summaryrefslogtreecommitdiff
path: root/net/netfilter/nf_tables_api.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netfilter/nf_tables_api.c')
-rw-r--r--net/netfilter/nf_tables_api.c89
1 files changed, 57 insertions, 32 deletions
diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c
index e34d05cc5754..2b5f97e1d40b 100644
--- a/net/netfilter/nf_tables_api.c
+++ b/net/netfilter/nf_tables_api.c
@@ -4115,6 +4115,7 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk,
struct nft_table *table;
struct nft_set *set;
struct nft_ctx ctx;
+ size_t alloc_size;
char *name;
u64 size;
u64 timeout;
@@ -4263,8 +4264,10 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk,
size = 0;
if (ops->privsize != NULL)
size = ops->privsize(nla, &desc);
-
- set = kvzalloc(sizeof(*set) + size + udlen, GFP_KERNEL);
+ alloc_size = sizeof(*set) + size + udlen;
+ if (alloc_size < size)
+ return -ENOMEM;
+ set = kvzalloc(alloc_size, GFP_KERNEL);
if (!set)
return -ENOMEM;
@@ -4277,15 +4280,7 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk,
err = nf_tables_set_alloc_name(&ctx, set, name);
kfree(name);
if (err < 0)
- goto err_set_alloc_name;
-
- if (nla[NFTA_SET_EXPR]) {
- expr = nft_set_elem_expr_alloc(&ctx, set, nla[NFTA_SET_EXPR]);
- if (IS_ERR(expr)) {
- err = PTR_ERR(expr);
- goto err_set_alloc_name;
- }
- }
+ goto err_set_name;
udata = NULL;
if (udlen) {
@@ -4296,21 +4291,19 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk,
INIT_LIST_HEAD(&set->bindings);
set->table = table;
write_pnet(&set->net, net);
- set->ops = ops;
+ set->ops = ops;
set->ktype = ktype;
- set->klen = desc.klen;
+ set->klen = desc.klen;
set->dtype = dtype;
set->objtype = objtype;
- set->dlen = desc.dlen;
- set->expr = expr;
+ set->dlen = desc.dlen;
set->flags = flags;
- set->size = desc.size;
+ set->size = desc.size;
set->policy = policy;
- set->udlen = udlen;
- set->udata = udata;
+ set->udlen = udlen;
+ set->udata = udata;
set->timeout = timeout;
set->gc_int = gc_int;
- set->handle = nf_tables_alloc_handle(table);
set->field_count = desc.field_count;
for (i = 0; i < desc.field_count; i++)
@@ -4320,20 +4313,32 @@ static int nf_tables_newset(struct net *net, struct sock *nlsk,
if (err < 0)
goto err_set_init;
+ if (nla[NFTA_SET_EXPR]) {
+ expr = nft_set_elem_expr_alloc(&ctx, set, nla[NFTA_SET_EXPR]);
+ if (IS_ERR(expr)) {
+ err = PTR_ERR(expr);
+ goto err_set_expr_alloc;
+ }
+
+ set->expr = expr;
+ }
+
+ set->handle = nf_tables_alloc_handle(table);
+
err = nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set);
if (err < 0)
- goto err_set_trans;
+ goto err_set_expr_alloc;
list_add_tail_rcu(&set->list, &table->sets);
table->use++;
return 0;
-err_set_trans:
+err_set_expr_alloc:
+ if (set->expr)
+ nft_expr_destroy(&ctx, set->expr);
+
ops->destroy(set);
err_set_init:
- if (expr)
- nft_expr_destroy(&ctx, expr);
-err_set_alloc_name:
kfree(set->name);
err_set_name:
kvfree(set);
@@ -5145,6 +5150,24 @@ static void nf_tables_set_elem_destroy(const struct nft_ctx *ctx,
kfree(elem);
}
+static int nft_set_elem_expr_setup(struct nft_ctx *ctx,
+ const struct nft_set_ext *ext,
+ struct nft_expr *expr)
+{
+ struct nft_expr *elem_expr = nft_set_ext_expr(ext);
+ int err;
+
+ if (expr == NULL)
+ return 0;
+
+ err = nft_expr_clone(elem_expr, expr);
+ if (err < 0)
+ return -ENOMEM;
+
+ nft_expr_destroy(ctx, expr);
+ return 0;
+}
+
static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
const struct nlattr *attr, u32 nlmsg_flags)
{
@@ -5347,15 +5370,17 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
*nft_set_ext_obj(ext) = obj;
obj->use++;
}
- if (expr) {
- memcpy(nft_set_ext_expr(ext), expr, expr->ops->size);
- kfree(expr);
- expr = NULL;
- }
+
+ err = nft_set_elem_expr_setup(ctx, ext, expr);
+ if (err < 0)
+ goto err_elem_expr;
+ expr = NULL;
trans = nft_trans_elem_alloc(ctx, NFT_MSG_NEWSETELEM, set);
- if (trans == NULL)
- goto err_trans;
+ if (trans == NULL) {
+ err = -ENOMEM;
+ goto err_elem_expr;
+ }
ext->genmask = nft_genmask_cur(ctx->net) | NFT_SET_ELEM_BUSY_MASK;
err = set->ops->insert(ctx->net, set, &elem, &ext2);
@@ -5399,7 +5424,7 @@ err_set_full:
set->ops->remove(ctx->net, set, &elem);
err_element_clash:
kfree(trans);
-err_trans:
+err_elem_expr:
if (obj)
obj->use--;