def _insert_new_locations(self):
    """Checks for new city/country combinations and appends them to the
    geographic data table in mysql.
    """
    limit = self.test_limit if self.test else None
    with db_session(self.engine) as session:
        existing_location_ids = {i[0] for i
                                 in session.query(Geographic.id).all()}
        new_locations = []
        for city, country, key in (session
                                   .query(self.city_col,
                                          self.country_col,
                                          self.location_key_col)
                                   .distinct(self.location_key_col)
                                   .limit(limit)):
            if key not in existing_location_ids and key is not None:
                logging.info(f"new location {city}, {country}")
                new_locations.append(dict(id=key, city=city, country=country))
                existing_location_ids.add(key)

    if new_locations:
        logging.warning(f"Adding {len(new_locations)} new locations to database")
        insert_data(self.db_config_env, "mysqldb", self.database,
                    Base, Geographic, new_locations)
def tests_insert_and_exists(self):
    data = [{"_id": 10, "_another_id": 2, "some_field": 20},
            {"_id": 10, "_another_id": 2, "some_field": 30},  # <--- Dupe pk, so should be ignored
            {"_id": 20, "_another_id": 2, "some_field": 30}]
    objs = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                       Base, DummyModel, data)
    self.assertEqual(len(objs), 2)

    objs = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                       Base, DummyModel, data)
    self.assertEqual(len(objs), 0)
def run(self):
    data = extract_data(limit=1000 if self.test else None)
    logging.info(f'Got {len(data)} rows')

    database = 'dev' if self.test else 'production'
    for chunk in split_batches(data, 10000):
        logging.info(f'Inserting chunk of size {len(chunk)}')
        insert_data('MYSQLDB', 'mysqldb', database, Base,
                    ApplnFamily, chunk, low_memory=True)
    self.output().touch()
def run(self):
    db = 'production' if not self.test else 'dev'
    keys = self.get_abstract_file_keys(bucket, key_prefix)

    engine = get_mysql_engine(self.db_config_env, 'mysqldb', db)
    with db_session(engine) as session:
        # collect the projects already in the database; mesh terms are only
        # linked to projects that exist
        existing_projects = set()
        projects = session.query(Projects.application_id).distinct()
        for p in projects:
            existing_projects.add(int(p.application_id))

        # projects which already have mesh terms linked
        projects_done = set()
        projects_mesh = session.query(ProjectMeshTerms.project_id).distinct()
        for p in projects_mesh:
            projects_done.add(int(p.project_id))

        mesh_term_ids = {int(m.id) for m in session.query(MeshTerms.id).all()}

        logging.info('Inserting associations')
        for key_count, key in enumerate(keys):
            if self.test and (key_count > 2):
                continue
            # collect mesh results from the s3 file and group by project id;
            # each project id has a set of mesh terms and corresponding term ids
            df_mesh = retrieve_mesh_terms(bucket, key)
            project_terms = self.format_mesh_terms(df_mesh)

            # go through documents
            for project_count, (project_id, terms) in enumerate(project_terms.items()):
                rows = []
                if self.test and (project_count > 2):
                    continue
                if (project_id in projects_done) or (project_id not in existing_projects):
                    continue
                for term, term_id in zip(terms['terms'], terms['ids']):
                    term_id = int(term_id)
                    # add term to mesh term table if not present
                    if term_id not in mesh_term_ids:
                        insert_data(self.db_config_env, 'mysqldb', db, Base,
                                    MeshTerms, [{'id': term_id, 'term': term}],
                                    low_memory=True)
                        mesh_term_ids.add(term_id)
                    # prepare row to be added to the project-mesh_term link table
                    rows.append({'project_id': project_id,
                                 'mesh_term_id': term_id})
                # insert rows to populate the project-mesh_term link table
                insert_data(self.db_config_env, 'mysqldb', db, Base,
                            ProjectMeshTerms, rows, low_memory=True)

    self.output().touch()  # mark the task as done
