Ejemplo n.º 1
0
def test_scatter_tensor():
    """Scattering one (2, 1) tensor into 2 chunks gives two (1, 1) micro-batches.

    Each resulting micro-batch exposes its single chunk through ``.tensor``.
    """
    batch = torch.zeros(2, 1)

    first, second = scatter(batch, chunks=2)

    # Both halves should have the leading dimension halved (2 -> 1).
    for microbatch in (first, second):
        assert microbatch.tensor.size() == (1, 1)
Ejemplo n.º 2
0
def test_scatter_multiple_tensors():
    """Scatter two tensors passed as separate positional arguments into 2 chunks.

    Each micro-batch is iterated to obtain its per-input chunks. Given inputs
    of shape (2, 1) and (4, 2), every micro-batch should hold a (1, 1) chunk
    followed by a (2, 2) chunk.
    """
    inputs = (torch.zeros(2, 1), torch.zeros(4, 2))

    first, second = scatter(*inputs, chunks=2)
    first_parts = list(first)
    second_parts = list(second)

    # Check position 0 in both micro-batches, then position 1 in both.
    for position, expected in enumerate(((1, 1), (2, 2))):
        for parts in (first_parts, second_parts):
            assert parts[position].size() == expected
Ejemplo n.º 3
0
def test_scatter_tuple():
    """Scatter a tuple of two tensors, passed as one argument, into 2 chunks.

    Chunks are read through each micro-batch's ``.tensors`` sequence. Given
    inputs of shape (2, 1) and (4, 2), every micro-batch should hold a (1, 1)
    chunk followed by a (2, 2) chunk.
    """
    inputs = (torch.zeros(2, 1), torch.zeros(4, 2))

    first, second = scatter(inputs, chunks=2)

    # Leading dimensions are halved; trailing dimensions are unchanged.
    for position, expected in enumerate(((1, 1), (2, 2))):
        assert first.tensors[position].size() == expected
        assert second.tensors[position].size() == expected