Example #1
0
def test_struct_should_forward_torch_calls_on_each_element_in_structure():
    """Check that a torch call on a dict-backed TensorStruct reaches every entry.

    ``unsqueeze(dim=1)`` is invoked once on the struct. Each named member
    must come back with a new singleton axis inserted at position 1.
    """
    members = {
        'points': torch.zeros((10, 2)),
        'values': torch.zeros((10, 1)),
    }
    unsqueezed = TensorStruct(members).unsqueeze(dim=1)

    # Expected per-key shapes after inserting a length-1 axis at dim 1.
    expected_shapes = {
        'points': (10, 1, 2),
        'values': (10, 1, 1),
    }
    for key, shape in expected_shapes.items():
        assert unsqueezed[key].shape == shape
Example #2
0
def test_struct_should_forward_torch_calls_if_single_element_in_structure():
    """Check that a TensorStruct wrapping one bare tensor forwards torch calls.

    The result of ``unsqueeze(dim=1)`` is expected to expose the shape of
    the underlying tensor with a length-1 axis inserted at position 1.
    """
    expected = (10, 1, 5)
    wrapped = TensorStruct(torch.zeros((10, 5)))
    result = wrapped.unsqueeze(dim=1)
    assert result.shape == expected
Example #3
0
def test_struct_should_return_tensor_if_keep_struct_is_set_to_false():
    """Check that ``keep_struct=False`` makes the forwarded call return a raw tensor.

    The call should hand back a plain ``torch.Tensor``, not a wrapper. That
    tensor must still carry the unsqueezed shape.
    """
    source = torch.zeros((10, 5))
    raw = TensorStruct(source).unsqueeze(dim=1, keep_struct=False)
    assert isinstance(raw, torch.Tensor)
    assert raw.shape == (10, 1, 5)