def test_aggregator_order_preserved(self):
    batch_size = 10
    num_processes = 2
    num_rows = batch_size * num_processes * 10
    self.options.aggregators = [
        _CountMultiples(-1),
        _CountMultiples(0),
        _CountMultiples(0.5),
        _CountMultiples(1),
        ca.CountsAggregator(),
        _CountMultiples(10)
    ]
    self.options.num_processes = num_processes
    self.options.num_rows = num_rows
    self.options.batch_size = batch_size
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    x = self.worker.start()
    goal_results = [
        -num_rows,
        0,
        num_rows / 2,
        num_rows,
        num_rows,
        10 * num_rows
    ]
    self.assertListEqual(x, goal_results)
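
# For reference, a minimal sketch of how an aggregator like _CountMultiples
# could behave, given the results asserted above. This is an illustrative
# reconstruction, not the actual helper (which is defined elsewhere in this
# module); it assumes a map/reduce-style aggregator API in which map() is
# called once per generated row and reduce() folds pairs of partial results.
class _CountMultiplesSketch(object):

    def __init__(self, multiple):
        self.multiple = multiple

    def map(self, row):
        # Each row contributes `multiple`, so summing over all rows gives
        # multiple * num_rows, matching goal_results above.
        return self.multiple

    def reduce(self, value1, value2):
        return value1 + value2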
def test_aggregators_not_shared(self):
    batch_size = 10
    num_processes = 2
    num_rows = batch_size * num_processes * 10
    fields_needed = self.dist_holder.var_order
    top_level_aggregator = _CheckShared(fields_needed)
    self.options.aggregators = [top_level_aggregator]
    self.options.num_processes = num_processes
    self.options.num_rows = num_rows
    self.options.batch_size = batch_size
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    x = self.worker.start()

    # How many times did any _CheckShared aggregator get called?
    self.assertListEqual(x, [num_rows])

    # How many times did the *top-level* _CheckShared.map get called?
    # If a single aggregator were shared, it would be 200.
    # If the mothership's aggregator is not shared with the workers,
    # it should be zero.
    self.assertEqual(top_level_aggregator.map_call_counter, 0)

    # How many times did the *top-level* _CheckShared.reduce get called?
    # If a single aggregator were shared, it would be 199.
    # If the mothership's aggregator is not shared with the workers,
    # it should be 19: one fewer than the 20 batch-level results
    # (num_rows / batch_size = 200 / 10 = 20 batches).
    self.assertEqual(top_level_aggregator.reduce_call_counter, 19)
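
# A rough sketch of the call-counting idea behind _CheckShared (the real
# helper lives elsewhere in this module; the constructor argument and exact
# method signatures here are assumptions). It records how often *this
# instance's* map() and reduce() run, which is what lets the assertions
# above distinguish a single shared aggregator from per-worker copies: the
# mothership's copy never maps rows itself, and only reduces the workers'
# 20 batch-level results, hence 0 map() calls and 19 reduce() calls.
class _CheckSharedSketch(object):

    def __init__(self, fields_needed):
        self.fields_needed = fields_needed
        self.map_call_counter = 0
        self.reduce_call_counter = 0

    def map(self, row):
        self.map_call_counter += 1
        return 1

    def reduce(self, value1, value2):
        self.reduce_call_counter += 1
        return value1 + value2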
def test_id_repeatability(self):
    # This is in response to a bug that made row-IDs not repeatable.
    # The purpose of this test is to ensure that the generator will
    # generate the same rows, with the same row-IDs, in both
    # single-process and multi-process modes.
    #
    # Note that this test is *LONG*. Unfortunately, it seems to take
    # a very long run to trip the bug, probably because it takes a while
    # for the workers and mothership to get out of sync in
    # multiprocessing mode.
    self.maxDiff = None
    self.options.random_seed = 10
    self.options.batch_size = 5
    num_batches = 225
    self.options.num_rows = self.options.batch_size * num_batches

    # Run 1: single process
    self.options.aggregators = [_RowCollector()]
    self.options.num_processes = 1
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    result_list = self.worker.start()
    dict1 = result_list[0]
    self.assertEqual(len(dict1), self.options.num_rows)

    # Run 2: two processes
    self.options.aggregators = [_RowCollector()]
    self.options.num_processes = 2
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    result_list = self.worker.start()
    dict2 = result_list[0]
    self.assertEqual(len(dict2), self.options.num_rows)

    self.assertDictEqual(dict1, dict2)
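
# An illustrative sketch of a _RowCollector-style aggregator (the real one
# is defined elsewhere in this module). It assumes map() receives one
# generated row and that a row exposes some repeatable identifier; the
# `row_id` attribute used here is a placeholder for whatever the generator
# actually provides. Collecting rows keyed by ID is what lets the
# assertDictEqual above verify that both runs produced identical rows with
# identical IDs.
class _RowCollectorSketch(object):

    def map(self, row):
        return {row.row_id: row}

    def reduce(self, dict1, dict2):
        merged = dict(dict1)
        merged.update(dict2)
        return merged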
def test_aggregator_death3(self):
    fields_needed = self.dist_holder.var_order
    self.options.aggregators = [_DieDuringDone(fields_needed)]
    for num_processes in [1, 2]:
        self.options.num_processes = num_processes
        self.worker = gw.Worker(self.options, self.dummy_logger,
                                self.dist_holder)
        with self.assertRaises(KeyboardInterrupt):
            self.worker.start()
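
# A minimal sketch of the failure-injection idea behind _DieDuringDone
# (the real helper is defined elsewhere in this module; the assumption here
# is that the worker calls done() on each aggregator after all rows have
# been processed). Raising KeyboardInterrupt from done() is what the
# assertRaises above expects to propagate out of worker.start().
class _DieDuringDoneSketch(object):

    def __init__(self, fields_needed):
        self.fields_needed = fields_needed

    def map(self, row):
        return 1

    def reduce(self, value1, value2):
        return value1 + value2

    def done(self):
        raise KeyboardInterrupt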
def test_generate_aggregate_results(self):
    self.options.aggregators = [ca.CountsAggregator()]
    for num_processes in [1, 2]:
        self.options.num_processes = num_processes
        self.worker = gw.Worker(self.options, self.dummy_logger,
                                self.dist_holder)
        result = self.worker.start()
        correct_result = [self.num_rows]
        self.assertListEqual(result, correct_result)
def generate_rows(options, queryset_aggregators, logger, dist_holder):
    # Note: so that the later zip() works, it is important that
    # queryset_aggregators be put at the *start* of options.aggregators.
    # Why? options.aggregators may already contain aggregators, such as a
    # line-raw aggregator, so the worker will return results for our
    # queryset_aggregators *and* more besides. For the results to line up
    # with the queryset, the queryset results must come at the beginning
    # of the result list, which requires that queryset_aggregators be at
    # the beginning of the options.aggregators list.
    options.aggregators = queryset_aggregators + options.aggregators

    # Spawn a worker and start it
    worker = gw.Worker(options, logger, dist_holder)
    aggregator_results = worker.start()
    return aggregator_results
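
# For illustration only, the zip() mentioned above would pair each query
# with its aggregator's result, which only works if the queryset results
# come first in the returned list (the names `queryset` and `process` below
# are hypothetical):
#
#     results = generate_rows(options, queryset_aggregators, logger,
#                             dist_holder)
#     for (query, result) in zip(queryset, results):
#         process(query, result)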
def test_aggregation_singleproc(self):
    batch_size = 10
    num_processes = 1
    self.options.num_processes = num_processes
    num_rows = batch_size * num_processes * 4
    self.options.aggregators = [ca.CountsAggregator()]
    self.options.num_rows = num_rows
    self.options.batch_size = batch_size
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    x = self.worker.start()
    self.assertListEqual(x, [num_rows])
def test_aggregation_multiproc(self):
    batch_size = 10
    num_processes = 2
    self.options.num_processes = num_processes
    num_rows = batch_size * num_processes * 2
    counts_agg = ca.CountsAggregator()
    self.options.aggregators = [counts_agg]
    self.options.num_rows = num_rows
    self.options.batch_size = batch_size
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)
    x = self.worker.start()
    self.assertListEqual(x, [num_rows])
def test_aggregator_death1(self):
    fields_needed = self.dist_holder.var_order
    for num_processes in [1, 2]:
        batch_size = 100
        num_rows = batch_size * num_processes
        self.options.aggregators = [
            _DieDuringMap(batch_size, fields_needed)
        ]
        self.options.num_processes = num_processes
        self.options.num_rows = num_rows
        self.worker = gw.Worker(self.options, self.dummy_logger,
                                self.dist_holder)
        with self.assertRaises(KeyboardInterrupt):
            self.worker.start()
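
# A sketch of the corresponding map-side failure injection, _DieDuringMap
# (the real helper is defined elsewhere in this module; the constructor
# arguments and the exact trigger condition are assumptions). Dying after
# batch_size rows means the crash happens mid-run inside a worker process,
# and the test above checks that the KeyboardInterrupt still surfaces from
# worker.start() rather than hanging the mothership.
class _DieDuringMapSketch(object):

    def __init__(self, batch_size, fields_needed):
        self.batch_size = batch_size
        self.fields_needed = fields_needed
        self.rows_seen = 0

    def map(self, row):
        self.rows_seen += 1
        if self.rows_seen >= self.batch_size:
            raise KeyboardInterrupt
        return 1

    def reduce(self, value1, value2):
        return value1 + value2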
def setUp(self):
    self.seed = int(time.time())
    self.seed_msg = "Random seed used for this test: %s" % self.seed
    self.longMessage = True
    spar_random.seed(self.seed)

    class Object(object):
        pass

    self.dummy_logger = logging.getLogger('dummy')
    self.dummy_logger.addHandler(logging.NullHandler())
    self.dummy_object = Object()
    self.num_rows = 100
    self.options = gw.DataGeneratorOptions(random_seed=self.seed,
                                           num_processes=2,
                                           num_rows=self.num_rows,
                                           verbose=False,
                                           aggregators=[],
                                           batch_size=5)

    # Build the distribution-holder
    learner_options = Object()
    pums_files = \
        [("mock pums",
          stringio.StringIO(mock_data_files.mock_pums_data))]
    pums_dict = \
        learn_distributions.learn_pums_dists(learner_options,
                                             self.dummy_logger,
                                             pums_files)
    names_files = \
        [('male_first_names.txt',
          stringio.StringIO(mock_data_files.mock_male_first_names)),
         ('female_first_names.txt',
          stringio.StringIO(mock_data_files.mock_female_first_names)),
         ('last_names.txt',
          stringio.StringIO(mock_data_files.mock_last_names))]
    names_dict = \
        learn_distributions.learn_name_dists(learner_options,
                                             self.dummy_logger,
                                             names_files)
    zipcode_files = \
        [('mock_zipcodes',
          stringio.StringIO(mock_data_files.mock_zipcodes))]
    zipcode_dict = \
        learn_distributions.learn_zipcode_dists(learner_options,
                                                self.dummy_logger,
                                                zipcode_files)
    text_files = \
        [('mock_text',
          stringio.StringIO(mock_data_files.mock_text_files))]
    text_engine = \
        learn_distributions.train_text_engine(learner_options,
                                              self.dummy_logger,
                                              text_files)
    streets_files = \
        [('mock street file',
          stringio.StringIO(mock_data_files.mock_street_names))]
    address_dict = \
        learn_distributions.learn_street_address_dists(learner_options,
                                                       self.dummy_logger,
                                                       streets_files)
    self.dist_holder = \
        learn_distributions.make_distribution_holder(learner_options,
                                                     self.dummy_logger,
                                                     pums_dict,
                                                     names_dict,
                                                     zipcode_dict,
                                                     address_dict,
                                                     text_engine)
    self.worker = gw.Worker(self.options, self.dummy_logger,
                            self.dist_holder)