示例#1
0
def test_result_gather_concatenate():
    """Gathering tensors whose first dimension differs concatenates them along dim 0."""
    first_dims = (4, 8, 3)
    outputs = [{"foo": torch.zeros(rows, 5)} for rows in first_dims]
    gathered = Result.gather(outputs)["foo"]
    assert isinstance(gathered, torch.Tensor)
    # Rows from every output are laid end to end: 4 + 8 + 3 == 15.
    assert list(gathered.shape) == [sum(first_dims), 5]
示例#2
0
def test_result_gather_scalar():
    """Gathering 0-dim tensors stacks them into a 1-D tensor with one entry per output."""
    outputs = [{"foo": torch.tensor(value)} for value in (1, 2, 3)]
    stacked = Result.gather(outputs)["foo"]
    assert isinstance(stacked, torch.Tensor)
    assert list(stacked.shape) == [len(outputs)]
示例#3
0
def test_result_gather_stack():
    """Gathering same-shaped tensors concatenates them along dim 0 (no new dimension is added)."""
    n_outputs, rows, cols = 3, 4, 5
    outputs = [{"foo": torch.zeros(rows, cols)} for _ in range(n_outputs)]
    merged = Result.gather(outputs)["foo"]
    assert isinstance(merged, torch.Tensor)
    # 3 outputs of shape (4, 5) -> (12, 5).
    assert list(merged.shape) == [n_outputs * rows, cols]
示例#4
0
def test_result_gather_mixed_types():
    """Test that a collection of mixed types gets gathered into a list.

    Elements are checked one by one rather than with a single ``list == list``.
    For the tensor entry, ``list.__eq__`` would evaluate ``bool(tensor == tensor)``.
    Because ``==`` broadcasts, a gathered ``tensor([1])`` would wrongly pass.
    ``torch.equal`` requires identical shape and values.
    """
    outputs = [
        {"foo": 1.2},
        {"foo": ["bar", None]},
        {"foo": torch.tensor(1)},
    ]
    result = Result.gather(outputs)
    gathered = result["foo"]
    assert isinstance(gathered, list)
    assert len(gathered) == 3
    assert gathered[0] == 1.2
    assert gathered[1] == ["bar", None]
    assert isinstance(gathered[2], torch.Tensor)
    assert torch.equal(gathered[2], torch.tensor(1))
示例#5
0
def test_result_gather_different_shapes():
    """Test that tensors of varying shape get gathered into a list.

    The list length is asserted explicitly because ``zip`` stops silently at
    the shorter sequence, so an empty or truncated result would otherwise pass.
    Each element is compared with ``torch.equal`` (same shape and values).
    ``torch.eq`` broadcasts, so e.g. a ``(2, 3)`` zeros tensor would compare
    equal to a ``(1, 2, 3)`` one.
    """
    outputs = [
        {"foo": torch.tensor(1)},
        {"foo": torch.zeros(2, 3)},
        {"foo": torch.zeros(1, 2, 3)},
    ]
    result = Result.gather(outputs)
    expected = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)]
    assert isinstance(result["foo"], list)
    assert len(result["foo"]) == len(expected)
    for gathered, want in zip(result["foo"], expected):
        assert torch.equal(gathered, want)