Example #1
    def prepare(self):
        if self.test:
            self.process_batch_size = 1000
            logging.warning("Batch size restricted to "
                            f"{self.process_batch_size}"
                            " while in test mode")

        es_mode = 'dev' if self.test else 'prod'
        es, es_config = setup_es(es_mode,
                                 self.test,
                                 self.drop_and_recreate,
                                 dataset='nih',
                                 aliases='health_scanner',
                                 increment_version=True)

        # Collect article IDs from the old index
        _old_config = es_config.copy()
        _old_config['index'] = es_config['old_index']
        logging.info(f"Collected article IDs...")
        _ids = get_es_ids(es, _old_config, size=10000)
        logging.info(f"Collected {len(_ids)} IDs")
        done_ids = get_es_ids(es, es_config, size=10000)

        # Generate the job params
        job_params = []
        batches = split_batches(_ids, self.process_batch_size)
        for count, batch in enumerate(batches, 1):
            # Magical '0.3' is the lower end of the deduplication
            # fraction found by inspection
            done = sum(_id in done_ids for _id in batch) / len(batch) > 0.3
            # write batch of ids to s3
            batch_file = ''
            if not done:
                batch_file = put_s3_batch(batch, self.intermediate_bucket,
                                          self.routine_id)
            params = {
                "batch_file": batch_file,
                "config": 'mysqldb.config',
                "bucket": self.intermediate_bucket,
                "done": done,
                'outinfo': es_config['host'],
                'out_port': es_config['port'],
                'out_index': es_config['index'],
                'in_index': es_config['old_index'],
                'out_type': es_config['type'],
                'aws_auth_region': es_config['region'],
                'entity_type': 'paper',
                'test': self.test,
                'routine_id': self.routine_id
            }

            job_params.append(params)
            if self.test and count > 1:
                logging.warning("Breaking after 2 batches "
                                "while in test mode.")
                logging.warning(job_params)
                break
        logging.info("Batch preparation completed, "
                     f"with {len(job_params)} batches")
        return job_params
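
A toy illustration of the "mostly done" check above: a batch is skipped when more than 30% of its ids already exist in the target index. The ids below are made up.

done_ids = {'a', 'b', 'c'}                 # ids already present in the target index
batch = ['a', 'b', 'x', 'y', 'z']          # 2 of these 5 ids are already indexed
fraction_done = sum(_id in done_ids for _id in batch) / len(batch)
done = fraction_done > 0.3                 # 0.4 > 0.3, so this batch counts as done
print(fraction_done, done)                 # 0.4 True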
Example #2
    def run(self):
        """Apply health labels using model."""
        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)

        # collect and unpickle models from s3
        logging.info("Collecting models from S3")
        s3 = boto3.resource('s3')
        vectoriser_obj = s3.Object(self.bucket, self.vectoriser_key)
        vectoriser = pickle.loads(
            vectoriser_obj.get()['Body']._raw_stream.read())
        classifier_obj = s3.Object(self.bucket, self.classifier_key)
        classifier = pickle.loads(
            classifier_obj.get()['Body']._raw_stream.read())

        # retrieve organisations and categories
        nrows = 1000 if self.test else None
        logging.info("Collecting organisations from database")
        with db_session(self.engine) as session:
            orgs = (session.query(Organization.id).filter(
                Organization.is_health.is_(None)).limit(nrows).all())

        for batch_count, batch in enumerate(
                split_batches(orgs, self.insert_batch_size), 1):
            batch_orgs_with_cats = []
            for (org_id, ) in batch:
                with db_session(self.engine) as session:
                    categories = (session.query(
                        OrganizationCategory.category_name).filter(
                            OrganizationCategory.organization_id ==
                            org_id).all())
                # each org's categories are joined into one comma-separated string,
                # giving a list like ['cat,cat,cat', 'cat,cat'] across the batch
                categories = ','.join(cat_name for (cat_name, ) in categories)
                batch_orgs_with_cats.append({
                    'id': org_id,
                    'categories': categories
                })

            logging.debug(
                f"{len(batch_orgs_with_cats)} organisations retrieved from database"
            )

            logging.debug("Predicting health flags")
            batch_orgs_with_flag = predict_health_flag(batch_orgs_with_cats,
                                                       vectoriser, classifier)

            logging.debug(
                f"{len(batch_orgs_with_flag)} organisations to update")
            with db_session(self.engine) as session:
                session.bulk_update_mappings(Organization,
                                             batch_orgs_with_flag)
            logging.info(
                f"{batch_count} batches health labeled and written to db")

        # mark as done
        logging.warning("Task complete")
        self.output().touch()
Example #3
    def prepare(self):
        '''Chunk up elasticsearch data, and submit batch
        jobs over those chunks.'''
        if self.test:
            self.process_batch_size = 1000
            logging.warning("Batch size restricted to "
                            f"{self.process_batch_size}"
                            " while in test mode")

        # Setup elasticsearch and extract all ids
        es_mode = 'dev' if self.test else 'prod'
        es, es_config = setup_es(es_mode,
                                 self.test,
                                 drop_and_recreate=False,
                                 dataset=self.dataset,
                                 increment_version=False)
        ids = get_es_ids(es, es_config, size=10000)  # All ids in this index
        ids = ids - self._done_ids  # Don't repeat done ids

        # Override the default index if specified
        es_config['index'] = (self.index if self.index is not None else
                              es_config['index'])

        # Generate the job params
        job_params = []
        batches = split_batches(ids, self.process_batch_size)
        for count, batch in enumerate(batches, 1):
            done = False  # Already taken care of with _done_ids
            # write batch of ids to s3
            batch_file = ''
            if not done:
                batch_file = put_s3_batch(batch, self.intermediate_bucket,
                                          self.routine_id)
            params = {
                "batch_file": batch_file,
                "config": self.sql_config_filename,
                "bucket": self.intermediate_bucket,
                "done": done,
                "count": len(ids),
                'outinfo': es_config['host'],
                'out_port': es_config['port'],
                'index': es_config['index'],
                'out_type': es_config['type'],
                'aws_auth_region': es_config['region'],
                'test': self.test,
                'routine_id': self.routine_id,
                'entity_type': self.entity_type,
                **self.kwargs
            }
            job_params.append(params)
            # Test mode
            if self.test and count > 1:
                logging.warning("Breaking after 2 batches "
                                "while in test mode.")
                logging.warning(job_params)
                break
        # Done
        logging.info("Batch preparation completed, "
                     f"with {len(job_params)} batches")
        return job_params
Example #4
def _batch_query_articles_by_doi(query, articles, batch_size=10):
    """Manages batches and generates sparql queries for articles and queries them from
    mag via the sparql api using the supplied `doi`.

    Args:
        query (str): sparql query containing a format string placeholder {}
        articles (:obj:`list` of :obj:`dict`): articles to query in MAG.
            Must contatin at least `id` and `doi` in each dict.
        batch_size (int): number of ids to query in a batch. Max size = 50

    Yields:
        (:obj:`list` of :obj:`dict`): batches of data returned from MAG
    """
    if not 1 <= batch_size <= 10:  # max limit for uri length
        raise ValueError("batch_size must be between 1 and 10")

    for articles_batch in split_batches(articles, batch_size):
        clean_dois = [(a['doi'].replace('\n', '').replace('\\',
                                                          '').replace('"', ''))
                      for a in articles_batch]
        concat_dois = ','.join(f'"{a}"^^xsd:string' for a in clean_dois)
        article_filter = f"FILTER (?doi IN ({concat_dois}))"

        for results_batch in sparql_query(
                MAG_ENDPOINT, query.format(article_filter=article_filter)):
            yield articles_batch, results_batch
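
The core of the function above is turning a batch of DOIs into a SPARQL FILTER clause. A standalone sketch of that string construction, using made-up article dicts:

articles_batch = [{'id': 1, 'doi': '10.1000/xyz123\n'},
                  {'id': 2, 'doi': '10.1000/"abc"\\456'}]

# strip characters that would break the quoted SPARQL literal
clean_dois = [a['doi'].replace('\n', '').replace('\\', '').replace('"', '')
              for a in articles_batch]
concat_dois = ','.join(f'"{doi}"^^xsd:string' for doi in clean_dois)
article_filter = f"FILTER (?doi IN ({concat_dois}))"
print(article_filter)
# FILTER (?doi IN ("10.1000/xyz123"^^xsd:string,"10.1000/abc456"^^xsd:string))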
Example #5
def _batched_entity_filter(concat_format, filter_on, ids, batch_size):
    """Creates batches of entity filters for SPARQL queries. Constructing a
    'filter in' statement using the provided ids, splitting whenever the batch size is
    hit.

    A call with the following arguments:
        _batched_entity_filter('<http://ma-graph.org/entity/{}>', 'data', [1, 2], 50)

    Will return a generator which yields the following:
        "FILTER (?data IN (<http://ma-graph.org/entity/1>,<http://ma-graph.org/entity/2>))"

    Args:
        concat_format (str): string format to be applied when concatenating ids
            requires a placeholder for id {}
        filter_on (str): name of the field to use in the filter. The '?' prefix
            is not required
        ids (list): ids to include in the filters, processed in batches
        batch_size (int): number of ids to query in a batch.

    Yields:
        (str): filter string containing ids up to the chosen batch size
    """
    for batch_of_ids in split_batches(ids, batch_size):
        entities = ','.join(concat_format.format(i) for i in batch_of_ids)

        yield f"FILTER (?{filter_on} IN ({entities}))"
Example #6
    def run(self):
        data = extract_data(limit=1000 if self.test else None)
        logging.info(f'Got {len(data)} rows')
        database = 'dev' if self.test else 'production'
        for chunk in split_batches(data, 10000):
            logging.info(f'Inserting chunk of size {len(chunk)}')
            insert_data('MYSQLDB', 'mysqldb', database,
                        Base, ApplnFamily, chunk, low_memory=True)
        self.output().touch()
Example #7
    def __iter__(self):
        for batch_of_ids in split_batches(self.ids, self.batch_size):
            self.title_articles_lookup.clear()
            for article in (self.session.query(Article).filter(
                    Article.id.in_(batch_of_ids)).all()):
                self.title_articles_lookup[prepare_title(
                    article.title)].append(article.id)

            for title in self.title_articles_lookup:
                yield title
Example #8
    def run(self):
        """Collect and process organizations, categories and long descriptions."""

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)
        limit = 2000 if self.test else None
        batch_size = 30 if self.test else 1000

        with db_session(self.engine) as session:
            all_orgs = session.query(
                Organisation.id, Organisation.addresses).limit(limit).all()
            existing_org_location_ids = session.query(
                OrganisationLocation.id).all()
        logging.info(f"{len(all_orgs)} organisations retrieved from database")
        logging.info(
            f"{len(existing_org_location_ids)} organisations have previously been processed"
        )

        # convert to a list of dictionaries with the nested addresses unpacked
        orgs = get_orgs_to_process(all_orgs, existing_org_location_ids)
        logging.info(f"{len(orgs)} new organisations to geocode")

        total_batches = ceil(len(orgs) / batch_size)
        logging.info(f"{total_batches} batches")
        completed_batches = 0
        for batch in split_batches(orgs, batch_size=batch_size):
            # geocode first to add missing country for UK
            batch = map(geocode_uk_with_postcode, batch)
            batch = map(add_country_details, batch)

            # remove data not in OrganisationLocation columns
            org_location_cols = OrganisationLocation.__table__.columns.keys()
            batch = [{k: v
                      for k, v in org.items() if k in org_location_cols}
                     for org in batch]

            insert_data(self.db_config_env, 'mysqldb', database, Base,
                        OrganisationLocation, batch)
            completed_batches += 1
            logging.info(
                f"Completed {completed_batches} of {total_batches} batches")

            if self.test and completed_batches > 1:
                logging.warning("Breaking after 2 batches in test mode")
                break

        # mark as done
        logging.warning("Finished task")
        self.output().touch()
Example #9
    def run(self):
        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # collect mesh terms from S3
        bucket = 'innovation-mapping-general'
        key = 'crunchbase_descriptions/crunchbase_descriptions_mesh.txt'
        mesh_terms = retrieve_mesh_terms(bucket, key)
        mesh_terms = format_mesh_terms(mesh_terms)  # {'id': ['term1', 'term2'], ...}
        logging.info(f"File contains {len(mesh_terms)} orgs with mesh terms")

        logging.info("Extracting previously processed orgs")
        with db_session(self.engine) as session:
            all_orgs = session.query(Organization.id,
                                     Organization.mesh_terms).all()
        processed_orgs = {
            org_id
            for (org_id, mesh_terms) in all_orgs if mesh_terms is not None
        }
        all_orgs = {org_id for (org_id, _) in all_orgs}
        logging.info(f"{len(all_orgs)} total orgs in database")
        logging.info(f"{len(processed_orgs)} previously processed orgs")

        # reformat for batch insert, removing not found and previously processed terms
        meshed_orgs = [{
            'id': org_id,
            'mesh_terms': '|'.join(terms)
        } for org_id, terms in mesh_terms.items()
                       if org_id in all_orgs and org_id not in processed_orgs]

        logging.info(f"{len(meshed_orgs)} organisations to update in database")

        for count, batch in enumerate(
                split_batches(meshed_orgs, self.insert_batch_size), 1):
            with db_session(self.engine) as session:
                session.bulk_update_mappings(Organization, batch)
            logging.info(
                f"{count} batch{'es' if count > 1 else ''} written to db")
            if self.test and count > 1:
                logging.info("Breaking after 2 batches while in test mode")
                break

        # mark as done
        logging.warning("Task complete")
        self.output().touch()
Example #10
    def run(self):
        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # collect file
        logging.info(f"Collecting org_parents from crunchbase tar")
        org_parents = get_files_from_tar(['org_parents'])[0]
        logging.info(f"{len(org_parents)} parent ids in crunchbase export")

        # collect previously processed orgs
        logging.info("Extracting previously processed organisations")
        with db_session(self.engine) as session:
            processed_orgs = session.query(Organization.id,
                                           Organization.parent_id).all()
        all_orgs = {org for (org, _) in processed_orgs}
        logging.info(f"{len(all_orgs)} total orgs in database")
        processed_orgs = {
            org
            for (org, parent_id) in processed_orgs if parent_id is not None
        }
        logging.info(f"{len(processed_orgs)} previously processed orgs")

        # reformat into a list of dicts, removing orgs that already have a parent_id
        # or are missing from the database
        org_parents = org_parents[['uuid', 'parent_uuid']]
        org_parents.columns = ['id', 'parent_id']
        org_parents = org_parents[org_parents['id'].isin(all_orgs)]
        org_parents = org_parents[~org_parents['id'].isin(processed_orgs)]
        org_parents = org_parents.to_dict(orient='records')
        logging.info(f"{len(org_parents)} organisations to update in MYSQL")

        # insert parent_ids into db in batches
        for count, batch in enumerate(
                split_batches(org_parents, self.insert_batch_size), 1):
            with db_session(self.engine) as session:
                session.bulk_update_mappings(Organization, batch)
            logging.info(
                f"{count} batch{'es' if count > 1 else ''} written to db")
            if self.test and count > 1:
                logging.info("Breaking after 2 batches while in test mode")
                break

        # mark as done
        logging.warning("Task complete")
        self.output().touch()
Example #11
def run():
    test = literal_eval(os.environ["BATCHPAR_test"])
    db_name = os.environ["BATCHPAR_db_name"]
    table = os.environ["BATCHPAR_table"]
    batch_size = int(os.environ["BATCHPAR_batch_size"])
    s3_path = os.environ["BATCHPAR_outinfo"]

    logging.warning(f"Processing {table} file")

    # database setup
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)
    table_name = f"crunchbase_{table}"
    table_class = get_class_by_tablename(Base, table_name)

    # collect file
    nrows = 1000 if test else None
    df = get_files_from_tar([table], nrows=nrows)[0]
    logging.warning(f"{len(df)} rows in file")

    # get primary key fields and set of all those already existing in the db
    pk_cols = list(table_class.__table__.primary_key.columns)
    pk_names = [pk.name for pk in pk_cols]
    with db_session(engine) as session:
        existing_rows = set(session.query(*pk_cols).all())

    # process and insert data
    processed_rows = process_non_orgs(df, existing_rows, pk_names)
    for batch in split_batches(processed_rows, batch_size):
        insert_data("BATCHPAR_config",
                    'mysqldb',
                    db_name,
                    Base,
                    table_class,
                    batch,  # insert this batch only, not the full set of processed rows
                    low_memory=True)

    logging.warning(f"Marking task as done to {s3_path}")
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    logging.warning("Batch job complete.")
Example #12
    def prepare(self):
        if self.test:
            self.process_batch_size = 100
        # MySQL setup
        database = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # Subtract off all done ids
        Base.metadata.create_all(engine)
        with db_session(engine) as session:
            result = session.query(Project.rcn).all()
            done_rcn = {r[0] for r in result}

        # Get all possible ids (or "RCN" in Cordis-speak)
        nrows = 1000 if self.test else None
        all_rcn = set(
            get_framework_ids('fp7', nrows=nrows) +
            get_framework_ids('h2020', nrows=nrows))
        all_rcn = all_rcn - done_rcn

        # Generate the job params
        batches = split_batches(all_rcn, self.process_batch_size)
        params = [{
            "batch_file": put_s3_batch(batch, self.intermediate_bucket,
                                       self.routine_id),
            "config": 'mysqldb.config',
            "db_name": database,
            "bucket": self.intermediate_bucket,
            "outinfo": 'dummy',
            "done": False,
            'test': self.test
        } for batch in batches]
        return params
Example #13
    def prepare(self):
        '''Prepare the batch job parameters'''
        db = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_config_env, 'mysqldb', db)

        with db_session(engine) as session:
            # use isnot() so the null check happens in SQL rather than in Python
            results = (session
                       .query(Projects.id, func.length(Projects.abstractText))
                       .filter(Projects.abstractText.isnot(None))
                       .distinct(Projects.abstractText)
                       .all())

        # Keep documents with a length larger than the 10th percentile.
        perc = np.percentile([r[1] for r in results], 10)
        all_ids = [r.id for r in results if r[1] >= perc]

        job_params = []
        for count, batch in enumerate(split_batches(all_ids,
                                                    self.process_batch_size),
                                      start=1):
            # write batch of ids to s3
            key = f'text2vec-{self.routine_id}-{self.date}-{count}'
            batch_file = put_s3_batch(batch, self.intermediate_bucket, key)
            done = key in DONE_KEYS
            params = {
                "config": "mysqldb.config",
                "bucket": self.intermediate_bucket,
                "batch_file": batch_file,
                "db_name": db,
                "done": done,
                'outinfo': f"s3://{self.intermediate_bucket}/{key}",  # mark as done
                'test': self.test,
            }
            job_params.append(params)
            logging.info(params)
        return job_params
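
A toy illustration of the 10th-percentile cut-off used above, with made-up (id, abstract length) pairs in place of the query results:

import numpy as np

results = [(1, 50), (2, 120), (3, 8), (4, 300), (5, 95),
           (6, 14), (7, 220), (8, 180), (9, 60), (10, 75)]

perc = np.percentile([length for _, length in results], 10)
all_ids = [doc_id for doc_id, length in results if length >= perc]
print(perc, all_ids)   # documents shorter than the 10th percentile are dropped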
Example #14
    def prepare(self):
        if self.test:
            self.process_batch_size = 1000
            logging.warning("Batch size restricted to "
                            f"{self.process_batch_size}"
                            " while in test mode")

        # MySQL setup
        self.database = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_config_env, 'mysqldb', self.database)

        # Elasticsearch setup
        es_mode = 'dev' if self.test else 'prod'
        es, es_config = setup_es(es_mode,
                                 self.test,
                                 self.drop_and_recreate,
                                 dataset='crunchbase',
                                 aliases='health_scanner')

        # Get set of existing ids from elasticsearch via scroll
        scanner = scan(es,
                       query={"_source": False},
                       index=es_config['index'],
                       doc_type=es_config['type'])
        existing_ids = {s['_id'] for s in scanner}
        logging.info(f"Collected {len(existing_ids)} existing in "
                     "Elasticsearch")

        # Get set of all organisations from mysql
        all_orgs = list(all_org_ids(engine))
        logging.info(f"{len(all_orgs)} organisations in MySQL")

        # Remove previously processed
        orgs_to_process = list(org for org in all_orgs
                               if org not in existing_ids)
        logging.info(f"{len(orgs_to_process)} to be processed")

        job_params = []
        for count, batch in enumerate(
                split_batches(orgs_to_process, self.process_batch_size), 1):
            logging.info(f"Processing batch {count} with size {len(batch)}")

            # write batch of ids to s3
            batch_file = put_s3_batch(batch, self.intermediate_bucket,
                                      'crunchbase_to_es')
            params = {
                "batch_file": batch_file,
                "config": 'mysqldb.config',
                "db_name": self.database,
                "bucket": self.intermediate_bucket,
                "done": False,
                'outinfo': es_config['host'],
                'out_port': es_config['port'],
                'out_index': es_config['index'],
                'out_type': es_config['type'],
                'aws_auth_region': es_config['region'],
                'entity_type': 'company',
                "test": self.test
            }

            logging.info(params)
            job_params.append(params)
            if self.test and count > 1:
                logging.warning("Breaking after 2 batches while in "
                                "test mode.")
                break

        logging.warning("Batch preparation completed, "
                        f"with {len(job_params)} batches")
        return job_params
Example #15
    def run(self):
        """Collect and process organizations, categories and long descriptions."""

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)

        # collect files
        nrows = 200 if self.test else None
        cat_groups, orgs, org_descriptions = get_files_from_tar(
            ['category_groups', 'organizations', 'organization_descriptions'],
            nrows=nrows)
        # process category_groups
        cat_groups = rename_uuid_columns(cat_groups)
        insert_data(self.db_config_env,
                    'mysqldb',
                    database,
                    Base,
                    CategoryGroup,
                    cat_groups.to_dict(orient='records'),
                    low_memory=True)

        # process organizations and categories
        with db_session(self.engine) as session:
            existing_orgs = session.query(Organization.id).all()
        existing_orgs = {org[0] for org in existing_orgs}

        logging.info("Summary of organisation data:")
        logging.info(f"Total number of organisations:\t {len(orgs)}")
        logging.info(
            f"Number of organisations already in the database:\t {len(existing_orgs)}"
        )
        logging.info(f"Number of category groups and text descriptions:\t"
                     f"{len(cat_groups)}, {len(org_descriptions)}")

        processed_orgs, org_cats, missing_cat_groups = process_orgs(
            orgs, existing_orgs, cat_groups, org_descriptions)
        # Insert CatGroups
        insert_data(self.db_config_env, 'mysqldb', database, Base,
                    CategoryGroup, missing_cat_groups)
        # Insert orgs in batches
        n_batches = round(len(processed_orgs) / self.insert_batch_size)
        logging.info(
            f"Inserting {n_batches} batches of size {self.insert_batch_size}")
        for i, batch in enumerate(
                split_batches(processed_orgs, self.insert_batch_size)):
            if i % 100 == 0:
                logging.info(f"Inserting batch {i} of {n_batches}")
            insert_data(self.db_config_env,
                        'mysqldb',
                        database,
                        Base,
                        Organization,
                        batch,
                        low_memory=True)

        # link table needs to be inserted via non-bulk method to enforce relationship
        logging.info("Filtering duplicates...")
        org_cats, existing_org_cats, failed_org_cats = filter_out_duplicates(
            self.db_config_env,
            'mysqldb',
            database,
            Base,
            OrganizationCategory,
            org_cats,
            low_memory=True)
        logging.info(
            f"Inserting {len(org_cats)} org categories "
            f"({len(existing_org_cats)} already existed and {len(failed_org_cats)} failed)"
        )
        #org_cats = [OrganizationCategory(**org_cat) for org_cat in org_cats]
        with db_session(self.engine) as session:
            session.add_all(org_cats)

        # mark as done
        self.output().touch()
Example #16
    def run(self):
        # s3 setup
        s3 = boto3.resource('s3')
        intermediate_file = s3.Object(BUCKET, INTERMEDIATE_FILE)

        # database setup
        database = 'dev' if self.test else 'production'
        logging.info(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        Base.metadata.create_all(self.engine)

        eu = get_eu_countries()
        logging.info(f"Retrieved {len(eu)} EU countries")

        with db_session(self.engine) as session:
            all_fos_ids = {f.id for f in (session
                                          .query(FieldOfStudy.id)
                                          .all())}
            logging.info(f"{len(all_fos_ids):,} fields of study in database")

            eu_grid_ids = {i.id for i in (session
                                          .query(Institute.id)
                                          .filter(Institute.country.in_(eu))
                                          .all())}
            logging.info(f"{len(eu_grid_ids):,} EU institutes in GRID")

        try:
            processed_grid_ids = set(json.loads(intermediate_file
                                                .get()['Body']
                                                ._raw_stream.read()))

            logging.info(f"{len(processed_grid_ids)} previously processed institutes")
            eu_grid_ids = eu_grid_ids - processed_grid_ids
            logging.info(f"{len(eu_grid_ids):,} institutes to process")
        except ClientError:
            logging.info("Unable to load file of processed institutes, starting from scratch")
            processed_grid_ids = set()

        if self.test:
            self.batch_size = 500
            batch_limit = 1
        else:
            batch_limit = None
        testing_finished = False

        row_count = 0
        for institute_count, grid_id in enumerate(eu_grid_ids):
            paper_ids, author_ids = set(), set()
            data = {Paper: [],
                    Author: [],
                    PaperAuthor: set(),
                    PaperFieldsOfStudy: set(),
                    PaperLanguage: set()}

            if not institute_count % 50:
                logging.info(f"{institute_count:,} of {len(eu_grid_ids):,} institutes processed")

            if not check_institute_exists(grid_id):
                logging.debug(f"{grid_id} not found in MAG")
                continue

            # these tables have data stored in sets for deduping so the fieldnames will
            # need to be added when converting to a list of dicts for loading to the db
            field_names_to_add = {PaperAuthor: ('paper_id', 'author_id'),
                                  PaperFieldsOfStudy: ('paper_id', 'field_of_study_id'),
                                  PaperLanguage: ('paper_id', 'language')}

            logging.info(f"Querying MAG for {grid_id}")
            for row in query_by_grid_id(grid_id,
                                        from_date=self.from_date,
                                        min_citations=self.min_citations,
                                        batch_size=self.batch_size,
                                        batch_limit=batch_limit):

                fos_id = row['fieldOfStudyId']
                if fos_id not in all_fos_ids:
                    logging.info(f"Getting missing field of study {fos_id} from MAG")
                    update_field_of_study_ids_sparql(self.engine, fos_ids=[fos_id])
                    all_fos_ids.add(fos_id)

                # the returned data is normalised and therefore contains many duplicates
                paper_id = row['paperId']
                if paper_id not in paper_ids:
                    data[Paper].append({'id': paper_id,
                                        'title': row['paperTitle'],
                                        'citation_count': row['paperCitationCount'],
                                        'created_date': row['paperCreatedDate'],
                                        'doi': row.get('paperDoi'),
                                        'book_title': row.get('bookTitle')})
                    paper_ids.add(paper_id)

                author_id = row['authorId']
                if author_id not in author_ids:
                    data[Author].append({'id': author_id,
                                         'name': row['authorName'],
                                         'grid_id': grid_id})
                    author_ids.add(author_id)

                data[PaperAuthor].add((row['paperId'], row['authorId']))

                data[PaperFieldsOfStudy].add((row['paperId'], row['fieldOfStudyId']))

                try:
                    data[PaperLanguage].add((row['paperId'], row['paperLanguage']))
                except KeyError:
                    # language is an optional field
                    pass

                row_count += 1
                if self.test and row_count >= 1000:
                    logging.warning("Breaking after 1000 rows in test mode")
                    testing_finished = True
                    break

            # write out to SQL
            for table, rows in data.items():
                if table in field_names_to_add:
                    rows = [{k: v for k, v in zip(field_names_to_add[table], row)}
                            for row in rows]
                logging.debug(f"Writing {len(rows):,} rows to {table.__table__.name}")

                for batch in split_batches(rows, self.insert_batch_size):
                    insert_data('MYSQLDB', 'mysqldb', database, Base, table, batch)

            # flag institute as completed on S3
            processed_grid_ids.add(grid_id)
            intermediate_file.put(Body=json.dumps(list(processed_grid_ids)))

            if testing_finished:
                break


        # mark as done
        logging.info("Task complete")
        self.output().touch()
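
A toy illustration of the set-of-tuples to list-of-dicts conversion done in the "write out to SQL" step above, with hypothetical paper/author ids:

field_names = ('paper_id', 'author_id')      # as in field_names_to_add[PaperAuthor]
deduped_rows = {(101, 7), (102, 7)}          # a set keeps each (paper, author) pair unique
rows = [{k: v for k, v in zip(field_names, row)} for row in deduped_rows]
print(rows)  # e.g. [{'paper_id': 101, 'author_id': 7}, {'paper_id': 102, 'author_id': 7}]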
Example #17
def test_split_batches_when_data_is_smaller_than_batch_size(generate_test_data):
    yielded_batches = []
    for batch in split_batches(generate_test_data(200), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 1
Example #18
def test_split_batches_yields_multiple_batches_with_remainder(generate_test_data):
    yielded_batches = []
    for batch in split_batches(generate_test_data(2400), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 3
Example #19
def test_split_batches_with_set(generate_test_set_data):
    yielded_batches = []
    for batch in split_batches(generate_test_set_data(2400), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 3
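
The three tests above pin down the expected behaviour of split_batches: chunks of at most batch_size items, a final partial chunk for any remainder, and support for sets as well as lists. A minimal sketch consistent with those tests (the real implementation may differ):

def split_batches(iterable, batch_size):
    """Yield successive lists of at most `batch_size` items from any iterable."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:                    # e.g. 2400 items with batch_size 1000 -> 3 batches
        yield batch

assert len(list(split_batches(range(2400), 1000))) == 3      # multiple batches + remainder
assert len(list(split_batches(set(range(200)), 1000))) == 1  # works with sets, single batch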
Example #20
    def prepare(self):
        if self.test:
            self.process_batch_size = 1000
            logging.warning("Batch size restricted to "
                            f"{self.process_batch_size}"
                            " while in test mode")

        # MySQL setup
        database = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_config_env, self.db_section,
                                  database)

        # Elasticsearch setup
        es_mode = 'dev' if self.test else 'prod'
        es, es_config = setup_es(es_mode,
                                 self.test,
                                 self.drop_and_recreate,
                                 dataset=self.dataset,
                                 aliases=self.aliases)

        # Get set of existing ids from elasticsearch via scroll
        existing_ids = get_es_ids(es, es_config)
        logging.info(f"Collected {len(existing_ids)} existing in "
                     "Elasticsearch")

        # Get set of all organisations from mysql
        with db_session(engine) as session:
            result = session.query(self.id_field).all()
            all_ids = {r[0] for r in result}
        logging.info(f"{len(all_ids)} organisations in MySQL")

        # Remove previously processed
        ids_to_process = (org for org in all_ids if org not in existing_ids)

        job_params = []
        for count, batch in enumerate(
                split_batches(ids_to_process, self.process_batch_size), 1):
            # write batch of ids to s3
            batch_file = put_s3_batch(batch, self.intermediate_bucket,
                                      self.routine_id)
            params = {
                "batch_file": batch_file,
                "config": 'mysqldb.config',
                "db_name": database,
                "bucket": self.intermediate_bucket,
                "done": False,
                'outinfo': es_config['host'],
                'out_port': es_config['port'],
                'out_index': es_config['index'],
                'out_type': es_config['type'],
                'aws_auth_region': es_config['region'],
                'entity_type': self.entity_type,
                'test': self.test,
                'routine_id': self.routine_id
            }
            params.update(self.kwargs)

            logging.info(params)
            job_params.append(params)
            if self.test and count > 1:
                logging.warning("Breaking after 2 batches while in "
                                "test mode.")
                logging.warning(job_params)
                break

        logging.warning("Batch preparation completed, "
                        f"with {len(job_params)} batches")
        return job_params