Exemplo n.º 1
0
def test_neural_beamformer_wpe_output(ch, num_spk, multi_source_wpe,
                                      use_dnn_mask_for_wpe):
    """Check the outputs of a WPE-only NeuralBeamformer (beamformer disabled).

    The model runs in eval mode on seeded random input. The test verifies:

    * the number of returned spectra;
    * their shape and dtype;
    * when a DNN mask drives WPE, that a ``mask_dereverb1`` entry exists in
      the auxiliary dict with the same shape as the first spectrum.
    """
    torch.random.manual_seed(0)
    # Multi-channel input carries a trailing channel axis; mono input does not.
    if ch > 1:
        wav = torch.randn(2, 16, ch)
    else:
        wav = torch.randn(2, 16)
    wav = wav.float()
    wav_lens = torch.LongTensor([16, 12])

    encoder = STFTEncoder(n_fft=8, hop_length=2)
    separator = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        use_wpe=True,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        multi_source_wpe=multi_source_wpe,
        wlayers=2,
        wunits=2,
        wprojs=2,
        taps=5,
        delay=3,
        use_beamformer=False,
    )
    separator.eval()

    spectrum, spec_lens = encoder(wav, wav_lens)
    outputs, _, extras = separator(spectrum, spec_lens)

    assert isinstance(outputs, list)
    # One stream per speaker only when DNN masks drive WPE without
    # multi-source mode; otherwise a single stream is returned.
    per_speaker = use_dnn_mask_for_wpe and not multi_source_wpe
    expected_count = num_spk if per_speaker else 1
    assert len(outputs) == expected_count

    first = outputs[0]
    assert first.shape == spectrum.shape
    assert first.dtype == torch.float
    assert isinstance(extras, dict)
    if use_dnn_mask_for_wpe:
        assert "mask_dereverb1" in extras, extras.keys()
        assert extras["mask_dereverb1"].shape == first.shape
Exemplo n.º 2
0
def test_neural_beamformer_bf_output(
    num_spk,
    use_noise_mask,
    beamformer_type,
    diagonal_loading,
    mask_flooring,
    use_torch_solver,
):
    """Check mask and spectrum outputs of a beamformer-only NeuralBeamformer.

    A 2-channel slice of ``random_speech`` goes through an STFT encoder and
    then the model in eval mode. For every speaker the test checks:

    * a ``mask_spk{n}`` entry exists in the auxiliary dict;
    * the mask's channel axis (``shape[-2]``) equals 2;
    * the enhanced spectrum matches both the first-channel slice of the mask
      and of the input spectrum;
    * the spectrum has the expected dtype.
    """
    multi_source_only = (
        "lcmv",
        "lcmp",
        "wlcmp",
        "mvdr_tfs",
        "mvdr_tfs_souden",
    )
    if num_spk == 1 and beamformer_type in multi_source_only:
        # these beamformer types only support multiple-source cases
        return

    num_channels = 2
    wav = random_speech[..., :num_channels].float()
    wav_lens = torch.LongTensor([16, 12])

    torch.random.manual_seed(0)
    encoder = STFTEncoder(n_fft=8, hop_length=2)
    separator = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        diagonal_loading=diagonal_loading,
        mask_flooring=mask_flooring,
        use_torch_solver=use_torch_solver,
    )
    separator.eval()
    spectrum, spec_lens = encoder(wav, wav_lens)
    outputs, _, extras = separator(spectrum, spec_lens)

    assert isinstance(extras, dict)
    if use_noise_mask:
        assert "mask_noise1" in extras
        assert extras["mask_noise1"].shape == extras["mask_spk1"].shape
    assert isinstance(outputs, list)
    assert len(outputs) == num_spk
    for idx, spec in enumerate(outputs, start=1):
        key = "mask_spk{}".format(idx)
        assert key in extras, extras.keys()
        mask = extras[key]
        assert mask.shape[-2] == num_channels
        assert spec.shape == mask[..., 0, :].shape
        assert spec.shape == spectrum[..., 0, :].shape
        # Native complex outputs are only expected on PyTorch 1.9+.
        if is_torch_1_9_plus and torch.is_complex(spec):
            assert spec.dtype == torch.complex64
        else:
            assert spec.dtype == torch.float
Exemplo n.º 3
0
def test_neural_beamformer_bf_output(num_spk, use_noise_mask, beamformer_type):
    """Check mask and spectrum outputs of a beamformer-only NeuralBeamformer.

    A 2-channel slice of ``random_speech`` goes through an STFT encoder and
    then the model in eval mode. For every speaker the test checks:

    * a ``mask_spk{n}`` entry exists with 2 channels on ``shape[-2]``;
    * the enhanced spectrum matches the first-channel slice of both the mask
      and the input spectrum;
    * the spectrum is ``torch.float``.
    """
    # NOTE(review): another function named test_neural_beamformer_bf_output
    # (taking diagonal_loading/mask_flooring/use_torch_solver) also appears in
    # this file; within one module the later definition replaces the earlier.
    num_channels = 2
    wav = random_speech[..., :num_channels].float()
    wav_lens = torch.LongTensor([16, 12])

    torch.random.manual_seed(0)
    encoder = STFTEncoder(n_fft=8, hop_length=2)
    separator = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
    )
    separator.eval()
    spectrum, spec_lens = encoder(wav, wav_lens)
    outputs, _, extras = separator(spectrum, spec_lens)

    assert isinstance(extras, dict)
    if use_noise_mask:
        assert "mask_noise1" in extras
        assert extras["mask_noise1"].shape == extras["mask_spk1"].shape
    assert isinstance(outputs, list)
    assert len(outputs) == num_spk
    for idx, spec in enumerate(outputs, start=1):
        key = "mask_spk{}".format(idx)
        assert key in extras, extras.keys()
        mask = extras[key]
        assert mask.shape[-2] == num_channels
        assert spec.shape == mask[..., 0, :].shape
        assert spec.shape == spectrum[..., 0, :].shape
        assert spec.dtype == torch.float
