Exemplo n.º 1
0
 def test_scores_not_affected_by_padding(self):
     """Scoring one sentence alone must match scoring it inside a batch.

     The batched call mixes sentences of different lengths, so any
     padding the detector applies must not change the first sentence's
     scores.
     """
     fixture_path = 'tests/fixtures/misconceptions.jsonl'
     with open(fixture_path, 'r') as fixture:
         misconceptions = MisconceptionDataset.from_jsonl(fixture)
     first, second = 'Lorem ipsum', 'dolor sit amet'
     single = self.detector.score(first, misconceptions)
     batched_first = self.detector.score([first, second], misconceptions)[0]
     assert np.allclose(single, batched_first)
Exemplo n.º 2
0
 def test_scores_not_affected_by_padding(self):
     """Check single-sentence and list inputs give the same scores.

     ``score`` must accept both a lone sentence and a list of sentences.
     Padding introduced by batching sentences of different lengths must
     not affect the result.
     """
     with open('tests/fixtures/misconceptions.jsonl', 'r') as jsonl_file:
         dataset = MisconceptionDataset.from_jsonl(jsonl_file)
     batch = ['Lorem ipsum', 'dolor sit amet']
     lone_result = self.detector.score(batch[0], dataset)
     batch_result = self.detector.score(batch, dataset)
     assert np.allclose(lone_result, batch_result[0])
Exemplo n.º 3
0
 def test_from_jsonl(self):
     """The JSONL fixture parses into exactly one fully-populated Misconception."""
     with open('tests/fixtures/misconceptions.jsonl', 'r') as handle:
         dataset = MisconceptionDataset.from_jsonl(handle)
     assert len(dataset) == 1
     expected = Misconception(
         id=1,
         canonical_sentence="Don't lick faces",
         sources=('https://www.google.com',),
         category=(),
         pos_variations=("Don't lick faces",),
         neg_variations=(),
         reliability_score=1,
         origin='Sameer',
     )
     assert dataset[0] == expected
Exemplo n.º 4
0
def test_pipeline():
    """Smoke-test the retriever + detector pipeline on a single claim.

    Loads ``misconceptions.jsonl`` from the working directory, builds a
    BERTScore retriever and an SBERT detector restored from a local
    checkpoint, and runs the pipeline once. Nothing is asserted; the test
    only checks that the pipeline runs end to end.
    """
    # TODO: Periodically reload.
    logger.info('Loading misconceptions')
    with open('misconceptions.jsonl', 'r') as jsonl_file:
        dataset = MisconceptionDataset.from_jsonl(jsonl_file)

    logger.info('Loading models')
    backbone = 'digitalepidemiologylab/covid-twitter-bert'
    retriever = BertScoreDetector(backbone)
    detector = SentenceBertClassifier(model_name=backbone, num_classes=3)
    # NOTE(review): absolute, user-specific checkpoint path.
    checkpoint = torch.load('/home/rlogan/SBERT-MNLI-ckpt-2.pt')
    logger.info('Restoring detector checkpoint')
    detector.load_state_dict(checkpoint)

    claim = "North Korea and China conspired together to create the coronavirus."
    run = Pipeline(retriever=retriever, detector=detector)
    run(claim, dataset)
