Class: GTCRN

Inherits:
Object
  • Object
show all
Defined in:
lib/gtcrn.rb

Constant Summary collapse

# Path of the bundled streaming GTCRN ONNX model, resolved relative to this file.
MODEL_PATH =
File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze

# Keyword options splatted into Torch.istft by #enhance_speech_waveform_channel.
# The window is the element-wise square root of a 512-point Hann window.
# Frozen so the shared options cannot be mutated by accident (the freeze is shallow).
ISTFT_OPTS =
{
  n_fft: 512,
  hop_length: 256,
  win_length: 512,
  window: Torch.hann_window(512).pow(0.5)
}.freeze

# Keyword options splatted into Torch.stft: the inverse-transform options plus
# the forward-only settings (padding mode and complex output).
STFT_OPTS =
ISTFT_OPTS.merge(
  pad_mode: "reflect",
  return_complex: true
).freeze

Instance Method Summary collapse

Constructor Details

#initialize ⇒ GTCRN

Returns a new instance of GTCRN.



20
21
22
23
# File 'lib/gtcrn.rb', line 20

# Opens an ONNX Runtime inference session on MODEL_PATH and records the
# name of every output the session reports, for use in later #run calls.
def initialize
  @session = OnnxRuntime::InferenceSession.new(MODEL_PATH)
  @output_names = @session.outputs.map { |output| output[:name] }
end

Instance Method Details

#enhance_speech(path, dest = nil) ⇒ Object



25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# File 'lib/gtcrn.rb', line 25

# Enhances the audio file at +path+ and writes the result to disk.
#
# The output location is chosen as follows:
# * no +dest+: next to the input, with ".enhanced" inserted before the extension;
# * +dest+ is an existing directory: inside it, using the input's base name
#   with ".enhanced" inserted before the extension;
# * otherwise: +dest+ itself.
#
# @param path [String, Pathname] input audio file
# @param dest [String, Pathname, nil] output file or directory
# @return [Pathname] the path the enhanced audio was saved to
# @raise [RuntimeError] if the loaded audio is not sampled at 16000 Hz
def enhance_speech(path, dest = nil)
  path = Pathname(path)
  enhanced_ext = ".enhanced" + path.extname

  target =
    if dest
      requested = Pathname(dest)
      requested.directory? ? requested / path.basename.sub_ext(enhanced_ext) : requested
    else
      path.sub_ext(enhanced_ext)
    end

  waveform, sample_rate = TorchAudio.load(path.to_path)
  unless sample_rate == 16000
    raise "Sampling rate must be 16000 Hz, but given: #{sample_rate} Hz"
  end

  # The output is saved with the same sample rate that was read from the input.
  TorchAudio.save(target.to_path, enhance_speech_waveform(waveform), sample_rate)

  target
end

#enhance_speech_waveform(waveform) ⇒ Object



44
45
46
47
48
49
50
51
52
# File 'lib/gtcrn.rb', line 44

# Enhances a waveform that is either a single channel (1D) or a collection
# of channels (2D), delegating each channel to #enhance_speech_waveform_channel.
#
# @param waveform an object responding to +ndim+; when 2D it is iterated
#   channel by channel
# @return the enhanced channel for 1D input, or the per-channel results
#   combined with Torch.stack for 2D input
# @raise [ArgumentError] if +waveform+ is neither 1D nor 2D
def enhance_speech_waveform(waveform)
  ndim = waveform.ndim
  unless ndim == 1 || ndim == 2
    raise ArgumentError, "wrong dimension of argument (given #{ndim}, expected 1D or 2D)"
  end

  # A 1D waveform is already a single channel; no stacking is needed.
  return enhance_speech_waveform_channel(waveform) if ndim == 1

  Torch.stack(waveform.map { |channel| enhance_speech_waveform_channel(channel) })
end

#enhance_speech_waveform_channel(channel) ⇒ Object



54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# File 'lib/gtcrn.rb', line 54

# Enhances a single channel by feeding its STFT to the ONNX session one
# frame at a time and inverting the collected outputs with Torch.istft.
#
# @param channel the samples of one channel, passed straight to Torch.stft
# @return the result of Torch.istft with its leading axis squeezed away
def enhance_speech_waveform_channel(channel)
  # Session inputs 1..3 supply the shapes of the three caches; each cache
  # starts out zero-filled and is replaced by the model's output every step.
  conv_cache, tra_cache, inter_cache = (1..3).map do |index|
    cache_shape = @session.inputs[index][:shape]
    OnnxRuntime::OrtValue.from_numo(Numo::SFloat.zeros(*cache_shape))
  end

  # NOTE(review): [nil] presumably prepends a batch axis, giving a 4-axis
  # array after view_as_real (the slicing below indexes four axes) - confirm.
  spectrum = Torch.stft(channel, **STFT_OPTS)
  frames = Torch.view_as_real(spectrum[nil]).numo

  # One session run per index along the second-to-last axis; the caches
  # returned by a run are the inputs of the next one.
  enhanced_frames = Array.new(frames.shape[-2]) do |step|
    mix = OnnxRuntime::OrtValue.from_numo(frames[0.., 0.., step..step, 0..])
    enh, conv_cache, tra_cache, inter_cache = @session.run(
      @output_names,
      {mix: mix, conv_cache: conv_cache, tra_cache: tra_cache, inter_cache: inter_cache},
      output_type: :ort_value
    )
    enh.numo
  end

  # Rebuild the complex spectrum from the last axis (index 0 = real, 1 = imaginary).
  merged = Numo::NArray.concatenate(enhanced_frames, axis: 2)
  real_part = Torch.from_numo(merged[0.., 0.., 0.., 0])
  imag_part = Torch.from_numo(merged[0.., 0.., 0.., 1])

  Torch.istft(Torch.complex(real_part, imag_part), **ISTFT_OPTS).squeeze(0)
end