Class: Toy::LLM::Engine::GPT2SeqEngineCuda

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/llm/engine/gpt2_seq_engine_cuda.rb

Constant Summary collapse

# Epsilon passed to every tnn_layer_norm call in build_forward!.
LN_EPS = 0.00001

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize ⇒ GPT2SeqEngineCuda

Returns a new instance of GPT2SeqEngineCuda.



52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 52

def initialize
  # Session handle; replaced by tnn_session_new in realize!.
  @sess = TinyNNCuda.tnn_null_ptr

  # Model dimensions; all zero until realize! assigns them.
  @g_vocab = 0
  @g_d_model = 0
  @g_n_heads = 0
  @g_d_head = 0
  @g_d_ff = 0
  @g_n_layers = 0
  @g_context = 0

  # Graph inputs/outputs; created later by build_forward! / build_train_step!.
  @g_t_tok = TinyNNCuda.tnn_null_ptr
  @g_t_pos = TinyNNCuda.tnn_null_ptr
  @g_t_labels = TinyNNCuda.tnn_null_ptr
  @g_t_hp = TinyNNCuda.tnn_null_ptr
  @g_t_loss = TinyNNCuda.tnn_null_ptr
  @g_t_logits = TinyNNCuda.tnn_null_ptr

  # Single (non-per-layer) weights; allocated in realize!.
  @g_wte = TinyNNCuda.tnn_null_ptr
  @g_wpe = TinyNNCuda.tnn_null_ptr
  @g_lnf_g = TinyNNCuda.tnn_null_ptr
  @g_lnf_b = TinyNNCuda.tnn_null_ptr

  # Per-layer / per-head tensor lists, filled by realize!. Each begins as a
  # one-element array that is immediately popped empty.
  # NOTE(review): presumably this pins the element type for the Ruby-subset
  # toolchain this file targets -- confirm before replacing with [].
  @g_ln1_g = [TinyNNCuda.tnn_null_ptr]
  @g_ln1_g.pop
  @g_ln1_b = [TinyNNCuda.tnn_null_ptr]
  @g_ln1_b.pop
  @g_ln2_g = [TinyNNCuda.tnn_null_ptr]
  @g_ln2_g.pop
  @g_ln2_b = [TinyNNCuda.tnn_null_ptr]
  @g_ln2_b.pop
  @g_w_q = [TinyNNCuda.tnn_null_ptr]
  @g_w_q.pop
  @g_b_q = [TinyNNCuda.tnn_null_ptr]
  @g_b_q.pop
  @g_w_k = [TinyNNCuda.tnn_null_ptr]
  @g_w_k.pop
  @g_b_k = [TinyNNCuda.tnn_null_ptr]
  @g_b_k.pop
  @g_w_v = [TinyNNCuda.tnn_null_ptr]
  @g_w_v.pop
  @g_b_v = [TinyNNCuda.tnn_null_ptr]
  @g_b_v.pop
  @g_w_o = [TinyNNCuda.tnn_null_ptr]
  @g_w_o.pop
  @g_b_o = [TinyNNCuda.tnn_null_ptr]
  @g_b_o.pop
  @g_fc_w = [TinyNNCuda.tnn_null_ptr]
  @g_fc_w.pop
  @g_fc_b = [TinyNNCuda.tnn_null_ptr]
  @g_fc_b.pop
  @g_pr_w = [TinyNNCuda.tnn_null_ptr]
  @g_pr_w.pop
  @g_pr_b = [TinyNNCuda.tnn_null_ptr]
  @g_pr_b.pop

  # Flat list of every trainable weight plus its optimizer m/v state,
  # index-aligned (see alloc_w1 / alloc_w2).
  @g_weights = [TinyNNCuda.tnn_null_ptr]
  @g_weights.pop
  @g_opt_m = [TinyNNCuda.tnn_null_ptr]
  @g_opt_m.pop
  @g_opt_v = [TinyNNCuda.tnn_null_ptr]
  @g_opt_v.pop

  # LCG state for rand_unit; realize! overwrites it with the seed.
  @g_rng = 0
end

Instance Attribute Details

#g_b_kObject

Returns the value of attribute g_b_k.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Key-projection bias tensors; realize! appends one per (layer, head) pair,
# layer-major.
def g_b_k
  return @g_b_k
end

#g_b_oObject

