Class: Toy::LLM::TrainingBatch

Inherits:
Object
Defined in:
lib/toy/llm/training_batch.rb

Constructor Details

#initialize(vocab, context, batch_size) ⇒ TrainingBatch

Returns a new instance of TrainingBatch.



# File 'lib/toy/llm/training_batch.rb', line 48

def initialize(vocab, context, batch_size)
  if vocab <= 0
    raise "TrainingBatch: vocab must be positive, got " + vocab.to_s
  end
  if context <= 0
    raise "TrainingBatch: context must be positive, got " + context.to_s
  end
  if batch_size != 1
    raise "TrainingBatch: batch_size " + batch_size.to_s +
          " unsupported — batched training deferred (labels are " +
          "context x vocab, batch is not multiplied into the row count)"
  end
  @vocab   = vocab
  @context = context
  @batch   = batch_size

  # The validated positions vector: 0..context-1, the only shape
  # every current training graph accepts (RoPE reads positions[k]).
  @positions = [0]; @positions.pop
  p = 0
  while p < context
    @positions.push(p)
    p = p + 1
  end

  # Typed-empty until fill! (type-pin Array[Int] — literal-seed +
  # pop, the codebase's standard pattern).
  @seq_ids = [0]; @seq_ids.pop

  # Zero labels Mat at the final shape until fill! rebuilds it; hp
  # is caller-owned (Mat(1,7), see Toy::AdamW#hp) — zero Mat(1,7)
  # placeholder so the member type is concrete from construction.
  @labels = Mat.new(context, vocab)
  @hp     = Mat.new(1, 7)
end
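
The constructor's guards and its positions vector can be sketched outside the class. This is a minimal illustration, not the real implementation: `check_training_batch_args` is a hypothetical helper, it uses ordinary Ruby idioms (string interpolation, ranges) instead of the explicit loops in the listing above, and it does not touch `Mat`.

```ruby
# Hypothetical stand-in that mirrors only the argument guards of
# TrainingBatch#initialize. It is not the real class.
def check_training_batch_args(vocab, context, batch_size)
  raise "vocab must be positive, got #{vocab}" if vocab <= 0
  raise "context must be positive, got #{context}" if context <= 0
  raise "batch_size #{batch_size} unsupported" if batch_size != 1

  # The positions vector is always 0..context-1.
  (0...context).to_a
end

positions = check_training_batch_args(32, 4, 1)

# Any batch size other than 1 is rejected up front.
rejected = nil
begin
  check_training_batch_args(32, 4, 2)
rescue RuntimeError => e
  rejected = e.message
end
```

With these arguments `positions` holds the four indices 0 through 3, and `rejected` holds the message raised for `batch_size` 2.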

Instance Attribute Details

#batch ⇒ Object

Returns the value of attribute batch.



# File 'lib/toy/llm/training_batch.rb', line 45

def batch
  @batch
end

#context ⇒ Object

Returns the value of attribute context.



# File 'lib/toy/llm/training_batch.rb', line 45

def context
  @context
end

#hp ⇒ Object

Returns the value of attribute hp.



# File 'lib/toy/llm/training_batch.rb', line 45

def hp
  @hp
end

#labels ⇒ Object

Returns the value of attribute labels.



# File 'lib/toy/llm/training_batch.rb', line 45

def labels
  @labels
end

#positions ⇒ Object

Returns the value of attribute positions.



# File 'lib/toy/llm/training_batch.rb', line 45

def positions
  @positions
end

#seq_ids ⇒ Object

Returns the value of attribute seq_ids.



# File 'lib/toy/llm/training_batch.rb', line 45

def seq_ids
  @seq_ids
end

#vocab ⇒ Object

Returns the value of attribute vocab.



# File 'lib/toy/llm/training_batch.rb', line 45

def vocab
  @vocab
end

Instance Method Details

#fill!(new_seq_ids) ⇒ Object

Validates the sequence and rebuilds the labels Mat. Fails loudly (raises) on a length mismatch or an out-of-vocab id; without this guard, the one-hot scatter would silently write outside the row. Labels are rebuilt via Toy::Labels.next_token, which is byte-identical to the hand-built shift-by-one one-hot it replaces. Returns nil.
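
Why an unguarded scatter is dangerous can be seen with a plain array. The sketch below assumes the labels are stored in one flat row-major buffer, which this page does not confirm for `Mat`; the variable names are illustrative only.

```ruby
# ASSUMPTION: labels live in one flat row-major buffer, so the cell
# (row, col) is stored at index row * vocab + col.
vocab   = 3
context = 2
flat    = Array.new(context * vocab, 0.0)

# Unguarded scatter for row 0 with an out-of-vocab id (3 is not in 0...3).
bad_id = 3
flat[0 * vocab + bad_id] = 1.0

row0 = flat[0, vocab]      # the row that was supposed to receive the 1.0
row1 = flat[vocab, vocab]  # the neighbouring row
```

Under that assumption `row0` stays all-zero and the 1.0 lands in the first cell of `row1`, with no error raised. The id range check in `fill!` turns this kind of silent corruption into an exception.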



# File 'lib/toy/llm/training_batch.rb', line 89

def fill!(new_seq_ids)
  if new_seq_ids.length != @context
    raise "TrainingBatch#fill!: seq_ids length " +
          new_seq_ids.length.to_s + " != context " + @context.to_s
  end
  k = 0
  while k < @context
    id = new_seq_ids[k]
    if id < 0 || id >= @vocab
      raise "TrainingBatch#fill!: seq_ids[" + k.to_s + "] = " +
            id.to_s + " out of vocab 0..." + @vocab.to_s
    end
    k = k + 1
  end
  @seq_ids = new_seq_ids
  @labels  = Toy::Labels.next_token(new_seq_ids, @vocab, @context, @batch)
  nil
end
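
For readers unfamiliar with shift-by-one labels, here is a rough sketch of the idea using nested arrays. `next_token_one_hot` is a hypothetical helper, not `Toy::Labels.next_token`, and the treatment of the last row (left all-zero because no token follows it) is an assumption that this page does not confirm.

```ruby
# Hypothetical helper illustrating shift-by-one one-hot labels.
# Rows are positions, columns are vocab ids.
# ASSUMPTION: the last row stays all-zero because it has no next token.
def next_token_one_hot(seq_ids, vocab)
  context = seq_ids.length
  labels = Array.new(context) { Array.new(vocab, 0.0) }
  (0...context - 1).each do |k|
    # Position k is trained to predict the token at position k + 1.
    labels[k][seq_ids[k + 1]] = 1.0
  end
  labels
end

labels = next_token_one_hot([2, 0, 3], 4)
```

For the sequence `[2, 0, 3]`, row 0 marks id 0 and row 1 marks id 3, so each position's target is the token that follows it.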

#fill_fixed_target!(new_seq_ids, target_id) ⇒ Object

Fixed-target objective (toy#73 item 3): validates the sequence the same way fill! does, but builds labels where every position targets `target_id` (the lora-smoke objective: push the whole prompt toward one token). Same loud guarantees: a length mismatch, an out-of-vocab id, and an out-of-vocab target all raise (the target is checked by Toy::Labels.fixed_target). Returns nil.



# File 'lib/toy/llm/training_batch.rb', line 114

def fill_fixed_target!(new_seq_ids, target_id)
  if new_seq_ids.length != @context
    raise "TrainingBatch#fill_fixed_target!: seq_ids length " +
          new_seq_ids.length.to_s + " != context " + @context.to_s
  end
  k = 0
  while k < @context
    id = new_seq_ids[k]
    if id < 0 || id >= @vocab
      raise "TrainingBatch#fill_fixed_target!: seq_ids[" + k.to_s +
            "] = " + id.to_s + " out of vocab 0..." + @vocab.to_s
    end
    k = k + 1
  end
  @seq_ids = new_seq_ids
  @labels  = Toy::Labels.fixed_target(@vocab, @context, target_id)
  nil
end
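
The fixed-target labels can be pictured as the same one-hot row repeated for every position. The sketch below uses nested arrays and a hypothetical helper, `fixed_target_one_hot`; it is not `Toy::Labels.fixed_target`, and its error message is illustrative.

```ruby
# Hypothetical helper: every position gets a one-hot row at target_id.
def fixed_target_one_hot(vocab, context, target_id)
  if target_id < 0 || target_id >= vocab
    raise "target_id #{target_id} out of vocab 0...#{vocab}"
  end

  # Block form of Array.new so each row is a distinct array.
  Array.new(context) do
    row = Array.new(vocab, 0.0)
    row[target_id] = 1.0
    row
  end
end

fixed = fixed_target_one_hot(4, 3, 1)
```

Unlike the next-token labels, the rows here do not depend on the input sequence at all; the sequence is validated only so that the forward pass sees in-vocab ids.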