Example no. 1
def evaluate(config, modules):
    metric = "map"
    fold = config["fold"]
    train_output_path = _pipeline_path(config, modules)
    test_output_path = train_output_path / "pred" / "test" / "best"

    searcher = modules["searcher"]
    benchmark = modules["benchmark"]
    reranker = modules["reranker"]

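    # Reuse cached test predictions if they exist; otherwise run the first-stage searcher,
    # load the best trained reranker for this fold, and predict on the fold's test queries.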
    if os.path.exists(test_output_path):
        test_preds = Searcher.load_trec_run(test_output_path)
    else:
        topics_fn = benchmark.topic_file
        searcher_cache_dir = os.path.join(searcher.get_cache_path(), benchmark.name)
        searcher_run_dir = searcher.query_from_file(topics_fn, searcher_cache_dir)

        best_search_run_path = evaluator.search_best_run(searcher_run_dir, benchmark, metric)["path"][fold]
        best_search_run = searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        reranker["extractor"].create(qids=best_search_run.keys(), docids=docids, topics=benchmark.topics[benchmark.query_type])
        reranker.build()

        reranker["trainer"].load_best_model(reranker, train_output_path)

        test_run = {qid: docs for qid, docs in best_search_run.items() if qid in benchmark.folds[fold]["predict"]["test"]}
        test_dataset = PredDataset(qid_docid_to_rank=test_run, extractor=reranker["extractor"], mode="test")

        test_preds = reranker["trainer"].predict(reranker, test_dataset, test_output_path)

    metrics = evaluator.eval_runs(test_preds, benchmark.qrels, ["ndcg_cut_20", "ndcg_cut_10", "map", "P_20", "P_10"])
    print("test metrics for fold=%s:" % fold, metrics)

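    # Average metrics over every fold whose test predictions already exist on disk.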
    print("\ncomputing metrics across all folds")
    avg = {}
    found = 0
    for fold in benchmark.folds:
        pred_path = _pipeline_path(config, modules, fold=fold) / "pred" / "test" / "best"
        if not os.path.exists(pred_path):
            print("\tfold=%s results are missing and will not be included" % fold)
            continue

        found += 1
        preds = Searcher.load_trec_run(pred_path)
        metrics = evaluator.eval_runs(preds, benchmark.qrels, ["ndcg_cut_20", "ndcg_cut_10", "map", "P_20", "P_10"])
        for metric, val in metrics.items():
            avg.setdefault(metric, []).append(val)

    avg = {k: np.mean(v) for k, v in avg.items()}
    print(f"average metrics across {found}/{len(benchmark.folds)} folds:", avg)
Example no. 2
    def bircheval(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        searcher_runs, reranker_runs = self.find_birch_crossvalidated_results()

        fold_test_metrics = evaluator.eval_runs(
            reranker_runs[fold]["test"], self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level
        )
        logger.info("rerank: fold=%s test metrics: %s", fold, fold_test_metrics)
Example no. 3
    def evaluate(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        test_output_path = train_output_path / "pred" / "test" / "best"
        logger.debug("results path: %s", train_output_path)

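        # Reuse cached test predictions when present; otherwise rerun first-stage ranking,
        # load the best reranker checkpoint, and predict on this fold's test split.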
        if os.path.exists(test_output_path):
            test_preds = Searcher.load_trec_run(test_output_path)
        else:
            self.rank.search()
            rank_results = self.rank.evaluate()
            best_search_run_path = rank_results["path"][fold]
            best_search_run = Searcher.load_trec_run(best_search_run_path)

            docids = set(docid for querydocs in best_search_run.values()
                         for docid in querydocs)
            self.reranker.extractor.preprocess(
                qids=best_search_run.keys(),
                docids=docids,
                topics=self.benchmark.topics[self.benchmark.query_type])
            self.reranker.build_model()
            self.reranker.searcher_scores = best_search_run

            self.reranker.trainer.load_best_model(self.reranker,
                                                  train_output_path)

            test_run = {
                qid: docs
                for qid, docs in best_search_run.items()
                if qid in self.benchmark.folds[fold]["predict"]["test"]
            }
            test_dataset = PredSampler()
            test_dataset.prepare(test_run, self.benchmark.qrels,
                                 self.reranker.extractor)

            test_preds = self.reranker.trainer.predict(self.reranker,
                                                       test_dataset,
                                                       test_output_path)

        metrics = evaluator.eval_runs(test_preds, self.benchmark.qrels,
                                      evaluator.DEFAULT_METRICS,
                                      self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold, metrics)

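        # Average metrics across every fold whose test predictions already exist on disk.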
        print("\ncomputing metrics across all folds")
        avg = {}
        found = 0
        from pathlib import Path

        for fold in self.benchmark.folds:
            # TODO fix by using multiple Tasks
            pred_path = Path(test_output_path.as_posix().replace(
                "fold-" + self.config["fold"], "fold-" + fold))
            if not os.path.exists(pred_path):
                print(
                    "\tfold=%s results are missing and will not be included" %
                    fold)
                continue

            found += 1
            preds = Searcher.load_trec_run(pred_path)
            metrics = evaluator.eval_runs(preds, self.benchmark.qrels,
                                          evaluator.DEFAULT_METRICS,
                                          self.benchmark.relevance_level)
            for metric, val in metrics.items():
                avg.setdefault(metric, []).append(val)

        avg = {k: np.mean(v) for k, v in avg.items()}
        logger.info(
            "rerank: average cross-validated metrics when choosing iteration based on '%s':",
            self.config["optimize"])
        for metric, score in sorted(avg.items()):
            logger.info("%25s: %0.4f", metric, score)
Example no. 4
    def train(self,
              reranker,
              train_dataset,
              train_output_path,
              dev_data,
              dev_output_path,
              qrels,
              metric,
              relevance_level=1,
              init_path=None):
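        # On TPU, redirect training output to the configured storage location, using an md5
        # hash of the original path as the directory name.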
        if self.tpu:
            train_output_path = "{0}/{1}/{2}".format(
                self.config["storage"], "train_output",
                hashlib.md5(
                    str(train_output_path).encode("utf-8")).hexdigest())
        os.makedirs(dev_output_path, exist_ok=True)
        start_epoch = self.config["niters"] if reranker.config.get(
            "modeltype", "") in ["nir", "cedr"] else 0
        train_records = self.get_tf_train_records(reranker, train_dataset)
        dev_records = self.get_tf_dev_records(reranker, dev_data)
        dev_dist_dataset = self.strategy.experimental_distribute_dataset(
            dev_records)

        # Does not vary much from https://www.tensorflow.org/tutorials/distribute/custom_training
        strategy_scope = self.strategy.scope()
        with strategy_scope:
            reranker.build_model()
            wrapped_model = self.get_wrapped_model(reranker.model)
            if init_path:
                logger.info(f"Initializing model from checkpoint {init_path}")
                print("number of vars: ",
                      len(wrapped_model.trainable_variables))
                wrapped_model.load_weights(init_path)

            loss_object = self.get_loss(self.config["loss"])
            optimizer_1 = tf.keras.optimizers.Adam(
                learning_rate=self.config["lr"])
            optimizer_2 = tf.keras.optimizers.Adam(
                learning_rate=self.config["bertlr"])

            def compute_loss(labels, predictions):
                per_example_loss = loss_object(labels, predictions)
                return tf.nn.compute_average_loss(
                    per_example_loss, global_batch_size=self.config["batch"])

            def is_bert_parameters(name):
                name = name.lower()
                '''
                if "layer" in name:
                    if not ("9" in name or "10" in name or "11" in name or "12" in name):
                        return False
                '''
                if "/bert/" in name:
                    return True
                if "/electra/" in name:
                    return True
                if "/roberta/" in name:
                    return True
                if "/albert/" in name:
                    return True
                return False

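        # One training step: gradients for BERT-style encoder variables are applied by
        # optimizer_2 (bertlr), while classifier and remaining variables use optimizer_1 (lr).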
        def train_step(inputs):
            data, labels = inputs

            with tf.GradientTape() as tape:
                train_predictions = wrapped_model(data, training=True)
                loss = compute_loss(labels, train_predictions)

            gradients = tape.gradient(loss, wrapped_model.trainable_variables)

            # TODO: Expose the layer names to look out for as a ConfigOption?
            # TODO: Crystina mentioned that Hugging Face models have 'bert' in all the layers (including classifiers). Handle this case
            bert_variables = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if is_bert_parameters(variable.name)
                and "classifier" not in variable.name
            ]
            classifier_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if "classifier" in variable.name
            ]
            other_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if (not is_bert_parameters(variable.name))
                and "classifier" not in variable.name
            ]

            assert len(bert_variables) + len(classifier_vars) + len(
                other_vars) == len(wrapped_model.trainable_variables)
            # TODO: Clean this up for general use
            # Making sure that we did not miss any variables

            if self.config["lr"] > 0:
                optimizer_1.apply_gradients(classifier_vars + other_vars)
            if self.config["bertlr"] > 0:
                optimizer_2.apply_gradients(bert_variables)

            return loss

        def test_step(inputs):
            data, labels = inputs
            predictions = wrapped_model.predict_step(data)

            return predictions

        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = self.strategy.run(train_step,
                                                   args=(dataset_inputs, ))

            return self.strategy.reduce(tf.distribute.ReduceOp.SUM,
                                        per_replica_losses,
                                        axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs, ))

        best_metric = -np.inf
        epoch = 0
        num_batches = 0
        total_loss = 0
        iter_bar = tqdm(total=self.config["itersize"])

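        # Set the initial (possibly warmed-up) learning rates for both optimizers and log them.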
        initial_lr = self.change_lr(epoch,
                                    self.config["bertlr"],
                                    do_warmup=self.config["warmupbert"])
        K.set_value(optimizer_2.lr, K.get_value(initial_lr))
        wandb.log({"bertlr": K.get_value(initial_lr)},
                  step=epoch + start_epoch,
                  commit=False)

        initial_lr = self.change_lr(epoch,
                                    self.config["lr"],
                                    do_warmup=self.config["warmupnonbert"])
        K.set_value(optimizer_1.lr, K.get_value(initial_lr))
        wandb.log({"lr": K.get_value(initial_lr)},
                  step=epoch + start_epoch,
                  commit=False)

        train_records = train_records.shuffle(100000)
        train_dist_dataset = self.strategy.experimental_distribute_dataset(
            train_records)

        # Goes through the dataset ONCE (i.e. niters * itersize * batch samples). However, the dataset may already
        # contain multiple instances of the same sample, depending upon what Sampler was used. If you want multiple
        # epochs, achieve it by tweaking the niters and itersize values.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1  # increment before averaging so the first batch does not divide by zero
            train_loss = total_loss / num_batches
            iter_bar.update(1)

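            # Every `itersize` batches counts as one epoch: update the learning rates, log the
            # training loss, and validate on the dev set every `validatefreq` epochs.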
            if num_batches % self.config["itersize"] == 0:
                epoch += 1

                # Do warmup and decay
                new_lr = self.change_lr(epoch,
                                        self.config["bertlr"],
                                        do_warmup=self.config["warmupbert"])
                K.set_value(optimizer_2.lr, K.get_value(new_lr))
                wandb.log({"bertlr": K.get_value(new_lr)},
                          step=epoch + start_epoch,
                          commit=False)

                new_lr = self.change_lr(epoch,
                                        self.config["lr"],
                                        do_warmup=self.config["warmupnonbert"])
                K.set_value(optimizer_1.lr, K.get_value(new_lr))
                wandb.log({"lr": K.get_value(new_lr)},
                          step=epoch + start_epoch,
                          commit=False)

                iter_bar.close()
                logger.info("train_loss for epoch {} is {}".format(
                    epoch, train_loss))
                wandb.log({"loss": float(train_loss.numpy())},
                          step=epoch + start_epoch,
                          commit=False)
                total_loss = 0

                if epoch % self.config["validatefreq"] == 0:
                    dev_predictions = []
                    for x in tqdm(dev_dist_dataset, desc="validation"):
                        pred_batch = (distributed_test_step(x).values
                                      if self.strategy.num_replicas_in_sync > 1
                                      else [distributed_test_step(x)])
                        for p in pred_batch:
                            dev_predictions.extend(p)

                    trec_preds = self.get_preds_in_trec_format(
                        dev_predictions, dev_data)
                    metrics = evaluator.eval_runs(
                        trec_preds, dict(qrels),
                        evaluator.DEFAULT_METRICS + ["bpref"], relevance_level)
                    logger.info(
                        "dev metrics: %s", " ".join([
                            f"{metric}={v:0.3f}"
                            for metric, v in sorted(metrics.items())
                        ]))
                    if metrics[metric] > best_metric:
                        logger.info("Writing checkpoint")
                        best_metric = metrics[metric]
                        wrapped_model.save_weights(
                            "{0}/dev.best".format(train_output_path))

                    wandb.log(
                        {
                            f"dev-{k}": v
                            for k, v in metrics.items() if k in [
                                "map", "bpref", "P_20", "ndcg_cut_20",
                                "judged_10", "judged_20", "judged_200"
                            ]
                        },
                        step=epoch + start_epoch,
                        commit=False)

                iter_bar = tqdm(total=self.config["itersize"])

            if num_batches >= self.config["niters"] * self.config["itersize"]:
                break
    def evaluate(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        logger.debug("results path: %s", train_output_path)

        searcher_runs, reranker_runs = self.find_crossvalidated_results()

        if fold not in reranker_runs:
            logger.error(
                "could not find predictions; run the train command first")
            raise ValueError(
                "could not find predictions; run the train command first")

        fold_dev_metrics = evaluator.eval_runs(reranker_runs[fold]["dev"],
                                               self.benchmark.qrels,
                                               self.metrics,
                                               self.benchmark.relevance_level)
        logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

        unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(
            self.benchmark, "unsampled_qrels") else self.benchmark.qrels
        fold_test_metrics = evaluator.eval_runs(reranker_runs[fold]["test"],
                                                unsampled_qrels, self.metrics,
                                                self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold,
                    fold_test_metrics)

        if len(reranker_runs) != len(self.benchmark.folds):
            logger.info(
                "rerank: skipping cross-validated metrics because results exist for only %s/%s folds",
                len(reranker_runs), len(self.benchmark.folds))
            logger.info("available runs: %s", reranker_runs.keys())
            return {
                "fold_test_metrics": fold_test_metrics,
                "fold_dev_metrics": fold_dev_metrics,
                "cv_metrics": None,
                "interpolated_cv_metrics": None,
            }

        logger.info(
            "rerank: average cross-validated metrics when choosing iteration based on '%s':",
            self.config["optimize"])
        all_preds = {}
        for preds in reranker_runs.values():
            for qid, docscores in preds["test"].items():
                all_preds.setdefault(qid, {})
                for docid, score in docscores.items():
                    all_preds[qid][docid] = score

        cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels,
                                         self.metrics,
                                         self.benchmark.relevance_level)
        interpolated_results = evaluator.interpolated_eval(
            searcher_runs, reranker_runs, self.benchmark,
            self.config["optimize"], self.metrics)

        for metric, score in sorted(cv_metrics.items()):
            logger.info("%25s: %0.4f", metric, score)

        return {
            "fold_test_metrics": fold_test_metrics,
            "fold_dev_metrics": fold_dev_metrics,
            "cv_metrics": cv_metrics,
            "interpolated_results": interpolated_results,
        }
    def predict_and_eval(self, init_path=None):
        fold = self.config["fold"]
        self.reranker.build_model()
        if not init_path or init_path == "none":
            logger.info("No init path given, using default parameters")
        else:
            logger.info(f"Load from {init_path}")
            init_path = Path(
                init_path) if not init_path.startswith("gs:") else init_path
            self.reranker.trainer.load_best_model(self.reranker,
                                                  init_path,
                                                  do_not_hash=True)

        dirname = str(init_path).split("/")[-1] if init_path else "noinitpath"
        savedir = Path(
            __file__).parent.absolute() / "downloaded_runfiles" / dirname
        dev_output_path = savedir / fold / "dev"
        test_output_path = savedir / fold / "test"
        test_output_path.parent.mkdir(exist_ok=True, parents=True)

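        # Run first-stage retrieval and load the best run for this fold.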
        self.rank.search()
        threshold = self.config["threshold"]
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values()
                     for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(),
            docids=docids,
            topics=self.benchmark.topics[self.benchmark.query_type])

        # dev run
        dev_run = defaultdict(dict)
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["dev"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            dev_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(dev_run[qid])} for query {qid}"
                        break
                    dev_run[qid][docid] = score
        dev_dataset = PredSampler()
        dev_dataset.prepare(dev_run,
                            self.benchmark.qrels,
                            self.reranker.extractor,
                            relevance_level=self.benchmark.relevance_level)

        # test_run
        test_run = defaultdict(dict)
        # This is possible because best_search_run is an OrderedDict
        for qid, docs in best_search_run.items():
            if qid in self.benchmark.folds[fold]["predict"]["test"]:
                for idx, (docid, score) in enumerate(docs.items()):
                    if idx >= threshold:
                        assert len(
                            test_run[qid]
                        ) == threshold, f"Expect {threshold} on each qid, got {len(test_run[qid])} for query {qid}"
                        break
                    test_run[qid][docid] = score

        unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(
            self.benchmark, "unsampled_qrels") else self.benchmark.qrels
        test_dataset = PredSampler()
        test_dataset.prepare(test_run,
                             unsampled_qrels,
                             self.reranker.extractor,
                             relevance_level=self.benchmark.relevance_level)
        logger.info("test prepared")

        # prediction
        dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset,
                                                  dev_output_path)
        fold_dev_metrics = evaluator.eval_runs(dev_preds, unsampled_qrels,
                                               self.metrics,
                                               self.benchmark.relevance_level)
        logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset,
                                                   test_output_path)
        fold_test_metrics = evaluator.eval_runs(test_preds, unsampled_qrels,
                                                self.metrics,
                                                self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold,
                    fold_test_metrics)
        wandb.save(str(dev_output_path))
        wandb.save(str(test_output_path))

        # add cross validate results:
        n_folds = len(self.benchmark.folds)
        folds_fn = {
            f"s{i}": savedir / f"s{i}" / "test"
            for i in range(1, n_folds + 1)
        }
        if not all([fn.exists() for fn in folds_fn.values()]):
            return {"fold_test_metrics": fold_test_metrics, "cv_metrics": None}

        all_preds = {}
        reranker_runs = {
            fold: {
                "dev": Searcher.load_trec_run(fn.parent / "dev"),
                "test": Searcher.load_trec_run(fn)
            }
            for fold, fn in folds_fn.items()
        }

        for fold, dev_test in reranker_runs.items():
            preds = dev_test["test"]
            qids = self.benchmark.folds[fold]["predict"]["test"]
            for qid, docscores in preds.items():
                if qid not in qids:
                    continue
                all_preds.setdefault(qid, {})
                for docid, score in docscores.items():
                    all_preds[qid][docid] = score

        cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels,
                                         self.metrics,
                                         self.benchmark.relevance_level)
        for metric, score in sorted(cv_metrics.items()):
            logger.info("%25s: %0.4f", metric, score)

        searcher_runs = {}
        rank_results = self.rank.evaluate()
        for fold in self.benchmark.folds:
            searcher_runs[fold] = {
                "dev": Searcher.load_trec_run(rank_results["path"][fold])
            }
            searcher_runs[fold]["test"] = searcher_runs[fold]["dev"]

        interpolated_results = evaluator.interpolated_eval(
            searcher_runs, reranker_runs, self.benchmark,
            self.config["optimize"], self.metrics)

        return {
            "fold_test_metrics": fold_test_metrics,
            "cv_metrics": cv_metrics,
            "interpolated_results": interpolated_results,
        }
Example no. 7
    def train(self,
              reranker,
              train_dataset,
              train_output_path,
              dev_data,
              dev_output_path,
              qrels,
              metric,
              relevance_level=1):
        """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

        Args:
           train_dataset (IterableDataset): training dataset
           train_output_path (Path): directory under which train_dataset runs and training loss will be saved
           dev_data (IterableDataset): dev dataset
           dev_output_path (Path): directory where dev_data runs and metrics will be saved

        """
        # Set up logging
        # TODO why not put this under train_output_path?
        summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" /
                                       self.config["boardname"],
                                       comment=train_output_path)

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        model = reranker.model.to(self.device)
        self.optimizer = torch.optim.Adam(filter(
            lambda param: param.requires_grad, model.parameters()),
                                          lr=self.config["lr"])

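        # Use automatic mixed precision for training when configured; otherwise fall back to a
        # no-op context manager.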
        if self.config["amp"] in ("both", "train"):
            self.amp_train_autocast = torch.cuda.amp.autocast
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.amp_train_autocast = contextlib.nullcontext
            self.scaler = None

        # REF-TODO how to handle interactions between fastforward and schedule? --> just save its state
        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda epoch: self.lr_multiplier(step=epoch * self.n_batch_per_iter))

        if self.config["softmaxloss"]:
            self.loss = pair_softmax_loss
        else:
            self.loss = pair_hinge_loss

        dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metric_fn = self.get_paths_for_early_stopping(
            train_output_path, dev_output_path)

        num_workers = 1 if self.config["multithread"] else 0
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.config["batch"],
            pin_memory=True,
            num_workers=num_workers)

        # if we're fastforwarding, set first iteration and load last saved weights
        initial_iter, metrics = (self.fastforward_training(
            reranker, weights_output_path, loss_fn, metric_fn)
                                 if self.config["fastforward"] else (0, {}))
        dev_best_metric = metrics.get(metric, -np.inf)
        logger.info("starting training from iteration %s/%s", initial_iter + 1,
                    self.config["niters"])
        logger.info(f"Best metric loaded: {metric}={dev_best_metric}")

        train_loss = []
        # are we resuming training? fastforward loss and data if so
        if initial_iter > 0:
            train_loss = self.load_loss_file(loss_fn)

            # are we done training? if not, fastforward through prior batches
            if initial_iter < self.config["niters"]:
                logger.debug("fastforwarding train_dataloader to iteration %s",
                             initial_iter)
                self.exhaust_used_train_data(train_dataloader,
                                             n_batch_to_exhaust=initial_iter *
                                             self.n_batch_per_iter)

        logger.info(self.get_validation_schedule_msg(initial_iter))
        train_start_time = time.time()
        for niter in range(initial_iter, self.config["niters"]):
            niter = niter + 1  # index from 1
            model.train()

            iter_start_time = time.time()
            iter_loss_tensor = self.single_train_iteration(
                reranker, train_dataloader)
            logger.info("A single iteration takes {}".format(time.time() -
                                                             iter_start_time))
            train_loss.append(iter_loss_tensor.item())
            logger.info("iter = %d loss = %f", niter, train_loss[-1])

            # save model weights only when fastforward enabled
            if self.config["fastforward"]:
                weights_fn = weights_output_path / f"{niter}.p"
                reranker.save_weights(weights_fn, self.optimizer)

            # predict performance on dev set
            if niter % self.config["validatefreq"] == 0:
                pred_fn = dev_output_path / f"{niter}.run"
                preds = self.predict(reranker, dev_data, pred_fn)

                # log dev metrics
                metrics = evaluator.eval_runs(preds, qrels,
                                              evaluator.DEFAULT_METRICS,
                                              relevance_level)
                logger.info(
                    "dev metrics: %s", " ".join([
                        f"{metric}={v:0.3f}"
                        for metric, v in sorted(metrics.items())
                    ]))
                summary_writer.add_scalar("ndcg_cut_20",
                                          metrics["ndcg_cut_20"], niter)
                summary_writer.add_scalar("map", metrics["map"], niter)
                summary_writer.add_scalar("P_20", metrics["P_20"], niter)
                # write best dev weights to file
                if metrics[metric] > dev_best_metric:
                    dev_best_metric = metrics[metric]
                    logger.info("new best dev metric: %0.4f", dev_best_metric)
                    reranker.save_weights(dev_best_weight_fn, self.optimizer)
                    self.write_to_metric_file(metric_fn, metrics)

            # write train_loss to file
            # loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
            self.write_to_loss_file(loss_fn, train_loss)

            summary_writer.add_scalar("training_loss", iter_loss_tensor.item(),
                                      niter)
            reranker.add_summary(summary_writer, niter)
            summary_writer.flush()
        logger.info("training loss: %s", train_loss)
        logger.info("Training took {}".format(time.time() - train_start_time))
        summary_writer.close()
Example no. 8
    def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
        if self.tpu:
            # WARNING: not sure if pathlib is compatible with gs://
            train_output_path = Path(
                "{0}/{1}/{2}".format(
                    self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
                )
            )

        dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metric_fn = self.get_paths_for_early_stopping(
            train_output_path, dev_output_path
        )

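        # Build tf.data datasets for training and dev, and distribute the dev set across replicas.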
        train_records = self.get_tf_train_records(reranker, train_dataset)
        dev_records = self.get_tf_dev_records(reranker, dev_data)
        dev_dist_dataset = self.strategy.experimental_distribute_dataset(dev_records)

        # Does not vary much from https://www.tensorflow.org/tutorials/distribute/custom_training
        strategy_scope = self.strategy.scope()
        with strategy_scope:
            reranker.build_model()
            wrapped_model = self.get_wrapped_model(reranker.model)
            loss_object = self.get_loss(self.config["loss"])
            optimizer_1 = tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
            optimizer_2 = tf.keras.optimizers.Adam(learning_rate=self.config["bertlr"])

            # "You should remove the use of the LossScaleOptimizer when TPUs are used."
            if self.amp and not self.tpu:
                optimizer_2 = mixed_precision.LossScaleOptimizer(optimizer_2, loss_scale="dynamic")

            def compute_loss(labels, predictions):
                per_example_loss = loss_object(labels, predictions)
                return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.config["batch"])

        def is_bert_variable(name):
            if "bert" in name:
                return True
            if "electra" in name:
                return True
            return False

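        # One training step: BERT/ELECTRA variables are updated by optimizer_2 (bertlr), while
        # classifier and remaining variables are updated by optimizer_1 (lr).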
        def train_step(inputs):
            data, labels = inputs

            with tf.GradientTape() as tape:
                train_predictions = wrapped_model(data, training=True)
                loss = compute_loss(labels, train_predictions)
                if self.amp and not self.tpu:
                    loss = optimizer_2.get_scaled_loss(loss)

            gradients = tape.gradient(loss, wrapped_model.trainable_variables)
            if self.amp and not self.tpu:
                gradients = optimizer_2.get_unscaled_gradients(gradients)

            bert_variables = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if is_bert_variable(variable.name) and "classifier" not in variable.name
            ]
            classifier_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if "classifier" in variable.name
            ]
            other_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if not is_bert_variable(variable.name) and "classifier" not in variable.name
            ]

            assert len(bert_variables) + len(classifier_vars) + len(other_vars) == len(wrapped_model.trainable_variables)
            # TODO: Clean this up for general use
            # Making sure that we did not miss any variables
            optimizer_1.apply_gradients(classifier_vars)
            optimizer_2.apply_gradients(bert_variables)
            if other_vars:
                optimizer_1.apply_gradients(other_vars)

            return loss

        def test_step(inputs):
            data, labels = inputs
            predictions = wrapped_model.predict_step(data)

            return predictions

        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = self.strategy.run(train_step, args=(dataset_inputs,))

            return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs,))

        train_records = train_records.shuffle(100000)
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_records)

        initial_iter, metrics = (
            self.fastforward_training(wrapped_model, weights_output_path, loss_fn, metric_fn)
            if self.config["fastforward"]
            else (0, {})
        )
        dev_best_metric = metrics.get(metric, -np.inf)
        logger.info("starting training from iteration %s/%s", initial_iter + 1, self.config["niters"])
        logger.info(f"Best metric loaded: {metric}={dev_best_metric}")

        cur_step = initial_iter * self.n_batch_per_iter
        initial_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
        K.set_value(optimizer_2.lr, K.get_value(initial_lr))
        train_loss = self.load_loss_file(loss_fn) if initial_iter > 0 else []
        if 0 < initial_iter < self.config["niters"]:
            self.exhaust_used_train_data(train_dist_dataset, n_batch_to_exhaust=initial_iter * self.n_batch_per_iter)

        niter = initial_iter
        total_loss = 0
        trec_preds = {}
        iter_bar = tqdm(desc="Training iteration", total=self.n_batch_per_iter)
        # Goes through the dataset ONCE (i.e. niters * itersize).
        # However, the dataset may already contain multiple instances of the same sample,
        # depending upon what Sampler was used.
        # If you want multiple epochs, achieve it by tweaking the niters and itersize values.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            cur_step += 1
            iter_bar.update(1)

            # Do warmup and decay
            new_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
            K.set_value(optimizer_2.lr, K.get_value(new_lr))

            if cur_step % self.n_batch_per_iter == 0:
                niter += 1

                iter_bar.close()
                iter_bar = tqdm(total=self.n_batch_per_iter)
                train_loss.append(total_loss / self.n_batch_per_iter)
                logger.info("iter={} loss = {}".format(niter, train_loss[-1]))
                self.write_to_loss_file(loss_fn, train_loss)
                total_loss = 0

                if self.config["fastforward"]:
                    wrapped_model.save_weights(f"{weights_output_path}/{niter}")

                if niter % self.config["validatefreq"] == 0:
                    dev_predictions = []
                    for x in tqdm(dev_dist_dataset, desc="validation"):
                        pred_batch = (
                            distributed_test_step(x).values
                            if self.strategy.num_replicas_in_sync > 1
                            else [distributed_test_step(x)]
                        )
                        for p in pred_batch:
                            dev_predictions.extend(p)

                    trec_preds = self.get_preds_in_trec_format(dev_predictions, dev_data)
                    metrics = evaluator.eval_runs(trec_preds, dict(qrels), evaluator.DEFAULT_METRICS, relevance_level)
                    logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
                    if metrics[metric] > dev_best_metric:
                        dev_best_metric = metrics[metric]
                        logger.info("new best dev metric: %0.4f", dev_best_metric)

                        self.write_to_metric_file(metric_fn, metrics)
                        wrapped_model.save_weights(dev_best_weight_fn)
                        Searcher.write_trec_run(trec_preds, outfn=(dev_output_path / "best").as_posix())

            if cur_step >= self.config["niters"] * self.n_batch_per_iter:
                break

        return trec_preds
Example no. 9
    def evaluate(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        logger.debug("results path: %s", train_output_path)
        metrics = (self.config["metrics"] if list(self.config["metrics"]) != ["default"]
                   else evaluator.DEFAULT_METRICS)

        searcher_runs, reranker_runs = self.find_crossvalidated_results()

        if fold not in reranker_runs:
            logger.error(
                "could not find predictions; run the train command first")
            raise ValueError(
                "could not find predictions; run the train command first")

        fold_dev_metrics = evaluator.eval_runs(reranker_runs[fold]["dev"],
                                               self.benchmark.qrels, metrics,
                                               self.benchmark.relevance_level)
        pretty_fold_dev_metrics = " ".join([
            f"{metric}={v:0.3f}"
            for metric, v in sorted(fold_dev_metrics.items())
        ])
        logger.info("rerank: fold=%s dev metrics: %s", fold,
                    pretty_fold_dev_metrics)

        fold_test_metrics = evaluator.eval_runs(reranker_runs[fold]["test"],
                                                self.benchmark.qrels, metrics,
                                                self.benchmark.relevance_level)
        pretty_fold_test_metrics = " ".join([
            f"{metric}={v:0.3f}"
            for metric, v in sorted(fold_test_metrics.items())
        ])
        logger.info("rerank: fold=%s test metrics: %s", fold,
                    pretty_fold_test_metrics)

        if len(reranker_runs) != len(self.benchmark.folds):
            logger.info(
                "rerank: skipping cross-validated metrics because results exist for only %s/%s folds",
                len(reranker_runs),
                len(self.benchmark.folds),
            )
            return {
                "fold_test_metrics": fold_test_metrics,
                "fold_dev_metrics": fold_dev_metrics,
                "cv_metrics": None,
                "interpolated_cv_metrics": None,
            }

        logger.info(
            "rerank: average cross-validated metrics when choosing iteration based on '%s':",
            self.config["optimize"])
        all_preds = {}
        for preds in reranker_runs.values():
            for qid, docscores in preds["test"].items():
                all_preds.setdefault(qid, {})
                for docid, score in docscores.items():
                    all_preds[qid][docid] = score

        cv_metrics = evaluator.eval_runs(all_preds, self.benchmark.qrels,
                                         metrics,
                                         self.benchmark.relevance_level)
        interpolated_results = evaluator.interpolated_eval(
            searcher_runs, reranker_runs, self.benchmark,
            self.config["optimize"], metrics)

        for metric, score in sorted(cv_metrics.items()):
            logger.info("%25s: %0.4f", metric, score)

        logger.info("interpolated with alphas = %s",
                    sorted(interpolated_results["alphas"].values()))
        for metric, score in sorted(interpolated_results["score"].items()):
            logger.info("%25s: %0.4f", metric + " [interp]", score)

        return {
            "fold_test_metrics": fold_test_metrics,
            "fold_dev_metrics": fold_dev_metrics,
            "cv_metrics": cv_metrics,
            "interpolated_results": interpolated_results,
        }
Example no. 10
    def train(self,
              reranker,
              train_dataset,
              train_output_path,
              dev_data,
              dev_output_path,
              qrels,
              metric,
              relevance_level=1):
        """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

        Args:
           train_dataset (IterableDataset): training dataset
           train_output_path (Path): directory under which train_dataset runs and training loss will be saved
           dev_data (IterableDataset): dev dataset
           dev_output_path (Path): directory where dev_data runs and metrics will be saved

        """
        # Set up logging
        # TODO why not put this under train_output_path?
        summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" /
                                       self.config["boardname"],
                                       comment=train_output_path)

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        model = reranker.model.to(self.device)
        self.optimizer = torch.optim.Adam(filter(
            lambda param: param.requires_grad, model.parameters()),
                                          lr=self.config["lr"])

        if self.config["softmaxloss"]:
            self.loss = pair_softmax_loss
        else:
            self.loss = pair_hinge_loss

        dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping(
            train_output_path, dev_output_path)

        initial_iter = self.fastforward_training(
            reranker, weights_output_path,
            loss_fn) if self.config["fastforward"] else 0
        logger.info("starting training from iteration %s/%s", initial_iter,
                    self.config["niters"])

        num_workers = 1 if self.config["multithread"] else 0
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.config["batch"],
            pin_memory=True,
            num_workers=num_workers)
        # dataiter = iter(train_dataloader)
        # sample_input = dataiter.next()
        # summary_writer.add_graph(
        #     reranker.model,
        #     [
        #         sample_input["query"].to(self.device),
        #         sample_input["posdoc"].to(self.device),
        #         sample_input["negdoc"].to(self.device),
        #     ],
        # )

        train_loss = []
        # are we resuming training?
        if initial_iter > 0:
            train_loss = self.load_loss_file(loss_fn)

            # are we done training?
            if initial_iter < self.config["niters"]:
                logger.debug("fastforwarding train_dataloader to iteration %s",
                             initial_iter)
                batches_per_epoch = self.config["itersize"] // self.config["batch"]
                for niter in range(initial_iter):
                    for bi, batch in enumerate(train_dataloader):
                        if (bi + 1) % batches_per_epoch == 0:
                            break

        dev_best_metric = -np.inf
        validation_frequency = self.config["validatefreq"]
        train_start_time = time.time()
        for niter in range(initial_iter, self.config["niters"]):
            model.train()

            iter_start_time = time.time()
            iter_loss_tensor = self.single_train_iteration(
                reranker, train_dataloader)
            logger.info("A single iteration takes {}".format(time.time() -
                                                             iter_start_time))
            train_loss.append(iter_loss_tensor.item())
            logger.info("iter = %d loss = %f", niter, train_loss[-1])

            # write model weights to file
            weights_fn = weights_output_path / f"{niter}.p"
            reranker.save_weights(weights_fn, self.optimizer)
            # predict performance on dev set

            if niter % validation_frequency == 0:
                pred_fn = dev_output_path / f"{niter}.run"
                preds = self.predict(reranker, dev_data, pred_fn)

                # log dev metrics
                metrics = evaluator.eval_runs(preds, qrels,
                                              evaluator.DEFAULT_METRICS,
                                              relevance_level)
                logger.info(
                    "dev metrics: %s", " ".join([
                        f"{metric}={v:0.3f}"
                        for metric, v in sorted(metrics.items())
                    ]))
                summary_writer.add_scalar("ndcg_cut_20",
                                          metrics["ndcg_cut_20"], niter)
                summary_writer.add_scalar("map", metrics["map"], niter)
                summary_writer.add_scalar("P_20", metrics["P_20"], niter)
                # write best dev weights to file
                if metrics[metric] > dev_best_metric:
                    dev_best_metric = metrics[metric]
                    logger.info("new best dev metric: %0.4f", dev_best_metric)
                    reranker.save_weights(dev_best_weight_fn, self.optimizer)

            # write train_loss to file
            loss_fn.write_text("\n".join(
                f"{idx} {loss}" for idx, loss in enumerate(train_loss)))

            summary_writer.add_scalar("training_loss", iter_loss_tensor.item(),
                                      niter)
            reranker.add_summary(summary_writer, niter)
            summary_writer.flush()
        logger.info("training loss: %s", train_loss)
        logger.info("Training took {}".format(time.time() - train_start_time))
        summary_writer.close()
Example no. 11
    def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric):
        """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

        Args:
           train_dataset (IterableDataset): training dataset
           train_output_path (Path): directory under which train_dataset runs and training loss will be saved
           dev_data (IterableDataset): dev dataset
           dev_output_path (Path): directory where dev_data runs and metrics will be saved

        """

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = reranker.model.to(self.device)
        self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.cfg["lr"])

        if self.cfg["softmaxloss"]:
            self.loss = pair_softmax_loss
        else:
            self.loss = pair_hinge_loss

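        # Create the output directories for dev predictions, model weights, and training info.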
        os.makedirs(dev_output_path, exist_ok=True)
        dev_best_weight_fn = train_output_path / "dev.best"
        weights_output_path = train_output_path / "weights"
        info_output_path = train_output_path / "info"
        os.makedirs(weights_output_path, exist_ok=True)
        os.makedirs(info_output_path, exist_ok=True)

        loss_fn = info_output_path / "loss.txt"
        metrics_fn = dev_output_path / "metrics.json"
        metrics_history = {}
        initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn)
        logger.info("starting training from iteration %s/%s", initial_iter, self.cfg["niters"])

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.cfg["batch"], pin_memory=True, num_workers=0
        )

        train_loss = []
        # are we resuming training?
        if initial_iter > 0:
            train_loss = self.load_loss_file(loss_fn)

            # are we done training?
            if initial_iter < self.cfg["niters"]:
                logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
                batches_per_epoch = self.cfg["itersize"] // self.cfg["batch"]
                for niter in range(initial_iter):
                    for bi, batch in enumerate(train_dataloader):
                        if (bi + 1) % batches_per_epoch == 0:
                            break

        dev_best_metric = -np.inf
        for niter in range(initial_iter, self.cfg["niters"]):
            model.train()

            iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)

            train_loss.append(iter_loss_tensor.item())
            logger.info("iter = %d loss = %f", niter, train_loss[-1])

            # write model weights to file
            weights_fn = weights_output_path / f"{niter}.p"
            reranker.save_weights(weights_fn, self.optimizer)

            # predict performance on dev set
            pred_fn = dev_output_path / f"{niter}.run"
            preds = self.predict(reranker, dev_data, pred_fn)

            # log dev metrics
            metrics = evaluator.eval_runs(preds, qrels, ["ndcg_cut_20", "map", "P_20"])
            logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))

            # write best dev weights to file
            if metrics[metric] > dev_best_metric:
                dev_best_metric = metrics[metric]
                reranker.save_weights(dev_best_weight_fn, self.optimizer)
            for m in metrics:
                metrics_history.setdefault(m, []).append(metrics[m])

            # write train_loss to file
            loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))

        with open(metrics_fn, "w", encoding="utf-8") as outf:
            json.dump(metrics_history, outf)
        plot_metrics(metrics_history, str(dev_output_path) + ".pdf", interactive=self.cfg["interactive"])
        plot_loss(train_loss, str(loss_fn).replace(".txt", ".pdf"), interactive=self.cfg["interactive"])