Class: Toy::SwiGLU

Inherits:
Object
  • Object
show all
Defined in:
lib/toy.rb

Overview

Toy::SwiGLU — gated FFN used by Llama / SmolLM2 / Qwen2 / Phi.

gate(x) = x · W_gate       up(x) = x · W_up
y       = (silu(gate(x)) * up(x)) · W_down

Three linear layers, no bias by default (llama convention). Element-wise multiply between silu(gate) and up before the down projection.

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(d_model, d_ff) ⇒ SwiGLU

Returns a new instance of SwiGLU.



413
414
415
416
417
418
419
# File 'lib/toy.rb', line 413

# Stores both widths and allocates the three projection matrices.
#
# @param d_model [Object] model width, kept in @d_model
# @param d_ff [Object] feed-forward width, kept in @d_ff
# @return [SwiGLU] a new instance (via .new)
def initialize(d_model, d_ff)
  @d_model, @d_ff = d_model, d_ff

  # Matrices are created in gate, up, down order. Gate and up share the
  # (d_model, d_ff) arguments; down takes them swapped.
  @w_gate = Mat.new(@d_model, @d_ff)
  @w_up   = Mat.new(@d_model, @d_ff)
  @w_down = Mat.new(@d_ff, @d_model)
end

Instance Attribute Details

#d_ff ⇒ Object

Returns the value of attribute d_ff.



411
412
413
# File 'lib/toy.rb', line 411

# Reader for the feed-forward width given to the constructor.
#
# @return [Object] the value held in @d_ff
def d_ff; @d_ff; end

#d_model ⇒ Object

Returns the value of attribute d_model.



411
412
413
# File 'lib/toy.rb', line 411

# Reader for the model width given to the constructor.
#
# @return [Object] the value held in @d_model
def d_model; @d_model; end

#w_down ⇒ Object

Returns the value of attribute w_down.



411
412
413
# File 'lib/toy.rb', line 411

# Reader for the down-projection weights (the matrix built from
# (d_ff, d_model) in the constructor).
#
# @return [Object] the value held in @w_down
def w_down; @w_down; end

#w_gate ⇒ Object

Returns the value of attribute w_gate.



411
412
413
# File 'lib/toy.rb', line 411

# Reader for the gate-projection weights (the matrix built from
# (d_model, d_ff) in the constructor).
#
# @return [Object] the value held in @w_gate
def w_gate; @w_gate; end

#w_up ⇒ Object

Returns the value of attribute w_up.



411
412
413
# File 'lib/toy.rb', line 411

# Reader for the up-projection weights (the matrix built from
# (d_model, d_ff) in the constructor).
#
# @return [Object] the value held in @w_up
def w_up; @w_up; end

Instance Method Details

#algorithm ⇒ Object



438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
# File 'lib/toy.rb', line 438

# Builds a Toy::Card describing the forward pass step by step.
#
# The card lists the input/output, the two hyper-parameters (rendered from
# @d_model and @d_ff), the three weight matrices, and the four binding
# steps that end in returning y.
#
# @return [Toy::Card] the populated card
def algorithm
  card = Toy::Card.new("SwiGLU.forward(x)", "Llama-family MLP")

  # Signature: one input, one output, same shape.
  card.add_input("x", "R^{T×D}", "")
  card.add_output("y", "R^{T×D}", "")

  # Hyper-parameters are passed as strings.
  card.add_hyper("D", @d_model.to_s)
  card.add_hyper("D_f", @d_ff.to_s)

  # Learnable parameters.
  card.add_param("W_gate, W_up", "R^{D×D_f}", "")
  card.add_param("W_down", "R^{D_f×D}", "no biases — Llama convention")

  # Computation steps, in execution order.
  card.step_bind("g", "x · W_gate", "g ∈ R^{T×D_f}")
  card.step_bind("u", "x · W_up", "u ∈ R^{T×D_f}")
  card.step_bind("h", "silu(g) ⊙ u", "h ∈ R^{T×D_f}")
  card.step_bind("y", "h · W_down", "y ∈ R^{T×D}")
  card.step_return("y")

  card
end

#algorithm_card ⇒ Object



454
# File 'lib/toy.rb', line 454

# Renders the card built by #algorithm as pseudocode.
#
# @return [Object] whatever the card's render_pseudocode returns
def algorithm_card
  card = algorithm
  card.render_pseudocode
end

#forward(x) ⇒ Object

x: [T, D] → [T, D]



422
423
424
425
426
427
428
# File 'lib/toy.rb', line 422

# Forward pass: y = (silu(x · W_gate) * (x · W_up)) · W_down.
#
# @param x [Object] input responding to #matmul, shaped [T, D]
# @return [Object] result of the down projection, shaped [T, D]
def forward(x)
  hidden = x.matmul(@w_gate)           # [T, Df]  gate projection
  lifted = x.matmul(@w_up)             # [T, Df]  up projection

  # Both helpers are called for their effect on `hidden`; their return
  # values are not used.
  Toy.silu!(hidden)                    # hidden := silu(hidden)
  Toy.hadamard!(hidden, lifted)        # hidden := hidden * lifted

  hidden.matmul(@w_down)               # [T, D]
end

#param_count ⇒ Object



433
434
435
436
# File 'lib/toy.rb', line 433

# Total number of weights across the three projections.
#
# Each matrix is built from the (d_model, d_ff) pair, and there are no
# bias terms (llama convention), so the total is three times their product.
#
# @return [Numeric] 3 · d_model · d_ff
def param_count
  per_matrix = @d_model * @d_ff
  per_matrix * 3
end

#summary ⇒ Object



430
431
432
# File 'lib/toy.rb', line 430

# One-line description of the layer's two widths.
#
# @return [String] e.g. "SwiGLU(d=4, d_ff=8)"
def summary
  "SwiGLU(d=#{@d_model}, d_ff=#{@d_ff})"
end