def test_retrieve_arxiv_batch_rows_stops_at_end_cursor(self, mocked_batch):
    # Three mocked batches of two rows each are available, but only the rows
    # up to the end cursor (4) should be yielded.
    all_rows = [f'row{i}' for i in range(1, 7)]
    mocked_batch.side_effect = [(all_rows[i:i + 2], f'mytoken|{i + 2}') for i in (0, 2, 4)]
    yielded = list(retrieve_arxiv_batch_rows(0, 4, 'mytoken|1'))
    assert yielded == all_rows[:4]
def test_retrieve_arxiv_batch_rows_returns_all_rows_till_empty_token(self, mocked_batch):
    # The third mocked batch returns no resumption token; every row fetched up
    # to that point should be yielded even though the end cursor is far away.
    all_rows = [f'row{i}' for i in range(1, 7)]
    next_tokens = ['mytoken|2', 'mytoken|4', None]
    mocked_batch.side_effect = [
        (all_rows[start:start + 2], token)
        for start, token in zip((0, 2, 4), next_tokens)
    ]
    assert list(retrieve_arxiv_batch_rows(0, 9999, 'mytoken')) == all_rows
def test_retrieve_arxiv_batch_rows_calls_arxiv_batch_correctly(self, mocked_batch):
    # The first request uses 'mytoken|0'; every later request must reuse the
    # resumption token returned by the previous batch.
    returned_tokens = ['mytoken|2', 'mytoken|4', None]
    mocked_batch.side_effect = [('data', token) for token in returned_tokens]
    list(retrieve_arxiv_batch_rows(0, 9999, 'mytoken'))
    expected_calls = [mock.call(f'mytoken|{cursor}') for cursor in (0, 2, 4)]
    assert mocked_batch.mock_calls == expected_calls
def _ensure_categories_exist(engine, categories, known_categories, article_id):
    """Add any of ``categories`` that are missing from the Category table.

    Args:
        engine: database engine (from ``get_mysql_engine``) used to open a session.
        categories (list): category ids attached to a single article.
        known_categories (set): category ids already confirmed present, or
            already added, earlier in this run. Updated in place so each
            distinct category is looked up in the database at most once.
        article_id: id of the article the categories belong to (used for logging).
    """
    unchecked = [cat for cat in categories if cat not in known_categories]
    if not unchecked:
        # every category has already been confirmed: no session needed
        return
    with db_session(engine) as session:
        for cat in unchecked:
            if cat in known_categories:
                # the same category was listed twice for this article
                continue
            try:
                session.query(Category).filter(Category.id == cat).one()
            except NoResultFound:
                logging.warning(f"missing category: '{cat}' for article {article_id}. "
                                "Adding to Category table")
                session.add(Category(id=cat))
            known_categories.add(cat)


def run():
    """Retrieve one batch of arXiv articles and store them in MySQL.

    Configuration is read from environment variables:
        BATCHPAR_db_name: name of the database to write to.
        BATCHPAR_outinfo: S3 path of the empty marker object written on success.
        BATCHPAR_start_cursor / BATCHPAR_end_cursor: cursor range of the batch;
            ``end_cursor - start_cursor`` articles are expected.

    Subject categories are loaded from S3 first. Articles and article/category
    links are then inserted with ``insert_data``. Categories missing from the
    Category table are added as they are encountered.

    Raises:
        ValueError: if the insert results do not account for every expected
            article, or for every article category.
    """
    db_name = os.environ["BATCHPAR_db_name"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_cursor = int(os.environ["BATCHPAR_start_cursor"])
    end_cursor = int(os.environ["BATCHPAR_end_cursor"])
    batch_size = end_cursor - start_cursor
    logging.warning(f"Retrieving {batch_size} articles between {start_cursor - 1}:{end_cursor - 1}")

    # Setup the database connectors
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # load arxiv subject categories to database
    bucket = 'innovation-mapping-general'
    cat_file = 'arxiv_classification/arxiv_subject_classifications.csv'
    load_arxiv_categories("BATCHPAR_config", db_name, bucket, cat_file)

    # process data
    articles = []
    article_cats = []
    # category ids confirmed in the database during this run; this avoids
    # re-querying the same category for every article
    known_categories = set()
    resumption_token = request_token()
    for row in retrieve_arxiv_batch_rows(start_cursor, end_cursor, resumption_token):
        categories = row.pop('categories', [])
        articles.append(row)
        if categories:
            _ensure_categories_exist(engine, categories, known_categories, row['id'])
        article_cats.extend(dict(article_id=row['id'], category_id=cat) for cat in categories)

    inserted_articles, existing_articles, failed_articles = insert_data(
        "BATCHPAR_config", "mysqldb", db_name, Base, Article, articles,
        return_non_inserted=True)
    logging.warning(f"total article categories: {len(article_cats)}")
    inserted_article_cats, existing_article_cats, failed_article_cats = insert_data(
        "BATCHPAR_config", "mysqldb", db_name, Base, ArticleCategory, article_cats,
        return_non_inserted=True)

    # sanity checks before the batch is marked as done
    # (implicit string concatenation -- the original passed a tuple, so the
    # tuple repr was logged instead of a single message)
    logging.warning(f'inserted articles: {len(inserted_articles)} '
                    f'existing articles: {len(existing_articles)} '
                    f'failed articles: {len(failed_articles)}')
    logging.warning(f'inserted article categories: {len(inserted_article_cats)} '
                    f'existing article categories: {len(existing_article_cats)} '
                    f'failed article categories: {len(failed_article_cats)}')
    # NOTE(review): this compares against the requested range (batch_size), not
    # len(articles). Retrieval stops early when no resumption token is returned,
    # so a short final batch would raise here -- confirm that is intended.
    if len(inserted_articles) + len(existing_articles) + len(failed_articles) != batch_size:
        raise ValueError('Inserted articles do not match original data.')
    if len(inserted_article_cats) + len(existing_article_cats) + len(failed_article_cats) != len(article_cats):
        raise ValueError('Inserted article categories do not match original data.')

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")