Class: Toy::LLM::Blocks::GDNBlock

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/llm/blocks/gdn_block.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize ⇒ GDNBlock

Returns a new instance of GDNBlock.



44
45
46
47
48
49
50
51
52
53
54
55
# File 'lib/toy/llm/blocks/gdn_block.rb', line 44

# Creates an unallocated block: all dimensions are zero, every tensor handle
# is TinyNN's null pointer, and the three parallel optimizer arrays are empty.
# alloc_trainable_f32_weights! later replaces all of these values.
def initialize
  # Dimensions (set for real by alloc_trainable_f32_weights!).
  @gdn_d_model = 0
  @gdn_s_v = 0
  @gdn_n_heads = 0

  # Gamma handed to RMSNorm at the start of build_forward.
  @t_rn_gamma = TinyNN.tnn_null_ptr

  # Projection weights applied to the normalised input in build_forward.
  @t_w_q = TinyNN.tnn_null_ptr
  @t_w_k = TinyNN.tnn_null_ptr
  @t_w_v = TinyNN.tnn_null_ptr
  @t_w_z = TinyNN.tnn_null_ptr
  @t_w_a = TinyNN.tnn_null_ptr
  @t_w_b = TinyNN.tnn_null_ptr

  # Decay-gate parameters.
  @t_a_log = TinyNN.tnn_null_ptr
  @t_dt_bias = TinyNN.tnn_null_ptr

  # Output-gate gamma and output projection.
  @t_go_gamma = TinyNN.tnn_null_ptr
  @t_w_o = TinyNN.tnn_null_ptr

  # Initial recurrence state (never pushed into ft_weights).
  @t_state0 = TinyNN.tnn_null_ptr

  # Each array is seeded with one null handle that is popped immediately,
  # which leaves it empty.
  # NOTE(review): presumably this pins the element type for a type-inferring
  # toolchain -- confirm before simplifying these to plain [].
  @ft_weights = [TinyNN.tnn_null_ptr]
  @ft_weights.pop
  @ft_m = [TinyNN.tnn_null_ptr]
  @ft_m.pop
  @ft_v = [TinyNN.tnn_null_ptr]
  @ft_v.pop
end

Instance Attribute Details

#ft_m ⇒ Object

Returns the value of attribute ft_m.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @ft_m: an array that starts empty and receives one tensor handle
# per registered weight (the second allocation made by reg1/reg2/reg4), so it
# stays index-aligned with ft_weights.
# NOTE(review): presumably the Adam first-moment buffers -- confirm.
def ft_m
  @ft_m
end

#ft_v ⇒ Object

Returns the value of attribute ft_v.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @ft_v: an array that starts empty and receives one tensor handle
# per registered weight (the third allocation made by reg1/reg2/reg4), so it
# stays index-aligned with ft_weights.
# NOTE(review): presumably the Adam second-moment buffers -- confirm.
def ft_v
  @ft_v
end

#ft_weights ⇒ Object

Returns the value of attribute ft_weights.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @ft_weights: the array of trainable weight handles, in the order
# reg1/reg2/reg4 registered them. set_params! walks this array; state0 is
# never pushed into it.
def ft_weights
  @ft_weights
end

#gdn_d_model ⇒ Object

Returns the value of attribute gdn_d_model.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @gdn_d_model: 0 after construction, then the d_model argument
# given to alloc_trainable_f32_weights!.
def gdn_d_model
  @gdn_d_model
end

#gdn_n_heads ⇒ Object

Returns the value of attribute gdn_n_heads.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @gdn_n_heads: 0 after construction, then the n_heads argument
# given to alloc_trainable_f32_weights!. build_forward uses it as the
# per-head loop bound.
def gdn_n_heads
  @gdn_n_heads
end

#gdn_s_v ⇒ Object

Returns the value of attribute gdn_s_v.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @gdn_s_v: 0 after construction, then the s_v argument given to
# alloc_trainable_f32_weights! (s_v * n_heads is the inner width used there).
def gdn_s_v
  @gdn_s_v
end

#t_a_log ⇒ Object

Returns the value of attribute t_a_log.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_a_log: the null pointer until alloc_trainable_f32_weights!
# registers it via reg4(sess, 1, n_heads, 1, 1). build_forward passes it to
# GDN.decay_gate.
def t_a_log
  @t_a_log
end

#t_dt_bias ⇒ Object

Returns the value of attribute t_dt_bias.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_dt_bias: the null pointer until alloc_trainable_f32_weights!
# registers it via reg4(sess, 1, n_heads, 1, 1). build_forward passes it to
# GDN.decay_gate.
def t_dt_bias
  @t_dt_bias
end

#t_go_gamma ⇒ Object

