Class: Toy::GPT2

Inherits:
Object
Defined in:
lib/toy/models/toy_gpt2.rb

Overview

GPT-2: a decoder-only transformer language model. The field is named `stack` (not `blocks`) for readability, as in "the stack of N transformer blocks". The name is kept independently of the older Spinel field-name-collapse constraint that originally forced it.

Constructor Details

#initialize(cfg) ⇒ GPT2

Returns a new instance of GPT2.



# File 'lib/toy/models/toy_gpt2.rb', line 78

def initialize(cfg)
  @cfg         = cfg
  @token_embed = Toy::Embedding.new(cfg.vocab, cfg.d_model)
  @pos_embed   = Toy::Embedding.new(cfg.ctx,   cfg.d_model)
  @final_norm  = Toy::LayerNorm.new(cfg.d_model)

  # Block stack: literal-seed + push so Spinel infers
  # PtrArray<Toy::GPT2Block>.
  @stack = [Toy::GPT2Block.new(cfg)]
  li = 1
  while li < cfg.n_layers
    @stack.push(Toy::GPT2Block.new(cfg))
    li += 1
  end
end
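
The seed-then-push pattern used by the constructor can be tried in isolation. The sketch below is only an illustration: `DemoBlock` and `build_stack` are hypothetical stand-ins, not part of Toy, and the claim about Spinel's type inference comes from the source comment rather than from this code.

```ruby
# Hypothetical stand-in for Toy::GPT2Block; used for illustration only.
DemoBlock = Struct.new(:d_model)

# Build a homogeneous array by seeding it with one literal element and
# pushing the rest, mirroring the constructor's while loop.
def build_stack(n_layers, d_model)
  stack = [DemoBlock.new(d_model)]
  li = 1
  while li < n_layers
    stack.push(DemoBlock.new(d_model))
    li += 1
  end
  stack
end

stack = build_stack(4, 8)
puts stack.length                            # 4
puts stack.all? { |b| b.is_a?(DemoBlock) }   # true
```

One consequence of seeding with a literal: the array always holds at least one block, even if `n_layers` were 0. The loops in `forward` and `param_count` iterate `while li < @cfg.n_layers`, so they would skip that extra block in such a configuration.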

Instance Attribute Details

#cfg ⇒ Object

Returns the value of attribute cfg.



# File 'lib/toy/models/toy_gpt2.rb', line 76

def cfg
  @cfg
end

#final_norm ⇒ Object

Returns the value of attribute final_norm.



# File 'lib/toy/models/toy_gpt2.rb', line 76

def final_norm
  @final_norm
end

#pos_embed ⇒ Object

Returns the value of attribute pos_embed.



# File 'lib/toy/models/toy_gpt2.rb', line 76

def pos_embed
  @pos_embed
end

#stack ⇒ Object

Returns the value of attribute stack.



# File 'lib/toy/models/toy_gpt2.rb', line 76

def stack
  @stack
end

#token_embed ⇒ Object

Returns the value of attribute token_embed.



# File 'lib/toy/models/toy_gpt2.rb', line 76

def token_embed
  @token_embed
end

Instance Method Details

#algorithm ⇒ Object

Phuong–Hutter style algorithm card for the whole model. See arXiv:2207.09238 for the formalism. Mamba (arXiv:2312.00752) and FlashAttention (arXiv:2205.14135) Algorithm 1 are the modern exemplars for shape-annotated pseudocode.

`algorithm` returns the structured form (Toy::Card); `algorithm_card` renders it to the human-readable Phuong–Hutter text. The structured form is what prep/card_to_code.rb consumes for round-trip parsing.



# File 'lib/toy/models/toy_gpt2.rb', line 133

def algorithm
  c = Toy::Card.new("Toy::GPT2.forward(x, p_start)", "HF GPT-2 family")
  c.add_input("x",       "{1..V}^T", "token IDs")
  c.add_input("p_start", "",        "absolute position of x[0]")
  c.add_output("P",      "R^{T×V}",  "logits")
  c.add_hyper("V",   @cfg.vocab.to_s)
  c.add_hyper("D",   @cfg.d_model.to_s)
  c.add_hyper("H",   @cfg.n_heads.to_s)
  c.add_hyper("D_f", @cfg.d_ff.to_s)
  c.add_hyper("N",   @cfg.n_layers.to_s)
  c.add_hyper("ctx", @cfg.ctx.to_s)
  c.add_param("W_e",         "R^{V×D}",   "token embeddings")
  c.add_param("W_p",         "R^{ctx×D}", "learned absolute positions")
  c.add_param("θ_block_ℓ",   "(ℓ=1..N)",  "per-block; see GPT2Block")
  c.add_param("γ_f, β_f",    "R^D",       "final LayerNorm")
  c.add_param_extra("(total " + Toy.fmt_count(param_count) +
                    ", embeddings tied: logits = e · W_e^⊤)")
  c.step_bind("e", "W_e[x] + W_p[p_start : p_start+T]", "e ∈ R^{T×D}")
  c.step_loop("ℓ ← 1, …, N", "")
  c.step_update("e", "e + Attn(LN(e; γ_ℓ^1, β_ℓ^1, ε); θ_ℓ^attn)",
                "e ∈ R^{T×D}", "")
  c.step_update("e", "e + FFN (LN(e; γ_ℓ^2, β_ℓ^2, ε); θ_ℓ^ffn )",
                "e ∈ R^{T×D}", "")
  c.step_loop_close
  c.step_update("e", "LN(e; γ_f, β_f, ε)", "e ∈ R^{T×D}", "")
  c.step_bind("P", "e · W_e^⊤",            "P ∈ R^{T×V}")
  c.step_return("P")
  c
end

#algorithm_card ⇒ Object



# File 'lib/toy/models/toy_gpt2.rb', line 163

def algorithm_card; algorithm.render_pseudocode; end
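
The split between a structured card and its rendered text can be illustrated with a much smaller builder. The `MiniCard` class below is hypothetical and is not Toy::Card: its method names only echo the calls shown in `algorithm`, and its output format is invented for the example. It is a sketch of the general idea (store steps as data, render them later), not of the actual implementation.

```ruby
# Hypothetical miniature of a structured algorithm card. This is NOT
# Toy::Card; it only illustrates keeping steps as data and rendering later.
class MiniCard
  def initialize(title)
    @title = title
    @steps = []          # each step is [kind, lhs, rhs]
  end

  # Record an assignment step and return self so calls can be chained.
  def step_bind(lhs, rhs)
    @steps << [:bind, lhs, rhs]
    self
  end

  # Record the return step.
  def step_return(name)
    @steps << [:return, name, nil]
    self
  end

  # Render numbered pseudocode lines from the stored steps.
  def render_pseudocode
    lines = ["Algorithm: #{@title}"]
    @steps.each_with_index do |(kind, lhs, rhs), i|
      body = kind == :bind ? "#{lhs} ← #{rhs}" : "return #{lhs}"
      lines << "#{i + 1}  #{body}"
    end
    lines.join("\n")
  end
end

card = MiniCard.new("demo(x)")
card.step_bind("e", "W_e[x]").step_bind("P", "e · W_e^⊤").step_return("P")
puts card.render_pseudocode
```

Because the steps are kept as plain data until rendering, a separate tool can read them without parsing the rendered text, which is the role the source assigns to the structured form.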

#algorithm_card_full ⇒ Object

Recursive card — model + block + sub-ops inlined.



# File 'lib/toy/models/toy_gpt2.rb', line 166

def algorithm_card_full
  blk = @stack[0]
  s = algorithm_card + "\n\n"
  s = s + "─── sub-algorithms ─────────────────────────────────────────────────────\n\n"
  s = s + blk.algorithm_card    + "\n\n"
  s = s + blk.ln1.algorithm_card  + "\n\n"
  s = s + blk.attn.algorithm_card + "\n\n"
  s = s + blk.ffn.algorithm_card
  s
end

#forward(ids, start_pos) ⇒ Object

ids: Array<Int> (length T), start_pos: Int → logits [T, V]



# File 'lib/toy/models/toy_gpt2.rb', line 95

def forward(ids, start_pos)
  x = @token_embed.lookup(ids)                       # [T, D]
  x.add!(@pos_embed.slice(start_pos, ids.length))    # [T, D]

  li = 0
  while li < @cfg.n_layers
    x = @stack[li].forward(x)                        # [T, D]
    li += 1
  end

  x_final = @final_norm.forward(x)                   # [T, D]
  x_final.matmul_t(@token_embed.weight)              # [T, V]
end
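
The shape flow of `forward` (embed, add positions, project through the tied embedding matrix) can be traced with plain Ruby arrays. This is a minimal sketch under stated assumptions: the sizes and weight values are invented, `toy_forward` is a hypothetical helper, and the transformer blocks and final LayerNorm are omitted, so only the embedding and tied-logit steps are shown.

```ruby
# Toy sizes (hypothetical): V = 3 tokens, D = 2 dims, ctx = 4 positions.
W_E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]              # [V, D]
W_P = [[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]  # [ctx, D]

# Embed ids, add the positional rows start_pos...start_pos+T, then project
# back through W_E transposed (tied weights). Blocks and the final LayerNorm
# are deliberately left out, so this only illustrates the shapes.
def toy_forward(ids, start_pos)
  x = ids.each_with_index.map do |id, t|
    W_E[id].zip(W_P[start_pos + t]).map { |e, p| e + p }   # row of [T, D]
  end
  x.map do |row|
    W_E.map { |w| row.zip(w).sum { |a, b| a * b } }        # row of [T, V]
  end
end

logits = toy_forward([2, 0], 1)
p logits.length      # 2 (T)
p logits[0].length   # 3 (V)
p logits             # [[1.5, 1.0, 2.5], [1.0, 0.5, 1.5]]
```

Passing `start_pos = 1` selects positional rows 1 and 2 rather than 0 and 1, which corresponds to `p_start` ("absolute position of x[0]") in the algorithm card.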

#param_count ⇒ Object

Total trainable parameter count. Tied embeddings counted once.



# File 'lib/toy/models/toy_gpt2.rb', line 110

def param_count
  total = @token_embed.param_count + @pos_embed.param_count +
          @final_norm.param_count
  li = 0
  while li < @cfg.n_layers
    total = total + @stack[li].param_count
    li += 1
  end
  total
end
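
As a worked example of the "tied embeddings counted once" rule, the sketch below recomputes the total for the standard GPT-2 small hyperparameters (V = 50257, ctx = 1024, D = 768, D_f = 3072, N = 12). The per-block breakdown is an assumption: it follows the Hugging Face GPT-2 layout (biases on every linear layer, fused QKV projection), since the definition of Toy::GPT2Block is not shown here. `block_params` and `gpt2_params` are hypothetical helper names.

```ruby
# Parameter count for one block, ASSUMING the Hugging Face GPT-2 layout
# (biases on every linear layer, fused QKV projection). Toy::GPT2Block may
# differ; its definition is not shown here.
def block_params(d, d_ff)
  ln   = 2 * d                                  # gamma + beta
  attn = (d * 3 * d + 3 * d) + (d * d + d)      # QKV proj + output proj
  ffn  = (d * d_ff + d_ff) + (d_ff * d + d)     # up proj + down proj
  ln + attn + ln + ffn
end

# Whole model: the token table is counted once (it is reused as the output
# projection), plus positions, N blocks, and the final LayerNorm.
def gpt2_params(vocab, ctx, d, d_ff, n_layers)
  vocab * d + ctx * d + n_layers * block_params(d, d_ff) + 2 * d
end

puts block_params(768, 3072)                      # 7087872
puts gpt2_params(50_257, 1024, 768, 3072, 12)     # 124439808
```

Counting the output projection separately would add another V × D = 38,597,376 parameters, which is why the tying matters for the reported total.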