コード例 #1
0
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Run a forward pass of the enhancement model on a 2-utterance batch.

    Invalid encoder/loss and consistency/loss combinations are expected to
    raise at construction time; all other combinations must produce a loss.
    """
    # ci_sdr needs longer signals, so size the batch by loss type.
    n_samples, lengths = (300, [300, 200]) if loss_type == "ci_sdr" else (100, [100, 80])
    inputs = torch.randn(2, n_samples)
    ilens = torch.LongTensor(lengths)
    speech_refs = [torch.randn(2, n_samples).float() for _ in range(2)]

    time_domain_losses = ("snr", "si_snr", "ci_sdr")
    model_kwargs = dict(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    # ConvEncoder only supports time-domain losses.
    if isinstance(encoder, ConvEncoder) and loss_type not in time_domain_losses:
        with pytest.raises(TypeError):
            ESPnetEnhancementModel(**model_kwargs)
        return
    # STFT consistency is incompatible with mask/time-domain losses.
    if stft_consistency and loss_type in ("mask_mse",) + time_domain_losses:
        with pytest.raises(ValueError):
            ESPnetEnhancementModel(**model_kwargs)
        return

    enh_model = ESPnetEnhancementModel(**model_kwargs)
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #2
0
ファイル: test_espnet_model.py プロジェクト: jnishi/espnet-1
def test_single_channel_model(encoder, decoder, separator, stft_consistency,
                              loss_type, mask_type, training):
    """Forward a 2-utterance batch through the enhancement model.

    Incompatible encoder/loss and consistency/loss pairs must raise during
    model construction; valid pairs must yield (loss, stats, weight).
    """
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    inputs = torch.randn(2, 100)
    ilens = torch.LongTensor([100, 80])
    speech_refs = [torch.randn(2, 100).float() for _ in range(2)]

    model_kwargs = dict(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    # ConvEncoder only works with the time-domain si_snr loss.
    if isinstance(encoder, ConvEncoder) and loss_type != "si_snr":
        with pytest.raises(TypeError):
            ESPnetEnhancementModel(**model_kwargs)
        return
    # STFT consistency cannot be combined with these losses.
    if stft_consistency and loss_type in ("mask_mse", "si_snr"):
        with pytest.raises(ValueError):
            ESPnetEnhancementModel(**model_kwargs)
        return

    enh_model = ESPnetEnhancementModel(**model_kwargs)
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #3
0
def test_ineube(n_mics, training, loss_wrappers, output_from):
    """Forward a single multichannel utterance through an iNeuBe model.

    Uses null encoder/decoder since iNeuBe operates on raw waveforms.
    Skipped (silently) on torch < 1.9.
    """
    if not is_torch_1_9_plus:
        return
    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    inputs = torch.randn(1, 300, n_mics)
    ilens = torch.LongTensor([300])
    speech_refs = [torch.randn(1, 300).float() for _ in range(2)]

    separator = iNeuBe(
        2, mic_channels=n_mics, output_from=output_from, tcn_blocks=1, tcn_repeats=1
    )
    enh_model = ESPnetEnhancementModel(
        encoder=NullEncoder(),
        separator=separator,
        decoder=NullDecoder(),
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #4
0
def test_dptnet(training, loss_wrappers):
    """Forward a 2-utterance batch through a DPTNet-based enhancement model."""
    # Matching conv encoder/decoder pair for the dptnet separator.
    enc = ConvEncoder(channel=16, kernel_size=36, stride=18)
    dec = ConvDecoder(channel=16, kernel_size=36, stride=18)

    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=enc,
        separator=dptnet_separator,
        decoder=dec,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #5
0
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        """Build an ESPnetEnhancementModel from parsed task arguments.

        Instantiates encoder/separator/decoder and the configured loss
        wrappers, then optionally applies parameter initialization.
        """
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator_cls = separator_choices.get_class(args.separator)
        # The separator consumes the encoder's feature dimension.
        separator = separator_cls(encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

        # Wrap each configured criterion in its loss wrapper.
        loss_wrappers = [
            loss_wrapper_choices.get_class(ctr["wrapper"])(
                criterion=criterion_choices.get_class(ctr["name"])(**ctr["conf"]),
                **ctr["wrapper_conf"],
            )
            for ctr in args.criterions
        ]

        # 1. Build model
        model = ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            loss_wrappers=loss_wrappers,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
コード例 #6
0
def test_criterion_behavior_noise(encoder, decoder, separator):
    """Train-mode forward pass with a noise-targeted SI-SNR criterion."""
    complex_only = isinstance(separator, (DCCRNSeparator, DC_CRNSeparator))
    if complex_only and not isinstance(encoder, STFTEncoder):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return

    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]
    noise_ref = torch.randn(2, 300)

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(is_noise_loss=True))],
    )
    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        "noise_ref1": noise_ref,
    }
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #7
0
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward a 2-utterance single-channel batch through the model."""
    # DCCRN separator does not support ConvEncoder and ConvDecoder.
    if isinstance(separator, DCCRNSeparator) and isinstance(encoder, ConvEncoder):
        return

    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #8
