Module: Ignis::AI::GPT2Loader
Defined in: lib/nnw/ai/gpt2_loader.rb
Overview
Loads HuggingFace GPT-2 safetensors checkpoints into an Ignis Transformer::Model.
GPT-2 weights are NOT a drop-in fit for nn.Linear-style layers: the attention/MLP projections are `Conv1D` layers whose weights are stored as [in, out] (transposed relative to nn.Linear’s [out, in]), the QKV projection is a single fused [embed, 3*embed] tensor, and the LM head is tied to the token embedding. This loader applies the corresponding transforms (transpose, column split, weight reuse) while copying weights in.
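To make the layout difference concrete, here is a minimal sketch in plain Ruby (nested arrays, no Ignis types) of why a Conv1D weight has to be transposed before it can be used as an nn.Linear-style weight. The helper names `conv1d_forward` and `linear_forward` are illustrative only and are not part of this module; biases are omitted.

```ruby
# Conv1D stores W as [in, out] and computes y[o] = sum_i x[i] * W[i][o].
def conv1d_forward(x, w_in_out)
  n_out = w_in_out.first.length
  (0...n_out).map { |o| x.each_index.sum { |i| x[i] * w_in_out[i][o] } }
end

# An nn.Linear-style layer stores W as [out, in] and computes
# y[o] = sum_i W[o][i] * x[i].
def linear_forward(x, w_out_in)
  w_out_in.map { |row| row.each_index.sum { |i| row[i] * x[i] } }
end

x = [1.0, 2.0]                 # in = 2
w_conv = [[1.0, 2.0, 3.0],     # [in, out] = [2, 3]
          [4.0, 5.0, 6.0]]
w_linear = w_conv.transpose    # [out, in] = [3, 2]

y_conv = conv1d_forward(x, w_conv)
y_linear = linear_forward(x, w_linear)
puts y_conv.inspect
puts y_linear.inspect
```

Both calls produce the same vector, which is the property the loader relies on when it transposes each Conv1D weight on the way in.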
Class Method Summary
- .load(model, dir) ⇒ model
  Load weights from <dir>/model.safetensors into model (in place).
- .read_header(path) ⇒ Object
  Read the safetensors JSON header.
- .transpose(arr, rows, cols) ⇒ Object
  Transpose a row-major [rows, cols] flat array → [cols, rows].
- .transpose_slice_cols(arr, rows, total_cols, col_off, len) ⇒ Object
  Slice columns [col_off, col_off+len) from a [rows, total_cols] array, then transpose → [len, rows].
Class Method Details
.load(model, dir) ⇒ model
Load weights from <dir>/model.safetensors into model (in place).
# File 'lib/nnw/ai/gpt2_loader.rb', line 26

def load(model, dir)
  path = File.join(dir, "model.safetensors")
  raise ArgumentError, "not found: #{path}" unless File.exist?(path)

  header, data_offset = read_header(path)
  params = model.named_parameters
  embed = model.… # embedding width; the accessor name is missing from this listing
  n_layer = model.num_layers

  File.open(path, "rb") do |f|
    # Read one tensor by name: seek to its byte range and decode it.
    get = lambda do |name|
      meta = header[name] or raise "missing safetensors tensor: #{name}"
      b, e = meta["data_offsets"]
      f.seek(data_offset + b)
      raw = f.read(e - b)
      { values: raw.unpack("e*"), shape: meta["shape"] } # F32 little-endian
    end

    # Copy a flat array of values into the named model parameter.
    set = lambda do |pname, values|
      p = params[pname] or raise "missing model param: #{pname}"
      unless p.numel == values.length
        raise "size mismatch for #{pname}: param has #{p.numel}, data has #{values.length}"
      end
      p.data.from_host(values)
    end

    # Embeddings (direct)
    set.call("token_embedding.weight", get.call("wte.weight")[:values])
    set.call("position_embedding.weight", get.call("wpe.weight")[:values])

    n_layer.times do |i|
      set.call("blocks.#{i}.norm1.weight", get.call("h.#{i}.ln_1.weight")[:values])
      set.call("blocks.#{i}.norm1.bias", get.call("h.#{i}.ln_1.bias")[:values])
      set.call("blocks.#{i}.norm2.weight", get.call("h.#{i}.ln_2.weight")[:values])
      set.call("blocks.#{i}.norm2.bias", get.call("h.#{i}.ln_2.bias")[:values])

      # Fused QKV: c_attn.weight is Conv1D [embed, 3*embed]. Split columns into
      # q|k|v and transpose each into nn.Linear layout [out, in] = [embed, embed].
      cattn_w = get.call("h.#{i}.attn.c_attn.weight")[:values]
      cattn_b = get.call("h.#{i}.attn.c_attn.bias")[:values]
      set.call("blocks.#{i}.attention.q_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 0 * embed, embed))
      set.call("blocks.#{i}.attention.k_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 1 * embed, embed))
      set.call("blocks.#{i}.attention.v_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 2 * embed, embed))
      set.call("blocks.#{i}.attention.q_proj.bias", cattn_b[(0 * embed)...(1 * embed)])
      set.call("blocks.#{i}.attention.k_proj.bias", cattn_b[(1 * embed)...(2 * embed)])
      set.call("blocks.#{i}.attention.v_proj.bias", cattn_b[(2 * embed)...(3 * embed)])

      # Attention output projection: Conv1D [embed, embed] -> transpose.
      aproj = get.call("h.#{i}.attn.c_proj.weight")[:values]
      set.call("blocks.#{i}.attention.out_proj.weight", transpose(aproj, embed, embed))
      set.call("blocks.#{i}.attention.out_proj.bias", get.call("h.#{i}.attn.c_proj.bias")[:values])

      # MLP fc1: c_fc Conv1D [embed, ff] -> transpose -> [ff, embed].
      fc = get.call("h.#{i}.mlp.c_fc.weight")
      ff = fc[:shape][1]
      set.call("blocks.#{i}.feed_forward.fc1.weight", transpose(fc[:values], embed, ff))
      set.call("blocks.#{i}.feed_forward.fc1.bias", get.call("h.#{i}.mlp.c_fc.bias")[:values])

      # MLP fc2: c_proj Conv1D [ff, embed] -> transpose -> [embed, ff].
      fp = get.call("h.#{i}.mlp.c_proj.weight")[:values]
      set.call("blocks.#{i}.feed_forward.fc2.weight", transpose(fp, ff, embed))
      set.call("blocks.#{i}.feed_forward.fc2.bias", get.call("h.#{i}.mlp.c_proj.bias")[:values])
    end

    set.call("norm.weight", get.call("ln_f.weight")[:values])
    set.call("norm.bias", get.call("ln_f.bias")[:values])

    # LM head is tied to the token embedding (logits = x @ wte^T).
    set.call("head.weight", get.call("wte.weight")[:values])
  end

  model
end
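The tensor-reading step in the listing above (seek to `data_offset + begin`, read `end - begin` bytes, decode with `unpack("e*")`) can be tried in isolation. The sketch below is an assumption-laden stand-in: it uses a `StringIO` instead of a checkpoint file, a hand-written header with two made-up tensors `a` and `b`, and a `data_offset` of 0. It also adds a dtype guard that the listing does not have, since `"e*"` only decodes 32-bit little-endian floats correctly.

```ruby
require "stringio"

# Two hypothetical F32 tensors packed back to back, as in a safetensors data section.
a = [1.0, 2.0, 3.0, 4.0] # shape [2, 2], 16 bytes
b = [0.5, -0.5]          # shape [2], 8 bytes
blob = a.pack("e*") + b.pack("e*")

# Header entries give byte ranges relative to the start of the data section.
header = {
  "a" => { "dtype" => "F32", "shape" => [2, 2], "data_offsets" => [0, 16] },
  "b" => { "dtype" => "F32", "shape" => [2], "data_offsets" => [16, 24] }
}

data_offset = 0 # in a real file this is 8 + the JSON header length
io = StringIO.new(blob)

read_tensor = lambda do |name|
  meta = header.fetch(name)
  # Extra guard (not in the listing): refuse anything that is not F32.
  raise "unsupported dtype #{meta["dtype"]}" unless meta["dtype"] == "F32"
  first, last = meta["data_offsets"]
  io.seek(data_offset + first)
  io.read(last - first).unpack("e*")
end

b_values = read_tensor.call("b") # tensors can be read in any order
a_values = read_tensor.call("a")
puts a_values.inspect
puts b_values.inspect
```

Because each read seeks to an absolute position, the order in which tensors are requested does not matter, which is why the loader can fetch `wte.weight` twice (once for the embedding, once for the tied head).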
.read_header(path) ⇒ Object
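Read the safetensors JSON header. Returns [header, data_offset] as [Hash, Integer]: the parsed header with its "__metadata__" entry removed, and the byte offset at which tensor data begins (8 + header length).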
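# File 'lib/nnw/ai/gpt2_loader.rb', line 101

def read_header(path)
  File.open(path, "rb") do |f|
    n = f.read(8).unpack1("Q<") # header length: unsigned 64-bit little-endian
    h = JSON.parse(f.read(n))   # the next n bytes are the JSON header
    h.delete("__metadata__")    # not a tensor entry
    [h, 8 + n]                  # tensor data starts right after the header
  end
end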
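As a quick check of the file layout this method expects (8-byte little-endian length, then the JSON header, then raw tensor bytes), the following sketch writes a tiny file and reads it back. The method body is copied here as a top-level method so the example runs on its own; the tensor name `w`, its values, and the `"format" => "pt"` metadata entry are made up for illustration.

```ruby
require "json"
require "tempfile"

# Same logic as the listing above, as a standalone method.
def read_header(path)
  File.open(path, "rb") do |f|
    n = f.read(8).unpack1("Q<")
    h = JSON.parse(f.read(n))
    h.delete("__metadata__")
    [h, 8 + n]
  end
end

# Build a minimal safetensors-style file with one F32 tensor of two values.
json = JSON.generate({
  "__metadata__" => { "format" => "pt" },
  "w" => { "dtype" => "F32", "shape" => [2], "data_offsets" => [0, 8] }
})
payload = [1.5, -2.0].pack("e*")

file = Tempfile.new(["tiny", ".safetensors"])
file.binmode
file.write([json.bytesize].pack("Q<")) # 8-byte header length
file.write(json)                       # JSON header
file.write(payload)                    # tensor data section
file.close

header, data_offset = read_header(file.path)
# Read the 8 payload bytes that start at data_offset and decode them.
values = File.binread(file.path, 8, data_offset).unpack("e*")
file.unlink

puts header.keys.inspect
puts data_offset
puts values.inspect
```

The returned `data_offset` is what `.load` adds to each tensor's `data_offsets` before seeking.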
.transpose(arr, rows, cols) ⇒ Object
Transpose a row-major [rows, cols] flat array → [cols, rows].
# File 'lib/nnw/ai/gpt2_loader.rb', line 111

def transpose(arr, rows, cols)
  out = Array.new(rows * cols)
  r = 0
  while r < rows
    base = r * cols # start of row r in the source
    c = 0
    while c < cols
      out[c * rows + r] = arr[base + c] # source (r, c) -> destination (c, r)
      c += 1
    end
    r += 1
  end
  out
end
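A small worked case may help with the index arithmetic: element (r, c) of the source sits at `r * cols + c` and lands at `c * rows + r` in the result. The sketch below re-implements the same mapping with `times` loops (a stylistic substitution for the `while` loops in the listing, so it is not the module's exact code) and cross-checks it against Ruby's built-in `Array#transpose` on nested arrays.

```ruby
# Flat row-major transpose, equivalent in effect to the listing above.
def transpose(arr, rows, cols)
  out = Array.new(rows * cols)
  rows.times do |r|
    cols.times do |c|
      out[c * rows + r] = arr[r * cols + c]
    end
  end
  out
end

flat = [1, 2, 3,
        4, 5, 6]            # [rows, cols] = [2, 3]
t = transpose(flat, 2, 3)   # [cols, rows] = [3, 2]

# Reference result using nested arrays and the built-in transpose.
reference = flat.each_slice(3).to_a.transpose.flatten

puts t.inspect
puts reference.inspect
```

Transposing twice with the dimensions swapped, `transpose(t, 3, 2)`, returns the original flat array.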
.transpose_slice_cols(arr, rows, total_cols, col_off, len) ⇒ Object
Slice columns [col_off, col_off+len) from a [rows, total_cols] array, then transpose → [len, rows]. (Used to split the fused QKV Conv1D weight.)
# File 'lib/nnw/ai/gpt2_loader.rb', line 128

def transpose_slice_cols(arr, rows, total_cols, col_off, len)
  out = Array.new(rows * len)
  r = 0
  while r < rows
    base = r * total_cols + col_off # first selected column of row r
    c = 0
    while c < len
      out[c * rows + r] = arr[base + c] # slice (r, c) -> destination (c, r)
      c += 1
    end
    r += 1
  end
  out
end
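To illustrate the QKV split, the sketch below applies the same slicing to a toy fused weight with `embed = 2`. The numbers are chosen so each block is recognisable (q in 1..4, k in 11..14, v in 21..24); the `times`-based body is a stylistic substitution for the `while` loops above, and the fused bias is split with plain ranges as in `.load`.

```ruby
# Slice columns [col_off, col_off + len) of a flat [rows, total_cols] array
# and return the slice transposed to [len, rows].
def transpose_slice_cols(arr, rows, total_cols, col_off, len)
  out = Array.new(rows * len)
  rows.times do |r|
    base = r * total_cols + col_off
    len.times { |c| out[c * rows + r] = arr[base + c] }
  end
  out
end

embed = 2
# Toy fused weight, Conv1D layout [embed, 3 * embed]: columns are q | k | v.
fused_w = [
  1, 2, 11, 12, 21, 22,
  3, 4, 13, 14, 23, 24
]
fused_b = [1, 2, 3, 4, 5, 6] # toy fused bias, length 3 * embed

q_w = transpose_slice_cols(fused_w, embed, 3 * embed, 0 * embed, embed)
k_w = transpose_slice_cols(fused_w, embed, 3 * embed, 1 * embed, embed)
v_w = transpose_slice_cols(fused_w, embed, 3 * embed, 2 * embed, embed)

q_b = fused_b[(0 * embed)...(1 * embed)]
k_b = fused_b[(1 * embed)...(2 * embed)]
v_b = fused_b[(2 * embed)...(3 * embed)]

puts q_w.inspect
puts k_w.inspect
puts v_w.inspect
```

Doing the slice and the transpose in one pass avoids building an intermediate [embed, embed] array for each of q, k, and v.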