Class: Ignis::AI::NN::RMSNorm

Inherits:
Module
  • Object
Defined in:
lib/nnw/ai/nn/rms_norm.rb

Overview

Root-mean-square layer normalization: y = gamma * x / sqrt(mean(x^2) + eps), where the mean is taken over the last (feature) dimension. It is used by Llama, Qwen, Mistral, SmolLM and Phi models. Unlike LayerNorm, there is no mean subtraction and no bias; the only learned parameter is a per-feature scale (gamma).
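To make the formula concrete, here is a plain-Ruby sketch on a single feature vector, alongside a stripped-down LayerNorm for contrast. Both helpers (`rms_norm_vector`, `layer_norm_vector`) are hypothetical and operate on host Arrays rather than Ignis tensors; the LayerNorm helper deliberately omits its own scale and bias so that only the mean subtraction differs.

```ruby
# Hypothetical helper (not Ignis API): y_i = gamma_i * x_i / sqrt(mean(x^2) + eps)
def rms_norm_vector(x, gamma, eps: 1e-6)
  mean_sq = x.sum { |v| v * v } / x.size.to_f
  inv_rms = 1.0 / Math.sqrt(mean_sq + eps)
  x.each_with_index.map { |v, i| gamma[i] * v * inv_rms }
end

# Hypothetical LayerNorm without scale/bias, for comparison: subtracts the mean first.
def layer_norm_vector(x, eps: 1e-6)
  mean = x.sum / x.size.to_f
  var = x.sum { |v| (v - mean)**2 } / x.size.to_f
  x.map { |v| (v - mean) / Math.sqrt(var + eps) }
end

x = [1.0, 2.0, 3.0, 4.0]
gamma = Array.new(x.size, 1.0) # gamma at ones, as in a freshly built layer

rms = rms_norm_vector(x, gamma)
ln = layer_norm_vector(x)

puts rms.inspect # all positive here: the input's offset is preserved
puts ln.inspect  # recentred around zero
```

With gamma at ones, the RMSNorm output has a mean square of almost exactly 1 but keeps the input's offset, whereas LayerNorm also forces the row to zero mean.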

Instance Attribute Summary

Attributes inherited from Module

#training

Instance Method Summary

Methods inherited from Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(normalized_shape, eps: 1e-6, device_id: 0) ⇒ RMSNorm

Returns a new instance of RMSNorm.

Parameters:

  • normalized_shape (Integer)

    size of the last (feature) dimension of the input; also the length of gamma

  • eps (Float) (defaults to: 1e-6)

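    small constant added inside the square root for numerical stability; published checkpoints typically use 1e-6 or 1e-5, so match the value in the model's config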

  • device_id (Integer) (defaults to: 0)

    index of the device on which gamma (the weight) is allocated

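The role of `eps` is easiest to see on degenerate rows. The sketch below uses a hypothetical host-side helper, `rms_norm_row`, with the same formula as the layer. It shows that an all-zero row turns into NaN without `eps` and into zeros with it, and that when a row's mean square is much smaller than `eps`, the output is damped instead of being rescaled to unit RMS.

```ruby
# Hypothetical helper (not Ignis API), same formula as the layer:
#   y_i = gamma_i * x_i / sqrt(mean(x^2) + eps)
def rms_norm_row(x, gamma, eps:)
  mean_sq = x.sum { |v| v * v } / x.size.to_f
  x.each_with_index.map { |v, i| gamma[i] * v / Math.sqrt(mean_sq + eps) }
end

ones = [1.0, 1.0, 1.0]
zeros = [0.0, 0.0, 0.0]

# Without eps the denominator is sqrt(0.0) = 0.0, and 0.0 / 0.0 is NaN.
without_eps = rms_norm_row(zeros, ones, eps: 0.0)
# With eps the denominator stays positive, so an all-zero row maps to zeros.
with_eps = rms_norm_row(zeros, ones, eps: 1e-6)

p without_eps.map(&:nan?) # => [true, true, true]
p with_eps                # => [0.0, 0.0, 0.0]

# mean(x^2) = 1e-8 is far below eps = 1e-6, so |y| ends up near 0.1 rather than 1.
tiny = [1e-4, -1e-4, 1e-4]
damped = rms_norm_row(tiny, ones, eps: 1e-6)
```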

# File 'lib/nnw/ai/nn/rms_norm.rb', line 18

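def initialize(normalized_shape, eps: 1e-6, device_id: 0)
  super()
  @normalized_shape = normalized_shape
  @eps = eps

  # gamma: one float32 scale per feature, allocated on device `device_id`
  weight_nv = Ignis::Shared::NvArray.new(shape: [normalized_shape],
                                        dtype: :float32, device_id: device_id)
  # Copy a host array of ones into gamma, so a fresh layer applies plain RMS normalization
  weight_nv.from_host(Array.new(normalized_shape, 1.0))
  # Register gamma as a trainable parameter under the name "weight"
  @weight = register_parameter("weight",
                               Tensor.new(data: weight_nv, requires_grad: true))
end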

Instance Attribute Details

#weight ⇒ Tensor (readonly)

Returns gamma (scale), initialized to ones.

Returns:

  • (Tensor)

    gamma (scale), initialized to ones



# File 'lib/nnw/ai/nn/rms_norm.rb', line 13

def weight
  @weight
end
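Because `weight` is registered with `requires_grad: true`, training needs gradients with respect to both gamma and the input. The sketch below writes them out for a single row and checks them against central finite differences. `rms_forward`, `rms_backward` and `numeric_grad` are hypothetical helpers on host Arrays; this page does not describe how Ignis computes the backward pass internally.

```ruby
# Hypothetical reference code (not Ignis API).
# Forward: y_i = g_i * x_i / r, with r = sqrt(mean(x^2) + eps).
def rms_forward(x, g, eps)
  r = Math.sqrt(x.sum { |v| v * v } / x.size.to_f + eps)
  x.each_with_index.map { |v, i| g[i] * v / r }
end

# Backward for an upstream gradient dy = dL/dy:
#   dL/dg_i = dy_i * x_i / r
#   dL/dx_j = g_j * dy_j / r - x_j / (n * r^3) * sum_i(dy_i * g_i * x_i)
def rms_backward(x, g, dy, eps)
  n = x.size
  r = Math.sqrt(x.sum { |v| v * v } / n.to_f + eps)
  dot = (0...n).sum { |i| dy[i] * g[i] * x[i] }
  dg = (0...n).map { |i| dy[i] * x[i] / r }
  dx = (0...n).map { |j| g[j] * dy[j] / r - x[j] * dot / (n * r**3) }
  [dx, dg]
end

# Central finite differences of the scalar returned by the block.
def numeric_grad(vec, h = 1e-6)
  vec.each_index.map do |k|
    plus = vec.dup
    minus = vec.dup
    plus[k] += h
    minus[k] -= h
    (yield(plus) - yield(minus)) / (2 * h)
  end
end

x = [0.3, -1.2, 2.0, 0.7]
g = [1.0, 0.5, -0.8, 1.5]
dy = [0.2, -0.1, 0.4, 1.0]
eps = 1e-6
# Scalar loss L = sum_i dy_i * y_i, so dL/dy is exactly dy.
loss = ->(xs, gs) { rms_forward(xs, gs, eps).zip(dy).sum { |y, d| y * d } }

dx, dg = rms_backward(x, g, dy, eps)
num_dx = numeric_grad(x) { |xs| loss.call(xs, g) }
num_dg = numeric_grad(g) { |gs| loss.call(x, gs) }
```

The input gradient has two terms: the direct path through x_j, and a correction through r, which depends on every element of the row.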

Instance Method Details

#forward(x) ⇒ Tensor

Parameters:

  • x (Tensor)

    input of shape [*, normalized_shape]; each row along the last dimension is normalized independently

Returns:

  • (Tensor)

    tensor of the same shape as x, normalized over the last dimension and scaled by gamma


# File 'lib/nnw/ai/nn/rms_norm.rb', line 32

def forward(x)
  x.rms_norm(@weight, eps: @eps)
end
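The [*, normalized_shape] contract means every leading dimension acts as a batch dimension, and each innermost row is normalized on its own. The hypothetical `PlainRMSNorm` class below is a minimal sketch of that behaviour, assuming host-side nested Arrays instead of `Tensor`/`NvArray`, with no autograd or device placement. It mirrors the constructor's ones-initialized gamma and the `to_s` format, but it is not the Ignis implementation, which delegates to `rms_norm` on the input tensor.

```ruby
# Hypothetical pure-Ruby mirror of RMSNorm (not Ignis API): host arrays only.
class PlainRMSNorm
  attr_reader :weight

  def initialize(normalized_shape, eps: 1e-6)
    @normalized_shape = normalized_shape
    @eps = eps
    # gamma starts at ones, so a fresh layer performs plain RMS normalization.
    @weight = Array.new(normalized_shape, 1.0)
  end

  # Accepts nested arrays of shape [*, normalized_shape] and normalizes
  # each innermost row independently.
  def forward(x)
    return x.map { |row| forward(row) } if x.first.is_a?(Array)

    unless x.size == @normalized_shape
      raise ArgumentError, "last dimension #{x.size} != #{@normalized_shape}"
    end

    inv_rms = 1.0 / Math.sqrt(x.sum { |v| v * v } / x.size.to_f + @eps)
    x.each_with_index.map { |v, i| @weight[i] * v * inv_rms }
  end

  def to_s
    "PlainRMSNorm(#{@normalized_shape}, eps=#{@eps})"
  end
end

norm = PlainRMSNorm.new(4)
batch = [[[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5]],
         [[10.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]]] # shape [2, 2, 4]
out = norm.forward(batch) # same nesting as batch; every row now has RMS close to 1
puts norm
```

Recursing over the leading dimensions keeps the layer agnostic to how many batch or sequence axes precede the feature axis, which is what the [*, normalized_shape] signature promises.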

#to_s ⇒ String

Returns:

  • (String)

    summary of the form RMSNorm(normalized_shape, eps=eps)


# File 'lib/nnw/ai/nn/rms_norm.rb', line 37

def to_s
  "RMSNorm(#{@normalized_shape}, eps=#{@eps})"
end