Class: Trainers::LoraLinear

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/trainers/lora/lora_linear.rb

Overview

Wraps a frozen Torch::NN::Linear with low-rank A/B adapter matrices.

Forward: y = Wx + b + (x @ A^T @ B^T) * scaling

Only lora_A and lora_B are trainable. The original weight and bias are frozen, keeping ~99% of parameters fixed during fine-tuning.

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0) ⇒ LoraLinear

Returns a new instance of LoraLinear.



13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# File 'lib/trainers/lora/lora_linear.rb', line 13

# Wraps +original_linear+ with a trainable low-rank (LoRA) adapter.
#
# The wrapped layer's weight (and bias, when present) are reused by
# reference and have +requires_grad+ switched off, so the layer object that
# was passed in is frozen as a side effect.
#
# @param original_linear [Torch::NN::Linear] layer whose weight/bias are reused
# @param r [Integer] adapter rank; must be a positive Integer
# @param lora_alpha [Numeric] numerator of the scaling factor (lora_alpha / r)
# @param lora_dropout [Numeric] dropout probability applied to the adapter
#   input; a dropout module is only created when this is greater than 0
# @raise [ArgumentError] if +r+ is not a positive Integer
def initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0)
  # Validate first, before anything is mutated: r = 0 would turn the scaling
  # factor into Infinity/NaN without raising, and any other bad rank would
  # only fail after the caller's layer had already been frozen below.
  unless r.is_a?(Integer) && r.positive?
    raise ArgumentError, "r must be a positive Integer, got #{r.inspect}"
  end

  super()

  # Prefer the sizes recorded on the wrapped layer; fall back to the weight's
  # own dimensions (weight is indexed as out_features x in_features here).
  @in_features  = original_linear.instance_variable_get(:@in_features) ||
                  original_linear.weight.size(1)
  @out_features = original_linear.instance_variable_get(:@out_features) ||
                  original_linear.weight.size(0)
  @r            = r
  @scaling      = lora_alpha.to_f / r

  # Freeze original parameters (shared with original_linear, not copied)
  @weight = original_linear.weight
  @weight.requires_grad = false

  @bias = original_linear.bias
  @bias.requires_grad = false if @bias

  # LoRA low-rank matrices (these are the only trainable parameters)
  # A: (r, in_features)  — initialized with Kaiming uniform
  # B: (out_features, r) — initialized to zero, so the adapter contributes
  #    nothing at first and the layer initially matches the original one
  @lora_A = Torch::NN::Parameter.new(Torch.empty(@r, @in_features))
  Torch::NN::Init.kaiming_uniform!(@lora_A, a: Math.sqrt(5))

  @lora_B = Torch::NN::Parameter.new(Torch.zeros(@out_features, @r))

  @lora_dropout = lora_dropout > 0 ? Torch::NN::Dropout.new(p: lora_dropout) : nil
end

Instance Attribute Details

#in_features ⇒ Object (readonly)

Returns the value of attribute in_features.



11
12
13
# File 'lib/trainers/lora/lora_linear.rb', line 11

# Input feature count recorded by the constructor (read from the wrapped
# linear layer, or from its weight's second dimension as a fallback).
def in_features = @in_features

#out_features ⇒ Object (readonly)

Returns the value of attribute out_features.



11
12
13
# File 'lib/trainers/lora/lora_linear.rb', line 11

# Output feature count recorded by the constructor (read from the wrapped
# linear layer, or from its weight's first dimension as a fallback).
def out_features = @out_features

#r ⇒ Object (readonly)

Returns the value of attribute r.



11
12
13
# File 'lib/trainers/lora/lora_linear.rb', line 11

# Adapter rank, exactly as passed to the constructor via +r:+.
def r = @r

#scaling ⇒ Object (readonly)

Returns the value of attribute scaling.



11
12
13
# File 'lib/trainers/lora/lora_linear.rb', line 11

# Multiplier applied to the adapter output; the constructor computes it as
# lora_alpha.to_f / r.
def scaling = @scaling

Instance Method Details

#extra_repr ⇒ Object



73
74
75
76
# File 'lib/trainers/lora/lora_linear.rb', line 73

