Class: LRSchedule

Inherits:
Object
Defined in:
lib/toy/train/training.rb

Overview


LRSchedule: linear warmup followed by cosine decay.

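at(step) returns the learning rate (LR) for the given 0-indexed optimizer step. During warmup (step < warmup_steps) the LR ramps linearly from max_lr / warmup_steps at step 0 to max_lr at step warmup_steps − 1, rather than starting at 0. After warmup, a half-period cosine decays the LR from max_lr toward min_lr over the remaining (total_steps − warmup_steps) steps. For any step ≥ total_steps, at(step) returns min_lr.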
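
The same schedule can be written as a piecewise formula. This is a reconstruction from the source listing of at on this page, not text from the original documentation; s (step), W (warmup_steps), and T (total_steps) are shorthand introduced here, and the cases assume warmup_steps ≤ total_steps.

```latex
\mathrm{lr}(s) =
\begin{cases}
\mathrm{max\_lr} \cdot \dfrac{s + 1}{W}, & s < W \\[1.5ex]
\mathrm{min\_lr} + \dfrac{\mathrm{max\_lr} - \mathrm{min\_lr}}{2}
  \left(1 + \cos\!\left(\pi \, \dfrac{s - W}{T - W}\right)\right), & W \le s < T \\[1.5ex]
\mathrm{min\_lr}, & s \ge T
\end{cases}
```

At s = W − 1 the first case gives max_lr, and at s = W the cosine term equals 1, so the second case also gives max_lr: the ramp and the decay meet at the peak.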

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(warmup_steps, total_steps, max_lr, min_lr) ⇒ LRSchedule

Returns a new instance of LRSchedule.



# File 'lib/toy/train/training.rb', line 65

def initialize(warmup_steps, total_steps, max_lr, min_lr)
  @warmup_steps = warmup_steps
  @total_steps  = total_steps
  @max_lr       = max_lr
  @min_lr       = min_lr
end
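
As a usage sketch, the constructor takes four positional arguments in the order shown in the signature. The block below uses a trimmed stand-in for the class (constructor plus readers, assuming the attributes are declared with attr_reader) so that it runs on its own; the numeric values are illustrative and not taken from the source.

```ruby
# Stand-in for the class in lib/toy/train/training.rb so this sketch is
# self-contained. attr_reader is an assumption based on the reader methods
# documented on this page.
class LRSchedule
  attr_reader :warmup_steps, :total_steps, :max_lr, :min_lr

  def initialize(warmup_steps, total_steps, max_lr, min_lr)
    @warmup_steps = warmup_steps
    @total_steps  = total_steps
    @max_lr       = max_lr
    @min_lr       = min_lr
  end
end

# Illustrative values: 100 warmup steps out of 1,000, LR from 3e-4 down to 3e-5.
schedule = LRSchedule.new(100, 1_000, 3.0e-4, 3.0e-5)

# Step counts come first, learning rates second. The constructor does no
# validation, so reading the values back is a cheap sanity check.
puts schedule.warmup_steps  # 100
puts schedule.total_steps   # 1000

# Number of steps covered by the cosine decay phase.
decay_steps = schedule.total_steps - schedule.warmup_steps
puts decay_steps            # 900
```

Because all four arguments are positional numerics, passing them in the wrong order raises no error; it only shows up later as an odd-looking schedule.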

Instance Attribute Details

#max_lr ⇒ Object

Returns the value of attribute max_lr.



# File 'lib/toy/train/training.rb', line 63

def max_lr
  @max_lr
end

#min_lr ⇒ Object

Returns the value of attribute min_lr.



# File 'lib/toy/train/training.rb', line 63

def min_lr
  @min_lr
end

#total_steps ⇒ Object

Returns the value of attribute total_steps.



# File 'lib/toy/train/training.rb', line 63

def total_steps
  @total_steps
end

#warmup_steps ⇒ Object

Returns the value of attribute warmup_steps.



# File 'lib/toy/train/training.rb', line 63

def warmup_steps
  @warmup_steps
end

Instance Method Details

#at(step) ⇒ Object

Returns the learning rate for the given 0-indexed optimizer step.



# File 'lib/toy/train/training.rb', line 72

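def at(step)
  # Warmup: linear ramp. The +1 makes step 0 yield max_lr / warmup_steps
  # (not 0) and the last warmup step yield exactly max_lr.
  if step < @warmup_steps
    return @max_lr * (step.to_f + 1.0) / @warmup_steps.to_f
  end
  # At or past the end of the schedule: hold at the floor.
  if step >= @total_steps
    return @min_lr
  end
  # Cosine decay: progress runs from 0.0 at the first post-warmup step
  # up to, but not including, 1.0.
  progress = (step - @warmup_steps).to_f / (@total_steps - @warmup_steps).to_f
  cos_v    = 0.5 * (1.0 + Math.cos(Math::PI * progress))
  @min_lr + (@max_lr - @min_lr) * cos_v
end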
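
To make the shape of the schedule concrete, the sketch below copies the class from the listings on this page (so it runs standalone) and samples at(step) at a few steps. The configuration (4 warmup steps, 12 total steps, max_lr 1.0, min_lr 0.0) is chosen for easy arithmetic and is not taken from the source.

```ruby
# Copy of LRSchedule as listed on this page, so the example is self-contained.
class LRSchedule
  def initialize(warmup_steps, total_steps, max_lr, min_lr)
    @warmup_steps = warmup_steps
    @total_steps  = total_steps
    @max_lr       = max_lr
    @min_lr       = min_lr
  end

  def at(step)
    if step < @warmup_steps
      return @max_lr * (step.to_f + 1.0) / @warmup_steps.to_f
    end
    if step >= @total_steps
      return @min_lr
    end
    progress = (step - @warmup_steps).to_f / (@total_steps - @warmup_steps).to_f
    cos_v    = 0.5 * (1.0 + Math.cos(Math::PI * progress))
    @min_lr + (@max_lr - @min_lr) * cos_v
  end
end

# Illustrative configuration: 4 warmup steps, 12 steps total, LR from 1.0 to 0.0.
schedule = LRSchedule.new(4, 12, 1.0, 0.0)

# Sample the first warmup step, the last warmup step, the first decay step,
# the midpoint of the decay, and the first step past the end of the schedule.
samples = [0, 3, 4, 8, 12].map { |step| [step, schedule.at(step)] }
samples.each { |step, lr| puts format("step %2d  lr %.4f", step, lr) }
# step  0  lr 0.2500
# step  3  lr 1.0000
# step  4  lr 1.0000
# step  8  lr 0.5000
# step 12  lr 0.0000
```

Steps 3 and 4 both sit at max_lr: the ramp reaches the peak on the last warmup step, and the cosine decay starts from the same value on the next one.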