Example #2
def test_scores_not_affected_by_padding(self):
    # Checks that score works on individual sentences as well as lists of
    # sentences, and that padding does not affect scores.
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)
    sentences = ['Lorem ipsum', 'dolor sit amet']
    score_a = self.detector.score(sentences[0], misconceptions)
    score_b = self.detector.score(sentences, misconceptions)[0]
    assert np.allclose(score_a, score_b)
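
This test assumes score accepts either a single sentence or a list of sentences. A minimal sketch of that normalization, with _score_batch as a hypothetical helper returning an array of shape (num_sentences, num_misconceptions):

def score(self, sentences, misconceptions):
    # A lone string is treated as a batch of one, and its row is returned
    # directly so callers need not index into a 2-D array.
    if isinstance(sentences, str):
        return self._score_batch([sentences], misconceptions)[0]
    return self._score_batch(sentences, misconceptions)
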
def test_from_jsonl(self):
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)
    assert len(misconceptions) == 1
    expected = Misconception(id=1,
                             canonical_sentence="Don't lick faces",
                             sources=("https://www.google.com", ),
                             category=tuple(),
                             pos_variations=("Don't lick faces", ),
                             neg_variations=tuple(),
                             reliability_score=1,
                             origin="Sameer")
    assert misconceptions[0] == expected
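
For reference, a fixture line in tests/fixtures/misconceptions.jsonl consistent with the expected record above, assuming the JSON keys mirror the Misconception field names:

{"id": 1, "canonical_sentence": "Don't lick faces", "sources": ["https://www.google.com"], "category": [], "pos_variations": ["Don't lick faces"], "neg_variations": [], "reliability_score": 1, "origin": "Sameer"}
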
Example #4
def test_pipeline():
    # TODO: Periodically reload.
    logger.info('Loading misconceptions')
    with open('misconceptions.jsonl', 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    logger.info('Loading models')
    retriever = BertScoreDetector('digitalepidemiologylab/covid-twitter-bert')
    detector = SentenceBertClassifier(
        model_name='digitalepidemiologylab/covid-twitter-bert',
        num_classes=3,
    )
    state_dict = torch.load('/home/rlogan/SBERT-MNLI-ckpt-2.pt')
    logger.info('Restoring detector checkpoint')
    detector.load_state_dict(state_dict)
    pipeline = Pipeline(retriever=retriever, detector=detector)

    pipeline(
        "North Korea and China conspired together to create the coronavirus.",
        misconceptions)
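
Pipeline itself is not shown in this snippet. A plausible sketch of how it could chain the two models: retrieve the closest misconceptions with the BERTScore model, then classify each pair with the detector. The k parameter and the detector's predict method are assumptions, not confirmed API:

class Pipeline:
    def __init__(self, retriever, detector):
        self.retriever = retriever
        self.detector = detector

    def __call__(self, sentence, misconceptions, k=10):
        # Score the sentence against every misconception and keep the k best.
        scores = self.retriever.score(sentence, misconceptions)
        candidates = [misconceptions[i] for i in scores.argsort()[::-1][:k]]
        # Hypothetical second stage: classify each (sentence, candidate)
        # pair, e.g. into the three classes configured via num_classes=3.
        return self.detector.predict(sentence, candidates)
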
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, help='input data')
    parser.add_argument('--output', type=str, help='output data')
    parser.add_argument('--misconceptions', type=str, help='JSONL file containing misconceptions')
    parser.add_argument('--score_type', type=str, choices=('precision', 'recall', 'f1'), default='f1')
    parser.add_argument('--k', type=int, default=5)
    args = parser.parse_args()

    detector = BertScoreDetector('covid-roberta/checkpoint-84500', score_type=args.score_type)
    detector.eval()
    detector.cuda()

    with open(args.misconceptions, 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    top_scoring_tweets = [MaxHeap(args.k) for _ in range(len(misconceptions))]

    for sentences in tqdm(generate_sentences(args.input)):
        scores = detector.score(sentences, misconceptions)
        # Indices of the k highest-scoring sentences per misconception
        # (argsort sorts ascending, so reverse the rows before slicing).
        top_k = scores.argsort(axis=0)[::-1, :][:args.k]
        for misconception_idx in range(len(misconceptions)):
            for sentence_idx in top_k[:, misconception_idx]:
                # Add (score, sentence) tuple to heap
                x = (scores[sentence_idx, misconception_idx], sentences[sentence_idx])
                top_scoring_tweets[misconception_idx].push(x)

    with open(args.output, 'w') as f:
        fieldnames = ['id', 'pos_variation', 'tweet', 'score']
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for i, pos_variation in enumerate(misconceptions.sentences):
            id_ = misconceptions[i].id
            for score, tweet in top_scoring_tweets[i].view():
                writer.writerow({
                    'id': id_,
                    'pos_variation': pos_variation,
                    'tweet': tweet,
                    'score': score
                })
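
MaxHeap is used above as a bounded top-k container: push any number of (score, tweet) tuples, then read back the k best with view. A minimal sketch consistent with that usage, built on a fixed-size min-heap; the real implementation may differ:

import heapq

class MaxHeap:
    """Keeps only the k largest items ever pushed."""

    def __init__(self, k):
        self.k = k
        self._heap = []  # min-heap holding at most k items

    def push(self, item):
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, item)
        else:
            # Push the new item and pop the smallest in one step; this is
            # a net no-op when the new item is smaller than everything kept.
            heapq.heappushpop(self._heap, item)

    def view(self):
        # Retained items, highest score first.
        return sorted(self._heap, reverse=True)
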
Example #6
        print('{0} cells updated.'.format(result.get('updatedCells')))

    def read_dataset(self):
        raise NotImplementedError
        # result = sheet.values().get(spreadsheetId=SAMPLE_SPREADSHEET_ID,
        #                             range=SAMPLE_RANGE_NAME).execute()
        # values = result.get('values', [])

        # if not values:
        #     print('No data found.')
        # else:
        #     for row in values:
        #         print(row)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d',
                        '--dataset_file',
                        type=Path,
                        default=Path('misconceptions.jsonl'))
    parser.add_argument('-r', '--range_name', type=str, default='Wikipedia!A1')
    args = parser.parse_args()

    with open(args.dataset_file, 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    obj = MisconceptionDatasetToGSheets()
    obj.start_service()
    obj.write_dataset(misconceptions, args.range_name)
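
read_dataset is stubbed out above; following its commented-out draft, a completed version might look like the sketch below, where sheet, SAMPLE_SPREADSHEET_ID, and SAMPLE_RANGE_NAME are assumed to be defined elsewhere, as in the draft:

def read_dataset(self):
    # Mirrors the commented-out draft in the class body.
    result = sheet.values().get(spreadsheetId=SAMPLE_SPREADSHEET_ID,
                                range=SAMPLE_RANGE_NAME).execute()
    values = result.get('values', [])
    if not values:
        print('No data found.')
    else:
        for row in values:
            print(row)
    return values
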
Example #7
def setUp(self):
    self.detector = MockDetector()
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)
    self.misconceptions = misconceptions
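
MockDetector is a test double standing in for a real detector. A hypothetical sketch that returns deterministic, content-dependent scores without loading a model, matching the single-sentence and batched calls exercised in the padding test above:

import numpy as np

class MockDetector:
    def score(self, sentences, misconceptions):
        single = isinstance(sentences, str)
        batch = [sentences] if single else sentences
        # Each score depends only on the sentence and the misconception
        # index, never on batch position, so padding cannot change it.
        scores = np.array([[(len(s) * (j + 1)) % 10 / 10.0
                            for j in range(len(misconceptions))]
                           for s in batch])
        return scores[0] if single else scores
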