summaryrefslogtreecommitdiff
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c136
1 files changed, 94 insertions, 42 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 43252a801c3f..ac88877dcade 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -39,6 +39,7 @@
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
+#include <linux/inet_diag.h>
#include <net/tls.h>
@@ -251,14 +252,26 @@ static void tls_write_space(struct sock *sk)
ctx->sk_write_space(sk);
}
-void tls_ctx_free(struct tls_context *ctx)
+/**
+ * tls_ctx_free() - free TLS ULP context
+ * @sk: socket to with @ctx is attached
+ * @ctx: TLS context structure
+ *
+ * Free TLS context. If @sk is %NULL caller guarantees that the socket
+ * to which @ctx was attached has no outstanding references.
+ */
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
if (!ctx)
return;
memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
- kfree(ctx);
+
+ if (sk)
+ kfree_rcu(ctx, rcu);
+ else
+ kfree(ctx);
}
static void tls_sk_proto_cleanup(struct sock *sk,
@@ -273,19 +286,14 @@ static void tls_sk_proto_cleanup(struct sock *sk,
kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
tls_sw_release_resources_tx(sk);
-#ifdef CONFIG_TLS_DEVICE
} else if (ctx->tx_conf == TLS_HW) {
tls_device_free_resources_tx(sk);
-#endif
}
if (ctx->rx_conf == TLS_SW)
tls_sw_release_resources_rx(sk);
-
-#ifdef CONFIG_TLS_DEVICE
- if (ctx->rx_conf == TLS_HW)
+ else if (ctx->rx_conf == TLS_HW)
tls_device_offload_cleanup_rx(sk);
-#endif
}
static void tls_sk_proto_close(struct sock *sk, long timeout)
@@ -306,7 +314,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
- icsk->icsk_ulp_data = NULL;
+ rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
sk->sk_prot = ctx->sk_proto;
if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space;
@@ -318,10 +326,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx);
- ctx->sk_proto_close(sk, timeout);
+ ctx->sk_proto->close(sk, timeout);
if (free_ctx)
- tls_ctx_free(ctx);
+ tls_ctx_free(sk, ctx);
}
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@@ -438,7 +446,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->getsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->getsockopt(sk, level,
+ optname, optval, optlen);
return do_tls_getsockopt(sk, optname, optval, optlen);
}
@@ -523,26 +532,18 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
}
if (tx) {
-#ifdef CONFIG_TLS_DEVICE
rc = tls_set_device_offload(sk, ctx);
conf = TLS_HW;
if (rc) {
-#else
- {
-#endif
rc = tls_set_sw_offload(sk, ctx, 1);
if (rc)
goto err_crypto_info;
conf = TLS_SW;
}
} else {
-#ifdef CONFIG_TLS_DEVICE
rc = tls_set_device_offload_rx(sk, ctx);
conf = TLS_HW;
if (rc) {
-#else
- {
-#endif
rc = tls_set_sw_offload(sk, ctx, 0);
if (rc)
goto err_crypto_info;
@@ -596,7 +597,8 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->setsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->setsockopt(sk, level, optname, optval,
+ optlen);
return do_tls_setsockopt(sk, optname, optval, optlen);
}
@@ -610,11 +612,8 @@ static struct tls_context *create_ctx(struct sock *sk)
if (!ctx)
return NULL;
- icsk->icsk_ulp_data = ctx;
- ctx->setsockopt = sk->sk_prot->setsockopt;
- ctx->getsockopt = sk->sk_prot->getsockopt;
- ctx->sk_proto_close = sk->sk_prot->close;
- ctx->unhash = sk->sk_prot->unhash;
+ rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
+ ctx->sk_proto = sk->sk_prot;
return ctx;
}
@@ -651,8 +650,8 @@ static void tls_hw_sk_destruct(struct sock *sk)
ctx->sk_destruct(sk);
/* Free ctx */
- tls_ctx_free(ctx);
- icsk->icsk_ulp_data = NULL;
+ rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+ tls_ctx_free(sk, ctx);
}
static int tls_hw_prot(struct sock *sk)
@@ -670,9 +669,6 @@ static int tls_hw_prot(struct sock *sk)
spin_unlock_bh(&device_spinlock);
tls_build_proto(sk);
- ctx->hash = sk->sk_prot->hash;
- ctx->unhash = sk->sk_prot->unhash;
- ctx->sk_proto_close = sk->sk_prot->close;
ctx->sk_destruct = sk->sk_destruct;
sk->sk_destruct = tls_hw_sk_destruct;
ctx->rx_conf = TLS_HW_RECORD;
@@ -704,7 +700,7 @@ static void tls_hw_unhash(struct sock *sk)
}
}
spin_unlock_bh(&device_spinlock);
- ctx->unhash(sk);
+ ctx->sk_proto->unhash(sk);
}
static int tls_hw_hash(struct sock *sk)
@@ -713,7 +709,7 @@ static int tls_hw_hash(struct sock *sk)
struct tls_device *dev;
int err;
- err = ctx->hash(sk);
+ err = ctx->sk_proto->hash(sk);
spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
if (dev->hash) {
@@ -803,7 +799,6 @@ static int tls_init(struct sock *sk)
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
- ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx);
out:
write_unlock_bh(&sk->sk_callback_lock);
@@ -815,12 +810,71 @@ static void tls_update(struct sock *sk, struct proto *p)
struct tls_context *ctx;
ctx = tls_get_ctx(sk);
- if (likely(ctx)) {
- ctx->sk_proto_close = p->close;
+ if (likely(ctx))
ctx->sk_proto = p;
- } else {
+ else
sk->sk_prot = p;
- }
+}
+
+static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
+{
+ u16 version, cipher_type;
+ struct tls_context *ctx;
+ struct nlattr *start;
+ int err;
+
+ start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
+ if (!start)
+ return -EMSGSIZE;
+
+ rcu_read_lock();
+ ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
+ if (!ctx) {
+ err = 0;
+ goto nla_failure;
+ }
+ version = ctx->prot_info.version;
+ if (version) {
+ err = nla_put_u16(skb, TLS_INFO_VERSION, version);
+ if (err)
+ goto nla_failure;
+ }
+ cipher_type = ctx->prot_info.cipher_type;
+ if (cipher_type) {
+ err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
+ if (err)
+ goto nla_failure;
+ }
+ err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
+ if (err)
+ goto nla_failure;
+
+ err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
+ if (err)
+ goto nla_failure;
+
+ rcu_read_unlock();
+ nla_nest_end(skb, start);
+ return 0;
+
+nla_failure:
+ rcu_read_unlock();
+ nla_nest_cancel(skb, start);
+ return err;
+}
+
+static size_t tls_get_info_size(const struct sock *sk)
+{
+ size_t size = 0;
+
+ size += nla_total_size(0) + /* INET_ULP_INFO_TLS */
+ nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */
+ nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */
+ nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
+ nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
+ 0;
+
+ return size;
}
void tls_register_device(struct tls_device *device)
@@ -844,6 +898,8 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.owner = THIS_MODULE,
.init = tls_init,
.update = tls_update,
+ .get_info = tls_get_info,
+ .get_info_size = tls_get_info_size,
};
static int __init tls_register(void)
@@ -851,9 +907,7 @@ static int __init tls_register(void)
tls_sw_proto_ops = inet_stream_ops;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
-#ifdef CONFIG_TLS_DEVICE
tls_device_init();
-#endif
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
@@ -862,9 +916,7 @@ static int __init tls_register(void)
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
-#ifdef CONFIG_TLS_DEVICE
tls_device_cleanup();
-#endif
}
module_init(tls_register);