Пример #1
0
def test_beamformer_net_bf_output(num_spk):
    """Check spectrum and mask outputs of BeamformerNet with beamforming on.

    WPE is disabled and a noise mask is requested. The returned mask dict
    must therefore hold a "noise1" entry shaped like "spk1", plus one
    "spk{n}" entry per speaker.

    Args:
        num_spk: number of speakers passed to the model (supplied by a test
            parametrization that is outside this view).
    """
    n_channels = 3
    batch = torch.randn(2, 16, n_channels).float()
    lengths = torch.LongTensor([16, 12])
    net = BeamformerNet(
        n_fft=8,
        hop_length=2,
        num_spk=num_spk,
        use_wpe=False,
        use_beamformer=True,
        use_noise_mask=True,
    )
    specs, _, masks = net(batch, lengths)

    assert isinstance(masks, dict)
    assert "noise1" in masks
    assert masks["noise1"].shape == masks["spk1"].shape

    if num_spk <= 1:
        # With a single speaker the estimate is a bare tensor, not a list.
        assert isinstance(specs, torch.Tensor)
        assert "spk1" in masks
        mask = masks["spk1"]
        assert mask.shape[-2] == n_channels
        assert specs.shape[:-1] == mask[..., 0, :].shape
        assert specs.shape[-1] == 2  # real and imag
        assert specs.dtype == torch.float
        return

    assert isinstance(specs, list)
    assert len(specs) == num_spk
    for idx, spec in enumerate(specs, start=1):
        key = f"spk{idx}"
        assert key in masks
        assert masks[key].shape[-2] == n_channels
        # Take channel 0 of the mask (axis -2, checked above to equal the
        # channel count). The result must match the spectrum shape without
        # its trailing real/imag axis.
        assert spec.shape[:-1] == masks[key][..., 0, :].shape
        assert spec.shape[-1] == 2  # real and imag
        assert spec.dtype == torch.float
Пример #2
0
def test_beamformer_net_forward_backward(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    normalize_input,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    wdropout_rate,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    beamformer_type,
    bdropout_rate,
):
    """Run one forward and backward pass through BeamformerNet.

    Every argument is forwarded unchanged, by keyword, to the BeamformerNet
    constructor. The values come from a test parametrization outside this
    view. The test passes if ``backward()`` on a scalar loss built from the
    estimates completes without raising.
    """
    config = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        num_spk=num_spk,
        normalize_input=normalize_input,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        wdropout_rate=wdropout_rate,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        bdropout_rate=bdropout_rate,
    )
    model = BeamformerNet(**config)

    mixture = torch.randn(2, 16, 2, requires_grad=True)
    lengths = torch.LongTensor([16, 12])
    est_speech, *_ = model(mixture, ilens=lengths)

    # Mean absolute value of each estimate, summed into one scalar.
    loss = sum(abs(est).mean() for est in est_speech)
    loss.backward()
Пример #3
0
def test_beamformer_net_wpe_output(ch, num_spk, use_dnn_mask_for_wpe):
    """Check BeamformerNet output when only WPE dereverberation is enabled.

    Args:
        ch: number of input channels. When ``ch`` is not greater than 1 the
            input is a 2-D ``(2, 16)`` tensor with no channel axis.
        num_spk: number of speakers passed to the model.
        use_dnn_mask_for_wpe: when true, the returned masks must contain a
            "dereverb" entry shaped like the spectrum minus its last axis.
    """
    torch.random.manual_seed(0)
    shape = (2, 16, ch) if ch > 1 else (2, 16)
    inputs = torch.randn(*shape).float()
    ilens = torch.LongTensor([16, 12])
    model = BeamformerNet(
        n_fft=8,
        hop_length=2,
        num_spk=num_spk,
        use_wpe=True,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        taps=5,
        delay=3,
        use_beamformer=False,
    )
    model.eval()
    spec, _, masks = model(inputs, ilens)

    assert spec.shape[0] == 2  # batch size
    assert spec.shape[-1] == 2  # real and imag
    assert spec.dtype == torch.float
    assert isinstance(masks, dict)

    if not use_dnn_mask_for_wpe:
        return
    assert "dereverb" in masks
    assert masks["dereverb"].shape == spec.shape[:-1]