Exemplo n.º 4
0
def test_criterion_behavior_dereverb(loss_type, num_spk):
    """Run a training forward pass with a dereverberation-aware PIT loss.

    A NeuralBeamformer with DNN-mask multi-source WPE and a beamformer is
    wrapped in ESPnetEnhancementModel together with either a mask-based
    FrequencyDomainMSE criterion (``loss_type == "mask_mse"``) or an SI-SNR
    criterion. Both criteria are built with ``is_dereverb_loss=True``. The
    test only checks that the forward pass completes and returns three
    values.
    """
    mixture = torch.randn(2, 300)
    mixture_lens = torch.LongTensor([300, 200])
    refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    # Only the first dereverb reference is fed to the model below.
    dereverb_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    separator = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )

    if loss_type == "mask_mse":
        criterion = FrequencyDomainMSE(
            compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
        )
    else:
        criterion = SISNRLoss(is_dereverb_loss=True)
    wrapper = PITSolver(criterion=criterion)

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[wrapper],
    )
    enh_model.train()

    batch = {"speech_mix": mixture, "speech_mix_lengths": mixture_lens}
    for idx, ref in enumerate(refs, start=1):
        batch["speech_ref{}".format(idx)] = ref
    batch["dereverb_ref1"] = dereverb_refs[0]
    loss, stats, weight = enh_model(**batch)
Exemplo n.º 5
0
def test_neural_beamformer_forward_backward(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    loss_type,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    multi_source_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    bnonlinear,
    beamformer_type,
):
    """Forward a NeuralBeamformer in training mode and backpropagate a loss.

    Unsupported or redundant parameter combinations return early.

    For mask-based ``loss_type`` the model is expected to return no
    enhanced speech. The loss is then the sum of mean absolute values of
    all auxiliary outputs. Otherwise it is the sum over enhanced spectra.

    ``multi_source_wpe`` and ``bnonlinear`` are forwarded to the model.
    Previously they only influenced the skip logic, so their parametrized
    values were never actually exercised.
    """
    # Skip some cases
    if num_spk > 1 and use_wpe and use_beamformer:
        if not multi_source_wpe:
            # Single-source WPE is not supported with beamformer in multi-speaker cases
            return
    elif num_spk == 1:
        if multi_source_wpe:
            # When num_spk == 1, `multi_source_wpe` has no effect
            return
        elif beamformer_type in (
            "lcmv",
            "lcmp",
            "wlcmp",
            "mvdr_tfs",
            "mvdr_tfs_souden",
        ):
            # only support multiple-source cases
            return
    if bnonlinear != "sigmoid" and (
        beamformer_type != "mvdr_souden" or multi_source_wpe
    ):
        # only test different nonlinear layers with MVDR_Souden
        return

    # ensures reproducibility and reversibility in the matrix inverse computation
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=n_fft, win_length=win_length, hop_length=hop_length)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        loss_type=loss_type,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        # previously parametrized but never passed to the model
        multi_source_wpe=multi_source_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        # previously parametrized but never passed to the model
        bnonlinear=bnonlinear,
        beamformer_type=beamformer_type,
        rtf_iterations=2,
        shared_power=True,
    )

    model.train()
    inputs = random_speech[..., :2].float()
    ilens = torch.LongTensor([16, 12])
    input_spectrum, flens = stft(inputs, ilens)
    est_speech, flens, others = model(input_spectrum, flens)
    if loss_type.startswith("mask"):
        # mask-based training only estimates masks (no enhanced speech)
        assert est_speech is None
        loss = sum([abs(m).mean() for m in others.values()])
    else:
        loss = sum([abs(est).mean() for est in est_speech])
    loss.backward()
Exemplo n.º 6
0
def test_beamformer_net_invalid_loss_type():
    """Constructing NeuralBeamformer with an unknown loss_type raises ValueError."""
    pytest.raises(ValueError, NeuralBeamformer, 10, loss_type="fff")
Exemplo n.º 7
0
def test_beamformer_net_invalid_bf_type():
    """Constructing NeuralBeamformer with an unknown beamformer_type raises ValueError."""
    pytest.raises(
        ValueError,
        NeuralBeamformer,
        10,
        use_beamformer=True,
        beamformer_type="fff",
    )
Exemplo n.º 8
0
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Run ESPnetEnhancementModel with a WPE + MVDR-Souden NeuralBeamformer.

    Unsupported or redundant combinations return early, before any tensors
    or modules are allocated. The model is then run in training or eval mode
    on a 2-channel mixture with random speech, noise and dereverberation
    references. The test only checks that the forward pass returns three
    values.
    """
    # Skip some testing cases (all guards first, so nothing is built in vain)
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # builtin complex support is only exercised on PyTorch 1.9+
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # `stft_consistency` is not tested together with these loss types
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    beamformer = NeuralBeamformer(
        input_dim=5,  # n_fft // 2 + 1 frequency bins for n_fft=8
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)