Returns the value of attribute g_b_o.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Attention output-projection bias tensors; realize! appends one per layer.
def g_b_o
  return @g_b_o
end

#g_b_qObject

Returns the value of attribute g_b_q.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Query-projection bias tensors; one per (layer, head) pair, layer-major.
def g_b_q
  return @g_b_q
end

#g_b_vObject

Returns the value of attribute g_b_v.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Value bias tensors; one per (layer, head). build_forward! adds them to each
# head's attention output rather than to v itself.
def g_b_v
  return @g_b_v
end

#g_cb_rcObject

Returns the value of attribute g_cb_rc.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Value returned by the most recent tnn_compute_backward call in step!;
# nil before the first step (initialize does not set it).
def g_cb_rc
  return @g_cb_rc
end

#g_contextObject

Returns the value of attribute g_context.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Context length set by realize! (0 before); sizes the token/position inputs
# and the wpe table.
def g_context
  return @g_context
end

#g_d_ffObject

Returns the value of attribute g_d_ff.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Feed-forward hidden width set by realize! (0 before).
def g_d_ff
  return @g_d_ff
end

#g_d_headObject

Returns the value of attribute g_d_head.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Per-head width, computed in realize! as d_model / n_heads (integer
# division); 0 before realize!.
def g_d_head
  return @g_d_head
end

#g_d_modelObject

Returns the value of attribute g_d_model.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Model (residual stream) width set by realize! (0 before).
def g_d_model
  return @g_d_model
end

#g_fc_bObject

Returns the value of attribute g_fc_b.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# FFN first-layer bias tensors (length d_ff); one per layer.
def g_fc_b
  return @g_fc_b
end

#g_fc_wObject

Returns the value of attribute g_fc_w.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# FFN first-layer weight tensors (allocated d_ff x d_model); one per layer.
def g_fc_w
  return @g_fc_w
end

#g_ln1_bObject

Returns the value of attribute g_ln1_b.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Pre-attention layer-norm bias tensors; one per layer, initialized to 0.0.
def g_ln1_b
  return @g_ln1_b
end

#g_ln1_gObject

Returns the value of attribute g_ln1_g.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Pre-attention layer-norm gain tensors; one per layer, initialized to 1.0.
def g_ln1_g
  return @g_ln1_g
end

#g_ln2_bObject

Returns the value of attribute g_ln2_b.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Pre-FFN layer-norm bias tensors; one per layer, initialized to 0.0.
def g_ln2_b
  return @g_ln2_b
end

#g_ln2_gObject

Returns the value of attribute g_ln2_g.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Pre-FFN layer-norm gain tensors; one per layer, initialized to 1.0.
def g_ln2_g
  return @g_ln2_g
end

#g_lnf_bObject

Returns the value of attribute g_lnf_b.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Final layer-norm bias (initialized to 0.0), applied before the tied
# unembedding.
def g_lnf_b
  return @g_lnf_b
end

#g_lnf_gObject

Returns the value of attribute g_lnf_g.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Final layer-norm gain (initialized to 1.0).
def g_lnf_g
  return @g_lnf_g
end

#g_n_headsObject

Returns the value of attribute g_n_heads.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Attention heads per layer, set by realize! (0 before).
def g_n_heads
  return @g_n_heads
end

#g_n_layersObject

Returns the value of attribute g_n_layers.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Number of transformer layers, set by realize! (0 before).
def g_n_layers
  return @g_n_layers
end

#g_opt_mObject

Returns the value of attribute g_opt_m.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Optimizer m-state tensors passed to tnn_opt_step_adamw; index-aligned with
# g_weights and zeroed in realize!.
def g_opt_m
  return @g_opt_m
end

#g_opt_vObject

Returns the value of attribute g_opt_v.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Optimizer v-state tensors passed to tnn_opt_step_adamw; index-aligned with
# g_weights and zeroed in realize!.
def g_opt_v
  return @g_opt_v
end

#g_pr_bObject

Returns the value of attribute g_pr_b.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# FFN output-projection bias tensors (length d_model); one per layer.
def g_pr_b
  return @g_pr_b
end

#g_pr_wObject

Returns the value of attribute g_pr_w.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# FFN output-projection weight tensors (allocated d_model x d_ff); one per
# layer.
def g_pr_w
  return @g_pr_w
end

#g_rb_rcObject

Returns the value of attribute g_rb_rc.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Value returned by tnn_realize_backward in build_train_step!; nil until
# then (initialize does not set it).
def g_rb_rc
  return @g_rb_rc
end

#g_rngObject

Returns the value of attribute g_rng.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Current LCG state; seeded by realize! and advanced by each rand_unit call.
def g_rng
  return @g_rng
end

#g_t_hpObject

Returns the value of attribute g_t_hp.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# 7-element f32 hyperparameter input shared by every AdamW step node;
# created in build_train_step! and uploaded by step!.
def g_t_hp
  return @g_t_hp
end

#g_t_labelsObject

Returns the value of attribute g_t_labels.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Label input for the cross-entropy loss (allocated context x vocab, f32);
# uploaded by step!.
def g_t_labels
  return @g_t_labels
end

#g_t_logitsObject

Returns the value of attribute g_t_logits.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Logits tensor produced by build_forward! (matmul with the tied g_wte).
def g_t_logits
  return @g_t_logits
end

#g_t_lossObject

Returns the value of attribute g_t_loss.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Cross-entropy loss tensor created in build_train_step!; downloaded by step!.
def g_t_loss
  return @g_t_loss
end

#g_t_posObject

Returns the value of attribute g_t_pos.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Position-id input (int32, length context) created in build_forward!;
# uploaded by step!.
def g_t_pos
  return @g_t_pos
end

#g_t_tokObject

Returns the value of attribute g_t_tok.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Token-id input (int32, length context) created in build_forward!;
# uploaded by step!.
def g_t_tok
  return @g_t_tok
end

#g_vocabObject

Returns the value of attribute g_vocab.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Vocabulary size set by realize! (0 before).
def g_vocab
  return @g_vocab
end

#g_w_kObject

Returns the value of attribute g_w_k.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Key-projection weights (allocated d_head x d_model); one per
# (layer, head), layer-major.
def g_w_k
  return @g_w_k
end

#g_w_oObject

Returns the value of attribute g_w_o.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Attention output-projection weights (allocated d_model x d_model); one per
# layer.
def g_w_o
  return @g_w_o
end

#g_w_qObject

Returns the value of attribute g_w_q.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Query-projection weights (allocated d_head x d_model); one per
# (layer, head), layer-major.
def g_w_q
  return @g_w_q
end

#g_w_vObject

Returns the value of attribute g_w_v.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Value-projection weights (allocated d_head x d_model); one per
# (layer, head), layer-major.
def g_w_v
  return @g_w_v
end

#g_weightsObject

Returns the value of attribute g_weights.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Every trainable weight tensor in allocation order (pushed by alloc_w1 /
# alloc_w2); index-aligned with g_opt_m and g_opt_v.
def g_weights
  return @g_weights
end

#g_wpeObject

Returns the value of attribute g_wpe.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Learned position-embedding weight (allocated context x d_model).
def g_wpe
  return @g_wpe
end

#g_wteObject

Returns the value of attribute g_wte.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# Token-embedding weight (allocated vocab x d_model); build_forward! also
# uses it as the tied unembedding for the logits.
def g_wte
  return @g_wte
end

#sessObject

Returns the value of attribute sess.



37
38
39
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 37

# TinyNNCuda session handle; a null pointer until realize! creates the
# session.
def sess
  return @sess
end

Instance Method Details

#alloc_w1(inits, n, init_mat) ⇒ Object



122
123
124
125
126
127
128
129
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 122

# Allocate a persistent 1-D f32 weight of length n together with its two
# optimizer state tensors (m, v) of the same length, and queue init_mat so
# realize! can upload it after tnn_finalize_weights.
# Allocation only: the tensors have no buffers until finalize.
# Returns the weight tensor.
def alloc_w1(inits, n, init_mat)
  weight = TinyNNCuda.tnn_input_1d_f32_persistent(@sess, n)
  @g_weights.push(weight)
  state_m = TinyNNCuda.tnn_input_1d_f32_persistent(@sess, n)
  @g_opt_m.push(state_m)
  state_v = TinyNNCuda.tnn_input_1d_f32_persistent(@sess, n)
  @g_opt_v.push(state_v)
  inits.push(init_mat)
  return weight
end

#alloc_w2(inits, rows, cols, init_mat) ⇒ Object

alloc-only (buffers don’t exist until tnn_finalize_weights); records (weight, m, v) into the optimizer arrays and the init Mat into `inits`.



113
114
115
116
117
118
119
120
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 113

# Allocate a persistent 2-D f32 weight of rows x cols together with its two
# optimizer state tensors (m, v) of the same shape, and queue init_mat so
# realize! can upload it after tnn_finalize_weights.
# Allocation only: the tensors have no buffers until finalize.
# Returns the weight tensor.
def alloc_w2(inits, rows, cols, init_mat)
  weight = TinyNNCuda.tnn_input_2d_f32_persistent(@sess, rows, cols)
  @g_weights.push(weight)
  state_m = TinyNNCuda.tnn_input_2d_f32_persistent(@sess, rows, cols)
  @g_opt_m.push(state_m)
  state_v = TinyNNCuda.tnn_input_2d_f32_persistent(@sess, rows, cols)
  @g_opt_v.push(state_v)
  inits.push(init_mat)
  return weight
end

#build_forward! ⇒ Object

GPT-2 forward → @g_t_logits (tied unembed).



199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 199

# Build the GPT-2 forward graph and store the logits node in @g_t_logits.
#
# Structure (as constructed below):
#   x = wte[tok] + wpe[pos]
#   for each layer: x += attn(LN1(x)); x += MLP(LN2(x))   (pre-LN residuals)
#   logits = matmul(wte, LN_f(x))                         (tied unembedding)
#
# Creates the int32 inputs @g_t_tok / @g_t_pos (length @g_context) that
# step! uploads into. Expects realize! to have allocated all weights first.
# Returns nil.
def build_forward!
  @g_t_tok = TinyNNCuda.tnn_input_1d_i32(@sess, @g_context)
  @g_t_pos = TinyNNCuda.tnn_input_1d_i32(@sess, @g_context)

  # Token embedding + learned position embedding (row gathers).
  x = TinyNNCuda.tnn_add(@sess,
        TinyNNCuda.tnn_get_rows(@sess, @g_wte, @g_t_tok),
        TinyNNCuda.tnn_get_rows(@sess, @g_wpe, @g_t_pos))
  # NOTE(review): residual-stream nodes are flagged via tnn_set_output;
  # presumably this keeps them materialized for backward -- confirm.
  TinyNNCuda.tnn_set_output(x)

  # Scaled dot-product attention factor 1/sqrt(d_head).
  att_scale = 1.0 / Math.sqrt(@g_d_head.to_f)
  li = 0
  while li < @g_n_layers
    # attention sub-block (per-head loop + concat)
    h1 = TinyNNCuda.tnn_layer_norm(@sess, x, @g_ln1_g[li], @g_ln1_b[li], LN_EPS)
    head_out = TinyNNCuda.tnn_null_ptr
    hh = 0
    while hh < @g_n_heads
      # Flat per-head index; matches realize!'s layer-major, head-minor
      # push order into @g_w_q/@g_b_q/@g_w_k/@g_b_k/@g_w_v/@g_b_v.
      hi = li * @g_n_heads + hh
      q = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, @g_w_q[hi], h1), @g_b_q[hi])
      k = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, @g_w_k[hi], h1), @g_b_k[hi])
      # The value bias is deliberately applied after attention (see `head`):
      # if each softmax row sums to 1 along the contracted axis,
      # probs*(v + b) == probs*v + b. NOTE(review): relies on that axis
      # convention of tnn_softmax/tnn_matmul -- confirm.
      v = TinyNNCuda.tnn_matmul(@sess, @g_w_v[hi], h1)   # bias added to output
      # Operand order (k, q) follows tnn_matmul's own convention; the result
      # is the per-head score matrix, scaled by 1/sqrt(d_head).
      scores = TinyNNCuda.tnn_scale(@sess, TinyNNCuda.tnn_matmul(@sess, k, q), att_scale)
      # Causal mask. NOTE(review): presumably sets entries above the
      # diagonal (offset 0) to -inf before softmax -- confirm.
      scores = TinyNNCuda.tnn_diag_mask_inf(@sess, scores, 0)
      probs  = TinyNNCuda.tnn_softmax(@sess, scores)
      # Materialize v transposed as a contiguous (context, d_head)-sized
      # tensor for the probs matmul.
      v_t    = TinyNNCuda.tnn_cont_2d(@sess, TinyNNCuda.tnn_transpose(@sess, v), @g_context, @g_d_head)
      head   = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, v_t, probs), @g_b_v[hi])
      # Stack head outputs along dim 0 (head 0 seeds the accumulator).
      if hh == 0
        head_out = head
      else
        head_out = TinyNNCuda.tnn_concat(@sess, head_out, head, 0)
      end
      hh = hh + 1
    end
    # Output projection of the concatenated heads, then residual add.
    ao = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, @g_w_o[li], head_out), @g_b_o[li])
    x  = TinyNNCuda.tnn_add(@sess, x, ao)
    TinyNNCuda.tnn_set_output(x)

    # FFN sub-block
    # LN2 -> fc (d_model -> d_ff) -> GELU -> proj (d_ff -> d_model) -> residual.
    h2  = TinyNNCuda.tnn_layer_norm(@sess, x, @g_ln2_g[li], @g_ln2_b[li], LN_EPS)
    pre = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, @g_fc_w[li], h2), @g_fc_b[li])
    act = TinyNNCuda.tnn_gelu(@sess, pre)
    mlp = TinyNNCuda.tnn_add(@sess, TinyNNCuda.tnn_matmul(@sess, @g_pr_w[li], act), @g_pr_b[li])
    x   = TinyNNCuda.tnn_add(@sess, x, mlp)
    TinyNNCuda.tnn_set_output(x)
    li = li + 1
  end

  # Final layer norm, then logits via the token-embedding matrix (weight tying:
  # @g_wte serves as both embedding and unembedding).
  x_final = TinyNNCuda.tnn_layer_norm(@sess, x, @g_lnf_g, @g_lnf_b, LN_EPS)
  TinyNNCuda.tnn_set_output(x_final)
  @g_t_logits = TinyNNCuda.tnn_matmul(@sess, @g_wte, x_final)   # tied
  TinyNNCuda.tnn_set_output(@g_t_logits)
  nil
