Class: Ignis::AI::Optim::LRScheduler::CosineWithWarmup

Inherits:
Object
  • Object
show all
Defined in:
lib/nnw/ai/optim/lr_scheduler.rb

Overview

Cosine learning-rate decay preceded by a linear warmup phase, a common combination for Transformer training.

Instance Method Summary collapse

Constructor Details

#initialize(optimizer, warmup_steps:, total_steps:, min_lr: 0.0) ⇒ CosineWithWarmup

Returns a new instance of CosineWithWarmup.

Parameters:

  • optimizer (Base)
  • warmup_steps (Integer)
  • total_steps (Integer)
  • min_lr (Float) (defaults to: 0.0)


80
81
82
83
84
85
86
# File 'lib/nnw/ai/optim/lr_scheduler.rb', line 80

# Builds a scheduler that drives the learning rate of +optimizer+.
#
# The optimizer's learning rate at construction time is remembered as the
# base (peak) rate that the schedule scales from.
#
# @param optimizer [Base] optimizer whose +lr+ is read here
# @param warmup_steps [Integer] number of steps in the warmup phase
# @param total_steps [Integer] step count at which the schedule ends
# @param min_lr [Float] floor learning rate; coerced with +to_f+
def initialize(optimizer, warmup_steps:, total_steps:, min_lr: 0.0)
  @optimizer = optimizer
  @warmup_steps, @total_steps = warmup_steps, total_steps
  @min_lr = min_lr.to_f
  # Snapshot the starting rate so later schedule updates never compound.
  @base_lr = @optimizer.lr
end

Instance Method Details

#current_lr ⇒ Float

Returns:

  • (Float)


111
112
113
# File 'lib/nnw/ai/optim/lr_scheduler.rb', line 111

# Reports the learning rate currently set on the wrapped optimizer.
#
# The value is read straight from the optimizer rather than cached here.
#
# @return [Float]
def current_lr
  @optimizer.lr
end

#step ⇒ Float

Returns:

  • (Float)


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# File 'lib/nnw/ai/optim/lr_scheduler.rb', line 89

# Recomputes the learning rate from the optimizer's current step count,
# writes it to the optimizer, and returns it.
#
# Three regimes, checked in this order:
# * before +warmup_steps+: the base rate scaled linearly by
#   +step_count / warmup_steps+ (this check wins even if the step count
#   is already past +total_steps+);
# * at or beyond +total_steps+: the minimum rate;
# * otherwise: cosine interpolation from the base rate down to the
#   minimum rate across the steps between warmup and +total_steps+.
#
# @return [Float] the learning rate just assigned to the optimizer
def step
  count = @optimizer.step_count

  lr =
    if count < @warmup_steps
      # Linear ramp from zero towards the base rate.
      @base_lr * (count.to_f / @warmup_steps)
    elsif count >= @total_steps
      # Schedule finished: hold at the floor.
      @min_lr
    else
      # Fraction of the decay window completed, in [0, 1).
      fraction = (count - @warmup_steps).to_f / (@total_steps - @warmup_steps)
      @min_lr + 0.5 * (@base_lr - @min_lr) * (1.0 + Math.cos(Math::PI * fraction))
    end

  @optimizer.lr = lr
  lr
end