Class: Toy::RoPE

Inherits:
Object
  • Object
show all
Defined in:
lib/toy.rb

Overview

Toy::RoPE — rotary position embedding, in the Llama / SmolLM2 / Qwen2 (“rotate_half”) form. Each head vector is split into two halves. Element k of the first half is paired with element k of the second half, and each pair is rotated by an angle proportional to the token's position.

Precomputes cos/sin tables for positions 0...max_seq at construction; #rotate! rotates its input in place.

for freq k in [0, Dh/2):
  theta_k = theta_base^(-2k/Dh)
  angle   = pos * theta_k
  x[k],       x[k+Dh/2]  →
    x[k]      * cos(angle) - x[k+Dh/2] * sin(angle),
    x[k+Dh/2] * cos(angle) + x[k]      * sin(angle)

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(d_head, max_seq, theta_base) ⇒ RoPE

theta_base typical values: 10000 (Llama-1/2/TinyLlama), 100000 (SmolLM2), 1000000 (Qwen2 long-context). Renamed from `base` to avoid Spinel's local-name collapse with the integer offsets named `base`.



477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
# File 'lib/toy.rb', line 477

# Build the cos/sin lookup tables for every (position, frequency) pair.
#
# d_head     - per-head dimension D_h. Rotation pairs are (k, k + D_h/2)
#              for k in [0, D_h/2). With an odd d_head the last component
#              is never paired, so rotate! leaves it untouched.
# max_seq    - number of tabulated positions (valid positions 0...max_seq).
# theta_base - frequency base, e.g. 10000 (Llama-1/2/TinyLlama),
#              100000 (SmolLM2), 1000000 (Qwen2 long-context). Named
#              theta_base rather than base to avoid Spinel's local-name
#              collapse with integer offsets named base.
#
# Both tables are flat, row-major [max_seq, D_h/2]. Entry p * half + k
# holds cos / sin of p * theta_base^(-2k/D_h).
def initialize(d_head, max_seq, theta_base)
  @d_head  = d_head
  @max_seq = max_seq
  half     = d_head / 2
  n        = max_seq * half
  @cos_tbl = Array.new(n, 1.0)
  @sin_tbl = Array.new(n, 0.0)
  log_b    = Math.log(theta_base.to_f)
  inv_dh   = 1.0 / d_head.to_f
  # theta_k = theta_base^(-2k/D_h) depends only on k. Compute it once per
  # frequency instead of once per (p, k) pair; the values are identical.
  rope_freq = Array.new(half, 0.0)
  k = 0
  while k < half
    rope_freq[k] = Math.exp(-2.0 * k.to_f * inv_dh * log_b)
    k += 1
  end
  p = 0
  while p < max_seq
    k = 0
    while k < half
      angle = p.to_f * rope_freq[k]
      @cos_tbl[p * half + k] = Math.cos(angle)
      @sin_tbl[p * half + k] = Math.sin(angle)
      k += 1
    end
    p += 1
  end
end

Instance Attribute Details

#cos_tbl ⇒ Object

Returns the value of attribute cos_tbl.



472
473
474
# File 'lib/toy.rb', line 472

# Flat cosine table built in initialize; entry p * (d_head / 2) + k.
def cos_tbl; @cos_tbl; end

#d_head ⇒ Object

Returns the value of attribute d_head.



472
473
474
# File 'lib/toy.rb', line 472

# Per-head dimension D_h passed to the constructor.
def d_head; @d_head; end

#max_seq ⇒ Object

Returns the value of attribute max_seq.



472
473
474
# File 'lib/toy.rb', line 472

# Number of positions covered by the precomputed tables.
def max_seq; @max_seq; end

#sin_tbl ⇒ Object

Returns the value of attribute sin_tbl.



472
473
474
# File 'lib/toy.rb', line 472

# Flat sine table built in initialize; entry p * (d_head / 2) + k.
def sin_tbl; @sin_tbl; end

Instance Method Details

#algorithm ⇒ Object



529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
# File 'lib/toy.rb', line 529

