summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/net/tls.h3
-rw-r--r--net/tls/tls_sw.c266
2 files changed, 198 insertions, 71 deletions
diff --git a/include/net/tls.h b/include/net/tls.h
index 2a6ac8d642af..90bf52db573e 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -145,12 +145,13 @@ struct tls_sw_context_tx {
struct tls_sw_context_rx {
struct crypto_aead *aead_recv;
struct crypto_wait async_wait;
-
struct strparser strp;
+ struct sk_buff_head rx_list; /* list of decrypted 'data' records */
void (*saved_data_ready)(struct sock *sk);
struct sk_buff *recv_pkt;
u8 control;
+ int async_capable;
bool decrypted;
atomic_t decrypt_pending;
bool async_notify;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index b8e50e22b777..86b9527c4826 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
struct aead_request *aead_req = (struct aead_request *)req;
struct scatterlist *sgout = aead_req->dst;
+ struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx;
struct scatterlist *sg;
@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk);
ctx = tls_sw_ctx_rx(tls_ctx);
- pending = atomic_dec_return(&ctx->decrypt_pending);
/* Propagate if there was an err */
if (err) {
ctx->async_wait.err = err;
tls_err_abort(skb->sk, err);
+ } else {
+ struct strp_msg *rxm = strp_msg(skb);
+
+ rxm->offset += tls_ctx->rx.prepend_size;
+ rxm->full_len -= tls_ctx->rx.overhead_size;
}
/* After using skb->sk to propagate sk through crypto async callback
@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
*/
skb->sk = NULL;
- /* Release the skb, pages and memory allocated for crypto req */
- kfree_skb(skb);
- /* Skip the first S/G entry as it points to AAD */
- for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
- if (!sg)
- break;
- put_page(sg_page(sg));
+ /* Free the destination pages if skb was not decrypted inplace */
+ if (sgout != sgin) {
+ /* Skip the first S/G entry as it points to AAD */
+ for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+ if (!sg)
+ break;
+ put_page(sg_page(sg));
+ }
}
kfree(aead_req);
+ pending = atomic_dec_return(&ctx->decrypt_pending);
+
if (!pending && READ_ONCE(ctx->async_notify))
complete(&ctx->async_wait.completion);
}
@@ -1271,7 +1279,7 @@ out:
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct iov_iter *out_iov,
struct scatterlist *out_sg,
- int *chunk, bool *zc)
+ int *chunk, bool *zc, bool async)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1371,13 +1379,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
fallback_to_reg_recv:
sgout = sgin;
pages = 0;
- *chunk = 0;
+ *chunk = data_len;
*zc = false;
}
/* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, iv,
- data_len, aead_req, *zc);
+ data_len, aead_req, async);
if (err == -EINPROGRESS)
return err;
@@ -1390,7 +1398,8 @@ fallback_to_reg_recv:
}
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
- struct iov_iter *dest, int *chunk, bool *zc)
+ struct iov_iter *dest, int *chunk, bool *zc,
+ bool async)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1403,7 +1412,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
return err;
#endif
if (!ctx->decrypted) {
- err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+ err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
if (err < 0) {
if (err == -EINPROGRESS)
tls_advance_record_sn(sk, &tls_ctx->rx);
@@ -1429,7 +1438,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
bool zc = true;
int chunk;
- return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+ return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
}
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1456,6 +1465,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
return true;
}
+/* This function traverses the rx_list in tls receive context to copies the
+ * decrypted data records into the buffer provided by caller zero copy is not
+ * true. Further, the records are removed from the rx_list if it is not a peek
+ * case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+ struct msghdr *msg,
+ size_t skip,
+ size_t len,
+ bool zc,
+ bool is_peek)
+{
+ struct sk_buff *skb = skb_peek(&ctx->rx_list);
+ ssize_t copied = 0;
+
+ while (skip && skb) {
+ struct strp_msg *rxm = strp_msg(skb);
+
+ if (skip < rxm->full_len)
+ break;
+
+ skip = skip - rxm->full_len;
+ skb = skb_peek_next(skb, &ctx->rx_list);
+ }
+
+ while (len && skb) {
+ struct sk_buff *next_skb;
+ struct strp_msg *rxm = strp_msg(skb);
+ int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+ if (!zc || (rxm->full_len - skip) > len) {
+ int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+ msg, chunk);
+ if (err < 0)
+ return err;
+ }
+
+ len = len - chunk;
+ copied = copied + chunk;
+
+ /* Consume the data from record if it is non-peek case*/
+ if (!is_peek) {
+ rxm->offset = rxm->offset + chunk;
+ rxm->full_len = rxm->full_len - chunk;
+
+ /* Return if there is unconsumed data in the record */
+ if (rxm->full_len - skip)
+ break;
+ }
+
+ /* The remaining skip-bytes must lie in 1st record in rx_list.
+ * So from the 2nd record, 'skip' should be 0.
+ */
+ skip = 0;
+
+ if (msg)
+ msg->msg_flags |= MSG_EOR;
+
+ next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+ if (!is_peek) {
+ skb_unlink(skb, &ctx->rx_list);
+ kfree_skb(skb);
+ }
+
+ skb = next_skb;
+ }
+
+ return copied;
+}
+
int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg,
size_t len,
@@ -1466,7 +1546,8 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
- unsigned char control;
+ unsigned char control = 0;
+ ssize_t decrypted = 0;
struct strp_msg *rxm;
struct sk_buff *skb;
ssize_t copied = 0;
@@ -1474,6 +1555,7 @@ int tls_sw_recvmsg(struct sock *sk,
int target, err = 0;
long timeo;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+ bool is_peek = flags & MSG_PEEK;
int num_async = 0;
flags |= nonblock;
@@ -1484,11 +1566,28 @@ int tls_sw_recvmsg(struct sock *sk,
psock = sk_psock_get(sk);
lock_sock(sk);
- target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
- timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+ /* Process pending decrypted records. It must be non-zero-copy */
+ err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+ if (err < 0) {
+ tls_err_abort(sk, err);
+ goto end;
+ } else {
+ copied = err;
+ }
+
+ len = len - copied;
+ if (len) {
+ target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+ timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+ } else {
+ goto recv_end;
+ }
+
do {
- bool zc = false;
+ bool retain_skb = false;
bool async = false;
+ bool zc = false;
+ int to_decrypt;
int chunk = 0;
skb = tls_wait_data(sk, psock, flags, timeo, &err);
@@ -1498,7 +1597,7 @@ int tls_sw_recvmsg(struct sock *sk,
msg, len, flags);
if (ret > 0) {
- copied += ret;
+ decrypted += ret;
len -= ret;
continue;
}
@@ -1525,70 +1624,70 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end;
}
- if (!ctx->decrypted) {
- int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+ to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
- if (!is_kvec && to_copy <= len &&
- likely(!(flags & MSG_PEEK)))
- zc = true;
+ if (to_decrypt <= len && !is_kvec && !is_peek)
+ zc = true;
- err = decrypt_skb_update(sk, skb, &msg->msg_iter,
- &chunk, &zc);
- if (err < 0 && err != -EINPROGRESS) {
- tls_err_abort(sk, EBADMSG);
- goto recv_end;
- }
-
- if (err == -EINPROGRESS) {
- async = true;
- num_async++;
- goto pick_next_record;
- }
-
- ctx->decrypted = true;
+ err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+ &chunk, &zc, ctx->async_capable);
+ if (err < 0 && err != -EINPROGRESS) {
+ tls_err_abort(sk, EBADMSG);
+ goto recv_end;
}
- if (!zc) {
- chunk = min_t(unsigned int, rxm->full_len, len);
+ if (err == -EINPROGRESS) {
+ async = true;
+ num_async++;
+ goto pick_next_record;
+ } else {
+ if (!zc) {
+ if (rxm->full_len > len) {
+ retain_skb = true;
+ chunk = len;
+ } else {
+ chunk = rxm->full_len;
+ }
+
+ err = skb_copy_datagram_msg(skb, rxm->offset,
+ msg, chunk);
+ if (err < 0)
+ goto recv_end;
- err = skb_copy_datagram_msg(skb, rxm->offset, msg,
- chunk);
- if (err < 0)
- goto recv_end;
+ if (!is_peek) {
+ rxm->offset = rxm->offset + chunk;
+ rxm->full_len = rxm->full_len - chunk;
+ }
+ }
}
pick_next_record:
- copied += chunk;
+ if (chunk > len)
+ chunk = len;
+
+ decrypted += chunk;
len -= chunk;
- if (likely(!(flags & MSG_PEEK))) {
- u8 control = ctx->control;
-
- /* For async, drop current skb reference */
- if (async)
- skb = NULL;
-
- if (tls_sw_advance_skb(sk, skb, chunk)) {
- /* Return full control message to
- * userspace before trying to parse
- * another message type
- */
- msg->msg_flags |= MSG_EOR;
- if (control != TLS_RECORD_TYPE_DATA)
- goto recv_end;
- } else {
- break;
- }
- } else {
- /* MSG_PEEK right now cannot look beyond current skb
- * from strparser, meaning we cannot advance skb here
- * and thus unpause strparser since we'd loose original
- * one.
+
+ /* For async or peek case, queue the current skb */
+ if (async || is_peek || retain_skb) {
+ skb_queue_tail(&ctx->rx_list, skb);
+ skb = NULL;
+ }
+
+ if (tls_sw_advance_skb(sk, skb, chunk)) {
+ /* Return full control message to
+ * userspace before trying to parse
+ * another message type
*/
+ msg->msg_flags |= MSG_EOR;
+ if (ctx->control != TLS_RECORD_TYPE_DATA)
+ goto recv_end;
+ } else {
break;
}
/* If we have a new message from strparser, continue now. */
- if (copied >= target && !ctx->recv_pkt)
+ if (decrypted >= target && !ctx->recv_pkt)
break;
} while (len);
@@ -1602,13 +1701,33 @@ recv_end:
/* one of async decrypt failed */
tls_err_abort(sk, err);
copied = 0;
+ decrypted = 0;
+ goto end;
}
} else {
reinit_completion(&ctx->async_wait.completion);
}
WRITE_ONCE(ctx->async_notify, false);
+
+ /* Drain records from the rx_list & copy if required */
+ if (is_peek || is_kvec)
+ err = process_rx_list(ctx, msg, copied,
+ decrypted, false, is_peek);
+ else
+ err = process_rx_list(ctx, msg, 0,
+ decrypted, true, is_peek);
+ if (err < 0) {
+ tls_err_abort(sk, err);
+ copied = 0;
+ goto end;
+ }
+
+ WARN_ON(decrypted != err);
}
+ copied += decrypted;
+
+end:
release_sock(sk);
if (psock)
sk_psock_put(sk, psock);
@@ -1645,7 +1764,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
}
if (!ctx->decrypted) {
- err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+ err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
if (err < 0) {
tls_err_abort(sk, EBADMSG);
@@ -1832,6 +1951,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
if (ctx->aead_recv) {
kfree_skb(ctx->recv_pkt);
ctx->recv_pkt = NULL;
+ skb_queue_purge(&ctx->rx_list);
crypto_free_aead(ctx->aead_recv);
strp_stop(&ctx->strp);
write_lock_bh(&sk->sk_callback_lock);
@@ -1881,6 +2001,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct crypto_aead **aead;
struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size;
+ struct crypto_tfm *tfm;
char *iv, *rec_seq;
int rc = 0;
@@ -1927,6 +2048,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
crypto_init_wait(&sw_ctx_rx->async_wait);
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
+ skb_queue_head_init(&sw_ctx_rx->rx_list);
aead = &sw_ctx_rx->aead_recv;
}
@@ -1994,6 +2116,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_aead;
if (sw_ctx_rx) {
+ tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+ sw_ctx_rx->async_capable =
+ tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
/* Set up strparser */
memset(&cb, 0, sizeof(cb));
cb.rcv_msg = tls_queue;