Class: CausalModel

Inherits:
HuggingfaceModel
Defined in:
lib/scout/model/python/huggingface/causal.rb

Direct Known Subclasses

NextTokenModel

Instance Attribute Summary

Attributes inherited from TorchModel

#criterion, #device, #dtype, #optimizer

Attributes inherited from ScoutModel

#directory, #options, #state

Instance Method Summary

Methods inherited from HuggingfaceModel

#fix_options

Methods inherited from TorchModel

criterion, device, dtype, feature_dataset, feature_tsv, #fix_options, freeze, freeze_layer, #freeze_layer, get_layer, #get_layer, get_weights, #get_weights, init_python, load, load_architecture, load_state, model_architecture, optimizer, #reset_state, save, save_architecture, save_state, tensor, text_dataset

Methods inherited from ScoutModel

#add, #add_list, #eval, #eval_list, #execute, #extract_features, #extract_features_list, #init, #load_method, #load_options, #load_ruby_code, #load_state, #post_process, #post_process_list, #restore, #save, #save_method, #save_options, #save_state, #state_file, #train

Constructor Details

#initialize(...) ⇒ CausalModel

Returns a new instance of CausalModel.



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# File 'lib/scout/model/python/huggingface/causal.rb', line 4

# Builds a causal language-model wrapper on top of HuggingfaceModel and
# configures its eval and train blocks.
#
# All arguments are forwarded unchanged to HuggingfaceModel#initialize with
# the string "CausalLM" prepended (presumably selecting the causal-LM model
# class on the Hugging Face side — NOTE(review): confirm in HuggingfaceModel).
def initialize(...)
  super("CausalLM", ...)

  # Eval: run the chat messages through the Python helper
  # scout_ai.huggingface.eval.eval_causal_lm_chat using the model/tokenizer
  # pair held in @state and the model-level options. The second block
  # argument is accepted for interface compatibility but not used.
  self.eval do |messages, _list|
    model, tokenizer = @state
    ScoutPython.call_method(
      "scout_ai.huggingface.eval", :eval_causal_lm_chat,
      model, tokenizer, messages,
      options[:chat_template],
      options[:chat_template_kwargs],
      options[:generation_kwargs],
      options[:tool_argument]
    )
  end

  # Train: delegate to scout_ai.huggingface.rlhf.train_rlhf. The Python side
  # receives the path in state_file rather than the in-memory model, and
  # load_state is called afterwards — presumably to pick up the weights the
  # trainer wrote there.
  # NOTE(review): pairs/labels are forwarded as-is; an earlier note suggested
  # [response, reward] or [prompt, response, reward] records — confirm the
  # expected format against train_rlhf.
  train do |pairs, labels|
    # Only the tokenizer is needed here; the model is not passed in memory.
    _model, tokenizer = @state

    ScoutPython.call_method(
      "scout_ai.huggingface.rlhf", :train_rlhf,
      self.state_file, tokenizer, pairs, labels, options[:rlhf_config]
    )
    load_state
  end
end

Instance Method Details

#chat(messages, tools = nil, runtime_options = {}) ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# File 'lib/scout/model/python/huggingface/causal.rb', line 31

# Runs one chat turn against the loaded causal LM via the Python helper
# scout_ai.huggingface.eval.eval_causal_lm_response, initializing the model
# first if no state is loaded yet.
#
# @param messages [Object] chat messages, forwarded as-is to the Python side.
# @param tools [Object, nil] tool definitions, forwarded as-is.
# @param runtime_options [Hash, nil] per-call overrides for :chat_template,
#   :chat_template_kwargs, :generation_kwargs, :tool_argument and
#   :response_parser. Any override that is nil or false falls back to the
#   model-level option of the same name. nil is treated as an empty hash,
#   and the caller's hash is never modified.
# @return [Object] whatever eval_causal_lm_response returns.
def chat(messages, tools = nil, runtime_options = {})
  init unless @state
  model, tokenizer = @state

  # Set up a copy: IndiferentHash.setup may extend the hash it is given in
  # place (NOTE(review): confirm), which would leak into the caller's hash
  # and fail outright on a frozen one.
  overrides = IndiferentHash.setup((runtime_options || {}).dup)
  resolve = ->(key) { overrides[key] || options[key] }

  ScoutPython.call_method(
    "scout_ai.huggingface.eval", :eval_causal_lm_response,
    model, tokenizer, messages, tools,
    resolve.call(:chat_template),
    resolve.call(:chat_template_kwargs),
    resolve.call(:generation_kwargs),
    resolve.call(:tool_argument),
    resolve.call(:response_parser)
  )
end