From 30bab7cdb56da4819ff081ad658646f2df16c098 Mon Sep 17 00:00:00 2001 From: Jiri Pirko Date: Mon, 25 Jul 2022 10:29:14 +0200 Subject: net: devlink: make sure that devlink_try_get() works with valid pointer during xarray iteration Remove dependency on devlink_mutex during devlinks xarray iteration. The reason is that devlink_register/unregister() functions taking devlink_mutex would deadlock during devlink reload operation of devlink instance which registers/unregisters nested devlink instances. The devlinks xarray consistency is ensured internally by xarray. There is a reference taken when working with devlink using devlink_try_get(). But there is no guarantee that devlink pointer picked during xarray iteration is not freed before devlink_try_get() is called. Make sure that devlink_try_get() works with valid pointer. Achieve it by: 1) Splitting devlink_put() so the completion is sent only after grace period. Completion unblocks the devlink_unregister() routine, which is followed-up by devlink_free() 2) During devlinks xa_array iteration, get devlink pointer from xa_array holding RCU read lock and taking reference using devlink_try_get() before unlock. Signed-off-by: Jiri Pirko Reviewed-by: Jakub Kicinski Signed-off-by: Jakub Kicinski --- net/core/devlink.c | 171 +++++++++++++++++++++++++---------------------------- 1 file changed, 80 insertions(+), 91 deletions(-) diff --git a/net/core/devlink.c b/net/core/devlink.c index 98d79feeb3dc..c7abd928f389 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -70,6 +70,7 @@ struct devlink { u8 reload_failed:1; refcount_t refcount; struct completion comp; + struct rcu_head rcu; char priv[] __aligned(NETDEV_ALIGN); }; @@ -221,8 +222,6 @@ static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC); /* devlink_mutex * * An overall lock guarding every operation coming from userspace. - * It also guards devlink devices list and it is taken when - * driver registers/unregisters it. */ static DEFINE_MUTEX(devlink_mutex); @@ -232,10 +231,21 @@ struct net *devlink_net(const struct devlink *devlink) } EXPORT_SYMBOL_GPL(devlink_net); +static void __devlink_put_rcu(struct rcu_head *head) +{ + struct devlink *devlink = container_of(head, struct devlink, rcu); + + complete(&devlink->comp); +} + void devlink_put(struct devlink *devlink) { if (refcount_dec_and_test(&devlink->refcount)) - complete(&devlink->comp); + /* Make sure unregister operation that may await the completion + * is unblocked only after all users are after the end of + * RCU grace period. + */ + call_rcu(&devlink->rcu, __devlink_put_rcu); } struct devlink *__must_check devlink_try_get(struct devlink *devlink) @@ -278,12 +288,55 @@ void devl_unlock(struct devlink *devlink) } EXPORT_SYMBOL_GPL(devl_unlock); +static struct devlink * +devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter, + void * (*xa_find_fn)(struct xarray *, unsigned long *, + unsigned long, xa_mark_t)) +{ + struct devlink *devlink; + + rcu_read_lock(); +retry: + devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED); + if (!devlink) + goto unlock; + /* For a possible retry, the xa_find_after() should be always used */ + xa_find_fn = xa_find_after; + if (!devlink_try_get(devlink)) + goto retry; +unlock: + rcu_read_unlock(); + return devlink; +} + +static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(indexp, filter, xa_find); +} + +static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(indexp, filter, xa_find_after); +} + +/* Iterate over devlink pointers which were possible to get reference to. + * devlink_put() needs to be called for each iterated devlink pointer + * in loop body in order to release the reference. + */ +#define devlinks_xa_for_each_get(index, devlink, filter) \ + for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter); \ + devlink; devlink = devlinks_xa_find_get_next(&index, filter)) + +#define devlinks_xa_for_each_registered_get(index, devlink) \ + devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED) + static struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) { struct devlink *devlink; unsigned long index; - bool found = false; char *busname; char *devname; @@ -293,21 +346,15 @@ static struct devlink *devlink_get_from_attrs(struct net *net, busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]); devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); - lockdep_assert_held(&devlink_mutex); - - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { + devlinks_xa_for_each_registered_get(index, devlink) { if (strcmp(devlink->dev->bus->name, busname) == 0 && strcmp(dev_name(devlink->dev), devname) == 0 && - net_eq(devlink_net(devlink), net)) { - found = true; - break; - } + net_eq(devlink_net(devlink), net)) + return devlink; + devlink_put(devlink); } - if (!found || !devlink_try_get(devlink)) - devlink = ERR_PTR(-ENODEV); - - return devlink; + return ERR_PTR(-ENODEV); } static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink, @@ -1329,10 +1376,7 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -1432,10 +1476,7 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) { devlink_put(devlink); continue; @@ -1495,10 +1536,7 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -2177,10 +2215,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -2449,10 +2484,7 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -2601,10 +2633,7 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || !devlink->ops->sb_pool_get) goto retry; @@ -2822,10 +2851,7 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || !devlink->ops->sb_port_pool_get) goto retry; @@ -3071,10 +3097,7 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || !devlink->ops->sb_tc_pool_bind_get) goto retry; @@ -5158,10 +5181,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -5393,10 +5413,7 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -5977,10 +5994,7 @@ static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -6511,10 +6525,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -7691,10 +7702,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry_rep; @@ -7721,10 +7729,7 @@ retry_rep: devlink_put(devlink); } - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry_port; @@ -8291,10 +8296,7 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -8518,10 +8520,7 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -8832,10 +8831,7 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) goto retry; @@ -9589,10 +9585,8 @@ void devlink_register(struct devlink *devlink) ASSERT_DEVLINK_NOT_REGISTERED(devlink); /* Make sure that we are in .probe() routine */ - mutex_lock(&devlink_mutex); xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); devlink_notify_register(devlink); - mutex_unlock(&devlink_mutex); } EXPORT_SYMBOL_GPL(devlink_register); @@ -9609,10 +9603,8 @@ void devlink_unregister(struct devlink *devlink) devlink_put(devlink); wait_for_completion(&devlink->comp); - mutex_lock(&devlink_mutex); devlink_notify_unregister(devlink); xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); - mutex_unlock(&devlink_mutex); } EXPORT_SYMBOL_GPL(devlink_unregister); @@ -12281,10 +12273,7 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net) * all devlink instances from this namespace into init_net. */ mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - + devlinks_xa_for_each_registered_get(index, devlink) { if (!net_eq(devlink_net(devlink), net)) goto retry; -- cgit v1.2.3