diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/core/devlink.c | 291 | ||||
-rw-r--r-- | net/ipv4/tcp.c | 42 | ||||
-rw-r--r-- | net/ipv6/ip6mr.c | 4 | ||||
-rw-r--r-- | net/sched/sch_cbq.c | 3 | ||||
-rw-r--r-- | net/smc/af_smc.c | 1 | ||||
-rw-r--r-- | net/smc/smc_diag.c | 1 | ||||
-rw-r--r-- | net/smc/smc_ism.c | 19 | ||||
-rw-r--r-- | net/smc/smc_ism.h | 20 | ||||
-rw-r--r-- | net/smc/smc_tx.c | 10 | ||||
-rw-r--r-- | net/tls/tls.h | 29 | ||||
-rw-r--r-- | net/tls/tls_device.c | 19 | ||||
-rw-r--r-- | net/tls/tls_main.c | 20 | ||||
-rw-r--r-- | net/tls/tls_strp.c | 488 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 228 |
14 files changed, 858 insertions, 317 deletions
diff --git a/net/core/devlink.c b/net/core/devlink.c index 98d79feeb3dc..ca4c9939d569 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -70,6 +70,7 @@ struct devlink { u8 reload_failed:1; refcount_t refcount; struct completion comp; + struct rcu_head rcu; char priv[] __aligned(NETDEV_ALIGN); }; @@ -88,6 +89,7 @@ struct devlink_linecard { const char *type; struct devlink_linecard_type *types; unsigned int types_count; + struct devlink *nested_devlink; }; /** @@ -221,8 +223,6 @@ static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC); /* devlink_mutex * * An overall lock guarding every operation coming from userspace. - * It also guards devlink devices list and it is taken when - * driver registers/unregisters it. */ static DEFINE_MUTEX(devlink_mutex); @@ -232,10 +232,21 @@ struct net *devlink_net(const struct devlink *devlink) } EXPORT_SYMBOL_GPL(devlink_net); +static void __devlink_put_rcu(struct rcu_head *head) +{ + struct devlink *devlink = container_of(head, struct devlink, rcu); + + complete(&devlink->comp); +} + void devlink_put(struct devlink *devlink) { if (refcount_dec_and_test(&devlink->refcount)) - complete(&devlink->comp); + /* Make sure unregister operation that may await the completion + * is unblocked only after all users are after the end of + * RCU grace period. + */ + call_rcu(&devlink->rcu, __devlink_put_rcu); } struct devlink *__must_check devlink_try_get(struct devlink *devlink) @@ -278,12 +289,62 @@ void devl_unlock(struct devlink *devlink) } EXPORT_SYMBOL_GPL(devl_unlock); +static struct devlink * +devlinks_xa_find_get(struct net *net, unsigned long *indexp, xa_mark_t filter, + void * (*xa_find_fn)(struct xarray *, unsigned long *, + unsigned long, xa_mark_t)) +{ + struct devlink *devlink; + + rcu_read_lock(); +retry: + devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED); + if (!devlink) + goto unlock; + /* For a possible retry, the xa_find_after() should be always used */ + xa_find_fn = xa_find_after; + if (!devlink_try_get(devlink)) + goto retry; + if (!net_eq(devlink_net(devlink), net)) { + devlink_put(devlink); + goto retry; + } +unlock: + rcu_read_unlock(); + return devlink; +} + +static struct devlink *devlinks_xa_find_get_first(struct net *net, + unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(net, indexp, filter, xa_find); +} + +static struct devlink *devlinks_xa_find_get_next(struct net *net, + unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(net, indexp, filter, xa_find_after); +} + +/* Iterate over devlink pointers which were possible to get reference to. + * devlink_put() needs to be called for each iterated devlink pointer + * in loop body in order to release the reference. + */ +#define devlinks_xa_for_each_get(net, index, devlink, filter) \ + for (index = 0, \ + devlink = devlinks_xa_find_get_first(net, &index, filter); \ + devlink; devlink = devlinks_xa_find_get_next(net, &index, filter)) + +#define devlinks_xa_for_each_registered_get(net, index, devlink) \ + devlinks_xa_for_each_get(net, index, devlink, DEVLINK_REGISTERED) + static struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) { struct devlink *devlink; unsigned long index; - bool found = false; char *busname; char *devname; @@ -293,21 +354,14 @@ static struct devlink *devlink_get_from_attrs(struct net *net, busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]); devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); - lockdep_assert_held(&devlink_mutex); - - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { + devlinks_xa_for_each_registered_get(net, index, devlink) { if (strcmp(devlink->dev->bus->name, busname) == 0 && - strcmp(dev_name(devlink->dev), devname) == 0 && - net_eq(devlink_net(devlink), net)) { - found = true; - break; - } + strcmp(dev_name(devlink->dev), devname) == 0) + return devlink; + devlink_put(devlink); } - if (!found || !devlink_try_get(devlink)) - devlink = ERR_PTR(-ENODEV); - - return devlink; + return ERR_PTR(-ENODEV); } static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink, @@ -803,6 +857,24 @@ static int devlink_nl_put_handle(struct sk_buff *msg, struct devlink *devlink) return 0; } +static int devlink_nl_put_nested_handle(struct sk_buff *msg, struct devlink *devlink) +{ + struct nlattr *nested_attr; + + nested_attr = nla_nest_start(msg, DEVLINK_ATTR_NESTED_DEVLINK); + if (!nested_attr) + return -EMSGSIZE; + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + nla_nest_end(msg, nested_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(msg, nested_attr); + return -EMSGSIZE; +} + struct devlink_reload_combination { enum devlink_reload_action action; enum devlink_reload_limit limit; @@ -1329,13 +1401,7 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(devlink_rate, &devlink->rate_list, list) { enum devlink_command cmd = DEVLINK_CMD_RATE_NEW; @@ -1356,7 +1422,6 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -1432,15 +1497,7 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) { - devlink_put(devlink); - continue; - } - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { if (idx < start) { idx++; devlink_put(devlink); @@ -1495,13 +1552,7 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(devlink_port, &devlink->port_list, list) { if (idx < start) { @@ -1521,7 +1572,6 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -2104,6 +2154,10 @@ static int devlink_nl_linecard_fill(struct sk_buff *msg, nla_nest_end(msg, attr); } + if (linecard->nested_devlink && + devlink_nl_put_nested_handle(msg, linecard->nested_devlink)) + goto nla_put_failure; + genlmsg_end(msg, hdr); return 0; @@ -2177,13 +2231,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { mutex_lock(&devlink->linecards_lock); list_for_each_entry(linecard, &devlink->linecard_list, list) { if (idx < start) { @@ -2206,7 +2254,6 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, idx++; } mutex_unlock(&devlink->linecards_lock); -retry: devlink_put(devlink); } out: @@ -2449,13 +2496,7 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(devlink_sb, &devlink->sb_list, list) { if (idx < start) { @@ -2475,7 +2516,6 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -2601,12 +2641,8 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops->sb_pool_get) + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_pool_get) goto retry; devl_lock(devlink); @@ -2822,12 +2858,8 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops->sb_port_pool_get) + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_port_pool_get) goto retry; devl_lock(devlink); @@ -3071,12 +3103,8 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops->sb_tc_pool_bind_get) + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_tc_pool_bind_get) goto retry; devl_lock(devlink); @@ -5158,13 +5186,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(param_item, &devlink->param_list, list) { if (idx < start) { @@ -5186,7 +5208,6 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -5393,13 +5414,7 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(devlink_port, &devlink->port_list, list) { list_for_each_entry(param_item, @@ -5426,7 +5441,6 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg, } } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -5977,16 +5991,9 @@ static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink, &idx, start); -retry: devlink_put(devlink); if (err) goto out; @@ -6511,13 +6518,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, int err = 0; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { if (idx < start || !devlink->ops->info_get) goto inc; @@ -6535,7 +6536,6 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, } inc: idx++; -retry: devlink_put(devlink); } mutex_unlock(&devlink_mutex); @@ -7691,13 +7691,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry_rep; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { mutex_lock(&devlink->reporters_lock); list_for_each_entry(reporter, &devlink->reporter_list, list) { @@ -7717,17 +7711,10 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, idx++; } mutex_unlock(&devlink->reporters_lock); -retry_rep: devlink_put(devlink); } - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry_port; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(port, &devlink->port_list, list) { mutex_lock(&port->reporters_lock); @@ -7752,7 +7739,6 @@ retry_rep: mutex_unlock(&port->reporters_lock); } devl_unlock(devlink); -retry_port: devlink_put(devlink); } out: @@ -8291,13 +8277,7 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(trap_item, &devlink->trap_list, list) { if (idx < start) { @@ -8317,7 +8297,6 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -8518,13 +8497,7 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(group_item, &devlink->trap_group_list, list) { @@ -8545,7 +8518,6 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -8832,13 +8804,7 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg, int err; mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) - goto retry; - + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); list_for_each_entry(policer_item, &devlink->trap_policer_list, list) { @@ -8859,7 +8825,6 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg, idx++; } devl_unlock(devlink); -retry: devlink_put(devlink); } out: @@ -9589,10 +9554,8 @@ void devlink_register(struct devlink *devlink) ASSERT_DEVLINK_NOT_REGISTERED(devlink); /* Make sure that we are in .probe() routine */ - mutex_lock(&devlink_mutex); xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); devlink_notify_register(devlink); - mutex_unlock(&devlink_mutex); } EXPORT_SYMBOL_GPL(devlink_register); @@ -9609,10 +9572,8 @@ void devlink_unregister(struct devlink *devlink) devlink_put(devlink); wait_for_completion(&devlink->comp); - mutex_lock(&devlink_mutex); devlink_notify_unregister(devlink); xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); - mutex_unlock(&devlink_mutex); } EXPORT_SYMBOL_GPL(devlink_unregister); @@ -10316,6 +10277,7 @@ EXPORT_SYMBOL_GPL(devlink_linecard_provision_set); void devlink_linecard_provision_clear(struct devlink_linecard *linecard) { mutex_lock(&linecard->state_lock); + WARN_ON(linecard->nested_devlink); linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; linecard->type = NULL; devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); @@ -10334,6 +10296,7 @@ EXPORT_SYMBOL_GPL(devlink_linecard_provision_clear); void devlink_linecard_provision_fail(struct devlink_linecard *linecard) { mutex_lock(&linecard->state_lock); + WARN_ON(linecard->nested_devlink); linecard->state = DEVLINK_LINECARD_STATE_PROVISIONING_FAILED; devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); mutex_unlock(&linecard->state_lock); @@ -10381,6 +10344,23 @@ void devlink_linecard_deactivate(struct devlink_linecard *linecard) } EXPORT_SYMBOL_GPL(devlink_linecard_deactivate); +/** + * devlink_linecard_nested_dl_set - Attach/detach nested devlink + * instance to linecard. + * + * @linecard: devlink linecard + * @nested_devlink: devlink instance to attach or NULL to detach + */ +void devlink_linecard_nested_dl_set(struct devlink_linecard *linecard, + struct devlink *nested_devlink) +{ + mutex_lock(&linecard->state_lock); + linecard->nested_devlink = nested_devlink; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_nested_dl_set); + int devl_sb_register(struct devlink *devlink, unsigned int sb_index, u32 size, u16 ingress_pools_count, u16 egress_pools_count, u16 ingress_tc_count, @@ -12281,13 +12261,7 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net) * all devlink instances from this namespace into init_net. */ mutex_lock(&devlink_mutex); - xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) { - if (!devlink_try_get(devlink)) - continue; - - if (!net_eq(devlink_net(devlink), net)) - goto retry; - + devlinks_xa_for_each_registered_get(net, index, devlink) { WARN_ON(!(devlink->features & DEVLINK_F_RELOAD)); err = devlink_reload(devlink, &init_net, DEVLINK_RELOAD_ACTION_DRIVER_REINIT, @@ -12295,7 +12269,6 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net) &actions_performed, NULL); if (err && err != -EOPNOTSUPP) pr_warn("Failed to reload devlink instance into init_net\n"); -retry: devlink_put(devlink); } mutex_unlock(&devlink_mutex); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index ba2bdc811374..dc7cc3ce6a53 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -1635,7 +1635,7 @@ static void tcp_eat_recv_skb(struct sock *sk, struct sk_buff *skb) __kfree_skb(skb); } -static struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) +struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) { struct sk_buff *skb; u32 offset; @@ -1658,6 +1658,7 @@ static struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) } return NULL; } +EXPORT_SYMBOL(tcp_recv_skb); /* * This routine provides an alternative to tcp_recvmsg() for routines @@ -1788,6 +1789,45 @@ int tcp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) } EXPORT_SYMBOL(tcp_read_skb); +void tcp_read_done(struct sock *sk, size_t len) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 seq = tp->copied_seq; + struct sk_buff *skb; + size_t left; + u32 offset; + + if (sk->sk_state == TCP_LISTEN) + return; + + left = len; + while (left && (skb = tcp_recv_skb(sk, seq, &offset)) != NULL) { + int used; + + used = min_t(size_t, skb->len - offset, left); + seq += used; + left -= used; + + if (skb->len > offset + used) + break; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) { + tcp_eat_recv_skb(sk, skb); + ++seq; + break; + } + tcp_eat_recv_skb(sk, skb); + } + WRITE_ONCE(tp->copied_seq, seq); + + tcp_rcv_space_adjust(sk); + + /* Clean up data we have read: This will do ACK frames. */ + if (left != len) + tcp_cleanup_rbuf(sk, len - left); +} +EXPORT_SYMBOL(tcp_read_done); + int tcp_peek_len(struct socket *sock) { return tcp_inq(sock->sk); diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index d546fc09d803..a9ba41648e36 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -2133,10 +2133,8 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt, */ cache_proxy = mr_mfc_find_any_parent(mrt, vif); if (cache_proxy && - cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255) { - rcu_read_unlock(); + cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255) goto forward; - } } /* diff --git a/net/sched/sch_cbq.c b/net/sched/sch_cbq.c index 599e26fc2fa8..91a0dc463c48 100644 --- a/net/sched/sch_cbq.c +++ b/net/sched/sch_cbq.c @@ -979,7 +979,7 @@ cbq_reset(struct Qdisc *sch) } -static int cbq_set_lss(struct cbq_class *cl, struct tc_cbq_lssopt *lss) +static void cbq_set_lss(struct cbq_class *cl, struct tc_cbq_lssopt *lss) { if (lss->change & TCF_CBQ_LSS_FLAGS) { cl->share = (lss->flags & TCF_CBQ_LSS_ISOLATED) ? NULL : cl->tparent; @@ -997,7 +997,6 @@ static int cbq_set_lss(struct cbq_class *cl, struct tc_cbq_lssopt *lss) } if (lss->change & TCF_CBQ_LSS_OFFTIME) cl->offtime = lss->offtime; - return 0; } static void cbq_rmprio(struct cbq_sched_data *q, struct cbq_class *cl) diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index 6e70d9c10b78..79c1318af1fe 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -3515,3 +3515,4 @@ MODULE_DESCRIPTION("smc socket address family"); MODULE_LICENSE("GPL"); MODULE_ALIAS_NETPROTO(PF_SMC); MODULE_ALIAS_TCP_ULP("smc"); +MODULE_ALIAS_GENL_FAMILY(SMC_GENL_FAMILY_NAME); diff --git a/net/smc/smc_diag.c b/net/smc/smc_diag.c index 1fca2f90a9c7..80ea7d954ece 100644 --- a/net/smc/smc_diag.c +++ b/net/smc/smc_diag.c @@ -268,3 +268,4 @@ module_init(smc_diag_init); module_exit(smc_diag_exit); MODULE_LICENSE("GPL"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 43 /* AF_SMC */); +MODULE_ALIAS_GENL_FAMILY(SMCR_GENL_FAMILY_NAME); diff --git a/net/smc/smc_ism.c b/net/smc/smc_ism.c index a2084ecdb97e..911fe08bc54b 100644 --- a/net/smc/smc_ism.c +++ b/net/smc/smc_ism.c @@ -33,17 +33,6 @@ int smc_ism_cantalk(u64 peer_gid, unsigned short vlan_id, struct smcd_dev *smcd) vlan_id); } -int smc_ism_write(struct smcd_dev *smcd, const struct smc_ism_position *pos, - void *data, size_t len) -{ - int rc; - - rc = smcd->ops->move_data(smcd, pos->token, pos->index, pos->signal, - pos->offset, data, len); - - return rc < 0 ? rc : 0; -} - void smc_ism_get_system_eid(u8 **eid) { if (!smc_ism_v2_capable) @@ -440,7 +429,7 @@ int smcd_register_dev(struct smcd_dev *smcd) if (list_empty(&smcd_dev_list.list)) { u8 *system_eid = NULL; - smcd->ops->get_system_eid(smcd, &system_eid); + system_eid = smcd->ops->get_system_eid(); if (system_eid[24] != '0' || system_eid[28] != '0') { smc_ism_v2_capable = true; memcpy(smc_ism_v2_system_eid, system_eid, @@ -519,13 +508,13 @@ void smcd_handle_event(struct smcd_dev *smcd, struct smcd_event *event) EXPORT_SYMBOL_GPL(smcd_handle_event); /* SMCD Device interrupt handler. Called from ISM device interrupt handler. - * Parameters are smcd device pointer and DMB number. Find the connection and - * schedule the tasklet for this connection. + * Parameters are smcd device pointer, DMB number, and the DMBE bitmask. + * Find the connection and schedule the tasklet for this connection. * * Context: * - Function called in IRQ context from ISM device driver IRQ handler. */ -void smcd_handle_irq(struct smcd_dev *smcd, unsigned int dmbno) +void smcd_handle_irq(struct smcd_dev *smcd, unsigned int dmbno, u16 dmbemask) { struct smc_connection *conn = NULL; unsigned long flags; diff --git a/net/smc/smc_ism.h b/net/smc/smc_ism.h index 004b22a13ffa..d6b2db604fe8 100644 --- a/net/smc/smc_ism.h +++ b/net/smc/smc_ism.h @@ -28,13 +28,6 @@ struct smc_ism_vlanid { /* VLAN id set on ISM device */ refcount_t refcnt; /* Reference count */ }; -struct smc_ism_position { /* ISM device position to write to */ - u64 token; /* Token of DMB */ - u32 offset; /* Offset into DMBE */ - u8 index; /* Index of DMBE */ - u8 signal; /* Generate interrupt on owner side */ -}; - struct smcd_dev; int smc_ism_cantalk(u64 peer_gid, unsigned short vlan_id, struct smcd_dev *dev); @@ -45,12 +38,21 @@ int smc_ism_put_vlan(struct smcd_dev *dev, unsigned short vlan_id); int smc_ism_register_dmb(struct smc_link_group *lgr, int buf_size, struct smc_buf_desc *dmb_desc); int smc_ism_unregister_dmb(struct smcd_dev *dev, struct smc_buf_desc *dmb_desc); -int smc_ism_write(struct smcd_dev *dev, const struct smc_ism_position *pos, - void *data, size_t len); int smc_ism_signal_shutdown(struct smc_link_group *lgr); void smc_ism_get_system_eid(u8 **eid); u16 smc_ism_get_chid(struct smcd_dev *dev); bool smc_ism_is_v2_capable(void); void smc_ism_init(void); int smcd_nl_get_device(struct sk_buff *skb, struct netlink_callback *cb); + +static inline int smc_ism_write(struct smcd_dev *smcd, u64 dmb_tok, + unsigned int idx, bool sf, unsigned int offset, + void *data, size_t len) +{ + int rc; + + rc = smcd->ops->move_data(smcd, dmb_tok, idx, sf, offset, data, len); + return rc < 0 ? rc : 0; +} + #endif diff --git a/net/smc/smc_tx.c b/net/smc/smc_tx.c index 4e8377657a62..64dedffe9d26 100644 --- a/net/smc/smc_tx.c +++ b/net/smc/smc_tx.c @@ -320,15 +320,11 @@ int smc_tx_sendpage(struct smc_sock *smc, struct page *page, int offset, int smcd_tx_ism_write(struct smc_connection *conn, void *data, size_t len, u32 offset, int signal) { - struct smc_ism_position pos; int rc; - memset(&pos, 0, sizeof(pos)); - pos.token = conn->peer_token; - pos.index = conn->peer_rmbe_idx; - pos.offset = conn->tx_off + offset; - pos.signal = signal; - rc = smc_ism_write(conn->lgr->smcd, &pos, data, len); + rc = smc_ism_write(conn->lgr->smcd, conn->peer_token, + conn->peer_rmbe_idx, signal, conn->tx_off + offset, + data, len); if (rc) conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; return rc; diff --git a/net/tls/tls.h b/net/tls/tls.h index 3740740504e3..0e840a0c3437 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -1,4 +1,5 @@ /* + * Copyright (c) 2016 Tom Herbert <tom@herbertland.com> * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. * @@ -127,8 +128,24 @@ int tls_sw_fallback_init(struct sock *sk, struct tls_offload_context_tx *offload_ctx, struct tls_crypto_info *crypto_info); -int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb, - struct sk_buff_head *dst); +int tls_strp_dev_init(void); +void tls_strp_dev_exit(void); + +void tls_strp_done(struct tls_strparser *strp); +void tls_strp_stop(struct tls_strparser *strp); +int tls_strp_init(struct tls_strparser *strp, struct sock *sk); +void tls_strp_data_ready(struct tls_strparser *strp); + +void tls_strp_check_rcv(struct tls_strparser *strp); +void tls_strp_msg_done(struct tls_strparser *strp); + +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb); +void tls_rx_msg_ready(struct tls_strparser *strp); + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); +int tls_strp_msg_cow(struct tls_sw_context_rx *ctx); +struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx); +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst); static inline struct tls_msg *tls_msg(struct sk_buff *skb) { @@ -139,7 +156,13 @@ static inline struct tls_msg *tls_msg(struct sk_buff *skb) static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx) { - return ctx->recv_pkt; + DEBUG_NET_WARN_ON_ONCE(!ctx->strp.msg_ready || !ctx->strp.anchor->len); + return ctx->strp.anchor; +} + +static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) +{ + return ctx->strp.msg_ready; } #ifdef CONFIG_TLS_DEVICE diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index b1fcd61836d1..fc513c1806a0 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -894,27 +894,26 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx, static int tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx) { - int err = 0, offset, copy, nsg, data_len, pos; - struct sk_buff *skb, *skb_iter, *unused; + int err, offset, copy, data_len, pos; + struct sk_buff *skb, *skb_iter; struct scatterlist sg[1]; struct strp_msg *rxm; char *orig_buf, *buf; - skb = tls_strp_msg(sw_ctx); - rxm = strp_msg(skb); - offset = rxm->offset; - + rxm = strp_msg(tls_strp_msg(sw_ctx)); orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation); if (!orig_buf) return -ENOMEM; buf = orig_buf; - nsg = skb_cow_data(skb, 0, &unused); - if (unlikely(nsg < 0)) { - err = nsg; + err = tls_strp_msg_cow(sw_ctx); + if (unlikely(err)) goto free_buf; - } + + skb = tls_strp_msg(sw_ctx); + rxm = strp_msg(skb); + offset = rxm->offset; sg_init_table(sg, 1); sg_set_buf(&sg[0], buf, diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 9703636cfc60..08ddf9d837ae 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -725,6 +725,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, if (tx) { ctx->sk_write_space = sk->sk_write_space; sk->sk_write_space = tls_write_space; + } else { + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); + + tls_strp_check_rcv(&rx_ctx->strp); } return 0; @@ -1141,20 +1145,28 @@ static int __init tls_register(void) if (err) return err; + err = tls_strp_dev_init(); + if (err) + goto err_pernet; + err = tls_device_init(); - if (err) { - unregister_pernet_subsys(&tls_proc_ops); - return err; - } + if (err) + goto err_strp; tcp_register_ulp(&tcp_tls_ulp_ops); return 0; +err_strp: + tls_strp_dev_exit(); +err_pernet: + unregister_pernet_subsys(&tls_proc_ops); + return err; } static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); + tls_strp_dev_exit(); tls_device_cleanup(); unregister_pernet_subsys(&tls_proc_ops); } diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index 9ccab79a6e1e..b945288c312e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -1,17 +1,493 @@ // SPDX-License-Identifier: GPL-2.0-only +/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */ #include <linux/skbuff.h> +#include <linux/workqueue.h> +#include <net/strparser.h> +#include <net/tcp.h> +#include <net/sock.h> +#include <net/tls.h> #include "tls.h" -int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb, - struct sk_buff_head *dst) +static struct workqueue_struct *tls_strp_wq; + +static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +{ + if (strp->stopped) + return; + + strp->stopped = 1; + + /* Report an error on the lower socket */ + strp->sk->sk_err = -err; + sk_error_report(strp->sk); +} + +static void tls_strp_anchor_free(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + shinfo->frag_list = NULL; + consume_skb(strp->anchor); + strp->anchor = NULL; +} + +/* Create a new skb with the contents of input copied to its page frags */ +static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) { - struct sk_buff *clone; + struct strp_msg *rxm; + struct sk_buff *skb; + int i, err, offset; + + skb = alloc_skb_with_frags(0, strp->anchor->len, TLS_PAGE_ORDER, + &err, strp->sk->sk_allocation); + if (!skb) + return NULL; + + offset = strp->stm.offset; + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset, + skb_frag_address(frag), + skb_frag_size(frag))); + offset += skb_frag_size(frag); + } + + skb_copy_header(skb, strp->anchor); + rxm = strp_msg(skb); + rxm->offset = 0; + return skb; +} + +/* Steal the input skb, input msg is invalid after calling this function */ +struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx) +{ + struct tls_strparser *strp = &ctx->strp; + +#ifdef CONFIG_TLS_DEVICE + DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted); +#else + /* This function turns an input into an output, + * that can only happen if we have offload. + */ + WARN_ON(1); +#endif + + if (strp->copy_mode) { + struct sk_buff *skb; + + /* Replace anchor with an empty skb, this is a little + * dangerous but __tls_cur_msg() warns on empty skbs + * so hopefully we'll catch abuses. + */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return NULL; - clone = skb_clone(skb, sk->sk_allocation); - if (!clone) + swap(strp->anchor, skb); + return skb; + } + + return tls_strp_msg_make_copy(strp); +} + +/* Force the input skb to be in copy mode. The data ownership remains + * with the input skb itself (meaning unpause will wipe it) but it can + * be modified. + */ +int tls_strp_msg_cow(struct tls_sw_context_rx *ctx) +{ + struct tls_strparser *strp = &ctx->strp; + struct sk_buff *skb; + + if (strp->copy_mode) + return 0; + + skb = tls_strp_msg_make_copy(strp); + if (!skb) return -ENOMEM; - __skb_queue_tail(dst, clone); + + tls_strp_anchor_free(strp); + strp->anchor = skb; + + tcp_read_done(strp->sk, strp->stm.full_len); + strp->copy_mode = 1; + return 0; } + +/* Make a clone (in the skb sense) of the input msg to keep a reference + * to the underlying data. The reference-holding skbs get placed on + * @dst. + */ +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + if (strp->copy_mode) { + struct sk_buff *skb; + + WARN_ON_ONCE(!shinfo->nr_frags); + + /* We can't skb_clone() the anchor, it gets wiped by unpause */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return -ENOMEM; + + __skb_queue_tail(dst, strp->anchor); + strp->anchor = skb; + } else { + struct sk_buff *iter, *clone; + int chunk, len, offset; + + offset = strp->stm.offset; + len = strp->stm.full_len; + iter = shinfo->frag_list; + + while (len > 0) { + if (iter->len <= offset) { + offset -= iter->len; + goto next; + } + + chunk = iter->len - offset; + offset = 0; + + clone = skb_clone(iter, strp->sk->sk_allocation); + if (!clone) + return -ENOMEM; + __skb_queue_tail(dst, clone); + + len -= chunk; +next: + iter = iter->next; + } + } + + return 0; +} + +static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + int i; + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + + for (i = 0; i < shinfo->nr_frags; i++) + __skb_frag_unref(&shinfo->frags[i], false); + shinfo->nr_frags = 0; + strp->copy_mode = 0; +} + +static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, + unsigned int offset, size_t in_len) +{ + struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; + size_t sz, len, chunk; + struct sk_buff *skb; + skb_frag_t *frag; + + if (strp->msg_ready) + return 0; + + skb = strp->anchor; + frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + + len = in_len; + /* First make sure we got the header */ + if (!strp->stm.full_len) { + /* Assume one page is more than enough for headers */ + chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + sz = tls_rx_msg_size(strp, strp->anchor); + if (sz < 0) { + desc->error = sz; + return 0; + } + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (sz > 0) + chunk = min_t(size_t, chunk, sz - skb->len); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + frag++; + len -= chunk; + offset += chunk; + + strp->stm.full_len = sz; + if (!strp->stm.full_len) + goto read_done; + } + + /* Load up more data */ + while (len && strp->stm.full_len > skb->len) { + chunk = min_t(size_t, len, strp->stm.full_len - skb->len); + chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + frag++; + len -= chunk; + offset += chunk; + } + + if (strp->stm.full_len == skb->len) { + desc->count = 0; + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + } + +read_done: + return in_len - len; +} + +static int tls_strp_read_copyin(struct tls_strparser *strp) +{ + struct socket *sock = strp->sk->sk_socket; + read_descriptor_t desc; + + desc.arg.data = strp; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + /* sk should be locked here, so okay to do read_sock */ + sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); + + return desc.error; +} + +static int tls_strp_read_short(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo; + struct page *page; + int need_spc, len; + + /* If the rbuf is small or rcv window has collapsed to 0 we need + * to read the data out. Otherwise the connection will stall. + * Without pressure threshold of INT_MAX will never be ready. + */ + if (likely(!tcp_epollin_ready(strp->sk, INT_MAX))) + return 0; + + shinfo = skb_shinfo(strp->anchor); + shinfo->frag_list = NULL; + + /* If we don't know the length go max plus page for cipher overhead */ + need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + + for (len = need_spc; len > 0; len -= PAGE_SIZE) { + page = alloc_page(strp->sk->sk_allocation); + if (!page) { + tls_strp_flush_anchor_copy(strp); + return -ENOMEM; + } + + skb_fill_page_desc(strp->anchor, shinfo->nr_frags++, + page, 0, 0); + } + + strp->copy_mode = 1; + strp->stm.offset = 0; + + strp->anchor->len = 0; + strp->anchor->data_len = 0; + strp->anchor->truesize = round_up(need_spc, PAGE_SIZE); + + tls_strp_read_copyin(strp); + + return 0; +} + +static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) +{ + struct tcp_sock *tp = tcp_sk(strp->sk); + struct sk_buff *first; + u32 offset; + + first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset); + if (WARN_ON_ONCE(!first)) + return; + + /* Bestow the state onto the anchor */ + strp->anchor->len = offset + len; + strp->anchor->data_len = offset + len; + strp->anchor->truesize = offset + len; + + skb_shinfo(strp->anchor)->frag_list = first; + + skb_copy_header(strp->anchor, first); + strp->anchor->destructor = NULL; + + strp->stm.offset = offset; +} + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +{ + struct strp_msg *rxm; + struct tls_msg *tlm; + + DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready); + DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); + + if (!strp->copy_mode && force_refresh) { + if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) + return; + + tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); + } + + rxm = strp_msg(strp->anchor); + rxm->full_len = strp->stm.full_len; + rxm->offset = strp->stm.offset; + tlm = tls_msg(strp->anchor); + tlm->control = strp->mark; +} + +/* Called with lock held on lower socket */ +static int tls_strp_read_sock(struct tls_strparser *strp) +{ + int sz, inq; + + inq = tcp_inq(strp->sk); + if (inq < 1) + return 0; + + if (unlikely(strp->copy_mode)) + return tls_strp_read_copyin(strp); + + if (inq < strp->stm.full_len) + return tls_strp_read_short(strp); + + if (!strp->stm.full_len) { + tls_strp_load_anchor_with_queue(strp, inq); + + sz = tls_rx_msg_size(strp, strp->anchor); + if (sz < 0) { + tls_strp_abort_strp(strp, sz); + return sz; + } + + strp->stm.full_len = sz; + + if (!strp->stm.full_len || inq < strp->stm.full_len) + return tls_strp_read_short(strp); + } + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + + return 0; +} + +void tls_strp_check_rcv(struct tls_strparser *strp) +{ + if (unlikely(strp->stopped) || strp->msg_ready) + return; + + if (tls_strp_read_sock(strp) == -ENOMEM) + queue_work(tls_strp_wq, &strp->work); +} + +/* Lower sock lock held */ +void tls_strp_data_ready(struct tls_strparser *strp) +{ + /* This check is needed to synchronize with do_tls_strp_work. + * do_tls_strp_work acquires a process lock (lock_sock) whereas + * the lock held here is bh_lock_sock. The two locks can be + * held by different threads at the same time, but bh_lock_sock + * allows a thread in BH context to safely check if the process + * lock is held. In this case, if the lock is held, queue work. + */ + if (sock_owned_by_user_nocheck(strp->sk)) { + queue_work(tls_strp_wq, &strp->work); + return; + } + + tls_strp_check_rcv(strp); +} + +static void tls_strp_work(struct work_struct *w) +{ + struct tls_strparser *strp = + container_of(w, struct tls_strparser, work); + + lock_sock(strp->sk); + tls_strp_check_rcv(strp); + release_sock(strp->sk); +} + +void tls_strp_msg_done(struct tls_strparser *strp) +{ + WARN_ON(!strp->stm.full_len); + + if (likely(!strp->copy_mode)) + tcp_read_done(strp->sk, strp->stm.full_len); + else + tls_strp_flush_anchor_copy(strp); + + strp->msg_ready = 0; + memset(&strp->stm, 0, sizeof(strp->stm)); + + tls_strp_check_rcv(strp); +} + +void tls_strp_stop(struct tls_strparser *strp) +{ + strp->stopped = 1; +} + +int tls_strp_init(struct tls_strparser *strp, struct sock *sk) +{ + memset(strp, 0, sizeof(*strp)); + + strp->sk = sk; + + strp->anchor = alloc_skb(0, GFP_KERNEL); + if (!strp->anchor) + return -ENOMEM; + + INIT_WORK(&strp->work, tls_strp_work); + + return 0; +} + +/* strp must already be stopped so that tls_strp_recv will no longer be called. + * Note that tls_strp_done is not called with the lower socket held. + */ +void tls_strp_done(struct tls_strparser *strp) +{ + WARN_ON(!strp->stopped); + + cancel_work_sync(&strp->work); + tls_strp_anchor_free(strp); +} + +int __init tls_strp_dev_init(void) +{ + tls_strp_wq = create_singlethread_workqueue("kstrp"); + if (unlikely(!tls_strp_wq)) + return -ENOMEM; + + return 0; +} + +void tls_strp_dev_exit(void) +{ + destroy_workqueue(tls_strp_wq); +} diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index ed5e6f1df9c7..0fc24a5ce208 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1283,13 +1283,13 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, static int tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, - long timeo) + bool released, long timeo) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); DEFINE_WAIT_FUNC(wait, woken_wake_function); - while (!ctx->recv_pkt) { + while (!tls_strp_msg_ready(ctx)) { if (!sk_psock_queue_empty(psock)) return 0; @@ -1297,8 +1297,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_error(sk); if (!skb_queue_empty(&sk->sk_receive_queue)) { - __strp_unpause(&ctx->strp); - if (ctx->recv_pkt) + tls_strp_check_rcv(&ctx->strp); + if (tls_strp_msg_ready(ctx)) break; } @@ -1311,10 +1311,12 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, if (nonblock || !timeo) return -EAGAIN; + released = true; add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_wait_event(sk, &timeo, - ctx->recv_pkt || !sk_psock_queue_empty(psock), + tls_strp_msg_ready(ctx) || + !sk_psock_queue_empty(psock), &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); @@ -1324,6 +1326,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_intr_errno(timeo); } + tls_strp_msg_load(&ctx->strp, released); + return 1; } @@ -1408,13 +1412,15 @@ tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb, /* Decrypt handlers * - * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers. + * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers. * They must transform the darg in/out argument are as follows: * | Input | Output * ------------------------------------------------------------------- * zc | Zero-copy decrypt allowed | Zero-copy performed * async | Async decrypt allowed | Async crypto used / in progress * skb | * | Output skb + * + * If ZC decryption was performed darg.skb will point to the input skb. */ /* This function decrypts the input skb into either out_iov or in out_sg @@ -1567,7 +1573,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, clear_skb = NULL; if (unlikely(darg->async)) { - err = tls_strp_msg_hold(sk, skb, &ctx->async_hold); + err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold); if (err) __skb_queue_tail(&ctx->async_hold, darg->skb); return err; @@ -1588,49 +1594,22 @@ exit_free_skb: } static int -tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, - struct tls_decrypt_arg *darg) -{ - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - int err; - - if (tls_ctx->rx_conf != TLS_HW) - return 0; - - err = tls_device_decrypted(sk, tls_ctx); - if (err <= 0) - return err; - - darg->zc = false; - darg->async = false; - darg->skb = tls_strp_msg(ctx); - ctx->recv_pkt = NULL; - return 1; -} - -static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, - struct tls_decrypt_arg *darg) +tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx, + struct msghdr *msg, struct tls_decrypt_arg *darg) { - struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_prot_info *prot = &tls_ctx->prot_info; struct strp_msg *rxm; int pad, err; - err = tls_decrypt_device(sk, tls_ctx, darg); - if (err < 0) - return err; - if (err) - goto decrypt_done; - - err = tls_decrypt_sg(sk, dest, NULL, darg); + err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg); if (err < 0) { if (err == -EBADMSG) TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); return err; } - if (darg->async) - goto decrypt_done; + /* keep going even for ->async, the code below is TLS 1.3 */ + /* If opportunistic TLS 1.3 ZC failed retry without ZC */ if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && darg->tail != TLS_RECORD_TYPE_DATA)) { @@ -1638,21 +1617,87 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, if (!darg->tail) TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); - return tls_rx_one_record(sk, dest, darg); + return tls_decrypt_sw(sk, tls_ctx, msg, darg); } -decrypt_done: - if (darg->skb == ctx->recv_pkt) - ctx->recv_pkt = NULL; - pad = tls_padding_length(prot, darg->skb, darg); if (pad < 0) { - consume_skb(darg->skb); + if (darg->skb != tls_strp_msg(ctx)) + consume_skb(darg->skb); return pad; } rxm = strp_msg(darg->skb); rxm->full_len -= pad; + + return 0; +} + +static int +tls_decrypt_device(struct sock *sk, struct msghdr *msg, + struct tls_context *tls_ctx, struct tls_decrypt_arg *darg) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int pad, err; + + if (tls_ctx->rx_conf != TLS_HW) + return 0; + + err = tls_device_decrypted(sk, tls_ctx); + if (err <= 0) + return err; + + pad = tls_padding_length(prot, tls_strp_msg(ctx), darg); + if (pad < 0) + return pad; + + darg->async = false; + darg->skb = tls_strp_msg(ctx); + /* ->zc downgrade check, in case TLS 1.3 gets here */ + darg->zc &= !(prot->version == TLS_1_3_VERSION && + tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA); + + rxm = strp_msg(darg->skb); + rxm->full_len -= pad; + + if (!darg->zc) { + /* Non-ZC case needs a real skb */ + darg->skb = tls_strp_msg_detach(ctx); + if (!darg->skb) + return -ENOMEM; + } else { + unsigned int off, len; + + /* In ZC case nobody cares about the output skb. + * Just copy the data here. Note the skb is not fully trimmed. + */ + off = rxm->offset + prot->prepend_size; + len = rxm->full_len - prot->overhead_size; + + err = skb_copy_datagram_msg(darg->skb, off, msg, len); + if (err) + return err; + } + return 1; +} + +static int tls_rx_one_record(struct sock *sk, struct msghdr *msg, + struct tls_decrypt_arg *darg) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int err; + + err = tls_decrypt_device(sk, msg, tls_ctx, darg); + if (!err) + err = tls_decrypt_sw(sk, tls_ctx, msg, darg); + if (err < 0) + return err; + + rxm = strp_msg(darg->skb); rxm->offset += prot->prepend_size; rxm->full_len -= prot->overhead_size; tls_advance_record_sn(sk, prot, &tls_ctx->rx); @@ -1692,9 +1737,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) { - consume_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; - __strp_unpause(&ctx->strp); + tls_strp_msg_done(&ctx->strp); } /* This function traverses the rx_list in tls receive context to copies the @@ -1781,7 +1824,7 @@ out: return copied ? : err; } -static void +static bool tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, size_t len_left, size_t decrypted, ssize_t done, size_t *flushed_at) @@ -1789,14 +1832,14 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, size_t max_rec; if (len_left <= decrypted) - return; + return false; max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) - return; + return false; *flushed_at = done; - sk_flush_backlog(sk); + return sk_flush_backlog(sk); } static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, @@ -1868,13 +1911,13 @@ int tls_sw_recvmsg(struct sock *sk, size_t flushed_at = 0; struct strp_msg *rxm; struct tls_msg *tlm; - struct sk_buff *skb; ssize_t copied = 0; bool async = false; int target, err = 0; long timeo; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_peek = flags & MSG_PEEK; + bool released = true; bool bpf_strp_enabled; bool zc_capable; @@ -1907,11 +1950,12 @@ int tls_sw_recvmsg(struct sock *sk, zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && ctx->zc_capable; decrypted = 0; - while (len && (decrypted + copied < target || ctx->recv_pkt)) { + while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) { struct tls_decrypt_arg darg; int to_decrypt, chunk; - err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo); + err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, released, + timeo); if (err <= 0) { if (psock) { chunk = sk_msg_recvmsg(sk, psock, msg, len, @@ -1927,8 +1971,8 @@ int tls_sw_recvmsg(struct sock *sk, memset(&darg.inargs, 0, sizeof(darg.inargs)); - rxm = strp_msg(ctx->recv_pkt); - tlm = tls_msg(ctx->recv_pkt); + rxm = strp_msg(tls_strp_msg(ctx)); + tlm = tls_msg(tls_strp_msg(ctx)); to_decrypt = rxm->full_len - prot->overhead_size; @@ -1942,16 +1986,12 @@ int tls_sw_recvmsg(struct sock *sk, else darg.async = false; - err = tls_rx_one_record(sk, &msg->msg_iter, &darg); + err = tls_rx_one_record(sk, msg, &darg); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto recv_end; } - skb = darg.skb; - rxm = strp_msg(skb); - tlm = tls_msg(skb); - async |= darg.async; /* If the type of records being processed is not known yet, @@ -1961,24 +2001,30 @@ int tls_sw_recvmsg(struct sock *sk, * is known just after record is dequeued from stream parser. * For tls1.3, we disable async. */ - err = tls_record_content_type(msg, tlm, &control); + err = tls_record_content_type(msg, tls_msg(darg.skb), &control); if (err <= 0) { + DEBUG_NET_WARN_ON_ONCE(darg.zc); tls_rx_rec_done(ctx); put_on_rx_list_err: - __skb_queue_tail(&ctx->rx_list, skb); + __skb_queue_tail(&ctx->rx_list, darg.skb); goto recv_end; } /* periodically flush backlog, and feed strparser */ - tls_read_flush_backlog(sk, prot, len, to_decrypt, - decrypted + copied, &flushed_at); + released = tls_read_flush_backlog(sk, prot, len, to_decrypt, + decrypted + copied, + &flushed_at); /* TLS 1.3 may have updated the length by more than overhead */ + rxm = strp_msg(darg.skb); chunk = rxm->full_len; tls_rx_rec_done(ctx); if (!darg.zc) { bool partially_consumed = chunk > len; + struct sk_buff *skb = darg.skb; + + DEBUG_NET_WARN_ON_ONCE(darg.skb == tls_strp_msg(ctx)); if (async) { /* TLS 1.2-only, to_decrypt must be text len */ @@ -1992,6 +2038,7 @@ put_on_rx_list: } if (bpf_strp_enabled) { + released = true; err = sk_psock_tls_strp_read(psock, skb); if (err != __SK_PASS) { rxm->offset = rxm->offset + rxm->full_len; @@ -2018,13 +2065,13 @@ put_on_rx_list: rxm->full_len -= chunk; goto put_on_rx_list; } + + consume_skb(skb); } decrypted += chunk; len -= chunk; - consume_skb(skb); - /* Return full control message to userspace before trying * to parse another message type */ @@ -2098,7 +2145,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, struct tls_decrypt_arg darg; err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, - timeo); + true, timeo); if (err <= 0) goto splice_read_end; @@ -2158,23 +2205,21 @@ bool tls_sw_sock_is_readable(struct sock *sk) ingress_empty = list_empty(&psock->ingress_msg); rcu_read_unlock(); - return !ingress_empty || ctx->recv_pkt || + return !ingress_empty || tls_strp_msg_ready(ctx) || !skb_queue_empty(&ctx->rx_list); } -static int tls_read_size(struct strparser *strp, struct sk_buff *skb) +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_prot_info *prot = &tls_ctx->prot_info; char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; - struct strp_msg *rxm = strp_msg(skb); - struct tls_msg *tlm = tls_msg(skb); size_t cipher_overhead; size_t data_len = 0; int ret; /* Verify that we have a full TLS header, or wait for more data */ - if (rxm->offset + prot->prepend_size > skb->len) + if (strp->stm.offset + prot->prepend_size > skb->len) return 0; /* Sanity-check size of on-stack buffer. */ @@ -2184,11 +2229,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) } /* Linearize header to local buffer */ - ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); + ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size); if (ret < 0) goto read_failure; - tlm->control = header[0]; + strp->mark = header[0]; data_len = ((header[4] & 0xFF) | (header[3] << 8)); @@ -2215,7 +2260,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) } tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, - TCP_SKB_CB(skb)->seq + rxm->offset); + TCP_SKB_CB(skb)->seq + strp->stm.offset); return data_len + TLS_HEADER_SIZE; read_failure: @@ -2224,14 +2269,11 @@ read_failure: return ret; } -static void tls_queue(struct strparser *strp, struct sk_buff *skb) +void tls_rx_msg_ready(struct tls_strparser *strp) { - struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - - ctx->recv_pkt = skb; - strp_pause(strp); + struct tls_sw_context_rx *ctx; + ctx = container_of(strp, struct tls_sw_context_rx, strp); ctx->saved_data_ready(strp->sk); } @@ -2241,7 +2283,7 @@ static void tls_data_ready(struct sock *sk) struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_psock *psock; - strp_data_ready(&ctx->strp); + tls_strp_data_ready(&ctx->strp); psock = sk_psock_get(sk); if (psock) { @@ -2317,13 +2359,11 @@ void tls_sw_release_resources_rx(struct sock *sk) kfree(tls_ctx->rx.iv); if (ctx->aead_recv) { - kfree_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; __skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); - strp_stop(&ctx->strp); + tls_strp_stop(&ctx->strp); /* If tls_sw_strparser_arm() was not called (cleanup paths) - * we still want to strp_stop(), but sk->sk_data_ready was + * we still want to tls_strp_stop(), but sk->sk_data_ready was * never swapped. */ if (ctx->saved_data_ready) { @@ -2338,7 +2378,7 @@ void tls_sw_strparser_done(struct tls_context *tls_ctx) { struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - strp_done(&ctx->strp); + tls_strp_done(&ctx->strp); } void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) @@ -2411,8 +2451,6 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) rx_ctx->saved_data_ready = sk->sk_data_ready; sk->sk_data_ready = tls_data_ready; write_unlock_bh(&sk->sk_callback_lock); - - strp_check_rcv(&rx_ctx->strp); } void tls_update_rx_zc_capable(struct tls_context *tls_ctx) @@ -2432,7 +2470,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; struct crypto_aead **aead; - struct strp_callbacks cb; u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; struct crypto_tfm *tfm; char *iv, *rec_seq, *key, *salt, *cipher_name; @@ -2666,12 +2703,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) crypto_info->version != TLS_1_3_VERSION && !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); - /* Set up strparser */ - memset(&cb, 0, sizeof(cb)); - cb.rcv_msg = tls_queue; - cb.parse_msg = tls_read_size; - - strp_init(&sw_ctx_rx->strp, sk, &cb); + tls_strp_init(&sw_ctx_rx->strp, sk); } goto out; |