Example #1
0
def test_struct_should_forward_torch_calls_on_each_element_in_structure():
    """A torch method called on a multi-tensor struct is applied to every member."""
    members = {
        'points': torch.zeros((10, 2)),
        'values': torch.zeros((10, 1)),
    }
    expanded = TensorStruct(members).unsqueeze(dim=1)
    for name, want in (('points', (10, 1, 2)), ('values', (10, 1, 1))):
        assert expanded[name].shape == want
Example #2
0
def test_struct_common_size_should_return_size_of_first_tensor_in_dict():
    """common_size(0) reports the leading size of one of the struct's tensors.

    The two tensors deliberately disagree on dim 0 (10 vs. 5); the assertion
    accepts either value rather than pinning a specific traversal order.
    """
    nested = {'c': torch.ones((5, 2))}
    struct = TensorStruct({'a': torch.ones((10, 2)), 'b': nested})
    size = struct.common_size(0)
    assert size in (10, 5)
Example #3
0
def test_struct_should_update_values_from_struct_at_given_indices():
    """Slice-assigning a struct writes into exactly the selected rows.

    Rows 2..6 of both nested tensors must become zero, while the rows outside
    the slice must keep the ones they were created with.
    """
    t = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    new_data = TensorStruct({
        'b': torch.zeros((5, 1)),
        'c': torch.zeros((5, 2))
    })
    t['a'][2:7] = new_data
    for key in ('b', 'c'):
        updated = t['a'][key]
        # The whole slice is overwritten, not just its first element.
        assert torch.all(updated[2:7] == 0)
        # Rows outside the slice are left untouched.
        assert torch.all(updated[:2] == 1)
        assert torch.all(updated[7:] == 1)
Example #4
0
def test_struct_should_update_dict_if_valid_struct_given():
    """Assigning a struct with matching keys to a nested key replaces its tensors."""
    struct = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    replacement = TensorStruct({
        key: torch.zeros_like(struct['a'][key]) for key in ('b', 'c')
    })
    struct['a'] = replacement
    for key in ('b', 'c'):
        assert struct['a'][key][0, 0] == 0
Example #5
0
def test_struct_tensors_should_return_list_of_tensors_in_struct():
    """tensors() flattens the nested struct into its leaf tensors.

    Both leaves (the ones tensor and the twos tensor) must appear in the result.
    """
    t = TensorStruct({
        'a': torch.ones(5),
        'b': {
            'c': {
                'd': torch.ones(5) * 2
            }
        }
    })
    ts = t.tensors()
    assert len(ts) == 2
    ones = torch.ones(5)
    # Scale the reference tensor itself; multiplying the boolean result of
    # the comparison (as before) made this a repeat of the ones check.
    twos = torch.ones(5) * 2
    assert any(torch.equal(ones, t_) for t_ in ts)
    assert any(torch.equal(twos, t_) for t_ in ts)
Example #6
0
def test_struct_should_raise_if_getting_property_on_structure_with_multiple_tensors():
    """Reading .shape on a struct holding several tensors raises ValueError."""
    members = {
        'points': torch.zeros((10, 2)),
        'values': torch.zeros((10, 1)),
    }
    struct = TensorStruct(members)
    with pytest.raises(ValueError):
        struct.shape
Example #7
0
def test_struct_should_return_element_when_indexing_with_string():
    """String keys yield a raw tensor for leaves and a TensorStruct for sub-dicts."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': {'c': (5, 4)}})

    leaf = struct['a']
    assert isinstance(leaf, torch.Tensor)
    assert leaf.shape == (10, 2)

    subtree = struct['b']
    assert isinstance(subtree, TensorStruct)
    assert subtree['c'].shape == (5, 4)
Example #8
0
def test_struct_should_allow_to_create_nested_zeros_tensors():
    """zeros() builds a nested struct whose leaves all carry prefix_shape in front.

    Leaf specs may be a bare int or a shape tuple; with prefix_shape=(1,) every
    resulting tensor gains a leading dimension of size 1.
    """
    spec = {
        'a': 5,
        'b': (10, ),
        'c': (3, 14),
        'd': {
            'e': 2,
            'f': (3, 1, 4),
            'g': {'h': {'i': (8, 2)}},
        },
    }
    data = TensorStruct.zeros(spec, prefix_shape=(1, )).data()
    expected = [
        (('a', ), (1, 5)),
        (('b', ), (1, 10)),
        (('c', ), (1, 3, 14)),
        (('d', 'e'), (1, 2)),
        (('d', 'f'), (1, 3, 1, 4)),
        (('d', 'g', 'h', 'i'), (1, 8, 2)),
    ]
    for path, shape in expected:
        node = data
        for key in path:
            node = node[key]
        assert node.shape == shape
Example #9
0
def test_struct_should_raise_when_given_invalid_key():
    """Looking up a key that is not in the struct raises KeyError."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    with pytest.raises(KeyError):
        struct['c']
