Class: NextTokenModel

Inherits:
CausalModel show all
Defined in:
lib/scout/model/python/huggingface/causal/next_token.rb

Instance Attribute Summary

Attributes inherited from TorchModel

#criterion, #device, #dtype, #optimizer

Attributes inherited from ScoutModel

#directory, #options, #state

Instance Method Summary collapse

Methods inherited from CausalModel

#chat

Methods inherited from HuggingfaceModel

#fix_options

Methods inherited from TorchModel

criterion, device, dtype, feature_dataset, feature_tsv, #fix_options, freeze, freeze_layer, #freeze_layer, get_layer, #get_layer, get_weights, #get_weights, init_python, load, load_architecture, load_state, model_architecture, optimizer, #reset_state, save, save_architecture, save_state, tensor, text_dataset

Methods inherited from ScoutModel

#add, #add_list, #eval, #eval_list, #execute, #extract_features, #extract_features_list, #init, #load_method, #load_options, #load_ruby_code, #load_state, #post_process, #post_process_list, #restore, #save, #save_method, #save_options, #save_state, #state_file, #train

Constructor Details

#initialize ⇒ NextTokenModel

Returns a new instance of NextTokenModel.



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# File 'lib/scout/model/python/huggingface/causal/next_token.rb', line 4

# Builds the instance through the parent constructor and hands +train+ the
# block that performs next-token training.
#
# All arguments are forwarded unchanged to the superclass initializer.
#
# The block passed to +train+ receives the training texts and:
# 1. unpacks +@state+ into the model and the tokenizer,
# 2. picks an output directory: the +output+ entry of the model directory
#    when a directory is set, otherwise a path from +TmpFile.tmp_file+,
# 3. builds a dataset by calling the Python function +list_dataset+ in
#    +scout_ai.huggingface.data+ with the tokenizer and the texts,
# 4. calls the Python function +train_next_token+ in
#    +scout_ai.huggingface.train.next_token+, forwarding
#    +options[:training_args]+ as additional keyword arguments.
def initialize(...)
  super(...)

  train do |texts|
    model, tokenizer = @state

    output_dir =
      if directory
        directory['output'].find
      else
        TmpFile.tmp_file "next_token_model"
      end

    dataset = ScoutPython.call_method(
      "scout_ai.huggingface.data", :list_dataset, tokenizer, texts
    )

    # Splatting nil with ** raises TypeError on Ruby versions before 3.4, so
    # fall back to an empty hash when no training arguments are configured.
    training_args = options[:training_args] || {}

    ScoutPython.call_method(
      "scout_ai.huggingface.train.next_token", :train_next_token,
      model: model, tokenizer: tokenizer, dataset: dataset, output_dir: output_dir, **training_args
    )
  end
end