Class: Trainers::EarlyStoppingCallback

Inherits:
TrainerCallback
Defined in:
lib/trainers/callbacks.rb

Overview

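Early stopping callback: requests that training stop once the monitored metric has gone `patience` consecutive evaluations without improving.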

Instance Method Summary

Methods inherited from TrainerCallback

#on_epoch_begin, #on_epoch_end, #on_log, #on_save, #on_step_begin, #on_step_end, #on_train_begin, #on_train_end

Constructor Details

#initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss") ⇒ EarlyStoppingCallback

Returns a new instance of EarlyStoppingCallback.



# File 'lib/trainers/callbacks.rb', line 78

# Builds an early-stopping callback with fresh tracking state.
#
# @param patience [Integer] number of consecutive non-improving evaluations
#   tolerated before #on_evaluate requests a stop
# @param threshold [Numeric] stored for the improvement check; not read by the
#   visible code (presumably consumed by #improved? — confirm)
# @param metric_name [String, Symbol] key of the metric to monitor in the
#   metrics hash handed to #on_evaluate
def initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss")
  # Configuration, fixed for the lifetime of the callback.
  @patience, @threshold, @metric_name = patience, threshold, metric_name

  # Tracking state: best metric seen so far (nil until the first reading)
  # and how many evaluations in a row have failed to beat it.
  @best_value = nil
  @wait_count = 0
end

Instance Method Details

#on_evaluate(args, state, control, metrics: nil, **kwargs) ⇒ Object



# File 'lib/trainers/callbacks.rb', line 86

# Inspects the monitored metric after an evaluation pass and asks training
# to stop once it has failed to improve for +@patience+ consecutive
# evaluations.
#
# The metric is looked up under +@metric_name+ as configured, then as a
# String, then as a Symbol. This lets string- and symbol-keyed metric hashes
# both work however +metric_name+ was given.
#
# A NaN reading is never adopted as the best value. Ordered comparisons
# against NaN are always false, so a NaN "best" would presumably block every
# later improvement. A NaN is instead counted as an evaluation without
# improvement.
#
# @param args [Object] not used by this callback
# @param state [Object] not used by this callback
# @param control [#should_training_stop=] set to +true+ when patience runs out
# @param metrics [Hash, nil] evaluation results. The call does nothing when
#   this is nil or when the monitored metric is absent.
# @param kwargs [Hash] ignored
# @return [void] the return value is not meaningful
def on_evaluate(args, state, control, metrics: nil, **kwargs)
  return unless metrics

  current = metrics[@metric_name] ||
            metrics[@metric_name.to_s] ||
            metrics[@metric_name.to_sym]
  return unless current

  # Duck-typed NaN check: Float/BigDecimal respond to #nan?, Integer does not.
  nan = current.respond_to?(:nan?) && current.nan?

  if !nan && (@best_value.nil? || improved?(current, @best_value))
    @best_value = current
    @wait_count = 0
  else
    @wait_count += 1
    if @wait_count >= @patience
      puts "Early stopping triggered after #{@wait_count} evaluations without improvement"
      control.should_training_stop = true
    end
  end
end