Class: Toy::LLM::TrainingBatch
- Inherits: Object
- Defined in: lib/toy/llm/training_batch.rb
Instance Attribute Summary
- #batch ⇒ Object
  Returns the value of attribute batch.
- #context ⇒ Object
  Returns the value of attribute context.
- #hp ⇒ Object
  Returns the value of attribute hp.
- #labels ⇒ Object
  Returns the value of attribute labels.
- #positions ⇒ Object
  Returns the value of attribute positions.
- #seq_ids ⇒ Object
  Returns the value of attribute seq_ids.
- #vocab ⇒ Object
  Returns the value of attribute vocab.
Instance Method Summary
- #fill!(new_seq_ids) ⇒ Object
  Validate the sequence and rebuild the labels Mat.
- #fill_fixed_target!(new_seq_ids, target_id) ⇒ Object
  FIXED-TARGET objective (toy#73 item 3): validate the sequence the same way fill! does, but build labels where EVERY position targets `target_id` (the lora-smoke objective: push the whole prompt toward one token).
- #initialize(vocab, context, batch_size) ⇒ TrainingBatch (constructor)
  Returns a new instance of TrainingBatch.
Constructor Details
#initialize(vocab, context, batch_size) ⇒ TrainingBatch
Returns a new instance of TrainingBatch. Raises unless vocab and context are positive and batch_size is 1 (batched training is deferred).
# File 'lib/toy/llm/training_batch.rb', line 48

def initialize(vocab, context, batch_size)
  if vocab <= 0
    raise "TrainingBatch: vocab must be positive, got " + vocab.to_s
  end
  if context <= 0
    raise "TrainingBatch: context must be positive, got " + context.to_s
  end
  if batch_size != 1
    raise "TrainingBatch: batch_size " + batch_size.to_s +
      " unsupported — batched training deferred (labels are " +
      "context x vocab, batch is not multiplied into the row count)"
  end
  @vocab = vocab
  @context = context
  @batch = batch_size
  # The validated positions vector: 0..context-1, the only shape
  # every current training graph accepts (RoPE reads positions[k]).
  @positions = [0]; @positions.pop
  p = 0
  while p < context
    @positions.push(p)
    p = p + 1
  end
  # Typed-empty until fill! (type-pin Array[Int] — literal-seed +
  # pop, the codebase's standard pattern).
  @seq_ids = [0]; @seq_ids.pop
  # Zero labels Mat at the final shape until fill! rebuilds it; hp
  # is caller-owned (Mat(1,7), see Toy::AdamW#hp) — zero Mat(1,7)
  # placeholder so the member type is concrete from construction.
  @labels = Mat.new(context, vocab)
  @hp = Mat.new(1, 7)
end
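In plain Ruby, the constructor's three guards and its positions loop reduce to the sketch below. It is a hypothetical, self-contained illustration: the helper name `check_training_batch_args` is invented here and is not part of Toy. The codebase writes the loop as literal-seed + pop followed by `push` for its type-pinning convention; in ordinary Ruby, `(0...context).to_a` yields the same values.

```ruby
# Hypothetical helper (not part of Toy): the same three guards as
# TrainingBatch#initialize, returning the positions vector it builds.
def check_training_batch_args(vocab, context, batch_size)
  raise "vocab must be positive, got #{vocab}" if vocab <= 0
  raise "context must be positive, got #{context}" if context <= 0
  raise "batch_size #{batch_size} unsupported" if batch_size != 1
  # Same values the while-loop pushes: 0, 1, ..., context - 1.
  (0...context).to_a
end

p check_training_batch_args(8, 4, 1) # prints [0, 1, 2, 3]
```

Passing a batch_size of 2 raises, matching the class's current restriction to a batch size of 1.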
Instance Attribute Details
#batch ⇒ Object
Returns the value of attribute batch.
# File 'lib/toy/llm/training_batch.rb', line 45

def batch
  @batch
end
#context ⇒ Object
Returns the value of attribute context.
# File 'lib/toy/llm/training_batch.rb', line 45

def context
  @context
end
#hp ⇒ Object
Returns the value of attribute hp.
# File 'lib/toy/llm/training_batch.rb', line 45

def hp
  @hp
end
#labels ⇒ Object
Returns the value of attribute labels.
# File 'lib/toy/llm/training_batch.rb', line 45

def labels
  @labels
end
#positions ⇒ Object
Returns the value of attribute positions.
# File 'lib/toy/llm/training_batch.rb', line 45

def positions
  @positions
end
#seq_ids ⇒ Object
Returns the value of attribute seq_ids.
# File 'lib/toy/llm/training_batch.rb', line 45

def seq_ids
  @seq_ids
end
#vocab ⇒ Object
Returns the value of attribute vocab.
# File 'lib/toy/llm/training_batch.rb', line 45

def vocab
  @vocab
end
Instance Method Details
#fill!(new_seq_ids) ⇒ Object
Validate the sequence and rebuild the labels Mat. Fails LOUD on a length mismatch or an out-of-vocab id (otherwise the unguarded one-hot scatter would silently write outside the row). On success, stores new_seq_ids in @seq_ids and rebuilds labels via Toy::Labels.next_token, which is byte-identical to the hand-built shift-by-one one-hot it replaces. Returns nil.
# File 'lib/toy/llm/training_batch.rb', line 89

def fill!(new_seq_ids)
  if new_seq_ids.length != @context
    raise "TrainingBatch#fill!: seq_ids length " + new_seq_ids.length.to_s +
      " != context " + @context.to_s
  end
  k = 0
  while k < @context
    id = new_seq_ids[k]
    if id < 0 || id >= @vocab
      raise "TrainingBatch#fill!: seq_ids[" + k.to_s + "] = " + id.to_s +
        " out of vocab 0..." + @vocab.to_s
    end
    k = k + 1
  end
  @seq_ids = new_seq_ids
  @labels = Toy::Labels.next_token(new_seq_ids, @vocab, @context, @batch)
  nil
end
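The contract of fill! (validate loudly, then build shift-by-one one-hot labels) can be illustrated with a self-contained sketch. Nested Arrays stand in for `Mat`, and `next_token_labels` is a hypothetical name. It assumes row k targets the token at position k + 1 and that the final row, which has no next token, is left all-zero; this page does not say how `Toy::Labels.next_token` handles that row or what it does with the `batch` argument.

```ruby
# Hypothetical sketch of fill!'s contract; a nested Array stands in for
# Mat(context, vocab). ASSUMPTION: the last row (no next token) stays zero.
def next_token_labels(seq_ids, vocab, context)
  if seq_ids.length != context
    raise "seq_ids length #{seq_ids.length} != context #{context}"
  end
  # Fail loud before any scatter, as fill! does.
  seq_ids.each_with_index do |id, k|
    if id < 0 || id >= vocab
      raise "seq_ids[#{k}] = #{id} out of vocab 0...#{vocab}"
    end
  end
  labels = Array.new(context) { Array.new(vocab, 0.0) }
  (0...context - 1).each do |k|
    labels[k][seq_ids[k + 1]] = 1.0 # row k targets the token at k + 1
  end
  labels
end

next_token_labels([2, 0, 3], 4, 3).each { |row| p row }
# [1.0, 0.0, 0.0, 0.0]
# [0.0, 0.0, 0.0, 1.0]
# [0.0, 0.0, 0.0, 0.0]
```

Validating every id before writing anything is what prevents the silent out-of-row write: an id of 4 with a vocab of 4 would otherwise extend the Ruby row array rather than fail.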
#fill_fixed_target!(new_seq_ids, target_id) ⇒ Object
FIXED-TARGET objective (toy#73 item 3): validate the sequence the same way fill! does, but build labels where EVERY position targets `target_id` (the lora-smoke objective: push the whole prompt toward one token). Same loud guarantees: a length mismatch, an out-of-vocab id, and an out-of-vocab target all raise (the target check happens in Toy::Labels.fixed_target). Returns nil.
# File 'lib/toy/llm/training_batch.rb', line 114

def fill_fixed_target!(new_seq_ids, target_id)
  if new_seq_ids.length != @context
    raise "TrainingBatch#fill_fixed_target!: seq_ids length " + new_seq_ids.length.to_s +
      " != context " + @context.to_s
  end
  k = 0
  while k < @context
    id = new_seq_ids[k]
    if id < 0 || id >= @vocab
      raise "TrainingBatch#fill_fixed_target!: seq_ids[" + k.to_s + "] = " + id.to_s +
        " out of vocab 0..." + @vocab.to_s
    end
    k = k + 1
  end
  @seq_ids = new_seq_ids
  @labels = Toy::Labels.fixed_target(@vocab, @context, target_id)
  nil
end
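Under the fixed-target objective every one of the `context` label rows is the same one-hot vector. The sketch below is a self-contained illustration with a hypothetical `fixed_target_labels` helper and nested Arrays in place of `Mat`. The target check mirrors the documented behavior of `Toy::Labels.fixed_target` (raise on an out-of-vocab target), but its message and internals here are assumptions. The seq_ids checks are the same as in fill! and are omitted.

```ruby
# Hypothetical sketch: every row is one-hot at target_id.
# A nested Array stands in for Mat(context, vocab).
def fixed_target_labels(vocab, context, target_id)
  if target_id < 0 || target_id >= vocab
    raise "target_id #{target_id} out of vocab 0...#{vocab}"
  end
  Array.new(context) do
    row = Array.new(vocab, 0.0)
    row[target_id] = 1.0
    row
  end
end

fixed_target_labels(4, 3, 2).each { |row| p row }
# prints [0.0, 0.0, 1.0, 0.0] three times
```

The label matrix depends only on `vocab`, `context`, and `target_id`: fill_fixed_target! still validates and stores `new_seq_ids`, but does not pass them to `Toy::Labels.fixed_target`.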