Class: Ignis::AI::Transformer::RopeGqaAttention
- Inherits: NN::Module
  - Object
  - NN::Module
  - Ignis::AI::Transformer::RopeGqaAttention
- Defined in:
  - lib/nnw/ai/transformer/modern.rb
Overview
Self-attention with rotary position embeddings (RoPE) and grouped-query attention (GQA). The projections carry no bias by default (Llama/Qwen convention); pass bias: true to enable it.
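To make the grouping concrete, here is a minimal plain-Ruby sketch of how query heads might share KV heads. The helper kv_head_for is hypothetical and not part of the library. It assumes contiguous grouping (query head h reads KV head h / group_size), which is the common Llama-style layout. The actual sdpa kernel is not shown on this page and may order heads differently.

```ruby
# Hypothetical helper (not part of the library): returns the KV head that a
# given query head reads from, assuming contiguous Llama-style grouping.
def kv_head_for(q_head, num_heads:, num_kv_heads:)
  unless (num_heads % num_kv_heads).zero?
    raise ArgumentError, "num_heads must be a multiple of num_kv_heads"
  end

  group_size = num_heads / num_kv_heads # query heads per KV head
  q_head / group_size                   # integer division picks the group
end

# Eight query heads sharing two KV heads: four query heads per group.
mapping = (0...8).map { |h| kv_head_for(h, num_heads: 8, num_kv_heads: 2) }
puts mapping.inspect
```

With num_kv_heads equal to num_heads this reduces to ordinary multi-head attention, and with num_kv_heads: 1 it reduces to multi-query attention.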
Instance Attribute Summary
Attributes inherited from NN::Module
Instance Method Summary
- #forward(x, pos_offset: 0) ⇒ Tensor
  Returns [seq, embed_dim].
- #initialize(embed_dim, num_heads, num_kv_heads:, head_dim: nil, rope_base: 10000.0, rope_scaling: nil, bias: false, device_id: 0) ⇒ RopeGqaAttention (constructor)
  A new instance of RopeGqaAttention.
- #to_s ⇒ String
Methods inherited from NN::Module
#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!
Constructor Details
#initialize(embed_dim, num_heads, num_kv_heads:, head_dim: nil, rope_base: 10000.0, rope_scaling: nil, bias: false, device_id: 0) ⇒ RopeGqaAttention
Returns a new instance of RopeGqaAttention.
# File 'lib/nnw/ai/transformer/modern.rb', line 56

def initialize(embed_dim, num_heads, num_kv_heads:, head_dim: nil, rope_base: 10000.0,
               rope_scaling: nil, bias: false, device_id: 0)
  super()
  raise ArgumentError, "num_heads must be a multiple of num_kv_heads" unless (num_heads % num_kv_heads).zero?
  @embed_dim = embed_dim
  @num_heads = num_heads
  @num_kv_heads = num_kv_heads
  @head_dim = head_dim || (embed_dim / num_heads)
  # Fail early (at construction) rather than silently miscompute later:
  # RoPE needs an even head_dim; the flash kernels cap head_dim at 128.
  raise ArgumentError, "head_dim must be even for RoPE (got #{@head_dim})" unless @head_dim.even?
  raise ArgumentError, "head_dim #{@head_dim} exceeds flash-attention HEAD_DIM_MAX (128)" if @head_dim > 128
  @rope_base = rope_base
  # Precompute the (optionally scaled) inv_freq table once; reused every layer/step.
  @inv_freq = Transformer.compute_inv_freq(@head_dim, rope_base, rope_scaling)
  q_out = num_heads * @head_dim
  kv_out = num_kv_heads * @head_dim
  @q_proj = register_module("q_proj", NN::Linear.new(embed_dim, q_out, bias: bias, device_id: device_id))
  @k_proj = register_module("k_proj", NN::Linear.new(embed_dim, kv_out, bias: bias, device_id: device_id))
  @v_proj = register_module("v_proj", NN::Linear.new(embed_dim, kv_out, bias: bias, device_id: device_id))
  @o_proj = register_module("o_proj", NN::Linear.new(q_out, embed_dim, bias: bias, device_id: device_id))
end
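The constructor delegates the frequency table to Transformer.compute_inv_freq, whose body is not shown here. As a rough illustration only, the sketch below assumes the standard unscaled RoPE schedule, inv_freq[i] = 1 / base^(2i / head_dim), which corresponds to rope_scaling: nil. The name sketch_inv_freq is hypothetical, and the library's real function may differ, particularly when scaling is enabled.

```ruby
# Hypothetical stand-in for Transformer.compute_inv_freq with rope_scaling: nil.
# Assumes the standard RoPE schedule: inv_freq[i] = 1 / base^(2i / head_dim).
def sketch_inv_freq(head_dim, rope_base = 10000.0)
  # Each frequency rotates one pair of channels, so head_dim must be even.
  raise ArgumentError, "head_dim must be even for RoPE (got #{head_dim})" unless head_dim.even?

  (0...head_dim / 2).map { |i| 1.0 / (rope_base**(2.0 * i / head_dim)) }
end

head_dim = 4096 / 32 # embed_dim / num_heads, the default when head_dim is nil
inv_freq = sketch_inv_freq(head_dim)
puts "entries: #{inv_freq.length}"
puts "first:   #{inv_freq.first}"
puts "last:    #{inv_freq.last}"
```

The table has head_dim / 2 entries because each frequency drives one two-channel rotation. This is also why the constructor rejects an odd head_dim up front.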
Instance Method Details
#forward(x, pos_offset: 0) ⇒ Tensor
Returns [seq, embed_dim].
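The shapes implied by the four projections can be traced without any tensor library. The helper forward_shapes below is hypothetical and only does the arithmetic. It assumes the input x is [seq, embed_dim] and that the attention output keeps the query width, which the o_proj input size (q_out) suggests.

```ruby
# Hypothetical helper (not part of the library): traces tensor shapes through
# forward using only the sizes passed to the constructor.
def forward_shapes(seq:, embed_dim:, num_heads:, num_kv_heads:, head_dim: nil)
  head_dim ||= embed_dim / num_heads
  q_out = num_heads * head_dim     # width of q_proj output
  kv_out = num_kv_heads * head_dim # width of k_proj and v_proj output
  {
    x: [seq, embed_dim],
    q: [seq, q_out],
    k: [seq, kv_out],
    v: [seq, kv_out],
    ctx: [seq, q_out],     # assumed: attention output keeps the query width
    out: [seq, embed_dim]  # o_proj maps q_out back to embed_dim
  }
end

shapes = forward_shapes(seq: 16, embed_dim: 4096, num_heads: 32, num_kv_heads: 8)
shapes.each { |name, shape| puts "#{name}: #{shape.inspect}" }
```

In this configuration the K and V projections are a quarter of the width of Q, which is where GQA saves parameters and KV memory.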
# File 'lib/nnw/ai/transformer/modern.rb', line 82

def forward(x, pos_offset: 0)
  # RoPE is applied to Q and K (not V); q has num_heads, k has num_kv_heads.
  q = @q_proj.call(x).rope(num_heads: @num_heads, pos_offset: pos_offset, inv_freq: @inv_freq)
  k = @k_proj.call(x).rope(num_heads: @num_kv_heads, pos_offset: pos_offset, inv_freq: @inv_freq)
  v = @v_proj.call(x)
  ctx = q.sdpa(k, v, num_heads: @num_heads, num_kv_heads: @num_kv_heads, causal: true)
  @o_proj.call(ctx)
end
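The following plain-Ruby sketch illustrates why rotating Q and K by position is enough, and what a position offset does to the scores. It is not the library's rope kernel. The helpers rope_rotate and dot are hypothetical, and the sketch assumes the interleaved pairing convention (channels 2i and 2i+1 rotate together), whereas some implementations pair channel i with i + head_dim / 2 instead. It shows that the Q·K score depends only on the relative distance between positions, so shifting both positions by the same offset leaves it unchanged.

```ruby
# Hypothetical plain-Ruby RoPE for a single head vector. Assumes interleaved
# pairing: channels (2i, 2i+1) are rotated by angle = position * inv_freq[i].
def rope_rotate(vec, position, inv_freq)
  vec.each_slice(2).with_index.flat_map do |(a, b), i|
    angle = position * inv_freq[i]
    c = Math.cos(angle)
    s = Math.sin(angle)
    [a * c - b * s, a * s + b * c] # standard 2-D rotation of the pair
  end
end

def dot(u, v)
  u.zip(v).sum { |a, b| a * b }
end

head_dim = 4
inv_freq = (0...head_dim / 2).map { |i| 1.0 / (10000.0**(2.0 * i / head_dim)) }
q = [1.0, 0.5, -0.25, 2.0]
k = [0.3, -1.0, 0.75, 0.1]

# Query at position 5 attending to key at position 2 (distance 3) ...
score_base = dot(rope_rotate(q, 5, inv_freq), rope_rotate(k, 2, inv_freq))
# ... and the same pair shifted by an offset of 100 (still distance 3).
score_shifted = dot(rope_rotate(q, 105, inv_freq), rope_rotate(k, 102, inv_freq))

puts (score_base - score_shifted).abs < 1e-9
```

V is left unrotated because position only needs to influence the attention weights, not the values being mixed.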
#to_s ⇒ String
# File 'lib/nnw/ai/transformer/modern.rb', line 92

def to_s
  "RopeGqaAttention(embed=#{@embed_dim}, q_heads=#{@num_heads}, kv_heads=#{@num_kv_heads}, head_dim=#{@head_dim})"
end