def test_where(self):
    """Test ht.where() in both call forms (cond-only and cond/x/y) across split configurations."""
    # cases to test
    # no x and y: where(cond) acts like nonzero(), returning one (row, col) pair per true element
    a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None)
    cond = a > 3
    wh = ht.where(cond)
    # six elements are > 3, each contributing a 2-element index row
    self.assertEqual(wh.gshape, (6, 2))
    self.assertEqual(wh.dtype, ht.int64)
    self.assertEqual(wh.split, None)
    # split
    a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1)
    cond = a > 3
    wh = ht.where(cond)
    self.assertEqual(wh.gshape, (6, 2))
    self.assertEqual(wh.dtype, ht.int64)
    # a distributed condition yields an index tensor split along axis 0
    self.assertEqual(wh.split, 0)
    # not split cond
    a = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [0.0, 3.0, 6.0]], split=None)
    res = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, -1.0], [0.0, 3.0, -1.0]], split=None)
    wh = ht.where(a < 4.0, a, -1)
    self.assertTrue(
        ht.equal(a[ht.nonzero(a < 4)], ht.array([0.0, 1.0, 2.0, 0.0, 2.0, 0.0, 3.0]))
    )
    self.assertTrue(ht.equal(wh, res))
    self.assertEqual(wh.gshape, (3, 3))
    self.assertEqual(wh.dtype, ht.float64)
    # split cond
    a = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [0.0, 3.0, 6.0]], split=0)
    res = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, -1.0], [0.0, 3.0, -1.0]], split=0)
    wh = ht.where(a < 4.0, a, -1)
    # every position not selected from a must carry the fill value -1
    self.assertTrue(ht.all(wh[ht.nonzero(a >= 4)] == -1))
    self.assertTrue(ht.equal(wh, res))
    self.assertEqual(wh.gshape, (3, 3))
    self.assertEqual(wh.dtype, ht.float64)
    # result keeps the split of the inputs
    self.assertEqual(wh.split, 0)

    a = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [0.0, 3.0, 6.0]], split=1)
    res = ht.array([[0.0, 1.0, 2.0], [0.0, 2.0, -1.0], [0.0, 3.0, -1.0]], split=1)
    wh = ht.where(a < 4.0, a, -1.0)
    self.assertTrue(ht.equal(wh, res))
    self.assertEqual(wh.gshape, (3, 3))
    self.assertEqual(wh.dtype, ht.float)
    self.assertEqual(wh.split, 1)

    # supplying x without y is invalid
    with self.assertRaises(TypeError):
        ht.where(cond, a)
    # x and y with mismatching splits are not supported
    with self.assertRaises(NotImplementedError):
        ht.where(cond, ht.ones((3, 3), split=0), ht.ones((3, 3), split=1))
def test_nonzero(self):
    """Test ht.nonzero()/DNDarray.nonzero() for non-split and split inputs."""
    # non-distributed input: free-function form on a boolean mask
    arr = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None)
    mask = arr > 3
    indices = ht.nonzero(mask)
    # five elements exceed 3; each produces a (row, col) index pair
    self.assertEqual(indices.gshape, (5, 2))
    self.assertEqual(indices.dtype, ht.int64)
    self.assertEqual(indices.split, None)

    # distributed input (split=1): method form; result is split along axis 0
    arr = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1)
    mask = arr > 3
    indices = mask.nonzero()
    self.assertEqual(indices.gshape, (6, 2))
    self.assertEqual(indices.dtype, ht.int64)
    self.assertEqual(indices.split, 0)

    # the index tensor can drive advanced-indexing assignment
    arr[indices] = 10.0
    self.assertEqual(ht.all(arr[indices] == 10), 1)
