Class: GRX::Optim::Adam

Inherits:
Object
Defined in:
lib/grx/optim.rb

Overview

Adam: Adaptive Moment Estimation (Kingma & Ba, 2015). The standard optimizer for deep neural networks.
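The update implemented by #step below is the standard bias-corrected Adam rule. As a reference, using the usual notation from the paper mapped onto the constructor keywords (lr, beta1, beta2, epsilon):

  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
  v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
  \hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t)
  \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

When weight_decay > 0, g_t is first replaced by g_t + weight_decay * \theta_{t-1}, i.e. classic L2 regularization folded into the gradient rather than AdamW-style decoupled decay.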

Instance Method Summary

#step ⇒ Object
#zero_grad ⇒ Object

Constructor Details

#initialize(params, lr: 0.001, beta1: 0.9, beta2: 0.999, epsilon: 1e-8, weight_decay: 0.0) ⇒ Adam

Returns a new instance of Adam.



# File 'lib/grx/optim.rb', line 57

def initialize(params, lr: 0.001, beta1: 0.9, beta2: 0.999,
               epsilon: 1e-8, weight_decay: 0.0)
  @params       = params
  @lr           = lr
  @beta1        = beta1
  @beta2        = beta2
  @epsilon      = epsilon
  @weight_decay = weight_decay
  @t            = 0  # current step

  # First- and second-moment estimates (initialized to zero)
  @m = params.map { |p| Tensor.zeros_like(p) }
  @v = params.map { |p| Tensor.zeros_like(p) }
end
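A minimal construction sketch. Here model and its params reader are hypothetical placeholders for however your code collects the gradient-tracking Tensors to optimize:

params = model.params  # hypothetical: an Array of GRX Tensors with gradients enabled
optim  = GRX::Optim::Adam.new(params, lr: 1e-3, weight_decay: 0.01)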

Instance Method Details

#step ⇒ Object



# File 'lib/grx/optim.rb', line 72

def step
  @t += 1
  beta1t = @beta1 ** @t  # beta1^t for bias correction
  beta2t = @beta2 ** @t

  @params.each_with_index do |param, i|
    next unless param.grad

    grad = param.grad

    # Classic (non-decoupled) L2 weight decay: fold weight_decay * param into the gradient
    if @weight_decay > 0
      grad = grad + param.scale(@weight_decay)
    end

    # Use the native C kernel when the extension is loaded
    if CAPI::LOADED
      CAPI.grx_adam_step(
        param.storage.ptr,
        @m[i].storage.ptr,
        @v[i].storage.ptr,
        grad.storage.ptr,
        @lr, @beta1, @beta2, @epsilon,
        beta1t, beta2t,
        param.numel
      )
    else
      # Pure-Ruby fallback: update moments, apply bias correction, step the parameters
      p_data = param.to_a
      m_data = @m[i].to_a
      v_data = @v[i].to_a
      g_data = grad.to_a
      p_data.each_with_index do |_, j|
        m_data[j] = @beta1 * m_data[j] + (1 - @beta1) * g_data[j]
        v_data[j] = @beta2 * v_data[j] + (1 - @beta2) * g_data[j] ** 2
        mh = m_data[j] / (1 - beta1t)
        vh = v_data[j] / (1 - beta2t)
        p_data[j] -= @lr * mh / (Math.sqrt(vh) + @epsilon)
      end
      param.storage.instance_variable_set(:@data, p_data)
      @m[i].storage.instance_variable_set(:@data, m_data)
      @v[i].storage.instance_variable_set(:@data, v_data)
    end
  end
end
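A sketch of one training iteration. The forward and backward calls are hypothetical stand-ins for GRX's actual autograd API, which this page does not document:

optim.zero_grad             # clear the previous iteration's gradients
loss = model.call(batch)    # hypothetical forward pass returning a scalar loss Tensor
loss.backward               # hypothetical: populates param.grad for every parameter
optim.step                  # apply the Adam update described above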

#zero_grad ⇒ Object



# File 'lib/grx/optim.rb', line 116

def zero_grad
  @params.each(&:zero_grad!)
end
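The presence of zero_grad (delegating to each parameter's zero_grad!) suggests gradients accumulate across backward passes; if so, call it once per iteration before computing new gradients, as in the training-loop sketch above.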