# Example #1
0
def test_dpcl_separator_forward_backward_complex(input_dim, rnn_type, layer,
                                                 unit, dropout, num_spk, emb_D,
                                                 nonlinear):
    """Run one forward/backward pass of DPCLSeparator on a complex input.

    A separator is built from the given hyper-parameters, put in training
    mode and fed a random ``ComplexTensor`` of shape ``(2, 10, input_dim)``
    with lengths ``[10, 8]``. The test checks that the third return value
    contains the ``"tf_embedding"`` entry and that the mean absolute value
    of that entry can be back-propagated.

    NOTE(review): the arguments are presumably injected by
    ``pytest.mark.parametrize`` decorators that are not visible here.
    """
    batch_size, num_frames = 2, 10

    separator = DPCLSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        emb_D=emb_D,
        nonlinear=nonlinear,
    )
    separator.train()

    # Real part is drawn before the imaginary part, as in the original,
    # so the random stream is consumed in the same order.
    real_part = torch.rand(batch_size, num_frames, input_dim)
    imag_part = torch.rand(batch_size, num_frames, input_dim)
    spectrum = ComplexTensor(real_part, imag_part)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    # Only the auxiliary dictionary is inspected; the first two outputs
    # are still unpacked so a wrong number of return values fails here.
    _, _, extras = separator(spectrum, ilens=lengths)

    assert "tf_embedding" in extras

    loss = extras["tf_embedding"].abs().mean()
    loss.backward()
# Example #2
0
def test_dpcl_separator_invalid_type():
    """Expect ValueError when DPCLSeparator is built with this configuration.

    NOTE(review): the offending argument is presumably ``nonlinear="fff"``;
    confirm against the DPCLSeparator constructor.
    """
    config = dict(
        input_dim=10,
        rnn_type="rnn",
        layer=2,
        unit=10,
        dropout=0.1,
        num_spk=2,
        emb_D=40,
        nonlinear="fff",
    )

    # Only the constructor call sits inside the context manager.
    with pytest.raises(ValueError):
        DPCLSeparator(**config)
# Example #3
0
def test_dpcl_separator_output():
    """Check the output structure of DPCLSeparator in eval mode.

    For each speaker count from 1 to 3, a separator is built, switched to
    eval mode and applied to the same random ``(2, 10, 10)`` input with
    lengths ``[10, 8]``. The test verifies that:

    * the first return value is a list with one entry per speaker,
    * the third return value is a dict,
    * that dict contains the ``"tf_embedding"`` key.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 4):
        model = DPCLSeparator(
            input_dim=10,
            rnn_type="rnn",
            layer=2,
            unit=10,
            dropout=0.1,
            num_spk=num_spk,
            emb_D=40,
            nonlinear="relu",
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        assert len(specs) == num_spk, len(specs)
        # The key check does not depend on the speaker index, so it is
        # asserted once instead of inside a loop with an unused variable.
        assert "tf_embedding" in others