Exemplo n.º 5
0
def main():
    """Collect the top-``k`` scoring input sentences for every misconception.

    ``generate_sentences(--input)`` yields groups of sentences. Each group is
    scored against all misconceptions with a BERTScore detector, giving a
    ``(num_sentences, num_misconceptions)`` score matrix. The ``k`` best
    ``(score, sentence)`` pairs per misconception are pushed into a
    per-misconception ``MaxHeap(k)`` (presumably bounded to ``k`` entries).
    The heaps are then written to ``--output`` as CSV with columns
    ``id, pos_variation, tweet, score``.
    """
    parser = argparse.ArgumentParser()
    # These three paths are used unconditionally below, so require them up
    # front instead of failing with a TypeError after the model has loaded.
    parser.add_argument('--input', type=str, required=True, help='input data')
    parser.add_argument('--output', type=str, required=True, help='output data')
    parser.add_argument('--misconceptions', type=str, required=True,
                        help='JSONL file containing misconceptions')
    parser.add_argument('--score_type', type=str,
                        choices=('precision', 'recall', 'f1'), default='f1')
    parser.add_argument('--k', type=int, default=5)
    args = parser.parse_args()
    if args.k < 1:
        parser.error('--k must be a positive integer')

    detector = BertScoreDetector('covid-roberta/checkpoint-84500',
                                 score_type=args.score_type)
    detector.eval()
    detector.cuda()

    with open(args.misconceptions, 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    top_scoring_tweets = [MaxHeap(args.k) for _ in range(len(misconceptions))]

    for sentences in tqdm(generate_sentences(args.input)):
        # scores[s, m] is the score of sentence s against misconception m.
        scores = detector.score(sentences, misconceptions)
        # For each misconception (column): row indices of the k best scores.
        top_k = scores.argsort(axis=0)[::-1, :][:args.k]
        for misconception_idx, heap in enumerate(top_scoring_tweets):
            for sentence_idx in top_k[:, misconception_idx]:
                # Add (score, sentence) tuple to heap
                heap.push((scores[sentence_idx, misconception_idx],
                           sentences[sentence_idx]))

    # newline='' lets the csv module control line endings; without it every
    # row is followed by a blank line on Windows.
    with open(args.output, 'w', newline='') as f:
        fieldnames = ['id', 'pos_variation', 'tweet', 'score']
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for i, pos_variation in enumerate(misconceptions.sentences):
            id_ = misconceptions[i].id
            for score, tweet in top_scoring_tweets[i].view():
                writer.writerow({
                    'id': id_,
                    'pos_variation': pos_variation,
                    'tweet': tweet,
                    'score': score,
                })
Exemplo n.º 6
0
        print('{0} cells updated.'.format(result.get('updatedCells')))

    def read_dataset(self):
        """Read a misconception dataset back from the spreadsheet.

        Not implemented yet.

        Raises:
            NotImplementedError: always. The exception must be *raised*;
                returning the exception class would hand callers a class
                object and hide the missing feature.
        """
        # TODO: fetch rows with
        #   sheet.values().get(spreadsheetId=..., range=...).execute()
        # and build a MisconceptionDataset from result.get('values', []).
        raise NotImplementedError('read_dataset is not implemented yet')


if __name__ == '__main__':
    # CLI: source JSONL dataset and destination Google Sheets range.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_file', type=Path,
                        default=Path('misconceptions.jsonl'))
    parser.add_argument('-r', '--range_name', type=str,
                        default='Wikipedia!A1')
    args = parser.parse_args()

    # Load every misconception from the JSONL file.
    with open(args.dataset_file, 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    # Start the Sheets service, then write the dataset to the given range.
    obj = MisconceptionDatasetToGSheets()
    obj.start_service()
    obj.write_dataset(misconceptions, args.range_name)
Exemplo n.º 7
0
def main():
    """Run one misinformation-detection model over every input post.

    Loads the model selected by ``--model_name``, reads misconceptions and
    input posts from ``--db_input``, and produces one record per
    (post, misconception) pair. The records are optionally written to a CSV
    file (``--file``) and/or a database (``--db_output``).

    Model families:
      * similarity (``bow-cosine``, ``glove-avg-cosine``, ``bert-base-cosine``,
        ``ct-bert-cosine``, ``bert-score-{base,ft,ct}``): confidence is the
        similarity score, ``label_id`` is ``'n/a'``.
      * ``pyserini``: confidence is the search hit score (0 if not retrieved).
      * entailment (``bow-log-*``, ``boe-log-*``, ``bilstm-*``, ``sbert-*``):
        ``label_id`` is the argmax class and ``confidence`` its probability.
      * stacked ``comb-*``: a COVID-Twitter-BERT BERTScore relevance filter
        (label 3 if score >= 0.4, else 2). Relevant pairs are then classified
        by a second model (``model2``) and ``classify_agree_disagree`` assigns
        the final label.

    Exits via ``parser.error`` for an unknown ``--model_name`` or a missing
    ``--model_dir`` where one is needed.
    """
    ct_bert = 'digitalepidemiologylab/covid-twitter-bert'
    relevance_threshold = 0.4  # BERTScore cut-off for the comb-* relevance stage

    # ---- Model families -----------------------------------------------------
    bert_score_models = ('bert-score-base', 'bert-score-ft', 'bert-score-ct')
    similarity_models = ('bow-cosine', 'bert-base-cosine', 'ct-bert-cosine',
                         'glove-avg-cosine') + bert_score_models
    bow_log_models = ('bow-log-snli', 'bow-log-mnli', 'bow-log-mednli')
    boe_log_models = ('boe-log-snli', 'boe-log-mnli', 'boe-log-mednli')
    bilstm_models = ('bilstm-snli', 'bilstm-mnli', 'bilstm-mednli')
    comb_bilstm_models = ('comb-bilstm-snli', 'comb-bilstm-mnli',
                          'comb-bilstm-mednli')
    # SBERT checkpoints fine-tuned from bert-base-cased.
    sbert_base_ckpts = {
        'sbert-mnli': 'mnli-sbert-ckpt-1.pt',
        'sbert-snli': 'snli-sbert-ckpt-2.pt',
        'sbert-mednli': 'mednli-sbert-ckpt-5.pt',
    }
    # SBERT checkpoints fine-tuned from COVID-Twitter-BERT. The stacked
    # variant 'comb-<name>' uses the checkpoint of '<name>' as its 2nd stage.
    sbert_ct_ckpts = {
        'sbert-snli-ct': 'snli-sbert-ct-ckpt-2.pt',
        'sbert-mnli-ct': 'mnli-sbert-ct-ckpt-2.pt',
        'sbert-mednli-ct': 'mednli-sbert-ct-ckpt-7.pt',
        'sbert-ann-ct': 'ann-ch-sbert-ct-3.pt',
        'sbert-fever-ct': 'fever-sbert-ct-1.pt',
        'sbert-scifact-ct': 'scifact-sbert-ct-2.pt',
    }
    comb_sbert_ckpts = {
        'comb-' + name: ckpt for name, ckpt in sbert_ct_ckpts.items()
    }
    sbert_models = tuple(sbert_base_ckpts) + tuple(sbert_ct_ckpts)
    combined_models = comb_bilstm_models + tuple(comb_sbert_ckpts)
    # Models whose weights / index are read from --model_dir.
    model_dir_models = (('pyserini',) + bow_log_models + boe_log_models +
                        bilstm_models + sbert_models + combined_models)
    supported_models = similarity_models + model_dir_models

    # ---- Arguments ----------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    # Required for entailment models and pyserini retrieval.
    parser.add_argument('--model_dir', type=Path)
    parser.add_argument('--db_input', type=str, required=True)
    parser.add_argument('--file', type=Path, required=False)
    parser.add_argument('--db_output', type=str, required=False)

    args = parser.parse_args()
    model_name = args.model_name
    db_input = args.db_input

    # Fail fast instead of hitting an unbound `model` deep inside the run.
    if model_name not in supported_models:
        parser.error('unsupported --model_name {0!r}'.format(model_name))
    if model_name in model_dir_models and args.model_dir is None:
        parser.error('--model_dir is required for --model_name {0!r}'.format(
            model_name))

    def checkpoint(filename):
        # NOTE(review): os.path.join ignores the '/' when model_dir is absolute
        # and prefixes '/' when it is relative -- kept as in the original.
        return os.path.join('/', args.model_dir, filename)

    # ---- Load models --------------------------------------------------------
    model = None
    model2 = None  # second-stage classifier, only used by the comb-* models
    searcher = None

    if model_name == 'glove-avg-cosine':
        model = GloVeCosine()
    elif model_name == 'bert-base-cosine':
        model = BERTCosine('base')
    elif model_name == 'ct-bert-cosine':
        model = BERTCosine('ct-bert')
    elif model_name == 'bert-score-base':
        model = BertScoreDetector('bert-large-uncased')
    elif model_name == 'bert-score-ft':
        model = BertScoreDetector('models/roberta-ckpt/covid-roberta')
    elif model_name == 'bert-score-ct' or model_name in combined_models:
        # The comb-* models use CT-BERT BERTScore as their relevance stage.
        model = BertScoreDetector(ct_bert)
    elif model_name == 'pyserini':
        searcher = SimpleSearcher(str(args.model_dir))
    elif model_name in bow_log_models:
        model = BoWLogistic(args.model_dir)
    elif model_name in boe_log_models:
        model = BoELogistic(args.model_dir)
    elif model_name in sbert_base_ckpts:
        model = load_sbert_model('bert-base-cased',
                                 checkpoint(sbert_base_ckpts[model_name]))
    elif model_name in sbert_ct_ckpts:
        model = load_sbert_model(ct_bert, checkpoint(sbert_ct_ckpts[model_name]))
    elif model_name in bilstm_models:
        model = load_bilstm(checkpoint('bilstm.pt'),
                            checkpoint('bilstm-field.pt'))
    # 'bow-cosine' is rebuilt per input post inside the loop below.

    if model_name in comb_sbert_ckpts:
        model2 = load_sbert_model(ct_bert,
                                  checkpoint(comb_sbert_ckpts[model_name]))
    elif model_name in comb_bilstm_models:
        model2 = load_bilstm(checkpoint('bilstm.pt'),
                             checkpoint('bilstm-field.pt'))

    # Switch torch models to inference mode (and GPU when available).
    if model_name in (bert_score_models + bilstm_models + sbert_models +
                      combined_models):
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
    if model2 is not None:
        model2.eval()
        if torch.cuda.is_available():
            model2.cuda()
    if model_name in ('bert-base-cosine', 'ct-bert-cosine'):
        model.model.eval()
        if torch.cuda.is_available():
            model.model.cuda()

    # ---- Read and encode misconceptions -------------------------------------
    misinfos = MisconceptionDataset.from_db(db_input)
    mis = []  # first positive variation of each misconception
    mid = []  # matching misconception ids
    for misinfo in misinfos:
        mis.append(misinfo.pos_variations[0])
        mid.append(misinfo.id)
    n = len(mid)

    mis_vect = None
    mx = None
    if model_name in (('bert-base-cosine', 'ct-bert-cosine', 'glove-avg-cosine')
                      + bert_score_models + boe_log_models + bilstm_models +
                      combined_models):
        mis_vect = model._encode(mis)
    elif model_name in bow_log_models:
        mx = model._encode(mis, 'hyp')

    def entailment_records(post_ids, preds, max_probs, probs):
        # One record per misconception for a classifier's predictions.
        return pd.DataFrame({
            'input_id': post_ids,
            'model_id': model_name,
            'label_id': preds,
            'misinfo_id': mid,
            'confidence': max_probs,
            'misc': probs.tolist()
        }).to_dict('records')

    # ---- Predict on every input post ----------------------------------------
    engine = get_engine(db_input)
    connection = engine.connect()
    output = []
    try:
        inputs = get_inputs(engine, connection)

        for item in inputs:
            print(item['id'])
            text = item['text']
            # Each post is paired with every misconception.
            posts = [text] * n
            post_ids = [item['id']] * n

            if model_name == 'bow-cosine':
                # The BoW vocabulary includes the current post, so the model
                # and the misconception vectors are rebuilt per input.
                model = BoWCosine(mis + [text])
                mis_vect = model._encode(mis)

            if model_name in similarity_models:
                post_vect = model._encode([text])
                score = model._score(post_vect, mis_vect)
                records = pd.DataFrame({
                    'input_id': post_ids,
                    'model_id': model_name,
                    'label_id': 'n/a',
                    'misinfo_id': mid,
                    'confidence': score[0]
                }).to_dict('records')

            elif model_name == 'pyserini':
                # NOTE(review): k=86 is hard-coded, presumably the number of
                # indexed misconceptions -- confirm (k=n may be intended).
                hits = searcher.search(text, k=86)
                hit_ids = [int(hit.docid) for hit in hits]
                confidence = [hit.score for hit in hits]
                # Misconceptions that were not retrieved get confidence 0.
                for missing_id in list(set(mid) - set(hit_ids)):
                    hit_ids.append(missing_id)
                    confidence.append(0)
                records = pd.DataFrame({
                    'input_id': post_ids,
                    'model_id': model_name,
                    'label_id': 'n/a',
                    'misinfo_id': hit_ids,
                    'confidence': confidence
                }).to_dict('records')

            elif model_name in bow_log_models:
                px = model._encode(posts, 'prem')
                preds, probs = model._predict(px, mx)
                records = entailment_records(post_ids, preds,
                                             probs.max(axis=1), probs)

            elif model_name in boe_log_models:
                # Encode the replicated post list (one row per misconception).
                # The old `_encode([text]) * n` multiplied the feature values
                # by n if `_encode` returns an array.
                post_vect = model._encode(posts)
                preds, probs = model._predict(post_vect, mis_vect)
                records = entailment_records(post_ids, preds,
                                             probs.max(axis=1), probs)

            elif model_name in bilstm_models:
                post_vect = model._encode(posts)
                with torch.no_grad():
                    logits = model(post_vect, mis_vect)
                    probs = sm(logits)
                    top_probs, top_labels = probs.max(dim=-1)
                records = entailment_records(post_ids, top_labels.tolist(),
                                             top_probs.tolist(), probs)

            elif model_name in sbert_models:
                with torch.no_grad():
                    logits = model(posts, mis)
                    _, preds = logits.max(dim=-1)
                    probs = sm(logits)
                    max_probs, _ = probs.max(dim=-1)
                records = entailment_records(post_ids, preds.tolist(),
                                             max_probs.tolist(), probs)

            else:  # combined_models (validated above)
                # Stage 1 -- relevance via CT-BERT BERTScore.
                post_vect = model._encode([text])
                bert_score = model._score(post_vect, mis_vect)[0]
                rel_preds = [3 if x >= relevance_threshold else 2
                             for x in bert_score]
                df = pd.DataFrame({
                    'input_id': post_ids,
                    'input': posts,
                    'model_id': model_name,
                    'rel_label_id': rel_preds,
                    'misinfo': mis,
                    'misinfo_id': mid,
                    'confidence': bert_score
                })

                # Stage 2 -- agree/disagree, only for relevant pairs.
                # .copy() so adding 'probs' does not write into a view of df.
                relevant = df[df.rel_label_id == 3].copy()
                if relevant.shape[0] > 0:
                    pair_posts = relevant.input.tolist()
                    pair_mis = relevant.misinfo.tolist()
                    if model_name in comb_bilstm_models:
                        pair_posts = model2._encode(pair_posts)
                        pair_mis = model2._encode(pair_mis)
                    with torch.no_grad():
                        probs = sm(model2(pair_posts, pair_mis))
                    relevant['probs'] = probs.tolist()
                    df = df.merge(relevant[['input_id', 'misinfo_id', 'probs']],
                                  how='left',
                                  left_on=['input_id', 'misinfo_id'],
                                  right_on=['input_id', 'misinfo_id'])
                    df['label_id'] = df.apply(classify_agree_disagree, axis=1)
                else:
                    df['label_id'] = df['rel_label_id']

                records = df[[
                    'input_id', 'model_id', 'label_id', 'misinfo_id',
                    'confidence'
                ]].to_dict('records')

            output.append(records)
    finally:
        connection.close()

    # ---- Write to file ------------------------------------------------------
    if args.file:
        # Concatenate once instead of re-copying the frame for every post.
        frames = [pd.DataFrame.from_records(o) for o in output]
        output_df = (pd.concat(frames, ignore_index=True)
                     if frames else pd.DataFrame())
        output_df.to_csv(args.file)

        print('Writing predictions to file is complete...')

    # ---- Write to DB --------------------------------------------------------
    if args.db_output:
        engine = get_engine(args.db_output)
        connection = engine.connect()
        try:
            for i, records in enumerate(output):
                print('Writing to DB:', i)
                put_outputs(records, engine)
        finally:
            connection.close()
        print('Writing predictions to DB is complete...')
Exemplo n.º 8
0
 def setUp(self):
     """Give each test a MockDetector and the shared misconception fixture."""
     self.detector = MockDetector()
     with open('tests/fixtures/misconceptions.jsonl', 'r') as fixture_file:
         self.misconceptions = MisconceptionDataset.from_jsonl(fixture_file)
Exemplo n.º 9
0
 def test_hashable(self):
     """Hashing an (empty) dataset yields the ``uid`` it was built with."""
     expected_uid = 123
     dataset = MisconceptionDataset((), (), uid=expected_uid)
     assert hash(dataset) == expected_uid