class AMSLServiceDeprecated(AMSLTask):
    """
    Defunct task via #14415 as of 2018-12-12. Will be removed soon.

    Retrieve AMSL API response. Outbound: holdingsfiles, contentfiles, metadata_usage.

    2018-12-12: discovery API EOL, XXX: adjust, refs #14415.

    Example output (discovery):

        [
            {
                "shardLabel": "SLUB-dbod",
                "sourceID": "64",
                "megaCollection": "Perinorm – Datenbank für Normen und technische Regeln",
                "productISIL": null,
                "externalLinkToContentFile": null,
                "contentFileLabel": null,
                "contentFileURI": null,
                "linkToContentFile": null,
                "ISIL": "DE-105",
                "evaluateHoldingsFileForLibrary": "no",
                "holdingsFileLabel": null,
                "holdingsFileURI": null,
                "linkToHoldingsFile": null
            },
            {
                "shardLabel": "SLUB-dbod",
                "sourceID": "64",
                "megaCollection": "Perinorm – Datenbank für Normen und technische Regeln",
                "productISIL": null,
                "externalLinkToContentFile": null,
                "contentFileLabel": null,
                "contentFileURI": null,
                "linkToContentFile": null,
                "ISIL": "DE-14",
                "evaluateHoldingsFileForLibrary": "no",
                "holdingsFileLabel": null,
                "holdingsFileURI": null,
                "linkToHoldingsFile": null
            },
            ...
    """

    date = luigi.DateParameter(default=datetime.date.today())
    name = luigi.Parameter(
        default='outboundservices:discovery',
        description='discovery, holdingsfiles, contentfiles, metadata_usage, freeContent')

    def run(self):
        parts = self.name.split(':')
        if not len(parts) == 2:
            raise RuntimeError('realm:name expected, e.g. outboundservices:discovery')
        realm, name = parts

        link = '%s/%s/list?do=%s' % (self.config.get('amsl', 'base').rstrip('/'), realm, name)
        output = shellout("""curl --fail "{link}" | pigz -c > {output} """, link=link)

        # Check for valid JSON before moving the file; simplifies debugging.
        with gzip.open(output, 'rb') as handle:
            try:
                _ = json.load(handle)
            except ValueError as err:
                self.logger.warning("AMSL API did not return valid JSON: %s", err)
                raise

        luigi.LocalTarget(output).move(self.output().path)

    def output(self):
        return luigi.LocalTarget(path=self.path(digest=True, ext='json.gz'), format=Gzip)
class CrunchbaseSql2EsTask(autobatch.AutoBatchTask):
    '''Launches batch tasks to pipe Crunchbase organisation data from MySQL to Elasticsearch.

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        db_config_env (str): The output database environment variable
        process_batch_size (int): Number of rows to process in a batch
        insert_batch_size (int): Number of rows to insert into the db in a batch
        intermediate_bucket (str): S3 bucket where the list of ids for each batch are written
        drop_and_recreate (bool): If in test mode, drop and recreate the ES index?
    '''
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    db_config_env = luigi.Parameter()
    process_batch_size = luigi.IntParameter(default=10000)
    insert_batch_size = luigi.IntParameter()
    intermediate_bucket = luigi.Parameter()
    drop_and_recreate = luigi.BoolParameter(default=False)

    def requires(self):
        yield DescriptionMeshTask(date=self.date,
                                  _routine_id=self._routine_id,
                                  test=self.test,
                                  insert_batch_size=self.insert_batch_size,
                                  db_config_path=self.db_config_path,
                                  db_config_env=self.db_config_env)

    def output(self):
        '''Points to the output database engine'''
        self.db_config_path = os.environ[self.db_config_env]
        db_config = get_config(self.db_config_path, "mysqldb")
        db_config["database"] = 'dev' if self.test else 'production'
        db_config["table"] = "Crunchbase to Elasticsearch <dummy>"  # Note, not a real table
        update_id = "CrunchbaseToElasticsearch_{}".format(self.date)
        return MySqlTarget(update_id=update_id, **db_config)

    def prepare(self):
        if self.test:
            self.process_batch_size = 1000
            logging.warning(f"Batch size restricted to {self.process_batch_size} while in test mode")

        # MySQL setup
        self.database = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_config_env, 'mysqldb', self.database)

        # Elasticsearch setup
        es_mode = 'dev' if self.test else 'prod'
        es, es_config = setup_es(es_mode, self.test, self.drop_and_recreate,
                                 dataset='crunchbase',
                                 aliases='health_scanner')

        # Get set of existing ids from elasticsearch via scroll
        scanner = scan(es, query={"_source": False},
                       index=es_config['index'],
                       doc_type=es_config['type'])
        existing_ids = {s['_id'] for s in scanner}
        logging.info(f"Collected {len(existing_ids)} existing in Elasticsearch")

        # Get set of all organisations from mysql
        all_orgs = list(all_org_ids(engine))
        logging.info(f"{len(all_orgs)} organisations in MySQL")

        # Remove previously processed
        orgs_to_process = list(org for org in all_orgs if org not in existing_ids)
        logging.info(f"{len(orgs_to_process)} to be processed")

        job_params = []
        for count, batch in enumerate(split_batches(orgs_to_process,
                                                    self.process_batch_size), 1):
            logging.info(f"Processing batch {count} with size {len(batch)}")
            # write batch of ids to s3
            batch_file = put_s3_batch(batch, self.intermediate_bucket, 'crunchbase_to_es')
            params = {"batch_file": batch_file,
                      "config": 'mysqldb.config',
                      "db_name": self.database,
                      "bucket": self.intermediate_bucket,
                      "done": False,
                      'outinfo': es_config['host'],
                      'out_port': es_config['port'],
                      'out_index': es_config['index'],
                      'out_type': es_config['type'],
                      'aws_auth_region': es_config['region'],
                      'entity_type': 'company',
                      "test": self.test}
            logging.info(params)
            job_params.append(params)
            if self.test and count > 1:
                logging.warning("Breaking after 2 batches while in test mode.")
                break

        logging.warning(f"Batch preparation completed, with {len(job_params)} batches")
        return job_params

    def combine(self, job_params):
        '''Touch the checkpoint'''
        self.output().touch()
class CommonDateTask(luigi.Task): d = luigi.DateParameter() def output(self): return MockTarget(self.d.strftime('/n2000y01a05n/%Y_%m-_-%daww/21mm01dara21/ooo'))
class Blah(RunOnceTask): date = luigi.DateParameter() blah_arg = luigi.IntParameter()
def testDateWithInterval(self): p = luigi.DateParameter(config_path=dict(section="foo", name="bar"), interval=3, start=datetime.date(2001, 2, 1)) self.assertEqual(datetime.date(2001, 2, 1), _value(p))
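# For context on what the interval/start pair in the test above does, a minimal sketch of the
# rounding behaviour, assuming luigi's DateParameter applies it in normalize() (dates illustrative):
import datetime
import luigi

# interval=3 snaps dates down to the nearest 3-day step counted from `start`.
p = luigi.DateParameter(interval=3, start=datetime.date(2001, 2, 1))
print(p.normalize(datetime.date(2001, 2, 3)))  # 2001-02-01 (2 days past start -> rounded down)
print(p.normalize(datetime.date(2001, 2, 4)))  # 2001-02-04 (exactly one 3-day step)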
class DailyProcessFromCybersourceTask(PullFromCybersourceTaskMixin, luigi.Task): """ A task that reads a local file generated from a daily Cybersource pull, and writes to a TSV file. The output file should be readable by Hive, and be in a common format across other payment accounts. """ run_date = luigi.DateParameter( default=datetime.date.today(), description='Date to fetch Cybersource report. Default is today.', ) output_root = luigi.Parameter( description='URL of location to write output.', ) def requires(self): args = { 'run_date': self.run_date, 'output_root': self.output_root, 'overwrite': self.overwrite, 'merchant_id': self.merchant_id, } return DailyPullFromCybersourceTask(**args) def run(self): # Read from input and reformat for output. self.remove_output_on_overwrite() with self.input().open('r') as input_file: # Skip the first line, which provides information about the source # of the file. The second line should define the column headings. _download_header = input_file.readline() reader = csv.DictReader(input_file, delimiter=',') with self.output().open('w') as output_file: for row in reader: # Output most of the fields from the original source. # The values not included are: # batch_id: CyberSource batch in which the transaction was sent. # payment_processor: code for organization that processes the payment. result = [ # Date row['batch_date'], # Name of system. 'cybersource', # CyberSource merchant ID used for the transaction. row['merchant_id'], # Merchant-generated order reference or tracking number. # For shoppingcart or otto, this should equal order_id, # though sometimes it is basket_id. row['merchant_ref_number'], # ISO currency code used for the transaction. row['currency'], row['amount'], # Transaction fee '\\N', TRANSACTION_TYPE_MAP[row['transaction_type']], # We currently only process credit card transactions with Cybersource 'credit_card', # Type of credit card used row['payment_method'].lower().replace(' ', '_'), # Identifier for the transaction. row['request_id'], ] output_file.write('\t'.join(result)) output_file.write('\n') def output(self): """ Output is set up so it can be read in as a Hive table with partitions. The form is {output_root}/payments/dt={CCYY-mm-dd}/cybersource_{merchant}.tsv """ date_string = self.run_date.strftime('%Y-%m-%d') # pylint: disable=no-member partition_path_spec = HivePartition('dt', date_string).path_spec filename = "cybersource_{}.tsv".format(self.merchant_id) url_with_filename = url_path_join(self.output_root, "payments", partition_path_spec, filename) return get_target_from_url(url_with_filename)
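# To make the partition layout described in output() concrete, a small illustration with invented
# values, assuming HivePartition('dt', date).path_spec renders as 'dt=<date>':
output_root = 's3://example-warehouse'      # hypothetical
merchant_id = 'example_merchant'            # hypothetical
partition_path_spec = 'dt=2016-03-01'
print('/'.join([output_root, 'payments', partition_path_spec,
                'cybersource_{}.tsv'.format(merchant_id)]))
# s3://example-warehouse/payments/dt=2016-03-01/cybersource_example_merchant.tsv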
class GridTask(luigi.Task): """Join arxiv articles with GRID data for institute addresses and geocoding. Args: date (datetime): Datetime used to label the outputs _routine_id (str): String used to label the AWS task db_config_env (str): environmental variable pointing to the db config file db_config_path (str): The output database configuration mag_config_path (str): Microsoft Academic Graph Api key configuration path insert_batch_size (int): number of records to insert into the database at once (not used in this task but passed down to others) articles_from_date (str): new and updated articles from this date will be retrieved. Must be in YYYY-MM-DD format (not used in this task but passed down to others) """ date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter(default=True) db_config_env = luigi.Parameter() db_config_path = luigi.Parameter() mag_config_path = luigi.Parameter() insert_batch_size = luigi.IntParameter(default=500) articles_from_date = luigi.Parameter() def output(self): '''Points to the output database engine''' db_config = misctools.get_config(self.db_config_path, "mysqldb") db_config["database"] = 'dev' if self.test else 'production' db_config["table"] = "arXlive <dummy>" # Note, not a real table update_id = "ArxivGrid_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def requires(self): yield MagSparqlTask(date=self.date, _routine_id=self._routine_id, db_config_path=self.db_config_path, db_config_env=self.db_config_env, mag_config_path=self.mag_config_path, test=self.test, articles_from_date=self.articles_from_date, insert_batch_size=self.insert_batch_size) def run(self): # database setup database = 'dev' if self.test else 'production' logging.info(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) Base.metadata.create_all(self.engine) article_institute_batcher = BatchWriter(self.insert_batch_size, add_article_institutes, self.engine) match_attempted_batcher = BatchWriter(self.insert_batch_size, update_existing_articles, self.engine) fuzzer = ComboFuzzer([fuzz.token_sort_ratio, fuzz.partial_ratio], store_history=True) # extract lookup of GRID institute names to ids - seems to be OK to hold in memory institute_name_id_lookup = grid_name_lookup(self.engine) with db_session(self.engine) as session: # used to check GRID ids from MAG are valid (they are not all...) 
all_grid_ids = {i.id for i in session.query(Institute.id).all()} logging.info(f"{len(all_grid_ids)} institutes in GRID") article_query = (session.query( Article.id, Article.mag_authors).filter( Article.institute_match_attempted.is_(False) & ~Article.institutes.any() & Article.mag_authors.isnot(None))) total = article_query.count() logging.info( f"Total articles with authors and no institutes links: {total}" ) logging.debug("Starting the matching process") articles = article_query.all() for count, article in enumerate(articles, start=1): article_institute_links = [] for author in article.mag_authors: # prevent duplicates when a mixture of institute aliases are used in the same article existing_article_institute_ids = { link['institute_id'] for link in article_institute_links } # extract and validate grid_id try: extracted_grid_id = author['affiliation_grid_id'] except KeyError: pass else: # check grid id is valid if (extracted_grid_id in all_grid_ids and extracted_grid_id not in existing_article_institute_ids): links = create_article_institute_links( article_id=article.id, institute_ids=[extracted_grid_id], score=1) article_institute_links.extend(links) logging.debug(f"Used grid_id: {extracted_grid_id}") continue # extract author affiliation try: affiliation = author['author_affiliation'] except KeyError: # no grid id or affiliation for this author logging.debug(f"No affiliation found in: {author}") continue # look for an exact match on affiliation name try: institute_ids = institute_name_id_lookup[affiliation] except KeyError: pass else: institute_ids = set( institute_ids) - existing_article_institute_ids links = create_article_institute_links( article_id=article.id, institute_ids=institute_ids, score=1) article_institute_links.extend(links) logging.debug(f"Found an exact match for: {affiliation}") continue # fuzzy matching try: match, score = fuzzer.fuzzy_match_one( affiliation, institute_name_id_lookup.keys()) except KeyError: # failed fuzzy match logging.debug(f"Failed fuzzy match: {affiliation}") else: institute_ids = institute_name_id_lookup[match] institute_ids = set( institute_ids) - existing_article_institute_ids links = create_article_institute_links( article_id=article.id, institute_ids=institute_ids, score=score) article_institute_links.extend(links) logging.debug( f"Found a fuzzy match: {affiliation} {score} {match}" ) # add links for this article to the batch queue article_institute_batcher.extend(article_institute_links) # mark that matching has been attempted for this article match_attempted_batcher.append( dict(id=article.id, institute_match_attempted=True)) if not count % 100: logging.info( f"{count} processed articles from {total} : {(count / total) * 100:.1f}%" ) if self.test and count == 50: logging.warning("Exiting after 50 articles in test mode") logging.debug(article_institute_batcher) break # pick up any left over in the batches if article_institute_batcher: article_institute_batcher.write() if match_attempted_batcher: match_attempted_batcher.write() logging.info("All articles processed") logging.info( f"Total successful fuzzy matches for institute names: {len(fuzzer.successful_fuzzy_matches)}" ) logging.info( f"Total failed fuzzy matches for institute names{len(fuzzer.failed_fuzzy_matches): }" ) # mark as done logging.info("Task complete") self.output().touch()
class MeshJoinTask(luigi.Task): '''Joins MeSH labels stored in S3 to NIH projects in MySQL. Args: date (str): Date used to label the outputs _routine_id (str): String used to label the AWS task db_config_env (str): Environment variable for path to MySQL database configuration. ''' date = luigi.DateParameter() _routine_id = luigi.Parameter() db_config_env = luigi.Parameter() test = luigi.BoolParameter() @staticmethod def format_mesh_terms(df): """ Removes unrequired columns and pivots the mesh terms data into a dictionary. Args: df (dataframe): mesh terms as returned from retrieve_mesh_terms Returns: (dict): document_id: list of mesh terms """ logging.info("Formatting mesh terms") # remove PRC rows df = df.drop(df[df.term == 'PRC'].index, axis=0) # remove invalid error rows df = df.drop(df[df.doc_id.astype(str).str.contains('ERROR.*ERROR', na=False)].index, axis=0) df['term_id'] = df['term_id'].apply(lambda x: int(x[1:])) # pivot and remove unrequired columns doc_terms = { doc_id: {'terms': grouped.term.values, 'ids': grouped.term_id.values} for doc_id, grouped in df.groupby("doc_id")} return doc_terms @staticmethod def get_abstract_file_keys(bucket, key_prefix): s3 = boto3.resource('s3') s3bucket = s3.Bucket(bucket) return {o.key for o in s3bucket.objects.filter(Prefix=key_prefix)} def output(self): db_config = get_config(os.environ[self.db_config_env], "mysqldb") db_config['database'] = 'dev' if self.test else 'production' db_config['table'] = "MeshTerms <dummy>" update_id = "NihJoinMeshTerms_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def run(self): db = 'production' if not self.test else 'dev' keys = self.get_abstract_file_keys(bucket, key_prefix) engine = get_mysql_engine(self.db_config_env, 'mysqldb', db) with db_session(engine) as session: if self.test: existing_projects = set() projects = session.query(Projects.application_id).distinct() for p in projects: existing_projects.update(int(p.application_id)) projects_done = set() projects_mesh = session.query(ProjectMeshTerms.project_id).distinct() for p in projects_mesh: projects_done.update(int(p.project_id)) mesh_term_ids = {int(m.id) for m in session.query(MeshTerms.id).all()} logging.info('Inserting associations') for key_count, key in enumerate(keys): if self.test and (key_count > 2): continue # collect mesh results from s3 file and groups by project id # each project id has set of mesh terms and corresponding term ids df_mesh = retrieve_mesh_terms(bucket, key) project_terms = self.format_mesh_terms(df_mesh) # go through documents for project_count, (project_id, terms) in enumerate(project_terms.items()): rows = [] if self.test and (project_count > 2): continue if (project_id in projects_done) or (project_id not in existing_projects): continue for term, term_id in zip(terms['terms'], terms['ids']): term_id = int(term_id) # add term to mesh term table if not present if term_id not in mesh_term_ids: objs = insert_data( self.db_config_env, 'mysqldb', db, Base, MeshTerms, [{'id': term_id, 'term': term}], low_memory=True) mesh_term_ids.update({term_id}) # prepare row to be added to project-mesh_term link table rows.append({'project_id': project_id, 'mesh_term_id': term_id}) # inesrt rows to link table insert_data(self.db_config_env, 'mysqldb', db, Base, ProjectMeshTerms, rows, low_memory=True) self.output().touch() # populate project-mesh_term link table
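# A small, hypothetical illustration of what the format_mesh_terms staticmethod above does to the
# dataframe returned by retrieve_mesh_terms (the sample rows are invented):
import pandas as pd

df = pd.DataFrame({
    'doc_id':  [1001, 1001, 1002],
    'term':    ['Neoplasms', 'Genomics', 'PRC'],   # 'PRC' rows are dropped
    'term_id': ['D009369', 'D023281', 'D000000'],
})
doc_terms = MeshJoinTask.format_mesh_terms(df)
# Document 1002 only had a 'PRC' row, so it disappears; term ids lose the leading letter:
# {1001: {'terms': array(['Neoplasms', 'Genomics'], ...), 'ids': array([9369, 23281])}}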
class EndSong(luigi.Task): date = luigi.DateParameter()
class InputText(luigi.ExternalTask): date = luigi.DateParameter() def output(self): print("Task {} is running.".format(self.__class__.__name__)) return luigi.LocalTarget(self.date.strftime('/tmp/text/%Y-%m-%d.txt'))
class Sql2EsTask(autobatch.AutoBatchTask): '''Launches batch tasks to pipe data from MySQL to Elasticsearch. Args: date (datetime): Datetime used to label the outputs. routine_id (str): String used to label the AWS task. intermediate_bucket (str): Name of the S3 bucket where to store the batch ids. db_config_env (str): The output database envariable. process_batch_size (int): Number of rows to process in a batch. drop_and_recreate (bool): If in test mode, drop and recreate the ES index? dataset (str): Name of the elasticsearch dataset. id_field (SqlAlchemy selectable attribute): The ID field attribute. entity_type (str): Name of the entity type to label this task with. kwargs (dict): Any other job parameters to pass to the batchable. ''' date = luigi.DateParameter() routine_id = luigi.Parameter() intermediate_bucket = luigi.Parameter() db_config_env = luigi.Parameter() db_section = luigi.Parameter(default="mysqldb") process_batch_size = luigi.IntParameter(default=10000) drop_and_recreate = luigi.BoolParameter(default=False) aliases = luigi.Parameter(default=None) dataset = luigi.Parameter() id_field = luigi.Parameter() entity_type = luigi.Parameter() kwargs = luigi.DictParameter(default={}) def output(self): '''Points to the output database engine''' self.db_config_path = os.environ[self.db_config_env] db_config = get_config(self.db_config_path, "mysqldb") db_config["database"] = 'dev' if self.test else 'production' db_config["table"] = f"{self.routine_id} <dummy>" # Not a real table update_id = f"{self.routine_id}_{self.date}" return MySqlTarget(update_id=update_id, **db_config) def prepare(self): if self.test: self.process_batch_size = 1000 logging.warning("Batch size restricted to " f"{self.process_batch_size}" " while in test mode") # MySQL setup database = 'dev' if self.test else 'production' engine = get_mysql_engine(self.db_config_env, self.db_section, database) # Elasticsearch setup es_mode = 'dev' if self.test else 'prod' es, es_config = setup_es(es_mode, self.test, self.drop_and_recreate, dataset=self.dataset, aliases=self.aliases) # Get set of existing ids from elasticsearch via scroll existing_ids = get_es_ids(es, es_config) logging.info(f"Collected {len(existing_ids)} existing in " "Elasticsearch") # Get set of all organisations from mysql with db_session(engine) as session: result = session.query(self.id_field).all() all_ids = {r[0] for r in result} logging.info(f"{len(all_ids)} organisations in MySQL") # Remove previously processed ids_to_process = (org for org in all_ids if org not in existing_ids) job_params = [] for count, batch in enumerate( split_batches(ids_to_process, self.process_batch_size), 1): # write batch of ids to s3 batch_file = put_s3_batch(batch, self.intermediate_bucket, self.routine_id) params = { "batch_file": batch_file, "config": 'mysqldb.config', "db_name": database, "bucket": self.intermediate_bucket, "done": False, 'outinfo': es_config['host'], 'out_port': es_config['port'], 'out_index': es_config['index'], 'out_type': es_config['type'], 'aws_auth_region': es_config['region'], 'entity_type': self.entity_type, 'test': self.test, 'routine_id': self.routine_id } params.update(self.kwargs) logging.info(params) job_params.append(params) if self.test and count > 1: logging.warning("Breaking after 2 batches while in " "test mode.") logging.warning(job_params) break logging.warning("Batch preparation completed, " f"with {len(job_params)} batches") return job_params def combine(self, job_params): '''Touch the checkpoint''' self.output().touch()
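# prepare() above leans on split_batches to chunk the outstanding ids before each chunk is written
# to S3 with put_s3_batch. The real helper lives in the project's utilities; the assumed behaviour
# is simply consecutive chunks of at most process_batch_size items, roughly:
def split_batches(iterable, batch_size):
    """Yield consecutive lists of at most batch_size items (sketch of the assumed helper)."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

print(list(split_batches(range(5), 2)))  # [[0, 1], [2, 3], [4]]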
def setUp(self): self.task = self.task_class( # pylint: disable=not-callable date=luigi.DateParameter().parse(self.DATE), overwrite_from_date=datetime.date(2014, 4, 1), )
class CreateCIKLookup(luigi.Task): date = luigi.DateParameter() def requires(self): return [PrepareEnv(dirs=""), GetCrawler(self.date)] def slices(self, s, *args): position = 0 for length in args: yield s[position:position + length] position += length def run(self): company_dict = {} main_json = {'root': []} forms = {} forminfo = {} crawl_file = open(output_folder + "/crawler_{}.idx".format(self.date), "r") for line in crawl_file.readlines()[9:]: company, form, cik, dateoffile, url = self.slices( line, 62, 12, 12, 12, 86) if form.strip() not in forminfo: forminfo[form.strip()] = 1 else: forminfo[form.strip()] = forminfo[form.strip()] + 1 f = { 'form': form.strip(), 'date': dateoffile.strip(), 'url': url.strip() } if cik.strip() not in forms: forms[cik.strip()] = [f] else: forms[cik.strip()].append(f) main_json.get('root').append(forms) with open(output_folder + "/form_counts_{}.json".format(self.date), 'w') as fm: ujson.dump(forminfo, fm, indent=4) with open(lookups_folder + "/company_{}.json".format(self.date), 'w') as lkp: ujson.dump(company_dict, lkp, indent=4) with open(lookups_folder + "/cik_url_{}.json".format(self.date), 'w') as urllkp: ujson.dump(main_json, urllkp, indent=4) def output(self): return [ luigi.LocalTarget(path=lookups_folder + "/company_{}.json".format(self.date)), luigi.LocalTarget(path=output_folder + "/form_counts_{}.json".format(self.date)), luigi.LocalTarget(path=lookups_folder + "/cik_url_{}.json".format(self.date)) ]
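# A self-contained illustration of the fixed-width parsing done by slices() on one invented
# crawler.idx row, assuming the 62/12/12/12/86 column widths used in run():
import datetime

line = ('EXAMPLE HOLDINGS INC'.ljust(62) + '10-K'.ljust(12) + '1234567'.ljust(12)
        + '2017-02-01'.ljust(12) + 'edgar/data/1234567/0001234567-17-000001-index.htm')
task = CreateCIKLookup(date=datetime.date(2017, 2, 1))
company, form, cik, dateoffile, url = task.slices(line, 62, 12, 12, 12, 86)
print(company.strip(), form.strip(), cik.strip(), dateoffile.strip(), url.strip())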
class AMSLFilterConfig(AMSLTask): """ Turn AMSL API to a span(1) filter configuration. Cases (Spring 2017): 61084 Case 3 (ISIL, SID, Collection, Holding File) 805 Case 2 (ISIL, SID, Collection, Product ISIL) 380 Case 1 (ISIL, SID, Collection) 38 Case 4 (ISIL, SID, Collection, External Content File) 31 Case 6 (ISIL, SID, Collection, External Content File, Holding File) 1 Case 7 (ISIL, SID, Collection, Internal Content File, Holding File) 1 Case 5 (ISIL, SID, Collection, Internal Content File) Process: AMSL Discovery API | v AMSLFilterConfig | v $ span-tag -c config.json < input.is > output.is Notes: This task turns an AMSL discovery API response into a filterconfig[1], which span(1) can understand. AMSL API might not specify everything we need to know, so this task shall be the only place, where workarounds happen. While span-tag is fast, it is not fast enough to iterate over a disjuction of 60K items for each of the 100M documents fast enough, which - if we could - would simplify the implementation of this task. The main speed improvement comes from using lists of collection names instead of having each collection processed separately - which is how it works conceptually: Each collection could use a separate KBART file (or use none at all). We ignore collection names, if (external) content files are used. These content files are usually there, because we cannot infer the collection name from the data alone. Performance data point: 22 ISIL each with between 1 and 26 alternatives for attachment, each alternative consisting of around three filters. Around 30 holding or content files each with between 10 and 50000 entries referenced about 200 times in total: around 20k records/s. Case table (Feb 2019), X/-/o, yes, no, maybe. SID COLL ISIL LTHF LTCF ELTCF PI TCID ------------------------------------- X X X - - - - o X X X X - - - o X X X X X - - o X X X X - X - o X X X X - - X o X X X - X - - o X X X - - X - o X X X - - - X o For a transition period, we extend the collection list from AMSL with canoncical names via DOI prefix, see #13587. The mapping is fixed and need to be updated manually. ---- [1] https://git.io/vQohE """ date = luigi.DateParameter(default=datetime.date.today()) def extend_collections(self, colls): """ Given a list of collection names, extend the list by adding the canonicals names as well, refs #13587. """ if not hasattr(self, '_name_to_canonical'): with open(self.assets("amsl/13587.json")) as handle: self._name_to_canonical = json.load(handle) result = set() for c in colls: result.add(c) if c in self._name_to_canonical: result.add(self._name_to_canonical[c]) self.logger.debug( "extended collection list from %d to %d items (%d mappings)" % (len(colls), len(result), len(self._name_to_canonical))) return list(result) def requires(self): return { 'amsl': AMSLService(date=self.date), 'wiso': AMSLWisoPackages(date=self.date), } def run(self): with self.input().get('amsl').open() as handle: doc = json.loads(handle.read()) # Case: ISIL, SID, collection. isilsidcollections = collections.defaultdict( lambda: collections.defaultdict(set)) # Case: ISIL, SID, collection, link. isilsidlinkcollections = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(set ))) # Ready-made filters per ISIL. Some filters can be added on-the-fly # because there aren't many occurences. 
isilfilters = collections.defaultdict(list) for item in doc: isil, sid, mega_collection, technicalCollectionID = operator.itemgetter( 'ISIL', 'sourceID', 'megaCollection', 'technicalCollectionID')(item) if sid == '48': # Handled elsewhere. continue # refs #10495, a subject filter for a few hard-coded ISIL. if sid == '34' and isil in ('DE-L152', 'DE-1156', 'DE-1972', 'DE-Kn38'): isilfilters[isil].append({ "and": [ { "source": ["34"], }, { "subject": [ "Music", "Music education", ] }, ] }) continue # refs #10495, maybe use a TSV with custom column name to use a subject list? if sid == '34' and isil == 'DE-15-FID': isilfilters[isil].append({ "and": [ { "source": ["34"], }, { "subject": [ "Film studies", "Information science", "Mass communication", ] }, ] }) continue # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X - - - - o if dictcheck(item, contains=['sourceID', 'megaCollection', 'ISIL'], absent=[ 'linkToHoldingsFile', 'linkToContentFile', 'externalLinkToContentFile', 'productISIL' ], ignore=['technicalCollectionID']): isilsidcollections[isil][sid].add(mega_collection) if technicalCollectionID: isilsidcollections[isil][sid].add(technicalCollectionID) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X - - - X o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'productISIL' ], absent=[ 'linkToHoldingsFile', 'linkToContentFile', 'externalLinkToContentFile' ], ignore=['technicalCollectionID']): isilsidcollections[isil][sid].add(mega_collection) self.logger.debug("productISIL given, but ignored: %s, %s, %s", isil, sid, item['productISIL']) if technicalCollectionID: isilsidcollections[isil][sid].add(technicalCollectionID) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X X - - X o elif dictcheck( item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'linkToHoldingsFile', 'productISIL' ], absent=['linkToContentFile', 'externalLinkToContentFile'], ignore=['technicalCollectionID']): self.logger.debug( "productISIL is set, but we do not have a filter for it yet: %s, %s, %s", isil, sid, mega_collection) if item.get('evaluateHoldingsFileForLibrary') == "yes": isilsidlinkcollections[isil][sid][ item['linkToHoldingsFile']].add(mega_collection) if technicalCollectionID: isilsidlinkcollections[isil][sid][item[ 'linkToHoldingsFile']].add(technicalCollectionID) else: self.logger.warning( "evaluateHoldingsFileForLibrary=no plus link: skipping %s", item) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X X - - - o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'linkToHoldingsFile' ], absent=[ 'linkToContentFile', 'externalLinkToContentFile', 'productISIL' ], ignore=['technicalCollectionID']): if item.get('evaluateHoldingsFileForLibrary') == "yes": isilsidlinkcollections[isil][sid][ item['linkToHoldingsFile']].add(mega_collection) if technicalCollectionID: isilsidlinkcollections[isil][sid][item[ 'linkToHoldingsFile']].add(technicalCollectionID) else: self.logger.warning( "evaluateHoldingsFileForLibrary=no plus link: skipping %s", item) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X - - X - o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'externalLinkToContentFile' ], absent=[ 'linkToHoldingsFile', 'linkToContentFile', 'productISIL' ], ignore=['technicalCollectionID']): isilfilters[isil].append({ "and": [ { "source": [sid] }, { "holdings": { "urls": 
[item["externalLinkToContentFile"]] } }, ] }) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X - X - - o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'linkToContentFile' ], absent=[ 'linkToHoldingsFile', 'externalLinkToContentFile', 'productISIL' ], ignore=['technicalCollectionID']): isilfilters[isil].append({ "and": [ { "source": [sid] }, { "holdings": { "urls": [item["linkToContentFile"]] } }, ] }) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X X - X - o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'linkToHoldingsFile', 'externalLinkToContentFile' ], absent=['linkToContentFile', 'productISIL'], ignore=['technicalCollectionID']): if item.get('evaluateHoldingsFileForLibrary') == "yes": isilfilters[isil].append({ "and": [ { "source": [sid] }, { "holdings": { "urls": [item["externalLinkToContentFile"]] } }, { "holdings": { "urls": [item["linkToHoldingsFile"]] } }, ] }) else: self.logger.warning( "evaluateHoldingsFileForLibrary=no plus link: skipping %s", item) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # X X X X X - - o elif dictcheck(item, contains=[ 'sourceID', 'megaCollection', 'ISIL', 'linkToHoldingsFile', 'linkToContentFile' ], absent=['externalLinkToContentFile', 'productISIL'], ignore=['technicalCollectionID']): if item.get('evaluateHoldingsFileForLibrary') == "yes": isilfilters[isil].append({ "and": [ { "source": [sid] }, { "holdings": { "urls": [item["linkToContentFile"]] } }, { "holdings": { "urls": [item["linkToHoldingsFile"]] } }, ] }) else: self.logger.warning( "evaluateHoldingsFileForLibrary=no plus link: skipping %s", item) # SID COLL ISIL LTHF LTCF ELTCF PI TCID # ------------------------------------- # ? ? ? ? ? ? ? o else: raise RuntimeError( "unhandled combination of sid, collection and other parameters: %s", item) # A second pass. for isil, blob in list(isilsidcollections.items()): for sid, colls in list(blob.items()): isilfilters[isil].append({ "and": [ { "source": [sid] }, { "collection": sorted(self.extend_collections(colls)) }, ] }) # A second pass. for isil, blob in list(isilsidlinkcollections.items()): for sid, spec in list(blob.items()): for link, colls in list(spec.items()): isilfilters[isil].append({ "and": [ { "source": [sid] }, { "collection": sorted(self.extend_collections(colls)) }, { "holdings": { "urls": [link] } }, ] }) # Final assembly. filterconfig = collections.defaultdict(dict) for isil, filters in list(isilfilters.items()): if len(filters) == 0: continue if len(filters) == 1: filterconfig[isil] = filters[0] continue filterconfig[isil] = {"or": filters} # Include WISO. with self.input().get('wiso').open() as handle: wisoconf = json.load(handle) for isil, tree in list(wisoconf.items()): for filter in tree.get('or', []): filterconfig[isil]['or'].append(filter) # XXX: Adjust a few items for DE-14, cf. 2018-06-11, namely, add links # to external holding files, which are not included into the AMSL # discovery API response, refs #13378. fix_url = 'https://dbod.de/SLUB-EZB-KBART.zip' for term in filterconfig["DE-14"]["or"]: for t in term["and"]: if (not 'holdings' in t) and (not 'urls' in t.get( 'holdings', [])): continue if fix_url in t['holdings']['urls']: continue t['holdings']['urls'].append( "https://dbod.de/SLUB-EZB-KBART.zip") with self.output().open('w') as output: json.dump(filterconfig, output, cls=SetEncoder) def output(self): return luigi.LocalTarget(path=self.path(ext='json'))
class IntervalPullFromCybersourceTask(PullFromCybersourceTaskMixin, WarehouseMixin, luigi.WrapperTask): """Determines a set of dates to pull, and requires them.""" interval_end = luigi.DateParameter( default=datetime.datetime.utcnow().date(), significant=False, description='Default is today, UTC.', ) # Overwrite parameter definition to make it optional. output_root = luigi.Parameter( default=None, description='URL of location to write output.', ) def __init__(self, *args, **kwargs): super(IntervalPullFromCybersourceTask, self).__init__(*args, **kwargs) # Provide default for output_root at this level. if self.output_root is None: self.output_root = self.warehouse_path path = url_path_join(self.warehouse_path, 'payments') file_pattern = '*cybersource_{}.tsv'.format(self.merchant_id) path_targets = PathSetTask([path], include=[file_pattern], include_zero_length=True).output() paths = list( set([os.path.dirname(target.path) for target in path_targets])) dates = [path.rsplit('/', 2)[-1] for path in paths] latest_date = sorted(dates)[-1] latest_completion_date = datetime.datetime.strptime( latest_date, "dt=%Y-%m-%d").date() run_date = latest_completion_date + datetime.timedelta(days=1) # Limit intervals to merchant account close date(if any). if self.merchant_close_date: run_date = min(run_date, self.merchant_close_date) self.interval_end = min(self.interval_end, self.merchant_close_date) self.selection_interval = date_interval.Custom(self.interval_start, run_date) self.run_interval = date_interval.Custom(run_date, self.interval_end) def requires(self): """Internal method to actually calculate required tasks once.""" yield PathSelectionByDateIntervalTask( source=[url_path_join(self.warehouse_path, 'payments')], interval=self.selection_interval, pattern=[ '.*dt=(?P<date>\\d{{4}}-\\d{{2}}-\\d{{2}})/cybersource_{}\\.tsv' .format(self.merchant_id) ], expand_interval=datetime.timedelta(0), date_pattern='%Y-%m-%d', ) for run_date in self.run_interval: yield DailyProcessFromCybersourceTask( merchant_id=self.merchant_id, output_root=self.output_root, run_date=run_date, overwrite=self.overwrite, ) def output(self): return [task.output() for task in self.requires()]
class GenerateW3ACTTitleExport(luigi.Task): task_namespace = 'discovery' date = luigi.DateParameter(default=datetime.date.today()) record_count = 0 blocked_record_count = 0 missing_record_count = 0 embargoed_record_count = 0 target_count = 0 collection_count = 0 collection_published_count = 0 subject_count = 0 subject_published_count = 0 def requires(self): return [TargetList(self.date), CollectionList(self.date), SubjectList(self.date)] def output(self): logger.warning('in output') return state_file(self.date,'access-data', 'title-level-metadata-w3act.xml') def run(self): # Get the data: targets = json.load(self.input()[0].open()) self.target_count = len(targets) collections = json.load(self.input()[1].open()) self.collection_count = len(collections) subjects = json.load(self.input()[2].open()) self.subject_count = len(subjects) # Index collections by ID: collections_by_id = {} for col in collections: collections_by_id[int(col['id'])] = col if col['publish']: self.collection_published_count += 1 # Index subjects by ID: subjects_by_id = {} for sub in subjects: subjects_by_id[int(sub['id'])] = sub if sub['publish']: self.subject_published_count += 1 # Convert to records: records = [] for target in targets: # Skip blocked items: if target['crawl_frequency'] == 'NEVERCRAWL': logger.warning("The Target '%s' is blocked (NEVERCRAWL)." % target['title']) self.blocked_record_count += 1 continue # Skip items that have no crawl permission? # hasOpenAccessLicense == False, and inScopeForLegalDeposit == False ? # Skip items with no URLs: if len(target.get('urls',[])) == 0: logger.warning("Skipping %s" % target.get('title', 'NO TITLE')) continue # Get the url, use the first: url = target['urls'][0] # Extract the domain: parsed_url = tldextract.extract(url) publisher = parsed_url.registered_domain # Lookup in CDX: wayback_date_str = CdxIndex().get_first_capture_date(url) # Get date in '20130401120000' form. if wayback_date_str is None: logger.warning("The URL '%s' is not yet available, inScopeForLegalDeposit = %s" % (url, target['isNPLD'])) self.missing_record_count += 1 continue wayback_date = datetime.datetime.strptime(wayback_date_str, '%Y%m%d%H%M%S') first_date = wayback_date.isoformat() # Honour embargo ago = datetime.datetime.now() - wayback_date if ago.days <= 7: self.embargoed_record_count += 1 continue #### Otherwise, build the record: record_id = "%s/%s" % (wayback_date_str, base64.b64encode(hashlib.md5(url.encode('utf-8')).digest())) title = target['title'] # set the rights and wayback_url depending on licence if target.get('isOA', False): rights = '***Free access' wayback_url = 'https://www.webarchive.org.uk/wayback/archive/' + wayback_date_str + '/' + url else: rights = '***Available only in our Reading Rooms' wayback_url = 'https://bl.ldls.org.uk/welcome.html?' 
+ wayback_date_str + '/' + url rec = { 'id': record_id, 'date': first_date, 'url': url, 'title': title, 'rights': rights, 'publisher': publisher, 'wayback_url': wayback_url } # Add any collection: if len(target['subject_ids']) > 0: sub0 = subjects_by_id.get(int(target['subject_ids'][0]), {}) rec['subject'] = sub0.get('name', None) # And append record to the set: records.append(rec) self.record_count += 1 # declare output XML namespaces OAINS = 'http://www.openarchives.org/OAI/2.0/' OAIDCNS = 'http://www.openarchives.org/OAI/2.0/oai_dc/' DCNS = 'http://purl.org/dc/elements/1.1/' XLINKNS = 'http://www.w3.org/1999/xlink' OAIDC_B = "{%s}" % OAIDCNS DC_B = "{%s}" % DCNS XLINK_B = "{%s}" % XLINKNS # create OAI-PMH XML via lxml oaiPmh = etree.Element('OAI-PMH', nsmap={None:OAINS, 'oai_dc':OAIDCNS, 'dc':DCNS, 'xlink':XLINKNS}) listRecords = etree.SubElement(oaiPmh, 'ListRecords') for rec in records: record = etree.SubElement(listRecords, 'record') # header header = etree.SubElement(record, 'header') identifier = etree.SubElement(header, 'identifier') identifier.text = rec['id'] # metadata metadata = etree.SubElement(record, 'metadata') dc = etree.SubElement(metadata, OAIDC_B+'dc') source = etree.SubElement(dc, DC_B+'source' ) source.text = rec['url'] publisher = etree.SubElement(dc, DC_B+'publisher' ) publisher.text = rec['publisher'] title = etree.SubElement(dc, DC_B+'title' ) title.text = rec['title'] date = etree.SubElement(dc, DC_B+'date' ) date.text = rec['date'] rights = etree.SubElement(dc, DC_B+'rights' ) rights.text = rec['rights'] href = etree.SubElement(dc, XLINK_B+'href' ) href.text = rec['wayback_url'] if 'subject' in rec: subject = etree.SubElement(dc, DC_B+'subject') subject.text = rec['subject'] # output OAI-PMH XML with self.output().open('w') as f: f.write(etree.tostring(oaiPmh, xml_declaration=True, encoding='UTF-8', pretty_print=True)) def get_metrics(self, registry): # type: (CollectorRegistry) -> None g = Gauge('ukwa_record_count', 'Total number of UKWA records.', labelnames=['kind', 'status'], registry=registry) g.labels(kind='targets', status='_any_').set(self.target_count) g.labels(kind='collections', status='_any_').set(self.collection_count) g.labels(kind='collections', status='published').set(self.collection_published_count) g.labels(kind='subjects', status='_any_').set(self.subject_count) g.labels(kind='title_level', status='complete').set(self.record_count) g.labels(kind='title_level', status='blocked').set(self.blocked_record_count) g.labels(kind='title_level', status='missing').set(self.missing_record_count) g.labels(kind='title_level', status='embargoed').set(self.embargoed_record_count)
class BuildInternalReportingUserActivityCombinedView(VerticaCopyTaskMixin, WarehouseMixin, luigi.Task): """luigi task to build the combined view on top of the history and production tables for user activity.""" date = luigi.DateParameter() n_reduce_tasks = luigi.Parameter() history_schema = luigi.Parameter(default='history') def requires(self): return {'insert_source': LoadInternalReportingUserActivityToWarehouse( n_reduce_tasks=self.n_reduce_tasks, date=self.date, warehouse_path=self.warehouse_path, overwrite=self.overwrite, schema=self.schema, credentials=self.credentials), 'credentials': ExternalURL(self.credentials)} @property def view(self): """The "table name" is the name of the view we build over the table we insert here and the history table.""" return "f_user_activity_combined" def update_id(self): """All that matters is whether we've built the view before, and the parameter information doesn't matter.""" return "user_activity_view_built" def run(self): """Construct the view on top of the historical and new user activity tables.""" connection = self.output().connect() try: cursor = connection.cursor() # We mark this task as complete first, since the view creation does an implicit commit. self.output().touch(connection) # Creating the view commits the transaction as well. build_view_query = """CREATE VIEW {schema}.{view} AS SELECT * FROM ( SELECT * FROM {schema}.f_user_activity UNION SELECT * FROM {history}.f_user_activity ) AS u""".format(schema=self.schema, view=self.view, history=self.history_schema) log.debug(build_view_query) cursor.execute(build_view_query) log.debug("Committed transaction.") except Exception as exc: log.debug("Rolled back the transaction; exception raised: %s", str(exc)) connection.rollback() raise finally: connection.close() def output(self): """ Returns a Vertica target noting that the update occurred. """ return CredentialFileVerticaTarget( credentials_target=self.input()['credentials'], table=self.view, schema=self.schema, update_id=self.update_id() ) def complete(self): """ OverwriteOutputMixin redefines the complete method so that tasks are re-run, which is great for the Vertica loading tasks where we would delete and then re-start, but detrimental here, as the existence of the view does not depend on the data inside the table, only on the table's existence. We override this method again to revert to the standard luigi complete() method, because we can't meaningfully re-run this task given that CREATE VIEW IF NOT EXISTS and DROP VIEW IF EXISTS are not supported in Vertica. """ return self.output().exists()
class GenerateIndexAnnotations(luigi.Task): """ Gets the annotations needed for full-text indexing. """ task_namespace = 'discovery' date = luigi.DateParameter(default=datetime.date.today()) def output(self): return state_file(self.date,'access-data', 'indexer-annotations.json') def requires(self): return [TargetList(), CollectionList(), SubjectList()] def add_annotations(self, annotations, collection, targets_by_id, prefix=""): # assemble full collection name: collection_name = "%s%s" % (prefix, collection['name']) # deal with all targets: for tid in collection.get('target_ids',[]): if tid not in targets_by_id: logger.error("Target %i not found in targets list!" % tid) continue target = targets_by_id[tid] scope = target['scope'] if scope is None or scope == '': logger.error("Scope not set for %s - %s!" % (tid, target['urls']) ) continue for url in target.get('urls',[]): ann = annotations['collections'][scope].get(url, {'collection': collection_name, 'collections': [], 'subject': []}) if collection_name not in ann['collections']: ann['collections'].append(collection_name) # And subjects: for sid in target['subject_ids']: subject_name = self.subjects_by_id[sid]['name'] if subject_name not in ann['subject']: ann['subject'].append(subject_name) # and patch back in: annotations['collections'][scope][url] = ann # And add date ranges: annotations['collectionDateRanges'][collection_name] = {} if collection['start_date']: annotations['collectionDateRanges'][collection_name]['start'] = collection['start_date'] else: annotations['collectionDateRanges'][collection_name]['start'] = None if collection['end_date']: annotations['collectionDateRanges'][collection_name]['end'] = collection['end_date'] else: annotations['collectionDateRanges'][collection_name]['end'] = None # And process child collections: for child_collection in collection['children']: self.add_annotations(annotations, child_collection, targets_by_id, prefix="%s|" % collection_name) def run(self): targets = json.load(self.input()[0].open()) collections = json.load(self.input()[1].open()) subjects = json.load(self.input()[2].open()) # build look-up table for Target IDs targets_by_id = {} target_count = 0 for target in targets: tid = target['id'] targets_by_id[tid] = target target_count += 1 logger.info("Found %i targets..." % target_count) # build look-up table for subjects self.subjects_by_id = {} for top_level_subject in subjects: self.subjects_by_id[top_level_subject['id']] = top_level_subject for child_subject in top_level_subject['children']: self.subjects_by_id[child_subject['id']] = child_subject # Assemble the annotations, keyed on scope + url: annotations = { "collections": { "subdomains": { }, "resource": { }, "root": { }, "plus1": { } }, "collectionDateRanges": { } } for collection in collections: self.add_annotations(annotations, collection, targets_by_id) with self.output().open('w') as f: f.write('{}'.format(json.dumps(annotations, indent=4)))
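# An invented example of the annotations structure written at the end of run(). Entries are keyed
# by crawl scope and then URL, and collection names are pipe-joined along the collection hierarchy:
annotations_example = {
    "collections": {
        "root": {
            "http://example.org/": {
                "collection": "Top Collection|Child Collection",
                "collections": ["Top Collection|Child Collection"],
                "subject": ["Example Subject"],
            }
        },
        "subdomains": {}, "resource": {}, "plus1": {},
    },
    "collectionDateRanges": {
        "Top Collection|Child Collection": {"start": "2015-01-01", "end": None},
    },
}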
class DailyPullFromCybersourceTask(PullFromCybersourceTaskMixin, luigi.Task): """ A task that reads out of a remote Cybersource account and writes to a file. A complication is that this needs to be performed with more than one account (or merchant_id), with potentially different credentials. If possible, create the same credentials (username, password) for each account. Pulls are made for only a single day. This is what Cybersource supports for these reports, and it allows runs to performed incrementally on a daily tempo. """ # Date to fetch Cybersource report. run_date = luigi.DateParameter( default=datetime.date.today(), description='Default is today.', ) # This is the table that we had been using for gathering and # storing historical Cybersource data. It adds one additional # column over the 'PaymentBatchDetailReport' format. REPORT_NAME = 'PaymentSubmissionDetailReport' REPORT_FORMAT = 'csv' def requires(self): pass def run(self): self.remove_output_on_overwrite() auth = (self.username, self.password) response = requests.get(self.query_url, auth=auth) if response.status_code != requests.codes.ok: # pylint: disable=no-member msg = "Encountered status {} on request to Cybersource for {}".format( response.status_code, self.run_date) raise Exception(msg) with self.output().open('w') as output_file: output_file.write(response.content) def output(self): """Output is in the form {output_root}/cybersource/{CCYY-mm}/cybersource_{merchant}_{CCYYmmdd}.csv""" month_year_string = self.run_date.strftime('%Y-%m') # pylint: disable=no-member date_string = self.run_date.strftime('%Y%m%d') # pylint: disable=no-member filename = "cybersource_{merchant_id}_{date_string}.{report_format}".format( merchant_id=self.merchant_id, date_string=date_string, report_format=self.REPORT_FORMAT, ) url_with_filename = url_path_join(self.output_root, "cybersource", month_year_string, filename) return get_target_from_url(url_with_filename) @property def query_url(self): """Generate the url to download a report from a Cybersource account.""" slashified_date = self.run_date.strftime('%Y/%m/%d') # pylint: disable=no-member url = 'https://{host}/DownloadReport/{date}/{merchant_id}/{report_name}.{report_format}'.format( host=self.host, date=slashified_date, merchant_id=self.merchant_id, report_name=self.REPORT_NAME, report_format=self.REPORT_FORMAT) return url
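# An example of the URL that query_url assembles; host and merchant id are placeholders:
url = 'https://{host}/DownloadReport/{date}/{merchant_id}/{report_name}.{report_format}'.format(
    host='example.cybersource.com',
    date='2016/03/01',
    merchant_id='example_merchant',
    report_name='PaymentSubmissionDetailReport',
    report_format='csv')
# https://example.cybersource.com/DownloadReport/2016/03/01/example_merchant/PaymentSubmissionDetailReport.csv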
class SnowflakeLoadTask(SnowflakeLoadDownstreamMixin, luigi.Task): """ A task for copying data into a Snowflake database table. """ date = luigi.DateParameter() output_target = None required_tasks = None def requires(self): if self.required_tasks is None: self.required_tasks = { 'credentials': ExternalURL(url=self.credentials), 'insert_source_task': self.insert_source_task, } return self.required_tasks @property def insert_source_task(self): """ Defines the task that provides source of data. """ raise NotImplementedError @property def table(self): """ Provides the name of the database table. """ raise NotImplementedError @property def columns(self): """ Provides definition of columns. If only writing to existing tables, then columns() need only provide a list of names. If also needing to create the table, then columns() should define a list of (name, definition) tuples. For example, ('first_name', 'VARCHAR(255)'). """ raise NotImplementedError @property def file_format_name(self): raise NotImplementedError @property def pattern(self): """ Files matching this pattern will be used in the COPY operation. """ return ".*" @property def table_description(self): """ Description of table containing various facts, such as import time and excluded fields. """ return '' @property def qualified_stage_name(self): """ Fully qualified stage name. """ return "{database}.{schema}.{table}_stage".format( database=self.sf_database, schema=self.schema, table=self.table, ) @property def qualified_table_name(self): """ Fully qualified table name. """ return qualified_table_name( database=self.sf_database, schema=self.schema, table=self.table, ) @property def qualified_scratch_table_name(self): """ Fully qualified scratch table name. """ return "{database}.{scratch_schema}.{table}_{run_id}".format( database=self.sf_database, scratch_schema=self.scratch_schema, table=self.table, run_id=self.run_id, ) def create_scratch_table(self, connection): coldefs = ','.join( '{name} {definition}'.format(name=name, definition=definition) for name, definition in self.columns) query = "CREATE TABLE {scratch_table} ({coldefs}) COMMENT='{comment}'".format( scratch_table=self.qualified_scratch_table_name, coldefs=coldefs, comment=self.table_description.replace("'", "\\'")) _execute_query(connection, query) def create_format(self, connection): """ Invoke Snowflake's CREATE FILE FORMAT statement to create the named file format which configures the loading. 
The resulting file format name should be: {self.sf_database}.{self.schema}.{self.file_format_name} """ raise NotImplementedError def create_stage(self, connection): stage_url = canonicalize_s3_url( self.input()['insert_source_task'].path) query = """ CREATE OR REPLACE STAGE {stage_name} URL = '{stage_url}' CREDENTIALS = (AWS_KEY_ID='{aws_key_id}' AWS_SECRET_KEY='{aws_secret_key}') FILE_FORMAT = {database}.{schema}.{file_format_name}; """.format( stage_name=self.qualified_stage_name, database=self.sf_database, schema=self.schema, stage_url=stage_url, aws_key_id=self.output().aws_key_id, aws_secret_key=self.output().aws_secret_key, file_format_name=self.file_format_name, ) _execute_query(connection, query) def init_copy(self, connection): self.attempted_removal = True if self.overwrite: # Delete all markers related to this table self.output().clear_marker_table(connection) def copy(self, connection): query = """ COPY INTO {scratch_table} FROM @{stage_name} PATTERN='{pattern}' """.format( scratch_table=self.qualified_scratch_table_name, stage_name=self.qualified_stage_name, pattern=self.pattern, ) log.debug(query) _execute_query(connection, query) def swap(self, connection): query = """ ALTER TABLE {scratch_table} SWAP WITH {table} """.format( scratch_table=self.qualified_scratch_table_name, table=self.qualified_table_name, ) log.debug(query) try: _execute_query(connection, query) except ProgrammingError as err: if "does not exist" in str(err): # Since the table did not exist in the target schema, simply move it instead of swapping. query = """ ALTER TABLE {scratch_table} RENAME TO {table} """.format( scratch_table=self.qualified_scratch_table_name, table=self.qualified_table_name, ) _execute_query(connection, query) else: raise def drop_scratch(self, connection): query = """ DROP TABLE IF EXISTS {scratch_table} """.format(scratch_table=self.qualified_scratch_table_name) log.debug(query) _execute_query(connection, query) def run(self): connection = self.output().connect() try: cursor = connection.cursor() self.create_scratch_table(connection) self.create_format(connection) self.create_stage(connection) cursor.execute("BEGIN") self.init_copy(connection) self.copy(connection) self.swap(connection) self.drop_scratch(connection) self.output().touch(connection) connection.commit() except Exception as exc: log.exception("Rolled back the transaction; exception raised: %s", str(exc)) connection.rollback() raise finally: connection.close() def output(self): if self.output_target is None: self.output_target = SnowflakeTarget( credentials_target=self.input()['credentials'], database=self.sf_database, schema=self.schema, scratch_schema=self.scratch_schema, run_id=self.run_id, table=self.table, role=self.role, warehouse=self.warehouse, update_id=self.update_id(), ) return self.output_target def update_id(self): return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat())
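# A hypothetical subclass sketch showing the pieces a concrete load task has to supply
# (insert_source_task, table, columns, file_format_name, create_format). The table name, columns
# and file-format options below are invented, not part of the pipeline:
class LoadExampleEventsToSnowflake(SnowflakeLoadTask):  # hypothetical

    @property
    def insert_source_task(self):
        # Any task whose output is an S3 prefix of TSV files.
        return ExternalURL(url='s3://example-bucket/example_events/dt=2019-01-01/')

    @property
    def table(self):
        return 'example_events'

    @property
    def columns(self):
        # (name, definition) tuples because the scratch table is created by this task.
        return [
            ('event_id', 'VARCHAR(255)'),
            ('emitted_at', 'TIMESTAMP_NTZ'),
            ('event_count', 'INTEGER'),
        ]

    @property
    def file_format_name(self):
        return 'example_events_tsv_format'

    def create_format(self, connection):
        # Minimal tab-separated file format; the options are assumptions.
        query = """
        CREATE OR REPLACE FILE FORMAT {database}.{schema}.{name}
        TYPE = 'CSV' FIELD_DELIMITER = '\\t' COMPRESSION = 'AUTO'
        """.format(database=self.sf_database, schema=self.schema, name=self.file_format_name)
        _execute_query(connection, query)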
class BigQueryLoadTask(BigQueryLoadDownstreamMixin, luigi.Task): # Regardless whether loading only a partition or an entire table, # we still need a date to use to mark the table. date = luigi.DateParameter() output_target = None required_tasks = None def requires(self): if self.required_tasks is None: self.required_tasks = { 'credentials': ExternalURL(url=self.credentials), 'source': self.insert_source_task, } return self.required_tasks @property def insert_source_task(self): raise NotImplementedError @property def table(self): raise NotImplementedError @property def schema(self): raise NotImplementedError @property def table_description(self): return '' @property def table_friendly_name(self): return '' @property def partitioning_type(self): """Set to 'DAY' in order to partition by day. Default is to not partition at all.""" return None @property def field_delimiter(self): return "\t" @property def null_marker(self): return '\N' @property def quote_character(self): return '' def create_dataset(self, client): dataset = client.dataset(self.dataset_id) if not dataset.exists(): dataset.create() def create_table(self, client): dataset = client.dataset(self.dataset_id) table = dataset.table(self.table, self.schema) if not table.exists(): if self.partitioning_type: table.partitioning_type = self.partitioning_type if self.table_description: table.description = self.table_description if self.table_friendly_name: table.friendly_name = self.table_friendly_name table.create() def init_copy(self, client): self.attempted_removal = True if self.overwrite: dataset = client.dataset(self.dataset_id) table = dataset.table(self.table) if self.partitioning_type: # Delete only the specific partition, and clear the marker only for the partition. # Note that there is no partition.exists() functionality that is useful. Instead, # it returns table.exists(). Likewise, partition.delete() is a no-op if the actual # partition doesn't exist. So if the table exists, we just try to delete the partition. if table.exists(): partition = self._get_table_partition(dataset, table) partition.delete() self.output().clear_marker_table_entry() else: # Delete the entire table and all markers related to the table. 
if table.exists(): table.delete() self.output().clear_marker_table() def _get_destination_from_source(self, source_path): parsed_url = urlparse.urlparse(source_path) destination_path = url_path_join('gs://{}'.format(parsed_url.netloc), parsed_url.path) return destination_path def _get_table_partition(self, dataset, table): date_string = self.date.isoformat() stripped_date = date_string.replace('-', '') partition_name = '{}${}'.format(table.name, stripped_date) return dataset.table(partition_name) def _copy_data_to_gs(self, source_path, destination_path): if self.is_file(source_path): return_code = subprocess.call( ['gsutil', 'cp', source_path, destination_path]) else: log.debug(" ".join( ['gsutil', '-m', 'rsync', source_path, destination_path])) return_code = subprocess.call( ['gsutil', '-m', 'rsync', source_path, destination_path]) if return_code != 0: raise RuntimeError( 'Error while syncing {source} to {destination}'.format( source=source_path, destination=destination_path, )) def _get_load_url_from_destination(self, destination_path): if self.is_file(destination_path): return destination_path else: return url_path_join(destination_path, '*') def _run_load_table_job(self, client, job_id, table, load_uri): job = client.load_table_from_storage(job_id, table, load_uri) job.field_delimiter = self.field_delimiter job.quote_character = self.quote_character job.null_marker = self.null_marker if self.max_bad_records > 0: job.max_bad_records = self.max_bad_records log.debug("Starting BigQuery Load job.") job.begin() wait_for_job(job, check_error_result=False) try: log.debug( " Load job started: %s ended: %s input_files: %s output_rows: %s output_bytes: %s", job.started, job.ended, job.input_files, job.output_rows, job.output_bytes) except KeyError as keyerr: log.debug(" Load job started: %s ended: %s No load stats.", job.started, job.ended) if job.error_result: for error in job.errors: log.debug(" Load error: %s", error) raise RuntimeError(job.errors) else: log.debug(" No errors encountered!") def run(self): self.check_bigquery_availability() client = self.output().client self.create_dataset(client) self.init_copy(client) self.create_table(client) dataset = client.dataset(self.dataset_id) table = dataset.table(self.table, self.schema) source_path = self.input()['source'].path destination_path = self._get_destination_from_source(source_path) self._copy_data_to_gs(source_path, destination_path) load_uri = self._get_load_url_from_destination(destination_path) if self.partitioning_type: partition = self._get_table_partition(dataset, table) partition.partitioning_type = self.partitioning_type job_id = 'load_{table}_{date_string}_{timestamp}'.format( table=self.table, date_string=self.date.isoformat(), timestamp=int(time.time())) self._run_load_table_job(client, job_id, partition, load_uri) else: job_id = 'load_{table}_{timestamp}'.format(table=self.table, timestamp=int( time.time())) self._run_load_table_job(client, job_id, table, load_uri) self.output().touch() def output(self): if self.output_target is None: self.output_target = BigQueryTarget( credentials_target=self.input()['credentials'], dataset_id=self.dataset_id, table=self.table, update_id=self.update_id(), ) return self.output_target def update_id(self): return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat()) def is_file(self, path): if path.endswith('.tsv') or path.endswith('.csv') or path.endswith( '.gz'): return True else: return False def check_bigquery_availability(self): """Call to ensure fast 
failure if this machine doesn't have the Bigquery libraries available.""" if not bigquery_available: raise ImportError('Bigquery library not available')
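# --- Hypothetical usage sketch (not part of the original pipeline) -----------
# BigQueryLoadTask above is abstract: insert_source_task, table and schema are
# the required overrides, with partitioning_type as an optional one.  The
# subclass below is invented for illustration and assumes the
# google-cloud-bigquery SchemaField class matching the client API the task
# uses (client.dataset(), dataset.table(name, schema), ...).
from google.cloud.bigquery import SchemaField


class ExampleEnrollmentBigQueryLoadTask(BigQueryLoadTask):
    """Hypothetical subclass loading a TSV directory into a day-partitioned table."""

    @property
    def insert_source_task(self):
        # Any task whose output().path is a .tsv/.csv/.gz file or a directory.
        return ExternalURL(url='s3://example-bucket/enrollments/dt=2019-01-01/')

    @property
    def table(self):
        return 'example_course_enrollment'

    @property
    def partitioning_type(self):
        return 'DAY'  # load only the partition matching self.date

    @property
    def schema(self):
        return [
            SchemaField('date', 'DATE'),
            SchemaField('course_id', 'STRING'),
            SchemaField('enrollment_count', 'INTEGER'),
        ]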
class CourseActivityWeeklyTask(CourseActivityTask): """ Number of users performing each category of activity each ISO week. Note that this was the original activity metric, so it is stored in the original table that is simply named "course_activity" even though it should probably be named "course_activity_weekly". Also the schema does not match the other activity tables for the same reason. All references to weeks in here refer to ISO weeks. Note that ISO weeks may belong to different ISO years than the Gregorian calendar year. If, for example, you wanted to analyze all data in the past week, you could run the job on Monday and pass in 1 to the "weeks" parameter. This will not analyze data for the week that contains the current day (since it is not complete). It will only compute data for the previous week. TODO: update table name and schema to be consistent with other tables. Parameters: end_date (date): A day within the upper bound week. The week that contains this date will *not* be included in the analysis, however, all of the data up to the first day of this week will be included. This is consistent with all of our existing closed-open intervals. weeks (int): The number of weeks to include in the analysis, counting back from the week that contains the end_date. """ end_date = luigi.DateParameter(default=datetime.datetime.utcnow().date()) weeks = luigi.IntParameter(default=24) @property def interval(self): """Given the parameters, compute the first and last date of the interval.""" if self.weeks == 0: raise ValueError( 'Number of weeks to process must be greater than 0') starting_week = self.get_iso_week_containing_date( self.end_date - datetime.timedelta(weeks=self.weeks)) ending_week = self.get_iso_week_containing_date(self.end_date) # include all complete weeks up to but not including the week containing the end_date return luigi.date_interval.Custom(starting_week.monday(), ending_week.monday()) def get_iso_week_containing_date(self, date): iso_year, iso_weekofyear, iso_weekday = date.isocalendar() return Week(iso_year, iso_weekofyear) @property def table(self): return 'course_activity' @property def activity_query(self): # Note that hive timestamp format is "yyyy-mm-dd HH:MM:SS.ffff" so we have to snap all of our dates to midnight return """ SELECT act.course_id as course_id, CONCAT(cal.iso_week_start, ' 00:00:00') as interval_start, CONCAT(cal.iso_week_end, ' 00:00:00') as interval_end, act.category as label, COUNT(DISTINCT username) as count FROM user_activity_daily act JOIN calendar cal ON act.date = cal.date WHERE "{interval_start}" <= cal.date AND cal.date < "{interval_end}" GROUP BY act.course_id, cal.iso_week_start, cal.iso_week_end, act.category; """ @property def columns(self): return [ ('course_id', 'VARCHAR(255) NOT NULL'), ('interval_start', 'DATETIME NOT NULL'), ('interval_end', 'DATETIME NOT NULL'), ('label', 'VARCHAR(255) NOT NULL'), ('count', 'INT(11) NOT NULL'), ] @property def indexes(self): return [('course_id', 'label'), ('interval_end', )]
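# --- Illustration of the ISO-week interval semantics above -------------------
# The Week class used by get_iso_week_containing_date is not defined in this
# excerpt; the pure-datetime sketch below reproduces the same closed-open
# interval computation, with example values chosen for illustration.
import datetime


def iso_week_monday(day):
    """Monday of the ISO week containing `day` (equivalent to Week(...).monday())."""
    return day - datetime.timedelta(days=day.isocalendar()[2] - 1)


example_end_date = datetime.date(2014, 1, 15)   # a Wednesday
example_weeks = 1
interval_start = iso_week_monday(example_end_date - datetime.timedelta(weeks=example_weeks))
interval_end = iso_week_monday(example_end_date)
# interval_start == datetime.date(2014, 1, 6)
# interval_end   == datetime.date(2014, 1, 13)
# i.e. exactly one complete ISO week is analysed and the incomplete week
# containing end_date is excluded, as the docstring describes.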
def testDate(self):
    p = luigi.DateParameter(config_path=dict(section="foo", name="bar"))
    self.assertEqual(datetime.date(2001, 2, 3), _value(p))
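# --- Context for the test above (hypothetical setup) -------------------------
# The assertion only holds if the loaded Luigi configuration provides a value
# for option "bar" in section [foo]; something like the following would need
# to run before _value(p) is evaluated.  The literal date string is an
# invented example matching the expected result.
import luigi.configuration

config = luigi.configuration.get_config()
if not config.has_section('foo'):
    config.add_section('foo')
config.set('foo', 'bar', '2001-02-03')
# DateParameter then parses the configured string into datetime.date(2001, 2, 3).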
class CourseActivityMonthlyTask(CourseActivityTask): """ Number of users performing each category of activity each calendar month. Note that the month containing the end_date is not included in the analysis. If, for example, you wanted to analyze all data in the past month, you could run the job on the first day of the following month pass in 1 to the "months" parameter. This will not analyze data for the month that contains the current day (since it is not complete). It will only compute data for the previous month. Parameters: end_date (date): A date within the month that will be the upper bound of the closed-open interval. months (int): The number of months to include in the analysis, counting back from the month that contains the end_date. """ end_date = luigi.DateParameter(default=datetime.datetime.utcnow().date()) months = luigi.IntParameter(default=6) @property def interval(self): """Given the parameters, compute the first and last date of the interval.""" from dateutil.relativedelta import relativedelta # We don't actually care about the particular day of the month in this computation since we are fixing both the # start and end dates to the first day of the month, so we can perform simple arithmetic with the numeric month # and only have to worry about adjusting the year. Note that bankers perform this arithmetic differently so it # is spelled out here explicitly even though their are third party libraries that contain this computation. if self.months == 0: raise ValueError( 'Number of months to process must be greater than 0') ending_date = self.end_date.replace(day=1) starting_date = ending_date - relativedelta(months=self.months) return luigi.date_interval.Custom(starting_date, ending_date) @property def table(self): return 'course_activity_monthly' @property def activity_query(self): return """ SELECT act.course_id as course_id, cal.year, cal.month, act.category as label, COUNT(DISTINCT username) as count FROM user_activity_daily act JOIN calendar cal ON act.date = cal.date WHERE "{interval_start}" <= cal.date AND cal.date < "{interval_end}" GROUP BY act.course_id, cal.year, cal.month, act.category; """ @property def columns(self): return [ ('course_id', 'VARCHAR(255) NOT NULL'), ('year', 'INT(11) NOT NULL'), ('month', 'INT(11) NOT NULL'), ('label', 'VARCHAR(255) NOT NULL'), ('count', 'INT(11) NOT NULL'), ] @property def indexes(self): return [('course_id', 'label'), ('year', 'month')]
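# --- Illustration of the closed-open month interval above --------------------
# Same arithmetic as the interval property, with invented example values.
import datetime
from dateutil.relativedelta import relativedelta

example_end_date = datetime.date(2014, 3, 15)
example_months = 2
ending_date = example_end_date.replace(day=1)                        # 2014-03-01
starting_date = ending_date - relativedelta(months=example_months)   # 2014-01-01
# The interval [2014-01-01, 2014-03-01) covers January and February in full and
# deliberately excludes the incomplete month containing end_date.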
class AnalysisTask(luigi.Task): """Extract and analyse arXiv data to produce data and charts for the arXlive front end to consume. Proposed charts: 1. distribution of dl/non dl papers by country (horizontal bar) 2. distribution of dl/non dl papers by city (horizontal bar) 3. % ML papers by year (line) 4. share of ML activity in arxiv subjects, pre/post 2012 (horizontal point / slope) 5. rca, pre/post 2012 by country (horizontal point / slope) 6. rca over time, citation > mean & top 50 countries (horizontal violin) [NOT DONE] Proposed table data: 1. top countries by rca (moving window of last 12 months?) [NOT DONE] Args: date (datetime): Datetime used to label the outputs _routine_id (str): String used to label the AWS task db_config_env (str): environmental variable pointing to the db config file db_config_path (str): The output database configuration mag_config_path (str): Microsoft Academic Graph Api key configuration path insert_batch_size (int): number of records to insert into the database at once (not used in this task but passed down to others) articles_from_date (str): new and updated articles from this date will be retrieved. Must be in YYYY-MM-DD format (not used in this task but passed down to others) """ date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter(default=True) db_config_env = luigi.Parameter() db_config_path = luigi.Parameter() mag_config_path = luigi.Parameter() insert_batch_size = luigi.IntParameter(default=500) articles_from_date = luigi.Parameter() s3_path_prefix = luigi.Parameter(default="s3://nesta-arxlive") raw_data_path = luigi.Parameter(default="raw-inputs") grid_task_kwargs = DictParameterPlus() cherry_picked = luigi.Parameter() def output(self): '''Points to the output database engine''' db_config = get_config(self.db_config_path, "mysqldb") db_config["database"] = 'dev' if self.test else 'production' db_config["table"] = "arXlive <dummy>" # NB: not a real table update_id = "ArxivAnalysis_{}_{}".format(self.date, self.test) return mysqldb.MySqlTarget(update_id=update_id, **db_config) def requires(self): s3_path_prefix = (f"{self.s3_path_prefix}/" f"automl/{self.date}") data_path = (f"{self.s3_path_prefix}/" f"{self.raw_data_path}/{self.date}") yield WriteTopicTask(raw_s3_path_prefix=self.s3_path_prefix, s3_path_prefix=s3_path_prefix, data_path=data_path, date=self.date, cherry_picked=self.cherry_picked, test=self.test, grid_task_kwargs=self.grid_task_kwargs) def run(self): # Threshold for testing year_threshold = 2008 if self.test else YEAR_THRESHOLD test_label = 'test' if self.test else '' # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) # All queries except last prepare temporary tables # and the final query produces the dataframe # which collects data, such that there is one row per # article / institute / institute country for query, is_last in sql_queries(): if not is_last: self.engine.execute(query) df = pd.read_sql(query, self.engine) logging.info(f"Dataset contains {len(df)} articles") # Manual hack to factor Hong Kong outside of China for city in [ "Hong Kong", "Tsuen Wan", "Tuen Mun", "Tai Po", "Sai Kung" ]: df.loc[df.institute_city == f"{city}, CN", "institute_country"] = "Hong Kong" # Manual hack to factor out transnational corps countries = set(df.institute_country) df['is_multinational'] = df['institute_name'].apply( lambda x: dc.is_multinational(x, countries)) 
df.loc[df.is_multinational, 'institute_city'] = df.loc[ df.is_multinational, 'institute_name'].apply(lambda x: ''.join(x.split("(")[:-1])) df.loc[df.is_multinational, 'institute_country'] = "Transnationals" # collect topics, determine which represents # deep_learning and apply flag terms = [ "deep", "deep_learning", "reinforcement", "neural_networks", "neural_network" ] min_weight = 0.1 if self.test else 0.3 dl_topic_ids = dc.get_article_ids_by_terms(self.engine, terms=terms, min_weight=min_weight) df['is_dl'] = df.article_id.apply(lambda i: i in dl_topic_ids) logging.info( f"Flagged {df.is_dl.sum()} deep learning articles in dataset") df['date'] = df.apply( lambda row: row.article_updated or row.article_created, axis=1) df['year'] = df.date.apply(lambda date: date.year) df = dc.add_before_date_flag(df, date_column='date', before_year=year_threshold) # first plot - dl/non dl distribution by country (top n) pivot_by_country = (pd.pivot_table( df.groupby(['institute_country', 'is_dl']).size().reset_index(drop=False), index='institute_country', columns='is_dl', values=0).apply(lambda x: 100 * (x / x.sum())).rename( columns={ True: 'DL', False: 'non DL' })) fig, ax = plt.subplots() (pivot_by_country.sort_values( 'DL', ascending=False)[:N_TOP].sort_values('DL').plot.barh( ax=ax, color=[COLOR_B, COLOR_A], width=0.6)) ax.set_xlabel( 'Percentage of DL papers in arXiv CompSci\ncategories, by country') ax.set_ylabel('') handles, labels = ax.get_legend_handles_labels() _ = ax.legend(labels=[labels[1], labels[0]], handles=[handles[1], handles[0]], title='') dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_1{test_label}.png', plt) # second plot - dl/non dl distribution by city (top n) pivot_by_city = (pd.pivot_table( df.groupby(['institute_city', 'is_dl']).size().reset_index(drop=False), index='institute_city', columns='is_dl', values=0).apply(lambda x: 100 * (x / x.sum())).rename( columns={ True: 'DL', False: 'non DL' })) fig, ax = plt.subplots() (pivot_by_city.sort_values( 'DL', ascending=False)[:N_TOP].sort_values('DL').plot.barh( ax=ax, color=[COLOR_B, COLOR_A], width=0.8)) ax.set_xlabel( 'Percentage of DL papers in arXiv CompSci\ncategories, by city or multinational' ) ax.set_ylabel('') handles, labels = ax.get_legend_handles_labels() ax.legend(labels=[labels[1], labels[0]], handles=[handles[1], handles[0]], title='') dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_2{test_label}.png', plt) # third plot - percentage of dl papers by year deduped = df.drop_duplicates('article_id') start_year = 2000 papers_by_year = pd.crosstab(deduped['year'], deduped['is_dl']).loc[start_year:] papers_by_year = (100 * papers_by_year.apply(lambda x: x / x.sum(), axis=1)) papers_by_year = papers_by_year.drop(False, axis=1) # drop non-dl column fig, ax = plt.subplots(figsize=(20, 8)) papers_by_year.plot(ax=ax, legend=None, color=COLOR_A, linewidth=10) plt.xlabel('\nYear of paper publication') plt.ylabel('Percentage of DL papers\nin arXiv CompSci categories\n') plt.xticks( np.arange(min(papers_by_year.index), max(papers_by_year.index) + 1, 1)) ax.set_xticklabels( ['' if i % 2 else y for i, y in enumerate(papers_by_year.index)]) dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_3{test_label}.png', plt) # fourth plot - share of DL activity by arxiv # subject pre/post threshold df_all_cats = pd.read_sql("SELECT * FROM arxiv_categories", self.engine) condition = (df_all_cats.id.str.startswith('cs.') | (df_all_cats.id.str == 'stat.ML')) all_categories = list(df_all_cats.loc[condition].description) _before = f'Before 
{year_threshold}' _after = f'After {year_threshold}' cat_period_container = [] for cat in all_categories: subset = df.loc[[cat in x for x in df['arxiv_category_descs']], :] subset_ct = pd.crosstab(subset[f'before_{year_threshold}'], subset.is_dl, normalize=0) # This is true for some categories in dev mode # due to a smaller dataset if list(subset_ct.index) != [False, True]: continue subset_ct.index = [_after, _before] # this try /except may not be required when # running on the full dataset try: cat_period_container.append( pd.Series(subset_ct[True], name=cat)) except KeyError: pass cat_thres_df = (pd.concat(cat_period_container, axis=1).T.sort_values(_after, ascending=False)) other = cat_thres_df[N_TOP:].mean().rename('Other') cat_thres_df = cat_thres_df[:N_TOP].append(other) fig, ax = plt.subplots() (100 * cat_thres_df[_before]).plot(markeredgecolor=COLOR_B, marker='o', color=COLOR_B, ax=ax, markerfacecolor=COLOR_B, linewidth=7.5) (100 * cat_thres_df[_after]).plot(markeredgecolor=COLOR_A, marker='o', color=COLOR_A, ax=ax, markerfacecolor=COLOR_A, linewidth=7.5) ax.vlines(np.arange(len(cat_thres_df)), ymin=len(cat_thres_df) * [0], ymax=100 * cat_thres_df[_after], linestyle=':', linewidth=2) ax.set_xticks(np.arange(len(cat_thres_df))) ax.set_xticklabels(cat_thres_df.index, rotation=40, ha='right') ax.set_ylabel('Percentage of DL papers,\n' 'by arXiv CompSci category') ax.legend() dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_4{test_label}.png', plt) # fifth chart - changes in specialisation before / after threshold (top n countries) dl_counts = df.groupby('institute_country')['is_dl'].count() # remove the bottom 10% of countries here top_countries = list( dl_counts.loc[dl_counts > dl_counts.quantile(0.25)].index) top_countries = df.institute_country.apply( lambda x: x in top_countries) # Only highly citated papers avg_citation_counts = df[['year', 'citation_count' ]].groupby('year').quantile(0.5) avg_citation_counts['citation_count'] = avg_citation_counts[ 'citation_count'].apply(lambda x: x if x > 0 else 1) highly_cited = map(lambda x: dc.highly_cited(x, avg_citation_counts), [row for _, row in df.iterrows()]) highly_cited = np.array(list(highly_cited)) if self.test: highly_cited = np.array([True] * len(df)) # Min year threshold min_year = (df.year >= MIN_RCA_YEAR if not self.test else df.year >= 2000) # Apply filters before calculating RCA top_df = df.loc[top_countries & highly_cited & min_year] logging.info(f'Got {len(top_df)} rows for RCA calculation.\n' 'Breakdown (ctry, cite, yr) = ' f'{sum(top_countries)}, ' f'{sum(highly_cited)}, {sum(min_year)}') before_year = top_df[f'before_{year_threshold}'] logging.info("Before is DL = " f"{sum(top_df.loc[before_year].is_dl)}") logging.info("After is DL = " f"{sum(top_df.loc[~before_year].is_dl)}") # Calculate revealed comparative advantage pre_threshold_rca = dc.calculate_rca_by_country( top_df[top_df[f'before_{year_threshold}']], country_column='institute_country', commodity_column='is_dl') post_threshold_rca = dc.calculate_rca_by_country( top_df[~top_df[f'before_{year_threshold}']], country_column='institute_country', commodity_column='is_dl') rca_combined = (pd.merge( pre_threshold_rca, post_threshold_rca, left_index=True, right_index=True, suffixes=('_before', '_after')).rename(columns={ 'is_dl_before': _before, 'is_dl_after': _after }).sort_values(_after, ascending=False)) top_dl_countries = list( top_df[['institute_country', 'is_dl']].groupby('institute_country').sum().sort_values( 'is_dl', ascending=False)[:N_TOP].index) condition 
= rca_combined.index.isin(top_dl_countries) rca_combined_top = rca_combined[condition] fig, ax = plt.subplots() rca_combined_top[_before].plot(markeredgecolor=COLOR_B, marker='o', markersize=20, color='white', ax=ax, markerfacecolor=COLOR_B, linewidth=0) rca_combined_top[_after].plot(markeredgecolor=COLOR_A, marker='o', markersize=20, color='white', ax=ax, markerfacecolor=COLOR_A, linewidth=0) col = [ COLOR_A if x > y else '#d18270' for x, y in zip( rca_combined_top[_after], rca_combined_top[_before]) ] ax.vlines(np.arange(len(rca_combined_top)), ymin=rca_combined_top[_before], ymax=rca_combined_top[_after], linestyle=':', color=col, linewidth=4) ax.hlines(y=1, xmin=-0.5, xmax=len(rca_combined_top) - 0.5, color='darkgrey', linestyle='--', linewidth=4) ax.set_xticks(np.arange(len(rca_combined_top))) ax.set_xlim(-1, len(rca_combined_top)) ax.set_xticklabels(rca_combined_top.index, rotation=40, ha='right') ax.legend() ax.set_ylabel('Specialisation in Deep Learning\n' 'relative to other arXiv CompSci categories') ax.set_xlabel('') dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_5{test_label}.png', plt) # mark as done logging.warning("Task complete") self.output().touch()
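# --- Hypothetical sketch of the RCA helper used above ------------------------
# dc.calculate_rca_by_country is not shown in this module; a standard revealed
# comparative advantage (location quotient) over the boolean is_dl column could
# look like the function below.  This is an assumption about the helper's
# behaviour, not its actual implementation.
def calculate_rca_by_country(df, country_column, commodity_column):
    """RCA = (country's share of DL papers) / (country's share of all papers)."""
    dl_by_country = df.groupby(country_column)[commodity_column].sum()
    all_by_country = df.groupby(country_column)[commodity_column].count()
    share_of_dl = dl_by_country / dl_by_country.sum()
    share_of_all = all_by_country / all_by_country.sum()
    # Values > 1 indicate a country publishing more DL papers than its overall
    # arXiv CompSci output would suggest.
    return (share_of_dl / share_of_all).to_frame(name=commodity_column)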
class Bar(RunOnceTask):
    date = luigi.DateParameter()
class HealthLabelTask(luigi.Task): """Apply health labels to the organisation data in MYSQL. Args: date (datetime): Datetime used to label the outputs _routine_id (str): String used to label the AWS task test (bool): True if in test mode insert_batch_size (int): Number of rows to insert into the db in a batch db_config_env (str): The output database envariable bucket (str): S3 bucket where the models are stored vectoriser_key (str): S3 key for the vectoriser model classifier_key (str): S3 key for the classifier model """ date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter() insert_batch_size = luigi.IntParameter(default=500) db_config_env = luigi.Parameter() bucket = luigi.Parameter() vectoriser_key = luigi.Parameter() classifier_key = luigi.Parameter() def requires(self): yield OrgGeocodeTask(date=self.date, _routine_id=self._routine_id, test=self.test, db_config_env="MYSQLDB", city_col=Organization.city, country_col=Organization.country, location_key_col=Organization.location_id, insert_batch_size=self.insert_batch_size, env_files=[find_filepath_from_pathstub("nesta/nesta/"), find_filepath_from_pathstub("config/mysqldb.config")], job_def="py36_amzn1_image", job_name=f"CrunchBaseOrgGeocodeTask-{self._routine_id}", job_queue="HighPriority", region_name="eu-west-2", poll_time=10, memory=4096, max_live_jobs=2) def output(self): """Points to the output database engine""" self.db_config_path = os.environ[self.db_config_env] db_config = get_config(self.db_config_path, "mysqldb") db_config["database"] = 'dev' if self.test else 'production' db_config["table"] = "Crunchbase health labels <dummy>" # Note, not a real table update_id = "CrunchbaseHealthLabel_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def run(self): """Apply health labels using model.""" # database setup database = 'dev' if self.test else 'production' logging.warning(f"Using {database} database") self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) try_until_allowed(Base.metadata.create_all, self.engine) # collect and unpickle models from s3 logging.info("Collecting models from S3") s3 = boto3.resource('s3') vectoriser_obj = s3.Object(self.bucket, self.vectoriser_key) vectoriser = pickle.loads(vectoriser_obj.get()['Body']._raw_stream.read()) classifier_obj = s3.Object(self.bucket, self.classifier_key) classifier = pickle.loads(classifier_obj.get()['Body']._raw_stream.read()) # retrieve organisations and categories nrows = 1000 if self.test else None logging.info("Collecting organisations from database") with db_session(self.engine) as session: orgs = (session .query(Organization.id) .filter(Organization.is_health.is_(None)) .limit(nrows) .all()) for batch_count, batch in enumerate(split_batches(orgs, self.insert_batch_size), 1): batch_orgs_with_cats = [] for (org_id, ) in batch: with db_session(self.engine) as session: categories = (session .query(OrganizationCategory.category_name) .filter(OrganizationCategory.organization_id == org_id) .all()) # categories should be a list of str, comma separated: ['cat,cat,cat', 'cat,cat'] categories = ','.join(cat_name for (cat_name, ) in categories) batch_orgs_with_cats.append({'id': org_id, 'categories': categories}) logging.debug(f"{len(batch_orgs_with_cats)} organisations retrieved from database") logging.debug("Predicting health flags") batch_orgs_with_flag = predict_health_flag(batch_orgs_with_cats, vectoriser, classifier) logging.debug(f"{len(batch_orgs_with_flag)} organisations to update") with 
db_session(self.engine) as session: session.bulk_update_mappings(Organization, batch_orgs_with_flag) logging.info(f"{batch_count} batches health labeled and written to db") # mark as done logging.warning("Task complete") self.output().touch()
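# --- Hypothetical sketch of predict_health_flag -------------------------------
# The helper is imported from elsewhere; given the bulk_update_mappings call
# above it plausibly returns the input dicts with an added is_health key.  The
# vectoriser/classifier calls assume scikit-learn-style objects; this is an
# illustrative assumption, not the real implementation.
def predict_health_flag(orgs_with_categories, vectoriser, classifier):
    """Attach a boolean is_health flag to each {'id': ..., 'categories': ...} dict."""
    texts = [org['categories'] for org in orgs_with_categories]
    features = vectoriser.transform(texts)        # e.g. a fitted TfidfVectorizer
    predictions = classifier.predict(features)    # e.g. a fitted sklearn classifier
    return [{'id': org['id'], 'is_health': bool(flag)}
            for org, flag in zip(orgs_with_categories, predictions)]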
class InputText(luigi.ExternalTask):
    date = luigi.DateParameter()

    def output(self):
        return luigi.hdfs.HdfsTarget(self.date.strftime('/tmp/text/%Y-%m-%d.txt'))
class SomeDailyTask(luigi.Task):
    d = luigi.DateParameter()

    def output(self):
        return MockTarget(self.d.strftime('/data/2014/p/v/z/%Y_/_%m-_-%doctor/20/ZOOO'))
class AMSLWisoPackages(AMSLTask): """ Collect WISO packages. XXX(miku): Throw this away. """ date = luigi.DateParameter(default=datetime.date.today()) encoding = luigi.Parameter(default='latin-1', description='wiso journal id csv file encoding', significant=False) def requires(self): return AMSLService(date=self.date) def hardcoded_list_of_wiso_journal_identifiers(self): """ Refs. #10707, Att. #4444. """ ids = set() filename = self.assets( 'wiso/645896059854847ce4ccd1416e11ba372e45bfd6.csv') with io.open(filename, encoding=self.encoding) as handle: for i, line in enumerate(handle): if i == 0: continue line = line.strip() if not line: continue fields = line.split(';') if len(fields) < 12: raise ValueError( 'expected ; separated KBART-ish file with journal identifier at column 11' ) id = fields[11].strip() if not id: continue ids.add(id) return sorted(ids) def resolve_ubl_profile(self, colls): """ Given a list of collection names, replace the complete list with hardcoded WISO journal identifiers from #10707/4444. """ if 'wiso UB Leipzig Profil' in colls: return self.hardcoded_list_of_wiso_journal_identifiers() return colls def run(self): with self.input().open() as handle: doc = json.loads(handle.read()) isilpkg = collections.defaultdict(lambda: collections.defaultdict(set)) for item in doc: isil, sid = item.get('ISIL'), item.get('sourceID') mega_collection = item.get('megaCollection') lthf = item.get('linkToHoldingsFile') if sid != "48": continue isilpkg[isil][lthf].add(mega_collection) filterconfig = collections.defaultdict(dict) fzs_package_name = 'Genios (Fachzeitschriften)' for isil, blob in list(isilpkg.items()): include_fzs = False for _, colls in list(blob.items()): if fzs_package_name in colls: include_fzs = True filters = [] if include_fzs and isil != 'DE-15-FID': packages = set( itertools.chain(*[c for _, c in list(blob.items())])) packages = self.resolve_ubl_profile(packages) filters.append({ 'and': [{ 'source': ['48'] }, { 'package': [fzs_package_name] }, { 'package': [ name for name in packages if name != fzs_package_name ] }] }) for lthf, colls in list(blob.items()): if lthf is None or lthf == 'null': continue if isil == 'DE-15-FID': colls = self.resolve_ubl_profile(colls) filter = { 'and': [ { 'source': ['48'] }, { 'holdings': { 'urls': [lthf] } }, { 'package': [c for c in colls if c != fzs_package_name] }, ] } else: filter = { 'and': [{ 'source': ['48'] }, { 'holdings': { 'urls': [lthf] } }, { 'package': [c for c in colls if c != fzs_package_name] }, { 'not': { 'package': [fzs_package_name] } }] } filters.append(filter) filterconfig[isil] = {'or': filters} with self.output().open('w') as output: json.dump(filterconfig, output, cls=SetEncoder) def output(self): return luigi.LocalTarget(path=self.path(ext='json'))
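# --- Hypothetical sketch of the SetEncoder used in json.dump above -----------
# SetEncoder is imported from elsewhere in the codebase; a minimal encoder that
# serialises Python sets as sorted lists (so any remaining set values in
# filterconfig are JSON-serialisable) would look like this.
import json


class SetEncoder(json.JSONEncoder):
    """Encode sets as sorted lists; defer everything else to JSONEncoder."""

    def default(self, obj):
        if isinstance(obj, set):
            return sorted(obj)
        return json.JSONEncoder.default(self, obj)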