示例#1
0
def insert_records(records, table_class_name, db_file_path):
    """Insert ``records`` into the table ``table_class_name`` of a database.

    Args:
        records: Rows handed unchanged to ``add_records``.
        table_class_name: Name of the target table class.
        db_file_path: Database location passed to ``get_engine``.

    Returns:
        The ids reported by ``add_records`` for the inserted rows (the second
        element of its ``returns_id=True`` result).
    """
    db_engine = get_engine(db_file_path)
    _unused, inserted_ids = add_records(
        records, db_engine, table_class_name, returns_id=True)
    return inserted_ids
示例#2
0
    def from_db(cls, db_file_path: str):
        """Build the dataset from the misinformation rows of a database.

        Each row returned by ``get_misinfo`` has its ``misc`` JSON payload
        decoded, and every value is passed through ``_tuplify``. Presumably
        this makes list values immutable, per the ``immutable_obj`` naming;
        verify against ``_tuplify``. The result is used to build a
        ``Misconception``.

        Misconceptions without positive variations are not kept. Every
        positive variation is also collected into a flat ``sentences`` tuple.

        The dataset uid is the MD5 digest of every row's raw ``misc`` string,
        in query order and including rows that are not kept, read as a
        big-endian integer.

        Args:
            db_file_path: Database location handed to ``get_engine``.

        Returns:
            ``cls(misconceptions, sentences, uid)`` where the first two are
            tuples.
        """
        misconceptions = []
        sentences = []
        md5_hash = md5()

        engine = get_engine(db_file_path)
        connection = engine.connect()
        try:
            # Rows are consumed while the connection is still open.
            for misinfo in get_misinfo(engine, connection):
                md5_hash.update(misinfo['misc'].encode('utf-8'))
                obj = json.loads(misinfo['misc'])
                immutable_obj = {k: _tuplify(v) for k, v in obj.items()}
                misconception = Misconception(
                    misinfo['id'], misinfo['text'], misinfo['url'],
                    immutable_obj['category'], immutable_obj['pos_variations'],
                    immutable_obj['neg_variations'],
                    immutable_obj['reliability_score'], misinfo['source'])

                if misconception.pos_variations:
                    misconceptions.append(misconception)
                sentences.extend(misconception.pos_variations)
        finally:
            # Release the connection even if a row fails to decode.
            connection.close()

        uid = int.from_bytes(md5_hash.digest(), byteorder='big')
        return cls(tuple(misconceptions), tuple(sentences), uid)
示例#3
0
def get_annotated_data(db, annotated_data):
    """Load annotated (gold) labels from ``db`` into a DataFrame.

    Args:
        db: Database location passed to ``get_engine``.
        annotated_data: Identifier passed to ``get_outputs`` to select the
            annotation rows.

    Returns:
        DataFrame with columns ``input_id``, ``gold_label`` (cast to int) and
        ``misinfo_id``, one row per record returned by ``get_outputs``.
    """
    input_ids = []
    misinfo_ids = []
    labels = []

    engine = get_engine(db)
    connection = engine.connect()
    try:
        # Read every row while the connection is open.
        for a in get_outputs(engine, connection, annotated_data):
            input_ids.append(a.input_id)
            misinfo_ids.append(a.misinfo_id)
            labels.append(a.label_id)
    finally:
        # Always release the connection, even if the query/iteration fails.
        connection.close()

    annotated = pd.DataFrame({
        'input_id': input_ids,
        'gold_label': labels,
        'misinfo_id': misinfo_ids,
    })

    annotated['gold_label'] = annotated.gold_label.astype(int)

    return annotated