0
def test_criterion_behavior(training):
    """A test-only criterion must raise in train mode and run in eval mode."""
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref

    if not training:
        loss, stats, weight = enh_model(**kwargs)
    else:
        # only_for_test criteria are rejected during training.
        with pytest.raises(AttributeError):
            loss, stats, weight = enh_model(**kwargs)
コード例 #9
0
ファイル: test_espnet_model.py プロジェクト: siddalmia/espnet
def test_single_channel_model(encoder, decoder, separator, training,
                              loss_wrappers):
    """Forward a 2-utterance single-channel batch through the model."""
    complex_only = isinstance(separator, (DCCRNSeparator, DC_CRNSeparator))
    if complex_only and not isinstance(encoder, STFTEncoder):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return

    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #10
0
def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    """Forward a batch through a joint enhancement + ASR model."""
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_ref = torch.randn(2, 300).float()
    text = torch.LongTensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    text_lengths = torch.LongTensor([5, 5])

    # Front half: speech enhancement.
    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )
    # Back half: speech recognition.
    s2t_model = ESPnetASRModel(
        vocab_size=len(token_list),
        token_list=token_list,
        frontend=frontend,
        encoder=s2t_encoder,
        decoder=s2t_decoder,
        ctc=s2t_ctc,
        specaug=None,
        normalize=None,
        preencoder=None,
        postencoder=None,
        joint_network=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(enh_model=enh_model, s2t_model=s2t_model)

    enh_s2t_model.train(mode=training)

    kwargs = dict(
        speech=inputs,
        speech_lengths=ilens,
        speech_ref1=speech_ref,
        text=text,
        text_lengths=text_lengths,
    )
    loss, stats, weight = enh_s2t_model(**kwargs)
コード例 #11
0
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        """Build an ESPnetEnhancementModel wrapping the configured enh module."""
        assert check_argument_types()

        # 1. Build model
        enh_cls = enh_choices.get_class(args.enh)
        model = ESPnetEnhancementModel(enh_model=enh_cls(**args.enh_conf))

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
コード例 #12
0
def test_enh_diar_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    mask_module,
    training,
    loss_wrappers,
    diar_frontend,
    diar_encoder,
    diar_decoder,
    label_aggregator,
):
    """Forward a batch through a joint enhancement + diarization model."""
    inputs = torch.randn(2, 300)
    speech_ref = torch.randn(2, 300).float()
    # Frame-level binary speaker-activity labels for 2 speakers.
    text = torch.randint(high=2, size=(2, 300, 2))

    enh_model = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=mask_module,
        loss_wrappers=loss_wrappers,
    )
    diar_model = ESPnetDiarizationModel(
        label_aggregator=label_aggregator,
        frontend=diar_frontend,
        encoder=diar_encoder,
        decoder=diar_decoder,
        specaug=None,
        normalize=None,
        attractor=None,
    )
    enh_s2t_model = ESPnetEnhS2TModel(enh_model=enh_model, s2t_model=diar_model)

    enh_s2t_model.train(mode=training)

    kwargs = dict(
        speech=inputs,
        speech_ref1=speech_ref,
        speech_ref2=speech_ref,
        text=text,
    )
    loss, stats, weight = enh_s2t_model(**kwargs)