end

#build_train_step! ⇒ Object

CE loss + backward + opt_step_adamw per weight.



254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 254

# Attach a cross-entropy loss to @g_t_logits, build the backward graph, and
# extend it with one tnn_opt_step_adamw node per trainable weight, so a single
# tnn_compute_backward (see step!) both computes gradients and applies the
# updates. Stores the tnn_realize_backward result in @g_rb_rc.
#
# Must run after build_forward! (reads @g_t_logits). Creates the step inputs
# @g_t_labels (context x vocab f32) and @g_t_hp (7 f32 values).
# NOTE(review): the meaning of each @g_t_hp slot is defined by
# tnn_opt_step_adamw -- confirm against that API.
# Returns nil.
def build_train_step!
  session = @sess

  @g_t_labels = TinyNNCuda.tnn_input_2d_f32(session, @g_context, @g_vocab)
  @g_t_hp     = TinyNNCuda.tnn_input_1d_f32(session, 7)

  loss = TinyNNCuda.tnn_cross_entropy_loss(session, @g_t_logits, @g_t_labels)
  @g_t_loss = loss
  TinyNNCuda.tnn_set_output(loss)
  TinyNNCuda.tnn_set_loss(loss)

  TinyNNCuda.tnn_build_forward_only(session, loss)
  TinyNNCuda.tnn_build_backward(session)

  # One optimizer node per weight; m/v state is index-aligned with @g_weights.
  idx = 0
  count = @g_weights.length
  while idx < count
    weight = @g_weights[idx]
    grad   = TinyNNCuda.tnn_tensor_grad(session, weight)
    update = TinyNNCuda.tnn_opt_step_adamw(session, weight, grad, @g_opt_m[idx], @g_opt_v[idx], @g_t_hp)
    TinyNNCuda.tnn_extend_backward_graph(session, update)
    idx = idx + 1
  end

  @g_rb_rc = TinyNNCuda.tnn_realize_backward(session)
  nil
