def test_scores_not_affected_by_padding(self):
    # Scoring one sentence on its own must match that sentence's score when it
    # is scored as the first element of a (padded) batch.
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        dataset = MisconceptionDataset.from_jsonl(f)
    first, second = 'Lorem ipsum', 'dolor sit amet'
    solo = self.detector.score(first, dataset)
    batched = self.detector.score([first, second], dataset)[0]
    assert np.allclose(solo, batched)
def test_scores_not_affected_by_padding(self):
    """Score accepts a single sentence or a list, and batch padding leaves scores unchanged."""
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        misconception_data = MisconceptionDataset.from_jsonl(f)
    texts = ['Lorem ipsum', 'dolor sit amet']
    single = self.detector.score(texts[0], misconception_data)
    from_batch = self.detector.score(texts, misconception_data)[0]
    assert np.allclose(single, from_batch)
def test_from_jsonl(self):
    # Parsing the one-record fixture must yield exactly the expected Misconception.
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        dataset = MisconceptionDataset.from_jsonl(f)
    assert len(dataset) == 1
    assert dataset[0] == Misconception(
        id=1,
        canonical_sentence="Don't lick faces",
        sources=("https://www.google.com",),
        category=(),
        pos_variations=("Don't lick faces",),
        neg_variations=(),
        reliability_score=1,
        origin="Sameer",
    )
def test_pipeline():
    """Smoke test: run the full retrieve-then-classify pipeline on one sentence.

    Requires the misconception JSONL file, the pretrained transformer weights,
    and the fine-tuned SBERT checkpoint to be available locally.
    """
    # TODO: Periodically reload.
    logger.info('Loading misconceptions')
    with open('misconceptions.jsonl', 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)
    logger.info('Loading models')
    retriever = BertScoreDetector('digitalepidemiologylab/covid-twitter-bert')
    detector = SentenceBertClassifier(
        model_name='digitalepidemiologylab/covid-twitter-bert',
        num_classes=3,
    )
    # map_location='cpu' lets a checkpoint that was serialized on a GPU load on
    # a CPU-only machine; load_state_dict then copies the weights into the
    # model on whatever device it lives on.
    # NOTE(review): hard-coded user-specific path — should be parameterized.
    state_dict = torch.load('/home/rlogan/SBERT-MNLI-ckpt-2.pt', map_location='cpu')
    logger.info('Restoring detector checkpoint')
    detector.load_state_dict(state_dict)
    pipeline = Pipeline(retriever=retriever, detector=detector)
    pipeline(
        "North Korea and China conspired together to create the coronavirus.",
        misconceptions)
def main():
    """Rank tweets against each misconception and dump the top-k per misconception to CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, help='input data')
    parser.add_argument('--output', type=str, help='output data')
    parser.add_argument('--misconceptions', type=str,
                        help='JSONL file containing misconceptions')
    parser.add_argument('--score_type', type=str,
                        choices=('precision', 'recall', 'f1'), default='f1')
    parser.add_argument('--k', type=int, default=5)
    args = parser.parse_args()

    detector = BertScoreDetector('covid-roberta/checkpoint-84500',
                                 score_type=args.score_type)
    detector.eval()
    detector.cuda()

    with open(args.misconceptions, 'r') as f:
        misconceptions = MisconceptionDataset.from_jsonl(f)

    # One bounded max-heap of (score, sentence) per misconception.
    heaps = [MaxHeap(args.k) for _ in range(len(misconceptions))]

    for sentences in tqdm(generate_sentences(args.input)):
        scores = detector.score(sentences, misconceptions)
        # Indices of the k highest-scoring sentences for each misconception
        # (descending along the sentence axis).
        top_k = scores.argsort(axis=0)[::-1, :][:args.k]
        for col in range(len(misconceptions)):
            for row in top_k[:, col]:
                heaps[col].push((scores[row, col], sentences[row]))

    with open(args.output, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=['id', 'pos_variation', 'tweet', 'score'])
        writer.writeheader()
        for i, pos_variation in enumerate(misconceptions.sentences):
            id_ = misconceptions[i].id
            for score, tweet in heaps[i].view():
                writer.writerow({
                    'id': id_,
                    'pos_variation': pos_variation,
                    'tweet': tweet,
                    'score': score,
                })
print('{0} cells updated.'.format(result.get('updatedCells'))) def read_dataset(self): return NotImplementedError # result = sheet.values().get(spreadsheetId=SAMPLE_SPREADSHEET_ID, # range=SAMPLE_RANGE_NAME).execute() # values = result.get('values', []) # if not values: # print('No data found.') # else: # for row in values: # print(row) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-d', '--dataset_file', type=Path, default=Path('misconceptions.jsonl')) parser.add_argument('-r', '--range_name', type=str, default='Wikipedia!A1') args = parser.parse_args() with open(args.dataset_file, 'r') as f: misconceptions = MisconceptionDataset.from_jsonl(f) obj = MisconceptionDatasetToGSheets() obj.start_service() obj.write_dataset(misconceptions, args.range_name)
def setUp(self):
    """Create a mock detector and load the misconception fixture shared by the tests."""
    self.detector = MockDetector()
    with open('tests/fixtures/misconceptions.jsonl', 'r') as f:
        self.misconceptions = MisconceptionDataset.from_jsonl(f)