Module: Ignis::JIT::Kernels::Optimizer
- Defined in:
- lib/nvruby/jit/kernels/optimizer.rb
Overview
Optimizer CUDA kernels. Each optimizer step is a single fused kernel per parameter (update moments + param in one pass, avoiding multiple kernel launches).
Class Method Summary
-
.adam_step ⇒ Ignis::JIT::Kernel
Fused Adam step: update m, v, and param in one kernel launch.
-
.adamw_step ⇒ Ignis::JIT::Kernel
Fused AdamW step: Adam with decoupled weight decay. Weight decay is applied directly to the parameter, not through the gradient.
-
.grad_clip_scale ⇒ Ignis::JIT::Kernel
Gradient clipping, phase 2: scale the gradients by the clip factor `clip_factor = max_norm / (total_norm + eps)`.
-
.grad_squared_sum ⇒ Ignis::JIT::Kernel
Gradient clipping by global norm, phase 1: compute the per-parameter sum of squared gradients.
-
.sgd_step ⇒ Ignis::JIT::Kernel
SGD step with momentum and weight decay.
Class Method Details
.adam_step ⇒ Ignis::JIT::Kernel
Fused Adam step: update m, v, and param in one kernel launch
(lines 13–55)
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 13
# Builds the fused Adam kernel. One thread per element updates the first
# moment (m), the second moment (v) and the parameter in a single launch.
#
# A positive weight_decay is added to the gradient before the moment updates
# (coupled L2 regularization); adamw_step is the decoupled variant.
# bias_correction1 / bias_correction2 are divisors supplied by the host
# (presumably 1 - beta1**t and 1 - beta2**t -- confirm against the caller).
# Threads with idx >= n do nothing and there is no stride loop, so the launch
# grid must cover all n elements.
#
# @return [Ignis::JIT::Kernel] the compiled kernel, obtained via compile_cached
def adam_step
  compile_cached(<<~CUDA, "adam_step")
    extern "C" __global__
    void adam_step(float* __restrict__ param,
                   const float* __restrict__ grad,
                   float* __restrict__ m,
                   float* __restrict__ v,
                   const float lr,
                   const float beta1,
                   const float beta2,
                   const float eps,
                   const float weight_decay,
                   const float bias_correction1,
                   const float bias_correction2,
                   const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float g = grad[idx];

        // L2 regularization (Adam-style, not decoupled)
        if (weight_decay > 0.0f) {
          g += weight_decay * param[idx];
        }

        // Update biased first moment estimate
        float m_new = beta1 * m[idx] + (1.0f - beta1) * g;
        m[idx] = m_new;

        // Update biased second moment estimate
        float v_new = beta2 * v[idx] + (1.0f - beta2) * g * g;
        v[idx] = v_new;

        // Bias correction
        float m_hat = m_new / bias_correction1;
        float v_hat = v_new / bias_correction2;

        // Update parameter
        param[idx] -= lr * m_hat / (sqrtf(v_hat) + eps);
      }
    }
  CUDA
end
.adamw_step ⇒ Ignis::JIT::Kernel
Fused AdamW step: Adam with decoupled weight decay. Weight decay is applied directly to the parameter, not through the gradient.
(lines 60–97)
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 60
# Builds the fused AdamW kernel. One thread per element updates m, v and the
# parameter in a single launch.
#
# Weight decay is decoupled: the raw gradient alone drives the moments, and
# the parameter is updated as
#   param -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
# bias_correction1 / bias_correction2 are host-supplied divisors (presumably
# 1 - beta1**t and 1 - beta2**t -- confirm against the caller). Threads with
# idx >= n do nothing, so the launch grid must cover all n elements.
#
# @return [Ignis::JIT::Kernel] the compiled kernel, obtained via compile_cached
def adamw_step
  compile_cached(<<~CUDA, "adamw_step")
    extern "C" __global__
    void adamw_step(float* __restrict__ param,
                    const float* __restrict__ grad,
                    float* __restrict__ m,
                    float* __restrict__ v,
                    const float lr,
                    const float beta1,
                    const float beta2,
                    const float eps,
                    const float weight_decay,
                    const float bias_correction1,
                    const float bias_correction2,
                    const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float g = grad[idx];

        // Update biased first moment
        float m_new = beta1 * m[idx] + (1.0f - beta1) * g;
        m[idx] = m_new;

        // Update biased second moment
        float v_new = beta2 * v[idx] + (1.0f - beta2) * g * g;
        v[idx] = v_new;

        // Bias correction
        float m_hat = m_new / bias_correction1;
        float v_hat = v_new / bias_correction2;

        // Decoupled weight decay + Adam update
        param[idx] -= lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param[idx]);
      }
    }
  CUDA
