Class: GRX::NN::BatchNorm1d

Inherits:
Module
  • Object
show all
Defined in:
lib/grx/nn.rb

Overview

BatchNorm1d — Normalización por batch. Estabiliza el entrenamiento de redes profundas.

Instance Method Summary collapse

Methods inherited from Module

#call, #parameters, #zero_grad

Constructor Details

#initialize(num_features, epsilon: 1e-5, momentum: 0.1) ⇒ BatchNorm1d

Returns a new instance of BatchNorm1d.



197
198
199
200
201
202
203
204
205
206
207
208
209
210
# File 'lib/grx/nn.rb', line 197

# Builds a 1-D batch-normalization layer.
#
# @param num_features [Integer] size of the feature dimension
# @param epsilon [Float] numerical-stability constant added to the variance
# @param momentum [Float] EMA weight for the running statistics
def initialize(num_features, epsilon: 1e-5, momentum: 0.1)
  @num_features, @epsilon, @momentum = num_features, epsilon, momentum
  @training = true

  # Learnable affine parameters (scale and shift).
  @gamma = Tensor.ones([num_features],  requires_grad: true)
  @beta  = Tensor.zeros([num_features], requires_grad: true)

  # Running statistics — not trainable, consumed at inference time.
  @running_mean = Tensor.zeros([num_features])
  @running_var  = Tensor.ones([num_features])
end

Instance Method Details

#eval!Object



213
# File 'lib/grx/nn.rb', line 213

# Switches the layer to inference mode (running statistics are used).
# Returns self so calls can be chained.
def eval!
  @training = false
  self
end

#forward(x) ⇒ Object



215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# File 'lib/grx/nn.rb', line 215

# Normalizes a batch: y = gamma * (x - mean) / sqrt(var + eps) + beta.
#
# In training mode the batch mean/variance are used and the running
# statistics are updated via an exponential moving average; in eval mode
# the running statistics are used instead.
#
# NOTE(review): the math below goes through plain Ruby arrays, so no
# autograd graph is built here — gradients presumably will not flow to
# @gamma/@beta. Confirm whether Tensor ops should be used instead.
#
# @param x [Tensor] input of shape [batch, num_features]
# @return [Tensor] normalized output, same shape as x
def forward(x)
  # x: [batch, num_features]
  batch_size = x.shape[0]

  if @training
    # Column-major view: cols[j] holds every batch value of feature j.
    # Built once, instead of re-slicing the flat data per feature (was O(n^2)).
    cols = x.to_a.each_slice(@num_features).to_a.transpose

    # Per-feature batch statistics (biased variance, dividing by batch_size).
    means = cols.map { |col| col.sum / batch_size }
    vars  = cols.each_with_index.map do |col, j|
      col.sum { |v| (v - means[j]) ** 2 } / batch_size
    end

    # Update running statistics in one pass and rebuild each tensor once
    # (the original rebuilt the full Tensor once per feature).
    rm = @running_mean.to_a
    rv = @running_var.to_a
    @num_features.times do |j|
      rm[j] = (1 - @momentum) * rm[j] + @momentum * means[j]
      rv[j] = (1 - @momentum) * rv[j] + @momentum * vars[j]
    end
    @running_mean = Tensor.create(rm, [@num_features])
    @running_var  = Tensor.create(rv, [@num_features])

    mean_arr = means
    var_arr  = vars
  else
    mean_arr = @running_mean.to_a
    var_arr  = @running_var.to_a
  end

  # Hoist per-feature lookups out of the element loop; the original called
  # to_a on four tensors for every single element.
  gamma_arr = @gamma.to_a
  beta_arr  = @beta.to_a
  std_arr   = Array.new(@num_features) { |j| Math.sqrt(var_arr[j] + @epsilon) }

  # Normalize, then apply the learnable affine transform.
  norm_data = x.to_a.each_slice(@num_features).flat_map do |row|
    row.each_with_index.map do |v, j|
      x_hat = (v - mean_arr[j]) / std_arr[j]
      gamma_arr[j] * x_hat + beta_arr[j]
    end
  end

  Tensor.create(norm_data, x.shape)
end

#to_sObject



259
# File 'lib/grx/nn.rb', line 259

def to_s = "BatchNorm1d(#{@num_features})"

#train!Object



212
# File 'lib/grx/nn.rb', line 212

# Switches the layer to training mode (batch statistics are used).
# Returns self so calls can be chained.
def train!
  @training = true
  self
end