Class: Toy::GQAttention

Inherits:
Object
  • Object
show all
Defined in:
lib/toy.rb

Overview

Toy::GQAttention — grouped-query causal self-attention.

n_heads query heads share n_kv key/value heads (n_heads / n_kv query heads per KV head, so n_heads must be a multiple of n_kv). Used by SmolLM2 (9/3), TinyLlama (32/4), Qwen2.5-0.5B (14/2). When n_heads == n_kv this degenerates to standard MHA.

RoPE is applied to Q and K before the dot product; V is not rotated. `forward` takes two arguments, `(x, pos_start)`, because RoPE depends on each row's absolute position.

Q/K/V biases are pre-allocated as zeros but only applied when `has_qkv_bias` is true. SmolLM2 / Llama / TinyLlama: false (Llama convention). Qwen2 / Qwen2.5: true (Q/K/V have learned biases; O does not). The loader flips the flag when it finds attn_q.bias / attn_k.bias / attn_v.bias in the GGUF. Pre-allocating keeps the ivar types stable for Spinel (no reassignment after construction).

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(d_model, n_heads, n_kv, rope_obj) ⇒ GQAttention

Returns a new instance of GQAttention.



574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
# File 'lib/toy.rb', line 574

# Build a grouped-query attention layer with every weight and bias
# allocated up front (weights via Mat.new, biases as zero-filled Arrays).
#
# d_model  - model width D. Must be a multiple of n_heads; d_head = D / H.
# n_heads  - number of query heads H (> 0).
# n_kv     - number of shared K/V heads H_kv (> 0). Must divide n_heads;
#            group_size = H / H_kv query heads read each K/V head.
# rope_obj - rotary-embedding object; forward/attend call
#            rope_obj.rotate!(matrix, pos_start) on Q and K.
#
# Raises ArgumentError for shapes that integer division would otherwise
# silently truncate:
# - n_heads % n_kv != 0: forward would compute a K/V head index >= n_kv
#   (i.e. nil) for the last query heads.
# - d_model % n_heads != 0: n_heads * d_head < d_model, so the concatenated
#   heads would not match W_O (Mat.new(d_model, d_model)).
def initialize(d_model, n_heads, n_kv, rope_obj)
  # Check positivity first so the modulo checks below cannot divide by zero.
  if n_heads <= 0 || n_kv <= 0
    raise ArgumentError, "GQAttention: n_heads and n_kv must be positive, got n_heads=" +
      n_heads.to_s + ", n_kv=" + n_kv.to_s
  end
  if n_heads % n_kv != 0
    raise ArgumentError, "GQAttention: n_heads (" + n_heads.to_s +
      ") must be a multiple of n_kv (" + n_kv.to_s + ")"
  end
  if d_model % n_heads != 0
    raise ArgumentError, "GQAttention: d_model (" + d_model.to_s +
      ") must be a multiple of n_heads (" + n_heads.to_s + ")"
  end

  @d_model    = d_model
  @n_heads    = n_heads
  @n_kv       = n_kv
  @d_head     = d_model / n_heads          # exact after the check above
  @group_size = n_heads / n_kv             # exact after the check above
  @inv_sqrt   = 1.0 / Math.sqrt(@d_head)   # attention-score scale
  @rope       = rope_obj

  # One Q projection and zero bias per query head. The arrays are seeded with
  # element 0 instead of starting empty, matching the rest of this class.
  @w_q = [Mat.new(d_model, @d_head)]
  @b_q = [Array.new(@d_head, 0.0)]   # per Q head (zeros until enabled)
  hq = 1
  while hq < n_heads
    @w_q.push(Mat.new(d_model, @d_head))
    @b_q.push(Array.new(@d_head, 0.0))
    hq += 1
  end

  # One K and one V projection (plus zero biases) per shared KV head.
  @w_k = [Mat.new(d_model, @d_head)]
  @w_v = [Mat.new(d_model, @d_head)]
  @b_k = [Array.new(@d_head, 0.0)]   # per KV head
  @b_v = [Array.new(@d_head, 0.0)]
  hkv = 1
  while hkv < n_kv
    @w_k.push(Mat.new(d_model, @d_head))
    @w_v.push(Mat.new(d_model, @d_head))
    @b_k.push(Array.new(@d_head, 0.0))
    @b_v.push(Array.new(@d_head, 0.0))
    hkv += 1
  end

  @w_o = Mat.new(d_model, d_model)
  @has_qkv_bias = false               # flipped by the loader for Qwen2.x
