def main():
    root_coding_node = generate_coding_nodes(annotations_xlsx_file_path)

    all_article_annotated_dict, count_successful_xml_parsing, count_not_successful_xml_parsing = parse_xml_articles(
        articles_xml_directory=articles_xml_directory,
        root_coding_node=root_coding_node,
        limit_percent_data=limit_percent_data)

    count_successful_coding_extraction, count_not_successful_coding_extraction = connect_codings_with_articles(
        annotations_xlsx_file_path=annotations_xlsx_file_path,
        article_annotated_dict=all_article_annotated_dict,
        root_coding_node=root_coding_node)

    log_manager.info_global(
        "\n" +
        "\nNumber of xml files successfully parsed and transformed: " + str(count_successful_xml_parsing) +
        "\nNumber of xml files not parsed: " + str(count_not_successful_xml_parsing) +
        "\nAll files' contents were saved regardless of whether xml parsing succeeded or not. "
        "The xml parsing is done only for extracting correct meta data."
        "\n" +
        "\nNumber of codings successfully integrated into articles: " + str(count_successful_coding_extraction) +
        "\nNumber of codings not integrated: " + str(count_not_successful_coding_extraction) +
        "\n")

    return root_coding_node, list(all_article_annotated_dict.values())
def init_reference_tables(config: Type[ConfigRoot]):
    """Creates the main index table where all articles of a Sketch Engine corpus are loaded
    and their pos_ids and urls are persisted. Should only be done once per sub corpus.
    Other indices will refer to this main index as foreign keys."""

    log_manager.info_global("--------------------------------"
                            "\nCreating main reference table for indexing\n")

    reference_table_initiator.init(
        db_config={
            "host": config.db_host,
            "dbname": config.db_name,
            "user": config.db_user,
            "password": config.db_password,
            "port": config.db_port
        },
        ske_config={
            "ske_rest_url": config.ske_rest_url,
            "ske_corpus_id": config.ske_corpus_id,
            "ske_user": config.ske_user,
            "ske_password": config.ske_password
        },
        table_name_ref_articles=config.table_name_ref_articles,
        table_name_ref_sentences=config.table_name_ref_sentences,
        spacy_base_model=config.spacy_base_model,
        should_do_dummy_run=config.should_do_dummy_run,
    )
def correct_anfang_ende(self, anfang, ende, segment_for_check):
    corrected_anfang = anfang
    corrected_ende = ende

    for empty_line_number in self.empty_line_number_list:
        if corrected_anfang >= empty_line_number:
            corrected_anfang += 1
        if corrected_ende >= empty_line_number:
            corrected_ende += 1
        if corrected_anfang < empty_line_number and corrected_ende < empty_line_number:
            break

    segment_lines = segment_for_check.splitlines()
    article_lines = self.article_file_content.splitlines()

    if ((segment_lines[0] not in article_lines[corrected_anfang - 1]
         or segment_lines[-1] not in article_lines[corrected_ende - 1])
            and ('fieldname="inhalt"' not in segment_lines[0]
                 and 'fieldname="inhalt"' not in segment_lines[-1])):
        log_manager.info_global(
            f"Correcting the Anfang and Ende did not work correctly. Can most likely be ignored. article_id: {self.article_id}"
        )

    return corrected_anfang, corrected_ende
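# Illustrative sketch of the correction above (values are hypothetical): the Anfang/Ende
# offsets appear not to count empty lines, so each empty line at or before a coding shifts
# its offsets down by one. With self.empty_line_number_list == [2, 10] and a coding at
# Anfang=5, Ende=6, both offsets are shifted once (to 6 and 7), because only empty line 2
# lies at or before them; the loop then breaks at empty line 10.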
def transform_cats(
        gold_data_transform_rule: Type[TransformRule],
        gold_data_container: GoldDataContainer) -> GoldDataContainer:

    # Replace the cats_list overview of gold_data_container
    cats_list_new = []
    for cat_old, cat_new in gold_data_transform_rule.cat_replacements:
        if cat_new not in cats_list_new:
            found_old = False
            for cat_old_assigned in gold_data_container.cats_list:
                if cat_old_assigned == cat_old:
                    found_old = True
                    cats_list_new.append(cat_new)
            if not found_old:
                log_manager.info_global(
                    f"Did not find occurrence of category to be replaced: '{cat_old}'\n"
                    "It could be the case that no text got assigned to this category and it was filtered out.\n"
                    "If so, ignore this warning.")
    gold_data_container.cats_list = cats_list_new

    # Replace the assigned cats of each gold_data_item
    for gold_data_item in gold_data_container.gold_data_item_list:
        gold_data_item: GoldDataItem
        cats_dict_old = gold_data_item.cats
        cats_dict_new = {}
        for cat_old, cat_new in gold_data_transform_rule.cat_replacements:
            if cat_old in cats_dict_old:
                if cats_dict_old[cat_old] == 1:
                    cats_dict_new[cat_new] = 1
                elif cats_dict_old[cat_old] == 0 and cat_new not in cats_dict_new:
                    cats_dict_new[cat_new] = 0
        if len(cats_dict_new.keys()) == 0:
            raise Exception(
                "Not a single transformation was done. This can't be on purpose?")
        gold_data_item.cats = cats_dict_new

    return gold_data_container
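# Hypothetical example of the mapping above (category names are made up): with
#   cat_replacements = [("SC: Kritik", "SC"), ("SC: Lob", "SC")]
# an item with cats {"SC: Kritik": 1, "SC: Lob": 0} ends up with cats {"SC": 1},
# because a positive label (1) on any old category takes precedence over a 0
# when several old categories are merged into the same new one.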
def run_prodigy(config: Type[ConfigRoot]):

    log_manager.info_global("--------------------------------"
                            "\nRunning prodigy\n")

    if config.should_do_dummy_run:
        config.db_host = "127.0.0.1"
        config.db_name = "mara_db_dummy"
        config.db_user = "******"
        config.db_password = "******"
        config.db_port = 5432
        config.ske_batch_size = 2

    prodigy_manager.run(
        ske_config={
            "ske_rest_url": config.ske_rest_url,
            "ske_corpus_id": config.ske_corpus_id,
            "ske_user": config.ske_user,
            "ske_password": config.ske_password,
        },
        db_config={
            "host": config.db_host,
            "dbname": config.db_name,
            "user": config.db_user,
            "password": config.db_password,
            "port": config.db_port
        },
        dataset_name=config.prodigy_dataset_name,
        index1_table_name=config.index_table_name,
        index2_table_names=config.index_lmvr_table_names,
        ske_translation_table_name=config.ske_docid_url_table_name)
def load_gold_data(config: Type[ConfigRoot]) -> GoldDataContainer:

    log_manager.info_global("--------------------------------"
                            "\nLoading gold data from json file\n")

    return gold_data_manager.load_from_json(
        gold_data_json_path=config.gold_data_json_path)
def write_df_to_db(df, index_table_name, db_config):

    log_manager.debug_global("Creating SqlAlchemy engine ...")
    engine = sqlalchemy.create_engine(
        sqlalchemy.engine.url.URL('postgresql+psycopg2',
                                  host=db_config['host'],
                                  port=db_config['port'],
                                  username=db_config['user'],
                                  password=db_config['password'],
                                  database=db_config['dbname']))

    try:
        log_manager.debug_global(
            f"Writing DataFrame to DB {index_table_name} ...")
        df.to_sql(index_table_name, engine, if_exists='append')
    except ValueError as e:
        log_manager.info_global(
            f"Can't write DataFrame to Database: table {index_table_name} already exists"
        )
    finally:
        log_manager.debug_global("Disposing of SqlAlchemy engine ...")
        engine.dispose()
def transform_gold_data(config: Type[ConfigRoot],
                        gold_data_container: GoldDataContainer):

    log_manager.info_global("--------------------------------"
                            "\nTransforming gold data\n")

    return gold_data_manager.transform_cats(
        gold_data_transform_rule=config.gold_data_transform_rule,
        gold_data_container=gold_data_container)
def _populate_index_table(pos_start):

    log_manager.info_global(f"Start indexing at pos > {pos_start}")

    trainer.nlp.max_length = 3000000

    if should_do_dummy_run:
        limit = sql.SQL("LIMIT 20")
    else:
        limit = sql.SQL("")

    db_cursor.execute(
        sql.SQL("""
            SELECT * FROM {table_name_ref_articles}
            WHERE pos_mara002 > {pos_start}
            ORDER BY pos_mara002
            {limit}
        """).format(table_name_ref_articles=sql.Identifier(table_name_ref_articles),
                    pos_start=sql.Literal(pos_start),
                    limit=limit))

    pos_last = ske_manager.get_last_pos(ske_config)
    pos_percent_step = int((pos_last - pos_start) / 1000)
    pos_percent_current = pos_percent_step

    for dict_row in db_cursor.fetchall():
        pos = dict_row["pos_mara002"]
        if pos > pos_percent_current:
            log_manager.info_global(
                f"Currently at {round(pos_percent_current / pos_percent_step / 10, 1)}% ; at pos {pos} out of {pos_last}"
            )
            pos_percent_current += pos_percent_step
            db_connection.commit()

        text = ske_manager.get_doc_from_pos(ske_config, pos)["text"]
        cats = trainer.nlp(text).cats

        sql_col_list = [sql.Identifier("pos_mara002")]
        sql_val_list = [sql.Literal(pos)]
        for k, v in cats.items():
            sql_col_list.append(sql.Identifier(k))
            sql_val_list.append(sql.Literal(round(v, 6)))
        sql_col_stmt = sql.SQL(", ").join(sql_col_list)
        sql_val_stmt = sql.SQL(", ").join(sql_val_list)

        db_cursor.execute(
            sql.SQL("""
                INSERT INTO {index_table_name} ({cols})
                VALUES({vals})
            """).format(index_table_name=sql.Identifier(index_table_name),
                        cols=sql_col_stmt,
                        vals=sql_val_stmt))
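# Rough sketch of the statement the loop above composes (table and category column names
# are hypothetical; actual columns come from trainer.cats):
#   INSERT INTO "index_table" ("pos_mara002", "SC", "SM") VALUES(12345, 0.731024, 0.002113)
# i.e. one row per article position with one DECIMAL column per textcat category.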
def main(config: Type[ConfigRoot]):

    config, pipe = handle_cli_args(config)

    # TODO: It isn't ideal to init the log_manager here because we might want output during _update_config_with_cli_args
    # But I really wanted to get a custom log_global_path
    log_manager.initialize(config.log_global_path)

    log_manager.info_global("--------------------------------"
                            "\nSTART MAIN\n")

    pipe.run()
def persist_to_json(gold_data_json_path: str,
                    gold_data_container: GoldDataContainer):

    gold_data_json_dict = export_to_dict(
        gold_data_container=gold_data_container)

    with open(gold_data_json_path, "w", encoding="utf8") as f:
        # makes the json more readable but roughly 2x bigger when it contains only sentences
        # json.dump(gold_data_json_dict, f, indent=2, ensure_ascii=False)
        json.dump(gold_data_json_dict, f, ensure_ascii=False)

    log_manager.info_global(f"Persisted to file: {gold_data_json_path}")
def read_keyword_df(data_path):

    # For each keyword, store its token count in the entire corpus.
    # We extract keywords and their token counts from the file names in the DOCIDS folder.
    # The column 'csv_types' points to the CSV file that contains the doc ids where this keyword appears.
    keyword_df = pd.DataFrame({'csv_types': os.listdir(f'{data_path}/DOCIDS')})

    # Set up a regex pattern to extract information out of the file name:
    pattern = re.compile(
        r'docids_(freqs_(.*)_id_word_0_0_0_mara002)_n_([0-9]+)\.csv')
    # group 1 is the name of the CSV file containing the tokens, minus the extension '.csv'
    # group 2 is the keyword's ID, consisting of a number and a human-readable representation
    # group 3 is the token count of this keyword in the entire corpus

    csv_tokens = []
    ids = []
    counts = []

    # iterate over the keywords
    for value in keyword_df['csv_types']:
        match = pattern.fullmatch(value)
        try:
            # store the group values
            csv_tokens.append(match.group(1) + '.csv')
            ids.append(match.group(2))
            counts.append(int(match.group(3)))
        except AttributeError as e:
            log_manager.info_global(
                f"Unexpected file name: {value} \nDid not match the RegEx pattern."
            )
            # store null values
            csv_tokens.append(None)
            ids.append(None)
            counts.append(None)

    # The column 'csv_tokens' points to the CSV files that contain
    # all doc ids where each keyword appears, together with its token count within that document.
    keyword_df['csv_tokens'] = csv_tokens
    # The column 'corpus_count' contains the keyword's token count in the entire corpus.
    keyword_df['corpus_count'] = counts
    keyword_df['keyword_id'] = ids
    keyword_df['category'] = keyword_df['keyword_id'].apply(
        lambda x: 'SM' if int(x[0:2]) >= 42 else 'SC' if int(x[0:2]) >= 30 else 'Allgemein')

    # Set the keywords' IDs as the DataFrame's row labels.
    keyword_df = keyword_df.set_index('keyword_id')

    return keyword_df  # with the column 'csv_tokens'
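# Hypothetical file name to illustrate the pattern above (the keyword is invented):
#   'docids_freqs_42_hatespeech_id_word_0_0_0_mara002_n_1234.csv'
# would yield csv_tokens='freqs_42_hatespeech_id_word_0_0_0_mara002.csv',
# keyword_id='42_hatespeech' and corpus_count=1234; the leading '42' would then map
# the keyword to the 'SM' category (>= 42).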
def add_coding(self, coding_anfang, coding_ende, coding_tag, coding_segment):

    def match_coding(coding_list, coding_node):
        for children_node in coding_node.children:
            if children_node.coding_value == coding_list[0]:
                if len(coding_list) > 1:
                    return match_coding(coding_list[1:], children_node)
                else:
                    children_node.article_annotated_set.add(self)
                    return children_node
        raise Exception("Could not find matching coding_node!")

    codings_split = coding_tag.split("\\")
    matching_coding_node = match_coding(codings_split, self.root_coding_node)
    if codings_split[-1] != matching_coding_node.coding_value:
        raise Exception("Did not find correct matching coding_node")

    coding_dict = {
        "coding_anfang": coding_anfang,
        "coding_ende": coding_ende,
        "coding_tag": coding_tag,
        "Segment": coding_segment,
        "coding_node": matching_coding_node,
    }

    found_coding_dict = False
    for existing_coding_dict in self.coding_list:
        if coding_dict == existing_coding_dict:
            found_coding_dict = True
            log_manager.info_global(
                f"Redundant coding found. Can most likely be ignored. article_id: {self.article_id}, coding_dict: {coding_dict}"
            )
    if not found_coding_dict:
        self.coding_list.append(coding_dict)
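# Illustrative sketch of the tag matching above (the tag text is hypothetical): a code like
#   "Soziale Medien\Kritik\Datenschutz"
# is split on the backslash and walked level by level down the coding tree starting at
# root_coding_node, so the returned node is the leaf "Datenschutz" under "Kritik" under
# "Soziale Medien", and the article registers itself on that leaf node.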
def persist_gold_data(
    config: Type[ConfigRoot],
    gold_data_container: GoldDataContainer,
):

    log_manager.info_global(
        "--------------------------------"
        "\nPersisting transformed data into json structured for training\n")

    if config.should_do_dummy_run:
        config.gold_data_json_path = config.gold_data_json_path.replace(
            ".json", "__dummy.json")
        gold_data_container.gold_data_item_list = gold_data_container.gold_data_item_list[:40]

    gold_data_manager.persist_to_json(config.gold_data_json_path,
                                      gold_data_container)
def create_table(db_connection, db_cursor, table_name):

    try:
        log_manager.debug_global(f"Dropping table {table_name} ...")
        db_cursor.execute(
            sql.SQL("""
                DROP TABLE IF EXISTS {table};
            """).format(table=sql.Identifier(table_name)))

        log_manager.debug_global(f"Creating table {table_name} ...")
        sql_stmt = sql.SQL("""
            CREATE TABLE {table} (
                {docid} varchar NOT NULL,
                {pos} varchar NOT NULL,
                {url} varchar NULL,
                CONSTRAINT ske_docid_pos_pk PRIMARY KEY ({docid}),
                CONSTRAINT ske_docid_pos_un_pos UNIQUE ({pos}),
                CONSTRAINT ske_docid_pos_un_url UNIQUE ({url})
            );
        """).format(table=sql.Identifier(table_name),
                    docid=sql.Identifier('docid'),
                    pos=sql.Identifier('pos_mara002'),
                    url=sql.Identifier('url_index1'))
        db_cursor.execute(sql_stmt)

        db_connection.commit()

    except Exception as e:
        log_manager.info_global(f"There was an error: {e}")
        log_manager.debug_global(
            f"This was the SQL string: \n{sql_stmt.as_string(db_connection)}")
        log_manager.debug_global("Rolling back DB operations ...")
        db_connection.rollback()
        raise e

    return
def __init__(self,
             model_path=None,
             spacy_base_model=None,
             should_load_model=None,
             should_create_model=None,
             should_persist_model=None,
             cats=None,
             gold_data_json_path=None,
             exclusive_classes=None):

    if model_path is not None:
        if model_path[-1] != "/":
            model_path += "/"
        # TODO: when the same model is instantiated multiple times, this logger is instantiated multiple times too
        # TODO: change logger so that when should_persist_model==False, the logger only uses the global logger
        self.logger = log_manager.create_new_logger(model_path + "log.txt")
    else:
        self.logger = None  # TODO: Use global log_manager instead

    log_manager.info_global("Instantiating Trainer Object")

    self.model_path = model_path
    self.train_data_json_path = gold_data_json_path
    self.exclusive_classes = exclusive_classes
    self.should_create_model = should_create_model
    self.should_load_model = should_load_model
    self.should_persist_model = should_persist_model
    self.nlp = None

    if should_create_model and not should_load_model:
        self.create_model(cats=cats, spacy_base_model=spacy_base_model)
    elif not should_create_model and should_load_model:
        self.load_model()
    else:
        raise Exception(
            "Ambiguity between should_load_model and should_create_model.")
def get_prodigy_data(dataset_name, db_config, ske_config) -> GoldDataContainer:

    from prodigy.components.db import connect

    db = connect(db_id='postgresql', db_settings=db_config)

    prodigy_data = db.get_dataset(dataset_name)
    if not prodigy_data:
        log_manager.info_global(
            f"Dataset {dataset_name} doesn't exist in the prodigy database!")
        return

    log_manager.info_global(f"Loaded {len(prodigy_data)} entries")

    prodigy_gold_data_container = transform_to_gold_data(prodigy_data,
                                                         db_config,
                                                         ske_config)

    return prodigy_gold_data_container
def init_trainer(config: Type[ConfigRoot],
                 cats_list: List[str] = None) -> AbstractTrainer:

    log_manager.info_global("--------------------------------"
                            "\nInitializing model\n")

    model_path = config.model_path
    if config.should_do_dummy_run and config.should_persist_model:
        model_path += "__dummy"

    assert config.trainer_class is not None

    return config.trainer_class(
        spacy_base_model=config.spacy_base_model,
        model_path=model_path,
        should_persist_model=config.should_persist_model,
        should_load_model=config.should_load_model,
        should_create_model=config.should_create_model,
        cats=cats_list,
        gold_data_json_path=config.gold_data_json_path,
        exclusive_classes=config.exclusive_classes)
def run_lvmr_indexer(config: Type[ConfigRoot]):

    log_manager.info_global("--------------------------------"
                            "\nRunning LMVR indexer\n")

    lmvr_indexer.run(data_path=config.csv_folder_path,
                     db_config={
                         "host": config.db_host,
                         "dbname": config.db_name,
                         "user": config.db_user,
                         "password": config.db_password,
                         "port": config.db_port
                     },
                     index1_table_name=config.index_table_name,
                     index2_table_names=config.index_lmvr_table_names,
                     ske_config={
                         "ske_rest_url": config.ske_rest_url,
                         "ske_corpus_id": config.ske_corpus_id,
                         "ske_user": config.ske_user,
                         "ske_password": config.ske_password
                     })
def run_model_indexer(config: Type[ConfigRoot], trainer):

    log_manager.info_global("--------------------------------"
                            "\nRunning indexer\n")

    ske_config = {
        "ske_rest_url": config.ske_rest_url,
        "ske_corpus_id": config.ske_corpus_id,
        "ske_user": config.ske_user,
        "ske_password": config.ske_password,
    }
    db_config = {
        "host": config.db_host,
        "dbname": config.db_name,
        "user": config.db_user,
        "password": config.db_password,
        "port": config.db_port
    }

    if config.indexing_function == model_indexer.index_articles:
        model_indexer.index_articles(
            ske_config=ske_config,
            db_config=db_config,
            index_table_name=config.index_table_name,
            trainer=trainer,
            table_name_ref_articles=config.table_name_ref_articles,
            should_do_dummy_run=config.should_do_dummy_run,
        )
    elif config.indexing_function == model_indexer.index_sentences:
        model_indexer.index_sentences(
            ske_config=ske_config,
            db_config=db_config,
            index_table_name=config.index_table_name,
            trainer=trainer,
            table_name_ref_articles=config.table_name_ref_articles,
            should_do_dummy_run=config.should_do_dummy_run,
        )
def connect_codings_with_articles(annotations_xlsx_file_path,
                                  article_annotated_dict, root_coding_node):

    xl = pd.ExcelFile(annotations_xlsx_file_path)
    df_codings_on_articles = xl.parse("Codings")

    count_successful = 0
    count_not_successful = 0

    for row in df_codings_on_articles.iterrows():
        row_data = row[1]
        try:
            article_annotated = article_annotated_dict[row_data.Dokumentname]
            corrected_anfang, corrected_ende = article_annotated.correct_anfang_ende(
                anfang=row_data.Anfang,
                ende=row_data.Ende,
                segment_for_check=row_data.Segment)
            article_annotated.add_coding(coding_anfang=corrected_anfang,
                                         coding_ende=corrected_ende,
                                         coding_tag=row_data.Code,
                                         coding_segment=row_data.Segment)
            count_successful += 1
        except Exception as ex:
            log_manager.info_global(
                f"{ex.__class__.__name__}: when integrating codings into article: {row_data.Dokumentname}: {ex}"
            )
            count_not_successful += 1

    return count_successful, count_not_successful
def load_from_maxqdata(config: Type[ConfigRoot]) -> GoldDataContainer:

    log_manager.info_global(
        "--------------------------------"
        "\nLoading and transforming data from AMC and maxqdata export file\n")

    if config.should_do_dummy_run:
        limit_percent_data = 10
    else:
        limit_percent_data = 100

    root_coding_node, article_annotated_list = maxqdata_manager.load_from_amc_and_maxqdata(
        annotations_xlsx_file_path=config.annotations_xlsx_file_path,
        articles_xml_directory=config.articles_xml_directory,
        limit_percent_data=limit_percent_data)

    if config.maxqdata_gold_data_transform_function == maxqdata_manager.transform_to_gold_data_articles:
        return maxqdata_manager.transform_to_gold_data_articles(
            root_coding_node=root_coding_node,
            article_annotated_list=article_annotated_list)

    elif config.maxqdata_gold_data_transform_function == maxqdata_manager.transform_to_gold_data_sentences:
        nlp = spacy.load(config.spacy_base_model,
                         disable=["tagger", "parser", "ner"])
        if "sentencizer" not in nlp.pipe_names:
            nlp.add_pipe(nlp.create_pipe("sentencizer"))
        sentence_split_func = lambda t: nlp(t).sents
        return maxqdata_manager.transform_to_gold_data_sentences(
            spacy_base_model=config.spacy_base_model,
            root_coding_node=root_coding_node,
            article_annotated_list=article_annotated_list,
            sentence_split_func=sentence_split_func)

    else:
        raise Exception("No maxqdata_gold_data_transform_function defined.")
def load_prodigy_dataset(config: Type[ConfigRoot]):

    log_manager.info_global("--------------------------------"
                            "\nLoading prodigy data\n")

    prodigy_data = db_manager.get_prodigy_data(
        dataset_name=config.prodigy_dataset_name,
        db_config={
            "host": config.db_host,
            "dbname": config.db_name,
            "user": config.db_user,
            "password": config.db_password,
            "port": config.db_port
        },
        ske_config={
            "ske_rest_url": config.ske_rest_url,
            "ske_corpus_id": config.ske_corpus_id,
            "ske_user": config.ske_user,
            "ske_password": config.ske_password
        })

    return prodigy_data
def run_ske_translator(config: Type[ConfigRoot]):

    log_manager.info_global("--------------------------------"
                            "\nRunning SKE translator\n")

    ske_translator.run(
        ske_config={
            "ske_rest_url": config.ske_rest_url,
            "ske_corpus_id": config.ske_corpus_id,
            "ske_user": config.ske_user,
            "ske_password": config.ske_password
        },
        db_config={
            "host": config.db_host,
            "dbname": config.db_name,
            "user": config.db_user,
            "password": config.db_password,
            "port": config.db_port
        },
        docid_table_name=config.ske_docid_url_table_name,
        index1_table_name=config.index_table_name,
        index2_table_names=config.index_lmvr_table_names,
        should_drop_create_table=config.should_drop_create_ske_translator_table)
def run_training(config: Type[ConfigRoot], trainer: AbstractTrainer,
                 gold_data_container: GoldDataContainer):

    log_manager.info_global("--------------------------------"
                            "\nTraining model\n")

    train_data_container, eval_data_container = gold_data_manager.split_into_train_eval_data(
        gold_data_container=gold_data_container,
        data_cutoff=config.train_data_cutoff,
    )

    if config.should_do_dummy_run:
        train_data_container.gold_data_item_list = train_data_container.gold_data_item_list[:20]
        eval_data_container.gold_data_item_list = eval_data_container.gold_data_item_list[:20]
        config.iteration_limit = 2

    trainer.train(train_data=train_data_container,
                  eval_data=eval_data_container,
                  iteration_limit=config.iteration_limit)

    return train_data_container, eval_data_container
def run(data_path, db_config, index1_table_name, index2_table_names, ske_config):

    start = datetime.datetime.now()
    log_manager.info_global("--------------------------------")
    log_manager.info_global(
        f"{start.strftime('[%y-%m-%d %H:%M:%S]')} START INDEXING\n")

    log_manager.info_global("Creating DB tables ...")
    create_tables(db_config, index1_table_name, index2_table_names)

    log_manager.info_global("Creating DataFrames from original CSV files ...")

    # 1. Set up the keywords dataframe.
    log_manager.debug_global("Creating DataFrame for keywords ...")
    keyword_df = read_keyword_df(data_path)

    # Store the keywords df to the database.
    log_manager.debug_global("Writing keywords DF to DB ...")
    write_df_to_db(
        keyword_df.drop(columns=['csv_tokens', 'csv_types'], inplace=False),
        index2_table_names['keywords'], db_config)

    # 2. Set up the text token counts dataframe.
    log_manager.debug_global("Creating DataFrame for token counts ...")
    token_df = pd.DataFrame()

    # For each keyword, collect its token count per document;
    # the result is one row per (keyword_id, docid) pair in token_df.
    bar = create_progress_bar('Calculating total of tokens per text',
                              keyword_df.shape[0])
    for kw in keyword_df.itertuples():
        # kw is a Pandas object representing the row.
        # We find the token counts in the CSV file stored in the column 'csv_tokens' of keyword_df.
        temp_df = pd.read_csv(f'{data_path}/CSV/{kw.csv_tokens}',
                              sep='\t',
                              skiprows=8,
                              names=['docid', 'token', 'token_count'],
                              usecols=['docid', 'token_count'])
        # We need to group by doc id and sum all the token counts for the various shapes of the token.
        temp_df = temp_df.groupby(['docid'], as_index=False).sum()
        # Add a column with the keyword's ID.
        temp_df['keyword_id'] = kw.Index
        temp_df = temp_df.set_index(['keyword_id', 'docid'],
                                    verify_integrity=True)
        # 1st index: keyword_id, because this allows for fewer lookups when calculating the scores.
        # Append the rows to token_df.
        token_df = token_df.append(temp_df, verify_integrity=True)
        bar.next()
    bar.finish()

    # Don't write token_df to the DB yet because it has a FK constraint to doc_df.

    # 3. Set up the texts dataframe.
    log_manager.debug_global("Creating DataFrame for texts ...")
    # We use this file only to get a complete list of doc ids.
    doc_df = pd.read_csv(f'{data_path}/mara002_kvr_all.docids.counts.csv',
                         sep='\t',
                         names=['types_count', 'docid'],
                         usecols=['docid'])
    doc_df['score_rarity_diversity'] = 0.0
    doc_df['already_annotated'] = False
    doc_df['selected_on'] = None
    doc_df = doc_df.set_index('docid')

    # Calculate scores.
    log_manager.debug_global("Calculating scores for texts ...")
    doc_df = score_rarity_diversity(doc_df, keyword_df, token_df)

    # Write doc_df to the DB.
    log_manager.debug_global("Writing DF for texts to DB ...")
    write_df_to_db(doc_df, index2_table_names['scores'], db_config)

    # Now we can write token_df to the DB.
    log_manager.debug_global("Writing DF for tokens to DB ...")
    write_df_to_db(token_df, index2_table_names['tokens'], db_config)

    # All done!
    end = datetime.datetime.now()
    log_manager.info_global(
        f"{end.strftime('[%y-%m-%d %H:%M:%S]')} DONE INDEXING, duration: {end-start}"
    )

    return  # TODO: Is this empty return on purpose?
def _init_index_table():

    log_manager.info_global(
        "Checking if index table exists or if it needs to be created")

    pos_start = None

    db_cursor.execute(
        sql.SQL("""
            SELECT * FROM information_schema.tables
            WHERE table_name = {index_table_name}
        """).format(index_table_name=sql.Literal(index_table_name)))

    if db_cursor.rowcount == 1:
        log_manager.info_global(
            "Index table exists, fetching highest pos to start / continue from")

        db_cursor.execute(
            sql.SQL("""
                SELECT MAX(pos_mara002) FROM {index_table_name}
            """).format(index_table_name=sql.Identifier(index_table_name)))
        pos_start = db_cursor.fetchone()["max"]

        if pos_start is None:
            # So that when asking in _populate_index_table for WHERE pos > pos_start, the zero element will be used too
            pos_start = -1

    else:
        log_manager.info_global(
            "Index table does not exist, creating it and starting from pos > -1")

        sql_cat_col_list = [
            sql.SQL(" {c} DECIMAL,\n").format(c=sql.Identifier(cat))
            for cat in trainer.cats
        ]
        sql_stmt_define_cat_cols = reduce(lambda a, b: a + b, sql_cat_col_list)

        sql_stmt_create_table = sql.SQL(
            "CREATE TABLE {index_table_name} (\n"
            " pos_mara002 INT,\n"
            "{cols}"
            " PRIMARY KEY (pos_mara002),"
            " CONSTRAINT fkc FOREIGN KEY(pos_mara002) REFERENCES {table_name_ref_articles}(pos_mara002)"
            ")").format(index_table_name=sql.Identifier(index_table_name),
                        cols=sql_stmt_define_cat_cols,
                        table_name_ref_articles=sql.Identifier(table_name_ref_articles))

        log_manager.info_global(
            f"create table sql statement:\n{sql_stmt_create_table.as_string(db_cursor)}")

        db_cursor.execute(sql_stmt_create_table)

        # So that when asking in _populate_index_table for WHERE pos > pos_start, the zero element will be used too
        pos_start = -1

    db_connection.commit()

    return pos_start
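# Rough sketch of the generated statement (table and category names are hypothetical;
# actual category columns come from trainer.cats):
#   CREATE TABLE "index_table" (
#    pos_mara002 INT,
#    "SC" DECIMAL,
#    "SM" DECIMAL,
#    PRIMARY KEY (pos_mara002), CONSTRAINT fkc FOREIGN KEY(pos_mara002) REFERENCES "ref_articles"(pos_mara002))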
def update(examples):

    log_manager.debug_global("Prodigy: updating ...")

    nonlocal db_connection
    nonlocal db_cursor
    db_connection, db_cursor = db_manager.open_db_connection(
        db_config, db_connection, db_cursor)
    assert db_connection and db_connection.closed == 0  # 0 means 'open'
    assert db_cursor and not db_cursor.closed

    for example in examples:
        try:
            if index1_table_name and 'url' in example['meta']:
                url = example['meta']['url']
                log_manager.debug_global(
                    f"Storing annotation meta info for url={url} in table {index1_table_name} ..."
                )
                db_cursor.execute(
                    sql.SQL("UPDATE {index_table_name} "
                            "SET already_annotated = TRUE "
                            "WHERE {pk} = %(value)s").format(
                                index_table_name=sql.Identifier(index1_table_name),
                                pk=sql.Identifier('url')),
                    {'value': url})

            # TODO: this could be made safer to ensure
            # that index2 won't be updated accidentally with 'already_annotated'
            # when we are actually only streaming from index1.
            #
            # Currently the stream from index1 does not set 'docid' in example['meta'],
            # but this may not be good to rely on.
            if index2_table_names and 'docid' in example['meta']:
                docid = example['meta']['docid']
                log_manager.debug_global(
                    f"Storing annotation meta info for docid={docid} in table {index2_table_names['scores']} ..."
                )
                db_cursor.execute(
                    sql.SQL("UPDATE {index_table_name} "
                            "SET already_annotated = TRUE "
                            "WHERE {pk} = %(value)s").format(
                                index_table_name=sql.Identifier(index2_table_names['scores']),
                                pk=sql.Identifier('docid')),
                    {'value': docid})

            db_connection.commit()

        except Exception as ex:
            log_manager.info_global(
                f"Error storing an annotation in the database: {ex}")
            db_connection.rollback()
def run(ske_config, db_config, docid_table_name, index1_table_name,
        index2_table_names, should_drop_create_table=False):

    (db_connection, db_cursor) = db_manager.open_db_connection(db_config)

    if should_drop_create_table:
        create_table(db_connection, db_cursor, docid_table_name)

    # Direction 1: look for URLs that are not yet in the translation table.
    # Hannes says that pos -> docid is faster than docid -> pos
    # because the SKE uses pos as internal indices.
    log_manager.debug_global("Looking for URLs ...")
    url_records = select_urls_from_index1(db_cursor, docid_table_name,
                                          index1_table_name)
    log_manager.info_global(f"Found {len(url_records)} URLs to be converted.")

    if len(url_records) > 0:
        ske_manager.create_session(ske_config)

        progressbar = progress.bar.Bar(
            'Converting URLs to docid',
            max=len(url_records),
            suffix='%(index)d/%(max)d done, ETA: %(eta_td)s h')

        for record in url_records:
            url = record['url']
            pos = ske_manager.get_pos_from_url(url)
            docid = ske_manager.get_docid_from_pos(
                ske_config, pos)  # this calls the API endpoint 'fullref'
            insert_into_table(db_connection, db_cursor, docid_table_name,
                              docid, pos, url)
            progressbar.next()

        progressbar.finish()

    # Direction 2: look for docids that are not yet in the translation table.
    log_manager.debug_global("Looking for docids ...")
    docid_records = select_docids_from_index2(db_cursor, docid_table_name,
                                              index2_table_names)
    log_manager.debug_global(
        f"Found {len(docid_records)} docids to be converted.")

    if len(docid_records) > 0:
        ske_manager.create_session(ske_config)

        progressbar = progress.bar.Bar(
            'Converting docids to URLs',
            max=len(docid_records),
            suffix='%(index)d/%(max)d done, ETA: %(eta_td)s h')

        for record in docid_records:
            docid = record['docid']
            pos = ske_manager.get_pos_from_docid(
                ske_config, docid)  # this calls the API endpoint 'first'
            url = ske_manager.get_url_from_pos(ske_config, pos)
            insert_into_table(db_connection, db_cursor, docid_table_name,
                              docid, pos, url)
            progressbar.next()

        progressbar.finish()

    # All set!
    ske_manager.close_session()
    db_manager.close_db_connection(db_connection, db_cursor)

    return
def log_trainer(self, msg):
    self.logger.info(msg)
    log_manager.info_global(msg)