Module: Trainers
Defined in:
- lib/trainers.rb
- lib/trainers/trainer.rb
- lib/trainers/version.rb
- lib/trainers/callbacks.rb
- lib/trainers/save_utils.rb
- lib/trainers/data/dataset.rb
- lib/trainers/trainer_utils.rb
- lib/trainers/lora/lora_model.rb
- lib/trainers/lora/lora_utils.rb
- lib/trainers/lora/lora_config.rb
- lib/trainers/lora/lora_linear.rb
- lib/trainers/data/data_collator.rb
- lib/trainers/training_arguments.rb
- lib/trainers/optimization/optimizer.rb
- lib/trainers/optimization/scheduler.rb
Defined Under Namespace
Modules: EvalStrategy, LoraUtils, Optimization, SaveStrategy, SaveUtils, SchedulerType
Classes: CallbackHandler, DataCollatorWithPadding, Dataset, DefaultDataCollator, EarlyStoppingCallback, EvalPrediction, LoraConfig, LoraLinear, LoraModel, PrinterCallback, Trainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments
Constant Summary
- VERSION = "0.1.0"
Class Method Summary
- .from_pretrained(model_name, task: :sequence_classification, num_labels: 2) ⇒ Object
  Convenience method: loads a pretrained model and its tokenizer for use in training.
Class Method Details
.from_pretrained(model_name, task: :sequence_classification, num_labels: 2) ⇒ Object
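Convenience method: loads the pretrained model named by model_name with the head selected by task, plus the matching tokenizer. Supported tasks are :sequence_classification, :token_classification and :question_answering; any other value loads the base Transformers::AutoModel. num_labels is forwarded to the model's from_pretrained call for every task. Returns a two-element array, [model, tokenizer].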
```ruby
# File 'lib/trainers.rb', line 24
def self.from_pretrained(model_name, task: :sequence_classification, num_labels: 2)
  # Load transformers-rb only when this method is called.
  require "transformers-rb"

  # Pick the model head that matches the task; unknown tasks get the base model.
  model_class = case task
                when :sequence_classification
                  Transformers::AutoModelForSequenceClassification
                when :token_classification
                  Transformers::AutoModelForTokenClassification
                when :question_answering
                  Transformers::AutoModelForQuestionAnswering
                else
                  Transformers::AutoModel
                end

  model = model_class.from_pretrained(model_name, num_labels: num_labels)
  tokenizer = Transformers::AutoTokenizer.from_pretrained(model_name)
  [model, tokenizer]
end
```
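The task dispatch above can be sketched as a Hash lookup with a default, which makes the fallback to the base model explicit. This is a minimal illustration only: TASK_MODEL_CLASSES and model_class_name_for are hypothetical names, not part of the Trainers API, and the sketch returns constant names as strings so it runs without loading transformers-rb. The :summarization task is a hypothetical unlisted value used to show the fallback.

```ruby
# Hypothetical mapping mirroring the `case task` branches in
# Trainers.from_pretrained; not part of the gem's API.
TASK_MODEL_CLASSES = {
  sequence_classification: "Transformers::AutoModelForSequenceClassification",
  token_classification: "Transformers::AutoModelForTokenClassification",
  question_answering: "Transformers::AutoModelForQuestionAnswering"
}.freeze

# Returns the constant name for a task; any unlisted task falls back to the
# base model, matching the `else` branch of the original method.
def model_class_name_for(task)
  TASK_MODEL_CLASSES.fetch(task, "Transformers::AutoModel")
end

puts model_class_name_for(:token_classification)
# prints Transformers::AutoModelForTokenClassification
puts model_class_name_for(:summarization)
# prints Transformers::AutoModel
```

With transformers-rb loaded, Object.const_get can resolve such a name string to the actual class. The case statement in the original avoids that indirection, while a Hash keeps the supported tasks in one inspectable place.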