def run(self): """Collect and process organizations, categories and long descriptions.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) limit = 2000 if self.test else None batch_size = 30 if self.test else 1000 with db_session(self.engine) as session: all_orgs = session.query( Organisation.id, Organisation.addresses).limit(limit).all() existing_org_location_ids = session.query( OrganisationLocation.id).all() logging.info(f"{len(all_orgs)} organisations retrieved from database") logging.info( f"{len(existing_org_location_ids)} organisations have previously been processed" ) # convert to a list of dictionaries with the nested addresses unpacked orgs = get_orgs_to_process(all_orgs, existing_org_location_ids) logging.info(f"{len(orgs)} new organisations to geocode") total_batches = ceil(len(orgs) / batch_size) logging.info(f"{total_batches} batches") completed_batches = 0 for batch in split_batches(orgs, batch_size=batch_size): # geocode first to add missing country for UK batch = map(geocode_uk_with_postcode, batch) batch = map(add_country_details, batch) # remove data not in OrganisationLocation columns org_location_cols = OrganisationLocation.__table__.columns.keys() batch = [{k: v for k, v in org.items() if k in org_location_cols} for org in batch] insert_data(self.db_config_env, 'mysqldb', database, Base, OrganisationLocation, batch) completed_batches += 1 logging.info( f"Completed {completed_batches} of {total_batches} batches") if self.test and completed_batches > 1: logging.warning("Breaking after 2 batches in test mode") break # mark as done logging.warning("Finished task") self.output().touch()
def run():
    batch_file = os.environ['BATCHPAR_batch_file']
    bucket = os.environ['BATCHPAR_bucket']
    db_name = os.environ['BATCHPAR_db_name']
    db_env = "BATCHPAR_config"
    db_section = "mysqldb"

    # Setup the database connectors
    engine = get_mysql_engine(db_env, db_section, db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # Retrieve RCNs to iterate over
    s3 = boto3.resource('s3')
    obj = s3.Object(bucket, batch_file)
    all_rcn = json.loads(obj.get()['Body']._raw_stream.read())
    logging.info(f"{len(all_rcn)} project RCNs retrieved from s3")

    # Retrieve all topics
    data = defaultdict(list)
    for i, rcn in enumerate(all_rcn):
        logging.info(i)
        project, orgs, reports, pubs = fetch_data(rcn)
        if project is None:
            continue
        _topics = project.pop('topics')
        _calls = project.pop('proposal_call')
        # NB: Order below matters due to FK constraints!
        data['projects'].append(project)
        data['reports'] += prepare_data(reports, rcn)
        data['publications'] += prepare_data(pubs, rcn)
        data['organisations'] += extract_core_orgs(orgs, rcn)
        data['project_organisations'] += prepare_data(orgs, rcn)
        for topics, project_topics in split_links(_topics, rcn):
            data['topics'].append(topics)
            data['project_topics'].append(project_topics)
        for calls, project_calls in split_links(_calls, rcn):
            data['proposal_calls'].append(calls)
            data['project_proposal_calls'].append(project_calls)

    # Pipe the data to the db
    for table_prefix, rows in data.items():
        table_name = f'cordis_{table_prefix}'
        logging.info(table_name)
        _class = get_class_by_tablename(Base, table_name)
        insert_data(db_env, db_section, db_name, Base,
                    _class, rows, low_memory=True)
def run():
    test = literal_eval(os.environ["BATCHPAR_test"])
    db_name = os.environ["BATCHPAR_db_name"]
    batch_size = int(os.environ["BATCHPAR_batch_size"])  # example parameter
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_string = os.environ["BATCHPAR_start_string"]  # example parameter
    offset = int(os.environ["BATCHPAR_offset"])

    # reduce records in test mode
    if test:
        limit = 50
        logging.info(f"Limiting to {limit} rows in test mode")
    else:
        limit = batch_size

    logging.info(f"Processing {offset} - {offset + limit}")

    # database setup
    logging.info(f"Using {db_name} database")
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    with db_session(engine) as session:
        # consider moving this query and the one from the prepare step into a package
        batch_records = (session
                         .query(MyTable.id, MyTable.name)
                         .filter(MyTable.founded_on > '2007-01-01')
                         .offset(offset)
                         .limit(limit))

        # process and insert data
        processed_batch = []
        for row in batch_records:
            processed_row = some_func(start_string=start_string, row=row)
            processed_batch.append(processed_row)

    logging.info(f"Inserting {len(processed_batch)} rows")
    insert_data("BATCHPAR_config", 'mysqldb', db_name, Base,
                MyOtherTable, processed_batch, low_memory=True)

    logging.info(f"Marking task as done to {s3_path}")
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    logging.info("Batch job complete.")
def run(self):
    # Get all UK geographies, and group by country and base
    gss_codes = get_gss_codes()
    country_codes = defaultdict(lambda: defaultdict(list))
    for code in gss_codes:
        country = code[0]
        base = code[0:3]
        # Shortened test mode
        if not self.production and base not in ("S32", "S23"):
            continue
        country_codes[country][base].append(code)

    # Iterate through country and base
    output = []
    for country, base_codes in country_codes.items():
        # Try to find children for each base...
        for base in base_codes.keys():
            for base_, codes in base_codes.items():
                # ...except for the base of the parent
                if base == base_:
                    continue
                output += get_children(base, codes)

    # Write to database
    _class = get_class_by_tablename(Base, "onsOpenGeo_geographic_lookup")
    objs = insert_data(MYSQLDB_ENV, "mysqldb",
                       "production" if self.production else "dev",
                       Base, _class, output)
    self.output().touch()
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    group_urlnames = literal_eval(os.environ["BATCHPAR_group_urlnames"])
    group_urlnames = [x.decode("utf8") for x in group_urlnames]
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Generate the groups for these members
    _output = []
    for urlname in group_urlnames:
        _info = get_group_details(urlname, max_results=200)
        if len(_info) == 0:
            continue
        _output.append(_info)
    logging.info("Processed %s groups", len(_output))

    # Flatten the output
    output = flatten_data(_output,
                          keys=[('category', 'name'),
                                ('category', 'shortname'),
                                ('category', 'id'),
                                'created', 'country', 'city',
                                'description', 'id', 'lat', 'lon',
                                'members', 'name', 'topics', 'urlname'])

    objs = insert_data("BATCHPAR_config", "mysqldb", db,
                       Base, Group, output)

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")
    return len(objs)
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    member_ids = literal_eval(os.environ["BATCHPAR_member_ids"])
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Generate the groups for these members
    output = []
    for member_id in member_ids:
        response = get_member_details(member_id, max_results=200)
        output += get_member_groups(response)
    logging.info("Got %s groups", len(output))

    # Load connection to the db, and create the tables
    objs = insert_data("BATCHPAR_config", "mysqldb", db,
                       Base, GroupMember, output)
    logging.info("Inserted %s groups", len(objs))

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")
    return len(objs)
def run(self):
    config, geogs_list, dataset_id, date_format = process_config(self.config_name,
                                                                 test=not self.production)
    for igeo, geographies in enumerate(geogs_list):
        if igeo == 0:
            continue
        logging.debug(f"Geography number {igeo}")
        done = False
        record_offset = 0
        while not done:
            logging.debug(f"\tOffset of {record_offset}")
            df, done, record_offset = batch_request(config, dataset_id,
                                                    geographies, date_format,
                                                    max_api_calls=10,
                                                    record_offset=record_offset)
            data = {self.config_name: df}
            tables = reformat_nomis_columns(data)
            for name, table in tables.items():
                name = name.split('-sic')[0]  # If sic codes are used in the name
                logging.debug(f"\t\tInserting {len(table)} into nomis_{name}...")
                _class = get_class_by_tablename(Base, f"nomis_{name}")
                objs = insert_data(MYSQLDB_ENV, "mysqldb",
                                   "production" if self.production else "dev",
                                   Base, _class, table, low_memory=True)
                logging.debug(f"\t\tInserted {len(objs)}")

    # data = get_nomis_data(self.config_name, test=not self.production)
    # tables = reformat_nomis_columns({self.config_name: data})
    self.output().touch()
def test_object_to_dict(self):
    parents = [{"_id": 10, "_another_id": 2, "some_field": 20},
               {"_id": 20, "_another_id": 2, "some_field": 20}]
    _parents = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                           Base, DummyModel, parents)
    assert len(parents) == len(_parents)

    children = [{"_id": 10, "parent_id": 10},
                {"_id": 10, "parent_id": 20},
                {"_id": 20, "parent_id": 20},
                {"_id": 30, "parent_id": 20}]
    _children = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                            Base, DummyChild, children)
    assert len(children) == len(_children)

    # Re-retrieve parents from the database
    found_children = set()
    engine = get_mysql_engine("MYSQLDBCONF", "mysqldb")
    with db_session(engine) as session:
        for p in session.query(DummyModel).all():
            row = object_to_dict(p)
            assert type(row) is dict
            assert len(row['children']) > 0
            _found_children = set((c['_id'], c['parent_id'])
                                  for c in row['children'])
            found_children = found_children.union(_found_children)

            _row = object_to_dict(p, shallow=True)
            assert 'children' not in _row
            del row['children']
            assert row == _row

    assert len(found_children) == len(children) == len(_children)
def run(): test = literal_eval(os.environ["BATCHPAR_test"]) db_name = os.environ["BATCHPAR_db_name"] table = os.environ["BATCHPAR_table"] batch_size = int(os.environ["BATCHPAR_batch_size"]) s3_path = os.environ["BATCHPAR_outinfo"] logging.warning(f"Processing {table} file") # database setup engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name) try_until_allowed(Base.metadata.create_all, engine) table_name = f"crunchbase_{table}" table_class = get_class_by_tablename(Base, table_name) # collect file nrows = 1000 if test else None df = get_files_from_tar([table], nrows=nrows)[0] logging.warning(f"{len(df)} rows in file") # get primary key fields and set of all those already existing in the db pk_cols = list(table_class.__table__.primary_key.columns) pk_names = [pk.name for pk in pk_cols] with db_session(engine) as session: existing_rows = set(session.query(*pk_cols).all()) # process and insert data processed_rows = process_non_orgs(df, existing_rows, pk_names) for batch in split_batches(processed_rows, batch_size): insert_data("BATCHPAR_config", 'mysqldb', db_name, Base, table_class, processed_rows, low_memory=True) logging.warning(f"Marking task as done to {s3_path}") s3 = boto3.resource('s3') s3_obj = s3.Object(*parse_s3_path(s3_path)) s3_obj.put(Body="") logging.warning("Batch job complete.")
def run(self):
    '''Run the data collection'''
    # engine = get_mysql_engine(MYSQLDB_ENV, "mysqldb",
    #                           self.db_config['database'])
    # Base.metadata.create_all(engine)
    # Session = sessionmaker(engine)
    # session = Session()

    wiki_date = find_latest_wikidump()
    ngrams = extract_ngrams(wiki_date)
    if self.test:
        ngrams = list(ngrams)[0:100]

    # for n in ngrams:
    #     ngram = WiktionaryNgram(ngram=n)
    #     session.add(ngram)
    # session.commit()
    # session.close()
    insert_data(MYSQLDB_ENV, "mysqldb", self.db_config['database'],
                Base, WiktionaryNgram, [dict(ngram=n) for n in ngrams])
    self.output().touch()
def _insert_new_locations_no_id(self):
    """Checks for new city/country combinations and appends them to the
    geographic data table in mysql IF NO location_key_col IS PROVIDED.
    """
    limit = self.test_limit if self.test else None
    with db_session(self.engine) as session:
        existing_location_ids = {i[0] for i
                                 in session.query(Geographic.id).all()}
        new_locations = []
        all_locations = {(city, country) for city, country
                         in (session.query(self.city_col,
                                           self.country_col).limit(limit))}
        nulls = []
        for city, country in all_locations:
            if self.country_is_iso2:
                country = country_iso_code_to_name(country, iso2=True)
            if city is None or country is None:
                nulls.append((city, country))
                continue
            key = generate_composite_key(city, country)
            if key not in existing_location_ids and key is not None:
                logging.info(f"new location {city}, {country}")
                new_locations.append(dict(id=key, city=city, country=country))
                existing_location_ids.add(key)

    if len(nulls) > 0:
        logging.warning(f"{len(nulls)} locations had a null city or "
                        "country, so won't be processed.")
        logging.warning(nulls)

    if new_locations:
        logging.warning(f"Adding {len(new_locations)} new locations to database")
        insert_data(self.db_config_env, "mysqldb", self.database,
                    Base, Geographic, new_locations)
def test_db_session_query(self):
    parents = [{"_id": i, "_another_id": i, "some_field": 20}
               for i in range(0, 1000)]
    _parents = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                           Base, DummyModel, parents)

    # Re-retrieve parents from the database
    engine = get_mysql_engine("MYSQLDBCONF", "mysqldb")

    # Test for limit = 3
    limit = 3
    old_db = mock.MagicMock()
    old_db.is_active = False
    n_rows = 0
    for db, row in db_session_query(query=DummyModel, engine=engine,
                                    chunksize=10, limit=limit):
        assert type(row) is DummyModel
        if old_db != db:
            assert len(old_db.transaction._connections) == 0
            assert len(db.transaction._connections) > 0
            old_db = db
        n_rows += 1
    assert n_rows == limit

    # Test for limit = None
    old_db = mock.MagicMock()
    old_db.is_active = False
    n_rows = 0
    for db, row in db_session_query(query=DummyModel, engine=engine,
                                    chunksize=100, limit=None):
        assert type(row) is DummyModel
        if old_db != db:
            assert len(old_db.transaction._connections) == 0
            assert len(db.transaction._connections) > 0
            old_db = db
        n_rows += 1
    assert n_rows == len(parents) == 1000
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    group_urlname = os.environ["BATCHPAR_group_urlname"]
    group_id = os.environ["BATCHPAR_group_id"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Collect members
    logging.info("Getting %s", group_urlname)
    output = get_all_members(group_id, group_urlname, max_results=200)
    logging.info("Got %s members", len(output))

    # Load connection to the db, and create the tables
    objs = insert_data("BATCHPAR_config", "mysqldb", db,
                       Base, GroupMember, output)

    # Mainly for testing
    return len(objs)
def run():
    PAGE_SIZE = int(os.environ['BATCHPAR_PAGESIZE'])
    page = int(os.environ['BATCHPAR_page'])
    db = os.environ["BATCHPAR_db"]
    s3_path = os.environ["BATCHPAR_outinfo"]

    data = defaultdict(list)
    # Get all projects on this page
    projects = read_xml_from_url(TOP_URL, p=page, s=PAGE_SIZE)
    for project in projects.getchildren():
        # Extract the data for the project into 'row',
        # then recursively extract data from nested rows into the parent 'row'
        _, row = extract_data(project)
        extract_data_recursive(project, row)
        # Flatten out any list data directly into 'data'
        unpack_list_data(row, data)
        # Append the row
        data[row.pop('entity')].append(row)

    # Much of the participant data is repeated so remove overlaps
    if 'participant' in data:
        deduplicate_participants(data)

    # Finally, extract links between entities and the core projects
    extract_link_table(data)

    objs = []
    for table_name, rows in data.items():
        _class = get_class_by_tablename(Base, f"gtr_{table_name}")
        # Remove any fields that aren't in the ORM
        cleaned_rows = [{k: v for k, v in row.items() if k in _class.__dict__}
                        for row in rows]
        objs += insert_data("BATCHPAR_config", "mysqldb", db,
                            Base, _class, cleaned_rows)

    # Mark the task as done
    if s3_path != "":
        s3 = boto3.resource('s3')
        s3_obj = s3.Object(*parse_s3_path(s3_path))
        s3_obj.put(Body="")
    return len(objs)
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    iso2 = os.environ["BATCHPAR_iso2"]
    name = os.environ["BATCHPAR_name"]
    category = os.environ["BATCHPAR_cat"]
    coords = literal_eval(os.environ["BATCHPAR_coords"])
    radius = float(os.environ["BATCHPAR_radius"])
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Get the data
    mcg = MeetupCountryGroups(country_code=iso2, category=category,
                              coords=coords, radius=radius)
    mcg.get_groups_recursive()
    output = flatten_data(mcg.groups,
                          country_name=name,
                          country=iso2,
                          timestamp=func.utc_timestamp(),
                          keys=[('category', 'name'),
                                ('category', 'shortname'),
                                ('category', 'id'),
                                'description', 'created', 'country',
                                'city', 'id', 'lat', 'lon',
                                'members', 'name', 'topics', 'urlname'])

    # Add the data
    objs = insert_data("BATCHPAR_config", "mysqldb", db,
                       Base, Group, output)

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    # Mainly for testing
    return len(objs)
def run(self):
    # Load the input data (note the input contains the path
    # to the output)
    _filename = self.cherry_picked
    if _filename is None:
        _body = self.input().open("rb")
        _filename = _body.read().decode('utf-8')
    obj = s3.S3Target(f"{self.raw_s3_path_prefix}/"
                      f"{_filename}").open('rb')
    data = json.load(obj)

    # Get DB connections and settings
    database = 'dev' if self.test else 'production'
    engine = get_mysql_engine(self.db_conf_env, 'mysqldb', database)
    ArticleTopic.__table__.drop(engine)
    CorExTopic.__table__.drop(engine)

    # Insert the topic names data
    topics = [{'id': int(topic_name.split('_')[-1]) + 1, 'terms': terms}
              for topic_name, terms in data['data']['topic_names'].items()]
    insert_data(self.db_conf_env, 'mysqldb', database,
                Base, CorExTopic, topics, low_memory=True)
    logging.info(f'Inserted {len(topics)} topics')

    # Insert article topic weight data
    topic_articles = []
    done_ids = set()
    for row in data['data']['rows']:
        article_id = row.pop('id')
        if article_id in done_ids:
            continue
        done_ids.add(article_id)
        topic_articles += [{'topic_id': int(topic_name.split('_')[-1]) + 1,
                            'topic_weight': weight,
                            'article_id': article_id}
                           for topic_name, weight in row.items()]
        # Flush
        if len(topic_articles) > self.insert_batch_size:
            insert_data(self.db_conf_env, 'mysqldb', database, Base,
                        ArticleTopic, topic_articles, low_memory=True)
            topic_articles = []

    # Final flush
    if len(topic_articles) > 0:
        insert_data(self.db_conf_env, 'mysqldb', database, Base,
                    ArticleTopic, topic_articles, low_memory=True)

    # Touch the output
    self.output().touch()
def run(self): """Collect and process organizations, categories and long descriptions.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) # collect files nrows = 200 if self.test else None cat_groups, orgs, org_descriptions = get_files_from_tar( ['category_groups', 'organizations', 'organization_descriptions'], nrows=nrows) # process category_groups cat_groups = rename_uuid_columns(cat_groups) insert_data(self.db_config_env, 'mysqldb', database, Base, CategoryGroup, cat_groups.to_dict(orient='records'), low_memory=True) # process organizations and categories with db_session(self.engine) as session: existing_orgs = session.query(Organization.id).all() existing_orgs = {org[0] for org in existing_orgs} logging.info("Summary of organisation data:") logging.info(f"Total number of organisations:\t {len(orgs)}") logging.info( f"Number of organisations already in the database:\t {len(existing_orgs)}" ) logging.info(f"Number of category groups and text descriptions:\t" f"{len(cat_groups)}, {len(org_descriptions)}") processed_orgs, org_cats, missing_cat_groups = process_orgs( orgs, existing_orgs, cat_groups, org_descriptions) # Insert CatGroups insert_data(self.db_config_env, 'mysqldb', database, Base, CategoryGroup, missing_cat_groups) # Insert orgs in batches n_batches = round(len(processed_orgs) / self.insert_batch_size) logging.info( f"Inserting {n_batches} batches of size {self.insert_batch_size}") for i, batch in enumerate( split_batches(processed_orgs, self.insert_batch_size)): if i % 100 == 0: logging.info(f"Inserting batch {i} of {n_batches}") insert_data(self.db_config_env, 'mysqldb', database, Base, Organization, batch, low_memory=True) # link table needs to be inserted via non-bulk method to enforce relationship logging.info("Filtering duplicates...") org_cats, existing_org_cats, failed_org_cats = filter_out_duplicates( self.db_config_env, 'mysqldb', database, Base, OrganizationCategory, org_cats, low_memory=True) logging.info( f"Inserting {len(org_cats)} org categories " f"({len(existing_org_cats)} already existed and {len(failed_org_cats)} failed)" ) #org_cats = [OrganizationCategory(**org_cat) for org_cat in org_cats] with db_session(self.engine) as session: session.add_all(org_cats) # mark as done self.output().touch()
def run():
    db_name = os.environ["BATCHPAR_db_name"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_cursor = int(os.environ["BATCHPAR_start_cursor"])
    end_cursor = int(os.environ["BATCHPAR_end_cursor"])
    batch_size = end_cursor - start_cursor
    logging.warning(f"Retrieving {batch_size} articles between "
                    f"{start_cursor - 1}:{end_cursor - 1}")

    # Setup the database connectors
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # load arxiv subject categories to database
    bucket = 'innovation-mapping-general'
    cat_file = 'arxiv_classification/arxiv_subject_classifications.csv'
    load_arxiv_categories("BATCHPAR_config", db_name, bucket, cat_file)

    # process data
    articles = []
    article_cats = []
    resumption_token = request_token()
    for row in retrieve_arxiv_batch_rows(start_cursor, end_cursor, resumption_token):
        with db_session(engine) as session:
            categories = row.pop('categories', [])
            articles.append(row)
            for cat in categories:
                # TODO: this is inefficient and should be queried once into a set;
                # see the iterative process.
                try:
                    session.query(Category).filter(Category.id == cat).one()
                except NoResultFound:
                    logging.warning(f"missing category: '{cat}' for article "
                                    f"{row['id']}. Adding to Category table")
                    session.add(Category(id=cat))
                article_cats.append(dict(article_id=row['id'], category_id=cat))

    inserted_articles, existing_articles, failed_articles = insert_data(
        "BATCHPAR_config", "mysqldb", db_name, Base, Article,
        articles, return_non_inserted=True)

    logging.warning(f"total article categories: {len(article_cats)}")
    inserted_article_cats, existing_article_cats, failed_article_cats = insert_data(
        "BATCHPAR_config", "mysqldb", db_name, Base, ArticleCategory,
        article_cats, return_non_inserted=True)

    # sanity checks before the batch is marked as done
    logging.warning((f'inserted articles: {len(inserted_articles)} ',
                     f'existing articles: {len(existing_articles)} ',
                     f'failed articles: {len(failed_articles)}'))
    logging.warning((f'inserted article categories: {len(inserted_article_cats)} ',
                     f'existing article categories: {len(existing_article_cats)} ',
                     f'failed article categories: {len(failed_article_cats)}'))
    if len(inserted_articles) + len(existing_articles) + len(failed_articles) != batch_size:
        raise ValueError('Inserted articles do not match original data.')
    if (len(inserted_article_cats) + len(existing_article_cats)
            + len(failed_article_cats) != len(article_cats)):
        raise ValueError('Inserted article categories do not match original data.')

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")
def run(self):
    # s3 setup
    s3 = boto3.resource('s3')
    intermediate_file = s3.Object(BUCKET, INTERMEDIATE_FILE)

    # database setup
    database = 'dev' if self.test else 'production'
    logging.info(f"Using {database} database")
    self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
    Base.metadata.create_all(self.engine)

    eu = get_eu_countries()
    logging.info(f"Retrieved {len(eu)} EU countries")

    with db_session(self.engine) as session:
        all_fos_ids = {f.id for f in session.query(FieldOfStudy.id).all()}
        logging.info(f"{len(all_fos_ids):,} fields of study in database")

        eu_grid_ids = {i.id for i in (session
                                      .query(Institute.id)
                                      .filter(Institute.country.in_(eu))
                                      .all())}
    logging.info(f"{len(eu_grid_ids):,} EU institutes in GRID")

    try:
        processed_grid_ids = set(json.loads(intermediate_file
                                            .get()['Body']
                                            ._raw_stream.read()))
        logging.info(f"{len(processed_grid_ids)} previously processed institutes")
        eu_grid_ids = eu_grid_ids - processed_grid_ids
        logging.info(f"{len(eu_grid_ids):,} institutes to process")
    except ClientError:
        logging.info("Unable to load file of processed institutes, starting from scratch")
        processed_grid_ids = set()

    if self.test:
        self.batch_size = 500
        batch_limit = 1
    else:
        batch_limit = None

    testing_finished = False
    row_count = 0
    for institute_count, grid_id in enumerate(eu_grid_ids):
        paper_ids, author_ids = set(), set()
        data = {Paper: [],
                Author: [],
                PaperAuthor: set(),
                PaperFieldsOfStudy: set(),
                PaperLanguage: set()}

        if not institute_count % 50:
            logging.info(f"{institute_count:,} of {len(eu_grid_ids):,} institutes processed")

        if not check_institute_exists(grid_id):
            logging.debug(f"{grid_id} not found in MAG")
            continue

        # these tables have data stored in sets for deduping so the fieldnames will
        # need to be added when converting to a list of dicts for loading to the db
        field_names_to_add = {PaperAuthor: ('paper_id', 'author_id'),
                              PaperFieldsOfStudy: ('paper_id', 'field_of_study_id'),
                              PaperLanguage: ('paper_id', 'language')}

        logging.info(f"Querying MAG for {grid_id}")
        for row in query_by_grid_id(grid_id,
                                    from_date=self.from_date,
                                    min_citations=self.min_citations,
                                    batch_size=self.batch_size,
                                    batch_limit=batch_limit):
            fos_id = row['fieldOfStudyId']
            if fos_id not in all_fos_ids:
                logging.info(f"Getting missing field of study {fos_id} from MAG")
                update_field_of_study_ids_sparql(self.engine, fos_ids=[fos_id])
                all_fos_ids.add(fos_id)

            # the returned data is normalised and therefore contains many duplicates
            paper_id = row['paperId']
            if paper_id not in paper_ids:
                data[Paper].append({'id': paper_id,
                                    'title': row['paperTitle'],
                                    'citation_count': row['paperCitationCount'],
                                    'created_date': row['paperCreatedDate'],
                                    'doi': row.get('paperDoi'),
                                    'book_title': row.get('bookTitle')})
                paper_ids.add(paper_id)

            author_id = row['authorId']
            if author_id not in author_ids:
                data[Author].append({'id': author_id,
                                     'name': row['authorName'],
                                     'grid_id': grid_id})
                author_ids.add(author_id)

            data[PaperAuthor].add((row['paperId'], row['authorId']))
            data[PaperFieldsOfStudy].add((row['paperId'], row['fieldOfStudyId']))
            try:
                data[PaperLanguage].add((row['paperId'], row['paperLanguage']))
            except KeyError:
                # language is an optional field
                pass

            row_count += 1
            if self.test and row_count >= 1000:
                logging.warning("Breaking after 1000 rows in test mode")
                testing_finished = True
                break

        # write out to SQL
        for table, rows in data.items():
            if table in field_names_to_add:
                rows = [{k: v for k, v in zip(field_names_to_add[table], row)}
                        for row in rows]
            logging.debug(f"Writing {len(rows):,} rows to {table.__table__.name}")
            for batch in split_batches(rows, self.insert_batch_size):
                insert_data('MYSQLDB', 'mysqldb', database, Base, table, batch)

        # flag institute as completed on S3
        processed_grid_ids.add(grid_id)
        intermediate_file.put(Body=json.dumps(list(processed_grid_ids)))

        if testing_finished:
            break

    # mark as done
    logging.info("Task complete")
    self.output().touch()