Class: Trainers::DefaultDataCollator

Inherits:
Object
Defined in:
lib/trainers/data/data_collator.rb

Instance Method Summary

Instance Method Details

#call(features) ⇒ Object



108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# File 'lib/trainers/data/data_collator.rb', line 108

# Collates a list of per-example feature hashes into a single batch hash.
#
# The keys are taken from the first feature hash. For each key, the values
# from every feature are gathered and converted according to their type:
# - Torch::Tensor values are stacked with Torch.stack.
# - Integer/Float values become one tensor. The dtype is :float32 if ANY
#   value is a Float, otherwise :int64.
# - Array values are passed to Torch.tensor with no explicit dtype.
# - Anything else (e.g. strings, nil) is kept as a plain Array.
#
# @param features [Array<Hash>] one hash per example; assumed to share the
#   first hash's keys (a missing key yields nil for that example)
# @return [Hash] key => batched value; {} when features is empty
def call(features)
  return {} if features.empty?

  features.first.keys.each_with_object({}) do |key, batch|
    values = features.map { |f| f[key] }

    batch[key] =
      case values.first
      when Torch::Tensor
        Torch.stack(values)
      when Integer, Float
        # Decide the dtype from the whole column, not just the first element.
        # Otherwise [1, 0.5] would be cast to int64 and silently truncated.
        dtype = values.any?(Float) ? :float32 : :int64
        Torch.tensor(values, dtype: dtype)
      when Array
        Torch.tensor(values)
      else
        values
      end
  end
end