示例#1
0
    def test_retrieve_arxiv_batch_rows_stops_at_end_cursor(self, mocked_batch):
        """Only the first four rows come back when the end cursor is 4.

        The mocked batch function offers three batches of two rows each; the
        third batch must not contribute to the result.
        """
        all_rows = [f'row{number}' for number in range(1, 7)]
        mocked_batch.side_effect = [
            (all_rows[:2], 'mytoken|2'),
            (all_rows[2:4], 'mytoken|4'),
            (all_rows[4:], 'mytoken|6'),
        ]

        retrieved = list(retrieve_arxiv_batch_rows(0, 4, 'mytoken|1'))

        assert retrieved == all_rows[:4]
示例#2
0
    def test_retrieve_arxiv_batch_rows_returns_all_rows_till_empty_token(self, mocked_batch):
        """Every row of every batch is returned when the last batch has no token.

        The end cursor (9999) is far beyond the six available rows, so the
        ``None`` token on the third batch is what ends the retrieval.
        """
        expected_rows = ['row1', 'row2', 'row3', 'row4', 'row5', 'row6']
        first_batch = (expected_rows[:2], 'mytoken|2')
        second_batch = (expected_rows[2:4], 'mytoken|4')
        final_batch = (expected_rows[4:], None)
        mocked_batch.side_effect = [first_batch, second_batch, final_batch]

        retrieved = list(retrieve_arxiv_batch_rows(0, 9999, 'mytoken'))

        assert retrieved == expected_rows
示例#3
0
    def test_retrieve_arxiv_batch_rows_calls_arxiv_batch_correctly(self, mocked_batch):
        """The mocked batch function is called with the expected sequence of tokens.

        Starting from cursor 0 with token 'mytoken', the calls must use
        'mytoken|0', then the tokens handed back by the first two batches.
        """
        returned_tokens = ('mytoken|2', 'mytoken|4', None)
        mocked_batch.side_effect = [('data', token) for token in returned_tokens]

        # Exhaust the generator; only the recorded calls matter here.
        for _ in retrieve_arxiv_batch_rows(0, 9999, 'mytoken'):
            pass

        expected_calls = [mock.call(f'mytoken|{cursor}') for cursor in (0, 2, 4)]
        assert mocked_batch.mock_calls == expected_calls
示例#4
0
文件: run.py 项目: yitzikc/nesta
def run():
    """Collect one cursor range of arXiv articles and write them to MySQL.

    Configuration is read from environment variables:

    * ``BATCHPAR_db_name``: database name passed to the MySQL helpers.
    * ``BATCHPAR_outinfo``: S3 path that is written to once the batch is done.
    * ``BATCHPAR_start_cursor`` / ``BATCHPAR_end_cursor``: integer cursor range
      handed to ``retrieve_arxiv_batch_rows``.

    Each retrieved row has its ``categories`` entry split off into separate
    article/category link records; categories not found in the ``Category``
    table are added. Articles and links are then bulk inserted, the counts are
    sanity-checked, and an empty object is written to the S3 path.

    Raises:
        KeyError: if one of the environment variables above is not set.
        ValueError: if a cursor is not an integer, if the number of processed
            articles differs from ``end_cursor - start_cursor``, or if the
            number of processed article categories differs from the number
            collected.
    """
    db_name = os.environ["BATCHPAR_db_name"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_cursor = int(os.environ["BATCHPAR_start_cursor"])
    end_cursor = int(os.environ["BATCHPAR_end_cursor"])
    batch_size = end_cursor - start_cursor
    logging.warning(f"Retrieving {batch_size} articles between {start_cursor - 1}:{end_cursor - 1}")

    # Setup the database connectors
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # load arxiv subject categories to database
    bucket = 'innovation-mapping-general'
    cat_file = 'arxiv_classification/arxiv_subject_classifications.csv'
    load_arxiv_categories("BATCHPAR_config", db_name, bucket, cat_file)

    # process data
    articles = []
    article_cats = []
    resumption_token = request_token()
    for row in retrieve_arxiv_batch_rows(start_cursor, end_cursor, resumption_token):
        # A session is opened per row.
        # NOTE(review): presumably db_session commits on exit, so a newly added
        # Category is visible to later rows - confirm before hoisting the
        # session out of the loop.
        with db_session(engine) as session:
            # Categories are stored in their own link table, so strip them
            # from the article record before it is queued for insertion.
            categories = row.pop('categories', [])
            articles.append(row)
            for cat in categories:
                # TODO:this is inefficient and should be queried once to a set. see
                # iterative process.
                try:
                    session.query(Category).filter(Category.id == cat).one()
                except NoResultFound:
                    logging.warning(f"missing category: '{cat}' for article {row['id']}.  Adding to Category table")
                    session.add(Category(id=cat))
                article_cats.append(dict(article_id=row['id'], category_id=cat))

    inserted_articles, existing_articles, failed_articles = insert_data(
        "BATCHPAR_config", "mysqldb", db_name,
        Base, Article, articles,
        return_non_inserted=True)
    logging.warning(f"total article categories: {len(article_cats)}")
    inserted_article_cats, existing_article_cats, failed_article_cats = insert_data(
        "BATCHPAR_config", "mysqldb", db_name,
        Base, ArticleCategory, article_cats,
        return_non_inserted=True)

    # sanity checks before the batch is marked as done
    # Adjacent f-strings are concatenated into one message; passing them as a
    # tuple would log the tuple's repr instead of readable text.
    logging.warning(f'inserted articles: {len(inserted_articles)} '
                    f'existing articles: {len(existing_articles)} '
                    f'failed articles: {len(failed_articles)}')
    logging.warning(f'inserted article categories: {len(inserted_article_cats)} '
                    f'existing article categories: {len(existing_article_cats)} '
                    f'failed article categories: {len(failed_article_cats)}')
    # NOTE(review): this compares against the requested range size rather than
    # len(articles); a range that yields fewer rows than requested will raise
    # here - confirm that is intended.
    if len(inserted_articles) + len(existing_articles) + len(failed_articles) != batch_size:
        raise ValueError('Inserted articles do not match original data.')
    if len(inserted_article_cats) + len(existing_article_cats) + len(failed_article_cats) != len(article_cats):
        raise ValueError('Inserted article categories do not match original data.')

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")