Skip to content

Commit

Permalink
[CUDA EP] Implement Add/Sub/Mul/Div element wise operations for (u)in…
Browse files Browse the repository at this point in the history
…t8 and (u)int16
  • Loading branch information
Zyrin committed Dec 18, 2024
1 parent 9f52833 commit 92d0502
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
32 changes: 32 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1218,29 +1218,45 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, C
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div);
Expand Down Expand Up @@ -2183,29 +2199,45 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div)>,
Expand Down
22 changes: 13 additions & 9 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
BINARY_OP_TYPED(name, ver, double) \
BINARY_OP_TYPED(name, ver, BFloat16)

#define BINARY_OP_UZILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
#define BINARY_OP_BWUZCSILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint8_t) \
BINARY_OP_TYPED(name, ver, uint16_t) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int8_t) \
BINARY_OP_TYPED(name, ver, int16_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
BINARY_OP_HFD(name, ver)

#define BINARY_OP_REGISTER_VERSIONED_OIL(name, startver, endver) \
Expand Down Expand Up @@ -279,10 +283,10 @@ BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Sub, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Mul, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Div, 13, 13)

BINARY_OP_UZILHFD(Add, 14)
BINARY_OP_UZILHFD(Sub, 14)
BINARY_OP_UZILHFD(Mul, 14)
BINARY_OP_UZILHFD(Div, 14)
BINARY_OP_BWUZCSILHFD(Add, 14)
BINARY_OP_BWUZCSILHFD(Sub, 14)
BINARY_OP_BWUZCSILHFD(Mul, 14)
BINARY_OP_BWUZCSILHFD(Div, 14)

BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ namespace cuda {
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint8_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint16_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int8_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int16_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
Expand Down Expand Up @@ -135,11 +149,11 @@ BINARY_OPS()
// D: double
// O: bool

SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Pow_7)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(And, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
Expand Down

0 comments on commit 92d0502

Please sign in to comment.