Module: GPT2KVMetal

Defined in:
lib/toy/llm/engine/gpt2_kv_engine_metal.rb

Class Method Summary

Class Method Details

.decode_step(kv_cache, token_id, pos) ⇒ Object

Decode one new token at position `pos`. As a side effect, it writes this position's entries into the K/V cache (K, V[:, pos]). It returns the logits for the new position as a single-row (1, vocab) Mat, which the caller can argmax (greedy) or sample from.



337
338
339
340
341
342
343
344
345
346
# File 'lib/toy/llm/engine/gpt2_kv_engine_metal.rb', line 337

# Decode one new token at position `pos` using the KV cache.
#
# Rebuilds the single-step decode graph for `pos`, uploads `token_id` into
# it, runs it, and downloads the logits. Running the step graph also writes
# this position's K/V entries into the cache buffers.
#
# @param kv_cache [GPT2KVFFICacheMetal] realized cache with weights uploaded
# @param token_id [Integer] input token; must be in 0...kv_cache.vocab_size
# @param pos [Integer] position to decode; must be in 0...kv_cache.max_T
# @return [Mat] logits as a single row with kv_cache.vocab_size columns
# @raise [ArgumentError] if token_id or pos is out of range. The check runs
#   before the session is reset, so a rejected call leaves it untouched.
def self.decode_step(kv_cache, token_id, pos)
  # The per-head K/V buffers are sized for max_T positions; a pos outside
  # that range would index past the end of them instead of failing loudly.
  unless pos.is_a?(Integer) && pos >= 0 && pos < kv_cache.max_T
    raise ArgumentError, "pos #{pos.inspect} out of range 0...#{kv_cache.max_T}"
  end
  unless token_id.is_a?(Integer) && token_id >= 0 && token_id < kv_cache.vocab_size
    raise ArgumentError, "token_id #{token_id.inspect} out of range 0...#{kv_cache.vocab_size}"
  end

  sess = kv_cache.sess
  TinyNNMetal.tnn_reset_for_rebuild(sess)
  step = kv_cache.build_decode_step(pos)
  # Realize the step graph first; the token-id upload below targets one of
  # its tensors (step.t_token_id).
  TinyNNMetal.tnn_realize(sess, step.kv_step_logits)
  TinyNNMetal.upload_int_array(sess, step.t_token_id, [token_id])
  TinyNNMetal.tnn_compute(sess)
  # Logits ne=[vocab, 1]. Download as (1, vocab) row-major — same
  # layout as a single-row Mat with vocab columns.
  TinyNNMetal.download_row_major(sess, step.kv_step_logits, 1, kv_cache.vocab_size)
end

.upload_from(kv_cache, model) ⇒ Object

Upload all GPT-2 weights into a realized GPT2KVFFICacheMetal and zero-initialize its per-head K/V buffers. This is the KV-cache counterpart to GPT2FFI.upload_from. Note: pos_embed is uploaded in full (all context_length rows), not sliced.



281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# File 'lib/toy/llm/engine/gpt2_kv_engine_metal.rb', line 281

# Copy every GPT-2 parameter from `model` into the device tensors of a
# realized GPT2KVFFICacheMetal, and write zeros into every per-head K/V
# cache buffer. Counterpart to GPT2FFI.upload_from for the KV-cache variant.
#
# pos_embed is uploaded whole (every row), not sliced to a prompt length.
# Weight matrices go through stage_transposed_and_upload; embeddings and
# the zero K/V buffers through upload_row_major; 1-D vectors (LayerNorm
# parameters, biases) through tnn_upload_from_float_array with an explicit
# element count.
#
# @param kv_cache [GPT2KVFFICacheMetal] realized cache whose tensors are written
# @param model [Object] source model exposing token_embed, pos_embed,
#   ln_f_gamma, ln_f_beta and gpt2_blocks (indexed by layer)
# @return [void]
def self.upload_from(kv_cache, model)
  sess       = kv_cache.sess
  layers     = kv_cache.n_layers
  heads      = kv_cache.n_heads
  width      = kv_cache.d_model
  head_width = kv_cache.d_head
  cache_len  = kv_cache.max_T

  # Model-wide tensors: embeddings and the final LayerNorm.
  TinyNNMetal.upload_row_major(sess, kv_cache.t_token_embed, model.token_embed)
  TinyNNMetal.upload_row_major(sess, kv_cache.t_pos_embed, model.pos_embed)
  TinyNNMetal.tnn_upload_from_float_array(sess, kv_cache.t_ln_f_gamma, model.ln_f_gamma, width)
  TinyNNMetal.tnn_upload_from_float_array(sess, kv_cache.t_ln_f_beta, model.ln_f_beta, width)

  # One zero buffer per shape, reused for every head of every layer. The
  # backend allocator typically zeros memory already, but writing zeros
  # explicitly gives a clean state when the cache is reused across runs.
  zero_keys   = Mat.new(cache_len, head_width)   # K: (max_T, d_head)
  zero_values = Mat.new(head_width, cache_len)   # V: (d_head, max_T)

  layer = 0
  until layer >= layers
    src = model.gpt2_blocks[layer]
    dst = kv_cache.kv_blocks_ffi[layer]

    # The block's two LayerNorms (ln1, ln2).
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_ln1_gamma, src.ln1_gamma, width)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_ln1_beta, src.ln1_beta, width)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_ln2_gamma, src.ln2_gamma, width)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_ln2_beta, src.ln2_beta, width)

    head = 0
    until head >= heads
      # Per-head Q/K/V projections (staged transposed) and their biases...
      TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_q[head], src.w_q[head])
      TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_k[head], src.w_k[head])
      TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_v[head], src.w_v[head])
      TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_q[head], src.b_q[head], head_width)
      TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_k[head], src.b_k[head], head_width)
      TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_v[head], src.b_v[head], head_width)
      # ...then reset this head's K/V cache buffers.
      TinyNNMetal.upload_row_major(sess, dst.t_K[head], zero_keys)
      TinyNNMetal.upload_row_major(sess, dst.t_V[head], zero_values)
      head += 1
    end

    # Attention output projection and the two feed-forward layers.
    TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_o, src.w_o)
    TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_ff1, src.w_ff1)
    TinyNNMetal.stage_transposed_and_upload(sess, dst.t_w_ff2, src.w_ff2)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_o, src.b_o, width)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_ff1, src.b_ff1, kv_cache.d_ff)
    TinyNNMetal.tnn_upload_from_float_array(sess, dst.t_b_ff2, src.b_ff2, width)

    layer += 1
  end
end