Module: Ignis::LinAlg::OptimizedMatmul

Defined in:
lib/nvruby/linalg/optimized_matmul.rb

Overview

Optimized matrix multiplication using cuBLASLt

Features:

  • Heuristic-based algorithm selection

  • Large workspace for more algorithm choices

  • Descriptor-based API for optimal performance

  • Split-K support for better SM utilization

Class Method Summary

Class Method Details

.call(a, b, c: nil, alpha: 1.0, beta: 0.0, transpose_a: false, transpose_b: false, workspace_size: 256 * 1024 * 1024, stream: nil) ⇒ NvArray

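Perform optimized matrix multiplication. The auto-tuned cuBLASLt path is currently disabled; this method delegates to the numerically verified cuBLAS GEMM path (Matmul.call).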

Parameters:

  • a (NvArray)

    Left matrix

  • b (NvArray)

    Right matrix

  • c (NvArray, nil) (defaults to: nil)

    Output matrix (created if nil)

  • alpha (Float) (defaults to: 1.0)

    Scaling factor for A @ B

  • beta (Float) (defaults to: 0.0)

    Scaling factor for C

  • transpose_a (Boolean) (defaults to: false)

    Transpose A

  • transpose_b (Boolean) (defaults to: false)

    Transpose B

  • workspace_size (Integer) (defaults to: 256 * 1024 * 1024)

    Workspace size in bytes (currently unused: the call is delegated to Matmul.call, which takes no workspace argument)

  • stream (CUDA::Stream, nil) (defaults to: nil)

    CUDA stream

Returns:

  • (NvArray)

    Result matrix



29
30
31
32
33
34
35
36
37
38
39
40
41
# File 'lib/nvruby/linalg/optimized_matmul.rb', line 29

# Multiply two matrices and return the result matrix.
#
# @param a [NvArray] left matrix
# @param b [NvArray] right matrix
# @param c [NvArray, nil] output matrix, forwarded as-is to Matmul.call
# @param alpha [Float] scaling factor for A @ B
# @param beta [Float] scaling factor for C
# @param transpose_a [Boolean] transpose A
# @param transpose_b [Boolean] transpose B
# @param workspace_size [Integer] accepted for interface compatibility; not used
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] whatever Matmul.call returns
def call(a, b, c: nil, alpha: 1.0, beta: 0.0,
         transpose_a: false, transpose_b: false,
         workspace_size: 256 * 1024 * 1024, stream: nil)
  # Why the cuBLASLt path (execute_cublaslt) is bypassed: it describes
  # Ignis's ROW-major buffers with COLUMN-major layouts and never sets
  # CUBLASLT_ORDER_ROW, so it produced a transposed/incorrect product. The
  # bug stayed hidden because its benchmark used all-zero inputs. Until the
  # cuBLASLt layouts are fixed and verified, the work goes to the cuBLAS GEMM
  # path, which is numerically verified by benchmarks/verify_matmul.rb.

  # Kept in the signature for callers; the delegated path has no workspace.
  _ = workspace_size

  gemm_options = {
    c: c, alpha: alpha, beta: beta,
    transpose_a: transpose_a, transpose_b: transpose_b,
    stream: stream
  }
  Matmul.call(a, b, **gemm_options)
end

.call_with_algorithm(a, b, algo_index:, c: nil, alpha: 1.0, beta: 0.0, transpose_a: false, transpose_b: false, stream: nil) ⇒ NvArray

Perform matrix multiplication with a specific algorithm

Parameters:

  • a (NvArray)

    Left matrix

  • b (NvArray)

    Right matrix

  • algo_index (Integer)

    Algorithm index from heuristic results

  • c (NvArray, nil) (defaults to: nil)

    Output matrix (created if nil)

  • alpha (Float) (defaults to: 1.0)

    Scaling factor for A @ B

  • beta (Float) (defaults to: 0.0)

    Scaling factor for C

  • transpose_a (Boolean) (defaults to: false)

    Transpose A

  • transpose_b (Boolean) (defaults to: false)

    Transpose B

  • stream (CUDA::Stream, nil) (defaults to: nil)

    CUDA stream

Returns:

  • (NvArray)

    Result matrix

Raises:

  • (DimensionError)

    If the inner (K) dimensions of A and B do not match



55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/nvruby/linalg/optimized_matmul.rb', line 55

# Multiply two matrices through cuBLASLt using a caller-chosen algorithm.
#
# @param a [NvArray] left matrix
# @param b [NvArray] right matrix
# @param algo_index [Integer] algorithm index, passed through to
#   execute_cublaslt_with_algo
# @param c [NvArray, nil] output matrix; allocated via prepare_output when nil
# @param alpha [Float] scaling factor passed to the cuBLASLt execution
# @param beta [Float] scaling factor for C
# @param transpose_a [Boolean] transpose A
# @param transpose_b [Boolean] transpose B
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] the output matrix (the given c, or the freshly prepared one)
# @raise [DimensionError] when the inner dimensions of a and b disagree
#
# NOTE(review): this path still executes through cuBLASLt. Confirm that
# execute_cublaslt_with_algo handles row-major buffers correctly before
# relying on its results.
def call_with_algorithm(a, b, algo_index:, c: nil, alpha: 1.0, beta: 0.0,
                        transpose_a: false, transpose_b: false, stream: nil)
  validate_inputs!(a, b)

  rows, inner_a, inner_b, cols = compute_dimensions(a, b, transpose_a, transpose_b)
  if inner_a != inner_b
    raise DimensionError, "K dimensions mismatch: #{inner_a} vs #{inner_b}"
  end

  element_type = a.dtype

  # Only a nil output triggers allocation; any other value is used as given.
  output = c.nil? ? prepare_output(c, rows, cols, element_type, a.device_index) : c

  CuBLASLtBindings.ensure_loaded!

  execute_cublaslt_with_algo(
    a, b, output,
    rows, cols, inner_a,
    alpha, beta,
    transpose_a, transpose_b,
    element_type, algo_index, stream
  )

  output
end

.get_algorithms(m, n, k, dtype: :float32, workspace_size: 256 * 1024 * 1024) ⇒ Array<Hash>

Get available algorithms for a specific matmul configuration

Parameters:

  • m (Integer)

    Number of rows in A

  • n (Integer)

    Number of columns in B

  • k (Integer)

    Inner dimension

  • dtype (Symbol) (defaults to: :float32)

    Data type

  • workspace_size (Integer) (defaults to: 256 * 1024 * 1024)

    Max workspace size

Returns:

  • (Array<Hash>)

    Array of algorithm info hashes



88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# File 'lib/nvruby/linalg/optimized_matmul.rb', line 88

# List the algorithms the cuBLASLt heuristic proposes for an m x k by k x n
# multiplication.
#
# @param m [Integer] number of rows in A
# @param n [Integer] number of columns in B
# @param k [Integer] inner dimension
# @param dtype [Symbol] data type, mapped to cuBLASLt types by CuBLASLtBindings
# @param workspace_size [Integer] value handed to create_preference
# @return [Array<Hash>] one hash per heuristic result with the keys :index,
#   :workspace_size, :status, :waves_count and :algo_data
#
# NOTE(review): the descriptors are created before the begin/ensure, so if a
# later create_* helper raises, the ones already created are not passed to
# cleanup_descriptors — confirm whether those helpers can fail.
def get_algorithms(m, n, k, dtype: :float32, workspace_size: 256 * 1024 * 1024)
  bindings = CuBLASLtBindings
  bindings.ensure_loaded!

  handle = bindings.get_handle
  element_type = bindings.dtype_to_cuda_type(dtype)
  compute_kind = bindings.compute_type_for_dtype(dtype)
  scale_kind = bindings.scale_type_for_dtype(dtype)

  # The scale type governs how alpha/beta are interpreted.
  op_desc = create_matmul_desc(compute_kind, scale_kind)
  # Last argument equals the row count in each layout; presumably the
  # leading dimension of a column-major layout — TODO confirm against
  # create_matrix_layout.
  a_layout = create_matrix_layout(element_type, m, k, m)
  b_layout = create_matrix_layout(element_type, k, n, k)
  c_layout = create_matrix_layout(element_type, m, n, m)

  search_pref = create_preference(workspace_size)

  begin
    heuristic_struct = bindings::MatmulHeuristicResult
    capacity = 32
    results_buffer = FFI::MemoryPointer.new(heuristic_struct, capacity)
    found_ptr = FFI::MemoryPointer.new(:int)

    # c_layout is passed twice: once as the C input layout and once as the
    # output (D) layout.
    rc = bindings.cublasLtMatmulAlgoGetHeuristic(
      handle,
      op_desc,
      a_layout, b_layout, c_layout, c_layout,
      search_pref,
      capacity,
      results_buffer,
      found_ptr
    )
    bindings.check_status!(rc, "cublasLtMatmulAlgoGetHeuristic")

    found_ptr.read_int.times.map do |slot|
      # Wrap the slot-th struct inside the contiguous results buffer.
      entry = heuristic_struct.new(results_buffer + slot * heuristic_struct.size)

      {
        index: slot,
        workspace_size: entry[:workspaceSize],
        status: entry[:state],
        waves_count: entry[:wavesCount],
        algo_data: entry[:algo].to_a
      }
    end
  ensure
    cleanup_descriptors(op_desc, a_layout, b_layout, c_layout, search_pref)
  end
end