Module: Ignis::JIT::Kernels::Attention
- Defined in:
- lib/nvruby/jit/kernels/attention.rb
Overview
Attention and softmax CUDA kernels for transformer models. Includes numerically stable softmax, top-k/top-p filtering.
Class Method Summary collapse
-
.attention_score ⇒ Ignis::JIT::Kernel
Scaled dot-product attention score: score = Q @ K^T / sqrt(d_k) With optional causal mask (upper triangular set to -inf).
-
.flash_attention_backward ⇒ Ignis::JIT::Kernel
Flash Attention 2 backward.
-
.flash_attention_forward ⇒ Ignis::JIT::Kernel
Flash Attention 2 forward (Dao et al. 2023).
-
.rope_apply ⇒ Ignis::JIT::Kernel
Rotary Position Embedding (RoPE), HF/Llama/Qwen “rotate_half” convention.
-
.softmax_backward ⇒ Ignis::JIT::Kernel
Softmax backward: Jacobian-vector product grad_input = softmax * (grad_output - sum(grad_output * softmax)).
-
.softmax_forward ⇒ Ignis::JIT::Kernel
Numerically stable softmax forward along last dimension.
-
.topk_mask ⇒ Ignis::JIT::Kernel
Top-k mask: zero out all logits except the top-k highest values.
-
.topp_mask ⇒ Ignis::JIT::Kernel
Top-p (nucleus) mask: keep smallest set of tokens with cumulative prob >= p Assumes logits have already been softmaxed into probabilities.
Class Method Details
.attention_score ⇒ Ignis::JIT::Kernel
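Scaled dot-product attention score: score = Q @ K^T / sqrt(d_k), with an optional causal mask. The kernel takes the precomputed Q @ K^T scores, multiplies each by `scale` (normally 1/sqrt(d_k)), and, when the mask is enabled, overwrites the strict upper triangle (col > row) with -1e9, a large negative value that stands in for -inf.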
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 263

# Builds the `attention_score` kernel, which scales a precomputed square
# score matrix and optionally applies a causal mask.
#
# Kernel arguments:
#   scores          - input, seq_len * seq_len floats addressed row-major
#                     (row = idx / seq_len, col = idx % seq_len); read only
#   masked_scores   - output buffer of the same size
#   scale           - multiplier applied to every score
#   seq_len         - side length of the square matrix
#   use_causal_mask - non-zero replaces every entry with col > row by -1e9f
#
# Launch layout: 1-D grid with one thread per matrix element; surplus
# threads exit immediately.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def attention_score
  source = <<~CUDA
    extern "C" __global__
    void attention_score(const float* __restrict__ scores,
                         float* __restrict__ masked_scores,
                         const float scale,
                         const int seq_len,
                         const int use_causal_mask) {
        // 64-bit indexing: seq_len * seq_len no longer fits in a signed
        // 32-bit int once seq_len exceeds 46340.
        const long long n = (long long)seq_len;
        const long long total = n * n;
        const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= total) return;

        const long long row = idx / n;
        const long long col = idx % n;

        float val = scores[idx] * scale;

        // Causal mask: future positions (col > row) get a large negative
        // value so they contribute ~0 after a subsequent softmax.
        if (use_causal_mask && col > row) {
            val = -1e9f;
        }

        masked_scores[idx] = val;
    }
  CUDA
  compile_cached(source, "attention_score")
end
.flash_attention_backward ⇒ Ignis::JIT::Kernel
Flash Attention 2 backward. Recomputes attention weights on-the-fly during backward (memory efficient).
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 388

# Builds the `flash_attention_backward` kernel.
#
# Each thread owns one query row (q_idx is the global 1-D thread index;
# threads with q_idx >= seq_len exit). For its row the thread:
#   1. loads the Q, dO and O rows into per-thread arrays and computes
#      D_i = sum(dO_i * O_i);
#   2. recomputes the softmax statistics over all keys (row_max, then
#      row_sum) instead of reading stored attention weights;
#   3. loops over the keys a third time, recomputing p_ij and accumulating
#      gradients.
#
# Outputs:
#   dQ - each row is overwritten by the thread that owns it.
#   dK - accumulated with atomicAdd across query threads.
#   dV - accumulated with atomicAdd across query threads.
#   NOTE(review): because dK/dV are only ever added to, the caller
#   presumably has to zero them before launch - confirm at the call site.
#
# Limits visible in the code:
#   * every loop over d stops at HEAD_DIM_MAX (128), so components beyond
#     index 127 are ignored when head_dim is larger;
#   * TILE_SIZE is defined but not referenced by this kernel;
#   * with use_causal_mask set, keys with k_idx > q_idx are skipped.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def flash_attention_backward
  source = <<~CUDA
    #define TILE_SIZE 64
    #define HEAD_DIM_MAX 128

    extern "C" __global__
    void flash_attention_backward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        const float* __restrict__ O,
        const float* __restrict__ dO,
        float* __restrict__ dQ,
        float* __restrict__ dK,
        float* __restrict__ dV,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {
        int q_idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (q_idx >= seq_len) return;

        // Load Q row and dO row
        float q_row[HEAD_DIM_MAX], do_row[HEAD_DIM_MAX], o_row[HEAD_DIM_MAX];
        for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
            q_row[d] = Q[q_idx * head_dim + d];
            do_row[d] = dO[q_idx * head_dim + d];
            o_row[d] = O[q_idx * head_dim + d];
        }

        // Compute D_i = sum(dO_i * O_i) for this row
        float D_i = 0.0f;
        for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
            D_i += do_row[d] * o_row[d];
        }

        // Recompute attention: need softmax weights
        // First pass: compute row_max and row_sum
        float row_max = -1e20f;
        for (int k_idx = 0; k_idx < seq_len; k_idx++) {
            if (use_causal_mask && k_idx > q_idx) continue;
            float score = 0.0f;
            for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
                score += q_row[d] * K[k_idx * head_dim + d];
            }
            score *= scale;
            row_max = fmaxf(row_max, score);
        }

        float row_sum = 0.0f;
        for (int k_idx = 0; k_idx < seq_len; k_idx++) {
            if (use_causal_mask && k_idx > q_idx) continue;
            float score = 0.0f;
            for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
                score += q_row[d] * K[k_idx * head_dim + d];
            }
            score *= scale;
            row_sum += expf(score - row_max);
        }

        // Second pass: compute gradients
        float dq_acc[HEAD_DIM_MAX];
        for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) dq_acc[d] = 0.0f;

        for (int k_idx = 0; k_idx < seq_len; k_idx++) {
            if (use_causal_mask && k_idx > q_idx) continue;

            float score = 0.0f;
            for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
                score += q_row[d] * K[k_idx * head_dim + d];
            }
            score *= scale;
            float p_ij = expf(score - row_max) / row_sum;

            // dV += p_ij * dO
            // dP = dO @ V^T
            float dP = 0.0f;
            for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
                atomicAdd(&dV[k_idx * head_dim + d], p_ij * do_row[d]);
                dP += do_row[d] * V[k_idx * head_dim + d];
            }

            // dS = p_ij * (dP - D_i) * scale
            float dS = p_ij * (dP - D_i) * scale;

            // dQ += dS * K[k_idx]
            // dK[k_idx] += dS * Q[q_idx]
            for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
                dq_acc[d] += dS * K[k_idx * head_dim + d];
                atomicAdd(&dK[k_idx * head_dim + d], dS * q_row[d]);
            }
        }

        // Write dQ
        for (int d = 0; d < head_dim && d < HEAD_DIM_MAX; d++) {
            dQ[q_idx * head_dim + d] = dq_acc[d];
        }
    }
  CUDA
  compile_cached(source, "flash_attention_backward")
end
.flash_attention_forward ⇒ Ignis::JIT::Kernel
Flash Attention 2 forward (Dao et al. 2023). Tiled Q/K/V processing — avoids materializing full N×N attention matrix. O(N) memory vs O(N²) for standard attention. Uses online softmax (streaming max + sum) for numerical stability.
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 296

# Builds the `flash_attention_forward` kernel.
#
# Thread mapping: block b, thread t handles query row b * TILE_SIZE + t
# (TILE_SIZE = 64); rows at or past seq_len exit immediately.
# NOTE(review): this mapping only covers each row exactly once when the
# kernel is launched with blockDim.x == TILE_SIZE - confirm at the call site.
#
# Every thread streams over all keys in TILE_SIZE-sized groups while
# maintaining a running maximum, a running exp-sum and a rescaled output
# accumulator, so no seq_len x seq_len matrix is ever stored. Head
# dimensions past HEAD_DIM_MAX (128) are ignored. The output row is written
# only if the final exp-sum is positive.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def flash_attention_forward
  source = <<~CUDA
    #define TILE_SIZE 64
    #define HEAD_DIM_MAX 128

    extern "C" __global__
    void flash_attention_forward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        float* __restrict__ O,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {
        // One query row per thread, TILE_SIZE rows per block.
        const int query = blockIdx.x * TILE_SIZE + threadIdx.x;
        if (query >= seq_len) return;

        // Number of head components actually processed.
        const int dims = head_dim < HEAD_DIM_MAX ? head_dim : HEAD_DIM_MAX;

        // Private copy of the query row plus the un-normalized output.
        float q_vec[HEAD_DIM_MAX];
        float out_acc[HEAD_DIM_MAX];
        const float* q_ptr = Q + query * head_dim;
        for (int d = 0; d < dims; ++d) {
            q_vec[d] = q_ptr[d];
            out_acc[d] = 0.0f;
        }

        // Streaming softmax state.
        float running_max = -1e20f;
        float running_sum = 0.0f;

        const int tile_count = (seq_len + TILE_SIZE - 1) / TILE_SIZE;
        for (int tile = 0; tile < tile_count; ++tile) {
            const int tile_begin = tile * TILE_SIZE;
            int tile_end = tile_begin + TILE_SIZE;
            if (tile_end > seq_len) tile_end = seq_len;

            for (int key = tile_begin; key < tile_end; ++key) {
                // Causal masking: a query never attends to later keys.
                if (use_causal_mask && key > query) continue;

                const float* k_ptr = K + key * head_dim;
                const float* v_ptr = V + key * head_dim;

                // Scaled dot product q . k
                float logit = 0.0f;
                for (int d = 0; d < dims; ++d) {
                    logit += q_vec[d] * k_ptr[d];
                }
                logit *= scale;

                // Streaming softmax step: re-base previous state onto the
                // new maximum, then fold in this key.
                const float next_max = fmaxf(running_max, logit);
                const float old_weight = expf(running_max - next_max);
                const float new_weight = expf(logit - next_max);

                for (int d = 0; d < dims; ++d) {
                    out_acc[d] = out_acc[d] * old_weight + new_weight * v_ptr[d];
                }

                running_sum = running_sum * old_weight + new_weight;
                running_max = next_max;
            }
        }

        // Nothing accumulated (or a non-positive/NaN sum): leave O untouched.
        if (!(running_sum > 0.0f)) return;

        // Normalize by the total softmax mass.
        const float inv_sum = 1.0f / running_sum;
        float* o_ptr = O + query * head_dim;
        for (int d = 0; d < dims; ++d) {
            o_ptr[d] = out_acc[d] * inv_sum;
        }
    }
  CUDA
  compile_cached(source, "flash_attention_forward")
end
.rope_apply ⇒ Ignis::JIT::Kernel
Rotary Position Embedding (RoPE), HF/Llama/Qwen “rotate_half” convention. Input x is [seq, n_heads*head_dim] (heads contiguous). For each head, the first/second halves form rotation pairs: with half = head_dim/2,
inv_freq(i) = base^(-2i/head_dim), angle = (row + pos_offset) * inv_freq(i)
d < half: out[d] = x[d]*cos - x[d+half]*sin
d >= half: out[d] = x[d]*cos + x[d-half]*sin (i = d-half)
The rotation is orthogonal, so the BACKWARD is this same kernel with the sin sign flipped (R^T = R(-θ)); callers pass sin_sign = +1 fwd, -1 bwd. pos_offset lets decode rotate a single new token at its absolute position.
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 61

# Builds the `rope_apply` kernel (rotary position embedding, rotate-half
# pairing).
#
# One thread per element of x, addressed as a flat index over
# seq * n_heads * head_dim values: the component within the head is
# idx % head_dim and the sequence row is (idx / head_dim) / n_heads.
# Component d in the lower half is paired with d + head_dim/2, component d
# in the upper half with d - head_dim/2; both members of a pair use the
# same entry of inv_freq.
#
# Kernel arguments:
#   x, out     - input / output buffers
#                NOTE(review): both are declared __restrict__ and each thread
#                reads its partner element from x, so in-place use
#                (out == x) looks unsafe - confirm callers never alias them.
#   pos_offset - added to the row index to obtain the absolute position
#   inv_freq   - precomputed frequencies, indexed 0 .. head_dim/2 - 1
#   sin_sign   - multiplies sin(angle); the documented convention is +1 for
#                the forward rotation and -1 for the backward one
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def rope_apply
  source = <<~CUDA
    extern "C" __global__
    void rope_apply(const float* __restrict__ x,
                    float* __restrict__ out,
                    const int seq,
                    const int n_heads,
                    const int head_dim,
                    const int pos_offset,
                    const float* __restrict__ inv_freq,   // [head_dim/2] precomputed freqs
                    const float sin_sign) {
        const int gid = blockIdx.x * blockDim.x + threadIdx.x;
        if (gid >= seq * n_heads * head_dim) return;

        const int half_dim = head_dim / 2;
        const int lane = gid % head_dim;              // component inside the head
        const int token = gid / head_dim / n_heads;   // sequence row within this call
        const bool lower = lane < half_dim;

        // The frequencies arrive precomputed so the host can remap them for
        // RoPE scaling schemes (NTK / llama3 / YaRN); plain RoPE supplies
        // base^(-2i/head_dim).
        const int pair_slot = lower ? lane : lane - half_dim;
        const float theta = (float)(token + pos_offset) * inv_freq[pair_slot];
        const float cos_t = cosf(theta);
        const float sin_t = sinf(theta) * sin_sign;

        const float self_val = x[gid];
        const float partner = lower ? x[gid + half_dim] : x[gid - half_dim];

        out[gid] = lower ? (self_val * cos_t - partner * sin_t)
                         : (self_val * cos_t + partner * sin_t);
    }
  CUDA
  compile_cached(source, "rope_apply")
end
.softmax_backward ⇒ Ignis::JIT::Kernel
Softmax backward: Jacobian-vector product grad_input = softmax * (grad_output - sum(grad_output * softmax))
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 101

# Builds the `softmax_backward` kernel.
#
# One thread per row of an [outer_size, dim_size] layout. For its row the
# thread first reduces dot = sum_j grad_output[j] * softmax_output[j], then
# writes grad_input[j] = softmax_output[j] * (grad_output[j] - dot).
# Threads whose row index is >= outer_size do nothing.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def softmax_backward
  source = <<~CUDA
    extern "C" __global__
    void softmax_backward(const float* __restrict__ grad_output,
                          const float* __restrict__ softmax_output,
                          float* __restrict__ grad_input,
                          const int outer_size,
                          const int dim_size) {
        const int r = blockIdx.x * blockDim.x + threadIdx.x;
        if (r >= outer_size) return;

        // Offset of this thread's row inside each flat buffer.
        const int base = r * dim_size;

        // <grad_output, softmax_output> for the row
        float inner = 0.0f;
        for (int c = 0; c < dim_size; ++c) {
            inner += grad_output[base + c] * softmax_output[base + c];
        }

        // Jacobian-vector product, element by element
        for (int c = 0; c < dim_size; ++c) {
            grad_input[base + c] =
                softmax_output[base + c] * (grad_output[base + c] - inner);
        }
    }
  CUDA
  compile_cached(source, "softmax_backward")
end
.softmax_forward ⇒ Ignis::JIT::Kernel
Numerically stable softmax forward along the last dimension. Subtracts the row maximum before exponentiating, in three passes over each row (find max, exponentiate and sum, normalize). Input shape: [outer_size, dim_size]; softmax is taken along dim_size.
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 14

# Builds the `softmax_forward` kernel.
#
# One thread per row of an [outer_size, dim_size] input. Each thread makes
# three passes over its row: find the maximum, write exp(x - max) while
# summing, then scale by 1 / sum. Subtracting the maximum keeps expf from
# overflowing.
#
# Kernel arguments:
#   input      - outer_size * dim_size floats (read only)
#   output     - buffer of the same size
#   outer_size - number of rows
#   dim_size   - row length; rows of length <= 0 are skipped
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def softmax_forward
  source = <<~CUDA
    extern "C" __global__
    void softmax_forward(const float* __restrict__ input,
                         float* __restrict__ output,
                         const int outer_size,
                         const int dim_size) {
        const int row = blockIdx.x * blockDim.x + threadIdx.x;

        // An empty row has no element 0 to seed the max with, so bail out
        // before touching memory.
        if (row >= outer_size || dim_size <= 0) return;

        // 64-bit offset: row * dim_size can exceed INT_MAX for large inputs.
        const size_t offset = (size_t)row * (size_t)dim_size;
        const float* in_row = input + offset;
        float* out_row = output + offset;

        // Find max for numerical stability
        float max_val = in_row[0];
        for (int j = 1; j < dim_size; j++) {
            max_val = fmaxf(max_val, in_row[j]);
        }

        // Compute exp(x - max) and sum
        float sum = 0.0f;
        for (int j = 0; j < dim_size; j++) {
            float e = expf(in_row[j] - max_val);
            out_row[j] = e;
            sum += e;
        }

        // Normalize (sum >= 1 for finite input because the max maps to exp(0))
        float inv_sum = 1.0f / sum;
        for (int j = 0; j < dim_size; j++) {
            out_row[j] *= inv_sum;
        }
    }
  CUDA
  compile_cached(source, "softmax_forward")
end
.topk_mask ⇒ Ignis::JIT::Kernel
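Top-k mask: suppress all logits except the k highest values by overwriting them with -1e20 (so they vanish after a softmax); logits tied with the k-th largest value are kept. k is capped at 256. Used for top-k sampling in text generation.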
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 134

# Builds the `topk_mask` kernel.
#
# Launch layout: one block per row (row = blockIdx.x); only thread 0 of each
# block does any work. No shared memory is used.
#
# Thread 0 keeps a descending list of the k largest logits seen so far
# (insertion sort, k capped at TOPK_MAX_K = 256), takes the smallest entry
# of that list as the threshold and overwrites every logit strictly below
# it with -1e20f. Logits equal to the threshold survive, so more than k
# entries can remain when there are ties.
#
# Edge cases:
#   * k <= 0 or vocab_size <= 0: the row is left untouched (previously
#     k <= 0 indexed top_vals[-1]);
#   * k > 256: treated as k = 256;
#   * k > vocab_size: the threshold stays at -1e20f, so nothing is masked.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def topk_mask
  source = <<~CUDA
    #define TOPK_MAX_K 256

    extern "C" __global__
    void topk_mask(float* __restrict__ logits,
                   const int vocab_size,
                   const int k) {
        // Single-row operation (batch dim handled by caller): one block per
        // row, executed serially by thread 0.
        if (threadIdx.x != 0) return;
        if (k <= 0 || vocab_size <= 0) return;

        // 64-bit offset: row * vocab_size can exceed INT_MAX for big batches.
        float* row_logits = logits + (size_t)blockIdx.x * (size_t)vocab_size;

        // Running top-k list, kept sorted in descending order.
        float top_vals[TOPK_MAX_K];
        const int actual_k = k < TOPK_MAX_K ? k : TOPK_MAX_K;
        for (int i = 0; i < actual_k; i++) top_vals[i] = -1e20f;

        for (int j = 0; j < vocab_size; j++) {
            const float v = row_logits[j];
            if (v > top_vals[actual_k - 1]) {
                // Replace the current smallest entry, then bubble it up to
                // its sorted position.
                top_vals[actual_k - 1] = v;
                for (int m = actual_k - 1; m > 0 && top_vals[m] > top_vals[m - 1]; m--) {
                    const float tmp = top_vals[m];
                    top_vals[m] = top_vals[m - 1];
                    top_vals[m - 1] = tmp;
                }
            }
        }

        // k-th largest value = smallest entry of the list.
        const float kth = top_vals[actual_k - 1];

        // Mask logits below threshold
        for (int j = 0; j < vocab_size; j++) {
            if (row_logits[j] < kth) {
                row_logits[j] = -1e20f;
            }
        }
    }
  CUDA
  compile_cached(source, "topk_mask")
end
.topp_mask ⇒ Ignis::JIT::Kernel
Top-p (nucleus) mask: keep the smallest set of tokens whose cumulative probability is >= p and set every other entry to zero. Assumes the logits have already been softmaxed into probabilities; the kept probabilities are not renormalized.
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
# File 'lib/nvruby/jit/kernels/attention.rb', line 206

# Builds the `topp_mask` kernel.
#
# Launch layout: one block per row (row = blockIdx.x); only thread 0 of each
# block does any work.
#
# Thread 0 repeatedly picks the largest not-yet-selected probability (ties
# resolved by lowest index) and adds it to a running sum until the sum
# reaches p; every entry that was not selected is then set to 0. The kept
# entries are not renormalized.
#
# The selected set is always a prefix of the ordering (value descending,
# index ascending), so it is tracked with just the last selected
# (value, index) pair. This avoids a per-row device-heap allocation, whose
# failure used to skip the masking silently.
#
# Edge cases:
#   * at least the single most probable token is kept, even when p <= 0
#     (previously the whole row was zeroed);
#   * entries that are NaN or <= -1 are never selected and end up zeroed;
#   * cost is O(vocab_size * kept_tokens) on one thread. For production,
#     use a sort-based kernel.
#
# @return [Ignis::JIT::Kernel] the compiled (cached) kernel
def topp_mask
  source = <<~CUDA
    extern "C" __global__
    void topp_mask(float* __restrict__ probs,
                   const int vocab_size,
                   const float p) {
        if (threadIdx.x != 0) return;

        // 64-bit offset: row * vocab_size can exceed INT_MAX for big batches.
        float* row_probs = probs + (size_t)blockIdx.x * (size_t)vocab_size;

        float cumsum = 0.0f;
        float last_val = 0.0f;   // value of the most recently selected token
        int last_idx = -1;       // its index; -1 while nothing is selected

        do {
            // Largest entry that has not been selected yet.
            float best_val = -1.0f;
            int best_idx = -1;
            for (int j = 0; j < vocab_size; j++) {
                const float v = row_probs[j];
                // Unselected entries are those ordered after the last
                // selected (value, index) pair.
                const bool pending = (last_idx < 0) ||
                                     (v < last_val) ||
                                     (v == last_val && j > last_idx);
                if (pending && v > best_val) {
                    best_val = v;
                    best_idx = j;
                }
            }
            if (best_idx < 0) break;   // nothing selectable is left

            last_val = best_val;
            last_idx = best_idx;
            cumsum += best_val;
        } while (cumsum < p);

        // Zero out non-selected
        for (int j = 0; j < vocab_size; j++) {
            const float v = row_probs[j];
            const bool keep = (last_idx >= 0) &&
                              (v > last_val || (v == last_val && j <= last_idx));
            if (!keep) row_probs[j] = 0.0f;
        }
    }
  CUDA
  compile_cached(source, "topp_mask")
end