def insert_records(records, table_class_name, db_file_path):
    """Insert ``records`` into the table named ``table_class_name``.

    Opens an engine on ``db_file_path`` and returns the ids assigned to the
    newly inserted rows.
    """
    db_engine = get_engine(db_file_path)
    _, new_row_ids = add_records(
        records, db_engine, table_class_name, returns_id=True)
    return new_row_ids
def from_db(cls, db_file_path: str):
    """Build the dataset from the misinfo table of the given database.

    Fixes vs. original: the annotation said ``StringIO`` but the value is
    passed straight to ``get_engine`` like every other ``db_file_path``
    string in this module, so it is annotated as ``str``; the connection is
    now closed in a ``finally`` so it is not leaked when a row fails to
    parse.

    Parameters:
        db_file_path: path to the SQLite database file.

    Returns:
        ``cls(misconceptions, sentences, uid)`` where ``uid`` is an int
        derived from the MD5 of all ``misc`` JSON blobs (a content
        fingerprint of the table, not a security hash).
    """
    misconceptions = []
    sentences = []
    md5_hash = md5()
    engine = get_engine(db_file_path)
    connection = engine.connect()
    try:
        misinfos = get_misinfo(engine, connection)
        for misinfo in misinfos:
            md5_hash.update(misinfo['misc'].encode('utf-8'))
            obj = json.loads(misinfo['misc'])
            # Freeze the nested lists so the dataset is immutable/hashable.
            immutable_obj = {k: _tuplify(v) for k, v in obj.items()}
            misconception = Misconception(
                misinfo['id'], misinfo['text'], misinfo['url'],
                immutable_obj['category'], immutable_obj['pos_variations'],
                immutable_obj['neg_variations'],
                immutable_obj['reliability_score'], misinfo['source'])
            # Rows without any positive variation are skipped entirely.
            if misconception.pos_variations:
                misconceptions.append(misconception)
                for sentence in misconception.pos_variations:
                    sentences.append(sentence)
    finally:
        connection.close()
    misconceptions = tuple(misconceptions)
    sentences = tuple(sentences)
    digest = md5_hash.digest()
    uid = int.from_bytes(digest, byteorder='big')
    return cls(misconceptions, sentences, uid)
def get_annotated_data(db, annotated_data):
    """Load gold annotations from the DB into a DataFrame.

    Fix vs. original: the connection is closed in a ``finally`` block so it
    is no longer leaked when ``get_outputs`` or the row iteration raises.

    Parameters:
        db: database file path, forwarded to ``get_engine``.
        annotated_data: identifier of the annotation set, forwarded to
            ``get_outputs``.

    Returns:
        DataFrame with columns ``input_id``, ``gold_label`` (int), and
        ``misinfo_id``.
    """
    input_ids = []
    misinfo_ids = []
    labels = []
    engine = get_engine(db)
    connection = engine.connect()
    try:
        annotated = get_outputs(engine, connection, annotated_data)
        for a in annotated:
            input_ids.append(a.input_id)
            misinfo_ids.append(a.misinfo_id)
            labels.append(a.label_id)
    finally:
        connection.close()
    annotated = pd.DataFrame({
        'input_id': input_ids,
        'gold_label': labels,
        'misinfo_id': misinfo_ids,
    })
    annotated['gold_label'] = annotated.gold_label.astype(int)
    return annotated
def get_related_article_urls(news_api_client, news_api_config, max_tol,
                             category, db_file_path):
    """Page through the news API collecting article metadata, then persist it.

    Keeps fetching pages until the reported total (``totalResults``) is
    reached, the API signals ``maximumResultsReached``, or more than
    ``max_tol`` consecutive-ish failures occur.  Collected dicts
    (url/title/publishedAt) are written via ``update_article_url_db`` and
    the resulting article URLs are returned.
    """
    article_dict_list = list()
    endpoint = news_api_config['endpoint']
    params_config = news_api_config['params']
    num_hits = -1  # -1 = total unknown until the first successful fetch
    article_count = 0
    page_count = 1
    failure_count = 0
    while num_hits == -1 or article_count < num_hits:
        try:
            result = news_api_client.fetch(
                endpoint, page=page_count, **params_config)
            # API-level error for paging past the plan's limit: stop cleanly.
            if result['status'] == 'error' and result['code'] == 'maximumResultsReached':
                break
            num_hits = result['totalResults']
            articles = result['articles']
            article_count += len(articles)
            for article in articles:
                article_dict_list.append({
                    'url': article['url'],
                    'title': article['title'],
                    # publishedAt is optional in the payload; default to ''
                    'publishedAt': article.get('publishedAt', '')
                })
        except Exception:
            # Deliberate best-effort: count the failure, give up past
            # max_tol, and rebuild the client (presumably to reset its
            # session/credentials — TODO confirm against NewsApiClient).
            failure_count += 1
            if failure_count > max_tol:
                break
            news_api_client = NewsApiClient()
        # NOTE(review): page advances even after a failed fetch, so a
        # failed page is skipped rather than retried — confirm intended.
        page_count += 1
    engine = get_engine(db_file_path, echo=False)
    article_urls = update_article_url_db(article_dict_list, category, engine)
    return article_urls
