Class: GRX::Loss::CrossEntropyLoss

Inherits:
Object
Defined in:
lib/grx/loss.rb

Overview

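Cross-entropy loss that combines softmax with negative log-likelihood (NLL):

L = -mean(sum(target * log(softmax(logits))))

Softmax probabilities below EPS are clamped to EPS before the logarithm is taken.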
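
To make the formula concrete, here is a minimal plain-Ruby sketch that works on nested arrays instead of GRX tensors. The helpers `softmax_row` and `cross_entropy` are hypothetical names used only for illustration and are not part of GRX. The sketch averages over samples, as the formula reads. The documented #call divides by probs.size instead, which may differ from the number of samples depending on what to_a returns.

```ruby
# Hypothetical helpers for illustration only; not part of GRX.

# Softmax over one row of logits.
def softmax_row(row)
  max  = row.max                              # shift by the max for numerical stability
  exps = row.map { |x| Math.exp(x - max) }
  sum  = exps.sum
  exps.map { |e| e / sum }
end

# Mean over samples of -sum(target * log(softmax(logits))).
def cross_entropy(logits, targets)
  per_sample = logits.zip(targets).map do |row, t_row|
    probs = softmax_row(row)
    weighted_log = probs.zip(t_row).sum { |p, t| t * Math.log(p) }
    -weighted_log
  end
  per_sample.sum / per_sample.size
end

logits  = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # one-hot rows
loss = cross_entropy(logits, targets)
puts loss
```

For the second row the logits are equal, so softmax is uniform and that sample contributes log(3) to the mean.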

Constant Summary

EPS = 1e-7

Instance Method Summary

#call(logits, target) ⇒ Object

Instance Method Details

#call(logits, target) ⇒ Object

Raises:

  • (ShapeError) when logits.shape differs from target.shape

# File 'lib/grx/loss.rb', line 55

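def call(logits, target)
  raise ShapeError, "Incompatible shapes" if logits.shape != target.shape
  # Clamp probabilities to EPS so Math.log never receives 0.
  probs  = logits.softmax.to_a.map { |v| v < EPS ? EPS : v }
  t_data = target.to_a
  # Sum the negative log-likelihood terms, weighted by the target.
  loss   = probs.each_with_index.sum { |p, i| -t_data[i] * Math.log(p) }
  # Normalize by the number of entries in probs.
  loss / probs.size.to_f
end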
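
The clamp to EPS in #call matters when a softmax probability underflows to exactly zero. The following sketch, on plain flat arrays rather than GRX tensors, shows one way this can go wrong without the clamp. The helper `nll_sum` and the local `eps` are hypothetical names for illustration; `eps` simply repeats the documented EPS value.

```ruby
# Hypothetical helper for illustration only; mirrors the summation in #call.
def nll_sum(probs, target)
  probs.each_with_index.sum { |p, i| -target[i] * Math.log(p) }
end

eps    = 1e-7          # same value as the documented EPS constant
probs  = [1.0, 0.0]    # second probability has underflowed to zero
target = [1.0, 0.0]

# Without the clamp: Math.log(0.0) is -Infinity, and 0 * -Infinity is NaN,
# so the whole sum becomes NaN even though the target weight is zero.
unclamped = nll_sum(probs, target)

# With the clamp: every probability is at least eps, so each log is finite.
clamped = nll_sum(probs.map { |v| v < eps ? eps : v }, target)

puts unclamped.nan?
puts clamped.finite?
```

The clamp trades a small bias for numerical safety: probabilities below EPS are treated as EPS, which caps each log term at -log(1e-7), roughly 16.1.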