def test_full(self):
    """Exercise ``ht.full`` for default, explicit-dtype and split tensors and its argument validation."""
    # simple tensor: expected default dtype float32, not split
    data = ht.full((10, 2), 4, device=ht_device)
    self.assertIsInstance(data, ht.DNDarray)
    self.assertEqual(data.shape, (10, 2))
    self.assertEqual(data.lshape, (10, 2))
    self.assertEqual(data.dtype, ht.float32)
    self.assertEqual(data._DNDarray__array.dtype, torch.float32)
    self.assertEqual(data.split, None)
    self.assertTrue(ht.allclose(data, ht.float32(4.0, device=ht_device)))

    # non-standard dtype tensor: heat and underlying torch dtypes must both follow `dtype`
    data = ht.full((10, 2), 4, dtype=ht.int32, device=ht_device)
    self.assertIsInstance(data, ht.DNDarray)
    self.assertEqual(data.shape, (10, 2))
    self.assertEqual(data.lshape, (10, 2))
    self.assertEqual(data.dtype, ht.int32)
    self.assertEqual(data._DNDarray__array.dtype, torch.int32)
    self.assertEqual(data.split, None)
    self.assertTrue(ht.allclose(data, ht.int32(4, device=ht_device)))

    # split tensor: the local chunk along the split axis is at most the global extent,
    # while the non-split axis keeps its full size
    data = ht.full((10, 2), 4, split=0, device=ht_device)
    self.assertIsInstance(data, ht.DNDarray)
    self.assertEqual(data.shape, (10, 2))
    self.assertLessEqual(data.lshape[0], 10)
    self.assertEqual(data.lshape[1], 2)
    self.assertEqual(data.dtype, ht.float32)
    self.assertEqual(data._DNDarray__array.dtype, torch.float32)
    self.assertEqual(data.split, 0)
    self.assertTrue(ht.allclose(data, ht.float32(4.0, device=ht_device)))

    # exceptions
    # shape given as a string instead of a sequence of ints
    with self.assertRaises(TypeError):
        ht.full("(2, 3,)", 4, dtype=ht.float64, device=ht_device)
    # negative dimension in the shape
    with self.assertRaises(ValueError):
        ht.full((-1, 3), 2, dtype=ht.float64, device=ht_device)
    # non-integer split; a fill value is passed so the TypeError can only stem
    # from the invalid `split` (without it the call may fail on the missing
    # positional argument instead, leaving split validation untested)
    with self.assertRaises(TypeError):
        ht.full((2, 3), 4, dtype=ht.float64, split="axis", device=ht_device)
def test_any(self):
    """Verify ``any`` reductions return bool DNDarrays with the expected shape and truth values."""

    def verify(reduced, shape, expected):
        # shared checks: container type, result shape, bool dtype, element values
        self.assertIsInstance(reduced, ht.DNDarray)
        self.assertEqual(reduced.shape, shape)
        self.assertEqual(reduced.dtype, ht.bool)
        self.assertTrue(ht.equal(reduced, expected))

    # float values, minor axis, method form
    values = ht.float32([[2.7, 0, 0], [0, 0, 0], [0, 0.3, 0]], device=ht_device)
    reduced = values.any(axis=1)
    expected = ht.uint8([1, 0, 1], device=ht_device)
    verify(reduced, (3,), expected)

    # integer values, major axis, result written into a preallocated output tensor
    reduced = ht.zeros((2,), device=ht_device)
    values = ht.int32([[0, 0], [0, 0], [0, 1]], device=ht_device)
    ht.any(values, axis=0, out=reduced)
    expected = ht.uint8([0, 1], device=ht_device)
    verify(reduced, (2,), expected)

    # float values, no axis: an all-zero input reduces to a single false entry
    values = ht.float64([[0, 0, 0], [0, 0, 0]], device=ht_device)
    expected = ht.zeros(1, dtype=ht.uint8, device=ht_device)
    reduced = ht.any(values)
    verify(reduced, (1,), expected)

    # split tensor, reduced along the split axis
    values = ht.arange(10, split=0, device=ht_device)
    reduced = ht.any(values, axis=0)
    expected = ht.uint8([1], device=ht_device)
    verify(reduced, (1,), expected)
def test_full(self):
    """Exercise legacy ``ht.full`` for default, explicit-dtype and split tensors and its argument validation."""
    # simple tensor: expected default dtype float32, not split
    data = ht.full((10, 2), 4)
    self.assertIsInstance(data, ht.tensor)
    self.assertEqual(data.shape, (10, 2))
    self.assertEqual(data.lshape, (10, 2))
    self.assertEqual(data.dtype, ht.float32)
    self.assertEqual(data._tensor__array.dtype, torch.float32)
    self.assertEqual(data.split, None)
    self.assertTrue(ht.allclose(data, ht.float32(4.0)))

    # non-standard dtype tensor: heat and underlying torch dtypes must both follow `dtype`
    data = ht.full((10, 2), 4, dtype=ht.int32)
    self.assertIsInstance(data, ht.tensor)
    self.assertEqual(data.shape, (10, 2))
    self.assertEqual(data.lshape, (10, 2))
    self.assertEqual(data.dtype, ht.int32)
    self.assertEqual(data._tensor__array.dtype, torch.int32)
    self.assertEqual(data.split, None)
    self.assertTrue(ht.allclose(data, ht.int32(4)))

    # split tensor: the local chunk along the split axis is at most the global extent,
    # while the non-split axis keeps its full size
    data = ht.full((10, 2), 4, split=0)
    self.assertIsInstance(data, ht.tensor)
    self.assertEqual(data.shape, (10, 2))
    self.assertLessEqual(data.lshape[0], 10)
    self.assertEqual(data.lshape[1], 2)
    self.assertEqual(data.dtype, ht.float32)
    self.assertEqual(data._tensor__array.dtype, torch.float32)
    self.assertEqual(data.split, 0)
    self.assertTrue(ht.allclose(data, ht.float32(4.0)))

    # exceptions
    # shape given as a string instead of a sequence of ints
    with self.assertRaises(TypeError):
        ht.full('(2, 3,)', 4, dtype=ht.float64)
    # negative dimension in the shape
    with self.assertRaises(ValueError):
        ht.full((-1, 3), 2, dtype=ht.float64)
    # non-integer split; a fill value is passed so the TypeError can only stem
    # from the invalid `split` (without it the call may fail on the missing
    # positional argument instead, leaving split validation untested)
    with self.assertRaises(TypeError):
        ht.full((2, 3), 4, dtype=ht.float64, split='axis')