def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Forward smoke test for single-channel enhancement model combinations."""
    if loss_type == "ci_sdr":
        # ci_sdr will fail if the signal length is too short
        n_samples, lens = 300, [300, 200]
    else:
        n_samples, lens = 100, [100, 80]
    mix = torch.randn(2, n_samples)
    lengths = torch.LongTensor(lens)
    refs = [torch.randn(2, n_samples).float() for _ in range(2)]

    model_args = dict(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    # This combination is expected to fail with TypeError at construction.
    if isinstance(encoder, ConvEncoder) and loss_type not in (
        "snr",
        "si_snr",
        "ci_sdr",
    ):
        with pytest.raises(TypeError):
            ESPnetEnhancementModel(**model_args)
        return

    # This combination is expected to fail with ValueError at construction.
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        with pytest.raises(ValueError):
            ESPnetEnhancementModel(**model_args)
        return

    enh_model = ESPnetEnhancementModel(**model_args)
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Forward smoke test for single-channel enhancement model combinations."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    mix = torch.randn(2, 100)
    lengths = torch.LongTensor([100, 80])
    refs = [torch.randn(2, 100).float() for _ in range(2)]

    model_args = dict(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    # This combination is expected to fail with TypeError at construction.
    if isinstance(encoder, ConvEncoder) and loss_type != "si_snr":
        with pytest.raises(TypeError):
            ESPnetEnhancementModel(**model_args)
        return

    # This combination is expected to fail with ValueError at construction.
    if stft_consistency and loss_type in ["mask_mse", "si_snr"]:
        with pytest.raises(ValueError):
            ESPnetEnhancementModel(**model_args)
        return

    enh_model = ESPnetEnhancementModel(**model_args)
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_ineube(n_mics, training, loss_wrappers, output_from):
    """Forward smoke test for the iNeuBe multichannel separator."""
    if not is_torch_1_9_plus:
        return  # iNeuBe is only exercised on torch >= 1.9

    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    mix = torch.randn(1, 300, n_mics)
    lengths = torch.LongTensor([300])
    refs = [torch.randn(1, 300).float() for _ in range(2)]

    separator = iNeuBe(
        2,
        mic_channels=n_mics,
        output_from=output_from,
        tcn_blocks=1,
        tcn_repeats=1,
    )
    enh_model = ESPnetEnhancementModel(
        encoder=NullEncoder(),
        separator=separator,
        decoder=NullDecoder(),
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_dptnet(training, loss_wrappers):
    """Forward smoke test for the DPTNet separator with Conv encoder/decoder."""
    encoder = ConvEncoder(channel=16, kernel_size=36, stride=18)
    decoder = ConvDecoder(channel=16, kernel_size=36, stride=18)

    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=dptnet_separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    """Build an enhancement model from parsed task arguments.

    Instantiates the configured encoder, separator, and decoder, wraps each
    criterion config into a loss wrapper, assembles the model, and applies
    the requested parameter initialization.

    Args:
        args: Parsed command-line/config namespace for the Enh task.

    Returns:
        The constructed (and optionally initialized) ESPnetEnhancementModel.
    """
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    loss_wrappers = []
    for ctr in args.criterions:
        # Use .get() with empty defaults so criterion entries packed by
        # older versions that omit "conf"/"wrapper_conf" still load
        # instead of raising KeyError.
        criterion = criterion_choices.get_class(ctr["name"])(**ctr.get("conf", {}))
        loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
            criterion=criterion, **ctr.get("wrapper_conf", {})
        )
        loss_wrappers.append(loss_wrapper)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_criterion_behavior_noise(encoder, decoder, separator):
    """Check that a noise-loss criterion accepts a `noise_ref1` input."""
    if isinstance(separator, (DCCRNSeparator, DC_CRNSeparator)) and not isinstance(
        encoder, STFTEncoder
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return

    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]
    noise_ref = torch.randn(2, 300)

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(is_noise_loss=True))],
    )
    enh_model.train()

    batch = {
        "speech_mix": mix,
        "speech_mix_lengths": lengths,
        "noise_ref1": noise_ref,
    }
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward smoke test for single-channel enhancement models."""
    # The DCCRN separator does not support ConvEncoder and ConvDecoder.
    if isinstance(separator, DCCRNSeparator) and isinstance(encoder, ConvEncoder):
        return

    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_criterion_behavior(training):
    """A test-only criterion must raise in training mode and run in eval mode."""
    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref

    if training:
        # criteria built with only_for_test=True are expected to fail
        # with AttributeError when used during training
        with pytest.raises(AttributeError):
            enh_model(**batch)
    else:
        loss, stats, weight = enh_model(**batch)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward smoke test for single-channel enhancement models."""
    if isinstance(separator, (DCCRNSeparator, DC_CRNSeparator)) and not isinstance(
        encoder, STFTEncoder
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return

    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    """Forward smoke test for the joint enhancement + ASR model."""
    speech = torch.randn(2, 300)
    speech_lengths = torch.LongTensor([300, 200])
    speech_ref = torch.randn(2, 300).float()
    text = torch.LongTensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    text_lengths = torch.LongTensor([5, 5])

    # Front-end: speech enhancement model.
    front = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    # Back-end: ASR model.
    back = ESPnetASRModel(
        vocab_size=len(token_list),
        token_list=token_list,
        frontend=frontend,
        encoder=s2t_encoder,
        decoder=s2t_decoder,
        ctc=s2t_ctc,
        specaug=None,
        normalize=None,
        preencoder=None,
        postencoder=None,
        joint_network=None,
    )
    model = ESPnetEnhS2TModel(enh_model=front, s2t_model=back)
    if training:
        model.train()
    else:
        model.eval()

    loss, stats, weight = model(
        speech=speech,
        speech_lengths=speech_lengths,
        speech_ref1=speech_ref,
        text=text,
        text_lengths=text_lengths,
    )
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    """Build the wrapped enhancement model from parsed task arguments.

    Args:
        args: Parsed command-line/config namespace for the Enh task.

    Returns:
        The constructed (and optionally initialized) ESPnetEnhancementModel.
    """
    assert check_argument_types()

    enh = enh_choices.get_class(args.enh)(**args.enh_conf)

    # 1. Build model
    model = ESPnetEnhancementModel(enh_model=enh)

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_enh_diar_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    mask_module,
    training,
    loss_wrappers,
    diar_frontend,
    diar_encoder,
    diar_decoder,
    label_aggregator,
):
    """Forward smoke test for the joint enhancement + diarization model."""
    speech = torch.randn(2, 300)
    speech_ref = torch.randn(2, 300).float()
    text = torch.randint(high=2, size=(2, 300, 2))

    # Front-end: speech enhancement model.
    front = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=mask_module,
        loss_wrappers=loss_wrappers,
    )
    # Back-end: diarization model.
    back = ESPnetDiarizationModel(
        label_aggregator=label_aggregator,
        frontend=diar_frontend,
        encoder=diar_encoder,
        decoder=diar_decoder,
        specaug=None,
        normalize=None,
        attractor=None,
    )
    model = ESPnetEnhS2TModel(enh_model=front, s2t_model=back)
    if training:
        model.train()
    else:
        model.eval()

    loss, stats, weight = model(
        speech=speech,
        speech_ref1=speech_ref,
        speech_ref2=speech_ref,
        text=text,
    )
def test_criterion_behavior_dereverb(loss_type, num_spk):
    """Check that a dereverberation-loss criterion accepts `dereverb_ref1`."""
    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]

    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )

    # Pick the criterion matching the requested loss type.
    if loss_type == "mask_mse":
        criterion = FrequencyDomainMSE(
            compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
        )
    else:
        criterion = SISNRLoss(is_dereverb_loss=True)

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=criterion)],
    )
    enh_model.train()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    batch["dereverb_ref1"] = dereverb_refs[0]
    loss, stats, weight = enh_model(**batch)
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    """Build an enhancement model, optionally with a separate mask module.

    Args:
        args: Parsed command-line/config namespace for the Enh task.

    Returns:
        The constructed (and optionally initialized) ESPnetEnhancementModel.
    """
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    # "nomask" separators are paired with a dedicated mask module.
    mask_module = None
    if args.separator.endswith("nomask"):
        mask_module = mask_module_choices.get_class(args.mask_module)(
            input_dim=encoder.output_dim,
            **args.mask_module_conf,
        )

    # The guard on "criterions" keeps compatibility when loading models
    # that were packed by an older version without that argument.
    loss_wrappers = []
    for ctr in getattr(args, "criterions", None) or []:
        criterion = criterion_choices.get_class(ctr["name"])(**ctr.get("conf", {}))
        wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
            criterion=criterion, **ctr["wrapper_conf"]
        )
        loss_wrappers.append(wrapper)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
        mask_module=mask_module,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    """Build an enhancement model from parsed task arguments.

    Args:
        args: Parsed command-line/config namespace for the Enh task.

    Returns:
        The constructed (and optionally initialized) ESPnetEnhancementModel.
    """
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward smoke test for the SVoice model configuration."""
    mix = torch.randn(2, 300)
    lengths = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {"speech_mix": mix, "speech_mix_lengths": lengths}
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Forward smoke test for the NeuralBeamformer separator."""
    # Skip some testing cases:
    # `mask_type` has no effect when `loss_type` is not "mask...",
    # so only keep one representative mask_type for those losses.
    if not loss_type.startswith("mask") and mask_type != "IBM":
        return
    # builtin complex tensors are only exercised on recent PyTorch
    if use_builtin_complex and not is_torch_1_9_plus:
        return
    # skip this condition
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        return

    n_ch = 2
    mix = random_speech[..., :n_ch].float()
    lengths = torch.LongTensor([16, 12])
    refs = [torch.randn(2, 16, n_ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, n_ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, n_ch, dtype=torch.float)

    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    batch = {
        "speech_mix": mix,
        "speech_mix_lengths": lengths,
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    for spk, ref in enumerate(refs, start=1):
        batch[f"speech_ref{spk}"] = ref
    loss, stats, weight = enh_model(**batch)