Class: Ignis::AI::NN::Embedding

Inherits:
Module
  • Object
show all
Defined in:
lib/nnw/ai/nn/embedding.rb

Overview

Embedding layer: maps integer indices to dense vectors. The forward pass uses the gather_rows JIT kernel; the backward pass uses the scatter_add kernel, which accumulates gradients with atomicAdd.

Instance Attribute Summary collapse

Attributes inherited from Module

#training

Instance Method Summary collapse

Methods inherited from Module

#call, #eval!, #load_state_dict, #named_parameters, #num_parameters, #parameters, #state_dict, #to, #train!, #zero_grad!

Constructor Details

#initialize(num_embeddings, embedding_dim, device_id: 0) ⇒ Embedding

Returns a new instance of Embedding.

Parameters:

  • num_embeddings (Integer)

    vocabulary size

  • embedding_dim (Integer)

    dimension of embedding vectors

  • device_id (Integer) (defaults to: 0)


18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# File 'lib/nnw/ai/nn/embedding.rb', line 18

# Builds an embedding table of shape [num_embeddings, embedding_dim] on the
# given device, initialised uniformly in [-1/sqrt(embedding_dim), 1/sqrt(embedding_dim)].
#
# @param num_embeddings [Integer] vocabulary size (must be positive)
# @param embedding_dim [Integer] dimension of each embedding vector (must be positive)
# @param device_id [Integer] device on which the weight array is allocated
# @raise [ArgumentError] if num_embeddings or embedding_dim is not a positive Integer
def initialize(num_embeddings, embedding_dim, device_id: 0)
  # Fail fast on degenerate sizes: a zero-element table would launch the init
  # kernel with an empty grid, and a non-positive embedding_dim makes the
  # 1/sqrt(embedding_dim) scale infinite or raises Math::DomainError.
  unless num_embeddings.is_a?(Integer) && num_embeddings.positive?
    raise ArgumentError, "num_embeddings must be a positive Integer, got #{num_embeddings.inspect}"
  end
  unless embedding_dim.is_a?(Integer) && embedding_dim.positive?
    raise ArgumentError, "embedding_dim must be a positive Integer, got #{embedding_dim.inspect}"
  end

  super()
  @num_embeddings = num_embeddings
  @embedding_dim = embedding_dim

  # Initialize uniform[-scale, scale] (scale = 1/sqrt(embedding_dim)) via a
  # device kernel. The old host Array.new(num_embeddings*embedding_dim) was a
  # 262M-element / ~10GB array for a 128k-vocab model — infeasible. The kaiming
  # uniform kernel with bound=scale produces the same [-scale, scale] range.
  scale = 1.0 / Math.sqrt(embedding_dim)
  weight_nv = Ignis::Shared::NvArray.new(shape: [num_embeddings, embedding_dim],
                                        dtype: :float32, device_id: device_id)
  weight_nv.to_device
  n = num_embeddings * embedding_dim
  # Random 64-bit value handed to the init kernel (presumably its RNG seed).
  seed = Ignis::JIT::Kernel::U64.new(::Random.new.rand(2**64))
  init_kernel = Ignis::JIT::Kernels::Elementwise.kaiming_uniform_init
  init_kernel.launch(grid: [(n + 255) / 256], block: [256],
                     args: [weight_nv, scale, seed, n])

  @weight = register_parameter("weight",
             Tensor.new(data: weight_nv, requires_grad: true))
end

Instance Attribute Details

#weight ⇒ Tensor (readonly)

Returns weight matrix [num_embeddings, embedding_dim].

Returns:

  • (Tensor)

    weight matrix [num_embeddings, embedding_dim]



13
14
15
# File 'lib/nnw/ai/nn/embedding.rb', line 13

# @return [Tensor] the learnable weight matrix, shaped [num_embeddings, embedding_dim]
def weight; @weight; end

Instance Method Details

#forward(indices) ⇒ Tensor

Forward pass: gather rows from weight table.

Parameters:

  • indices (Tensor)

    integer indices [batch_size, seq_len] (int32 on GPU)

Returns:

  • (Tensor)

    embeddings [batch_size, seq_len, embedding_dim]



43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/nnw/ai/nn/embedding.rb', line 43

# Gathers one row of the weight table per index and, when the weight requires
# gradients, records a backward closure that scatter-adds the incoming gradient
# back into a [num_embeddings, embedding_dim] gradient buffer.
#
# @param indices [Tensor] index tensor (documented as [batch_size, seq_len], int32 on GPU)
# @return [Tensor] embeddings shaped indices.shape + [embedding_dim]
def forward(indices)
  num_indices = indices.numel
  total = num_indices * @embedding_dim
  # Enough 256-thread blocks to cover every output element; shared by the
  # gather launch here and the scatter launch in the backward closure.
  threads = 256
  grid = [(total + threads - 1) / threads]

  output_nv = Ignis::Shared::NvArray.new(shape: indices.shape + [@embedding_dim],
                                        dtype: :float32, device_id: @weight.device_id)
  output_nv.from_host(Array.new(total, 0.0))

  # A zero-sized grid is an invalid launch configuration, so an empty index
  # tensor (or zero embedding_dim) skips the kernel and returns the empty output.
  unless total.zero?
    Ignis::JIT::Kernels::Elementwise.gather_rows.launch(
      grid: grid, block: [threads],
      args: [@weight.data, indices.data, output_nv, num_indices, @embedding_dim]
    )
  end

  result = Tensor.new(data: output_nv,
                      requires_grad: @weight.requires_grad,
                      is_leaf: false)
  return result unless @weight.requires_grad

  saved_indices = indices.data
  Tape.record(result, inputs: [@weight]) do |grad|
    # NOTE(review): this zero-fills grad_weight through a host Array of
    # num_embeddings * embedding_dim floats on every backward pass, which is
    # very large for big vocabularies. Replace with a device-side zero fill
    # once NvArray exposes one.
    grad_weight = Ignis::Shared::NvArray.new(
      shape: [@num_embeddings, @embedding_dim],
      dtype: :float32, device_id: @weight.device_id)
    grad_weight.from_host(Array.new(@num_embeddings * @embedding_dim, 0.0))

    # scatter_add adds each gradient row into the grad_weight row named by its
    # index; the buffer must start zeroed because the kernel accumulates
    # (atomicAdd), so repeated indices sum their gradients.
    unless total.zero?
      Ignis::JIT::Kernels::Elementwise.scatter_add.launch(
        grid: grid, block: [threads],
        args: [grad, saved_indices, grad_weight, num_indices, @embedding_dim]
      )
    end
    [grad_weight]
  end

  result
end

#to_s ⇒ String

Returns:

  • (String)


80
81
82
# File 'lib/nnw/ai/nn/embedding.rb', line 80

# @return [String] a short summary of the table size, e.g. "Embedding(num=1000, dim=64)"
def to_s
  format('Embedding(num=%s, dim=%s)', @num_embeddings, @embedding_dim)
end