Class: Ignis::AI::Transformer::MultiHeadAttention

Inherits:
NN::Module
  • Object
Defined in:
lib/nnw/ai/transformer/attention.rb

Overview

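Multi-head scaled dot-product attention with optional causal masking (enabled by default via causal: true). Provides a full-sequence #forward pass and an incremental #decode_step that reuses a KV cache.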

Examples:

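attn = MultiHeadAttention.new(768, 12)                 # causal: true by default
out = attn.call(x, x, x)                               # causal self-attention
enc = MultiHeadAttention.new(768, 12, causal: false)   # bidirectional variant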
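To make "causal masking" concrete, the sketch below builds the additive mask that causal attention applies conceptually (0 where query i may see key j, that is j <= i, and -Infinity otherwise). It then shows that adding a mask row to raw scores before a softmax gives future positions zero weight. It works on plain Ruby arrays. toy_causal_mask and masked_softmax are hypothetical helpers, not part of this library, which applies the mask inside its sdpa kernel instead.

```ruby
# Hypothetical helper: additive causal mask for a length-n sequence.
# Entry [i][j] is 0.0 when query i may attend to key j (j <= i), else -Infinity.
def toy_causal_mask(n)
  Array.new(n) { |i| Array.new(n) { |j| j <= i ? 0.0 : -Float::INFINITY } }
end

# Hypothetical helper: softmax over scores after adding one mask row.
# exp(-Infinity) is 0.0, so masked (future) positions get zero weight.
def masked_softmax(scores, mask_row)
  shifted = scores.zip(mask_row).map { |s, m| s + m }
  top = shifted.max # subtract the max for numerical stability
  exps = shifted.map { |x| Math.exp(x - top) }
  total = exps.sum
  exps.map { |e| e / total }
end

mask = toy_causal_mask(3)
mask.each { |row| puts row.map { |x| x.zero? ? "0" : "-inf" }.join(" ") }
# 0 -inf -inf
# 0 0 -inf
# 0 0 0

p masked_softmax([1.0, 2.0, 3.0], mask[0]) # => [1.0, 0.0, 0.0]
```

The first query can only see itself, so all of its weight lands on position 0 regardless of the later scores.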

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0) ⇒ MultiHeadAttention

Returns a new instance of MultiHeadAttention.

Parameters:

  • embed_dim (Integer)

    model dimension

  • num_heads (Integer)

    number of attention heads

  • dropout (Float) (defaults to: 0.0)

    attention dropout rate

  • bias (Boolean) (defaults to: true)

    whether projections have bias

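  • causal (Boolean) (defaults to: true)

    whether attention is causal (each position attends only to itself and earlier positions); pass false for bidirectional attention

  • device_id (Integer) (defaults to: 0)

    device index passed to each NN::Linear projection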

Raises:

  • (ArgumentError)

    if embed_dim is not divisible by num_heads


# File 'lib/nnw/ai/transformer/attention.rb', line 18

def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0)
  super()
  raise ArgumentError, "embed_dim must be divisible by num_heads" unless (embed_dim % num_heads).zero?

  @embed_dim = embed_dim
  @num_heads = num_heads
  @head_dim = embed_dim / num_heads  # width of each head's slice of embed_dim
  # Not read by #forward or #decode_step; the sdpa kernel applies 1/sqrt(head_dim) itself.
  @scale = 1.0 / Math.sqrt(@head_dim)
  @causal = causal                   # fixed at construction; #forward has no per-call override

  # Separate Q, K, V and output projections, each embed_dim -> embed_dim.
  @q_proj = register_module("q_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @k_proj = register_module("k_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @v_proj = register_module("v_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  @out_proj = register_module("out_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
  # Registered as a submodule; #forward and #decode_step as shown do not call it.
  @attn_dropout = register_module("attn_dropout", NN::Dropout.new(p: dropout))
end
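The constructor's shape arithmetic can be checked by hand. The sketch below mirrors its divisibility check and derived values for the example configuration (768, 12). It also counts projection parameters, assuming each NN::Linear holds an in x out weight plus an out-sized bias when bias: true. head_geometry and projection_param_count are hypothetical helpers, not library methods.

```ruby
# Hypothetical helper mirroring the constructor's validation and derived sizes.
def head_geometry(embed_dim, num_heads)
  unless (embed_dim % num_heads).zero?
    raise ArgumentError, "embed_dim must be divisible by num_heads"
  end

  head_dim = embed_dim / num_heads
  { head_dim: head_dim, scale: 1.0 / Math.sqrt(head_dim) }
end

# Hypothetical helper: parameters in the four embed_dim -> embed_dim projections
# (q, k, v, out), assuming a standard Linear layout. Dropout has no parameters.
def projection_param_count(embed_dim, bias: true)
  per_linear = embed_dim * embed_dim + (bias ? embed_dim : 0)
  4 * per_linear
end

g = head_geometry(768, 12)
puts g[:head_dim]                           # => 64
puts g[:scale]                              # => 0.125
puts projection_param_count(768)            # => 2362368
puts projection_param_count(768, bias: false) # => 2359296

begin
  head_geometry(768, 10)
rescue ArgumentError => e
  puts e.message # => embed_dim must be divisible by num_heads
end
```

With 12 heads over 768 dimensions each head works in a 64-wide subspace, so scores are scaled by 1/8.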

Instance Method Details

#decode_step(x, cache, layer) ⇒ Tensor

Incremental attention for one new token using a KV cache (the decode path). Projects only this token’s q/k/v, appends k/v to the cache for the given layer, then attends the single query over all cached keys and values. No causal mask is needed: the new token is the last position, so it may attend to every cached entry. Must run under no_grad. Raises RuntimeError if the module was built with causal: false.

Parameters:

  • x (Tensor)
    [1, embed]

    hidden state of the new token (post-norm)

  • cache (KVCache)

    key/value cache shared across layers; this token’s k/v are appended to it

  • layer (Integer)

    this block’s layer index

Returns:

  • (Tensor)
    [1, embed]


# File 'lib/nnw/ai/transformer/attention.rb', line 67

def decode_step(x, cache, layer)
  raise "decode_step requires a causal attention module" unless @causal

  # Project only the new token.
  q = @q_proj.call(x)
  k = @k_proj.call(x)
  v = @v_proj.call(x)

  # Append this token's k/v, then view everything cached so far for this layer.
  cache.append(layer, k.data, v.data)
  kview = Tensor.new(data: cache.k_view(layer), requires_grad: false)
  vview = Tensor.new(data: cache.v_view(layer), requires_grad: false)

  # Single query over all cached positions; no mask is needed.
  ctx = q.decode_sdpa(kview, vview, num_heads: @num_heads)
  @out_proj.call(ctx)
end
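The decode path can be modeled on plain arrays to see why it needs no mask. In the sketch below, ToyKVCache is a hypothetical stand-in for KVCache: only the append/k_view/v_view names are borrowed from decode_step, and the real class stores tensors. Projections are omitted and everything runs as a single head. Feeding tokens one at a time reproduces, row by row, a full causal pass over the same inputs.

```ruby
# Hypothetical stand-in for KVCache: per-layer lists of key and value rows.
class ToyKVCache
  def initialize(num_layers)
    @keys = Array.new(num_layers) { [] }
    @values = Array.new(num_layers) { [] }
  end

  def append(layer, k_row, v_row)
    @keys[layer] << k_row
    @values[layer] << v_row
  end

  def k_view(layer)
    @keys[layer]
  end

  def v_view(layer)
    @values[layer]
  end
end

# Single-query scaled dot-product attention; keys after max_pos are masked out.
def attend(q_row, keys, values, max_pos: keys.size - 1)
  scale = 1.0 / Math.sqrt(q_row.size)
  scores = keys.each_with_index.map do |k_row, j|
    j > max_pos ? -Float::INFINITY : q_row.zip(k_row).sum { |a, b| a * b } * scale
  end
  top = scores.max
  exps = scores.map { |s| Math.exp(s - top) }
  total = exps.sum
  (0...values.first.size).map do |c|
    exps.each_with_index.sum { |e, j| e / total * values[j][c] }
  end
end

# One decode step: cache this token's k/v, then attend with no mask,
# because the newest token is the last position and may see every entry.
def decode_step_sketch(q_row, k_row, v_row, cache, layer)
  cache.append(layer, k_row, v_row)
  attend(q_row, cache.k_view(layer), cache.v_view(layer))
end

q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
k = [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

cache = ToyKVCache.new(1)
incremental = q.each_index.map { |t| decode_step_sketch(q[t], k[t], v[t], cache, 0) }
# Full causal pass for comparison: row t may only see positions 0..t.
full = q.each_index.map { |t| attend(q[t], k, v, max_pos: t) }
```

For the newest position the causal mask hides nothing, which is why decode_step can skip it. Presumably this is also why it refuses to run on a non-causal module, where each token would need keys that have not been generated yet.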

#forward(query, key, value, mask: nil) ⇒ Tensor

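Forward pass: multi-head scaled dot-product attention over the projected query, key, and value. The current implementation assumes batch = 1 (projections are [seq, embed_dim]), and the mask argument is not used: the kernel applies causal masking whenever the module was built with causal: true.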

Parameters:

  • query (Tensor)
    [batch, seq_q, embed_dim]
  • key (Tensor)
    [batch, seq_k, embed_dim]
  • value (Tensor)
    [batch, seq_k, embed_dim]
  • mask (Tensor, nil) (defaults to: nil)

    ignored by the current implementation; masking is controlled by causal: at construction

Returns:

  • (Tensor)
    [batch, seq_q, embed_dim]


# File 'lib/nnw/ai/transformer/attention.rb', line 41

def forward(query, key, value, mask: nil)
  # Project Q, K, V → [seq, embed_dim] each (batch = 1)
  q = @q_proj.call(query)
  k = @k_proj.call(key)
  v = @v_proj.call(value)

  # Multi-head scaled dot-product attention. Each head attends over its own
  # [seq, head_dim] slice; the per-head Flash-Attention-2 kernel applies the
  # 1/sqrt(head_dim) scaling, the causal mask (when @causal is true), and a
  # numerically stable online softmax. The mask argument is not used here:
  # causality is fixed at construction (causal: false gives bidirectional
  # attention).
  context = q.sdpa(k, v, num_heads: @num_heads, causal: @causal)

  # Output projection
  @out_proj.call(context)
end
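What the call q.sdpa(k, v, num_heads: @num_heads, causal: @causal) computes, per the comments above, can be sketched on plain arrays. This version makes several assumptions:

  • batch = 1
  • heads are contiguous head_dim-wide column slices of embed_dim (the kernel's actual layout is not shown here)
  • seq_q == seq_k when causal

It uses an ordinary softmax rather than the Flash-Attention-2 online softmax, which is a memory-saving reformulation of the same result. multi_head_sdpa, softmax, and head_slice are hypothetical names.

```ruby
# Numerically stable softmax (subtract the max before exponentiating).
def softmax(scores)
  top = scores.max
  exps = scores.map { |s| Math.exp(s - top) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Columns [h * head_dim, (h + 1) * head_dim) of every row.
def head_slice(rows, h, head_dim)
  rows.map { |row| row[h * head_dim, head_dim] }
end

# Hypothetical sketch of multi-head SDPA on [seq][embed_dim] arrays (batch = 1).
def multi_head_sdpa(q, k, v, num_heads:, causal: true)
  embed_dim = q.first.size
  unless (embed_dim % num_heads).zero?
    raise ArgumentError, "embed_dim must be divisible by num_heads"
  end

  head_dim = embed_dim / num_heads
  scale = 1.0 / Math.sqrt(head_dim)

  per_head = (0...num_heads).map do |h|
    qh, kh, vh = [q, k, v].map { |m| head_slice(m, h, head_dim) }
    qh.each_with_index.map do |q_row, i|
      scores = kh.each_with_index.map do |k_row, j|
        # Causal: query i sees keys 0..i only (assumes seq_q == seq_k).
        causal && j > i ? -Float::INFINITY : q_row.zip(k_row).sum { |a, b| a * b } * scale
      end
      weights = softmax(scores)
      (0...head_dim).map { |c| weights.each_with_index.sum { |w, j| w * vh[j][c] } }
    end
  end

  # Concatenate the heads back into [seq][embed_dim] rows (out_proj's input).
  q.each_index.map { |i| per_head.flat_map { |head| head[i] } }
end

zeros = [[0.0] * 4, [0.0] * 4]
v = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]
p multi_head_sdpa(zeros, zeros, v, num_heads: 2, causal: false)
# => [[2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0]]
p multi_head_sdpa(zeros, zeros, v, num_heads: 2)
# => [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]
```

With all-zero queries every visible key scores the same. Bidirectional rows therefore average both value rows, while the first causal row can only copy value row 0.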

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/attention.rb', line 83

def to_s
  "MultiHeadAttention(embed=#{@embed_dim}, heads=#{@num_heads}, head_dim=#{@head_dim})"
end