def test_struct_should_allow_to_create_nested_zeros_tensors():
    """Nested shape specs should yield leaf tensors with the prefix prepended."""
    spec = {
        'a': 5,
        'b': (10, ),
        'c': (3, 14),
        'd': {
            'e': 2,
            'f': (3, 1, 4),
            'g': {
                'h': {
                    'i': (8, 2),
                },
            },
        },
    }
    data = TensorStruct.zeros(spec, prefix_shape=(1, )).data()

    # Each entry: key path into the nested data, expected leaf shape.
    expected = [
        (('a',), (1, 5)),
        (('b',), (1, 10)),
        (('c',), (1, 3, 14)),
        (('d', 'e'), (1, 2)),
        (('d', 'f'), (1, 3, 1, 4)),
        (('d', 'g', 'h', 'i'), (1, 8, 2)),
    ]
    for path, shape in expected:
        leaf = data
        for key in path:
            leaf = leaf[key]
        assert leaf.shape == shape
def test_tensorstruct_should_be_serializable_and_deserializable():
    """A pickle round-trip should reproduce the struct's type, shapes and values."""
    x = TensorStruct.zeros({'a': (10, 2), 'b': (10, 3)})
    x_ = pickle.loads(pickle.dumps(x))

    # Previously the test only checked that loads() did not raise; verify that
    # the restored struct actually carries the same content.
    assert isinstance(x_, TensorStruct)
    original, restored = x.data(), x_.data()
    for key, shape in (('a', (10, 2)), ('b', (10, 3))):
        assert restored[key].shape == shape
        assert torch.equal(restored[key], original[key])
def test_struct_should_allow_indexing_with_list_of_indices():
    """Indexing with a list of row indices should keep exactly those rows."""
    selected_rows = [1, 3, 7]
    struct = TensorStruct.zeros({'a': (10, 2), 'b': (10, 3)})
    picked = struct[selected_rows]
    assert picked.common_size(0) == len(selected_rows)
def test_struct_should_allow_to_create_single_zeros_tensor():
    """A bare shape spec should produce one tensor with the prefix prepended."""
    item_shape = (2, 3)
    prefix = (4, 5)
    tensor = TensorStruct.zeros(item_shape, prefix,
                                dtype=torch.float64, device='cpu')
    assert tensor.shape == prefix + item_shape
    assert tensor.dtype == torch.float64
    assert tensor.device.type == 'cpu'