Class: GTCRN
Constant Summary collapse
- MODEL_PATH = File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze
- ISTFT_OPTS = { n_fft: 512, hop_length: 256, win_length: 512, window: Torch.hann_window(512).pow(0.5) }
- STFT_OPTS = ISTFT_OPTS.merge(pad_mode: "reflect", return_complex: true)
Instance Method Summary collapse
- #enhance_speech(path, dest = nil) ⇒ Object
- #enhance_speech_waveform(waveform) ⇒ Object
- #enhance_speech_waveform_channel(channel) ⇒ Object
- #initialize ⇒ GTCRN (constructor): A new instance of GTCRN.
Constructor Details
#initialize ⇒ GTCRN
Returns a new instance of GTCRN.
22 23 24 25 26 |
# File 'lib/gtcrn.rb', line 22

# Opens an ONNX inference session for the model at MODEL_PATH and records the
# session metadata used by later inference calls.
#
# @return [GTCRN] a new instance of GTCRN
def initialize
  @session = OnnxRuntime::InferenceSession.new(MODEL_PATH)

  # Every session input after the first one contributes a cache shape.
  cache_inputs = @session.inputs[1..]
  @cache_shapes = cache_inputs.map { |meta| meta[:shape] }

  # Output names are requested explicitly on every session run.
  @output_names = @session.outputs.map { |meta| meta[:name] }
end
Instance Method Details
#enhance_speech(path, dest = nil) ⇒ Object
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# File 'lib/gtcrn.rb', line 28

# Loads an audio file, enhances it and saves the result.
#
# The output path is chosen as follows:
# - no +dest+: next to the input, with ".enhanced" inserted before the extension;
# - +dest+ is an existing directory: inside it, using the input's base name with
#   ".enhanced" inserted before the extension;
# - any other +dest+: used as given.
#
# @param path [String, Pathname] audio file to read
# @param dest [String, Pathname, nil] output file or directory
# @return [Pathname] the path the enhanced audio was saved to
# @raise [RuntimeError] if the loaded sampling rate is not 16000 Hz
def enhance_speech(path, dest = nil)
  source = Pathname(path)
  enhanced_ext = ".enhanced" + source.extname

  target = dest ? Pathname(dest) : source.sub_ext(enhanced_ext)
  # Only a caller-supplied destination can be a directory to place the file in.
  target /= source.basename.sub_ext(enhanced_ext) if dest && target.directory?

  waveform, sample_rate = TorchAudio.load(source.to_path)
  if sample_rate != 16000
    raise "Sampling rate must be 16000 Hz, but given: #{sample_rate} Hz"
  end

  TorchAudio.save(target.to_path, enhance_speech_waveform(waveform), sample_rate)
  target
end
#enhance_speech_waveform(waveform) ⇒ Object
47 48 49 50 51 52 53 54 55 |
# File 'lib/gtcrn.rb', line 47

# Enhances a waveform given either as a single channel (1D) or as a
# collection of channels (2D).
#
# @param waveform [#ndim] 1D waveform, or 2D waveform that is iterated channel by channel
# @return [Object] for 1D input, the result of #enhance_speech_waveform_channel;
#   for 2D input, the per-channel results combined with Torch.stack
# @raise [ArgumentError] if the waveform is neither 1D nor 2D
def enhance_speech_waveform(waveform)
  ndim = waveform.ndim
  unless ndim == 1 || ndim == 2
    raise ArgumentError, "wrong dimension of argument (given #{ndim}, expected 1D or 2D)"
  end

  # A 1D waveform is itself the single channel; no wrapping or stacking needed.
  return enhance_speech_waveform_channel(waveform) if ndim == 1

  Torch.stack(waveform.map { |channel| enhance_speech_waveform_channel(channel) })
end
#enhance_speech_waveform_channel(channel) ⇒ Object
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# File 'lib/gtcrn.rb', line 57

# Enhances a single channel by running the model one STFT slice at a time,
# passing the caches returned by each run into the next one.
#
# @param channel [Object] single-channel waveform handed to Torch.stft
# @return [Object] result of Torch.istft on the enhanced slices, with dimension 0 squeezed
def enhance_speech_waveform_channel(channel)
  # Each cache starts out as a float32 zero tensor of its recorded shape.
  conv_cache, tra_cache, inter_cache = @cache_shapes.map do |dims|
    OrtValue(Torch.zeros(*dims, dtype: :float32))
  end

  # [nil] prepends a dimension (presumably the batch axis -- TODO confirm);
  # view_as_real turns the complex STFT into a real-valued tensor.
  spectrum = Torch.view_as_real(Torch.stft(channel, **STFT_OPTS)[nil])

  # Iterate over the second-to-last dimension (presumably time frames -- TODO confirm).
  frame_count = spectrum.shape[-2]
  enhanced_frames = frame_count.times.map do |index|
    # index..index keeps the sliced dimension instead of dropping it.
    frame = spectrum[0.., 0.., index..index, 0..]
    feeds = {
      mix: OrtValue(frame),
      conv_cache: conv_cache,
      tra_cache: tra_cache,
      inter_cache: inter_cache
    }
    # The caches are overwritten here so the next iteration sees the updated state.
    enh, conv_cache, tra_cache, inter_cache =
      @session.run(@output_names, feeds, output_type: :ort_value)
    TorchTensor(enh)
  end

  merged = Torch.cat(enhanced_frames, dim: 2)
  restored = Torch.istft(Torch.view_as_complex(merged), **ISTFT_OPTS)
  restored.squeeze(0)
end