Class: GRX::Optim::SGD

Inherits:
Object
  • Object
show all
Defined in:
lib/grx/optim.rb

Overview

SGD — Stochastic Gradient Descent (with optional momentum and L2 weight decay)

Instance Method Summary collapse

Constructor Details

#initialize(params, lr: 0.01, momentum: 0.0, weight_decay: 0.0) ⇒ SGD

Returns a new instance of SGD.



9
10
11
12
13
14
15
16
# File 'lib/grx/optim.rb', line 9

# Builds an SGD optimizer over the given parameters.
#
# @param params [#map] collection of parameters to optimize; each element is
#   passed to Tensor.zeros_like to create its velocity buffer
# @param lr [Numeric] learning rate
# @param momentum [Numeric] momentum coefficient
# @param weight_decay [Numeric] L2 regularization coefficient
def initialize(params, lr: 0.01, momentum: 0.0, weight_decay: 0.0)
  @params = params
  @lr, @momentum, @weight_decay = lr, momentum, weight_decay
  # Velocity buffer for momentum: one zero-filled tensor per parameter
  @velocity = @params.map { |tensor| Tensor.zeros_like(tensor) }
end

Instance Method Details

#step ⇒ Object



18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/grx/optim.rb', line 18

# Performs one optimization step over every tracked parameter.
#
# Parameters whose +grad+ is nil/false are skipped. For the others the
# effective gradient is assembled as follows:
#   1. grad = grad + param.scale(weight_decay)   (only when weight_decay > 0)
#   2. v    = v.scale(momentum) + grad           (only when momentum > 0;
#                                                 v then replaces grad)
# The update is applied through CAPI.grx_sgd_step when CAPI::LOADED is
# truthy, otherwise through a pure-Ruby element-wise fallback.
#
# @return [Object] whatever +@params.each_with_index+ returns
def step
  @params.each_with_index do |tensor, idx|
    update = tensor.grad
    next unless update

    # L2 regularization (weight decay)
    update += tensor.scale(@weight_decay) if @weight_decay.positive?

    if @momentum.positive?
      # v = momentum * v + grad; the refreshed velocity becomes the update
      update = @velocity[idx] = @velocity[idx].scale(@momentum) + update
    end

    if CAPI::LOADED
      # Native path — presumably computes param -= lr * grad in place
      # (NOTE(review): confirm against the C implementation).
      CAPI.grx_sgd_step(tensor.storage.ptr, update.storage.ptr, @lr, tensor.numel)
      next
    end

    # Ruby fallback: element-wise param[j] -= lr * grad[j], then the result
    # is written straight into the storage's @data instance variable.
    values = tensor.to_a
    deltas = update.to_a
    values.map!.with_index { |value, j| value - @lr * deltas[j] }
    tensor.storage.instance_variable_set(:@data, values)
  end
end

#zero_grad ⇒ Object



47
48
49
# File 'lib/grx/optim.rb', line 47

# Calls +zero_grad!+ on every tracked parameter.
#
# @return [Object] whatever +@params.each+ returns
def zero_grad
  @params.each { |tensor| tensor.zero_grad! }
end