Example 1
import torch

# Import path assumed from the ESPnet2 layout; adjust if the module lives elsewhere.
from espnet2.enh.nets.beamformer_net import BeamformerNet


def test_beamformer_net_bf_output(num_spk):
    ch = 2
    inputs = torch.randn(2, 16, ch)  # (batch, samples, channels) time-domain input
    inputs = inputs.float()
    ilens = torch.LongTensor([16, 12])  # per-utterance input lengths
    model = BeamformerNet(
        n_fft=8,
        hop_length=2,
        num_spk=num_spk,
        use_wpe=False,
        use_beamformer=True,
        use_noise_mask=True,
    )
    model.eval()
    specs, _, masks = model(inputs, ilens)
    assert isinstance(masks, dict)
    assert "noise1" in masks
    assert masks["noise1"].shape == masks["spk1"].shape
    if num_spk > 1:
        assert isinstance(specs, list)
        assert len(specs) == num_spk
        for n in range(1, num_spk + 1):
            assert "spk{}".format(n) in masks
            assert masks["spk{}".format(n)].shape[-2] == ch
            assert (
                specs[n - 1].shape[:-1]
                == masks["spk{}".format(n)][..., 0, :].shape
            )
            assert specs[n - 1].shape[-1] == 2  # real and imag
            assert specs[n - 1].dtype == torch.float
    else:
        assert isinstance(specs, torch.Tensor)
        assert "spk1" in masks
        assert masks["spk1"].shape[-2] == ch
        assert specs.shape[:-1] == masks["spk1"][..., 0, :].shape
        assert specs.shape[-1] == 2  # real and imag
        assert specs.dtype == torch.float
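
The num_spk argument is not defined inside the test itself; in a pytest suite it would normally be supplied by a parametrize decorator. A minimal sketch of that wiring (the values below are illustrative, not taken from the original test file):

import pytest

@pytest.mark.parametrize("num_spk", [1, 2])
def test_beamformer_net_bf_output(num_spk):
    ...  # body as in Example 1 above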
Example 2
import torch

# Import path assumed from the ESPnet2 layout; adjust if the module lives elsewhere.
from espnet2.enh.nets.beamformer_net import BeamformerNet


def test_beamformer_net_wpe_output(ch, num_spk, use_dnn_mask_for_wpe):
    torch.random.manual_seed(0)
    # Single-channel input carries no channel dimension.
    inputs = torch.randn(2, 16, ch) if ch > 1 else torch.randn(2, 16)
    inputs = inputs.float()
    ilens = torch.LongTensor([16, 12])
    model = BeamformerNet(
        n_fft=8,
        hop_length=2,
        num_spk=num_spk,
        use_wpe=True,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        taps=5,
        delay=3,
        use_beamformer=False,
    )
    model.eval()
    spec, _, masks = model(inputs, ilens)
    assert spec.shape[0] == 2  # batch size
    assert spec.shape[-1] == 2  # real and imag
    assert spec.dtype == torch.float
    assert isinstance(masks, dict)
    if use_dnn_mask_for_wpe:
        assert "dereverb" in masks
        assert masks["dereverb"].shape == spec.shape[:-1]
Example 3
import numpy as np
import torch

# Import path assumed from the ESPnet2 layout; adjust if the module lives elsewhere.
from espnet2.enh.nets.beamformer_net import BeamformerNet


# All of the arguments below are expected to be supplied by pytest parametrization.
def test_beamformer_net_consistency(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    normalize_input,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    wdropout_rate,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    beamformer_type,
    bdropout_rate,
):
    model = BeamformerNet(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        num_spk=num_spk,
        normalize_input=normalize_input,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        wdropout_rate=wdropout_rate,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        bdropout_rate=bdropout_rate,
    )
    model.eval()

    random_input_numpy = np.random.randn(2, 16, 2)  # np.float64
    random_input_torch = torch.from_numpy(random_input_numpy).float()
    # np.float64 --> np.float32 --> torch.float32
    random_input_numpy = torch.from_numpy(random_input_numpy.astype("float32"))

    # ensures reproducibility in the matrix inverse computation
    torch.random.manual_seed(0)
    est_speech_numpy, *_ = model(random_input_numpy,
                                 ilens=torch.LongTensor([16, 12]))

    torch.random.manual_seed(0)
    est_speech_torch, *_ = model(random_input_torch,
                                 ilens=torch.LongTensor([16, 12]))
    assert torch.allclose(est_speech_torch[0], est_speech_numpy[0])
    assert torch.allclose(est_speech_torch[-1], est_speech_numpy[-1])
    for est in est_speech_torch:
        assert est.dtype == torch.float
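
The consistency check rests on both conversion paths yielding the same float32 input tensor; a minimal standalone sketch of that equivalence (both casts apply the same float64-to-float32 rounding):

import numpy as np
import torch

x64 = np.random.randn(2, 16, 2)              # np.float64
a = torch.from_numpy(x64).float()            # cast on the torch side
b = torch.from_numpy(x64.astype("float32"))  # cast on the numpy side
assert torch.equal(a, b)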