Module: Trainers::LoraUtils
Defined in: lib/trainers/lora/lora_utils.rb
Class Method Summary
- .find_target_modules(model, target_modules) ⇒ Object
  Find all Linear modules in a model whose names match the target patterns.
- .format_number(n) ⇒ Object
  Format an integer with commas as thousands separators.
- .print_trainable_parameters(model) ⇒ Object
  Count total and trainable parameters.
- .replace_module(model, target_name, new_module) ⇒ Object
  Replace a child module on a parent by setting the instance variable.
Class Method Details
.find_target_modules(model, target_modules) ⇒ Object
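Find all Torch::NN::Linear modules in a model whose names match the target patterns. Pass :all_linear to select every Linear layer, or an Array of strings to select the layers whose dotted module name (as yielded by named_modules) contains any of those strings. Returns a Hash mapping module names to modules; any other value of target_modules selects nothing.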
```ruby
# File 'lib/trainers/lora/lora_utils.rb', line 6

def self.find_target_modules(model, target_modules)
  targets = {}
  model.named_modules.each do |name, mod|
    next unless mod.is_a?(Torch::NN::Linear)

    if target_modules == :all_linear
      targets[name] = mod
    elsif target_modules.is_a?(Array)
      if target_modules.any? { |pattern| name.include?(pattern) }
        targets[name] = mod
      end
    end
  end
  targets
end
```
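The matching rule can be tried without torch-rb. The sketch below uses a hypothetical standalone helper, target_match?, that mirrors the branch logic of find_target_modules on plain module-name strings. The module names are made up for illustration, and the Linear type check is omitted because no real modules are involved.

```ruby
# Hypothetical helper mirroring the matching rule in find_target_modules:
# :all_linear selects every Linear layer, while an Array selects a layer
# when any pattern is a substring of its dotted module name.
def target_match?(name, target_modules)
  if target_modules == :all_linear
    true
  elsif target_modules.is_a?(Array)
    target_modules.any? { |pattern| name.include?(pattern) }
  else
    false # any other value (e.g. a bare String) selects nothing
  end
end

# Hypothetical names, shaped like what named_modules might yield for a decoder.
names = [
  "layers.0.self_attn.q_proj",
  "layers.0.self_attn.k_proj",
  "layers.0.self_attn.v_proj",
  "layers.0.mlp.up_proj",
  "lm_head"
]

names.select { |n| target_match?(n, ["q_proj", "v_proj"]) }
# => ["layers.0.self_attn.q_proj", "layers.0.self_attn.v_proj"]

names.select { |n| target_match?(n, ["proj"]) }
# => every name except "lm_head"

names.count { |n| target_match?(n, :all_linear) }
# => 5
```

Because matching is by substring, a short pattern such as "proj" also catches the MLP projections, and passing a bare String instead of an Array silently selects nothing.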
.format_number(n) ⇒ Object
Format an integer with commas as thousands separators, e.g. 1234567 becomes "1,234,567".
```ruby
# File 'lib/trainers/lora/lora_utils.rb', line 69

def self.format_number(n)
  n.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1,').reverse
end
```
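To see how the reverse-and-group trick behaves, here is a standalone copy of the method (a top-level def for illustration, not the module method itself) run on a few values:

```ruby
# Standalone copy of LoraUtils.format_number: reverse the digits, insert a
# comma after every group of three that is followed by another digit, then
# reverse back so the commas land on thousands boundaries.
def format_number(n)
  n.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1,').reverse
end

format_number(999)        # => "999"
format_number(1000)       # => "1,000"
format_number(16_842_752) # => "16,842,752"
format_number(-1234567)   # => "-1,234,567"
```

The lookahead (?=\d) is what stops a trailing comma from being added when the digit count is an exact multiple of three, and it also keeps a leading minus sign from being separated.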
.print_trainable_parameters(model) ⇒ Object
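Count total and trainable parameters (those with requires_grad set), print a one-line summary, and return a Hash with the keys :total, :trainable and :percentage.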
```ruby
# File 'lib/trainers/lora/lora_utils.rb', line 52

def self.print_trainable_parameters(model)
  total = 0
  trainable = 0
  model.parameters.each do |p|
    total += p.numel
    trainable += p.numel if p.requires_grad
  end
  pct = total > 0 ? (trainable.to_f / total * 100) : 0.0
  puts "trainable params: #{format_number(trainable)} || " \
       "all params: #{format_number(total)} || " \
       "trainable%: #{format('%.4f', pct)}%"
  { total: total, trainable: trainable, percentage: pct }
end
```
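Because print_trainable_parameters only calls model.parameters, numel and requires_grad, it can be exercised with duck-typed stand-ins. This sketch copies the method into a hypothetical LoraUtilsSketch module and feeds it hypothetical ParamStub and ModelStub structs sized like one frozen 4096x4096 linear weight (no bias) plus a rank-8 LoRA pair, A (8x4096) and B (4096x8). The sizes are illustrative, not taken from any particular model.

```ruby
module LoraUtilsSketch
  def self.format_number(n)
    n.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1,').reverse
  end

  # Same logic as LoraUtils.print_trainable_parameters; it only needs
  # model.parameters, p.numel and p.requires_grad.
  def self.print_trainable_parameters(model)
    total = 0
    trainable = 0
    model.parameters.each do |p|
      total += p.numel
      trainable += p.numel if p.requires_grad
    end
    pct = total > 0 ? (trainable.to_f / total * 100) : 0.0
    puts "trainable params: #{format_number(trainable)} || " \
         "all params: #{format_number(total)} || " \
         "trainable%: #{format('%.4f', pct)}%"
    { total: total, trainable: trainable, percentage: pct }
  end
end

# Hypothetical stand-ins for torch-rb parameters and a model.
ParamStub = Struct.new(:numel, :requires_grad)
ModelStub = Struct.new(:parameters)

d = 4096 # hypothetical hidden size
r = 8    # hypothetical LoRA rank
model = ModelStub.new([
  ParamStub.new(d * d, false), # frozen base weight W (d x d)
  ParamStub.new(r * d, true),  # LoRA A (r x d), trainable
  ParamStub.new(d * r, true)   # LoRA B (d x r), trainable
])

stats = LoraUtilsSketch.print_trainable_parameters(model)
# prints: trainable params: 65,536 || all params: 16,842,752 || trainable%: 0.3891%
```

The trainable share here is 65,536 / 16,842,752 = 1/257, roughly 0.39%. An empty model takes the total > 0 guard and reports 0.0 instead of dividing by zero.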
.replace_module(model, target_name, new_module) ⇒ Object
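Replace a child module on a parent by setting the instance variable. target_name is a dotted path, such as a name yielded by named_modules: numeric segments are treated as indices into the parent's @modules collection (as in a ModuleList), and every other segment as an instance variable name. torch-rb's named_children discovers submodules by scanning instance variables, so setting the ivar is enough for the replacement to be picked up. Raises an error if an intermediate path segment cannot be found.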
```ruby
# File 'lib/trainers/lora/lora_utils.rb', line 27

def self.replace_module(model, target_name, new_module)
  parts = target_name.split(".")
  parent = model

  # Navigate to the parent module
  parts[0...-1].each do |part|
    if part =~ /\A\d+\z/
      # Numeric index: likely a ModuleList element
      parent = parent.instance_variable_get(:@modules)[part.to_i]
    else
      parent = parent.instance_variable_get(:"@#{part}")
    end
    raise "Could not find module part '#{part}' in #{target_name}" unless parent
  end

  child_name = parts.last
  if child_name =~ /\A\d+\z/
    parent.instance_variable_get(:@modules)[child_name.to_i] = new_module
  else
    parent.instance_variable_set(:"@#{child_name}", new_module)
  end
end
```
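The path-walking logic of replace_module relies only on instance_variable_get and instance_variable_set, so it can be demonstrated with plain Ruby objects. In this sketch ReplaceSketch holds a copy of the method, and StubLinear, StubAttention, StubList and StubModel are hypothetical stand-ins, not torch-rb classes. StubList keeps its children in an Array under @modules so that the integer indexing replace_module performs works; that is an assumption of the sketch, so check how your torch-rb version stores ModuleList children before relying on numeric path segments.

```ruby
module ReplaceSketch
  # Copy of LoraUtils.replace_module's navigation and assignment logic.
  def self.replace_module(model, target_name, new_module)
    parts = target_name.split(".")
    parent = model
    parts[0...-1].each do |part|
      if part =~ /\A\d+\z/
        parent = parent.instance_variable_get(:@modules)[part.to_i]
      else
        parent = parent.instance_variable_get(:"@#{part}")
      end
      raise "Could not find module part '#{part}' in #{target_name}" unless parent
    end
    child_name = parts.last
    if child_name =~ /\A\d+\z/
      parent.instance_variable_get(:@modules)[child_name.to_i] = new_module
    else
      parent.instance_variable_set(:"@#{child_name}", new_module)
    end
  end
end

# Hypothetical stand-ins: children live in instance variables, and the
# list keeps its children in an Array under @modules.
class StubLinear
  attr_reader :tag
  def initialize(tag)
    @tag = tag
  end
end

class StubAttention
  attr_reader :q_proj, :v_proj
  def initialize
    @q_proj = StubLinear.new("q")
    @v_proj = StubLinear.new("v")
  end
end

class StubList
  def initialize(mods)
    @modules = mods
  end

  def [](i)
    @modules[i]
  end
end

class StubModel
  attr_reader :layers
  def initialize
    @layers = StubList.new([StubAttention.new, StubAttention.new])
  end
end

model = StubModel.new
ReplaceSketch.replace_module(model, "layers.1.q_proj", StubLinear.new("lora_q"))
model.layers[1].q_proj.tag # => "lora_q"
model.layers[0].q_proj.tag # => "q" (other layers are untouched)
```

The same call with a path ending in a number, such as "layers.0", swaps the list element itself, while a path with an unknown intermediate segment raises before anything is modified.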