Class: Trainers::LoraLinear
- Inherits: Torch::NN::Module (Object > Torch::NN::Module > Trainers::LoraLinear)
- Defined in: lib/trainers/lora/lora_linear.rb
Overview
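Wraps a frozen Torch::NN::Linear with low-rank A/B adapter matrices.
Forward: y = x W^T + b + (x A^T B^T) * scaling, where scaling = lora_alpha / r.
Only lora_A (r × in_features) and lora_B (out_features × r) are trainable; the original weight and bias are frozen. The adapters add r * (in_features + out_features) parameters next to the in_features * out_features entries of the base weight, so for small r almost all of the layer's parameters stay fixed during fine-tuning. For example, with in_features = out_features = 4096 and r = 8, the adapters add 65,536 trainable parameters against 16,781,312 frozen ones (weight plus bias), about 0.4% of the layer.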
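How many parameters stay frozen depends on the layer size and on r. The minimal sketch below computes the counts for any layer; lora_param_counts is a hypothetical helper, not part of Trainers::LoraLinear, and it assumes the wrapped layer has a bias unless told otherwise.

```ruby
# Minimal sketch: parameter counts for one LoraLinear-style layer.
# lora_param_counts is a hypothetical helper, not part of the class.
def lora_param_counts(in_features, out_features, r, bias: true)
  frozen = in_features * out_features + (bias ? out_features : 0) # W (+ b)
  trainable = r * in_features + out_features * r                  # A + B
  {
    frozen: frozen,
    trainable: trainable,
    trainable_fraction: trainable.to_f / (frozen + trainable)
  }
end

small = lora_param_counts(768, 768, 8)
puts small[:trainable]                                  # 12288
puts format("%.2f%%", small[:trainable_fraction] * 100) # 2.04%
```

For a square layer the trainable share is roughly 2r / in_features, so it grows as the layer gets smaller: the 768 × 768 case above is about 2% trainable, versus about 0.4% for 4096 × 4096.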
Instance Attribute Summary
- #in_features ⇒ Object (readonly): Number of input features of the wrapped linear layer.
- #out_features ⇒ Object (readonly): Number of output features of the wrapped linear layer.
- #r ⇒ Object (readonly): Rank of the low-rank adapter matrices.
- #scaling ⇒ Object (readonly): Factor applied to the adapter output, computed as lora_alpha / r.
Instance Method Summary
- #extra_repr ⇒ Object: Returns a string listing in_features, out_features, r and scaling.
- #forward(x) ⇒ Object: Computes the frozen linear output plus the scaled low-rank update.
- #initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0) ⇒ LoraLinear (constructor): A new instance of LoraLinear.
- #load_lora_weights(lora_a_tensor, lora_b_tensor) ⇒ Object: Load LoRA weights from saved state.
- #lora_state_dict ⇒ Object: Extract LoRA adapter state for saving.
- #merge! ⇒ Object: Merge LoRA weights into the base weight matrix (for inference).
Constructor Details
#initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0) ⇒ LoraLinear
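Returns a new instance of LoraLinear that wraps original_linear and freezes its weight and bias. r is the adapter rank, lora_alpha sets scaling = lora_alpha / r, and lora_dropout is the dropout probability applied to the adapter input (no dropout layer is created when it is 0.0).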
# File 'lib/trainers/lora/lora_linear.rb', line 13
def initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0)
  super()
  @in_features = original_linear.instance_variable_get(:@in_features) || original_linear.weight.size(1)
  @out_features = original_linear.instance_variable_get(:@out_features) || original_linear.weight.size(0)
  @r = r
  @scaling = lora_alpha.to_f / r

  # Freeze original parameters
  @weight = original_linear.weight
  @weight.requires_grad = false
  @bias = original_linear.bias
  @bias.requires_grad = false if @bias

  # LoRA low-rank matrices (these are the only trainable parameters)
  # A: (r, in_features), initialized with Kaiming uniform
  # B: (out_features, r), initialized to zero so the update B * A starts at zero
  #    and the wrapped layer initially reproduces the original layer's output
  @lora_A = Torch::NN::Parameter.new(Torch.empty(@r, @in_features))
  Torch::NN::Init.kaiming_uniform!(@lora_A, a: Math.sqrt(5))
  @lora_B = Torch::NN::Parameter.new(Torch.zeros(@out_features, @r))

  # Dropout on the adapter input only; skipped entirely when lora_dropout is 0
  @lora_dropout = lora_dropout > 0 ? Torch::NN::Dropout.new(p: lora_dropout) : nil
end
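Zero-initializing lora_B makes the update lora_B · lora_A exactly zero at construction, so the wrapped layer starts out identical to the original. The sketch below mirrors that initialization on plain Ruby arrays. It assumes Torch::NN::Init.kaiming_uniform! follows PyTorch's kaiming_uniform_ defaults (fan_in mode, leaky_relu gain), under which a = sqrt(5) gives the bound 1 / sqrt(in_features); kaiming_uniform_bound, init_lora and delta_weight are hypothetical names.

```ruby
# Sketch of the LoRA initialization on plain Ruby arrays (no torch-rb required).
# Assumption: Kaiming uniform with PyTorch's defaults (fan_in, leaky_relu gain).

# Bound of U(-bound, bound): gain * sqrt(3 / fan_in), gain = sqrt(2 / (1 + a^2)).
def kaiming_uniform_bound(fan_in, a: Math.sqrt(5))
  gain = Math.sqrt(2.0 / (1 + a**2))
  gain * Math.sqrt(3.0 / fan_in)
end

# Returns [A, B] with A: r x in_features (uniform), B: out_features x r (zeros).
def init_lora(in_features, out_features, r, rng: Random.new(0))
  bound = kaiming_uniform_bound(in_features) # fan_in of A is in_features
  a = Array.new(r) { Array.new(in_features) { rng.rand(-bound..bound) } }
  b = Array.new(out_features) { Array.new(r, 0.0) }
  [a, b]
end

# Weight update B * A (out_features x in_features).
def delta_weight(b, a)
  cols = a.transpose
  b.map { |row| cols.map { |col| row.zip(col).sum { |u, v| u * v } } }
end

a, b = init_lora(16, 4, 2)
puts kaiming_uniform_bound(16)                # roughly 0.25, i.e. 1 / sqrt(16)
puts delta_weight(b, a).flatten.all?(&:zero?) # true
```

With a = sqrt(5) the gain is sqrt(1/3), which cancels the sqrt(3) in the uniform bound; that is why this choice reduces to 1 / sqrt(fan_in).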
Instance Attribute Details
#in_features ⇒ Object (readonly)
Returns the value of attribute in_features.
# File 'lib/trainers/lora/lora_linear.rb', line 11
def in_features
  @in_features
end
#out_features ⇒ Object (readonly)
Returns the value of attribute out_features.
# File 'lib/trainers/lora/lora_linear.rb', line 11
def out_features
  @out_features
end
#r ⇒ Object (readonly)
Returns the value of attribute r.
# File 'lib/trainers/lora/lora_linear.rb', line 11
def r
  @r
end
#scaling ⇒ Object (readonly)
Returns the value of attribute scaling.
# File 'lib/trainers/lora/lora_linear.rb', line 11
def scaling
  @scaling
end
Instance Method Details
#extra_repr ⇒ Object
# File 'lib/trainers/lora/lora_linear.rb', line 73
def extra_repr
  "in_features=#{@in_features}, out_features=#{@out_features}, " \
    "r=#{@r}, scaling=#{@scaling}"
end
#forward(x) ⇒ Object
# File 'lib/trainers/lora/lora_linear.rb', line 41
def forward(x)
  # Frozen base path: x W^T + b
  base_output = Torch::NN::F.linear(x, @weight, @bias)

  # LoRA path: optional dropout on the input, then (x A^T) B^T scaled by lora_alpha / r
  lora_input = @lora_dropout ? @lora_dropout.call(x) : x
  lora_output = lora_input.matmul(@lora_A.t).matmul(@lora_B.t) * @scaling

  base_output + lora_output
end
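To make the forward formula concrete, here is a sketch of the same computation on plain Ruby arrays, with x as a batch of row vectors. It assumes eval mode, so the adapter dropout is skipped; matmul and lora_forward are hypothetical helpers, not part of the class.

```ruby
# Sketch of LoraLinear#forward on plain arrays (dropout omitted, as in eval mode).
def matmul(m, n)
  cols = n.transpose
  m.map { |row| cols.map { |col| row.zip(col).sum { |u, v| u * v } } }
end

# x: batch x in, w: out x in, bias: out, a: r x in, b: out x r
def lora_forward(x, w, bias, a, b, scaling)
  base = matmul(x, w.transpose).map { |row| row.zip(bias).map { |v, c| v + c } }
  lora = matmul(matmul(x, a.transpose), b.transpose) # (x A^T) B^T, never forms B A
  base.zip(lora).map { |br, lr| br.zip(lr).map { |u, v| u + v * scaling } }
end

x = [[1.0, 2.0]]             # batch of 1, in_features = 2
w = [[1.0, 0.0], [0.0, 1.0]] # out_features = 2
bias = [0.5, -0.5]
a = [[1.0, 1.0]]             # r = 1
b = [[2.0], [0.0]]
p lora_forward(x, w, bias, a, b, 2.0) # [[13.5, 1.5]]
```

Multiplying by A^T first keeps the intermediate at r columns, so the adapter path costs about r * (in_features + out_features) multiply-adds per input row instead of the in_features * out_features needed to build lora_B · lora_A explicitly.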
#load_lora_weights(lora_a_tensor, lora_b_tensor) ⇒ Object
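Load LoRA weights from saved state. Copies lora_a_tensor into lora_A and lora_b_tensor into lora_B in place inside Torch.no_grad, so the copies are not recorded by autograd. The tensors are expected to have shapes (r, in_features) and (out_features, r).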
# File 'lib/trainers/lora/lora_linear.rb', line 66
def load_lora_weights(lora_a_tensor, lora_b_tensor)
  Torch.no_grad do
    @lora_A.copy!(lora_a_tensor)
    @lora_B.copy!(lora_b_tensor)
  end
end
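load_lora_weights copies the given tensors without checking their shapes. Assuming torch-rb's copy! broadcasts like PyTorch's copy_, a mismatched but broadcastable tensor (for example a (1, in_features) lora_A from a rank-1 checkpoint) would be expanded silently rather than rejected. A minimal sketch of a hypothetical pre-check, check_lora_shapes!, which is not part of the class:

```ruby
# Hypothetical pre-check (not part of Trainers::LoraLinear) to run before
# load_lora_weights. Works with any objects whose #shape returns an Array,
# such as torch-rb tensors.
def check_lora_shapes!(r:, in_features:, out_features:, lora_a:, lora_b:)
  expected = { "lora_A" => [r, in_features], "lora_B" => [out_features, r] }
  actual = { "lora_A" => lora_a.shape.to_a, "lora_B" => lora_b.shape.to_a }
  expected.each do |name, shape|
    next if actual[name] == shape
    raise ArgumentError, "#{name}: expected #{shape.inspect}, got #{actual[name].inspect}"
  end
  true
end

# Stand-in for a tensor so the sketch runs without torch-rb.
FakeTensor = Struct.new(:shape)

check_lora_shapes!(r: 8, in_features: 16, out_features: 4,
                   lora_a: FakeTensor.new([8, 16]), lora_b: FakeTensor.new([4, 8])) # => true
```

In real use, r, in_features and out_features would come from the layer's readonly attributes of the same names.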
#lora_state_dict ⇒ Object
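Extract LoRA adapter state for saving. Returns a Hash with the keys "lora_A" and "lora_B" mapped to the adapter tensors' data; the frozen base weight and bias are not included.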
# File 'lib/trainers/lora/lora_linear.rb', line 61
def lora_state_dict
  { "lora_A" => @lora_A.data, "lora_B" => @lora_B.data }
end
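lora_state_dict returns the adapter state of a single layer, so a model with several wrapped layers needs some way to combine them. One possible approach, sketched with hypothetical helpers flatten_lora_state and unflatten_lora_state and made-up layer names, prefixes each key with the layer name and splits it back on the last ".":

```ruby
# Hypothetical helpers (not part of Trainers::LoraLinear): combine per-layer
# hashes from lora_state_dict into one flat hash keyed "layer.lora_A", and
# split such a hash back into per-layer hashes.
def flatten_lora_state(per_layer)
  per_layer.each_with_object({}) do |(layer_name, state), flat|
    state.each { |key, tensor| flat["#{layer_name}.#{key}"] = tensor }
  end
end

def unflatten_lora_state(flat)
  flat.each_with_object({}) do |(full_key, tensor), per_layer|
    # Split on the last "." so layer names may themselves contain dots.
    layer_name, _sep, key = full_key.rpartition(".")
    (per_layer[layer_name] ||= {})[key] = tensor
  end
end

# Arrays stand in for tensors; the layer names are made up for illustration.
per_layer = {
  "encoder.q_proj" => { "lora_A" => [[0.1, 0.2]], "lora_B" => [[0.0], [0.0]] },
  "encoder.v_proj" => { "lora_A" => [[0.3, 0.4]], "lora_B" => [[0.5], [0.6]] }
}
flat = flatten_lora_state(per_layer)
p flat.keys # ["encoder.q_proj.lora_A", "encoder.q_proj.lora_B", "encoder.v_proj.lora_A", "encoder.v_proj.lora_B"]
```

Each restored per-layer hash can then be handed to that layer as load_lora_weights(state["lora_A"], state["lora_B"]).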
#merge! ⇒ Object
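Merge LoRA weights into the base weight matrix (for inference). Adds lora_B.matmul(lora_A) * scaling to the frozen weight in place and returns self. #forward still adds the adapter path after merging, and calling merge! again adds the update again, so once the weights are merged the adapter path should not be applied on top of them.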
# File 'lib/trainers/lora/lora_linear.rb', line 53
def merge!
  Torch.no_grad do
    @weight.add!(@lora_B.matmul(@lora_A) * @scaling)
  end
  self
end
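The merge is exact because (x A^T) B^T * scaling = x (scaling * B A)^T, so folding scaling * lora_B * lora_A into the weight reproduces the unmerged output. The sketch below checks this on plain arrays (bias omitted; matmul and add_scaled are hypothetical helpers). It also shows what happens when the adapter path is applied on top of the merged weight, which is what #forward does after merge!.

```ruby
# Sketch on plain Ruby arrays: W' = W + scaling * (B A) gives the same output
# as the unmerged forward pass. matmul and add_scaled are hypothetical helpers.
def matmul(m, n)
  cols = n.transpose
  m.map { |row| cols.map { |col| row.zip(col).sum { |u, v| u * v } } }
end

# Elementwise m + s * n for same-shaped matrices.
def add_scaled(m, n, s)
  m.zip(n).map { |mr, nr| mr.zip(nr).map { |u, v| u + s * v } }
end

w = [[1.0, 0.0], [0.0, 1.0]] # out_features x in_features
a = [[1.0, 1.0]]             # r x in_features (r = 1)
b = [[2.0], [0.0]]           # out_features x r
scaling = 2.0
x = [[1.0, 2.0]]             # one input row (bias omitted)

lora_path = matmul(matmul(x, a.transpose), b.transpose)
unmerged = add_scaled(matmul(x, w.transpose), lora_path, scaling)

merged_w = add_scaled(w, matmul(b, a), scaling) # the update merge! adds to @weight
merged = matmul(x, merged_w.transpose)

p unmerged # [[13.0, 2.0]]
p merged   # [[13.0, 2.0]]
# Adding the adapter path again on top of the merged weight counts it twice:
p add_scaled(merged, lora_path, scaling) # [[25.0, 2.0]]
```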