summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid S. Miller <davem@davemloft.net>2021-06-05 00:25:14 +0300
committerDavid S. Miller <davem@davemloft.net>2021-06-05 00:25:14 +0300
commit6fd815bb1ecc5d3cd99a31e0393fba0be517ed04 (patch)
tree6aba382d4a78664064216c7055da777d772568a5
parent579028dec182c026b9a85725682f1dfbdc825eaa (diff)
parentbf7b042dc62a31f66d3a41dd4dfc7806f267b307 (diff)
downloadlinux-6fd815bb1ecc5d3cd99a31e0393fba0be517ed04.tar.xz
Merge branch 'wireguard-fixes'
Jason A. Donenfeld says: ==================== wireguard fixes for 5.13-rc5 Here are bug fixes to WireGuard for 5.13-rc5: 1-2,6) These are small, trivial tweaks to our test harness. 3) Linus thinks -O3 is still dangerous to enable. The code gen wasn't so much different with -O2 either. 4) We were accidentally calling synchronize_rcu instead of synchronize_net while holding the rtnl_lock, resulting in some rather large stalls that hit production machines. 5) Peer allocation was wasting literally hundreds of megabytes on real world deployments, due to oddly sized large objects not fitting nicely into a kmalloc slab. 7-9) We move from an insanely expensive O(n) algorithm to a fast O(1) algorithm, and cleanup a massive memory leak in the process, in which allowed ips churn would leave danging nodes hanging around without cleanup until the interface was removed. The O(1) algorithm eliminates packet stalls and high latency issues, in addition to bringing operations that took as much as 10 minutes down to less than a second. ==================== Signed-off-by: David S. Miller <davem@davemloft.net>
-rw-r--r--drivers/net/wireguard/Makefile3
-rw-r--r--drivers/net/wireguard/allowedips.c189
-rw-r--r--drivers/net/wireguard/allowedips.h14
-rw-r--r--drivers/net/wireguard/main.c17
-rw-r--r--drivers/net/wireguard/peer.c27
-rw-r--r--drivers/net/wireguard/peer.h3
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c165
-rw-r--r--drivers/net/wireguard/socket.c2
-rwxr-xr-xtools/testing/selftests/wireguard/netns.sh1
-rw-r--r--tools/testing/selftests/wireguard/qemu/kernel.config1
10 files changed, 227 insertions, 195 deletions
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile
index fc52b2cb500b..dbe1f8514efc 100644
--- a/drivers/net/wireguard/Makefile
+++ b/drivers/net/wireguard/Makefile
@@ -1,5 +1,4 @@
-ccflags-y := -O3
-ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
+ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
wireguard-y := main.o
wireguard-y += noise.o
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 3725e9cd85f4..b7197e80f226 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -6,6 +6,8 @@
#include "allowedips.h"
#include "peer.h"
+static struct kmem_cache *node_cache;
+
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
if (bits == 32) {
@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bitlen = bits;
memcpy(node->bits, src, bits / 8U);
}
-#define CHOOSE_NODE(parent, key) \
- parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+
+static inline u8 choose(struct allowedips_node *node, const u8 *key)
+{
+ return (key[node->bit_at_a] >> node->bit_at_b) & 1;
+}
static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len)
@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
}
}
+static void node_free_rcu(struct rcu_head *rcu)
+{
+ kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
+}
+
static void root_free_rcu(struct rcu_head *rcu)
{
struct allowedips_node *node, *stack[128] = {
@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
while (len > 0 && (node = stack[--len])) {
push_rcu(stack, node->bit[0], &len);
push_rcu(stack, node->bit[1], &len);
- kfree(node);
+ kmem_cache_free(node_cache, node);
}
}
@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
}
}
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
-{
-#define REF(p) rcu_access_pointer(p)
-#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-#define PUSH(p) ({ \
- WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
- stack[len++] = p; \
- })
-
- struct allowedips_node __rcu **stack[128], **nptr;
- struct allowedips_node *node, *prev;
- unsigned int len;
-
- if (unlikely(!peer || !REF(*top)))
- return;
-
- for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
- nptr = stack[len - 1];
- node = DEREF(nptr);
- if (!node) {
- --len;
- continue;
- }
- if (!prev || REF(prev->bit[0]) == node ||
- REF(prev->bit[1]) == node) {
- if (REF(node->bit[0]))
- PUSH(&node->bit[0]);
- else if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else if (REF(node->bit[0]) == prev) {
- if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else {
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) == peer) {
- RCU_INIT_POINTER(node->peer, NULL);
- list_del_init(&node->peer_list);
- if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, DEREF(
- &node->bit[!REF(node->bit[0])]));
- kfree_rcu(node, rcu);
- node = DEREF(nptr);
- }
- }
- --len;
- }
- }
-
-#undef REF
-#undef DEREF
-#undef PUSH
-}
-
static unsigned int fls128(u64 a, u64 b)
{
return a ? fls64(a) + 64U : fls64(b);
@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference_bh(CHOOSE_NODE(node, key));
+ node = rcu_dereference_bh(node->bit[choose(node, key)]);
}
return found;
}
@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
u8 cidr, u8 bits, struct allowedips_node **rnode,
struct mutex *lock)
{
- struct allowedips_node *node = rcu_dereference_protected(trie,
- lockdep_is_held(lock));
+ struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
struct allowedips_node *parent = NULL;
bool exact = false;
@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
exact = true;
break;
}
- node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
- lockdep_is_held(lock));
+ node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
}
*rnode = parent;
return exact;
}
+static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
+{
+ node->parent_bit_packed = (unsigned long)parent | bit;
+ rcu_assign_pointer(*parent, node);
+}
+
+static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
+{
+ u8 bit = choose(parent, node->bits);
+ connect_node(&parent->bit[bit], bit, node);
+}
+
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return -EINVAL;
if (!rcu_access_pointer(*trie)) {
- node = kzalloc(sizeof(*node), GFP_KERNEL);
+ node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!node))
return -ENOMEM;
RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
- rcu_assign_pointer(*trie, node);
+ connect_node(trie, 2, node);
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node, lock)) {
@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}
- newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
+ newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!newnode))
return -ENOMEM;
RCU_INIT_POINTER(newnode->peer, peer);
@@ -243,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
- down = rcu_dereference_protected(CHOOSE_NODE(node, key),
- lockdep_is_held(lock));
+ const u8 bit = choose(node, key);
+ down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
if (!down) {
- rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
+ connect_node(&node->bit[bit], bit, newnode);
return 0;
}
}
@@ -254,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
parent = node;
if (newnode->cidr == cidr) {
- rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
+ choose_and_connect_node(newnode, down);
if (!parent)
- rcu_assign_pointer(*trie, newnode);
+ connect_node(trie, 2, newnode);
else
- rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
- newnode);
- } else {
- node = kzalloc(sizeof(*node), GFP_KERNEL);
- if (unlikely(!node)) {
- list_del(&newnode->peer_list);
- kfree(newnode);
- return -ENOMEM;
- }
- INIT_LIST_HEAD(&node->peer_list);
- copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+ choose_and_connect_node(parent, newnode);
+ return 0;
+ }
- rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
- rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
- if (!parent)
- rcu_assign_pointer(*trie, node);
- else
- rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
- node);
+ node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
+ if (unlikely(!node)) {
+ list_del(&newnode->peer_list);
+ kmem_cache_free(node_cache, newnode);
+ return -ENOMEM;
}
+ INIT_LIST_HEAD(&node->peer_list);
+ copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+
+ choose_and_connect_node(node, down);
+ choose_and_connect_node(node, newnode);
+ if (!parent)
+ connect_node(trie, 2, node);
+ else
+ choose_and_connect_node(parent, node);
return 0;
}
@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
+ struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
+ bool free_parent;
+
+ if (list_empty(&peer->allowedips_list))
+ return;
++table->seq;
- walk_remove_by_peer(&table->root4, peer, lock);
- walk_remove_by_peer(&table->root6, peer, lock);
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ if (node->bit[0] && node->bit[1])
+ continue;
+ child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
+ if (child)
+ child->parent_bit_packed = node->parent_bit_packed;
+ parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
+ *parent_bit = child;
+ parent = (void *)parent_bit -
+ offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
+ free_parent = !rcu_access_pointer(node->bit[0]) &&
+ !rcu_access_pointer(node->bit[1]) &&
+ (node->parent_bit_packed & 3) <= 1 &&
+ !rcu_access_pointer(parent->peer);
+ if (free_parent)
+ child = rcu_dereference_protected(
+ parent->bit[!(node->parent_bit_packed & 1)],
+ lockdep_is_held(lock));
+ call_rcu(&node->rcu, node_free_rcu);
+ if (!free_parent)
+ continue;
+ if (child)
+ child->parent_bit_packed = parent->parent_bit_packed;
+ *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
+ call_rcu(&parent->rcu, node_free_rcu);
+ }
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
return NULL;
}
+int __init wg_allowedips_slab_init(void)
+{
+ node_cache = KMEM_CACHE(allowedips_node, 0);
+ return node_cache ? 0 : -ENOMEM;
+}
+
+void wg_allowedips_slab_uninit(void)
+{
+ rcu_barrier();
+ kmem_cache_destroy(node_cache);
+}
+
#include "selftest/allowedips.c"
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
index e5c83cafcef4..2346c797eb4d 100644
--- a/drivers/net/wireguard/allowedips.h
+++ b/drivers/net/wireguard/allowedips.h
@@ -15,14 +15,11 @@ struct wg_peer;
struct allowedips_node {
struct wg_peer __rcu *peer;
struct allowedips_node __rcu *bit[2];
- /* While it may seem scandalous that we waste space for v4,
- * we're alloc'ing to the nearest power of 2 anyway, so this
- * doesn't actually make a difference.
- */
- u8 bits[16] __aligned(__alignof(u64));
u8 cidr, bit_at_a, bit_at_b, bitlen;
+ u8 bits[16] __aligned(__alignof(u64));
- /* Keep rarely used list at bottom to be beyond cache line. */
+ /* Keep rarely used members at bottom to be beyond cache line. */
+ unsigned long parent_bit_packed;
union {
struct list_head peer_list;
struct rcu_head rcu;
@@ -33,7 +30,7 @@ struct allowedips {
struct allowedips_node __rcu *root4;
struct allowedips_node __rcu *root6;
u64 seq;
-};
+} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */
void wg_allowedips_init(struct allowedips *table);
void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
@@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
bool wg_allowedips_selftest(void);
#endif
+int wg_allowedips_slab_init(void);
+void wg_allowedips_slab_uninit(void);
+
#endif /* _WG_ALLOWEDIPS_H */
diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c
index 7a7d5f1a80fc..75dbe77b0b4b 100644
--- a/drivers/net/wireguard/main.c
+++ b/drivers/net/wireguard/main.c
@@ -21,13 +21,22 @@ static int __init mod_init(void)
{
int ret;
+ ret = wg_allowedips_slab_init();
+ if (ret < 0)
+ goto err_allowedips;
+
#ifdef DEBUG
+ ret = -ENOTRECOVERABLE;
if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() ||
!wg_ratelimiter_selftest())
- return -ENOTRECOVERABLE;
+ goto err_peer;
#endif
wg_noise_init();
+ ret = wg_peer_init();
+ if (ret < 0)
+ goto err_peer;
+
ret = wg_device_init();
if (ret < 0)
goto err_device;
@@ -44,6 +53,10 @@ static int __init mod_init(void)
err_netlink:
wg_device_uninit();
err_device:
+ wg_peer_uninit();
+err_peer:
+ wg_allowedips_slab_uninit();
+err_allowedips:
return ret;
}
@@ -51,6 +64,8 @@ static void __exit mod_exit(void)
{
wg_genetlink_uninit();
wg_device_uninit();
+ wg_peer_uninit();
+ wg_allowedips_slab_uninit();
}
module_init(mod_init);
diff --git a/drivers/net/wireguard/peer.c b/drivers/net/wireguard/peer.c
index cd5cb0292cb6..1acd00ab2fbc 100644
--- a/drivers/net/wireguard/peer.c
+++ b/drivers/net/wireguard/peer.c
@@ -15,6 +15,7 @@
#include <linux/rcupdate.h>
#include <linux/list.h>
+static struct kmem_cache *peer_cache;
static atomic64_t peer_counter = ATOMIC64_INIT(0);
struct wg_peer *wg_peer_create(struct wg_device *wg,
@@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
if (wg->num_peers >= MAX_PEERS_PER_DEVICE)
return ERR_PTR(ret);
- peer = kzalloc(sizeof(*peer), GFP_KERNEL);
+ peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL);
if (unlikely(!peer))
return ERR_PTR(ret);
- if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
+ if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)))
goto err;
peer->device = wg;
@@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
return peer;
err:
- kfree(peer);
+ kmem_cache_free(peer_cache, peer);
return ERR_PTR(ret);
}
@@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer)
/* Mark as dead, so that we don't allow jumping contexts after. */
WRITE_ONCE(peer->is_dead, true);
- /* The caller must now synchronize_rcu() for this to take effect. */
+ /* The caller must now synchronize_net() for this to take effect. */
}
static void peer_remove_after_dead(struct wg_peer *peer)
@@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer)
lockdep_assert_held(&peer->device->device_update_lock);
peer_make_dead(peer);
- synchronize_rcu();
+ synchronize_net();
peer_remove_after_dead(peer);
}
@@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg)
peer_make_dead(peer);
list_add_tail(&peer->peer_list, &dead_peers);
}
- synchronize_rcu();
+ synchronize_net();
list_for_each_entry_safe(peer, temp, &dead_peers, peer_list)
peer_remove_after_dead(peer);
}
@@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu)
/* The final zeroing takes care of clearing any remaining handshake key
* material and other potentially sensitive information.
*/
- kfree_sensitive(peer);
+ memzero_explicit(peer, sizeof(*peer));
+ kmem_cache_free(peer_cache, peer);
}
static void kref_release(struct kref *refcount)
@@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer)
return;
kref_put(&peer->refcount, kref_release);
}
+
+int __init wg_peer_init(void)
+{
+ peer_cache = KMEM_CACHE(wg_peer, 0);
+ return peer_cache ? 0 : -ENOMEM;
+}
+
+void wg_peer_uninit(void)
+{
+ kmem_cache_destroy(peer_cache);
+}
diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h
index 8d53b687a1d1..76e4d3128ad4 100644
--- a/drivers/net/wireguard/peer.h
+++ b/drivers/net/wireguard/peer.h
@@ -80,4 +80,7 @@ void wg_peer_put(struct wg_peer *peer);
void wg_peer_remove(struct wg_peer *peer);
void wg_peer_remove_all(struct wg_device *wg);
+int wg_peer_init(void);
+void wg_peer_uninit(void);
+
#endif /* _WG_PEER_H */
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
index 846db14cb046..e173204ae7d7 100644
--- a/drivers/net/wireguard/selftest/allowedips.c
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -19,32 +19,22 @@
#include <linux/siphash.h>
-static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
- u8 cidr)
-{
- swap_endian(dst, src, bits);
- memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
- if (cidr)
- dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
-}
-
static __init void print_node(struct allowedips_node *node, u8 bits)
{
char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
- char *fmt_declaration = KERN_DEBUG
- "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ u8 ip1[16], ip2[16], cidr1, cidr2;
char *style = "dotted";
- u8 ip1[16], ip2[16];
u32 color = 0;
+ if (node == NULL)
+ return;
if (bits == 32) {
fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
- fmt_declaration = KERN_DEBUG
- "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
} else if (bits == 128) {
fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
- fmt_declaration = KERN_DEBUG
- "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
}
if (node->peer) {
hsiphash_key_t key = { { 0 } };
@@ -55,24 +45,20 @@ static __init void print_node(struct allowedips_node *node, u8 bits)
hsiphash_1u32(0xabad1dea, &key) % 200;
style = "bold";
}
- swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
- printk(fmt_declaration, ip1, node->cidr, style, color);
+ wg_allowedips_read_node(node, ip1, &cidr1);
+ printk(fmt_declaration, ip1, cidr1, style, color);
if (node->bit[0]) {
- swap_endian_and_apply_cidr(ip2,
- rcu_dereference_raw(node->bit[0])->bits, bits,
- node->cidr);
- printk(fmt_connection, ip1, node->cidr, ip2,
- rcu_dereference_raw(node->bit[0])->cidr);
- print_node(rcu_dereference_raw(node->bit[0]), bits);
+ wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
+ printk(fmt_connection, ip1, cidr1, ip2, cidr2);
}
if (node->bit[1]) {
- swap_endian_and_apply_cidr(ip2,
- rcu_dereference_raw(node->bit[1])->bits,
- bits, node->cidr);
- printk(fmt_connection, ip1, node->cidr, ip2,
- rcu_dereference_raw(node->bit[1])->cidr);
- print_node(rcu_dereference_raw(node->bit[1]), bits);
+ wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
+ printk(fmt_connection, ip1, cidr1, ip2, cidr2);
}
+ if (node->bit[0])
+ print_node(rcu_dereference_raw(node->bit[0]), bits);
+ if (node->bit[1])
+ print_node(rcu_dereference_raw(node->bit[1]), bits);
}
static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
@@ -121,8 +107,8 @@ static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
{
union nf_inet_addr mask;
- memset(&mask, 0x00, 128 / 8);
- memset(&mask, 0xff, cidr / 8);
+ memset(&mask, 0, sizeof(mask));
+ memset(&mask.all, 0xff, cidr / 8);
if (cidr % 32)
mask.all[cidr / 32] = (__force u32)htonl(
(0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
@@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allowedips_node *node)
}
static __init inline bool
-horrible_match_v4(const struct horrible_allowedips_node *node,
- struct in_addr *ip)
+horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
{
return (ip->s_addr & node->mask.ip) == node->ip.ip;
}
static __init inline bool
-horrible_match_v6(const struct horrible_allowedips_node *node,
- struct in6_addr *ip)
+horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
{
- return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
- node->ip.ip6[0] &&
- (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
- node->ip.ip6[1] &&
- (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
- node->ip.ip6[2] &&
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
(ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
}
static __init void
-horrible_insert_ordered(struct horrible_allowedips *table,
- struct horrible_allowedips_node *node)
+horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
{
struct horrible_allowedips_node *other = NULL, *where = NULL;
u8 my_cidr = horrible_mask_to_cidr(node->mask);
hlist_for_each_entry(other, &table->head, table) {
- if (!memcmp(&other->mask, &node->mask,
- sizeof(union nf_inet_addr)) &&
- !memcmp(&other->ip, &node->ip,
- sizeof(union nf_inet_addr)) &&
- other->ip_version == node->ip_version) {
+ if (other->ip_version == node->ip_version &&
+ !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
+ !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) {
other->value = node->value;
kfree(node);
return;
}
+ }
+ hlist_for_each_entry(other, &table->head, table) {
where = other;
if (horrible_mask_to_cidr(other->mask) <= my_cidr)
break;
@@ -201,8 +181,7 @@ static __init int
horrible_allowedips_insert_v4(struct horrible_allowedips *table,
struct in_addr *ip, u8 cidr, void *value)
{
- struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
- GFP_KERNEL);
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
if (unlikely(!node))
return -ENOMEM;
@@ -219,8 +198,7 @@ static __init int
horrible_allowedips_insert_v6(struct horrible_allowedips *table,
struct in6_addr *ip, u8 cidr, void *value)
{
- struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
- GFP_KERNEL);
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
if (unlikely(!node))
return -ENOMEM;
@@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct horrible_allowedips *table,
}
static __init void *
-horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
- struct in_addr *ip)
+horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
{
struct horrible_allowedips_node *node;
- void *ret = NULL;
hlist_for_each_entry(node, &table->head, table) {
- if (node->ip_version != 4)
- continue;
- if (horrible_match_v4(node, ip)) {
- ret = node->value;
- break;
- }
+ if (node->ip_version == 4 && horrible_match_v4(node, ip))
+ return node->value;
}
- return ret;
+ return NULL;
}
static __init void *
-horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
- struct in6_addr *ip)
+horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
{
struct horrible_allowedips_node *node;
- void *ret = NULL;
hlist_for_each_entry(node, &table->head, table) {
- if (node->ip_version != 6)
+ if (node->ip_version == 6 && horrible_match_v6(node, ip))
+ return node->value;
+ }
+ return NULL;
+}
+
+
+static __init void
+horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
+{
+ struct horrible_allowedips_node *node;
+ struct hlist_node *h;
+
+ hlist_for_each_entry_safe(node, h, &table->head, table) {
+ if (node->value != value)
continue;
- if (horrible_match_v6(node, ip)) {
- ret = node->value;
- break;
- }
+ hlist_del(&node->table);
+ kfree(node);
}
- return ret;
+
}
static __init bool randomized_test(void)
@@ -296,6 +278,7 @@ static __init bool randomized_test(void)
goto free;
}
kref_init(&peers[i]->refcount);
+ INIT_LIST_HEAD(&peers[i]->allowedips_list);
}
mutex_lock(&mutex);
@@ -333,7 +316,7 @@ static __init bool randomized_test(void)
if (wg_allowedips_insert_v4(&t,
(struct in_addr *)mutated,
cidr, peer, &mutex) < 0) {
- pr_err("allowedips random malloc: FAIL\n");
+ pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
if (horrible_allowedips_insert_v4(&h,
@@ -396,23 +379,33 @@ static __init bool randomized_test(void)
print_tree(t.root6, 128);
}
- for (i = 0; i < NUM_QUERIES; ++i) {
- prandom_bytes(ip, 4);
- if (lookup(t.root4, 32, ip) !=
- horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
- pr_err("allowedips random self-test: FAIL\n");
- goto free;
+ for (j = 0;; ++j) {
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 4);
+ if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
+ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
+ pr_err("allowedips random v4 self-test: FAIL\n");
+ goto free;
+ }
+ prandom_bytes(ip, 16);
+ if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
+ pr_err("allowedips random v6 self-test: FAIL\n");
+ goto free;
+ }
}
+ if (j >= NUM_PEERS)
+ break;
+ mutex_lock(&mutex);
+ wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
+ mutex_unlock(&mutex);
+ horrible_allowedips_remove_by_value(&h, peers[j]);
}
- for (i = 0; i < NUM_QUERIES; ++i) {
- prandom_bytes(ip, 16);
- if (lookup(t.root6, 128, ip) !=
- horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
- pr_err("allowedips random self-test: FAIL\n");
- goto free;
- }
+ if (t.root4 || t.root6) {
+ pr_err("allowedips random self-test removal: FAIL\n");
+ goto free;
}
+
ret = true;
free:
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index d9ad850daa79..8c496b747108 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -430,7 +430,7 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
if (new4)
wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
mutex_unlock(&wg->socket_update_lock);
- synchronize_rcu();
+ synchronize_net();
sock_free(old4);
sock_free(old6);
}
diff --git a/tools/testing/selftests/wireguard/netns.sh b/tools/testing/selftests/wireguard/netns.sh
index 7ed7cd95e58f..ebc4ee0fe179 100755
--- a/tools/testing/selftests/wireguard/netns.sh
+++ b/tools/testing/selftests/wireguard/netns.sh
@@ -363,6 +363,7 @@ ip1 -6 rule add table main suppress_prefixlength 0
ip1 -4 route add default dev wg0 table 51820
ip1 -4 rule add not fwmark 51820 table 51820
ip1 -4 rule add table main suppress_prefixlength 0
+n1 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/vethc/rp_filter'
# Flood the pings instead of sending just one, to trigger routing table reference counting bugs.
n1 ping -W 1 -c 100 -f 192.168.99.7
n1 ping -W 1 -c 100 -f abab::1111
diff --git a/tools/testing/selftests/wireguard/qemu/kernel.config b/tools/testing/selftests/wireguard/qemu/kernel.config
index 4eecb432a66c..74db83a0aedd 100644
--- a/tools/testing/selftests/wireguard/qemu/kernel.config
+++ b/tools/testing/selftests/wireguard/qemu/kernel.config
@@ -19,7 +19,6 @@ CONFIG_NETFILTER_XTABLES=y
CONFIG_NETFILTER_XT_NAT=y
CONFIG_NETFILTER_XT_MATCH_LENGTH=y
CONFIG_NETFILTER_XT_MARK=y
-CONFIG_NF_CONNTRACK_IPV4=y
CONFIG_NF_NAT_IPV4=y
CONFIG_IP_NF_IPTABLES=y
CONFIG_IP_NF_FILTER=y