Ejemplo n.º 1
0
def test_tasnet_forward_backward(
    N,
    L,
    B,
    H,
    P,
    X,
    R,
    num_spk,
    norm_type,
    causal,
    mask_nonlinear,
):
    """Run one forward and one backward pass through TasNet and check gradients flow.

    The arguments are forwarded unchanged to the ``TasNet`` constructor; they are
    presumably supplied by ``pytest.mark.parametrize`` decorators not visible here.
    """
    model = TasNet(
        N=N,
        L=L,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
        num_spk=num_spk,
        norm_type=norm_type,
        causal=causal,
        mask_nonlinear=mask_nonlinear,
    )

    # Keep a handle on the leaf input so we can verify gradients reach it.
    inputs = torch.randn(2, 100, requires_grad=True)
    est_speech, *_ = model(inputs, ilens=torch.LongTensor([100, 80]))
    loss = sum(est.mean() for est in est_speech)
    loss.backward()

    # Without these checks the test would pass even if the graph were cut
    # somewhere between the input and the separated outputs.
    assert inputs.grad is not None
    assert inputs.grad.shape == inputs.shape
Ejemplo n.º 2
0
def test_tasnet_cosistency(
    N,
    L,
    B,
    H,
    P,
    X,
    R,
    num_spk,
    norm_type,
    causal,
    mask_nonlinear,
):
    """Check TasNet gives (nearly) identical outputs for two equivalent float32 inputs.

    The same float64 NumPy array is converted to a float32 tensor along two
    different paths. The per-speaker outputs of the model in eval mode must
    agree within a mean absolute difference of 1e-5 for every speaker.

    The arguments are forwarded unchanged to the ``TasNet`` constructor; they are
    presumably supplied by ``pytest.mark.parametrize`` decorators not visible here.
    """
    model = TasNet(
        N=N,
        L=L,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
        num_spk=num_spk,
        norm_type=norm_type,
        causal=causal,
        mask_nonlinear=mask_nonlinear,
    )
    model.eval()

    base_array = np.random.randn(2, 100)  # np.float64
    # np.float64 -(shift)-> torch.float32 -(unshift)-> torch.float32
    input_via_torch = torch.from_numpy(base_array - 1.0).float() + 1.0
    # np.float64 --> np.float32 --> torch.float32
    input_via_numpy = torch.from_numpy(base_array.astype("float32"))

    # Only forward outputs are compared, so autograd tracking is unnecessary.
    with torch.no_grad():
        est_speech_numpy, *_ = model(
            input_via_numpy, ilens=torch.LongTensor([100, 80])
        )
        est_speech_torch, *_ = model(
            input_via_torch, ilens=torch.LongTensor([100, 80])
        )

    # Compare every speaker, not just the first and last one.
    assert len(est_speech_torch) == len(est_speech_numpy)
    for spk_torch, spk_numpy in zip(est_speech_torch, est_speech_numpy):
        assert (spk_torch - spk_numpy).abs().mean().item() < 1e-5
Ejemplo n.º 3
0
def test_tasnet_output(
    N,
    L,
    B,
    H,
    P,
    X,
    R,
    num_spk,
    norm_type,
    causal,
    mask_nonlinear,
):
    """Check the container types and shapes of TasNet's outputs for ``num_spk`` speakers.

    The model is called as ``model(inputs, ilens)`` and is expected to return
    ``(specs, _, masks)``. ``specs`` must be a list and ``masks`` a dict keyed
    ``"spk1"``, ``"spk2"``, and so on. Each per-speaker output must match its
    mask's shape and the input shape ``(2, 160)``.

    The previous version looped ``for num_spk in range(1, 3)``, which shadowed
    and discarded the ``num_spk`` argument. This version tests the value it is given.
    NOTE(review): make sure the (not visible) ``num_spk`` parametrization
    includes both 1 and 2 so that coverage is kept.
    """
    inputs = torch.randn(2, 160)
    ilens = torch.LongTensor([160, 120])
    model = TasNet(
        N=N,
        L=L,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
        num_spk=num_spk,
        norm_type=norm_type,
        causal=causal,
        mask_nonlinear=mask_nonlinear,
    )
    specs, _, masks = model(inputs, ilens)
    assert isinstance(specs, list)
    assert isinstance(masks, dict)
    for n in range(num_spk):
        key = "spk{}".format(n + 1)
        assert key in masks
        assert specs[n].shape == masks[key].shape
        assert specs[n].shape == (2, 160)
Ejemplo n.º 4
0
def test_tasnet_invalid_mask_nonlinear():
    """Constructing TasNet with an unknown ``mask_nonlinear`` must raise ValueError.

    NOTE(review): a later function with this same name in this file redefines it
    at module level and therefore shadows this test; confirm which one to keep.
    """
    positional_args = (5, 20, 5, 10, 3, 8, 4, 2)
    with pytest.raises(ValueError):
        TasNet(*positional_args, mask_nonlinear="fff")
Ejemplo n.º 5
0
def test_tasnet_invalid_norm_type():
    """Constructing TasNet with an unknown ``norm_type`` must raise ValueError.

    NOTE(review): a later function with this same name in this file redefines it
    at module level and therefore shadows this test; confirm which one to keep.
    """
    positional_args = (5, 20, 5, 10, 3, 8, 4, 2)
    with pytest.raises(ValueError):
        TasNet(*positional_args, norm_type="fff")
Ejemplo n.º 6
0
def test_tasnet_invalid_mask_nonlinear():
    """TasNet must reject an unrecognised mask nonlinearity with ValueError."""
    unknown_nonlinear = "fff"
    with pytest.raises(ValueError):
        TasNet(2, 40, 2, 2, 3, 3, 2, 2, mask_nonlinear=unknown_nonlinear)
Ejemplo n.º 7
0
def test_tasnet_invalid_norm_type():
    """TasNet must reject an unrecognised normalization type with ValueError."""
    unknown_norm = "fff"
    with pytest.raises(ValueError):
        TasNet(2, 40, 2, 2, 3, 3, 2, 2, norm_type=unknown_norm)