diff options
Diffstat (limited to 'net/ipv4/raw_diag.c')
-rw-r--r-- | net/ipv4/raw_diag.c | 57 |
1 files changed, 30 insertions, 27 deletions
diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c index ccacbde30a2c..999321834b94 100644 --- a/net/ipv4/raw_diag.c +++ b/net/ipv4/raw_diag.c @@ -34,57 +34,57 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r) * use helper to figure it out. */ -static struct sock *raw_lookup(struct net *net, struct sock *from, - const struct inet_diag_req_v2 *req) +static bool raw_lookup(struct net *net, struct sock *sk, + const struct inet_diag_req_v2 *req) { struct inet_diag_req_raw *r = (void *)req; - struct sock *sk = NULL; if (r->sdiag_family == AF_INET) - sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol, - r->id.idiag_dst[0], - r->id.idiag_src[0], - r->id.idiag_if, 0); + return raw_v4_match(net, sk, r->sdiag_raw_protocol, + r->id.idiag_dst[0], + r->id.idiag_src[0], + r->id.idiag_if, 0); #if IS_ENABLED(CONFIG_IPV6) else - sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol, - (const struct in6_addr *)r->id.idiag_src, - (const struct in6_addr *)r->id.idiag_dst, - r->id.idiag_if, 0); + return raw_v6_match(net, sk, r->sdiag_raw_protocol, + (const struct in6_addr *)r->id.idiag_src, + (const struct in6_addr *)r->id.idiag_dst, + r->id.idiag_if, 0); #endif - return sk; + return false; } static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) { struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); - struct sock *sk = NULL, *s; + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; + struct sock *sk; int slot; if (IS_ERR(hashinfo)) return ERR_CAST(hashinfo); - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { - sk_for_each(s, &hashinfo->ht[slot]) { - sk = raw_lookup(net, s, r); - if (sk) { + hlist = &hashinfo->ht[slot]; + sk_nulls_for_each(sk, hnode, hlist) { + if (raw_lookup(net, sk, r)) { /* * Grab it and keep until we fill - * diag meaage to be reported, so + * diag message to be reported, so * caller should call sock_put then. - * We can do that because we're keeping - * hashinfo->lock here. */ - sock_hold(sk); - goto out_unlock; + if (refcount_inc_not_zero(&sk->sk_refcnt)) + goto out_unlock; } } } + sk = ERR_PTR(-ENOENT); out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); - return sk ? sk : ERR_PTR(-ENOENT); + return sk; } static int raw_diag_dump_one(struct netlink_callback *cb, @@ -142,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct net *net = sock_net(skb->sk); struct inet_diag_dump_data *cb_data; + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; int num, s_num, slot, s_slot; struct sock *sk = NULL; struct nlattr *bc; @@ -154,11 +156,12 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, s_slot = cb->args[0]; num = s_num = cb->args[1]; - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { num = 0; - sk_for_each(sk, &hashinfo->ht[slot]) { + hlist = &hashinfo->ht[slot]; + sk_nulls_for_each(sk, hnode, hlist) { struct inet_sock *inet = inet_sk(sk); if (!net_eq(sock_net(sk), net)) @@ -181,7 +184,7 @@ next: } out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); cb->args[0] = slot; cb->args[1] = num; |