Class: SmolLM2KVFFICacheMetal

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/llm/engine/llama_kv_engine_metal.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initializeSmolLM2KVFFICacheMetal

Returns a new instance of SmolLM2KVFFICacheMetal.



196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 196

# Builds an empty, un-realized cache object. Every dimension starts at 0,
# every tensor/session handle starts as TinyNNMetal.tnn_null_ptr, and every
# optional feature starts disabled (false / 0 / 0.0, or 1.0 for the embed scale).
# NOTE(review): the real values are presumably filled in by the realize_* path
# (the inline notes below mention realize_for_mmap) — not shown here; confirm.
def initialize
  # Lifecycle flag and model dimensions.
  @realized   = false
  @max_T      = 0
  @d_model    = 0
  @d_ff       = 0
  @n_heads    = 0
  @n_kv       = 0
  @d_head     = 0
  @group_size = 0
  @n_layers   = 0
  @vocab_size = 0
  # RoPE and RMSNorm parameters.
  @rope_base  = 10000.0
  @rope_scaling        = Toy::RopeScaling.none
  @t_rope_freq_factors = TinyNNMetal.tnn_null_ptr
  @rms_eps    = 1.0e-5
  # Session handle and model-level tensors.
  @sess               = TinyNNMetal.tnn_null_ptr
  @t_token_embed      = TinyNNMetal.tnn_null_ptr
  @t_final_norm_gamma = TinyNNMetal.tnn_null_ptr
  @t_output           = TinyNNMetal.tnn_null_ptr
  # Architecture feature switches read by the graph builders.
  @has_untied_output  = false
  @has_qkv_bias       = false
  @has_qk_norm        = false
  @qk_norm_kind       = 0
  @swa_window         = 0
  @has_post_norms     = false
  @embed_scale        = 1.0
  @attn_softcap       = 0.0
  @final_softcap      = 0.0
  @swa_alternates     = false
  # Starts with one default-constructed block even though @n_layers is 0.
  # NOTE(review): presumably a placeholder that is replaced on realize — confirm.
  @kv_blocks_ffi      = [SmolLM2KVBlockFFIMetal.new]
  # Storage types (ggml type ids) and attention-kernel selection.
  @weight_type   = 0                # GGML_TYPE_F32; legacy default
  @kv_type_k     = 0                # GGML_TYPE_F32; opt in via enable_kv_q8!
  @kv_type_v     = 0                # GGML_TYPE_F32; opt in via enable_kv_q8!
  @use_flash_attn = false            # opt in via enable_flash_attn!
  # Mixture-of-experts configuration.
  @is_moe         = false
  @n_experts      = 0
  @n_experts_used = 0
  @gguf_handle_keepalive = TinyNNMetal.tnn_null_ptr  # set by realize_for_mmap
  # LoRA-on-Q configuration; enable_lora_q! sets the first two.
  @lora_q_enabled = false
  @lora_q_rank    = 0
  @lora_q_adamw_enabled = false
end

Instance Attribute Details

#attn_softcapObject

Returns the value of attribute attn_softcap.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Attention-logit soft-cap. Starts at 0.0 (disabled). build_attention_qhead_step
# forwards it to tnn_flash_attn_ext on the flash path, and on the non-flash path
# applies softcap * tanh(x / softcap) only when it is > 0.0.
def attn_softcap
  @attn_softcap
end

#d_ffObject

Returns the value of attribute d_ff.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts at 0. The dense SwiGLU gate/up outputs in build_block_step are
# annotated ne=[d_ff, 1], so this is presumably the FFN hidden width — confirm.
def d_ff
  @d_ff
end

#d_headObject

Returns the value of attribute d_head.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Per-head dimension. The decode builders use it as the row length for RoPE and
# the K/V cache views, and derive the attention scale 1/sqrt(d_head) from it.
# Starts at 0.
def d_head
  @d_head
end

#d_modelObject

Returns the value of attribute d_model.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Model width. build_moe_ffn uses it when reshaping the weighted expert outputs.
# Starts at 0.
def d_model
  @d_model
end

#embed_scaleObject

Returns the value of attribute embed_scale.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Multiplier applied to the looked-up token embedding in build_decode_step.
# The scale op is skipped when the value is exactly 1.0 (the initial value).
def embed_scale
  @embed_scale
end

#final_softcapObject

Returns the value of attribute final_softcap.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Soft-cap for the output logits. Starts at 0.0 (disabled); build_decode_step
# applies softcap * tanh(logits / softcap) only when it is > 0.0.
def final_softcap
  @final_softcap
end

#gguf_handle_keepaliveObject

Returns the value of attribute gguf_handle_keepalive.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts as TinyNNMetal.tnn_null_ptr; the constructor notes it is set by
# realize_for_mmap. NOTE(review): presumably holds the GGUF handle so mmap'd
# weights stay valid — confirm against realize_for_mmap.
def gguf_handle_keepalive
  @gguf_handle_keepalive
end

#group_sizeObject

Returns the value of attribute group_size.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Number of query heads that share one KV head: build_attention_qhead_step maps
# query head hq to KV head hq / group_size. Starts at 0, so it has to be set
# before a graph is built (the division would otherwise be by zero).
def group_size
  @group_size
end

#has_post_normsObject

Returns the value of attribute has_post_norms.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, build_block_step RMS-normalizes the attention output and the FFN
# output before their residual adds. Starts false.
def has_post_norms
  @has_post_norms
end

#has_qk_normObject

Returns the value of attribute has_qk_norm.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, Q and K are RMS-normalized before RoPE; qk_norm_kind selects how
# the gamma tensor is sliced. Starts false.
def has_qk_norm
  @has_qk_norm
end

#has_qkv_biasObject

Returns the value of attribute has_qkv_bias.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, the per-head Q, K and V projections have their bias tensors added
# (blk.t_b_q / t_b_k / t_b_v). Starts false.
def has_qkv_bias
  @has_qkv_bias
end

#has_untied_outputObject

Returns the value of attribute has_untied_output.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, build_decode_step computes logits against t_output; otherwise it
# reuses t_token_embed. Starts false.
def has_untied_output
  @has_untied_output
end

#is_moeObject

Returns the value of attribute is_moe.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, build_block_step routes the FFN through build_moe_ffn instead of
# the dense SwiGLU path. Starts false.
def is_moe
  @is_moe
end

#kv_blocks_ffiObject

Returns the value of attribute kv_blocks_ffi.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Array of per-layer block objects; build_decode_step indexes it by layer
# number. The constructor seeds it with one default SmolLM2KVBlockFFIMetal.
def kv_blocks_ffi
  @kv_blocks_ffi
end

#kv_type_kObject

Returns the value of attribute kv_type_k.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# ggml type id of the K cache: 0 (F32) by default, 8 (Q8_0) after enable_kv_q8!.
# build_decode_step uses it to choose the K row size in bytes.
def kv_type_k
  @kv_type_k
end

#kv_type_vObject

Returns the value of attribute kv_type_v.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# ggml type id of the V cache: 0 (F32) by default, 8 (Q8_0) after enable_kv_q8!.
# build_decode_step uses it to choose the V row size in bytes.
def kv_type_v
  @kv_type_v
end

#lora_q_adamw_enabledObject

Returns the value of attribute lora_q_adamw_enabled.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts false. NOTE(review): presumably flipped by enable_lora_q_adamw!, whose
# body is not visible here — confirm.
def lora_q_adamw_enabled
  @lora_q_adamw_enabled
end

#lora_q_enabledObject

Returns the value of attribute lora_q_enabled.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# True once enable_lora_q! has been called; build_attention_qhead_step then adds
# the LoRA term B[hq] @ (A[hq] @ h) to the raw Q projection. Starts false.
def lora_q_enabled
  @lora_q_enabled
end

#lora_q_rankObject

Returns the value of attribute lora_q_rank.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# LoRA rank recorded by enable_lora_q!; 0 until that is called.
def lora_q_rank
  @lora_q_rank
end

#max_TObject

Returns the value of attribute max_T.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts at 0. build_decode_step derives bytes_max_T = max_T * 4 from it.
# NOTE(review): presumably the number of positions the K/V caches can hold
# (their shape is described as ne=[d_head, max_T]) — confirm.
def max_T
  @max_T
end

#n_expertsObject

Returns the value of attribute n_experts.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Total number of experts; build_moe_ffn uses it to reshape the router
# probabilities. Starts at 0.
def n_experts
  @n_experts
end

#n_experts_usedObject

Returns the value of attribute n_experts_used.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Number of experts selected per token: the k passed to tnn_top_k in
# build_moe_ffn. Starts at 0.
def n_experts_used
  @n_experts_used
end

#n_headsObject

Returns the value of attribute n_heads.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Number of query heads; build_block_step builds one attention head for each
# index below n_heads. Starts at 0.
def n_heads
  @n_heads
end

#n_kvObject

Returns the value of attribute n_kv.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Number of KV heads; build_block_step computes and caches one K and one V
# vector per KV head. Starts at 0.
def n_kv
  @n_kv
end

#n_layersObject

Returns the value of attribute n_layers.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Number of transformer blocks that build_decode_step chains together.
# Starts at 0.
def n_layers
  @n_layers
end

#qk_norm_kindObject

Returns the value of attribute qk_norm_kind.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Selects the QK-norm gamma layout when has_qk_norm is set: 2 takes a per-head
# [d_head] slice of the gamma tensor; any other value passes the gamma tensor
# unchanged. Starts at 0.
def qk_norm_kind
  @qk_norm_kind
end

#realizedObject

Returns the value of attribute realized.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts false. NOTE(review): presumably set once a realize_* call has allocated
# the session and tensors — that code is not visible here; confirm.
def realized
  @realized
end

#rms_epsObject

Returns the value of attribute rms_eps.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Epsilon for RMSNorm; defaults to 1.0e-5. build_decode_step passes it down as
# the eps used by the norms in the graph.
def rms_eps
  @rms_eps
end

#rope_baseObject

Returns the value of attribute rope_base.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# RoPE frequency base handed to tnn_rope_ext; defaults to 10000.0.
def rope_base
  @rope_base
end

#rope_scalingObject

Returns the value of attribute rope_scaling.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# RoPE scaling parameters; the builders read freq_scale, ext_factor,
# attn_factor, beta_fast and beta_slow from it. Defaults to
# Toy::RopeScaling.none.
def rope_scaling
  @rope_scaling
end

#sessObject

Returns the value of attribute sess.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# TinyNNMetal session handle, passed as the first argument to the graph-building
# calls. Starts as TinyNNMetal.tnn_null_ptr.
def sess
  @sess
end

#swa_alternatesObject

Returns the value of attribute swa_alternates.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, odd-numbered layers ignore swa_window and attend over the full
# history; even layers keep the window. Starts false.
def swa_alternates
  @swa_alternates
end

#swa_windowObject

Returns the value of attribute swa_window.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Sliding-window size in positions. 0 (the default) means no window; when > 0,
# attention sees at most the last swa_window cached positions.
def swa_window
  @swa_window
end

#t_final_norm_gammaObject

Returns the value of attribute t_final_norm_gamma.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Gamma tensor for the final RMSNorm in build_decode_step. Starts as
# TinyNNMetal.tnn_null_ptr.
def t_final_norm_gamma
  @t_final_norm_gamma
end

#t_outputObject

Returns the value of attribute t_output.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Output-projection tensor used for the logits matmul when has_untied_output is
# true. Starts as TinyNNMetal.tnn_null_ptr.
def t_output
  @t_output
end

#t_rope_freq_factorsObject

Returns the value of attribute t_rope_freq_factors.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Passed as the last argument of every tnn_rope_ext call. Starts as
# TinyNNMetal.tnn_null_ptr. NOTE(review): presumably optional per-frequency
# RoPE factors, with the null pointer meaning "none" — confirm.
def t_rope_freq_factors
  @t_rope_freq_factors
end

#t_token_embedObject

Returns the value of attribute t_token_embed.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Token-embedding tensor: the source for tnn_get_rows in build_decode_step, and
# also the logits matrix when the output is tied. Starts as
# TinyNNMetal.tnn_null_ptr.
def t_token_embed
  @t_token_embed
end

#use_flash_attnObject

Returns the value of attribute use_flash_attn.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# When true, build_attention_qhead_step takes the fused tnn_flash_attn_ext path.
# Off by default; turned on by enable_flash_attn! and by enable_kv_q8!.
def use_flash_attn
  @use_flash_attn
end

#vocab_sizeObject

Returns the value of attribute vocab_size.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# Starts at 0. NOTE(review): not read by the builder methods visible here;
# presumably the length of the logits vector — confirm.
def vocab_size
  @vocab_size
end

#weight_typeObject

Returns the value of attribute weight_type.



116
117
118
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 116

# ggml type id for 2D linear weights. 0 (F32, the default) makes alloc_2d_w take
# the F32 allocator; any other value is handed to the typed allocator.
def weight_type
  @weight_type
end

Instance Method Details

#alloc_2d_w(rows, cols) ⇒ Object

Allocate one persistent 2D linear weight tensor at the configured type. Used by realize_for; keeps the Q8/F32 branch in one place. Non-2D-linear tensors (norms, biases, K/V cache, t_output) stay F32 even in Q8 mode — quantizing them costs accuracy with no compute saving.



322
323
324
325
326
327
328
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 322

# Allocates one persistent 2D weight tensor in the session, at the storage
# type configured by @weight_type.
#
# @param rows [Integer] first dimension forwarded to the allocator
# @param cols [Integer] second dimension forwarded to the allocator
# @return [Object] whatever the chosen TinyNNMetal allocator returns
def alloc_2d_w(rows, cols)
  # Type id 0 is GGML_TYPE_F32 and has its own dedicated allocator.
  f32_weights = (@weight_type == 0)
  return TinyNNMetal.tnn_input_2d_f32_persistent(@sess, rows, cols) if f32_weights

  # Any other type id goes through the typed allocator.
  TinyNNMetal.tnn_input_2d_persistent_typed(@sess, rows, cols, @weight_type)
end

#build_attention_qhead_step(t_h, blk, hq, t_pos, pos, scale, bytes_d_head, bytes_d_head_k, bytes_d_head_v, bytes_max_T, tag, tap_this_head, layer_idx) ⇒ Object

One query head. Uses the (already-written) K and V of the corresponding KV head — index = hq / group_size. `tag` is the "L<i>." layer prefix; `tap_this_head` is true only for head 0, so trace mode does not multiply the taps by n_heads.



1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 1298

# Builds the attention sub-graph for one query head at decode position `pos`.
#
# Steps: Q projection (+ optional LoRA delta, + optional bias) → optional
# QK-norm → RoPE → attention over the cached K/V rows of KV head
# hq / @group_size, limited to the sliding window when one applies to this
# layer. With @use_flash_attn the fused tnn_flash_attn_ext op is used and the
# method returns early; otherwise the explicit scale → (soft-cap) → softmax →
# V matmul chain is built.
#
# @param t_h            hidden state fed to the Q projection (the caller passes
#                       the post-rn1 tensor)
# @param blk            per-layer block object; t_w_q, t_K, t_V and the optional
#                       bias / norm / LoRA tensors are read from it
# @param hq             query head index
# @param t_pos          RoPE position input tensor
# @param pos            integer decode position; the history view ends at row pos
# @param scale          attention scale applied to the Q·K scores
# @param bytes_d_head   not used by this method
# @param bytes_d_head_k byte stride of one K cache row
# @param bytes_d_head_v byte stride of one V cache row
# @param bytes_max_T    not used by this method
# @param tag            layer prefix prepended to trace-tap names
# @param tap_this_head  when true, intermediate tensors go through trace_tap
# @param layer_idx      layer number; only its parity is used (alternating SWA)
# @return the head output tensor (reshaped to [d_head, 1] on the flash path)
def build_attention_qhead_step(t_h, blk, hq, t_pos, pos, scale,
                                bytes_d_head, bytes_d_head_k, bytes_d_head_v,
                                bytes_max_T, tag, tap_this_head,
                                layer_idx)
  # KV head shared by this query head (integer division).
  hkv = hq / @group_size

  # I-Gemma (#113): per-layer SWA toggle. Gemma 2 alternates layers
  # between full attention and sliding-window. When @swa_alternates
  # is true, only EVEN layers see the SWA window; odd layers get
  # effectively full attention (window = 0 ⇒ hist_count = pos+1).
  # Non-Gemma archs: @swa_alternates is false; all layers apply
  # @swa_window uniformly (or 0 for no-SWA models).
  swa_for_this_layer = @swa_window
  if @swa_alternates && layer_idx.odd?
    swa_for_this_layer = 0
  end

  t_q_raw = TinyNNMetal.tnn_matmul(@sess, blk.t_w_q[hq], t_h)   # ne=[d_head, 1]
  # F1.2: optional LoRA on Q. Standard placement is BEFORE the bias
  # add (HF LoRA practice — the bias stays a property of the base
  # projection, LoRA only adjusts the linear part). Math:
  #   q_lora = w_lora_b[hq] @ (w_lora_a[hq] @ t_h)
  #   q_raw  := q_raw + q_lora
  # With B init to zero, q_lora == 0 and q_raw is unchanged.
  if @lora_q_enabled
    t_lora_a_h    = TinyNNMetal.tnn_matmul(@sess, blk.t_w_lora_a_q[hq], t_h)      # ne=[r, 1]
    t_lora_b_a_h  = TinyNNMetal.tnn_matmul(@sess, blk.t_w_lora_b_q[hq], t_lora_a_h)# ne=[d_head, 1]
    t_q_raw       = TinyNNMetal.tnn_add(@sess, t_q_raw, t_lora_b_a_h)
  end
  if @has_qkv_bias
    t_q_pre = TinyNNMetal.tnn_add(@sess, t_q_raw, blk.t_b_q[hq])
  else
    t_q_pre = t_q_raw
  end
  if tap_this_head
    t_q_pre = trace_tap(tag + "q_pre", t_q_pre)
  end
  if @has_qk_norm
    if @qk_norm_kind == 2
      # OLMoE / Granite per-head gamma slice (see build_block_step's
      # K-norm comment). The gamma tensor is [d_model]; head hq's
      # slice lives at byte offset hq*d_head*4.
      q_gamma_view = TinyNNMetal.tnn_view_1d(@sess, blk.t_q_norm_gamma,
                                          @d_head, hq * @d_head * 4)
      t_q_pre = TinyNNMetal.tnn_rms_norm(@sess, t_q_pre, q_gamma_view, @rms_eps)
    else
      # Any other kind: the gamma tensor is passed through unsliced.
      t_q_pre = TinyNNMetal.tnn_rms_norm(@sess, t_q_pre, blk.t_q_norm_gamma, @rms_eps)
    end
  end
  t_q     = TinyNNMetal.tnn_rope_ext(@sess, t_q_pre, t_pos, @d_head,
                                @rope_base, @rope_scaling.freq_scale,
                                @rope_scaling.ext_factor,
                                @rope_scaling.attn_factor,
                                @rope_scaling.beta_fast,
                                @rope_scaling.beta_slow,
                                @t_rope_freq_factors)
  if tap_this_head
    t_q = trace_tap(tag + "q_rot", t_q)
  end

  # M3 + I-Gemma: sliding-window attention. When swa_for_this_layer
  # > 0, restrict the K/V view to the last `min(pos+1, swa_window)`
  # positions. swa_for_this_layer differs from @swa_window only
  # when @swa_alternates is set (Gemma 2's even/odd layer pattern).
  if swa_for_this_layer > 0 && (pos + 1) > swa_for_this_layer
    hist_start = pos + 1 - swa_for_this_layer
    hist_count = swa_for_this_layer
  else
    hist_start = 0
    hist_count = pos + 1
  end
  # P5.1+P5.2: K and V views share the same byte-stride math.
  # ggml_mul_mat dequantizes Q8 source on the fly when reads happen.
  t_K_hist = TinyNNMetal.tnn_view_2d(@sess, blk.t_K[hkv],
                                  @d_head, hist_count, bytes_d_head_k,
                                  hist_start * bytes_d_head_k)
  # P5.2: V is now ne=[d_head, max_T] (positions on ne1, mirror of K).
  # The history view at [d_head, hist_count] is what flash_attn_ext
  # expects natively — no transpose-cont in the flash path now.
  t_V_hist = TinyNNMetal.tnn_view_2d(@sess, blk.t_V[hkv],
                                  @d_head, hist_count, bytes_d_head_v,
                                  hist_start * bytes_d_head_v)

  if @use_flash_attn
    # P4.1+P5.2: fused softmax(Q·Kᵀ·scale + mask)·V via
    # ggml_flash_attn_ext. Reshape Q/K/V to the 3D shapes
    # flash_attn_ext expects (ne[3] defaults to 1 so we don't need
    # a fourth dim). V's layout is already correct post-P5.2 — no
    # transpose needed.
    t_q_3d   = TinyNNMetal.tnn_reshape_3d(@sess, t_q,      @d_head, 1, 1)
    t_K_3d   = TinyNNMetal.tnn_reshape_3d(@sess, t_K_hist, @d_head, hist_count, 1)
    t_V_3d   = TinyNNMetal.tnn_reshape_3d(@sess, t_V_hist, @d_head, hist_count, 1)
    # I-Gemma (#113): pass logit soft-cap to flash_attn_ext. The
    # kernel applies tanh(x/softcap)*softcap to attention logits
    # internally. 0.0 disables (every non-Gemma model).
    # The mask argument is nil and the argument after `scale` is 0.0.
    t_out_4d = TinyNNMetal.tnn_flash_attn_ext(@sess, t_q_3d, t_K_3d, t_V_3d, nil,
                                          scale, 0.0, @attn_softcap)
    # Output ne=[d_head, n_head=1, T_q=1, batch=1]; collapse to 2D.
    t_head = TinyNNMetal.tnn_reshape_2d(@sess, t_out_4d, @d_head, 1)
    if tap_this_head
      t_head = trace_tap(tag + "head0_flash", t_head)
    end
    # Early exit: the explicit softmax path below is skipped entirely.
    return t_head
  end

  t_scores = TinyNNMetal.tnn_matmul(@sess, t_K_hist, t_q)
  if tap_this_head
    t_scores = trace_tap(tag + "scores", t_scores)
  end
  t_scaled = TinyNNMetal.tnn_scale(@sess, t_scores, scale)
  # I-Gemma (#113): logit soft-cap in the non-flash path.
  #   y = softcap * tanh(x / softcap)
  # Composed via two scales + tanh. Skipped when @attn_softcap is not > 0.
  if @attn_softcap > 0.0
    t_scaled = TinyNNMetal.tnn_scale(@sess, t_scaled, 1.0 / @attn_softcap)
    t_scaled = TinyNNMetal.tnn_tanh(@sess, t_scaled)
    t_scaled = TinyNNMetal.tnn_scale(@sess, t_scaled, @attn_softcap)
  end
  t_attn   = TinyNNMetal.tnn_softmax(@sess, t_scaled)
  if tap_this_head
    t_attn = trace_tap(tag + "softmax", t_attn)
  end
  # P5.2: V_hist is [d_head, hist_count]; ggml_mul_mat needs the
  # matching k axis (hist_count) on both inputs, so V_hist is
  # transposed first. Per the original note, tnn_transpose
  # materializes its result via ggml_cont — one copy of
  # d_head × hist_count × 4 bytes per Q-head per layer. Cheap at
  # decode (typical hist_count ~ a few hundred), and both the flash
  # and non-flash paths start from the same V layout.
  t_V_T  = TinyNNMetal.tnn_transpose(@sess, t_V_hist)
  t_head = TinyNNMetal.tnn_matmul(@sess, t_V_T, t_attn)
  if tap_this_head
    t_head = trace_tap(tag + "head0", t_head)
  end
  t_head
end

#build_block_step(t_x, blk, t_pos, pos, scale, eps, bytes_d_head, bytes_d_head_k, bytes_d_head_v, bytes_max_T, layer_idx) ⇒ Object



1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 1105

# Builds the graph for one transformer block at decode position `pos`.
#
# Steps: pre-attention RMSNorm → for each KV head, project K and V, apply
# optional bias / QK-norm / RoPE (K only), and copy the new row into the K/V
# caches at row `pos` → one attention head per query head
# (build_attention_qhead_step) → concat → output projection → optional
# post-attention norm → residual add → pre-FFN RMSNorm → dense SwiGLU or MoE
# FFN → optional post-FFN norm → residual add.
#
# @param t_x            block input tensor (the residual stream)
# @param blk            per-layer block object holding weights and K/V caches
# @param t_pos          RoPE position input tensor
# @param pos            integer decode position; selects the cache row written
# @param scale          attention scale, forwarded to the head builder
# @param eps            RMSNorm epsilon used for the block-level norms
# @param bytes_d_head   forwarded to the head builder; not used directly here
# @param bytes_d_head_k byte stride of one K cache row
# @param bytes_d_head_v byte stride of one V cache row
# @param bytes_max_T    forwarded to the head builder; not used directly here
# @param layer_idx      layer number; used for the tap prefix and forwarded on
# @return the value of the final trace_tap call on the block output
def build_block_step(t_x, blk, t_pos, pos, scale, eps,
                      bytes_d_head, bytes_d_head_k, bytes_d_head_v,
                      bytes_max_T, layer_idx)
  # Layer-tag prefix for tap names. String concat of an int needs an
  # explicit .to_s. No padding is applied, so layer 0 yields "L0.".
  tag = "L" + layer_idx.to_s + "."

  t_h = TinyNNMetal.tnn_rms_norm(@sess, t_x, blk.t_rn1_gamma, eps)
  t_h = trace_tap(tag + "rn1", t_h)

  # --- compute K, V for each KV head (n_kv times), rope K, cpy into buffers ---
  hkv = 0
  while hkv < @n_kv
    t_k_raw = TinyNNMetal.tnn_matmul(@sess, blk.t_w_k[hkv], t_h)         # ne=[d_head, 1]
    if @has_qkv_bias
      t_k_pre = TinyNNMetal.tnn_add(@sess, t_k_raw, blk.t_b_k[hkv])
    else
      t_k_pre = t_k_raw
    end
    # Tap K (head 0 only) post-bias, pre-RoPE.
    if hkv == 0
      t_k_pre = trace_tap(tag + "k_pre", t_k_pre)
    end
    # M1 + #110: QK-norm. Two flavors:
    #   kind=1 (Qwen3): blk.t_k_norm_gamma is [d_head], shared
    #     across all KV heads; pass directly.
    #   kind=2 (OLMoE / Granite, per-head approximation):
    #     blk.t_k_norm_gamma is [n_kv * d_head] = [d_model_kv];
    #     view the per-head [d_head] slice at byte offset
    #     hkv*d_head*4. This computes per-head variance (not the
    #     true full-vector variance) but applies the correct
    #     per-element gamma scaling. Cheap and close-enough for
    #     models where per-head magnitudes are similar (which they
    #     typically are for projections of a single input).
    # Note: these norms read @rms_eps rather than the `eps` argument;
    # build_decode_step passes eps = @rms_eps, so the two coincide there.
    if @has_qk_norm
      if @qk_norm_kind == 2
        k_gamma_view = TinyNNMetal.tnn_view_1d(@sess, blk.t_k_norm_gamma,
                                            @d_head, hkv * @d_head * 4)
        t_k_pre = TinyNNMetal.tnn_rms_norm(@sess, t_k_pre, k_gamma_view, @rms_eps)
      else
        t_k_pre = TinyNNMetal.tnn_rms_norm(@sess, t_k_pre, blk.t_k_norm_gamma, @rms_eps)
      end
    end
    t_k_rot = TinyNNMetal.tnn_rope_ext(@sess, t_k_pre, t_pos, @d_head,
                                  @rope_base, @rope_scaling.freq_scale,
                                  @rope_scaling.ext_factor,
                                  @rope_scaling.attn_factor,
                                  @rope_scaling.beta_fast,
                                  @rope_scaling.beta_slow,
                                  @t_rope_freq_factors)
    if hkv == 0
      t_k_rot = trace_tap(tag + "k_rot", t_k_rot)
    end
    # V matmul: weight in A position so ggml's matmul kernel can
    # dispatch to Q8 (and other quantized) kernels. The result is
    # [d_head, 1] and is copied as-is into the [d_head, 1] cache slot
    # below; V gets no RoPE and no QK-norm.
    t_v_raw = TinyNNMetal.tnn_matmul(@sess, blk.t_w_v[hkv], t_h)         # ne=[d_head, 1]
    if @has_qkv_bias
      t_v_new = TinyNNMetal.tnn_add(@sess, t_v_raw, blk.t_b_v[hkv])      # bias is 1-D [d_head]
    else
      t_v_new = t_v_raw
    end
    if hkv == 0
      t_v_new = trace_tap(tag + "v_new", t_v_new)
    end

    # P5.1+P5.2: K and V both use the same per-position write pattern.
    # bytes_d_head_{k,v} reflect each cache's dtype (F32 → d_head*4,
    # Q8_0 → type-aware row size from tnn_row_size). cpy quantizes
    # f32 source → Q8 destination automatically when types differ.
    t_K_slot = TinyNNMetal.tnn_view_2d(@sess, blk.t_K[hkv],
                                    @d_head, 1, bytes_d_head_k, pos * bytes_d_head_k)
    t_cpy_k = TinyNNMetal.tnn_cpy(@sess, t_k_rot, t_K_slot)
    t_V_slot = TinyNNMetal.tnn_view_2d(@sess, blk.t_V[hkv],
                                    @d_head, 1, bytes_d_head_v, pos * bytes_d_head_v)
    t_cpy_v = TinyNNMetal.tnn_cpy(@sess, t_v_new, t_V_slot)
    # The copies are registered with the graph explicitly, before the
    # attention heads that view these cache rows are built.
    TinyNNMetal.tnn_add_to_graph(@sess, t_cpy_k)
    TinyNNMetal.tnn_add_to_graph(@sess, t_cpy_v)
    hkv = hkv + 1
  end

  # --- per-Q-head attention ---
  # Head 0 is built separately with tap_this_head = true; all later heads
  # pass false so trace taps are emitted once per layer, not once per head.
  t_head_out0 = build_attention_qhead_step(t_h, blk, 0, t_pos, pos,
                                            scale, bytes_d_head, bytes_d_head_k,
                                            bytes_d_head_v, bytes_max_T, tag, true,
                                            layer_idx)
  t_head_outs = [t_head_out0]
  hq = 1
  while hq < @n_heads
    t_head_outs.push(build_attention_qhead_step(t_h, blk, hq, t_pos, pos,
                                                  scale, bytes_d_head, bytes_d_head_k,
                                                  bytes_d_head_v, bytes_max_T, tag, false,
                                                  layer_idx))
    hq = hq + 1
  end

  # Concatenate the head outputs one at a time along dim 0.
  t_concat = t_head_outs[0]
  hq = 1
  while hq < @n_heads
    t_concat = TinyNNMetal.tnn_concat(@sess, t_concat, t_head_outs[hq], 0)
    hq = hq + 1
  end
  t_concat = trace_tap(tag + "concat", t_concat)

  t_out_proj = TinyNNMetal.tnn_matmul(@sess, blk.t_w_o, t_concat)
  t_out_proj = trace_tap(tag + "attn_out", t_out_proj)
  # I-Gemma (#113): post-attention RMSNorm applied to the attention
  # output BEFORE the residual add. Gemma 2's sandwich structure:
  #   pre_norm(x) → attention → post_norm → residual + …
  # No-op when has_post_norms is false (every non-Gemma arch).
  if @has_post_norms
    t_out_proj = TinyNNMetal.tnn_rms_norm(@sess, t_out_proj, blk.t_post_attn_norm_gamma, eps)
    t_out_proj = trace_tap(tag + "post_attn_norm", t_out_proj)
  end
  t_x_attn   = TinyNNMetal.tnn_add(@sess, t_x, t_out_proj)
  t_x_attn   = trace_tap(tag + "post_attn", t_x_attn)

  # --- FFN ---
  t_h2     = TinyNNMetal.tnn_rms_norm(@sess, t_x_attn, blk.t_rn2_gamma, eps)
  t_h2     = trace_tap(tag + "rn2", t_h2)

  if @is_moe
    t_dn = build_moe_ffn(blk, t_h2, tag)
  else
    # --- SwiGLU FFN (dense): down(silu(gate(h)) * up(h)) ---
    t_gate   = TinyNNMetal.tnn_matmul(@sess, blk.t_w_gate, t_h2)        # ne=[d_ff, 1]
    t_gate   = trace_tap(tag + "gate", t_gate)
    t_up     = TinyNNMetal.tnn_matmul(@sess, blk.t_w_up,   t_h2)        # ne=[d_ff, 1]
    t_up     = trace_tap(tag + "up", t_up)
    t_silug  = TinyNNMetal.tnn_silu(@sess, t_gate)
    t_silug  = trace_tap(tag + "silu_gate", t_silug)
    t_gated  = TinyNNMetal.tnn_mul(@sess, t_silug, t_up)
    t_gated  = trace_tap(tag + "gated", t_gated)
    t_dn     = TinyNNMetal.tnn_matmul(@sess, blk.t_w_down, t_gated)     # ne=[d_model, 1]
    t_dn     = trace_tap(tag + "dn", t_dn)
  end

  # I-Gemma (#113): post-FFN RMSNorm on the FFN output before the
  # residual add. Same pattern as the post-attn norm above.
  if @has_post_norms
    t_dn = TinyNNMetal.tnn_rms_norm(@sess, t_dn, blk.t_post_ffn_norm_gamma, eps)
    t_dn = trace_tap(tag + "post_ffn_norm", t_dn)
  end
  t_post_ffn = TinyNNMetal.tnn_add(@sess, t_x_attn, t_dn)
  # The tap's return value is the block's return value.
  trace_tap(tag + "post_ffn", t_post_ffn)
end

#build_decode_step(pos) ⇒ Object

Build the compute graph for one decode position.



1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 1047

# Builds the compute graph for decoding a single token at cache position `pos`.
#
# Graph layout: token-embedding lookup (optionally scaled) → @n_layers
# transformer blocks via build_block_step → final RMSNorm → logits matmul
# (against @t_output when untied, else @t_token_embed) → optional final
# soft-cap. The logits tensor is registered as the session output.
#
# @param pos [Integer] zero-based position being decoded; must satisfy
#   0 <= pos < @max_T because it is used as a row index into the K/V caches.
# @return [SmolLM2KVStepResultMetal] built from the token-id input tensor, the
#   RoPE-position input tensor and the logits tensor, in that order.
# @raise [ArgumentError] if pos lies outside the K/V cache.
def build_decode_step(pos)
  # Every block writes K/V at row `pos` and views rows up to `pos` of caches
  # that hold max_T rows. Reject out-of-range positions before any graph node
  # is created, instead of handing out-of-bounds view offsets to the backend.
  if pos < 0 || pos >= @max_T
    raise ArgumentError,
          "pos " + pos.to_s + " out of range for KV cache (max_T=" + @max_T.to_s + ")"
  end

  eps     = @rms_eps
  scale   = 1.0 / Math.sqrt(@d_head.to_f)
  d_head  = @d_head
  max_T   = @max_T
  bytes_d_head = d_head * 4
  bytes_max_T  = max_T * 4
  # P5.1+P5.2: row size for K and V. F32 → d_head*4; Q8_0 →
  # ggml_row_size(Q8_0, d_head) (block 32 × 34 bytes; 68 at d_head=64).
  # V is in the same layout as K post-P5.2 so the math is symmetric.
  # Type id 8 is GGML_TYPE_Q8_0 (see enable_kv_q8!).
  bytes_d_head_k = @kv_type_k == 8 ? TinyNNMetal.tnn_row_size(8, d_head) : bytes_d_head
  bytes_d_head_v = @kv_type_v == 8 ? TinyNNMetal.tnn_row_size(8, d_head) : bytes_d_head

  # Inputs: token id + RoPE position. Both length 1.
  t_token_id  = TinyNNMetal.tnn_input_1d_i32(@sess, 1)
  t_pos       = TinyNNMetal.tnn_input_1d_i32_ctx(@sess, 1)

  t_x = TinyNNMetal.tnn_get_rows(@sess, @t_token_embed, t_token_id)   # ne=[d_model, 1]
  # I-Gemma (#113): Gemma 2 scales token embeddings by sqrt(d_model)
  # post-lookup. Non-Gemma archs use @embed_scale = 1.0, in which case
  # the scale op is skipped entirely. The scalar is precomputed at
  # flag-detection time rather than derived here.
  if @embed_scale != 1.0
    t_x = TinyNNMetal.tnn_scale(@sess, t_x, @embed_scale)
  end
  t_x = trace_tap("embed", t_x)

  # Chain the transformer blocks; each one consumes the previous output.
  li = 0
  while li < @n_layers
    t_x = build_block_step(t_x, @kv_blocks_ffi[li], t_pos, pos,
                            scale, eps, bytes_d_head, bytes_d_head_k,
                            bytes_d_head_v, bytes_max_T, li)
    li = li + 1
  end

  t_x_final = TinyNNMetal.tnn_rms_norm(@sess, t_x, @t_final_norm_gamma, eps)
  t_x_final = trace_tap("final_norm", t_x_final)
  # Logits: untied path matmuls against t_output (lm_head); tied
  # path against t_token_embed. Both tensors are [vocab, d_model],
  # so the matmul shape is identical either way.
  if @has_untied_output
    t_kv_logits = TinyNNMetal.tnn_matmul(@sess, @t_output, t_x_final)
  else
    t_kv_logits = TinyNNMetal.tnn_matmul(@sess, @t_token_embed, t_x_final)
  end
  # I-Gemma (#113): final logit soft-cap. Gemma 2 applies
  # tanh(logits / final_softcap) * final_softcap to the output
  # logits before argmax / sampling. Skipped when final_softcap is not > 0.
  if @final_softcap > 0.0
    t_kv_logits = TinyNNMetal.tnn_scale(@sess, t_kv_logits, 1.0 / @final_softcap)
    t_kv_logits = TinyNNMetal.tnn_tanh(@sess, t_kv_logits)
    t_kv_logits = TinyNNMetal.tnn_scale(@sess, t_kv_logits, @final_softcap)
  end
  TinyNNMetal.tnn_set_output(t_kv_logits)
  SmolLM2KVStepResultMetal.new(t_token_id, t_pos, t_kv_logits)
end

#build_moe_ffn(blk, t_h2, tag) ⇒ Object

M2.3: Mixtral / Qwen-MoE routed FFN. Ports the validated graph from tinynn/ab_smoke_moe_ffn into the production decode path. Shapes:

t_h2          [d_model, 1]                    input (post-norm)
router_logits [n_experts, 1]                  matmul(w_router, h2)
probs         [n_experts, 1]                  softmax(logits)
top_idx       [n_experts_used, 1]             top_k(probs)
weights       [1, n_experts_used, 1]          get_rows(reshape_3d(probs,1,n_exp,1), top_idx)
e_gate / e_up [d_ff,    n_experts_used, 1]    mul_mat_id(...exps, h2, top_idx)
e_down        [d_model, n_experts_used, 1]    after weight × sum

The (mul/transpose/sum_rows/reshape) sum-across-K is the same trick the smoke uses; ggml has no axis-1 reduce primitive.



1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 1266

# Builds the routed mixture-of-experts FFN sub-graph for one token.
#
# Steps: router logits → softmax over all experts → top-k expert indices →
# gather the selected probabilities as routing weights → per-expert SwiGLU via
# tnn_mul_mat_id (gate, up, down) → weight each expert output → sum over the
# selected experts.
#
# NOTE(review): the routing weights are the full-softmax probabilities of the
# selected experts and are not renormalized over the top-k here — confirm this
# matches every architecture routed through this path.
#
# @param blk  per-layer block object; t_w_router and the *_exps expert weight
#             tensors are read from it
# @param t_h2 post-norm hidden state, annotated [d_model, 1] in the method docs
# @param tag  layer prefix prepended to trace-tap names
# @return the value of the final trace_tap call on the combined output
def build_moe_ffn(blk, t_h2, tag)
  t_logits     = TinyNNMetal.tnn_matmul(@sess, blk.t_w_router, t_h2)        # ne=[n_exp, 1]
  t_logits     = trace_tap(tag + "moe_logits", t_logits)
  t_probs      = TinyNNMetal.tnn_softmax(@sess, t_logits)                   # ne=[n_exp, 1]
  t_top_idx    = TinyNNMetal.tnn_top_k(@sess, t_probs, @n_experts_used)     # ne=[K, 1]
  # Reshape so get_rows picks one scalar probability per selected expert.
  t_probs_3d   = TinyNNMetal.tnn_reshape_3d(@sess, t_probs, 1, @n_experts, 1)
  t_w_route    = TinyNNMetal.tnn_get_rows(@sess, t_probs_3d, t_top_idx)     # ne=[1, K, 1]

  # Per-expert SwiGLU: down_e(silu(gate_e(h)) * up_e(h)) for the selected experts.
  t_e_gate     = TinyNNMetal.tnn_mul_mat_id(@sess, blk.t_w_gate_exps, t_h2, t_top_idx)
  t_e_up       = TinyNNMetal.tnn_mul_mat_id(@sess, blk.t_w_up_exps,   t_h2, t_top_idx)
  t_e_silu     = TinyNNMetal.tnn_silu(@sess, t_e_gate)
  t_e_gated    = TinyNNMetal.tnn_mul(@sess, t_e_silu, t_e_up)               # ne=[d_ff, K, 1]
  t_e_down     = TinyNNMetal.tnn_mul_mat_id(@sess, blk.t_w_down_exps, t_e_gated, t_top_idx)
  t_e_down     = trace_tap(tag + "moe_e_down", t_e_down)               # ne=[d_model, K, 1]

  # Broadcast weights over d_model: [d_model, K, 1] × [1, K, 1] → [d_model, K, 1].
  t_weighted   = TinyNNMetal.tnn_mul(@sess, t_e_down, t_w_route)

  # Sum across K (axis 1). Reshape to 2D (T=1 collapses), transpose
  # [d_model, K] → [K, d_model], sum_rows along ne0=K → [1, d_model],
  # reshape back to [d_model, 1].
  t_weighted_2d = TinyNNMetal.tnn_reshape_2d(@sess, t_weighted, @d_model, @n_experts_used)
  t_weighted_T  = TinyNNMetal.tnn_transpose(@sess, t_weighted_2d)
  t_summed_T    = TinyNNMetal.tnn_sum_rows(@sess, t_weighted_T)             # ne=[1, d_model]
  t_dn          = TinyNNMetal.tnn_reshape_2d(@sess, t_summed_T, @d_model, 1)
  trace_tap(tag + "moe_out", t_dn)
end

#dump_traceObject



911
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 911

# No-op in this class: the body is empty, so the call returns nil.
# NOTE(review): presumably a hook that a tracing variant fills in — confirm.
def dump_trace; end

#enable_flash_attn!Object

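P4.1: opt into ggml_flash_attn_ext for inference. Per Q-head it replaces the (scale → softmax → matmul) triplet with one fused call. Since P5.2 the V cache is stored as [d_head, max_T] (positions on ne1, mirroring K), which is the orientation flash_attn_ext consumes natively, so the flash path needs no per-step transpose of V; the same layout change is what makes a Q8 V cache possible. (Before P5.2, V was kept as [max_T, d_head] and had to be transpose-materialized with one ggml_cont per step.)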

Backward is unsupported in vendored ggml (flash_attn_back aborts), so this path is INFERENCE only. Call BEFORE realize_for_mmap.



275
276
277
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 275

# Opts this cache into the fused flash-attention path by setting
# @use_flash_attn, which build_attention_qhead_step checks to choose
# tnn_flash_attn_ext over the explicit scale/softmax/matmul chain.
# Returns true (the assigned value).
def enable_flash_attn!
  @use_flash_attn = true
end

#enable_kv_q8!Object

P5.1 (historical): opted into Q8_0 storage for the K cache only. Must be called BEFORE realize_for_mmap. In that phase V stayed F32 — its layout (positions along ne0) made per-position Q8 writes non-block-aligned. K’s layout (positions along ne1, d_head along ne0) writes whole d_head-vectors at a time, which for d_head=64 spans exactly 2 Q8_0 blocks of 32 elements each → aligned. The write path uses ggml_cpy which quantizes on f32→Q8 destination; the read path (attention matmul) dequantizes block-by-block inside ggml’s kernel. Cut K-cache memory & bandwidth ~4×. P5.1+P5.2 (current): opt into Q8_0 for both the K and V caches. Shrinks K and V memory + bandwidth (3.75× smaller at d_head=64).

Auto-enables flash attention. Reason: the non-flash V matmul requires a transpose-cont of V_hist, which is structurally impossible for Q8_0 (transposing flips the d_head and hist_count axes; hist_count generally isn’t a multiple of 32, so the contiguous Q8 destination can’t be allocated). flash_attn_ext consumes V in its natural [d_head, hist_count] orientation, which dodges the transpose entirely — so Q8 V works there.

Inference-only. flash_attn’s backward aborts in vendored ggml.



260
261
262
263
264
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 260

# Switches both KV-cache element types to Q8_0 and turns on the
# flash-attention flag in the same call.
#
# @return [true] the value stored in @use_flash_attn
def enable_kv_q8!
  q8_0 = 8   # GGML_TYPE_Q8_0
  @kv_type_v = q8_0
  @kv_type_k = q8_0
  @use_flash_attn = true
end

#enable_lora_q!(r) ⇒ Object

F1.2: enable per-Q-head LoRA on this session’s forward graph. Call BEFORE realize_for_mmap. Adapter A is (r, d_model), adapter B is (d_head, r); both trainable F32 tensors in ctx_w (not mmap’d, so writes survive). Standard LoRA init: A = small Gaussian, B = 0, which makes the adapter a no-op at step 0 (forward output == baseline). Use upload_lora_zero!(seed) to set up that init.



296
297
298
299
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 296

# Records the LoRA rank and marks per-Q-head LoRA as enabled.
#
# @param r the adapter rank, stored as-is in @lora_q_rank (not validated)
# @return the rank that was passed in
def enable_lora_q!(r)
  @lora_q_rank    = r
  @lora_q_enabled = true
  r
end

#enable_lora_q_adamw!Object

F1.2 step 6b: allocate persistent AdamW moments (m, v) alongside each LoRA-A/B pair, in ctx_w. Requires enable_lora_q!(…) to have been called first (so the rank is known). Call BEFORE realize_for_mmap. Without this, multi-position SFT loses Adam state at every graph rebuild and diverges to NaN.



306
307
308
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 306

# Raises the @lora_q_adamw_enabled flag. This method does not itself
# check that enable_lora_q! was called beforehand.
#
# @return [true] the value stored in the flag
def enable_lora_q_adamw!
  moments_on = true
  @lora_q_adamw_enabled = moments_on
end

#enable_moe!(n_experts, n_experts_used) ⇒ Object

M2.3: opt into the MoE FFN graph. Must be called BEFORE realize_for_mmap. n_experts is the total count in the GGUF; n_experts_used is the top-K routed per token. Mixtral-8x7B: enable_moe!(8, 2). Qwen3-30B-A3B: enable_moe!(128, 8) (with optional shared expert — not yet supported in this path).



284
285
286
287
288
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 284

# Stores the expert counts and flags this cache as mixture-of-experts.
#
# @param n_experts       total number of experts, stored in @n_experts
# @param n_experts_used  number of experts routed per token, stored in
#                        @n_experts_used
# @return the n_experts_used argument
def enable_moe!(n_experts, n_experts_used)
  @n_experts      = n_experts
  @n_experts_used = n_experts_used
  @is_moe         = true
  n_experts_used
end

#enable_trace!Object



909
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 909

# Empty hook: sets nothing and returns nil.
# NOTE(review): presumably a stub kept so callers can invoke
# enable_trace! unconditionally — confirm against the tracing-capable engine.
def enable_trace!
  nil
end

#head_nbytes(ggml_type, d_head, d_model) ⇒ Object

Per-head byte stride for slicing a full [n_heads*d_head, d_model] tensor into n_heads contiguous Dh×D blocks. A per-head slice is d_head rows of d_model elements, so the stride is d_head row-sizes.

tnn_row_size delegates to ggml_row_size, which is correct for EVERY type — F32, Q8_0, and the K-quants (Q4_K/Q5_K/Q6_K). The previous hand-coded F32/Q8_0-only branches returned 0 for any other type, which silently made the per-head offset `off_base + hq*0 == off_base` — i.e. every attention head read head 0’s weight slice. That collapsed multi-head attention on K-quant MoE models (forced down the realize_for_mmap path), compounding across layers into degenerate output. This was misdiagnosed as a ggml mul_mat_id K-quant bug (ggml#1506); it was ours. Block alignment holds because each row is a whole number of quant blocks (requires d_model % block == 0, which the per-head tnn_input_2d_persistent_mmap also enforces via ne0).



896
897
898
899
900
901
902
903
904
905
906
907
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 896

# Byte stride between consecutive per-head slices of a fused weight
# tensor: d_head rows, each TinyNNMetal.tnn_row_size(ggml_type, d_model)
# bytes wide.
#
# @param ggml_type  ggml type id forwarded to tnn_row_size
# @param d_head     number of rows in one head's slice
# @param d_model    number of elements per row
# @return d_head multiplied by the row size
#
# When the reported row size is not positive, prints a FATAL line to
# stdout and terminates the process with exit status 1 instead of
# returning a zero stride (which would alias every head onto head 0).
def head_nbytes(ggml_type, d_head, d_model)
  row_bytes = TinyNNMetal.tnn_row_size(ggml_type, d_model)
  if row_bytes > 0
    d_head * row_bytes
  else
    # Fail loud: a zero/negative stride must never reach the slicing code.
    msg = "FATAL: head_nbytes got row_size<=0 for ggml_type=" + ggml_type.to_s
    msg = msg + " d_model=" + d_model.to_s
    msg = msg + " — per-head attention stride would collapse. Aborting."
    puts msg
    exit 1
  end
end

#load_weights(path) ⇒ Object

Ruby-OO entry point for “load weights into this realized cache.” Auto-detects layout: GGUFs with the `toy.ggml_native` metadata key take the memcpy path (no transpose); legacy GGUFs take the transposing path. Callers stay layout-agnostic.



917
918
919
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 917

# Loads weights from the GGUF at +path+ into this cache by delegating to
# GGUFLoad.load_kv_cache_auto with this object as the target.
#
# @param path  location of the GGUF file, forwarded unchanged
# @return whatever GGUFLoad.load_kv_cache_auto returns
def load_weights(path)
  target_cache = self
  GGUFLoad.load_kv_cache_auto(target_cache, path)
end

#read_persistent_mat(t, rows, cols) ⇒ Object

Pull any persistent FFI tensor back to a Ruby Mat (chunked download, works for weight-sized tensors). Required by the design rule that the direct-loader path must keep Mat-roundtrip open — see docs/loader-api.md.

`t` is any tensor handle exposed on this cache or its blocks (e.g. `kv.t_token_embed`, `kv.kv_blocks_ffi[0].t_w_o`). `rows` and `cols` are the logical shape; we trust the caller.



929
930
931
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 929

# Downloads a persistent tensor from this cache's session via
# TinyNNMetal.download_to_mat.
#
# @param t     tensor handle to download
# @param rows  logical row count, forwarded unchecked
# @param cols  logical column count, forwarded unchecked
# @return whatever TinyNNMetal.download_to_mat returns
def read_persistent_mat(t, rows, cols)
  session = @sess
  TinyNNMetal.download_to_mat(session, t, rows, cols)
end

#realize_and_load_auto(gguf_path, max_T, cfg, flags) ⇒ Object



859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 859

# Realizes this cache for the GGUF at +gguf_path+ and gets its weights in
# place, choosing the strategy from the file's `toy.ggml_native` metadata:
#
# * native  → detect the weight type, record it via set_weight_type, and
#             realize through realize_for_mmap using the already-opened
#             GGUF handle.
# * legacy  → free the handle, realize backend-owned tensors through
#             realize_for, then copy weights in with load_weights.
#
# @param gguf_path  path of the GGUF file
# @param max_T      cache capacity forwarded to the realize_* call
# @param cfg        model config; the legacy path reads d_model, d_ff,
#                   n_heads, n_kv, n_layers, vocab, rope_base, rms_eps
# @param flags      object responding to untied and qkv_bias
# @return the GGUF handle on the native path (it is not freed here, so the
#         caller receives it), or TinyNNMetal.tnn_null_ptr on the legacy path
def realize_and_load_auto(gguf_path, max_T, cfg, flags)
  gguf = TinyNNMetal.tnn_gguf_load(gguf_path)
  is_native = TinyNNMetal.tnn_gguf_get_bool(gguf, "toy.ggml_native") == 1
  if is_native
    wtype = GGUFLoad.detect_weight_type(gguf_path)
    set_weight_type(wtype)
    # realize_for_mmap takes six required arguments; the trailing qk_norm
    # was previously omitted here, which raises ArgumentError. Forward the
    # cache's own @has_qk_norm so a value pre-set by the caller is honoured
    # and the constructor default (false) is used otherwise.
    realize_for_mmap(gguf, cfg, max_T, flags.untied, flags.qkv_bias, @has_qk_norm)
    puts "  BYO-pointer mmap (weight_type=" + wtype.to_s + ")"
    gguf
  else
    # The legacy loader reopens the file by path, so this handle is not
    # needed past the metadata probe.
    TinyNNMetal.tnn_gguf_free(gguf)
    realize_for(max_T, cfg.d_model, cfg.d_ff,
                cfg.n_heads, cfg.n_kv,
                cfg.n_layers, cfg.vocab,
                cfg.rope_base, cfg.rms_eps,
                flags.untied, flags.qkv_bias)
    load_weights(gguf_path)
    puts "  legacy copy load"
    TinyNNMetal.tnn_null_ptr
  end
end

#realize_for(max_T, d_model, d_ff, n_heads, n_kv, n_layers, vocab_size, rope_base, rms_eps, untied, qkv_bias) ⇒ Object

Declare every persistent tensor (weights + K/V buffers) and finalize. `untied` is true for TinyLlama-shape models that have a separate `output.weight` (lm_head); false for SmolLM2 / Qwen2.5 with tied embeddings. When false we skip the (vocab × d_model) t_output allocation entirely. `qkv_bias` is true for Qwen2.x; when false the b_q/b_k/b_v tensors aren’t allocated and Q/K/V matmuls land without an add.



940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 940

# Realizes this cache with backend-allocated tensors (no mmap): records the
# model dimensions, opens a new session, declares every persistent tensor
# (global weights, per-layer weights, per-kv-head K/V buffers, optional
# Q/K/V biases), finalizes the weight context and marks the cache realized.
#
# max_T      - first dimension argument of every K/V buffer allocation
# d_model    - model width; also the second dimension of most weights
# d_ff       - feed-forward width (gate/up/down projections)
# n_heads    - number of Q heads; @d_head is derived as d_model / n_heads
# n_kv       - number of K/V heads; @group_size is n_heads / n_kv
# n_layers   - number of transformer blocks to declare
# vocab_size - rows of the token embedding (and of t_output when untied)
# rope_base  - stored in @rope_base; not otherwise used in this method
# rms_eps    - stored in @rms_eps; not otherwise used in this method
# untied     - when truthy, a separate t_output tensor is allocated
# qkv_bias   - when truthy, per-head Q/K/V bias tensors are allocated
#
# Returns true (the value assigned to @realized).
#
# NOTE(review): the bare `d_head` below is not a local or a parameter, so
# it resolves to a method call — presumably the attribute reader for the
# @d_head assigned at the top; confirm.
def realize_for(max_T, d_model, d_ff, n_heads, n_kv, n_layers,
                vocab_size, rope_base, rms_eps, untied, qkv_bias)
  @max_T      = max_T
  @d_model    = d_model
  @d_ff       = d_ff
  @n_heads    = n_heads
  @n_kv       = n_kv
  @d_head     = d_model / n_heads
  @group_size = n_heads / n_kv
  @n_layers   = n_layers
  @vocab_size = vocab_size
  @rope_base  = rope_base
  @rms_eps    = rms_eps

  # NOTE(review): meaning of the literal 2 passed to tnn_session_new is not
  # visible here — confirm against the TinyNNMetal binding.
  @sess               = TinyNNMetal.tnn_session_new(2)
  @t_token_embed      = TinyNNMetal.tnn_input_2d_f32_persistent(@sess, vocab_size, d_model)
  @t_final_norm_gamma = TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_model)
  @has_untied_output  = untied
  @has_qkv_bias       = qkv_bias
  # @t_output is only (re)assigned for untied models; otherwise it keeps
  # whatever value it had before this call.
  if untied
    @t_output = alloc_2d_w(vocab_size, d_model)
  end

  # The block list is seeded with one element and then grown, so it always
  # holds at least one block even when n_layers < 1.
  @kv_blocks_ffi = [SmolLM2KVBlockFFIMetal.new]
  li = 1
  while li < n_layers
    @kv_blocks_ffi.push(SmolLM2KVBlockFFIMetal.new)
    li = li + 1
  end

  # Declare the tensors of each layer, in layer order.
  li = 0
  while li < n_layers
    blk = @kv_blocks_ffi[li]
    blk.t_rn1_gamma = TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_model)
    blk.t_rn2_gamma = TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_model)

    # Q: n_heads per-head matrices of (d_head, d_model). Quantizable.
    # Head 0 seeds each array; the loop below appends heads 1..n_heads-1.
    blk.t_w_q = [alloc_2d_w(d_head, d_model)]
    if qkv_bias
      blk.t_b_q = [TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head)]
    end
    hq = 1
    while hq < n_heads
      blk.t_w_q.push(alloc_2d_w(d_head, d_model))
      if qkv_bias
        blk.t_b_q.push(TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head))
      end
      hq = hq + 1
    end

    # K, V (and the persistent K/V buffers): n_kv per-head. Linear
    # weights quantizable; K/V cache buffers follow @kv_type_*
    # (P5.1 K, P5.2 V); biases stay F32.
    blk.t_w_k = [alloc_2d_w(d_head, d_model)]
    blk.t_w_v = [alloc_2d_w(d_head, d_model)]
    # P5.1: Q8 K alloc when enabled (see realize_for_mmap parallel path).
    # 8 is the type id that enable_kv_q8! stores (GGML_TYPE_Q8_0).
    if @kv_type_k == 8
      blk.t_K = [TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, d_head, 8)]
    else
      blk.t_K = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, d_head)]
    end
    # P5.2: V now mirrors K's layout (ne=[d_head, max_T]).
    if @kv_type_v == 8
      blk.t_V = [TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, d_head, 8)]
    else
      blk.t_V = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, d_head)]
    end
    if qkv_bias
      # K bias: 1-D (broadcasts over [d_head, 1] k matmul result).
      # V bias: 1-D too (the V matmul is now ordered weight-first, so
      # its result is [d_head, 1] like K — matches a 1-D bias).
      blk.t_b_k = [TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head)]
      blk.t_b_v = [TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head)]
    end
    # Remaining kv heads 1..n_kv-1: same allocations, in the same order as
    # for head 0 above (weights, K buffer, V buffer, then biases).
    hkv = 1
    while hkv < n_kv
      blk.t_w_k.push(alloc_2d_w(d_head, d_model))
      blk.t_w_v.push(alloc_2d_w(d_head, d_model))
      if @kv_type_k == 8
        blk.t_K.push(TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, d_head, 8))
      else
        blk.t_K.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, d_head))
      end
      if @kv_type_v == 8
        blk.t_V.push(TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, d_head, 8))
      else
        blk.t_V.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, d_head))
      end
      if qkv_bias
        blk.t_b_k.push(TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head))
        blk.t_b_v.push(TinyNNMetal.tnn_input_1d_f32_persistent(@sess, d_head))
      end
      hkv = hkv + 1
    end

    # Output projection and the three FFN projections: one full 2-D weight
    # each, not split per head.
    blk.t_w_o    = alloc_2d_w(d_model, @n_heads * @d_head)
    blk.t_w_gate = alloc_2d_w(d_ff,    d_model)
    blk.t_w_up   = alloc_2d_w(d_ff,    d_model)
    blk.t_w_down = alloc_2d_w(d_model, d_ff)
    li = li + 1
  end

  # All declarations are done; finalize, then flag the cache as realized.
  TinyNNMetal.tnn_finalize_weights(@sess)
  @realized = true
end

#realize_for_mmap(gguf_handle, cfg, max_T, untied, qkv_bias, qk_norm) ⇒ Object

Phase 2 BYO-pointer realization. Like realize_for but every GGUF-resident tensor (token_embed, norms, biases, all 2D linears, untied output) is allocated to POINT AT the file’s mmap’d pages rather than copied into a backend buffer. Only K/V cache and the compute scratch live in backend-allocated memory. The kv_cache holds the GGUF handle so the mmap stays alive for its lifetime.

Caller flow:

gguf  = TinyNNMetal.tnn_gguf_load(path)        # mmap'd, no_alloc
flags = GGUFLoad.detect_smollm2_flags(path)
wtype = GGUFLoad.detect_weight_type(path)
kv = SmolLM2KVFFICacheMetal.new
kv.realize_for_mmap(gguf, cfg, MAX_T, flags.untied, flags.qkv_bias, qk_norm)  # qk_norm: true when the model has attn_q_norm / attn_k_norm gammas
# weights are already in place; no load_weights call needed.


344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 344

def realize_for_mmap(gguf_handle, cfg, max_T, untied, qkv_bias, qk_norm)
  @max_T      = max_T
  @d_model    = cfg.d_model
  @d_ff       = cfg.d_ff
  @n_heads    = cfg.n_heads
  @n_kv       = cfg.n_kv
  @d_head     = cfg.head_dim
  @group_size = cfg.n_heads / cfg.n_kv
  @n_layers   = cfg.n_layers
  @vocab_size = cfg.vocab
  @rope_base    = cfg.rope_base
  @rope_scaling = cfg.rope_scaling
  @rms_eps      = cfg.rms_eps

  @gguf_handle_keepalive = gguf_handle   # prevent GC; mmap must outlive @sess
  @sess              = TinyNNMetal.tnn_session_new(2)
  @has_untied_output = untied
  @has_qkv_bias      = qkv_bias
  @has_qk_norm       = qk_norm
  # #110: if caller didn't pre-set qk_norm_kind via the
  # attr_accessor, default to 1 (per-head shared) for backward
  # compat with the Qwen3 detection that established the qk_norm
  # path. Models that want full-Q (OLMoE / Granite) must set
  # kv.qk_norm_kind = 2 BEFORE calling realize_for_mmap.
  if @has_qk_norm && @qk_norm_kind == 0
    @qk_norm_kind = 1
  end

  # llama3 / LongRoPE: allocate the (d_head/2)-elem freq_factors
  # tensor in ctx_w before finalize_weights. We compute and upload
  # the values after finalize (see below). For all other rope_scaling
  # kinds the FFI call still needs a pointer — pass tnn_null_ptr.
  if @rope_scaling.kind == :llama3
    @t_rope_freq_factors = TinyNNMetal.tnn_rope_freq_factors_alloc(@sess, @d_head)
  else
    @t_rope_freq_factors = TinyNNMetal.tnn_null_ptr
  end

  # Wire the GGUF's mmap region into the session as the source of
  # weight bytes. Subsequent tnn_input_*_persistent_mmap calls
  # allocate tensors with .data inside this region — no copy.
  map_base = TinyNNMetal.tnn_gguf_mmap_base(gguf_handle)
  map_size = TinyNNMetal.tnn_gguf_mmap_size(gguf_handle)
  TinyNNMetal.tnn_session_attach_weight_mmap(@sess, map_base, map_size)

  # toy#gguf-checkpoint-reload (#153) — from-scratch checkpoints
  # written by ToyGGUFWriter store one tensor per head
  # (blk.N.attn_q.head_H.weight) rather than the fused llama.cpp
  # shape. Detect via the head_0 sentinel; the per-Q-head/K/V
  # loaders below branch on it.
  @per_head_attn = TinyNNMetal.tnn_gguf_find_index(gguf_handle, "blk.0.attn_q.head_0.weight") >= 0
  if @per_head_attn
    puts "  per-head tensors detected (toy from-scratch checkpoint)"
  end

  # Globals — embeddings + final norm + optional untied output.
  eidx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, "token_embd.weight")
  eoff = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, eidx)
  etyp = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, eidx)
  @t_token_embed = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @vocab_size, @d_model, etyp, eoff)

  fnidx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, "output_norm.weight")
  fnoff = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, fnidx)
  @t_final_norm_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess,
                          @d_model, 0, fnoff)   # 0 = GGML_TYPE_F32

  if untied
    oidx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, "output.weight")
    ooff = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, oidx)
    otyp = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, oidx)
    @t_output = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                  @vocab_size, @d_model, otyp, ooff)
  end

  @kv_blocks_ffi = [SmolLM2KVBlockFFIMetal.new]
  li = 1
  while li < @n_layers
    @kv_blocks_ffi.push(SmolLM2KVBlockFFIMetal.new)
    li = li + 1
  end

  li = 0
  while li < @n_layers
    blk = @kv_blocks_ffi[li]
    prefix = "blk." + li.to_s

    # Norms — 1D F32 mmap'd directly.
    rn1_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_norm.weight")
    rn2_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_norm.weight")
    blk.t_rn1_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_model, 0,
                        TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, rn1_idx))
    blk.t_rn2_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_model, 0,
                        TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, rn2_idx))

    # I-Gemma (#113): post-attention and post-FFN RMSNorm gammas
    # (Gemma 2 sandwiches each sublayer between pre+post norms).
    # Tensor names: blk.X.post_attention_norm.weight, blk.X.post_ffw_norm.weight.
    if @has_post_norms
      pa_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".post_attention_norm.weight")
      pf_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".post_ffw_norm.weight")
      blk.t_post_attn_norm_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_model, 0,
                                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, pa_idx))
      blk.t_post_ffn_norm_gamma  = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_model, 0,
                                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, pf_idx))
    end

    # M1 + #110: QK-norm gammas. Two flavors detected via shape:
    #   kind=1: Qwen3 — gamma shape [d_head], shared across heads.
    #   kind=2: OLMoE / Granite — gamma shape [d_model], applied to
    #          the full Q before head split. Allocate the full
    #          [d_model] tensor; the graph builder either does a
    #          full-Q rms_norm OR views per-head d_head slices.
    gamma_nelems = (@qk_norm_kind == 2) ? @d_model : @d_head
    if @has_qk_norm
      qn_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_q_norm.weight")
      kn_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_k_norm.weight")
      blk.t_q_norm_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, gamma_nelems, 0,
                             TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, qn_idx))
      # K norm follows the same flavor as Q.
      k_gamma_nelems = (@qk_norm_kind == 2) ? (@n_kv * @d_head) : @d_head
      blk.t_k_norm_gamma = TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, k_gamma_nelems, 0,
                             TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, kn_idx))
    end

    # Q per-head — two layouts:
    # 1) Fused (llama.cpp): single attn_q.weight tensor; each head
    #    is a contiguous slice at offset q_base + h * head_nbytes.
    # 2) Per-head (toy from-scratch ckpt, #153): each head has its
    #    own attn_q.head_H.weight tensor with its own file offset.
    if @per_head_attn
      q0_idx  = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_q.head_0.weight")
      q0_type = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, q0_idx)
      blk.t_w_q = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, q0_type,
                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, q0_idx))]
      hq = 1
      while hq < @n_heads
        qh_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_q.head_" + hq.to_s + ".weight")
        blk.t_w_q.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, q0_type,
                         TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, qh_idx)))
        hq = hq + 1
      end
    else
      q_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_q.weight")
      q_off_base = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, q_idx)
      q_type     = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, q_idx)
      q_stride   = head_nbytes(q_type, @d_head, @d_model)
      blk.t_w_q = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, q_type, q_off_base)]
      hq = 1
      while hq < @n_heads
        blk.t_w_q.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, q_type,
                         q_off_base + hq * q_stride))
        hq = hq + 1
      end
    end

    # F1.2: per-Q-head LoRA adapter slots. F32-only, allocated in
    # ctx_w (trainable, not mmap'd). A: (r, d_model). B: (d_head, r).
    # Standard init (A small Gaussian + B zero) makes the adapter
    # equal to zero at step 0 → forward output matches the base
    # model exactly. Caller seeds via upload_lora_q_init!(seed).
    if @lora_q_enabled
      blk.t_w_lora_a_q = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                            @lora_q_rank, @d_model)]
      blk.t_w_lora_b_q = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                            @d_head, @lora_q_rank)]
      hq = 1
      while hq < @n_heads
        blk.t_w_lora_a_q.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @lora_q_rank, @d_model))
        blk.t_w_lora_b_q.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @d_head, @lora_q_rank))
        hq = hq + 1
      end

      # F1.2 step 6b: persistent AdamW moments paired with the LoRA
      # adapter tensors above. Same shapes. Live in ctx_w so they
      # survive tnn_reset_for_rebuild across multi-position SFT.
      if @lora_q_adamw_enabled
        blk.t_w_lora_a_q_m = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @lora_q_rank, @d_model)]
        blk.t_w_lora_a_q_v = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @lora_q_rank, @d_model)]
        blk.t_w_lora_b_q_m = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @d_head, @lora_q_rank)]
        blk.t_w_lora_b_q_v = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                @d_head, @lora_q_rank)]
        hqm = 1
        while hqm < @n_heads
          blk.t_w_lora_a_q_m.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                    @lora_q_rank, @d_model))
          blk.t_w_lora_a_q_v.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                    @lora_q_rank, @d_model))
          blk.t_w_lora_b_q_m.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                    @d_head, @lora_q_rank))
          blk.t_w_lora_b_q_v.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess,
                                    @d_head, @lora_q_rank))
          hqm = hqm + 1
        end
      end
    end

    # K, V per-kv-head — same dual-layout split (#153).
    if @per_head_attn
      k0_idx  = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_k.head_0.weight")
      v0_idx  = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_v.head_0.weight")
      k_type  = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, k0_idx)
      v_type  = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, v0_idx)
      blk.t_w_k = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, k_type,
                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, k0_idx))]
      blk.t_w_v = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, v_type,
                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, v0_idx))]
      k_stride = 0  # unused in per-head branch but referenced later
      v_stride = 0
    else
      k_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_k.weight")
      v_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_v.weight")
      k_off_base = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, k_idx)
      v_off_base = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, v_idx)
      k_type     = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, k_idx)
      v_type     = TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, v_idx)
      k_stride   = head_nbytes(k_type, @d_head, @d_model)
      v_stride   = head_nbytes(v_type, @d_head, @d_model)
      blk.t_w_k = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, k_type, k_off_base)]
      blk.t_w_v = [TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                     @d_head, @d_model, v_type, v_off_base)]
    end
    # P5.1+P5.2: K and V allocs both follow @kv_type_*. Layout is
    # `ne=[d_head, max_T]` for both — positions on ne1, d_head on
    # ne0. Per-position writes span a contiguous d_head-vector
    # which is Q8-block-aligned at d_head=64 (=2 blocks of 32).
    # See the struct comment on :kv_type_k / :kv_type_v.
    if @kv_type_k == 8
      blk.t_K = [TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, @d_head, 8)]
    else
      blk.t_K = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, @d_head)]
    end
    if @kv_type_v == 8
      blk.t_V = [TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, @d_head, 8)]
    else
      blk.t_V = [TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, @d_head)]
    end
    hkv = 1
    while hkv < @n_kv
      if @per_head_attn
        kh_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_k.head_" + hkv.to_s + ".weight")
        vh_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_v.head_" + hkv.to_s + ".weight")
        blk.t_w_k.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, k_type,
                         TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, kh_idx)))
        blk.t_w_v.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, v_type,
                         TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, vh_idx)))
      else
        blk.t_w_k.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, k_type,
                         k_off_base + hkv * k_stride))
        blk.t_w_v.push(TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                         @d_head, @d_model, v_type,
                         v_off_base + hkv * v_stride))
      end
      if @kv_type_k == 8
        blk.t_K.push(TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, @d_head, 8))
      else
        blk.t_K.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, @d_head))
      end
      if @kv_type_v == 8
        blk.t_V.push(TinyNNMetal.tnn_input_2d_persistent_typed(@sess, max_T, @d_head, 8))
      else
        blk.t_V.push(TinyNNMetal.tnn_input_2d_f32_persistent(@sess, max_T, @d_head))
      end
      hkv = hkv + 1
    end

    # Q/K/V biases — 1D F32 per head, contiguous in the file.
    if qkv_bias
      qb_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_q.bias")
      kb_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_k.bias")
      vb_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_v.bias")
      qb_off = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, qb_idx)
      kb_off = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, kb_idx)
      vb_off = TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, vb_idx)
      bias_stride = @d_head * 4  # f32

      blk.t_b_q = [TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0, qb_off)]
      hq = 1
      while hq < @n_heads
        blk.t_b_q.push(TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0,
                         qb_off + hq * bias_stride))
        hq = hq + 1
      end

      blk.t_b_k = [TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0, kb_off)]
      blk.t_b_v = [TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0, vb_off)]
      hkv = 1
      while hkv < @n_kv
        blk.t_b_k.push(TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0,
                         kb_off + hkv * bias_stride))
        blk.t_b_v.push(TinyNNMetal.tnn_input_1d_persistent_mmap(@sess, @d_head, 0,
                         vb_off + hkv * bias_stride))
        hkv = hkv + 1
      end
    end

    # O / FFN — full 2D weights, no per-head slicing.
    o_idx    = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".attn_output.weight")
    # M1.1: o_proj maps [n_heads * d_head] → [d_model]. For models
    # where d_head = d_model / n_heads (SmolLM2 / Llama / Qwen2.5)
    # these are equal; for Qwen3 with explicit head_dim=128 they
    # differ (n_heads * d_head = 2048, d_model = 1024).
    blk.t_w_o    = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess, @d_model, @n_heads * @d_head,
                     TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, o_idx),
                     TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, o_idx))

    if @is_moe
      # M2.3: MoE FFN. Per-expert weight matrices are stacked along
      # ne2 in the GGUF (llama.cpp convention):
      #   ffn_gate_inp.weight : ne=[d_model, n_experts]
      #   ffn_gate_exps.weight: ne=[d_model, d_ff,    n_experts]
      #   ffn_up_exps.weight  : ne=[d_model, d_ff,    n_experts]
      #   ffn_down_exps.weight: ne=[d_ff,    d_model, n_experts]
      # All mmap'd in place — Mixtral-8x7B Q4_K_M (26 GB) loads without
      # any RAM copy.
      router_idx    = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_gate_inp.weight")
      gate_exps_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_gate_exps.weight")
      up_exps_idx   = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_up_exps.weight")
      down_exps_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_down_exps.weight")
      # #112 (RESOLVED): K-quant MoE experts work. The old warning here
      # blamed ggml's mul_mat_id kernel for the OLMoE-Q4_K_M corruption,
      # but the op was always correct for K-quants (verified by op-level
      # and real-bytes reproducers in tinynn/ggml1506_*). The actual bug
      # was head_nbytes() returning 0 for K-quant ATTENTION weights,
      # collapsing every head onto head 0 — fixed there. K-quant expert
      # stacks (gate/up/down, including OLMoE's mixed q4_K+q6_K down_exps)
      # load and run coherently. See docs/notes/mul_mat_id_quants.md.
      blk.t_w_router    = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess,
                            @n_experts, @d_model,
                            TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, router_idx),
                            TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, router_idx))
      blk.t_w_gate_exps = TinyNNMetal.tnn_input_3d_persistent_mmap(@sess,
                            @d_model, @d_ff, @n_experts,
                            TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, gate_exps_idx),
                            TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, gate_exps_idx))
      blk.t_w_up_exps   = TinyNNMetal.tnn_input_3d_persistent_mmap(@sess,
                            @d_model, @d_ff, @n_experts,
                            TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, up_exps_idx),
                            TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, up_exps_idx))
      blk.t_w_down_exps = TinyNNMetal.tnn_input_3d_persistent_mmap(@sess,
                            @d_ff, @d_model, @n_experts,
                            TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, down_exps_idx),
                            TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, down_exps_idx))
    else
      gate_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_gate.weight")
      up_idx   = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_up.weight")
      down_idx = TinyNNMetal.tnn_gguf_find_index(gguf_handle, prefix + ".ffn_down.weight")
      blk.t_w_gate = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess, @d_ff, @d_model,
                       TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, gate_idx),
                       TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, gate_idx))
      blk.t_w_up   = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess, @d_ff, @d_model,
                       TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, up_idx),
                       TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, up_idx))
      blk.t_w_down = TinyNNMetal.tnn_input_2d_persistent_mmap(@sess, @d_model, @d_ff,
                       TinyNNMetal.tnn_gguf_tensor_type(gguf_handle, down_idx),
                       TinyNNMetal.tnn_gguf_tensor_file_offset(gguf_handle, down_idx))
    end

    li = li + 1
  end

  # F1.2: mark LoRA tensors as trainable BEFORE finalize_weights.
  # set_param flips a flag on the tensor; the build_backward pass
  # later walks PARAM-flagged nodes to emit grad nodes. Doing it
  # here (rather than in the smoke) keeps the cache class as the
  # single source of truth for what's trainable in a session.
  if @lora_q_enabled
    li2 = 0
    while li2 < @n_layers
      blk2 = @kv_blocks_ffi[li2]
      hq = 0
      while hq < @n_heads
        TinyNNMetal.tnn_set_param(blk2.t_w_lora_a_q[hq])
        TinyNNMetal.tnn_set_param(blk2.t_w_lora_b_q[hq])
        hq = hq + 1
      end
      li2 = li2 + 1
    end
  end

  # Finalize the regular persistent context (K/V cache buffers).
  # Mmap'd tensors don't need finalization — they were allocated
  # against weights_buf_mmap inline.
  TinyNNMetal.tnn_finalize_weights(@sess)

  # Upload llama3-style RoPE freq_factors once the backend buffer
  # for @t_rope_freq_factors exists (post-finalize). The values are
  # a per-model constant — never re-uploaded across rebuild cycles.
  if @rope_scaling.kind == :llama3
    ff = Toy::RopeScaling.compute_llama3_freq_factors(
      @d_head, @rope_base,
      @rope_scaling.orig_max_pos, @rope_scaling.factor,
      @rope_scaling.low_freq_factor, @rope_scaling.high_freq_factor)
    TinyNNMetal.tnn_upload_from_float_array(@sess, @t_rope_freq_factors,
                                       ff, ff.length)
  end

  # F1.2 step 6b: zero-init persistent Adam moments. AdamW's update
  # rule assumes m = v = 0 at step 0 (otherwise the first step picks
  # up garbage from the buffer). The bias-correction term beta1h/beta2h
  # then ramps in as the moments accumulate.
  if @lora_q_adamw_enabled
    za = Mat.new(@lora_q_rank, @d_model)
    zb = Mat.new(@d_head,      @lora_q_rank)
    i = 0
    while i < @lora_q_rank * @d_model; za.flat[i] = 0.0; i = i + 1; end
    j = 0
    while j < @d_head * @lora_q_rank; zb.flat[j] = 0.0; j = j + 1; end
    li_z = 0
    while li_z < @n_layers
      blk_z = @kv_blocks_ffi[li_z]
      hqz = 0
      while hqz < @n_heads
        TinyNNMetal.upload_row_major(@sess, blk_z.t_w_lora_a_q_m[hqz], za)
        TinyNNMetal.upload_row_major(@sess, blk_z.t_w_lora_a_q_v[hqz], za)
        TinyNNMetal.upload_row_major(@sess, blk_z.t_w_lora_b_q_m[hqz], zb)
        TinyNNMetal.upload_row_major(@sess, blk_z.t_w_lora_b_q_v[hqz], zb)
        hqz = hqz + 1
      end
      li_z = li_z + 1
    end
  end

  # Zero-init K/V cache buffers (same as realize_for + legacy load).
  # P5.1: skip K zero-init when K is Q8_0. upload_row_major writes
  # F32 row-major bytes which would corrupt a Q8 tensor's quantization
  # blocks. The K cache is read only at positions [0, pos+1], and
  # every position is written before it's read, so unset trailing
  # positions are never observed — zero-init is paranoia and safe
  # to skip for Q8. P5.2 flipped V to mirror K's layout, so V's
  # zero-init Mat now has the same shape as K's, and the same Q8
  # skip rule applies.
  kv_zero = Mat.new(max_T, @d_head)
  li = 0
  while li < @n_layers
    blk_f = @kv_blocks_ffi[li]
    hkv = 0
    while hkv < @n_kv
      if @kv_type_k != 8
        TinyNNMetal.upload_row_major(@sess, blk_f.t_K[hkv], kv_zero)
      end
      if @kv_type_v != 8
        TinyNNMetal.upload_row_major(@sess, blk_f.t_V[hkv], kv_zero)
      end
      hkv = hkv + 1
    end
    li = li + 1
  end

  @realized = true
end

#set_weight_type(t) ⇒ Object

Phase 3 opt-in: set the ggml type used for 2D linear weights when realize_for runs. 0 = F32, 8 = Q8_0. Call BEFORE realize_for — the persistent tensors are allocated there.



313
314
315
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 313

# Records the ggml type id to use for 2D linear weights (0 = F32, 8 = Q8_0).
# Only stores the value; call it before realize_for, since that is where the
# persistent tensors get allocated.
#
# @param t [Integer] ggml type id
# @return [Integer] the value that was stored
def set_weight_type(t); @weight_type = t; end

#trace_tap(_name, t) ⇒ Object



910
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 910

# Pass-through tap: the name is ignored and the second argument is returned
# untouched.
#
# @param _name [Object] label for the tap point (unused here)
# @param t [Object] value to hand back
# @return [Object] t, the very same object
def trace_tap(_name, t)
  t
end

#upload_lora_q_init!(seed, init_scale) ⇒ Object

Auto-dispatch: open the GGUF, peek at its `toy.ggml_native` flag, and route to either the BYO-pointer mmap path (Phase 2) or the legacy realize_for + load_weights copy path. Returns the GGUF handle (or null for the legacy path); the kv_cache holds it via `@gguf_handle_keepalive`. Caller must have `require_relative "toy/models/toy_smollm2_loader"` at the top-level driver — this file deliberately does NOT require it (require-order with GGUFLoad's methods that touch `weight_type` was triggering a Spinel GC crash in decode_step). F1.2: standard LoRA init for the Q adapters. A = small Gaussian (scale = init_scale, default 0.01); B = zero. With B=0 the LoRA contribution is exactly zero, so forward output matches the base model bit-for-bit at step 0. Call AFTER realize_for_mmap.



825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
# File 'lib/toy/llm/engine/llama_kv_engine_metal.rb', line 825

# F1.2: uploads the initial values of the Q-projection LoRA adapters.
#
# For every (layer, head) pair the A tensor receives Gaussian noise scaled by
# init_scale and the B tensor receives all zeros. Returns immediately when
# LoRA-Q is not enabled on this cache.
#
# @param seed [Integer] starting state of the LCG that drives the noise
# @param init_scale [Float] multiplier applied to every Gaussian sample
# @return [nil]
def upload_lora_q_init!(seed, init_scale)
  return unless @lora_q_enabled

  a_count = @lora_q_rank * @d_model
  b_count = @d_head * @lora_q_rank
  a_mat = Mat.new(@lora_q_rank, @d_model)
  b_zero = Mat.new(@d_head, @lora_q_rank)

  # B is the same for every head: clear it once, then upload the same Mat
  # repeatedly.
  k = 0
  while k < b_count
    b_zero.flat[k] = 0.0
    k += 1
  end

  # The LCG state carries over across heads and layers, so each adapter
  # draws from a different stretch of the sequence.
  state = seed
  layer = 0
  while layer < @n_layers
    block = @kv_blocks_ffi[layer]
    head = 0
    while head < @n_heads
      # Box-Muller on two LCG draws per element. The +1.0 keeps each uniform
      # strictly above zero so Math.log stays finite.
      k = 0
      while k < a_count
        state = (state * 1103515245 + 12345) & 0x7FFFFFFF
        u1 = (state.to_f + 1.0) / 2147483648.0
        state = (state * 1103515245 + 12345) & 0x7FFFFFFF
        u2 = (state.to_f + 1.0) / 2147483648.0
        radius = Math.sqrt(-2.0 * Math.log(u1))
        angle = 2.0 * Math::PI * u2
        a_mat.flat[k] = init_scale * radius * Math.cos(angle)
        k += 1
      end
      TinyNNMetal.upload_row_major(@sess, block.t_w_lora_a_q[head], a_mat)
      TinyNNMetal.upload_row_major(@sess, block.t_w_lora_b_q[head], b_zero)
      head += 1
    end
    layer += 1
  end
end