Class: Ignis::AI::Transformer::RopeGqaAttention

Inherits:
NN::Module
  • Object
Defined in:
lib/nnw/ai/transformer/modern.rb

Overview

Causal self-attention with rotary position embeddings (RoPE) and grouped-query attention (GQA). The Q/K/V/O projections carry no bias by default (Llama/Qwen convention); pass bias: true to enable it.
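
As a rough illustration of what grouped-query attention means for head sharing, the sketch below maps each query head to the key/value head it would read from. The helper kv_head_for is hypothetical and not part of this library, and the contiguous grouping it uses is the common convention, assumed here rather than confirmed for the sdpa kernel.

```ruby
# Hypothetical helper (not part of the library): maps a query head index to
# the key/value head it shares under grouped-query attention.
# Assumes contiguous groups, i.e. consecutive query heads share one KV head.
def kv_head_for(query_head, num_heads:, num_kv_heads:)
  unless (num_heads % num_kv_heads).zero?
    raise ArgumentError, "num_heads must be a multiple of num_kv_heads"
  end
  group_size = num_heads / num_kv_heads # query heads per KV head
  query_head / group_size               # integer division selects the group
end

# 8 query heads sharing 2 KV heads.
mapping = (0...8).map { |h| kv_head_for(h, num_heads: 8, num_kv_heads: 2) }
p mapping
```

With num_kv_heads equal to num_heads the group size is 1 and every query head gets its own KV head, which is plain multi-head attention.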

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, num_heads, num_kv_heads:, head_dim: nil, rope_base: 10000.0, rope_scaling: nil, bias: false, device_id: 0) ⇒ RopeGqaAttention

Returns a new instance of RopeGqaAttention.

Parameters:

  • embed_dim (Integer)
  • num_heads (Integer)

    query heads

  • num_kv_heads (Integer)

    key/value heads (== num_heads ⇒ plain MHA)

  • head_dim (Integer, nil) (defaults to: nil)

    per-head dim (default embed_dim/num_heads)

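  • rope_base (Float) (defaults to: 10000.0)

    RoPE theta

  • rope_scaling (defaults to: nil)

    optional RoPE scaling configuration, passed to Transformer.compute_inv_freq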

  • bias (Boolean) (defaults to: false)
  • device_id (Integer) (defaults to: 0)

Raises:

  • (ArgumentError)

    if num_heads is not a multiple of num_kv_heads, if head_dim is odd, or if head_dim exceeds 128


# File 'lib/nnw/ai/transformer/modern.rb', line 56

def initialize(embed_dim, num_heads, num_kv_heads:, head_dim: nil,
               rope_base: 10000.0, rope_scaling: nil, bias: false, device_id: 0)
  super()
  raise ArgumentError, "num_heads must be a multiple of num_kv_heads" unless (num_heads % num_kv_heads).zero?
  @embed_dim = embed_dim
  @num_heads = num_heads
  @num_kv_heads = num_kv_heads
  @head_dim = head_dim || (embed_dim / num_heads)
  # Fail early (at construction) rather than silently miscompute later:
  # RoPE needs an even head_dim; the flash kernels cap head_dim at 128.
  raise ArgumentError, "head_dim must be even for RoPE (got #{@head_dim})" unless @head_dim.even?
  raise ArgumentError, "head_dim #{@head_dim} exceeds flash-attention HEAD_DIM_MAX (128)" if @head_dim > 128
  @rope_base = rope_base
  # Precompute the (optionally scaled) inv_freq table once; reused every layer/step.
  @inv_freq = Transformer.compute_inv_freq(@head_dim, rope_base, rope_scaling)
  q_out  = num_heads * @head_dim
  kv_out = num_kv_heads * @head_dim
  @q_proj = register_module("q_proj", NN::Linear.new(embed_dim, q_out, bias: bias, device_id: device_id))
  @k_proj = register_module("k_proj", NN::Linear.new(embed_dim, kv_out, bias: bias, device_id: device_id))
  @v_proj = register_module("v_proj", NN::Linear.new(embed_dim, kv_out, bias: bias, device_id: device_id))
  @o_proj = register_module("o_proj", NN::Linear.new(q_out, embed_dim, bias: bias, device_id: device_id))
end
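
The constructor's checks and projection sizes can be reproduced without the library. The following is a minimal sketch, assuming only the arithmetic visible above; projection_dims is a hypothetical helper introduced for illustration, not a library method.

```ruby
# Hypothetical helper mirroring the constructor's validation and arithmetic.
# It builds no layers; it only reports the sizes the projections would have.
def projection_dims(embed_dim, num_heads, num_kv_heads, head_dim: nil)
  unless (num_heads % num_kv_heads).zero?
    raise ArgumentError, "num_heads must be a multiple of num_kv_heads"
  end
  head_dim ||= embed_dim / num_heads
  raise ArgumentError, "head_dim must be even for RoPE (got #{head_dim})" unless head_dim.even?
  raise ArgumentError, "head_dim #{head_dim} exceeds 128" if head_dim > 128

  {
    head_dim: head_dim,
    q_out: num_heads * head_dim,     # output width of q_proj, input width of o_proj
    kv_out: num_kv_heads * head_dim  # output width of k_proj and v_proj
  }
end

# 32 query heads and 8 KV heads on a 4096-wide model.
dims = projection_dims(4096, 32, 8)
p dims
```

In this configuration the K and V projections are a quarter of the width of the Q projection, which is where GQA saves parameters and KV-cache memory.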

Instance Method Details

#forward(x, pos_offset: 0) ⇒ Tensor

Returns [seq, embed_dim].

Parameters:

  • x (Tensor)
    [seq, embed_dim]
  • pos_offset (Integer) (defaults to: 0)

    absolute position of row 0 (for KV-cache decode)

Returns:

  • (Tensor)
    [seq, embed_dim]


# File 'lib/nnw/ai/transformer/modern.rb', line 82

def forward(x, pos_offset: 0)
  # RoPE is applied to Q and K (not V); q has num_heads, k has num_kv_heads.
  q = @q_proj.call(x).rope(num_heads: @num_heads, pos_offset: pos_offset, inv_freq: @inv_freq)
  k = @k_proj.call(x).rope(num_heads: @num_kv_heads, pos_offset: pos_offset, inv_freq: @inv_freq)
  v = @v_proj.call(x)
  ctx = q.sdpa(k, v, num_heads: @num_heads, num_kv_heads: @num_kv_heads, causal: true)
  @o_proj.call(ctx)
end
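
To show why pos_offset matters during KV-cache decode, here is a pure-Ruby sketch of rotary embedding on a single head vector. It assumes the standard unscaled frequencies base^(-2i/head_dim) and an interleaved pairing of dimensions; the library's Tensor#rope and Transformer.compute_inv_freq may use a different layout or scaling, and the helpers inv_freq, rope, and dot are hypothetical names.

```ruby
# Unscaled inverse frequencies: base^(-2i/head_dim) for i in 0...head_dim/2.
def inv_freq(head_dim, base = 10000.0)
  (0...head_dim / 2).map { |i| 1.0 / (base**(2.0 * i / head_dim)) }
end

# Rotate one head vector at absolute position `pos`.
# Assumption: dimensions are paired as (0,1), (2,3), ... (interleaved layout).
def rope(vec, pos, freqs)
  vec.each_slice(2).each_with_index.flat_map do |(a, b), i|
    angle = pos * freqs[i]
    c = Math.cos(angle)
    s = Math.sin(angle)
    [a * c - b * s, a * s + b * c]
  end
end

def dot(u, v)
  u.zip(v).sum { |a, b| a * b }
end

freqs = inv_freq(4)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative distance (2) at different absolute positions.
near = dot(rope(q, 5, freqs), rope(k, 3, freqs))
far  = dot(rope(q, 12, freqs), rope(k, 10, freqs))
puts (near - far).abs < 1e-9
```

The attention score depends only on the distance between the query and key positions, so a decode step must rotate its new row at its true absolute position. That is the role of pos_offset: row 0 of x is treated as position pos_offset rather than position 0.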

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/modern.rb', line 92

def to_s
  "RopeGqaAttention(embed=#{@embed_dim}, q_heads=#{@num_heads}, kv_heads=#{@num_kv_heads}, head_dim=#{@head_dim})"
end