def prepare(self):
    '''Prepare the batch job parameters.

    Splits the cursor range 1..total_articles() into consecutive windows
    of BATCH_SIZE and builds one parameter dict per window.

    Returns:
        list of dict: one entry per batch, each holding the keys
            "config", "start_cursor", "end_cursor", "db_name", "outinfo"
            and "done". "done" is True when the batch key is already
            present in DONE_KEYS; "db_name" is "dev" when self.test is
            truthy and "production" otherwise.
    '''
    job_params = []
    total = total_articles()
    # The range must include `total` itself: with a stop of `total`, a final
    # article sitting exactly on a batch boundary
    # ((total - 1) % BATCH_SIZE == 0, including total == 1) got no batch.
    for count, batch_start in enumerate(range(1, total + 1, BATCH_SIZE), 1):
        batch_done_key = f"arxiv_collection_from_{batch_start}"
        # Never run past the last article; end_cursor looks like an
        # exclusive upper bound, hence total + 1 -- confirm with the
        # batch consumer.
        end_cursor = min(batch_start + BATCH_SIZE, total + 1)
        params = {
            "config": "mysqldb.config",
            "start_cursor": batch_start,
            "end_cursor": end_cursor,
            "db_name": "production" if not self.test else "dev",
            "outinfo": f"s3://nesta-production-intermediate/{batch_done_key}",
            "done": batch_done_key in DONE_KEYS,
        }
        logging.warning(f"Batch {count}: {params}")
        job_params.append(params)
    return job_params
def test_total_articles_doesnt_override_delay(self, mocked_request, mock_response):
    '''If total_articles() passes a 'delay' keyword to the request, it is >= 3.'''
    # The mocked request hands back the parsed XML fixture.
    mocked_request.return_value = ET.fromstring(mock_response)

    total_articles()

    # call_args[1] holds the keyword arguments of the most recent call.
    request_kwargs = mocked_request.call_args[1]
    if 'delay' not in request_kwargs:
        # No explicit delay was passed, so there is nothing to check.
        return
    assert request_kwargs['delay'] >= 3
def test_total_articles_returns_correct_amount(self, mocked_request, mock_response):
    '''total_articles() yields 1463679 for the mocked response fixture.'''
    # The mocked request hands back the parsed XML fixture.
    parsed_response = ET.fromstring(mock_response)
    mocked_request.return_value = parsed_response

    article_count = total_articles()

    assert article_count == 1463679