def test_all(self):
    """Test ht.all()/DNDarray.all() across dtypes, splits, axis arguments, and `out=`."""
    array_len = 9

    # check all over all float elements of 1d tensor locally
    ones_noaxis = ht.ones(array_len, device=ht_device)
    x = (ones_noaxis == 1).all()
    # a full reduction yields a scalar-like bool DNDarray that is never split
    self.assertIsInstance(x, ht.DNDarray)
    self.assertEqual(x.shape, (1,))
    self.assertEqual(x.lshape, (1,))
    self.assertEqual(x.dtype, ht.bool)
    self.assertEqual(x._DNDarray__array.dtype, torch.bool)
    self.assertEqual(x.split, None)
    self.assertEqual(x._DNDarray__array, 1)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(ones_noaxis, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 1)

    # check all over all float elements of split 1d tensor
    ones_noaxis_split = ht.ones(array_len, split=0, device=ht_device)
    floats_is_one = ones_noaxis_split.all()
    self.assertIsInstance(floats_is_one, ht.DNDarray)
    self.assertEqual(floats_is_one.shape, (1,))
    self.assertEqual(floats_is_one.lshape, (1,))
    self.assertEqual(floats_is_one.dtype, ht.bool)
    self.assertEqual(floats_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(floats_is_one.split, None)
    self.assertEqual(floats_is_one._DNDarray__array, 1)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(ones_noaxis_split, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 1)

    # check all over all integer elements of 1d tensor locally
    ones_noaxis_int = ht.ones(array_len, device=ht_device).astype(ht.int)
    int_is_one = ones_noaxis_int.all()
    self.assertIsInstance(int_is_one, ht.DNDarray)
    self.assertEqual(int_is_one.shape, (1,))
    self.assertEqual(int_is_one.lshape, (1,))
    self.assertEqual(int_is_one.dtype, ht.bool)
    self.assertEqual(int_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(int_is_one.split, None)
    self.assertEqual(int_is_one._DNDarray__array, 1)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(ones_noaxis_int, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 1)

    # check all over all integer elements of split 1d tensor
    ones_noaxis_split_int = ht.ones(array_len, split=0, device=ht_device).astype(ht.int)
    split_int_is_one = ones_noaxis_split_int.all()
    self.assertIsInstance(split_int_is_one, ht.DNDarray)
    self.assertEqual(split_int_is_one.shape, (1,))
    self.assertEqual(split_int_is_one.lshape, (1,))
    self.assertEqual(split_int_is_one.dtype, ht.bool)
    self.assertEqual(split_int_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(split_int_is_one.split, None)
    self.assertEqual(split_int_is_one._DNDarray__array, 1)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(ones_noaxis_split_int, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 1)

    # check all over all float elements of 3d tensor locally
    ones_noaxis_volume = ht.ones((3, 3, 3), device=ht_device)
    volume_is_one = ones_noaxis_volume.all()
    self.assertIsInstance(volume_is_one, ht.DNDarray)
    self.assertEqual(volume_is_one.shape, (1,))
    self.assertEqual(volume_is_one.lshape, (1,))
    self.assertEqual(volume_is_one.dtype, ht.bool)
    self.assertEqual(volume_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(volume_is_one.split, None)
    self.assertEqual(volume_is_one._DNDarray__array, 1)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(ones_noaxis_volume, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 1)

    # check sequence is not all one (arange contains 0, so all() is False)
    sequence = ht.arange(array_len, device=ht_device)
    sequence_is_one = sequence.all()
    self.assertIsInstance(sequence_is_one, ht.DNDarray)
    self.assertEqual(sequence_is_one.shape, (1,))
    self.assertEqual(sequence_is_one.lshape, (1,))
    self.assertEqual(sequence_is_one.dtype, ht.bool)
    self.assertEqual(sequence_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(sequence_is_one.split, None)
    self.assertEqual(sequence_is_one._DNDarray__array, 0)

    out_noaxis = ht.zeros((1,), device=ht_device)
    ht.all(sequence, out=out_noaxis)
    self.assertEqual(out_noaxis._DNDarray__array, 0)

    # check all over all float elements of split 3d tensor
    # reducing along the split axis (0) produces a non-split result
    ones_noaxis_split_axis = ht.ones((3, 3, 3), split=0, device=ht_device)
    float_volume_is_one = ones_noaxis_split_axis.all(axis=0)
    self.assertIsInstance(float_volume_is_one, ht.DNDarray)
    self.assertEqual(float_volume_is_one.shape, (3, 3))
    self.assertEqual(float_volume_is_one.all(axis=1).dtype, ht.bool)
    self.assertEqual(float_volume_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(float_volume_is_one.split, None)

    out_noaxis = ht.zeros((3, 3), device=ht_device)
    ht.all(ones_noaxis_split_axis, axis=0, out=out_noaxis)

    # check all over all float elements of split 3d tensor with tuple axis
    ones_noaxis_split_axis = ht.ones((3, 3, 3), split=0, device=ht_device)
    float_volume_is_one = ones_noaxis_split_axis.all(axis=(0, 1))
    self.assertIsInstance(float_volume_is_one, ht.DNDarray)
    self.assertEqual(float_volume_is_one.shape, (3,))
    self.assertEqual(float_volume_is_one.all(axis=0).dtype, ht.bool)
    self.assertEqual(float_volume_is_one._DNDarray__array.dtype, torch.bool)
    self.assertEqual(float_volume_is_one.split, None)

    # check all over all float elements of split 5d tensor with negative axis
    # NOTE(review): variable is named "ones_..." but holds zeros; only
    # shape/dtype/split metadata are asserted here, not values
    ones_noaxis_split_axis_neg = ht.zeros((1, 2, 3, 4, 5), split=1, device=ht_device)
    float_5d_is_one = ones_noaxis_split_axis_neg.all(axis=-2)
    self.assertIsInstance(float_5d_is_one, ht.DNDarray)
    self.assertEqual(float_5d_is_one.shape, (1, 2, 3, 5))
    self.assertEqual(float_5d_is_one.dtype, ht.bool)
    self.assertEqual(float_5d_is_one._DNDarray__array.dtype, torch.bool)
    # split axis 1 is not reduced, so the result stays split
    self.assertEqual(float_5d_is_one.split, 1)

    out_noaxis = ht.zeros((1, 2, 3, 5), device=ht_device)
    ht.all(ones_noaxis_split_axis_neg, axis=-2, out=out_noaxis)

    # exceptions
    with self.assertRaises(ValueError):
        ht.ones(array_len, device=ht_device).all(axis=1)
    with self.assertRaises(ValueError):
        ht.ones(array_len, device=ht_device).all(axis=-2)
    with self.assertRaises(ValueError):
        # out shape (1, 2, 3, 5) does not match the reduced shape (4,)
        ht.ones((4, 4), device=ht_device).all(axis=0, out=out_noaxis)
    with self.assertRaises(TypeError):
        ht.ones(array_len, device=ht_device).all(axis="bad_axis_type")
def test_matmul(self):
    """Test ht.matmul() for every combination of operand splits, including
    matrix @ matrix, vector @ matrix, and matrix @ vector cases."""
    with self.assertRaises(ValueError):
        # inner dimensions do not match
        ht.matmul(ht.ones((25, 25)), ht.ones((42, 42)))

    # cases to test:
    n, m = 21, 31
    j, k = m, 45
    # torch reference operands with distinctive first row / edge column
    a_torch = torch.ones((n, m), device=self.device.torch_device)
    a_torch[0] = torch.arange(1, m + 1, device=self.device.torch_device)
    a_torch[:, -1] = torch.arange(1, n + 1, device=self.device.torch_device)
    b_torch = torch.ones((j, k), device=self.device.torch_device)
    b_torch[0] = torch.arange(1, k + 1, device=self.device.torch_device)
    b_torch[:, 0] = torch.arange(1, j + 1, device=self.device.torch_device)

    # splits None None
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)
    self.assertEqual(a.split, None)
    self.assertEqual(b.split, None)

    # splits None None
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    # allow_resplit=True lets matmul redistribute a in place (a.split becomes 0)
    ret00 = ht.matmul(a, b, allow_resplit=True)
    self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)
    self.assertEqual(a.split, 0)
    self.assertEqual(b.split, None)

    # NOTE(review): indentation was lost in the original; the distributed-split
    # cases below are placed under this guard, which skips them in
    # single-process runs — confirm the intended extent of the guard.
    if a.comm.size > 1:
        # splits 00
        a = ht.ones((n, m), split=0, dtype=ht.float64)
        b = ht.ones((j, k), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = a @ b
        ret_comp00 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp00))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        # float64 operand promotes the result
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 00 (numpy)
        a = ht.array(np.ones((n, m)), split=0)
        b = ht.array(np.ones((j, k)), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = a @ b
        ret_comp00 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp00))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        # numpy arrays default to float64
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 01
        a = ht.ones((n, m), split=0)
        b = ht.ones((j, k), split=1, dtype=ht.float64)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp01 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp01))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 10
        a = ht.ones((n, m), split=1)
        b = ht.ones((j, k), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp10 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp10))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits 11
        a = ht.ones((n, m), split=1)
        b = ht.ones((j, k), split=1)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp11 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp11))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits 11 (torch)
        a = ht.array(torch.ones((n, m), device=self.device.torch_device), split=1)
        b = ht.array(torch.ones((j, k), device=self.device.torch_device), split=1)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp11 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp11))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits 0 None
        a = ht.ones((n, m), split=0)
        b = ht.ones((j, k), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp0 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp0))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 1 None
        a = ht.ones((n, m), split=1)
        b = ht.ones((j, k), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp1 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp1))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits None 0
        a = ht.ones((n, m), split=None)
        b = ht.ones((j, k), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits None 1
        a = ht.ones((n, m), split=None)
        b = ht.ones((j, k), split=1)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # vector matrix mult:
        # a -> vector
        a_torch = torch.ones((m), device=self.device.torch_device)
        b_torch = torch.ones((j, k), device=self.device.torch_device)
        b_torch[0] = torch.arange(1, k + 1, device=self.device.torch_device)
        b_torch[:, 0] = torch.arange(1, j + 1, device=self.device.torch_device)

        # splits None None
        a = ht.ones((m), split=None)
        b = ht.ones((j, k), split=None)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, None)

        # splits None 0
        a = ht.ones((m), split=None)
        b = ht.ones((j, k), split=0)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits None 1
        a = ht.ones((m), split=None)
        b = ht.ones((j, k), split=1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 0 None
        # NOTE(review): despite the label, a is built with split=None and b with
        # split=0 — this duplicates the "splits None 0" case above; possibly
        # a = ht.ones((m), split=0), b split=None was intended. Confirm before changing.
        a = ht.ones((m), split=None)
        b = ht.ones((j, k), split=0)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 0 0
        a = ht.ones((m), split=0)
        b = ht.ones((j, k), split=0)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 0 1
        a = ht.ones((m), split=0)
        b = ht.ones((j, k), split=1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (k,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # b -> vector
        a_torch = torch.ones((n, m), device=self.device.torch_device)
        a_torch[0] = torch.arange(1, m + 1, device=self.device.torch_device)
        a_torch[:, -1] = torch.arange(1, n + 1, device=self.device.torch_device)
        b_torch = torch.ones((j), device=self.device.torch_device)

        # splits None None
        a = ht.ones((n, m), split=None)
        b = ht.ones((j), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array(a_torch @ b_torch, split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, None)

        # splits 0 None
        a = ht.ones((n, m), split=0)
        b = ht.ones((j), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array((a_torch @ b_torch), split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 1 None
        a = ht.ones((n, m), split=1)
        b = ht.ones((j), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array((a_torch @ b_torch), split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits None 0
        a = ht.ones((n, m), split=None)
        b = ht.ones((j), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array((a_torch @ b_torch), split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 0 0
        a = ht.ones((n, m), split=0)
        b = ht.ones((j), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array((a_torch @ b_torch), split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

        # splits 1 0
        a = ht.ones((n, m), split=1)
        b = ht.ones((j), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        ret00 = ht.matmul(a, b)
        ret_comp = ht.array((a_torch @ b_torch), split=None)
        self.assertTrue(ht.equal(ret00, ret_comp))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n,))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 0)

    # batched matmul (split along a batch dimension) is not implemented
    with self.assertRaises(NotImplementedError):
        a = ht.zeros((3, 3, 3), split=2)
        b = a.copy()
        a @ b
def __partial_fit(self, X, y, classes=None, _refit=False, sample_weight=None):
    """
    Actual implementation of Gaussian NB fitting. Adapted to HeAT from scikit-learn.

    Parameters
    ----------
    X : ht.tensor of shape (n_samples, n_features)
        Training set, where n_samples is the number of samples and
        n_features is the number of features.
    y : ht.tensor of shape (n_samples,)
        Labels for training set.
    classes : ht.tensor of shape (n_classes,), optional (default=None)
        List of all the classes that can possibly appear in the y vector.
        Must be provided at the first call to partial_fit, can be omitted
        in subsequent calls.
    _refit : bool, optional (default=False)
        If true, act as though this were the first time __partial_fit is
        called (ie, throw away any past fitting and start over).
    sample_weight : ht.tensor of shape (n_samples,), optional (default=None)
        Weights applied to individual samples (1. for unweighted).

    Returns
    -------
    self : object
    """
    # TODO: sanitize X and y shape: sanitation/validation module, cf. #468
    n_samples = X.shape[0]
    if X.numdims != 2:
        raise ValueError("expected X to be a 2-D tensor, is {}-D".format(X.numdims))
    if y.shape[0] != n_samples:
        raise ValueError(
            "y.shape[0] must match number of samples {}, is {}".format(n_samples, y.shape[0])
        )
    # TODO: sanitize sample_weight: sanitation/validation module, cf. #468
    if sample_weight is not None:
        if sample_weight.numdims != 1:
            raise ValueError("Sample weights must be 1D tensor")
        if sample_weight.shape != (n_samples,):
            raise ValueError(
                "sample_weight.shape == {}, expected {}!".format(sample_weight.shape, (n_samples,))
            )

    # If the ratio of data variance between dimensions is too small, it
    # will cause numerical errors. To address this, we artificially
    # boost the variance by epsilon, a small fraction of the standard
    # deviation of the largest dimension.
    self.epsilon_ = self.var_smoothing * ht.var(X, axis=0).max()

    if _refit:
        self.classes_ = None

    if self.__check_partial_fit_first_call(classes):
        # This is the first call to partial_fit:
        # initialize various cumulative counters
        n_features = X.shape[1]
        n_classes = len(self.classes_)
        self.theta_ = ht.zeros((n_classes, n_features), dtype=X.dtype, device=X.device)
        self.sigma_ = ht.zeros((n_classes, n_features), dtype=X.dtype, device=X.device)
        self.class_count_ = ht.zeros((n_classes,), dtype=ht.float64, device=X.device)

        # Initialise the class prior
        # Take into account the priors
        if self.priors is not None:
            if not isinstance(self.priors, ht.DNDarray):
                priors = ht.array(self.priors, dtype=X.dtype, split=None, device=X.device)
            else:
                priors = self.priors
            # Check that the provided priors match the number of classes
            if len(priors) != n_classes:
                raise ValueError("Number of priors must match number of" " classes.")
            # Check that the sum is 1
            if not ht.isclose(priors.sum(), ht.array(1.0, dtype=priors.dtype)):
                raise ValueError("The sum of the priors should be 1.")
            # Check that the priors are non-negative
            if (priors < 0).any():
                raise ValueError("Priors must be non-negative.")
            self.class_prior_ = priors
        else:
            # Initialize the priors to zeros for each class
            self.class_prior_ = ht.zeros(
                len(self.classes_), dtype=ht.float64, split=None, device=X.device
            )
    else:
        if X.shape[1] != self.theta_.shape[1]:
            raise ValueError(
                "Number of features {} does not match previous data {}.".format(
                    X.shape[1], self.theta_.shape[1]
                )
            )
        # Put epsilon back in each time
        self.sigma_[:, :] -= self.epsilon_

    classes = self.classes_

    unique_y = ht.unique(y, sorted=True)
    if unique_y.split is not None:
        # the label set must be replicated on every process for the loop below
        unique_y = ht.resplit(unique_y, axis=None)
    unique_y_in_classes = ht.eq(unique_y, classes)

    if not ht.all(unique_y_in_classes):
        raise ValueError(
            "The target label(s) {} in y do not exist in the "
            "initial classes {}".format(unique_y[~unique_y_in_classes], classes)
        )
    for y_i in unique_y:
        # assuming classes.split is None
        if y_i in classes:
            i = ht.where(classes == y_i).item()
        else:
            # NOTE(review): unreachable in practice — any y_i missing from
            # classes already raised ValueError above
            classes_ext = torch.cat(
                (classes._DNDarray__array, y_i._DNDarray__array.unsqueeze(0))
            )
            i = torch.argsort(classes_ext)[-1].item()
        where_y_i = ht.where(y == y_i)._DNDarray__array.tolist()
        X_i = X[where_y_i, :]

        if sample_weight is not None:
            sw_i = sample_weight[where_y_i]
            if 0 not in sw_i.shape:
                N_i = sw_i.sum()
            else:
                # no samples of this class on this chunk
                N_i = 0.0
                sw_i = None
        else:
            sw_i = None
            N_i = X_i.shape[0]

        # incrementally fold this class's samples into the running mean/variance
        new_theta, new_sigma = self.__update_mean_variance(
            self.class_count_[i], self.theta_[i, :], self.sigma_[i, :], X_i, sw_i
        )

        self.theta_[i, :] = new_theta
        self.sigma_[i, :] = new_sigma
        self.class_count_[i] += N_i

    self.sigma_[:, :] += self.epsilon_

    # Update only if no priors were provided
    if self.priors is None:
        # Empirical prior, with sample_weight taken into account
        self.class_prior_ = self.class_count_ / self.class_count_.sum()

    return self
def __partial_fit(
    self,
    x: DNDarray,
    y: DNDarray,
    classes: Optional[DNDarray] = None,
    _refit: bool = False,
    sample_weight: Optional[DNDarray] = None,
):
    """
    Actual implementation of Gaussian NB fitting. Adapted to HeAT from scikit-learn.

    Parameters
    ----------
    x : DNDarray
        Training set, where n_samples is the number of samples and
        n_features is the number of features. Shape = (n_samples, n_features)
    y : DNDarray
        Labels for training set. Shape = (n_samples,)
    classes : DNDarray, optional
        List of all the classes that can possibly appear in the y vector.
        Must be provided at the first call to :func:`partial_fit`, can be
        omitted in subsequent calls. Shape = (n_classes,)
    _refit : bool, optional
        If ``True``, act as though this were the first time
        :func:`__partial_fit` is called (ie, throw away any past fitting and
        start over).
    sample_weight : DNDarray, optional
        Weights applied to individual samples (1. for unweighted).
        Shape = (n_samples,)
    """
    # TODO: sanitize x and y shape: sanitation/validation module, cf. #468
    n_samples = x.shape[0]
    if x.ndim != 2:
        raise ValueError("expected x to be a 2-D tensor, is {}-D".format(x.ndim))
    if y.shape[0] != n_samples:
        raise ValueError(
            "y.shape[0] must match number of samples {}, is {}".format(n_samples, y.shape[0])
        )
    # TODO: sanitize sample_weight: sanitation/validation module, cf. #468
    if sample_weight is not None:
        if sample_weight.ndim != 1:
            raise ValueError("Sample weights must be 1D tensor")
        if sample_weight.shape != (n_samples,):
            raise ValueError(
                "sample_weight.shape == {}, expected {}!".format(sample_weight.shape, (n_samples,))
            )

    # If the ratio of data variance between dimensions is too small, it
    # will cause numerical errors. To address this, we artificially
    # boost the variance by epsilon, a small fraction of the standard
    # deviation of the largest dimension.
    self.epsilon_ = self.var_smoothing * ht.var(x, axis=0).max()

    if _refit:
        self.classes_ = None

    if self.__check_partial_fit_first_call(classes):
        # This is the first call to partial_fit:
        # initialize various cumulative counters
        n_features = x.shape[1]
        n_classes = len(self.classes_)
        self.theta_ = ht.zeros((n_classes, n_features), dtype=x.dtype, device=x.device)
        self.sigma_ = ht.zeros((n_classes, n_features), dtype=x.dtype, device=x.device)
        # one row of per-class counts per process (split=0); the rows are
        # summed along the distribution axis once fitting is done (below)
        self.class_count_ = ht.zeros(
            (x.comm.size, n_classes), dtype=ht.float64, device=x.device, split=0
        )

        # Initialise the class prior
        # Take into account the priors
        if self.priors is not None:
            if not isinstance(self.priors, ht.DNDarray):
                priors = ht.array(self.priors, dtype=x.dtype, split=None, device=x.device)
            else:
                priors = self.priors
            # Check that the provided priors match the number of classes
            if len(priors) != n_classes:
                raise ValueError("Number of priors must match number of" " classes.")
            # Check that the sum is 1
            if not ht.isclose(priors.sum(), ht.array(1.0, dtype=priors.dtype)):
                raise ValueError("The sum of the priors should be 1.")
            # Check that the priors are non-negative
            if (priors < 0).any():
                raise ValueError("Priors must be non-negative.")
            self.class_prior_ = priors
        else:
            # Initialize the priors to zeros for each class
            self.class_prior_ = ht.zeros(
                len(self.classes_), dtype=ht.float64, split=None, device=x.device
            )
    else:
        if x.shape[1] != self.theta_.shape[1]:
            raise ValueError(
                "Number of features {} does not match previous data {}.".format(
                    x.shape[1], self.theta_.shape[1]
                )
            )
        # Put epsilon back in each time
        self.sigma_[:, :] -= self.epsilon_

    classes = self.classes_

    # the label set must be replicated on every process for the loop below
    unique_y = ht.unique(y, sorted=True).resplit_(None)
    unique_y_in_classes = ht.eq(unique_y, classes)

    if not ht.all(unique_y_in_classes):
        raise ValueError(
            "The target label(s) {} in y do not exist in the "
            "initial classes {}".format(unique_y[~unique_y_in_classes], classes)
        )

    # from now on: extract torch tensors for local operations
    # DNDarrays for distributed operations only
    for y_i in unique_y.larray:
        # assuming classes.split is None
        if y_i in classes.larray:
            i = torch.where(classes.larray == y_i)[0].item()
        else:
            # unreachable in practice (missing labels already raised above).
            # BUG FIX: y_i iterates a torch tensor (unique_y.larray) and has
            # no `.larray` attribute — concatenate y_i directly.
            classes_ext = torch.cat((classes.larray, y_i.unsqueeze(0)))
            i = torch.argsort(classes_ext)[-1].item()
        where_y_i = torch.where(y.larray == y_i)[0]
        X_i = x[where_y_i, :]

        if sample_weight is not None:
            sw_i = sample_weight[where_y_i]
            if 0 not in sw_i.shape:
                N_i = sw_i.sum().item()
            else:
                # no samples of this class in the local chunk
                N_i = 0.0
                sw_i = None
        else:
            sw_i = None
            N_i = X_i.shape[0]

        # incrementally fold this class's samples into the running mean/variance
        new_theta, new_sigma = self.__update_mean_variance(
            self.class_count_.larray[:, i].item(),
            self.theta_[i, :],
            self.sigma_[i, :],
            X_i,
            sw_i,
        )

        self.theta_[i, :] = new_theta
        self.sigma_[i, :] = new_sigma
        self.class_count_.larray[:, i] += N_i

    self.sigma_[:, :] += self.epsilon_

    # Update only if no priors are provided
    if self.priors is None:
        # distributed class_count_: sum along distribution axis
        self.class_count_ = self.class_count_.sum(axis=0, keepdim=True)
        # Empirical prior, with sample_weight taken into account
        self.class_prior_ = (self.class_count_ / self.class_count_.sum()).squeeze(0)

    return self