summaryrefslogtreecommitdiff
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c471
1 files changed, 293 insertions, 178 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 09370f853031..ed5e6f1df9c7 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -47,9 +47,13 @@
#include "tls.h"
struct tls_decrypt_arg {
+ struct_group(inargs,
bool zc;
bool async;
u8 tail;
+ );
+
+ struct sk_buff *skb;
};
struct tls_decrypt_ctx {
@@ -180,39 +184,22 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx;
- struct tls_prot_info *prot;
struct scatterlist *sg;
- struct sk_buff *skb;
unsigned int pages;
+ struct sock *sk;
- skb = (struct sk_buff *)req->data;
- tls_ctx = tls_get_ctx(skb->sk);
+ sk = (struct sock *)req->data;
+ tls_ctx = tls_get_ctx(sk);
ctx = tls_sw_ctx_rx(tls_ctx);
- prot = &tls_ctx->prot_info;
/* Propagate if there was an err */
if (err) {
if (err == -EBADMSG)
- TLS_INC_STATS(sock_net(skb->sk),
- LINUX_MIB_TLSDECRYPTERROR);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
ctx->async_wait.err = err;
- tls_err_abort(skb->sk, err);
- } else {
- struct strp_msg *rxm = strp_msg(skb);
-
- /* No TLS 1.3 support with async crypto */
- WARN_ON(prot->tail_size);
-
- rxm->offset += prot->prepend_size;
- rxm->full_len -= prot->overhead_size;
+ tls_err_abort(sk, err);
}
- /* After using skb->sk to propagate sk through crypto async callback
- * we need to NULL it again.
- */
- skb->sk = NULL;
-
-
/* Free the destination pages if skb was not decrypted inplace */
if (sgout != sgin) {
/* Skip the first S/G entry as it points to AAD */
@@ -232,7 +219,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
}
static int tls_do_decryption(struct sock *sk,
- struct sk_buff *skb,
struct scatterlist *sgin,
struct scatterlist *sgout,
char *iv_recv,
@@ -252,16 +238,9 @@ static int tls_do_decryption(struct sock *sk,
(u8 *)iv_recv);
if (darg->async) {
- /* Using skb->sk to push sk through to crypto async callback
- * handler. This allows propagating errors up to the socket
- * if needed. It _must_ be cleared in the async handler
- * before consume_skb is called. We _know_ skb->sk is NULL
- * because it is a clone from strparser.
- */
- skb->sk = sk;
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
- tls_decrypt_done, skb);
+ tls_decrypt_done, sk);
atomic_inc(&ctx->decrypt_pending);
} else {
aead_request_set_callback(aead_req,
@@ -1404,51 +1383,90 @@ out:
return rc;
}
+static struct sk_buff *
+tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
+ unsigned int full_len)
+{
+ struct strp_msg *clr_rxm;
+ struct sk_buff *clr_skb;
+ int err;
+
+ clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
+ &err, sk->sk_allocation);
+ if (!clr_skb)
+ return NULL;
+
+ skb_copy_header(clr_skb, skb);
+ clr_skb->len = full_len;
+ clr_skb->data_len = full_len;
+
+ clr_rxm = strp_msg(clr_skb);
+ clr_rxm->offset = 0;
+
+ return clr_skb;
+}
+
+/* Decrypt handlers
+ *
+ * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
+ * They must transform the darg in/out argument are as follows:
+ * | Input | Output
+ * -------------------------------------------------------------------
+ * zc | Zero-copy decrypt allowed | Zero-copy performed
+ * async | Async decrypt allowed | Async crypto used / in progress
+ * skb | * | Output skb
+ */
+
/* This function decrypts the input skb into either out_iov or in out_sg
- * or in skb buffers itself. The input parameter 'zc' indicates if
+ * or in skb buffers itself. The input parameter 'darg->zc' indicates if
* zero-copy mode needs to be tried or not. With zero-copy mode, either
* out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
* NULL, then the decryption happens inside skb buffers itself, i.e.
- * zero-copy gets disabled and 'zc' is updated.
+ * zero-copy gets disabled and 'darg->zc' is updated.
*/
-
-static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
- struct iov_iter *out_iov,
- struct scatterlist *out_sg,
- struct tls_decrypt_arg *darg)
+static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
+ struct scatterlist *out_sg,
+ struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
int n_sgin, n_sgout, aead_size, err, pages = 0;
- struct strp_msg *rxm = strp_msg(skb);
- struct tls_msg *tlm = tls_msg(skb);
+ struct sk_buff *skb = tls_strp_msg(ctx);
+ const struct strp_msg *rxm = strp_msg(skb);
+ const struct tls_msg *tlm = tls_msg(skb);
struct aead_request *aead_req;
- struct sk_buff *unused;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
const int data_len = rxm->full_len - prot->overhead_size;
int tail_pages = !!prot->tail_size;
struct tls_decrypt_ctx *dctx;
+ struct sk_buff *clear_skb;
int iv_offset = 0;
u8 *mem;
+ n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+ rxm->full_len - prot->prepend_size);
+ if (n_sgin < 1)
+ return n_sgin ?: -EBADMSG;
+
if (darg->zc && (out_iov || out_sg)) {
+ clear_skb = NULL;
+
if (out_iov)
n_sgout = 1 + tail_pages +
iov_iter_npages_cap(out_iov, INT_MAX, data_len);
else
n_sgout = sg_nents(out_sg);
- n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
- rxm->full_len - prot->prepend_size);
} else {
- n_sgout = 0;
darg->zc = false;
- n_sgin = skb_cow_data(skb, 0, &unused);
- }
- if (n_sgin < 1)
- return -EBADMSG;
+ clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
+ if (!clear_skb)
+ return -ENOMEM;
+
+ n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
+ }
/* Increment to accommodate AAD */
n_sgin = n_sgin + 1;
@@ -1460,8 +1478,10 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
sk->sk_allocation);
- if (!mem)
- return -ENOMEM;
+ if (!mem) {
+ err = -ENOMEM;
+ goto exit_free_skb;
+ }
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
@@ -1510,117 +1530,141 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
if (err < 0)
goto exit_free;
- if (n_sgout) {
- if (out_iov) {
- sg_init_table(sgout, n_sgout);
- sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
+ if (clear_skb) {
+ sg_init_table(sgout, n_sgout);
+ sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
- err = tls_setup_from_iter(out_iov, data_len,
- &pages, &sgout[1],
- (n_sgout - 1 - tail_pages));
- if (err < 0)
- goto fallback_to_reg_recv;
+ err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
+ data_len + prot->tail_size);
+ if (err < 0)
+ goto exit_free;
+ } else if (out_iov) {
+ sg_init_table(sgout, n_sgout);
+ sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
- if (prot->tail_size) {
- sg_unmark_end(&sgout[pages]);
- sg_set_buf(&sgout[pages + 1], &dctx->tail,
- prot->tail_size);
- sg_mark_end(&sgout[pages + 1]);
- }
- } else if (out_sg) {
- memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
- } else {
- goto fallback_to_reg_recv;
+ err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
+ (n_sgout - 1 - tail_pages));
+ if (err < 0)
+ goto exit_free_pages;
+
+ if (prot->tail_size) {
+ sg_unmark_end(&sgout[pages]);
+ sg_set_buf(&sgout[pages + 1], &dctx->tail,
+ prot->tail_size);
+ sg_mark_end(&sgout[pages + 1]);
}
- } else {
-fallback_to_reg_recv:
- sgout = sgin;
- pages = 0;
- darg->zc = false;
+ } else if (out_sg) {
+ memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
}
/* Prepare and submit AEAD request */
- err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
+ err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
data_len + prot->tail_size, aead_req, darg);
- if (darg->async)
- return 0;
+ if (err)
+ goto exit_free_pages;
+
+ darg->skb = clear_skb ?: tls_strp_msg(ctx);
+ clear_skb = NULL;
+
+ if (unlikely(darg->async)) {
+ err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
+ if (err)
+ __skb_queue_tail(&ctx->async_hold, darg->skb);
+ return err;
+ }
if (prot->tail_size)
darg->tail = dctx->tail;
+exit_free_pages:
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages]));
exit_free:
kfree(mem);
+exit_free_skb:
+ consume_skb(clear_skb);
return err;
}
-static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
- struct iov_iter *dest,
- struct tls_decrypt_arg *darg)
+static int
+tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
+ struct tls_decrypt_arg *darg)
+{
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ int err;
+
+ if (tls_ctx->rx_conf != TLS_HW)
+ return 0;
+
+ err = tls_device_decrypted(sk, tls_ctx);
+ if (err <= 0)
+ return err;
+
+ darg->zc = false;
+ darg->async = false;
+ darg->skb = tls_strp_msg(ctx);
+ ctx->recv_pkt = NULL;
+ return 1;
+}
+
+static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
+ struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
- struct strp_msg *rxm = strp_msg(skb);
- struct tls_msg *tlm = tls_msg(skb);
+ struct strp_msg *rxm;
int pad, err;
- if (tlm->decrypted) {
- darg->zc = false;
- darg->async = false;
- return 0;
- }
-
- if (tls_ctx->rx_conf == TLS_HW) {
- err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
- if (err < 0)
- return err;
- if (err > 0) {
- tlm->decrypted = 1;
- darg->zc = false;
- darg->async = false;
- goto decrypt_done;
- }
- }
+ err = tls_decrypt_device(sk, tls_ctx, darg);
+ if (err < 0)
+ return err;
+ if (err)
+ goto decrypt_done;
- err = decrypt_internal(sk, skb, dest, NULL, darg);
+ err = tls_decrypt_sg(sk, dest, NULL, darg);
if (err < 0) {
if (err == -EBADMSG)
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
return err;
}
if (darg->async)
- goto decrypt_next;
+ goto decrypt_done;
/* If opportunistic TLS 1.3 ZC failed retry without ZC */
if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
darg->tail != TLS_RECORD_TYPE_DATA)) {
darg->zc = false;
- TLS_INC_STATS(sock_net(sk), LINUX_MIN_TLSDECRYPTRETRY);
- return decrypt_skb_update(sk, skb, dest, darg);
+ if (!darg->tail)
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
+ return tls_rx_one_record(sk, dest, darg);
}
decrypt_done:
- pad = tls_padding_length(prot, skb, darg);
- if (pad < 0)
+ if (darg->skb == ctx->recv_pkt)
+ ctx->recv_pkt = NULL;
+
+ pad = tls_padding_length(prot, darg->skb, darg);
+ if (pad < 0) {
+ consume_skb(darg->skb);
return pad;
+ }
+ rxm = strp_msg(darg->skb);
rxm->full_len -= pad;
rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size;
- tlm->decrypted = 1;
-decrypt_next:
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
return 0;
}
-int decrypt_skb(struct sock *sk, struct sk_buff *skb,
- struct scatterlist *sgout)
+int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
{
struct tls_decrypt_arg darg = { .zc = true, };
- return decrypt_internal(sk, skb, NULL, sgout, &darg);
+ return tls_decrypt_sg(sk, NULL, sgout, &darg);
}
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
@@ -1646,6 +1690,13 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
return 1;
}
+static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
+{
+ consume_skb(ctx->recv_pkt);
+ ctx->recv_pkt = NULL;
+ __strp_unpause(&ctx->strp);
+}
+
/* This function traverses the rx_list in tls receive context to copies the
* decrypted records into the buffer provided by caller zero copy is not
* true. Further, the records are removed from the rx_list if it is not a peek
@@ -1656,7 +1707,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
u8 *control,
size_t skip,
size_t len,
- bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
@@ -1690,12 +1740,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
if (err <= 0)
goto out;
- if (!zc || (rxm->full_len - skip) > len) {
- err = skb_copy_datagram_msg(skb, rxm->offset + skip,
- msg, chunk);
- if (err < 0)
- goto out;
- }
+ err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+ msg, chunk);
+ if (err < 0)
+ goto out;
len = len - chunk;
copied = copied + chunk;
@@ -1751,6 +1799,60 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
sk_flush_backlog(sk);
}
+static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
+ bool nonblock)
+{
+ long timeo;
+ int err;
+
+ lock_sock(sk);
+
+ timeo = sock_rcvtimeo(sk, nonblock);
+
+ while (unlikely(ctx->reader_present)) {
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+ ctx->reader_contended = 1;
+
+ add_wait_queue(&ctx->wq, &wait);
+ sk_wait_event(sk, &timeo,
+ !READ_ONCE(ctx->reader_present), &wait);
+ remove_wait_queue(&ctx->wq, &wait);
+
+ if (timeo <= 0) {
+ err = -EAGAIN;
+ goto err_unlock;
+ }
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeo);
+ goto err_unlock;
+ }
+ }
+
+ WRITE_ONCE(ctx->reader_present, 1);
+
+ return timeo;
+
+err_unlock:
+ release_sock(sk);
+ return err;
+}
+
+static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
+{
+ if (unlikely(ctx->reader_contended)) {
+ if (wq_has_sleeper(&ctx->wq))
+ wake_up(&ctx->wq);
+ else
+ ctx->reader_contended = 0;
+
+ WARN_ON_ONCE(!ctx->reader_present);
+ }
+
+ WRITE_ONCE(ctx->reader_present, 0);
+ release_sock(sk);
+}
+
int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg,
size_t len,
@@ -1760,9 +1862,9 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ ssize_t decrypted = 0, async_copy_bytes = 0;
struct sk_psock *psock;
unsigned char control = 0;
- ssize_t decrypted = 0;
size_t flushed_at = 0;
struct strp_msg *rxm;
struct tls_msg *tlm;
@@ -1780,7 +1882,9 @@ int tls_sw_recvmsg(struct sock *sk,
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
psock = sk_psock_get(sk);
- lock_sock(sk);
+ timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
+ if (timeo < 0)
+ return timeo;
bpf_strp_enabled = sk_psock_strp_enabled(psock);
/* If crypto failed the connection is broken */
@@ -1789,7 +1893,7 @@ int tls_sw_recvmsg(struct sock *sk,
goto end;
/* Process pending decrypted records. It must be non-zero-copy */
- err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
+ err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
if (err < 0)
goto end;
@@ -1799,13 +1903,12 @@ int tls_sw_recvmsg(struct sock *sk,
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
len = len - copied;
- timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
ctx->zc_capable;
decrypted = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) {
- struct tls_decrypt_arg darg = {};
+ struct tls_decrypt_arg darg;
int to_decrypt, chunk;
err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
@@ -1813,15 +1916,19 @@ int tls_sw_recvmsg(struct sock *sk,
if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len,
flags);
- if (chunk > 0)
- goto leave_on_list;
+ if (chunk > 0) {
+ decrypted += chunk;
+ len -= chunk;
+ continue;
+ }
}
goto recv_end;
}
- skb = ctx->recv_pkt;
- rxm = strp_msg(skb);
- tlm = tls_msg(skb);
+ memset(&darg.inargs, 0, sizeof(darg.inargs));
+
+ rxm = strp_msg(ctx->recv_pkt);
+ tlm = tls_msg(ctx->recv_pkt);
to_decrypt = rxm->full_len - prot->overhead_size;
@@ -1835,12 +1942,16 @@ int tls_sw_recvmsg(struct sock *sk,
else
darg.async = false;
- err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
+ err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto recv_end;
}
+ skb = darg.skb;
+ rxm = strp_msg(skb);
+ tlm = tls_msg(skb);
+
async |= darg.async;
/* If the type of records being processed is not known yet,
@@ -1851,34 +1962,36 @@ int tls_sw_recvmsg(struct sock *sk,
* For tls1.3, we disable async.
*/
err = tls_record_content_type(msg, tlm, &control);
- if (err <= 0)
+ if (err <= 0) {
+ tls_rx_rec_done(ctx);
+put_on_rx_list_err:
+ __skb_queue_tail(&ctx->rx_list, skb);
goto recv_end;
+ }
/* periodically flush backlog, and feed strparser */
tls_read_flush_backlog(sk, prot, len, to_decrypt,
decrypted + copied, &flushed_at);
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
- __skb_queue_tail(&ctx->rx_list, skb);
-
- if (async) {
- /* TLS 1.2-only, to_decrypt must be text length */
- chunk = min_t(int, to_decrypt, len);
-leave_on_list:
- decrypted += chunk;
- len -= chunk;
- continue;
- }
/* TLS 1.3 may have updated the length by more than overhead */
chunk = rxm->full_len;
+ tls_rx_rec_done(ctx);
if (!darg.zc) {
bool partially_consumed = chunk > len;
+ if (async) {
+ /* TLS 1.2-only, to_decrypt must be text len */
+ chunk = min_t(int, to_decrypt, len);
+ async_copy_bytes += chunk;
+put_on_rx_list:
+ decrypted += chunk;
+ len -= chunk;
+ __skb_queue_tail(&ctx->rx_list, skb);
+ continue;
+ }
+
if (bpf_strp_enabled) {
- /* BPF may try to queue the skb */
- __skb_unlink(skb, &ctx->rx_list);
err = sk_psock_tls_strp_read(psock, skb);
if (err != __SK_PASS) {
rxm->offset = rxm->offset + rxm->full_len;
@@ -1887,7 +2000,6 @@ leave_on_list:
consume_skb(skb);
continue;
}
- __skb_queue_tail(&ctx->rx_list, skb);
}
if (partially_consumed)
@@ -1896,22 +2008,21 @@ leave_on_list:
err = skb_copy_datagram_msg(skb, rxm->offset,
msg, chunk);
if (err < 0)
- goto recv_end;
+ goto put_on_rx_list_err;
if (is_peek)
- goto leave_on_list;
+ goto put_on_rx_list;
if (partially_consumed) {
rxm->offset += chunk;
rxm->full_len -= chunk;
- goto leave_on_list;
+ goto put_on_rx_list;
}
}
decrypted += chunk;
len -= chunk;
- __skb_unlink(skb, &ctx->rx_list);
consume_skb(skb);
/* Return full control message to userspace before trying
@@ -1931,30 +2042,32 @@ recv_end:
reinit_completion(&ctx->async_wait.completion);
pending = atomic_read(&ctx->decrypt_pending);
spin_unlock_bh(&ctx->decrypt_compl_lock);
- if (pending) {
+ ret = 0;
+ if (pending)
ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
- if (ret) {
- if (err >= 0 || err == -EINPROGRESS)
- err = ret;
- decrypted = 0;
- goto end;
- }
+ __skb_queue_purge(&ctx->async_hold);
+
+ if (ret) {
+ if (err >= 0 || err == -EINPROGRESS)
+ err = ret;
+ decrypted = 0;
+ goto end;
}
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, &control, copied,
- decrypted, false, is_peek);
+ decrypted, is_peek);
else
err = process_rx_list(ctx, msg, &control, 0,
- decrypted, true, is_peek);
+ async_copy_bytes, is_peek);
decrypted = max(err, 0);
}
copied += decrypted;
end:
- release_sock(sk);
+ tls_rx_reader_unlock(sk, ctx);
if (psock)
sk_psock_put(sk, psock);
return copied ? : err;
@@ -1971,33 +2084,34 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_msg *tlm;
struct sk_buff *skb;
ssize_t copied = 0;
- bool from_queue;
int err = 0;
long timeo;
int chunk;
- lock_sock(sk);
+ timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
+ if (timeo < 0)
+ return timeo;
- timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
-
- from_queue = !skb_queue_empty(&ctx->rx_list);
- if (from_queue) {
+ if (!skb_queue_empty(&ctx->rx_list)) {
skb = __skb_dequeue(&ctx->rx_list);
} else {
- struct tls_decrypt_arg darg = {};
+ struct tls_decrypt_arg darg;
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
timeo);
if (err <= 0)
goto splice_read_end;
- skb = ctx->recv_pkt;
+ memset(&darg.inargs, 0, sizeof(darg.inargs));
- err = decrypt_skb_update(sk, skb, NULL, &darg);
+ err = tls_rx_one_record(sk, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto splice_read_end;
}
+
+ tls_rx_rec_done(ctx);
+ skb = darg.skb;
}
rxm = strp_msg(skb);
@@ -2006,29 +2120,29 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
/* splice does not support reading control messages */
if (tlm->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL;
- goto splice_read_end;
+ goto splice_requeue;
}
chunk = min_t(unsigned int, rxm->full_len, len);
copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
if (copied < 0)
- goto splice_read_end;
+ goto splice_requeue;
- if (!from_queue) {
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
- }
if (chunk < rxm->full_len) {
- __skb_queue_head(&ctx->rx_list, skb);
rxm->offset += len;
rxm->full_len -= len;
- } else {
- consume_skb(skb);
+ goto splice_requeue;
}
+ consume_skb(skb);
+
splice_read_end:
- release_sock(sk);
+ tls_rx_reader_unlock(sk, ctx);
return copied ? : err;
+
+splice_requeue:
+ __skb_queue_head(&ctx->rx_list, skb);
+ goto splice_read_end;
}
bool tls_sw_sock_is_readable(struct sock *sk)
@@ -2074,7 +2188,6 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
if (ret < 0)
goto read_failure;
- tlm->decrypted = 0;
tlm->control = header[0];
data_len = ((header[4] & 0xFF) | (header[3] << 8));
@@ -2369,9 +2482,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
} else {
crypto_init_wait(&sw_ctx_rx->async_wait);
spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
+ init_waitqueue_head(&sw_ctx_rx->wq);
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
skb_queue_head_init(&sw_ctx_rx->rx_list);
+ skb_queue_head_init(&sw_ctx_rx->async_hold);
aead = &sw_ctx_rx->aead_recv;
}