def run():
    db_config = {
        "host": ConfigRoot.db_host,
        "dbname": ConfigRoot.db_name,
        "user": ConfigRoot.db_user,
        "password": ConfigRoot.db_password,
        "port": ConfigRoot.db_port
    }
    db_connection, db_cursor = db_manager.open_db_connection(db_config)

    # Work on a random subset of the reference articles table
    table_name_ref_articles_old = ConfigIndexBase.table_name_ref_articles
    table_name_ref_articles_new = table_name_ref_articles_old + "_random_subset"
    ConfigIndexBase.table_name_ref_articles = table_name_ref_articles_new

    create_main_reference_table(
        db_cursor,
        db_connection,
        table_name_ref_articles_old,
        table_name_ref_articles_new)  # TODO: comment this out if not needed

    for i in range(100):
        if i == 0:
            trainer1 = None
            trainer2 = None
        trainer1, trainer2 = train(trainer1, trainer2)  # TODO: comment this out if not needed
        drop_index(db_cursor, db_connection)  # TODO: comment this out if not needed
        create_index(trainer1, trainer2)  # TODO: comment this out if not needed
        compare_index(db_cursor, i)  # TODO: comment this out if not needed

    db_manager.close_db_connection(db_connection, db_cursor)
def stream_from_db_with_predictions(ske_config, db_config, index_table_name):
    log_manager.debug_global("Streaming from DB with predictions ...")

    db_connection = None
    db_cursor = None
    (db_connection, db_cursor) = db_manager.open_db_connection(db_config, db_connection, db_cursor)

    try:
        while True:
            db_cursor.execute(
                sql.SQL(
                    'SELECT *, ("AF: Social Companions" + "AF: Soziale Medien") AS AF_SC_SM '
                    'FROM {index_table_name} '
                    'WHERE already_annotated = FALSE '
                    'AND already_selected = FALSE '  # left-over from the old system
                    "AND ((selected_on IS NULL) OR (selected_on < (NOW() - INTERVAL '2 days'))) "
                    'ORDER BY AF_SC_SM ASC '
                    'LIMIT 1').format(
                        index_table_name=sql.Identifier(index_table_name)))
            result = db_cursor.fetchone()

            url = result["url"]
            _select_text(db_connection, db_cursor, index_table_name, 'url', url)

            options = _preselect_options(result)

            ske_doc = ske_manager.get_doc_from_url(ske_config, url)

            yield {
                "text": ske_doc["text"],
                "options": options['cats_as_options'],
                "accept": options['options_accepted'],
                "meta": {
                    "url": url,
                    "scores": options['scores_text']
                }
            }
    except Exception as ex:
        print(ex)
    finally:
        db_manager.close_db_connection(db_connection, db_cursor)
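# Usage sketch (an assumption for illustration, not the project's actual recipe wiring):
# the generator above yields tasks in the shape Prodigy's 'choice' interface expects
# ("text", "options", "accept", "meta"), so a recipe would typically hand it over as its
# stream. The variable names below are hypothetical.
#
#   stream = stream_from_db_with_predictions(ske_config, db_config, index_table_name)
#   return {
#       "dataset": dataset,
#       "stream": stream,
#       "view_id": "choice",
#   }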
def create_tables(db_config, index1_table_name, index2_table_names):
    (db_connection, db_cursor) = db_manager.open_db_connection(db_config)

    try:
        log_manager.debug_global("Dropping tables ...")
        db_cursor.execute(
            sql.SQL("""
                DROP TABLE IF EXISTS {table_keywords}, {table_scores}, {table_tokens} CASCADE;
                DROP INDEX IF EXISTS {score_idx} CASCADE;
            """).format(
                table_keywords=sql.Identifier(index2_table_names['keywords']),
                table_scores=sql.Identifier(index2_table_names['scores']),
                table_tokens=sql.Identifier(index2_table_names['tokens']),
                score_idx=sql.Identifier('index_2__mara002__lmvr_scores_score_rarity_diversity_idx')
            ))

        # table 1: keywords
        log_manager.debug_global(f"Creating table {index2_table_names['keywords']} ...")
        db_cursor.execute(
            sql.SQL("""
                CREATE TABLE {table} (
                    {pk} varchar NOT NULL,
                    corpus_count int4 NOT NULL,
                    category varchar NOT NULL,
                    CONSTRAINT index_2__mara002__lmvr_keywords_pk PRIMARY KEY ({pk})
                );
            """).format(
                table=sql.Identifier(index2_table_names['keywords']),
                pk=sql.Identifier('keyword_id')))

        # table 2: texts + scores
        log_manager.debug_global(f"Creating table {index2_table_names['scores']} ...")
        db_cursor.execute(
            sql.SQL("""
                CREATE TABLE {table} (
                    {pk} varchar NOT NULL,
                    {score1} numeric NOT NULL,
                    already_annotated bool NULL,
                    selected_on timestamptz NULL,
                    CONSTRAINT index_2__mara002__lmvr_scores_pk PRIMARY KEY ({pk})
                );
                CREATE INDEX index_2__mara002__lmvr_scores_score_rarity_diversity_idx
                    ON {table} USING btree ({score1} DESC);
            """).format(
                table=sql.Identifier(index2_table_names['scores']),
                pk=sql.Identifier('docid'),
                score1=sql.Identifier('score_rarity_diversity')))

        # table 3: keywords in texts
        log_manager.debug_global(f"Creating table {index2_table_names['tokens']} ...")
        db_cursor.execute(
            sql.SQL("""
                CREATE TABLE {table} (
                    {fk_texts} varchar NOT NULL,
                    {fk_kw} varchar NOT NULL,
                    token_count int4 NOT NULL DEFAULT 0,
                    CONSTRAINT index_2__mara002__lmvr_tokens_pk PRIMARY KEY ({fk_texts}, {fk_kw}),
                    CONSTRAINT index_2__mara002__lmvr_tokens_fk FOREIGN KEY ({fk_texts})
                        REFERENCES {table_texts}({fk_texts}) ON UPDATE CASCADE ON DELETE CASCADE,
                    CONSTRAINT index_2__mara002__lmvr_tokens_fk_keyword FOREIGN KEY ({fk_kw})
                        REFERENCES {table_kw}({fk_kw}) ON UPDATE CASCADE ON DELETE CASCADE
                );
            """).format(
                table=sql.Identifier(index2_table_names['tokens']),
                table_texts=sql.Identifier(index2_table_names['scores']),
                fk_texts=sql.Identifier('docid'),
                table_kw=sql.Identifier(index2_table_names['keywords']),
                fk_kw=sql.Identifier('keyword_id')))

        db_connection.commit()

    except Exception as e:
        db_connection.rollback()
        raise e

    finally:
        db_manager.close_db_connection(db_connection, db_cursor)

    return  # TODO: Is this empty return on purpose?
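# Minimal usage sketch. The concrete table names below are assumptions inferred from the
# constraint/index names hard-coded above; in the project they presumably come from a
# config object. Note that index1_table_name is accepted but not used by create_tables.
#
#   index2_table_names = {
#       "keywords": "index_2__mara002__lmvr_keywords",
#       "scores": "index_2__mara002__lmvr_scores",
#       "tokens": "index_2__mara002__lmvr_tokens",
#   }
#   create_tables(db_config, index1_table_name, index2_table_names)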
def run():
    # get the VR info
    eval_data_container = main.load_gold_data(ConfigLoadG1)
    eval_data_container_VR = main.transform_gold_data(ConfigTransformG1VR, eval_data_container)
    df_VR = pd.DataFrame(data=[{
        "article_id": gdi.article_id,
        "VR=ja": gdi.cats['Verantwortungsreferenz'] == 1,
    } for gdi in eval_data_container_VR.gold_data_item_list])

    # get the AF info
    eval_data_container = main.load_gold_data(ConfigLoadG1)
    eval_data_container_AF = main.transform_gold_data(ConfigTransformG1AF_Part1, eval_data_container)
    # eval_data_container_AF = main.transform_gold_data(ConfigTransformG1AF_Part2, eval_data_container_AF)
    df_AF = pd.DataFrame(data=[{
        "article_id": gdi.article_id,
        "AF=SM": gdi.cats['AF: Soziale Medien'] == 1,
        "AF=SC": gdi.cats['AF: Social Companions'] == 1,
    } for gdi in eval_data_container_AF.gold_data_item_list])

    # for each text, read from the DB how many LMs it contains
    db_connection, db_cursor = db_manager.open_db_connection(
        db_config={
            "host": credentials.db_host,
            "dbname": credentials.db_name,
            "user": credentials.db_user,
            "password": credentials.db_password,
            "port": credentials.db_port
        })
    db_cursor.execute(
        sql.SQL("""
            SELECT t.docid AS id,
                   COUNT(DISTINCT t.keyword_id) AS dist,
                   SUM(t.token_count) AS total
            FROM {table_name} AS t
            WHERE t.docid = ANY( %(docid_list)s )
            GROUP BY t.docid
            ORDER BY t.docid ASC
        """).format(table_name=sql.Identifier('index_2__mara002__lmvr_tokens')),
        {
            'docid_list': [gdi.article_id for gdi in eval_data_container.gold_data_item_list],
        })
    results = db_cursor.fetchall()
    df_LM = pd.DataFrame(data=[{
        "article_id": r['id'],
        "LMs total": r['total'],
        "LMs distinct": r['dist'],
    } for r in results])

    # close db connection
    db_manager.close_db_connection(db_connection, db_cursor)

    # merge the 3 dataframes
    df = df_LM.merge(df_AF, how='outer', on='article_id')
    df = df.merge(df_VR, how='outer', on='article_id')

    # the LM table in the DB doesn't contain all texts, so we have NaN values. Replace those with 0.
    df['LMs total'] = df['LMs total'].fillna(0)
    df['LMs distinct'] = df['LMs distinct'].fillna(0)

    # define shortcuts to filter the dataframe
    maskAF = (df['AF=SC'] == True) | (df['AF=SM'] == True)
    maskVR = (df['VR=ja'] == True)

    main.log_manager.info_global(
        "--------------------------------\n"
        "Calculations complete. \n"
        "You can now access the DataFrame as `df`. \n"
        "There are 2 masks provided as `maskAF` (SC or SM) and `maskVR` (trivial). \n"
    )

    # usage example:
    # df[maskAF & maskVR]
    # df[~maskVR]

    embed()
def stream_from_db_with_lmvr_keywords(ske_config, db_config, index1_table_name,
                                      index2_table_names, ske_translation_table_name):
    log_manager.debug_global("Streaming from database (index2) ...")

    # open db connection
    db_connection = None
    db_cursor = None
    (db_connection, db_cursor) = db_manager.open_db_connection(db_config, db_connection, db_cursor)
    # Don't know where to close the DB connection!

    while True:
        db_cursor.execute(
            sql.SQL("""
                SELECT *
                FROM {idx2_table} AS idx2
                INNER JOIN {ske_table} AS ske ON ske.{ske_fk_idx2} = idx2.{idx2_fk_ske}
                INNER JOIN {idx1_table} AS idx1 ON idx1.{idx1_pk} = ske.{ske_fk_idx1}
                WHERE idx1.already_annotated = FALSE
                AND idx2.already_annotated = FALSE
                AND idx1.already_selected = FALSE
                AND ((idx1.selected_on IS NULL) OR (idx1.selected_on < (NOW() - INTERVAL '2 days')))
                AND ((idx2.selected_on IS NULL) OR (idx2.selected_on < (NOW() - INTERVAL '2 days')))
                ORDER BY idx2.score_rarity_diversity DESC
                LIMIT 1
            """).format(
                idx2_table=sql.Identifier(index2_table_names['scores']),
                idx2_fk_ske=sql.Identifier('docid'),
                ske_table=sql.Identifier(ske_translation_table_name),
                ske_fk_idx2=sql.Identifier('docid'),
                ske_fk_idx1=sql.Identifier('url_index1'),
                idx1_table=sql.Identifier(index1_table_name),
                idx1_pk=sql.Identifier('url')))
        result = db_cursor.fetchone()
        # log_manager.debug_global(f"Result={result}")

        url = result['url']
        docid = result['docid']
        # log_manager.debug_global(f"Selected text with url={url}, docid={docid}")

        # Store the information that this URL is getting selected now
        _select_text(db_connection, db_cursor, index1_table_name, 'url', url)
        _select_text(db_connection, db_cursor, index2_table_names['scores'], 'docid', docid)

        # Calculate the preselection options based on model predictions
        # (Will be empty if there are no predictions for this URL)
        options = _preselect_options(result)

        # Get this text's LMVR token counts
        db_cursor.execute(
            sql.SQL("""
                SELECT keyword_id, token_count
                FROM {tokens_table}
                WHERE docid = %(docid)s
                AND token_count > 0
            """).format(tokens_table=sql.Identifier(index2_table_names['tokens'])),
            {'docid': docid})
        lmvr_count = {
            row['keyword_id']: int(row['token_count'])
            for row in db_cursor.fetchall()
        }
        lmvr_count_text = json.dumps(lmvr_count, ensure_ascii=False, sort_keys=True)

        # retrieve the text
        ske_doc = ske_manager.get_doc_from_url(ske_config, url)

        log_manager.debug_global("  Feeding this text into prodigy ...")

        yield {
            "text": ske_doc["text"],
            "options": options['cats_as_options'],
            "accept": options['options_accepted'],
            "meta": {
                "docid": result['docid'],
                "url": url,
                "category scores": options['scores_text'],
                "LMVR count": lmvr_count_text,
                "LMVR score": result['score_rarity_diversity']
            }
        }
def index_articles(ske_config, db_config, index_table_name, trainer,
                   table_name_ref_articles, should_do_dummy_run):
    if index_table_name is None or trainer is None or table_name_ref_articles is None:
        raise Exception("Necessary objects have not been initialized.")

    db_connection, db_cursor = db_manager.open_db_connection(db_config)

    def _init_index_table():
        log_manager.info_global("Checking if index table exists or if it needs to be created")

        pos_start = None
        db_cursor.execute(
            sql.SQL("""
                SELECT * FROM information_schema.tables
                WHERE table_name = {index_table_name}
            """).format(index_table_name=sql.Literal(index_table_name)))

        if db_cursor.rowcount == 1:
            log_manager.info_global(
                "Index table exists, fetching highest pos to start / continue from")
            db_cursor.execute(
                sql.SQL("""
                    SELECT MAX(pos_mara002) FROM {index_table_name}
                """).format(index_table_name=sql.Identifier(index_table_name)))
            pos_start = db_cursor.fetchone()["max"]
            if pos_start is None:
                # So that when asking in _populate_index_table for WHERE pos > pos_start,
                # the zero element will be used too
                pos_start = -1
        else:
            log_manager.info_global(
                "Index table does not exist, creating it and starting from pos > -1")
            sql_cat_col_list = [
                sql.SQL("    {c} DECIMAL,\n").format(c=sql.Identifier(cat))
                for cat in trainer.cats
            ]
            sql_stmt_define_cat_cols = reduce(lambda a, b: a + b, sql_cat_col_list)
            sql_stmt_create_table = sql.SQL(
                "CREATE TABLE {index_table_name} (\n"
                "    pos_mara002 INT,\n"
                "{cols}"
                "    PRIMARY KEY (pos_mara002),"
                "    CONSTRAINT fkc FOREIGN KEY(pos_mara002)"
                " REFERENCES {table_name_ref_articles}(pos_mara002)"
                ")").format(
                    index_table_name=sql.Identifier(index_table_name),
                    cols=sql_stmt_define_cat_cols,
                    table_name_ref_articles=sql.Identifier(table_name_ref_articles))
            log_manager.info_global(
                f"create table sql statement:\n{sql_stmt_create_table.as_string(db_cursor)}")
            db_cursor.execute(sql_stmt_create_table)
            # So that when asking in _populate_index_table for WHERE pos > pos_start,
            # the zero element will be used too
            pos_start = -1

        db_connection.commit()
        return pos_start

    def _populate_index_table(pos_start):
        log_manager.info_global(f"Start indexing at pos > {pos_start}")

        trainer.nlp.max_length = 3000000

        if should_do_dummy_run:
            limit = sql.SQL("LIMIT 20")
        else:
            limit = sql.SQL("")

        db_cursor.execute(
            sql.SQL("""
                SELECT * FROM {table_name_ref_articles}
                WHERE pos_mara002 > {pos_start}
                ORDER BY pos_mara002
                {limit}
            """).format(
                table_name_ref_articles=sql.Identifier(table_name_ref_articles),
                pos_start=sql.Literal(pos_start),
                limit=limit))

        pos_last = ske_manager.get_last_pos(ske_config)
        pos_percent_step = int((pos_last - pos_start) / 1000)
        pos_percent_current = pos_percent_step

        for dict_row in db_cursor.fetchall():
            pos = dict_row["pos_mara002"]
            if pos > pos_percent_current:
                log_manager.info_global(
                    f"Currently at {round(pos_percent_current / pos_percent_step / 10, 1)}% ; "
                    f"at pos {pos} out of {pos_last}")
                pos_percent_current += pos_percent_step
                db_connection.commit()

            text = ske_manager.get_doc_from_pos(ske_config, pos)["text"]
            cats = trainer.nlp(text).cats

            sql_col_list = [sql.Identifier("pos_mara002")]
            sql_val_list = [sql.Literal(pos)]
            for k, v in cats.items():
                sql_col_list.append(sql.Identifier(k))
                sql_val_list.append(sql.Literal(round(v, 6)))
            sql_col_stmt = sql.SQL(", ").join(sql_col_list)
            sql_val_stmt = sql.SQL(", ").join(sql_val_list)

            db_cursor.execute(
                sql.SQL("""
                    INSERT INTO {index_table_name} ({cols})
                    VALUES ({vals})
                """).format(
                    index_table_name=sql.Identifier(index_table_name),
                    cols=sql_col_stmt,
                    vals=sql_val_stmt))

    def _main():
        try:
            pos_start = _init_index_table()
            _populate_index_table(pos_start)
            db_connection.commit()
            ske_manager.close_session()
        except Exception as e:
            db_connection.rollback()
            ske_manager.close_session()
            raise e
        finally:
            db_manager.close_db_connection(db_connection, db_cursor)

    _main()
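# Usage sketch (all concrete values are assumptions for illustration; in the project the
# table names and trainer presumably come from the Config* classes used in run() above):
#
#   index_articles(
#       ske_config=ske_config,
#       db_config=db_config,
#       index_table_name="index_1__mara002",   # hypothetical name
#       trainer=trainer1,                      # object exposing .nlp (spaCy pipeline) and .cats
#       table_name_ref_articles=ConfigIndexBase.table_name_ref_articles,
#       should_do_dummy_run=True,              # only index 20 articles as a smoke test
#   )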
def update(examples):
    log_manager.debug_global("Prodigy: updating ...")

    nonlocal db_connection
    nonlocal db_cursor
    db_connection, db_cursor = db_manager.open_db_connection(db_config, db_connection, db_cursor)
    assert db_connection and db_connection.closed == 0  # 0 means 'open'
    assert db_cursor and not db_cursor.closed

    for example in examples:
        try:
            if index1_table_name and 'url' in example['meta']:
                url = example['meta']['url']
                log_manager.debug_global(
                    f"Storing annotation meta info for url={url} in table {index1_table_name} ...")
                db_cursor.execute(
                    sql.SQL("UPDATE {index_table_name} "
                            "SET already_annotated = TRUE "
                            "WHERE {pk} = %(value)s").format(
                                index_table_name=sql.Identifier(index1_table_name),
                                pk=sql.Identifier('url')),
                    {'value': url})

            # TODO: this could be made safer to ensure
            # that index2 won't be updated accidentally with 'already_annotated'
            # when we are actually only streaming from index1.
            #
            # Currently the stream from index1 does not set 'docid' in example['meta'],
            # but this may not be good to rely on.
            if index2_table_names and 'docid' in example['meta']:
                docid = example['meta']['docid']
                log_manager.debug_global(
                    f"Storing annotation meta info for docid={docid} in table {index2_table_names['scores']} ...")
                db_cursor.execute(
                    sql.SQL("UPDATE {index_table_name} "
                            "SET already_annotated = TRUE "
                            "WHERE {pk} = %(value)s").format(
                                index_table_name=sql.Identifier(index2_table_names['scores']),
                                pk=sql.Identifier('docid')),
                    {'value': docid})

            db_connection.commit()
        except Exception as ex:
            log_manager.info_global(f"Error storing an annotation in the database: {ex}")
            db_connection.rollback()
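# Wiring sketch (an assumption, since the enclosing recipe is not shown here): update()
# uses nonlocal, so it is defined inside a Prodigy recipe function and registered via the
# recipe's "update" component, e.g.:
#
#   return {
#       "dataset": dataset,
#       "stream": stream,
#       "update": update,      # called with batches of annotated examples
#       "view_id": "choice",
#   }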
def run(ske_config, db_config, docid_table_name, index1_table_name,
        index2_table_names, should_drop_create_table=False):
    (db_connection, db_cursor) = db_manager.open_db_connection(db_config)

    if should_drop_create_table:
        create_table(db_connection, db_cursor, docid_table_name)

    # Direction 1: look for URLs that are not yet in the translation table.
    # Hannes says that pos -> docid is faster than docid -> pos
    # because the SKE uses pos as internal indices.
    log_manager.debug_global("Looking for URLs ...")
    url_records = select_urls_from_index1(db_cursor, docid_table_name, index1_table_name)
    log_manager.info_global(f"Found {len(url_records)} URLs to be converted.")

    if len(url_records) > 0:
        ske_manager.create_session(ske_config)
        progressbar = progress.bar.Bar(
            'Converting URLs to docid',
            max=len(url_records),
            suffix='%(index)d/%(max)d done, ETA: %(eta_td)s h')

        for record in url_records:
            url = record['url']
            pos = ske_manager.get_pos_from_url(url)
            docid = ske_manager.get_docid_from_pos(ske_config, pos)  # this calls the API endpoint 'fullref'
            insert_into_table(db_connection, db_cursor, docid_table_name, docid, pos, url)
            progressbar.next()

        progressbar.finish()

    # Direction 2: look for docids that are not yet in the translation table
    log_manager.debug_global("Looking for docids ...")
    docid_records = select_docids_from_index2(db_cursor, docid_table_name, index2_table_names)
    log_manager.debug_global(f"Found {len(docid_records)} docids to be converted.")

    if len(docid_records) > 0:
        ske_manager.create_session(ske_config)
        progressbar = progress.bar.Bar(
            'Converting docids to URLs',
            max=len(docid_records),
            suffix='%(index)d/%(max)d done, ETA: %(eta_td)s h')

        for record in docid_records:
            docid = record['docid']
            pos = ske_manager.get_pos_from_docid(ske_config, docid)  # this calls the API endpoint 'first'
            url = ske_manager.get_url_from_pos(ske_config, pos)
            insert_into_table(db_connection, db_cursor, docid_table_name, docid, pos, url)
            progressbar.next()

        progressbar.finish()

    # All set!
    ske_manager.close_session()
    db_manager.close_db_connection(db_connection, db_cursor)

    return