# One-line summary of this layer's configuration: feature sizes, rank and
# scaling factor.
# NOTE(review): presumably picked up by Torch::NN::Module when printing the
# module — confirm against the torch.rb version in use.
#
# @return [String]
def extra_repr
  format(
    "in_features=%s, out_features=%s, r=%s, scaling=%s",
    @in_features, @out_features, @r, @scaling
  )
end

#forward(x) ⇒ Object



41
42
43
44
45
46
47
48
49
50
# File 'lib/trainers/lora/lora_linear.rb', line 41

# Computes the frozen linear output plus the scaled low-rank correction:
#   linear(x, weight, bias) + (dropout(x) @ A^T @ B^T) * scaling
#
# Dropout is applied to the adapter branch only, and only when a dropout
# module was configured in the constructor.
#
# @param x [Torch::Tensor] input (assumed to have in_features as its last
#   dimension — TODO confirm against callers)
# @return [Torch::Tensor] sum of the base and adapter outputs
def forward(x)
  # Frozen branch: the wrapped layer's weight and bias, untouched.
  frozen_out = Torch::NN::F.linear(x, @weight, @bias)

  # Adapter branch: project down through A, back up through B, then scale.
  adapter_in = x
  adapter_in = @lora_dropout.call(x) if @lora_dropout
  projected_down = adapter_in.matmul(@lora_A.t)
  projected_up = projected_down.matmul(@lora_B.t)

  frozen_out + (projected_up * @scaling)
end

#load_lora_weights(lora_a_tensor, lora_b_tensor) ⇒ Object

Load LoRA weights from saved state



66
67
68
69
70
71
# File 'lib/trainers/lora/lora_linear.rb', line 66

# Copies previously saved adapter matrices into this layer's LoRA parameters.
#
# Both tensors are shape-checked before anything is written, so a bad
# checkpoint is rejected without modifying either parameter.
#
# @param lora_a_tensor [Torch::Tensor] replacement values for lora_A
# @param lora_b_tensor [Torch::Tensor] replacement values for lora_B
# @raise [ArgumentError] if either tensor's shape differs from the parameter
#   it is meant to replace
def load_lora_weights(lora_a_tensor, lora_b_tensor)
  # copy! broadcasts its source, so e.g. a rank-1 checkpoint could silently
  # fill a higher-rank adapter. Check both shapes up front so a mismatched
  # checkpoint can neither be broadcast nor leave the adapter half-loaded.
  pairs = { "lora_A" => [@lora_A, lora_a_tensor], "lora_B" => [@lora_B, lora_b_tensor] }
  pairs.each do |name, (param, source)|
    next if param.shape == source.shape

    raise ArgumentError,
          "#{name} shape mismatch: expected #{param.shape.inspect}, got #{source.shape.inspect}"
  end

  # In-place writes to parameters must not be recorded by autograd.
  Torch.no_grad do
    @lora_A.copy!(lora_a_tensor)
    @lora_B.copy!(lora_b_tensor)
  end
end

#lora_state_dict ⇒ Object

Extract LoRA adapter state for saving



61
62
63
# File 'lib/trainers/lora/lora_linear.rb', line 61

# Collects the adapter tensors for saving.
#
# @return [Hash{String => Object}] the +data+ of each LoRA parameter, keyed
#   by "lora_A" and "lora_B"
def lora_state_dict
  adapter_params = { "lora_A" => @lora_A, "lora_B" => @lora_B }
  adapter_params.transform_values(&:data)
end

#merge! ⇒ Object

Merge LoRA weights into the base weight matrix (for inference)



53
54
55
56
57
58
# File 'lib/trainers/lora/lora_linear.rb', line 53

# Folds the adapter into the base weight (for inference):
#   weight += (lora_B @ lora_A) * scaling
#
# The adapter is consumed by the merge: lora_B is zeroed afterwards. Without
# that, forward would keep adding the LoRA path on top of the already-merged
# weight (applying the delta twice), and a second merge! would add it again.
# With lora_B at zero, forward output is the same before and after merging
# and repeated calls are harmless.
#
# Call lora_state_dict BEFORE merge! if the adapter needs to be saved;
# afterwards it reports a zeroed lora_B.
#
# @return [self]
def merge!
  Torch.no_grad do
    @weight.add!(@lora_B.matmul(@lora_A) * @scaling)
    # The delta now lives in @weight; clear it from the adapter path.
    @lora_B.zero!
  end
  self
end