Class: Toy::CausalSelfAttention
- Inherits:
-
Object
- Object
- Toy::CausalSelfAttention
- Defined in:
- lib/toy.rb
Overview
Toy::CausalSelfAttention
Per-head storage (Array<Mat> for Q/K/V) matches lib/transformer.rb's
Block layout, so the GGUF loader's split-heads helper works.
attn(x) =
q_h = (x · W_q^h + b_q^h) for h in 0..H-1
k_h = (x · W_k^h + b_k^h)
v_h = (x · W_v^h + b_v^h)
s_h = softmax(causal(q_h · k_h^T / sqrt(Dh))) · v_h
y = hstack(s_0..s_{H-1}) · W_o + b_o
Instance Attribute Summary collapse
-
#b_k ⇒ Object
Returns the value of attribute b_k.
-
#b_o ⇒ Object
Returns the value of attribute b_o.
-
#b_q ⇒ Object
Returns the value of attribute b_q.
-
#b_v ⇒ Object
Returns the value of attribute b_v.
-
#d_head ⇒ Object
Returns the value of attribute d_head.
-
#d_model ⇒ Object
Returns the value of attribute d_model.
-
#inv_sqrt ⇒ Object
Returns the value of attribute inv_sqrt.
-
#n_heads ⇒ Object
Returns the value of attribute n_heads.
-
#w_k ⇒ Object
Returns the value of attribute w_k.
-
#w_o ⇒ Object
Returns the value of attribute w_o.
-
#w_q ⇒ Object
Returns the value of attribute w_q.
-
#w_v ⇒ Object
Returns the value of attribute w_v.
Instance Method Summary collapse
-
#algorithm ⇒ Object
Algorithm card.
- #algorithm_card ⇒ Object
-
#forward(x) ⇒ Object
x: [T, D] → [T, D].
-
#head(x, h) ⇒ Object
One attention head.
-
#initialize(d_model, n_heads) ⇒ CausalSelfAttention
constructor
A new instance of CausalSelfAttention.
- #param_count ⇒ Object
- #summary ⇒ Object
Constructor Details
#initialize(d_model, n_heads) ⇒ CausalSelfAttention
Returns a new instance of CausalSelfAttention.
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
# File 'lib/toy.rb', line 195
# Allocates the per-head Q/K/V projections and the shared output projection.
#
# @param d_model model width (reported as "D" by #algorithm)
# @param n_heads number of attention heads (reported as "H" by #algorithm)
#
# NOTE(review): d_head is computed as d_model / n_heads; with Integer
# arguments that division truncates, so d_model is presumably expected to be
# a multiple of n_heads -- confirm against callers.
def initialize(d_model, n_heads)
  @d_model = d_model
  @n_heads = n_heads
  @d_head = d_model / n_heads
  @inv_sqrt = 1.0 / Math.sqrt(@d_head)

  # Head 0 is written as an array literal and the other heads are pushed
  # afterwards, so Spinel types the containers as PtrArray of Mat /
  # PtrArray of FloatArray.
  @w_q = [Mat.new(d_model, @d_head)]
  @w_k = [Mat.new(d_model, @d_head)]
  @w_v = [Mat.new(d_model, @d_head)]
  @b_q = [Array.new(@d_head, 0.0)]
  @b_k = [Array.new(@d_head, 0.0)]
  @b_v = [Array.new(@d_head, 0.0)]

  # Heads 1..n_heads-1, allocated in the same Q, K, V order as head 0.
  remaining = n_heads - 1
  while remaining > 0
    @w_q.push(Mat.new(d_model, @d_head))
    @w_k.push(Mat.new(d_model, @d_head))
    @w_v.push(Mat.new(d_model, @d_head))
    @b_q.push(Array.new(@d_head, 0.0))
    @b_k.push(Array.new(@d_head, 0.0))
    @b_v.push(Array.new(@d_head, 0.0))
    remaining -= 1
  end

  # Output projection applied to the concatenated heads.
  @w_o = Mat.new(d_model, d_model)
  @b_o = Array.new(d_model, 0.0)
