Class: Adam

Inherits:
Object
Defined in:
lib/toy/train/training.rb

Overview


Adam — optimizer wrapper that owns the m/v state and steps the model

Constructed from the model, so the AdamState is sized to the model’s parameter inventory. step(grads, lr) performs one Adam update. reset zeroes the first and second moments (m, v) and sets the bias-correction products (bc1, bc2) back to 1.0. It is called right after the warm-up call so that the warm-up step does not leak into real training.
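The update that step delegates to is not shown on this page. Assume apply_gradients_adam implements the standard Adam rule, and read bc1/bc2 as running products of the betas. That reading explains why reset sets them to 1.0 rather than 0. Under those assumptions, one update with gradient g_t and learning rate lr would be:

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\mathrm{bc1}_t &= \beta_1\, \mathrm{bc1}_{t-1} = \beta_1^{\,t}, \qquad
\mathrm{bc2}_t = \beta_2\, \mathrm{bc2}_{t-1} = \beta_2^{\,t}, \qquad
\mathrm{bc1}_0 = \mathrm{bc2}_0 = 1 \\
\theta_t &= \theta_{t-1} - \mathrm{lr}\,
  \frac{m_t / (1 - \mathrm{bc1}_t)}{\sqrt{v_t / (1 - \mathrm{bc2}_t)} + \epsilon}
\end{aligned}
```

Under this reading, skipping reset after warm-up would leave bc1 = β1 and a nonzero m. The first real update would then be bias-corrected as if it were step 2, and it would carry the warm-up gradient.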


Constructor Details

#initialize(model, beta1, beta2, eps) ⇒ Adam

Returns a new instance of Adam.



# File 'lib/toy/train/training.rb', line 140

def initialize(model, beta1, beta2, eps)
  @model = model
  # AdamState is built from the model's dimensions so m and v match its
  # parameter inventory.
  @state = AdamState.new(model.vocab_size, model.d_model, model.d_ff,
                         model.n_heads, model.d_head, model.n_layers,
                         model.context_length)
  @beta1 = beta1
  @beta2 = beta2
  @eps   = eps
end
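AdamState itself is not documented here. The hypothetical ToyAdamState below is a minimal sketch of what "shapes itself to the model’s parameter inventory" could mean. It allocates one zero-filled m buffer and one zero-filled v buffer per parameter, keyed by a made-up { name => element_count } map, and starts both bias-correction products at 1.0. The real class takes the model's dimensions (vocab_size, d_model, and so on) rather than a map, and the parameter names and sizes used here are illustrative only.

```ruby
# Hypothetical stand-in for AdamState: one flat Float array per parameter.
ToyAdamState = Struct.new(:m, :v, :bc1, :bc2) do
  # inventory: hypothetical { param_name => element_count } map.
  def self.for_inventory(inventory)
    zeros = ->(n) { Array.new(n, 0.0) }
    # Separate transform_values calls give m and v their own arrays.
    new(inventory.transform_values(&zeros),
        inventory.transform_values(&zeros),
        1.0, 1.0) # bias-correction products start at 1.0 (t = 0)
  end
end

inventory = { embed: 8 * 4, w_ff: 4 * 16 } # illustrative sizes
state = ToyAdamState.for_inventory(inventory)
p state.m[:w_ff].size # => 64
p state.bc1           # => 1.0
```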

Instance Attribute Details

#beta1 ⇒ Object

Returns the value of attribute beta1.



# File 'lib/toy/train/training.rb', line 138

def beta1
  @beta1
end

#beta2 ⇒ Object

Returns the value of attribute beta2.



# File 'lib/toy/train/training.rb', line 138

def beta2
  @beta2
end

#eps ⇒ Object

Returns the value of attribute eps.



# File 'lib/toy/train/training.rb', line 138

def eps
  @eps
end

#model ⇒ Object

Returns the value of attribute model.



# File 'lib/toy/train/training.rb', line 138

def model
  @model
end

#state ⇒ Object

Returns the value of attribute state.



# File 'lib/toy/train/training.rb', line 138

def state
  @state
end
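YARD writes "Returns the value of attribute X" for attr_reader declarations, and all five readers point at line 138. That line is therefore presumably a single attr_reader, although the source line itself is not shown. The hypothetical AdamReaders class below reproduces the pattern: after construction, the hyperparameters can be read but not reassigned. The values 0.9, 0.999 and 1e-8 are illustrative and are not defaults taken from this library.

```ruby
# Hypothetical class reproducing the assumed attr_reader on line 138.
class AdamReaders
  attr_reader :model, :state, :beta1, :beta2, :eps

  def initialize(model, state, beta1, beta2, eps)
    @model = model
    @state = state
    @beta1 = beta1
    @beta2 = beta2
    @eps   = eps
  end
end

# Placeholder symbols stand in for a real model and AdamState.
opt = AdamReaders.new(:model, :state, 0.9, 0.999, 1e-8)
p opt.beta2                # => 0.999
p opt.respond_to?(:beta2=) # => false: readers only, no writers
```

attr_reader only prevents reassignment. state returns the live object, so callers can still mutate its buffers in place.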

Instance Method Details

#reset ⇒ Object



# File 'lib/toy/train/training.rb', line 154

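def reset
  # Restart the bias-correction products at 1.0 and clear both moment
  # estimates so the warm-up step does not leak into real training.
  @state.bc1 = 1.0
  @state.bc2 = 1.0
  @state.m.fill_zero
  @state.v.fill_zero
end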
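The sketch below shows what reset undoes. It uses a hypothetical ToyState with plain Float arrays, and Array#fill(0.0) stands in for fill_zero. It applies one warm-up step using the standard Adam moment and bias-correction updates, which are assumed to match what apply_gradients_adam does. After that step every field has moved away from its initial value. The same four assignments as reset then restore the fresh state.

```ruby
# Hypothetical stand-in for AdamState: plain Float arrays for m and v.
ToyState = Struct.new(:m, :v, :bc1, :bc2)

# Same four assignments as Adam#reset; Array#fill(0.0) works in place,
# standing in for fill_zero on the real buffers.
def toy_reset(state)
  state.bc1 = 1.0
  state.bc2 = 1.0
  state.m.fill(0.0)
  state.v.fill(0.0)
  state
end

beta1, beta2 = 0.9, 0.999
state = ToyState.new([0.0, 0.0], [0.0, 0.0], 1.0, 1.0)

# One warm-up step with gradient g: standard Adam moment updates plus
# advancing the bias-correction products (assumed to match the real model).
g = [0.5, -1.0]
g.each_index do |i|
  state.m[i] = beta1 * state.m[i] + (1 - beta1) * g[i]
  state.v[i] = beta2 * state.v[i] + (1 - beta2) * g[i]**2
end
state.bc1 *= beta1
state.bc2 *= beta2
p state.bc1  # => 0.9

toy_reset(state)
p state.to_a # => [[0.0, 0.0], [0.0, 0.0], 1.0, 1.0]
```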

#step(grads, lr) ⇒ Object



# File 'lib/toy/train/training.rb', line 150

def step(grads, lr)
  # One Adam update; the arithmetic is delegated to the model.
  @model.apply_gradients_adam(grads, @state, lr, @beta1, @beta2, @eps)
end
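step only forwards to the model, which keeps the optimizer wrapper to a few lines. The sketch below shows that split end to end using hypothetical stand-ins. ToyModel holds a single Float parameter and implements apply_gradients_adam with the standard Adam rule; that this matches the real method is an assumption. ToyAdam mirrors this class's constructor and step. The loop minimises f(x) = x**2, whose gradient is 2x, starting from x = 3.0 with illustrative hyperparameters.

```ruby
# Hypothetical flat state: one scalar parameter, so m and v are Floats.
ToyScalarState = Struct.new(:m, :v, :bc1, :bc2)

# Hypothetical one-parameter model. Its apply_gradients_adam applies the
# standard Adam rule, which is an assumption about the real model's method.
class ToyModel
  attr_reader :x

  def initialize(x)
    @x = x
  end

  def apply_gradients_adam(grad, state, lr, beta1, beta2, eps)
    state.bc1 *= beta1                 # bc1 = beta1**t
    state.bc2 *= beta2                 # bc2 = beta2**t
    state.m = beta1 * state.m + (1 - beta1) * grad
    state.v = beta2 * state.v + (1 - beta2) * grad * grad
    m_hat = state.m / (1 - state.bc1)  # bias-corrected first moment
    v_hat = state.v / (1 - state.bc2)  # bias-corrected second moment
    @x -= lr * m_hat / (Math.sqrt(v_hat) + eps)
  end
end

# Mirrors Adam: owns the state, holds the hyperparameters, delegates step.
class ToyAdam
  def initialize(model, beta1, beta2, eps)
    @model = model
    @state = ToyScalarState.new(0.0, 0.0, 1.0, 1.0)
    @beta1 = beta1
    @beta2 = beta2
    @eps   = eps
  end

  def step(grad, lr)
    @model.apply_gradients_adam(grad, @state, lr, @beta1, @beta2, @eps)
  end
end

# Minimise f(x) = x**2 (gradient 2x) from x = 3.0.
model = ToyModel.new(3.0)
opt = ToyAdam.new(model, 0.9, 0.999, 1e-8)
500.times { opt.step(2 * model.x, 0.05) }
puts model.x
```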