Diffstat (limited to 'net/ipv4/inet_connection_sock.c')
-rw-r--r--  net/ipv4/inet_connection_sock.c  297
1 file changed, 221 insertions(+), 76 deletions(-)
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index eb31c7158b39..ebca860e113f 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -130,14 +130,75 @@ void inet_get_local_port_range(struct net *net, int *low, int *high)
}
EXPORT_SYMBOL(inet_get_local_port_range);
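+/* True when the socket binds to a specific address (not the IPv6 any or
+ * mapped address, not INADDR_ANY), in which case bind conflicts can be
+ * checked against the port+address (bhash2) bucket instead of every socket
+ * bound to the port.
+ */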
+static bool inet_use_bhash2_on_bind(const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6) {
+ int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
+
+ return addr_type != IPV6_ADDR_ANY &&
+ addr_type != IPV6_ADDR_MAPPED;
+ }
+#endif
+ return sk->sk_rcv_saddr != htonl(INADDR_ANY);
+}
+
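+/* Per-socket conflict check shared by the bhash and bhash2 paths: applies
+ * the SO_REUSEADDR/SO_REUSEPORT and SO_BINDTODEVICE rules for two sockets
+ * already known to share the same port. The address comparison itself is
+ * done by the callers.
+ */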
+static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2,
+ kuid_t sk_uid, bool relax,
+ bool reuseport_cb_ok, bool reuseport_ok)
+{
+ int bound_dev_if2;
+
+ if (sk == sk2)
+ return false;
+
+ bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
+
+ if (!sk->sk_bound_dev_if || !bound_dev_if2 ||
+ sk->sk_bound_dev_if == bound_dev_if2) {
+ if (sk->sk_reuse && sk2->sk_reuse &&
+ sk2->sk_state != TCP_LISTEN) {
+ if (!relax || (!reuseport_ok && sk->sk_reuseport &&
+ sk2->sk_reuseport && reuseport_cb_ok &&
+ (sk2->sk_state == TCP_TIME_WAIT ||
+ uid_eq(sk_uid, sock_i_uid(sk2)))))
+ return true;
+ } else if (!reuseport_ok || !sk->sk_reuseport ||
+ !sk2->sk_reuseport || !reuseport_cb_ok ||
+ (sk2->sk_state != TCP_TIME_WAIT &&
+ !uid_eq(sk_uid, sock_i_uid(sk2)))) {
+ return true;
+ }
+ }
+ return false;
+}
+
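+/* Walk every socket bound into this port+address (bhash2) bucket and report
+ * whether any of them conflicts with sk. IPv6-only sockets never conflict
+ * with an IPv4 sk.
+ */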
+static bool inet_bhash2_conflict(const struct sock *sk,
+ const struct inet_bind2_bucket *tb2,
+ kuid_t sk_uid,
+ bool relax, bool reuseport_cb_ok,
+ bool reuseport_ok)
+{
+ struct sock *sk2;
+
+ sk_for_each_bound_bhash2(sk2, &tb2->owners) {
+ if (sk->sk_family == AF_INET && ipv6_only_sock(sk2))
+ continue;
+
+ if (inet_bind_conflict(sk, sk2, sk_uid, relax,
+ reuseport_cb_ok, reuseport_ok))
+ return true;
+ }
+ return false;
+}
+
+/* This should be called only when the tb and tb2 hashbuckets' locks are held */
static int inet_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb,
+ const struct inet_bind2_bucket *tb2, /* may be null */
bool relax, bool reuseport_ok)
{
- struct sock *sk2;
bool reuseport_cb_ok;
- bool reuse = sk->sk_reuse;
- bool reuseport = !!sk->sk_reuseport;
struct sock_reuseport *reuseport_cb;
kuid_t uid = sock_i_uid((struct sock *)sk);
@@ -150,58 +211,88 @@ static int inet_csk_bind_conflict(const struct sock *sk,
/*
* Unlike other sk lookup places we do not check
* for sk_net here, since _all_ the socks listed
- * in tb->owners list belong to the same net - the
- * one this bucket belongs to.
+ * in tb->owners and tb2->owners lists belong
+ * to the same net - the one this bucket belongs to.
*/
- sk_for_each_bound(sk2, &tb->owners) {
- int bound_dev_if2;
+ if (!inet_use_bhash2_on_bind(sk)) {
+ struct sock *sk2;
- if (sk == sk2)
- continue;
- bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
- if ((!sk->sk_bound_dev_if ||
- !bound_dev_if2 ||
- sk->sk_bound_dev_if == bound_dev_if2)) {
- if (reuse && sk2->sk_reuse &&
- sk2->sk_state != TCP_LISTEN) {
- if ((!relax ||
- (!reuseport_ok &&
- reuseport && sk2->sk_reuseport &&
- reuseport_cb_ok &&
- (sk2->sk_state == TCP_TIME_WAIT ||
- uid_eq(uid, sock_i_uid(sk2))))) &&
- inet_rcv_saddr_equal(sk, sk2, true))
- break;
- } else if (!reuseport_ok ||
- !reuseport || !sk2->sk_reuseport ||
- !reuseport_cb_ok ||
- (sk2->sk_state != TCP_TIME_WAIT &&
- !uid_eq(uid, sock_i_uid(sk2)))) {
- if (inet_rcv_saddr_equal(sk, sk2, true))
- break;
- }
- }
+ sk_for_each_bound(sk2, &tb->owners)
+ if (inet_bind_conflict(sk, sk2, uid, relax,
+ reuseport_cb_ok, reuseport_ok) &&
+ inet_rcv_saddr_equal(sk, sk2, true))
+ return true;
+
+ return false;
+ }
+
+ /* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if
+ * ipv4) should have been checked already. We need to do these two
+ * checks separately because their spinlocks have to be acquired/released
+ * independently of each other, to prevent possible deadlocks
+ */
+ return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok);
+}
+
+/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or
+ * INADDR_ANY (if ipv4) socket.
+ *
+ * Caller must hold bhash hashbucket lock with local bh disabled, to protect
+ * against concurrent binds on the port for addr any
+ */
+static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev,
+ bool relax, bool reuseport_ok)
+{
+ kuid_t uid = sock_i_uid((struct sock *)sk);
+ const struct net *net = sock_net(sk);
+ struct sock_reuseport *reuseport_cb;
+ struct inet_bind_hashbucket *head2;
+ struct inet_bind2_bucket *tb2;
+ bool reuseport_cb_ok;
+
+ rcu_read_lock();
+ reuseport_cb = rcu_dereference(sk->sk_reuseport_cb);
+ /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */
+ reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks);
+ rcu_read_unlock();
+
+ head2 = inet_bhash2_addr_any_hashbucket(sk, net, port);
+
+ spin_lock(&head2->lock);
+
+ inet_bind_bucket_for_each(tb2, &head2->chain)
+ if (inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk))
+ break;
+
+ if (tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok)) {
+ spin_unlock(&head2->lock);
+ return true;
}
- return sk2 != NULL;
+
+ spin_unlock(&head2->lock);
+ return false;
}
/*
* Find an open port number for the socket. Returns with the
- * inet_bind_hashbucket lock held.
+ * inet_bind_hashbucket locks held if successful.
*/
static struct inet_bind_hashbucket *
-inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret)
+inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret,
+ struct inet_bind2_bucket **tb2_ret,
+ struct inet_bind_hashbucket **head2_ret, int *port_ret)
{
- struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
- int port = 0;
- struct inet_bind_hashbucket *head;
+ struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
+ int i, low, high, attempt_half, port, l3mdev;
+ struct inet_bind_hashbucket *head, *head2;
struct net *net = sock_net(sk);
- bool relax = false;
- int i, low, high, attempt_half;
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
u32 remaining, offset;
- int l3mdev;
+ bool relax = false;
l3mdev = inet_sk_bound_l3mdev(sk);
ports_exhausted:
@@ -239,11 +330,20 @@ other_parity_scan:
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
spin_lock_bh(&head->lock);
+ if (inet_use_bhash2_on_bind(sk)) {
+ if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false))
+ goto next_port;
+ }
+
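+ /* The bhash bucket lock is already held; the per-address bhash2 bucket
+  * lock nests inside it, matching the lock order used in
+  * inet_csk_get_port().
+  */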
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ spin_lock(&head2->lock);
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
- if (!inet_csk_bind_conflict(sk, tb, relax, false))
+ if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
+ if (!inet_csk_bind_conflict(sk, tb, tb2,
+ relax, false))
goto success;
+ spin_unlock(&head2->lock);
goto next_port;
}
tb = NULL;
@@ -272,6 +372,8 @@ next_port:
success:
*port_ret = port;
*tb_ret = tb;
+ *tb2_ret = tb2;
+ *head2_ret = head2;
return head;
}
@@ -365,56 +467,97 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb,
*/
int inet_csk_get_port(struct sock *sk, unsigned short snum)
{
+ struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
- struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
- int ret = 1, port = snum;
- struct inet_bind_hashbucket *head;
- struct net *net = sock_net(sk);
+ bool found_port = false, check_bind_conflict = true;
+ bool bhash_created = false, bhash2_created = false;
+ struct inet_bind_hashbucket *head, *head2;
+ struct inet_bind2_bucket *tb2 = NULL;
struct inet_bind_bucket *tb = NULL;
- int l3mdev;
+ bool head2_lock_acquired = false;
+ int ret = 1, port = snum, l3mdev;
+ struct net *net = sock_net(sk);
l3mdev = inet_sk_bound_l3mdev(sk);
if (!port) {
- head = inet_csk_find_open_port(sk, &tb, &port);
+ head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port);
if (!head)
return ret;
+
+ head2_lock_acquired = true;
+
+ if (tb && tb2)
+ goto success;
+ found_port = true;
+ } else {
+ head = &hinfo->bhash[inet_bhashfn(net, port,
+ hinfo->bhash_size)];
+ spin_lock_bh(&head->lock);
+ inet_bind_bucket_for_each(tb, &head->chain)
+ if (inet_bind_bucket_match(tb, net, port, l3mdev))
+ break;
+ }
+
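+ /* No bhash bucket for this port yet: create one and remember that we
+  * did, so it can be freed again if binding ultimately fails.
+  */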
+ if (!tb) {
+ tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net,
+ head, port, l3mdev);
if (!tb)
- goto tb_not_found;
- goto success;
+ goto fail_unlock;
+ bhash_created = true;
}
- head = &hinfo->bhash[inet_bhashfn(net, port,
- hinfo->bhash_size)];
- spin_lock_bh(&head->lock);
- inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port)
- goto tb_found;
-tb_not_found:
- tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
- net, head, port, l3mdev);
- if (!tb)
- goto fail_unlock;
-tb_found:
- if (!hlist_empty(&tb->owners)) {
- if (sk->sk_reuse == SK_FORCE_REUSE)
- goto success;
- if ((tb->fastreuse > 0 && reuse) ||
- sk_reuseport_match(tb, sk))
- goto success;
- if (inet_csk_bind_conflict(sk, tb, true, true))
+ if (!found_port) {
+ if (!hlist_empty(&tb->owners)) {
+ if (sk->sk_reuse == SK_FORCE_REUSE ||
+ (tb->fastreuse > 0 && reuse) ||
+ sk_reuseport_match(tb, sk))
+ check_bind_conflict = false;
+ }
+
+ if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) {
+ if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true))
+ goto fail_unlock;
+ }
+
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ spin_lock(&head2->lock);
+ head2_lock_acquired = true;
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+ }
+
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+ net, head2, port, l3mdev, sk);
+ if (!tb2)
goto fail_unlock;
+ bhash2_created = true;
}
+
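+ /* Explicitly requested port and no fastreuse/reuseport shortcut: run the
+  * full conflict check now that both buckets and their locks are held.
+  */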
+ if (!found_port && check_bind_conflict) {
+ if (inet_csk_bind_conflict(sk, tb, tb2, true, true))
+ goto fail_unlock;
+ }
+
success:
inet_csk_update_fastreuse(tb, sk);
if (!inet_csk(sk)->icsk_bind_hash)
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
WARN_ON(inet_csk(sk)->icsk_bind_hash != tb);
+ WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2);
ret = 0;
fail_unlock:
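+ /* On failure, undo any bucket allocations made above; pre-existing
+  * buckets are left untouched.
+  */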
+ if (ret) {
+ if (bhash_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ if (bhash2_created)
+ inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+ tb2);
+ }
+ if (head2_lock_acquired)
+ spin_unlock(&head2->lock);
spin_unlock_bh(&head->lock);
return ret;
}
@@ -763,14 +906,15 @@ static void reqsk_migrate_reset(struct request_sock *req)
/* return true if req was found in the ehash table */
static bool reqsk_queue_unlink(struct request_sock *req)
{
- struct inet_hashinfo *hashinfo = req_to_sk(req)->sk_prot->h.hashinfo;
+ struct sock *sk = req_to_sk(req);
bool found = false;
- if (sk_hashed(req_to_sk(req))) {
+ if (sk_hashed(sk)) {
+ struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
spinlock_t *lock = inet_ehash_lockp(hashinfo, req->rsk_hash);
spin_lock(lock);
- found = __sk_nulls_del_node_init_rcu(req_to_sk(req));
+ found = __sk_nulls_del_node_init_rcu(sk);
spin_unlock(lock);
}
if (timer_pending(&req->rsk_timer) && del_timer_sync(&req->rsk_timer))
@@ -962,6 +1106,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
inet_sk_set_state(newsk, TCP_SYN_RECV);
newicsk->icsk_bind_hash = NULL;
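+ /* Like icsk_bind_hash, the bhash2 link is not copied from the listener;
+  * it is set up when the child is added to the bind tables (see
+  * __inet_inherit_port()).
+  */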
+ newicsk->icsk_bind2_hash = NULL;
inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port;
inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num;