def test_pred_sampler(monkeypatch, tmpdir):
    benchmark = DummyBenchmark()
    extractor = EmbedText({"tokenizer": {"keepstops": True}}, provide={"collection": benchmark.collection})
    search_run = {"301": {"LA010189-0001": 50, "LA010189-0002": 100}}
    pred_dataset = PredSampler()
    pred_dataset.prepare(benchmark.qrels, search_run, extractor)

    def mock_id2vec(*args, **kwargs):
        return {"query": np.array([1, 2, 3, 4]), "posdoc": np.array([1, 1, 1, 1])}

    monkeypatch.setattr(EmbedText, "id2vec", mock_id2vec)
    dataloader = torch.utils.data.DataLoader(pred_dataset, batch_size=2)

    for idx, batch in enumerate(dataloader):
        print(idx, batch)
        assert len(batch["query"]) == 2
        assert len(batch["posdoc"]) == 2
        assert batch.get("negdoc") is None
        assert np.array_equal(batch["query"][0], np.array([1, 2, 3, 4]))
        assert np.array_equal(batch["query"][1], np.array([1, 2, 3, 4]))
        assert np.array_equal(batch["posdoc"][0], np.array([1, 1, 1, 1]))
        assert np.array_equal(batch["posdoc"][1], np.array([1, 1, 1, 1]))

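# The test above stubs EmbedText.id2vec with pytest's monkeypatch so every sampled (qid, docid) pair
# yields the same fixed vectors, then relies on DataLoader's default collation to stack them into
# batches. A minimal, self-contained sketch of that pattern follows; _FakeExtractor and _PairDataset
# are hypothetical stand-ins written for illustration, not capreolus classes.
import numpy as np
import torch


class _FakeExtractor:
    """Hypothetical stand-in for an extractor whose id2vec is patched in tests."""

    def id2vec(self, qid, docid):
        raise NotImplementedError("the real extractor would look up learned embeddings here")


class _PairDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for PredSampler: one (qid, docid) pair per item."""

    def __init__(self, extractor, pairs):
        self.extractor, self.pairs = extractor, pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        qid, docid = self.pairs[idx]
        return self.extractor.id2vec(qid, docid)


def test_fixed_vectors(monkeypatch):
    # patch the method on the class, just as the real test patches EmbedText.id2vec
    monkeypatch.setattr(_FakeExtractor, "id2vec", lambda self, qid, docid: {"query": np.arange(4), "posdoc": np.ones(4)})
    dataset = _PairDataset(_FakeExtractor(), [("301", "LA010189-0001"), ("301", "LA010189-0002")])
    batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=2)))
    # default_collate stacks the per-item dicts into one tensor per key
    assert batch["query"].shape == (2, 4) and batch["posdoc"].shape == (2, 4)
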
def predict(self):
    fold = self.config["fold"]
    self.rank.search()
    threshold = self.config["threshold"]
    rank_results = self.rank.evaluate()
    best_search_run_path = rank_results["path"][fold]
    best_search_run = Searcher.load_trec_run(best_search_run_path)

    docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
    self.reranker.extractor.preprocess(
        qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
    )
    train_output_path = self.get_results_path()
    self.reranker.build_model()
    self.reranker.trainer.load_best_model(self.reranker, train_output_path)

    test_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    break
                test_run[qid][docid] = score

    test_dataset = PredSampler()
    test_dataset.prepare(test_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)

    test_output_path = train_output_path / "pred" / "test" / "best"
    test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)
    preds = {"test": test_preds}

    return preds

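# The threshold loop in predict() above, which reappears in rerank_run and predict_and_eval below,
# keeps only the top-`threshold` documents per query and relies on the run being insertion-ordered
# by rank (Searcher.load_trec_run returns an OrderedDict). A minimal sketch of that pattern as a
# standalone helper; `truncate_run` is a hypothetical name, not a capreolus API.
from collections import defaultdict


def truncate_run(run, qids, threshold):
    """Keep the first `threshold` docs per query, restricted to qids.

    Assumes each run[qid] is ordered by decreasing retrieval score.
    """
    truncated = defaultdict(dict)
    for qid, docs in run.items():
        if qid not in qids:
            continue
        for idx, (docid, score) in enumerate(docs.items()):
            if idx >= threshold:
                break
            truncated[qid][docid] = score
    return truncated


# example with illustrative scores:
# truncate_run({"301": {"LA010189-0001": 9.1, "LA010189-0002": 8.7, "LA010189-0003": 8.2}}, {"301"}, 2)
# -> {"301": {"LA010189-0001": 9.1, "LA010189-0002": 8.7}}
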
def test_dssm_unigram(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark()
    reranker = DSSM(
        {"nhiddens": "56", "trainer": {"niters": 1, "itersize": 4, "batch": 2}},
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")

def test_birch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark()
    reranker = Birch(
        {"trainer": {"niters": 1, "itersize": 2, "batch": 2}},
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.searcher_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

def test_bertmaxp_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    benchmark = DummyBenchmark({"collection": {"name": "dummy"}})
    reranker = TFBERTMaxP(
        {
            "pretrained": "bert-base-uncased",
            "extractor": {
                "name": "bertpassage",
                "usecache": False,
                "maxseqlen": 32,
                "numpassages": 2,
                "passagelen": 15,
                "stride": 5,
                "index": {"name": "anserini", "indexstops": False, "stemmer": "porter", "collection": {"name": "dummy"}},
            },
            "trainer": {
                "name": "tensorflow",
                "batch": 4,
                "niters": 1,
                "itersize": 2,
                "lr": 0.001,
                "validatefreq": 1,
                "usecache": False,
                "tpuname": None,
                "tpuzone": None,
                "storage": None,
                "boardname": "default",
                "loss": "crossentropy",
                "eager": False,
            },
        },
        provide=benchmark,
    )
    reranker.extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.bm25_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainPairSampler()
    train_dataset.prepare(train_run, benchmark.qrels, reranker.extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, reranker.extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, "map")

def test_tk(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings", fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = TK(
        {
            "gradkernels": True,
            "scoretanh": False,
            "singlefc": True,
            "projdim": 32,
            "ffdim": 100,
            "numlayers": 2,
            "numattheads": 4,
            "alpha": 0.5,
            "usemask": False,
            "usemixer": True,
            "finetune": True,
            "trainer": {"niters": 1, "itersize": 4, "batch": 2},
        },
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")

def test_CDSSM(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings", fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = CDSSM(
        {
            "nkernel": 3,
            "nfilter": 1,
            "nhiddens": 30,
            "windowsize": 3,
            "dropoutrate": 0,
            "trainer": {"niters": 1, "itersize": 2, "batch": 1},
        },
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()
    reranker.searcher_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

def test_knrm_tf_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        vectors = np.zeros((8, 32))
        stoi = defaultdict(lambda: 0)
        itos = defaultdict(lambda: "dummy")
        return vectors, stoi, itos

    monkeypatch.setattr(capreolus.extractor.common, "load_pretrained_embeddings", fake_magnitude_embedding)
    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings", fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = TFKNRM(
        {
            "gradkernels": True,
            "finetune": False,
            "trainer": {"niters": 1, "itersize": 4, "batch": 2, "loss": "binary_crossentropy"},
        },
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainPairSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best.index")

def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_load_embeddings(self):
        self.embeddings = np.zeros((1, 50))
        self.stoi = {"<pad>": 0}
        self.itos = {v: k for k, v in self.stoi.items()}

    monkeypatch.setattr(EmbedText, "_load_pretrained_embeddings", fake_load_embeddings)
    benchmark = DummyBenchmark()
    reranker = KNRM(
        {
            "gradkernels": True,
            "scoretanh": False,
            "singlefc": True,
            "finetune": False,
            "trainer": {"niters": 1, "itersize": 4, "batch": 2},
        },
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")

def test_deeptilebars(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return Magnitude(None)

    monkeypatch.setattr(DeepTileExtractor, "_get_pretrained_emb", fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = DeepTileBar(
        {
            "name": "DeepTileBar",
            "passagelen": 30,
            "numberfilter": 3,
            "lstmhiddendim": 3,
            "linearhiddendim1": 32,
            "linearhiddendim2": 16,
            "trainer": {"niters": 1, "itersize": 4, "batch": 2},
        },
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")

def test_HINT(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings", fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = HINT(
        {"spatialGRU": 2, "LSTMdim": 6, "kmax": 10, "trainer": {"niters": 1, "itersize": 2, "batch": 1}},
        provide={"index": dummy_index, "benchmark": benchmark},
    )
    extractor = reranker.extractor
    metric = "map"
    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"], benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainTripletSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)

    reranker.trainer.train(reranker, train_dataset, Path(tmpdir) / "train", dev_dataset, Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")

def evaluate(self):
    fold = self.config["fold"]
    train_output_path = self.get_results_path()
    test_output_path = train_output_path / "pred" / "test" / "best"
    logger.debug("results path: %s", train_output_path)

    if os.path.exists(test_output_path):
        test_preds = Searcher.load_trec_run(test_output_path)
    else:
        self.rank.search()
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
        )
        self.reranker.build_model()
        self.reranker.searcher_scores = best_search_run

        self.reranker.trainer.load_best_model(self.reranker, train_output_path)

        test_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["predict"]["test"]}
        test_dataset = PredSampler()
        test_dataset.prepare(test_run, self.benchmark.qrels, self.reranker.extractor)

        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

    metrics = evaluator.eval_runs(test_preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s test metrics: %s", fold, metrics)

    print("\ncomputing metrics across all folds")
    avg = {}
    found = 0
    for fold in self.benchmark.folds:
        # TODO fix by using multiple Tasks
        from pathlib import Path

        pred_path = Path(test_output_path.as_posix().replace("fold-" + self.config["fold"], "fold-" + fold))
        if not os.path.exists(pred_path):
            print("\tfold=%s results are missing and will not be included" % fold)
            continue

        found += 1
        preds = Searcher.load_trec_run(pred_path)
        metrics = evaluator.eval_runs(preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
        for metric, val in metrics.items():
            avg.setdefault(metric, []).append(val)

    avg = {k: np.mean(v) for k, v in avg.items()}
    logger.info("rerank: average cross-validated metrics when choosing iteration based on '%s':", self.config["optimize"])
    for metric, score in sorted(avg.items()):
        logger.info("%25s: %0.4f", metric, score)

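# The tail of evaluate() averages each metric over whichever folds already have predictions on disk.
# The aggregation itself reduces to the dict-of-lists pattern sketched below; the helper name and the
# per-fold numbers are illustrative, not part of capreolus.
import numpy as np


def _average_fold_metrics(fold_metrics):
    """fold_metrics: {fold: {metric: value}} -> {metric: mean over available folds}."""
    avg = {}
    for metrics in fold_metrics.values():
        for metric, val in metrics.items():
            avg.setdefault(metric, []).append(val)
    return {metric: np.mean(vals) for metric, vals in avg.items()}


# example with placeholder values:
# _average_fold_metrics({"s1": {"map": 0.21, "P_20": 0.30}, "s2": {"map": 0.25, "P_20": 0.34}})
# -> {"map": 0.23, "P_20": 0.32}
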
def rerank_run(self, best_search_run, train_output_path, include_train=False, init_path=""):
    if not isinstance(train_output_path, Path):
        train_output_path = Path(train_output_path)

    fold = self.config["fold"]
    threshold = self.config["threshold"]
    dev_output_path = train_output_path / "pred" / "dev"
    logger.debug("results path: %s", train_output_path)

    docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
    self.reranker.extractor.preprocess(
        qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
    )
    self.reranker.build_model()
    self.reranker.searcher_scores = best_search_run

    train_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["train_qids"]}

    # For each qid, select the top-k (k = config["threshold"]) docs to be used in validation
    dev_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["dev"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(dev_run[qid]) == threshold, f"Expect {threshold} on each qid, got {len(dev_run[qid])} for query {qid}"
                    break
                dev_run[qid][docid] = score

    # Depending on the sampler chosen, the dataset may generate triplets or pairs
    train_dataset = self.sampler
    train_dataset.prepare(
        train_run,
        self.benchmark.qrels,
        self.reranker.extractor,
        relevance_level=self.benchmark.relevance_level,
    )
    dev_dataset = PredSampler()
    dev_dataset.prepare(
        dev_run,
        self.benchmark.qrels,
        self.reranker.extractor,
        relevance_level=self.benchmark.relevance_level,
    )

    train_args = [
        self.reranker,
        train_dataset,
        train_output_path,
        dev_dataset,
        dev_output_path,
        self.benchmark.qrels,
        self.config["optimize"],
        self.benchmark.relevance_level,
    ]
    if self.reranker.trainer.module_name == "tensorflowlog":
        self.reranker.trainer.train(*train_args, init_path=init_path)
    else:
        self.reranker.trainer.train(*train_args)

    self.reranker.trainer.load_best_model(self.reranker, train_output_path)
    dev_output_path = train_output_path / "pred" / "dev" / "best"
    dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset, dev_output_path)
    shutil.copy(dev_output_path, dev_output_path.parent / "dev.best")
    wandb.save(str(dev_output_path.parent / "dev.best"))

    test_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(test_run[qid]) == threshold, f"Expect {threshold} on each qid, got {len(test_run[qid])} for query {qid}"
                    break
                test_run[qid][docid] = score

    test_dataset = PredSampler()
    test_dataset.prepare(test_run, self.benchmark.unsampled_qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
    test_output_path = train_output_path / "pred" / "test" / "best"
    test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)
    shutil.copy(test_output_path, test_output_path.parent / "test.best")
    wandb.save(str(test_output_path.parent / "test.best"))

    preds = {"dev": dev_preds, "test": test_preds}

    if include_train:
        train_dataset = PredSampler()
        train_dataset.prepare(
            train_run,
            self.benchmark.qrels,
            self.reranker.extractor,
            relevance_level=self.benchmark.relevance_level,
        )
        train_output_path = train_output_path / "pred" / "train" / "best"
        train_preds = self.reranker.trainer.predict(self.reranker, train_dataset, train_output_path)
        preds["train"] = train_preds

    return preds

def predict_and_eval(self, init_path=None):
    fold = self.config["fold"]
    self.reranker.build_model()
    if not init_path or init_path == "none":
        logger.info("No init path given, using default parameters")
        self.reranker.build_model()
    else:
        logger.info(f"Load from {init_path}")
        init_path = Path(init_path) if not init_path.startswith("gs:") else init_path
        self.reranker.trainer.load_best_model(self.reranker, init_path, do_not_hash=True)

    dirname = str(init_path).split("/")[-1] if init_path else "noinitpath"
    savedir = Path(__file__).parent.absolute() / "downloaded_runfiles" / dirname
    dev_output_path = savedir / fold / "dev"
    test_output_path = savedir / fold / "test"
    test_output_path.parent.mkdir(exist_ok=True, parents=True)

    self.rank.search()
    threshold = self.config["threshold"]
    rank_results = self.rank.evaluate()
    best_search_run_path = rank_results["path"][fold]
    best_search_run = Searcher.load_trec_run(best_search_run_path)

    docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
    self.reranker.extractor.preprocess(
        qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
    )

    # dev run
    dev_run = defaultdict(dict)
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["dev"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(dev_run[qid]) == threshold, f"Expect {threshold} on each qid, got {len(dev_run[qid])} for query {qid}"
                    break
                dev_run[qid][docid] = score
    dev_dataset = PredSampler()
    dev_dataset.prepare(dev_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)

    # test run
    test_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(test_run[qid]) == threshold, f"Expect {threshold} on each qid, got {len(test_run[qid])} for query {qid}"
                    break
                test_run[qid][docid] = score

    unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(self.benchmark, "unsampled_qrels") else self.benchmark.qrels
    test_dataset = PredSampler()
    test_dataset.prepare(test_run, unsampled_qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
    logger.info("test prepared")

    # prediction
    dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset, dev_output_path)
    fold_dev_metrics = evaluator.eval_runs(dev_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

    test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)
    fold_test_metrics = evaluator.eval_runs(test_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s test metrics: %s", fold, fold_test_metrics)

    wandb.save(str(dev_output_path))
    wandb.save(str(test_output_path))

    # add cross-validated results:
    n_folds = len(self.benchmark.folds)
    folds_fn = {f"s{i}": savedir / f"s{i}" / "test" for i in range(1, n_folds + 1)}
    if not all([fn.exists() for fn in folds_fn.values()]):
        return {"fold_test_metrics": fold_test_metrics, "cv_metrics": None}

    all_preds = {}
    reranker_runs = {
        fold: {"dev": Searcher.load_trec_run(fn.parent / "dev"), "test": Searcher.load_trec_run(fn)}
        for fold, fn in folds_fn.items()
    }
    for fold, dev_test in reranker_runs.items():
        preds = dev_test["test"]
        qids = self.benchmark.folds[fold]["predict"]["test"]
        for qid, docscores in preds.items():
            if qid not in qids:
                continue
            all_preds.setdefault(qid, {})
            for docid, score in docscores.items():
                all_preds[qid][docid] = score

    cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    for metric, score in sorted(cv_metrics.items()):
        logger.info("%25s: %0.4f", metric, score)

    searcher_runs = {}
    rank_results = self.rank.evaluate()
    for fold in self.benchmark.folds:
        searcher_runs[fold] = {"dev": Searcher.load_trec_run(rank_results["path"][fold])}
        searcher_runs[fold]["test"] = searcher_runs[fold]["dev"]

    interpolated_results = evaluator.interpolated_eval(
        searcher_runs, reranker_runs, self.benchmark, self.config["optimize"], self.metrics
    )

    return {
        "fold_test_metrics": fold_test_metrics,
        "cv_metrics": cv_metrics,
        "interpolated_results": interpolated_results,
    }

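# Unlike evaluate(), which averages per-fold metrics, predict_and_eval() pools every fold's test
# predictions into a single run (each qid appears in exactly one test fold) and evaluates that run
# once. A minimal sketch of the pooling step under that disjoint-qids assumption; the helper name
# and the inline data are illustrative only.
def _pool_fold_test_runs(reranker_runs, fold_test_qids):
    """reranker_runs: {fold: {"test": {qid: {docid: score}}}}; fold_test_qids: {fold: set of qids}."""
    all_preds = {}
    for fold, dev_test in reranker_runs.items():
        for qid, docscores in dev_test["test"].items():
            if qid in fold_test_qids[fold]:
                all_preds.setdefault(qid, {}).update(docscores)
    return all_preds


# example with placeholder scores:
# _pool_fold_test_runs(
#     {"s1": {"test": {"301": {"d1": 1.2}}}, "s2": {"test": {"302": {"d3": 0.9}}}},
#     {"s1": {"301"}, "s2": {"302"}},
# )
# -> {"301": {"d1": 1.2}, "302": {"d3": 0.9}}
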
def rerank_run(self, best_search_run, train_output_path, include_train=False):
    if not isinstance(train_output_path, Path):
        train_output_path = Path(train_output_path)

    fold = self.config["fold"]
    dev_output_path = train_output_path / "pred" / "dev"
    logger.debug("results path: %s", train_output_path)

    docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
    self.reranker.extractor.preprocess(
        qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
    )
    self.reranker.build_model()
    self.reranker.searcher_scores = best_search_run

    train_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["train_qids"]}

    # For each qid, select the top-k (k = config["threshold"]) docs to be used in validation
    dev_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["dev"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= self.config["threshold"]:
                    break
                dev_run[qid][docid] = score

    # Depending on the sampler chosen, the dataset may generate triplets or pairs
    train_dataset = self.sampler
    train_dataset.prepare(train_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
    dev_dataset = PredSampler()
    dev_dataset.prepare(dev_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)

    dev_preds = self.reranker.trainer.train(
        self.reranker,
        train_dataset,
        train_output_path,
        dev_dataset,
        dev_output_path,
        self.benchmark.qrels,
        self.config["optimize"],
        self.benchmark.relevance_level,
    )

    self.reranker.trainer.load_best_model(self.reranker, train_output_path)
    dev_output_path = train_output_path / "pred" / "dev" / "best"
    if not dev_output_path.exists():
        dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset, dev_output_path)

    test_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= self.config["testthreshold"]:
                    break
                test_run[qid][docid] = score

    test_dataset = PredSampler()
    test_dataset.prepare(test_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
    test_output_path = train_output_path / "pred" / "test" / "best"
    test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

    preds = {"dev": dev_preds, "test": test_preds}

    if include_train:
        train_dataset = PredSampler()
        train_dataset.prepare(train_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
        train_output_path = train_output_path / "pred" / "train" / "best"
        train_preds = self.reranker.trainer.predict(self.reranker, train_dataset, train_output_path)
        preds["train"] = train_preds

    return preds