Class: Toy::SwiGLU
- Inherits: Object
- Defined in: lib/toy.rb
Overview
Toy::SwiGLU: the gated feed-forward network (FFN) used in Llama, SmolLM2, Qwen2 and Phi models.
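$$\operatorname{gate}(x) = x \cdot W_{\text{gate}}, \qquad \operatorname{up}(x) = x \cdot W_{\text{up}}$$
$$y = \bigl(\operatorname{silu}(\operatorname{gate}(x)) \odot \operatorname{up}(x)\bigr) \cdot W_{\text{down}}$$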
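silu is not defined on this page. Assuming Toy.silu! follows the standard definition (SiLU, also called swish), it applies the following to every entry of gate(x), with σ the logistic sigmoid:

$$\operatorname{silu}(z) = z\,\sigma(z) = \frac{z}{1 + e^{-z}}$$

So silu(0) = 0, silu(z) approaches z for large positive z, and it approaches 0 for large negative z.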
Three linear projections, all without biases (Llama convention). The element-wise product of silu(gate(x)) and up(x) is taken before the down projection.
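The formula can be checked with a dependency-free Ruby sketch on plain arrays of rows (one row per token). It does not use Toy: matmul, silu and swiglu below are local helpers written for illustration, and the tiny weights are arbitrary.

```ruby
# Standard SiLU: z * sigmoid(z) = z / (1 + e^-z).
def silu(z)
  z / (1.0 + Math.exp(-z))
end

# Dense matrix product on arrays of rows: [m, k] x [k, n] -> [m, n].
def matmul(a, b)
  n = b.first.size
  a.map do |row|
    (0...n).map { |j| (0...row.size).sum { |k| row[k] * b[k][j] } }
  end
end

# y = (silu(x · W_gate) ⊙ (x · W_up)) · W_down, with no biases.
def swiglu(x, w_gate, w_up, w_down)
  g = matmul(x, w_gate)                                   # [T, D_f]
  u = matmul(x, w_up)                                     # [T, D_f]
  h = g.zip(u).map do |g_row, u_row|                      # [T, D_f]
    g_row.zip(u_row).map { |gv, uv| silu(gv) * uv }       # element-wise product
  end
  matmul(h, w_down)                                       # [T, D]
end

# Tiny example: T = 1, D = 2, D_f = 3.
x      = [[1.0, -1.0]]
w_gate = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]
w_up   = [[1.0, 0.0, 2.0], [0.0, 0.0, 0.0]]
w_down = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = swiglu(x, w_gate, w_up, w_down)
```

With these weights gate(x) = [1, -1, 0] and up(x) = [1, 0, 2], so only the first hidden unit survives the product and y ≈ [[0.731, 0.0]], i.e. [[silu(1), 0]].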
Instance Attribute Summary
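- #d_ff ⇒ Object
  Returns the value of attribute d_ff, the hidden (intermediate) width D_f.
- #d_model ⇒ Object
  Returns the value of attribute d_model, the input and output width D.
- #w_down ⇒ Object
  Returns the value of attribute w_down, the down-projection weight W_down ∈ R^{D_f×D}.
- #w_gate ⇒ Object
  Returns the value of attribute w_gate, the gate-projection weight W_gate ∈ R^{D×D_f}.
- #w_up ⇒ Object
  Returns the value of attribute w_up, the up-projection weight W_up ∈ R^{D×D_f}.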
Instance Method Summary
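- #algorithm ⇒ Object
  Builds a Toy::Card describing the forward pass: input, output, hyperparameters, weights and steps.
- #algorithm_card ⇒ Object
  Renders the card from #algorithm as pseudocode.
- #forward(x) ⇒ Object
  Applies the gated FFN: x: [T, D] → [T, D].
- #initialize(d_model, d_ff) ⇒ SwiGLU (constructor)
  A new instance of SwiGLU.
- #param_count ⇒ Object
  Number of weights, 3 · d_model · d_ff (no biases).
- #summary ⇒ Object
  One-line description such as "SwiGLU(d=4, d_ff=8)".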
Constructor Details
Instance Attribute Details
#d_ff ⇒ Object
Returns the value of attribute d_ff.
```ruby
# File 'lib/toy.rb', line 411
def d_ff
  @d_ff
end
```
#d_model ⇒ Object
Returns the value of attribute d_model.
```ruby
# File 'lib/toy.rb', line 411
def d_model
  @d_model
end
```
#w_down ⇒ Object
Returns the value of attribute w_down.
```ruby
# File 'lib/toy.rb', line 411
def w_down
  @w_down
end
```
#w_gate ⇒ Object
Returns the value of attribute w_gate.
```ruby
# File 'lib/toy.rb', line 411
def w_gate
  @w_gate
end
```
#w_up ⇒ Object
Returns the value of attribute w_up.
```ruby
# File 'lib/toy.rb', line 411
def w_up
  @w_up
end
```
Instance Method Details
#algorithm ⇒ Object
```ruby
# File 'lib/toy.rb', line 438
def algorithm
  c = Toy::Card.new("SwiGLU.forward(x)", "Llama-family MLP")
  c.add_input("x", "R^{T×D}", "")
  c.add_output("y", "R^{T×D}", "")
  c.add_hyper("D", @d_model.to_s)
  c.add_hyper("D_f", @d_ff.to_s)
  c.add_param("W_gate, W_up", "R^{D×D_f}", "")
  c.add_param("W_down", "R^{D_f×D}", "no biases — Llama convention")
  c.step_bind("g", "x · W_gate", "g ∈ R^{T×D_f}")
  c.step_bind("u", "x · W_up", "u ∈ R^{T×D_f}")
  c.step_bind("h", "silu(g) ⊙ u", "h ∈ R^{T×D_f}")
  c.step_bind("y", "h · W_down", "y ∈ R^{T×D}")
  c.step_return("y")
  c
end
```
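The card's add_param declarations pin down the weight shapes: W_gate and W_up are D×D_f, and W_down is D_f×D. A hedged sketch of a shape check against those declarations, assuming (hypothetically) that weights are stored as Ruby arrays of rows; the real Toy tensor type and its shape accessor are not documented on this page, so shape and check_swiglu_shapes are illustrative names only.

```ruby
# Hypothetical helper: [rows, cols] of a matrix stored as an array of rows.
def shape(m)
  [m.size, m.first.size]
end

# Compares each weight with the shape declared on the algorithm card:
#   W_gate, W_up ∈ R^{D×D_f};  W_down ∈ R^{D_f×D}
def check_swiglu_shapes(d_model, d_ff, w_gate, w_up, w_down)
  {
    w_gate: shape(w_gate) == [d_model, d_ff],
    w_up:   shape(w_up)   == [d_model, d_ff],
    w_down: shape(w_down) == [d_ff, d_model]
  }
end

zeros = ->(r, c) { Array.new(r) { Array.new(c, 0.0) } }

# D = 2, D_f = 3: correct shapes, then a W_down that was not transposed.
ok  = check_swiglu_shapes(2, 3, zeros.(2, 3), zeros.(2, 3), zeros.(3, 2))
bad = check_swiglu_shapes(2, 3, zeros.(2, 3), zeros.(2, 3), zeros.(2, 3))
```

Returning a per-weight hash rather than a single boolean makes it obvious which projection is mis-shaped.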
#algorithm_card ⇒ Object
```ruby
# File 'lib/toy.rb', line 454
def algorithm_card; algorithm.render_pseudocode; end
```
#forward(x) ⇒ Object
x: [T, D] → [T, D]
```ruby
# File 'lib/toy.rb', line 422
def forward(x)
  gate = x.matmul(@w_gate)    # [T, Df]
  up   = x.matmul(@w_up)      # [T, Df]
  Toy.silu!(gate)             # [T, Df]
  Toy.hadamard!(gate, up)     # [T, Df]  (gate := gate * up)
  gate.matmul(@w_down)        # [T, D]
end
```
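Toy.silu! and Toy.hadamard! are not documented on this page. Their bang names and the (gate := gate * up) comment suggest both overwrite their first argument, so forward appears to reuse the gate buffer for h instead of allocating a third [T, Df] matrix. A sketch of in-place helpers with that behavior, assuming arrays of rows; ToySketch is a hypothetical module, not part of Toy.

```ruby
# Hypothetical stand-ins for Toy.silu! and Toy.hadamard!.
module ToySketch
  # Replaces every entry z of m with silu(z) = z / (1 + e^-z), in place.
  def self.silu!(m)
    m.each { |row| row.map! { |z| z / (1.0 + Math.exp(-z)) } }
    m
  end

  # a[i][j] *= b[i][j] for every entry, in place; b is left unchanged.
  def self.hadamard!(a, b)
    a.each_with_index do |row, i|
      row.each_index { |j| row[j] *= b[i][j] }
    end
    a
  end
end

gate = [[0.0, 1.0], [-1.0, 2.0]]
up   = [[3.0, 2.0], [5.0, 0.5]]
ToySketch.silu!(gate)          # gate now holds silu(gate)
ToySketch.hadamard!(gate, up)  # gate now holds silu(gate) ⊙ up
```

After the two calls gate plays the role of h in the algorithm card, and up is untouched.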
#param_count ⇒ Object
```ruby
# File 'lib/toy.rb', line 433
def param_count
  # 3 × (d_model × d_ff) — no biases (llama convention).
  3 * @d_model * @d_ff
end
```
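As a worked example of the 3 · d_model · d_ff formula, take the Llama 2 7B MLP dimensions (hidden size 4096, intermediate size 11008, 32 decoder layers; these numbers come from the published model configuration, not from this page). The sketch below restates the formula as a standalone function rather than instantiating Toy::SwiGLU.

```ruby
# Weights in one bias-free SwiGLU block:
# W_gate and W_up are D × D_f, W_down is D_f × D.
def swiglu_param_count(d_model, d_ff)
  3 * d_model * d_ff
end

per_layer  = swiglu_param_count(4096, 11_008)  # => 135266304
all_layers = per_layer * 32                    # => 4328521728
```

So the MLP blocks alone account for roughly 4.3 billion of that model's weights.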
#summary ⇒ Object
```ruby
# File 'lib/toy.rb', line 430
def summary
  "SwiGLU(d=" + @d_model.to_s + ", d_ff=" + @d_ff.to_s + ")"
end
```