Class: Toy::GPT2Block

Inherits:
Object
  • Object
show all
Defined in:
lib/toy/models/toy_gpt2.rb

Overview

One transformer block: pre-norm, residual after attention, pre-norm, residual after FFN.

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(cfg) ⇒ GPT2Block

Returns a new instance of GPT2Block.



39
40
41
42
43
44
# File 'lib/toy/models/toy_gpt2.rb', line 39

# Builds the four sub-layers of one transformer block.
#
# @param cfg [#d_model, #n_heads, #d_ff] configuration object; only these
#   three readers are consulted here.
def initialize(cfg)
  width = cfg.d_model

  # Construction order is kept as-is (ln1, ln2, attn, ffn) in case any
  # sub-layer initialiser draws from a shared RNG.
  @ln1  = Toy::LayerNorm.new(width)
  @ln2  = Toy::LayerNorm.new(width)
  @attn = Toy::CausalSelfAttention.new(width, cfg.n_heads)
  # NOTE(review): :gelu_new presumably selects the FFN activation — confirm
  # against Toy::FFN.
  @ffn  = Toy::FFN.new(width, cfg.d_ff, :gelu_new)
end

Instance Attribute Details

#attn ⇒ Object

Returns the value of attribute attn.



37
38
39
# File 'lib/toy/models/toy_gpt2.rb', line 37

# Reader for the attention sub-layer held in @attn.
#
# @return [Object] whatever #initialize stored in @attn
def attn; @attn; end

#ffn ⇒ Object

Returns the value of attribute ffn.



37
38
39
# File 'lib/toy/models/toy_gpt2.rb', line 37

# Reader for the feed-forward sub-layer held in @ffn.
#
# @return [Object] whatever #initialize stored in @ffn
def ffn; @ffn; end

#ln1 ⇒ Object

Returns the value of attribute ln1.



37
38
39
# File 'lib/toy/models/toy_gpt2.rb', line 37

# Reader for the first layer norm (the one applied before attention in
# #forward), held in @ln1.
#
# @return [Object] whatever #initialize stored in @ln1
def ln1; @ln1; end

#ln2 ⇒ Object

Returns the value of attribute ln2.



37
38
39
# File 'lib/toy/models/toy_gpt2.rb', line 37

# Reader for the second layer norm (the one applied before the FFN in
# #forward), held in @ln2.
#
# @return [Object] whatever #initialize stored in @ln2
def ln2; @ln2; end

Instance Method Details

#algorithm ⇒ Object



58
59
60
61
62
63
64
65
66
# File 'lib/toy/models/toy_gpt2.rb', line 58

# Describes the block's forward pass as a Toy::Card: one input, one output,
# two residual update steps and a final return step.
#
# @return [Toy::Card] the populated card
def algorithm
  Toy::Card.new("GPT2Block.forward(x)", "").tap do |card|
    card.add_input("x",  "R^{T×D}", "")
    card.add_output("x", "R^{T×D}", "")

    # The two steps mirror the two lines of #forward.
    card.step_update("x", "x + Attn(LN(x; γ_1, β_1, ε))", "", "residual after attention")
    card.step_update("x", "x + FFN (LN(x; γ_2, β_2, ε))", "", "residual after FFN")

    card.step_return("x")
  end
end

#algorithm_card ⇒ Object



68
# File 'lib/toy/models/toy_gpt2.rb', line 68

# Convenience wrapper: builds the card via #algorithm and returns the result
# of calling render_pseudocode on it.
#
# @return [Object] whatever render_pseudocode returns
def algorithm_card
  algorithm.render_pseudocode
end

#forward(x) ⇒ Object

x: [T, D] → [T, D]



47
48
49
50
51
# File 'lib/toy/models/toy_gpt2.rb', line 47

# Runs the block: normalise, attend, add back; normalise, feed-forward, add back.
#
# Both residual sums are applied to +x+ through +add!+, and that same +x+ is
# returned — callers that still need the original input should pass a copy.
#
# @param x [Object] block input; documented on this page as [T, D]
# @return [Object] the same +x+ object after both residual additions
def forward(x)
  # residual after attention
  attn_out = @attn.forward(@ln1.forward(x))
  x.add!(attn_out)

  # residual after FFN — ln2 sees x with the attention residual already applied
  ffn_out = @ffn.forward(@ln2.forward(x))
  x.add!(ffn_out)

  x
end

#param_count ⇒ Object



53
54
55
56
# File 'lib/toy/models/toy_gpt2.rb', line 53

# Total parameter count of the block: the sum of +param_count+ over the two
# layer norms, the attention layer and the FFN.
#
# @return [Integer] combined count (assuming each sub-layer reports an Integer)
def param_count
  [@ln1, @ln2, @attn, @ffn].sum(&:param_count)
end