Class: Ignis::AI::NN::Embedding
Defined in: lib/nnw/ai/nn/embedding.rb
Overview
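Embedding layer: maps integer indices to dense vectors. The forward pass uses the gather_rows JIT kernel to copy the selected rows of the weight table. The backward pass uses scatter_add, which accumulates gradients with atomicAdd so that an index appearing more than once sums all of its contributions into the same weight row.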
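To make the gather/scatter semantics concrete, here is a minimal CPU reference on plain nested Ruby arrays. The helpers reference_gather and reference_scatter_add are hypothetical and not part of Ignis. They mirror only what the overview describes (row gather forward, accumulating scatter backward), not the JIT kernels themselves.

```ruby
# Hypothetical CPU reference (not part of Ignis): row gather and scatter-add
# on plain nested Ruby arrays, for checking semantics on tiny inputs.

# Forward: output row k is a copy of weight row indices[k].
def reference_gather(weight, indices)
  indices.map { |i| weight[i].dup }
end

# Backward: add each output-gradient row into the weight-gradient row it was
# gathered from. Repeated indices accumulate; on the GPU this is what
# atomicAdd guarantees when several threads hit the same row.
def reference_scatter_add(grad_out, indices, num_embeddings, embedding_dim)
  grad_weight = Array.new(num_embeddings) { Array.new(embedding_dim, 0.0) }
  indices.each_with_index do |row, k|
    embedding_dim.times { |d| grad_weight[row][d] += grad_out[k][d] }
  end
  grad_weight
end

weight  = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
indices = [1, 2, 1]
out = reference_gather(weight, indices)
# out == [[1.0, 1.1], [2.0, 2.1], [1.0, 1.1]]

grad_out = [[1.0, 1.0], [0.5, 0.5], [2.0, 2.0]]
grad_w = reference_scatter_add(grad_out, indices, 3, 2)
# grad_w == [[0.0, 0.0], [3.0, 3.0], [0.5, 0.5]]  (index 1 appears twice)
```

Row 1 receives 1.0 + 2.0 because index 1 was gathered twice; a plain overwrite would keep only one of the two contributions.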
Instance Attribute Summary
- #weight ⇒ Tensor (readonly)
  Weight matrix [num_embeddings, embedding_dim].
Attributes inherited from Module
Instance Method Summary
- #forward(indices) ⇒ Tensor
  Forward pass: gather rows from weight table.
- #initialize(num_embeddings, embedding_dim, device_id: 0) ⇒ Embedding (constructor)
  A new instance of Embedding.
- #to_s ⇒ String
Methods inherited from Module
#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!
Constructor Details
#initialize(num_embeddings, embedding_dim, device_id: 0) ⇒ Embedding
Returns a new instance of Embedding.
# File 'lib/nnw/ai/nn/embedding.rb', line 18

def initialize(num_embeddings, embedding_dim, device_id: 0)
  super()
  @num_embeddings = num_embeddings
  @embedding_dim = embedding_dim
  # Initialize uniform[-scale, scale] (scale = 1/sqrt(embedding_dim)) via a
  # device kernel. The old host Array.new(num_embeddings*embedding_dim) was a
  # 262M-element / ~10GB array for a 128k-vocab model, which is infeasible.
  # The kaiming uniform kernel with bound=scale produces the same
  # [-scale, scale] range.
  scale = 1.0 / Math.sqrt(embedding_dim)
  weight_nv = Ignis::Shared::NvArray.new(shape: [num_embeddings, embedding_dim], dtype: :float32, device_id: device_id)
  weight_nv.to_device
  n = num_embeddings * embedding_dim
  init_kernel = Ignis::JIT::Kernels::Elementwise.kaiming_uniform_init
  init_kernel.launch(grid: [(n + 255) / 256], block: [256],
                     args: [weight_nv, scale.to_f, Ignis::JIT::Kernel::U64.new(::Random.new.rand(2**64)), n])
  @weight = register_parameter("weight", Tensor.new(data: weight_nv, requires_grad: true))
end
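The constructor's initialization reduces to two pieces of arithmetic: the bound scale = 1/sqrt(embedding_dim) and the ceiling-division grid (n + 255) / 256 that covers n elements with 256-thread blocks. The sketch below reproduces both on the host for a tiny table, assuming (as the source comment states) that kaiming_uniform_init samples uniformly in [-bound, bound]. The helpers init_scale, grid_blocks and host_uniform_init are hypothetical, and a host-built array like this is precisely what the constructor avoids for large vocabularies.

```ruby
# Hypothetical host-side sketch (small tables only): draw every weight from
# U[-scale, scale] with scale = 1/sqrt(embedding_dim), and compute the 1-D
# launch grid used for n elements at 256 threads per block.
BLOCK = 256

def init_scale(embedding_dim)
  1.0 / Math.sqrt(embedding_dim)
end

# Ceiling division: enough blocks of `block` threads to cover n elements.
# Equivalent to the constructor's (n + 255) / 256 when block == 256.
def grid_blocks(n, block = BLOCK)
  (n + block - 1) / block
end

def host_uniform_init(num_embeddings, embedding_dim, rng: Random.new(0))
  scale = init_scale(embedding_dim)
  Array.new(num_embeddings) do
    # Random#rand with a Float range returns a Float inside that range.
    Array.new(embedding_dim) { rng.rand(-scale..scale) }
  end
end

w = host_uniform_init(4, 16)
scale = init_scale(16)                             # 0.25
w.flatten.all? { |x| x.between?(-scale, scale) }   # true
grid_blocks(4 * 16)                                # 1
grid_blocks(257)                                   # 2
```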
Instance Attribute Details
#weight ⇒ Tensor (readonly)
Returns weight matrix [num_embeddings, embedding_dim].
# File 'lib/nnw/ai/nn/embedding.rb', line 13

def weight
  @weight
end
Instance Method Details
#forward(indices) ⇒ Tensor
Forward pass: gather the weight rows selected by indices. The result has shape indices.shape + [embedding_dim].
# File 'lib/nnw/ai/nn/embedding.rb', line 43

def forward(indices)
  num_indices = indices.numel
  # Output keeps the index tensor's shape with embedding_dim appended.
  output_shape = indices.shape + [@embedding_dim]
  output_nv = Ignis::Shared::NvArray.new(shape: output_shape, dtype: :float32, device_id: @weight.device_id)
  output_nv.from_host(Array.new(num_indices * @embedding_dim, 0.0))
  # Launch enough 256-thread blocks to cover all num_indices * embedding_dim
  # output elements.
  kernel = Ignis::JIT::Kernels::Elementwise.gather_rows
  total = num_indices * @embedding_dim
  kernel.launch(grid: [(total + 255) / 256], block: [256],
                args: [@weight.data, indices.data, output_nv, num_indices, @embedding_dim])
  result = Tensor.new(data: output_nv, requires_grad: @weight.requires_grad, is_leaf: false)
  if @weight.requires_grad
    saved_indices = indices.data
    saved_weight = @weight
    Tape.record(result, inputs: [@weight]) do |grad|
      # scatter_add: accumulate gradients for each embedding index.
      # NOTE: grad_weight is zero-filled from a host array of
      # num_embeddings * embedding_dim floats on every backward call.
      grad_weight = Ignis::Shared::NvArray.new(
        shape: [@num_embeddings, @embedding_dim], dtype: :float32, device_id: @weight.device_id)
      grad_weight.from_host(Array.new(@num_embeddings * @embedding_dim, 0.0))
      scatter_k = Ignis::JIT::Kernels::Elementwise.scatter_add
      scatter_k.launch(grid: [(total + 255) / 256], block: [256],
                       args: [grad, saved_indices, grad_weight, num_indices, @embedding_dim])
      [grad_weight]
    end
  end
  result
end
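Both launches in forward size the grid as ceil(total / 256) blocks for total = num_indices * embedding_dim, which suggests a flat one-thread-per-output-element mapping. The simulation below assumes that mapping, tid = k * embedding_dim + d; the kernel sources are not shown in this file, so treat it as a sketch. simulate_gather_rows and simulate_scatter_add are hypothetical names. Because the grid is rounded up to whole blocks, threads past total must be skipped, and the flat row-major result reshapes to indices.shape + [embedding_dim].

```ruby
# Hypothetical simulation of the assumed thread mapping: thread tid handles
# output element (k, d) with k, d = tid.divmod(embedding_dim), over flat
# row-major float buffers.
BLOCK = 256

# Threads actually launched: whole blocks, so possibly more than total.
def launched_threads(total)
  ((total + BLOCK - 1) / BLOCK) * BLOCK
end

def simulate_gather_rows(weight_flat, indices, embedding_dim)
  total = indices.size * embedding_dim
  out = Array.new(total, 0.0)
  launched_threads(total).times do |tid|
    next if tid >= total # bounds guard for the padded tail of the last block
    k, d = tid.divmod(embedding_dim)
    out[tid] = weight_flat[indices[k] * embedding_dim + d]
  end
  out
end

def simulate_scatter_add(grad_flat, indices, num_embeddings, embedding_dim)
  total = indices.size * embedding_dim
  grad_weight = Array.new(num_embeddings * embedding_dim, 0.0)
  launched_threads(total).times do |tid|
    next if tid >= total
    k, d = tid.divmod(embedding_dim)
    # On the GPU this += must be an atomicAdd: threads whose indices[k]
    # match write the same grad_weight element concurrently.
    grad_weight[indices[k] * embedding_dim + d] += grad_flat[tid]
  end
  grad_weight
end

weight_flat = [0.0, 0.1, 1.0, 1.1, 2.0, 2.1] # 3 x 2 table, row-major
indices = [2, 0, 2]
out = simulate_gather_rows(weight_flat, indices, 2)
# out == [2.0, 2.1, 0.0, 0.1, 2.0, 2.1]
grad = simulate_scatter_add([1.0, 1.0, 0.5, 0.5, 0.25, 0.25], indices, 3, 2)
# grad == [0.5, 0.5, 0.0, 0.0, 1.25, 1.25]
```

Sequential Ruby has no race, so a plain += suffices here; the atomic is only needed once the loop body runs in parallel threads.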
#to_s ⇒ String
# File 'lib/nnw/ai/nn/embedding.rb', line 80

def to_s
  "Embedding(num=#{@num_embeddings}, dim=#{@embedding_dim})"
end