Class: Toy::SmolLM2Block

Inherits:
Object
Defined in:
lib/toy/models/toy_smollm2.rb

Overview

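Llama-style transformer block: each sublayer (grouped-query attention, then a SwiGLU feed-forward network) is preceded by its own RMSNorm and wrapped in a residual connection.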

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(cfg, rope_obj) ⇒ SmolLM2Block

Returns a new instance of SmolLM2Block.



217
218
219
220
221
222
223
224
# File 'lib/toy/models/toy_smollm2.rb', line 217

# Builds the block's four sublayers from a model config.
#
# @param cfg config object; this method reads d_model, rms_eps, n_heads,
#   n_kv and d_ff from it
# @param rope_obj rotary-embedding helper handed to the attention sublayer
def initialize(cfg, rope_obj)
  # Both pre-norms are identical: width d_model, epsilon from the config.
  make_norm = lambda do
    Toy::RMSNorm.new(cfg.d_model).tap { |norm| norm.eps = cfg.rms_eps }
  end

  @rn1  = make_norm.call
  @rn2  = make_norm.call
  @attn = Toy::GQAttention.new(cfg.d_model, cfg.n_heads, cfg.n_kv, rope_obj)
  @ffn  = Toy::SwiGLU.new(cfg.d_model, cfg.d_ff)
end

Instance Attribute Details

#attn ⇒ Object

Returns the value of attribute attn.



215
216
217
# File 'lib/toy/models/toy_smollm2.rb', line 215

# Reader for the attention sublayer stored in @attn.
def attn; @attn; end

#ffn ⇒ Object

Returns the value of attribute ffn.



215
216
217
# File 'lib/toy/models/toy_smollm2.rb', line 215

# Reader for the feed-forward sublayer stored in @ffn.
def ffn; @ffn; end

#rn1 ⇒ Object

Returns the value of attribute rn1.



215
216
217
# File 'lib/toy/models/toy_smollm2.rb', line 215

# Reader for the pre-attention norm stored in @rn1.
def rn1; @rn1; end

#rn2 ⇒ Object

Returns the value of attribute rn2.



215
216
217
# File 'lib/toy/models/toy_smollm2.rb', line 215

# Reader for the pre-FFN norm stored in @rn2.
def rn2; @rn2; end

Instance Method Details

#algorithm ⇒ Object

Builds a Toy::Card listing the inputs, output and update steps of #forward.



238
239
240
241
242
243
244
245
246
247
248
249
# File 'lib/toy/models/toy_smollm2.rb', line 238

# Describes #forward as a Toy::Card: one input/output signature followed by
# the two residual update steps and the final return.
#
# @return [Toy::Card] the populated card
def algorithm
  Toy::Card.new("SmolLM2Block.forward(x, p_start)", "").tap do |card|
    # Signature.
    card.add_input("x",       "R^{T×D}", "")
    card.add_input("p_start", "",       "")
    card.add_output("x",      "R^{T×D}", "")

    # Body: attention sublayer, then feed-forward sublayer, each residual.
    card.step_update("x", "x + GQAttn(RMSNorm(x; γ_1, ε), p_start)",
                     "", "residual; RoPE inside attn")
    card.step_update("x", "x + SwiGLU(RMSNorm(x; γ_2, ε))",
                     "", "residual")
    card.step_return("x")
  end
end

#algorithm_card ⇒ Object

Returns the #algorithm card rendered as pseudocode.



251
# File 'lib/toy/models/toy_smollm2.rb', line 251

# Pseudocode rendering of the card produced by #algorithm.
def algorithm_card
  card = algorithm
  card.render_pseudocode
end

#forward(x, pos_start) ⇒ Object

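x: [T, D] → [T, D]. Both residual adds are applied to x in place with add!, and the same x object is returned. pos_start: absolute position of row 0 of x, passed through to the attention sublayer.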



227
228
229
230
231
# File 'lib/toy/models/toy_smollm2.rb', line 227

# Runs one pre-norm block over x.
#
# @param x activations, documented as [T, D]; must respond to add!, which
#   is relied on to update x in place (TODO confirm against tensor API)
# @param pos_start absolute position of row 0 of x; forwarded to the
#   attention sublayer (presumably for RoPE)
# @return the same x object, after both residual adds
def forward(x, pos_start)
  # Sublayer 1: RMSNorm -> grouped-query attention -> residual add.
  # The attention output is fully computed before x is mutated.
  normed = @rn1.forward(x)
  x.add!(@attn.forward(normed, pos_start))

  # Sublayer 2: RMSNorm -> SwiGLU FFN -> residual add. This reads the
  # already-updated x, so the attention residual feeds the FFN.
  normed = @rn2.forward(x)
  x.add!(@ffn.forward(normed))

  x
end

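#param_count ⇒ Object

Returns the sum of the parameter counts of both RMSNorms, the attention sublayer and the feed-forward sublayer.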



233
234
235
236
# File 'lib/toy/models/toy_smollm2.rb', line 233

# Total parameter count of the block: the two RMSNorms plus the attention
# and feed-forward sublayers, queried in that order.
def param_count
  [@rn1, @rn2, @attn, @ffn].sum(&:param_count)
end