Module: Sampler

Defined in:
lib/toy/train/sampler.rb

Class Method Summary collapse

Class Method Details

.argmax(logits) ⇒ Object



268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# File 'lib/toy/train/sampler.rb', line 268

# Greedy pick: index of the largest value among the first `logits.ncols`
# entries of `logits.flat`.
#
# Only a strictly greater value displaces the current leader, so on ties the
# lowest index wins. Index 0 is the starting leader, which is why the scan
# begins at 1.
def self.argmax(logits)
  cells = logits.flat
  width = logits.ncols
  winner = 0
  top = cells[0]
  idx = 1
  while idx < width
    cand = cells[idx]
    if cand > top
      winner = idx
      top = cand
    end
    idx = idx + 1
  end
  winner
end

.multinomial(logits, ctx) ⇒ Object



284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# File 'lib/toy/train/sampler.rb', line 284

# Draw one index from softmax(logits) by walking the cumulative
# (unnormalized) weights until they reach `ctx.next_unit * sum`.
#
# logits - object exposing `ncols` and an indexable `flat`.
# ctx    - object whose `next_unit` supplies the random draw
#          (assumed to lie in [0, 1] — TODO confirm against the RNG).
#
# Returns the chosen index. An entry whose weight underflows to exactly 0.0
# (e.g. one masked with NEG_INF_SCORE) is never returned: previously a draw
# of 0.0 satisfied `cum >= target` on the very first slot even when that
# slot carried zero weight, and the end-of-loop fallback could likewise land
# on a zero-weight last slot.
def self.multinomial(logits, ctx)
  n = logits.ncols
  # Max-shift before exponentiating so the largest weight is exp(0) = 1.
  max_v = NEG_INF_SCORE
  j = 0
  while j < n
    v = logits.flat[j]
    if v > max_v
      max_v = v
    end
    j = j + 1
  end
  sum = 0.0
  j = 0
  while j < n
    sum = sum + Math.exp(logits.flat[j] - max_v)
    j = j + 1
  end
  target = ctx.next_unit * sum
  cum = 0.0
  # Last index seen with a strictly positive weight; used as the fallback.
  last_live = -1
  j = 0
  while j < n
    w = Math.exp(logits.flat[j] - max_v)
    # Zero-weight slots contribute nothing to `cum`, so skipping them leaves
    # the walk unchanged while guaranteeing they cannot be selected.
    if w > 0.0
      cum = cum + w
      if cum >= target
        return j
      end
      last_live = j
    end
    j = j + 1
  end
  # Only reached when the draw overshoots the total mass (or n == 0).
  if last_live >= 0
    return last_live
  end
  n - 1
end

.pick(logits, cfg, ctx) ⇒ Object

Final pick. If cfg.temperature <= 0, return argmax (greedy). Otherwise softmax + multinomial draw using ctx’s RNG.



261
262
263
264
265
266
# File 'lib/toy/train/sampler.rb', line 261

# Final token choice. A non-positive cfg.temperature selects the greedy
# path (Sampler.argmax); any other value delegates to Sampler.multinomial,
# which draws using ctx.
def self.pick(logits, cfg, ctx)
  greedy = cfg.temperature <= 0.0
  if greedy
    Sampler.argmax(logits)
  else
    Sampler.multinomial(logits, ctx)
  end
end

.repetition_penalty(logits, ctx, p) ⇒ Object

Penalize the logits of any token already in the generated context (`ctx.generated_ids`) by the factor `p`. A value of 1.0 or less disables the penalty (default 1.0). This follows the HF convention: positive logits are DIVIDED by `p` and non-positive ones are MULTIPLIED by it, so the token becomes less likely in either case. For most fine-tunes a value of 1.05–1.2 is reasonable.



82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# File 'lib/toy/train/sampler.rb', line 82

# Apply a repetition penalty, in place, to every token id listed in
# ctx.generated_ids.
#
# logits - object exposing `ncols` and a writable, indexable `flat`.
# ctx    - object exposing `generated_ids` (indexable, with `length`).
# p      - penalty factor; p <= 1.0 leaves the logits untouched.
#
# Positive logits are divided by p, non-positive ones multiplied by p. Each
# distinct token id is penalized exactly once no matter how often it occurs
# in the history; previously a token seen k times was scaled by p**k, which
# compounds far beyond the once-per-token convention this scheme follows.
# Ids outside 0...ncols are ignored.
#
# Returns the same `logits` object.
def self.repetition_penalty(logits, ctx, p)
  if p <= 1.0
    return logits
  end
  n = logits.ncols
  # One flag per slot marks ids that have already been penalized.
  # (Built with push/pop to match the typed-empty-array idiom used in this file.)
  done = [false]
  done.pop
  j = 0
  while j < n
    done.push(false)
    j = j + 1
  end
  seen = ctx.generated_ids
  i = 0
  while i < seen.length
    tid = seen[i]
    if tid >= 0 && tid < n && !done[tid]
      done[tid] = true
      v = logits.flat[tid]
      if v > 0.0
        logits.flat[tid] = v / p
      else
        logits.flat[tid] = v * p
      end
    end
    i = i + 1
  end
  logits
