Class: Toy::AdamW

Inherits:
Object
Defined in:
lib/toy/llm/adamw.rb

Constant Summary

MODE_UNSET = 0

EXPLICIT MODE (toy#64 item 4): names which slot-5/6 convention this hp object feeds (see the loud finding at the top of lib/toy/llm/adamw.rb). 0 means unset: hp() FAILS LOUD until the caller declares a mode, because feeding the wrong convention into a training graph silently changes numerics (train.rb's gate breaks with lora-style hp). Use the factories:

Toy::AdamW.for_from_scratch:  slots 5/6 = constant betas
                              (from-scratch / warm-start / vit
                              graph family; bias_correct=false)
Toy::AdamW.for_lora:          slots 5/6 = 1/(1-beta^t) per-step
                              bias correction, beta2=0.999
                              (lora-family graphs, ALSO the
                              gpt2 + full-finetune graphs,
                              which read the lora convention;
                              bias_correct=true)

or set `mode` explicitly after .new. The mode names the HP-SLOT CONVENTION (graph family), not the recipe name.

MODE_FROM_SCRATCH = 1
MODE_LORA = 2
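For reference, the two slot-5/6 conventions can be written out as follows. Here $t$ is the 1-indexed step that hp receives in lora mode, and $s_5, s_6$ are labels introduced only for this summary:

```latex
\begin{aligned}
\text{MODE\_FROM\_SCRATCH:}\quad & (s_5,\ s_6) = (\beta_1,\ \beta_2) \\
\text{MODE\_LORA:}\quad & (s_5,\ s_6) = \left(\frac{1}{1-\beta_1^{\,t}},\ \frac{1}{1-\beta_2^{\,t}}\right), \qquad t \ge 1
\end{aligned}
```

In lora mode the values are largest at $t = 1$, where they equal $1/(1-\beta_1)$ and $1/(1-\beta_2)$, and they tend to 1 as $t$ grows.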

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize ⇒ AdamW

Initializes the FROM-SCRATCH defaults, which is the most common (blessed) case. Every field is set in the body, with NO kwargs or default args (Spinel landmine #4). Callers that need different values (lora) set the differing fields AFTER .new via the accessors. Alternatively, they can use the .for_from_scratch / .for_lora factories, which also set the REQUIRED mode. The warm / vit recipes set adamw.lr per step. mode starts UNSET, so hp() raises until it is declared.



# File 'lib/toy/llm/adamw.rb', line 72

def initialize
  @lr           = 0.001
  @beta1        = 0.9
  @beta2        = 0.95
  @eps          = 1.0e-8
  @weight_decay = 0.0
  @bias_correct = false
  @mode         = MODE_UNSET
end
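The construct-then-configure pattern described above can be sketched with a hypothetical stand-in class, `SketchAdamW`. It copies the documented defaults and accessors. The real class is Toy::AdamW in lib/toy/llm/adamw.rb, and `to_a` is a helper added here only for comparison. The sketch builds a lora configuration by hand and checks it against a copy of the `.for_lora` factory:

```ruby
# Hypothetical stand-in mirroring the documented Toy::AdamW defaults and
# accessors; it is not the real class.
class SketchAdamW
  MODE_UNSET        = 0
  MODE_FROM_SCRATCH = 1
  MODE_LORA         = 2

  attr_accessor :lr, :beta1, :beta2, :eps, :weight_decay, :bias_correct, :mode

  # No kwargs or default args: every field is set in the body.
  def initialize
    @lr           = 0.001
    @beta1        = 0.9
    @beta2        = 0.95
    @eps          = 1.0e-8
    @weight_decay = 0.0
    @bias_correct = false
    @mode         = MODE_UNSET
  end

  # Mirrors the documented .for_lora factory.
  def self.for_lora
    a = new
    a.beta2        = 0.999
    a.bias_correct = true
    a.mode         = MODE_LORA
    a
  end

  # Hypothetical helper: field snapshot for comparing configurations.
  def to_a
    [@lr, @beta1, @beta2, @eps, @weight_decay, @bias_correct, @mode]
  end
end

# Lora-style config built by hand: construct with the defaults, then
# override only the differing fields, including the required mode.
manual = SketchAdamW.new
manual.beta2        = 0.999
manual.bias_correct = true
manual.mode         = SketchAdamW::MODE_LORA

puts manual.to_a == SketchAdamW.for_lora.to_a  # prints true
```

The constructor takes no arguments, so every variant is expressed as field overrides after .new. The factories simply bundle those overrides together with the required mode.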

Instance Attribute Details

#beta1 ⇒ Object

Returns the value of attribute beta1.



# File 'lib/toy/llm/adamw.rb', line 63

def beta1
  @beta1
end

#beta2 ⇒ Object

Returns the value of attribute beta2.



# File 'lib/toy/llm/adamw.rb', line 63

def beta2
  @beta2
end

#bias_correct ⇒ Object

Returns the value of attribute bias_correct.



# File 'lib/toy/llm/adamw.rb', line 63

def bias_correct
  @bias_correct
end

#eps ⇒ Object

Returns the value of attribute eps.



# File 'lib/toy/llm/adamw.rb', line 63

def eps
  @eps
end

#lr ⇒ Object

Returns the value of attribute lr.



# File 'lib/toy/llm/adamw.rb', line 63

def lr
  @lr
end

#mode ⇒ Object

Returns the value of attribute mode.



# File 'lib/toy/llm/adamw.rb', line 63

def mode
  @mode
end

#weight_decay ⇒ Object

Returns the value of attribute weight_decay.



# File 'lib/toy/llm/adamw.rb', line 63

def weight_decay
  @weight_decay
end

Class Method Details

.for_from_scratch ⇒ Object

Factory for the from-scratch / warm / vit graph family. It returns the constructor defaults with mode = MODE_FROM_SCRATCH. Slots 5/6 = constant betas (bias_correct=false).



# File 'lib/toy/llm/adamw.rb', line 84

def self.for_from_scratch
  a = Toy::AdamW.new
  a.mode = MODE_FROM_SCRATCH
  a
end

.for_lora ⇒ Object

Factory for the lora graph family (ALSO used by the gpt2 + full-finetune graphs). It returns the constructor defaults with beta2=0.999, per-step bias correction (slots 5/6 = 1/(1-beta^t)), and mode = MODE_LORA.



# File 'lib/toy/llm/adamw.rb', line 92

def self.for_lora
  a = Toy::AdamW.new
  a.beta2        = 0.999
  a.bias_correct = true
  a.mode = MODE_LORA
  a
end

Instance Method Details

#hp(step) ⇒ Object

Builds the Mat(1,7) hyperparameter row that the recipe passes to step!. The result is byte-identical to the m_hp that the runners used to build by hand inline.

`step` is ignored when bias_correct == false. In that case the from-scratch / warm / vit graphs take slots 5/6 as constant betas, so hp(step) is step-agnostic in that mode. When bias_correct == true (lora), `step` is the CALLER's 1-indexed step (>= 1). The `** step.to_f` reproduces train_lora.rb:174-175 / smoke_recipe_lora.rb:88-89 VERBATIM.

FAILS LOUD (toy#64 item 4) when mode is unset, or when mode and bias_correct disagree. Both are silent-numerics traps caused by the slot-5/6 dual meaning described above.
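The following standalone sketch shows what the lora-mode slots 5/6 carry, and why `step` must be 1-indexed. It evaluates 1/(1-beta^t) with the .for_lora betas (0.9 and 0.999). `bias_denominators` is a hypothetical helper, not part of Toy::AdamW:

```ruby
# Hypothetical helper: the lora-mode slot-5/6 values 1/(1 - beta^t),
# with t the 1-indexed step (same expression as hp's bias_correct arm).
def bias_denominators(beta1, beta2, t)
  [1.0 / (1.0 - (beta1 ** t.to_f)), 1.0 / (1.0 - (beta2 ** t.to_f))]
end

# beta1 = 0.9 and beta2 = 0.999 are the .for_lora values.
[1, 2, 10, 1000].each do |t|
  s5, s6 = bias_denominators(0.9, 0.999, t)
  printf("t=%-5d slot5=%.4f slot6=%.4f\n", t, s5, s6)
end

# t = 0 shows why the caller's step must be >= 1: beta^0 = 1, so the
# denominator is 0.0 and Float division yields Infinity.
p bias_denominators(0.9, 0.999, 0)  # prints [Infinity, Infinity]
```

Slot 5 is already close to 1 after a few dozen steps. Slot 6 (beta2 = 0.999) is still roughly 1.58 at t = 1000.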



# File 'lib/toy/llm/adamw.rb', line 112

def hp(step)
  if @mode == MODE_UNSET
    raise "Toy::AdamW#hp: mode unset — declare which slot-5/6 " +
          "convention this hp feeds (use Toy::AdamW.for_from_scratch " +
          "or Toy::AdamW.for_lora, or set adamw.mode)"
  end
  if @mode == MODE_FROM_SCRATCH && @bias_correct
    raise "Toy::AdamW#hp: mode=from_scratch but bias_correct=true — " +
          "the from-scratch/warm/vit graphs read slots 5/6 as " +
          "CONSTANT BETAS; lora-style hp breaks their byte gates"
  end
  if @mode == MODE_LORA && !@bias_correct
    raise "Toy::AdamW#hp: mode=lora but bias_correct=false — the " +
          "lora-family graphs read slots 5/6 as per-step " +
          "1/(1-beta^t) bias-correction denominators"
  end
  m = Mat.new(1, 7)            # Mat.new zero-fills @flat (transformer.rb:74)
  m.flat[0] = @lr
  m.flat[1] = @beta1
  m.flat[2] = @beta2
  m.flat[3] = @eps
  m.flat[4] = @weight_decay
  # SLOT 5/6 DUAL MEANING (see the loud finding at the top of this
  # file): bias_correct selects which of the two FFI conventions we
  # feed. Do NOT unify — the from-scratch/warm/vit graphs read these
  # as constant betas; the lora graph reads them as 1/(1-beta^t).
  if @bias_correct
    m.flat[5] = 1.0 / (1.0 - (@beta1 ** step.to_f))
    m.flat[6] = 1.0 / (1.0 - (@beta2 ** step.to_f))
  else
    m.flat[5] = @beta1
    m.flat[6] = @beta2
  end
  m
end
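The following standalone sketch reproduces the same guard-then-fill logic so it can be explored outside the Toy runtime. It rests on three assumptions. A plain 7-element Array stands in for Mat(1, 7). The hypothetical `HpConfig` Struct stands in for the Toy::AdamW fields. The error messages are shortened:

```ruby
# Hypothetical stand-in for Toy::AdamW#hp. A 7-element Array replaces
# Mat(1, 7); HpConfig replaces the real class's fields.
MODE_UNSET        = 0
MODE_FROM_SCRATCH = 1
MODE_LORA         = 2

HpConfig = Struct.new(:lr, :beta1, :beta2, :eps, :weight_decay,
                      :bias_correct, :mode)

def sketch_hp(c, step)
  # Same three guards as hp: refuse to guess the slot-5/6 convention.
  raise "mode unset" if c.mode == MODE_UNSET
  if c.mode == MODE_FROM_SCRATCH && c.bias_correct
    raise "mode=from_scratch but bias_correct=true"
  end
  if c.mode == MODE_LORA && !c.bias_correct
    raise "mode=lora but bias_correct=false"
  end

  slots = [c.lr, c.beta1, c.beta2, c.eps, c.weight_decay]  # slots 0..4
  if c.bias_correct
    slots << 1.0 / (1.0 - (c.beta1 ** step.to_f))  # slot 5: 1/(1-beta1^t)
    slots << 1.0 / (1.0 - (c.beta2 ** step.to_f))  # slot 6: 1/(1-beta2^t)
  else
    slots << c.beta1 << c.beta2                    # slots 5/6: constant betas
  end
  slots
end

scratch = HpConfig.new(0.001, 0.9, 0.95, 1.0e-8, 0.0, false, MODE_FROM_SCRATCH)
lora    = HpConfig.new(0.001, 0.9, 0.999, 1.0e-8, 0.0, true, MODE_LORA)
p sketch_hp(scratch, 7)  # slots 5/6 are 0.9 and 0.95 for any step
p sketch_hp(lora, 1)     # slots 5/6 are about 10 and 1000 at t = 1

begin
  sketch_hp(HpConfig.new(0.001, 0.9, 0.95, 1.0e-8, 0.0, false, MODE_UNSET), 1)
rescue RuntimeError => e
  puts "refused: #{e.message}"  # prints "refused: mode unset"
end
```

Checking the mode before touching any slot means a misconfigured object never produces a plausible-looking row, which is the point of failing loud here.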

#hp_for_step(step) ⇒ Object

MODE-AWARE per-step hp builder (toy#73 item 5). It takes the CALLER's 0-INDEXED loop step, which is the convention every recipe loop already uses for its `step == 0` is_first branch, and applies the mode's step convention internally:

MODE_FROM_SCRATCH — step-agnostic (slots 5/6 = constant betas);
                    hp_for_step(k) == hp(k).
MODE_LORA         — the lora-family graphs want a 1-INDEXED t in
                    1/(1-beta^t); hp_for_step(k) == hp(k + 1).

This removes the lora paths' off-by-one ceremony (a 1-indexed loop counter carried solely to feed hp(step)). The result is byte-identical to the hp calls it replaces. The #64 mode guard STAYS: both arms delegate to hp, so an unset mode or a mode/bias_correct mismatch still fails loud.



# File 'lib/toy/llm/adamw.rb', line 162

def hp_for_step(step)
  if @mode == MODE_LORA
    return hp(step + 1)
  end
  hp(step)
end
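The index shift that hp_for_step performs in lora mode can be checked in isolation. This sketch models only slot 5, using beta1 = 0.9 from .for_lora. `slot5_lora` and `slot5_for_step_lora` are hypothetical stand-ins for hp(t) slot 5 and hp_for_step(step) slot 5:

```ruby
# Hypothetical stand-ins; only slot 5 of the lora-mode hp row is modelled.
LORA_BETA1 = 0.9  # the .for_lora beta1 (unchanged from the constructor default)

# hp(t) slot 5 in lora mode: t is 1-indexed (t >= 1).
def slot5_lora(t)
  1.0 / (1.0 - (LORA_BETA1 ** t.to_f))
end

# hp_for_step(step) slot 5 in lora mode: the caller's 0-indexed step is
# shifted to the 1-indexed t internally.
def slot5_for_step_lora(step)
  slot5_lora(step + 1)
end

# Old style: a 1-indexed counter kept only to feed hp.
old_style = (1..3).map { |t| slot5_lora(t) }
# New style: the recipe's own 0-indexed step goes in unchanged, so
# `step == 0` can still mark the first iteration.
new_style = (0...3).map { |step| slot5_for_step_lora(step) }

p old_style == new_style  # prints true
```

Keeping the +1 inside the mode-aware method lets every recipe loop share one 0-indexed convention. The lora-only shift then sits next to the formula that needs it.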