summaryrefslogtreecommitdiff
path: root/net/mptcp
diff options
context:
space:
mode:
Diffstat (limited to 'net/mptcp')
-rw-r--r--net/mptcp/options.c3
-rw-r--r--net/mptcp/pm_netlink.c5
-rw-r--r--net/mptcp/protocol.c38
-rw-r--r--net/mptcp/protocol.h2
-rw-r--r--net/mptcp/sockopt.c2
-rw-r--r--net/mptcp/token.c14
-rw-r--r--net/mptcp/token_test.c3
7 files changed, 42 insertions, 25 deletions
diff --git a/net/mptcp/options.c b/net/mptcp/options.c
index 5ded85e2c374..b30cea2fbf3f 100644
--- a/net/mptcp/options.c
+++ b/net/mptcp/options.c
@@ -1594,8 +1594,7 @@ mp_rst:
TCPOLEN_MPTCP_PRIO,
opts->backup, TCPOPT_NOP);
- MPTCP_INC_STATS(sock_net((const struct sock *)tp),
- MPTCP_MIB_MPPRIOTX);
+ MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPPRIOTX);
}
mp_capable_done:
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index 2ea7eae43bdb..b5505b8167f9 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -1143,7 +1143,7 @@ void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ss
if (!tcp_rtx_and_write_queues_empty(ssk)) {
subflow->stale = 1;
__mptcp_retransmit_pending_data(sk);
- MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE);
+ MPTCP_INC_STATS(net, MPTCP_MIB_SUBFLOWSTALE);
}
unlock_sock_fast(ssk, slow);
@@ -1903,8 +1903,7 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
}
if (token)
- return mptcp_userspace_pm_set_flags(sock_net(skb->sk),
- token, &addr, &remote, bkup);
+ return mptcp_userspace_pm_set_flags(net, token, &addr, &remote, bkup);
spin_lock_bh(&pernet->lock);
entry = __lookup_addr(pernet, &addr.addr, lookup_by_id);
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 8cd6cc67c2c5..1ff2efbab29e 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -923,9 +923,8 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
{
struct mptcp_subflow_context *subflow;
- struct sock *sk = (struct sock *)msk;
- sock_owned_by_me(sk);
+ msk_owned_by_me(msk);
mptcp_for_each_subflow(msk, subflow) {
if (READ_ONCE(subflow->data_avail))
@@ -1408,7 +1407,7 @@ static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk)
u64 linger_time;
long tout = 0;
- sock_owned_by_me(sk);
+ msk_owned_by_me(msk);
if (__mptcp_check_fallback(msk)) {
if (!msk->first)
@@ -1890,7 +1889,7 @@ static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
u32 time, advmss = 1;
u64 rtt_us, mstamp;
- sock_owned_by_me(sk);
+ msk_owned_by_me(msk);
if (copied <= 0)
return;
@@ -2217,7 +2216,7 @@ static struct sock *mptcp_subflow_get_retrans(struct mptcp_sock *msk)
struct mptcp_subflow_context *subflow;
int min_stale_count = INT_MAX;
- sock_owned_by_me((const struct sock *)msk);
+ msk_owned_by_me(msk);
if (__mptcp_check_fallback(msk))
return NULL;
@@ -2724,8 +2723,8 @@ static int mptcp_init_sock(struct sock *sk)
mptcp_ca_reset(sk);
sk_sockets_allocated_inc(sk);
- sk->sk_rcvbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1]);
- sk->sk_sndbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1]);
+ sk->sk_rcvbuf = READ_ONCE(net->ipv4.sysctl_tcp_rmem[1]);
+ sk->sk_sndbuf = READ_ONCE(net->ipv4.sysctl_tcp_wmem[1]);
return 0;
}
@@ -2892,6 +2891,12 @@ static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
return EPOLLIN | EPOLLRDNORM;
}
+static void mptcp_listen_inuse_dec(struct sock *sk)
+{
+ if (inet_sk_state_load(sk) == TCP_LISTEN)
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
+}
+
bool __mptcp_close(struct sock *sk, long timeout)
{
struct mptcp_subflow_context *subflow;
@@ -2901,6 +2906,7 @@ bool __mptcp_close(struct sock *sk, long timeout)
sk->sk_shutdown = SHUTDOWN_MASK;
if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
+ mptcp_listen_inuse_dec(sk);
inet_sk_state_store(sk, TCP_CLOSE);
goto cleanup;
}
@@ -3001,6 +3007,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)
if (msk->fastopening)
return 0;
+ mptcp_listen_inuse_dec(sk);
inet_sk_state_store(sk, TCP_CLOSE);
mptcp_stop_timer(sk);
@@ -3639,12 +3646,13 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
static int mptcp_listen(struct socket *sock, int backlog)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
+ struct sock *sk = sock->sk;
struct socket *ssock;
int err;
pr_debug("msk=%p", msk);
- lock_sock(sock->sk);
+ lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk);
if (!ssock) {
err = -EINVAL;
@@ -3652,18 +3660,20 @@ static int mptcp_listen(struct socket *sock, int backlog)
}
mptcp_token_destroy(msk);
- inet_sk_state_store(sock->sk, TCP_LISTEN);
- sock_set_flag(sock->sk, SOCK_RCU_FREE);
+ inet_sk_state_store(sk, TCP_LISTEN);
+ sock_set_flag(sk, SOCK_RCU_FREE);
err = ssock->ops->listen(ssock, backlog);
- inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
- if (!err)
- mptcp_copy_inaddrs(sock->sk, ssock->sk);
+ inet_sk_state_store(sk, inet_sk_state_load(ssock->sk));
+ if (!err) {
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
+ mptcp_copy_inaddrs(sk, ssock->sk);
+ }
mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED);
unlock:
- release_sock(sock->sk);
+ release_sock(sk);
return err;
}
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index 601469249da8..61fd8eabfca2 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -755,7 +755,7 @@ static inline void mptcp_token_init_request(struct request_sock *req)
int mptcp_token_new_request(struct request_sock *req);
void mptcp_token_destroy_request(struct request_sock *req);
-int mptcp_token_new_connect(struct sock *sk);
+int mptcp_token_new_connect(struct sock *ssk);
void mptcp_token_accept(struct mptcp_subflow_request_sock *r,
struct mptcp_sock *msk);
bool mptcp_token_exists(u32 token);
diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c
index d4b1e6ec1b36..582ed93bcc8a 100644
--- a/net/mptcp/sockopt.c
+++ b/net/mptcp/sockopt.c
@@ -18,7 +18,7 @@
static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{
- sock_owned_by_me((const struct sock *)msk);
+ msk_owned_by_me(msk);
if (likely(!__mptcp_check_fallback(msk)))
return NULL;
diff --git a/net/mptcp/token.c b/net/mptcp/token.c
index 65430f314a68..5bb924534387 100644
--- a/net/mptcp/token.c
+++ b/net/mptcp/token.c
@@ -134,7 +134,7 @@ int mptcp_token_new_request(struct request_sock *req)
/**
* mptcp_token_new_connect - create new key/idsn/token for subflow
- * @sk: the socket that will initiate a connection
+ * @ssk: the socket that will initiate a connection
*
* This function is called when a new outgoing mptcp connection is
* initiated.
@@ -148,11 +148,12 @@ int mptcp_token_new_request(struct request_sock *req)
*
* returns 0 on success.
*/
-int mptcp_token_new_connect(struct sock *sk)
+int mptcp_token_new_connect(struct sock *ssk)
{
- struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
+ struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
struct mptcp_sock *msk = mptcp_sk(subflow->conn);
int retries = MPTCP_TOKEN_MAX_RETRIES;
+ struct sock *sk = subflow->conn;
struct token_bucket *bucket;
again:
@@ -169,12 +170,13 @@ again:
}
pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
- sk, subflow->local_key, subflow->token, subflow->idsn);
+ ssk, subflow->local_key, subflow->token, subflow->idsn);
WRITE_ONCE(msk->token, subflow->token);
__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
bucket->chain_len++;
spin_unlock_bh(&bucket->lock);
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
return 0;
}
@@ -190,8 +192,10 @@ void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
struct mptcp_sock *msk)
{
struct mptcp_subflow_request_sock *pos;
+ struct sock *sk = (struct sock *)msk;
struct token_bucket *bucket;
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
bucket = token_bucket(req->token);
spin_lock_bh(&bucket->lock);
@@ -370,12 +374,14 @@ void mptcp_token_destroy_request(struct request_sock *req)
*/
void mptcp_token_destroy(struct mptcp_sock *msk)
{
+ struct sock *sk = (struct sock *)msk;
struct token_bucket *bucket;
struct mptcp_sock *pos;
if (sk_unhashed((struct sock *)msk))
return;
+ sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
bucket = token_bucket(msk->token);
spin_lock_bh(&bucket->lock);
pos = __token_lookup_msk(bucket, msk->token);
diff --git a/net/mptcp/token_test.c b/net/mptcp/token_test.c
index 5d984bec1cd8..0758865ab658 100644
--- a/net/mptcp/token_test.c
+++ b/net/mptcp/token_test.c
@@ -57,6 +57,9 @@ static struct mptcp_sock *build_msk(struct kunit *test)
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk);
refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
sock_net_set((struct sock *)msk, &init_net);
+
+ /* be sure the token helpers can dereference sk->sk_prot */
+ ((struct sock *)msk)->sk_prot = &tcp_prot;
return msk;
}