Esempio n. 1
0
def test_dptnet(training, loss_wrappers):
    """Run one forward pass of an enhancement model built around DPTNet."""
    conv_enc = ConvEncoder(channel=16, kernel_size=36, stride=18)
    conv_dec = ConvDecoder(channel=16, kernel_size=36, stride=18)

    mix = torch.randn(2, 300)
    mix_lens = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]
    model = ESPnetEnhancementModel(
        encoder=conv_enc,
        separator=dptnet_separator,
        decoder=conv_dec,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    # train(True) is equivalent to .train(), train(False) to .eval()
    model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
Esempio n. 2
0
def test_single_channel_model(encoder, decoder, separator, training, loss_wrappers):
    """Smoke-test a forward pass of ESPnetEnhancementModel on 1-channel input.

    Builds the model from the given encoder/decoder/separator fixtures,
    switches to train or eval mode, and runs one forward call with a
    2-utterance batch and two speech references.
    """
    # DCCRN separator does not support ConvEncoder and ConvDecoder
    if isinstance(encoder, ConvEncoder) and isinstance(separator, DCCRNSeparator):
        return
    inputs = torch.randn(2, 300)  # batch of 2 waveforms, 300 samples each
    ilens = torch.LongTensor([300, 200])  # valid lengths; 2nd utt is padded
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    # one "speech_refN" keyword per reference speaker (N starts at 1)
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Esempio n. 3
0
def test_criterion_behavior(training):
    """A criterion flagged only_for_test=True must raise during training.

    In eval mode the same forward call is expected to succeed.
    """
    mix = torch.randn(2, 300)
    mix_lens = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]
    model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )

    model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref

    if training:
        # test-only criterion is unusable in training mode
        with pytest.raises(AttributeError):
            loss, stats, weight = model(**batch)
    else:
        loss, stats, weight = model(**batch)
Esempio n. 4
0
def test_single_channel_model(encoder, decoder, separator, training,
                              loss_wrappers):
    """Forward one batch through a model assembled from the given modules."""
    # DCCRNSeparator and DC_CRNSeparator only work for complex spectrum
    # features, so they require an STFTEncoder in front; skip otherwise.
    needs_complex = isinstance(separator, (DCCRNSeparator, DC_CRNSeparator))
    if needs_complex and not isinstance(encoder, STFTEncoder):
        return
    mix = torch.randn(2, 300)
    mix_lens = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]
    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )

    model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
Esempio n. 5
0
def test_ineube(n_mics, training, loss_wrappers, output_from):
    """Run one forward pass of the iNeuBe multichannel enhancement model."""
    if not is_torch_1_9_plus:
        return
    from espnet2.enh.decoder.null_decoder import NullDecoder
    from espnet2.enh.encoder.null_encoder import NullEncoder

    mix = torch.randn(1, 300, n_mics)
    mix_lens = torch.LongTensor([300])
    refs = [torch.randn(1, 300).float() for _ in range(2)]

    # iNeuBe works on raw waveforms, so encoder/decoder are pass-through
    model = ESPnetEnhancementModel(
        encoder=NullEncoder(),
        separator=iNeuBe(
            2,
            mic_channels=n_mics,
            output_from=output_from,
            tcn_blocks=1,
            tcn_repeats=1,
        ),
        decoder=NullDecoder(),
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
Esempio n. 6
0
def test_single_channel_model(
    encoder, decoder, separator, stft_consistency, loss_type, mask_type, training
):
    """Test model construction and one forward pass per loss configuration.

    Invalid combinations (time-domain losses with a ConvEncoder, or
    stft_consistency with incompatible losses) are expected to raise at
    construction time; every other combination should run a full forward.
    """
    if loss_type == "ci_sdr":
        # ci_sdr will fail if the signal length is too short,
        # so give it longer (300-sample) inputs
        inputs = torch.randn(2, 300)
        ilens = torch.LongTensor([300, 200])
        speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    else:
        # other losses work fine with shorter (faster) 100-sample inputs
        inputs = torch.randn(2, 100)
        ilens = torch.LongTensor([100, 80])
        speech_refs = [torch.randn(2, 100).float(), torch.randn(2, 100).float()]

    # spectral losses are incompatible with a time-domain ConvEncoder
    if loss_type not in ("snr", "si_snr", "ci_sdr") and isinstance(
        encoder, ConvEncoder
    ):
        with pytest.raises(TypeError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return
    # stft_consistency is rejected for these loss types
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        with pytest.raises(ValueError):
            enh_model = ESPnetEnhancementModel(
                encoder=encoder,
                separator=separator,
                decoder=decoder,
                stft_consistency=stft_consistency,
                loss_type=loss_type,
                mask_type=mask_type,
            )
        return

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    # one "speech_refN" keyword per reference speaker (N starts at 1)
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
    }
    loss, stats, weight = enh_model(**kwargs)
Esempio n. 7
0
def test_single_channel_model(encoder, decoder, separator, stft_consistency,
                              loss_type, mask_type, training):
    """Test construction errors and a forward pass per loss configuration."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    mix = torch.randn(2, 100)
    mix_lens = torch.LongTensor([100, 80])
    refs = [torch.randn(2, 100).float() for _ in range(2)]

    def build_model():
        # every construction site below uses exactly these arguments
        return ESPnetEnhancementModel(
            encoder=encoder,
            separator=separator,
            decoder=decoder,
            stft_consistency=stft_consistency,
            loss_type=loss_type,
            mask_type=mask_type,
        )

    # spectral losses reject a time-domain ConvEncoder at construction
    if isinstance(encoder, ConvEncoder) and loss_type != "si_snr":
        with pytest.raises(TypeError):
            enh_model = build_model()
        return
    # stft_consistency is rejected for these loss types
    if stft_consistency and loss_type in ("mask_mse", "si_snr"):
        with pytest.raises(ValueError):
            enh_model = build_model()
        return

    enh_model = build_model()
    enh_model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = enh_model(**batch)
Esempio n. 8
0
def test_svoice_model(encoder, decoder, separator, training, loss_wrappers):
    """One forward pass through an SVoice-style enhancement model."""
    mix = torch.randn(2, 300)
    mix_lens = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(2)]

    model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        loss_wrappers=loss_wrappers,
    )
    model.train(mode=training)

    batch = {"speech_mix": mix, "speech_mix_lengths": mix_lens}
    for spk, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(spk)] = ref
    loss, stats, weight = model(**batch)
Esempio n. 9
0
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Forward a 2-channel batch through a NeuralBeamformer-based model.

    Covers WPE + MVDR (souden) beamforming with optional noise masking,
    across mask/loss-type combinations; incompatible combinations are
    skipped early.
    """
    # Skip some testing cases
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # builtin-complex path is gated on is_torch_1_9_plus here
        # NOTE(review): an earlier comment said "PyTorch 1.8+", which
        # disagrees with the 1.9 guard — confirm the intended minimum
        return

    ch = 2  # number of microphone channels
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])  # valid lengths for the 2 utterances
    speech_refs = [torch.randn(2, 16, ch).float() for spk in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # skip this condition
        return

    # tiny WPE + beamformer configuration to keep the test fast
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    # speech refs per speaker, plus single noise and dereverberation refs
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)