end
Instance Attribute Details
#b_k ⇒ Object
Returns the value of attribute b_k.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head key biases: one Array per head, each created by the constructor
# as d_head zeros.
def b_k
  @b_k
end
#b_o ⇒ Object
Returns the value of attribute b_o.
192 193 194 |
# File 'lib/toy.rb', line 192
# Output-projection bias: an Array created by the constructor as d_model
# zeros.
def b_o
  @b_o
end
#b_q ⇒ Object
Returns the value of attribute b_q.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head query biases: one Array per head, each created by the constructor
# as d_head zeros.
def b_q
  @b_q
end
#b_v ⇒ Object
Returns the value of attribute b_v.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head value biases: one Array per head, each created by the constructor
# as d_head zeros.
def b_v
  @b_v
end
#d_head ⇒ Object
Returns the value of attribute d_head.
192 193 194 |
# File 'lib/toy.rb', line 192
# Width of a single head, stored by the constructor as d_model / n_heads.
def d_head
  @d_head
end
#d_model ⇒ Object
Returns the value of attribute d_model.
192 193 194 |
# File 'lib/toy.rb', line 192
# Model width, exactly as passed to the constructor.
def d_model
  @d_model
end
#inv_sqrt ⇒ Object
Returns the value of attribute inv_sqrt.
192 193 194 |
# File 'lib/toy.rb', line 192
# Score scale factor, stored by the constructor as 1.0 / Math.sqrt(d_head)
# and applied to the raw attention scores in #head.
def inv_sqrt
  @inv_sqrt
end
#n_heads ⇒ Object
Returns the value of attribute n_heads.
192 193 194 |
# File 'lib/toy.rb', line 192
# Number of attention heads, exactly as passed to the constructor.
def n_heads
  @n_heads
end
#w_k ⇒ Object
Returns the value of attribute w_k.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head key projection weights: an Array holding one
# Mat.new(d_model, d_head) per head.
def w_k
  @w_k
end
#w_o ⇒ Object
Returns the value of attribute w_o.
192 193 194 |
# File 'lib/toy.rb', line 192
# Output projection weights, created by the constructor as
# Mat.new(d_model, d_model).
def w_o
  @w_o
end
#w_q ⇒ Object
Returns the value of attribute w_q.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head query projection weights: an Array holding one
# Mat.new(d_model, d_head) per head.
def w_q
  @w_q
end
#w_v ⇒ Object
Returns the value of attribute w_v.
192 193 194 |
# File 'lib/toy.rb', line 192
# Per-head value projection weights: an Array holding one
# Mat.new(d_model, d_head) per head.
def w_v
  @w_v
end
Instance Method Details
#algorithm ⇒ Object
Algorithm card. Shapes: x ∈ R^{T×D}; D_h = D/H.
265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
# File 'lib/toy.rb', line 265
# Builds the Toy::Card describing CausalSelfAttention.forward(x): inputs,
# outputs, hyper-parameters, parameters, and the per-head steps, in that
# order. Returns the populated card (rendering is left to the caller).
def algorithm
  model_width = @d_model.to_s
  head_count = @n_heads.to_s
  head_width = @d_head.to_s

  card = Toy::Card.new("CausalSelfAttention.forward(x)", "")

  # Signature.
  card.add_input("x", "R^{T×D}", "")
  card.add_output("y", "R^{T×D}", "")

  # Hyper-parameters take their values from this instance.
  card.add_hyper("D", model_width)
  card.add_hyper("H", head_count)
  card.add_hyper("D_h", head_width)

  # Learned parameters.
  card.add_param("W_Q^h, W_K^h, W_V^h", "R^{D×D_h}", "")
  card.add_param("b_Q^h, b_K^h, b_V^h", "R^{D_h}", "")
  card.add_param("W_O", "R^{D×D}", "")
  card.add_param("b_O", "R^{D}", "")

  # Per-head body: projections, scaled scores, causal mask, softmax, mix.
  card.step_loop("h ← 1, …, H", "per head")
  card.step_bind("q^h", "x · W_Q^h + b_Q^h", "q^h ∈ R^{T×D_h}")
  card.step_bind("k^h", "x · W_K^h + b_K^h", "k^h ∈ R^{T×D_h}")
  card.step_bind("v^h", "x · W_V^h + b_V^h", "v^h ∈ R^{T×D_h}")
  card.step_bind("S^h", "q^h · (k^h)^⊤ / √D_h", "S^h ∈ R^{T×T}")
  card.step_update("S^h", "CausalMask(S^h)", "", "j>i ↦ −∞")
  card.step_bind("A^h", "softmax_rows(S^h)", "A^h ∈ R^{T×T}")
  card.step_bind("o^h", "A^h · v^h", "o^h ∈ R^{T×D_h}")
  card.step_loop_close

  # Merge the heads and project back to the model width.
  card.step_bind("y", "concat(o^1, …, o^H) · W_O + b_O", "y ∈ R^{T×D}")
  card.step_return("y")
  card
