From df1c631648c55bfb247339279f9bc573c7f283f4 Mon Sep 17 00:00:00 2001 From: David Ahern Date: Fri, 31 Mar 2017 07:14:02 -0700 Subject: net: mpls: Limit memory allocation for mpls_route Limit memory allocation size for mpls_route to 4096. Signed-off-by: David Ahern Signed-off-by: David S. Miller --- net/mpls/af_mpls.c | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) (limited to 'net/mpls') diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index 1863b94133e4..f84c52b6eafc 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -26,6 +26,9 @@ #define MAX_NEW_LABELS 2 +/* max memory we will use for mpls_route */ +#define MAX_MPLS_ROUTE_MEM 4096 + /* Maximum number of labels to look ahead at when selecting a path of * a multipath route */ @@ -477,14 +480,20 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels) { u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen); struct mpls_route *rt; + size_t size; - rt = kzalloc(sizeof(*rt) + num_nh * nh_size, GFP_KERNEL); - if (rt) { - rt->rt_nhn = num_nh; - rt->rt_nhn_alive = num_nh; - rt->rt_nh_size = nh_size; - rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels); - } + size = sizeof(*rt) + num_nh * nh_size; + if (size > MAX_MPLS_ROUTE_MEM) + return ERR_PTR(-EINVAL); + + rt = kzalloc(size, GFP_KERNEL); + if (!rt) + return ERR_PTR(-ENOMEM); + + rt->rt_nhn = num_nh; + rt->rt_nhn_alive = num_nh; + rt->rt_nh_size = nh_size; + rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels); return rt; } @@ -898,8 +907,10 @@ static int mpls_route_add(struct mpls_route_config *cfg) err = -ENOMEM; rt = mpls_rt_alloc(nhs, max_via_alen, MAX_NEW_LABELS); - if (!rt) + if (IS_ERR(rt)) { + err = PTR_ERR(rt); goto errout; + } rt->rt_protocol = cfg->rc_protocol; rt->rt_payload_type = cfg->rc_payload_type; @@ -1970,7 +1981,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) if (limit > MPLS_LABEL_IPV4NULL) { struct net_device *lo = net->loopback_dev; rt0 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS); - if (!rt0) + if (IS_ERR(rt0)) goto nort0; RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo); rt0->rt_protocol = RTPROT_KERNEL; @@ -1984,7 +1995,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) if (limit > MPLS_LABEL_IPV6NULL) { struct net_device *lo = net->loopback_dev; rt2 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS); - if (!rt2) + if (IS_ERR(rt2)) goto nort2; RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo); rt2->rt_protocol = RTPROT_KERNEL; -- cgit v1.2.3