Class: Trainers::DataCollatorWithPadding

Inherits:
Object
  • Object
show all
Defined in:
lib/trainers/data/data_collator.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil) ⇒ DataCollatorWithPadding

Returns a new instance of DataCollatorWithPadding.



7
8
9
10
11
12
# File 'lib/trainers/data/data_collator.rb', line 7

# Builds a collator by storing the given configuration on the instance.
#
# Nothing is validated or transformed here; each keyword is assigned as-is
# to the instance variable of the same name.
#
# @param tokenizer [Object] required; stored in +@tokenizer+
# @param padding [Object] stored in +@padding+ (defaults to +true+)
# @param max_length [Object, nil] stored in +@max_length+ (defaults to +nil+)
# @param pad_to_multiple_of [Object, nil] stored in +@pad_to_multiple_of+
#   (defaults to +nil+)
# NOTE(review): how these settings are consumed is not visible in this
# method -- presumably by the padding helpers; confirm against them.
def initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil)
  @tokenizer = tokenizer
  @padding = padding
  @max_length = max_length
  @pad_to_multiple_of = pad_to_multiple_of
end

Instance Attribute Details

#max_length ⇒ Object (readonly)

Returns the value of attribute max_length.



5
6
7
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the +@max_length+ instance variable.
#
# @return [Object, nil] whatever is currently stored in +@max_length+; the
#   constructor assigns its +max_length:+ keyword there (default +nil+)
def max_length
  @max_length
end

#pad_to_multiple_of ⇒ Object (readonly)

Returns the value of attribute pad_to_multiple_of.



5
6
7
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the +@pad_to_multiple_of+ instance variable.
#
# @return [Object, nil] whatever is currently stored in
#   +@pad_to_multiple_of+; the constructor assigns its
#   +pad_to_multiple_of:+ keyword there (default +nil+)
def pad_to_multiple_of
  @pad_to_multiple_of
end

#padding ⇒ Object (readonly)

Returns the value of attribute padding.



5
6
7
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the +@padding+ instance variable.
#
# @return [Object] whatever is currently stored in +@padding+; the
#   constructor assigns its +padding:+ keyword there (default +true+)
def padding
  @padding
end

#tokenizer ⇒ Object (readonly)

Returns the value of attribute tokenizer.



5
6
7
# File 'lib/trainers/data/data_collator.rb', line 5

# Reader for the +@tokenizer+ instance variable.
#
# @return [Object] whatever is currently stored in +@tokenizer+; the
#   constructor assigns its required +tokenizer:+ keyword there
def tokenizer
  @tokenizer
end

Instance Method Details

#call(features) ⇒ Object



14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# File 'lib/trainers/data/data_collator.rb', line 14

# Collates a list of per-example feature hashes into one batch hash.
#
# The key set is taken from the first feature only. For each key, the values
# are gathered across all features with +feature[key]+, and the class of the
# FIRST value decides how the whole column is turned into a batch entry:
#
# * +Array+         -> +pad_and_stack(key, values)+
# * +Torch::Tensor+ -> +Torch.stack(values)+ when the first tensor reports
#   +dim == 0+, otherwise +pad_and_stack_tensors(key, values)+
# * +Integer+       -> +Torch.tensor(values, dtype: :int64)+
# * +Float+         -> +Torch.tensor(values, dtype: :float32)+
# * anything else   -> the plain Ruby array of values, unchanged
#
# NOTE(review): +pad_and_stack+ and +pad_and_stack_tensors+ are defined
# outside this method; presumably they pad to a common length before
# stacking -- confirm against their definitions.
#
# @param features [Array<Hash>] one entry per example; the first entry must
#   respond to +#keys+ and every entry to +#[]+
# @return [Hash] batch entries keyed like the first feature, in the same key
#   order; an empty hash when +features+ is empty
def call(features)
  return {} if features.empty?

  features.first.keys.each_with_object({}) do |name, collated|
    column = features.map { |feature| feature[name] }
    sample = column.first

    collated[name] =
      case sample
      when Array
        pad_and_stack(name, column)
      when Torch::Tensor
        # The first tensor's rank selects between direct stacking and the
        # padding helper.
        sample.dim == 0 ? Torch.stack(column) : pad_and_stack_tensors(name, column)
      when Integer
        Torch.tensor(column, dtype: :int64)
      when Float
        Torch.tensor(column, dtype: :float32)
      else
        column
      end
  end
end