end

Instance Attribute Details

#b_k ⇒ Object

Returns the value of attribute b_k.



569
570
571
# File 'lib/toy.rb', line 569

# Per-KV-head K bias vectors: n_kv Arrays of d_head Floats, zero-filled by
# the constructor and only added when has_qkv_bias is true.
def b_k
  return @b_k
end

#b_q ⇒ Object

Returns the value of attribute b_q.



569
570
571
# File 'lib/toy.rb', line 569

# Per-query-head Q bias vectors: n_heads Arrays of d_head Floats,
# zero-filled by the constructor and only added when has_qkv_bias is true.
def b_q
  return @b_q
end

#b_v ⇒ Object

Returns the value of attribute b_v.



569
570
571
# File 'lib/toy.rb', line 569

# Per-KV-head V bias vectors: n_kv Arrays of d_head Floats, zero-filled by
# the constructor and only added when has_qkv_bias is true.
def b_v
  return @b_v
end

#d_head ⇒ Object

Returns the value of attribute d_head.



569
570
571
# File 'lib/toy.rb', line 569

# Width of each attention head: d_model / n_heads (integer division).
def d_head
  return @d_head
end

#d_model ⇒ Object

Returns the value of attribute d_model.



569
570
571
# File 'lib/toy.rb', line 569

# Model width D, as passed to the constructor.
def d_model
  return @d_model
end

#group_size ⇒ Object

Returns the value of attribute group_size.



569
570
571
# File 'lib/toy.rb', line 569

# Number of query heads sharing one K/V head: n_heads / n_kv.
def group_size
  return @group_size
end

#has_qkv_bias ⇒ Object

Returns the value of attribute has_qkv_bias.



569
570
571
# File 'lib/toy.rb', line 569

# Whether the Q/K/V biases are added during projection. The constructor sets
# it to false; enable_qkv_bias! sets it to true.
def has_qkv_bias
  return @has_qkv_bias
end

#inv_sqrt ⇒ Object

Returns the value of attribute inv_sqrt.



569
570
571
# File 'lib/toy.rb', line 569

# Attention-score scale 1.0 / sqrt(d_head), applied to Q·K^T in attend.
def inv_sqrt
  return @inv_sqrt
end

#n_heads ⇒ Object

Returns the value of attribute n_heads.



569
570
571
# File 'lib/toy.rb', line 569

# Number of query heads H.
def n_heads
  return @n_heads
end

#n_kv ⇒ Object

Returns the value of attribute n_kv.



569
570
571
# File 'lib/toy.rb', line 569

# Number of shared key/value heads H_kv.
def n_kv
  return @n_kv
end

#rope ⇒ Object

Returns the value of attribute rope.



569
570
571
# File 'lib/toy.rb', line 569

# The rotary-embedding object given to the constructor as rope_obj; its
# rotate!(matrix, pos_start) is applied to Q and K.
def rope
  return @rope
end

#w_k ⇒ Object

Returns the value of attribute w_k.



569
570
571
# File 'lib/toy.rb', line 569

# Per-KV-head K projection matrices (n_kv entries, each Mat.new(d_model, d_head)).
def w_k
  return @w_k
end

#w_o ⇒ Object

Returns the value of attribute w_o.



569
570
571
# File 'lib/toy.rb', line 569

# Output projection W_O, a single Mat.new(d_model, d_model) with no bias.
def w_o
  return @w_o
end

#w_q ⇒ Object

Returns the value of attribute w_q.



569
570
571
# File 'lib/toy.rb', line 569

# Per-query-head Q projection matrices (n_heads entries, each Mat.new(d_model, d_head)).
def w_q
  return @w_q
end

#w_v ⇒ Object

Returns the value of attribute w_v.



569
570
571
# File 'lib/toy.rb', line 569

# Per-KV-head V projection matrices (n_kv entries, each Mat.new(d_model, d_head)).
def w_v
  return @w_v
