Example 1
def test_pred_sampler(monkeypatch, tmpdir):
    benchmark = DummyBenchmark()
    extractor = EmbedText({"tokenizer": {
        "keepstops": True
    }},
                          provide={"collection": benchmark.collection})
    search_run = {"301": {"LA010189-0001": 50, "LA010189-0002": 100}}
    pred_dataset = PredSampler()
    pred_dataset.prepare(search_run, benchmark.qrels, extractor)

    def mock_id2vec(*args, **kwargs):
        return {
            "query": np.array([1, 2, 3, 4]),
            "posdoc": np.array([1, 1, 1, 1])
        }

    monkeypatch.setattr(EmbedText, "id2vec", mock_id2vec)
    dataloader = torch.utils.data.DataLoader(pred_dataset, batch_size=2)
    for idx, batch in enumerate(dataloader):
        print(idx, batch)
        assert len(batch["query"]) == 2
        assert len(batch["posdoc"]) == 2
        assert batch.get("negdoc") is None
        assert np.array_equal(batch["query"][0], np.array([1, 2, 3, 4]))
        assert np.array_equal(batch["query"][1], np.array([1, 2, 3, 4]))
        assert np.array_equal(batch["posdoc"][0], np.array([1, 1, 1, 1]))
        assert np.array_equal(batch["posdoc"][1], np.array([1, 1, 1, 1]))
Example 2
    def predict(self):
        fold = self.config["fold"]
        self.rank.search()
        threshold = self.config["threshold"]
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
        )
        train_output_path = self.get_results_path()
        self.reranker.build_model()
        self.reranker.trainer.load_best_model(self.reranker, train_output_path)

        test_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["test"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        break
                    test_run[qid][docid] = score

        test_dataset = PredSampler()
        test_dataset.prepare(
            test_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level
        )
        test_output_path = train_output_path / "pred" / "test" / "best"
        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

        preds = {"test": test_preds}

        return preds
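The runs handled here are nested dicts of the form {qid: {docid: score}}, ordered by rank (best first). Below is a hedged sketch of the top-threshold trimming used above, pulled out into a standalone helper; trim_run is an assumed name, not part of the codebase, and like the loop above it relies on the run dict preserving rank order (the OrderedDict comment).

from collections import defaultdict


def trim_run(run, keep_qids, threshold):
    """Keep only the first `threshold` (highest-ranked) docs for each selected qid."""
    trimmed = defaultdict(dict)
    for qid, docs in run.items():
        if qid not in keep_qids:
            continue
        for idx, (docid, score) in enumerate(docs.items()):
            if idx >= threshold:
                break
            trimmed[qid][docid] = score
    return trimmed


# e.g. trim_run({"301": {"d1": 3.0, "d2": 2.0, "d3": 1.0}}, {"301"}, 2)
# -> {"301": {"d1": 3.0, "d2": 2.0}}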
Example 3
def test_dssm_unigram(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark()
    reranker = DSSM(
        {
            "nhiddens": "56",
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2
            }
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
Example 4
def test_birch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark()
    reranker = Birch({"trainer": {
        "niters": 1,
        "itersize": 2,
        "batch": 2
    }},
                     provide={
                         "index": dummy_index,
                         "benchmark": benchmark
                     })
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.searcher_scores = {
        "301": {
            "LA010189-0001": 2,
            "LA010189-0002": 1
        }
    }
    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)
Example 5
def test_bertmaxp_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark({"collection": {"name": "dummy"}})
    reranker = TFBERTMaxP(
        {
            "pretrained": "bert-base-uncased",
            "extractor": {
                "name": "bertpassage",
                "usecache": False,
                "maxseqlen": 32,
                "numpassages": 2,
                "passagelen": 15,
                "stride": 5,
                "index": {
                    "name": "anserini",
                    "indexstops": False,
                    "stemmer": "porter",
                    "collection": {
                        "name": "dummy"
                    }
                },
            },
            "trainer": {
                "name": "tensorflow",
                "batch": 4,
                "niters": 1,
                "itersize": 2,
                "lr": 0.001,
                "validatefreq": 1,
                "usecache": False,
                "tpuname": None,
                "tpuzone": None,
                "storage": None,
                "boardname": "default",
                "loss": "crossentropy",
                "eager": False,
            },
        },
        provide=benchmark,
    )

    reranker.extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                                  benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.bm25_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}
    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainPairSampler()
    train_dataset.prepare(train_run, benchmark.qrels, reranker.extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, reranker.extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, "map")
Example 6
def test_tk(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)

    benchmark = DummyBenchmark()
    reranker = TK(
        {
            "gradkernels": True,
            "scoretanh": False,
            "singlefc": True,
            "projdim": 32,
            "ffdim": 100,
            "numlayers": 2,
            "numattheads": 4,
            "alpha": 0.5,
            "usemask": False,
            "usemixer": True,
            "finetune": True,
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
Example 7
def test_CDSSM(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)

    benchmark = DummyBenchmark()
    reranker = CDSSM(
        {
            "nkernel": 3,
            "nfilter": 1,
            "nhiddens": 30,
            "windowsize": 3,
            "dropoutrate": 0,
            "trainer": {
                "niters": 1,
                "itersize": 2,
                "batch": 1
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.searcher_scores = {
        "301": {
            "LA010189-0001": 2,
            "LA010189-0002": 1
        }
    }
    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)
Example 8
def test_knrm_tf_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        vectors = np.zeros((8, 32))
        stoi = defaultdict(lambda: 0)
        itos = defaultdict(lambda: "dummy")

        return vectors, stoi, itos

    monkeypatch.setattr(capreolus.extractor.common,
                        "load_pretrained_embeddings", fake_magnitude_embedding)
    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = TFKNRM(
        {
            "gradkernels": True,
            "finetune": False,
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2,
                "loss": "binary_crossentropy"
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainPairSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best.index")
Example 9
def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_load_embeddings(self):
        self.embeddings = np.zeros((1, 50))
        self.stoi = {"<pad>": 0}
        self.itos = {v: k for k, v in self.stoi.items()}

    monkeypatch.setattr(EmbedText, "_load_pretrained_embeddings",
                        fake_load_embeddings)

    benchmark = DummyBenchmark()
    reranker = KNRM(
        {
            "gradkernels": True,
            "scoretanh": False,
            "singlefc": True,
            "finetune": False,
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)

    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
Example 10
def test_deeptilebars(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return Magnitude(None)

    monkeypatch.setattr(DeepTileExtractor, "_get_pretrained_emb",
                        fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = DeepTileBar(
        {
            "name": "DeepTileBar",
            "passagelen": 30,
            "numberfilter": 3,
            "lstmhiddendim": 3,
            "linearhiddendim1": 32,
            "linearhiddendim2": 16,
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
Example 11
def test_HINT(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)

    benchmark = DummyBenchmark()
    reranker = HINT(
        {
            "spatialGRU": 2,
            "LSTMdim": 6,
            "kmax": 10,
            "trainer": {
                "niters": 1,
                "itersize": 2,
                "batch": 1
            }
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
Example 12
    def evaluate(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        test_output_path = train_output_path / "pred" / "test" / "best"
        logger.debug("results path: %s", train_output_path)

        if os.path.exists(test_output_path):
            test_preds = Searcher.load_trec_run(test_output_path)
        else:
            self.rank.search()
            rank_results = self.rank.evaluate()
            best_search_run_path = rank_results["path"][fold]
            best_search_run = Searcher.load_trec_run(best_search_run_path)

            docids = set(docid for querydocs in best_search_run.values()
                         for docid in querydocs)
            self.reranker.extractor.preprocess(
                qids=best_search_run.keys(),
                docids=docids,
                topics=self.benchmark.topics[self.benchmark.query_type])
            self.reranker.build_model()
            self.reranker.searcher_scores = best_search_run

            self.reranker.trainer.load_best_model(self.reranker,
                                                  train_output_path)

            test_run = {
                qid: docs
                for qid, docs in best_search_run.items()
                if qid in self.benchmark.folds[fold]["predict"]["test"]
            }
            test_dataset = PredSampler()
            test_dataset.prepare(test_run, self.benchmark.qrels,
                                 self.reranker.extractor)

            test_preds = self.reranker.trainer.predict(self.reranker,
                                                       test_dataset,
                                                       test_output_path)

        metrics = evaluator.eval_runs(test_preds, self.benchmark.qrels,
                                      evaluator.DEFAULT_METRICS,
                                      self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold, metrics)

        print("\ncomputing metrics across all folds")
        avg = {}
        found = 0
        for fold in self.benchmark.folds:
            # TODO fix by using multiple Tasks
            from pathlib import Path

            pred_path = Path(test_output_path.as_posix().replace(
                "fold-" + self.config["fold"], "fold-" + fold))
            if not os.path.exists(pred_path):
                print(
                    "\tfold=%s results are missing and will not be included" %
                    fold)
                continue

            found += 1
            preds = Searcher.load_trec_run(pred_path)
            metrics = evaluator.eval_runs(preds, self.benchmark.qrels,
                                          evaluator.DEFAULT_METRICS,
                                          self.benchmark.relevance_level)
            for metric, val in metrics.items():
                avg.setdefault(metric, []).append(val)

        avg = {k: np.mean(v) for k, v in avg.items()}
        logger.info(
            "rerank: average cross-validated metrics when choosing iteration based on '%s':",
            self.config["optimize"])
        for metric, score in sorted(avg.items()):
            logger.info("%25s: %0.4f", metric, score)

    def rerank_run(self,
                   best_search_run,
                   train_output_path,
                   include_train=False,
                   init_path=""):
        if not isinstance(train_output_path, Path):
            train_output_path = Path(train_output_path)

        fold = self.config["fold"]
        threshold = self.config["threshold"]
        dev_output_path = train_output_path / "pred" / "dev"
        logger.debug("results path: %s", train_output_path)

        docids = set(docid for querydocs in best_search_run.values()
                     for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(),
            docids=docids,
            topics=self.benchmark.topics[self.benchmark.query_type])

        self.reranker.build_model()
        self.reranker.searcher_scores = best_search_run

        train_run = {
            qid: docs
            for qid, docs in best_search_run.items()
            if qid in self.benchmark.folds[fold]["train_qids"]
        }

        # For each qid, select the top config["threshold"] docs to be used in validation
        dev_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["dev"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            dev_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(dev_run[qid])} for query {qid}"
                        break
                    dev_run[qid][docid] = score

        # Depending on the sampler chosen, the dataset may generate triplets or pairs
        train_dataset = self.sampler
        train_dataset.prepare(
            train_run,
            self.benchmark.qrels,
            self.reranker.extractor,
            relevance_level=self.benchmark.relevance_level,
        )
        dev_dataset = PredSampler()
        dev_dataset.prepare(
            dev_run,
            self.benchmark.qrels,
            self.reranker.extractor,
            relevance_level=self.benchmark.relevance_level,
        )

        train_args = [
            self.reranker, train_dataset, train_output_path, dev_dataset,
            dev_output_path, self.benchmark.qrels, self.config["optimize"],
            self.benchmark.relevance_level
        ]
        if self.reranker.trainer.module_name == "tensorflowlog":
            self.reranker.trainer.train(*train_args, init_path=init_path)
        else:
            self.reranker.trainer.train(*train_args)

        self.reranker.trainer.load_best_model(self.reranker, train_output_path)
        dev_output_path = train_output_path / "pred" / "dev" / "best"
        dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset,
                                                  dev_output_path)
        shutil.copy(dev_output_path, dev_output_path.parent / "dev.best")
        wandb.save(str(dev_output_path.parent / "dev.best"))

        test_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["test"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            test_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(test_run[qid])} for query {qid}"
                        break
                    test_run[qid][docid] = score

        test_dataset = PredSampler()
        test_dataset.prepare(test_run,
                             self.benchmark.unsampled_qrels,
                             self.reranker.extractor,
                             relevance_level=self.benchmark.relevance_level)
        test_output_path = train_output_path / "pred" / "test" / "best"
        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset,
                                                   test_output_path)
        shutil.copy(test_output_path, test_output_path.parent / "test.best")
        wandb.save(str(test_output_path.parent / "test.best"))

        preds = {"dev": dev_preds, "test": test_preds}

        if include_train:
            train_dataset = PredSampler()
            train_dataset.prepare(
                train_run,
                self.benchmark.qrels,
                self.reranker.extractor,
                relevance_level=self.benchmark.relevance_level,
            )

            train_output_path = train_output_path / "pred" / "train" / "best"
            train_preds = self.reranker.trainer.predict(
                self.reranker, train_dataset, train_output_path)
            preds["train"] = train_preds

        return preds

    def predict_and_eval(self, init_path=None):
        fold = self.config["fold"]
        self.reranker.build_model()
        if not init_path or init_path == "none":
            logger.info("No init path given, using default parameters")
        else:
            logger.info(f"Load from {init_path}")
            init_path = Path(
                init_path) if not init_path.startswith("gs:") else init_path
            self.reranker.trainer.load_best_model(self.reranker,
                                                  init_path,
                                                  do_not_hash=True)

        dirname = str(init_path).split("/")[-1] if init_path else "noinitpath"
        savedir = Path(
            __file__).parent.absolute() / "downloaded_runfiles" / dirname
        dev_output_path = savedir / fold / "dev"
        test_output_path = savedir / fold / "test"
        test_output_path.parent.mkdir(exist_ok=True, parents=True)

        self.rank.search()
        threshold = self.config["threshold"]
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values()
                     for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(),
            docids=docids,
            topics=self.benchmark.topics[self.benchmark.query_type])

        # dev run
        dev_run = defaultdict(dict)
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["dev"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            dev_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(dev_run[qid])} for query {qid}"
                        break
                    dev_run[qid][docid] = score
        dev_dataset = PredSampler()
        dev_dataset.prepare(dev_run,
                            self.benchmark.qrels,
                            self.reranker.extractor,
                            relevance_level=self.benchmark.relevance_level)

        # test_run
        test_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["test"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            test_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(test_run[qid])} for query {qid}"
                        break
                    test_run[qid][docid] = score

        unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(
            self.benchmark, "unsampled_qrels") else self.benchmark.qrels
        test_dataset = PredSampler()
        test_dataset.prepare(test_run,
                             unsampled_qrels,
                             self.reranker.extractor,
                             relevance_level=self.benchmark.relevance_level)
        logger.info("test prepared")

        # prediction
        dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset,
                                                  dev_output_path)
        fold_dev_metrics = evaluator.eval_runs(dev_preds, unsampled_qrels,
                                               self.metrics,
                                               self.benchmark.relevance_level)
        logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset,
                                                   test_output_path)
        fold_test_metrics = evaluator.eval_runs(test_preds, unsampled_qrels,
                                                self.metrics,
                                                self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold,
                    fold_test_metrics)
        wandb.save(str(dev_output_path))
        wandb.save(str(test_output_path))

        # add cross validate results:
        n_folds = len(self.benchmark.folds)
        folds_fn = {
            f"s{i}": savedir / f"s{i}" / "test"
            for i in range(1, n_folds + 1)
        }
        if not all([fn.exists() for fn in folds_fn.values()]):
            return {"fold_test_metrics": fold_test_metrics, "cv_metrics": None}

        all_preds = {}
        reranker_runs = {
            fold: {
                "dev": Searcher.load_trec_run(fn.parent / "dev"),
                "test": Searcher.load_trec_run(fn)
            }
            for fold, fn in folds_fn.items()
        }

        for fold, dev_test in reranker_runs.items():
            preds = dev_test["test"]
            qids = self.benchmark.folds[fold]["predict"]["test"]
            for qid, docscores in preds.items():
                if qid not in qids:
                    continue
                all_preds.setdefault(qid, {})
                for docid, score in docscores.items():
                    all_preds[qid][docid] = score

        cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels,
                                         self.metrics,
                                         self.benchmark.relevance_level)
        for metric, score in sorted(cv_metrics.items()):
            logger.info("%25s: %0.4f", metric, score)

        searcher_runs = {}
        rank_results = self.rank.evaluate()
        for fold in self.benchmark.folds:
            searcher_runs[fold] = {
                "dev": Searcher.load_trec_run(rank_results["path"][fold])
            }
            searcher_runs[fold]["test"] = searcher_runs[fold]["dev"]

        interpolated_results = evaluator.interpolated_eval(
            searcher_runs, reranker_runs, self.benchmark,
            self.config["optimize"], self.metrics)

        return {
            "fold_test_metrics": fold_test_metrics,
            "cv_metrics": cv_metrics,
            "interpolated_results": interpolated_results,
        }
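An illustrative sketch (toy run data and assumed fold names) of the cross-validation pooling performed above: each fold contributes predictions only for its own test qids, and the pooled run is then scored once, just as the per-fold runs are passed to evaluator.eval_runs.

all_preds = {}
reranker_runs = {
    "s1": {"test": {"301": {"LA010189-0001": 1.2, "LA010189-0002": 0.4}}},
    "s2": {"test": {"302": {"LA010189-0003": 0.9}}},
}
fold_test_qids = {"s1": {"301"}, "s2": {"302"}}

for fold, dev_test in reranker_runs.items():
    for qid, docscores in dev_test["test"].items():
        if qid not in fold_test_qids[fold]:
            continue
        all_preds.setdefault(qid, {}).update(docscores)

assert all_preds == {
    "301": {"LA010189-0001": 1.2, "LA010189-0002": 0.4},
    "302": {"LA010189-0003": 0.9},
}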
Example 15
    def rerank_run(self,
                   best_search_run,
                   train_output_path,
                   include_train=False):
        if not isinstance(train_output_path, Path):
            train_output_path = Path(train_output_path)

        fold = self.config["fold"]
        dev_output_path = train_output_path / "pred" / "dev"
        logger.debug("results path: %s", train_output_path)

        docids = set(docid for querydocs in best_search_run.values()
                     for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(),
            docids=docids,
            topics=self.benchmark.topics[self.benchmark.query_type])
        self.reranker.build_model()
        self.reranker.searcher_scores = best_search_run

        train_run = {
            qid: docs
            for qid, docs in best_search_run.items()
            if qid in self.benchmark.folds[fold]["train_qids"]
        }
        # For each qid, select the top config["threshold"] docs to be used in validation
        dev_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["dev"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= self.config["threshold"]:
                        break
                    dev_run[qid][docid] = score

        # Depending on the sampler chosen, the dataset may generate triplets or pairs
        train_dataset = self.sampler
        train_dataset.prepare(train_run,
                              self.benchmark.qrels,
                              self.reranker.extractor,
                              relevance_level=self.benchmark.relevance_level)
        dev_dataset = PredSampler()
        dev_dataset.prepare(dev_run,
                            self.benchmark.qrels,
                            self.reranker.extractor,
                            relevance_level=self.benchmark.relevance_level)

        dev_preds = self.reranker.trainer.train(
            self.reranker,
            train_dataset,
            train_output_path,
            dev_dataset,
            dev_output_path,
            self.benchmark.qrels,
            self.config["optimize"],
            self.benchmark.relevance_level,
        )

        self.reranker.trainer.load_best_model(self.reranker, train_output_path)
        dev_output_path = train_output_path / "pred" / "dev" / "best"
        if not dev_output_path.exists():
            dev_preds = self.reranker.trainer.predict(self.reranker,
                                                      dev_dataset,
                                                      dev_output_path)

        test_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["test"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= self.config["testthreshold"]:
                        break
                    test_run[qid][docid] = score

        test_dataset = PredSampler()
        test_dataset.prepare(test_run,
                             self.benchmark.qrels,
                             self.reranker.extractor,
                             relevance_level=self.benchmark.relevance_level)
        test_output_path = train_output_path / "pred" / "test" / "best"
        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset,
                                                   test_output_path)

        preds = {"dev": dev_preds, "test": test_preds}

        if include_train:
            train_dataset = PredSampler()
            train_dataset.prepare(train_run,
                                  self.benchmark.qrels,
                                  self.reranker.extractor,
                                  relevance_level=self.benchmark.relevance_level)

            train_output_path = train_output_path / "pred" / "train" / "best"
            train_preds = self.reranker.trainer.predict(
                self.reranker, train_dataset, train_output_path)
            preds["train"] = train_preds

        return preds