end
.grad_clip_scale ⇒ Ignis::JIT::Kernel
Gradient clipping, phase 2: scale the gradients by the clip factor `clip_factor = max_norm / (total_norm + eps)`.
(lines 169–182)
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 169
# Gradient clipping, phase 2: multiplies every gradient element in place by
# clip_factor.
#
# The factor arrives as a kernel argument, so the caller computes it (per the
# method summary, max_norm / (total_norm + eps)). One thread per element;
# threads with idx >= n do nothing, so the grid must cover all n elements.
#
# @return [Ignis::JIT::Kernel] the compiled kernel, obtained via compile_cached
def grad_clip_scale
  compile_cached(<<~CUDA, "grad_clip_scale")
    extern "C" __global__
    void grad_clip_scale(float* __restrict__ grad,
                         const float clip_factor,
                         const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad[idx] *= clip_factor;
      }
    }
  CUDA
end
.grad_squared_sum ⇒ Ignis::JIT::Kernel
Gradient clipping by global norm, phase 1: compute the per-parameter sum of squared gradients.
(lines 137–164)
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 137
# Gradient clipping by global norm, phase 1: adds sum(grad[i]^2) for
# i in [0, n) into *partial_sum.
#
# Each block reduces its values in dynamic shared memory, then thread 0
# atomically adds the block total into partial_sum[0]. Because the result is
# accumulated with atomicAdd, partial_sum must be initialized before the first
# launch (presumably zeroed by the caller -- confirm).
#
# Launch requirements:
# * dynamic shared memory of at least blockDim.x * sizeof(float) bytes
#   (sdata is indexed by threadIdx.x);
# * any block size: the reduction no longer assumes blockDim.x is a power of
#   two (the previous version silently dropped elements otherwise);
# * any grid size: a grid-stride loop covers all n elements even when the
#   grid is smaller than ceil(n / blockDim.x).
# For power-of-two blocks with a grid covering n, the summation order is the
# same as before.
#
# @return [Ignis::JIT::Kernel] the compiled kernel, obtained via compile_cached
def grad_squared_sum
  source = <<~CUDA
    extern "C" __global__
    void grad_squared_sum(const float* __restrict__ grad,
                          float* __restrict__ partial_sum,
                          const int n) {
      extern __shared__ float sdata[];
      const unsigned int tid = threadIdx.x;

      // Grid-stride accumulation: 64-bit indices so i + stride cannot
      // overflow a signed int near n == INT_MAX.
      const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
      float acc = 0.0f;
      for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + tid;
           i < n; i += stride) {
        const float g = grad[i];
        acc += g * g;
      }
      sdata[tid] = acc;
      __syncthreads();

      // Tree reduction in shared memory. Start at the largest power of two
      // strictly below blockDim.x and bounds-check the partner, so block
      // sizes that are not powers of two still fold in every element.
      unsigned int s = 1;
      while (s < blockDim.x) {
        s <<= 1;
      }
      for (s >>= 1; s > 0; s >>= 1) {
        if (tid < s && tid + s < blockDim.x) {
          sdata[tid] += sdata[tid + s];
        }
        // s is uniform across the block, so every thread reaches this barrier.
        __syncthreads();
      }

      if (tid == 0) {
        atomicAdd(partial_sum, sdata[0]);
      }
    }
  CUDA
  compile_cached(source, "grad_squared_sum")
end
.sgd_step ⇒ Ignis::JIT::Kernel
SGD step with momentum and weight decay
(lines 101–132)
# File 'lib/nvruby/jit/kernels/optimizer.rb', line 101
# Builds the SGD kernel (optional momentum and weight decay), one thread per
# element.
#
# g = grad (+ weight_decay * param when weight_decay > 0). When momentum > 0,
# velocity = momentum * velocity + g and param -= lr * velocity. Otherwise
# param -= lr * g and velocity is neither read nor written. Threads with
# idx >= n do nothing, so the launch grid must cover all n elements.
#
# @return [Ignis::JIT::Kernel] the compiled kernel, obtained via compile_cached
def sgd_step
  compile_cached(<<~CUDA, "sgd_step")
    extern "C" __global__
    void sgd_step(float* __restrict__ param,
                  const float* __restrict__ grad,
                  float* __restrict__ velocity,
                  const float lr,
                  const float momentum,
                  const float weight_decay,
                  const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float g = grad[idx];
        if (weight_decay > 0.0f) {
          g += weight_decay * param[idx];
        }
        float v;
        if (momentum > 0.0f) {
          v = momentum * velocity[idx] + g;
          velocity[idx] = v;
        } else {
          v = g;
        }
        param[idx] -= lr * v;
      }
    }
  CUDA
end