end

#const_mat(rows, cols, value) ⇒ Object



100
101
102
103
104
105
106
107
108
109
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 100

# Build a rows x cols Mat and set every element of its flat storage to value.
# Used for layer-norm gains (1.0) and biases (0.0) in realize!.
def const_mat(rows, cols, value)
  out = Mat.new(rows, cols)
  remaining = rows * cols
  # Every slot gets the same value, so fill order is irrelevant; walk down.
  while remaining > 0
    remaining = remaining - 1
    out.flat[remaining] = value
  end
  out
end

#rand_unit ⇒ Object

seeded LCG → ~[-1,1)



84
85
86
87
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 84

# Advance the seeded 31-bit LCG stored in @g_rng and return a Float in
# [-1, 1).
#
# The state is masked to 31 bits, so (state >> 8) is a 23-bit integer in
# [0, 2**23). Dividing by 2**22 maps it onto [0, 2), and subtracting 1 gives
# [-1, 1). (Dividing by 2**23 would confine every draw to [-1, 0), making all
# random weight inits non-positive.)
# NOTE(review): this changes the init sequence for a given seed; if another
# engine shares the old formula for parity tests, update it together.
def rand_unit
  @g_rng = ((@g_rng * 1103515245) + 12345) & 0x7fffffff
  ((@g_rng >> 8).to_f / 4194304.0) - 1.0
