Class: GTCRN

Inherits:
Object
  • Object
show all
Defined in:
lib/gtcrn.rb

Constant Summary collapse

# Path of the bundled ONNX model file, built relative to this file's directory.
MODEL_PATH =
File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze
# Keyword options splatted into Torch.istft. The window is a 512-point Hann
# window raised to the power 0.5 (i.e. a square-root Hann window).
# Frozen so that the shared settings cannot be mutated by accident.
ISTFT_OPTS =
{
  n_fft: 512,
  hop_length: 256,
  win_length: 512,
  window: Torch.hann_window(512).pow(0.5)
}.freeze
# Keyword options splatted into Torch.stft: the ISTFT settings plus reflect
# padding and complex-valued output. Frozen for the same reason as ISTFT_OPTS.
STFT_OPTS =
ISTFT_OPTS.merge(
  pad_mode: "reflect",
  return_complex: true
).freeze

Instance Method Summary collapse

Constructor Details

#initialize ⇒ GTCRN

Returns a new instance of GTCRN.



20
21
22
# File 'lib/gtcrn.rb', line 20

# Builds a new enhancer by opening an ONNX Runtime inference session on the
# model file located at MODEL_PATH. The session is kept in @session.
def initialize
  @session = OnnxRuntime::InferenceSession.new(MODEL_PATH)
end

Instance Method Details

#enhance_speech(path, dest = nil) ⇒ Object



24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# File 'lib/gtcrn.rb', line 24

# Enhances the audio file at +path+ and writes the result to disk.
#
# The output location is chosen as follows:
# * +dest+ omitted (nil/false): next to the input, with ".enhanced" inserted
#   before the extension (e.g. "a.wav" -> "a.enhanced.wav").
# * +dest+ is an existing directory: inside that directory, using the input's
#   base name with ".enhanced" inserted before the extension.
# * otherwise: exactly +dest+.
#
# @param path [String, Pathname] audio file to read
# @param dest [String, Pathname, nil] output file or existing directory
# @return [Pathname] the path the enhanced audio was saved to
# @raise [RuntimeError] if the loaded audio's sample rate is not 16000 Hz
def enhance_speech(path, dest = nil)
  source = Pathname(path)
  suffix = ".enhanced#{source.extname}"

  target =
    if dest
      candidate = Pathname(dest)
      candidate.directory? ? candidate / source.basename.sub_ext(suffix) : candidate
    else
      source.sub_ext(suffix)
    end

  waveform, sample_rate = TorchAudio.load(source.to_path)
  unless sample_rate == 16000
    raise "Sampling rate must be 16000 Hz, but given: #{sample_rate} Hz"
  end

  cleaned = enhance_speech_waveform(waveform)
  TorchAudio.save(target.to_path, cleaned.squeeze, sample_rate)

  target
end

#enhance_speech_waveform(waveform) ⇒ Object



43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# File 'lib/gtcrn.rb', line 43

# Enhances a waveform by feeding its STFT through the ONNX session one slice
# at a time and inverting the concatenated model outputs with Torch.istft.
#
# Only +waveform[0]+ is processed. The three cache tensors start as zeros
# (shaped after session inputs 1..3) and are replaced by the session's
# returned caches after every step, so state carries over between slices.
#
# @param waveform [Torch::Tensor] audio tensor; presumably (channels, samples)
#   as returned by TorchAudio.load -- TODO confirm
# @return [Torch::Tensor] the tensor returned by Torch.istft
def enhance_speech_waveform(waveform)
  conv_cache, tra_cache, inter_cache = (1..3).map do |index|
    Numo::SFloat.zeros(*@session.inputs[index][:shape])
  end

  # [nil] adds a leading axis before the complex spectrum is split into a
  # trailing real/imaginary pair and handed over to Numo.
  spectrum = Torch.stft(waveform[0], **STFT_OPTS)[nil]
  frames = Torch.view_as_real(spectrum).numo
  output_names = @session.outputs.map { |output| output[:name] }

  # Walk the second-to-last axis (presumably time frames), one slice per run.
  enhanced_frames = frames.shape[-2].times.map do |t|
    feed = {
      mix: OnnxRuntime::OrtValue.from_numo(frames[0.., 0.., t..t, 0..]),
      conv_cache: OnnxRuntime::OrtValue.from_numo(conv_cache),
      tra_cache: OnnxRuntime::OrtValue.from_numo(tra_cache),
      inter_cache: OnnxRuntime::OrtValue.from_numo(inter_cache)
    }
    enh, conv_cache, tra_cache, inter_cache = @session.run(output_names, feed, output_type: :numo)
    enh
  end

  stacked = Numo::NArray.concatenate(enhanced_frames, axis: 2)
  # The last axis holds the real part at index 0 and the imaginary part at 1.
  real_part = Torch.from_numo(stacked[0.., 0.., 0.., 0])
  imag_part = Torch.from_numo(stacked[0.., 0.., 0.., 1])
  Torch.istft(real_part + 1i * imag_part, **ISTFT_OPTS)
end