Class: Trainers::DataCollatorWithPadding
- Inherits: Object
- Inheritance chain: Object → Trainers::DataCollatorWithPadding
- Defined in:
- lib/trainers/data/data_collator.rb
Instance Attribute Summary collapse
-
#max_length ⇒ Object
readonly
Returns the value of attribute max_length.
-
#pad_to_multiple_of ⇒ Object
readonly
Returns the value of attribute pad_to_multiple_of.
-
#padding ⇒ Object
readonly
Returns the value of attribute padding.
-
#tokenizer ⇒ Object
readonly
Returns the value of attribute tokenizer.
Instance Method Summary collapse
- #call(features) ⇒ Object
-
#initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil) ⇒ DataCollatorWithPadding
constructor
A new instance of DataCollatorWithPadding.
Constructor Details
#initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil) ⇒ DataCollatorWithPadding
Returns a new instance of DataCollatorWithPadding.
7 8 9 10 11 12 |
# File 'lib/trainers/data/data_collator.rb', line 7

# Builds a collator and records the supplied options on the instance.
#
# Every argument is stored exactly as given; nothing is validated or
# transformed here.
#
# @param tokenizer [Object] stored in @tokenizer (required keyword)
# @param padding [Object] stored in @padding; defaults to true
# @param max_length [Object, nil] stored in @max_length; defaults to nil
# @param pad_to_multiple_of [Object, nil] stored in @pad_to_multiple_of;
#   defaults to nil
def initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil)
  @tokenizer = tokenizer
  @padding = padding
  @max_length = max_length
  @pad_to_multiple_of = pad_to_multiple_of
end
Instance Attribute Details
#max_length ⇒ Object (readonly)
Returns the value of attribute max_length.
5 6 7 |
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the max_length attribute.
#
# @return [Object, nil] the current value of @max_length
def max_length
  @max_length
end
#pad_to_multiple_of ⇒ Object (readonly)
Returns the value of attribute pad_to_multiple_of.
5 6 7 |
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the pad_to_multiple_of attribute.
#
# @return [Object, nil] the current value of @pad_to_multiple_of
def pad_to_multiple_of
  @pad_to_multiple_of
end
#padding ⇒ Object (readonly)
Returns the value of attribute padding.
5 6 7 |
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the padding attribute.
#
# @return [Object] the current value of @padding
def padding
  @padding
end
#tokenizer ⇒ Object (readonly)
Returns the value of attribute tokenizer.
5 6 7 |
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the tokenizer attribute.
#
# @return [Object] the current value of @tokenizer
def tokenizer
  @tokenizer
end
Instance Method Details
#call(features) ⇒ Object
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# File 'lib/trainers/data/data_collator.rb', line 14

# Collates a list of feature hashes into a single batch hash.
#
# The key set is taken from the first feature only. For each key the values
# are gathered across all features, and the class of the FIRST gathered
# value alone selects how that column is combined:
#
# * Array         -> pad_and_stack(key, values)
# * Torch::Tensor -> Torch.stack(values) when the first tensor's dim is 0,
#                    otherwise pad_and_stack_tensors(key, values)
# * Integer       -> Torch.tensor(values, dtype: :int64)
# * Float         -> Torch.tensor(values, dtype: :float32)
# * anything else -> the gathered values, unchanged
#
# NOTE(review): later values in a column are not inspected, so a column whose
# first entry differs in type from the rest is handled as if it were
# homogeneous — confirm callers guarantee uniform columns.
#
# @param features [Array<Hash>] one hash per example
# @return [Hash] one entry per key of the first feature; {} when features
#   is empty
def call(features)
  return {} if features.empty?

  features.first.keys.each_with_object({}) do |name, collated|
    column = features.map { |example| example[name] }
    head = column.first

    collated[name] =
      case head
      when Array
        pad_and_stack(name, column)
      when Torch::Tensor
        # Zero-dimensional tensors can be stacked directly.
        head.dim.zero? ? Torch.stack(column) : pad_and_stack_tensors(name, column)
      when Integer
        Torch.tensor(column, dtype: :int64)
      when Float
        Torch.tensor(column, dtype: :float32)
      else
        column
      end
  end
end