def test_criterion_behavior_noise(encoder, decoder, separator):
    if not isinstance(encoder, STFTEncoder) and isinstance(
        separator, (DCCRNSeparator, DC_CRNSeparator)
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    noise_ref = torch.randn(2, 300)
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(is_noise_loss=True))],
    )
    enh_model.train()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
        "noise_ref1": noise_ref,
    }
    loss, stats, weight = enh_model(**kwargs)
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    if loss_type == "ci_sdr":
        # ci_sdr will fail if the signal length is too short
        inputs = torch.randn(2, 300)
        ilens = torch.LongTensor([300, 200])
        speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    else:
        inputs = torch.randn(2, 100)
        ilens = torch.LongTensor([100, 80])
        speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]
    if loss_type not in ("snr", "si_snr", "ci_sdr") and isinstance(
        encoder, ConvEncoder
    ):
        with pytest.raises(TypeError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        with pytest.raises(ValueError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")
    inputs = torch.randn(2, 100)
    ilens = torch.LongTensor([100, 80])
    speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]
    if loss_type != "si_snr" and isinstance(encoder, ConvEncoder):
        with pytest.raises(TypeError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    if stft_consistency and loss_type in ["mask_mse", "si_snr"]:
        with pytest.raises(ValueError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_criterion_behavior_dereverb(loss_type, num_spk):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_ref = [torch.randn(2, 300).float() for _ in range(num_spk)]
    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )
    if loss_type == "mask_mse":
        loss_wrapper = PITSolver(
            criterion=FrequencyDomainMSE(
                compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
            )
        )
    else:
        loss_wrapper = PITSolver(criterion=SISNRLoss(is_dereverb_loss=True))
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )
    enh_model.train()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "dereverb_ref1": dereverb_ref[0],
    }
    loss, stats, weight = enh_model(**kwargs)
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    loss_wrappers = []
    for ctr in args.criterions:
        criterion = criterion_choices.get_class(ctr["name"])(**ctr["conf"])
        loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
            criterion=criterion, **ctr["wrapper_conf"]
        )
        loss_wrappers.append(loss_wrapper)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_ref = torch.randn(2, 300).float()
    text = torch.LongTensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    text_lengths = torch.LongTensor([5, 5])
    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    s2t_model = ESPnetASRModel(
        vocab_size=len(token_list),
        token_list=token_list,
        frontend=frontend,
        encoder=s2t_encoder,
        decoder=s2t_decoder,
        ctc=s2t_ctc,
        specaug=None,
        normalize=None,
        preencoder=None,
        postencoder=None,
        joint_network=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(
        enh_model=enh_model,
        s2t_model=s2t_model,
    )
    if training:
        enh_s2t_model.train()
    else:
        enh_s2t_model.eval()
    kwargs = {
        "speech": inputs,
        "speech_lengths": ilens,
        "speech_ref1": speech_ref,
        "text": text,
        "text_lengths": text_lengths,
    }
    loss, stats, weight = enh_s2t_model(**kwargs)
def test_dptnet(training, loss_wrappers):
    encoder = ConvEncoder(channel=16, kernel_size=36, stride=18)
    decoder = ConvDecoder(channel=16, kernel_size=36, stride=18)
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=dptnet_separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    # DCCRNSeparator does not support ConvEncoder and ConvDecoder
    if isinstance(encoder, ConvEncoder) and isinstance(separator, DCCRNSeparator):
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_ineube(n_mics, training, loss_wrappers, output_from):
    if not is_torch_1_9_plus:
        return
    inputs = torch.randn(1, 300, n_mics)
    ilens = torch.LongTensor([300])
    speech_refs = [torch.randn(1, 300).float(), torch.randn(1, 300).float()]

    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    encoder = NullEncoder()
    decoder = NullDecoder()
    separator = iNeuBe(
        2, mic_channels=n_mics, output_from=output_from, tcn_blocks=1, tcn_repeats=1
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def test_criterion_behavior(training):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    if training:
        # a criterion marked with only_for_test=True must raise in training mode
        with pytest.raises(AttributeError):
            loss, stats, weight = enh_model(**kwargs)
    else:
        loss, stats, weight = enh_model(**kwargs)
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    if not isinstance(encoder, STFTEncoder) and isinstance(
        separator, (DCCRNSeparator, DC_CRNSeparator)
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    assert check_argument_types()

    enh_model = enh_choices.get_class(args.enh)(**args.enh_conf)

    # 1. Build model
    model = ESPnetEnhancementModel(enh_model=enh_model)

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_enh_diar_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    mask_module,
    training,
    loss_wrappers,
    diar_frontend,
    diar_encoder,
    diar_decoder,
    label_aggregator,
):
    inputs = torch.randn(2, 300)
    speech_ref = torch.randn(2, 300).float()
    text = torch.randint(high=2, size=(2, 300, 2))
    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=mask_module,
        loss_wrappers=loss_wrappers,
    )
    diar_model = ESPnetDiarizationModel(
        label_aggregator=label_aggregator,
        frontend=diar_frontend,
        encoder=diar_encoder,
        decoder=diar_decoder,
        specaug=None,
        normalize=None,
        attractor=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(
        enh_model=enh_model,
        s2t_model=diar_model,
    )
    if training:
        enh_s2t_model.train()
    else:
        enh_s2t_model.eval()
    kwargs = {
        "speech": inputs,
        "speech_ref1": speech_ref,
        "speech_ref2": speech_ref,
        "text": text,
    }
    loss, stats, weight = enh_s2t_model(**kwargs)
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    if args.separator.endswith("nomask"):
        mask_module = mask_module_choices.get_class(args.mask_module)(
            input_dim=encoder.output_dim,
            **args.mask_module_conf,
        )
    else:
        mask_module = None

    loss_wrappers = []
    if getattr(args, "criterions", None) is not None:
        # This check keeps backward compatibility with models
        # packaged by older versions.
        for ctr in args.criterions:
            criterion_conf = ctr.get("conf", {})
            criterion = criterion_choices.get_class(ctr["name"])(**criterion_conf)
            loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
                criterion=criterion, **ctr["wrapper_conf"]
            )
            loss_wrappers.append(loss_wrapper)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
        mask_module=mask_module,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
    assert check_argument_types()

    encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
    separator = separator_choices.get_class(args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

    # 1. Build model
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 2. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
    metrics: List[str],
    frame_size: int = 512,
    frame_hop: int = 256,
):
    assert check_argument_types()
    for metric in metrics:
        assert metric in (
            "STOI",
            "ESTOI",
            "SNR",
            "SI_SNR",
            "SDR",
            "SAR",
            "SIR",
            "framewise-SNR",
        ), metric

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0] for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp]
    inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp]

    # get sample rate
    fs, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    stft = STFTEncoder(n_fft=frame_size, hop_length=frame_hop)

    do_bss_eval = "SDR" in metrics or "SAR" in metrics or "SIR" in metrics
    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)

            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError(
                    "Reference must be multi-channel when the "
                    "network output is multi-channel."
                )
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            if do_bss_eval or num_spk > 1:
                sdr, sir, sar, perm = bss_eval_sources(
                    ref, inf, compute_permutation=True
                )
            else:
                perm = [0]

            ilens = torch.LongTensor([ref.shape[1]])
            # (num_spk, T, F)
            ref_spec, flens = stft(torch.from_numpy(ref), ilens)
            inf_spec, _ = stft(torch.from_numpy(inf), ilens)

            for i in range(num_spk):
                p = int(perm[i])
                for metric in metrics:
                    name = f"{metric}_spk{i + 1}"
                    if metric == "STOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=False)
                        )
                    elif metric == "ESTOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=True)
                        )
                    elif metric == "SNR":
                        snr_score = -float(
                            ESPnetEnhancementModel.snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            )
                        )
                        writer[name][key] = str(snr_score)
                    elif metric == "SI_SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.si_snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            )
                        )
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SDR":
                        writer[name][key] = str(sdr[i])
                    elif metric == "SAR":
                        writer[name][key] = str(sar[i])
                    elif metric == "SIR":
                        writer[name][key] = str(sir[i])
                    elif metric == "framewise-SNR":
                        framewise_snr = -ESPnetEnhancementModel.snr_loss(
                            ref_spec[i].abs(), inf_spec[i].abs()
                        )
                        writer[name][key] = " ".join(map(str, framewise_snr.tolist()))
                    else:
                        raise ValueError("Unsupported metric: %s" % metric)

                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[perm[i]].data[key]
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    # Skip some testing cases
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # builtin complex support is only available in PyTorch 1.8+
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for spk in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # skip this condition
        return

    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0] for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp]
    inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp]

    # get sample rate
    sample_rate, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)

            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError(
                    "Reference must be multi-channel when the "
                    "network output is multi-channel."
                )
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            sdr, sir, sar, perm = bss_eval_sources(ref, inf, compute_permutation=True)

            for i in range(num_spk):
                stoi_score = stoi(ref[i], inf[int(perm[i])], fs_sig=sample_rate)
                si_snr_score = -float(
                    ESPnetEnhancementModel.si_snr_loss(
                        torch.from_numpy(ref[i][None, ...]),
                        torch.from_numpy(inf[int(perm[i])][None, ...]),
                    )
                )
                writer[f"STOI_spk{i + 1}"][key] = str(stoi_score)
                writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
                writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
                writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
                writer[f"SIR_spk{i + 1}"][key] = str(sir[i])

                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[perm[i]].data[key]