def test_shape_empty(self):
    """Check the rank ids and shape reported by an empty tensor.

    Without an explicit shape each rank reports an extent of 0; when a
    shape is passed to the constructor, that shape is reported back.
    """
    # Each case: extra constructor keyword arguments, expected shape.
    cases = (
        ({}, [0, 0]),
        ({"shape": [10, 20]}, [10, 20]),
    )
    for extra_kwargs, expected_shape in cases:
        # A fresh rank-id list per tensor, as in separate constructor calls.
        tensor = Tensor(rank_ids=["M", "K"], **extra_kwargs)
        self.assertEqual(tensor.getRankIds(), ["M", "K"])
        self.assertEqual(tensor.getShape(), expected_shape)
def test_shape_0D(self):
    """Test shape of 0-D tensor"""
    # A tensor constructed with no rank ids should report empty rank ids
    # and an empty shape.
    t = Tensor(rank_ids=[])
    # NOTE(review): presumably this stores a value at the root so the
    # tensor is not empty when its shape is queried -- confirm against
    # Tensor.getRoot for the 0-D case.
    p = t.getRoot()
    p += 1
    self.assertEqual(t.getRankIds(), [])
    self.assertEqual(t.getShape(), [])
def test_constructor_shape(self):
    """Check that a shape given to the constructor is reported by the
    tensor and by its root, along with the rank ids."""
    rank_names = ["M", "K"]
    extents = [4, 8]
    tensor = Tensor(rank_ids=rank_names, shape=extents)

    # The tensor and its root must report the same rank ids ...
    self.assertEqual(tensor.getRankIds(), rank_names)
    self.assertEqual(tensor.getRoot().getRankIds(), rank_names)

    # ... and the same shape.
    self.assertEqual(tensor.getShape(), extents)
    self.assertEqual(tensor.getRoot().getShape(), extents)