Class: Ignis::AI::Transformer::MultiHeadAttention
- Inherits: NN::Module
  - Object
  - NN::Module
  - Ignis::AI::Transformer::MultiHeadAttention
- Defined in: lib/nnw/ai/transformer/attention.rb
Overview
Multi-head attention with optional causal masking (enabled by default through causal: true). Supports standard scaled dot-product attention.
Instance Attribute Summary
Attributes inherited from NN::Module
Instance Method Summary
- #decode_step(x, cache, layer) ⇒ Tensor
  Incremental attention for one new token via a KV cache (decode path).
- #forward(query, key, value, mask: nil) ⇒ Tensor
  Forward pass: scaled dot-product attention.
- #initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0) ⇒ MultiHeadAttention (constructor)
  A new instance of MultiHeadAttention.
- #to_s ⇒ String
Methods inherited from NN::Module
#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!
Constructor Details
#initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0) ⇒ MultiHeadAttention
Returns a new instance of MultiHeadAttention.
# File 'lib/nnw/ai/transformer/attention.rb', line 18
def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0)
  super()
  raise ArgumentError, "embed_dim must be divisible by num_heads" unless (embed_dim % num_heads).zero?

  @embed_dim = embed_dim
  @num_heads = num_heads
  @head_dim = embed_dim / num_heads
  @scale = 1.0 / Math.sqrt(@head_dim)
  @causal = causal

  # Q, K, V and output projections all map embed_dim to embed_dim.
  @q_proj = register_module("q_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @k_proj = register_module("k_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @v_proj = register_module("v_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @out_proj = register_module("out_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @attn_dropout = register_module("attn_dropout", NN::Dropout.new(p: dropout))
end
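The derived sizes in the constructor can be checked by hand. The following is a minimal sketch in plain Ruby; attention_dims is a hypothetical helper written for illustration, not part of Ignis, and it only reproduces the divisibility check and the head_dim and scale arithmetic.

```ruby
# Hypothetical helper (not an Ignis API): mirrors the constructor's
# divisibility check and its head_dim / scale arithmetic.
def attention_dims(embed_dim, num_heads)
  raise ArgumentError, "embed_dim must be divisible by num_heads" unless (embed_dim % num_heads).zero?

  head_dim = embed_dim / num_heads
  { head_dim: head_dim, scale: 1.0 / Math.sqrt(head_dim) }
end

dims = attention_dims(768, 12)
puts dims[:head_dim]  # 64
puts dims[:scale]     # 0.125
```

With embed_dim = 768 and num_heads = 12, each head works on a 64-wide slice, and scores are scaled by 1/sqrt(64) = 0.125. A pair such as (10, 3) raises ArgumentError, as in the constructor.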
Instance Method Details
#decode_step(x, cache, layer) ⇒ Tensor
Incremental attention for one new token via a KV cache (decode path). Projects only this token's q/k/v, appends k/v to the cache for the given layer, then attends the single query over all cached keys and values. No causal mask is needed: the new token is the last position, so every cached position is visible to it. Raises unless the module was constructed with causal: true. Must run under no_grad.
# File 'lib/nnw/ai/transformer/attention.rb', line 67
def decode_step(x, cache, layer)
  raise "decode_step requires a causal attention module" unless @causal

  q = @q_proj.call(x)
  k = @k_proj.call(x)
  v = @v_proj.call(x)
  cache.append(layer, k.data, v.data)
  kview = Tensor.new(data: cache.k_view(layer), requires_grad: false)
  vview = Tensor.new(data: cache.v_view(layer), requires_grad: false)
  ctx = q.decode_sdpa(kview, vview, num_heads: @num_heads)
  @out_proj.call(ctx)
end
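The decode path can be illustrated without Ignis at all. The sketch below is an assumption-laden toy, not the library's implementation: ToyKVCache and toy_decode_step are hypothetical names, rows are plain Ruby arrays, there is a single head, and the q/k/v projections are skipped. It only borrows the append / k_view / v_view method names from the code above to show the append-then-attend order.

```ruby
# Hypothetical toy cache: per-layer lists of key rows and value rows.
class ToyKVCache
  def initialize
    @keys = Hash.new { |h, layer| h[layer] = [] }
    @values = Hash.new { |h, layer| h[layer] = [] }
  end

  def append(layer, k_row, v_row)
    @keys[layer] << k_row
    @values[layer] << v_row
  end

  def k_view(layer)
    @keys[layer]
  end

  def v_view(layer)
    @values[layer]
  end
end

# One decode step for a single head: cache the new k/v row, then attend
# the single query over every cached position (no mask is applied).
def toy_decode_step(q_row, k_row, v_row, cache, layer)
  cache.append(layer, k_row, v_row)
  keys = cache.k_view(layer)
  values = cache.v_view(layer)

  scale = 1.0 / Math.sqrt(q_row.length)
  scores = keys.map { |k| scale * q_row.zip(k).sum { |a, b| a * b } }

  # Numerically stable softmax over the cached positions.
  max = scores.max
  exps = scores.map { |s| Math.exp(s - max) }
  total = exps.sum
  weights = exps.map { |e| e / total }

  # Weighted sum of the cached value rows.
  (0...v_row.length).map do |d|
    weights.zip(values).sum { |w, v| w * v[d] }
  end
end

cache = ToyKVCache.new
first = toy_decode_step([1.0, 0.0], [1.0, 0.0], [5.0, 6.0], cache, 0)
toy_decode_step([0.0, 1.0], [0.0, 1.0], [7.0, 8.0], cache, 0)
p first                   # [5.0, 6.0]
p cache.k_view(0).length  # 2
```

On the first step the cache holds one position, so the output equals that position's value row. Each later step adds one row to the cache and mixes all cached values, which is why no mask is required.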
#forward(query, key, value, mask: nil) ⇒ Tensor
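Forward pass: scaled dot-product attention. In the source shown below, the mask: argument is accepted but is not passed to the attention kernel; masking is controlled by the causal: flag given at construction.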
# File 'lib/nnw/ai/transformer/attention.rb', line 41
def forward(query, key, value, mask: nil)
  # Project Q, K, V → [seq, embed_dim] each (batch = 1)
  q = @q_proj.call(query)
  k = @k_proj.call(key)
  v = @v_proj.call(value)

  # Real multi-head scaled dot-product attention with causal masking.
  # Each head attends over its own [seq, head_dim] slice; the per-head
  # Flash-Attention-2 kernel applies 1/sqrt(head_dim) scaling, the causal
  # mask, and a numerically-stable online softmax. (A nil mask still means
  # causal for a GPT-2-style decoder; pass causal: false at construction
  # for bidirectional attention.)
  context = q.sdpa(k, v, num_heads: @num_heads, causal: @causal)

  # Output projection
  @out_proj.call(context)
end
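For readers who want to see what the fused kernel computes per head, here is a minimal reference sketch in plain Ruby. It is not the Flash-Attention-2 kernel and not an Ignis API: softmax and sdpa_reference are hypothetical names, the inputs are nested arrays for one head, and the softmax is the ordinary (non-online) stable form.

```ruby
# Numerically stable softmax over an array of floats.
def softmax(scores)
  max = scores.max
  exps = scores.map { |s| Math.exp(s - max) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Single-head scaled dot-product attention over plain arrays.
# q, k, v: arrays of rows, each row an array of head_dim floats.
def sdpa_reference(q, k, v, causal: true)
  scale = 1.0 / Math.sqrt(q.first.length)
  q.each_with_index.map do |q_row, i|
    # With causal masking, position i only sees positions 0..i.
    visible = (causal ? (0..i) : (0...k.length)).to_a
    scores = visible.map { |j| scale * q_row.zip(k[j]).sum { |a, b| a * b } }
    weights = softmax(scores)
    # Weighted sum of the visible value rows.
    (0...v.first.length).map do |d|
      visible.zip(weights).sum { |j, w| w * v[j][d] }
    end
  end
end

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]

causal_out = sdpa_reference(q, k, v, causal: true)
p causal_out.first  # [1.0, 2.0]
```

Under causal masking the first position can only attend to itself, so its output is exactly the first value row. With causal: false the same position also sees the second row and its output becomes a blend of both.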
#to_s ⇒ String
# File 'lib/nnw/ai/transformer/attention.rb', line 83
def to_s
  "MultiHeadAttention(embed=#{@embed_dim}, heads=#{@num_heads}, head_dim=#{@head_dim})"
end