Class: GTCRN
- Inherits: Object (hierarchy: Object → GTCRN)
- Defined in:
- lib/gtcrn.rb
Constant Summary collapse
- MODEL_PATH =
File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze
- ISTFT_OPTS =
{ n_fft: 512, hop_length: 256, win_length: 512, window: Torch.hann_window(512).pow(0.5) }
- STFT_OPTS =
ISTFT_OPTS.merge( pad_mode: "reflect", return_complex: true )
Instance Method Summary collapse
- #enhance_speech(path, dest = nil) ⇒ Object
- #enhance_speech_waveform(waveform) ⇒ Object
- #initialize ⇒ GTCRN (constructor): A new instance of GTCRN.
Constructor Details
#initialize ⇒ GTCRN
Returns a new instance of GTCRN.
20 21 22 |
# File 'lib/gtcrn.rb', line 20

# Builds the enhancer by opening an ONNX Runtime inference session on the
# model file referenced by MODEL_PATH. The session is kept in @session for
# later use.
def initialize
  @session = OnnxRuntime::InferenceSession.new(MODEL_PATH)
end
Instance Method Details
#enhance_speech(path, dest = nil) ⇒ Object
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# File 'lib/gtcrn.rb', line 24

# Loads an audio file, passes its waveform through #enhance_speech_waveform
# and saves the result.
#
# How the output location is chosen:
# * +dest+ omitted (nil or false): the input path with its extension
#   replaced by ".enhanced<ext>".
# * +dest+ is a directory: that directory joined with the input's basename,
#   extension replaced by ".enhanced<ext>".
# * anything else: +dest+ itself is used as the output file path.
#
# @param path [String, Pathname] audio file to read
# @param dest [String, Pathname, nil] output file or directory (optional)
# @return [Pathname] the path the enhanced audio was saved to
# @raise [RuntimeError] if the loaded sample rate is not 16000 Hz
def enhance_speech(path, dest = nil)
  source = Pathname(path)
  enhanced_ext = ".enhanced" + source.extname

  target =
    if dest
      candidate = Pathname(dest)
      candidate.directory? ? candidate / source.basename.sub_ext(enhanced_ext) : candidate
    else
      source.sub_ext(enhanced_ext)
    end

  waveform, sample_rate = TorchAudio.load(source.to_path)
  unless sample_rate == 16000
    raise "Sampling rate must be 16000 Hz, but given: #{sample_rate} Hz"
  end

  result = enhance_speech_waveform(waveform)
  TorchAudio.save(target.to_path, result.squeeze, sample_rate)
  target
end
#enhance_speech_waveform(waveform) ⇒ Object
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# File 'lib/gtcrn.rb', line 43

# Enhances a waveform by running the ONNX model over its STFT one slice at
# a time and resynthesising the output with the inverse STFT.
#
# Only waveform[0] is processed.
# NOTE(review): presumably that is the first channel of a
# (channels, samples) tensor — confirm against what TorchAudio.load returns.
#
# @param waveform [Torch::Tensor] input audio; indexed as waveform[0]
# @return [Object] whatever Torch.istft returns for the enhanced spectrum
def enhance_speech_waveform(waveform)
  # The three cache tensors start as zeros, shaped after model inputs 1..3.
  conv_cache, tra_cache, inter_cache = (1..3).map do |idx|
    Numo::SFloat.zeros(*@session.inputs[idx][:shape])
  end

  # Complex STFT -> extra leading axis -> trailing (real, imag) axis -> Numo.
  spectrum = Torch.stft(waveform[0], **STFT_OPTS)
  frames = Torch.view_as_real(spectrum[nil]).numo

  output_names = @session.outputs.map { |output| output[:name] }

  # Step along axis 2 one slice at a time. The model's first output is the
  # enhanced slice; the remaining three replace the caches and are fed back
  # in on the next step.
  enhanced_frames = frames.shape[-2].times.map do |i|
    enh, conv_cache, tra_cache, inter_cache = @session.run(
      output_names,
      {
        mix: OnnxRuntime::OrtValue.from_numo(frames[0.., 0.., i..i, 0..]),
        conv_cache: OnnxRuntime::OrtValue.from_numo(conv_cache),
        tra_cache: OnnxRuntime::OrtValue.from_numo(tra_cache),
        inter_cache: OnnxRuntime::OrtValue.from_numo(inter_cache)
      },
      output_type: :numo
    )
    enh
  end

  # Re-join the slices along axis 2, then rebuild the complex spectrum from
  # the two entries of the last axis (index 0 = real, index 1 = imaginary).
  stacked = Numo::NArray.concatenate(enhanced_frames, axis: 2)
  real_part = Torch.from_numo(stacked[0.., 0.., 0.., 0])
  imag_part = Torch.from_numo(stacked[0.., 0.., 0.., 1])

  Torch.istft(real_part + 1i * imag_part, **ISTFT_OPTS)
end