summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--net/core/skmsg.c31
1 files changed, 15 insertions, 16 deletions
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 3e78f2a80747..881a5b290946 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -684,20 +684,8 @@ EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
struct sk_buff *skb)
{
- int ret;
-
- /* strparser clones the skb before handing it to a upper layer,
- * meaning we have the same data, but sk is NULL. We do want an
- * sk pointer though when we run the BPF program. So we set it
- * here and then NULL it to ensure we don't trigger a BUG_ON()
- * in skb/sk operations later if kfree_skb is called with a
- * valid skb->sk pointer and no destructor assigned.
- */
- skb->sk = psock->sk;
bpf_compute_data_end_sk_skb(skb);
- ret = bpf_prog_run_pin_on_cpu(prog, skb);
- skb->sk = NULL;
- return ret;
+ return bpf_prog_run_pin_on_cpu(prog, skb);
}
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
@@ -736,10 +724,11 @@ static void sk_psock_skb_redirect(struct sk_buff *skb)
schedule_work(&psock_other->work);
}
-static void sk_psock_tls_verdict_apply(struct sk_buff *skb, int verdict)
+static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
{
switch (verdict) {
case __SK_REDIRECT:
+ skb_set_owner_r(skb, sk);
sk_psock_skb_redirect(skb);
break;
case __SK_PASS:
@@ -757,11 +746,17 @@ int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
rcu_read_lock();
prog = READ_ONCE(psock->progs.skb_verdict);
if (likely(prog)) {
+ /* We skip full set_owner_r here because if we do a SK_PASS
+ * or SK_DROP we can skip skb memory accounting and use the
+ * TLS context.
+ */
+ skb->sk = psock->sk;
tcp_skb_bpf_redirect_clear(skb);
ret = sk_psock_bpf_run(psock, prog, skb);
ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
+ skb->sk = NULL;
}
- sk_psock_tls_verdict_apply(skb, ret);
+ sk_psock_tls_verdict_apply(skb, psock->sk, ret);
rcu_read_unlock();
return ret;
}
@@ -823,6 +818,7 @@ static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
kfree_skb(skb);
goto out;
}
+ skb_set_owner_r(skb, sk);
prog = READ_ONCE(psock->progs.skb_verdict);
if (likely(prog)) {
tcp_skb_bpf_redirect_clear(skb);
@@ -847,8 +843,11 @@ static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
rcu_read_lock();
prog = READ_ONCE(psock->progs.skb_parser);
- if (likely(prog))
+ if (likely(prog)) {
+ skb->sk = psock->sk;
ret = sk_psock_bpf_run(psock, prog, skb);
+ skb->sk = NULL;
+ }
rcu_read_unlock();
return ret;
}