示例#4
0
def get_related_article_urls(news_api_client, news_api_config, max_tol, category, db_file_path):
    """Page through a news API endpoint and store the returned article URLs.

    Pages are requested starting at 1 until one of these happens:
    - as many articles as the reported ``totalResults`` have been seen;
    - the API reports ``maximumResultsReached``;
    - a page comes back empty;
    - more than ``max_tol`` requests have failed.

    After each failure the client is replaced by a fresh ``NewsApiClient()``
    and the next page is tried.

    Args:
        news_api_client: Object exposing ``fetch(endpoint, page=..., **params)``.
        news_api_config: Dict with ``'endpoint'`` and ``'params'`` entries.
        max_tol: Number of failed requests tolerated before giving up.
        category: Passed through to ``update_article_url_db``.
        db_file_path: Database location passed to ``get_engine``.

    Returns:
        Whatever ``update_article_url_db`` returns for the collected articles.
    """
    article_dict_list = []
    endpoint = news_api_config['endpoint']
    params_config = news_api_config['params']
    num_hits = -1  # -1: total number of results not known yet
    article_count = 0
    page_count = 1
    failure_count = 0
    while num_hits == -1 or article_count < num_hits:
        try:
            result = news_api_client.fetch(endpoint, page=page_count, **params_config)
            if result['status'] == 'error' and result['code'] == 'maximumResultsReached':
                break

            num_hits = result['totalResults']
            articles = result['articles']
            if not articles:
                # No more data even though totalResults claims otherwise;
                # without this the loop would request pages forever.
                break
            article_count += len(articles)
            for article in articles:
                article_dict_list.append({'url': article['url'], 'title': article['title'],
                                          'publishedAt': article.get('publishedAt', '')})
        except Exception:
            # Best-effort: tolerate up to max_tol failures, recreating the
            # client each time.
            failure_count += 1
            if failure_count > max_tol:
                break
            news_api_client = NewsApiClient()
        page_count += 1

    engine = get_engine(db_file_path, echo=False)
    article_urls = update_article_url_db(article_dict_list, category, engine)
    return article_urls
示例#5
0
def insert_records(records, table_class_name, db_file_path, output_file_path):
    """Insert ``records`` into a table and write the new ids to a text file.

    Parent directories of ``output_file_path`` are created if missing. The
    file is overwritten with one inserted id per line.

    Args:
        records: Rows handed unchanged to ``add_records``.
        table_class_name: Name of the target table class.
        db_file_path: Database location passed to ``get_engine``.
        output_file_path: Destination for the id list.
    """
    db_engine = get_engine(db_file_path)
    _unused, inserted_ids = add_records(
        records, db_engine, table_class_name, returns_id=True)

    destination = Path(output_file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, 'w') as out_file:
        out_file.writelines('{}\n'.format(rid) for rid in inserted_ids)
示例#6
0
def get_preds_db(db, model_name):
    """Load the stored predictions of ``model_name`` from ``db``.

    Args:
        db: Database location passed to ``get_engine``.
        model_name: Model identifier passed to ``get_outputs``. It is also
            stored in the ``model_id`` column.

    Returns:
        DataFrame with columns ``input_id``, ``model_id``, ``label_id``,
        ``misinfo_id`` and ``confidence``.

        - For models in ``SIMILARITY_MODELS`` a ``rank`` column is added. It
          is the descending rank of ``confidence`` within each ``input_id``
          group (pandas default tie handling: average).
        - For all other models ``label_id`` is cast to int.
    """
    input_id = []
    labels = []
    mid = []
    confidence = []

    engine = get_engine(db)
    connection = engine.connect()
    try:
        # Load predictions while the connection is open.
        for p in get_outputs(engine, connection, model_name):
            input_id.append(p.input_id)
            mid.append(p.misinfo_id)
            labels.append(p.label_id)
            confidence.append(p.confidence)
    finally:
        # Always release the connection, even if the query/iteration fails.
        connection.close()

    # Create data frame
    pred_df = pd.DataFrame({
        'input_id': input_id,
        'model_id': model_name,
        'label_id': labels,
        'misinfo_id': mid,
        'confidence': confidence,
    })

    if model_name in SIMILARITY_MODELS:
        grouped = pred_df.groupby('input_id')
        pred_df['rank'] = grouped['confidence'].rank(ascending=False)
    else:
        pred_df['label_id'] = pred_df.label_id.astype(int)
    return pred_df
