Class: Ignis::AI::Transformer::FeedForward
- Inherits: NN::Module
- Inheritance chain: Object → NN::Module → Ignis::AI::Transformer::FeedForward
- Defined in: lib/nnw/ai/transformer/feed_forward.rb
Overview
Feed-forward network: Linear → Activation → Dropout → Linear. Used in each Transformer block after attention.
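The pipeline above can be sketched in plain Ruby on a single input vector. This illustrates the computation only and is not the library's implementation: the helper names (`linear`, `gelu`, `dropout`, `feed_forward`), the tanh-based GELU approximation, and the inverted-dropout scaling are all assumptions, and the real class works on Tensor objects through NN::Linear and NN::Dropout.

```ruby
# Plain-Ruby sketch of Linear -> Activation -> Dropout -> Linear.
# Every helper name here is hypothetical and exists only for illustration.

# Affine map: y[i] = sum_j weights[i][j] * x[j] + bias[i]
def linear(x, weights, bias)
  weights.each_with_index.map do |row, i|
    row.zip(x).sum { |w, v| w * v } + bias[i]
  end
end

# GELU, tanh approximation (assumed; the library's exact variant is not shown).
def gelu(v)
  0.5 * v * (1.0 + Math.tanh(Math.sqrt(2.0 / Math::PI) * (v + 0.044715 * v**3)))
end

# Inverted dropout: zero each element with probability `rate`, rescale survivors.
# With rate = 0.0 (the constructor default) the input passes through unchanged.
def dropout(x, rate, rng: Random.new)
  return x if rate.zero?

  x.map { |v| rng.rand < rate ? 0.0 : v / (1.0 - rate) }
end

def feed_forward(x, w1, b1, w2, b2, rate: 0.0)
  h = linear(x, w1, b1)        # embed_dim -> ff_dim
  h = h.map { |v| gelu(v) }    # element-wise activation
  h = dropout(h, rate)         # regularization (no-op when rate is 0.0)
  linear(h, w2, b2)            # ff_dim -> embed_dim
end

# Toy sizes: embed_dim = 2, ff_dim = 3
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
b2 = [0.0, 0.0]

y = feed_forward([0.0, 0.0], w1, b1, w2, b2)
```

The output has the same length as the input (`embed_dim`), which is what lets the block sit after attention without changing the model width.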
Instance Attribute Summary
Attributes inherited from NN::Module
Instance Method Summary
- #forward(x) ⇒ Tensor
  Forward pass.
- #initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0) ⇒ FeedForward (constructor)
  A new instance of FeedForward.
- #to_s ⇒ String
Methods inherited from NN::Module
#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!
Constructor Details
#initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0) ⇒ FeedForward
Returns a new instance of FeedForward.
# File 'lib/nnw/ai/transformer/feed_forward.rb', line 14

def initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0)
  super()
  @activation = activation
  @fc1 = register_module("fc1", NN::Linear.new(embed_dim, ff_dim, device_id: device_id))
  @fc2 = register_module("fc2", NN::Linear.new(ff_dim, embed_dim, device_id: device_id))
  @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
end
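As a rough guide to the size of the two layers the constructor registers, the sketch below counts trainable parameters. It assumes each NN::Linear holds a weight matrix plus a bias vector, which this page does not confirm; `feed_forward_param_count` is a hypothetical helper, not part of the library. Dropout contributes no parameters.

```ruby
# Hypothetical helper: counts weights and biases for fc1 and fc2,
# assuming each linear layer has an (out x in) weight matrix and an optional bias.
def feed_forward_param_count(embed_dim, ff_dim, bias: true)
  fc1 = embed_dim * ff_dim + (bias ? ff_dim : 0)     # embed_dim -> ff_dim
  fc2 = ff_dim * embed_dim + (bias ? embed_dim : 0)  # ff_dim -> embed_dim
  fc1 + fc2
end

count = feed_forward_param_count(512, 2048)
```

If the bias assumption holds, the inherited #num_parameters would be expected to report the same figure for a block built with those dimensions.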
Instance Method Details
#forward(x) ⇒ Tensor
Forward pass.
# File 'lib/nnw/ai/transformer/feed_forward.rb', line 26

def forward(x)
  h = @fc1.call(x)
  h = apply_activation(h)
  h = @dropout.call(h)
  @fc2.call(h)
end
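`apply_activation` is not documented on this page. Since the constructor stores the `activation:` symbol, one plausible shape is a dispatch on that symbol. The sketch below works on plain arrays of floats rather than Tensors; only `:gelu` comes from the source (as the default), while the `:relu` branch, the erf-based GELU formula, and the error handling are assumptions.

```ruby
# Hypothetical stand-in for the private apply_activation helper.
# Operates on an Array of Floats; the real method receives a Tensor.
def apply_activation(values, activation = :gelu)
  case activation
  when :gelu
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    values.map { |v| 0.5 * v * (1.0 + Math.erf(v / Math.sqrt(2.0))) }
  when :relu
    values.map { |v| v.positive? ? v : 0.0 }
  else
    raise ArgumentError, "unknown activation: #{activation.inspect}"
  end
end

relu_out = apply_activation([-1.0, 0.0, 2.0], :relu)
gelu_out = apply_activation([-1.0, 0.0, 2.0])
```

Raising on an unknown symbol is a design choice in this sketch: it surfaces a mistyped `activation:` at the first forward pass instead of silently skipping the nonlinearity.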
#to_s ⇒ String
# File 'lib/nnw/ai/transformer/feed_forward.rb', line 34

def to_s
  "FeedForward(fc1=#{@fc1}, activation=#{@activation}, fc2=#{@fc2})"
end