# Example #1
# 0
def test_struct_should_allow_to_create_nested_zeros_tensors():
    """TensorStruct.zeros should build a nested struct from a nested spec.

    Every leaf of the spec (an int or a shape tuple) is expected to become a
    tensor whose shape is ``prefix_shape`` followed by that leaf's shape,
    regardless of how deeply the leaf is nested.
    """
    spec = {
        'a': 5,
        'b': (10, ),
        'c': (3, 14),
        'd': {
            'e': 2,
            'f': (3, 1, 4),
            'g': {'h': {'i': (8, 2)}},
        },
    }
    data = TensorStruct.zeros(spec, prefix_shape=(1, )).data()

    # (key path into data(), expected leaf shape) pairs, checked in order.
    expected_shapes = [
        (('a', ), (1, 5)),
        (('b', ), (1, 10)),
        (('c', ), (1, 3, 14)),
        (('d', 'e'), (1, 2)),
        (('d', 'f'), (1, 3, 1, 4)),
        (('d', 'g', 'h', 'i'), (1, 8, 2)),
    ]
    for path, shape in expected_shapes:
        leaf = data
        for key in path:
            leaf = leaf[key]
        assert leaf.shape == shape
# Example #2
# 0
def test_tensorstruct_should_be_serializable_and_deserializable():
    """A TensorStruct should survive a pickle round-trip with its contents intact.

    The struct is pickled and unpickled in-process. Its leaves are then
    compared against the original's: same shapes and equal values.
    """
    x = TensorStruct.zeros({
        'a': (10, 2),
        'b': (10, 3)
    })
    # Round-trip through pickle. The payload is produced by this test itself,
    # so loading it is safe.
    x_ = pickle.loads(pickle.dumps(x))

    # Without these checks the test only proved that pickling does not raise.
    # Verify that the restored struct actually carries the same data.
    original = x.data()
    restored = x_.data()
    for key in ('a', 'b'):
        # Compare shapes first, so the element-wise check below cannot pass
        # through broadcasting on mismatched shapes.
        assert restored[key].shape == original[key].shape
        assert (restored[key] == original[key]).all()
# Example #3
# 0
def test_struct_should_allow_indexing_with_list_of_indices():
    """Indexing a TensorStruct with a list of positions selects those entries.

    Picking three positions should leave a struct that reports a common size
    of 3 along dimension 0.
    """
    picked = [1, 3, 7]
    struct = TensorStruct.zeros({'a': (10, 2), 'b': (10, 3)})
    assert struct[picked].common_size(0) == len(picked)
# Example #4
# 0
def test_struct_should_allow_to_create_single_zeros_tensor():
    """TensorStruct.zeros given a plain shape should yield a single tensor.

    The second positional argument is expected to be prepended to the
    requested shape. The ``dtype`` and ``device`` keywords should be honoured.
    """
    target_dtype = torch.float64
    z = TensorStruct.zeros((2, 3), (4, 5), dtype=target_dtype, device='cpu')
    assert z.shape == (4, 5, 2, 3)
    assert z.dtype == target_dtype
    assert z.device.type == 'cpu'