Class: Ignis::AI::Optim::SGD

Inherits:
Base
  • Object
show all
Defined in:
lib/nnw/ai/optim/sgd.rb

Overview

SGD optimizer with optional momentum and weight decay.

Instance Attribute Summary

Attributes inherited from Base

#params, #step_count

Instance Method Summary collapse

Methods inherited from Base

#clip_grad_norm!, #lr, #lr=, #zero_grad!

Constructor Details

#initialize(params, lr:, momentum: 0.0, weight_decay: 0.0) ⇒ SGD

Returns a new instance of SGD.

Parameters:

  • params (Array<Tensor>)
  • lr (Float)

    learning rate

  • momentum (Float) (defaults to: 0.0)

    momentum factor (0.0 = no momentum)

  • weight_decay (Float) (defaults to: 0.0)

    L2 penalty



14
15
16
17
18
19
20
21
22
23
24
25
26
27
# File 'lib/nnw/ai/optim/sgd.rb', line 14

# Builds an SGD optimizer over +params+.
#
# The learning rate is handed to the superclass constructor; momentum and
# weight decay are coerced with +to_f+ and kept on the instance. Velocity
# buffers are created only when momentum is strictly positive: one
# zero-filled NvArray per parameter, matching the parameter's shape, dtype
# and device, stored under the parameter's +object_id+.
#
# @param params [Array<Tensor>] parameters to optimize
# @param lr [Float] learning rate
# @param momentum [Float] momentum factor (0.0 = no momentum)
# @param weight_decay [Float] L2 penalty
def initialize(params, lr:, momentum: 0.0, weight_decay: 0.0)
  super(params, lr: lr)
  @momentum = momentum.to_f
  @weight_decay = weight_decay.to_f
  @velocities = {}

  # Without momentum no velocity state is needed, so the table stays empty.
  return unless @momentum > 0.0

  params.each do |param|
    buffer = Ignis::Shared::NvArray.new(
      shape: param.shape,
      dtype: param.dtype,
      device_id: param.device_id
    )
    # Start every velocity at zero by uploading a host-side array of zeros.
    zeros = Array.new(param.numel, 0.0)
    buffer.from_host(zeros)
    @velocities[param.object_id] = buffer
  end
end

Instance Method Details

#step ⇒ void

This method returns an undefined value.



30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/nnw/ai/optim/sgd.rb', line 30

# Performs one optimization step over all parameters.
#
# For every parameter that has a gradient, launches the +sgd_step+ kernel
# with 256 threads per block and enough blocks to cover +numel+ elements,
# passing the parameter data, its gradient, a velocity buffer, the learning
# rate, momentum, weight decay and the element count. Afterwards it calls
# +Ignis.synchronize+ and increments +@step_count+.
#
# Parameters without a gradient are skipped, as are parameters with zero
# elements.
#
# @return [void]
def step
  kernel = Ignis::JIT::Kernels::Optimizer.sgd_step

  @params.each do |p|
    grad = p.grad
    next unless grad

    n = p.numel
    # A zero-element parameter would produce grid: [0]. There is nothing to
    # update, and a zero-sized grid is presumably an invalid launch
    # configuration (as it is in CUDA) -- NOTE(review): confirm against the
    # Ignis kernel launcher.
    next if n.zero?

    # When no velocity buffer exists for this parameter, the gradient is
    # passed in the velocity slot so the kernel still gets an argument there.
    velocity = @velocities[p.object_id] || grad

    kernel.launch(grid: [(n + 255) / 256], block: [256],
                  args: [p.data, grad, velocity, @lr, @momentum, @weight_decay, n])
  end
  Ignis.synchronize

  @step_count += 1
end