Exemplo n.º 1
0
def test_neural_beamformer_wpe_output(
    ch, num_spk, multi_source_wpe, use_dnn_mask_for_wpe
):
    """Run NeuralBeamformer in WPE-only mode and validate its outputs.

    The model is built with ``use_wpe=True`` and ``use_beamformer=False``,
    fed the STFT of a random batch of two signals, and the returned spectra
    and auxiliary dict are checked for type, count, shape and dtype.

    Args:
        ch: Number of input channels; a value of 1 (or less) produces an
            input without a channel axis.
        num_spk: Number of speakers passed to the model.
        multi_source_wpe: Forwarded to ``NeuralBeamformer``.
        use_dnn_mask_for_wpe: Forwarded to ``NeuralBeamformer``; when true,
            a ``"mask_dereverb1"`` entry is expected in the auxiliary dict.
    """
    torch.random.manual_seed(0)
    if ch > 1:
        waveform = torch.randn(2, 16, ch)
    else:
        waveform = torch.randn(2, 16)
    waveform = waveform.float()
    lengths = torch.LongTensor([16, 12])

    encoder = STFTEncoder(n_fft=8, hop_length=2)
    beamformer = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        # Dereverberation (WPE) only; the beamforming stage is disabled.
        use_wpe=True,
        use_beamformer=False,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        multi_source_wpe=multi_source_wpe,
        # Tiny network / filter settings.
        wlayers=2,
        wunits=2,
        wprojs=2,
        taps=5,
        delay=3,
    )
    beamformer.eval()

    input_spectrum, flens = encoder(waveform, lengths)
    specs, _, others = beamformer(input_spectrum, flens)

    assert isinstance(specs, list)
    # One output per speaker only when a DNN mask is used without
    # multi-source WPE; a single output otherwise.
    per_speaker = use_dnn_mask_for_wpe and not multi_source_wpe
    expected_count = num_spk if per_speaker else 1
    assert len(specs) == expected_count

    first_spec = specs[0]
    assert first_spec.shape == input_spectrum.shape
    assert first_spec.dtype == torch.float
    assert isinstance(others, dict)

    if not use_dnn_mask_for_wpe:
        return
    assert "mask_dereverb1" in others, others.keys()
    assert others["mask_dereverb1"].shape == specs[0].shape
Exemplo n.º 2
0
def test_neural_beamformer_bf_output(
    num_spk,
    use_noise_mask,
    beamformer_type,
    diagonal_loading,
    mask_flooring,
    use_torch_solver,
):
    """Run NeuralBeamformer in beamforming-only mode and validate its outputs.

    The model is built with ``use_wpe=False`` and ``use_beamformer=True`` on
    the first two channels of the module-level ``random_speech`` fixture.
    The returned spectra and the masks in the auxiliary dict are checked for
    presence, count, shape and dtype.

    Args:
        num_spk: Number of speakers passed to the model.
        use_noise_mask: Forwarded to ``NeuralBeamformer``; when true, a
            ``"mask_noise1"`` entry is expected in the auxiliary dict.
        beamformer_type: Beamformer type name forwarded to the model.
        diagonal_loading: Forwarded to ``NeuralBeamformer``.
        mask_flooring: Forwarded to ``NeuralBeamformer``.
        use_torch_solver: Forwarded to ``NeuralBeamformer``.
    """
    # These beamformer types only support multiple-source cases, so the
    # single-speaker combinations return without checking anything.
    multi_source_only = (
        "lcmv",
        "lcmp",
        "wlcmp",
        "mvdr_tfs",
        "mvdr_tfs_souden",
    )
    if num_spk == 1 and beamformer_type in multi_source_only:
        return

    ch = 2
    waveform = random_speech[..., :ch].float()
    lengths = torch.LongTensor([16, 12])

    # Seed right before the modules are constructed.
    torch.random.manual_seed(0)
    encoder = STFTEncoder(n_fft=8, hop_length=2)
    beamformer = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        use_beamformer=True,
        taps=2,
        delay=3,
        # Tiny mask-estimation network.
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        diagonal_loading=diagonal_loading,
        mask_flooring=mask_flooring,
        use_torch_solver=use_torch_solver,
    )
    beamformer.eval()

    input_spectrum, flens = encoder(waveform, lengths)
    specs, _, others = beamformer(input_spectrum, flens)

    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape

    assert isinstance(specs, list)
    assert len(specs) == num_spk

    # The length assert above guarantees this visits speakers 1..num_spk.
    for spk, enhanced in enumerate(specs, 1):
        mask_key = "mask_spk{}".format(spk)
        assert mask_key in others, others.keys()
        spk_mask = others[mask_key]
        assert spk_mask.shape[-2] == ch
        # The enhanced spectrum has the mask/input shape with the
        # second-to-last axis (size ``ch``) removed.
        assert enhanced.shape == spk_mask[..., 0, :].shape
        assert enhanced.shape == input_spectrum[..., 0, :].shape
        # torch.is_complex is only consulted when is_torch_1_9_plus is truthy.
        native_complex = is_torch_1_9_plus and torch.is_complex(enhanced)
        expected_dtype = torch.complex64 if native_complex else torch.float
        assert enhanced.dtype == expected_dtype
Exemplo n.º 3
0
def test_neural_beamformer_bf_output(num_spk, use_noise_mask, beamformer_type):
    """Run NeuralBeamformer in beamforming-only mode and validate its outputs.

    The model is built with ``use_wpe=False`` and ``use_beamformer=True`` on
    the first two channels of the module-level ``random_speech`` fixture.
    The returned spectra and the masks in the auxiliary dict are checked for
    presence, count, shape and dtype.

    Args:
        num_spk: Number of speakers passed to the model.
        use_noise_mask: Forwarded to ``NeuralBeamformer``; when true, a
            ``"mask_noise1"`` entry is expected in the auxiliary dict.
        beamformer_type: Beamformer type name forwarded to the model.
    """
    ch = 2
    waveform = random_speech[..., :ch].float()
    lengths = torch.LongTensor([16, 12])

    # Seed right before the modules are constructed.
    torch.random.manual_seed(0)
    encoder = STFTEncoder(n_fft=8, hop_length=2)
    beamformer = NeuralBeamformer(
        encoder.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        use_beamformer=True,
        taps=2,
        delay=3,
        # Tiny mask-estimation network.
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
    )
    beamformer.eval()

    input_spectrum, flens = encoder(waveform, lengths)
    specs, _, others = beamformer(input_spectrum, flens)

    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape

    assert isinstance(specs, list)
    assert len(specs) == num_spk

    # The length assert above guarantees this visits speakers 1..num_spk.
    for spk, enhanced in enumerate(specs, 1):
        mask_key = "mask_spk{}".format(spk)
        assert mask_key in others, others.keys()
        spk_mask = others[mask_key]
        assert spk_mask.shape[-2] == ch
        # The enhanced spectrum has the mask/input shape with the
        # second-to-last axis (size ``ch``) removed.
        assert enhanced.shape == spk_mask[..., 0, :].shape
        assert enhanced.shape == input_spectrum[..., 0, :].shape
        assert enhanced.dtype == torch.float