diff options
Diffstat (limited to 'net/netfilter/nf_tables_api.c')
-rw-r--r-- | net/netfilter/nf_tables_api.c | 89 |
1 files changed, 57 insertions, 32 deletions
diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index e34d05cc5754..2b5f97e1d40b 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -4115,6 +4115,7 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk, struct nft_table *table; struct nft_set *set; struct nft_ctx ctx; + size_t alloc_size; char *name; u64 size; u64 timeout; @@ -4263,8 +4264,10 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk, size = 0; if (ops->privsize != NULL) size = ops->privsize(nla, &desc); - - set = kvzalloc(sizeof(*set) + size + udlen, GFP_KERNEL); + alloc_size = sizeof(*set) + size + udlen; + if (alloc_size < size) + return -ENOMEM; + set = kvzalloc(alloc_size, GFP_KERNEL); if (!set) return -ENOMEM; @@ -4277,15 +4280,7 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk, err = nf_tables_set_alloc_name(&ctx, set, name); kfree(name); if (err < 0) - goto err_set_alloc_name; - - if (nla[NFTA_SET_EXPR]) { - expr = nft_set_elem_expr_alloc(&ctx, set, nla[NFTA_SET_EXPR]); - if (IS_ERR(expr)) { - err = PTR_ERR(expr); - goto err_set_alloc_name; - } - } + goto err_set_name; udata = NULL; if (udlen) { @@ -4296,21 +4291,19 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk, INIT_LIST_HEAD(&set->bindings); set->table = table; write_pnet(&set->net, net); - set->ops = ops; + set->ops = ops; set->ktype = ktype; - set->klen = desc.klen; + set->klen = desc.klen; set->dtype = dtype; set->objtype = objtype; - set->dlen = desc.dlen; - set->expr = expr; + set->dlen = desc.dlen; set->flags = flags; - set->size = desc.size; + set->size = desc.size; set->policy = policy; - set->udlen = udlen; - set->udata = udata; + set->udlen = udlen; + set->udata = udata; set->timeout = timeout; set->gc_int = gc_int; - set->handle = nf_tables_alloc_handle(table); set->field_count = desc.field_count; for (i = 0; i < desc.field_count; i++) @@ -4320,20 +4313,32 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk, if (err < 0) goto err_set_init; + if (nla[NFTA_SET_EXPR]) { + expr = nft_set_elem_expr_alloc(&ctx, set, nla[NFTA_SET_EXPR]); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_expr_alloc; + } + + set->expr = expr; + } + + set->handle = nf_tables_alloc_handle(table); + err = nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set); if (err < 0) - goto err_set_trans; + goto err_set_expr_alloc; list_add_tail_rcu(&set->list, &table->sets); table->use++; return 0; -err_set_trans: +err_set_expr_alloc: + if (set->expr) + nft_expr_destroy(&ctx, set->expr); + ops->destroy(set); err_set_init: - if (expr) - nft_expr_destroy(&ctx, expr); -err_set_alloc_name: kfree(set->name); err_set_name: kvfree(set); @@ -5145,6 +5150,24 @@ static void nf_tables_set_elem_destroy(const struct nft_ctx *ctx, kfree(elem); } +static int nft_set_elem_expr_setup(struct nft_ctx *ctx, + const struct nft_set_ext *ext, + struct nft_expr *expr) +{ + struct nft_expr *elem_expr = nft_set_ext_expr(ext); + int err; + + if (expr == NULL) + return 0; + + err = nft_expr_clone(elem_expr, expr); + if (err < 0) + return -ENOMEM; + + nft_expr_destroy(ctx, expr); + return 0; +} + static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, const struct nlattr *attr, u32 nlmsg_flags) { @@ -5347,15 +5370,17 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, *nft_set_ext_obj(ext) = obj; obj->use++; } - if (expr) { - memcpy(nft_set_ext_expr(ext), expr, expr->ops->size); - kfree(expr); - expr = NULL; - } + + err = nft_set_elem_expr_setup(ctx, ext, expr); + if (err < 0) + goto err_elem_expr; + expr = NULL; trans = nft_trans_elem_alloc(ctx, NFT_MSG_NEWSETELEM, set); - if (trans == NULL) - goto err_trans; + if (trans == NULL) { + err = -ENOMEM; + goto err_elem_expr; + } ext->genmask = nft_genmask_cur(ctx->net) | NFT_SET_ELEM_BUSY_MASK; err = set->ops->insert(ctx->net, set, &elem, &ext2); @@ -5399,7 +5424,7 @@ err_set_full: set->ops->remove(ctx->net, set, &elem); err_element_clash: kfree(trans); -err_trans: +err_elem_expr: if (obj) obj->use--; |