Beispiel #1
0
    def test_batchExtractMaxLen(self):
        """Check ``batchExtract`` on the fixture batch with both limits at 1.

        The two trailing ``1`` arguments are presumably the number of answers
        and the maximal answer length -- confirm against
        ``AnswerExtractor.batchExtract``.
        """
        reader = AnswerExtractor(self.mockModel, torch.device("cpu"))

        # While extracting, AbstractReader.scores2logSpanProb is swapped for
        # the implementation supplied by the mock model.
        patched_scoring = mock.patch.object(
            AbstractReader, "scores2logSpanProb",
            self.mockModel.scores2logSpanProb)

        with patched_scoring:
            outcome = reader.batchExtract(self.batch, 1, 1)
            predicted, log_probs, passage_ids, char_spans = outcome

            self.assertListEqual(predicted, ["Iris"])
            self.assertEqual(len(log_probs), 1)
            self.assertAlmostEqual(log_probs[0], -0.9327521295671886)
            self.assertListEqual(passage_ids, [self.batch.ids[0]])
            self.assertListEqual(char_spans, [(1, 5)])
Beispiel #2
0
 def test_init_gpu(self):
     """Check device bookkeeping when the extractor is built for ``cuda:0``.

     The test is skipped when torch reports no available CUDA device.
     """
     if not torch.cuda.is_available():
         # skipTest raises, so nothing below runs without a GPU.
         self.skipTest("Cuda device is not available.")

     cuda_device = torch.device("cuda:0")
     gpu_extractor = AnswerExtractor(self.model, cuda_device)

     self.assertEqual(gpu_extractor.device, cuda_device)
     # The first model parameter must report the requested device.
     first_param = next(gpu_extractor.model.parameters())
     self.assertEqual(first_param.device, cuda_device)
Beispiel #3
0
def run_reader_extractive(checkpointDict, reader_output, reranker_output):
    """Run the extractive reader and store its answers as JSON lines.

    A ``Reader`` is rebuilt from the checkpoint, optionally wrapped in
    ``torch.nn.DataParallel``, and driven through ``AnswerExtractor.extract``
    over a ``ReaderDataset``. All results are collected first and written to
    ``reader_output`` only after the extraction loop has finished.

    :param checkpointDict: Loaded checkpoint. Its ``"config"`` entry
        configures the model and the tokenizer; the ``"cache"`` value inside
        it is overwritten in place with the globally configured cache path.
    :param reader_output: Target that is opened with ``jsonlines.open`` for
        writing; one record per extracted item.
    :param reranker_output: First argument handed to ``ReaderDataset``
        (presumably the reranker's result file -- confirm with the caller).
    """
    reader_cfg = config["reader"]["extractive"]["config"]
    transformers_cache = config["transformers_cache"]

    # The cache path stored with the checkpoint may be stale, so replace it.
    model_cfg = checkpointDict["config"]
    model_cfg["cache"] = transformers_cache

    model = Reader(model_cfg, initPretrainedWeights=False)
    Checkpoint.loadModel(model, checkpointDict, config["device"])

    # Data parallelism needs to be requested in the config and requires more
    # than one visible CUDA device.
    use_data_parallel = ("multi_gpu" in reader_cfg
                         and reader_cfg["multi_gpu"]
                         and torch.cuda.device_count() > 1)
    if use_data_parallel:
        model = torch.nn.DataParallel(model)
        logging.info("DataParallel active!")

    extractor = AnswerExtractor(model, config["device"])
    extractor.model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_cfg['tokenizer_type'],
        cache_dir=transformers_cache,
        use_fast=True)
    passages = PassDatabase(get_database_path())

    reader_dataset = ReaderDataset(
        reranker_output, tokenizer, passages, reader_cfg["batch_size"],
        model_cfg['include_doc_title'])
    with reader_dataset as dataset:
        logging.info("Extracting top k answers scores")

        predictions = extractor.extract(
            dataset,
            reader_cfg["top_k_answers"],
            reader_cfg["max_tokens_for_answer"])

        # Buffer the records in extraction order; the output file is only
        # opened once every item has been processed.
        records = []
        for query, answers, scores, passageIds, charOffsets in \
                tqdm(predictions, total=len(dataset)):
            records.append({
                "raw_question": query,
                "answers": answers,
                "reader_scores": scores,
                "passages": passageIds,
                "char_offsets": charOffsets
            })

        with jsonlines.open(reader_output, "w") as writer:
            for record in records:
                writer.write(record)