def insert_records(records, table_class_name, db_file_path, output_file_path):
    """Insert ``records`` and dump the new row ids to ``output_file_path``.

    The output file gets one id per line; parent directories are created as
    needed.
    """
    engine = get_engine(db_file_path)
    _, record_ids = add_records(
        records, engine, table_class_name, returns_id=True)
    out_path = Path(output_file_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Identical bytes to the original write loop: "<id>\n" per record.
    out_path.write_text(''.join('{}\n'.format(rid) for rid in record_ids))
def get_preds_db(db, model_name):
    """Load ``model_name``'s predictions from the DB into a DataFrame.

    For similarity models the confidences are ranked per ``input_id``
    (descending) into a ``rank`` column; otherwise ``label_id`` is cast to
    int.
    """
    input_id = []
    labels = []
    mid = []
    confidence = []
    if model_name in STANCE_MODELS:
        # NOTE(review): pos/neg are filled from p.misc below but are never
        # added to pred_df, so this collection is currently dead weight —
        # confirm whether a 'misc'/'pos'/'neg' column was intended.
        pos = []
        neg = []
    engine = get_engine(db)
    connection = engine.connect()
    preds = get_outputs(engine, connection, model_name)
    # Load predictions
    for p in preds:
        input_id.append(p.input_id)
        mid.append(p.misinfo_id)
        labels.append(p.label_id)
        confidence.append(p.confidence)
        if model_name in STANCE_MODELS:
            # presumably misc = [p_agree, p_disagree] — TODO confirm writer
            pos.append(p.misc[0])
            neg.append(p.misc[1])
    connection.close()
    # Create data frame
    pred_df = pd.DataFrame({
        'input_id': input_id,
        'model_id': model_name,
        'label_id': labels,
        'misinfo_id': mid,
        'confidence': confidence,
    })
    if model_name in SIMILARITY_MODELS:
        # Rank misconceptions per input by confidence, highest = rank 1.
        grouped = pred_df.groupby('input_id')
        ranked_confidence = grouped['confidence'].rank(ascending=False)
        pred_df['rank'] = ranked_confidence
    else:
        pred_df['label_id'] = pred_df.label_id.astype(int)
    return pred_df
def main():
    """CLI entry point: score every input post against every misconception.

    Loads the model selected by --model_name (similarity, entailment, or a
    stacked "comb-*" relevance+stance pipeline), reads misconceptions and
    input posts from --db_input, and writes per-(post, misconception)
    predictions to --file (CSV) and/or --db_output.

    Fix vs. original: the eval/cuda model-name list below was missing a
    comma between 'comb-sbert-scifact-ct' and 'sbert-ann-ct', so the two
    literals were implicitly concatenated into one bogus entry and neither
    model was ever put into eval mode or moved to the GPU.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    # Required for entailment models and pyserini retrieval
    parser.add_argument('--model_dir', type=Path)
    parser.add_argument('--db_input', type=str, required=True)
    parser.add_argument('--file', type=Path, required=False)
    parser.add_argument('--db_output', type=str, required=False)
    args = parser.parse_args()
    model_name = args.model_name
    db_input = args.db_input

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    # Similarity
    if model_name == 'glove-avg-cosine':
        model = GloVeCosine()
    if model_name == 'bert-base-cosine':
        model = BERTCosine('base')
    if model_name == 'ct-bert-cosine':
        model = BERTCosine('ct-bert')
    if model_name == 'bert-score-base':
        model = BertScoreDetector('bert-large-uncased')
    if model_name == 'bert-score-ft':
        model = BertScoreDetector('models/roberta-ckpt/covid-roberta')
    if model_name in [
            'bert-score-ct', 'comb-bilstm-snli', 'comb-bilstm-mnli',
            'comb-bilstm-mednli', 'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
            'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
    ]:
        # Every 'comb-*' model uses CT-BERT BertScore as its relevance stage.
        model = BertScoreDetector('digitalepidemiologylab/covid-twitter-bert')
    if model_name == 'pyserini':
        model_dir = str(args.model_dir)
        searcher = SimpleSearcher(model_dir)

    # Entailment
    # BoW Logistic
    if model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
        model = BoWLogistic(args.model_dir)
    # BoE Logistic
    if model_name in ['boe-log-snli', 'boe-log-mnli', 'boe-log-mednli']:
        model = BoELogistic(args.model_dir)
    if model_name == 'sbert-mnli':
        model_path = os.path.join('/', args.model_dir, 'mnli-sbert-ckpt-1.pt')
        model = load_sbert_model('bert-base-cased', model_path)
    if model_name == 'sbert-snli':
        model_path = os.path.join('/', args.model_dir, 'snli-sbert-ckpt-2.pt')
        model = load_sbert_model('bert-base-cased', model_path)
    if model_name == 'sbert-mednli':
        model_path = os.path.join('/', args.model_dir, 'mednli-sbert-ckpt-5.pt')
        model = load_sbert_model('bert-base-cased', model_path)
    # For 'comb-*' names the entailment net is the SECOND stage (model2);
    # for plain 'sbert-*' names it is THE model.
    if model_name in ['sbert-snli-ct', 'comb-sbert-snli-ct']:
        model_path = os.path.join('/', args.model_dir, 'snli-sbert-ct-ckpt-2.pt')
        if model_name == 'sbert-snli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-snli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in ['sbert-mnli-ct', 'comb-sbert-mnli-ct']:
        model_path = os.path.join('/', args.model_dir, 'mnli-sbert-ct-ckpt-2.pt')
        if model_name == 'sbert-mnli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-mnli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in ['sbert-mednli-ct', 'comb-sbert-mednli-ct']:
        model_path = os.path.join('/', args.model_dir, 'mednli-sbert-ct-ckpt-7.pt')
        if model_name == 'sbert-mednli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-mednli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in ['sbert-ann-ct', 'comb-sbert-ann-ct']:
        # model_path = os.path.join('/', args.model_dir, 'ann-sbert-ct-9.pt')
        model_path = os.path.join('/', args.model_dir, 'ann-ch-sbert-ct-3.pt')
        if model_name == 'sbert-ann-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-ann-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in ['sbert-fever-ct', 'comb-sbert-fever-ct']:
        model_path = os.path.join('/', args.model_dir, 'fever-sbert-ct-1.pt')
        if model_name == 'sbert-fever-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-fever-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in ['sbert-scifact-ct', 'comb-sbert-scifact-ct']:
        model_path = os.path.join('/', args.model_dir, 'scifact-sbert-ct-2.pt')
        if model_name == 'sbert-scifact-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-scifact-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
    if model_name in [
            'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli', 'comb-bilstm-snli',
            'comb-bilstm-mnli', 'comb-bilstm-mednli'
    ]:
        model_path = os.path.join('/', args.model_dir, 'bilstm.pt')
        field_path = os.path.join('/', args.model_dir, 'bilstm-field.pt')
        if model_name in ['bilstm-snli', 'bilstm-mnli', 'bilstm-mednli']:
            model = load_bilstm(model_path, field_path)
        elif model_name in [
                'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli'
        ]:
            model2 = load_bilstm(model_path, field_path)

    # Put torch models in eval mode and move them to GPU when available.
    if model_name in [
            'bert-score-base', 'bert-score-ft', 'bert-score-ct',
            'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli', 'sbert-snli',
            'sbert-mnli', 'sbert-mednli', 'sbert-snli-ct', 'sbert-mnli-ct',
            'sbert-mednli-ct', 'comb-bilstm-snli', 'comb-bilstm-mnli',
            'comb-bilstm-mednli', 'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'sbert-fever-ct', 'comb-sbert-fever-ct',
            'sbert-scifact-ct', 'comb-sbert-scifact-ct',
            # ^ comma restored here: the original implicitly concatenated
            # 'comb-sbert-scifact-ct' with 'sbert-ann-ct'.
            'sbert-ann-ct', 'comb-sbert-ann-ct'
    ]:
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
    if model_name in [
            'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
            'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
            'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
    ]:
        model2.eval()
        if torch.cuda.is_available():
            model2.cuda()
    if model_name in ['bert-base-cosine', 'ct-bert-cosine']:
        # BERTCosine wraps the transformer in a .model attribute.
        model.model.eval()
        if torch.cuda.is_available():
            model.model.cuda()

    # ------------------------------------------------------------------
    # Read misinformation
    # ------------------------------------------------------------------
    misinfos = MisconceptionDataset.from_db(db_input)
    mis = []
    mid = []
    for misinfo in misinfos:
        # First positive variation is the canonical phrasing.
        mis.append(misinfo.pos_variations[0])
        mid.append(misinfo.id)
    n = len(mid)

    # Encode misconceptions once up front (encoding independent of the post).
    if model_name in [
            'bert-base-cosine', 'bert-ft-cosine', 'ct-bert-cosine',
            'bert-score-base', 'bert-score-ft', 'bert-score-ct',
            'glove-avg-cosine', 'boe-log-snli', 'boe-log-mnli',
            'boe-log-mednli', 'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli',
            'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
            'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
            'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
    ]:
        mis_vect = model._encode(mis)
    elif model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
        mx = model._encode(mis, 'hyp')

    # ------------------------------------------------------------------
    # Predict over all input posts (tweets)
    # ------------------------------------------------------------------
    engine = get_engine(db_input)
    connection = engine.connect()
    inputs = get_inputs(engine, connection)
    output = []
    for input in inputs:
        print(input['id'])
        posts = [input['text']] * n
        post_ids = [input['id']] * n

        # Similarity Models
        # BoW Cosine: vocabulary depends on the post, so rebuild per input.
        if model_name == 'bow-cosine':
            corpus = mis + [input['text']]
            model = BoWCosine(corpus)
            mis_vect = model._encode(mis)
        if model_name in [
                'bow-cosine', 'bert-base-cosine', 'bert-ft-cosine',
                'ct-bert-cosine', 'bert-score-base', 'bert-score-ft',
                'bert-score-ct', 'glove-avg-cosine'
        ]:
            post_vect = model._encode([input['text']])
            score = model._score(post_vect, mis_vect)
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': 'n/a',
                'misinfo_id': mid,
                'confidence': score[0]
            }).to_dict('records')
        if model_name == 'pyserini':
            hits = searcher.search(input['text'], k=86)
            mid2 = []
            score = []
            for i in range(0, len(hits)):
                mid2.append(int(hits[i].docid))
                score.append(hits[i].score)
            # Misconceptions the searcher didn't return get confidence 0.
            missed = set(mid) - set(mid2)
            missed = list(missed)
            for m in missed:
                mid2.append(m)
                score.append(0)
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': 'n/a',
                'misinfo_id': mid2,
                'confidence': score
            }).to_dict('records')

        # Entailment Models
        # BoW Logistic
        if model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
            px = model._encode(posts, 'prem')
            preds, probs = model._predict(px, mx)
            max_probs = probs.max(axis=1)
        # BoE Logistic
        if model_name in ['boe-log-snli', 'boe-log-mnli', 'boe-log-mednli']:
            post_vect = model._encode([input['text']]) * n
            preds, probs = model._predict(post_vect, mis_vect)
            max_probs = probs.max(axis=1)
        # BiLSTM
        if model_name in [
                'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli', 'bilstm-fnc'
        ]:
            post_vect = model._encode([input['text']] * n)
            with torch.no_grad():
                logits = model(post_vect, mis_vect)
            probs = sm(logits)
            max = probs.max(dim=-1)
            max_probs = max[0].tolist()
            preds = max[1].tolist()
        # SBERT
        if model_name in [
                'sbert-snli', 'sbert-mnli', 'sbert-mednli', 'sbert-mednli-ct',
                'sbert-mnli-ct', 'sbert-snli-ct', 'sbert-fever-ct',
                'sbert-scifact-ct', 'sbert-ann-ct'
        ]:
            with torch.no_grad():
                logits = model(posts, mis)
            _, preds = logits.max(dim=-1)
            probs = sm(logits)
            max_probs, _ = probs.max(dim=-1)

        # Stacked: BertScore relevance stage, then BiLSTM/SBERT stance stage
        if model_name in [
                'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
                'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
                'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
                'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
        ]:
            ## BertScore-DA
            post_vect = model._encode([input['text']])
            score = model._score(post_vect, mis_vect)
            bert_score = score[0]
            ## Relevance Classification (3 = relevant, 2 = not relevant)
            rel_preds = [3 if x >= 0.4 else 2 for x in bert_score]
            df = pd.DataFrame({
                'input_id': post_ids,
                'input': posts,
                'model_id': model_name,
                'rel_label_id': rel_preds,
                'misinfo': mis,
                'misinfo_id': mid,
                'confidence': bert_score
            })
            ## Agree/Disagree Classification on the relevant pairs only
            relevant = df[df.rel_label_id == 3]
            if relevant.shape[0] > 0:
                post = relevant.input.tolist()
                misinfo = relevant.misinfo.tolist()
                if model_name in [
                        'comb-bilstm-snli', 'comb-bilstm-mnli',
                        'comb-bilstm-mednli'
                ]:
                    # BiLSTM stance model needs pre-encoded inputs.
                    post = model2._encode(post)
                    misinfo = model2._encode(misinfo)
                with torch.no_grad():
                    logits = model2(post, misinfo)
                probs = sm(logits)
                relevant['probs'] = probs.tolist()
                relevant = relevant[['input_id', 'misinfo_id', 'probs']]
                df = df.merge(
                    relevant,
                    how='left',
                    left_on=['input_id', 'misinfo_id'],
                    right_on=['input_id', 'misinfo_id'])
                df['label_id'] = df.apply(classify_agree_disagree, axis=1)
            else:
                # Nothing relevant: keep the relevance labels as final.
                df['label_id'] = df['rel_label_id']
            df = df[[
                'input_id', 'model_id', 'label_id', 'misinfo_id', 'confidence'
            ]].to_dict('records')

        # Output
        if model_name in [
                'bow-cosine', 'bert-base-cosine', 'bert-ft-cosine',
                'ct-bert-cosine', 'bert-score-base', 'bert-score-ft',
                'bert-score-ct', 'glove-avg-cosine'
        ]:
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': 'n/a',
                'misinfo_id': mid,
                'confidence': score[0]
            }).to_dict('records')
        elif model_name in [
                'bow-log-snli', 'bow-log-mnli', 'bow-log-mednli',
                'boe-log-snli', 'boe-log-mnli', 'boe-log-mednli',
                'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli'
        ]:
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': preds,
                'misinfo_id': mid,
                'confidence': max_probs,
                'misc': probs.tolist()
            }).to_dict('records')
        elif model_name in [
                'sbert-snli', 'sbert-mnli', 'sbert-mednli', 'sbert-mednli-ct',
                'sbert-snli-ct', 'sbert-mnli-ct', 'sbert-fever-ct',
                'sbert-scifact-ct', 'sbert-ann-ct'
        ]:
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': preds.tolist(),
                'misinfo_id': mid,
                'confidence': max_probs.tolist(),
                'misc': probs.tolist()
            }).to_dict('records')
        output.append(df)
    connection.close()

    # Write to File
    if args.file:
        output_df = pd.DataFrame()
        for o in output:
            df = pd.DataFrame.from_records(o)
            output_df = pd.concat([output_df, df], ignore_index=True)
        output_df.to_csv(args.file)
        print('Writing predictions to file is complete...')

    # Write to DB
    if args.db_output:
        engine = get_engine(args.db_output)
        connection = engine.connect()
        for i in range(len(output)):
            print('Writing to DB:', i)
            put_outputs(output[i], engine)
        connection.close()
        print('Writing predictions to DB is complete...')
def insert_records(records_dict, table_class_name, db_file_path):
    """Insert the values of ``records_dict`` into ``table_class_name``.

    Row ids are not requested and nothing is returned.
    """
    engine = get_engine(db_file_path)
    add_records(records_dict.values(), engine, table_class_name)