def test_tril(self):
    """Check ``ht.tril`` and ``DNDarray.tril`` for local and distributed inputs.

    Every input layout is exercised with a zero, a positive and a negative
    diagonal offset ``k``. A 1D input of length n is expected to come back as
    an (n, n) matrix; for inputs with more than two dimensions every trailing
    2D slice is compared against ``torch.tril``. The distributed cases check
    the resulting split axis, local shapes, global sum and a few corner
    entries. Finally, invalid arguments must raise ``TypeError``.
    """
    torch_device = self.device.torch_device
    offsets = (0, 2, -2)

    # 1D case, data is not split, module-level call: a length-5 vector yields (5, 5)
    local_ones = ht.ones((5,))
    for k in offsets:
        with self.subTest(case="1D local", k=k):
            result = ht.tril(local_ones, k=k)
            comparison = torch.ones((5, 5), device=torch_device).tril(diagonal=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (5, 5))
            self.assertEqual(result.lshape, (5, 5))
            self.assertEqual(result.split, None)
            self.assertTrue((result._DNDarray__array == comparison).all())

    # 2D case, data is not split, method
    local_ones = ht.ones((4, 5))
    for k in offsets:
        with self.subTest(case="2D local", k=k):
            result = local_ones.tril(k=k)
            comparison = torch.ones((4, 5), device=torch_device).tril(diagonal=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (4, 5))
            self.assertEqual(result.lshape, (4, 5))
            self.assertEqual(result.split, None)
            self.assertTrue((result._DNDarray__array == comparison).all())

    # 2D+ case, data is not split, method
    local_ones = ht.ones((3, 4, 5, 6))
    for k in offsets:
        with self.subTest(case="2D+ local", k=k):
            result = local_ones.tril(k=k)
            comparison = torch.ones((5, 6), device=torch_device).tril(diagonal=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (3, 4, 5, 6))
            self.assertEqual(result.lshape, (3, 4, 5, 6))
            self.assertEqual(result.split, None)
            # the (5, 6) reference broadcasts over the two leading dimensions,
            # so every [i, j] slice is compared in one expression
            self.assertTrue((result._DNDarray__array == comparison).all())

    # 1D case, data is split, method: the (5, 5) result is split along the columns.
    # Expected global sums of tril(ones(5, 5), k) for each offset.
    distributed_ones = ht.ones((5,), split=0)
    for k, expected_sum in ((0, 15), (2, 22), (-2, 6)):
        with self.subTest(case="1D split", k=k):
            result = distributed_ones.tril(k=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (5, 5))
            self.assertEqual(result.split, 1)
            # rows are not split, so every process holds all 5 rows
            self.assertEqual(result.lshape[0], 5)
            self.assertLessEqual(result.lshape[1], 5)
            # assertEqual, not assertTrue(value, msg): a wrong sum must actually fail
            self.assertEqual(result.sum(), expected_sum)
            if result.comm.rank == 0:
                self.assertTrue(result._DNDarray__array[-1, 0] == 1)
            # NOTE(review): compares the rank with a shape extent, so this only
            # runs with at least shape[0] processes — confirm this is intended
            if result.comm.rank == result.shape[0] - 1:
                self.assertTrue(result._DNDarray__array[0, -1] == 0)

    # Expected global sums of tril(ones(4, 5), k) for each offset.
    sums_4x5 = ((0, 10), (2, 17), (-2, 3))

    # 2D case, data is horizontally split (split=0), method
    distributed_ones = ht.ones((4, 5), split=0)
    for k, expected_sum in sums_4x5:
        with self.subTest(case="2D split=0", k=k):
            result = distributed_ones.tril(k=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (4, 5))
            self.assertEqual(result.split, 0)
            self.assertLessEqual(result.lshape[0], 4)
            self.assertEqual(result.lshape[1], 5)
            self.assertEqual(result.sum(), expected_sum)
            if result.comm.rank == 0:
                self.assertTrue(result._DNDarray__array[0, -1] == 0)
            if result.comm.rank == result.shape[0] - 1:
                self.assertTrue(result._DNDarray__array[-1, 0] == 1)

    # 2D case, data is vertically split (split=1), method
    distributed_ones = ht.ones((4, 5), split=1)
    for k, expected_sum in sums_4x5:
        with self.subTest(case="2D split=1", k=k):
            result = distributed_ones.tril(k=k)
            self.assertIsInstance(result, ht.DNDarray)
            self.assertEqual(result.shape, (4, 5))
            self.assertEqual(result.split, 1)
            self.assertEqual(result.lshape[0], 4)
            self.assertLessEqual(result.lshape[1], 5)
            self.assertEqual(result.sum(), expected_sum)
            if result.comm.rank == 0:
                self.assertTrue(result._DNDarray__array[-1, 0] == 1)
            if result.comm.rank == result.shape[0] - 1:
                self.assertTrue(result._DNDarray__array[0, -1] == 0)

    # invalid arguments
    with self.assertRaises(TypeError):
        ht.tril("asdf")
    # a non-integer offset must be rejected; the former `m=[...]` keyword
    # presumably failed at Python argument binding instead of reaching the k check
    with self.assertRaises(TypeError):
        ht.tril(distributed_ones, k=["sdf", "sf"])
