Class: GRX::Optim::Adam
- Inherits: Object
- Defined in: lib/grx/optim.rb
Overview
Adam — Adaptive Moment Estimation (Kingma & Ba, 2015). The standard optimizer for deep neural networks.
Instance Method Summary
- #initialize(params, lr: 0.001, beta1: 0.9, beta2: 0.999, epsilon: 1e-8, weight_decay: 0.0) ⇒ Adam
  constructor: A new instance of Adam.
- #step ⇒ Object
- #zero_grad ⇒ Object
Constructor Details
#initialize(params, lr: 0.001, beta1: 0.9, beta2: 0.999, epsilon: 1e-8, weight_decay: 0.0) ⇒ Adam
Returns a new instance of Adam.
# File 'lib/grx/optim.rb', line 57

def initialize(params, lr: 0.001, beta1: 0.9, beta2: 0.999, epsilon: 1e-8, weight_decay: 0.0)
  @params = params
  @lr = lr
  @beta1 = beta1
  @beta2 = beta2
  @epsilon = epsilon
  @weight_decay = weight_decay
  @t = 0 # current step
  # First- and second-order moments (initialized to zero)
  @m = params.map { |p| Tensor.zeros_like(p) }
  @v = params.map { |p| Tensor.zeros_like(p) }
end
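Because the moment buffers @m and @v start at zero, their estimates are biased toward zero early in training; Adam compensates by dividing by 1 - beta^t. A minimal pure-Ruby sketch (plain floats, no GRX types involved) showing how large these correction factors are at first and how quickly they decay toward 1:

```ruby
# Bias-correction factors 1 / (1 - beta^t) for Adam's default betas.
# At t = 1 the first-moment factor is 10x (beta1 = 0.9); both factors
# approach 1 as t grows, so the correction mostly matters early on.
beta1 = 0.9
beta2 = 0.999

factors = [1, 10, 100, 1000].map do |t|
  [t, 1.0 / (1 - beta1 ** t), 1.0 / (1 - beta2 ** t)]
end

factors.each do |t, c1, c2|
  puts format("t=%4d  1/(1-beta1^t)=%.4f  1/(1-beta2^t)=%.4f", t, c1, c2)
end
```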
Instance Method Details
#step ⇒ Object
# File 'lib/grx/optim.rb', line 72

def step
  @t += 1
  beta1t = @beta1 ** @t # beta1^t for bias correction
  beta2t = @beta2 ** @t
  @params.each_with_index do |param, i|
    next unless param.grad
    grad = param.grad
    if @weight_decay > 0
      grad = grad + param.scale(@weight_decay)
    end
    if CAPI::LOADED
      CAPI.grx_adam_step(
        param.storage.ptr, @m[i].storage.ptr, @v[i].storage.ptr,
        grad.storage.ptr,
        @lr, @beta1, @beta2, @epsilon, beta1t, beta2t, param.numel
      )
    else
      # Pure-Ruby fallback
      p_data = param.to_a
      m_data = @m[i].to_a
      v_data = @v[i].to_a
      g_data = grad.to_a
      p_data.each_with_index do |_, j|
        m_data[j] = @beta1 * m_data[j] + (1 - @beta1) * g_data[j]
        v_data[j] = @beta2 * v_data[j] + (1 - @beta2) * g_data[j] ** 2
        mh = m_data[j] / (1 - beta1t)
        vh = v_data[j] / (1 - beta2t)
        p_data[j] -= @lr * mh / (Math.sqrt(vh) + @epsilon)
      end
      param.storage.instance_variable_set(:@data, p_data)
      @m[i].storage.instance_variable_set(:@data, m_data)
      @v[i].storage.instance_variable_set(:@data, v_data)
    end
  end
end
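The pure-Ruby fallback path can be exercised on plain Float arrays without any GRX types. A minimal sketch (the `adam_step!` helper is hypothetical, not part of the library) showing that on the first step the bias corrections exactly cancel the (1 - beta) factors, so the update magnitude is roughly lr regardless of the gradient's scale:

```ruby
# One Adam step on plain Float arrays, mirroring the pure-Ruby
# fallback in #step (no Tensor or CAPI involved).
def adam_step!(params, grads, m, v, t, lr: 0.001, beta1: 0.9, beta2: 0.999, epsilon: 1e-8)
  beta1t = beta1 ** t
  beta2t = beta2 ** t
  params.each_index do |j|
    m[j] = beta1 * m[j] + (1 - beta1) * grads[j]
    v[j] = beta2 * v[j] + (1 - beta2) * grads[j] ** 2
    mh = m[j] / (1 - beta1t)  # bias-corrected first moment
    vh = v[j] / (1 - beta2t)  # bias-corrected second moment
    params[j] -= lr * mh / (Math.sqrt(vh) + epsilon)
  end
end

params = [1.0]
m = [0.0]
v = [0.0]
adam_step!(params, [1.0], m, v, 1)
# At t = 1, mh = g and vh = g**2, so the update is lr * g / |g|,
# i.e. ~lr in magnitude (up to epsilon): params[0] is ~0.999.
puts params[0]
```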
#zero_grad ⇒ Object
# File 'lib/grx/optim.rb', line 116

def zero_grad
  @params.each(&:zero_grad!)
end
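Note that when weight_decay > 0, #step adds weight_decay * param to the raw gradient before the moment updates (coupled L2 regularization, not the decoupled AdamW variant). A minimal pure-Ruby sketch of that pre-step transformation:

```ruby
# With weight_decay > 0, the effective gradient fed into the Adam
# moment updates is grad + weight_decay * param, so parameters are
# pulled toward zero even when the loss gradient is zero.
weight_decay = 0.01
param = 2.0
raw_grad = 0.0

effective_grad = raw_grad + weight_decay * param
puts effective_grad  # 0.02: nonzero even though the loss gradient is zero
```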