Пример #4
0
def test_beamformer_net_invalid_bf_type():
    """Constructing BeamformerNet with beamformer_type="fff" must raise ValueError."""
    bad_kwargs = {"use_beamformer": True, "beamformer_type": "fff"}
    with pytest.raises(ValueError):
        BeamformerNet(**bad_kwargs)
Пример #5
0
def test_beamformer_net_consistency(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    normalize_input,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    wdropout_rate,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    beamformer_type,
    bdropout_rate,
):
    """Check that the float64 -> float32 conversion path does not affect output.

    The same random float64 array is turned into a float32 tensor in two
    ways: via ``torch.Tensor.float()`` and via ``ndarray.astype("float32")``.
    Both tensors go through the model in eval mode with the same RNG seed.
    Every estimate must match element-wise (``torch.allclose``) and be
    ``torch.float``.

    All arguments are forwarded unchanged to the BeamformerNet constructor.
    They come from a test parametrization outside this view.
    """
    model = BeamformerNet(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        num_spk=num_spk,
        normalize_input=normalize_input,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        wdropout_rate=wdropout_rate,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        bdropout_rate=bdropout_rate,
    )
    model.eval()

    random_input = np.random.randn(2, 16, 2)  # np.float64
    # Path 1: np.float64 --> torch.float64 --> torch.float32
    input_via_torch = torch.from_numpy(random_input).float()
    # Path 2: np.float64 --> np.float32 --> torch.float32
    input_via_numpy = torch.from_numpy(random_input.astype("float32"))

    # ensures reproducibility in the matrix inverse computation
    torch.random.manual_seed(0)
    est_speech_numpy, *_ = model(input_via_numpy,
                                 ilens=torch.LongTensor([16, 12]))

    torch.random.manual_seed(0)
    est_speech_torch, *_ = model(input_via_torch,
                                 ilens=torch.LongTensor([16, 12]))

    # Compare every estimate pairwise, not only the first and the last.
    # Otherwise middle speakers are never checked when num_spk > 2.
    assert len(est_speech_torch) == len(est_speech_numpy)
    for est_torch, est_numpy in zip(est_speech_torch, est_speech_numpy):
        assert torch.allclose(est_torch, est_numpy)
    for est in est_speech_torch:
        assert est.dtype == torch.float
Пример #6
0
def test_beamformer_net_forward_backward(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    normalize_input,
    mask_type,
    loss_type,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    dropout_rate,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    beamformer_type,
):
    """Run one training-mode forward/backward pass through BeamformerNet.

    Arguments are forwarded to the BeamformerNet constructor. ``dropout_rate``
    is used for both ``wdropout_rate`` and ``bdropout_rate``.

    When ``loss_type`` starts with "mask", the model must return ``None`` as
    its speech estimate, and the loss is built from the returned masks.
    Otherwise the loss is built from the speech estimates.
    """
    # NOTE(review): a function with this same name appears earlier in SOURCE.
    # If both live in one module, this later definition shadows the earlier
    # one and pytest collects only this one. Confirm they come from separate
    # files.
    model = BeamformerNet(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        num_spk=num_spk,
        normalize_input=normalize_input,
        mask_type=mask_type,
        loss_type=loss_type,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        wdropout_rate=dropout_rate,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        bdropout_rate=dropout_rate,
    )

    model.train()
    # The second output (frame lengths) is not needed by this test.
    est_speech, _, masks = model(torch.randn(2, 16, 2, requires_grad=True),
                                 ilens=torch.LongTensor([16, 12]))
    if loss_type.startswith("mask"):
        # For "mask*" loss types the model is expected to return no speech
        # estimate, so the loss comes from the masks.
        assert est_speech is None
        loss = sum(abs(m).mean() for m in masks.values())
    else:
        loss = sum(abs(est).mean() for est in est_speech)
    loss.backward()
Пример #7
0
def test_beamformer_net_invalid_loss_type():
    """Constructing BeamformerNet with loss_type="fff" must raise ValueError."""
    bad_kwargs = {"loss_type": "fff"}
    with pytest.raises(ValueError):
        BeamformerNet(**bad_kwargs)