// SPDX-License-Identifier: GPL-2.0 #include "comm.h" #include #include #include #include #include #include "rwsem.h" DECLARE_RC_STRUCT(comm_str) { refcount_t refcnt; char str[]; }; static struct comm_strs { struct rw_semaphore lock; struct comm_str **strs; int num_strs; int capacity; } _comm_strs; static void comm_strs__remove_if_last(struct comm_str *cs); static void comm_strs__init(void) { init_rwsem(&_comm_strs.lock); _comm_strs.capacity = 16; _comm_strs.num_strs = 0; _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs)); } static struct comm_strs *comm_strs__get(void) { static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT; pthread_once(&comm_strs_type_once, comm_strs__init); return &_comm_strs; } static refcount_t *comm_str__refcnt(struct comm_str *cs) { return &RC_CHK_ACCESS(cs)->refcnt; } static const char *comm_str__str(const struct comm_str *cs) { return &RC_CHK_ACCESS(cs)->str[0]; } static struct comm_str *comm_str__get(struct comm_str *cs) { struct comm_str *result; if (RC_CHK_GET(result, cs)) refcount_inc_not_zero(comm_str__refcnt(cs)); return result; } static void comm_str__put(struct comm_str *cs) { if (!cs) return; if (refcount_dec_and_test(comm_str__refcnt(cs))) { RC_CHK_FREE(cs); } else { if (refcount_read(comm_str__refcnt(cs)) == 1) comm_strs__remove_if_last(cs); RC_CHK_PUT(cs); } } static struct comm_str *comm_str__new(const char *str) { struct comm_str *result = NULL; RC_STRUCT(comm_str) *cs; cs = malloc(sizeof(*cs) + strlen(str) + 1); if (ADD_RC_CHK(result, cs)) { refcount_set(comm_str__refcnt(result), 1); strcpy(&cs->str[0], str); } return result; } static int comm_str__cmp(const void *_lhs, const void *_rhs) { const struct comm_str *lhs = *(const struct comm_str * const *)_lhs; const struct comm_str *rhs = *(const struct comm_str * const *)_rhs; return strcmp(comm_str__str(lhs), comm_str__str(rhs)); } static int comm_str__search(const void *_key, const void *_member) { const char *key = _key; const struct comm_str *member = *(const struct comm_str * const *)_member; return strcmp(key, comm_str__str(member)); } static void comm_strs__remove_if_last(struct comm_str *cs) { struct comm_strs *comm_strs = comm_strs__get(); down_write(&comm_strs->lock); /* * Are there only references from the array, if so remove the array * reference under the write lock so that we don't race with findnew. */ if (refcount_read(comm_str__refcnt(cs)) == 1) { struct comm_str **entry; entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), comm_str__search); comm_str__put(*entry); for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++) comm_strs->strs[i] = comm_strs->strs[i + 1]; comm_strs->num_strs--; } up_write(&comm_strs->lock); } static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str) { struct comm_str **result; result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), comm_str__search); if (!result) return NULL; return comm_str__get(*result); } static struct comm_str *comm_strs__findnew(const char *str) { struct comm_strs *comm_strs = comm_strs__get(); struct comm_str *result; if (!comm_strs) return NULL; down_read(&comm_strs->lock); result = __comm_strs__find(comm_strs, str); up_read(&comm_strs->lock); if (result) return result; down_write(&comm_strs->lock); result = __comm_strs__find(comm_strs, str); if (!result) { if (comm_strs->num_strs == comm_strs->capacity) { struct comm_str **tmp; tmp = reallocarray(comm_strs->strs, comm_strs->capacity + 16, sizeof(*comm_strs->strs)); if (!tmp) { up_write(&comm_strs->lock); return NULL; } comm_strs->strs = tmp; comm_strs->capacity += 16; } result = comm_str__new(str); if (result) { comm_strs->strs[comm_strs->num_strs++] = result; qsort(comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), comm_str__cmp); } } up_write(&comm_strs->lock); return comm_str__get(result); } struct comm *comm__new(const char *str, u64 timestamp, bool exec) { struct comm *comm = zalloc(sizeof(*comm)); if (!comm) return NULL; comm->start = timestamp; comm->exec = exec; comm->comm_str = comm_strs__findnew(str); if (!comm->comm_str) { free(comm); return NULL; } return comm; } int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) { struct comm_str *new, *old = comm->comm_str; new = comm_strs__findnew(str); if (!new) return -ENOMEM; comm_str__put(old); comm->comm_str = new; comm->start = timestamp; if (exec) comm->exec = true; return 0; } void comm__free(struct comm *comm) { comm_str__put(comm->comm_str); free(comm); } const char *comm__str(const struct comm *comm) { return comm_str__str(comm->comm_str); }