Class: Toy::Trainer

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/train/toy_trainer.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(model) ⇒ Trainer

Defaults match the train_tinystories.rb constants — sensible starting points for a small transformer LM.



45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# File 'lib/toy/train/toy_trainer.rb', line 45

# Builds a trainer around +model+ using fixed default hyperparameters.
#
# The constructor stores the hyperparameters, then creates the gradient
# container, the Adam optimizer and the learning-rate schedule from them.
#
# @param model [Object] model to train; it must respond to vocab_size,
#   d_model, d_ff, n_heads, d_head, n_layers and context_length, all of
#   which are forwarded to Gradients.new
def initialize(model)
  @model    = model
  @step_idx = 0

  # Values forwarded to Adam.new below.
  @beta1 = 0.9
  @beta2 = 0.999
  @eps   = 0.00000001

  # Values forwarded to LRSchedule.new below.
  @lr_max      = 0.001
  @lr_min      = 0.00001
  @warmup      = 200
  @total_steps = 1000

  # The gradient container is built from the model's dimension attributes.
  @grads = Gradients.new(
    model.vocab_size, model.d_model, model.d_ff,
    model.n_heads, model.d_head, model.n_layers,
    model.context_length
  )
  @optimizer = Adam.new(model, @beta1, @beta2, @eps)
  @schedule  = LRSchedule.new(@warmup, @total_steps, @lr_max, @lr_min)
end

Instance Attribute Details

#beta1 ⇒ Object

Returns the value of attribute beta1.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The beta1 value the constructor sets to 0.9 and hands to Adam.new.
# @return [Float]
def beta1 = @beta1

#beta2 ⇒ Object

Returns the value of attribute beta2.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The beta2 value the constructor sets to 0.999 and hands to Adam.new.
# @return [Float]
def beta2 = @beta2

#eps ⇒ Object

Returns the value of attribute eps.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The eps value the constructor sets to 0.00000001 and hands to Adam.new.
# @return [Float]
def eps = @eps

#grads ⇒ Object

Returns the value of attribute grads.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The Gradients instance created by the constructor and reused by every step!.
# @return [Gradients]
def grads = @grads

#lr_max ⇒ Object

Returns the value of attribute lr_max.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The lr_max value the constructor sets to 0.001 and hands to LRSchedule.new.
# @return [Float]
def lr_max = @lr_max

#lr_min ⇒ Object

Returns the value of attribute lr_min.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The lr_min value the constructor sets to 0.00001 and hands to LRSchedule.new.
# @return [Float]
def lr_min = @lr_min

#model ⇒ Object

Returns the value of attribute model.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The model object that was passed to the constructor.
# @return [Object]
def model = @model

#optimizer ⇒ Object

Returns the value of attribute optimizer.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The Adam instance created by the constructor.
# @return [Adam]
def optimizer = @optimizer

#schedule ⇒ Object

Returns the value of attribute schedule.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The LRSchedule instance created by the constructor.
# @return [LRSchedule]
def schedule = @schedule

#step_idx ⇒ Object

Returns the value of attribute step_idx.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The step counter; the constructor starts it at 0 and step! adds 1 per call.
# @return [Integer]
def step_idx = @step_idx

#total_steps ⇒ Object

Returns the value of attribute total_steps.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The total_steps value the constructor sets to 1000 and hands to LRSchedule.new.
# @return [Integer]
def total_steps = @total_steps

#warmup ⇒ Object

Returns the value of attribute warmup.



39
40
41
# File 'lib/toy/train/toy_trainer.rb', line 39

# The warmup value the constructor sets to 200 and hands to LRSchedule.new.
# @return [Integer]
def warmup = @warmup

Instance Method Details

#lr ⇒ Object

Convenience: current learning rate.



82
83
84
# File 'lib/toy/train/toy_trainer.rb', line 82

# Learning rate the schedule reports for the current step index.
# Does not advance the step counter.
# @return [Object] the result of @schedule.at(@step_idx)
def lr = @schedule.at(@step_idx)

#reset_optimizer! ⇒ Object

Reset optimizer state (e.g. after a warm-up step that you don’t want to count). Step counter stays where it is — change @step_idx by hand if you want that too.



77
78
79
# File 'lib/toy/train/toy_trainer.rb', line 77

# Asks the optimizer to reset itself; the step counter is left untouched.
# @return [Object] whatever @optimizer.reset returns
def reset_optimizer! = @optimizer.reset

#step!(seq) ⇒ Object

One optimizer step on a single sequence. Returns the loss. The body is deliberately tiny — zero the gradients, forward, backward, optimizer step — and that is the whole point: this is what training is.



65
66
67
68
69
70
71
72
# File 'lib/toy/train/toy_trainer.rb', line 65

# Runs one training step on +seq+ and returns the loss.
#
# Sequence of calls: fill_zero on the gradients, forward then backward on the
# model (backward receives the gradients), then the optimizer's step with the
# learning rate the schedule gives for the current step index. The step index
# is incremented only after the optimizer has been stepped.
#
# @param seq [Object] training sequence, passed unchanged to the model's
#   forward and backward
# @return [Object] the value of @grads.loss after the step
def step!(seq)
  @grads.fill_zero
  @model.forward(seq)
  @model.backward(seq, @grads)

  # Look up the rate for the current index before bumping the counter.
  rate = @schedule.at(@step_idx)
  @optimizer.step(@grads, rate)

  @step_idx = @step_idx + 1
  @grads.loss
end