Class: GRX::NN::Module

Inherits:
Object
  • Object
show all
Defined in:
lib/grx/nn.rb

Overview

Module — clase base para todas las capas

Instance Method Summary collapse

Instance Method Details

#call(*args) ⇒ Object

Delega en #forward, que cada subclase debe implementar.



33
34
35
# File 'lib/grx/nn.rb', line 33

# Runs the layer by delegating to #forward, which each subclass implements.
#
# Positional arguments, keyword arguments and the block (if any) are all
# passed through unchanged. This lets a subclass declare keyword parameters
# on #forward (e.g. +forward(x, training: true)+) and callers use
# +layer.call(x, training: false)+. With a bare +*args+ splat, Ruby 3 would
# collapse the keywords into a positional Hash and drop the block.
#
# @return [Object] whatever the subclass's #forward returns
def call(*args, **kwargs, &block)
  forward(*args, **kwargs, &block)
end

#parameters ⇒ Object

Retorna todos los parámetros entrenables (para pasarlos al optimizador)



10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# File 'lib/grx/nn.rb', line 10

# Returns every trainable tensor reachable from this module, e.g. to hand
# to an optimizer.
#
# Instance variables are scanned in definition order:
# * a Tensor is included when its +requires_grad+ is truthy;
# * a sub-Module contributes its own #parameters;
# * an Array is searched element by element, including nested Arrays;
# * any other value is ignored.
#
# The result is de-duplicated by object identity. A tensor or submodule
# referenced from several places (e.g. tied weights) therefore appears only
# once, and a consumer iterating the list touches each tensor a single time.
#
# @return [Array<Tensor>]
def parameters
  instance_variables
    .flat_map { |var| trainable_parameters_in(instance_variable_get(var)) }
    .uniq(&:object_id)
end

# Extracts the trainable tensors held by a single value (see #parameters).
# Arrays are walked recursively, so arbitrarily nested lists of layers work.
# Declared with `private def` so it does not change the visibility of
# methods defined after it.
private def trainable_parameters_in(value)
  case value
  when Tensor then value.requires_grad ? [value] : []
  when Module then value.parameters
  when Array  then value.flat_map { |item| trainable_parameters_in(item) }
  else []
  end
end

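#zero_grad ⇒ Object

Pone a cero el gradiente de cada parámetro entrenable (llama a zero_grad! sobre cada elemento de #parameters).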



28
29
30
# File 'lib/grx/nn.rb', line 28

# Clears the gradient of every trainable parameter by sending +zero_grad!+
# to each tensor returned by #parameters.
#
# @return [Array] the same parameter list that was cleared
def zero_grad
  params = parameters
  params.each { |param| param.zero_grad! }
end