Class: Ignis::AI::Transformer::FeedForward

Inherits:
NN::Module
  • Object
Defined in:
lib/nnw/ai/transformer/feed_forward.rb

Overview

Feed-forward network: Linear → Activation → Dropout → Linear. Used in each Transformer block after attention.

Instance Attribute Summary

Attributes inherited from NN::Module

#training

Instance Method Summary

Methods inherited from NN::Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0) ⇒ FeedForward

Returns a new instance of FeedForward.

Parameters:

  • embed_dim (Integer)

    model dimension

  • ff_dim (Integer)

    feed-forward hidden dimension (typically 4x embed_dim)

  • activation (Symbol) (defaults to: :gelu)

    :gelu, :relu, or :silu

  • dropout (Float) (defaults to: 0.0)

    dropout rate

  • device_id (Integer) (defaults to: 0)

    device ID passed to both NN::Linear layers

# File 'lib/nnw/ai/transformer/feed_forward.rb', line 14

def initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0)
  super()
  @activation = activation

  @fc1 = register_module("fc1", NN::Linear.new(embed_dim, ff_dim, device_id: device_id))
  @fc2 = register_module("fc2", NN::Linear.new(ff_dim, embed_dim, device_id: device_id))
  @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
end
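
To make the effect of embed_dim and ff_dim concrete, the following sketch estimates how many weights the two layers hold. The helper feed_forward_param_count is hypothetical and not part of the library, and it assumes each NN::Linear stores a weight matrix plus a bias vector, which the source above does not confirm. Dropout contributes no parameters.

```ruby
# Hypothetical helper, not part of the documented API.
# Assumes each Linear layer has an (in x out) weight matrix and, optionally,
# a bias vector of length out.
def feed_forward_param_count(embed_dim, ff_dim, bias: true)
  fc1 = embed_dim * ff_dim + (bias ? ff_dim : 0)    # embed_dim -> ff_dim
  fc2 = ff_dim * embed_dim + (bias ? embed_dim : 0) # ff_dim -> embed_dim
  fc1 + fc2
end

embed_dim = 512
ff_dim = 4 * embed_dim # the "typically 4x embed_dim" convention
puts feed_forward_param_count(embed_dim, ff_dim) # => 2099712
```

If the module follows this layout, the inherited #num_parameters should report a comparable figure for the same dimensions.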

Instance Method Details

#forward(x) ⇒ Tensor

Forward pass: applies fc1, the activation, dropout, and then fc2.

Parameters:

  • x (Tensor)

Returns:

  • (Tensor)


# File 'lib/nnw/ai/transformer/feed_forward.rb', line 26

def forward(x)
  h = @fc1.call(x)
  h = apply_activation(h)
  h = @dropout.call(h)
  @fc2.call(h)
end
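
The body of apply_activation is not shown on this page. As an illustration only, the sketch below reproduces the same four steps on a single vector using plain Ruby arrays. Every name in it (ACTIVATIONS, linear, dropout, feed_forward) is hypothetical, the GELU uses the common tanh approximation, and the dropout is the usual inverted form that is active only in training; the library's own choices may differ.

```ruby
# Plain-Ruby sketch of Linear -> Activation -> Dropout -> Linear.
# None of this is the library's implementation; all names are illustrative.
ACTIVATIONS = {
  relu: ->(v) { v > 0 ? v : 0.0 },
  silu: ->(v) { v / (1.0 + Math.exp(-v)) },
  # tanh approximation of GELU (an assumption; an exact erf form is also common)
  gelu: lambda do |v|
    0.5 * v * (1.0 + Math.tanh(Math.sqrt(2.0 / Math::PI) * (v + 0.044715 * v**3)))
  end
}.freeze

# y = W x + b, with W given as an array of rows
def linear(x, weight, bias)
  weight.each_with_index.map do |row, i|
    row.zip(x).sum { |w, v| w * v } + bias[i]
  end
end

# Inverted dropout: a no-op outside training, otherwise zeroes each value
# with probability `rate` and rescales the survivors by 1 / (1 - rate)
def dropout(h, rate:, training:, rng: Random.new)
  return h if !training || rate.zero?
  h.map { |v| rng.rand < rate ? 0.0 : v / (1.0 - rate) }
end

def feed_forward(x, w1, b1, w2, b2, activation: :gelu, rate: 0.0, training: false)
  act = ACTIVATIONS.fetch(activation) do
    raise ArgumentError, "unknown activation: #{activation}"
  end
  h = linear(x, w1, b1)                          # fc1
  h = h.map(&act)                                # activation
  h = dropout(h, rate: rate, training: training) # dropout
  linear(h, w2, b2)                              # fc2
end

# embed_dim = 2, ff_dim = 3
x  = [1.0, -2.0]
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]]
b2 = [0.5, 0.0]
puts feed_forward(x, w1, b1, w2, b2, activation: :relu).inspect # => [1.5, 2.0]
```

The output has the same length as the input (embed_dim), which is what lets the block sit between attention and the next layer without reshaping.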

#to_s ⇒ String

Returns:

  • (String)


# File 'lib/nnw/ai/transformer/feed_forward.rb', line 34

def to_s
  "FeedForward(fc1=#{@fc1}, activation=#{@activation}, fc2=#{@fc2})"
end