def test_argmax(self):
    """Check ``ht.argmax`` / ``DNDarray.argmax`` for local and split inputs.

    Covers:

    * reductions along a local axis, with and without ``keepdim``;
    * a flattened (no ``axis``) reduction of a split 1D array;
    * reductions along and across the split axis of a 2D array;
    * the ``out`` argument;
    * the errors raised for invalid ``axis`` values.

    The returned indices are always expected to be ``int64``.
    """
    torch.manual_seed(1)
    torch_device = self.device.torch_device
    data = ht.random.randn(3, 4, 5)

    # 3D local tensor, major axis
    result = ht.argmax(data, axis=0)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(result.dtype, ht.int64)
    self.assertEqual(result._DNDarray__array.dtype, torch.int64)
    self.assertEqual(result.shape, (4, 5))
    self.assertEqual(result.lshape, (4, 5))
    self.assertEqual(result.split, None)
    self.assertTrue((result._DNDarray__array == data._DNDarray__array.argmax(0)).all())

    # 3D local tensor, minor axis, reduced dimension kept
    result = ht.argmax(data, axis=-1, keepdim=True)
    expected = data._DNDarray__array.argmax(-1, keepdim=True)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(result.dtype, ht.int64)
    self.assertEqual(result._DNDarray__array.dtype, torch.int64)
    self.assertEqual(result.shape, (3, 4, 1))
    self.assertEqual(result.lshape, (3, 4, 1))
    self.assertEqual(result.split, None)
    self.assertTrue((result._DNDarray__array == expected).all())

    # 1D split tensor, no axis: arange(-10, 10) peaks at its last element, global index 19
    data = ht.arange(-10, 10, split=0)
    result = ht.argmax(data)
    expected = torch.tensor([19], device=torch_device)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(result.dtype, ht.int64)
    self.assertEqual(result._DNDarray__array.dtype, torch.int64)
    self.assertEqual(result.shape, (1,))
    self.assertEqual(result.lshape, (1,))
    self.assertEqual(result.split, None)
    # .all() keeps this consistent with every other tensor comparison in the suite
    self.assertTrue((result._DNDarray__array == expected).all())

    # 2D split tensor, along the axis: each process contributes its own (4, 5) block
    data = ht.array(ht.random.randn(4, 5), is_split=0)
    result = ht.argmax(data, axis=1)
    expected = torch.argmax(data._DNDarray__array, dim=1)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(result.dtype, ht.int64)
    self.assertEqual(result._DNDarray__array.dtype, torch.int64)
    self.assertEqual(result.shape, (ht.MPI_WORLD.size * 4,))
    self.assertEqual(result.lshape, (4,))
    self.assertEqual(result.split, 0)
    self.assertTrue((result._DNDarray__array == expected).all())

    # 2D split tensor, across the axis: strictly lower-triangular ones (k=-1),
    # for which the test expects no column's argmax to be row 0
    size = ht.MPI_WORLD.size * 2
    data = ht.tril(ht.ones((size, size), split=0), k=-1)
    result = ht.argmax(data, axis=0)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(result.dtype, ht.int64)
    self.assertEqual(result._DNDarray__array.dtype, torch.int64)
    self.assertEqual(result.shape, (size,))
    self.assertEqual(result.lshape, (size,))
    self.assertEqual(result.split, None)
    # skipped on GPU, where argmax behaves differently
    # (NOTE(review): presumably tie-breaking among equal maxima — confirm)
    if not (torch.cuda.is_available() and result.device == ht.gpu):
        self.assertTrue((result._DNDarray__array != 0).all())

    # same input (`size` and `data` reused), reduced into a preallocated output array
    output = ht.empty((size,))
    result = ht.argmax(data, axis=0, out=output)
    self.assertIsInstance(result, ht.DNDarray)
    self.assertEqual(output.dtype, ht.int64)
    self.assertEqual(output._DNDarray__array.dtype, torch.int64)
    self.assertEqual(output.shape, (size,))
    self.assertEqual(output.lshape, (size,))
    self.assertEqual(output.split, None)
    # skipped on GPU for the same reason as above
    if not (torch.cuda.is_available() and output.device == ht.gpu):
        self.assertTrue((output._DNDarray__array != 0).all())

    # invalid axis arguments
    with self.assertRaises(TypeError):
        data.argmax(axis=(0, 1))
    with self.assertRaises(TypeError):
        data.argmax(axis=1.1)
    with self.assertRaises(TypeError):
        data.argmax(axis="y")
    with self.assertRaises(ValueError):
        ht.argmax(data, axis=-4)