Module: Toy::LLM::Primitives::GQA

Defined in:
lib/toy/llm/primitives/gqa.rb,
lib/toy/llm/primitives/gqa_cuda.rb,
lib/toy/llm/primitives/gqa_metal.rb

Constant Summary

NAME = :gqa

Class Method Summary

Class Method Details

.attention(sess, t_k, t_q, t_vt, attn_mask, scale, batch) ⇒ Object

Pure attention math: scores -> scaled and masked softmax -> weighted V. The method uses no weights, no KV cache, and no instance variables; every input is a tensor handle or a scalar supplied by the L2 block.

t_k        : selected KV-head key tensor    ne=[d_head, T*B]
t_q        : per-Q-head rotated query       ne=[d_head, T*B]
t_vt       : selected KV-head V transpose   ne=[T*B, d_head]
attn_mask  : block-causal mask tensor (used only when batch > 1)
scale      : 1.0 / sqrt(d_head)
batch      : @seq_b (batch > 1 selects the soft_max_ext path, otherwise the diag_mask_inf path)

Returns the per-head output tensor ne=[d_head, T*B].
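
As a rough illustration of the math described above, the following plain-Ruby sketch computes causal attention for a single head on nested arrays. It is not the TinyNN implementation: `reference_attention` and its row-major layout (one row of length d_head per position) are assumptions made for this example, and it covers only a single sequence with a plain causal mask.

```ruby
# Reference sketch only: plain Ruby arrays, not TinyNN tensor handles.
# q, k, v are arrays of T rows, each row of length d_head (hypothetical layout).
def reference_attention(q, k, v)
  d_head = q.first.length
  scale  = 1.0 / Math.sqrt(d_head)

  q.each_with_index.map do |q_row, i|
    # Scaled dot-product scores against keys 0..i; later keys are causally masked out.
    scores = k[0..i].map { |k_row| scale * q_row.zip(k_row).sum { |a, b| a * b } }

    # Numerically stable softmax over the visible positions.
    max     = scores.max
    exps    = scores.map { |s| Math.exp(s - max) }
    total   = exps.sum
    weights = exps.map { |e| e / total }

    # Weighted sum of the visible value rows gives one output row of length d_head.
    (0...d_head).map { |d| weights.zip(v).sum { |w, v_row| w * v_row[d] } }
  end
end

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = reference_attention(q, k, v)
```

Position 0 can only attend to itself, so `out[0]` equals `v[0]`; position 1 mixes both value rows, with more weight on the key it matches.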



# File 'lib/toy/llm/primitives/gqa.rb', line 48

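def self.attention(sess, t_k, t_q, t_vt, attn_mask, scale, batch)
  # Raw attention scores from the key and query tensors.
  t_scores = TinyNN.tnn_matmul(sess, t_k, t_q)
  if batch > 1
    # Batched path: scale and block-causal mask are passed to the softmax op.
    t_attn = TinyNN.tnn_soft_max_ext(sess, t_scores, attn_mask, scale, 0.0)
  else
    # Single-sequence path: scale, apply the diagonal mask, then softmax.
    t_scaled = TinyNN.tnn_scale(sess, t_scores, scale)
    t_masked = TinyNN.tnn_diag_mask_inf(sess, t_scaled, 0)
    t_attn   = TinyNN.tnn_softmax(sess, t_masked)
  end
  # Weight the transposed V tensor by the attention probabilities.
  TinyNN.tnn_matmul(sess, t_vt, t_attn)
end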
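
The two branches above are expected to produce the same probabilities for a single sequence. The plain-Ruby sketch below illustrates this under a stated assumption: that the fused op computes softmax(scores * scale + mask) with an additive mask of 0.0 and -Infinity entries. The source does not confirm this for `tnn_soft_max_ext`, and the helper names (`softmax_row`, `unfused_path`, `fused_path`, `causal_mask`) are hypothetical.

```ruby
# Numerically stable softmax over one row; -Infinity entries receive weight 0.
def softmax_row(row)
  max   = row.max
  exps  = row.map { |x| Math.exp(x - max) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Unfused path: scale, set entries above the diagonal to -Infinity, then softmax.
def unfused_path(scores, scale)
  scores.each_with_index.map do |row, i|
    masked = row.each_with_index.map { |s, j| j > i ? -Float::INFINITY : s * scale }
    softmax_row(masked)
  end
end

# Fused path (assumed semantics): softmax(scores * scale + mask) in one step.
def fused_path(scores, mask, scale)
  scores.zip(mask).map do |row, mask_row|
    softmax_row(row.zip(mask_row).map { |s, m| s * scale + m })
  end
end

# Additive causal mask: 0.0 where key j is visible to query i, -Infinity otherwise.
def causal_mask(t)
  Array.new(t) { |i| Array.new(t) { |j| j > i ? -Float::INFINITY : 0.0 } }
end

scores  = [[1.0, 2.0], [3.0, 4.0]]
scale   = 0.5
unfused = unfused_path(scores, scale)
fused   = fused_path(scores, causal_mask(2), scale)
```

For a batch of several sequences the mask would be block-causal rather than plain causal, which is why the batched branch takes an explicit mask tensor.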