-rw-r--r--  drivers/vhost/vhost.c              54
-rw-r--r--  drivers/vhost/vhost.h               2
-rw-r--r--  include/linux/sched/vhost_task.h    3
-rw-r--r--  kernel/vhost_task.c                53
4 files changed, 88 insertions, 24 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c6448ff37768..b60955682474 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -273,7 +273,7 @@ static void __vhost_worker_flush(struct vhost_worker *worker)
{
struct vhost_flush_struct flush;
- if (!worker->attachment_cnt)
+ if (!worker->attachment_cnt || worker->killed)
return;
init_completion(&flush.wait_event);
@@ -388,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
__vhost_vq_meta_reset(vq);
}
-static bool vhost_worker(void *data)
+static bool vhost_run_work_list(void *data)
{
struct vhost_worker *worker = data;
struct vhost_work *work, *work_next;
@@ -413,6 +413,40 @@ static bool vhost_worker(void *data)
return !!node;
}
+static void vhost_worker_killed(void *data)
+{
+ struct vhost_worker *worker = data;
+ struct vhost_dev *dev = worker->dev;
+ struct vhost_virtqueue *vq;
+ int i, attach_cnt = 0;
+
+ mutex_lock(&worker->mutex);
+ worker->killed = true;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ vq = dev->vqs[i];
+
+ mutex_lock(&vq->mutex);
+ if (worker ==
+ rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->mutex))) {
+ rcu_assign_pointer(vq->worker, NULL);
+ attach_cnt++;
+ }
+ mutex_unlock(&vq->mutex);
+ }
+
+ worker->attachment_cnt -= attach_cnt;
+ if (attach_cnt)
+ synchronize_rcu();
+ /*
+ * Finish vhost_worker_flush calls and any other works that snuck in
+ * before the synchronize_rcu.
+ */
+ vhost_run_work_list(worker);
+ mutex_unlock(&worker->mutex);
+}
+
static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
kfree(vq->indirect);
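
The vhost_worker_killed() added above follows the usual RCU unpublish-then-reclaim shape: clear every published pointer to the worker under the lock, wait out in-flight readers with synchronize_rcu(), and only then drain what is left. A minimal sketch of that shape, with hypothetical names (detach_and_drain, struct obj, drain) standing in for the vhost specifics:

    #include <linux/rcupdate.h>
    #include <linux/mutex.h>

    struct obj;
    void drain(struct obj *o);	/* hypothetical final cleanup */

    /* Hypothetical sketch, not part of the patch. */
    static void detach_and_drain(struct obj __rcu **slot, struct obj *o,
                                 struct mutex *lock)
    {
            mutex_lock(lock);
            rcu_assign_pointer(*slot, NULL);	/* stop new lookups */
            mutex_unlock(lock);

            /* wait for readers that already dereferenced the old pointer */
            synchronize_rcu();

            drain(o);	/* nothing can find o anymore; safe to drain */
    }

Work queued between the unpublish and the grace period still lands on the worker's list, which is why the patch runs vhost_run_work_list() one last time after synchronize_rcu().
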
@@ -627,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
if (!worker)
return NULL;
+ worker->dev = dev;
snprintf(name, sizeof(name), "vhost-%d", current->pid);
- vtsk = vhost_task_create(vhost_worker, worker, name);
+ vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+ worker, name);
if (!vtsk)
goto free_worker;
@@ -661,6 +697,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
struct vhost_worker *old_worker;
mutex_lock(&worker->mutex);
+ if (worker->killed) {
+ mutex_unlock(&worker->mutex);
+ return;
+ }
+
mutex_lock(&vq->mutex);
old_worker = rcu_dereference_check(vq->worker,
@@ -681,6 +722,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
* device wide flushes which doesn't use RCU for execution.
*/
mutex_lock(&old_worker->mutex);
+ if (old_worker->killed) {
+ mutex_unlock(&old_worker->mutex);
+ return;
+ }
+
/*
* We don't want to call synchronize_rcu for every vq during setup
* because it will slow down VM startup. If we haven't done
@@ -758,7 +804,7 @@ static int vhost_free_worker(struct vhost_dev *dev,
return -ENODEV;
mutex_lock(&worker->mutex);
- if (worker->attachment_cnt) {
+ if (worker->attachment_cnt || worker->killed) {
mutex_unlock(&worker->mutex);
return -EBUSY;
}
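
Across the vhost.c hunks the same guard repeats: every operation that takes worker->mutex now checks worker->killed first and backs off, so once vhost_worker_killed() has run, no flush, attach, or free can race with teardown. Schematically (illustrative only, condensing the guards added above):

    /* Illustrative only: the common shape of the worker->killed guards. */
    static int worker_op(struct vhost_worker *worker)
    {
            int ret = 0;

            mutex_lock(&worker->mutex);
            if (worker->killed) {
                    /* attach returns silently; free returns -EBUSY */
                    ret = -EBUSY;
                    goto unlock;
            }
            /* ... operate on a live worker ... */
    unlock:
            mutex_unlock(&worker->mutex);
            return ret;
    }
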
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 91ade037f08e..bb75a292d50c 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -28,12 +28,14 @@ struct vhost_work {
struct vhost_worker {
struct vhost_task *vtsk;
+ struct vhost_dev *dev;
/* Used to serialize device wide flushing with worker swapping. */
struct mutex mutex;
struct llist_head work_list;
u64 kcov_handle;
u32 id;
int attachment_cnt;
+ bool killed;
};
/* Poll a file (eventfd or socket) */
diff --git a/include/linux/sched/vhost_task.h b/include/linux/sched/vhost_task.h
index bc60243d43b3..25446c5d3508 100644
--- a/include/linux/sched/vhost_task.h
+++ b/include/linux/sched/vhost_task.h
@@ -4,7 +4,8 @@
struct vhost_task;
-struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
+struct vhost_task *vhost_task_create(bool (*fn)(void *),
+ void (*handle_sigkill)(void *), void *arg,
const char *name);
void vhost_task_start(struct vhost_task *vtsk);
void vhost_task_stop(struct vhost_task *vtsk);
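
With the new declaration a caller hands vhost_task_create() both a work function and a kill handler. A hedged sketch of the resulting lifecycle, modeled on the drivers/vhost/vhost.c hunk above (my_run_work, my_handle_kill, and start_worker are placeholders):

    #include <linux/sched/vhost_task.h>

    static bool my_run_work(void *data)
    {
            /* process queued work; return true if work was done so the
             * task loops again, false to let it sleep */
            return false;
    }

    static void my_handle_kill(void *data)
    {
            /* detach the worker and flush remaining work, as vhost does */
    }

    static struct vhost_task *start_worker(void *arg)
    {
            struct vhost_task *vtsk;

            vtsk = vhost_task_create(my_run_work, my_handle_kill, arg,
                                     "my-worker");
            if (!vtsk)
                    return NULL;
            vhost_task_start(vtsk);

            /* later, teardown calls vhost_task_stop(vtsk), unless the
             * task was already killed */
            return vtsk;
    }
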
diff --git a/kernel/vhost_task.c b/kernel/vhost_task.c
index da35e5b7f047..8800f5acc007 100644
--- a/kernel/vhost_task.c
+++ b/kernel/vhost_task.c
@@ -10,38 +10,32 @@
enum vhost_task_flags {
VHOST_TASK_FLAGS_STOP,
+ VHOST_TASK_FLAGS_KILLED,
};
struct vhost_task {
bool (*fn)(void *data);
+ void (*handle_sigkill)(void *data);
void *data;
struct completion exited;
unsigned long flags;
struct task_struct *task;
+ /* serialize SIGKILL and vhost_task_stop calls */
+ struct mutex exit_mutex;
};
static int vhost_task_fn(void *data)
{
struct vhost_task *vtsk = data;
- bool dead = false;
for (;;) {
bool did_work;
- if (!dead && signal_pending(current)) {
+ if (signal_pending(current)) {
struct ksignal ksig;
- /*
- * Calling get_signal will block in SIGSTOP,
- * or clear fatal_signal_pending, but remember
- * what was set.
- *
- * This thread won't actually exit until all
- * of the file descriptors are closed, and
- * the release function is called.
- */
- dead = get_signal(&ksig);
- if (dead)
- clear_thread_flag(TIF_SIGPENDING);
+
+ if (get_signal(&ksig))
+ break;
}
/* mb paired w/ vhost_task_stop */
@@ -57,7 +51,19 @@ static int vhost_task_fn(void *data)
schedule();
}
+ mutex_lock(&vtsk->exit_mutex);
+ /*
+ * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL.
+ * When the vhost layer has called vhost_task_stop it's already stopped
+ * new work and flushed.
+ */
+ if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
+ set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
+ vtsk->handle_sigkill(vtsk->data);
+ }
+ mutex_unlock(&vtsk->exit_mutex);
complete(&vtsk->exited);
+
do_exit(0);
}
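
Taken together with vhost_task_stop() below, exit_mutex turns the stop-vs-SIGKILL race into a first-writer-wins protocol: each side sets its flag and acts only if the other side's flag is not already set, all under the same mutex. Condensed from the two hunks:

    /* Condensed from this patch. Whichever side takes exit_mutex first
     * wins; the other sees the flag already set and does nothing.
     */

    /* killed path (vhost_task_fn, after get_signal() returns true): */
    mutex_lock(&vtsk->exit_mutex);
    if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
            set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
            vtsk->handle_sigkill(vtsk->data);	/* flush outstanding work */
    }
    mutex_unlock(&vtsk->exit_mutex);

    /* stop path (vhost_task_stop): */
    mutex_lock(&vtsk->exit_mutex);
    if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
            set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
            vhost_task_wake(vtsk);	/* the task exits via the STOP check */
    }
    mutex_unlock(&vtsk->exit_mutex);
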
@@ -78,12 +84,17 @@ EXPORT_SYMBOL_GPL(vhost_task_wake);
* @vtsk: vhost_task to stop
*
* vhost_task_fn ensures the worker thread exits after
- * VHOST_TASK_FLAGS_SOP becomes true.
+ * VHOST_TASK_FLAGS_STOP becomes true.
*/
void vhost_task_stop(struct vhost_task *vtsk)
{
- set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
- vhost_task_wake(vtsk);
+ mutex_lock(&vtsk->exit_mutex);
+ if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
+ set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
+ vhost_task_wake(vtsk);
+ }
+ mutex_unlock(&vtsk->exit_mutex);
+
/*
* Make sure vhost_task_fn is no longer accessing the vhost_task before
* freeing it below.
@@ -96,14 +107,16 @@ EXPORT_SYMBOL_GPL(vhost_task_stop);
/**
* vhost_task_create - create a copy of a task to be used by the kernel
* @fn: vhost worker function
- * @arg: data to be passed to fn
+ * @handle_sigkill: vhost function to call when the worker is killed
+ * @arg: data to be passed to fn and handle_sigkill
* @name: the thread's name
*
* This returns a specialized task for use by the vhost layer or NULL on
* failure. The returned task is inactive, and the caller must fire it up
* through vhost_task_start().
*/
-struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
+struct vhost_task *vhost_task_create(bool (*fn)(void *),
+ void (*handle_sigkill)(void *), void *arg,
const char *name)
{
struct kernel_clone_args args = {
@@ -122,8 +135,10 @@ struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
if (!vtsk)
return NULL;
init_completion(&vtsk->exited);
+ mutex_init(&vtsk->exit_mutex);
vtsk->data = arg;
vtsk->fn = fn;
+ vtsk->handle_sigkill = handle_sigkill;
args.fn_arg = vtsk;