summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/net/dst.h33
-rw-r--r--include/net/sock.h2
-rw-r--r--net/ipv4/tcp_ipv4.c9
-rw-r--r--net/ipv6/tcp_ipv6.c11
4 files changed, 45 insertions, 10 deletions
diff --git a/include/net/dst.h b/include/net/dst.h
index 30cd2f9cd1dd..d30afbdc1a59 100644
--- a/include/net/dst.h
+++ b/include/net/dst.h
@@ -306,6 +306,39 @@ static inline void skb_dst_force(struct sk_buff *skb)
}
}
+/**
+ * dst_hold_safe - Take a reference on a dst if possible
+ * @dst: pointer to dst entry
+ *
+ * This helper returns false if it could not safely
+ * take a reference on a dst.
+ */
+static inline bool dst_hold_safe(struct dst_entry *dst)
+{
+ if (dst->flags & DST_NOCACHE)
+ return atomic_inc_not_zero(&dst->__refcnt);
+ dst_hold(dst);
+ return true;
+}
+
+/**
+ * skb_dst_force_safe - makes sure skb dst is refcounted
+ * @skb: buffer
+ *
+ * If dst is not yet refcounted and not destroyed, grab a ref on it.
+ */
+static inline void skb_dst_force_safe(struct sk_buff *skb)
+{
+ if (skb_dst_is_noref(skb)) {
+ struct dst_entry *dst = skb_dst(skb);
+
+ if (!dst_hold_safe(dst))
+ dst = NULL;
+
+ skb->_skb_refdst = (unsigned long)dst;
+ }
+}
+
/**
* __skb_tunnel_rx - prepare skb for rx reinsert
diff --git a/include/net/sock.h b/include/net/sock.h
index 41d98f1d0459..6ed6df149bce 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -760,7 +760,7 @@ extern void sk_stream_write_space(struct sock *sk);
static inline void __sk_add_backlog(struct sock *sk, struct sk_buff *skb)
{
/* dont let skb dst not refcounted, we are going to leave rcu lock */
- skb_dst_force(skb);
+ skb_dst_force_safe(skb);
if (!sk->sk_backlog.tail)
sk->sk_backlog.head = skb;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 624ceca7ffd1..09451a2cbd6a 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1905,7 +1905,7 @@ bool tcp_prequeue(struct sock *sk, struct sk_buff *skb)
skb_queue_len(&tp->ucopy.prequeue) == 0)
return false;
- skb_dst_force(skb);
+ skb_dst_force_safe(skb);
__skb_queue_tail(&tp->ucopy.prequeue, skb);
tp->ucopy.memory += skb->truesize;
if (tp->ucopy.memory > sk->sk_rcvbuf) {
@@ -2098,9 +2098,10 @@ void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
{
struct dst_entry *dst = skb_dst(skb);
- dst_hold(dst);
- sk->sk_rx_dst = dst;
- inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
+ if (dst_hold_safe(dst)) {
+ sk->sk_rx_dst = dst;
+ inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
+ }
}
EXPORT_SYMBOL(inet_sk_rx_dst_set);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 65c310d6e92a..90004c6e3bff 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -97,11 +97,12 @@ static void inet6_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
struct dst_entry *dst = skb_dst(skb);
const struct rt6_info *rt = (const struct rt6_info *)dst;
- dst_hold(dst);
- sk->sk_rx_dst = dst;
- inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
- if (rt->rt6i_node)
- inet6_sk(sk)->rx_dst_cookie = rt->rt6i_node->fn_sernum;
+ if (dst_hold_safe(dst)) {
+ sk->sk_rx_dst = dst;
+ inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
+ if (rt->rt6i_node)
+ inet6_sk(sk)->rx_dst_cookie = rt->rt6i_node->fn_sernum;
+ }
}
static void tcp_v6_hash(struct sock *sk)