Class: Ignis::MathDx::GemmKernel

Inherits:
Object
Defined in:
lib/nvruby/mathdx/gemm_kernel.rb

Overview

Device-side GEMM kernel using cuBLASDx. Generates and compiles CUDA C++ code that uses cuBLASDx for thread-block-level GEMM.

cuBLASDx enables embedding GEMM operations inside CUDA kernels, so they can be fused with other operations (epilogs) to reduce global memory traffic.

Examples:

Basic GEMM kernel

kernel = GemmKernel.new(m: 64, n: 64, k: 64, dtype: :float16)
kernel.compile!
kernel.execute(a, b, c)
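
The epilogs mentioned in the overview can be pictured as element-wise functions applied to the GEMM result before it is written back. The following is a minimal host-side sketch, not the generated CUDA code: `epilog_fn_sketch` and `apply_epilog_sketch` are hypothetical names, and the exact GELU formula is an assumption (the generated kernel may use an approximation instead).

```ruby
# Returns a lambda for the named epilog (hypothetical helper, host-side only).
def epilog_fn_sketch(name)
  case name
  when nil      then ->(x) { x }                       # no epilog: identity
  when :relu    then ->(x) { x > 0 ? x : 0.0 }
  when :gelu    then ->(x) { 0.5 * x * (1.0 + Math.erf(x / Math.sqrt(2.0))) } # exact GELU (assumed)
  when :sigmoid then ->(x) { 1.0 / (1.0 + Math.exp(-x)) }
  when :tanh    then ->(x) { Math.tanh(x) }
  else raise ArgumentError, "unknown epilog: #{name.inspect}"
  end
end

# Applies the epilog to every element of a matrix stored as nested arrays.
def apply_epilog_sketch(matrix, name)
  fn = epilog_fn_sketch(name)
  matrix.map { |row| row.map { |x| fn.call(x) } }
end

apply_epilog_sketch([[-1.0, 2.0]], :relu)
```

In the fused kernel this step would run on values still held on-chip, which is where the saving in memory traffic comes from.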

Constructor Details

#initialize(m:, n:, k:, dtype: :float32, epilog: nil, block_size: 128) ⇒ GemmKernel

Initialize GEMM kernel configuration

Parameters:

  • m (Integer)

    Output rows (must be power of 2, <= 128)

  • n (Integer)

    Output columns (must be power of 2, <= 128)

  • k (Integer)

    Inner dimension (must be power of 2)

  • dtype (Symbol) (defaults to: :float32)

    Data type (:float16, :float32, :float64)

  • epilog (Symbol, nil) (defaults to: nil)

    Epilog (:relu, :gelu, :sigmoid, :tanh, nil)

  • block_size (Integer) (defaults to: 128)

    Threads per block



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 44

def initialize(m:, n:, k:, dtype: :float32, epilog: nil, block_size: 128)
  validate_dimensions!(m, n, k)
  validate_dtype!(dtype)
  validate_epilog!(epilog) if epilog

  @m = m
  @n = n
  @k = k
  @dtype = dtype
  @epilog = epilog
  @block_size = block_size
  @compiled = false
  @kernel = nil
end
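
The body of `validate_dimensions!` is not shown on this page. Based only on the documented constraints (m and n are powers of 2 no larger than 128, k is a power of 2), a check along these lines would be sufficient. This is a sketch under assumptions: the method names, the `max_tile` parameter, and the use of `ArgumentError` are hypothetical.

```ruby
# True when value is a positive Integer with exactly one bit set.
def power_of_two?(value)
  value.is_a?(Integer) && value > 0 && (value & (value - 1)).zero?
end

# Hypothetical stand-in for validate_dimensions!; the error class is an assumption.
def validate_dimensions_sketch!(m, n, k, max_tile: 128)
  { m: m, n: n, k: k }.each do |name, value|
    unless power_of_two?(value)
      raise ArgumentError, "#{name} must be a power of 2, got #{value.inspect}"
    end
  end
  { m: m, n: n }.each do |name, value|
    raise ArgumentError, "#{name} must be <= #{max_tile}, got #{value}" if value > max_tile
  end
  true
end

validate_dimensions_sketch!(64, 64, 256)
```

The bit trick `value & (value - 1)` clears the lowest set bit, so the result is zero only when a single bit was set.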

Instance Attribute Details

#block_size ⇒ Integer (readonly)

Returns the thread block size.

Returns:

  • (Integer)

    Thread block size



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 32

def block_size
  @block_size
end

#compiled ⇒ Boolean (readonly)

Returns whether the kernel is compiled.

Returns:

  • (Boolean)

    Whether kernel is compiled



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 35

def compiled
  @compiled
end

#dtype ⇒ Symbol (readonly)

Returns the data type.

Returns:

  • (Symbol)

    Data type



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 26

def dtype
  @dtype
end

#epilog ⇒ Symbol? (readonly)

Returns the epilog operation, or nil if none is set.

Returns:

  • (Symbol, nil)

    Epilog operation



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 29

def epilog
  @epilog
end

#k ⇒ Integer (readonly)

Returns the matrix K dimension.

Returns:

  • (Integer)

    Matrix K dimension



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 23

def k
  @k
end

#m ⇒ Integer (readonly)

Returns the matrix M dimension.

Returns:

  • (Integer)

    Matrix M dimension



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 17

def m
  @m
end

#n ⇒ Integer (readonly)

Returns the matrix N dimension.

Returns:

  • (Integer)

    Matrix N dimension



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 20

def n
  @n
end

Instance Method Details

#compile!(device_id: 0) ⇒ self

Compile the GEMM kernel

Parameters:

  • device_id (Integer) (defaults to: 0)

    Target GPU device

Returns:

  • (self)


# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 62

def compile!(device_id: 0)
  source = generate_source
  @kernel = Ignis::JIT::Compiler.compile(
    source,
    "cublasdx_gemm",
    device_id: device_id,
    options: nvrtc_options
  )
  @compiled = true
  self
end

#destroy! ⇒ void

This method returns an undefined value.

Release kernel resources



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 121

def destroy!
  @kernel = nil
  @compiled = false
end
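
The `compile!`, `execute`, and `destroy!` listings together imply a simple lifecycle: `execute` is only valid between `compile!` and `destroy!`, and `compile!` returns `self` so calls can be chained. The sketch below reproduces just that state handling with stand-in classes; `KernelLifecycleSketch` and `StateErrorSketch` are hypothetical and do not touch the GPU or the JIT compiler.

```ruby
# Hypothetical stand-in for the library's StateError.
class StateErrorSketch < StandardError; end

# Hypothetical stand-in that mirrors only the @compiled / @kernel bookkeeping.
class KernelLifecycleSketch
  attr_reader :compiled

  def initialize
    @compiled = false
    @kernel = nil
  end

  # Mirrors compile!: stores a handle, sets the flag, returns self for chaining.
  def compile!
    @kernel = :compiled_handle # placeholder for the JIT-compiled kernel
    @compiled = true
    self
  end

  # Mirrors the guard at the top of execute.
  def execute
    raise StateErrorSketch, "Kernel not compiled. Call compile! first." unless @compiled
    :launched
  end

  # Mirrors destroy!: drops the handle and resets the flag.
  def destroy!
    @kernel = nil
    @compiled = false
  end
end

kernel = KernelLifecycleSketch.new.compile! # chaining works because compile! returns self
kernel.execute
kernel.destroy!
```

After `destroy!`, a further `execute` raises until `compile!` is called again.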

#execute(a, b, c, alpha: 1.0, beta: 0.0, stream: nil) ⇒ NvArray

Execute the GEMM kernel: C = alpha * A @ B + beta * C
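
For checking results, the formula above can be written as a plain host-side reference on nested arrays. This is an illustrative sketch only: `gemm_reference` is a hypothetical helper, it returns a new matrix instead of updating C in place, and it ignores dtype rounding and epilogs.

```ruby
# Host-side reference: returns alpha * (A @ B) + beta * C for nested row-major arrays.
def gemm_reference(a, b, c, alpha: 1.0, beta: 0.0)
  m = a.length        # rows of A and C
  k = b.length        # inner dimension
  n = b.first.length  # columns of B and C
  raise ArgumentError, "inner dimensions differ" unless a.all? { |row| row.length == k }

  Array.new(m) do |i|
    Array.new(n) do |j|
      dot = (0...k).sum { |p| a[i][p] * b[p][j] }
      alpha * dot + beta * c[i][j]
    end
  end
end

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[1.0, 1.0], [1.0, 1.0]]
gemm_reference(a, b, c, alpha: 1.0, beta: 0.0)
```

With the default `beta: 0.0`, the existing contents of C do not contribute to the result.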

Parameters:

  • a (NvArray)

    Input matrix A (M x K)

  • b (NvArray)

    Input matrix B (K x N)

  • c (NvArray)

    Output matrix C (M x N)

  • alpha (Float) (defaults to: 1.0)

    Scaling factor for A @ B

  • beta (Float) (defaults to: 0.0)

    Scaling factor for C

  • stream (CUDA::Stream, nil) (defaults to: nil)

    CUDA stream

Returns:

  • (NvArray)

    Output matrix C on the device

Raises:

  • (StateError)

    If the kernel has not been compiled



# File 'lib/nvruby/mathdx/gemm_kernel.rb', line 82

def execute(a, b, c, alpha: 1.0, beta: 0.0, stream: nil)
  raise StateError, "Kernel not compiled. Call compile! first." unless @compiled

  validate_execution_inputs!(a, b, c)

  # Ensure arrays are on device
  a_dev = a.on_device? ? a : a.to_device
  b_dev = b.on_device? ? b : b.to_device
  c_dev = c.on_device? ? c : c.to_device

  # Calculate grid dimensions
  grid_m = (a_dev.shape[0] + @m - 1) / @m
  grid_n = (b_dev.shape[1] + @n - 1) / @n

  # Launch kernel
  @kernel.launch(
    grid: [grid_m, grid_n],
    block: [@block_size],
    args: [
      a_dev.device_ptr,
      b_dev.device_ptr,
      c_dev.device_ptr,
      a_dev.shape[0],         # M
      b_dev.shape[1],         # N
      a_dev.shape[1],         # K
      a_dev.shape[1],         # lda
      b_dev.shape[1],         # ldb
      c_dev.shape[1],         # ldc
      alpha,
      beta
    ],
    stream: stream
  )

  c_dev
end
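
The grid computation in `execute` is a ceiling division: one thread block per m x n tile of the output, rounded up so partial tiles are still covered. The leading dimensions passed to the kernel (`lda` = K, `ldb` = N, `ldc` = N) suggest row-major storage, although this page does not state the layout explicitly. The helpers below are hypothetical names used only to work through the arithmetic.

```ruby
# Ceiling division for positive integers, as used for grid_m and grid_n.
def ceil_div(size, tile)
  (size + tile - 1) / tile
end

# Number of thread blocks needed to cover a rows x cols output with tile_m x tile_n tiles.
def grid_dims_sketch(rows, cols, tile_m:, tile_n:)
  [ceil_div(rows, tile_m), ceil_div(cols, tile_n)]
end

# Flat offset of element (row, col) in a row-major matrix with leading dimension ld
# (row-major layout is an assumption here).
def row_major_offset(row, col, ld)
  row * ld + col
end

grid_dims_sketch(200, 130, tile_m: 64, tile_n: 64) # => [4, 3]
row_major_offset(2, 3, 10)                         # => 23
```

A 200 x 130 output with 64 x 64 tiles needs 4 x 3 blocks; the last block in each direction covers a partial tile, so the kernel has to guard against out-of-range indices there.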