Class: LRSchedule
- Inherits: Object
  - Object
    - LRSchedule
- Defined in: lib/toy/train/training.rb
Overview
LRSchedule — linear warmup → cosine decay
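at(step) returns the learning rate (LR) for the given 0-indexed optimizer step. During warmup (step < warmup_steps) the LR ramps linearly from max_lr / warmup_steps at step 0 to max_lr at step warmup_steps − 1. After warmup, a half-period cosine decays the LR from max_lr to min_lr over the remaining (total_steps − warmup_steps) steps. For step ≥ total_steps the LR stays at min_lr.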
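As a compact summary, the schedule can be sketched as a piecewise function. The symbols s, W, T, \eta_{\max} and \eta_{\min} are shorthand introduced here for step, warmup_steps, total_steps, max_lr and min_lr; they do not appear in the source. The cases are read top to bottom, in the same order as the checks in #at.

```latex
\mathrm{lr}(s) =
\begin{cases}
  \eta_{\max}\,\dfrac{s + 1}{W}, & s < W \\[6pt]
  \eta_{\min}, & s \ge T \\[6pt]
  \eta_{\min} + (\eta_{\max} - \eta_{\min})\,\dfrac{1}{2}
    \left(1 + \cos\!\left(\pi\,\dfrac{s - W}{T - W}\right)\right), & W \le s < T
\end{cases}
```

At s = W the cosine term equals 1, so the decay phase starts at \eta_{\max}; the value reaches \eta_{\min} only at s = T, where the second case applies.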
Instance Attribute Summary
- #max_lr ⇒ Object
  Returns the value of attribute max_lr.
- #min_lr ⇒ Object
  Returns the value of attribute min_lr.
- #total_steps ⇒ Object
  Returns the value of attribute total_steps.
- #warmup_steps ⇒ Object
  Returns the value of attribute warmup_steps.
Instance Method Summary
- #at(step) ⇒ Object
- #initialize(warmup_steps, total_steps, max_lr, min_lr) ⇒ LRSchedule (constructor)
  A new instance of LRSchedule.
Constructor Details
#initialize(warmup_steps, total_steps, max_lr, min_lr) ⇒ LRSchedule
Returns a new instance of LRSchedule.
# File 'lib/toy/train/training.rb', line 65

def initialize(warmup_steps, total_steps, max_lr, min_lr)
  @warmup_steps = warmup_steps
  @total_steps = total_steps
  @max_lr = max_lr
  @min_lr = min_lr
end
Instance Attribute Details
#max_lr ⇒ Object
Returns the value of attribute max_lr.
# File 'lib/toy/train/training.rb', line 63

def max_lr
  @max_lr
end
#min_lr ⇒ Object
Returns the value of attribute min_lr.
# File 'lib/toy/train/training.rb', line 63

def min_lr
  @min_lr
end
#total_steps ⇒ Object
Returns the value of attribute total_steps.
# File 'lib/toy/train/training.rb', line 63

def total_steps
  @total_steps
end
#warmup_steps ⇒ Object
Returns the value of attribute warmup_steps.
# File 'lib/toy/train/training.rb', line 63

def warmup_steps
  @warmup_steps
end
Instance Method Details
#at(step) ⇒ Object
Returns the learning rate for the given 0-indexed optimizer step.
# File 'lib/toy/train/training.rb', line 72

def at(step)
  if step < @warmup_steps
    return @max_lr * (step.to_f + 1.0) / @warmup_steps.to_f
  end
  if step >= @total_steps
    return @min_lr
  end
  progress = (step - @warmup_steps).to_f / (@total_steps - @warmup_steps).to_f
  cos_v = 0.5 * (1.0 + Math.cos(Math::PI * progress))
  @min_lr + (@max_lr - @min_lr) * cos_v
end
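To see the three phases of #at in action, here is a minimal usage sketch. The class is re-declared from the listings above so the snippet runs on its own; the use of attr_reader is an assumption (the page lists only readers), and the settings (4 warmup steps, 12 total steps, max_lr 1.0, min_lr 0.1) are hypothetical values picked for easy arithmetic, not taken from the project.

```ruby
# Standalone copy of LRSchedule so this sketch is self-contained.
# attr_reader is assumed; the documentation lists only reader methods.
class LRSchedule
  attr_reader :warmup_steps, :total_steps, :max_lr, :min_lr

  def initialize(warmup_steps, total_steps, max_lr, min_lr)
    @warmup_steps = warmup_steps
    @total_steps = total_steps
    @max_lr = max_lr
    @min_lr = min_lr
  end

  def at(step)
    # Warmup: linear ramp that reaches max_lr on the last warmup step.
    if step < @warmup_steps
      return @max_lr * (step.to_f + 1.0) / @warmup_steps.to_f
    end
    # Past the end of the schedule: hold at min_lr.
    if step >= @total_steps
      return @min_lr
    end
    # Decay: half-period cosine from max_lr toward min_lr.
    progress = (step - @warmup_steps).to_f / (@total_steps - @warmup_steps).to_f
    cos_v = 0.5 * (1.0 + Math.cos(Math::PI * progress))
    @min_lr + (@max_lr - @min_lr) * cos_v
  end
end

# Hypothetical settings: warmup_steps, total_steps, max_lr, min_lr.
schedule = LRSchedule.new(4, 12, 1.0, 0.1)

# Print the LR for every step, plus two steps past total_steps.
(0..13).each do |step|
  puts format("step %2d  lr %.4f", step, schedule.at(step))
end
```

With these values, step 0 yields 0.25 rather than 0, steps 3 and 4 both yield 1.0 (warmup reaches max_lr on its last step and the cosine phase starts from max_lr), and every step from 12 onward yields 0.1.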