Module: Toy::LLM::Primitives::GQA
- Defined in:
  lib/toy/llm/primitives/gqa.rb,
  lib/toy/llm/primitives/gqa_cuda.rb,
  lib/toy/llm/primitives/gqa_metal.rb
Constant Summary
- NAME = :gqa
Class Method Summary
- .attention(sess, t_k, t_q, t_vt, attn_mask, scale, batch) ⇒ Object
  Pure attention math: scores -> scaled+masked softmax -> weighted V.
Class Method Details
.attention(sess, t_k, t_q, t_vt, attn_mask, scale, batch) ⇒ Object
Pure attention math: scores -> scaled+masked softmax -> weighted V. No weights, no KV cache, no ivars. All inputs are tensor handles + scalars supplied by the L2 block.
t_k : selected KV-head key tensor ne=[d_head, T*B]
t_q : per-Q-head rotated query ne=[d_head, T*B]
t_vt : selected KV-head V transpose ne=[T*B, d_head]
attn_mask : block-causal mask tensor (used only when batch > 1)
scale : 1.0 / sqrt(d_head)
batch : @seq_b; batch > 1 takes the soft_max_ext path, otherwise the scale + diag_mask_inf + softmax path
Returns the per-head output tensor ne=[d_head, T*B].
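The parameter list refers to a "selected KV-head" tensor per query head and to scale = 1.0 / sqrt(d_head). As a minimal sketch of how a caller might compute both, assuming the usual grouped-query layout in which contiguous query heads share one KV head: the helper name kv_head_for and the head counts below are hypothetical and do not come from this module.

```ruby
# Hypothetical helper (not part of Toy::LLM::Primitives::GQA): returns the
# index of the KV head that serves a given query head under grouped-query
# attention. Assumes n_q_heads is a multiple of n_kv_heads and that query
# heads are grouped contiguously.
def kv_head_for(q_head, n_q_heads, n_kv_heads)
  unless (n_q_heads % n_kv_heads).zero?
    raise ArgumentError, "n_q_heads must be a multiple of n_kv_heads"
  end
  group_size = n_q_heads / n_kv_heads
  q_head / group_size
end

# Illustrative sizes only: 8 query heads sharing 2 KV heads.
mapping = (0...8).map { |h| kv_head_for(h, 8, 2) }
# mapping => [0, 0, 0, 0, 1, 1, 1, 1]

# The scale argument as described above, for an assumed d_head of 64.
d_head = 64
scale = 1.0 / Math.sqrt(d_head)
# scale => 0.125
```

With n_kv_heads equal to n_q_heads the mapping is the identity, which corresponds to ordinary multi-head attention.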
# File 'lib/toy/llm/primitives/gqa.rb', line 48
def self.attention(sess, t_k, t_q, t_vt, attn_mask, scale, batch)
  t_scores = TinyNN.tnn_matmul(sess, t_k, t_q)
  if batch > 1
    t_attn = TinyNN.tnn_soft_max_ext(sess, t_scores, attn_mask, scale, 0.0)
  else
    t_scaled = TinyNN.tnn_scale(sess, t_scores, scale)
    t_masked = TinyNN.tnn_diag_mask_inf(sess, t_scaled, 0)
    t_attn = TinyNN.tnn_softmax(sess, t_masked)
  end
  TinyNN.tnn_matmul(sess, t_vt, t_attn)
end
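
For checking results by hand, the single-sequence branch of the listing above (scale, causal mask, softmax, weighted V) can be mirrored on plain Ruby arrays. This is a sketch under stated assumptions, not the module's implementation: the name reference_causal_attention is hypothetical, q, k and v are taken to be arrays of T rows with d_head floats each, and v is passed untransposed rather than as t_vt.

```ruby
# Plain-Ruby reference for causal scaled dot-product attention on one head.
# Illustrative only; it does not call TinyNN and assumes row-major arrays:
# q, k, v are each T rows of d_head floats.
def reference_causal_attention(q, k, v, scale)
  d_head = v.first.size
  q.each_with_index.map do |q_row, i|
    # Scaled scores against every key; positions j > i are masked to -inf.
    scores = k.each_with_index.map do |k_row, j|
      j > i ? -Float::INFINITY : scale * q_row.zip(k_row).sum { |a, b| a * b }
    end
    # Numerically stable softmax over the key axis.
    max = scores.max
    exps = scores.map { |s| Math.exp(s - max) }
    total = exps.sum
    weights = exps.map { |e| e / total }
    # Weighted sum of the value rows gives one output row of d_head floats.
    (0...d_head).map { |d| weights.zip(v).sum { |w, row| w * row[d] } }
  end
end

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = reference_causal_attention(q, k, v, 1.0 / Math.sqrt(2))
# out[0] => [1.0, 2.0]  (the first position can only attend to itself)
```

The second output row is a blend of both value rows, weighted toward the second because its key matches the query.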