Module: Ignis::AI::GPT2Loader

Defined in:
lib/nnw/ai/gpt2_loader.rb

Overview

Loads HuggingFace GPT-2 safetensors checkpoints into an Ignis Transformer::Model.

GPT-2 is NOT a drop-in for nn.Linear: its attention/MLP projections are `Conv1D` layers whose weights are stored as [in, out] (transposed relative to nn.Linear's [out, in]), the QKV projection is a single fused [embed, 3*embed] tensor, and the LM head is tied to the token embedding. This loader applies those transforms while copying the weights in.

Examples:

model = Ignis::AI::Transformer::Model.gpt2_small
Ignis::AI::GPT2Loader.load(model, "gpt2")
model.eval!

Class Method Summary

Class Method Details

.load(model, dir) ⇒ model

Load weights from <dir>/model.safetensors into model (in place).

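Parameters:

  • model (Ignis::AI::Transformer::Model) — the model to populate; it must provide named_parameters, embed_dim and num_layers.
  • dir (String) — directory containing model.safetensors.

Returns:

  • (model) — the same model object, with its weights loaded in place.

Raises:

  • (ArgumentError) — if <dir>/model.safetensors does not exist.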


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# File 'lib/nnw/ai/gpt2_loader.rb', line 26

# Load GPT-2 weights from <dir>/model.safetensors into +model+ (in place).
#
# HF GPT-2 stores its projections as Conv1D weights in [in, out] layout. Each
# one is transposed into nn.Linear's [out, in] layout, the fused c_attn QKV
# tensor is split into q/k/v, and the LM head is filled from the token
# embedding (the two are tied).
#
# Tensor names are first looked up bare ("wte.weight", "h.0.ln_1.weight", ...)
# and then with a "transformer." prefix, the form used by GPT2LMHeadModel
# checkpoints. Only F32 tensors are supported.
#
# @param model [Ignis::AI::Transformer::Model] model to populate; must provide
#   #named_parameters, #embed_dim and #num_layers
# @param dir [String] directory containing model.safetensors
# @return [Ignis::AI::Transformer::Model] the same +model+
# @raise [ArgumentError] if <dir>/model.safetensors does not exist
# @raise [RuntimeError] if a tensor or model parameter is missing, a tensor is
#   not F32, tensor data is truncated, or an element count does not match the
#   target parameter
def load(model, dir)
  path = File.join(dir, "model.safetensors")
  raise ArgumentError, "not found: #{path}" unless File.exist?(path)

  header, data_offset = read_header(path)
  params = model.named_parameters
  embed = model.embed_dim
  n_layer = model.num_layers

  File.open(path, "rb") do |f|
    # Read one tensor as { values: Array<Float>, shape: Array<Integer> }.
    get = lambda do |name|
      meta = header[name] || header["transformer.#{name}"]
      raise "missing safetensors tensor: #{name}" unless meta

      # unpack("e*") only decodes 4-byte little-endian floats. Reject other
      # dtypes here rather than failing later with a misleading size mismatch.
      dtype = meta["dtype"]
      raise "unsupported dtype for #{name}: #{dtype} (only F32 is supported)" unless dtype == "F32"

      b, e = meta["data_offsets"]
      f.seek(data_offset + b)
      raw = f.read(e - b)
      raise "truncated tensor data for #{name}" unless raw && raw.bytesize == e - b

      { values: raw.unpack("e*"), shape: meta["shape"] }
    end

    # Copy +values+ into the named model parameter after checking its size.
    set = lambda do |pname, values|
      param = params[pname]
      raise "missing model param: #{pname}" unless param
      unless param.numel == values.length
        raise "size mismatch for #{pname}: param has #{param.numel}, data has #{values.length}"
      end
      param.data.from_host(values)
    end

    # Embeddings are copied directly. wte is decoded once and reused for the
    # tied LM head below; it is the largest tensor in the checkpoint.
    wte = get.call("wte.weight")[:values]
    set.call("token_embedding.weight", wte)
    set.call("position_embedding.weight", get.call("wpe.weight")[:values])

    n_layer.times do |i|
      set.call("blocks.#{i}.norm1.weight", get.call("h.#{i}.ln_1.weight")[:values])
      set.call("blocks.#{i}.norm1.bias",   get.call("h.#{i}.ln_1.bias")[:values])
      set.call("blocks.#{i}.norm2.weight", get.call("h.#{i}.ln_2.weight")[:values])
      set.call("blocks.#{i}.norm2.bias",   get.call("h.#{i}.ln_2.bias")[:values])

      # Fused QKV: c_attn.weight is Conv1D [embed, 3*embed]. Split its columns
      # into q|k|v and transpose each into nn.Linear layout [out, in] = [embed, embed].
      cattn_w = get.call("h.#{i}.attn.c_attn.weight")[:values]
      cattn_b = get.call("h.#{i}.attn.c_attn.bias")[:values]
      set.call("blocks.#{i}.attention.q_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 0 * embed, embed))
      set.call("blocks.#{i}.attention.k_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 1 * embed, embed))
      set.call("blocks.#{i}.attention.v_proj.weight", transpose_slice_cols(cattn_w, embed, 3 * embed, 2 * embed, embed))
      set.call("blocks.#{i}.attention.q_proj.bias", cattn_b[(0 * embed)...(1 * embed)])
      set.call("blocks.#{i}.attention.k_proj.bias", cattn_b[(1 * embed)...(2 * embed)])
      set.call("blocks.#{i}.attention.v_proj.bias", cattn_b[(2 * embed)...(3 * embed)])

      # Attention output projection: Conv1D [embed, embed] -> transpose.
      aproj = get.call("h.#{i}.attn.c_proj.weight")[:values]
      set.call("blocks.#{i}.attention.out_proj.weight", transpose(aproj, embed, embed))
      set.call("blocks.#{i}.attention.out_proj.bias", get.call("h.#{i}.attn.c_proj.bias")[:values])

      # MLP fc1: c_fc Conv1D [embed, ff] -> transpose -> [ff, embed].
      fc = get.call("h.#{i}.mlp.c_fc.weight")
      ff = fc[:shape][1]
      set.call("blocks.#{i}.feed_forward.fc1.weight", transpose(fc[:values], embed, ff))
      set.call("blocks.#{i}.feed_forward.fc1.bias", get.call("h.#{i}.mlp.c_fc.bias")[:values])

      # MLP fc2: c_proj Conv1D [ff, embed] -> transpose -> [embed, ff].
      fp = get.call("h.#{i}.mlp.c_proj.weight")[:values]
      set.call("blocks.#{i}.feed_forward.fc2.weight", transpose(fp, ff, embed))
      set.call("blocks.#{i}.feed_forward.fc2.bias", get.call("h.#{i}.mlp.c_proj.bias")[:values])
    end

    set.call("norm.weight", get.call("ln_f.weight")[:values])
    set.call("norm.bias",   get.call("ln_f.bias")[:values])

    # LM head is tied to the token embedding (logits = x @ wte^T). Pass a copy
    # so the two parameters never receive the same host array object.
    set.call("head.weight", wte.dup)
  end

  model
end

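.read_header(path) ⇒ Array(Hash, Integer)

Read the safetensors JSON header. Returns [header, data_offset]: the tensor-metadata Hash (with "__metadata__" removed) and the absolute byte offset at which tensor data begins.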



101
102
103
104
105
106
107
108
# File 'lib/nnw/ai/gpt2_loader.rb', line 101

# Read the safetensors JSON header.
#
# The file starts with an unsigned 64-bit little-endian header length N,
# followed by N bytes of JSON that map tensor names to their metadata. Tensor
# data starts immediately after the header.
#
# @param path [String] path to a .safetensors file
# @return [Array(Hash, Integer)] the tensor header (with "__metadata__" removed)
#   and the absolute byte offset at which tensor data begins
# @raise [RuntimeError] if the file is shorter than the 8-byte length prefix or
#   declares a header longer than the rest of the file
def read_header(path)
  File.open(path, "rb") do |f|
    prefix = f.read(8)
    if prefix.nil? || prefix.bytesize < 8
      raise "invalid safetensors file (truncated length prefix): #{path}"
    end

    n = prefix.unpack1("Q<")
    # Check before reading, so a corrupt length cannot trigger a huge allocation.
    if n > f.size - 8
      raise "invalid safetensors file (header length #{n} exceeds file size): #{path}"
    end

    h = JSON.parse(f.read(n))
    h.delete("__metadata__")
    [h, 8 + n]
  end
end

.transpose(arr, rows, cols) ⇒ Object

Transpose a row-major [rows, cols] flat array → [cols, rows].



111
112
113
114
115
116
117
118
119
120
121
122
123
124
# File 'lib/nnw/ai/gpt2_loader.rb', line 111

# Transpose a flat, row-major [rows, cols] array into a flat [cols, rows] array.
#
# Element (row, col) of the input, at index row * cols + col, moves to index
# col * rows + row of the result. The result always has rows * cols slots.
#
# @param arr [Array] flat row-major input
# @param rows [Integer] number of input rows
# @param cols [Integer] number of input columns
# @return [Array] flat row-major transposed copy
def transpose(arr, rows, cols)
  Array.new(rows * cols).tap do |result|
    rows.times do |row|
      offset = row * cols
      cols.times { |col| result[col * rows + row] = arr[offset + col] }
    end
  end
end

.transpose_slice_cols(arr, rows, total_cols, col_off, len) ⇒ Object

Slice columns [col_off, col_off+len) from a [rows, total_cols] array, then transpose → [len, rows]. (Used to split the fused QKV Conv1D weight.)



128
129
130
131
132
133
134
135
136
137
138
139
140
141
# File 'lib/nnw/ai/gpt2_loader.rb', line 128

# Take columns [col_off, col_off + len) of a flat, row-major
# [rows, total_cols] array and return them transposed, as a flat [len, rows]
# array. The loader uses this to split the fused Conv1D QKV weight.
#
# @param arr [Array] flat row-major input
# @param rows [Integer] number of input rows
# @param total_cols [Integer] number of input columns (row stride)
# @param col_off [Integer] first column to take
# @param len [Integer] number of columns to take
# @return [Array] flat row-major [len, rows] result with rows * len slots
def transpose_slice_cols(arr, rows, total_cols, col_off, len)
  Array.new(rows * len).tap do |result|
    rows.times do |row|
      start = row * total_cols + col_off
      len.times { |col| result[col * rows + row] = arr[start + col] }
    end
  end
end