end

#random_mat(rows, cols, scale) ⇒ Object



89
90
91
92
93
94
95
96
97
98
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 89

# Build a rows x cols Mat whose flat elements, in ascending index order, are
# successive rand_unit draws multiplied by scale. Advances @g_rng exactly
# rows * cols times, so the fill order is part of the seed contract.
def random_mat(rows, cols, scale)
  out = Mat.new(rows, cols)
  total = rows * cols
  idx = 0
  while idx < total
    sample = rand_unit
    out.flat[idx] = sample * scale
    idx = idx + 1
  end
  out
end

#realize!(vocab, d_model, n_heads, d_ff, n_layers, context, seed) ⇒ Object

Build the full random-init training graph. Realize ordering is load-bearing (alloc → set_param → finalize_weights → upload → backward → realize_backward); uploading a persistent weight before finalize aborts (“tensor buffer not set”).



135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 135

# Build the full random-init GPT-2 training graph and leave the engine ready
# for step!.
#
# vocab, d_model, n_heads, d_ff, n_layers, context: model shape (integers).
# seed: initial @g_rng state for the rand_unit-based weight init.
#
# Raises ArgumentError unless n_heads is positive and divides d_model.
# @g_d_head = d_model / n_heads is integer division. build_forward! then
# concatenates the n_heads head outputs (each from d_head-row projections)
# and feeds them to the d_model x d_model output projection, so a remainder
# would produce an inconsistent graph instead of a clear error. The check
# runs before any state is modified.
#
# Ordering is load-bearing: alloc -> set_param -> finalize_weights -> upload
# -> build forward/backward -> realize_backward. Uploading a persistent
# weight before tnn_finalize_weights aborts ("tensor buffer not set").
#
# NOTE(review): the per-layer weight arrays are only appended to, so calling
# this twice on one instance would mix tensors from two sessions -- it
# appears intended to be called once per instance.
# Returns nil.
def realize!(vocab, d_model, n_heads, d_ff, n_layers, context, seed)
  if n_heads <= 0 || d_model % n_heads != 0
    raise ArgumentError, "d_model must be a positive multiple of n_heads"
  end

  @g_vocab = vocab; @g_d_model = d_model; @g_n_heads = n_heads
  @g_d_head = d_model / n_heads; @g_d_ff = d_ff
  @g_n_layers = n_layers; @g_context = context
  @g_rng = seed
  @sess = TinyNNCuda.tnn_session_new(1)
  # Per-head decomposition makes node count scale O(n_layers × n_heads);
  # budget like LlamaSeqEngine (the default cap overflows at backward-expand
  # on bigger shapes). Must precede realize (no compute tensors stored yet).
  TinyNNCuda.tnn_session_set_graph_capacity(@sess, n_layers * n_heads * 1000 + 65536)

  # Host-side init matrices, index-aligned with @g_weights (one push per
  # alloc_w1/alloc_w2 call).
  inits = [Mat.new(1, 1)]; inits.pop

  @g_wte = alloc_w2(inits, vocab,   d_model, random_mat(vocab,   d_model, 0.02))
  @g_wpe = alloc_w2(inits, context, d_model, random_mat(context, d_model, 0.02))

  # Per-layer weights. Per-head tensors are pushed layer-major, head-minor,
  # which build_forward! relies on (index li * n_heads + hh).
  li = 0
  while li < n_layers
    @g_ln1_g.push(alloc_w1(inits, d_model, const_mat(1, d_model, 1.0)))
    @g_ln1_b.push(alloc_w1(inits, d_model, const_mat(1, d_model, 0.0)))
    hh = 0
    while hh < n_heads
      @g_w_q.push(alloc_w2(inits, @g_d_head, d_model, random_mat(@g_d_head, d_model, 0.02)))
      @g_b_q.push(alloc_w1(inits, @g_d_head, const_mat(1, @g_d_head, 0.0)))
      @g_w_k.push(alloc_w2(inits, @g_d_head, d_model, random_mat(@g_d_head, d_model, 0.02)))
      @g_b_k.push(alloc_w1(inits, @g_d_head, const_mat(1, @g_d_head, 0.0)))
      @g_w_v.push(alloc_w2(inits, @g_d_head, d_model, random_mat(@g_d_head, d_model, 0.02)))
      @g_b_v.push(alloc_w1(inits, @g_d_head, const_mat(1, @g_d_head, 0.0)))
      hh = hh + 1
    end
    @g_w_o.push(alloc_w2(inits, d_model, d_model, random_mat(d_model, d_model, 0.02)))
    @g_b_o.push(alloc_w1(inits, d_model, const_mat(1, d_model, 0.0)))
    @g_ln2_g.push(alloc_w1(inits, d_model, const_mat(1, d_model, 1.0)))
    @g_ln2_b.push(alloc_w1(inits, d_model, const_mat(1, d_model, 0.0)))
    @g_fc_w.push(alloc_w2(inits, d_ff, d_model, random_mat(d_ff, d_model, 0.02)))
    @g_fc_b.push(alloc_w1(inits, d_ff, const_mat(1, d_ff, 0.0)))
    @g_pr_w.push(alloc_w2(inits, d_model, d_ff, random_mat(d_model, d_ff, 0.02)))
    @g_pr_b.push(alloc_w1(inits, d_model, const_mat(1, d_model, 0.0)))
    li = li + 1
  end
  @g_lnf_g = alloc_w1(inits, d_model, const_mat(1, d_model, 1.0))
  @g_lnf_b = alloc_w1(inits, d_model, const_mat(1, d_model, 0.0))

  # Mark every weight trainable, then let the session allocate buffers.
  gi = 0
  while gi < @g_weights.length
    TinyNNCuda.tnn_set_param(@g_weights[gi])
    gi = gi + 1
  end
  TinyNNCuda.tnn_finalize_weights(@sess)

  # Buffers now exist: upload initial values and zero the optimizer state.
  gk = 0
  while gk < @g_weights.length
    TinyNNCuda.upload_row_major(@sess, @g_weights[gk], inits[gk])
    TinyNNCuda.tnn_zero_tensor(@sess, @g_opt_m[gk])
    TinyNNCuda.tnn_zero_tensor(@sess, @g_opt_v[gk])
    gk = gk + 1
  end

  build_forward!
  build_train_step!
  nil
