def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Gradients must flow through STFTDecoder back to the complex spectrum."""
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    # Bin count is halved (+1 for DC/Nyquist) when only one side of the
    # spectrum is kept.
    n_bins = n_fft // 2 + 1 if onesided else n_fft
    real_part = torch.rand(2, 300, n_bins, requires_grad=True)
    imag_part = torch.rand(2, 300, n_bins, requires_grad=True)
    spectrum = ComplexTensor(real_part, imag_part)
    sample_lens = torch.tensor(
        [300 * hop_length, 295 * hop_length], dtype=torch.long
    )
    wav, ilens = decoder(spectrum, sample_lens)
    # A scalar reduction is enough to exercise the backward pass.
    wav.sum().backward()
def test_STFTDecoder_invalid_type(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """A plain real tensor (not a complex spectrum) must raise TypeError."""
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    with pytest.raises(TypeError):
        n_bins = n_fft // 2 + 1 if onesided else n_fft
        real_only = torch.rand(2, 300, n_bins, requires_grad=True)
        sample_lens = torch.tensor(
            [300 * hop_length, 295 * hop_length], dtype=torch.long
        )
        decoder(real_only, sample_lens)
def test_STFTDecoder_invalid_type(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """A plain real tensor (not a complex spectrum) must raise TypeError."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    with pytest.raises(TypeError):
        n_bins = n_fft // 2 + 1 if onesided else n_fft
        real_only = torch.rand(2, 300, n_bins, requires_grad=True)
        sample_lens = torch.tensor(
            [300 * hop_length, 295 * hop_length], dtype=torch.long
        )
        decoder(real_only, sample_lens)
def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Gradients must flow through STFTDecoder back to the complex spectrum."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )
    # Bin count is halved (+1 for DC/Nyquist) when only one side of the
    # spectrum is kept.
    n_bins = n_fft // 2 + 1 if onesided else n_fft
    real_part = torch.rand(2, 300, n_bins, requires_grad=True)
    imag_part = torch.rand(2, 300, n_bins, requires_grad=True)
    spectrum = ComplexTensor(real_part, imag_part)
    sample_lens = torch.tensor(
        [300 * hop_length, 295 * hop_length], dtype=torch.long
    )
    wav, ilens = decoder(spectrum, sample_lens)
    # A scalar reduction is enough to exercise the backward pass.
    wav.sum().backward()
from espnet2.enh.encoder.conv_encoder import ConvEncoder from espnet2.enh.encoder.stft_encoder import STFTEncoder from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel from espnet2.enh.espnet_model import ESPnetEnhancementModel from espnet2.enh.loss.criterions.time_domain import SISNRLoss from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver from espnet2.enh.separator.rnn_separator import RNNSeparator from espnet2.layers.label_aggregation import LabelAggregate enh_stft_encoder = STFTEncoder( n_fft=32, hop_length=16, ) enh_stft_decoder = STFTDecoder( n_fft=32, hop_length=16, ) enh_rnn_separator = RNNSeparator( input_dim=17, layer=1, unit=10, num_spk=1, ) si_snr_loss = SISNRLoss() fix_order_solver = FixedOrderSolver(criterion=si_snr_loss) default_frontend = DefaultFrontend( fs=300,
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Smoke-test ESPnetEnhancementModel forward with a NeuralBeamformer separator."""
    # `mask_type` only influences mask-based losses; skip redundant combos.
    if not loss_type.startswith("mask") and mask_type != "IBM":
        return
    # Built-in complex tensors are gated here on PyTorch >= 1.9.
    if use_builtin_complex and not is_torch_1_9_plus:
        return
    n_ch = 2
    mix = random_speech[..., :n_ch].float()
    mix_lens = torch.LongTensor([16, 12])
    # One reference signal per speaker, same batch/length/channel layout as mix.
    speech_refs = [torch.randn(2, 16, n_ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, n_ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, n_ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)
    # These loss types do not combine with STFT-consistency; skip.
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        return
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    # Module.train(False) is equivalent to eval().
    enh_model.train(training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for idx, ref in enumerate(speech_refs, start=1):
        batch["speech_ref{}".format(idx)] = ref
    batch["noise_ref1"] = noise_ref1
    batch["dereverb_ref1"] = dereverb_ref1
    loss, stats, weight = enh_model(**batch)
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0") stft_encoder = STFTEncoder( n_fft=28, hop_length=16, ) stft_encoder_bultin_complex = STFTEncoder( n_fft=28, hop_length=16, use_builtin_complex=True, ) stft_decoder = STFTDecoder( n_fft=28, hop_length=16, ) conv_encoder = ConvEncoder( channel=15, kernel_size=32, stride=16, ) conv_decoder = ConvDecoder( channel=15, kernel_size=32, stride=16, ) rnn_separator = RNNSeparator(
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0") stft_encoder = STFTEncoder( n_fft=16, hop_length=8, ) stft_encoder_bultin_complex = STFTEncoder( n_fft=16, hop_length=8, use_builtin_complex=True, ) stft_decoder = STFTDecoder( n_fft=16, hop_length=8, ) conv_encoder = ConvEncoder( channel=9, kernel_size=20, stride=10, ) conv_decoder = ConvDecoder( channel=9, kernel_size=20, stride=10, ) rnn_separator = RNNSeparator(