def test_dptnet(training, loss_wrappers):
    """Forward/backward smoke test of ESPnetEnhancementModel with DPTNet."""
    mix = torch.randn(2, 300)
    mix_lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    model = ESPnetEnhancementModel(
        encoder=ConvEncoder(channel=16, kernel_size=36, stride=18),
        separator=dptnet_separator,
        decoder=ConvDecoder(channel=16, kernel_size=36, stride=18),
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    # nn.Module.train(mode) covers both the train and eval cases
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward/backward smoke test over encoder/decoder/separator combinations."""
    # The DCCRN separator does not support ConvEncoder/ConvDecoder, so skip.
    if isinstance(separator, DCCRNSeparator) and isinstance(encoder, ConvEncoder):
        return
    mix = torch.randn(2, 300)
    mix_lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    batch.update(
        {"speech_ref{}".format(spk + 1): ref for spk, ref in enumerate(refs)}
    )
    loss, stats, weight = model(**batch)
def test_criterion_behavior(training):
    """A criterion flagged only_for_test must be rejected in training mode."""
    mix = torch.randn(2, 300)
    mix_lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]
    model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    if training:
        # an only_for_test criterion is not usable while training
        with pytest.raises(AttributeError):
            model(**batch)
    else:
        loss, stats, weight = model(**batch)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward/backward smoke test over encoder/decoder/separator combinations."""
    complex_only_separators = (DCCRNSeparator, DC_CRNSeparator)
    if isinstance(separator, complex_only_separators) and not isinstance(
        encoder, STFTEncoder
    ):
        # DCCRNSeparator and DC_CRNSeparator only work on complex spectrum
        # features, which require an STFT encoder
        return
    mix = torch.randn(2, 300)
    mix_lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    batch.update(
        {"speech_ref{}".format(spk + 1): ref for spk, ref in enumerate(refs)}
    )
    loss, stats, weight = model(**batch)
def test_ineube(n_mics, training, loss_wrappers, output_from):
    """Forward/backward smoke test of the iNeuBe multichannel separator."""
    if not is_torch_1_9_plus:
        return
    mix = torch.randn(1, 300, n_mics)
    mix_lengths = torch.LongTensor([300])
    refs = [torch.randn(1, 300).float() for _ in range(2)]
    # Null encoder/decoder: iNeuBe consumes the raw waveform directly
    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    model = ESPnetEnhancementModel(
        encoder=NullEncoder(),
        separator=iNeuBe(
            2,
            mic_channels=n_mics,
            output_from=output_from,
            tcn_blocks=1,
            tcn_repeats=1,
        ),
        decoder=NullDecoder(),
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Smoke-test model construction and forward for each loss configuration.

    Invalid configurations must raise at construction time (TypeError for
    spectral losses on a time-domain ConvEncoder, ValueError for
    stft_consistency combined with non-spectrum losses); valid ones must
    run a full forward pass.
    """
    if loss_type == "ci_sdr":
        # ci_sdr will fail if the signal length is too short, so give it
        # longer inputs than the other losses
        inputs = torch.randn(2, 300)
        ilens = torch.LongTensor([300, 200])
        speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    else:
        inputs = torch.randn(2, 100)
        ilens = torch.LongTensor([100, 80])
        speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]

    def _build_model():
        # single constructor shared by the error paths and the success path,
        # so all three call sites stay in sync
        return ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            stft_consistency=stft_consistency,
            loss_type=loss_type,
            mask_type=mask_type,
        )

    if loss_type not in ("snr", "si_snr", "ci_sdr") and isinstance(
        encoder, ConvEncoder
    ):
        # spectrum-domain losses are incompatible with a time-domain encoder
        with pytest.raises(TypeError):
            _build_model()
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # stft_consistency is only meaningful for spectrum losses
        with pytest.raises(ValueError):
            _build_model()
        return

    enh_model = _build_model()
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Smoke-test model construction and forward for each loss configuration."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")
    mix = torch.randn(2, 100)
    mix_lengths = torch.LongTensor([100, 80])
    refs = [torch.randn(2, 100).float() for _ in range(2)]
    model_kwargs = dict(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if isinstance(encoder, ConvEncoder) and loss_type != "si_snr":
        # non-si_snr losses cannot be combined with a time-domain encoder
        with pytest.raises(TypeError):
            ESPnetEnhancementModel(**model_kwargs)
        return
    if stft_consistency and loss_type in ("mask_mse", "si_snr"):
        with pytest.raises(ValueError):
            ESPnetEnhancementModel(**model_kwargs)
        return
    model = ESPnetEnhancementModel(**model_kwargs)
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward/backward smoke test of the SVoice model configuration."""
    mix = torch.randn(2, 300)
    mix_lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    model.train(mode=training)
    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lengths}
    batch.update(
        {"speech_ref{}".format(spk + 1): ref for spk, ref in enumerate(refs)}
    )
    loss, stats, weight = model(**batch)
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Forward smoke test of the NeuralBeamformer (WPE + MVDR) front-end."""
    # Skip redundant combinations: mask_type only matters for mask losses
    if mask_type != "IBM" and not loss_type.startswith("mask"):
        return
    # builtin-complex path is only exercised on torch >= 1.9 in this suite
    if use_builtin_complex and not is_torch_1_9_plus:
        return
    ch = 2
    mix = random_speech[..., :ch].float()
    mix_lengths = torch.LongTensor([16, 12])
    refs = [torch.randn(2, 16, ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)
    # stft_consistency is incompatible with these losses; nothing to test
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        return
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    model.train(mode=training)
    batch = {
        "speech_mix": mix,
        "speech_mix_lengths": mix_lengths,
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    batch.update(
        {"speech_ref{}".format(spk + 1): ref for spk, ref in enumerate(refs)}
    )
    loss, stats, weight = model(**batch)