Module: Ignis::JIT::Kernels::Optimizer

Defined in:
lib/nvruby/jit/kernels/optimizer.rb

Overview

Optimizer CUDA kernels. Each optimizer step runs as a single fused kernel per parameter tensor: the moments and the parameter are updated in one pass, which avoids multiple kernel launches.

Class Method Summary

Class Method Details

.adam_step → Ignis::JIT::Kernel

Fused Adam step: update m, v, and param in one kernel launch

Returns: Ignis::JIT::Kernel



13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 13

# Fused Adam step kernel: updates the first moment +m+, the second moment +v+
# and the parameter +param+ for every element in a single kernel launch.
#
# Kernel arguments (in order):
#   param            [float*]       parameter buffer, updated in place
#   grad             [const float*] gradient, read only
#   m                [float*]       first-moment buffer, updated in place
#   v                [float*]       second-moment buffer, updated in place
#   lr               [float]        learning rate
#   beta1, beta2     [float]        moment decay rates
#   eps              [float]        added to sqrt(v_hat) in the denominator
#   weight_decay     [float]        coupled L2 penalty: added to the gradient
#                                   before the moments are updated; skipped
#                                   when <= 0
#   bias_correction1 [float]        divisor applied to m (presumably
#                                   1 - beta1**t, computed by the caller -- confirm)
#   bias_correction2 [float]        divisor applied to v (presumably
#                                   1 - beta2**t -- confirm)
#   n                [int]          number of elements
#
# The kernel body is a grid-stride loop, so any 1-D launch configuration
# covers all +n+ elements. A launch with one thread per element performs
# exactly the same per-element arithmetic as a single-pass kernel.
#
# @return [Ignis::JIT::Kernel] the kernel produced by +compile_cached+
def adam_step
  source = <<~CUDA
    extern "C" __global__
    void adam_step(float* __restrict__ param,
                   const float* __restrict__ grad,
                   float* __restrict__ m,
                   float* __restrict__ v,
                   const float lr,
                   const float beta1,
                   const float beta2,
                   const float eps,
                   const float weight_decay,
                   const float bias_correction1,
                   const float bias_correction2,
                   const int n) {
      // Grid-stride loop; 64-bit indices so i + stride cannot overflow.
      const long long stride = (long long)blockDim.x * gridDim.x;
      for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
           i < n; i += stride) {
        float g = grad[i];

        // L2 regularization (Adam-style, not decoupled)
        if (weight_decay > 0.0f) {
          g += weight_decay * param[i];
        }

        // Update biased first moment estimate
        float m_new = beta1 * m[i] + (1.0f - beta1) * g;
        m[i] = m_new;

        // Update biased second moment estimate
        float v_new = beta2 * v[i] + (1.0f - beta2) * g * g;
        v[i] = v_new;

        // Bias correction
        float m_hat = m_new / bias_correction1;
        float v_hat = v_new / bias_correction2;

        // Update parameter
        param[i] -= lr * m_hat / (sqrtf(v_hat) + eps);
      }
    }
  CUDA
  compile_cached(source, "adam_step")
end

.adamw_step → Ignis::JIT::Kernel

Fused AdamW step: Adam with decoupled weight decay. The weight decay is applied directly to the parameter, not through the gradient.

Returns: Ignis::JIT::Kernel



60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 60

# Fused AdamW step kernel: Adam with decoupled weight decay.
#
# The decay term is applied directly in the parameter update rather than being
# added to the gradient, so it never enters the moment estimates +m+ and +v+.
# Unlike the coupled Adam kernel, there is no `weight_decay > 0` check; the
# term is always applied, and a value of 0 simply contributes nothing.
#
# Kernel arguments (in order):
#   param            [float*]       parameter buffer, updated in place
#   grad             [const float*] gradient, read only
#   m, v             [float*]       first/second moment buffers, updated in place
#   lr, beta1, beta2 [float]        learning rate and moment decay rates
#   eps              [float]        added to sqrt(v_hat) in the denominator
#   weight_decay     [float]        decoupled decay coefficient
#   bias_correction1 [float]        divisor applied to m (presumably
#                                   1 - beta1**t, computed by the caller -- confirm)
#   bias_correction2 [float]        divisor applied to v (presumably
#                                   1 - beta2**t -- confirm)
#   n                [int]          number of elements
#
# Per-element update:
#   param -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
#
# The kernel body is a grid-stride loop, so any 1-D launch configuration
# covers all +n+ elements.
#
# @return [Ignis::JIT::Kernel] the kernel produced by +compile_cached+
def adamw_step
  source = <<~CUDA
    extern "C" __global__
    void adamw_step(float* __restrict__ param,
                    const float* __restrict__ grad,
                    float* __restrict__ m,
                    float* __restrict__ v,
                    const float lr,
                    const float beta1,
                    const float beta2,
                    const float eps,
                    const float weight_decay,
                    const float bias_correction1,
                    const float bias_correction2,
                    const int n) {
      // Grid-stride loop; 64-bit indices so i + stride cannot overflow.
      const long long stride = (long long)blockDim.x * gridDim.x;
      for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
           i < n; i += stride) {
        float g = grad[i];

        // Update biased first moment
        float m_new = beta1 * m[i] + (1.0f - beta1) * g;
        m[i] = m_new;

        // Update biased second moment
        float v_new = beta2 * v[i] + (1.0f - beta2) * g * g;
        v[i] = v_new;

        // Bias correction
        float m_hat = m_new / bias_correction1;
        float v_hat = v_new / bias_correction2;

        // Decoupled weight decay + Adam update
        param[i] -= lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param[i]);
      }
    }
  CUDA
  compile_cached(source, "adamw_step")
end

.grad_clip_scale → Ignis::JIT::Kernel

