Class: Ignis::AI::Optim::AdamW

Inherits:
Base
  • Object
Defined in:
lib/nnw/ai/optim/adamw.rb

Overview

AdamW optimizer (Loshchilov & Hutter, 2019). The key difference from Adam with L2 regularization is that weight decay is decoupled from the gradient update: the decay is applied directly to the parameters instead of being added to the gradient.
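To make the distinction concrete, here is a minimal scalar sketch of a single first step (zero-initialized moments) under both schemes. It is plain Ruby on Floats and does not use Ignis; the method names `adam_l2_first_step` and `adamw_first_step` are hypothetical, and the formulas follow the standard update from the paper rather than anything confirmed about this class's kernel.

```ruby
# Illustration only: one optimizer step on a single Float, starting from
# zero moment estimates. These helpers are hypothetical, not part of Ignis.

# Adam with L2 regularization: the decay term is added to the gradient,
# so it passes through the adaptive normalization with everything else.
def adam_l2_first_step(theta, grad, lr:, wd:, beta1: 0.9, beta2: 0.999, eps: 1e-8)
  g = grad + wd * theta
  m_hat = ((1.0 - beta1) * g) / (1.0 - beta1)      # bias-corrected first moment
  v_hat = ((1.0 - beta2) * g * g) / (1.0 - beta2)  # bias-corrected second moment
  theta - lr * m_hat / (Math.sqrt(v_hat) + eps)
end

# AdamW: the moments see only the raw gradient, and the decay is applied
# directly to the parameter as a separate term.
def adamw_first_step(theta, grad, lr:, wd:, beta1: 0.9, beta2: 0.999, eps: 1e-8)
  m_hat = ((1.0 - beta1) * grad) / (1.0 - beta1)
  v_hat = ((1.0 - beta2) * grad * grad) / (1.0 - beta2)
  theta - lr * m_hat / (Math.sqrt(v_hat) + eps) - lr * wd * theta
end

l2 = adam_l2_first_step(1.0, 0.5, lr: 0.1, wd: 0.01)
w  = adamw_first_step(1.0, 0.5, lr: 0.1, wd: 0.01)
puts format("Adam + L2: %.6f  AdamW: %.6f", l2, w)
```

On the first step the normalized Adam update has magnitude close to `lr` regardless of the gradient's scale, so the L2 term is largely absorbed by the normalization, while the decoupled version still shrinks the parameter by `lr * wd * theta`.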

Instance Attribute Summary

Attributes inherited from Base

#params, #step_count

Instance Method Summary

Methods inherited from Base

#clip_grad_norm!, #lr, #lr=, #zero_grad!

Constructor Details

#initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01) ⇒ AdamW

Returns a new instance of AdamW.

Parameters:

  • params (Array<Tensor>)
  • lr (Float) (defaults to: 1e-3)
  • beta1 (Float) (defaults to: 0.9)
  • beta2 (Float) (defaults to: 0.999)
  • eps (Float) (defaults to: 1e-8)
  • weight_decay (Float) (defaults to: 0.01): decoupled weight decay



# File 'lib/nnw/ai/optim/adamw.rb', line 18

def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999,
               eps: 1e-8, weight_decay: 0.01)
  super(params, lr: lr)
  @beta1 = beta1.to_f
  @beta2 = beta2.to_f
  @eps = eps.to_f
  @weight_decay = weight_decay.to_f

  # Per-parameter moment buffers, keyed by each parameter's object_id.
  @m_states = {} # first-moment estimates
  @v_states = {} # second-moment estimates
  params.each do |p|
    # Zero-initialized buffers matching the parameter's shape, dtype, and device.
    m = Ignis::Shared::NvArray.new(shape: p.shape, dtype: p.dtype, device_id: p.device_id)
    m.from_host(Array.new(p.numel, 0.0))
    @m_states[p.object_id] = m

    v = Ignis::Shared::NvArray.new(shape: p.shape, dtype: p.dtype, device_id: p.device_id)
    v.from_host(Array.new(p.numel, 0.0))
    @v_states[p.object_id] = v
  end
end

Instance Method Details

#step ⇒ void

This method returns an undefined value.



# File 'lib/nnw/ai/optim/adamw.rb', line 40

def step
  @step_count += 1
  kernel = Ignis::JIT::Kernels::Optimizer.adamw_step

  # Bias-correction denominators for the current step count.
  bias_correction1 = 1.0 - @beta1**@step_count
  bias_correction2 = 1.0 - @beta2**@step_count

  @params.each do |p|
    next unless p.grad # skip parameters that have no gradient
    n = p.numel
    m = @m_states[p.object_id]
    v = @v_states[p.object_id]

    # Ceiling division: enough 256-thread blocks to cover all n elements.
    kernel.launch(grid: [(n + 255) / 256], block: [256],
                  args: [p.data, p.grad, m, v,
                         @lr, @beta1, @beta2, @eps, @weight_decay,
                         bias_correction1.to_f, bias_correction2.to_f, n])
  end
  Ignis.synchronize
end
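
The source of the `adamw_step` kernel is not shown on this page, so what it computes per element cannot be confirmed here. As a rough CPU reference, the sketch below applies the standard AdamW update from the paper to plain Ruby arrays, using the same bias-correction terms that `#step` passes to the kernel. `ReferenceAdamW` and `blocks_for` are hypothetical names introduced for illustration and are not part of Ignis.

```ruby
# Hypothetical pure-Ruby reference, assuming the kernel follows the standard
# AdamW update. Parameters are Array<Float> and are mutated in place.
class ReferenceAdamW
  attr_reader :step_count

  def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999,
                 eps: 1e-8, weight_decay: 0.01)
    @params = params
    @lr = lr.to_f
    @beta1 = beta1.to_f
    @beta2 = beta2.to_f
    @eps = eps.to_f
    @weight_decay = weight_decay.to_f
    @m = params.map { |p| Array.new(p.size, 0.0) } # first moments
    @v = params.map { |p| Array.new(p.size, 0.0) } # second moments
    @step_count = 0
  end

  # grads[k] is the gradient for params[k], or nil to skip that parameter.
  def step(grads)
    @step_count += 1
    bc1 = 1.0 - @beta1**@step_count
    bc2 = 1.0 - @beta2**@step_count

    @params.each_with_index do |param, k|
      g = grads[k]
      next unless g

      param.each_index do |i|
        @m[k][i] = @beta1 * @m[k][i] + (1.0 - @beta1) * g[i]
        @v[k][i] = @beta2 * @v[k][i] + (1.0 - @beta2) * g[i] * g[i]
        m_hat = @m[k][i] / bc1
        v_hat = @v[k][i] / bc2
        # Adaptive step plus decoupled decay, both scaled by the learning rate.
        param[i] -= @lr * (m_hat / (Math.sqrt(v_hat) + @eps) + @weight_decay * param[i])
      end
    end
  end
end

# Ceiling division, as in the grid size computed by #step: (n + 255) / 256.
def blocks_for(n, block_size = 256)
  (n + block_size - 1) / block_size
end

weights = [[1.0, -2.0]]
opt = ReferenceAdamW.new(weights, lr: 0.1, weight_decay: 0.0)
opt.step([[0.5, -0.25]])
puts weights.inspect
```

A reference like this can be useful for checking GPU results on small tensors, provided the kernel really does implement the same formula.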