def test_result_gather_concatenate():
    """
    Gathering tensors whose leading dimension differs (4, 8 and 3 rows of
    width 5) must concatenate them along dim 0 into a single tensor.
    """
    outputs = [{"foo": torch.zeros(rows, 5)} for rows in (4, 8, 3)]

    gathered = Result.gather(outputs)["foo"]

    assert isinstance(gathered, torch.Tensor)
    # 4 + 8 + 3 rows, trailing dimension unchanged
    assert tuple(gathered.shape) == (15, 5)
def test_result_gather_scalar():
    """
    Gathering a sequence of 0-dim tensors must stack them into one 1-D
    tensor with one entry per step output.
    """
    outputs = [{"foo": torch.tensor(value)} for value in (1, 2, 3)]

    gathered = Result.gather(outputs)["foo"]

    assert isinstance(gathered, torch.Tensor)
    assert tuple(gathered.shape) == (3,)
def test_result_gather_stack():
    """
    Gathering tensors that all share one shape must still concatenate them
    along dim 0 (three (4, 5) tensors become one (12, 5) tensor), not stack
    them along a new leading dimension.
    """
    outputs = [{"foo": torch.zeros(4, 5)} for _ in range(3)]

    gathered = Result.gather(outputs)["foo"]

    assert isinstance(gathered, torch.Tensor)
    assert tuple(gathered.shape) == (12, 5)
def test_result_gather_mixed_types():
    """
    Test that a collection of mixed types gets gathered into a list.

    The gathered list must keep every per-step value, in order, unchanged.
    """
    outputs = [
        {"foo": 1.2},
        {"foo": ["bar", None]},
        {"foo": torch.tensor(1)},
    ]
    result = Result.gather(outputs)
    expected = [1.2, ["bar", None], torch.tensor(1)]

    gathered = result["foo"]
    assert isinstance(gathered, list)
    assert len(gathered) == len(expected)
    for got, want in zip(gathered, expected):
        if isinstance(want, torch.Tensor):
            # Plain ``==`` on tensors is element-wise, and list equality would
            # call bool() on that result, which only works for single-element
            # tensors; torch.equal checks shape and values explicitly.
            assert isinstance(got, torch.Tensor)
            assert torch.equal(got, want)
        else:
            assert got == want
def test_result_gather_different_shapes():
    """
    Test that tensors of varying shape get gathered into a list.

    The list must contain exactly the input tensors, in order, with their
    original shapes.
    """
    outputs = [
        {"foo": torch.tensor(1)},
        {"foo": torch.zeros(2, 3)},
        {"foo": torch.zeros(1, 2, 3)},
    ]
    result = Result.gather(outputs)
    expected = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)]

    gathered = result["foo"]
    assert isinstance(gathered, list)
    # zip() stops at the shorter sequence, so without this check a short or
    # empty result would make the element-wise assertion below pass vacuously.
    assert len(gathered) == len(expected)
    # torch.equal requires identical shapes; torch.eq(...).all() broadcasts
    # and would accept e.g. a (2, 3) tensor in place of a (1, 2, 3) one.
    assert all(torch.equal(got, want) for got, want in zip(gathered, expected))