Class: GRX::Loss::BCELoss

Inherits:
Object
Defined in:
lib/grx/loss.rb

Overview

BCELoss: binary cross-entropy loss, L = -mean(t * log(p) + (1 - t) * log(1 - p)), where p is the prediction and t the target. pred must contain probabilities in (0, 1). If your model outputs logits, apply a sigmoid first.
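The formula is easy to check by hand on plain Ruby arrays. The sketch below is a standalone re-implementation for illustration only, not the GRX API. The helper names sigmoid and bce are hypothetical, and the sketch assumes flat arrays of equal length. It also shows the "sigmoid first" step for raw logits.

```ruby
# Hypothetical standalone helpers, not part of GRX: they mirror the
# BCELoss formula on plain Ruby arrays for illustration only.
EPS = 1e-7 # same value as GRX::Loss::BCELoss::EPS

# Logistic sigmoid: maps a raw logit to a probability in (0, 1).
def sigmoid(x)
  1.0 / (1.0 + Math.exp(-x))
end

# Mean binary cross-entropy over two flat arrays of equal length.
def bce(probs, targets)
  terms = probs.zip(targets).map do |prob, t|
    prob = prob.clamp(EPS, 1 - EPS) # keep Math.log away from 0
    -(t * Math.log(prob) + (1 - t) * Math.log(1 - prob))
  end
  terms.sum / terms.size
end

logits  = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]
probs   = logits.map { |x| sigmoid(x) } # sigmoid first: the loss expects probabilities
loss    = bce(probs, targets)
puts loss.round(4) # prints 0.3778
```

Passing the logits directly would be wrong. The values -1.0 and 0.0 lie outside (0, 1), so the clamp would silently turn them into EPS and produce a misleading loss.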

Constant Summary

EPS = 1e-7
Bound used to clamp predictions to [EPS, 1 - EPS] before taking logarithms.
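The clamp matters because Math.log(0.0) returns -Infinity in Ruby. Without it, even a perfect prediction (p = 1, t = 1) yields NaN, because the (1 - t) * log(1 - p) term evaluates to 0 * -Infinity. The sketch below uses a hypothetical per-element helper, bce_term, which is not part of GRX. It applies the same clamp as #call. For targets in {0, 1}, EPS also caps a single term at -log(EPS), about 16.118.

```ruby
EPS = 1e-7

# Hypothetical per-element BCE term (not part of GRX), optionally clamping
# the prediction to [EPS, 1 - EPS] the same way BCELoss#call does.
def bce_term(prob, t, clamp: true)
  prob = prob.clamp(EPS, 1 - EPS) if clamp
  -(t * Math.log(prob) + (1 - t) * Math.log(1 - prob))
end

Math.log(0.0)                         # => -Infinity
bce_term(1.0, 1.0, clamp: false).nan? # => true (0 * -Infinity is NaN, even for a perfect prediction)
bce_term(1.0, 1.0)                    # => ~1.0e-7 (finite, near zero)
bce_term(0.0, 1.0)                    # => ~16.118 (-log(EPS), the cap for a maximally wrong prediction)
```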

Instance Method Summary

Instance Method Details

#call(pred, target) ⇒ Object

Raises:

ShapeError if pred.shape differs from target.shape.

# File 'lib/grx/loss.rb', line 35

def call(pred, target)
  raise ShapeError, "Incompatible shapes" if pred.shape != target.shape

  # Clamp predictions to [EPS, 1 - EPS] so Math.log never receives 0.
  p_data = pred.to_a.map { |v| v.clamp(EPS, 1 - EPS) }
  t_data = target.to_a
  total  = p_data.size.to_f

  # Sum the per-element binary cross-entropy terms.
  loss = p_data.zip(t_data).sum do |prob, t|
    -(t * Math.log(prob) + (1 - t) * Math.log(1 - prob))
  end

  # Average over all elements.
  loss / total
end