summaryrefslogtreecommitdiff
path: root/drivers/misc
diff options
context:
space:
mode:
authorOfir Bitton <obitton@habana.ai>2020-07-30 14:56:38 +0300
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>2020-09-09 20:14:08 +0300
commitfb8b4592f9585a9f50683ad3423c74a21c9f0991 (patch)
tree7697bcdde084e8a00f279e7c5cd1c5202edb3b55 /drivers/misc
parentb18c6071868cef30e178bdae8f4479e162ec0747 (diff)
downloadlinux-fb8b4592f9585a9f50683ad3423c74a21c9f0991.tar.xz
habanalabs: validate packet id during CB parse
[ Upstream commit bc75be24fa88ef10eecaff2b2a9ada8189e5ab5d ] During command buffer parsing, driver extracts packet id from user buffer. Driver must validate this packet id, since it is being used in order to extract information from internal structures. Signed-off-by: Ofir Bitton <obitton@habana.ai> Reviewed-by: Oded Gabbay <oded.gabbay@gmail.com> Signed-off-by: Oded Gabbay <oded.gabbay@gmail.com> Signed-off-by: Sasha Levin <sashal@kernel.org>
Diffstat (limited to 'drivers/misc')
-rw-r--r--drivers/misc/habanalabs/gaudi/gaudi.c35
-rw-r--r--drivers/misc/habanalabs/goya/goya.c31
2 files changed, 66 insertions, 0 deletions
diff --git a/drivers/misc/habanalabs/gaudi/gaudi.c b/drivers/misc/habanalabs/gaudi/gaudi.c
index 637a9d608707..0261f60df563 100644
--- a/drivers/misc/habanalabs/gaudi/gaudi.c
+++ b/drivers/misc/habanalabs/gaudi/gaudi.c
@@ -154,6 +154,29 @@ static const u16 gaudi_packet_sizes[MAX_PACKET_ID] = {
[PACKET_LOAD_AND_EXE] = sizeof(struct packet_load_and_exe)
};
+static inline bool validate_packet_id(enum packet_id id)
+{
+ switch (id) {
+ case PACKET_WREG_32:
+ case PACKET_WREG_BULK:
+ case PACKET_MSG_LONG:
+ case PACKET_MSG_SHORT:
+ case PACKET_CP_DMA:
+ case PACKET_REPEAT:
+ case PACKET_MSG_PROT:
+ case PACKET_FENCE:
+ case PACKET_LIN_DMA:
+ case PACKET_NOP:
+ case PACKET_STOP:
+ case PACKET_ARB_POINT:
+ case PACKET_WAIT:
+ case PACKET_LOAD_AND_EXE:
+ return true;
+ default:
+ return false;
+ }
+}
+
static const char * const
gaudi_tpc_interrupts_cause[GAUDI_NUM_OF_TPC_INTR_CAUSE] = {
"tpc_address_exceed_slm",
@@ -3859,6 +3882,12 @@ static int gaudi_validate_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);
+ if (!validate_packet_id(pkt_id)) {
+ dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
+ rc = -EINVAL;
+ break;
+ }
+
pkt_size = gaudi_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
@@ -4082,6 +4111,12 @@ static int gaudi_patch_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);
+ if (!validate_packet_id(pkt_id)) {
+ dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
+ rc = -EINVAL;
+ break;
+ }
+
pkt_size = gaudi_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
diff --git a/drivers/misc/habanalabs/goya/goya.c b/drivers/misc/habanalabs/goya/goya.c
index 88460b2138d8..c179085ced7b 100644
--- a/drivers/misc/habanalabs/goya/goya.c
+++ b/drivers/misc/habanalabs/goya/goya.c
@@ -139,6 +139,25 @@ static u16 goya_packet_sizes[MAX_PACKET_ID] = {
[PACKET_STOP] = sizeof(struct packet_stop)
};
+static inline bool validate_packet_id(enum packet_id id)
+{
+ switch (id) {
+ case PACKET_WREG_32:
+ case PACKET_WREG_BULK:
+ case PACKET_MSG_LONG:
+ case PACKET_MSG_SHORT:
+ case PACKET_CP_DMA:
+ case PACKET_MSG_PROT:
+ case PACKET_FENCE:
+ case PACKET_LIN_DMA:
+ case PACKET_NOP:
+ case PACKET_STOP:
+ return true;
+ default:
+ return false;
+ }
+}
+
static u64 goya_mmu_regs[GOYA_MMU_REGS_NUM] = {
mmDMA_QM_0_GLBL_NON_SECURE_PROPS,
mmDMA_QM_1_GLBL_NON_SECURE_PROPS,
@@ -3381,6 +3400,12 @@ static int goya_validate_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);
+ if (!validate_packet_id(pkt_id)) {
+ dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
+ rc = -EINVAL;
+ break;
+ }
+
pkt_size = goya_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
@@ -3616,6 +3641,12 @@ static int goya_patch_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);
+ if (!validate_packet_id(pkt_id)) {
+ dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
+ rc = -EINVAL;
+ break;
+ }
+
pkt_size = goya_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {