Esempio n. 1
0
    def predict(self, reranker, pred_data, pred_fn):
        """Score `pred_data` with `reranker.model` and write the resulting run file to `pred_fn`.

        Args:
           reranker: object whose `model` attribute is wrapped with `self.get_wrapped_model`
           pred_data: data to predict on; converted to records by `self.get_tf_dev_records`
           pred_fn: path the run file is written to

        Returns:
           the predictions in the format produced by `self.get_preds_in_trec_format`

        """
        pred_records = self.get_tf_dev_records(reranker, pred_data)
        pred_dist_dataset = self.strategy.experimental_distribute_dataset(pred_records)

        with self.strategy.scope():
            wrapped_model = self.get_wrapped_model(reranker.model)

        def test_step(inputs):
            # Labels are part of each record but are not needed for prediction.
            data, _ = inputs
            return wrapped_model.predict_step(data)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs,))

        multi_replica = self.strategy.num_replicas_in_sync > 1
        predictions = []
        for x in tqdm(pred_dist_dataset, desc="validation"):
            step_output = distributed_test_step(x)
            # With several replicas the step returns a per-replica container;
            # with a single replica it returns the predictions directly.
            pred_batch = step_output.values if multi_replica else [step_output]
            for p in pred_batch:
                predictions.extend(p)

        trec_preds = self.get_preds_in_trec_format(predictions, pred_data)

        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when pred_fn actually has one (i.e. is not a bare filename).
        output_dir = os.path.dirname(pred_fn)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        Searcher.write_trec_run(trec_preds, pred_fn)

        return trec_preds
Esempio n. 2
0
    def predict(self, reranker, pred_data, pred_fn):
        """Predict query-document scores on `pred_data` using `reranker` and write a corresponding run file to `pred_fn`

        Args:
           reranker (Reranker): a PyTorch Reranker; its `model` is moved to the selected device
               and its `test` method is called on every batch
           pred_data (IterableDataset): data to predict on
           pred_fn (Path): path to write the prediction run file to

        Returns:
           TREC Run: dict mapping qid -> docid -> score

        """

        # Side effect kept from the original implementation: the chosen device is stored on the trainer.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = reranker.model.to(self.device)
        model.eval()

        preds = {}
        pred_dataloader = torch.utils.data.DataLoader(pred_data, batch_size=self.cfg["batch"], pin_memory=True, num_workers=0)
        with torch.autograd.no_grad():
            for batch in pred_dataloader:
                # Only non-list values are moved to the device; list values (e.g. the ids) stay as they are.
                batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
                scores = reranker.test(batch)
                scores = scores.view(-1).cpu().numpy()
                for qid, docid, score in zip(batch["qid"], batch["posdocid"], scores):
                    # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats
                    preds.setdefault(qid, {})[docid] = score.astype(np.float16).item()

        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when pred_fn actually has one (i.e. is not a bare filename).
        output_dir = os.path.dirname(pred_fn)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        Searcher.write_trec_run(preds, pred_fn)

        return preds
Esempio n. 3
0
def predict_and_save_to_file(gen, model, outfn, prepare_batch):
    """Score every batch produced by `gen` with `model` and write the scores to `outfn` as a run file.

    Args:
        gen: iterable of batches; each batch is indexed with "qid" and "posdocid"
        model: object with a `test` method returning a tensor of scores
        outfn: path the run file is written to
        prepare_batch: callable applied to each batch before it is passed to the model

    Returns:
        defaultdict mapping qid -> docid -> score
    """
    preds = defaultdict(dict)
    with torch.autograd.no_grad():
        for data in tqdm(gen):
            # Grab the ids before prepare_batch replaces the batch.
            qid_batch, docid_batch = data["qid"], data["posdocid"]
            data = prepare_batch(data)

            if pipeline.cfg["reranker"].startswith("Cedr"):
                scores = model.test(data)
            else:
                query, query_idf, doc = data["query"], data["query_idf"], data["posdoc"]
                scores = model.test(query,
                                    query_idf,
                                    doc,
                                    qids=qid_batch,
                                    posdoc_ids=docid_batch)
            scores = scores.view(-1).cpu().numpy()
            for qid, docid, score in zip(qid_batch, docid_batch, scores):
                # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats
                preds[qid][docid] = score.astype(np.float16).item()

    # logger.info("predicted scores for %s pairs", sum(1 for qid in preds for docid in preds[qid]))

    # logger.info("writing predictions file: %s", outfn)
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when outfn actually has one (i.e. is not a bare filename).
    output_dir = os.path.dirname(outfn)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    Searcher.write_trec_run(preds, outfn)

    return preds
Esempio n. 4
0
def test_write_run(tmpdir):
    """Write a run with Searcher.write_trec_run and check that load_trec_run reads it back unchanged."""
    expected_run = {"q1": {"d1": 1.1, "d2": 1.0}, "q2": {"d5": 9.0}}
    run_path = tmpdir / "searcher"

    Searcher.write_trec_run(expected_run, run_path)
    loaded_run = Searcher.load_trec_run(run_path)

    assert sorted(loaded_run.items()) == sorted(expected_run.items())
Esempio n. 5
0
def test_search_run_metrics(tmpdir):
    """Check that search_run_metrics writes a ".metrics" cache, extends it for new qids, and reuses it."""
    valid_metrics = {"P", "map", "map_cut", "ndcg_cut", "Rprec", "recip_rank"}
    qrels_dict = {"q1": {"d1": 1, "d2": 0, "d3": 2}, "q2": {"d5": 0, "d6": 1}}
    run_dict = {
        "q1": {"d1": 1.1, "d2": 1.0},
        "q2": {"d5": 9.0, "d6": 8.0},
        "q3": {"d7": 1.0, "d8": 2.0},
    }

    fn = tmpdir / "searcher"
    Searcher.write_trec_run(run_dict, fn)
    searcher = Searcher(None, None, None, None)

    # First pass: only q1 and q2 have judgments. The call is made for its
    # side effect of producing the cache file.
    first_qids = set(qrels_dict)
    first_evaluator = pytrec_eval.RelevanceEvaluator(qrels_dict, valid_metrics)
    searcher.search_run_metrics(fn, first_evaluator, first_qids)
    assert os.path.exists(fn + ".metrics")

    # Second pass: judgments for q3 are added, so the cache has to be updated.
    qrels_dict["q3"] = {"d7": 0, "d8": 2}
    all_qids = set(qrels_dict)
    full_evaluator = pytrec_eval.RelevanceEvaluator(qrels_dict, valid_metrics)
    metrics = searcher.search_run_metrics(fn, full_evaluator, all_qids)
    for expected_qid in ("q3", "q2"):
        assert expected_qid in metrics

    # Third pass: with the run file deleted, the results can only come from
    # the cache, and they must match the previous pass (including q3).
    os.remove(fn)
    cached_metrics = searcher.search_run_metrics(fn, full_evaluator, all_qids)
    assert metrics == cached_metrics
Esempio n. 6
0
    def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
        """Train `reranker.model` under `self.strategy`, validating on `dev_data` every `validatefreq` iterations.

        Args:
           reranker: object providing `build_model()` and, afterwards, a `model` attribute
           train_dataset: training data; converted to records by `self.get_tf_train_records`
           train_output_path (Path): directory for training outputs; replaced by a path below
               `self.config["storage"]` when `self.tpu` is set
           dev_data: validation data; converted to records by `self.get_tf_dev_records`
           dev_output_path (Path): directory that receives the "best" dev run file
           qrels: relevance judgments, passed to `evaluator.eval_runs` as `dict(qrels)`
           metric (str): name of the metric used to select the best dev weights
           relevance_level (int): forwarded to `evaluator.eval_runs`

        Returns:
           the predictions of the most recent validation pass (format of
           `self.get_preds_in_trec_format`), or an empty dict if no validation pass ran

        """
        if self.tpu:
            # WARNING: not sure if pathlib is compatible with gs://
            train_output_path = Path(
                "{0}/{1}/{2}".format(
                    self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
                )
            )

        dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metric_fn = self.get_paths_for_early_stopping(
            train_output_path, dev_output_path
        )

        train_records = self.get_tf_train_records(reranker, train_dataset)
        dev_records = self.get_tf_dev_records(reranker, dev_data)
        dev_dist_dataset = self.strategy.experimental_distribute_dataset(dev_records)

        # "You should remove the use of the LossScaleOptimizer when TPUs are used."
        use_loss_scaling = self.amp and not self.tpu

        # Does not vary much from https://www.tensorflow.org/tutorials/distribute/custom_training
        with self.strategy.scope():
            reranker.build_model()
            wrapped_model = self.get_wrapped_model(reranker.model)
            loss_object = self.get_loss(self.config["loss"])
            optimizer_1 = tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
            optimizer_2 = tf.keras.optimizers.Adam(learning_rate=self.config["bertlr"])

            if use_loss_scaling:
                optimizer_2 = mixed_precision.LossScaleOptimizer(optimizer_2, loss_scale="dynamic")

            def compute_loss(labels, predictions):
                per_example_loss = loss_object(labels, predictions)
                return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.config["batch"])

        def is_bert_variable(name):
            return "bert" in name or "electra" in name

        def train_step(inputs):
            data, labels = inputs

            with tf.GradientTape() as tape:
                train_predictions = wrapped_model(data, training=True)
                loss = compute_loss(labels, train_predictions)
                # Differentiate the scaled loss under mixed precision, but keep
                # `loss` itself unscaled so the reported training loss does not
                # depend on the current loss scale.
                grad_target = optimizer_2.get_scaled_loss(loss) if use_loss_scaling else loss

            gradients = tape.gradient(grad_target, wrapped_model.trainable_variables)
            if use_loss_scaling:
                # get_unscaled_gradients returns a new list; the result must be
                # used, otherwise the scaled gradients would be applied below.
                gradients = optimizer_2.get_unscaled_gradients(gradients)

            grads_and_vars = list(zip(gradients, wrapped_model.trainable_variables))
            bert_variables = [
                (grad, variable)
                for grad, variable in grads_and_vars
                if is_bert_variable(variable.name) and "classifier" not in variable.name
            ]
            classifier_vars = [(grad, variable) for grad, variable in grads_and_vars if "classifier" in variable.name]
            other_vars = [
                (grad, variable)
                for grad, variable in grads_and_vars
                if not is_bert_variable(variable.name) and "classifier" not in variable.name
            ]

            # Making sure that we did not miss any variables
            assert len(bert_variables) + len(classifier_vars) + len(other_vars) == len(wrapped_model.trainable_variables)
            # TODO: Clean this up for general use
            optimizer_1.apply_gradients(classifier_vars)
            optimizer_2.apply_gradients(bert_variables)
            if other_vars:
                optimizer_1.apply_gradients(other_vars)

            return loss

        def test_step(inputs):
            # Labels are part of each record but are not needed for prediction.
            data, _ = inputs
            return wrapped_model.predict_step(data)

        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = self.strategy.run(train_step, args=(dataset_inputs,))

            return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs,))

        train_records = train_records.shuffle(100000)
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_records)

        initial_iter, metrics = (
            self.fastforward_training(wrapped_model, weights_output_path, loss_fn, metric_fn)
            if self.config["fastforward"]
            else (0, {})
        )
        dev_best_metric = metrics.get(metric, -np.inf)
        logger.info("starting training from iteration %s/%s", initial_iter + 1, self.config["niters"])
        logger.info(f"Best metric loaded: {metric}={dev_best_metric}")

        cur_step = initial_iter * self.n_batch_per_iter
        initial_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
        K.set_value(optimizer_2.lr, K.get_value(initial_lr))
        train_loss = self.load_loss_file(loss_fn) if initial_iter > 0 else []
        if 0 < initial_iter < self.config["niters"]:
            self.exhaust_used_train_data(train_dist_dataset, n_batch_to_exhaust=initial_iter * self.n_batch_per_iter)

        niter = initial_iter
        total_loss = 0
        trec_preds = {}
        multi_replica = self.strategy.num_replicas_in_sync > 1
        iter_bar = tqdm(desc="Training iteration", total=self.n_batch_per_iter)
        # Goes through the dataset ONCE (i.e niters * itersize).
        # However, the dataset may already contain multiple instances of the same sample,
        # depending upon what Sampler was used.
        # If you want multiple epochs, achieve it by tweaking the niters and itersize values.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            cur_step += 1
            iter_bar.update(1)

            # Do warmup and decay
            new_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
            K.set_value(optimizer_2.lr, K.get_value(new_lr))

            if cur_step % self.n_batch_per_iter == 0:
                niter += 1

                iter_bar.close()
                iter_bar = tqdm(total=self.n_batch_per_iter)
                train_loss.append(total_loss / self.n_batch_per_iter)
                logger.info("iter={} loss = {}".format(niter, train_loss[-1]))
                self.write_to_loss_file(loss_fn, train_loss)
                total_loss = 0

                if self.config["fastforward"]:
                    wrapped_model.save_weights(f"{weights_output_path}/{niter}")

                if niter % self.config["validatefreq"] == 0:
                    dev_predictions = []
                    # Separate loop variable so the training batch `x` is not shadowed.
                    for dev_batch in tqdm(dev_dist_dataset, desc="validation"):
                        step_output = distributed_test_step(dev_batch)
                        pred_batch = step_output.values if multi_replica else [step_output]
                        for p in pred_batch:
                            dev_predictions.extend(p)

                    trec_preds = self.get_preds_in_trec_format(dev_predictions, dev_data)
                    metrics = evaluator.eval_runs(trec_preds, dict(qrels), evaluator.DEFAULT_METRICS, relevance_level)
                    logger.info("dev metrics: %s", " ".join([f"{name}={v:0.3f}" for name, v in sorted(metrics.items())]))
                    if metrics[metric] > dev_best_metric:
                        dev_best_metric = metrics[metric]
                        logger.info("new best dev metric: %0.4f", dev_best_metric)

                        self.write_to_metric_file(metric_fn, metrics)
                        wrapped_model.save_weights(dev_best_weight_fn)
                        Searcher.write_trec_run(trec_preds, outfn=(dev_output_path / "best").as_posix())

            if cur_step >= self.config["niters"] * self.n_batch_per_iter:
                break

        # The bar is recreated after every iteration, so the last one is still open here.
        iter_bar.close()

        return trec_preds
Esempio n. 7
0
def interpolate(_config):
    """Pick, per fold, the dev iteration whose interpolated run scores best, then report test metrics.

    For every fold of the pipeline's benchmark, each saved dev/test run pair is interpolated
    with the run being reranked via `Searcher.crossvalidated_interpolation`. The iteration whose
    interpolated dev run has the highest mean `ndcg_cut_20` is selected, its interpolated test
    run is evaluated, and the un-interpolated dev/test runs of that iteration are written to
    `runs.rerankerIES.<fold>.dev` / `.test` in the current working directory. Averages over
    all test queries are logged at the end.

    Args:
        _config: configuration passed to `pipeline.initialize`

    Raises:
        RuntimeError: if a fold does not have exactly the 'dev' and 'test' prediction sets, or
            if no best iteration could be selected from a fold's dev runs
    """
    from capreolus.searcher import Searcher
    import pytrec_eval

    pipeline.initialize(_config)
    logger.info("initialized pipeline with results path: %s",
                pipeline.reranker_path)

    benchmark = pipeline.benchmark
    benchmark.build()  # TODO move this to pipeline.initialize?

    # Defined once, outside the fold loop, because it is also used in the summary below.
    target_metric = "ndcg_cut_20"
    # target_metric = "map"

    test_metrics = {}
    for foldname, fold in sorted(benchmark.folds.items()):
        if not (len(fold["predict"]) == 2 and "dev" in fold["predict"]
                and "test" in fold["predict"]):
            raise RuntimeError(
                "this evaluation command is only supported for benchmarks with 'dev' and 'test' folds"
            )

        logger.debug("evaluating fold: %s", foldname)
        predict_path = os.path.join(pipeline.reranker_path, foldname,
                                    "predict")

        dev_qids = set(fold["predict"]["dev"])
        dev_qrels = {
            qid: labels
            for qid, labels in pipeline.collection.qrels.items()
            if qid in dev_qids
        }
        dev_eval = pytrec_eval.RelevanceEvaluator(dev_qrels,
                                                  {"ndcg_cut", "P", "map"})

        test_qids = set(fold["predict"]["test"])
        test_qrels = {
            qid: labels
            for qid, labels in pipeline.collection.qrels.items()
            if qid in test_qids
        }
        searcher_dev = {
            qid: docscores
            for qid, docscores in benchmark.reranking_runs[foldname].items()
            if qid in dev_qids
        }
        searcher_test = {
            qid: docscores
            for qid, docscores in benchmark.reranking_runs[foldname].items()
            if qid in test_qids
        }

        # Reset per fold so a fold without usable dev runs can never fall back
        # to the run selected for a previous fold.
        best_metric, best_iter, use_run = -np.inf, None, None
        devpath = os.path.join(predict_path, "dev")
        for iterfn in os.listdir(devpath):
            dev_run = Searcher.load_trec_run(os.path.join(devpath, iterfn))
            test_run = Searcher.load_trec_run(
                os.path.join(predict_path, "test", iterfn))
            alpha, interpolated_test_run, interpolated_dev_run = Searcher.crossvalidated_interpolation(
                dev={
                    "reranker": dev_run,
                    "searcher": searcher_dev,
                    "qrels": dev_qrels
                },
                test={
                    "reranker": test_run,
                    "searcher": searcher_test,
                    "qrels": test_qrels
                },
                metric=target_metric,
            )

            this_metric = np.mean([
                q[target_metric]
                for q in dev_eval.evaluate(interpolated_dev_run).values()
            ])
            if this_metric > best_metric:
                best_metric = this_metric
                best_iter = iterfn
                use_run = interpolated_test_run
                print(foldname, iterfn, best_metric, alpha)

        if best_iter is None:
            raise RuntimeError(
                f"could not select a best iteration for fold {foldname} from the dev runs in {devpath}"
            )
        logger.debug("best dev %s was on iteration #%s", target_metric,
                     best_iter)

        test_eval = pytrec_eval.RelevanceEvaluator(test_qrels,
                                                   {"ndcg_cut", "P", "map"})
        for qid, metrics in test_eval.evaluate(use_run).items():
            assert qid in test_qids
            for metric, value in metrics.items():
                test_metrics.setdefault(metric, {})
                assert qid not in test_metrics[metric], "fold testqid overlap"
                test_metrics[metric][qid] = value

        # output files for Anserini interpolation script
        Searcher.write_trec_run(
            Searcher.load_trec_run(os.path.join(predict_path, "dev",
                                                best_iter)),
            f"runs.rerankerIES.{foldname}.dev")
        Searcher.write_trec_run(
            Searcher.load_trec_run(
                os.path.join(predict_path, "test", best_iter)),
            f"runs.rerankerIES.{foldname}.test")

    logger.info(f"optimized for {target_metric}")
    # Index by target_metric explicitly rather than by whichever metric name
    # happened to be left over from the per-fold loop above.
    logger.info(f"results on {len(test_metrics[target_metric])} aggregated test qids")
    for metric in ["ndcg_cut_20", "map", "P_5", "P_20"]:
        interpolated_avg = np.mean([*test_metrics[metric].values()])
        logger.info(f"[interpolated] avg {metric}: {interpolated_avg:0.3f}")
Esempio n. 8
0
def _merge_fold_metrics(aggregate, per_query_metrics, fold_qids):
    """Add one fold's per-query metrics to `aggregate` (metric name -> qid -> value)."""
    for qid, metrics in per_query_metrics.items():
        assert qid in fold_qids
        for metric, value in metrics.items():
            per_metric = aggregate.setdefault(metric, {})
            assert qid not in per_metric, "fold testqid overlap"
            per_metric[qid] = value


def evaluate(_config):
    """Evaluate reranker, searcher and interpolated runs on the test queries of every fold.

    For every fold of the pipeline's benchmark, the dev run with the highest mean
    `ndcg_cut_20` selects the iteration whose test run is evaluated. The run being reranked is
    evaluated on the same test queries, and an interpolation of both (via
    `Searcher.crossvalidated_interpolation`) is evaluated as well. The selected runs are written
    to `runs.reranker.<fold>.*` and `runs.searcher.<fold>.*` in the current working directory,
    averages and a paired t-test (reranker vs. searcher) are logged, and all per-query metrics
    are dumped to `results.json`.

    Args:
        _config: configuration passed to `pipeline.initialize`

    Raises:
        RuntimeError: if a fold does not have exactly the 'dev' and 'test' prediction sets, or
            if no best iteration could be selected from a fold's dev runs
    """
    from capreolus.searcher import Searcher
    import pytrec_eval

    pipeline.initialize(_config)
    logger.debug("initialized pipeline with results path: %s",
                 pipeline.reranker_path)

    benchmark = pipeline.benchmark
    benchmark.build()  # TODO move this to pipeline.initialize?

    # Defined once, outside the fold loop, because it is also used in the summary below.
    target_metric = "ndcg_cut_20"
    # target_metric = "map"

    test_metrics = {}
    searcher_test_metrics = {}
    interpolated_test_metrics = {}
    for foldname, fold in sorted(benchmark.folds.items()):
        if not (len(fold["predict"]) == 2 and "dev" in fold["predict"]
                and "test" in fold["predict"]):
            raise RuntimeError(
                "this evaluation command is only supported for benchmarks with 'dev' and 'test' folds"
            )

        logger.debug("evaluating fold: %s", foldname)
        predict_path = os.path.join(pipeline.reranker_path, foldname,
                                    "predict")

        dev_qids = set(fold["predict"]["dev"])
        dev_qrels = {
            qid: labels
            for qid, labels in pipeline.collection.qrels.items()
            if qid in dev_qids
        }
        dev_eval = pytrec_eval.RelevanceEvaluator(dev_qrels,
                                                  {"ndcg_cut", "P", "map"})

        best_metric, best_iter, dev_run = -np.inf, None, None
        devpath = os.path.join(predict_path, "dev")
        for iterfn in os.listdir(devpath):
            run = Searcher.load_trec_run(os.path.join(devpath, iterfn))
            this_metric = np.mean(
                [q[target_metric] for q in dev_eval.evaluate(run).values()])
            if this_metric > best_metric:
                best_metric = this_metric
                best_iter = iterfn
                dev_run = run

        # Fail with a clear message instead of a TypeError from os.path.join(…, None).
        if best_iter is None:
            raise RuntimeError(
                f"could not select a best iteration for fold {foldname} from the dev runs in {devpath}"
            )
        logger.debug("best dev %s=%0.3f was on iteration #%s", target_metric,
                     best_metric, best_iter)

        test_run = Searcher.load_trec_run(
            os.path.join(predict_path, "test", best_iter))
        test_qids = set(fold["predict"]["test"])
        test_qrels = {
            qid: labels
            for qid, labels in pipeline.collection.qrels.items()
            if qid in test_qids
        }
        test_eval = pytrec_eval.RelevanceEvaluator(test_qrels,
                                                   {"ndcg_cut", "P", "map"})
        _merge_fold_metrics(test_metrics, test_eval.evaluate(test_run),
                            test_qids)

        # compute metrics for the run being reranked
        _merge_fold_metrics(
            searcher_test_metrics,
            test_eval.evaluate(benchmark.reranking_runs[foldname]), test_qids)

        # choose an alpha for interpolation using the dev_qids,
        # then create a run by interpolating the searcher and reranker scores
        searcher_dev = {
            qid: docscores
            for qid, docscores in benchmark.reranking_runs[foldname].items()
            if qid in dev_qids
        }
        searcher_test = {
            qid: docscores
            for qid, docscores in benchmark.reranking_runs[foldname].items()
            if qid in test_qids
        }
        alpha, interpolated_test_run, _ = Searcher.crossvalidated_interpolation(
            dev={
                "reranker": dev_run,
                "searcher": searcher_dev,
                "qrels": dev_qrels
            },
            test={
                "reranker": test_run,
                "searcher": searcher_test,
                "qrels": test_qrels
            },
            metric=target_metric,
        )

        # output files for Anserini interpolation script
        Searcher.write_trec_run(dev_run, f"runs.reranker.{foldname}.dev")
        Searcher.write_trec_run(test_run, f"runs.reranker.{foldname}.test")
        Searcher.write_trec_run(searcher_dev, f"runs.searcher.{foldname}.dev")
        Searcher.write_trec_run(searcher_test,
                                f"runs.searcher.{foldname}.test")

        logger.debug(f"interpolation alpha={alpha}")
        _merge_fold_metrics(interpolated_test_metrics,
                            test_eval.evaluate(interpolated_test_run),
                            test_qids)

    logger.info(f"optimized for {target_metric}")
    # Index by target_metric explicitly rather than by whichever metric name
    # happened to be left over from the per-fold loop above.
    logger.info(f"results on {len(test_metrics[target_metric])} aggregated test qids")
    for metric in ["map", "P_20", "ndcg_cut_20"]:
        assert len(test_metrics[metric]) == len(searcher_test_metrics[metric])
        assert len(test_metrics[metric]) == len(
            interpolated_test_metrics[metric])

        searcher_avg = np.mean([*searcher_test_metrics[metric].values()])
        logger.info(f"[searcher] avg {metric}: {searcher_avg:0.3f}")

        # Paired test: both score lists are ordered by the same sorted qids.
        sigtest_qids = sorted(test_metrics[metric].keys())
        sigtest = ttest_rel(
            [searcher_test_metrics[metric][qid] for qid in sigtest_qids],
            [test_metrics[metric][qid] for qid in sigtest_qids])

        avg = np.mean([*test_metrics[metric].values()])
        logger.info(
            f"[reranker] avg {metric}: {avg:0.3f}\tp={sigtest.pvalue:0.3f} (vs. searcher)"
        )

        interpolated_avg = np.mean(
            [*interpolated_test_metrics[metric].values()])
        logger.info(f"[interpolated] avg {metric}: {interpolated_avg:0.3f}")

    # NOTE(review): predict_path still refers to the last fold processed, so the
    # aggregated results land in that fold's predict directory - confirm this is intended.
    with open(os.path.join(predict_path, "results.json"), "wt") as outf:
        json.dump(
            (test_metrics, searcher_test_metrics, interpolated_test_metrics),
            outf)