Exemplo n.º 1
0
def test_skim_separator_forward_backward_real(
    input_dim,
    layer,
    causal,
    unit,
    dropout,
    num_spk,
    nonlinear,
    mem_type,
    segment_size,
    seg_overlap,
):
    """Run a training-mode forward and backward pass through SkiMSeparator.

    The hyper-parameters are presumably supplied by pytest parametrization
    defined outside this view. A random batch of two sequences (lengths 10
    and 8) is separated; the test checks that one output per speaker is
    returned and that a scalar loss built from the first output can be
    back-propagated.
    """
    separator = SkiMSeparator(
        input_dim=input_dim,
        causal=causal,
        num_spk=num_spk,
        nonlinear=nonlinear,
        layer=layer,
        unit=unit,
        segment_size=segment_size,
        dropout=dropout,
        mem_type=mem_type,
        seg_overlap=seg_overlap,
    )
    separator.train()

    feats = torch.rand(2, 10, input_dim)
    feat_lens = torch.tensor([10, 8], dtype=torch.long)

    # Only the per-speaker outputs are inspected; output lengths and the
    # auxiliary dict are ignored here.
    outputs, _, _ = separator(feats, ilens=feat_lens)

    first = outputs[0]
    assert isinstance(first, Tensor)
    assert len(outputs) == num_spk

    loss = first.abs().mean()
    loss.backward()
Exemplo n.º 2
0
def test_skim_separator_invalid_type():
    """Constructing SkiMSeparator with unsupported options must raise ValueError.

    Both ``nonlinear`` and ``mem_type`` receive unsupported values, so the
    test passes as soon as whichever validation runs first raises
    ValueError; the other option's validation is not necessarily reached.
    """
    bad_kwargs = dict(
        input_dim=10,
        layer=2,
        unit=10,
        dropout=0.1,
        num_spk=2,
        nonlinear="fff",
        mem_type="aaa",
        segment_size=2,
    )
    with pytest.raises(ValueError):
        SkiMSeparator(**bad_kwargs)
Exemplo n.º 3
0
def test_skim_separator_output():
    """Check the output structure of SkiMSeparator for 1 and 2 speakers.

    For each speaker count the model is run in eval mode on a random batch
    of two sequences (lengths 10 and 8). The test checks that it returns a
    list with one input-shaped output per speaker, plus a dict holding a
    ``mask_spk{n}`` entry whose shape matches the corresponding output.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = SkiMSeparator(
            input_dim=10,
            layer=2,
            unit=10,
            dropout=0.1,
            # Use the loop variable: hard-coding 2 here meant the
            # single-speaker configuration was never actually built.
            num_spk=num_spk,
            nonlinear="relu",
            segment_size=2,
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        assert len(specs) == num_spk
        for n, spec in enumerate(specs):
            mask_key = "mask_spk{}".format(n + 1)
            assert x.shape == spec.shape
            assert mask_key in others
            assert spec.shape == others[mask_key].shape