diff options
Diffstat (limited to 'include/net/sock.h')
-rw-r--r-- | include/net/sock.h | 48 |
1 files changed, 34 insertions, 14 deletions
diff --git a/include/net/sock.h b/include/net/sock.h index 2eb916d1ff64..b770261fbdaf 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1053,6 +1053,12 @@ static inline void sk_wmem_queued_add(struct sock *sk, int val) WRITE_ONCE(sk->sk_wmem_queued, sk->sk_wmem_queued + val); } +static inline void sk_forward_alloc_add(struct sock *sk, int val) +{ + /* Paired with lockless reads of sk->sk_forward_alloc */ + WRITE_ONCE(sk->sk_forward_alloc, sk->sk_forward_alloc + val); +} + void sk_stream_write_space(struct sock *sk); /* OOB backlog add */ @@ -1323,6 +1329,7 @@ struct proto { /* * Pressure flag: try to collapse. * Technical note: it is used by multiple contexts non atomically. + * Make sure to use READ_ONCE()/WRITE_ONCE() for all reads/writes. * All the __sk_mem_schedule() is of this nature: accounting * is strict, actions are advisory and have some latency. */ @@ -1339,6 +1346,7 @@ struct proto { struct kmem_cache *slab; unsigned int obj_size; + unsigned int ipv6_pinfo_offset; slab_flags_t slab_flags; unsigned int useroffset; /* Usercopy region offset */ unsigned int usersize; /* Usercopy region size */ @@ -1375,7 +1383,7 @@ static inline int sk_forward_alloc_get(const struct sock *sk) if (sk->sk_prot->forward_alloc_get) return sk->sk_prot->forward_alloc_get(sk); #endif - return sk->sk_forward_alloc; + return READ_ONCE(sk->sk_forward_alloc); } static inline bool __sk_stream_memory_free(const struct sock *sk, int wake) @@ -1420,6 +1428,12 @@ static inline bool sk_has_memory_pressure(const struct sock *sk) return sk->sk_prot->memory_pressure != NULL; } +static inline bool sk_under_global_memory_pressure(const struct sock *sk) +{ + return sk->sk_prot->memory_pressure && + !!READ_ONCE(*sk->sk_prot->memory_pressure); +} + static inline bool sk_under_memory_pressure(const struct sock *sk) { if (!sk->sk_prot->memory_pressure) @@ -1429,7 +1443,7 @@ static inline bool sk_under_memory_pressure(const struct sock *sk) mem_cgroup_under_socket_pressure(sk->sk_memcg)) return true; - return !!*sk->sk_prot->memory_pressure; + return !!READ_ONCE(*sk->sk_prot->memory_pressure); } static inline long @@ -1506,7 +1520,7 @@ proto_memory_pressure(struct proto *prot) { if (!prot->memory_pressure) return false; - return !!*prot->memory_pressure; + return !!READ_ONCE(*prot->memory_pressure); } @@ -1665,14 +1679,14 @@ static inline void sk_mem_charge(struct sock *sk, int size) { if (!sk_has_account(sk)) return; - sk->sk_forward_alloc -= size; + sk_forward_alloc_add(sk, -size); } static inline void sk_mem_uncharge(struct sock *sk, int size) { if (!sk_has_account(sk)) return; - sk->sk_forward_alloc += size; + sk_forward_alloc_add(sk, size); sk_mem_reclaim(sk); } @@ -1892,7 +1906,9 @@ struct sockcm_cookie { static inline void sockcm_init(struct sockcm_cookie *sockc, const struct sock *sk) { - *sockc = (struct sockcm_cookie) { .tsflags = sk->sk_tsflags }; + *sockc = (struct sockcm_cookie) { + .tsflags = READ_ONCE(sk->sk_tsflags) + }; } int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg, @@ -2687,9 +2703,9 @@ void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk, static inline void sock_recv_timestamp(struct msghdr *msg, struct sock *sk, struct sk_buff *skb) { - ktime_t kt = skb->tstamp; struct skb_shared_hwtstamps *hwtstamps = skb_hwtstamps(skb); - + u32 tsflags = READ_ONCE(sk->sk_tsflags); + ktime_t kt = skb->tstamp; /* * generate control messages if * - receive time stamping in software requested @@ -2697,10 +2713,10 @@ sock_recv_timestamp(struct msghdr *msg, struct sock *sk, struct sk_buff *skb) * - hardware time stamps available and wanted */ if (sock_flag(sk, SOCK_RCVTSTAMP) || - (sk->sk_tsflags & SOF_TIMESTAMPING_RX_SOFTWARE) || - (kt && sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) || + (tsflags & SOF_TIMESTAMPING_RX_SOFTWARE) || + (kt && tsflags & SOF_TIMESTAMPING_SOFTWARE) || (hwtstamps->hwtstamp && - (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE))) + (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE))) __sock_recv_timestamp(msg, sk, skb); else sock_write_timestamp(sk, kt); @@ -2722,7 +2738,8 @@ static inline void sock_recv_cmsgs(struct msghdr *msg, struct sock *sk, #define TSFLAGS_ANY (SOF_TIMESTAMPING_SOFTWARE | \ SOF_TIMESTAMPING_RAW_HARDWARE) - if (sk->sk_flags & FLAGS_RECV_CMSGS || sk->sk_tsflags & TSFLAGS_ANY) + if (sk->sk_flags & FLAGS_RECV_CMSGS || + READ_ONCE(sk->sk_tsflags) & TSFLAGS_ANY) __sock_recv_cmsgs(msg, sk, skb); else if (unlikely(sock_flag(sk, SOCK_TIMESTAMP))) sock_write_timestamp(sk, skb->tstamp); @@ -2814,20 +2831,23 @@ sk_is_refcounted(struct sock *sk) * skb_steal_sock - steal a socket from an sk_buff * @skb: sk_buff to steal the socket from * @refcounted: is set to true if the socket is reference-counted + * @prefetched: is set to true if the socket was assigned from bpf */ static inline struct sock * -skb_steal_sock(struct sk_buff *skb, bool *refcounted) +skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched) { if (skb->sk) { struct sock *sk = skb->sk; *refcounted = true; - if (skb_sk_is_prefetched(skb)) + *prefetched = skb_sk_is_prefetched(skb); + if (*prefetched) *refcounted = sk_is_refcounted(sk); skb->destructor = NULL; skb->sk = NULL; return sk; } + *prefetched = false; *refcounted = false; return NULL; } |