Class: Trainers::EarlyStoppingCallback
- Inherits: TrainerCallback
- Object
- TrainerCallback
- Trainers::EarlyStoppingCallback
- Defined in: lib/trainers/callbacks.rb
Overview
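Early stopping callback. Stops training when the monitored evaluation metric ("eval_loss" by default) has not improved for `patience` consecutive evaluations.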
Instance Method Summary
- #initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss") ⇒ EarlyStoppingCallback (constructor)
  A new instance of EarlyStoppingCallback.
- #on_evaluate(args, state, control, metrics: nil, **kwargs) ⇒ Object
Methods inherited from TrainerCallback
#on_epoch_begin, #on_epoch_end, #on_log, #on_save, #on_step_begin, #on_step_end, #on_train_begin, #on_train_end
Constructor Details
#initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss") ⇒ EarlyStoppingCallback
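Returns a new instance of EarlyStoppingCallback.
Parameters:
- patience (default 3): number of consecutive evaluations without improvement after which training is stopped.
- threshold (default 0.0): stored for the improvement comparison, which is performed by improved? (not shown on this page).
- metric_name (default "eval_loss"): key of the metric to monitor; it is looked up in the metrics hash first as given, then as a symbol.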
# File 'lib/trainers/callbacks.rb', line 78

def initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss")
  @patience = patience
  @threshold = threshold
  @metric_name = metric_name
  @best_value = nil
  @wait_count = 0
end
Instance Method Details
#on_evaluate(args, state, control, metrics: nil, **kwargs) ⇒ Object
Reads the monitored metric from metrics and returns without doing anything if metrics or the metric is missing. The first reading, or any reading that improved? accepts, becomes the new best value and resets the wait counter. Any other reading increments the counter; once it reaches patience, the callback prints a message and sets control.should_training_stop to true. The improved? helper is not shown on this page.
# File 'lib/trainers/callbacks.rb', line 86

def on_evaluate(args, state, control, metrics: nil, **kwargs)
  return unless metrics

  current = metrics[@metric_name] || metrics[@metric_name.to_sym]
  return unless current

  if @best_value.nil? || improved?(current, @best_value)
    @best_value = current
    @wait_count = 0
  else
    @wait_count += 1
    if @wait_count >= @patience
      puts "Early stopping triggered after #{@wait_count} evaluations without improvement"
      control.should_training_stop = true
    end
  end
end
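The interaction between patience, the wait counter, and the stop flag may be easier to follow in a small simulation. The following is a minimal, self-contained sketch and not the library's code: EarlyStoppingSketch copies the two methods listed above, while the Control struct and the body of improved? are assumptions made for illustration. In particular, the sketch assumes that lower values are better and that an improvement must beat the best value by more than threshold; the real helper may behave differently.

```ruby
# Hypothetical stand-in for the trainer's control object; the real one is
# only assumed to expose a writable should_training_stop attribute.
Control = Struct.new(:should_training_stop)

# Sketch class: initialize and on_evaluate mirror the listings above.
class EarlyStoppingSketch
  def initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss")
    @patience = patience
    @threshold = threshold
    @metric_name = metric_name
    @best_value = nil
    @wait_count = 0
  end

  def on_evaluate(args, state, control, metrics: nil, **kwargs)
    return unless metrics

    # Accept either a string key or a symbol key.
    current = metrics[@metric_name] || metrics[@metric_name.to_sym]
    return unless current

    if @best_value.nil? || improved?(current, @best_value)
      @best_value = current
      @wait_count = 0
    else
      @wait_count += 1
      if @wait_count >= @patience
        puts "Early stopping triggered after #{@wait_count} evaluations without improvement"
        control.should_training_stop = true
      end
    end
  end

  private

  # ASSUMPTION: lower is better, and the gain must exceed the threshold.
  # The documented source does not show this method.
  def improved?(current, best)
    current < best - @threshold
  end
end

callback = EarlyStoppingSketch.new(patience: 2, threshold: 0.01)
control = Control.new(false)

# Made-up evaluation losses, passed with a symbol key to exercise the fallback lookup.
losses = [0.90, 0.80, 0.795, 0.81, 0.70]
stopped_at = nil

losses.each_with_index do |loss, index|
  callback.on_evaluate(nil, nil, control, metrics: { eval_loss: loss })
  if control.should_training_stop
    stopped_at = index
    break
  end
end
```

Under these assumptions the third and fourth readings do not beat the best value (0.80) by more than 0.01, so the counter reaches the patience of 2 at index 3 and the final reading is never evaluated. Note that the callback only sets the flag; the training loop is responsible for checking control.should_training_stop and exiting.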