Class: Trainers::LoraModel

Inherits: Object
Defined in:
lib/trainers/lora/lora_model.rb

Class Method Summary

Class Method Details

.apply(model, config) ⇒ Object

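Apply LoRA adapters to a model, modifying it in place.

config may be a LoraConfig or a Hash of LoraConfig keyword arguments.

  1. Finds all Linear layers whose names match config.target_modules. If none match, it raises an error that lists the available module names.

  2. Replaces each matching layer with a LoraLinear that wraps it and freezes the original weight.

  3. Freezes all base model parameters, then re-enables gradients for biases as selected by config.bias:
     - :all — every parameter whose name contains "bias";
     - :lora_only — the biases held by the LoraLinear layers;
     - :none — biases stay frozen.

  4. As a result, only the LoRA A/B matrices (plus any biases selected above) remain trainable.

Returns the modified model (the same object that was passed in).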



(source lines 13–57)
# File 'lib/trainers/lora/lora_model.rb', line 13

# Apply LoRA adapters to +model+ in place.
#
# Steps:
#   1. Resolve the Linear layers matching config.target_modules.
#   2. Freeze every existing parameter of the model.
#   3. Wrap each matched layer in a LoraLinear and swap it into the model.
#   4. Optionally re-enable gradients for biases according to config.bias
#      (:all, :lora_only; anything else leaves biases frozen).
#
# @param model [Object] module exposing #parameters, #named_parameters and
#   #named_modules (presumably a Torch::NN::Module — TODO confirm)
# @param config [LoraConfig, Hash] a LoraConfig, or a Hash of its keyword
#   arguments with Symbol or String keys
# @return [Object] the same +model+, modified in place
# @raise [RuntimeError] if no module matches config.target_modules; this is
#   detected before any parameter is frozen or replaced
def self.apply(model, config)
  # Hashes may come from a parsed lora_config.json, whose keys are Strings.
  config = LoraConfig.new(**config.transform_keys(&:to_sym)) if config.is_a?(Hash)

  # Look up targets before mutating anything so that a bad target_modules
  # value fails without leaving every parameter of the model frozen.
  targets = LoraUtils.find_target_modules(model, config.target_modules)

  if targets.empty?
    raise "No Linear modules found matching target_modules: #{config.target_modules.inspect}. " \
          "Available modules: #{model.named_modules.map(&:first).join(', ')}"
  end

  # Freeze all base parameters. The LoraLinear wrappers are constructed after
  # this loop, so their own LoRA A/B matrices are not touched by it.
  model.parameters.each { |p| p.requires_grad = false }

  targets.each do |name, linear|
    lora_linear = LoraLinear.new(
      linear,
      r:            config.r,
      lora_alpha:   config.lora_alpha,
      lora_dropout: config.lora_dropout
    )

    LoraUtils.replace_module(model, name, lora_linear)
  end

  # The bias mode is a String when the config was round-tripped through JSON.
  bias_mode = config.bias
  bias_mode = bias_mode.to_sym if bias_mode.is_a?(String)

  case bias_mode
  when :all
    # Substring match: any parameter whose name contains "bias" is unfrozen.
    model.named_parameters.each do |name, param|
      param.requires_grad = true if name.include?("bias")
    end
  when :lora_only
    # Only biases held directly by the LoraLinear wrappers (their @bias ivar).
    model.named_modules.each do |_, mod|
      next unless mod.is_a?(LoraLinear)

      lora_bias = mod.instance_variable_get(:@bias)
      lora_bias.requires_grad = true if lora_bias
    end
  end
  # :none (or any other value) — biases stay frozen.

  puts "LoRA applied to #{targets.size} modules: #{targets.keys.join(', ')}"
  LoraUtils.print_trainable_parameters(model)

  model
end

.load_adapters(model, input_dir) ⇒ Object

Load LoRA adapter weights from input_dir into a model that already has LoRA applied (see .apply).

(source lines 83–85)
# File 'lib/trainers/lora/lora_model.rb', line 83

# Load previously saved LoRA adapter weights from +input_dir+ into +model+.
#
# The model must already have LoRA applied. The actual reading is delegated to
# SaveUtils.load_lora_adapters, and its return value is passed through unchanged.
#
# @param model [Object] a model whose target layers are already LoraLinear
# @param input_dir [String] directory holding the saved adapters (presumably
#   one written by .save_adapters — TODO confirm)
# @return [Object] whatever SaveUtils.load_lora_adapters returns
def self.load_adapters(model, input_dir)
  SaveUtils.load_lora_adapters(model, input_dir)
end

.merge(model) ⇒ Object

Merge every LoRA adapter's weights back into its base layer's weights (for inference). Prints how many adapters were merged and returns the model.

(source lines 60–70)
# File 'lib/trainers/lora/lora_model.rb', line 60

# Fold every LoRA adapter in +model+ back into its base weights (e.g. before
# inference).
#
# Calls #merge! on each LoraLinear reported by model.named_modules, in order,
# then prints how many adapters were merged. The LoraLinear modules are not
# removed from the model by this method.
#
# @param model [Object] model exposing #named_modules
# @return [Object] the same +model+
def self.merge(model)
  adapters = model.named_modules.map { |_name, mod| mod }.grep(LoraLinear)
  adapters.each(&:merge!)
  puts "Merged #{adapters.size} LoRA adapters into base model"
  model
end

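.save_adapters(model, output_dir, config: nil) ⇒ Object

Save only the LoRA adapter weights to output_dir. When config is given, its hash form is also written as pretty-printed JSON to output_dir/lora_config.json.

(source lines 73–80)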
# File 'lib/trainers/lora/lora_model.rb', line 73

# Persist only the LoRA adapter weights of +model+ into +output_dir+.
#
# Weight serialization is delegated to SaveUtils.save_lora_adapters. When
# +config+ is given, its #to_h form is additionally written as pretty-printed
# JSON to <output_dir>/lora_config.json.
#
# @param model [Object] model with LoRA applied
# @param output_dir [String] destination directory (assumed to exist or to be
#   created by SaveUtils.save_lora_adapters — TODO confirm)
# @param config [LoraConfig, Hash, nil] optional config to store alongside
# @return [Integer, nil] bytes written to lora_config.json, or nil when no
#   config was given
def self.save_adapters(model, output_dir, config: nil)
  SaveUtils.save_lora_adapters(model, output_dir)
  return unless config

  File.write(File.join(output_dir, "lora_config.json"), JSON.pretty_generate(config.to_h))
end