# Describe rotate! as a Toy::Card (inputs, hyperparameters, loop body).
# Only D_h comes from instance state. theta_base is not kept after
# construction, so its hyperparameter entry just notes that the cos/sin
# tables are precomputed. Returns the populated card.
def algorithm
  card = Toy::Card.new("RoPE.rotate!(x, p_start)", "rotate_half / NEOX form")

  # Signature: what the caller passes in and what is fixed per instance.
  card.add_input("x", "R^{T×D_h}", "one head's Q or K")
  card.add_input("p_start", "", "absolute position of row 0")
  card.add_hyper("D_h", @d_head.to_s)
  card.add_hyper("θ_base", "(cos/sin tables precomputed)")

  # Body: one 2-D rotation per (row, frequency) pair.
  card.step_loop("t ← 0, …, T-1, k ← 0, …, D_h/2 − 1", "")
  card.step_bind("p", "p_start + t", "")
  card.step_bind("c, s", "cos(p · θ_base^{−2k/D_h}), sin(p · θ_base^{−2k/D_h})", "")
  card.step_update("(x[t,k], x[t,k+D_h/2])", "(x[t,k]·c − x[t,k+D_h/2]·s, x[t,k+D_h/2]·c + x[t,k]·s)", "", "")
  card.step_loop_close

  card
end

#algorithm_card ⇒ Object



546
# File 'lib/toy.rb', line 546

# Pseudocode rendering of the #algorithm card (whatever
# Toy::Card#render_pseudocode returns).
def algorithm_card
  card = algorithm
  card.render_pseudocode
end

#param_count ⇒ Object

Always returns 0: the cos/sin tables are precomputed, not learned.



527
# File 'lib/toy.rb', line 527

# RoPE contributes no learned parameters: its cos/sin tables are
# computed in initialize, so the count is always zero.
def param_count
  0
end

#rotate!(x, pos_start) ⇒ Object

x: [T, Dh] rotated in-place. Row t corresponds to absolute position (pos_start + t).



502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
# File 'lib/toy.rb', line 502

# Rotate one head's Q or K rows in place.
#
# x         - matrix exposing #nrows and a row-major #flat array. Row t
#             occupies flat[t * d_head ... (t + 1) * d_head). Assumes
#             x.flat returns the backing array; the original implementation
#             also wrote through it.
# pos_start - absolute position of row 0; row t uses position pos_start + t.
#
# For each row and each k in [0, d_head/2), the pair
# (x[k], x[k + d_head/2]) is rotated by the tabulated angle for that
# position and frequency.
#
# Raises ArgumentError if any row position lies outside the precomputed
# range 0...max_seq. Without this check, a negative position would silently
# wrap through Ruby's negative array indexing, and a position >= max_seq
# would fail with an opaque nil-arithmetic TypeError. An empty x (0 rows)
# is always a no-op.
#
# Returns nil.
def rotate!(x, pos_start)
  t    = x.nrows
  dh   = @d_head
  half = dh / 2
  if t > 0 && (pos_start < 0 || pos_start + t > @max_seq)
    raise ArgumentError, "RoPE.rotate!: positions " + pos_start.to_s + "..." +
                         (pos_start + t).to_s + " outside table 0..." + @max_seq.to_s
  end
  flat = x.flat # loop-invariant: fetch once, write in place
  i = 0
  while i < t
    row = i * dh
    tbl = (pos_start + i) * half # start of this position's table row
    ck = 0
    while ck < half
      co = @cos_tbl[tbl + ck]
      si = @sin_tbl[tbl + ck]
      xa = flat[row + ck]
      xb = flat[row + half + ck]
      flat[row + ck]        = xa * co - xb * si
      flat[row + half + ck] = xb * co + xa * si
      ck += 1
    end
    i += 1
  end
end

#summary ⇒ Object



524
525
526
# File 'lib/toy.rb', line 524

# Short human-readable description, e.g. "RoPE(d_head=64, max_seq=2048)".
def summary
  dh_text = @d_head.to_s
  ms_text = @max_seq.to_s
  "RoPE(d_head=" + dh_text + ", max_seq=" + ms_text + ")"
end