Beispiel #4
0
    def test_extract_max_len(self):
        """Check ``extract`` over the fixture dataset with both limits at 1.

        Each yielded item must carry the expected query, exactly one answer
        with its score, the matching passage id and the character offsets of
        the answer span.
        """
        reader = AnswerExtractor(self.mockModel, torch.device("cpu"))

        # Expected answer and character span for each dataset item, in order.
        wanted_answers = [["Iris"], ["Some"], ["the"], ["were"], ["Johnson"]]
        wanted_offsets = [[(1, 5)], [(1, 5)], [(1, 4)], [(1, 5)], [(1, 8)]]

        patched_scoring = mock.patch.object(
            AbstractReader, "scores2logSpanProb",
            self.mockModel.scores2logSpanProb)

        with patched_scoring:
            # test maxlen
            extracted = extractor_items = reader.extract(self.dataset, 1, 1)
            for idx, item in enumerate(extracted):
                query, answers, log_probs, passage_ids, char_spans = item

                self.assertEqual(query, self.gtQueries[idx])
                self.assertListEqual(answers, wanted_answers[idx])
                self.assertEqual(len(log_probs), 1)
                self.assertAlmostEqual(log_probs[0], -0.9327521295671886)
                self.assertListEqual(passage_ids, [self.gtPassageIds[idx]])
                self.assertListEqual(char_spans, wanted_offsets[idx])
Beispiel #5
0
    def test_extract(self):
        """Check ``extract`` over the fixture dataset with 2 as its argument.

        Every yielded item is compared against the fixture ground truth: two
        scored answers, both attributed to the same passage. Afterwards the
        flags recorded on the fixture dataset are verified.
        """
        reader = AnswerExtractor(self.mockModel, torch.device("cpu"))

        patched_scoring = mock.patch.object(
            AbstractReader, "scores2logSpanProb",
            self.mockModel.scores2logSpanProb)

        with patched_scoring:
            for idx, item in enumerate(reader.extract(self.dataset, 2)):
                query, answers, log_probs, passage_ids, char_spans = item

                self.assertEqual(query, self.gtQueries[idx])
                self.assertListEqual(answers, self.expectedAnswers[idx])

                self.assertEqual(len(log_probs), 2)
                self.assertAlmostEqual(log_probs[0], -0.5)
                self.assertAlmostEqual(log_probs[1], -0.9327521295671886)

                # Both answers are expected to come from the same passage.
                wanted_ids = [self.gtPassageIds[idx], self.gtPassageIds[idx]]
                self.assertListEqual(passage_ids, wanted_ids)
                self.assertListEqual(char_spans, self.gtSpanCharOffset[idx])

            # State the dataset fixture holds once extraction is finished.
            dataset = self.dataset
            self.assertFalse(dataset.calledActivateMultiprocessing)
            self.assertTrue(dataset.storedSkipAnswerMatching)
            self.assertFalse(dataset.storedUseGroundTruthPassage)
Beispiel #6
0
 def test_init(self):
     """Check device bookkeeping when the extractor is built for the CPU."""
     cpu = torch.device("cpu")
     cpu_extractor = AnswerExtractor(self.model, cpu)

     self.assertEqual(cpu_extractor.device, cpu)
     # The first model parameter must report the requested device.
     first_param = next(cpu_extractor.model.parameters())
     self.assertEqual(first_param.device, cpu)