def test_reduce_0D(self):
    """Reducing the only axis of a 1-D distribution leaves a 0-D one.

    The reduced distribution must report empty ``dist``, ``shape`` and
    ``grid_shape`` tuples, and its target set must equal the set holding
    only the first target of the original distribution.
    """
    size = 10 ** 5
    full = Distribution(self.context, (size,))
    reduced = full.reduce(axes=[0])

    # Every per-dimension attribute collapses to the empty tuple.
    self.assertEqual(reduced.dist, ())
    self.assertSequenceEqual(reduced.shape, ())
    self.assertEqual(reduced.grid_shape, ())

    # Only the first of the original targets is expected to remain.
    actual_targets = set(reduced.targets)
    expected_targets = set(full.targets[:1])
    self.assertEqual(actual_targets, expected_targets)
def test_reduce_0D(self):
    """Check that reducing axis 0 of a 1-D distribution gives a 0-D result.

    Asserts that ``dist``, ``shape`` and ``grid_shape`` of the reduced
    distribution are all empty, and that its targets, compared as a set,
    match the first entry of the source distribution's targets.
    """
    length = 10**5
    source = Distribution(self.context, (length, ))
    collapsed = source.reduce(axes=[0])

    # No dimensions are left after the reduction.
    self.assertEqual(collapsed.dist, ())
    self.assertSequenceEqual(collapsed.shape, ())
    self.assertEqual(collapsed.grid_shape, ())

    # Targets are compared as sets; only the leading source target is kept.
    remaining = set(collapsed.targets)
    leading = set(source.targets[:1])
    self.assertEqual(remaining, leading)
def test_reduce(self):
    """Reduce a 3-D distribution along each single axis and check the result.

    The source distribution has dist types ``("b", "c", "n")`` on a
    ``(2, 2, 1)`` grid.  For every axis the reduced distribution must drop
    that axis from ``dist``, ``shape`` and ``grid_shape``.  Reducing axis 0
    or axis 1 must leave a proper subset of the original targets, while
    reducing axis 2 (grid extent 1) must leave the target set unchanged.
    """
    n_rows = 10 ** 5
    n_cols = 10 ** 6
    n_deep = 10 ** 4
    full = Distribution(
        self.context,
        (n_rows, n_cols, n_deep),
        ("b", "c", "n"),
        grid_shape=(2, 2, 1),
    )

    # --- axis 0 removed -------------------------------------------------
    without_axis0 = full.reduce(axes=[0])
    self.assertEqual(without_axis0.dist, ("c", "n"))
    self.assertSequenceEqual(without_axis0.shape, (n_cols, n_deep))
    self.assertEqual(without_axis0.grid_shape, full.grid_shape[1:])
    # assertLess on sets checks for a proper subset.
    self.assertLess(set(without_axis0.targets), set(full.targets))

    # --- axis 1 removed -------------------------------------------------
    without_axis1 = full.reduce(axes=[1])
    self.assertEqual(without_axis1.dist, ("b", "n"))
    self.assertSequenceEqual(without_axis1.shape, (n_rows, n_deep))
    self.assertEqual(
        without_axis1.grid_shape,
        full.grid_shape[:1] + full.grid_shape[2:],
    )
    self.assertLess(set(without_axis1.targets), set(full.targets))

    # --- axis 2 removed -------------------------------------------------
    without_axis2 = full.reduce(axes=[2])
    self.assertEqual(without_axis2.dist, ("b", "c"))
    self.assertSequenceEqual(without_axis2.shape, (n_rows, n_cols))
    self.assertEqual(without_axis2.grid_shape, full.grid_shape[:-1])
    # The grid has extent 1 along axis 2, so the same targets are expected.
    self.assertEqual(set(without_axis2.targets), set(full.targets))
def test_reduce(self):
    """Verify single-axis reductions of a 3-D ``('b', 'c', 'n')`` distribution.

    The distribution is built on a ``(2, 2, 1)`` grid.  After reducing one
    axis, that axis must be missing from ``dist``, ``shape`` and
    ``grid_shape``.  The target set must shrink to a proper subset when
    axis 0 or axis 1 is reduced and stay equal when axis 2 is reduced.
    """
    rows = 10**5
    cols = 10**6
    depth = 10**4
    shape3d = (rows, cols, depth)
    source = Distribution(self.context, shape3d, ('b', 'c', 'n'),
                          grid_shape=(2, 2, 1))

    # Reduce over the first axis.
    reduced0 = source.reduce(axes=[0])
    self.assertEqual(reduced0.dist, ('c', 'n'))
    self.assertSequenceEqual(reduced0.shape, (cols, depth))
    self.assertEqual(reduced0.grid_shape, source.grid_shape[1:])
    # Set "<" is the proper-subset relation.
    self.assertLess(set(reduced0.targets), set(source.targets))

    # Reduce over the middle axis.
    reduced1 = source.reduce(axes=[1])
    self.assertEqual(reduced1.dist, ('b', 'n'))
    self.assertSequenceEqual(reduced1.shape, (rows, depth))
    self.assertEqual(reduced1.grid_shape,
                     source.grid_shape[:1] + source.grid_shape[2:])
    self.assertLess(set(reduced1.targets), set(source.targets))

    # Reduce over the last axis, whose grid extent is 1.
    reduced2 = source.reduce(axes=[2])
    self.assertEqual(reduced2.dist, ('b', 'c'))
    self.assertSequenceEqual(reduced2.shape, (rows, cols))
    self.assertEqual(reduced2.grid_shape, source.grid_shape[:-1])
    self.assertEqual(set(reduced2.targets), set(source.targets))