Class: GRX::Loss::HuberLoss

Inherits:
Object
Defined in:
lib/grx/loss.rb

Overview

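HuberLoss (Smooth L1): the mean Huber loss between a prediction and a target. Residuals whose magnitude is at most delta are penalized quadratically and larger ones linearly, which makes the loss less sensitive to outliers than squared error. With the default delta of 1.0 it coincides with the Smooth L1 loss.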
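For reference, the per-element rule applied in #call can be written as follows. The symbols r_i (residual) and n (total number of elements) are notation introduced here, not names from the GRX source.

```latex
L_\delta(r) =
\begin{cases}
  \tfrac{1}{2} r^2, & |r| \le \delta \\
  \delta \left( |r| - \tfrac{1}{2} \delta \right), & |r| > \delta
\end{cases}
\qquad
\text{loss} = \frac{1}{n} \sum_{i=1}^{n} L_\delta(r_i),
\quad r_i = \text{pred}_i - \text{target}_i
```

Both branches equal (1/2)·delta² at |r| = delta, so the loss is continuous at the threshold; setting delta = 1 gives Smooth L1.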

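Instance Method Summary

#call(pred, target) ⇒ Object
  Computes the mean Huber loss between pred and target.
#initialize(delta: 1.0) ⇒ HuberLoss (constructor)
  Returns a new instance of HuberLoss.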

Constructor Details

#initialize(delta: 1.0) ⇒ HuberLoss

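Returns a new instance of HuberLoss.

Parameters:

delta (Float) (defaults to: 1.0): threshold on |pred - target| at which the loss switches from the quadratic branch to the linear branch.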

# File 'lib/grx/loss.rb', line 68

def initialize(delta: 1.0)
  @delta = delta
end
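To illustrate what delta controls, here is a stand-alone sketch. huber_element is a hypothetical helper written for this page, not part of GRX; it applies the same per-element expression as #call to a single residual.

```ruby
# Hypothetical helper (not part of GRX) mirroring the per-element
# expression used in HuberLoss#call for a single residual r.
def huber_element(r, delta)
  a = r.abs
  a <= delta ? 0.5 * a * a : delta * (a - 0.5 * delta)
end

# At |r| == delta both branches equal 0.5 * delta**2, so the loss is
# continuous: quadratic 0.5 * 2.0 * 2.0 = 2.0, linear 2.0 * (2.0 - 1.0) = 2.0.
puts huber_element(2.0, 2.0)   # => 2.0

# Past the threshold the loss grows only linearly, with slope delta,
# so a larger delta penalizes the same outlier more heavily.
puts huber_element(3.0, 1.0)   # => 2.5
puts huber_element(3.0, 2.0)   # => 4.0

# Small residuals stay in the quadratic branch regardless of sign.
puts huber_element(-0.5, 1.0)  # => 0.125
```

A small delta makes the loss behave almost like an absolute-error (L1) loss, while a large delta makes it behave almost like half the squared error.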

Instance Method Details

#call(pred, target) ⇒ Object

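Computes the mean Huber loss over all elements of pred and target.

Raises:

ShapeError: if pred.shape != target.shape.
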
# File 'lib/grx/loss.rb', line 72

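def call(pred, target)
  raise ShapeError, "Incompatible shapes" if pred.shape != target.shape
  d = @delta
  # Absolute residuals, flattened so multi-dimensional inputs are averaged
  # over every element rather than over their first axis.
  diffs = (pred - target).abs.to_a.flatten
  # Quadratic for |r| <= delta, linear (slope delta) beyond it.
  loss  = diffs.sum { |v| v <= d ? 0.5 * v * v : d * (v - 0.5 * d) }
  # Mean over all elements.
  loss / diffs.size.to_f
end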
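The following self-contained sketch reproduces the behaviour of #call on plain, possibly nested, Ruby Arrays so it can run without GRX. ShapeError, shape_of and huber_mean are hypothetical stand-ins defined here; GRX's real tensor and error classes may differ.

```ruby
# Hypothetical stand-in for GRX's ShapeError.
class ShapeError < StandardError; end

# Hypothetical helper: dimensions of a rectangular nested Array,
# e.g. [[1, 2], [3, 4]] -> [2, 2].
def shape_of(a)
  return [] unless a.is_a?(Array)
  [a.size] + (a.empty? ? [] : shape_of(a.first))
end

# Mean Huber loss over every element of pred and target.
def huber_mean(pred, target, delta: 1.0)
  raise ShapeError, "Incompatible shapes" if shape_of(pred) != shape_of(target)

  # Pair up elements after flattening, then take absolute residuals.
  diffs = pred.flatten.zip(target.flatten).map { |x, y| (x - y).abs }
  total = diffs.sum { |v| v <= delta ? 0.5 * v * v : delta * (v - 0.5 * delta) }
  total / diffs.size.to_f
end

pred   = [[0.0, 2.0], [0.5, -3.0]]
target = [[0.0, 0.0], [0.0,  0.0]]
# Per-element losses: 0.0, 1.5, 0.125, 2.5; their mean is 4.125 / 4.
puts huber_mean(pred, target)  # => 1.03125

begin
  huber_mean([1.0, 2.0], [1.0])
rescue ShapeError => e
  puts "ShapeError: #{e.message}"  # => ShapeError: Incompatible shapes
end
```

Flattening before averaging means each element contributes equally, so a 2x2 input is divided by 4 rather than by its 2 rows.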