end
#algorithm_card ⇒ Object
290 |
# File 'lib/toy.rb', line 290
# Convenience wrapper: builds the card via #algorithm and returns the result
# of its render_pseudocode call.
def algorithm_card
  card = algorithm
  card.render_pseudocode
end
#forward(x) ⇒ Object
x: [T, D] → [T, D]
224 225 226 227 228 229 230 231 232 233 234 235 236 |
# File 'lib/toy.rb', line 224
# Full attention pass. x: [T, D] -> [T, D].
#
# Evaluates every head on x, stacks the head outputs with
# Toy.hstack_heads, multiplies by W_o and adds b_o in place.
def forward(x)
  # Seed the list with head 0 as a literal, then push the remaining heads.
  head_outputs = [head(x, 0)]              # each entry is [T, Dh]
  idx = 1
  while idx < @n_heads
    head_outputs.push(head(x, idx))
    idx += 1
  end

  # [T, D] after stacking; still [T, D] after the output projection.
  y = Toy.hstack_heads(head_outputs, @n_heads, @d_head, @d_model).matmul(@w_o)
  Toy.add_bias!(y, @b_o)
  y
end
#head(x, h) ⇒ Object
One attention head. x: [T, D] → [T, Dh].
239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
# File 'lib/toy.rb', line 239
# Evaluates attention head +h+ on +x+. x: [T, D] -> [T, Dh].
#
# Projects x with head h's Q/K/V weights and biases, forms q * k^T scaled by
# @inv_sqrt, applies the causal mask and a row-wise softmax in place, then
# multiplies the resulting weights by v.
def head(x, h)
  queries = x.matmul(@w_q[h])
  Toy.add_bias!(queries, @b_q[h])          # [T, Dh]

  keys = x.matmul(@w_k[h])
  Toy.add_bias!(keys, @b_k[h])             # [T, Dh]

  values = x.matmul(@w_v[h])
  Toy.add_bias!(values, @b_v[h])           # [T, Dh]

  weights = queries.matmul_t(keys)         # [T, T]
  weights.scale!(@inv_sqrt)
  Toy.causal_mask!(weights)
  Toy.softmax_rows!(weights)

  weights.matmul(values)                   # [T, Dh]
end
#param_count ⇒ Object
258 259 260 261 262 |
# File 'lib/toy.rb', line 258
# Total number of scalar parameters held by this layer.
#
# Each head has three projections (Q, K, V), each with a d_model x d_head
# weight and a d_head bias; the output projection adds a d_model x d_model
# weight and a d_model bias.
def param_count
  one_projection = @d_model * @d_head + @d_head
  qkv_per_head = one_projection * 3
  output_projection = @d_model * @d_model + @d_model
  qkv_per_head * @n_heads + output_projection
end
#summary ⇒ Object
254 255 256 257 |
# File 'lib/toy.rb', line 254
# One-line description of the layer's configuration, e.g.
# "CausalSelfAttention(d_model=8, heads=2, d_head=4)".
def summary
  model_part = "CausalSelfAttention(d_model=" + @d_model.to_s
  heads_part = ", heads=" + @n_heads.to_s
  width_part = ", d_head=" + @d_head.to_s
  model_part + heads_part + width_part + ")"
end