コード例 #1
0
ファイル: test_torch.py プロジェクト: iamhatesz/torchstruct
def test_struct_should_forward_torch_calls_on_each_element_in_structure():
    t = TensorStruct({
        'points': torch.zeros((10, 2)),
        'values': torch.zeros((10, 1))
    })
    t_expanded = t.unsqueeze(dim=1)
    assert t_expanded['points'].shape == (10, 1, 2)
    assert t_expanded['values'].shape == (10, 1, 1)
コード例 #2
0
ファイル: test_torch.py プロジェクト: iamhatesz/torchstruct
def test_struct_should_forward_torch_calls_if_single_element_in_structure():
    t = TensorStruct(torch.zeros((10, 5)))
    t_expanded = t.unsqueeze(dim=1)
    assert t_expanded.shape == (10, 1, 5)
コード例 #3
0
ファイル: test_torch.py プロジェクト: iamhatesz/torchstruct
def test_struct_should_return_tensor_if_keep_struct_is_set_to_false():
    t = TensorStruct(torch.zeros((10, 5)))
    t_expanded = t.unsqueeze(dim=1, keep_struct=False)
    assert isinstance(t_expanded, torch.Tensor)
    assert t_expanded.shape == (10, 1, 5)