end

#step!(seq_ids, positions, m_labels, m_hp, is_first) ⇒ Object

One training step. is_first selects full reset vs grads-only (momenta persist). Returns the loss Float.



278
279
280
281
282
283
284
285
286
287
288
289
290
291
# File 'lib/toy/llm/engine/gpt2_seq_engine_cuda.rb', line 278

# Run one training step and return the loss as a Float.
#
# seq_ids / positions: uploaded as int arrays into @g_t_tok / @g_t_pos.
# m_labels / m_hp: uploaded row-major into @g_t_labels / @g_t_hp.
# is_first: true -> tnn_graph_reset (full reset); false ->
#   tnn_graph_reset_grads_only (grads only, so optimizer momenta persist).
# The tnn_compute_backward result is kept in @g_cb_rc.
# NOTE(review): the return value is scratch slot 0, which tnn_download of
# @g_t_loss presumably fills -- confirm.
def step!(seq_ids, positions, m_labels, m_hp, is_first)
  session = @sess

  if is_first
    TinyNNCuda.tnn_graph_reset(session)
  else
    TinyNNCuda.tnn_graph_reset_grads_only(session)
  end

  # This step's inputs.
  TinyNNCuda.upload_int_array(session, @g_t_tok, seq_ids)
  TinyNNCuda.upload_int_array(session, @g_t_pos, positions)
  TinyNNCuda.upload_row_major(session, @g_t_labels, m_labels)
  TinyNNCuda.upload_row_major(session, @g_t_hp, m_hp)

  # Forward + backward, including the AdamW nodes appended in build_train_step!.
  @g_cb_rc = TinyNNCuda.tnn_compute_backward(session)

  TinyNNCuda.tnn_download(session, @g_t_loss)
  loss = TinyNNCuda.tnn_scratch_get(session, 0)
  loss
end