Returns the value of attribute t_go_gamma.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_go_gamma: the null pointer until alloc_trainable_f32_weights!
# registers it via reg1 with length s_v * n_heads. build_forward passes it to
# GDN.gated_out.
def t_go_gamma
  @t_go_gamma
end

#t_rn_gamma ⇒ Object

Returns the value of attribute t_rn_gamma.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_rn_gamma: the null pointer until alloc_trainable_f32_weights!
# registers it via reg1 with length d_model. build_forward passes it to
# RMSNorm.build as the gamma for the input normalisation.
def t_rn_gamma
  @t_rn_gamma
end

#t_state0 ⇒ Object

Returns the value of attribute t_state0.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_state0: the null pointer until alloc_trainable_f32_weights!
# allocates it with tnn_input_2d_f32_persistent(sess, s_v, s_v * n_heads).
# It is not pushed into ft_weights; zero_state! clears it and build_forward
# takes one s_v x s_v view of it per head.
def t_state0
  @t_state0
end

#t_w_a ⇒ Object

Returns the value of attribute t_w_a.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_a: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, n_heads, d_model). In build_forward its matmul
# output is reshaped and fed to GDN.decay_gate.
def t_w_a
  @t_w_a
end

#t_w_b ⇒ Object

Returns the value of attribute t_w_b.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_b: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, n_heads, d_model). In build_forward its matmul
# output is reshaped and fed to GDN.update_gate_train.
def t_w_b
  @t_w_b
end

#t_w_k ⇒ Object

Returns the value of attribute t_w_k.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_k: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, s_v * n_heads, d_model). In build_forward its
# matmul output is reshaped and passed through GDN.l2_train.
def t_w_k
  @t_w_k
end

#t_w_o ⇒ Object

Returns the value of attribute t_w_o.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_o: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, d_model, s_v * n_heads). build_forward applies
# it to the gated output just before the residual add.
def t_w_o
  @t_w_o
end

#t_w_q ⇒ Object

Returns the value of attribute t_w_q.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_q: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, s_v * n_heads, d_model). In build_forward its
# matmul output is reshaped and passed through GDN.l2_train.
def t_w_q
  @t_w_q
end

#t_w_v ⇒ Object

Returns the value of attribute t_w_v.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_v: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, s_v * n_heads, d_model). In build_forward its
# matmul output is reshaped and handed to GDN.recur_unrolled.
def t_w_v
  @t_w_v
end

#t_w_z ⇒ Object

Returns the value of attribute t_w_z.



31
32
33
# File 'lib/toy/llm/blocks/gdn_block.rb', line 31

# Reader for @t_w_z: the null pointer until alloc_trainable_f32_weights!
# registers it via reg2(sess, s_v * n_heads, d_model). In build_forward its
# matmul output is passed, without reshaping, to GDN.gated_out.
def t_w_z
  @t_w_z
end

Instance Method Details

#alloc_trainable_f32_weights!(sess, d_model, s_v, n_heads) ⇒ Object

Allocate the block’s trainable persistent F32 weights + their Adam moments (parallel ft_weights/ft_m/ft_v arrays, populated in lockstep so the engine / a train loop can opt_step generically). d_model is the residual width; n_heads × s_v = the GDN inner width. state0 is a zeroed [s_v, s_v*n_heads] constant carry (one [s_v,s_v] block per head), NOT a param. Each weight’s m/v match its shape (opt_step_adamw asserts same-shape).



63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# File 'lib/toy/llm/blocks/gdn_block.rb', line 63

# Records the block dimensions and allocates every tensor the block owns.
#
# Trainable weights go through reg1/reg2/reg4, which also allocate the two
# companion tensors and append to ft_weights/ft_m/ft_v. The registration
# order below therefore fixes the order of those arrays.
#
# @param sess    session handle forwarded to every allocation
# @param d_model residual width
# @param s_v     per-head width
# @param n_heads number of heads (inner width is s_v * n_heads)
# @return the state0 handle (value of the final assignment)
def alloc_trainable_f32_weights!(sess, d_model, s_v, n_heads)
  @gdn_d_model = d_model
  @gdn_s_v = s_v
  @gdn_n_heads = n_heads

  inner_width = s_v * n_heads

  # Gamma for the input RMSNorm.
  @t_rn_gamma = reg1(sess, d_model)

  # Projections W : [d_model, out]; matmul(W, h) contracts ne0 = d_model and
  # yields [out, T]. input_2d_f32_persistent(rows, cols) gives ne0 = cols and
  # ne1 = rows, hence the (out, d_model) argument order.
  @t_w_q = reg2(sess, inner_width, d_model)
  @t_w_k = reg2(sess, inner_width, d_model)
  @t_w_v = reg2(sess, inner_width, d_model)
  @t_w_z = reg2(sess, inner_width, d_model)
  @t_w_a = reg2(sess, n_heads, d_model)
  @t_w_b = reg2(sess, n_heads, d_model)

  # Per-head decay-gate parameters.
  @t_a_log = reg4(sess, 1, n_heads, 1, 1)
  @t_dt_bias = reg4(sess, 1, n_heads, 1, 1)

  # Output-gate gamma and the projection back to d_model.
  @t_go_gamma = reg1(sess, inner_width)
  @t_w_o = reg2(sess, d_model, inner_width)

  # Initial state: allocated directly, so it is NOT a trainable param and
  # never enters ft_weights.
  @t_state0 = TinyNN.tnn_input_2d_f32_persistent(sess, s_v, inner_width)
