Module: Trainers

Defined in:
lib/trainers.rb,
lib/trainers/trainer.rb,
lib/trainers/version.rb,
lib/trainers/callbacks.rb,
lib/trainers/save_utils.rb,
lib/trainers/data/dataset.rb,
lib/trainers/trainer_utils.rb,
lib/trainers/lora/lora_model.rb,
lib/trainers/lora/lora_utils.rb,
lib/trainers/lora/lora_config.rb,
lib/trainers/lora/lora_linear.rb,
lib/trainers/data/data_collator.rb,
lib/trainers/training_arguments.rb,
lib/trainers/optimization/optimizer.rb,
lib/trainers/optimization/scheduler.rb

Defined Under Namespace

Modules: EvalStrategy, LoraUtils, Optimization, SaveStrategy, SaveUtils, SchedulerType
Classes: CallbackHandler, DataCollatorWithPadding, Dataset, DefaultDataCollator, EarlyStoppingCallback, EvalPrediction, LoraConfig, LoraLinear, LoraModel, PrinterCallback, Trainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments

Constant Summary

VERSION = "0.1.0"

Class Method Summary

Class Method Details

.from_pretrained(model_name, task: :sequence_classification, num_labels: 2) ⇒ Array

Convenience method: loads a pretrained model (with a head chosen by `task`) and its tokenizer for use in training, and returns them as `[model, tokenizer]`.



Source: lib/trainers.rb, lines 24–42
# File 'lib/trainers.rb', line 24

# Loads a pretrained model together with its tokenizer.
#
# The model class is picked from +task+; any task other than the three
# recognised symbols falls back to the bare +Transformers::AutoModel+.
#
# @param model_name [String] identifier passed to both +from_pretrained+ calls
# @param task [Symbol] :sequence_classification, :token_classification,
#   :question_answering, or anything else for the base model
# @param num_labels [Integer] forwarded to the model's +from_pretrained+
#   (for every task, including the fallback)
# @return [Array] a two-element array: [model, tokenizer]
def self.from_pretrained(model_name, task: :sequence_classification, num_labels: 2)
  # Deferred so the dependency is only loaded when this helper is actually used.
  require "transformers-rb"

  # Only the matching branch's constant is referenced.
  head_class =
    case task
    when :sequence_classification then Transformers::AutoModelForSequenceClassification
    when :token_classification    then Transformers::AutoModelForTokenClassification
    when :question_answering      then Transformers::AutoModelForQuestionAnswering
    else Transformers::AutoModel
    end

  # Array elements evaluate left to right: the model loads before the tokenizer.
  [
    head_class.from_pretrained(model_name, num_labels: num_labels),
    Transformers::AutoTokenizer.from_pretrained(model_name)
  ]
end