end

.temperature(logits, t) ⇒ Object

In-place divide by temperature. T<=0 means “do nothing here, let pick fall through to argmax”; T=1 is also a no-op.



63
64
65
66
67
68
69
70
71
72
73
74
75
# File 'lib/toy/train/sampler.rb', line 63

# Scale the first `logits.ncols` entries of `logits.flat` in place by 1/t.
#
# A non-positive t, or t exactly 1.0, leaves the logits untouched. The
# reciprocal is computed once and multiplied in rather than dividing per
# element. Returns the same `logits` object.
def self.temperature(logits, t)
  if t <= 0.0
    return logits
  end
  if t == 1.0
    return logits
  end
  scale = 1.0 / t
  cells = logits.flat
  width = logits.ncols
  idx = 0
  while idx < width
    cells[idx] = cells[idx] * scale
    idx = idx + 1
  end
  logits
end

.top_k(logits, k) ⇒ Object

Keep top-k logits; mask the rest with -INFINITY (= -1e30 here so softmax never sees -Inf). k<=0 disables; k >= the number of logits is likewise a no-op.



105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# File 'lib/toy/train/sampler.rb', line 105

# Keep the k largest logits and overwrite every other entry with
# NEG_INF_SCORE, in place. Returns the same `logits` object.
#
# k <= 0 or k >= logits.ncols is a no-op. Selection runs k argmax passes
# over a scratch copy (O(k*n)); ties go to the lowest index because only a
# strictly greater value replaces the current best.
def self.top_k(logits, k)
  if k <= 0
    return logits
  end
  width = logits.ncols
  if k >= width
    return logits
  end
  cells = logits.flat
  # `scratch` is a copy whose winners get knocked out after each pass so the
  # next pass finds the runner-up; `keep` flags the winners. Both arrays are
  # built with push/pop, matching the typed-empty-array idiom of this file.
  scratch = [0.0]
  scratch.pop
  keep = [false]
  keep.pop
  idx = 0
  while idx < width
    scratch.push(cells[idx])
    keep.push(false)
    idx = idx + 1
  end
  round = 0
  while round < k
    top_i = -1
    top_v = NEG_INF_SCORE
    idx = 0
    while idx < width
      cand = scratch[idx]
      if cand > top_v
        top_v = cand
        top_i = idx
      end
      idx = idx + 1
    end
    if top_i < 0
      # Nothing left above NEG_INF_SCORE: leave the logits exactly as they are.
      return logits
    end
    keep[top_i] = true
    scratch[top_i] = NEG_INF_SCORE
    round = round + 1
  end
  idx = 0
  while idx < width
    if !keep[idx]
      cells[idx] = NEG_INF_SCORE
    end
    idx = idx + 1
  end
  logits
end

.top_p(logits, p) ⇒ Object

Top-p / nucleus: softmax → pick tokens in descending probability order → keep the smallest set whose cumulative probability mass ≥ p, masking the rest. p>=1 or p<=0 disables.



168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# File 'lib/toy/train/sampler.rb', line 168

# Nucleus filter, in place: keep the highest-probability tokens until their
# cumulative softmax mass reaches p, and overwrite every other logit with
# NEG_INF_SCORE. Returns the same `logits` object.
#
# p >= 1.0 or p <= 0.0 is a no-op. Tokens are chosen by repeated
# "largest not yet picked" scans (O(n^2) worst case, stopping as soon as the
# mass threshold is met); ties go to the lowest index.
def self.top_p(logits, p)
  if p >= 1.0
    return logits
  end
  if p <= 0.0
    return logits
  end
  width = logits.ncols
  cells = logits.flat
  # Max-shift so the largest exponent is exp(0) = 1.
  peak = NEG_INF_SCORE
  idx = 0
  while idx < width
    score = cells[idx]
    if score > peak
      peak = score
    end
    idx = idx + 1
  end
  # `probs` holds the softmax; `picked` flags tokens admitted to the nucleus.
  # Both are built with push/pop, matching the typed-empty-array idiom here.
  probs = [0.0]
  probs.pop
  picked = [false]
  picked.pop
  total = 0.0
  idx = 0
  while idx < width
    weight = Math.exp(cells[idx] - peak)
    probs.push(weight)
    picked.push(false)
    total = total + weight
    idx = idx + 1
  end
  norm = 1.0 / total
  idx = 0
  while idx < width
    probs[idx] = probs[idx] * norm
    idx = idx + 1
  end
  mass = 0.0
  round = 0
  while round < width
    top_i = -1
    top_v = -1.0
    idx = 0
    while idx < width
      if !picked[idx]
        if probs[idx] > top_v
          top_v = probs[idx]
          top_i = idx
        end
      end
      idx = idx + 1
    end
    if top_i < 0
      break
    end
    picked[top_i] = true
    mass = mass + top_v
    if mass >= p
      break
    end
    round = round + 1
  end
  # Everything not admitted to the nucleus is masked out.
  idx = 0
  while idx < width
    if !picked[idx]
      cells[idx] = NEG_INF_SCORE
    end
    idx = idx + 1
  end
  logits
end