def prepare(self):
    if self.test:
        self.process_batch_size = 1000
        logging.warning("Batch size restricted to "
                        f"{self.process_batch_size}"
                        " while in test mode")

    es_mode = 'dev' if self.test else 'prod'
    es, es_config = setup_es(es_mode, self.test, self.drop_and_recreate,
                             dataset='nih',
                             aliases='health_scanner',
                             increment_version=True)

    # Count articles from the old index
    _old_config = es_config.copy()
    _old_config['index'] = es_config['old_index']
    logging.info("Collecting article IDs...")
    _ids = get_es_ids(es, _old_config, size=10000)
    logging.info(f"Collected {len(_ids)} IDs")
    done_ids = get_es_ids(es, es_config, size=10000)

    # Generate the job params
    job_params = []
    batches = split_batches(_ids, self.process_batch_size)
    for count, batch in enumerate(batches, 1):
        # Magical '0.3' is the lower end of the deduplication
        # fraction found by inspection
        done = sum(_id in done_ids for _id in batch) / len(batch) > 0.3
        # write batch of ids to s3
        batch_file = ''
        if not done:
            batch_file = put_s3_batch(batch,
                                      self.intermediate_bucket,
                                      self.routine_id)
        params = {
            "batch_file": batch_file,
            "config": 'mysqldb.config',
            "bucket": self.intermediate_bucket,
            "done": done,
            'outinfo': es_config['host'],
            'out_port': es_config['port'],
            'out_index': es_config['index'],
            'in_index': es_config['old_index'],
            'out_type': es_config['type'],
            'aws_auth_region': es_config['region'],
            'entity_type': 'paper',
            'test': self.test,
            'routine_id': self.routine_id
        }
        job_params.append(params)
        if self.test and count > 1:
            logging.warning("Breaking after 2 batches "
                            "while in test mode.")
            logging.warning(job_params)
            break

    logging.info("Batch preparation completed, "
                 f"with {len(job_params)} batches")
    return job_params
def run(self): """Apply health labels using model.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) # collect and unpickle models from s3 logging.info("Collecting models from S3") s3 = boto3.resource('s3') vectoriser_obj = s3.Object(self.bucket, self.vectoriser_key) vectoriser = pickle.loads( vectoriser_obj.get()['Body']._raw_stream.read()) classifier_obj = s3.Object(self.bucket, self.classifier_key) classifier = pickle.loads( classifier_obj.get()['Body']._raw_stream.read()) # retrieve organisations and categories nrows = 1000 if self.test else None logging.info("Collecting organisations from database") with db_session(self.engine) as session: orgs = (session.query(Organization.id).filter( Organization.is_health.is_(None)).limit(nrows).all()) for batch_count, batch in enumerate( split_batches(orgs, self.insert_batch_size), 1): batch_orgs_with_cats = [] for (org_id, ) in batch: with db_session(self.engine) as session: categories = (session.query( OrganizationCategory.category_name).filter( OrganizationCategory.organization_id == org_id).all()) # categories should be a list of str, comma separated: ['cat,cat,cat', 'cat,cat'] categories = ','.join(cat_name for (cat_name, ) in categories) batch_orgs_with_cats.append({ 'id': org_id, 'categories': categories }) logging.debug( f"{len(batch_orgs_with_cats)} organisations retrieved from database" ) logging.debug("Predicting health flags") batch_orgs_with_flag = predict_health_flag(batch_orgs_with_cats, vectoriser, classifier) logging.debug( f"{len(batch_orgs_with_flag)} organisations to update") with db_session(self.engine) as session: session.bulk_update_mappings(Organization, batch_orgs_with_flag) logging.info( f"{batch_count} batches health labeled and written to db") # mark as done logging.warning("Task complete") self.output().touch()
def prepare(self):
    '''Chunk up elasticsearch data, and submit batch jobs over those chunks.'''
    if self.test:
        self.process_batch_size = 1000
        logging.warning("Batch size restricted to "
                        f"{self.process_batch_size}"
                        " while in test mode")

    # Setup elasticsearch and extract all ids
    es_mode = 'dev' if self.test else 'prod'
    es, es_config = setup_es(es_mode, self.test,
                             drop_and_recreate=False,
                             dataset=self.dataset,
                             increment_version=False)
    ids = get_es_ids(es, es_config, size=10000)  # All ids in this index
    ids = ids - self._done_ids  # Don't repeat done ids

    # Override the default index if specified
    es_config['index'] = (self.index if self.index is not None
                          else es_config['index'])

    # Generate the job params
    job_params = []
    batches = split_batches(ids, self.process_batch_size)
    for count, batch in enumerate(batches, 1):
        done = False  # Already taken care of with _done_ids
        # write batch of ids to s3
        batch_file = ''
        if not done:
            batch_file = put_s3_batch(batch,
                                      self.intermediate_bucket,
                                      self.routine_id)
        params = {
            "batch_file": batch_file,
            "config": self.sql_config_filename,
            "bucket": self.intermediate_bucket,
            "done": done,
            "count": len(ids),
            'outinfo': es_config['host'],
            'out_port': es_config['port'],
            'index': es_config['index'],
            'out_type': es_config['type'],
            'aws_auth_region': es_config['region'],
            'test': self.test,
            'routine_id': self.routine_id,
            'entity_type': self.entity_type,
            **self.kwargs
        }
        job_params.append(params)

        # Test mode
        if self.test and count > 1:
            logging.warning("Breaking after 2 batches "
                            "while in test mode.")
            logging.warning(job_params)
            break

    # Done
    logging.info("Batch preparation completed, "
                 f"with {len(job_params)} batches")
    return job_params
def _batch_query_articles_by_doi(query, articles, batch_size=10):
    """Manages batches and generates sparql queries for articles, and queries
    them from MAG via the sparql api using the supplied `doi`.

    Args:
        query (str): sparql query containing a format string placeholder {}
        articles (:obj:`list` of :obj:`dict`): articles to query in MAG.
            Must contain at least `id` and `doi` in each dict.
        batch_size (int): number of ids to query in a batch. Max size = 10

    Yields:
        (:obj:`tuple`): the current batch of articles and the corresponding
            batch of data returned from MAG
    """
    if not 1 <= batch_size <= 10:  # max limit for uri length
        raise ValueError("batch_size must be between 1 and 10")

    for articles_batch in split_batches(articles, batch_size):
        clean_dois = [(a['doi'].replace('\n', '')
                               .replace('\\', '')
                               .replace('"', ''))
                      for a in articles_batch]
        concat_dois = ','.join(f'"{a}"^^xsd:string' for a in clean_dois)
        article_filter = f"FILTER (?doi IN ({concat_dois}))"

        for results_batch in sparql_query(
                MAG_ENDPOINT, query.format(article_filter=article_filter)):
            yield articles_batch, results_batch
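# A minimal usage sketch of _batch_query_articles_by_doi. The query template
# and article dicts below are illustrative assumptions only (the predicate
# URIs and DOIs are not taken from the surrounding code); only the call
# signature and the {article_filter} placeholder come from the function above.
EXAMPLE_ARTICLE_QUERY = """
SELECT ?paper ?doi ?citationCount
WHERE {{
    ?paper <http://ma-graph.org/property/doi> ?doi .
    ?paper <http://ma-graph.org/property/citationCount> ?citationCount .
    {article_filter}
}}
"""

example_articles = [{'id': 1, 'doi': '10.1000/example.001'},
                    {'id': 2, 'doi': '10.1000/example.002'}]

for articles_batch, results_batch in _batch_query_articles_by_doi(
        EXAMPLE_ARTICLE_QUERY, example_articles, batch_size=2):
    # each results_batch can be matched back to articles_batch via the doi
    for result in results_batch:
        print(result)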
def _batched_entity_filter(concat_format, filter_on, ids, batch_size):
    """Creates batches of entity filters for SPARQL queries, constructing a
    'filter in' statement using the provided ids and splitting whenever the
    batch size is hit.

    A call with the following arguments:
        _batched_entity_filter('<http://ma-graph.org/entity/{}>', 'data', [1, 2], 50)

    Will return a generator which yields the following:
        "FILTER (?data IN (<http://ma-graph.org/entity/1>,<http://ma-graph.org/entity/2>))"

    Args:
        concat_format (str): string format to be applied when concatenating ids.
            Requires a placeholder for id {}
        filter_on (str): name of the field to use in the filter.
            The '?' prefix is not required
        ids (list): if ids are supplied they are queried as batches,
            otherwise all entities are queried
        batch_size (int): number of ids to query in a batch

    Yields:
        (str): filter string containing ids up to the chosen batch size
    """
    for batch_of_ids in split_batches(ids, batch_size):
        entities = ','.join(concat_format.format(i) for i in batch_of_ids)
        yield f"FILTER (?{filter_on} IN ({entities}))"
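# A minimal usage sketch of _batched_entity_filter, building one SPARQL query
# per batch of entity ids. The query template, field name and ids below are
# illustrative assumptions; only the call signature and the yielded FILTER
# string come from the docstring above.
EXAMPLE_ENTITY_QUERY = """
SELECT ?paper ?title
WHERE {{
    ?paper <http://ma-graph.org/property/title> ?title .
    {entity_filter}
}}
"""

example_paper_ids = [101, 102, 103]
for entity_filter in _batched_entity_filter(
        '<http://ma-graph.org/entity/{}>', 'paper',
        example_paper_ids, batch_size=2):
    # each query restricts ?paper to at most batch_size entities
    query = EXAMPLE_ENTITY_QUERY.format(entity_filter=entity_filter)
    print(query)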
def run(self):
    data = extract_data(limit=1000 if self.test else None)
    logging.info(f'Got {len(data)} rows')

    database = 'dev' if self.test else 'production'
    for chunk in split_batches(data, 10000):
        logging.info(f'Inserting chunk of size {len(chunk)}')
        insert_data('MYSQLDB', 'mysqldb', database,
                    Base, ApplnFamily, chunk, low_memory=True)

    self.output().touch()
def __iter__(self):
    for batch_of_ids in split_batches(self.ids, self.batch_size):
        self.title_articles_lookup.clear()
        for article in (self.session
                        .query(Article)
                        .filter(Article.id.in_(batch_of_ids))
                        .all()):
            self.title_articles_lookup[prepare_title(article.title)].append(article.id)

        for title in self.title_articles_lookup:
            yield title
def run(self): """Collect and process organizations, categories and long descriptions.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) limit = 2000 if self.test else None batch_size = 30 if self.test else 1000 with db_session(self.engine) as session: all_orgs = session.query( Organisation.id, Organisation.addresses).limit(limit).all() existing_org_location_ids = session.query( OrganisationLocation.id).all() logging.info(f"{len(all_orgs)} organisations retrieved from database") logging.info( f"{len(existing_org_location_ids)} organisations have previously been processed" ) # convert to a list of dictionaries with the nested addresses unpacked orgs = get_orgs_to_process(all_orgs, existing_org_location_ids) logging.info(f"{len(orgs)} new organisations to geocode") total_batches = ceil(len(orgs) / batch_size) logging.info(f"{total_batches} batches") completed_batches = 0 for batch in split_batches(orgs, batch_size=batch_size): # geocode first to add missing country for UK batch = map(geocode_uk_with_postcode, batch) batch = map(add_country_details, batch) # remove data not in OrganisationLocation columns org_location_cols = OrganisationLocation.__table__.columns.keys() batch = [{k: v for k, v in org.items() if k in org_location_cols} for org in batch] insert_data(self.db_config_env, 'mysqldb', database, Base, OrganisationLocation, batch) completed_batches += 1 logging.info( f"Completed {completed_batches} of {total_batches} batches") if self.test and completed_batches > 1: logging.warning("Breaking after 2 batches in test mode") break # mark as done logging.warning("Finished task") self.output().touch()
def run(self): # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) # collect mesh terms from S3 bucket = 'innovation-mapping-general' key = 'crunchbase_descriptions/crunchbase_descriptions_mesh.txt' mesh_terms = retrieve_mesh_terms(bucket, key) mesh_terms = format_mesh_terms( mesh_terms) # [{'id': ['term1', 'term2']}, ...] logging.info(f"File contains {len(mesh_terms)} orgs with mesh terms") logging.info("Extracting previously processed orgs") with db_session(self.engine) as session: all_orgs = session.query(Organization.id, Organization.mesh_terms).all() processed_orgs = { org_id for (org_id, mesh_terms) in all_orgs if mesh_terms is not None } all_orgs = {org_id for (org_id, _) in all_orgs} logging.info(f"{len(all_orgs)} total orgs in database") logging.info(f"{len(processed_orgs)} previously processed orgs") # reformat for batch insert, removing not found and previously processed terms meshed_orgs = [{ 'id': org_id, 'mesh_terms': '|'.join(terms) } for org_id, terms in mesh_terms.items() if org_id in all_orgs and org_id not in processed_orgs] logging.info(f"{len(meshed_orgs)} organisations to update in database") for count, batch in enumerate( split_batches(meshed_orgs, self.insert_batch_size), 1): with db_session(self.engine) as session: session.bulk_update_mappings(Organization, batch) logging.info( f"{count} batch{'es' if count > 1 else ''} written to db") if self.test and count > 1: logging.info("Breaking after 2 batches while in test mode") break # mark as done logging.warning("Task complete") self.output().touch()
def run(self): # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) # collect file logging.info(f"Collecting org_parents from crunchbase tar") org_parents = get_files_from_tar(['org_parents'])[0] logging.info(f"{len(org_parents)} parent ids in crunchbase export") # collect previously processed orgs logging.info("Extracting previously processed organisations") with db_session(self.engine) as session: processed_orgs = session.query(Organization.id, Organization.parent_id).all() all_orgs = {org for (org, _) in processed_orgs} logging.info(f"{len(all_orgs)} total orgs in database") processed_orgs = { org for (org, parent_id) in processed_orgs if parent_id is not None } logging.info(f"{len(processed_orgs)} previously processed orgs") # reformat into a list of dicts, removing orgs that already have a parent_id # or are missing from the database org_parents = org_parents[['uuid', 'parent_uuid']] org_parents.columns = ['id', 'parent_id'] org_parents = org_parents[org_parents['id'].isin(all_orgs)] org_parents = org_parents[~org_parents['id'].isin(processed_orgs)] org_parents = org_parents.to_dict(orient='records') logging.info(f"{len(org_parents)} organisations to update in MYSQL") # insert parent_ids into db in batches for count, batch in enumerate( split_batches(org_parents, self.insert_batch_size), 1): with db_session(self.engine) as session: session.bulk_update_mappings(Organization, batch) logging.info( f"{count} batch{'es' if count > 1 else ''} written to db") if self.test and count > 1: logging.info("Breaking after 2 batches while in test mode") break # mark as done logging.warning("Task complete") self.output().touch()
def run():
    test = literal_eval(os.environ["BATCHPAR_test"])
    db_name = os.environ["BATCHPAR_db_name"]
    table = os.environ["BATCHPAR_table"]
    batch_size = int(os.environ["BATCHPAR_batch_size"])
    s3_path = os.environ["BATCHPAR_outinfo"]

    logging.warning(f"Processing {table} file")

    # database setup
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)
    table_name = f"crunchbase_{table}"
    table_class = get_class_by_tablename(Base, table_name)

    # collect file
    nrows = 1000 if test else None
    df = get_files_from_tar([table], nrows=nrows)[0]
    logging.warning(f"{len(df)} rows in file")

    # get primary key fields and set of all those already existing in the db
    pk_cols = list(table_class.__table__.primary_key.columns)
    pk_names = [pk.name for pk in pk_cols]
    with db_session(engine) as session:
        existing_rows = set(session.query(*pk_cols).all())

    # process and insert data in batches
    processed_rows = process_non_orgs(df, existing_rows, pk_names)
    for batch in split_batches(processed_rows, batch_size):
        insert_data("BATCHPAR_config", 'mysqldb', db_name,
                    Base, table_class, batch, low_memory=True)

    logging.warning(f"Marking task as done to {s3_path}")
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    logging.warning("Batch job complete.")
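# A hedged sketch of how the BATCHPAR_* environment variables read above might
# be populated for a local dry run. The values (table name, bucket, config
# path) are illustrative assumptions, not taken from the project's batch
# launcher; only the variable names come from run() above.
if __name__ == "__main__":
    os.environ.setdefault("BATCHPAR_test", "True")
    os.environ.setdefault("BATCHPAR_db_name", "dev")
    os.environ.setdefault("BATCHPAR_table", "funding_rounds")      # illustrative
    os.environ.setdefault("BATCHPAR_batch_size", "1000")
    os.environ.setdefault("BATCHPAR_outinfo", "s3://some-bucket/some-key")  # illustrative
    os.environ.setdefault("BATCHPAR_config", "/path/to/mysqldb.config")     # illustrative
    run()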
def prepare(self):
    if self.test:
        self.process_batch_size = 100

    # MySQL setup
    database = 'dev' if self.test else 'production'
    engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

    # Subtract off all done ids
    Base.metadata.create_all(engine)
    with db_session(engine) as session:
        result = session.query(Project.rcn).all()
        done_rcn = {r[0] for r in result}

    # Get all possible ids (or "RCN" in Cordis-speak)
    nrows = 1000 if self.test else None
    all_rcn = set(get_framework_ids('fp7', nrows=nrows)
                  + get_framework_ids('h2020', nrows=nrows))
    all_rcn = all_rcn - done_rcn

    # Generate the job params
    batches = split_batches(all_rcn, self.process_batch_size)
    params = [{"batch_file": put_s3_batch(batch,
                                          self.intermediate_bucket,
                                          self.routine_id),
               "config": 'mysqldb.config',
               "db_name": database,
               "bucket": self.intermediate_bucket,
               "outinfo": 'dummy',
               "done": False,
               'test': self.test}
              for batch in batches]
    return params
def prepare(self):
    '''Prepare the batch job parameters'''
    db = 'dev' if self.test else 'production'
    engine = get_mysql_engine(self.db_config_env, 'mysqldb', db)
    with db_session(engine) as session:
        results = (session
                   .query(Projects.id, func.length(Projects.abstractText))
                   .filter(Projects.abstractText.isnot(None))
                   .distinct(Projects.abstractText)
                   .all())

    # Keep documents with a length larger than the 10th percentile.
    perc = np.percentile([r[1] for r in results], 10)
    all_ids = [r.id for r in results if r[1] >= perc]

    job_params = []
    for count, batch in enumerate(split_batches(all_ids,
                                                self.process_batch_size),
                                  start=1):
        # write batch of ids to s3
        key = f'text2vec-{self.routine_id}-{self.date}-{count}'
        batch_file = put_s3_batch(batch, self.intermediate_bucket, key)
        done = key in DONE_KEYS
        params = {
            "config": "mysqldb.config",
            "bucket": self.intermediate_bucket,
            "batch_file": batch_file,
            "db_name": db,
            "done": done,
            'outinfo': f"s3://{self.intermediate_bucket}/{key}",  # mark as done
            'test': self.test,
        }
        job_params.append(params)
        logging.info(params)
    return job_params
def prepare(self): if self.test: self.process_batch_size = 1000 logging.warning("Batch size restricted to " f"{self.process_batch_size}" " while in test mode") # MySQL setup self.database = 'dev' if self.test else 'production' engine = get_mysql_engine(self.db_config_env, 'mysqldb', self.database) # Elasticsearch setup es_mode = 'dev' if self.test else 'prod' es, es_config = setup_es(es_mode, self.test, self.drop_and_recreate, dataset='crunchbase', aliases='health_scanner') # Get set of existing ids from elasticsearch via scroll scanner = scan(es, query={"_source": False}, index=es_config['index'], doc_type=es_config['type']) existing_ids = {s['_id'] for s in scanner} logging.info(f"Collected {len(existing_ids)} existing in " "Elasticsearch") # Get set of all organisations from mysql all_orgs = list(all_org_ids(engine)) logging.info(f"{len(all_orgs)} organisations in MySQL") # Remove previously processed orgs_to_process = list(org for org in all_orgs if org not in existing_ids) logging.info(f"{len(orgs_to_process)} to be processed") job_params = [] for count, batch in enumerate( split_batches(orgs_to_process, self.process_batch_size), 1): logging.info(f"Processing batch {count} with size {len(batch)}") # write batch of ids to s3 batch_file = put_s3_batch(batch, self.intermediate_bucket, 'crunchbase_to_es') params = { "batch_file": batch_file, "config": 'mysqldb.config', "db_name": self.database, "bucket": self.intermediate_bucket, "done": False, 'outinfo': es_config['host'], 'out_port': es_config['port'], 'out_index': es_config['index'], 'out_type': es_config['type'], 'aws_auth_region': es_config['region'], 'entity_type': 'company', "test": self.test } logging.info(params) job_params.append(params) if self.test and count > 1: logging.warning("Breaking after 2 batches while in " "test mode.") break logging.warning("Batch preparation completed, " f"with {len(job_params)} batches") return job_params
def run(self): """Collect and process organizations, categories and long descriptions.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) # collect files nrows = 200 if self.test else None cat_groups, orgs, org_descriptions = get_files_from_tar( ['category_groups', 'organizations', 'organization_descriptions'], nrows=nrows) # process category_groups cat_groups = rename_uuid_columns(cat_groups) insert_data(self.db_config_env, 'mysqldb', database, Base, CategoryGroup, cat_groups.to_dict(orient='records'), low_memory=True) # process organizations and categories with db_session(self.engine) as session: existing_orgs = session.query(Organization.id).all() existing_orgs = {org[0] for org in existing_orgs} logging.info("Summary of organisation data:") logging.info(f"Total number of organisations:\t {len(orgs)}") logging.info( f"Number of organisations already in the database:\t {len(existing_orgs)}" ) logging.info(f"Number of category groups and text descriptions:\t" f"{len(cat_groups)}, {len(org_descriptions)}") processed_orgs, org_cats, missing_cat_groups = process_orgs( orgs, existing_orgs, cat_groups, org_descriptions) # Insert CatGroups insert_data(self.db_config_env, 'mysqldb', database, Base, CategoryGroup, missing_cat_groups) # Insert orgs in batches n_batches = round(len(processed_orgs) / self.insert_batch_size) logging.info( f"Inserting {n_batches} batches of size {self.insert_batch_size}") for i, batch in enumerate( split_batches(processed_orgs, self.insert_batch_size)): if i % 100 == 0: logging.info(f"Inserting batch {i} of {n_batches}") insert_data(self.db_config_env, 'mysqldb', database, Base, Organization, batch, low_memory=True) # link table needs to be inserted via non-bulk method to enforce relationship logging.info("Filtering duplicates...") org_cats, existing_org_cats, failed_org_cats = filter_out_duplicates( self.db_config_env, 'mysqldb', database, Base, OrganizationCategory, org_cats, low_memory=True) logging.info( f"Inserting {len(org_cats)} org categories " f"({len(existing_org_cats)} already existed and {len(failed_org_cats)} failed)" ) #org_cats = [OrganizationCategory(**org_cat) for org_cat in org_cats] with db_session(self.engine) as session: session.add_all(org_cats) # mark as done self.output().touch()
def run(self):
    # s3 setup
    s3 = boto3.resource('s3')
    intermediate_file = s3.Object(BUCKET, INTERMEDIATE_FILE)

    # database setup
    database = 'dev' if self.test else 'production'
    logging.info(f"Using {database} database")
    self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
    Base.metadata.create_all(self.engine)

    eu = get_eu_countries()
    logging.info(f"Retrieved {len(eu)} EU countries")

    with db_session(self.engine) as session:
        all_fos_ids = {f.id for f in (session
                                      .query(FieldOfStudy.id)
                                      .all())}
        logging.info(f"{len(all_fos_ids):,} fields of study in database")

        eu_grid_ids = {i.id for i in (session
                                      .query(Institute.id)
                                      .filter(Institute.country.in_(eu))
                                      .all())}
        logging.info(f"{len(eu_grid_ids):,} EU institutes in GRID")

    try:
        processed_grid_ids = set(json.loads(intermediate_file
                                            .get()['Body']
                                            ._raw_stream.read()))
        logging.info(f"{len(processed_grid_ids)} previously processed institutes")
        eu_grid_ids = eu_grid_ids - processed_grid_ids
        logging.info(f"{len(eu_grid_ids):,} institutes to process")
    except ClientError:
        logging.info("Unable to load file of processed institutes, "
                     "starting from scratch")
        processed_grid_ids = set()

    if self.test:
        self.batch_size = 500
        batch_limit = 1
    else:
        batch_limit = None

    testing_finished = False
    row_count = 0
    for institute_count, grid_id in enumerate(eu_grid_ids):
        paper_ids, author_ids = set(), set()
        data = {Paper: [],
                Author: [],
                PaperAuthor: set(),
                PaperFieldsOfStudy: set(),
                PaperLanguage: set()}

        if not institute_count % 50:
            logging.info(f"{institute_count:,} of {len(eu_grid_ids):,} "
                         "institutes processed")

        if not check_institute_exists(grid_id):
            logging.debug(f"{grid_id} not found in MAG")
            continue

        # these tables have data stored in sets for deduping so the fieldnames
        # will need to be added when converting to a list of dicts for loading
        # to the db
        field_names_to_add = {PaperAuthor: ('paper_id', 'author_id'),
                              PaperFieldsOfStudy: ('paper_id', 'field_of_study_id'),
                              PaperLanguage: ('paper_id', 'language')}

        logging.info(f"Querying MAG for {grid_id}")
        for row in query_by_grid_id(grid_id,
                                    from_date=self.from_date,
                                    min_citations=self.min_citations,
                                    batch_size=self.batch_size,
                                    batch_limit=batch_limit):
            fos_id = row['fieldOfStudyId']
            if fos_id not in all_fos_ids:
                logging.info(f"Getting missing field of study {fos_id} from MAG")
                update_field_of_study_ids_sparql(self.engine, fos_ids=[fos_id])
                all_fos_ids.add(fos_id)

            # the returned data is normalised and therefore contains many duplicates
            paper_id = row['paperId']
            if paper_id not in paper_ids:
                data[Paper].append({'id': paper_id,
                                    'title': row['paperTitle'],
                                    'citation_count': row['paperCitationCount'],
                                    'created_date': row['paperCreatedDate'],
                                    'doi': row.get('paperDoi'),
                                    'book_title': row.get('bookTitle')})
                paper_ids.add(paper_id)

            author_id = row['authorId']
            if author_id not in author_ids:
                data[Author].append({'id': author_id,
                                     'name': row['authorName'],
                                     'grid_id': grid_id})
                author_ids.add(author_id)

            data[PaperAuthor].add((row['paperId'], row['authorId']))
            data[PaperFieldsOfStudy].add((row['paperId'], row['fieldOfStudyId']))
            try:
                data[PaperLanguage].add((row['paperId'], row['paperLanguage']))
            except KeyError:
                # language is an optional field
                pass

            row_count += 1
            if self.test and row_count >= 1000:
                logging.warning("Breaking after 1000 rows in test mode")
                testing_finished = True
                break

        # write out to SQL
        for table, rows in data.items():
            if table in field_names_to_add:
                rows = [{k: v for k, v in zip(field_names_to_add[table], row)}
                        for row in rows]
            logging.debug(f"Writing {len(rows):,} rows to {table.__table__.name}")
            for batch in split_batches(rows, self.insert_batch_size):
                insert_data('MYSQLDB', 'mysqldb', database, Base, table, batch)

        # flag institute as completed on S3
        processed_grid_ids.add(grid_id)
        intermediate_file.put(Body=json.dumps(list(processed_grid_ids)))

        if testing_finished:
            break

    # mark as done
    logging.info("Task complete")
    self.output().touch()
def test_split_batches_when_data_is_smaller_than_batch_size(generate_test_data):
    yielded_batches = []
    for batch in split_batches(generate_test_data(200), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 1
def test_split_batches_yields_multiple_batches_with_remainder(generate_test_data):
    yielded_batches = []
    for batch in split_batches(generate_test_data(2400), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 3
def test_split_batches_with_set(generate_test_set_data):
    yielded_batches = []
    for batch in split_batches(generate_test_set_data(2400), batch_size=1000):
        yielded_batches.append(batch)

    assert len(yielded_batches) == 3
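# The tests above pin down the expected behaviour of split_batches: it accepts
# any iterable (lists and sets alike) and yields full batches followed by a
# final partial batch (2400 items at batch_size=1000 gives 3 batches). A
# minimal sketch consistent with those tests; the production implementation
# may differ in detail, e.g. the type of the yielded batches.
def split_batches_sketch(iterable, batch_size):
    """Yield successive batches of up to `batch_size` items from any iterable."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final partial batch, e.g. the 400 leftover items of 2400
        yield batch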
def prepare(self): if self.test: self.process_batch_size = 1000 logging.warning("Batch size restricted to " f"{self.process_batch_size}" " while in test mode") # MySQL setup database = 'dev' if self.test else 'production' engine = get_mysql_engine(self.db_config_env, self.db_section, database) # Elasticsearch setup es_mode = 'dev' if self.test else 'prod' es, es_config = setup_es(es_mode, self.test, self.drop_and_recreate, dataset=self.dataset, aliases=self.aliases) # Get set of existing ids from elasticsearch via scroll existing_ids = get_es_ids(es, es_config) logging.info(f"Collected {len(existing_ids)} existing in " "Elasticsearch") # Get set of all organisations from mysql with db_session(engine) as session: result = session.query(self.id_field).all() all_ids = {r[0] for r in result} logging.info(f"{len(all_ids)} organisations in MySQL") # Remove previously processed ids_to_process = (org for org in all_ids if org not in existing_ids) job_params = [] for count, batch in enumerate( split_batches(ids_to_process, self.process_batch_size), 1): # write batch of ids to s3 batch_file = put_s3_batch(batch, self.intermediate_bucket, self.routine_id) params = { "batch_file": batch_file, "config": 'mysqldb.config', "db_name": database, "bucket": self.intermediate_bucket, "done": False, 'outinfo': es_config['host'], 'out_port': es_config['port'], 'out_index': es_config['index'], 'out_type': es_config['type'], 'aws_auth_region': es_config['region'], 'entity_type': self.entity_type, 'test': self.test, 'routine_id': self.routine_id } params.update(self.kwargs) logging.info(params) job_params.append(params) if self.test and count > 1: logging.warning("Breaking after 2 batches while in " "test mode.") logging.warning(job_params) break logging.warning("Batch preparation completed, " f"with {len(job_params)} batches") return job_params