108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
# File 'lib/trainers/data/data_collator.rb', line 108
# Collates a list of per-example feature hashes into one batch hash.
#
# The keys of the FIRST feature hash decide which keys are collated. For each
# key, the values are gathered from every feature (in input order) and
# converted according to the type of the first value:
# * Torch::Tensor       -> Torch.stack(values)
# * Integer / Float     -> 1-D Torch.tensor; dtype :int64 only when every value
#                          in the column is an Integer, otherwise :float32 so a
#                          mixed column such as [3, 2.5] is not forced to int64
# * Array               -> Torch.tensor(values) with no explicit dtype
# * anything else       -> the plain Array of values (e.g. strings)
#
# NOTE(review): a key missing from a later feature yields nil in its column;
# no validation is done here — confirm callers guarantee uniform keys.
#
# @param features [Array<Hash>] one hash per example
# @return [Hash] key => collated value; an empty Hash when +features+ is empty
def call(features)
  return {} if features.empty?

  features.first.keys.each_with_object({}) do |key, batch|
    values = features.map { |feature| feature[key] }
    batch[key] =
      case values.first
      when Torch::Tensor
        Torch.stack(values)
      when Integer, Float
        # Look at the whole column, not just the first element, so that an
        # Integer followed by Floats does not truncate the fractional values.
        dtype = values.all?(Integer) ? :int64 : :float32
        Torch.tensor(values, dtype: dtype)
      when Array
        Torch.tensor(values)
      else
        values
      end
  end
end
|