Class: Toy::GQAttention
- Inherits: Object (Object > Toy::GQAttention)
- Defined in: lib/toy.rb
Overview
Toy::GQAttention — grouped-query causal self-attention.
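n_heads query heads share n_kv key/value heads, so each KV head serves group_size = n_heads / n_kv query heads. Both n_heads / n_kv and d_model / n_heads are integer divisions, so n_heads must be a multiple of n_kv and d_model a multiple of n_heads. Used by SmolLM2 (n_heads/n_kv = 9/3), TinyLlama (32/4), and Qwen2.5-0.5B (14/2). When n_heads == n_kv this degenerates to standard multi-head attention (MHA).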
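The head-sharing rule can be sketched in plain Ruby. `kv_head_for` is a hypothetical helper, not part of Toy: it mirrors the `hq / @group_size` lookup in #forward (0-based heads) and adds a divisibility check that the class itself does not perform.

```ruby
# Map a 0-based query head to the 0-based KV head it shares,
# using the same integer division as #forward (hq / group_size).
def kv_head_for(hq, n_heads, n_kv)
  raise ArgumentError, "n_heads must be a multiple of n_kv" unless n_heads % n_kv == 0
  group_size = n_heads / n_kv
  hq / group_size
end

{ "SmolLM2" => [9, 3], "TinyLlama" => [32, 4], "Qwen2.5-0.5B" => [14, 2] }.each do |name, (h, kv)|
  mapping = (0...h).map { |hq| kv_head_for(hq, h, kv) }
  puts "#{name}: #{mapping.inspect}"
end
# first line printed: SmolLM2: [0, 0, 0, 1, 1, 1, 2, 2, 2]
```

Consecutive query heads form a group, so with 9/3 heads 0 to 2 read KV head 0, heads 3 to 5 read KV head 1, and so on. With n_kv == n_heads every query head gets its own KV head, which is plain MHA.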
RoPE is applied to Q and K (not V) before the dot product. forward takes two arguments, `forward(x, pos_start)`, because RoPE depends on absolute position; pos_start is the absolute position of row 0 of x.
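Why the position has to be passed in can be seen in a minimal RoPE sketch. `rope_rotate` and `dot` are hypothetical helpers, not Toy's rope object, and the sketch assumes the adjacent-pair layout (dimensions 2i and 2i+1 rotate together) with base 10000; real checkpoints and Toy's implementation may pair dimensions differently. Each vector is rotated by an angle proportional to its absolute position, so the rotation needs that position as input, while the score between a rotated query and a rotated key depends only on the offset between their positions.

```ruby
# Rotate consecutive pairs (2i, 2i+1) of vec by pos * base^(-2i/d).
# Assumption: adjacent-pair RoPE with base 10000.
def rope_rotate(vec, pos, base = 10_000.0)
  d = vec.length
  out = vec.dup
  (0...(d / 2)).each do |i|
    theta = pos / (base**(2.0 * i / d))
    c = Math.cos(theta)
    s = Math.sin(theta)
    a = vec[2 * i]
    b = vec[2 * i + 1]
    out[2 * i] = a * c - b * s
    out[2 * i + 1] = a * s + b * c
  end
  out
end

def dot(a, b)
  a.zip(b).sum { |x, y| x * y }
end

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]

# Same offset (2) at two different absolute positions gives the same score.
s1 = dot(rope_rotate(q, 5), rope_rotate(k, 3))
s2 = dot(rope_rotate(q, 105), rope_rotate(k, 103))
puts((s1 - s2).abs < 1e-9) # prints true
```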
Q/K/V biases are pre-allocated as zeros but applied only when `has_qkv_bias` is true. SmolLM2 / Llama / TinyLlama: false (Llama convention). Qwen2 / Qwen2.5: true (Q/K/V have learned biases; O does not). The loader flips the flag when it finds attn_q.bias / attn_k.bias / attn_v.bias in the GGUF file. Pre-allocation keeps the instance-variable types stable for Spinel (no reassignment after construction).
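The pre-allocate-then-flip pattern can be sketched as follows. `TinyAttn` and `fill_bias!` are hypothetical stand-ins, not Toy's classes or its GGUF loader. A loader that follows the no-reassignment rule would presumably copy the loaded values element by element into the zero arrays created at construction, so each instance variable keeps pointing at the same Array object and only the boolean flag changes.

```ruby
# Hypothetical stand-in for GQAttention's bias handling.
class TinyAttn
  attr_reader :b_q, :has_qkv_bias

  def initialize(n_heads, d_head)
    @b_q = Array.new(n_heads) { Array.new(d_head, 0.0) } # zeros, fixed shape
    @has_qkv_bias = false
  end

  def enable_qkv_bias!
    @has_qkv_bias = true
  end
end

# Copy loaded values into the existing Array instead of assigning a new one.
def fill_bias!(dst, src)
  raise ArgumentError, "length mismatch" unless dst.length == src.length
  i = 0
  while i < src.length
    dst[i] = src[i].to_f
    i += 1
  end
  dst
end

attn = TinyAttn.new(2, 3)
before = attn.b_q[1].object_id
fill_bias!(attn.b_q[1], [0.1, -0.2, 0.3])
attn.enable_qkv_bias!
puts(attn.b_q[1].object_id == before) # prints true
puts(attn.b_q[1].inspect)             # prints [0.1, -0.2, 0.3]
```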
Instance Attribute Summary
- #b_k ⇒ Object: Per-KV-head key biases (n_kv zero-initialized arrays of length d_head).
- #b_q ⇒ Object: Per-query-head query biases (n_heads zero-initialized arrays of length d_head).
- #b_v ⇒ Object: Per-KV-head value biases (n_kv zero-initialized arrays of length d_head).
- #d_head ⇒ Object: Per-head dimension, d_model / n_heads.
- #d_model ⇒ Object: Model width D.
- #group_size ⇒ Object: Query heads per KV head, n_heads / n_kv.
- #has_qkv_bias ⇒ Object: Whether the Q/K/V biases are applied (false until #enable_qkv_bias! is called).
- #inv_sqrt ⇒ Object: Attention scale, 1 / √d_head.
- #n_heads ⇒ Object: Number of query heads H.
- #n_kv ⇒ Object: Number of key/value heads H_kv.
- #rope ⇒ Object: The RoPE object (rope_obj) whose rotate! is applied to Q and K.
- #w_k ⇒ Object: Per-KV-head key projections (n_kv Mat objects, d_model × d_head).
- #w_o ⇒ Object: Output projection (one Mat, d_model × d_model).
- #w_q ⇒ Object: Per-query-head query projections (n_heads Mat objects, d_model × d_head).
- #w_v ⇒ Object: Per-KV-head value projections (n_kv Mat objects, d_model × d_head).
Instance Method Summary
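- #algorithm ⇒ Object: Builds a Toy::Card describing #forward as pseudocode.
- #algorithm_card ⇒ Object: Renders #algorithm with render_pseudocode.
- #attend(x, hq, k_h, v_h, pos_start) ⇒ Object: One query-head attention.
- #enable_qkv_bias! ⇒ Object: Called by the GGUF loader when attn_q.bias / attn_k.bias / attn_v.bias are present.
- #forward(x, pos_start) ⇒ Object: x: [T, D] → [T, D].
- #initialize(d_model, n_heads, n_kv, rope_obj) ⇒ GQAttention (constructor): A new instance of GQAttention.
- #param_count ⇒ Object: Number of learned parameters, including the Q/K/V biases when enabled.
- #summary ⇒ Object: One-line string describing the configuration.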
Constructor Details
#initialize(d_model, n_heads, n_kv, rope_obj) ⇒ GQAttention
Returns a new instance of GQAttention.
```ruby
# File 'lib/toy.rb', line 574
def initialize(d_model, n_heads, n_kv, rope_obj)
  @d_model = d_model
  @n_heads = n_heads
  @n_kv = n_kv
  @d_head = d_model / n_heads
  @group_size = n_heads / n_kv
  @inv_sqrt = 1.0 / Math.sqrt(@d_head)
  @rope = rope_obj
  @w_q = [Mat.new(d_model, @d_head)]
  @b_q = [Array.new(@d_head, 0.0)] # per Q head (zeros until enabled)
  hq = 1
  while hq < n_heads
    @w_q.push(Mat.new(d_model, @d_head))
    @b_q.push(Array.new(@d_head, 0.0))
    hq += 1
  end
  @w_k = [Mat.new(d_model, @d_head)]
  @w_v = [Mat.new(d_model, @d_head)]
  @b_k = [Array.new(@d_head, 0.0)] # per KV head
  @b_v = [Array.new(@d_head, 0.0)]
  hkv = 1
  while hkv < n_kv
    @w_k.push(Mat.new(d_model, @d_head))
    @w_v.push(Mat.new(d_model, @d_head))
    @b_k.push(Array.new(@d_head, 0.0))
    @b_v.push(Array.new(@d_head, 0.0))
    hkv += 1
  end
  @w_o = Mat.new(d_model, d_model)
  @has_qkv_bias = false # flipped by the loader for Qwen2.x
end
```
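A quick check of the sizes the constructor derives, for the three configurations named in the overview. The d_model values (576, 2048, 896) are the commonly published hidden sizes for SmolLM2-135M, TinyLlama-1.1B and Qwen2.5-0.5B and are assumptions here, not read from Toy; `CONFIGS` and `derived` are hypothetical names. Unlike the constructor, which truncates silently, this sketch rejects uneven head splits.

```ruby
# Derived sizes, computed with the same integer arithmetic as #initialize.
# The d_model values are assumptions; verify them against each checkpoint's config.
CONFIGS = {
  "SmolLM2-135M"   => { d_model: 576,  n_heads: 9,  n_kv: 3 },
  "TinyLlama-1.1B" => { d_model: 2048, n_heads: 32, n_kv: 4 },
  "Qwen2.5-0.5B"   => { d_model: 896,  n_heads: 14, n_kv: 2 }
}.freeze

def derived(cfg)
  # The constructor truncates silently; this sketch refuses uneven splits.
  raise ArgumentError, "d_model % n_heads != 0" unless cfg[:d_model] % cfg[:n_heads] == 0
  raise ArgumentError, "n_heads % n_kv != 0" unless cfg[:n_heads] % cfg[:n_kv] == 0
  d_head = cfg[:d_model] / cfg[:n_heads]
  { d_head: d_head,
    group_size: cfg[:n_heads] / cfg[:n_kv],
    inv_sqrt: 1.0 / Math.sqrt(d_head) }
end

CONFIGS.each { |name, cfg| puts "#{name}: #{derived(cfg)}" }
```

Under these assumptions all three come out at d_head = 64 (so inv_sqrt = 0.125), with group sizes 3, 8 and 7.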
Instance Attribute Details
#b_k ⇒ Object
Returns the value of attribute b_k.
```ruby
# File 'lib/toy.rb', line 569
def b_k
  @b_k
end
```
#b_q ⇒ Object
Returns the value of attribute b_q.
```ruby
# File 'lib/toy.rb', line 569
def b_q
  @b_q
end
```
#b_v ⇒ Object
Returns the value of attribute b_v.
```ruby
# File 'lib/toy.rb', line 569
def b_v
  @b_v
end
```
#d_head ⇒ Object
Returns the value of attribute d_head.
```ruby
# File 'lib/toy.rb', line 569
def d_head
  @d_head
end
```
#d_model ⇒ Object
Returns the value of attribute d_model.
```ruby
# File 'lib/toy.rb', line 569
def d_model
  @d_model
end
```
#group_size ⇒ Object
Returns the value of attribute group_size.
```ruby
# File 'lib/toy.rb', line 569
def group_size
  @group_size
end
```
#has_qkv_bias ⇒ Object
Returns the value of attribute has_qkv_bias.
```ruby
# File 'lib/toy.rb', line 569
def has_qkv_bias
  @has_qkv_bias
end
```
#inv_sqrt ⇒ Object
Returns the value of attribute inv_sqrt.
```ruby
# File 'lib/toy.rb', line 569
def inv_sqrt
  @inv_sqrt
end
```
#n_heads ⇒ Object
Returns the value of attribute n_heads.
```ruby
# File 'lib/toy.rb', line 569
def n_heads
  @n_heads
end
```
#n_kv ⇒ Object
Returns the value of attribute n_kv.
```ruby
# File 'lib/toy.rb', line 569
def n_kv
  @n_kv
end
```
#rope ⇒ Object
Returns the value of attribute rope.
```ruby
# File 'lib/toy.rb', line 569
def rope
  @rope
end
```
#w_k ⇒ Object
Returns the value of attribute w_k.
```ruby
# File 'lib/toy.rb', line 569
def w_k
  @w_k
end
```
#w_o ⇒ Object
Returns the value of attribute w_o.
```ruby
# File 'lib/toy.rb', line 569
def w_o
  @w_o
end
```
#w_q ⇒ Object
Returns the value of attribute w_q.
```ruby
# File 'lib/toy.rb', line 569
def w_q
  @w_q
end
```
#w_v ⇒ Object
Returns the value of attribute w_v.
```ruby
# File 'lib/toy.rb', line 569
def w_v
  @w_v
end
```
Instance Method Details
#algorithm ⇒ Object
```ruby
# File 'lib/toy.rb', line 690
def algorithm
  c = Toy::Card.new("GQAttention.forward(x, p_start)", "grouped-query + RoPE")
  c.add_input("x", "R^{T×D}", "")
  c.add_input("p_start", "ℕ", "")
  c.add_output("y", "R^{T×D}", "")
  c.add_hyper("D", @d_model.to_s)
  c.add_hyper("H", @n_heads.to_s)
  c.add_hyper("H_kv", @n_kv.to_s)
  c.add_hyper("g", @group_size.to_s)
  c.add_hyper("D_h", @d_head.to_s)
  c.add_param("W_Q^h", "R^{D×D_h}", "per query head (h=1..H)")
  c.add_param("W_K^j, W_V^j", "R^{D×D_h}", "per KV head (j=1..H_kv); shared across g Q heads")
  if @has_qkv_bias
    c.add_param("b_Q^h", "R^{D_h}", "per query head")
    c.add_param("b_K^j, b_V^j", "R^{D_h}", "per KV head — Qwen2.x convention")
    c.add_param("W_O", "R^{D×D}", "no output bias")
  else
    c.add_param("W_O", "R^{D×D}", "no biases — Llama convention")
  end
  c.step_loop("j ← 1, …, H_kv", "KV computed once per group")
  if @has_qkv_bias
    c.step_bind("k^j", "RoPE(x · W_K^j + b_K^j, p_start)", "k^j ∈ R^{T×D_h}")
    c.step_update("v^j", "x · W_V^j + b_V^j", "v^j ∈ R^{T×D_h}", "V not rotated")
  else
    c.step_bind("k^j", "RoPE(x · W_K^j, p_start)", "k^j ∈ R^{T×D_h}")
    c.step_update("v^j", "x · W_V^j", "v^j ∈ R^{T×D_h}", "V not rotated")
  end
  c.step_loop_close
  c.step_loop("h ← 1, …, H", "per query head")
  c.step_bind("j", "⌊(h−1) / g⌋ + 1", "")
  if @has_qkv_bias
    c.step_bind("q^h", "RoPE(x · W_Q^h + b_Q^h, p_start)", "q^h ∈ R^{T×D_h}")
  else
    c.step_bind("q^h", "RoPE(x · W_Q^h, p_start)", "q^h ∈ R^{T×D_h}")
  end
  c.step_bind("S^h", "CausalMask(q^h · (k^j)^⊤ / √D_h)", "S^h ∈ R^{T×T}")
  c.step_bind("o^h", "softmax_rows(S^h) · v^j", "o^h ∈ R^{T×D_h}")
  c.step_loop_close
  c.step_bind("y", "concat(o^1, …, o^H) · W_O", "y ∈ R^{T×D}")
  c.step_return("y")
  c
end
```
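Expanded row by row, the CausalMask and softmax_rows steps of the card come to the following, where q^h_s, k^j_t and v^j_t denote rows s and t of the corresponding T × D_h matrices (1-based, as in the card). This assumes CausalMask blocks future positions by setting their scores to −∞ before the softmax, which is the usual construction; Toy.causal_mask! itself is not shown on this page.

```latex
S^h_{st} =
\begin{cases}
  \dfrac{q^h_s \cdot k^j_t}{\sqrt{D_h}}, & t \le s,\\[4pt]
  -\infty, & t > s,
\end{cases}
\qquad
o^h_s = \sum_{t=1}^{s} \frac{\exp\left(S^h_{st}\right)}{\sum_{u=1}^{s} \exp\left(S^h_{su}\right)}\, v^j_t,
\qquad
j = \left\lfloor \frac{h-1}{g} \right\rfloor + 1
```

The 1-based j here matches the 0-based `grp = hq / @group_size` in #forward once both indices are shifted by one.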
#algorithm_card ⇒ Object
```ruby
# File 'lib/toy.rb', line 734
def algorithm_card; algorithm.render_pseudocode; end
```
#attend(x, hq, k_h, v_h, pos_start) ⇒ Object
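One query-head attention. Projects x to the query for head hq (adding b_q[hq] when biases are enabled), rotates it with RoPE starting at pos_start, and attends over k_h and v_h, the already-projected key (already rotated) and value of the KV head that hq shares. x: [T, D] → [T, Dh].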
```ruby
# File 'lib/toy.rb', line 660
def attend(x, hq, k_h, v_h, pos_start)
  q_h = x.matmul(@w_q[hq]) # [T, Dh]
  if @has_qkv_bias
    Toy.add_bias!(q_h, @b_q[hq])
  end
  @rope.rotate!(q_h, pos_start)
  scores = q_h.matmul_t(k_h) # [T, T]
  scores.scale!(@inv_sqrt)
  Toy.causal_mask!(scores)
  Toy.softmax_rows!(scores)
  scores.matmul(v_h) # [T, Dh]
end
```
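The score, mask and softmax steps of #attend can be mirrored on plain Ruby arrays. `matmul_t` and `attend_head` here are hypothetical helpers over Arrays of rows, not Toy's Mat methods; the query is assumed to be already projected and rotated, and the causal mask is applied by simply leaving out future columns rather than writing −∞ into them.

```ruby
# a: [T, d], b: [T, d] -> a * b^T, shape [T, T]
def matmul_t(a, b)
  a.map { |ra| b.map { |rb| ra.zip(rb).sum { |x, y| x * y } } }
end

# One causal attention head: scale, mask, row softmax, weight V.
def attend_head(q, k, v)
  inv_sqrt = 1.0 / Math.sqrt(q[0].length)
  scores = matmul_t(q, k)
  scores.each_with_index.map do |row, s|
    # Causal mask: query row s only sees key columns t <= s.
    visible = row[0..s].map { |x| x * inv_sqrt }
    m = visible.max # subtract the row max for a numerically stable softmax
    exps = visible.map { |x| Math.exp(x - m) }
    z = exps.sum
    weights = exps.map { |e| e / z }
    # Weighted sum of the visible value rows gives one output row.
    (0...v[0].length).map do |c|
      weights.each_with_index.sum { |w, t| w * v[t][c] }
    end
  end
end

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = attend_head(q, k, v)
p(out[0]) # row 0 can only attend to itself: prints [1.0, 2.0]
```

Each output row is a convex combination of the value rows at or before it, so changing a later value row never changes an earlier output row.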
#enable_qkv_bias! ⇒ Object
Called by the GGUF loader when attn_q.bias / attn_k.bias / attn_v.bias are present. Biases are already allocated; this just flips the flag.
```ruby
# File 'lib/toy.rb', line 610
def enable_qkv_bias!
  @has_qkv_bias = true
end
```
#forward(x, pos_start) ⇒ Object
x: [T, D] → [T, D]. pos_start: absolute position of row 0 of x.
```ruby
# File 'lib/toy.rb', line 615
def forward(x, pos_start)
  # 1) project + rotate K, V once per KV head (n_kv times).
  k0 = x.matmul(@w_k[0]) # [T, Dh]
  if @has_qkv_bias
    Toy.add_bias!(k0, @b_k[0])
  end
  @rope.rotate!(k0, pos_start)
  v0 = x.matmul(@w_v[0]) # [T, Dh] (V is not rotated)
  if @has_qkv_bias
    Toy.add_bias!(v0, @b_v[0])
  end
  ks = [k0]
  vs = [v0]
  hkv = 1
  while hkv < @n_kv
    k_h = x.matmul(@w_k[hkv])
    if @has_qkv_bias
      Toy.add_bias!(k_h, @b_k[hkv])
    end
    @rope.rotate!(k_h, pos_start)
    v_h = x.matmul(@w_v[hkv])
    if @has_qkv_bias
      Toy.add_bias!(v_h, @b_v[hkv])
    end
    ks.push(k_h)
    vs.push(v_h)
    hkv += 1
  end
  # 2) per query head: project Q, rotate, attend with the
  #    corresponding (shared) K, V.
  head0 = attend(x, 0, ks[0], vs[0], pos_start)
  heads = [head0]
  hq = 1
  while hq < @n_heads
    grp = hq / @group_size
    heads.push(attend(x, hq, ks[grp], vs[grp], pos_start))
    hq += 1
  end
  concat = Toy.hstack_heads(heads, @n_heads, @d_head, @d_model) # [T, D]
  concat.matmul(@w_o) # [T, D]
end
```
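The data flow of #forward (K/V projected once per KV head, reused by every query head in the group, head outputs concatenated, then W_O) can be sketched on plain Ruby arrays. All names here are hypothetical and not Toy's API; RoPE and biases are omitted, and the column-wise concatenation is an assumption about what Toy.hstack_heads does, based on its [T, D] result.

```ruby
# Plain-Ruby sketch of the GQA forward data flow (no RoPE, no biases).
def matmul(a, b) # [m, n] x [n, p] -> [m, p]
  a.map { |row| (0...b[0].length).map { |c| row.each_with_index.sum { |x, i| x * b[i][c] } } }
end

def causal_attend(q, k, v)
  inv_sqrt = 1.0 / Math.sqrt(q[0].length)
  q.each_with_index.map do |qs, s|
    logits = (0..s).map { |t| qs.zip(k[t]).sum { |x, y| x * y } * inv_sqrt }
    m = logits.max
    exps = logits.map { |l| Math.exp(l - m) }
    z = exps.sum
    (0...v[0].length).map { |c| (0..s).sum { |t| exps[t] / z * v[t][c] } }
  end
end

def gqa_forward(x, w_q, w_k, w_v, w_o)
  group_size = w_q.length / w_k.length
  # 1) K and V computed once per KV head.
  ks = w_k.map { |w| matmul(x, w) }
  vs = w_v.map { |w| matmul(x, w) }
  # 2) Each query head attends with the K/V of its group.
  heads = (0...w_q.length).map do |hq|
    grp = hq / group_size
    causal_attend(matmul(x, w_q[hq]), ks[grp], vs[grp])
  end
  # 3) Concatenate head outputs along columns, then project with W_O.
  concat = x.each_index.map { |t| heads.flat_map { |h| h[t] } }
  matmul(concat, w_o)
end

# Deterministic toy weights: D = 4, two query heads sharing one KV head, d_head = 2.
def fill(rows, cols, seed)
  (0...rows).map { |i| (0...cols).map { |j| Math.sin(seed + 3 * i + j) } }
end

x = fill(3, 4, 0.0)
w_q = [fill(4, 2, 1.0), fill(4, 2, 2.0)]
w_k = [fill(4, 2, 3.0)]
w_v = [fill(4, 2, 4.0)]
w_o = fill(4, 4, 5.0)
y = gqa_forward(x, w_q, w_k, w_v, w_o)
p([y.length, y[0].length]) # prints [3, 4]
```

Passing two identical copies of w_k and w_v (MHA layout) gives the same output as sharing one, which is exactly the saving GQA exploits: the duplicate K/V projections are never computed.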
#param_count ⇒ Object
```ruby
# File 'lib/toy.rb', line 678
def param_count
  # Q: n_heads × (d_model × d_head). K/V: n_kv × (d_model × d_head). O: d_model²
  # Plus Q/K/V biases when enabled (Qwen2.x).
  n = @n_heads * @d_model * @d_head + 2 * @n_kv * @d_model * @d_head + @d_model * @d_model
  if @has_qkv_bias
    n = n + @n_heads * @d_head + 2 * @n_kv * @d_head
  end
  n
end
```
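The same count as a free function, useful for checking a configuration without building Mats. `gqa_param_count` is a hypothetical helper; the model sizes in the example calls (d_model 576 with 9/3 heads, d_model 896 with 14/2 heads) are assumed published values, as above.

```ruby
# Same arithmetic as #param_count, for one attention layer.
def gqa_param_count(d_model, n_heads, n_kv, qkv_bias: false)
  d_head = d_model / n_heads
  n = n_heads * d_model * d_head +   # W_Q: one D x D_h block per query head
      2 * n_kv * d_model * d_head +  # W_K and W_V: one block each per KV head
      d_model * d_model              # W_O
  n += n_heads * d_head + 2 * n_kv * d_head if qkv_bias # b_Q, b_K, b_V
  n
end

puts gqa_param_count(576, 9, 3)                  # prints 884736
puts gqa_param_count(896, 14, 2, qkv_bias: true) # prints 1836160
```

For the 576-wide, 9/3-head case the K/V term is 221,184 parameters; with n_kv = n_heads it would be 663,552, a factor of group_size = 3 larger.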
#summary ⇒ Object
```ruby
# File 'lib/toy.rb', line 673
def summary
  "GQAttention(d=" + @d_model.to_s + ", n_q=" + @n_heads.to_s + ", n_kv=" + @n_kv.to_s +
    ", d_head=" + @d_head.to_s + ", group=" + @group_size.to_s + ")"
end
```