Class: Ignis::AI::Transformer::SwiGLU

Inherits:
NN::Module < Object
Defined in:
lib/nnw/ai/transformer/swiglu.rb

Overview

SwiGLU feed-forward (Llama / Qwen / SmolLM / Mistral):

down( silu(gate(x)) ⊙ up(x) )

Two input projections (gate and up) map embed_dim to the hidden width ff_dim, and one output projection (down) maps back to embed_dim. silu(z) = z·sigmoid(z). Llama-style models use no bias, which is the default here. Because this is a pure composition of verified ops (Linear, silu, elementwise mul), autograd derives the backward pass automatically, so no custom gradient is needed.
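To make the formula concrete, here is a minimal plain-Ruby sketch of the same computation for a single input vector. It does not use the Ignis tensor API. The weights are hypothetical nested Arrays laid out as [out_features][in_features], and the helpers `sigmoid`, `silu`, `linear` and `swiglu` are illustrative names only.

```ruby
# Plain-Ruby sketch of down(silu(gate(x)) ⊙ up(x)) for one vector.
# Weights are hypothetical nested Arrays, [out_features][in_features];
# this is NOT the NN::Linear / Tensor API, only the arithmetic.

def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

def silu(z)
  z * sigmoid(z)
end

# y = W x, no bias (the Llama-style default)
def linear(weight, x)
  weight.map { |row| row.zip(x).sum { |w, xi| w * xi } }
end

def swiglu(x, w_gate, w_up, w_down)
  gate   = linear(w_gate, x).map { |g| silu(g) } # silu(gate(x)), length ff_dim
  up     = linear(w_up, x)                       # up(x), length ff_dim
  hidden = gate.zip(up).map { |g, u| g * u }     # elementwise product
  linear(w_down, hidden)                         # back to length embed_dim
end

# Toy sizes: embed_dim = 2, ff_dim = 3
x      = [1.0, -2.0]
w_gate = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w_up   = [[1.0, 1.0], [2.0, 0.0], [0.0, 1.0]]
w_down = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]

y = swiglu(x, w_gate, w_up, w_down)
puts y.inspect
```

The gate branch alone passes through the nonlinearity, and the up branch scales it linearly. That multiplicative gating is what distinguishes SwiGLU from a plain two-layer MLP.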

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, ff_dim, bias: false, device_id: 0) ⇒ SwiGLU

Returns a new instance of SwiGLU.

Parameters:

  • embed_dim (Integer)

    model (embedding) dim; width of the input and output

  • ff_dim (Integer)

    hidden (intermediate) dim; Llama uses about 8/3·embed_dim, rounded up to a fixed multiple (256 in the original Llama)

  • bias (Boolean) (defaults to: false)

    include projection biases (Llama/Qwen: false)

  • device_id (Integer) (defaults to: 0)

    device on which the gate, up and down projections are created
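SwiGLU takes ff_dim from the caller and does not derive it. For reference, this sketch shows how the original Llama recipe arrives at its width. It takes 2/3 of a conventional 4× MLP, which is about 8/3·embed_dim, and rounds up to a multiple of 256. The helper `llama_ff_dim` is hypothetical and not part of this library. Other models (Qwen, Mistral, SmolLM) choose their widths differently.

```ruby
# Hypothetical helper: Llama-style SwiGLU hidden width.
# 2/3 of a 4x MLP (about 8/3 * embed_dim), rounded UP to a multiple of `multiple_of`.
def llama_ff_dim(embed_dim, multiple_of: 256)
  hidden = (2 * 4 * embed_dim) / 3                          # integer division truncates
  multiple_of * ((hidden + multiple_of - 1) / multiple_of)  # ceil to multiple_of
end

puts llama_ff_dim(4096) # => 11008
puts llama_ff_dim(5120) # => 13824
```

The 2/3 factor keeps the parameter count close to a standard 4× MLP, because SwiGLU has three projections instead of two.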


# File 'lib/nnw/ai/transformer/swiglu.rb', line 17

def initialize(embed_dim, ff_dim, bias: false, device_id: 0)
  super()
  @embed_dim = embed_dim
  @ff_dim = ff_dim
  @gate = register_module("gate", NN::Linear.new(embed_dim, ff_dim, bias: bias, device_id: device_id))
  @up   = register_module("up",   NN::Linear.new(embed_dim, ff_dim, bias: bias, device_id: device_id))
  @down = register_module("down", NN::Linear.new(ff_dim, embed_dim, bias: bias, device_id: device_id))
end
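The constructor registers three Linear layers, so the inherited #num_parameters should come to about 3·embed_dim·ff_dim. The sketch below computes that figure. It assumes each NN::Linear(in, out) holds an out×in weight and, with bias: true, an out-sized bias vector. `swiglu_param_count` is a hypothetical helper and not part of the library.

```ruby
# Hypothetical helper: expected parameter count of a SwiGLU block,
# assuming Linear(in, out) = out*in weight (+ out bias when bias: true).
def swiglu_param_count(embed_dim, ff_dim, bias: false)
  weights = 2 * embed_dim * ff_dim + ff_dim * embed_dim # gate + up, then down
  biases  = bias ? (2 * ff_dim + embed_dim) : 0         # gate/up: ff_dim each; down: embed_dim
  weights + biases
end

puts swiglu_param_count(4096, 11008) # => 135266304
```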

Instance Method Details

#forward(x) ⇒ Tensor

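Applies the SwiGLU feed-forward to x. Leading dimensions (*) pass through unchanged, so the result has shape [*, embed_dim].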

Parameters:

  • x (Tensor)
    shape [*, embed_dim]

Returns:

  • (Tensor)
    shape [*, embed_dim]


# File 'lib/nnw/ai/transformer/swiglu.rb', line 28

def forward(x)
  @down.call(@gate.call(x).silu * @up.call(x))
end
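The backward pass comes from autograd. It is still useful to know which local derivative flows through the gate branch. With h = silu(g)·u, the gradients are ∂h/∂u = silu(g) and ∂h/∂g = u·silu'(g), where silu'(z) = sigmoid(z)·(1 + z·(1 − sigmoid(z))). The following plain-Ruby sketch, which is independent of the Ignis autograd, checks that closed form against a central finite difference.

```ruby
# Finite-difference check of the silu derivative used in the gate branch:
#   silu'(z) = sigmoid(z) * (1 + z * (1 - sigmoid(z)))

def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

def silu(z)
  z * sigmoid(z)
end

def silu_grad(z)
  s = sigmoid(z)
  s * (1.0 + z * (1.0 - s))
end

eps = 1e-6
[-3.0, -0.5, 0.0, 1.0, 4.0].each do |z|
  numeric = (silu(z + eps) - silu(z - eps)) / (2 * eps) # central difference
  printf("z=%5.2f analytic=%.6f numeric=%.6f\n", z, silu_grad(z), numeric)
end
```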

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/swiglu.rb', line 33

def to_s
  "SwiGLU(embed=#{@embed_dim}, ff=#{@ff_dim})"
end