Example #10
0
def test_struct_should_raise_if_updating_invalid_key():
    """Assigning to a key the struct does not already contain raises KeyError."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    missing_key = 'c'
    with pytest.raises(KeyError):
        struct[missing_key] = 5
Example #11
0
def test_cat_should_cat_nested_tensors():
    """cat() concatenates matching leaves across a list of structs along dim 0."""
    parts = []
    for _ in range(5):
        parts.append(TensorStruct.ones({'a': (10, 2), 'b': (10, 3)}))
    joined = cat(parts, dim=0)
    assert joined['a'].shape == (50, 2)
    assert joined['b'].shape == (50, 3)
Example #12
0
def test_struct_should_raise_if_updating_with_invalid_struct():
    """Replacing a sub-struct with one whose keys do not match raises ValueError."""
    struct = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    # Key 'c' is missing compared to struct['a'].
    incomplete = TensorStruct({'b': torch.zeros((10, 1))})

    with pytest.raises(ValueError):
        struct['a'] = incomplete
Example #13
0
def test_struct_should_raise_if_using_unsupported_assignment():
    """Assigning a plain int over a nested sub-struct raises ValueError."""
    struct = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    scalar = 7
    with pytest.raises(ValueError):
        struct['a'] = scalar
Example #14
0
def test_struct_should_raise_if_constructed_from_invalid_data():
    """A leaf given as a tuple instead of a tensor triggers an AssertionError."""
    bad_payload = {'a': (1, 2)}
    with pytest.raises(AssertionError):
        TensorStruct(bad_payload)
Example #15
0
def test_struct_should_update_tensor_if_tensor_given():
    """Assigning a tensor to an existing leaf key replaces that leaf."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    replacement = torch.zeros_like(struct['a'])
    struct['a'] = replacement
    assert struct['a'][0, 0] == 0
Example #16
0
def test_struct_should_allow_indexing_with_list_of_indices():
    """Indexing with a Python list of ints selects those rows from every leaf."""
    struct = TensorStruct.zeros({'a': (10, 2), 'b': (10, 3)})
    picked = struct[[1, 3, 7]]
    assert picked.common_size(0) == 3
Example #17
0
def test_struct_should_return_single_elements_when_indexing_with_int():
    """An int index drops the leading dimension but still returns a TensorStruct."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    row = struct[1]
    assert isinstance(row, TensorStruct)
    for key, want in (('a', (2, )), ('b', (1, ))):
        assert row[key].shape == want
Example #18
0
def test_struct_should_return_tensor_if_keep_struct_is_set_to_false():
    """With keep_struct=False a forwarded call returns the bare torch.Tensor."""
    struct = TensorStruct(torch.zeros((10, 5)))
    result = struct.unsqueeze(dim=1, keep_struct=False)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (10, 1, 5)
Example #19
0
def test_struct_should_return_narrowed_elements_when_indexing_with_slice():
    """A slice index narrows dim 0 of every leaf and keeps the struct wrapper."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    head = struct[:5]
    assert isinstance(head, TensorStruct)
    for key, want in (('a', (5, 2)), ('b', (5, 1))):
        assert head[key].shape == want
Example #20
0
def test_struct_should_return_tensor_property_if_single_element_in_structure():
    """A struct wrapping a single tensor exposes that tensor's .shape."""
    wrapped = TensorStruct(torch.zeros((10, 5)))
    expected_shape = (10, 5)
    assert wrapped.shape == expected_shape
Example #21
0
def test_struct_should_return_narrowed_elements_when_indexing_with_tensor():
    """A LongTensor index selects the given rows from every leaf."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    index = torch.arange(2, 9, dtype=torch.long)
    subset = struct[index]
    assert isinstance(subset, TensorStruct)
    for key, want in (('a', (7, 2)), ('b', (7, 1))):
        assert subset[key].shape == want
Example #22
0
def test_struct_should_forward_torch_calls_if_single_element_in_structure():
    """A torch method called on a single-tensor struct reaches that tensor."""
    expanded = TensorStruct(torch.zeros((10, 5))).unsqueeze(dim=1)
    assert expanded.shape == (10, 1, 5)
Example #23
0
def test_struct_should_raise_when_given_index_other_than_string_int_slice_tensor():
    """Indexing with an unsupported key type (here None) raises ValueError."""
    struct = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    with pytest.raises(ValueError):
        struct[None]
Example #24
0
def test_tensorstruct_should_be_serializable_and_deserializable():
    """A TensorStruct survives a pickle round-trip with its contents intact.

    Previously the test only checked that loading did not raise; it now also
    verifies the restored object is a TensorStruct with equal leaf tensors.
    The pickle payload is produced by the test itself, so loading it is safe.
    """
    x = TensorStruct.zeros({
        'a': (10, 2),
        'b': (10, 3)
    })
    x_ = pickle.loads(pickle.dumps(x))
    assert isinstance(x_, TensorStruct)
    for key in ('a', 'b'):
        # torch.equal checks both shape and element values.
        assert torch.equal(x_[key], x[key])
Example #25
0
def test_struct_should_allow_to_create_single_zeros_tensor():
    """zeros() with a bare shape spec yields a single-tensor struct.

    The second positional shape ends up in front, giving (4, 5, 2, 3); dtype
    and device keywords are honoured.
    """
    struct = TensorStruct.zeros(
        (2, 3), (4, 5), dtype=torch.float64, device='cpu')
    assert struct.shape == (4, 5, 2, 3)
    assert struct.dtype == torch.float64
    assert struct.device.type == 'cpu'