Class: GTCRN

Inherits:
Object
  • Object
show all
Includes:
NDAV::Converter
Defined in:
lib/gtcrn.rb

Constant Summary collapse

# Path of the ONNX model file, built relative to the directory of this source
# file (__dir__). The resulting String is frozen. #initialize opens it with
# OnnxRuntime::InferenceSession.
MODEL_PATH =
File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze
# Keyword options splatted into Torch.istft when converting the model output
# back to a waveform: FFT size 512, hop length 256, window length 512, and a
# 512-point Hann window raised to the power 0.5 (i.e. its square root).
# NOTE(review): this Hash is not frozen, so it can be mutated at runtime.
ISTFT_OPTS =
{
  n_fft: 512,
  hop_length: 256,
  win_length: 512,
  window: Torch.hann_window(512).pow(0.5)
}
# Keyword options splatted into Torch.stft for the forward transform: the same
# settings as ISTFT_OPTS, plus "reflect" padding and complex-valued output.
# Built with Hash#merge, so ISTFT_OPTS itself is left unchanged.
STFT_OPTS =
ISTFT_OPTS.merge(
  pad_mode: "reflect",
  return_complex: true
)

Instance Method Summary collapse

Constructor Details

#initialize ⇒ GTCRN

Returns a new instance of GTCRN.



22
23
24
25
26
# File 'lib/gtcrn.rb', line 22

# Opens an inference session on the model at MODEL_PATH and records the
# metadata the enhancement methods need on every run.
#
# Stores:
# * @session      - the OnnxRuntime::InferenceSession itself
# * @cache_shapes - the :shape of every model input after the first one;
#                   enhance_speech_waveform_channel uses these to allocate
#                   its zero-filled cache tensors
# * @output_names - the :name of every model output, passed to @session.run
def initialize
  session = OnnxRuntime::InferenceSession.new(MODEL_PATH)
  cache_inputs = session.inputs[1..]

  @session = session
  @cache_shapes = cache_inputs.map { |spec| spec[:shape] }
  @output_names = session.outputs.map { |spec| spec[:name] }
end

Instance Method Details

#enhance_speech(path, dest = nil) ⇒ Object



28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/gtcrn.rb', line 28

# Enhances the audio file at +path+ and saves the result to disk.
#
# The output location is chosen as follows:
# * +dest+ omitted (or falsy) - next to the input, with ".enhanced" inserted
#   before the extension ("a.wav" becomes "a.enhanced.wav")
# * +dest+ is an existing directory - inside that directory, using the same
#   ".enhanced" file name
# * any other +dest+ - used as the output file path as given
#
# @param path [String, Pathname] audio file to read with TorchAudio.load
# @param dest [String, Pathname, nil] output file or directory (see above)
# @return [Pathname] the path the enhanced audio was saved to
# @raise [RuntimeError] if the loaded sample rate is not 16000
def enhance_speech(path, dest=nil)
  source = Pathname(path)
  suffix = ".enhanced" + source.extname

  target =
    if !dest
      source.sub_ext(suffix)
    else
      requested = Pathname(dest)
      requested.directory? ? requested / source.basename.sub_ext(suffix) : requested
    end

  waveform, sample_rate = TorchAudio.load(source.to_path)
  if sample_rate != 16000
    raise "Sampling rate must be 16000 Hz, but given: #{sample_rate} Hz"
  end

  # Saved with the sample rate that was read from the input file.
  TorchAudio.save(target.to_path, enhance_speech_waveform(waveform), sample_rate)
  target
end

#enhance_speech_waveform(waveform) ⇒ Object



47
48
49
50
51
52
53
54
55
# File 'lib/gtcrn.rb', line 47

# Enhances an in-memory waveform, channel by channel.
#
# A 1-D waveform is treated as a single channel and the enhanced channel is
# returned directly. A 2-D waveform is iterated along its first dimension,
# each element is enhanced with enhance_speech_waveform_channel, and the
# results are combined with Torch.stack.
#
# @param waveform an object responding to +ndim+ (and, for 2-D input,
#   enumerable over its channels)
# @return the enhanced channel (1-D input) or the stacked channels (2-D input)
# @raise [ArgumentError] if +waveform+ is neither 1-D nor 2-D
def enhance_speech_waveform(waveform)
  rank = waveform.ndim

  case rank
  when 1
    enhance_speech_waveform_channel(waveform)
  when 2
    Torch.stack(waveform.map { |channel| enhance_speech_waveform_channel(channel) })
  else
    raise ArgumentError, "wrong dimension of argment (given #{rank}, expected 1D or 2D)"
  end
end

#enhance_speech_waveform_channel(channel) ⇒ Object



57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/gtcrn.rb', line 57

# Enhances a single channel by running the model once per slice of its STFT,
# feeding the cache outputs of each run into the next run.
#
# @param channel the samples of one channel, handed to Torch.stft
#   (presumably a 1-D tensor - confirm against callers)
# @return the Torch.istft of the concatenated model outputs, with
#   dimension 0 squeezed
def enhance_speech_waveform_channel(channel)
  # Each cache starts as float32 zeros in the shape recorded by #initialize.
  initial_caches = @cache_shapes.map { |shape| OrtValue(Torch.zeros(*shape, dtype: :float32)) }
  conv_cache, tra_cache, inter_cache = initial_caches

  # Indexing with [nil] presumably adds a leading axis; view_as_real then
  # exposes the complex values as real numbers for the model.
  spectrum = Torch.view_as_real(Torch.stft(channel, **STFT_OPTS)[nil])
  step_count = spectrum.shape[-2]

  enhanced_steps = Array.new(step_count) do |step|
    # A one-element range (not an integer) keeps the sliced axis in place.
    slice = spectrum[0.., 0.., step..step, 0..]
    feed = {
      mix: OrtValue(slice),
      conv_cache: conv_cache,
      tra_cache: tra_cache,
      inter_cache: inter_cache
    }
    # The three cache variables live outside this block, so the values
    # returned here are what the next iteration feeds back in.
    enh, conv_cache, tra_cache, inter_cache =
      @session.run(@output_names, feed, output_type: :ort_value)
    TorchTensor(enh)
  end

  # Re-join the per-step outputs along dimension 2 before inverting the STFT.
  joined = Torch.cat(enhanced_steps, dim: 2)
  restored = Torch.istft(Torch.view_as_complex(joined), **ISTFT_OPTS)
  restored.squeeze(0)
end