Example #1
def test_criterion_behavior_noise(encoder, decoder, separator):
    if not isinstance(encoder, STFTEncoder) and isinstance(
        separator, (DCCRNSeparator, DC_CRNSeparator)
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    noise_ref = torch.randn(2, 300)
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(is_noise_loss=True))],
    )

    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
        "noise_ref1": noise_ref,
    }
    loss, stats, weight = enh_model(**kwargs)
Example #2
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    if loss_type == "ci_sdr":
        # ci_sdr fails when the signal is too short, so use longer inputs here
        inputs = torch.randn(2, 300)
        ilens = torch.LongTensor([300, 200])
        speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    else:
        inputs = torch.randn(2, 100)
        ilens = torch.LongTensor([100, 80])
        speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]

    if loss_type not in ("snr", "si_snr", "ci_sdr") and isinstance(
        encoder, ConvEncoder
    ):
        with pytest.raises(TypeError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        with pytest.raises(ValueError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #3
def test_single_channel_model(encoder, decoder, separator, stft_consistency,
                              loss_type, mask_type, training):
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    inputs = torch.randn(2, 100)
    ilens = torch.LongTensor([100, 80])
    speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]

    if loss_type != "si_snr" and isinstance(encoder, ConvEncoder):
        with pytest.raises(TypeError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    if stft_consistency and loss_type in ["mask_mse", "si_snr"]:
        with pytest.raises(ValueError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #4
def test_criterion_behavior_dereverb(loss_type, num_spk):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_ref = [torch.randn(2, 300).float() for _ in range(num_spk)]
    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )
    if loss_type == "mask_mse":
        loss_wrapper = PITSolver(
            criterion=FrequencyDomainMSE(
                compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
            )
        )
    else:
        loss_wrapper = PITSolver(criterion=SISNRLoss(is_dereverb_loss=True))
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )

    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "dereverb_ref1": dereverb_ref[0],
    }
    loss, stats, weight = enh_model(**kwargs)
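
Note: stft_encoder and stft_decoder above are module-level objects defined elsewhere in the test file. A minimal sketch of compatible definitions, assuming n_fft=32 so that the resulting 17 frequency bins match the separator's input_dim=17:

from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder

# assumed fixture definitions: n_fft=32 gives 32 // 2 + 1 = 17 frequency bins
stft_encoder = STFTEncoder(n_fft=32, hop_length=16)
stft_decoder = STFTDecoder(n_fft=32, hop_length=16)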
Example #5
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator = separator_choices.get_class(args.separator)(
            encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

        loss_wrappers = []
        for ctr in args.criterions:
            criterion = criterion_choices.get_class(ctr["name"])(**ctr["conf"])
            loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
                criterion=criterion, **ctr["wrapper_conf"])
            loss_wrappers.append(loss_wrapper)

        # 1. Build model
        model = ESPnetEnhancementModel(encoder=encoder,
                                       separator=separator,
                                       decoder=decoder,
                                       loss_wrappers=loss_wrappers,
                                       **args.model_conf)

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
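
Note: args.criterions is expected to be a list of dicts, each naming a registered criterion and a loss wrapper. An illustrative value, with the names "si_snr" and "pit" assumed from typical ESPnet enhancement configs rather than taken from this listing:

criterions = [
    {
        "name": "si_snr",  # looked up in criterion_choices
        "conf": {},
        "wrapper": "pit",  # looked up in loss_wrapper_choices
        "wrapper_conf": {"weight": 1.0},
    }
]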
Example #6
def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_ref = torch.randn(2, 300).float()
    text = torch.LongTensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    text_lengths = torch.LongTensor([5, 5])
    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    s2t_model = ESPnetASRModel(
        vocab_size=len(token_list),
        token_list=token_list,
        frontend=frontend,
        encoder=s2t_encoder,
        decoder=s2t_decoder,
        ctc=s2t_ctc,
        specaug=None,
        normalize=None,
        preencoder=None,
        postencoder=None,
        joint_network=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(
        enh_model=enh_model,
        s2t_model=s2t_model,
    )

    if training:
        enh_s2t_model.train()
    else:
        enh_s2t_model.eval()

    kwargs = {
        "speech": inputs,
        "speech_lengths": ilens,
        "speech_ref1": speech_ref,
        "text": text,
        "text_lengths": text_lengths,
    }
    loss, stats, weight = enh_s2t_model(**kwargs)
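
Note: token_list is a module-level fixture; any vocabulary whose size covers the label indices used in text (1-5 above) works. A minimal sketch:

# assumed fixture: 7 tokens, so the indices 1-5 used in `text` are valid
token_list = ["<blank>", "a", "b", "c", "d", "<space>", "<eos>"]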
Example #7
def test_dptnet(training, loss_wrappers):
    encoder = ConvEncoder(channel=16, kernel_size=36, stride=18)
    decoder = ConvDecoder(channel=16, kernel_size=36, stride=18)

    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=dptnet_separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
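
Note: dptnet_separator is defined elsewhere in the test module. A hedged sketch of a compatible instance (arguments illustrative; input_dim must match the ConvEncoder's 16 channels):

from espnet2.enh.separator.dptnet_separator import DPTNetSeparator

# assumed fixture; remaining constructor arguments left at their defaults
dptnet_separator = DPTNetSeparator(input_dim=16, num_spk=2)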
Example #8
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    # DCCRNSeparator does not support ConvEncoder and ConvDecoder
    if isinstance(encoder, ConvEncoder) and isinstance(separator, DCCRNSeparator):
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #9
def test_ineube(n_mics, training, loss_wrappers, output_from):
    if not is_torch_1_9_plus:
        return
    inputs = torch.randn(1, 300, n_mics)
    ilens = torch.LongTensor([300])
    speech_refs = [torch.randn(1, 300).float(), torch.randn(1, 300).float()]
    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    encoder = NullEncoder()
    decoder = NullDecoder()
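    # iNeuBe's first positional argument is presumably the number of speakers
    # (2 here, matching the two speech_refs)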
    separator = iNeuBe(
        2, mic_channels=n_mics, output_from=output_from, tcn_blocks=1, tcn_repeats=1
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #10
def test_criterion_behavior(training):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }

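    # SISNRLoss(only_for_test=True) marks the criterion as evaluation-only,
    # so running the model in training mode is expected to raise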
    if training:
        with pytest.raises(AttributeError):
            loss, stats, weight = enh_model(**kwargs)
    else:
        loss, stats, weight = enh_model(**kwargs)
Example #11
def test_single_channel_model(encoder, decoder, separator, training,
                              loss_wrappers):
    if not isinstance(encoder, STFTEncoder) and isinstance(
            separator, (DCCRNSeparator, DC_CRNSeparator)):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #12
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        assert check_argument_types()

        enh_model = enh_choices.get_class(args.enh)(**args.enh_conf)

        # 1. Build model
        model = ESPnetEnhancementModel(enh_model=enh_model)

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #13
def test_enh_diar_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    mask_module,
    training,
    loss_wrappers,
    diar_frontend,
    diar_encoder,
    diar_decoder,
    label_aggregator,
):
    inputs = torch.randn(2, 300)
    speech_ref = torch.randn(2, 300).float()
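    # frame-level diarization labels: (batch, num_frames, num_spk), binary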
    text = torch.randint(high=2, size=(2, 300, 2))
    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=mask_module,
        loss_wrappers=loss_wrappers,
    )
    diar_model = ESPnetDiarizationModel(
        label_aggregator=label_aggregator,
        frontend=diar_frontend,
        encoder=diar_encoder,
        decoder=diar_decoder,
        specaug=None,
        normalize=None,
        attractor=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(
        enh_model=enh_model,
        s2t_model=diar_model,
    )

    if training:
        enh_s2t_model.train()
    else:
        enh_s2t_model.eval()

    kwargs = {
        "speech": inputs,
        "speech_ref1": speech_ref,
        "speech_ref2": speech_ref,
        "text": text,
    }
    loss, stats, weight = enh_s2t_model(**kwargs)
Example #14
File: enh.py Project: akreal/espnet
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator = separator_choices.get_class(args.separator)(
            encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)
        if args.separator.endswith("nomask"):
            mask_module = mask_module_choices.get_class(args.mask_module)(
                input_dim=encoder.output_dim,
                **args.mask_module_conf,
            )
        else:
            mask_module = None

        loss_wrappers = []

        if getattr(args, "criterions", None) is not None:
            # This check keeps compatibility with models packaged
            # by older versions
            for ctr in args.criterions:
                criterion_conf = ctr.get("conf", {})
                criterion = criterion_choices.get_class(
                    ctr["name"])(**criterion_conf)
                loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
                    criterion=criterion, **ctr["wrapper_conf"])
                loss_wrappers.append(loss_wrapper)

        # 1. Build model
        model = ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            loss_wrappers=loss_wrappers,
            mask_module=mask_module,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #15
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator = separator_choices.get_class(args.separator)(
            encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

        # 1. Build model
        model = ESPnetEnhancementModel(encoder=encoder,
                                       separator=separator,
                                       decoder=decoder,
                                       **args.model_conf)

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #16
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Example #17
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
    metrics: List[str],
    frame_size: int = 512,
    frame_hop: int = 256,
):
    assert check_argument_types()
    for metric in metrics:
        assert metric in (
            "STOI",
            "ESTOI",
            "SNR",
            "SI_SNR",
            "SDR",
            "SAR",
            "SIR",
            "framewise-SNR",
        ), metric

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0]
        for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp
    ]
    inf_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp
    ]

    # get sample rate
    fs, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    stft = STFTEncoder(n_fft=frame_size, hop_length=frame_hop)

    do_bss_eval = "SDR" in metrics or "SAR" in metrics or "SIR" in metrics
    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the "
                                 "network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            if do_bss_eval or num_spk > 1:
                sdr, sir, sar, perm = bss_eval_sources(
                    ref, inf, compute_permutation=True)
            else:
                perm = [0]

            ilens = torch.LongTensor([ref.shape[1]])
            # (num_spk, T, F)
            ref_spec, flens = stft(torch.from_numpy(ref), ilens)
            inf_spec, _ = stft(torch.from_numpy(inf), ilens)

            for i in range(num_spk):
                p = int(perm[i])
                for metric in metrics:
                    name = f"{metric}_spk{i + 1}"
                    if metric == "STOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=False))
                    elif metric == "ESTOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=True))
                    elif metric == "SNR":
                        snr_score = -float(
                            ESPnetEnhancementModel.snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(snr_score)
                    elif metric == "SI_SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.si_snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SDR":
                        writer[name][key] = str(sdr[i])
                    elif metric == "SAR":
                        writer[name][key] = str(sar[i])
                    elif metric == "SIR":
                        writer[name][key] = str(sir[i])
                    elif metric == "framewise-SNR":
                        framewise_snr = -ESPnetEnhancementModel.snr_loss(
                            ref_spec[i].abs(), inf_spec[i].abs())
                        writer[name][key] = " ".join(
                            map(str, framewise_snr.tolist()))
                    else:
                        raise ValueError("Unsupported metric: %s" % metric)
                # save permutation-assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[p].data[key]
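
A hypothetical invocation of scoring(), with all paths and metric choices illustrative:

scoring(
    output_dir="exp/enhanced/scoring",
    dtype="int16",
    log_level="INFO",
    key_file="data/test/wav.scp",
    ref_scp=["data/test/spk1.scp", "data/test/spk2.scp"],
    inf_scp=["exp/enhanced/spk1.scp", "exp/enhanced/spk2.scp"],
    ref_channel=0,
    metrics=["STOI", "SDR", "SI_SNR"],
)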
Example #18
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    # Skip some testing cases
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # the builtin complex support used here requires PyTorch 1.9+
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # skip this condition
        return

    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)
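
Note: random_speech is a module-level test tensor; a sketch with the shape this test assumes (batch=2, 16 samples, 2 channels):

# assumed module-level input of shape (batch, samples, channels)
random_speech = torch.randn(2, 16, 2)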
Example #19
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0]
        for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp
    ]
    inf_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp
    ]

    # get sample rate
    sample_rate, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the \
                    network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            sdr, sir, sar, perm = bss_eval_sources(ref,
                                                   inf,
                                                   compute_permutation=True)

            for i in range(num_spk):
                stoi_score = stoi(ref[i],
                                  inf[int(perm[i])],
                                  fs_sig=sample_rate)
                si_snr_score = -float(
                    ESPnetEnhancementModel.si_snr_loss(
                        torch.from_numpy(ref[i][None, ...]),
                        torch.from_numpy(inf[int(perm[i])][None, ...]),
                    ))
                writer[f"STOI_spk{i + 1}"][key] = str(stoi_score)
                writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
                writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
                writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
                writer[f"SIR_spk{i + 1}"][key] = str(sir[i])
                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[perm[i]].data[key]