示例#7
0
def main():
    """Run one misinformation-detection model over every input in a database.

    Command-line arguments:
        --model_name  (required) which similarity / entailment / stacked model
                      to run; selects the setup branches below.
        --model_dir   checkpoint directory; used by the entailment models and
                      by pyserini retrieval.
        --db_input    (required) database holding both the misconceptions and
                      the inputs to score.
        --file        optional CSV path the predictions are written to.
        --db_output   optional database the predictions are written to.

    Each input is scored against the first positive variation of every
    misconception. This produces one record per (input, misconception) pair
    with keys ``input_id``, ``model_id``, ``label_id``, ``misinfo_id`` and
    ``confidence``; the non-stacked entailment models also add ``misc``.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument(
        '--model_dir',
        type=Path)  # Required for entailment models and pyserini retrieval
    parser.add_argument('--db_input', type=str, required=True)
    parser.add_argument('--file', type=Path, required=False)
    parser.add_argument('--db_output', type=str, required=False)

    args = parser.parse_args()
    model_name = args.model_name
    db_input = args.db_input

    # NOTE(review): checkpoint paths below are built with
    # os.path.join('/', args.model_dir, ...). This turns a *relative*
    # --model_dir into a path under the filesystem root. Confirm that is
    # intended (e.g. a container layout) before passing relative paths.

    # Similarity
    if model_name == 'glove-avg-cosine':
        model = GloVeCosine()

    if model_name == 'bert-base-cosine':
        model = BERTCosine('base')

    if model_name == 'ct-bert-cosine':
        model = BERTCosine('ct-bert')

    if model_name == 'bert-score-base':
        model = BertScoreDetector('bert-large-uncased')

    if model_name == 'bert-score-ft':
        model = BertScoreDetector('models/roberta-ckpt/covid-roberta')

    # The stacked ("comb-*") models use BERTScore as their first stage.
    if model_name in [
            'bert-score-ct', 'comb-bilstm-snli', 'comb-bilstm-mnli',
            'comb-bilstm-mednli', 'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
            'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
    ]:
        model = BertScoreDetector('digitalepidemiologylab/covid-twitter-bert')

    if model_name == 'pyserini':
        model_dir = str(args.model_dir)
        searcher = SimpleSearcher(model_dir)

    # Entailment
    # BoW Logistic
    if model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
        model = BoWLogistic(args.model_dir)

    # BoE Logistic
    if model_name in ['boe-log-snli', 'boe-log-mnli', 'boe-log-mednli']:
        model = BoELogistic(args.model_dir)

    if model_name == 'sbert-mnli':
        model_path = os.path.join('/', args.model_dir, 'mnli-sbert-ckpt-1.pt')
        model = load_sbert_model('bert-base-cased', model_path)

    if model_name == 'sbert-snli':
        model_path = os.path.join('/', args.model_dir, 'snli-sbert-ckpt-2.pt')
        model = load_sbert_model('bert-base-cased', model_path)

    if model_name == 'sbert-mednli':
        model_path = os.path.join('/', args.model_dir,
                                  'mednli-sbert-ckpt-5.pt')
        model = load_sbert_model('bert-base-cased', model_path)

    # For "comb-*" names the entailment model is the second stage (model2);
    # otherwise it is the only model.
    if model_name in ['sbert-snli-ct', 'comb-sbert-snli-ct']:
        model_path = os.path.join('/', args.model_dir,
                                  'snli-sbert-ct-ckpt-2.pt')
        if model_name == 'sbert-snli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-snli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in ['sbert-mnli-ct', 'comb-sbert-mnli-ct']:
        model_path = os.path.join('/', args.model_dir,
                                  'mnli-sbert-ct-ckpt-2.pt')
        if model_name == 'sbert-mnli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-mnli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in ['sbert-mednli-ct', 'comb-sbert-mednli-ct']:
        model_path = os.path.join('/', args.model_dir,
                                  'mednli-sbert-ct-ckpt-7.pt')
        if model_name == 'sbert-mednli-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-mednli-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in ['sbert-ann-ct', 'comb-sbert-ann-ct']:
        model_path = os.path.join('/', args.model_dir, 'ann-ch-sbert-ct-3.pt')
        if model_name == 'sbert-ann-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-ann-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in ['sbert-fever-ct', 'comb-sbert-fever-ct']:
        model_path = os.path.join('/', args.model_dir, 'fever-sbert-ct-1.pt')
        if model_name == 'sbert-fever-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-fever-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in ['sbert-scifact-ct', 'comb-sbert-scifact-ct']:
        model_path = os.path.join('/', args.model_dir, 'scifact-sbert-ct-2.pt')
        if model_name == 'sbert-scifact-ct':
            model = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)
        elif model_name == 'comb-sbert-scifact-ct':
            model2 = load_sbert_model(
                'digitalepidemiologylab/covid-twitter-bert', model_path)

    if model_name in [
            'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli', 'comb-bilstm-snli',
            'comb-bilstm-mnli', 'comb-bilstm-mednli'
    ]:
        model_path = os.path.join('/', args.model_dir, 'bilstm.pt')
        field_path = os.path.join('/', args.model_dir, 'bilstm-field.pt')
        if model_name in ['bilstm-snli', 'bilstm-mnli', 'bilstm-mednli']:
            model = load_bilstm(model_path, field_path)
        elif model_name in [
                'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli'
        ]:
            model2 = load_bilstm(model_path, field_path)

    # Switch the loaded model(s) to eval mode and move them to the GPU when
    # one is available. Every entry must be comma-separated: adjacent string
    # literals would silently concatenate and drop models from this list.
    if model_name in [
            'bert-score-base', 'bert-score-ft', 'bert-score-ct', 'bilstm-snli',
            'bilstm-mnli', 'bilstm-mednli', 'sbert-snli', 'sbert-mnli',
            'sbert-mednli', 'sbert-snli-ct', 'sbert-mnli-ct',
            'sbert-mednli-ct', 'comb-bilstm-snli', 'comb-bilstm-mnli',
            'comb-bilstm-mednli', 'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
            'comb-sbert-mednli-ct', 'sbert-fever-ct', 'comb-sbert-fever-ct',
            'sbert-scifact-ct', 'comb-sbert-scifact-ct',
            'sbert-ann-ct', 'comb-sbert-ann-ct'
    ]:
        model.eval()
        if torch.cuda.is_available():
            model.cuda()

        if model_name in [
                'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
                'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
                'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
                'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
        ]:
            model2.eval()
            if torch.cuda.is_available():
                model2.cuda()

    if model_name in ['bert-base-cosine', 'ct-bert-cosine']:
        model.model.eval()
        if torch.cuda.is_available():
            model.model.cuda()

    # Read misinformation: each misconception is represented by its first
    # positive variation.
    misinfos = MisconceptionDataset.from_db(db_input)
    mis = []
    mid = []
    for misinfo in misinfos:
        mis.append(misinfo.pos_variations[0])
        mid.append(misinfo.id)
    n = len(mid)

    # Encode misconceptions once, up front (bow-cosine re-encodes per input).
    if model_name in [
            'bert-base-cosine', 'bert-ft-cosine', 'ct-bert-cosine',
            'bert-score-base', 'bert-score-ft', 'bert-score-ct',
            'glove-avg-cosine', 'boe-log-snli', 'boe-log-mnli',
            'boe-log-mednli', 'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli',
            'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
            'comb-sbert-snli-ct', 'comb-sbert-mnli-ct', 'comb-sbert-mednli-ct',
            'comb-sbert-fever-ct', 'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
    ]:
        mis_vect = model._encode(mis)

    elif model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
        mx = model._encode(mis, 'hyp')

    # Predict
    # Tweets
    engine = get_engine(db_input)
    connection = engine.connect()
    inputs = get_inputs(engine, connection)

    output = []

    for tweet in inputs:
        print(tweet['id'])
        # Pair the current input with every misconception.
        posts = [tweet['text']] * n
        post_ids = [tweet['id']] * n

        # Similarity Models
        # BoW Cosine: the vocabulary includes the current input, so the model
        # and the misconception vectors are rebuilt for every input.
        if model_name == 'bow-cosine':
            corpus = mis + [tweet['text']]
            model = BoWCosine(corpus)
            mis_vect = model._encode(mis)

        if model_name in [
                'bow-cosine', 'bert-base-cosine', 'bert-ft-cosine',
                'ct-bert-cosine', 'bert-score-base', 'bert-score-ft',
                'bert-score-ct', 'glove-avg-cosine'
        ]:
            # The output records are built from `score` in the Output section.
            post_vect = model._encode([tweet['text']])
            score = model._score(post_vect, mis_vect)

        if model_name == 'pyserini':
            hits = searcher.search(tweet['text'], k=86)

            # Collect the misconception id and retrieval score of every hit.
            mid2 = [int(hit.docid) for hit in hits]
            score = [hit.score for hit in hits]

            # Misconceptions not retrieved get score 0 so that every
            # misconception appears in the output.
            missed = list(set(mid) - set(mid2))
            mid2.extend(missed)
            score.extend([0] * len(missed))

            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': 'n/a',
                'misinfo_id': mid2,
                'confidence': score
            }).to_dict('records')

        # Entailment Models
        # BoW Logistic
        if model_name in ['bow-log-snli', 'bow-log-mnli', 'bow-log-mednli']:
            px = model._encode(posts, 'prem')
            preds, probs = model._predict(px, mx)
            max_probs = probs.max(axis=1)

        # BoE Logistic
        if model_name in ['boe-log-snli', 'boe-log-mnli', 'boe-log-mednli']:
            # NOTE(review): `* n` replicates the encoding n times only if
            # _encode returns a list; on a numpy array it would scale the
            # values instead — confirm _encode's return type.
            post_vect = model._encode([tweet['text']]) * n

            preds, probs = model._predict(post_vect, mis_vect)
            max_probs = probs.max(axis=1)

        # BiLSTM
        if model_name in [
                'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli', 'bilstm-fnc'
        ]:
            post_vect = model._encode([tweet['text']] * n)
            with torch.no_grad():
                logits = model(post_vect, mis_vect)
                probs = sm(logits)
                max_out = probs.max(dim=-1)
                max_probs = max_out[0].tolist()
                preds = max_out[1].tolist()

        # SBERT
        if model_name in [
                'sbert-snli', 'sbert-mnli', 'sbert-mednli', 'sbert-mednli-ct',
                'sbert-mnli-ct', 'sbert-snli-ct', 'sbert-fever-ct',
                'sbert-scifact-ct', 'sbert-ann-ct'
        ]:
            with torch.no_grad():
                logits = model(posts, mis)
                _, preds = logits.max(dim=-1)
                probs = sm(logits)
                max_probs, _ = probs.max(dim=-1)

        # Stacked: BERTScore relevance filter, then BiLSTM/SBERT entailment
        if model_name in [
                'comb-bilstm-snli', 'comb-bilstm-mnli', 'comb-bilstm-mednli',
                'comb-sbert-snli-ct', 'comb-sbert-mnli-ct',
                'comb-sbert-mednli-ct', 'comb-sbert-fever-ct',
                'comb-sbert-scifact-ct', 'comb-sbert-ann-ct'
        ]:
            ## BertScore-DA
            post_vect = model._encode([tweet['text']])
            score = model._score(post_vect, mis_vect)
            bert_score = score[0]

            ## Relevance Classification: pairs scoring >= 0.4 get label 3 and
            ## are treated as relevant below; all others get label 2.
            rel_preds = [3 if x >= 0.4 else 2 for x in bert_score]
            df = pd.DataFrame({
                'input_id': post_ids,
                'input': posts,
                'model_id': model_name,
                'rel_label_id': rel_preds,
                'misinfo': mis,
                'misinfo_id': mid,
                'confidence': bert_score
            })

            ## Agree/Disagree Classification (relevant pairs only)
            # .copy() so the 'probs' column is added to an independent frame,
            # not to a filtered slice of df (SettingWithCopyWarning).
            relevant = df[df.rel_label_id == 3].copy()

            if relevant.shape[0] > 0:
                post = relevant.input.tolist()
                misinfo = relevant.misinfo.tolist()

                if model_name in [
                        'comb-bilstm-snli', 'comb-bilstm-mnli',
                        'comb-bilstm-mednli'
                ]:
                    post = model2._encode(post)
                    misinfo = model2._encode(misinfo)

                with torch.no_grad():
                    logits = model2(post, misinfo)
                    probs = sm(logits)
                    relevant['probs'] = probs.tolist()

                    relevant = relevant[['input_id', 'misinfo_id', 'probs']]
                    df = df.merge(relevant,
                                  how='left',
                                  left_on=['input_id', 'misinfo_id'],
                                  right_on=['input_id', 'misinfo_id'])

                    df['label_id'] = df.apply(classify_agree_disagree, axis=1)

            else:
                df['label_id'] = df['rel_label_id']

            df = df[[
                'input_id', 'model_id', 'label_id', 'misinfo_id', 'confidence'
            ]].to_dict('records')

        # Output
        if model_name in [
                'bow-cosine', 'bert-base-cosine', 'bert-ft-cosine',
                'ct-bert-cosine', 'bert-score-base', 'bert-score-ft',
                'bert-score-ct', 'glove-avg-cosine'
        ]:
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': 'n/a',
                'misinfo_id': mid,
                'confidence': score[0]
            }).to_dict('records')

        elif model_name in [
                'bow-log-snli', 'bow-log-mnli', 'bow-log-mednli',
                'boe-log-snli', 'boe-log-mnli', 'boe-log-mednli',
                'bilstm-snli', 'bilstm-mnli', 'bilstm-mednli'
        ]:

            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': preds,
                'misinfo_id': mid,
                'confidence': max_probs,
                'misc': probs.tolist()
            }).to_dict('records')

        elif model_name in [
                'sbert-snli', 'sbert-mnli', 'sbert-mednli', 'sbert-mednli-ct',
                'sbert-snli-ct', 'sbert-mnli-ct', 'sbert-fever-ct',
                'sbert-scifact-ct', 'sbert-ann-ct'
        ]:
            df = pd.DataFrame({
                'input_id': post_ids,
                'model_id': model_name,
                'label_id': preds.tolist(),
                'misinfo_id': mid,
                'confidence': max_probs.tolist(),
                'misc': probs.tolist()
            }).to_dict('records')

        output.append(df)

    connection.close()

    # Write to File: concatenate once instead of growing a frame per input.
    if args.file:
        frames = [pd.DataFrame.from_records(o) for o in output]
        if frames:
            output_df = pd.concat(frames, ignore_index=True)
        else:
            output_df = pd.DataFrame()

        output_df.to_csv(args.file)

        print('Writing predictions to file is complete...')

    # Write to DB: one put_outputs call per input's batch of records.
    if args.db_output:
        engine = get_engine(args.db_output)
        connection = engine.connect()

        for i, records in enumerate(output):
            print('Writing to DB:', i)
            put_outputs(records, engine)

        connection.close()
        print('Writing predictions to DB is complete...')
示例#8
0
def insert_records(records_dict, table_class_name, db_file_path):
    """Insert every value of ``records_dict`` into the table ``table_class_name``.

    Args:
        records_dict: Mapping whose values are the rows to insert; its keys
            are ignored.
        table_class_name: Name of the target table class.
        db_file_path: Database location passed to ``get_engine``.
    """
    db_engine = get_engine(db_file_path)
    rows_to_add = records_dict.values()
    add_records(rows_to_add, db_engine, table_class_name)