diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_device_fallback.c | 6 | ||||
-rw-r--r-- | net/tls/tls_main.c | 75 | ||||
-rw-r--r-- | net/tls/tls_proc.c | 1 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 84 |
4 files changed, 146 insertions, 20 deletions
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index e40bedd112b6..3bae29ae57ca 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -232,7 +232,7 @@ static int fill_sg_in(struct scatterlist *sg_in, s32 *sync_size, int *resync_sgs) { - int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + int tcp_payload_offset = skb_tcp_all_headers(skb); int payload_len = skb->len - tcp_payload_offset; u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); struct tls_record_info *record; @@ -310,8 +310,8 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, struct sk_buff *skb, s32 sync_size, u64 rcd_sn) { - int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); + int tcp_payload_offset = skb_tcp_all_headers(skb); int payload_len = skb->len - tcp_payload_offset; void *buf, *iv, *aad, *dummy_buf; struct aead_request *aead_req; @@ -372,7 +372,7 @@ free_nskb: static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb) { - int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + int tcp_payload_offset = skb_tcp_all_headers(skb); struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); int payload_len = skb->len - tcp_payload_offset; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 2ffede463e4a..1b3efc96db0b 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -533,6 +533,37 @@ static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, return 0; } +static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, + int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + unsigned int value; + int err, len; + + if (ctx->prot_info.version != TLS_1_3_VERSION) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < sizeof(value)) + return -EINVAL; + + lock_sock(sk); + err = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) + value = ctx->rx_no_pad; + release_sock(sk); + if (err) + return err; + + if (put_user(sizeof(value), optlen)) + return -EFAULT; + if (copy_to_user(optval, &value, sizeof(value))) + return -EFAULT; + + return 0; +} + static int do_tls_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen) { @@ -547,6 +578,9 @@ static int do_tls_getsockopt(struct sock *sk, int optname, case TLS_TX_ZEROCOPY_RO: rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_getsockopt_no_pad(sk, optval, optlen); + break; default: rc = -ENOPROTOOPT; break; @@ -718,6 +752,38 @@ static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, return 0; } +static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, + unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + u32 val; + int rc; + + if (ctx->prot_info.version != TLS_1_3_VERSION || + sockptr_is_null(optval) || optlen < sizeof(val)) + return -EINVAL; + + rc = copy_from_sockptr(&val, optval, sizeof(val)); + if (rc) + return -EFAULT; + if (val > 1) + return -EINVAL; + rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); + if (rc < 1) + return rc == 0 ? -EINVAL : rc; + + lock_sock(sk); + rc = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { + ctx->rx_no_pad = val; + tls_update_rx_zc_capable(ctx); + rc = 0; + } + release_sock(sk); + + return rc; +} + static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, unsigned int optlen) { @@ -736,6 +802,9 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); release_sock(sk); break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_setsockopt_no_pad(sk, optval, optlen); + break; default: rc = -ENOPROTOOPT; break; @@ -976,6 +1045,11 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb) if (err) goto nla_failure; } + if (ctx->rx_no_pad) { + err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); + if (err) + goto nla_failure; + } rcu_read_unlock(); nla_nest_end(skb, start); @@ -997,6 +1071,7 @@ static size_t tls_get_info_size(const struct sock *sk) nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ + nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 0; return size; diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c index feeceb0e4cb4..0c200000cc45 100644 --- a/net/tls/tls_proc.c +++ b/net/tls/tls_proc.c @@ -18,6 +18,7 @@ static const struct snmp_mib tls_mib_list[] = { SNMP_MIB_ITEM("TlsRxDevice", LINUX_MIB_TLSRXDEVICE), SNMP_MIB_ITEM("TlsDecryptError", LINUX_MIB_TLSDECRYPTERROR), SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC), + SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIN_TLSDECRYPTRETRY), SNMP_MIB_SENTINEL }; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e30649f6dde5..f1777d67527f 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -47,6 +47,7 @@ struct tls_decrypt_arg { bool zc; bool async; + u8 tail; }; noinline void tls_err_abort(struct sock *sk, int err) @@ -133,7 +134,8 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len) return __skb_nsg(skb, offset, len, 0); } -static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb) +static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, + struct tls_decrypt_arg *darg) { struct strp_msg *rxm = strp_msg(skb); struct tls_msg *tlm = tls_msg(skb); @@ -142,7 +144,7 @@ static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb) /* Determine zero-padding length */ if (prot->version == TLS_1_3_VERSION) { int offset = rxm->full_len - TLS_TAG_SIZE - 1; - char content_type = 0; + char content_type = darg->zc ? darg->tail : 0; int err; while (content_type == 0) { @@ -1415,18 +1417,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, struct strp_msg *rxm = strp_msg(skb); struct tls_msg *tlm = tls_msg(skb); int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; + u8 *aad, *iv, *tail, *mem = NULL; struct aead_request *aead_req; struct sk_buff *unused; - u8 *aad, *iv, *mem = NULL; struct scatterlist *sgin = NULL; struct scatterlist *sgout = NULL; - const int data_len = rxm->full_len - prot->overhead_size + - prot->tail_size; + const int data_len = rxm->full_len - prot->overhead_size; + int tail_pages = !!prot->tail_size; int iv_offset = 0; if (darg->zc && (out_iov || out_sg)) { if (out_iov) - n_sgout = 1 + + n_sgout = 1 + tail_pages + iov_iter_npages_cap(out_iov, INT_MAX, data_len); else n_sgout = sg_nents(out_sg); @@ -1450,9 +1452,10 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, mem_size = aead_size + (nsg * sizeof(struct scatterlist)); mem_size = mem_size + prot->aad_size; mem_size = mem_size + MAX_IV_SIZE; + mem_size = mem_size + prot->tail_size; /* Allocate a single block of memory which contains - * aead_req || sgin[] || sgout[] || aad || iv. + * aead_req || sgin[] || sgout[] || aad || iv || tail. * This order achieves correct alignment for aead_req, sgin, sgout. */ mem = kmalloc(mem_size, sk->sk_allocation); @@ -1465,6 +1468,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, sgout = sgin + n_sgin; aad = (u8 *)(sgout + n_sgout); iv = aad + prot->aad_size; + tail = iv + MAX_IV_SIZE; /* For CCM based ciphers, first byte of nonce+iv is a constant */ switch (prot->cipher_type) { @@ -1518,9 +1522,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1], - (n_sgout - 1)); + (n_sgout - 1 - tail_pages)); if (err < 0) goto fallback_to_reg_recv; + + if (prot->tail_size) { + sg_unmark_end(&sgout[pages]); + sg_set_buf(&sgout[pages + 1], tail, + prot->tail_size); + sg_mark_end(&sgout[pages + 1]); + } } else if (out_sg) { memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); } else { @@ -1535,10 +1546,13 @@ fallback_to_reg_recv: /* Prepare and submit AEAD request */ err = tls_do_decryption(sk, skb, sgin, sgout, iv, - data_len, aead_req, darg); + data_len + prot->tail_size, aead_req, darg); if (darg->async) return 0; + if (prot->tail_size) + darg->tail = *tail; + /* Release the pages in case iov was mapped to pages */ for (; pages > 0; pages--) put_page(sg_page(&sgout[pages])); @@ -1583,9 +1597,16 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, } if (darg->async) goto decrypt_next; + /* If opportunistic TLS 1.3 ZC failed retry without ZC */ + if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && + darg->tail != TLS_RECORD_TYPE_DATA)) { + darg->zc = false; + TLS_INC_STATS(sock_net(sk), LINUX_MIN_TLSDECRYPTRETRY); + return decrypt_skb_update(sk, skb, dest, darg); + } decrypt_done: - pad = padding_length(prot, skb); + pad = tls_padding_length(prot, skb, darg); if (pad < 0) return pad; @@ -1717,6 +1738,24 @@ out: return copied ? : err; } +static void +tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, + size_t len_left, size_t decrypted, ssize_t done, + size_t *flushed_at) +{ + size_t max_rec; + + if (len_left <= decrypted) + return; + + max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; + if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) + return; + + *flushed_at = done; + sk_flush_backlog(sk); +} + int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, @@ -1729,6 +1768,7 @@ int tls_sw_recvmsg(struct sock *sk, struct sk_psock *psock; unsigned char control = 0; ssize_t decrypted = 0; + size_t flushed_at = 0; struct strp_msg *rxm; struct tls_msg *tlm; struct sk_buff *skb; @@ -1767,7 +1807,7 @@ int tls_sw_recvmsg(struct sock *sk, timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && - prot->version != TLS_1_3_VERSION; + ctx->zc_capable; decrypted = 0; while (len && (decrypted + copied < target || ctx->recv_pkt)) { struct tls_decrypt_arg darg = {}; @@ -1818,6 +1858,10 @@ int tls_sw_recvmsg(struct sock *sk, if (err <= 0) goto recv_end; + /* periodically flush backlog, and feed strparser */ + tls_read_flush_backlog(sk, prot, len, to_decrypt, + decrypted + copied, &flushed_at); + ctx->recv_pkt = NULL; __strp_unpause(&ctx->strp); __skb_queue_tail(&ctx->rx_list, skb); @@ -2249,6 +2293,14 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) strp_check_rcv(&rx_ctx->strp); } +void tls_update_rx_zc_capable(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + rx_ctx->zc_capable = tls_ctx->rx_no_pad || + tls_ctx->prot_info.version != TLS_1_3_VERSION; +} + int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -2484,12 +2536,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) if (sw_ctx_rx) { tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); - if (crypto_info->version == TLS_1_3_VERSION) - sw_ctx_rx->async_capable = 0; - else - sw_ctx_rx->async_capable = - !!(tfm->__crt_alg->cra_flags & - CRYPTO_ALG_ASYNC); + tls_update_rx_zc_capable(ctx); + sw_ctx_rx->async_capable = + crypto_info->version != TLS_1_3_VERSION && + !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); /* Set up strparser */ memset(&cb, 0, sizeof(cb)); |