end

Instance Method Details

#algorithm ⇒ Object



690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
# File 'lib/toy.rb', line 690

# Describe forward() as a Toy::Card: inputs/outputs, hyperparameters from
# this instance, parameter shapes, and the step-by-step pseudocode.
# The bias terms appear in the parameter list and formulas only when
# has_qkv_bias is true. Returns the populated card.
def algorithm
  biased = @has_qkv_bias
  # Suffixes spliced into the projection formulas; empty strings give the
  # bias-free (Llama-convention) formulas.
  q_bias = biased ? " + b_Q^h" : ""
  k_bias = biased ? " + b_K^j" : ""
  v_bias = biased ? " + b_V^j" : ""

  card = Toy::Card.new("GQAttention.forward(x, p_start)", "grouped-query + RoPE")

  # Interface.
  card.add_input("x", "R^{T×D}", "")
  card.add_input("p_start", "", "")
  card.add_output("y", "R^{T×D}", "")

  # Hyperparameters taken from this instance.
  card.add_hyper("D", @d_model.to_s)
  card.add_hyper("H", @n_heads.to_s)
  card.add_hyper("H_kv", @n_kv.to_s)
  card.add_hyper("g", @group_size.to_s)
  card.add_hyper("D_h", @d_head.to_s)

  # Learned parameters.
  card.add_param("W_Q^h", "R^{D×D_h}", "per query head (h=1..H)")
  card.add_param("W_K^j, W_V^j", "R^{D×D_h}",
                 "per KV head (j=1..H_kv); shared across g Q heads")
  if biased
    card.add_param("b_Q^h", "R^{D_h}", "per query head")
    card.add_param("b_K^j, b_V^j", "R^{D_h}", "per KV head — Qwen2.x convention")
    card.add_param("W_O", "R^{D×D}", "no output bias")
  else
    card.add_param("W_O", "R^{D×D}", "no biases — Llama convention")
  end

  # K/V are projected once per KV head.
  card.step_loop("j ← 1, …, H_kv", "KV computed once per group")
  card.step_bind("k^j", "RoPE(x · W_K^j" + k_bias + ", p_start)", "k^j ∈ R^{T×D_h}")
  card.step_update("v^j", "x · W_V^j" + v_bias, "v^j ∈ R^{T×D_h}", "V not rotated")
  card.step_loop_close

  # Each query head attends with its group's shared K/V.
  card.step_loop("h ← 1, …, H", "per query head")
  card.step_bind("j", "⌊(h−1) / g⌋ + 1", "")
  card.step_bind("q^h", "RoPE(x · W_Q^h" + q_bias + ", p_start)", "q^h ∈ R^{T×D_h}")
  card.step_bind("S^h", "CausalMask(q^h · (k^j)^⊤ / √D_h)", "S^h ∈ R^{T×T}")
  card.step_bind("o^h", "softmax_rows(S^h) · v^j", "o^h ∈ R^{T×D_h}")
  card.step_loop_close

  card.step_bind("y", "concat(o^1, …, o^H) · W_O", "y ∈ R^{T×D}")
  card.step_return("y")
  card
end

#algorithm_card ⇒ Object



734
# File 'lib/toy.rb', line 734

# Build the algorithm card and return its pseudocode rendering.
def algorithm_card
  card = algorithm
  card.render_pseudocode
end

#attend(x, hq, k_h, v_h, pos_start) ⇒ Object

One query-head attention. x: [T, D] → [T, Dh].



660
661
662
663
664
665
666
667
668
669
670
671
# File 'lib/toy.rb', line 660

# Causal attention output for one query head. x: [T, D] → [T, Dh].
#
# x         - input activations.
# hq        - 0-based query-head index into @w_q / @b_q.
# k_h, v_h  - the already-projected K (RoPE-rotated) and V of the KV head
#             serving hq.
# pos_start - absolute position of row 0 of x, passed to RoPE for Q.
def attend(x, hq, k_h, v_h, pos_start)
  query = x.matmul(@w_q[hq])                  # [T, Dh]
  Toy.add_bias!(query, @b_q[hq]) if @has_qkv_bias
  @rope.rotate!(query, pos_start)

  # Scores [T, T]: scaled by 1/sqrt(Dh), causally masked, then row-softmaxed
  # in place into attention weights.
  weights = query.matmul_t(k_h)
  weights.scale!(@inv_sqrt)
  Toy.causal_mask!(weights)
  Toy.softmax_rows!(weights)

  weights.matmul(v_h)                         # [T, Dh]