end

#build_forward(sess, t_x, seq_t, eps) ⇒ Object

Forward: residual update for x [d_model, T] (B=1). Returns [d_model, T]. Dims (d_model/s_v/n_heads) come from self (set at alloc) so this matches the seam’s per-layer call shape; seq_t/eps arrive from the forward ctx.



129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# File 'lib/toy/llm/blocks/gdn_block.rb', line 129

# Builds the forward graph for this block and returns the residual sum of the
# input and the block output.
#
# s_v and n_heads are read from the instance (set by
# alloc_trainable_f32_weights!). Shape notes in the comments below follow the
# [ne0, ne1, ...] convention used by the allocation code.
#
# @param sess   session handle forwarded to every TinyNN / primitive call
# @param t_x    input tensor handle, [d_model, T]
# @param seq_t  sequence length T used for the reshapes and the recurrence
# @param eps    epsilon forwarded to RMSNorm, l2_train and gated_out
# @return the handle produced by tnn_add(t_x, out), [d_model, T]
def build_forward(sess, t_x, seq_t, eps)
  s_v     = @gdn_s_v
  n_heads = @gdn_n_heads

  # Byte strides into state0 for the per-head views; they do not depend on
  # the head index, so compute them once. fbytes is the F32 element size.
  fbytes     = 4
  row_bytes  = s_v * fbytes
  head_bytes = s_v * s_v * fbytes

  h = Toy::LLM::Primitives::RMSNorm.build(sess, t_x, @t_rn_gamma, eps)

  q2 = TinyNN.tnn_matmul(sess, @t_w_q, h)   # [S_v*H, T]
  k2 = TinyNN.tnn_matmul(sess, @t_w_k, h)
  v2 = TinyNN.tnn_matmul(sess, @t_w_v, h)
  z2 = TinyNN.tnn_matmul(sess, @t_w_z, h)   # [S_v*H, T] output gate
  a2 = TinyNN.tnn_matmul(sess, @t_w_a, h)   # [H, T] decay stream
  b2 = TinyNN.tnn_matmul(sess, @t_w_b, h)   # [H, T] update stream

  # Reshape projections into the recurrence's packed [S_v, H, T] / [1, H, T].
  # z2 stays 2-D: it is consumed directly by gated_out.
  q3 = TinyNN.tnn_reshape_3d(sess, q2, s_v, n_heads, seq_t)
  k3 = TinyNN.tnn_reshape_3d(sess, k2, s_v, n_heads, seq_t)
  v3 = TinyNN.tnn_reshape_3d(sess, v2, s_v, n_heads, seq_t)
  a3 = TinyNN.tnn_reshape_3d(sess, a2, 1,   n_heads, seq_t)
  b3 = TinyNN.tnn_reshape_3d(sess, b2, 1,   n_heads, seq_t)

  qn = Toy::LLM::Primitives::GDN.l2_train(sess, q3, eps)
  kn = Toy::LLM::Primitives::GDN.l2_train(sess, k3, eps)
  g  = Toy::LLM::Primitives::GDN.decay_gate(sess, a3, @t_dt_bias, @t_a_log)
  bt = Toy::LLM::Primitives::GDN.update_gate_train(sess, b3)

  # Per-head recurrence; concat head outputs along ne0 -> [S_v*H, T].
  # o stays the null pointer only if n_heads is 0.
  o  = TinyNN.tnn_null_ptr
  hh = 0
  while hh < n_heads
    # Head hh's [s_v, s_v] slice of state0, addressed by byte offset.
    st_h = TinyNN.tnn_view_2d(sess, @t_state0, s_v, s_v,
                              row_bytes, hh * head_bytes)
    o_h = Toy::LLM::Primitives::GDN.recur_unrolled(sess, qn, kn, v3, g, bt,
                                                   st_h, s_v, n_heads, hh, seq_t)
    if hh == 0
      o = o_h
    else
      o = TinyNN.tnn_concat(sess, o, o_h, 0)
    end
    hh = hh + 1
  end

  gated = Toy::LLM::Primitives::GDN.gated_out(sess, o, z2, @t_go_gamma, eps)
  out   = TinyNN.tnn_matmul(sess, @t_w_o, gated)   # [d_model, T]
  TinyNN.tnn_add(sess, t_x, out)
