def test_flatten(self):
    """Check that ``iterables.flatten`` merges nested lists into one flat list."""
    expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    # Two different groupings of the same ten numbers must flatten identically.
    nestedCases = (
        [[1, 2, 3], [4, 5, 6], [7, 8], [9, 10]],
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]],
    )
    for nested in nestedCases:
        self.assertEqual(iterables.flatten(nested), expected)
def mpiFlatten(allCPUResults):
    """
    Collapse the per-CPU result lists back into a single flat list.

    The work is delegated to ``iterables.flatten``. The intent is to restore
    the order the items had before they were grouped into a list of
    ``mpiIter`` results.

    See Also
    --------
    mpiIter : used for distributing objects/tasks
    """
    flattened = iterables.flatten(allCPUResults)
    return flattened
def test_typicalBalancing(self):
    """Test load balancing for the typical case (numProcs < numObjs).

    Here the total imbalance should be 1 (except for the perfectly
    balanced case).
    """
    numObjs = 25
    numProcs = 6
    allObjs = list(range(numObjs))
    distributed = self._distributeObjects(allObjs, numProcs)

    # typical case (more objects than processes)
    sizes = list(map(len, distributed))
    self.assertLessEqual(max(sizes) - min(sizes), 1)
    # flattening the distribution must give back the original objects
    self.assertEqual(iterables.flatten(distributed), allObjs)
def test_split(self):
    """Check ``iterables.split`` chunk counts and round-tripping through ``flatten``."""
    data = list(range(50))

    # Splitting into 10 chunks, then into 1 chunk, must round-trip through flatten.
    for numChunks in (10, 1):
        chunks = iterables.split(data, numChunks)
        self.assertEqual(len(chunks), numChunks)
        rejoined = iterables.flatten(chunks)
        self.assertEqual(data, rejoined)

    # More chunks than items, padded with [None]: 60 chunks and 60 flattened entries.
    padded = iterables.split(data, 60, padWith=[None])
    self.assertEqual(len(padded), 60)
    rejoined = iterables.flatten(padded)
    self.assertEqual(len(rejoined), 60)

    # Same padded split again; this repeats the chunk-count check above.
    padded = iterables.split(data, 60, padWith=[None])
    self.assertEqual(len(padded), 60)

    # Single-element input.
    data = [0]
    single = iterables.split(data, 1)
    rejoined = iterables.flatten(single)
    self.assertEqual(rejoined, data)
def test_perfectBalancing(self):
    """Test load balancing when numProcs divides numObjects.

    Every process should end up with the same number of objects.
    """
    numObjs = 25
    numProcs = 5
    allObjs = list(range(numObjs))
    distributed = self._distributeObjects(allObjs, numProcs)

    sizes = list(map(len, distributed))
    spread = max(sizes) - min(sizes)

    # ensure we haven't missed any objects
    self.assertEqual(iterables.flatten(distributed), allObjs)
    # check imbalance
    self.assertEqual(spread, 0)
def test_excessProcesses(self):
    """Test load balancing when numProcs exceeds numObjects.

    Some processes should receive a single object and the rest should
    receive no objects.
    """
    numObjs = 5
    numProcs = 25
    allObjs = list(range(numObjs))
    distributed = self._distributeObjects(allObjs, numProcs)

    sizes = list(map(len, distributed))
    spread = max(sizes) - min(sizes)

    # ensure we haven't missed any objects
    self.assertEqual(iterables.flatten(distributed), allObjs)
    # check imbalance
    self.assertLessEqual(spread, 1)
def _gatherList(localList):
    """
    Gather ``localList`` from all ranks to rank 0 and flatten it there.

    Parameters
    ----------
    localList
        This rank's contribution, passed to ``armi.MPI_COMM.gather``.

    Returns
    -------
    On rank 0, the gathered result after ``iterables.flatten``. On any other
    rank, whatever ``gather`` returned, unchanged (presumably ``None`` under
    mpi4py-style gather semantics -- TODO confirm).
    """
    gathered = armi.MPI_COMM.gather(localList, root=0)
    if armi.MPI_RANK != 0:
        # Non-root ranks pass the gather result straight through.
        return gathered
    return iterables.flatten(gathered)