Class: Ignis::AI::DataLoader
- Inherits: Object
  - Object
    - Ignis::AI::DataLoader
- Defined in: lib/nnw/ai/trainer.rb
Overview
DataLoader: batching, shuffling, and copying of training token batches to GPU memory.
Instance Method Summary
- #initialize(data, batch_size:, seq_len:, device_id: 0, shuffle: true) ⇒ DataLoader (constructor)
  Returns a new instance of DataLoader.
- #next_batch ⇒ Hash{Symbol => Tensor}
  Get next training batch.
- #num_batches ⇒ Integer
  Number of batches per epoch.
Constructor Details
#initialize(data, batch_size:, seq_len:, device_id: 0, shuffle: true) ⇒ DataLoader
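Returns a new instance of DataLoader. The constructor flattens data into a single token stream, stores the batch size, sequence length, and target device, starts reading at position 0, and calls reshuffle! once when shuffle is true.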
# File 'lib/nnw/ai/trainer.rb', line 181

def initialize(data, batch_size:, seq_len:, device_id: 0, shuffle: true)
  @data = data.flatten
  @batch_size = batch_size
  @seq_len = seq_len
  @device_id = device_id
  @shuffle = shuffle
  @position = 0

  # Shuffle on init
  reshuffle! if @shuffle
end
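Because the constructor calls data.flatten, a nested corpus (one token array per document, which is an assumed input shape) becomes a single stream with no separator between documents, so a seq_len window can straddle a document boundary. The plain-Ruby snippet below mirrors only that one line of the constructor; the token IDs are made up, and whether straddling matters in practice depends on how reshuffle! (not shown on this page) reorders the data.

```ruby
# Hypothetical corpus: each inner array is one tokenized document.
docs = [[101, 7, 8, 102], [101, 9, 102], [101, 4, 5, 6, 102]]

# Mirrors `@data = data.flatten` in DataLoader#initialize: documents are
# concatenated into one token stream with no separator inserted.
stream = docs.flatten
p stream
# => [101, 7, 8, 102, 101, 9, 102, 101, 4, 5, 6, 102]

# With seq_len = 5, the first window already spans two documents.
seq_len = 5
p stream[0, seq_len]
# => [101, 7, 8, 102, 101]
```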
Instance Method Details
#next_batch ⇒ Hash{Symbol => Tensor}
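Get next training batch. Returns input_ids with shape [batch_size, seq_len] and targets flattened to [batch_size * seq_len], where each target is the token that follows the corresponding input. When fewer than batch_size * seq_len + 1 tokens remain, the read position wraps to the start of the data (calling reshuffle! if shuffle is enabled). Both tensors are int32 on device_id with requires_grad: false.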
# File 'lib/nnw/ai/trainer.rb', line 195

def next_batch
  total_tokens = @batch_size * @seq_len

  # Wrap around when the stream cannot supply total_tokens inputs plus the
  # one extra token needed for the last target.
  if @position + total_tokens + 1 > @data.length
    @position = 0
    reshuffle! if @shuffle
  end

  input_ids = []
  targets = []

  # Each row is a contiguous window; targets are the inputs shifted right by one.
  @batch_size.times do |b|
    start = @position + b * @seq_len
    input_ids.concat(@data[start, @seq_len])
    targets.concat(@data[start + 1, @seq_len])
  end

  @position += total_tokens

  # Copy the host arrays into int32 device buffers on @device_id.
  input_nv = Ignis::Shared::NvArray.new(shape: [@batch_size, @seq_len], dtype: :int32, device_id: @device_id)
  input_nv.from_host(input_ids)

  target_nv = Ignis::Shared::NvArray.new(shape: [@batch_size * @seq_len], dtype: :int32, device_id: @device_id)
  target_nv.from_host(targets)

  {
    input_ids: Tensor.new(data: input_nv, requires_grad: false),
    targets: Tensor.new(data: target_nv, requires_grad: false)
  }
end
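The windowing in next_batch can be checked without a GPU. The sketch below is a hypothetical HostBatcher class that copies the index arithmetic but keeps plain Ruby arrays in place of NvArray and Tensor and leaves out reshuffle!, so it shows only how inputs and shifted targets are sliced from the token stream.

```ruby
# Hypothetical host-only mirror of DataLoader#next_batch.
# Plain arrays stand in for NvArray/Tensor; shuffling is omitted.
class HostBatcher
  attr_reader :position

  def initialize(data, batch_size:, seq_len:)
    @data = data.flatten
    @batch_size = batch_size
    @seq_len = seq_len
    @position = 0
  end

  # Returns { input_ids: [[...], ...], targets: [...] }, where targets are
  # the inputs shifted one token to the right.
  def next_batch
    total_tokens = @batch_size * @seq_len
    # Same wrap-around test as the original: the last target sits one
    # token past the last input, hence the + 1.
    @position = 0 if @position + total_tokens + 1 > @data.length

    input_rows = []
    targets = []
    @batch_size.times do |b|
      start = @position + b * @seq_len
      input_rows << @data[start, @seq_len]
      targets.concat(@data[start + 1, @seq_len])
    end
    @position += total_tokens

    # input_ids keeps a [batch_size, seq_len] layout and targets stay flat,
    # matching the shapes the original allocates on the device.
    { input_ids: input_rows, targets: targets }
  end
end

loader = HostBatcher.new((0...10).to_a, batch_size: 2, seq_len: 3)
batch = loader.next_batch
p batch[:input_ids] # => [[0, 1, 2], [3, 4, 5]]
p batch[:targets]   # => [1, 2, 3, 4, 5, 6]
```

A second call on the same 10-token stream wraps back to position 0 and returns the same batch, because 6 + 6 + 1 > 10. Like the original, the sketch does not guard against a stream shorter than batch_size * seq_len + 1 tokens.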
#num_batches ⇒ Integer
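Number of batches per epoch, computed as (data.length - 1) / (batch_size * seq_len) with integer division. One token is subtracted because the last target lies one position past the last input.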
# File 'lib/nnw/ai/trainer.rb', line 230

def num_batches
  (@data.length - 1) / (@batch_size * @seq_len)
end
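The formula in num_batches agrees with the wrap-around test in next_batch: with T = batch_size * seq_len, the batch starting at k * T is served only if (k + 1) * T + 1 <= data.length, so exactly (data.length - 1) / T batches fit before the reset. The standalone functions below are hypothetical helpers, not part of the class; they replay both rules on bare stream lengths so the two counts can be compared.

```ruby
# Mirrors DataLoader#num_batches for a stream of `length` tokens.
def num_batches(length, batch_size, seq_len)
  (length - 1) / (batch_size * seq_len)
end

# Counts how many batches next_batch serves from position 0 before its
# wrap-around condition fires (shuffling ignored).
def batches_before_wrap(length, batch_size, seq_len)
  total = batch_size * seq_len
  position = 0
  count = 0
  until position + total + 1 > length
    count += 1
    position += total
  end
  count
end

p num_batches(1001, 4, 25)         # => 10
p batches_before_wrap(1001, 4, 25) # => 10
p num_batches(1000, 4, 25)         # => 9
```

For an empty stream the formula returns -1 rather than 0, because Ruby's Integer#/ floors (0 - 1) / T to -1.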