Module: Ignis::JIT::Kernels::Attention

Defined in:
lib/nvruby/jit/kernels/attention.rb

Overview

Attention and softmax CUDA kernels for transformer models. Includes numerically stable softmax, top-k/top-p filtering.

Class Method Summary collapse

Class Method Details

.attention_score ⇒ Ignis::JIT::Kernel

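Scaled dot-product attention score post-processing: multiplies precomputed scores (Q @ K^T) by `scale` (typically 1/sqrt(d_k)). With the optional causal mask enabled, entries above the diagonal (col > row) are set to -1e9, a large negative stand-in for -inf.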

Returns:



263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# File 'lib/nvruby/jit/kernels/attention.rb', line 263

# Builds the JIT kernel that post-processes an attention score matrix:
# every element is scaled and, optionally, a causal mask is applied.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def attention_score
  source = <<~CUDA
    // masked_scores[i] = scores[i] * scale over a row-major
    // [seq_len, seq_len] matrix. When use_causal_mask is non-zero, entries
    // above the diagonal (col > row) are replaced with -1e9f.
    //
    // Launch: 1-D grid and block of any size. The grid-stride loop covers
    // every element even when the grid is smaller than the matrix.
    extern "C" __global__
    void attention_score(const float* __restrict__ scores,
                         float* __restrict__ masked_scores,
                         const float scale,
                         const int seq_len,
                         const int use_causal_mask) {
      if (seq_len <= 0) return;

      // 64-bit element indices: seq_len * seq_len no longer fits in a
      // 32-bit int once seq_len exceeds 46340.
      const long long n = seq_len;
      const long long total = n * n;
      const long long stride = (long long)gridDim.x * blockDim.x;
      long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;

      for (; idx < total; idx += stride) {
        float val = scores[idx] * scale;

        // Causal mask: future positions (col > row) get a large negative
        // value rather than zero, so that exp() of them underflows to 0 in
        // a subsequent softmax.
        if (use_causal_mask) {
          const long long row = idx / n;
          const long long col = idx - row * n;
          if (col > row) {
            val = -1e9f;
          }
        }

        masked_scores[idx] = val;
      }
    }
  CUDA
  compile_cached(source, "attention_score")
end

.flash_attention_backward ⇒ Ignis::JIT::Kernel

Flash Attention 2 backward. Recomputes attention weights on-the-fly during backward (memory efficient).

Returns:



388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
# File 'lib/nvruby/jit/kernels/attention.rb', line 388

# Builds the JIT kernel for the attention backward pass that recomputes the
# softmax weights from Q and K instead of reading a stored attention matrix.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def flash_attention_backward
  source = <<~CUDA
    #define HEAD_DIM_MAX 128

    // Scaled dot product between a query row held by the calling thread and
    // one key row in global memory, over the first hd features.
    static __device__ __forceinline__ float scaled_qk_dot(
        const float* q_row,
        const float* __restrict__ k_row,
        const int hd,
        const float scale) {
      float score = 0.0f;
      for (int d = 0; d < hd; d++) {
        score += q_row[d] * k_row[d];
      }
      return score * scale;
    }

    // Attention backward for row-major [seq_len, head_dim] Q, K, V, O, dO.
    //
    // Launch: 1-D, one thread per query row (threads with
    // q_idx >= seq_len exit immediately).
    //
    // Outputs:
    //   dQ - overwritten, one row per thread.
    //   dK, dV - accumulated with atomicAdd from every query thread, so the
    //            caller must zero them before the launch.
    //
    // Only the first min(head_dim, HEAD_DIM_MAX) features of each row are
    // read or written; head_dim is still used as the row stride.
    extern "C" __global__
    void flash_attention_backward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        const float* __restrict__ O,
        const float* __restrict__ dO,
        float* __restrict__ dQ,
        float* __restrict__ dK,
        float* __restrict__ dV,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {

      const int q_idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (q_idx >= seq_len) return;

      // Clamp once instead of re-testing both bounds in every loop.
      const int hd = min(head_dim, HEAD_DIM_MAX);
      const size_t q_off = (size_t)q_idx * head_dim;

      // Load the Q and dO rows, clear the dQ accumulator and compute
      // D_i = sum_d(dO_i[d] * O_i[d]). O is read exactly once, so it is
      // consumed straight from global memory instead of a local copy.
      float q_row[HEAD_DIM_MAX], do_row[HEAD_DIM_MAX], dq_acc[HEAD_DIM_MAX];
      float D_i = 0.0f;
      for (int d = 0; d < hd; d++) {
        q_row[d] = Q[q_off + d];
        do_row[d] = dO[q_off + d];
        dq_acc[d] = 0.0f;
        D_i += do_row[d] * O[q_off + d];
      }

      // With the causal mask, keys after q_idx contribute nothing, so the
      // key loops simply stop there.
      const int k_end = use_causal_mask ? min(seq_len, q_idx + 1) : seq_len;

      // Pass 1: row maximum of the scaled scores.
      float row_max = -1e20f;
      for (int k_idx = 0; k_idx < k_end; k_idx++) {
        const float score =
            scaled_qk_dot(q_row, K + (size_t)k_idx * head_dim, hd, scale);
        row_max = fmaxf(row_max, score);
      }

      // Pass 2: softmax denominator.
      float row_sum = 0.0f;
      for (int k_idx = 0; k_idx < k_end; k_idx++) {
        const float score =
            scaled_qk_dot(q_row, K + (size_t)k_idx * head_dim, hd, scale);
        row_sum += expf(score - row_max);
      }

      // Pass 3: gradients.
      for (int k_idx = 0; k_idx < k_end; k_idx++) {
        const size_t kv_off = (size_t)k_idx * head_dim;
        const float* k_row = K + kv_off;
        const float* v_row = V + kv_off;
        float* dk_row = dK + kv_off;
        float* dv_row = dV + kv_off;

        const float score = scaled_qk_dot(q_row, k_row, hd, scale);
        const float p_ij = expf(score - row_max) / row_sum;

        // dV[k] += p_ij * dO_i   and   dP = dO_i . V[k]
        float dP = 0.0f;
        for (int d = 0; d < hd; d++) {
          atomicAdd(&dv_row[d], p_ij * do_row[d]);
          dP += do_row[d] * v_row[d];
        }

        // dS = p_ij * (dP - D_i) * scale
        const float dS = p_ij * (dP - D_i) * scale;

        // dQ_i += dS * K[k]   and   dK[k] += dS * Q_i
        for (int d = 0; d < hd; d++) {
          dq_acc[d] += dS * k_row[d];
          atomicAdd(&dk_row[d], dS * q_row[d]);
        }
      }

      // Write dQ
      for (int d = 0; d < hd; d++) {
        dQ[q_off + d] = dq_acc[d];
      }
    }
  CUDA
  compile_cached(source, "flash_attention_backward")
end

.flash_attention_forward ⇒ Ignis::JIT::Kernel

Flash Attention 2 forward (Dao et al. 2023). Tiled Q/K/V processing — avoids materializing full N×N attention matrix. O(N) memory vs O(N²) for standard attention. Uses online softmax (streaming max + sum) for numerical stability.

Returns:



296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# File 'lib/nvruby/jit/kernels/attention.rb', line 296

# Builds the JIT kernel for the attention forward pass that uses a streaming
# (online) softmax, so no [seq_len, seq_len] score matrix is stored.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def flash_attention_forward
  source = <<~CUDA
    #define TILE_SIZE 64
    #define HEAD_DIM_MAX 128

    // Computes one output row O[q_idx] = softmax(scale * Q[q_idx] K^T) V
    // with a running max / running sum, visiting keys in ascending order.
    // hd is the number of features actually processed; head_dim is the row
    // stride of Q, K, V and O.
    static __device__ void flash_attention_forward_row(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        float* __restrict__ O,
        const int q_idx,
        const int seq_len,
        const int head_dim,
        const int hd,
        const float scale,
        const int use_causal_mask) {
      const size_t q_off = (size_t)q_idx * head_dim;

      // Per-thread copy of the query row and the output accumulator.
      float q_row[HEAD_DIM_MAX];
      float acc[HEAD_DIM_MAX];
      for (int d = 0; d < hd; d++) {
        q_row[d] = Q[q_off + d];
        acc[d] = 0.0f;
      }

      float row_max = -1e20f;
      float row_sum = 0.0f;

      // Causal: keys after q_idx are masked, so iteration stops there
      // instead of visiting and skipping them.
      const int k_end = use_causal_mask ? min(seq_len, q_idx + 1) : seq_len;

      // Iterate over K/V tiles
      const int num_kv_tiles = (k_end + TILE_SIZE - 1) / TILE_SIZE;
      for (int kv_tile = 0; kv_tile < num_kv_tiles; kv_tile++) {
        const int kv_start = kv_tile * TILE_SIZE;
        const int kv_stop = min(kv_start + TILE_SIZE, k_end);

        for (int k_idx = kv_start; k_idx < kv_stop; k_idx++) {
          const float* k_row = K + (size_t)k_idx * head_dim;
          const float* v_row = V + (size_t)k_idx * head_dim;

          // Dot product Q[q_idx] . K[k_idx]
          float score = 0.0f;
          for (int d = 0; d < hd; d++) {
            score += q_row[d] * k_row[d];
          }
          score *= scale;

          // Online softmax update (Milakov & Gimelshein): rescale what has
          // been accumulated so far to the new maximum, then add this key.
          const float new_max = fmaxf(row_max, score);
          const float exp_diff = expf(row_max - new_max);
          const float exp_score = expf(score - new_max);

          for (int d = 0; d < hd; d++) {
            acc[d] = acc[d] * exp_diff + exp_score * v_row[d];
          }

          row_sum = row_sum * exp_diff + exp_score;
          row_max = new_max;
        }
      }

      // Final normalization. The row is always written: if the softmax sum
      // is not positive the factor is 0, so a zero accumulator yields a zero
      // row while NaNs in the accumulator still propagate.
      const float inv_sum = (row_sum > 0.0f) ? 1.0f / row_sum : 0.0f;
      for (int d = 0; d < hd; d++) {
        O[q_off + d] = acc[d] * inv_sum;
      }
    }

    // Attention forward for row-major [seq_len, head_dim] Q, K, V, O.
    //
    // Launch: 1-D. Query rows are grouped into tiles of TILE_SIZE; blocks
    // stride over tiles and threads stride over the rows of a tile, so any
    // grid/block size covers every row. With
    // grid = ceil(seq_len / TILE_SIZE) and block = TILE_SIZE each thread
    // handles exactly one row.
    //
    // Only the first min(head_dim, HEAD_DIM_MAX) features of each row are
    // read or written.
    extern "C" __global__
    void flash_attention_forward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        float* __restrict__ O,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {

      const int hd = min(head_dim, HEAD_DIM_MAX);
      const int num_q_tiles = (seq_len + TILE_SIZE - 1) / TILE_SIZE;

      for (int q_tile = blockIdx.x; q_tile < num_q_tiles; q_tile += gridDim.x) {
        for (int r = threadIdx.x; r < TILE_SIZE; r += blockDim.x) {
          const int q_idx = q_tile * TILE_SIZE + r;
          if (q_idx >= seq_len) break;

          flash_attention_forward_row(Q, K, V, O, q_idx, seq_len, head_dim, hd,
                                      scale, use_causal_mask);
        }
      }
    }
  CUDA
  compile_cached(source, "flash_attention_forward")
end

.rope_apply ⇒ Ignis::JIT::Kernel

Rotary Position Embedding (RoPE), HF/Llama/Qwen “rotate_half” convention. Input x is [seq, n_heads*head_dim] (heads contiguous). For each head, the first/second halves form rotation pairs: with half = head_dim/2,

inv_freq(i) = base^(-2i/head_dim),  angle = (row + pos_offset) * inv_freq(i)
d <  half:  out[d] = x[d]*cos - x[d+half]*sin
d >= half:  out[d] = x[d]*cos + x[d-half]*sin   (i = d-half)

The rotation is orthogonal, so the BACKWARD is this same kernel with the sin sign flipped (R^T = R(-θ)); callers pass sin_sign = +1 fwd, -1 bwd. pos_offset lets decode rotate a single new token at its absolute position.

Returns:



61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# File 'lib/nvruby/jit/kernels/attention.rb', line 61

# Builds the JIT kernel that applies a rotary position embedding to `x` and
# writes the result to `out`, one thread per element.
#
# The kernel pairs element d of a head with element d + head_dim/2 (or
# d - head_dim/2 for the second half), rotates by
# (row + pos_offset) * inv_freq[pair index], and multiplies the sine term by
# `sin_sign`.
#
# NOTE(review): every thread reads its partner element from `x`, so passing
# the same buffer as `x` and `out` looks unsafe — confirm callers never do.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def rope_apply
  compile_cached(<<~CUDA, "rope_apply")
    extern "C" __global__
    void rope_apply(const float* __restrict__ x,
                    float* __restrict__ out,
                    const int seq, const int n_heads, const int head_dim,
                    const int pos_offset,
                    const float* __restrict__ inv_freq,  // [head_dim/2] precomputed freqs
                    const float sin_sign) {
      int total = seq * n_heads * head_dim;
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx >= total) return;

      int d   = idx % head_dim;
      int row = (idx / head_dim) / n_heads;   // sequence position within this call
      int half = head_dim / 2;
      int pos = row + pos_offset;

      // Precomputed inv_freq lets the caller apply RoPE scaling (NTK/llama3/
      // YaRN) by remapping frequencies on the host; standard RoPE just passes
      // base^(-2i/head_dim).
      int freq_idx = (d < half) ? d : (d - half);
      float angle = (float)pos * inv_freq[freq_idx];
      float c = cosf(angle);
      float s = sinf(angle) * sin_sign;

      float xd = x[idx];
      if (d < half) {
        out[idx] = xd * c - x[idx + half] * s;
      } else {
        out[idx] = xd * c + x[idx - half] * s;
      }
    }
  CUDA
end

.softmax_backward ⇒ Ignis::JIT::Kernel

Softmax backward (Jacobian-vector product): grad_input = softmax * (grad_output - sum(grad_output * softmax)).

Returns:



101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# File 'lib/nvruby/jit/kernels/attention.rb', line 101

# Builds the JIT kernel for the softmax backward pass along the last
# dimension of an [outer_size, dim_size] matrix.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def softmax_backward
  source = <<~CUDA
    // grad_input[r][j] = softmax[r][j] * (grad_output[r][j] - dot_r), where
    // dot_r = sum_j grad_output[r][j] * softmax[r][j].
    //
    // Launch: 1-D, one thread per row; threads with row >= outer_size do
    // nothing.
    extern "C" __global__
    void softmax_backward(const float* __restrict__ grad_output,
                          const float* __restrict__ softmax_output,
                          float* __restrict__ grad_input,
                          const int outer_size,
                          const int dim_size) {
      const int row = blockIdx.x * blockDim.x + threadIdx.x;
      if (row >= outer_size) return;

      // 64-bit row offset: row * dim_size can exceed the range of a 32-bit
      // int for large matrices.
      const size_t offset = (size_t)row * dim_size;
      const float* go = grad_output + offset;
      const float* so = softmax_output + offset;
      float* gi = grad_input + offset;

      // dot(grad_output, softmax_output)
      float dot = 0.0f;
      for (int j = 0; j < dim_size; j++) {
        dot += go[j] * so[j];
      }

      // grad_input = softmax * (grad_output - dot)
      for (int j = 0; j < dim_size; j++) {
        gi[j] = so[j] * (go[j] - dot);
      }
    }
  CUDA
  compile_cached(source, "softmax_backward")
end

.softmax_forward ⇒ Ignis::JIT::Kernel

Numerically stable softmax forward along last dimension. Subtracts the per-row maximum before exponentiating, then normalizes by the row sum (three passes: max, exp-and-sum, normalize). Input shape: [outer_size, dim_size], softmax along dim_size.

Returns:



14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# File 'lib/nvruby/jit/kernels/attention.rb', line 14

# Builds the JIT kernel for a max-subtracted softmax along the last dimension
# of an [outer_size, dim_size] matrix.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def softmax_forward
  source = <<~CUDA
    // output[r][j] = exp(input[r][j] - max_r) / sum_j exp(input[r][j] - max_r)
    //
    // Launch: 1-D, one thread per row; threads with row >= outer_size do
    // nothing. The exponentials are staged in the output row before being
    // normalized.
    extern "C" __global__
    void softmax_forward(const float* __restrict__ input,
                         float* __restrict__ output,
                         const int outer_size,
                         const int dim_size) {
      const int row = blockIdx.x * blockDim.x + threadIdx.x;
      // Empty rows: nothing to normalize, and reading in_row[0] below would
      // be out of bounds.
      if (row >= outer_size || dim_size <= 0) return;

      // 64-bit row offset: row * dim_size can exceed the range of a 32-bit
      // int for large matrices.
      const size_t offset = (size_t)row * dim_size;
      const float* in_row = input + offset;
      float* out_row = output + offset;

      // Find max for numerical stability
      float max_val = in_row[0];
      for (int j = 1; j < dim_size; j++) {
        max_val = fmaxf(max_val, in_row[j]);
      }

      // Compute exp(x - max) and sum
      float sum = 0.0f;
      for (int j = 0; j < dim_size; j++) {
        const float e = expf(in_row[j] - max_val);
        out_row[j] = e;
        sum += e;
      }

      // Normalize
      const float inv_sum = 1.0f / sum;
      for (int j = 0; j < dim_size; j++) {
        out_row[j] *= inv_sum;
      }
    }
  CUDA
  compile_cached(source, "softmax_forward")
end

.topk_mask ⇒ Ignis::JIT::Kernel

Top-k mask: masks out every logit below the k-th largest value by setting it to -1e20 (logits tied with the k-th value are kept; k is capped at 256). Used for top-k sampling in text generation.

Returns:



134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# File 'lib/nvruby/jit/kernels/attention.rb', line 134

# Builds the JIT kernel that masks, in place, every logit below the k-th
# largest value of its row.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def topk_mask
  source = <<~CUDA
    #define TOPK_MAX 256

    // In-place top-k mask over rows of vocab_size logits.
    //
    // Launch: one block per row (blockIdx.x selects the row), any block
    // size. No dynamic shared memory is required.
    //
    // Thread 0 finds the k-th largest logit with a bounded insertion sort,
    // then all threads of the block set logits strictly below it to -1e20f.
    // Logits equal to the k-th value are kept, so ties can leave more than k
    // survivors. k is capped at TOPK_MAX. k <= 0 and k >= vocab_size leave
    // the row untouched.
    extern "C" __global__
    void topk_mask(float* __restrict__ logits,
                   const int vocab_size,
                   const int k) {
      // Uniform across the block, so returning before __syncthreads() is
      // safe. k <= 0 would otherwise index top_vals[-1]; k >= vocab_size
      // keeps every logit by definition.
      if (vocab_size <= 0 || k <= 0 || k >= vocab_size) return;

      float* row_logits = logits + (size_t)blockIdx.x * vocab_size;

      __shared__ float kth_shared;

      if (threadIdx.x == 0) {
        const int actual_k = k < TOPK_MAX ? k : TOPK_MAX;

        // top_vals[0 .. actual_k-1] holds the largest values seen so far in
        // descending order; the last slot is the current k-th largest.
        float top_vals[TOPK_MAX];
        for (int i = 0; i < actual_k; i++) top_vals[i] = -1e20f;

        for (int j = 0; j < vocab_size; j++) {
          const float v = row_logits[j];
          if (v > top_vals[actual_k - 1]) {
            // Replace the smallest kept value, then bubble the newcomer up
            // to its sorted position.
            top_vals[actual_k - 1] = v;
            for (int m = actual_k - 1; m > 0 && top_vals[m] > top_vals[m - 1]; m--) {
              const float tmp = top_vals[m];
              top_vals[m] = top_vals[m - 1];
              top_vals[m - 1] = tmp;
            }
          }
        }
        kth_shared = top_vals[actual_k - 1];
      }
      // Publishes the threshold and orders thread 0's reads of the row
      // before the writes below.
      __syncthreads();

      // Mask logits below threshold
      const float kth = kth_shared;
      for (int j = threadIdx.x; j < vocab_size; j += blockDim.x) {
        if (row_logits[j] < kth) {
          row_logits[j] = -1e20f;
        }
      }
    }
  CUDA
  compile_cached(source, "topk_mask")
end

.topp_mask ⇒ Ignis::JIT::Kernel

Top-p (nucleus) mask: keeps the smallest set of highest-probability tokens whose cumulative probability is >= p and zeroes the rest. Assumes logits have already been softmaxed into probabilities.

Returns:



206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# File 'lib/nvruby/jit/kernels/attention.rb', line 206

# Builds the JIT kernel that applies, in place, a top-p (nucleus) mask to
# each row of probabilities.
#
# @return [Ignis::JIT::Kernel] the kernel returned by `compile_cached`
def topp_mask
  source = <<~CUDA
    // In-place top-p mask over rows of vocab_size probabilities.
    //
    // Launch: one block per row (blockIdx.x selects the row); only thread 0
    // of each block does work.
    //
    // Tokens are taken in descending probability order (lowest index first
    // among equal values) until their cumulative sum reaches p; everything
    // else is set to 0. The highest-probability token is always kept, and
    // p >= 1 leaves the row untouched.
    //
    // No device-side heap allocation: instead of a per-token "selected"
    // array, the kernel remembers the last selected (value, index) pair.
    // Because selection follows a strict total order, that pair alone tells
    // which tokens were already taken. Cost is one scan of the row per kept
    // token.
    extern "C" __global__
    void topp_mask(float* __restrict__ probs,
                   const int vocab_size,
                   const float p) {
      if (threadIdx.x != 0 || vocab_size <= 0) return;

      // The whole distribution is the nucleus; also avoids scanning the row
      // once per token when rounding keeps the cumulative sum just below 1.
      if (p >= 1.0f) return;

      float* row_probs = probs + (size_t)blockIdx.x * vocab_size;

      float last_val = 0.0f;  // value of the most recently selected token
      int last_idx = -1;      // its index; -1 means nothing selected yet
      float cumsum = 0.0f;

      // do/while so that at least the top token survives even when p <= 0.
      do {
        float best_val = -1.0f;
        int best_idx = -1;
        for (int j = 0; j < vocab_size; j++) {
          const float v = row_probs[j];
          // Not yet selected: strictly smaller than the last pick, or tied
          // with it but at a higher index.
          const bool remaining = (last_idx < 0) || (v < last_val) ||
                                 (v == last_val && j > last_idx);
          if (remaining && v > best_val) {
            best_val = v;
            best_idx = j;
          }
        }
        if (best_idx < 0) break;  // nothing left to select
        last_val = best_val;
        last_idx = best_idx;
        cumsum += best_val;
      } while (cumsum < p);

      // Zero out non-selected
      for (int j = 0; j < vocab_size; j++) {
        const float v = row_probs[j];
        const bool kept = (last_idx >= 0) &&
                          (v > last_val || (v == last_val && j <= last_idx));
        if (!kept) row_probs[j] = 0.0f;
      }
    }
  CUDA
  compile_cached(source, "topp_mask")
end