コード例 #13
0
def test_criterion_behavior_dereverb(loss_type, num_spk):
    """Train-mode forward pass with a dereverberation-targeted criterion."""
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_ref = [torch.randn(2, 300).float() for _ in range(num_spk)]

    # Tiny WPE + MVDR beamformer, just large enough to run.
    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )

    if loss_type != "mask_mse":
        loss_wrapper = PITSolver(criterion=SISNRLoss(is_dereverb_loss=True))
    else:
        loss_wrapper = PITSolver(
            criterion=FrequencyDomainMSE(
                compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
            )
        )

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )
    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        "dereverb_ref1": dereverb_ref[0],
    }
    for idx in range(num_spk):
        kwargs[f"speech_ref{idx + 1}"] = speech_refs[idx]
    loss, stats, weight = enh_model(**kwargs)
コード例 #14
0
ファイル: enh.py プロジェクト: akreal/espnet
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        """Build an ESPnetEnhancementModel with optional mask module.

        Separators named "*nomask" delegate masking to a dedicated mask
        module; criterions may be absent in checkpoints packed by older
        versions, in which case no loss wrappers are built.
        """
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator_cls = separator_choices.get_class(args.separator)
        separator = separator_cls(encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

        mask_module = None
        if args.separator.endswith("nomask"):
            mask_module = mask_module_choices.get_class(args.mask_module)(
                input_dim=encoder.output_dim,
                **args.mask_module_conf,
            )

        loss_wrappers = []
        # This check is for the compatibility when load models
        # that packed by older version
        criterions = getattr(args, "criterions", None)
        if criterions is not None:
            for ctr in criterions:
                criterion_cls = criterion_choices.get_class(ctr["name"])
                criterion = criterion_cls(**ctr.get("conf", {}))
                wrapper_cls = loss_wrapper_choices.get_class(ctr["wrapper"])
                loss_wrappers.append(
                    wrapper_cls(criterion=criterion, **ctr["wrapper_conf"])
                )

        # 1. Build model
        model = ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            loss_wrappers=loss_wrappers,
            mask_module=mask_module,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
コード例 #15
0
ファイル: enh.py プロジェクト: zhouyh-jlu/espnet
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
        """Build a plain ESPnetEnhancementModel (no loss wrappers)."""
        assert check_argument_types()

        encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
        separator_cls = separator_choices.get_class(args.separator)
        # The separator consumes the encoder's feature dimension.
        separator = separator_cls(encoder.output_dim, **args.separator_conf)
        decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)

        # 1. Build model
        model = ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 2. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
コード例 #16
0
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    """Forward a 2-utterance batch through an SVoice-based model."""
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    enh_model.train(mode=training)

    kwargs = {"speech_mix": inputs, "speech_mix_lengths": ilens}
    for idx, ref in enumerate(speech_refs, start=1):
        kwargs[f"speech_ref{idx}"] = ref
    loss, stats, weight = enh_model(**kwargs)
コード例 #17
0
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Forward a 2-channel batch through an STFT + NeuralBeamformer model."""
    # Skip some testing cases
    if mask_type != "IBM" and not loss_type.startswith("mask"):
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if use_builtin_complex and not is_torch_1_9_plus:
        # builtin complex support requires a newer PyTorch (see guard above)
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # skip this condition
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for spk in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)

    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    # Tiny WPE + MVDR-Souden beamformer, just large enough to run.
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    enh_model.train(mode=training)

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    for idx in range(num_spk):
        kwargs[f"speech_ref{idx + 1}"] = speech_refs[idx]
    loss, stats, weight = enh_model(**kwargs)