Class: Block

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/models/transformer.rb

Overview

Block: a transformer block's parameters.

norm1_gamma, norm2_gamma : Array of Float (length d_model)
w_q, w_k, w_v            : Array of Mat (one per head, each d_model × d_head)
w_o, w_ff1, w_ff2        : Mat

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(d_model, d_head, d_ff, n_heads) ⇒ Block

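Allocates every weight matrix (zero-initialized) and sets both norm gamma arrays to 1.0. Call #fill_random_all(scale) to initialize parameters. For gradients / Adam moments, call #fill_zero so that the gamma arrays are zeroed as well.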



294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# File 'lib/toy/models/transformer.rb', line 294

# Builds the parameter containers for one transformer block.
#
# d_model - length of both gamma arrays; also the first Mat.new argument for
#           w_q/w_k/w_v/w_ff1, both arguments for w_o, and the second for w_ff2.
# d_head  - second Mat.new argument for every per-head matrix.
# d_ff    - second Mat.new argument for w_ff1 and first for w_ff2.
# n_heads - number of per-head matrices in w_q, w_k and w_v. One matrix is
#           always seeded before the loop, so any value below 1 still
#           produces exactly one head.
def initialize(d_model, d_head, d_ff, n_heads)
  # Gamma arrays start at 1.0, not 0.0; fill_zero is what resets them to 0.0.
  @norm1_gamma = Array.new(d_model, 1.0)
  @norm2_gamma = Array.new(d_model, 1.0)

  # Per-head matrices: literal-seed pattern so Spinel types as PtrArray of Mat.
  # (Each array is created with its first Mat already inside it, then grown.)
  @w_q = [Mat.new(d_model, d_head)]
  @w_k = [Mat.new(d_model, d_head)]
  @w_v = [Mat.new(d_model, d_head)]
  # Start at 1 because head 0 was created by the seed literals above.
  h = 1
  while h < n_heads
    @w_q.push(Mat.new(d_model, d_head))
    @w_k.push(Mat.new(d_model, d_head))
    @w_v.push(Mat.new(d_model, d_head))
    h += 1
  end

  # Single (not per-head) matrices.
  @w_o   = Mat.new(d_model, d_model)
  @w_ff1 = Mat.new(d_model, d_ff)
  @w_ff2 = Mat.new(d_ff,    d_model)
end

Instance Attribute Details

#norm1_gamma ⇒ Object

Returns the value of attribute norm1_gamma.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @norm1_gamma, which the constructor creates as an Array of
# d_model Floats, each initially 1.0.
def norm1_gamma
  @norm1_gamma
end

#norm2_gamma ⇒ Object

Returns the value of attribute norm2_gamma.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @norm2_gamma, which the constructor creates as an Array of
# d_model Floats, each initially 1.0.
def norm2_gamma
  @norm2_gamma
end

#w_ff1 ⇒ Object

Returns the value of attribute w_ff1.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_ff1, the single Mat the constructor builds with
# Mat.new(d_model, d_ff).
def w_ff1
  @w_ff1
end

#w_ff2 ⇒ Object

Returns the value of attribute w_ff2.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_ff2, the single Mat the constructor builds with
# Mat.new(d_ff, d_model).
def w_ff2
  @w_ff2
end

#w_k ⇒ Object

Returns the value of attribute w_k.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_k, the Array of per-head Mat objects (each built with
# Mat.new(d_model, d_head)) created by the constructor.
def w_k
  @w_k
end

#w_o ⇒ Object

Returns the value of attribute w_o.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_o, the single Mat the constructor builds with
# Mat.new(d_model, d_model).
def w_o
  @w_o
end

#w_q ⇒ Object

Returns the value of attribute w_q.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_q, the Array of per-head Mat objects (each built with
# Mat.new(d_model, d_head)) created by the constructor.
def w_q
  @w_q
end

#w_v ⇒ Object

Returns the value of attribute w_v.



289
290
291
# File 'lib/toy/models/transformer.rb', line 289

# Reader for @w_v, the Array of per-head Mat objects (each built with
# Mat.new(d_model, d_head)) created by the constructor.
def w_v
  @w_v
end

Instance Method Details

#fill_random_all(scale) ⇒ Object



315
316
317
318
319
320
321
322
323
324
325
326
# File 'lib/toy/models/transformer.rb', line 315

# Randomizes every weight matrix of the block by calling fill_random on it.
#
# scale - passed through unchanged to each fill_random call.
#
# Call order is: for each head in turn its q, k and v matrix, then w_o,
# w_ff1 and w_ff2. The norm gamma arrays are not modified. The return value
# is whatever the final @w_ff2.fill_random call returns.
def fill_random_all(scale)
  # The number of heads is taken from w_q; w_k and w_v are indexed in step.
  head_count = @w_q.length
  head = 0
  while head < head_count
    # Keep q/k/v interleaved per head so the sequence of calls is unchanged.
    @w_q[head].fill_random(scale)
    @w_k[head].fill_random(scale)
    @w_v[head].fill_random(scale)
    head += 1
  end

  # Matrices that exist once per block rather than once per head.
  @w_o.fill_random(scale)
  @w_ff1.fill_random(scale)
  @w_ff2.fill_random(scale)
end

#fill_zero ⇒ Object



328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# File 'lib/toy/models/transformer.rb', line 328

# Resets the whole block to zero: both norm gamma arrays are overwritten with
# 0.0 element by element, and fill_zero is called on every weight matrix.
#
# Order: gamma arrays first, then q/k/v for each head in turn, then w_o,
# w_ff1 and w_ff2. The return value is whatever the final @w_ff2.fill_zero
# call returns.
def fill_zero
  # The length of norm1_gamma drives the loop; norm2_gamma is written at the
  # same indices.
  gamma_len = @norm1_gamma.length
  pos = 0
  while pos < gamma_len
    @norm1_gamma[pos] = 0.0
    @norm2_gamma[pos] = 0.0
    pos += 1
  end

  # Per-head matrices; the head count is taken from w_q.
  head_count = @w_q.length
  head = 0
  while head < head_count
    @w_q[head].fill_zero
    @w_k[head].fill_zero
    @w_v[head].fill_zero
    head += 1
  end

  # Matrices that exist once per block.
  @w_o.fill_zero
  @w_ff1.fill_zero
  @w_ff2.fill_zero
end