Class: GRX::Loss::CrossEntropyLoss
- Inherits: Object
- Defined in: lib/grx/loss.rb
Overview
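CrossEntropyLoss: softmax followed by negative log-likelihood (NLL).
L = -mean(sum(target * log(softmax(logits))))
As implemented in #call, each softmax probability p[i] is first clamped to at least EPS; the per-element terms -target[i] * log(p[i]) are then summed over the flattened arrays and divided by the total number of elements.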
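The formula can be checked by hand on plain Ruby arrays. The sketch below is a stand-alone illustration, not GRX code: the helpers `softmax` and `cross_entropy` are hypothetical, the max-logit subtraction inside `softmax` is an assumed overflow guard (it does not change the probabilities), and the final division by the element count mirrors what `#call` does.

```ruby
# Hypothetical helpers for illustration only; they are not part of GRX.

# Softmax over a flat array. Subtracting the max logit (an assumed overflow
# guard) leaves the resulting probabilities unchanged.
def softmax(logits)
  m = logits.max
  exps = logits.map { |x| Math.exp(x - m) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Softmax + NLL: clamp each probability to at least eps, take -t * log(p)
# per element, then divide the sum by the element count (as #call does).
def cross_entropy(logits, target, eps: 1e-7)
  probs = softmax(logits).map { |p| p < eps ? eps : p }
  terms = probs.zip(target).map { |p, t| -t * Math.log(p) }
  terms.sum / terms.size
end

logits = [2.0, 1.0, 0.1]
target = [1.0, 0.0, 0.0] # one-hot: class 0 is the correct class

probs = softmax(logits)
puts probs.inspect                 # three probabilities summing to 1
puts cross_entropy(logits, target) # equals -log(probs[0]) / 3
```

With a one-hot target only the correct class contributes, so the result is -log(p[0]) divided by the number of elements (3 here), roughly 0.139.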
Constant Summary
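- EPS = 1e-7
  Lower bound applied to each softmax probability in #call before Math.log, so the log is always finite.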
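Why the clamp matters: `Math.log(0.0)` in Ruby returns `-Infinity`, so a softmax probability that underflows to zero would make the loss infinite. The snippet below is an illustrative sketch that repeats the `v < EPS ? EPS : v` clamp from `#call` on its own; the lambda name `clamp` is hypothetical.

```ruby
EPS = 1e-7

# Same clamp as in CrossEntropyLoss#call; `clamp` is a hypothetical name.
clamp = ->(v) { v < EPS ? EPS : v }

raw = Math.log(0.0)               # -Infinity: an underflowed probability breaks the loss
safe = Math.log(clamp.call(0.0))  # about -16.118, i.e. -7 * ln(10)

puts raw
puts safe
# With the clamp, one element contributes at most -t * log(EPS), about 16.118 * t.
puts(-safe)
```

The trade-off: once the correct-class probability drops below EPS, the clamp caps that element's term, so the loss stops growing with how wrong the prediction is.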
Instance Method Summary
- #call(logits, target)
  Softmax cross-entropy loss between logits and target.
Instance Method Details
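#call(logits, target) ⇒ Float
Applies softmax to logits, clamps each probability to at least EPS, and returns the mean of -target * log(p) over all elements. Raises ShapeError when logits.shape != target.shape.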
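# File 'lib/grx/loss.rb', line 55

def call(logits, target)
  # Both tensors must have exactly the same shape.
  raise ShapeError, "Incompatible shapes" if logits.shape != target.shape
  # Softmax the logits, then clamp each probability to at least EPS so Math.log never sees 0.
  probs = logits.softmax.to_a.map { |v| v < EPS ? EPS : v }
  t_data = target.to_a
  # Sum -target * log(p) over every element...
  loss = probs.each_with_index.sum { |p, i| -t_data[i] * Math.log(p) }
  # ...and divide by the element count: the mean over the flattened arrays.
  loss / probs.size.to_f
end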
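To exercise the body of `#call` without the library, the sketch below copies its steps into a top-level method and feeds it a minimal tensor stand-in. `FakeTensor`, `cross_entropy_call`, and the local `ShapeError` are hypothetical; they only provide the `shape`, `to_a`, and `softmax` methods the code relies on, and the real GRX tensor's softmax (for example, the axis it normalizes over) may behave differently.

```ruby
# Hypothetical stand-ins; they are not the GRX API.
class ShapeError < StandardError; end

# A 1-D "tensor" exposing only what #call uses: shape, to_a, softmax.
FakeTensor = Struct.new(:data) do
  def shape
    [data.size]
  end

  def to_a
    data.dup
  end

  # Softmax over all elements, with the max logit subtracted for stability.
  def softmax
    m = data.max
    exps = data.map { |x| Math.exp(x - m) }
    total = exps.sum
    self.class.new(exps.map { |e| e / total })
  end
end

EPS = 1e-7

# Same steps as CrossEntropyLoss#call.
def cross_entropy_call(logits, target)
  raise ShapeError, "Incompatible shapes" if logits.shape != target.shape
  probs = logits.softmax.to_a.map { |v| v < EPS ? EPS : v }
  t_data = target.to_a
  loss = probs.each_with_index.sum { |p, i| -t_data[i] * Math.log(p) }
  loss / probs.size.to_f
end

logits = FakeTensor.new([2.0, 1.0, 0.1])
target = FakeTensor.new([1.0, 0.0, 0.0])
puts cross_entropy_call(logits, target)

begin
  cross_entropy_call(logits, FakeTensor.new([1.0, 0.0]))
rescue ShapeError => e
  puts "ShapeError: #{e.message}"
end
```

The second call raises because the shapes [3] and [2] differ; that shape comparison is the only validation #call performs before computing the loss.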