Class: Ignis::AI::Transformer::Block

Inherits:
NN::Module
Defined in:
lib/nnw/ai/transformer/block.rb

Overview

A single Transformer block: multi-head attention followed by a feed-forward network, each wrapped in a residual connection. Supports both pre-norm (GPT-2, LLaMA) and post-norm (original Transformer) variants.

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, num_heads, ff_dim, dropout: 0.0, pre_norm: true, activation: :gelu, device_id: 0) ⇒ Block

Returns a new instance of Block.

Parameters:

  • embed_dim (Integer)
  • num_heads (Integer)
  • ff_dim (Integer)
  • dropout (Float) (defaults to: 0.0)
  • pre_norm (Boolean) (defaults to: true)

    true selects pre-norm; false selects post-norm

  • activation (Symbol) (defaults to: :gelu)
  • device_id (Integer) (defaults to: 0)


# File 'lib/nnw/ai/transformer/block.rb', line 16

def initialize(embed_dim, num_heads, ff_dim, dropout: 0.0,
               pre_norm: true, activation: :gelu, device_id: 0)
  super()
  @pre_norm = pre_norm

  @attention = register_module("attention",
                MultiHeadAttention.new(embed_dim, num_heads, dropout: dropout, device_id: device_id))
  @feed_forward = register_module("feed_forward",
                   FeedForward.new(embed_dim, ff_dim, activation: activation,
                                   dropout: dropout, device_id: device_id))
  @norm1 = register_module("norm1", NN::LayerNorm.new(embed_dim, device_id: device_id))
  @norm2 = register_module("norm2", NN::LayerNorm.new(embed_dim, device_id: device_id))
  @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
end
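
The constructor assigns the return value of register_module to an instance variable, which implies that the call hands back the module it was given. A minimal sketch of how such a registry could work follows. ToyModule, ToyBlock, and named_modules are hypothetical names used only for illustration; this is not the NN::Module source.

```ruby
# Hypothetical miniature of a module registry; not the NN::Module source.
class ToyModule
  attr_reader :children

  def initialize
    @children = {}
  end

  # Stores the child under the given name and returns it, so the result
  # can be assigned to an instance variable in one expression.
  def register_module(name, child)
    @children[name] = child
    child
  end

  # Dotted names of every registered descendant, depth first.
  def named_modules(prefix = "")
    @children.flat_map do |name, child|
      path = prefix.empty? ? name : "#{prefix}.#{name}"
      [path] + child.named_modules(path)
    end
  end
end

# Registers two children the same way Block#initialize does.
class ToyBlock < ToyModule
  def initialize
    super()
    @attention = register_module("attention", ToyModule.new)
    @norm1 = register_module("norm1", ToyModule.new)
  end
end

names = ToyBlock.new.named_modules
```

Registering children by name in insertion order is what would let inherited helpers such as #parameters or #state_dict walk the block without knowing its instance variables.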

Instance Method Details

#decode_step(x, cache, layer) ⇒ Tensor

Incremental forward pass for one new token (decode path, pre-norm only). Mirrors #forward but routes attention through the KV cache and operates on a single [1, embed] row. Dropout is the identity in eval mode, so it is omitted.

Parameters:

  • x (Tensor)

    input [1, embed]

  • cache (KVCache)
  • layer (Integer)

    this block’s index

Returns:

  • (Tensor)

    [1, embed]

Raises:

  • (RuntimeError)

    if the block was constructed with pre_norm: false



# File 'lib/nnw/ai/transformer/block.rb', line 68

def decode_step(x, cache, layer)
  raise "Block#decode_step is only implemented for pre-norm blocks" unless @pre_norm

  residual = x
  h = @norm1.call(x)
  h = @attention.decode_step(h, cache, layer)
  x = residual + h

  residual = x
  h = @norm2.call(x)
  h = @feed_forward.call(h)
  residual + h
end
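
The idea behind routing attention through a KV cache can be sketched in plain Ruby. The example below is a toy under stated assumptions: one head, identity projections for queries, keys, and values, and no norms or residuals. ToyKVCache, attend, and toy_decode_step are hypothetical names; the real KVCache and MultiHeadAttention#decode_step may differ in layout and interface.

```ruby
# Hypothetical toy cache: one list of keys and one list of values per layer.
class ToyKVCache
  def initialize
    @keys = Hash.new { |h, layer| h[layer] = [] }
    @values = Hash.new { |h, layer| h[layer] = [] }
  end

  def append(layer, key, value)
    @keys[layer] << key
    @values[layer] << value
  end

  def keys(layer)
    @keys[layer]
  end

  def values(layer)
    @values[layer]
  end
end

def dot(a, b)
  a.zip(b).sum { |p, q| p * q }
end

# Numerically stable softmax over a list of scores.
def softmax(scores)
  max = scores.max
  exps = scores.map { |s| Math.exp(s - max) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Scaled dot-product attention of one query over the given keys and values.
def attend(query, keys, values)
  scale = Math.sqrt(query.size)
  weights = softmax(keys.map { |k| dot(query, k) / scale })
  values.transpose.map { |column| dot(weights, column) }
end

# One decode step: store the new token's key and value, then attend over
# everything cached so far. Earlier tokens are never recomputed.
def toy_decode_step(x, cache, layer)
  cache.append(layer, x, x) # identity projections: key == value == x
  attend(x, cache.keys(layer), cache.values(layer))
end

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = ToyKVCache.new
stepwise = tokens.map { |t| toy_decode_step(t, cache, 0) }

# Recomputing attention for the last token over the whole prefix gives the
# same row as the cached path.
full_last = attend(tokens.last, tokens, tokens)
```

Because each step only appends one key and one value, the per-token cost grows with the number of cached positions rather than requiring a full pass over the sequence.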

#forward(x, mask: nil) ⇒ Tensor

Forward pass.

Parameters:

  • x (Tensor)

    input [batch*seq, embed_dim]

  • mask (Tensor, nil) (defaults to: nil)

    attention mask

Returns:

  • (Tensor)


# File 'lib/nnw/ai/transformer/block.rb', line 35

def forward(x, mask: nil)
  if @pre_norm
    # Pre-norm (GPT-2 style): Norm → Attn → Residual, Norm → FF → Residual
    residual = x
    h = @norm1.call(x)
    h = @attention.call(h, h, h, mask: mask)
    h = @dropout.call(h)
    x = residual + h

    residual = x
    h = @norm2.call(x)
    h = @feed_forward.call(h)
    h = @dropout.call(h)
    residual + h
  else
    # Post-norm (original Transformer): Attn → Residual → Norm
    residual = x
    h = @attention.call(x, x, x, mask: mask)
    h = @dropout.call(h)
    x = @norm1.call(residual + h)

    residual = x
    h = @feed_forward.call(x)
    h = @dropout.call(h)
    @norm2.call(residual + h)
  end
end
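
The difference between the two branches is only where normalization sits relative to the residual sum. The sketch below shows that ordering on a single plain-Ruby vector. It assumes a layer norm without learned scale or shift, omits dropout, and uses two arbitrary lambdas as stand-ins for attention and feed-forward; none of these helpers are part of the Ignis API.

```ruby
# Illustrative stand-ins only; none of these helpers are part of the Ignis API.

# Layer normalization over one vector, without learned scale or shift.
def layer_norm(v, eps: 1e-5)
  mean = v.sum / v.size.to_f
  var = v.sum { |e| (e - mean)**2 } / v.size
  v.map { |e| (e - mean) / Math.sqrt(var + eps) }
end

# Element-wise sum, used for the residual connection.
def add(a, b)
  a.zip(b).map { |p, q| p + q }
end

# Pre-norm: x + sublayer(norm(x)), applied for each sublayer in turn.
def pre_norm_block(x, sublayers)
  sublayers.reduce(x) { |acc, f| add(acc, f.call(layer_norm(acc))) }
end

# Post-norm: norm(x + sublayer(x)), applied for each sublayer in turn.
def post_norm_block(x, sublayers)
  sublayers.reduce(x) { |acc, f| layer_norm(add(acc, f.call(acc))) }
end

attn = ->(v) { v.map { |e| 0.5 * e } } # hypothetical stand-in for attention
ff   = ->(v) { v.map { |e| e * e } }   # hypothetical stand-in for feed-forward

x = [1.0, 2.0, 3.0, 4.0]
pre  = pre_norm_block(x, [attn, ff])
post = post_norm_block(x, [attn, ff])
```

In the post-norm path the block output is always normalized, so its mean is approximately zero. In the pre-norm path the input passes through the residual stream unnormalized, which is why pre-norm models typically apply one extra normalization after the last block.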

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/block.rb', line 83

def to_s
  style = @pre_norm ? "pre-norm" : "post-norm"
  "Block(#{style}, attn=#{@attention}, ff=#{@feed_forward})"
end