end

#enable_qkv_bias! ⇒ Object

Called by the GGUF loader when attn_q.bias / attn_k.bias / attn_v.bias are present. Biases are already allocated; this just flips the flag.



610
611
612
# File 'lib/toy.rb', line 610

# Turn on the Q/K/V biases. The bias arrays already exist (zero-filled by
# the constructor); only the flag changes. Returns true.
def enable_qkv_bias!
  @has_qkv_bias = true
  @has_qkv_bias
end

#forward(x, pos_start) ⇒ Object

x: [T, D] → [T, D]. pos_start: absolute position of row 0 of x.



615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
# File 'lib/toy.rb', line 615

# Causal grouped-query self-attention over x. x: [T, D] → [T, D].
#
# x         - input activations.
# pos_start - absolute position of row 0 of x; passed to RoPE for K (here)
#             and for Q (inside attend).
#
# Returns concat(head outputs) · W_O.
def forward(x, pos_start)
  # Stage 1: project each of the n_kv K/V heads exactly once. The arrays are
  # seeded with head 0 rather than starting empty (presumably for the same
  # Spinel type-stability reason the constructor pre-allocates).
  keys   = [kv_key_head(x, 0, pos_start)]
  values = [kv_value_head(x, 0)]
  j = 1
  while j < @n_kv
    keys.push(kv_key_head(x, j, pos_start))
    values.push(kv_value_head(x, j))
    j += 1
  end

  # Stage 2: query head hq reads the shared K/V of group hq / group_size.
  outputs = [attend(x, 0, keys[0], values[0], pos_start)]
  hq = 1
  while hq < @n_heads
    g = hq / @group_size
    outputs.push(attend(x, hq, keys[g], values[g], pos_start))
    hq += 1
  end

  Toy.hstack_heads(outputs, @n_heads, @d_head, @d_model).matmul(@w_o)   # [T, D]
end

# Helper for forward: K for KV head j = x · W_K^j (+ b_K^j when biases are
# enabled), rotated in place by RoPE starting at pos_start. Returns [T, Dh].
def kv_key_head(x, j, pos_start)
  k = x.matmul(@w_k[j])
  if @has_qkv_bias
    Toy.add_bias!(k, @b_k[j])
  end
  @rope.rotate!(k, pos_start)
  k
end

# Helper for forward: V for KV head j = x · W_V^j (+ b_V^j when biases are
# enabled). V is never rotated. Returns [T, Dh].
def kv_value_head(x, j)
  v = x.matmul(@w_v[j])
  if @has_qkv_bias
    Toy.add_bias!(v, @b_v[j])
  end
  v
end

#param_count ⇒ Object



678
679
680
681
682
683
684
685
686
687
688
# File 'lib/toy.rb', line 678

# Total number of learned scalars in this layer:
#   Q:   n_heads matrices of d_model × d_head
#   K,V: n_kv matrices each of d_model × d_head
#   O:   one d_model × d_model matrix (never biased)
# plus, when Q/K/V biases are enabled, d_head per Q, K and V head.
def param_count
  projected_heads = @n_heads + 2 * @n_kv   # Q heads + K heads + V heads
  total = projected_heads * @d_model * @d_head + @d_model * @d_model
  total += projected_heads * @d_head if @has_qkv_bias
  total
end

#summary ⇒ Object



673
674
675
676
677
# File 'lib/toy.rb', line 673

# One-line description of the layer's shape, e.g.
# "GQAttention(d=576, n_q=9, n_kv=3, d_head=64, group=3)".
def summary
  text = "GQAttention(d=" + @d_model.to_s
  text += ", n_q=" + @n_heads.to_s
  text += ", n_kv=" + @n_kv.to_s
  text += ", d_head=" + @d_head.to_s
  text + ", group=" + @group_size.to_s + ")"
end