end

#reg1(sess, n) ⇒ Object

reg1,2,4: alloc a weight of the given rank + matching m/v, push the triple into ft_weights/ft_m/ft_v, return the weight handle.



86
87
88
89
90
91
92
# File 'lib/toy/llm/blocks/gdn_block.rb', line 86

# Registers one rank-1 trainable weight of length n.
#
# Allocates the weight plus two same-length companion tensors and appends the
# triple to ft_weights / ft_m / ft_v. All three allocations finish before any
# array is touched, so an allocation that raises cannot leave the parallel
# arrays with different lengths.
#
# @param sess session handle forwarded to TinyNN
# @param n    number of elements
# @return the weight handle
def reg1(sess, n)
  weight   = TinyNN.tnn_input_1d_f32_persistent(sess, n)
  moment_m = TinyNN.tnn_input_1d_f32_persistent(sess, n)
  moment_v = TinyNN.tnn_input_1d_f32_persistent(sess, n)
  @ft_weights.push(weight)
  @ft_m.push(moment_m)
  @ft_v.push(moment_v)
  weight
end

#reg2(sess, rows, cols) ⇒ Object



94
95
96
97
98
99
100
# File 'lib/toy/llm/blocks/gdn_block.rb', line 94

# Registers one rank-2 trainable weight.
#
# Allocates the weight plus two same-shape companion tensors and appends the
# triple to ft_weights / ft_m / ft_v. All three allocations finish before any
# array is touched, so an allocation that raises cannot leave the parallel
# arrays with different lengths.
#
# @param sess session handle forwarded to TinyNN
# @param rows first size argument of tnn_input_2d_f32_persistent
# @param cols second size argument of tnn_input_2d_f32_persistent
# @return the weight handle
def reg2(sess, rows, cols)
  weight   = TinyNN.tnn_input_2d_f32_persistent(sess, rows, cols)
  moment_m = TinyNN.tnn_input_2d_f32_persistent(sess, rows, cols)
  moment_v = TinyNN.tnn_input_2d_f32_persistent(sess, rows, cols)
  @ft_weights.push(weight)
  @ft_m.push(moment_m)
  @ft_v.push(moment_v)
  weight
end

#reg4(sess, a, b, c, d) ⇒ Object



102
103
104
105
106
107
108
# File 'lib/toy/llm/blocks/gdn_block.rb', line 102

# Registers one rank-4 trainable weight.
#
# Allocates the weight plus two same-shape companion tensors and appends the
# triple to ft_weights / ft_m / ft_v. All three allocations finish before any
# array is touched, so an allocation that raises cannot leave the parallel
# arrays with different lengths.
#
# @param sess session handle forwarded to TinyNN
# @param a    first size argument of tnn_input_4d_f32_persistent
# @param b    second size argument
# @param c    third size argument
# @param d    fourth size argument
# @return the weight handle
def reg4(sess, a, b, c, d)
  weight   = TinyNN.tnn_input_4d_f32_persistent(sess, a, b, c, d)
  moment_m = TinyNN.tnn_input_4d_f32_persistent(sess, a, b, c, d)
  moment_v = TinyNN.tnn_input_4d_f32_persistent(sess, a, b, c, d)
  @ft_weights.push(weight)
  @ft_m.push(moment_m)
  @ft_v.push(moment_v)
  weight
end

#set_params! ⇒ Object

Mark every projection weight a trainable param. Call BEFORE finalize_weights (load-bearing order, gpt2_seq_engine.rb:128). a_log + dt_bias ARE trained (per-head decay shape); state0 is NOT (it is not in ft_weights).



113
114
115
116
117
118
119
# File 'lib/toy/llm/blocks/gdn_block.rb', line 113

# Flags every handle in ft_weights as a trainable parameter, in registration
# order. state0 is unaffected because it is never stored in ft_weights.
#
# @return [nil]
def set_params!
  params = @ft_weights
  idx = 0
  until idx >= params.length
    TinyNN.tnn_set_param(params[idx])
    idx += 1
  end
end

#zero_state!(sess) ⇒ Object

Zero the constant initial state (after finalize_weights).



122
123
124
# File 'lib/toy/llm/blocks/gdn_block.rb', line 122

# Clears the initial recurrence state by handing state0 to
# TinyNN.tnn_zero_tensor.
#
# @param sess session handle forwarded to TinyNN
# @return whatever tnn_zero_tensor returns
def zero_state!(sess)
  initial_state = @t_state0
  TinyNN.tnn_zero_tensor(sess, initial_state)
end