Class: Ignis::AI::Optim::Adam

Inherits:
Base
  • Object
Defined in:
lib/nnw/ai/optim/adam.rb

Overview

Adam optimizer (Kingma & Ba, 2014). Each parameter tensor is handled by a single fused kernel launch that updates m (the first-moment estimate), v (the second-moment estimate), and the parameter itself.
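For reference, the per-element arithmetic of one fused update can be sketched on the CPU in plain Ruby. The kernel source is not shown on this page, so two details are assumptions taken from the standard algorithm: weight_decay is applied as an L2 term added to the gradient, and eps is added after the square root of the bias-corrected second moment. adam_update_element is a hypothetical helper, not part of Ignis.

```ruby
# CPU sketch of one element of an Adam update (not the Ignis kernel itself).
# Assumes L2 weight decay is folded into the gradient and eps is added
# outside the square root of the bias-corrected second moment.
def adam_update_element(param, grad, m, v, lr:, beta1:, beta2:, eps:,
                        weight_decay:, bias_correction1:, bias_correction2:)
  g = grad + weight_decay * param          # L2 term added to the gradient
  m = beta1 * m + (1.0 - beta1) * g        # first-moment moving average
  v = beta2 * v + (1.0 - beta2) * g * g    # second-moment moving average
  m_hat = m / bias_correction1             # undo the bias from zero init
  v_hat = v / bias_correction2
  param -= lr * m_hat / (Math.sqrt(v_hat) + eps)
  [param, m, v]                            # new values for all three buffers
end
```

With the defaults, a first step from param = 1.0 with grad = 0.5 moves the parameter to about 0.999: on step one m_hat equals g and sqrt(v_hat) equals |g|, so the step is almost exactly lr regardless of the gradient's magnitude.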

Instance Attribute Summary

Attributes inherited from Base

#params, #step_count

Instance Method Summary

Methods inherited from Base

#clip_grad_norm!, #lr, #lr=, #zero_grad!

Constructor Details

#initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0) ⇒ Adam

Returns a new instance of Adam.

Parameters:

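  • params (Array<Tensor>)

    tensors to optimize; zero-initialized first- and second-moment buffers with the same shape, dtype, and device are allocated for each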
  • lr (Float) (defaults to: 1e-3)

    learning rate

  • beta1 (Float) (defaults to: 0.9)

    exponential decay rate for the first-moment (mean) estimate

  • beta2 (Float) (defaults to: 0.999)

    exponential decay rate for the second-moment (uncentered variance) estimate

  • eps (Float) (defaults to: 1e-8)

    small constant added to the denominator for numerical stability

  • weight_decay (Float) (defaults to: 0.0)

    L2 regularization (weight decay) coefficient
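Written out, these parameters correspond to the standard Adam recurrences (Kingma & Ba, 2014), with α = lr, β₁ = beta1, β₂ = beta2, ε = eps, λ = weight_decay, and t = step_count. Folding λθ into the gradient and placing ε outside the square root follow the paper and the "L2 regularization" description; they are assumptions about the kernel, whose source is not shown here.

```latex
\begin{aligned}
g_t &= \nabla_\theta f(\theta_{t-1}) + \lambda\,\theta_{t-1} \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
```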



# File 'lib/nnw/ai/optim/adam.rb', line 17

def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999,
               eps: 1e-8, weight_decay: 0.0)
  super(params, lr: lr)
  @beta1 = beta1.to_f
  @beta2 = beta2.to_f
  @eps = eps.to_f
  @weight_decay = weight_decay.to_f

  # Initialize first and second moment estimates to zero
  @m_states = {}
  @v_states = {}
  params.each do |p|
    m = Ignis::Shared::NvArray.new(shape: p.shape, dtype: p.dtype, device_id: p.device_id)
    m.from_host(Array.new(p.numel, 0.0))
    @m_states[p.object_id] = m

    v = Ignis::Shared::NvArray.new(shape: p.shape, dtype: p.dtype, device_id: p.device_id)
    v.from_host(Array.new(p.numel, 0.0))
    @v_states[p.object_id] = v
  end
end
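The zero initialization above is why #step divides by 1 − beta1**t and 1 − beta2**t. The plain-Ruby sketch below (ema_with_correction is a hypothetical helper) feeds a constant gradient g into a zero-initialized moving average. For a constant input, m_t = (1 − beta^t)·g, so the raw average starts far below g while the bias-corrected value equals g at every step.

```ruby
# Returns [raw EMA, bias-corrected EMA] for each of `steps` updates with a
# constant input g, starting from a zero-initialized average.
def ema_with_correction(g, beta, steps)
  m = 0.0
  (1..steps).map do |t|
    m = beta * m + (1.0 - beta) * g
    [m, m / (1.0 - beta**t)]
  end
end
```

With beta = 0.9 and g = 2.0 the raw first value is 0.2; with beta = 0.999 (the second-moment default) the raw value after one step is 1000 times too small, which is why the correction matters most for v early in training.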

Instance Method Details

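#step ⇒ void

Increments step_count and performs one Adam update on every parameter that has a gradient (parameters whose grad is nil are skipped), then calls Ignis.synchronize so the updates are complete before the method returns.

This method returns an undefined value.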



# File 'lib/nnw/ai/optim/adam.rb', line 40

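def step
  @step_count += 1
  kernel = Ignis::JIT::Kernels::Optimizer.adam_step

  # Bias corrections depend only on the step count, so compute them once per step
  bias_correction1 = 1.0 - @beta1**@step_count
  bias_correction2 = 1.0 - @beta2**@step_count

  @params.each do |p|
    next unless p.grad # skip parameters that received no gradient
    n = p.numel
    m = @m_states[p.object_id]
    v = @v_states[p.object_id]

    # Launch at least n threads: ceil(n / 256) blocks of 256 threads each
    kernel.launch(grid: [(n + 255) / 256], block: [256],
                  args: [p.data, p.grad, m, v,
                         @lr, @beta1, @beta2, @eps, @weight_decay,
                         bias_correction1.to_f, bias_correction2.to_f, n])
  end
  # Wait for all launched kernels to finish before returning
  Ignis.synchronize
end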
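For checking GPU results on small tensors, a pure-Ruby CPU analogue of #initialize and #step can be handy. FakeTensor, CpuAdam, and launch_blocks are hypothetical names for this sketch only: parameters are flat Float arrays rather than Ignis tensors, the per-element math carries the same weight-decay and eps assumptions as the standard algorithm, and each inner-loop iteration stands in for one GPU thread.

```ruby
# Hypothetical stand-in for an Ignis tensor: flat Float data plus a gradient.
FakeTensor = Struct.new(:data, :grad)

# Number of 256-thread blocks #step launches for n elements, i.e. ceil(n / 256).
def launch_blocks(n, block_size = 256)
  (n + block_size - 1) / block_size
end

# CPU sketch mirroring Adam#initialize and Adam#step.
class CpuAdam
  attr_reader :step_count

  def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0)
    @params = params
    @lr, @beta1, @beta2 = lr.to_f, beta1.to_f, beta2.to_f
    @eps, @weight_decay = eps.to_f, weight_decay.to_f
    @step_count = 0
    # Zero-initialized moment buffers keyed by object identity, as in Adam#initialize
    @m_states = params.each_with_object({}) { |p, h| h[p.object_id] = Array.new(p.data.length, 0.0) }
    @v_states = params.each_with_object({}) { |p, h| h[p.object_id] = Array.new(p.data.length, 0.0) }
  end

  def step
    @step_count += 1
    bc1 = 1.0 - @beta1**@step_count
    bc2 = 1.0 - @beta2**@step_count
    @params.each do |p|
      next unless p.grad # parameters without a gradient are left untouched
      m = @m_states[p.object_id]
      v = @v_states[p.object_id]
      p.data.each_index do |i| # one iteration plays the role of one GPU thread
        g = p.grad[i] + @weight_decay * p.data[i]
        m[i] = @beta1 * m[i] + (1.0 - @beta1) * g
        v[i] = @beta2 * v[i] + (1.0 - @beta2) * g * g
        p.data[i] -= @lr * (m[i] / bc1) / (Math.sqrt(v[i] / bc2) + @eps)
      end
    end
  end
end
```

Minimizing f(x) = x² from x = 1.0 with lr: 0.1 (gradient 2x), the first step moves x to about 0.9, again a step of roughly lr independent of the gradient's scale.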