Class: Torch::NN::MultiheadAttention

Inherits:
Module
  • Object
show all
Defined in:
lib/torch/nn/multihead_attention.rb

Instance Attribute Summary

Attributes inherited from Module

#training

Instance Method Summary collapse

Methods inherited from Module

#_apply, #add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #deep_dup, #double, #eval, #float, #half, #inspect, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad

Methods included from Utils

#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize(embed_dim, num_heads, dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false, kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil) ⇒ MultiheadAttention

Returns a new instance of MultiheadAttention.

Raises:

  • (ArgumentError)


4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# File 'lib/torch/nn/multihead_attention.rb', line 4

# Multi-head attention layer, mirroring PyTorch's nn.MultiheadAttention
# constructor.
#
# NOTE(review): device and dtype are accepted for signature compatibility
# but are not referenced in the visible body — confirm against upstream.
#
# @raise [ArgumentError] when embed_dim is not divisible by num_heads
def initialize(
  embed_dim,
  num_heads,
  dropout: 0.0,
  bias: true,
  add_bias_kv: false,
  add_zero_attn: false,
  kdim: nil,
  vdim: nil,
  batch_first: false,
  device: nil,
  dtype: nil
)
  super()

  @embed_dim = embed_dim
  @kdim = kdim || embed_dim
  @vdim = vdim || embed_dim

  # A single fused in-projection weight is possible only when the key and
  # value dims both equal the embedding dim.
  @qkv_same_embed_dim = [@kdim, @vdim].all? { |dim| dim == @embed_dim }

  @num_heads = num_heads
  @dropout = dropout
  @batch_first = batch_first

  # Integer division; the check below guarantees the heads split evenly.
  @head_dim = @embed_dim.div(@num_heads)
  raise ArgumentError, "embed_dim must be divisible by num_heads" if @head_dim * @num_heads != @embed_dim

  if @qkv_same_embed_dim
    # One fused [3 * E, E] weight; register the per-tensor weights as nil
    # so the parameter keys stay consistent across both layouts.
    @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
    ["q", "k", "v"].each do |prefix|
      register_parameter("#{prefix}_proj_weight", nil)
    end
  else
    # Separate projection weights sized to the (possibly distinct)
    # key/value dims.
    @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
    @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
    @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
    register_parameter('in_proj_weight', nil)
  end

  if bias
    @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
  else
    register_parameter('in_proj_bias', nil)
  end

  @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)

  # Optional learned bias rows appended to key/value sequences.
  if add_bias_kv
    @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
    @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
  else
    @bias_k = @bias_v = nil
  end

  @add_zero_attn = add_zero_attn

  reset_parameters
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method in the class Torch::NN::Module

Instance Method Details

#batch_first? ⇒ Boolean

Returns:

  • (Boolean)


64
65
66
# File 'lib/torch/nn/multihead_attention.rb', line 64

# Whether this layer expects/returns (batch, seq, feature) ordered tensors.
#
# @return [Boolean] strictly true or false, regardless of what was stored
#   in @batch_first
def batch_first?
  @batch_first ? true : false
end

#forward(query, key, value, key_padding_mask: nil, need_weights: true, attn_mask: nil) ⇒ Object



86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# File 'lib/torch/nn/multihead_attention.rb', line 86

# Runs multi-head attention over query/key/value.
#
# When batch_first? is set, inputs are transposed to (seq, batch, feature)
# before the functional call and the output is transposed back.
#
# @return [Array] [attn_output, attn_output_weights]
def forward(
  query,
  key,
  value,
  key_padding_mask: nil,
  need_weights: true,
  attn_mask: nil
)
  if batch_first?
    query = query.transpose(1, 0)
    key = key.transpose(1, 0)
    value = value.transpose(1, 0)
  end

  # Keyword options shared by both call variants below.
  options = {
    training: @training,
    key_padding_mask: key_padding_mask,
    need_weights: need_weights,
    attn_mask: attn_mask
  }

  if @qkv_same_embed_dim
    attn_output, attn_output_weights = F.multi_head_attention_forward(
      query, key, value,
      @embed_dim, @num_heads,
      @in_proj_weight, @in_proj_bias,
      @bias_k, @bias_v, @add_zero_attn,
      @dropout, @out_proj.weight, @out_proj.bias,
      **options
    )
  else
    # Separate q/k/v projection weights (kdim/vdim differ from embed_dim).
    attn_output, attn_output_weights = F.multi_head_attention_forward(
      query, key, value,
      @embed_dim, @num_heads,
      @in_proj_weight, @in_proj_bias,
      @bias_k, @bias_v, @add_zero_attn,
      @dropout, @out_proj.weight, @out_proj.bias,
      **options,
      use_separate_proj_weight: true,
      q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
    )
  end

  attn_output = attn_output.transpose(1, 0) if batch_first?

  [attn_output, attn_output_weights]
end

#reset_parameters ⇒ Object



68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# File 'lib/torch/nn/multihead_attention.rb', line 68

# Re-initializes all learnable parameters: Xavier-uniform for projection
# weights (and the optional bias_k/bias_v), zeros for projection biases.
def reset_parameters
  if @qkv_same_embed_dim
    Init.xavier_uniform!(@in_proj_weight)
  else
    [@q_proj_weight, @k_proj_weight, @v_proj_weight].each do |weight|
      Init.xavier_uniform!(weight)
    end
  end

  # Biases exist only when the layer was built with bias: true.
  if @in_proj_bias
    [@in_proj_bias, @out_proj.bias].each { |b| Init.constant!(b, 0.0) }
  end

  # bias_k / bias_v are nil unless add_bias_kv: true was given.
  [@bias_k, @bias_v].compact.each { |b| Init.xavier_uniform!(b) }
end