Class: Ignis::MathDx::GemmKernel
- Inherits:
- Object
- Defined in:
- lib/nvruby/mathdx/gemm_kernel.rb
Overview
Device-side GEMM kernel using cuBLASDx. Generates and compiles CUDA C++ code that uses cuBLASDx for thread-block GEMM.
cuBLASDx allows GEMM operations to be embedded inside CUDA kernels, so they can be fused with other operations (epilogs) to reduce memory traffic.
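To illustrate what fusing an epilog means, the following CPU-only Ruby sketch computes a small GEMM and applies an element-wise epilog to each accumulator before it is stored, so the unfused intermediate matrix is never written out. It is not the generated CUDA code: `fused_gemm`, the `EPILOGS` table, and the `:relu` epilog name are hypothetical, chosen only for illustration.

```ruby
# CPU reference sketch (not the generated CUDA kernel).
# `fused_gemm`, EPILOGS, and the :relu epilog are hypothetical names.
EPILOGS = {
  nil   => ->(x) { x },
  :relu => ->(x) { x.negative? ? 0.0 : x }
}.freeze

def fused_gemm(a, b, epilog: nil)
  op = EPILOGS.fetch(epilog)
  rows = a.length
  cols = b.first.length
  inner = b.length

  Array.new(rows) do |i|
    Array.new(cols) do |j|
      acc = 0.0
      inner.times { |p| acc += a[i][p] * b[p][j] }
      op.call(acc) # epilog applied before the single store of C[i][j]
    end
  end
end

a = [[1.0, -2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]
plain = fused_gemm(a, b)                # no epilog
fused = fused_gemm(a, b, epilog: :relu) # negative entries clamped to 0.0
```

In an unfused pipeline the epilog would be a second pass that reads and rewrites every element of C; applying it to the accumulator avoids that extra round trip.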
Instance Attribute Summary
-
#block_size ⇒ Integer
readonly
Thread block size.
-
#compiled ⇒ Boolean
readonly
Whether kernel is compiled.
-
#dtype ⇒ Symbol
readonly
Data type.
-
#epilog ⇒ Symbol?
readonly
Epilog operation.
-
#k ⇒ Integer
readonly
Matrix K dimension.
-
#m ⇒ Integer
readonly
Matrix M dimension.
-
#n ⇒ Integer
readonly
Matrix N dimension.
Instance Method Summary
-
#compile!(device_id: 0) ⇒ self
Compile the GEMM kernel.
-
#destroy! ⇒ void
Release kernel resources.
-
#execute(a, b, c, alpha: 1.0, beta: 0.0, stream: nil) ⇒ NvArray
Execute the GEMM kernel: C = alpha * A @ B + beta * C.
-
#initialize(m:, n:, k:, dtype: :float32, epilog: nil, block_size: 128) ⇒ GemmKernel
constructor
Initialize GEMM kernel configuration.
Constructor Details
#initialize(m:, n:, k:, dtype: :float32, epilog: nil, block_size: 128) ⇒ GemmKernel
Initialize GEMM kernel configuration.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 44
    def initialize(m:, n:, k:, dtype: :float32, epilog: nil, block_size: 128)
      validate_dimensions!(m, n, k)
      validate_dtype!(dtype)
      validate_epilog!(epilog) if epilog

      @m = m
      @n = n
      @k = k
      @dtype = dtype
      @epilog = epilog
      @block_size = block_size
      @compiled = false
      @kernel = nil
    end
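The `validate_*!` helpers called by the constructor are private and not shown in this listing. The sketch below is only a guess at the kind of checks they might perform: the module name `GemmConfigChecks`, the accepted dtype list, and the use of `ArgumentError` are all assumptions, not the library's actual behavior.

```ruby
# Hypothetical stand-ins for the private validators; the module name,
# accepted dtypes, and error class are assumptions made for illustration.
module GemmConfigChecks
  ASSUMED_DTYPES = %i[float16 float32 float64].freeze

  module_function

  # Rejects any dimension that is not a positive Integer.
  def validate_dimensions!(m, n, k)
    { m: m, n: n, k: k }.each do |name, value|
      next if value.is_a?(Integer) && value.positive?

      raise ArgumentError, "#{name} must be a positive Integer, got #{value.inspect}"
    end
  end

  # Rejects any dtype symbol outside the assumed list.
  def validate_dtype!(dtype)
    return if ASSUMED_DTYPES.include?(dtype)

    raise ArgumentError, "unsupported dtype: #{dtype.inspect}"
  end
end

GemmConfigChecks.validate_dimensions!(64, 64, 32) # passes silently
GemmConfigChecks.validate_dtype!(:float32)        # passes silently
```

Validating in the constructor means a misconfigured kernel fails before any source is generated or compiled.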
Instance Attribute Details
#block_size ⇒ Integer (readonly)
Returns the thread block size.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 32
    def block_size
      @block_size
    end
#compiled ⇒ Boolean (readonly)
Returns whether the kernel is compiled.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 35
    def compiled
      @compiled
    end
#dtype ⇒ Symbol (readonly)
Returns the data type.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 26
    def dtype
      @dtype
    end
#epilog ⇒ Symbol? (readonly)
Returns the epilog operation, or nil if none is set.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 29
    def epilog
      @epilog
    end
#k ⇒ Integer (readonly)
Returns the matrix K dimension.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 23
    def k
      @k
    end
#m ⇒ Integer (readonly)
Returns the matrix M dimension.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 17
    def m
      @m
    end
#n ⇒ Integer (readonly)
Returns the matrix N dimension.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 20
    def n
      @n
    end
Instance Method Details
#compile!(device_id: 0) ⇒ self
Compile the GEMM kernel.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 62
    def compile!(device_id: 0)
      source = generate_source
      @kernel = Ignis::JIT::Compiler.compile(
        source,
        "cublasdx_gemm",
        device_id: device_id,
        options:
      )
      @compiled = true
      self
    end
#destroy! ⇒ void
This method returns an undefined value.
Release kernel resources.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 121
    def destroy!
      @kernel = nil
      @compiled = false
    end
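The compiled flag drives a simple lifecycle: `compile!` sets it, `execute` requires it, and `destroy!` clears it. The stand-in class below mimics only that state handling so it can run without a GPU. `LifecycleSketch` and its nested `StateError` are hypothetical, and the class does no code generation, compilation, or launching.

```ruby
# Hypothetical stand-in that mirrors only the compiled-state handling
# documented here; it does not generate, compile, or launch anything.
class LifecycleSketch
  class StateError < StandardError; end

  attr_reader :compiled

  def initialize
    @compiled = false
    @kernel = nil
  end

  def compile!
    @kernel = :fake_kernel_handle # placeholder for the JIT-compiled kernel
    @compiled = true
    self # returning self allows chaining, as the documented compile! does
  end

  def execute
    raise StateError, "Kernel not compiled. Call compile! first." unless @compiled

    :launched
  end

  def destroy!
    @kernel = nil
    @compiled = false
  end
end

kernel = LifecycleSketch.new.compile!
kernel.execute  # allowed: the kernel is compiled
kernel.destroy!
kernel.compiled # false again; execute would now raise StateError
```

Because `destroy!` resets the flag rather than invalidating the object, the same instance can presumably be compiled again afterwards.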
#execute(a, b, c, alpha: 1.0, beta: 0.0, stream: nil) ⇒ NvArray
Execute the GEMM kernel: C = alpha * A @ B + beta * C.
    # File 'lib/nvruby/mathdx/gemm_kernel.rb', line 82
    def execute(a, b, c, alpha: 1.0, beta: 0.0, stream: nil)
      raise StateError, "Kernel not compiled. Call compile! first." unless @compiled

      validate_execution_inputs!(a, b, c)

      # Ensure arrays are on device
      a_dev = a.on_device? ? a : a.to_device
      b_dev = b.on_device? ? b : b.to_device
      c_dev = c.on_device? ? c : c.to_device

      # Calculate grid dimensions
      grid_m = (a_dev.shape[0] + @m - 1) / @m
      grid_n = (b_dev.shape[1] + @n - 1) / @n

      # Launch kernel
      @kernel.launch(
        grid: [grid_m, grid_n],
        block: [@block_size],
        args: [
          a_dev.device_ptr,
          b_dev.device_ptr,
          c_dev.device_ptr,
          a_dev.shape[0], # M
          b_dev.shape[1], # N
          a_dev.shape[1], # K
          a_dev.shape[1], # lda
          b_dev.shape[1], # ldb
          c_dev.shape[1], # ldc
          alpha,
          beta
        ],
        stream: stream
      )

      c_dev
    end
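Two details of `execute` are easy to check on the CPU: the ceiling division that sizes the launch grid, and the `alpha`/`beta` semantics of the result. The helpers below (`grid_dims`, `reference_gemm`) are hypothetical, operate on nested Ruby arrays rather than `NvArray`, and are meant only as a reference for testing small cases.

```ruby
# Hypothetical CPU helpers; plain nested arrays stand in for NvArray.

# Number of tiles needed to cover a rows x cols output with tile_m x tile_n
# tiles, using the same integer ceiling division as execute.
def grid_dims(rows, cols, tile_m, tile_n)
  [(rows + tile_m - 1) / tile_m, (cols + tile_n - 1) / tile_n]
end

# Reference for C = alpha * A @ B + beta * C on row-major nested arrays.
def reference_gemm(a, b, c, alpha: 1.0, beta: 0.0)
  a.each_index.map do |i|
    b.first.each_index.map do |j|
      dot = a[i].each_index.sum { |p| a[i][p] * b[p][j] }
      alpha * dot + beta * c[i][j]
    end
  end
end

grid_dims(100, 130, 64, 64) # => [2, 3]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[1.0, 1.0], [1.0, 1.0]]
reference_gemm(a, b, c, alpha: 2.0, beta: 0.5)
# => [[38.5, 44.5], [86.5, 100.5]]
```

With the default `beta: 0.0` the prior contents of C do not contribute to the result, which matches the documented defaults of `execute`.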