Class: Ignis::AI::Transformer::ModernBlock

Inherits:
NN::Module
  • Object
Defined in:
lib/nnw/ai/transformer/modern.rb

Overview

Llama/Qwen-style pre-norm block with residual connections: x = x + attn(rmsnorm(x)); x = x + swiglu(rmsnorm(x)).

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, num_heads, num_kv_heads:, ff_dim:, rope_base: 10000.0, rope_scaling: nil, head_dim: nil, eps: 1e-6, device_id: 0) ⇒ ModernBlock

Returns a new instance of ModernBlock.

Parameters:

  • embed_dim (Integer)
  • num_heads (Integer)
  • num_kv_heads (Integer)
  • ff_dim (Integer)
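  • rope_base (Float) (defaults to: 10000.0)
  • rope_scaling (defaults to: nil)

    forwarded unchanged to RopeGqaAttention

  • head_dim (defaults to: nil)

    forwarded unchanged to RopeGqaAttention
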
  • eps (Float) (defaults to: 1e-6)

    RMSNorm epsilon

  • device_id (Integer) (defaults to: 0)
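
For orientation, rope_base is conventionally the base of the geometric frequency schedule used by rotary position embeddings. The sketch below shows that standard schedule in plain Ruby. `rope_inv_freqs` and `rope_angles` are hypothetical helpers written for illustration, and it is an assumption (not confirmed by this page) that RopeGqaAttention follows the same convention.

```ruby
# Hypothetical helpers, not part of the library: the conventional RoPE
# frequency schedule, inv_freq[i] = base ** (-2i / head_dim).
def rope_inv_freqs(head_dim, base: 10000.0)
  raise ArgumentError, "head_dim must be even" unless head_dim.even?

  (0...head_dim / 2).map { |i| base**(-2.0 * i / head_dim) }
end

# Rotation angle applied to each pair of dimensions at a token position.
def rope_angles(position, head_dim, base: 10000.0)
  rope_inv_freqs(head_dim, base: base).map { |f| position * f }
end

freqs = rope_inv_freqs(8)
puts freqs.inspect
```

A larger base stretches the schedule toward lower frequencies, which is why long-context variants of this block style often raise it.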


# File 'lib/nnw/ai/transformer/modern.rb', line 106

def initialize(embed_dim, num_heads, num_kv_heads:, ff_dim:,
               rope_base: 10000.0, rope_scaling: nil, head_dim: nil, eps: 1e-6, device_id: 0)
  super()
  @attn  = register_module("attn",
            RopeGqaAttention.new(embed_dim, num_heads, num_kv_heads: num_kv_heads,
                                 head_dim: head_dim, rope_base: rope_base,
                                 rope_scaling: rope_scaling, device_id: device_id))
  @mlp   = register_module("mlp", SwiGLU.new(embed_dim, ff_dim, device_id: device_id))
  @norm1 = register_module("norm1", NN::RMSNorm.new(embed_dim, eps: eps, device_id: device_id))
  @norm2 = register_module("norm2", NN::RMSNorm.new(embed_dim, eps: eps, device_id: device_id))
end
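
The constructor passes num_heads, num_kv_heads and head_dim straight to RopeGqaAttention. In grouped-query attention, several query heads share one key/value head. The sketch below works out the resulting sizes under two common conventions that this page does not confirm: num_heads is a multiple of num_kv_heads, and a nil head_dim falls back to embed_dim / num_heads. `gqa_layout` is a hypothetical helper, not a library method.

```ruby
# Hypothetical helper, not part of the library. Assumes the common GQA
# conventions: num_heads is a multiple of num_kv_heads, and a nil head_dim
# falls back to embed_dim / num_heads.
def gqa_layout(embed_dim, num_heads, num_kv_heads:, head_dim: nil)
  unless (num_heads % num_kv_heads).zero?
    raise ArgumentError, "num_heads must be a multiple of num_kv_heads"
  end

  head_dim ||= embed_dim / num_heads
  {
    head_dim: head_dim,
    queries_per_kv: num_heads / num_kv_heads, # query heads sharing one K/V head
    q_width: num_heads * head_dim,            # width of the query projection
    kv_width: num_kv_heads * head_dim         # width of each key/value projection
  }
end

layout = gqa_layout(4096, 32, num_kv_heads: 8)
puts layout
```

With num_kv_heads equal to num_heads this reduces to ordinary multi-head attention, and with num_kv_heads equal to 1 it becomes multi-query attention.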

Instance Method Details

#forward(x) ⇒ Tensor

Parameters:

  • x (Tensor)
    input of shape [seq, embed]

Returns:

  • (Tensor)


# File 'lib/nnw/ai/transformer/modern.rb', line 120

def forward(x)
  x = x + @attn.call(@norm1.call(x))
  x + @mlp.call(@norm2.call(x))
end
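
The data flow of forward can be traced without the library. The following is a minimal plain-Ruby sketch, assuming nested Arrays as stand-in tensors, an RMSNorm without the learned gain, and caller-supplied lambdas in place of the attention and SwiGLU sublayers. All names here (`rms_norm`, `add`, `pre_norm_block`) are hypothetical and exist only for illustration.

```ruby
# Plain-Ruby stand-ins for illustration only; none of this is library code.
# A "tensor" here is an Array of rows, each row an Array of Floats.

# RMSNorm without a learned gain: divide each row by its root mean square.
def rms_norm(x, eps: 1e-6)
  x.map do |row|
    rms = Math.sqrt(row.sum { |v| v * v } / row.size + eps)
    row.map { |v| v / rms }
  end
end

# Element-wise sum of two equally shaped tensors (the residual connection).
def add(a, b)
  a.zip(b).map { |ra, rb| ra.zip(rb).map { |va, vb| va + vb } }
end

# Same data flow as the forward method above, with sublayers passed as callables.
def pre_norm_block(x, attn:, mlp:, eps: 1e-6)
  x = add(x, attn.call(rms_norm(x, eps: eps)))
  add(x, mlp.call(rms_norm(x, eps: eps)))
end

# Toy sublayers: scale every element by a constant.
scale = ->(k) { ->(t) { t.map { |row| row.map { |v| v * k } } } }

x = [[3.0, 4.0], [1.0, 1.0]]
y = pre_norm_block(x, attn: scale.call(0.0), mlp: scale.call(0.0))
puts y.inspect
```

With both toy sublayers returning zeros, the residual path carries the input through unchanged, which is the property that makes pre-norm blocks easy to stack. The output always has the same shape as the input, since each sublayer result is added back onto x.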

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/modern.rb', line 126

def to_s
  "ModernBlock(#{@attn}, #{@mlp})"
end