summaryrefslogtreecommitdiff
path: root/net/mctp/af_mctp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/mctp/af_mctp.c')
-rw-r--r--net/mctp/af_mctp.c315
1 files changed, 288 insertions, 27 deletions
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index 85cc1a28cbe9..f0702d920d8d 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -6,6 +6,7 @@
* Copyright (c) 2021 Google
*/
+#include <linux/compat.h>
#include <linux/if_arp.h>
#include <linux/net.h>
#include <linux/mctp.h>
@@ -16,8 +17,13 @@
#include <net/mctpdevice.h>
#include <net/sock.h>
+#define CREATE_TRACE_POINTS
+#include <trace/events/mctp.h>
+
/* socket implementation */
+static void mctp_sk_expire_keys(struct timer_list *timer);
+
static int mctp_release(struct socket *sock)
{
struct sock *sk = sock->sk;
@@ -36,6 +42,13 @@ static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr)
return !addr->__smctp_pad0 && !addr->__smctp_pad1;
}
+static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr)
+{
+ return !addr->__smctp_pad0[0] &&
+ !addr->__smctp_pad0[1] &&
+ !addr->__smctp_pad0[2];
+}
+
static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
{
struct sock *sk = sock->sk;
@@ -83,18 +96,26 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
int rc, addrlen = msg->msg_namelen;
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb;
struct mctp_route *rt;
struct sk_buff *skb;
if (addr) {
+ const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
+ MCTP_TAG_PREALLOC;
+
if (addrlen < sizeof(struct sockaddr_mctp))
return -EINVAL;
if (addr->smctp_family != AF_MCTP)
return -EINVAL;
if (!mctp_sockaddr_is_ok(addr))
return -EINVAL;
- if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
+ if (addr->smctp_tag & ~tagbits)
+ return -EINVAL;
+ /* can't preallocate a non-owned tag */
+ if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
+ !(addr->smctp_tag & MCTP_TAG_OWNER))
return -EINVAL;
} else {
@@ -108,11 +129,6 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
if (addr->smctp_network == MCTP_NET_ANY)
addr->smctp_network = mctp_default_net(sock_net(sk));
- rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
- addr->smctp_addr.s_addr);
- if (!rt)
- return -EHOSTUNREACH;
-
skb = sock_alloc_send_skb(sk, hlen + 1 + len,
msg->msg_flags & MSG_DONTWAIT, &rc);
if (!skb)
@@ -124,19 +140,46 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
*(u8 *)skb_put(skb, 1) = addr->smctp_type;
rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
- if (rc < 0) {
- kfree_skb(skb);
- return rc;
- }
+ if (rc < 0)
+ goto err_free;
/* set up cb */
cb = __mctp_cb(skb);
cb->net = addr->smctp_network;
+ /* direct addressing */
+ if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
+ extaddr, msg->msg_name);
+
+ if (!mctp_sockaddr_ext_is_ok(extaddr) ||
+ extaddr->smctp_halen > sizeof(cb->haddr)) {
+ rc = -EINVAL;
+ goto err_free;
+ }
+
+ cb->ifindex = extaddr->smctp_ifindex;
+ cb->halen = extaddr->smctp_halen;
+ memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
+
+ rt = NULL;
+ } else {
+ rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
+ addr->smctp_addr.s_addr);
+ if (!rt) {
+ rc = -EHOSTUNREACH;
+ goto err_free;
+ }
+ }
+
rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
addr->smctp_tag);
return rc ? : len;
+
+err_free:
+ kfree_skb(skb);
+ return rc;
}
static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
@@ -144,6 +187,7 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{
DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct sk_buff *skb;
size_t msglen;
u8 type;
@@ -191,6 +235,17 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
addr->__smctp_pad1 = 0;
msg->msg_namelen = sizeof(*addr);
+
+ if (msk->addr_ext) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
+ msg->msg_name);
+ msg->msg_namelen = sizeof(*ae);
+ ae->smctp_ifindex = cb->ifindex;
+ ae->smctp_halen = cb->halen;
+ memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0));
+ memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
+ memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
+ }
}
rc = len;
@@ -203,18 +258,186 @@ out_free:
return rc;
}
+/* We're done with the key; invalidate, stop reassembly, and remove from lists.
+ */
+static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
+ unsigned long flags, unsigned long reason)
+__releases(&key->lock)
+__must_hold(&net->mctp.keys_lock)
+{
+ struct sk_buff *skb;
+
+ trace_mctp_key_release(key, reason);
+ skb = key->reasm_head;
+ key->reasm_head = NULL;
+ key->reasm_dead = true;
+ key->valid = false;
+ mctp_dev_release_key(key->dev, key);
+ spin_unlock_irqrestore(&key->lock, flags);
+
+ hlist_del(&key->hlist);
+ hlist_del(&key->sklist);
+
+ /* unref for the lists */
+ mctp_key_unref(key);
+
+ kfree_skb(skb);
+}
+
static int mctp_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
- return -EINVAL;
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (optlen != sizeof(int))
+ return -EINVAL;
+ if (copy_from_sockptr(&val, optval, sizeof(int)))
+ return -EFAULT;
+ msk->addr_ext = val;
+ return 0;
+ }
+
+ return -ENOPROTOOPT;
}
static int mctp_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen)
{
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int len, val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (len != sizeof(int))
+ return -EINVAL;
+ val = !!msk->addr_ext;
+ if (copy_to_user(optval, &val, len))
+ return -EFAULT;
+ return 0;
+ }
+
return -EINVAL;
}
+static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
+{
+ struct net *net = sock_net(&msk->sk);
+ struct mctp_sk_key *key = NULL;
+ struct mctp_ioc_tag_ctl ctl;
+ unsigned long flags;
+ u8 tag;
+
+ if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
+ return -EFAULT;
+
+ if (ctl.tag)
+ return -EINVAL;
+
+ if (ctl.flags)
+ return -EINVAL;
+
+ key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY,
+ true, &tag);
+ if (IS_ERR(key))
+ return PTR_ERR(key);
+
+ ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
+ if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) {
+ spin_lock_irqsave(&key->lock, flags);
+ __mctp_key_remove(key, net, flags, MCTP_TRACE_KEY_DROPPED);
+ mctp_key_unref(key);
+ return -EFAULT;
+ }
+
+ mctp_key_unref(key);
+ return 0;
+}
+
+static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg)
+{
+ struct net *net = sock_net(&msk->sk);
+ struct mctp_ioc_tag_ctl ctl;
+ unsigned long flags, fl2;
+ struct mctp_sk_key *key;
+ struct hlist_node *tmp;
+ int rc;
+ u8 tag;
+
+ if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
+ return -EFAULT;
+
+ if (ctl.flags)
+ return -EINVAL;
+
+ /* Must be a local tag, TO set, preallocated */
+ if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC))
+ return -EINVAL;
+
+ tag = ctl.tag & MCTP_TAG_MASK;
+ rc = -EINVAL;
+
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
+ hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
+ /* we do an irqsave here, even though we know the irq state,
+ * so we have the flags to pass to __mctp_key_remove
+ */
+ spin_lock_irqsave(&key->lock, fl2);
+ if (key->manual_alloc &&
+ ctl.peer_addr == key->peer_addr &&
+ tag == key->tag) {
+ __mctp_key_remove(key, net, fl2,
+ MCTP_TRACE_KEY_DROPPED);
+ rc = 0;
+ } else {
+ spin_unlock_irqrestore(&key->lock, fl2);
+ }
+ }
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+ return rc;
+}
+
+static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
+{
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+
+ switch (cmd) {
+ case SIOCMCTPALLOCTAG:
+ return mctp_ioctl_alloctag(msk, arg);
+ case SIOCMCTPDROPTAG:
+ return mctp_ioctl_droptag(msk, arg);
+ }
+
+ return -EINVAL;
+}
+
+#ifdef CONFIG_COMPAT
+static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
+ unsigned long arg)
+{
+ void __user *argp = compat_ptr(arg);
+
+ switch (cmd) {
+ /* These have compatible ptr layouts */
+ case SIOCMCTPALLOCTAG:
+ case SIOCMCTPDROPTAG:
+ return mctp_ioctl(sock, cmd, (unsigned long)argp);
+ }
+
+ return -ENOIOCTLCMD;
+}
+#endif
+
static const struct proto_ops mctp_dgram_ops = {
.family = PF_MCTP,
.release = mctp_release,
@@ -224,7 +447,7 @@ static const struct proto_ops mctp_dgram_ops = {
.accept = sock_no_accept,
.getname = sock_no_getname,
.poll = datagram_poll,
- .ioctl = sock_no_ioctl,
+ .ioctl = mctp_ioctl,
.gettstamp = sock_gettstamp,
.listen = sock_no_listen,
.shutdown = sock_no_shutdown,
@@ -234,18 +457,67 @@ static const struct proto_ops mctp_dgram_ops = {
.recvmsg = mctp_recvmsg,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
+#ifdef CONFIG_COMPAT
+ .compat_ioctl = mctp_compat_ioctl,
+#endif
};
+static void mctp_sk_expire_keys(struct timer_list *timer)
+{
+ struct mctp_sock *msk = container_of(timer, struct mctp_sock,
+ key_expiry);
+ struct net *net = sock_net(&msk->sk);
+ unsigned long next_expiry, flags, fl2;
+ struct mctp_sk_key *key;
+ struct hlist_node *tmp;
+ bool next_expiry_valid = false;
+
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
+
+ hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
+ /* don't expire. manual_alloc is immutable, no locking
+ * required.
+ */
+ if (key->manual_alloc)
+ continue;
+
+ spin_lock_irqsave(&key->lock, fl2);
+ if (!time_after_eq(key->expiry, jiffies)) {
+ __mctp_key_remove(key, net, fl2,
+ MCTP_TRACE_KEY_TIMEOUT);
+ continue;
+ }
+
+ if (next_expiry_valid) {
+ if (time_before(key->expiry, next_expiry))
+ next_expiry = key->expiry;
+ } else {
+ next_expiry = key->expiry;
+ next_expiry_valid = true;
+ }
+ spin_unlock_irqrestore(&key->lock, fl2);
+ }
+
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+ if (next_expiry_valid)
+ mod_timer(timer, next_expiry);
+}
+
static int mctp_sk_init(struct sock *sk)
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
INIT_HLIST_HEAD(&msk->keys);
+ timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
return 0;
}
static void mctp_sk_close(struct sock *sk, long timeout)
{
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
+
+ del_timer_sync(&msk->key_expiry);
sk_common_release(sk);
}
@@ -264,9 +536,9 @@ static void mctp_sk_unhash(struct sock *sk)
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct net *net = sock_net(sk);
+ unsigned long flags, fl2;
struct mctp_sk_key *key;
struct hlist_node *tmp;
- unsigned long flags;
/* remove from any type-based binds */
mutex_lock(&net->mctp.bind_lock);
@@ -276,21 +548,10 @@ static void mctp_sk_unhash(struct sock *sk)
/* remove tag allocations */
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
- hlist_del_rcu(&key->sklist);
- hlist_del_rcu(&key->hlist);
-
- spin_lock(&key->reasm_lock);
- if (key->reasm_head)
- kfree_skb(key->reasm_head);
- key->reasm_head = NULL;
- key->reasm_dead = true;
- spin_unlock(&key->reasm_lock);
-
- kfree_rcu(key, rcu);
+ spin_lock_irqsave(&key->lock, fl2);
+ __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
-
- synchronize_rcu();
}
static struct proto mctp_proto = {
@@ -398,7 +659,7 @@ static __exit void mctp_exit(void)
sock_unregister(PF_MCTP);
}
-module_init(mctp_init);
+subsys_initcall(mctp_init);
module_exit(mctp_exit);
MODULE_DESCRIPTION("MCTP core");