Gradient clipping by global norm, phase 2: scale the gradients by the clip factor, where clip_factor = max_norm / (total_norm + eps).

Returns: Ignis::JIT::Kernel



169
170
171
172
173
174
175
176
177
178
179
180
181
182
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 169

# Gradient clipping, phase 2: multiplies every gradient element in place by
# +clip_factor+.
#
# Kernel arguments (in order):
#   grad        [float*] gradient buffer, scaled in place
#   clip_factor [float]  scale factor; applied unconditionally, so the caller
#                        presumably limits it to <= 1 when clipping -- confirm
#   n           [int]    number of elements
#
# The kernel body is a grid-stride loop, so any 1-D launch configuration
# covers all +n+ elements.
#
# @return [Ignis::JIT::Kernel] the kernel produced by +compile_cached+
def grad_clip_scale
  source = <<~CUDA
    extern "C" __global__
    void grad_clip_scale(float* __restrict__ grad,
                         const float clip_factor,
                         const int n) {
      // Grid-stride loop; 64-bit indices so i + stride cannot overflow.
      const long long stride = (long long)blockDim.x * gridDim.x;
      for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
           i < n; i += stride) {
        grad[i] *= clip_factor;
      }
    }
  CUDA
  compile_cached(source, "grad_clip_scale")
end

.grad_squared_sum → Ignis::JIT::Kernel

Gradient clipping by global norm, phase 1: compute the sum of squared gradient entries for each parameter.

Returns: Ignis::JIT::Kernel



137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 137

# Gradient clipping, phase 1: adds the sum of squares of +grad+ into
# +partial_sum[0]+.
#
# Each block reduces its share of grad[i]^2 in shared memory, and thread 0
# then atomicAdd-s the block total into +partial_sum+. The kernel only adds
# to that value, so the caller must initialise it (e.g. to 0) before the
# launch. The result is a sum of squares; the caller presumably takes the
# square root to obtain the norm -- confirm.
#
# Launch requirements:
# * dynamic shared memory of at least blockDim.x * sizeof(float) bytes,
#   because +sdata+ is indexed by threadIdx.x;
# * any block size: the reduction below is correct for non-power-of-two
#   blockDim.x. The previous `s = blockDim.x / 2; s >>= 1` loop silently
#   dropped elements in that case.
#
# The per-thread accumulation is a grid-stride loop, so any 1-D grid covers
# all +n+ elements. For power-of-two blocks with one thread per element, the
# arithmetic matches the previous single-pass kernel.
#
# Kernel arguments: grad [const float*], partial_sum [float*], n [int].
#
# @return [Ignis::JIT::Kernel] the kernel produced by +compile_cached+
def grad_squared_sum
  source = <<~CUDA
    extern "C" __global__
    void grad_squared_sum(const float* __restrict__ grad,
                          float* __restrict__ partial_sum,
                          const int n) {
      extern __shared__ float sdata[];
      const unsigned int tid = threadIdx.x;

      // Each thread accumulates grad^2 over its grid-stride slice.
      float acc = 0.0f;
      const long long stride = (long long)blockDim.x * gridDim.x;
      for (long long i = (long long)blockIdx.x * blockDim.x + tid;
           i < n; i += stride) {
        float g = grad[i];
        acc += g * g;
      }
      sdata[tid] = acc;
      __syncthreads();

      // Tree reduction valid for any blockDim.x: when the active count is
      // odd, the middle element is carried into the next round.
      for (unsigned int active = blockDim.x; active > 1; ) {
        unsigned int half = (active + 1) / 2;
        if (tid + half < active) {
          sdata[tid] += sdata[tid + half];
        }
        __syncthreads();
        active = half;
      }

      if (tid == 0) {
        atomicAdd(partial_sum, sdata[0]);
      }
    }
  CUDA
  compile_cached(source, "grad_squared_sum")
end

.sgd_step → Ignis::JIT::Kernel

SGD step with momentum and weight decay

Returns: Ignis::JIT::Kernel



101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 101

# SGD step kernel with optional momentum and L2 weight decay.
#
# Kernel arguments (in order):
#   param        [float*]       parameter buffer, updated in place
#   grad         [const float*] gradient, read only
#   velocity     [float*]       momentum buffer; read and written only when
#                               momentum > 0, otherwise never touched
#   lr           [float]        learning rate
#   momentum     [float]        momentum coefficient; <= 0 disables momentum
#   weight_decay [float]        L2 penalty added to the gradient; skipped when <= 0
#   n            [int]          number of elements
#
# Per element:
#   g = grad + weight_decay * param
#   v = momentum * velocity + g   (stored back into velocity), or v = g
#   param -= lr * v
# There is no dampening term and no Nesterov look-ahead.
#
# The kernel body is a grid-stride loop, so any 1-D launch configuration
# covers all +n+ elements.
#
# @return [Ignis::JIT::Kernel] the kernel produced by +compile_cached+
def sgd_step
  source = <<~CUDA
    extern "C" __global__
    void sgd_step(float* __restrict__ param,
                  const float* __restrict__ grad,
                  float* __restrict__ velocity,
                  const float lr,
                  const float momentum,
                  const float weight_decay,
                  const int n) {
      // Grid-stride loop; 64-bit indices so i + stride cannot overflow.
      const long long stride = (long long)blockDim.x * gridDim.x;
      for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
           i < n; i += stride) {
        float g = grad[i];

        if (weight_decay > 0.0f) {
          g += weight_decay * param[i];
        }

        float v;
        if (momentum > 0.0f) {
          v = momentum * velocity[i] + g;
          velocity[i] = v;
        } else {
          v = g;
        }

        param[i] -= lr * v;
      }
    }
  CUDA
  compile_cached(source, "sgd_step")
end