def evaluate(config, modules):
    metric = "map"
    fold = config["fold"]
    train_output_path = _pipeline_path(config, modules)
    test_output_path = train_output_path / "pred" / "test" / "best"

    searcher = modules["searcher"]
    benchmark = modules["benchmark"]
    reranker = modules["reranker"]

    if os.path.exists(test_output_path):
        test_preds = Searcher.load_trec_run(test_output_path)
    else:
        topics_fn = benchmark.topic_file
        searcher_cache_dir = os.path.join(searcher.get_cache_path(), benchmark.name)
        searcher_run_dir = searcher.query_from_file(topics_fn, searcher_cache_dir)

        best_search_run_path = evaluator.search_best_run(searcher_run_dir, benchmark, metric)["path"][fold]
        best_search_run = searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        reranker["extractor"].create(qids=best_search_run.keys(), docids=docids, topics=benchmark.topics[benchmark.query_type])
        reranker.build()
        reranker["trainer"].load_best_model(reranker, train_output_path)

        test_run = {qid: docs for qid, docs in best_search_run.items() if qid in benchmark.folds[fold]["predict"]["test"]}
        test_dataset = PredDataset(qid_docid_to_rank=test_run, extractor=reranker["extractor"], mode="test")

        test_preds = reranker["trainer"].predict(reranker, test_dataset, test_output_path)

    metrics = evaluator.eval_runs(test_preds, benchmark.qrels, ["ndcg_cut_20", "ndcg_cut_10", "map", "P_20", "P_10"])
    print("test metrics for fold=%s:" % fold, metrics)

    print("\ncomputing metrics across all folds")
    avg = {}
    found = 0
    for fold in benchmark.folds:
        pred_path = _pipeline_path(config, modules, fold=fold) / "pred" / "test" / "best"
        if not os.path.exists(pred_path):
            print("\tfold=%s results are missing and will not be included" % fold)
            continue

        found += 1
        preds = Searcher.load_trec_run(pred_path)
        metrics = evaluator.eval_runs(preds, benchmark.qrels, ["ndcg_cut_20", "ndcg_cut_10", "map", "P_20", "P_10"])
        for metric, val in metrics.items():
            avg.setdefault(metric, []).append(val)

    avg = {k: np.mean(v) for k, v in avg.items()}
    print(f"average metrics across {found}/{len(benchmark.folds)} folds:", avg)
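# A minimal, self-contained sketch (hypothetical toy values, not part of this module) of the cross-fold
# averaging pattern used above: per-fold metric dicts are accumulated per metric name, then averaged.
#
#     import numpy as np
#     fold_metrics = [{"map": 0.25, "P_20": 0.30}, {"map": 0.35, "P_20": 0.40}]  # one dict per fold
#     avg = {}
#     for metrics in fold_metrics:
#         for name, val in metrics.items():
#             avg.setdefault(name, []).append(val)
#     avg = {k: np.mean(v) for k, v in avg.items()}  # -> {"map": 0.30, "P_20": 0.35}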
def bircheval(self):
    fold = self.config["fold"]
    train_output_path = self.get_results_path()
    searcher_runs, reranker_runs = self.find_birch_crossvalidated_results()

    fold_test_metrics = evaluator.eval_runs(
        reranker_runs[fold]["test"], self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level
    )
    logger.info("rerank: fold=%s test metrics: %s", fold, fold_test_metrics)
def evaluate(self):
    fold = self.config["fold"]
    train_output_path = self.get_results_path()
    test_output_path = train_output_path / "pred" / "test" / "best"
    logger.debug("results path: %s", train_output_path)

    if os.path.exists(test_output_path):
        test_preds = Searcher.load_trec_run(test_output_path)
    else:
        self.rank.search()
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
        )
        self.reranker.build_model()
        self.reranker.searcher_scores = best_search_run
        self.reranker.trainer.load_best_model(self.reranker, train_output_path)

        test_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["predict"]["test"]}
        test_dataset = PredSampler()
        test_dataset.prepare(test_run, self.benchmark.qrels, self.reranker.extractor)

        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

    metrics = evaluator.eval_runs(test_preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s test metrics: %s", fold, metrics)

    print("\ncomputing metrics across all folds")
    avg = {}
    found = 0
    for fold in self.benchmark.folds:
        # TODO fix by using multiple Tasks
        from pathlib import Path

        pred_path = Path(test_output_path.as_posix().replace("fold-" + self.config["fold"], "fold-" + fold))
        if not os.path.exists(pred_path):
            print("\tfold=%s results are missing and will not be included" % fold)
            continue

        found += 1
        preds = Searcher.load_trec_run(pred_path)
        metrics = evaluator.eval_runs(preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
        for metric, val in metrics.items():
            avg.setdefault(metric, []).append(val)

    avg = {k: np.mean(v) for k, v in avg.items()}
    logger.info("rerank: average cross-validated metrics when choosing iteration based on '%s':", self.config["optimize"])
    for metric, score in sorted(avg.items()):
        logger.info("%25s: %0.4f", metric, score)
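# For reference, a hedged sketch (toy values) of the in-memory run format produced by Searcher.load_trec_run
# and consumed above: a dict mapping qid -> {docid: score}, filtered down to the fold's heldout queries
# before reranking.
#
#     best_search_run = {"301": {"d1": 14.2, "d2": 13.7}, "302": {"d9": 11.0}}
#     test_qids = {"302"}
#     test_run = {qid: docs for qid, docs in best_search_run.items() if qid in test_qids}
#     # -> {"302": {"d9": 11.0}}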
def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1, init_path=None):
    if self.tpu:
        train_output_path = "{0}/{1}/{2}".format(
            self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
        )

    os.makedirs(dev_output_path, exist_ok=True)
    start_epoch = self.config["niters"] if reranker.config.get("modeltype", "") in ["nir", "cedr"] else 0
    train_records = self.get_tf_train_records(reranker, train_dataset)
    dev_records = self.get_tf_dev_records(reranker, dev_data)
    dev_dist_dataset = self.strategy.experimental_distribute_dataset(dev_records)

    # Does not vary much from https://www.tensorflow.org/tutorials/distribute/custom_training
    strategy_scope = self.strategy.scope()
    with strategy_scope:
        reranker.build_model()
        wrapped_model = self.get_wrapped_model(reranker.model)
        if init_path:
            logger.info(f"Initializing model from checkpoint {init_path}")
            print("number of vars: ", len(wrapped_model.trainable_variables))
            wrapped_model.load_weights(init_path)

        loss_object = self.get_loss(self.config["loss"])
        optimizer_1 = tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
        optimizer_2 = tf.keras.optimizers.Adam(learning_rate=self.config["bertlr"])

        def compute_loss(labels, predictions):
            per_example_loss = loss_object(labels, predictions)
            return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.config["batch"])

        def is_bert_parameters(name):
            name = name.lower()
            '''
            if "layer" in name:
                if not ("9" in name or "10" in name or "11" in name or "12" in name):
                    return False
            '''
            if "/bert/" in name:
                return True
            if "/electra/" in name:
                return True
            if "/roberta/" in name:
                return True
            if "/albert/" in name:
                return True
            return False

        def train_step(inputs):
            data, labels = inputs

            with tf.GradientTape() as tape:
                train_predictions = wrapped_model(data, training=True)
                loss = compute_loss(labels, train_predictions)

            gradients = tape.gradient(loss, wrapped_model.trainable_variables)

            # TODO: Expose the layer names to look out for as a ConfigOption?
            # TODO: Crystina mentioned that huggingface models have 'bert' in all the layers (including classifiers). Handle this case
            bert_variables = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if is_bert_parameters(variable.name) and "classifier" not in variable.name
            ]
            classifier_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if "classifier" in variable.name
            ]
            other_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if not is_bert_parameters(variable.name) and "classifier" not in variable.name
            ]

            # Making sure that we did not miss any variables
            assert len(bert_variables) + len(classifier_vars) + len(other_vars) == len(wrapped_model.trainable_variables)

            # TODO: Clean this up for general use
            if self.config["lr"] > 0:
                optimizer_1.apply_gradients(classifier_vars + other_vars)
            if self.config["bertlr"] > 0:
                optimizer_2.apply_gradients(bert_variables)

            return loss

        def test_step(inputs):
            data, labels = inputs
            predictions = wrapped_model.predict_step(data)
            return predictions

        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = self.strategy.run(train_step, args=(dataset_inputs,))
            return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs,))

        best_metric = -np.inf
        epoch = 0
        num_batches = 0
        total_loss = 0
        iter_bar = tqdm(total=self.config["itersize"])

        initial_lr = self.change_lr(epoch, self.config["bertlr"], do_warmup=self.config["warmupbert"])
        K.set_value(optimizer_2.lr, K.get_value(initial_lr))
        wandb.log({"bertlr": K.get_value(initial_lr)}, step=epoch + start_epoch, commit=False)

        initial_lr = self.change_lr(epoch, self.config["lr"], do_warmup=self.config["warmupnonbert"])
        K.set_value(optimizer_1.lr, K.get_value(initial_lr))
        wandb.log({"lr": K.get_value(initial_lr)}, step=epoch + start_epoch, commit=False)

        train_records = train_records.shuffle(100000)
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_records)

        # Goes through the dataset ONCE (i.e. niters * itersize * batch samples). However, the dataset may already
        # contain multiple instances of the same sample, depending on which Sampler was used. If you want multiple
        # epochs, achieve it by tweaking the niters and itersize values.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1
            train_loss = total_loss / num_batches
            iter_bar.update(1)

            if num_batches % self.config["itersize"] == 0:
                epoch += 1

                # Do warmup and decay
                new_lr = self.change_lr(epoch, self.config["bertlr"], do_warmup=self.config["warmupbert"])
                K.set_value(optimizer_2.lr, K.get_value(new_lr))
                wandb.log({"bertlr": K.get_value(new_lr)}, step=epoch + start_epoch, commit=False)

                new_lr = self.change_lr(epoch, self.config["lr"], do_warmup=self.config["warmupnonbert"])
                K.set_value(optimizer_1.lr, K.get_value(new_lr))
                wandb.log({"lr": K.get_value(new_lr)}, step=epoch + start_epoch, commit=False)

                iter_bar.close()
                logger.info("train_loss for epoch {} is {}".format(epoch, train_loss))
                wandb.log({"loss": float(train_loss.numpy())}, step=epoch + start_epoch, commit=False)
                total_loss = 0

                if epoch % self.config["validatefreq"] == 0:
                    dev_predictions = []
                    for x in tqdm(dev_dist_dataset, desc="validation"):
                        pred_batch = (
                            distributed_test_step(x).values
                            if self.strategy.num_replicas_in_sync > 1
                            else [distributed_test_step(x)]
                        )
                        for p in pred_batch:
                            dev_predictions.extend(p)

                    trec_preds = self.get_preds_in_trec_format(dev_predictions, dev_data)
                    metrics = evaluator.eval_runs(trec_preds, dict(qrels), evaluator.DEFAULT_METRICS + ["bpref"], relevance_level)
                    logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
                    if metrics[metric] > best_metric:
                        logger.info("Writing checkpoint")
                        best_metric = metrics[metric]
                        wrapped_model.save_weights("{0}/dev.best".format(train_output_path))

                    wandb.log(
                        {
                            f"dev-{k}": v
                            for k, v in metrics.items()
                            if k in ["map", "bpref", "P_20", "ndcg_cut_20", "judged_10", "judged_20", "judged_200"]
                        },
                        step=epoch + start_epoch,
                        commit=False,
                    )

                iter_bar = tqdm(total=self.config["itersize"])

            if num_batches >= self.config["niters"] * self.config["itersize"]:
                break
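# change_lr is defined elsewhere in this trainer; as a rough illustration only (an assumption, not the
# actual implementation), a warmup-then-constant schedule of the kind toggled by warmupbert / warmupnonbert
# could look like the hypothetical helper below:
#
#     def example_change_lr(epoch, base_lr, do_warmup, warmup_epochs=3):
#         # hypothetical: linear warmup over the first warmup_epochs, then hold base_lr
#         if do_warmup and epoch < warmup_epochs:
#             return base_lr * (epoch + 1) / warmup_epochs
#         return base_lr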
def evaluate(self):
    fold = self.config["fold"]
    train_output_path = self.get_results_path()
    logger.debug("results path: %s", train_output_path)

    searcher_runs, reranker_runs = self.find_crossvalidated_results()

    if fold not in reranker_runs:
        logger.error("could not find predictions; run the train command first")
        raise ValueError("could not find predictions; run the train command first")

    fold_dev_metrics = evaluator.eval_runs(reranker_runs[fold]["dev"], self.benchmark.qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

    unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(self.benchmark, "unsampled_qrels") else self.benchmark.qrels
    fold_test_metrics = evaluator.eval_runs(reranker_runs[fold]["test"], unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s test metrics: %s", fold, fold_test_metrics)

    if len(reranker_runs) != len(self.benchmark.folds):
        logger.info(
            "rerank: skipping cross-validated metrics because results exist for only %s/%s folds",
            len(reranker_runs),
            len(self.benchmark.folds),
        )
        logger.info("available runs: %s", reranker_runs.keys())
        return {
            "fold_test_metrics": fold_test_metrics,
            "fold_dev_metrics": fold_dev_metrics,
            "cv_metrics": None,
            "interpolated_cv_metrics": None,
        }

    logger.info("rerank: average cross-validated metrics when choosing iteration based on '%s':", self.config["optimize"])
    all_preds = {}
    for preds in reranker_runs.values():
        for qid, docscores in preds["test"].items():
            all_preds.setdefault(qid, {})
            for docid, score in docscores.items():
                all_preds[qid][docid] = score

    cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    interpolated_results = evaluator.interpolated_eval(searcher_runs, reranker_runs, self.benchmark, self.config["optimize"], self.metrics)

    for metric, score in sorted(cv_metrics.items()):
        logger.info("%25s: %0.4f", metric, score)

    return {
        "fold_test_metrics": fold_test_metrics,
        "fold_dev_metrics": fold_dev_metrics,
        "cv_metrics": cv_metrics,
        "interpolated_results": interpolated_results,
    }
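# evaluator.interpolated_eval is implemented elsewhere in the package; as a hedged sketch of the general
# idea (linear interpolation of first-stage and reranker scores, with alpha tuned on dev), not the
# library's actual code, a hypothetical helper over {qid: {docid: score}} runs might be:
#
#     def example_interpolate(searcher_run, reranker_run, alpha):
#         return {
#             qid: {docid: alpha * score + (1 - alpha) * reranker_run[qid].get(docid, 0.0)
#                   for docid, score in docs.items()}
#             for qid, docs in searcher_run.items()
#         }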
def predict_and_eval(self, init_path=None):
    fold = self.config["fold"]
    self.reranker.build_model()
    if not init_path or init_path == "none":
        logger.info("No init path given, using default parameters")
        self.reranker.build_model()
    else:
        logger.info(f"Load from {init_path}")
        init_path = Path(init_path) if not init_path.startswith("gs:") else init_path
        self.reranker.trainer.load_best_model(self.reranker, init_path, do_not_hash=True)

    dirname = str(init_path).split("/")[-1] if init_path else "noinitpath"
    savedir = Path(__file__).parent.absolute() / "downloaded_runfiles" / dirname
    dev_output_path = savedir / fold / "dev"
    test_output_path = savedir / fold / "test"
    test_output_path.parent.mkdir(exist_ok=True, parents=True)

    self.rank.search()
    threshold = self.config["threshold"]
    rank_results = self.rank.evaluate()
    best_search_run_path = rank_results["path"][fold]
    best_search_run = Searcher.load_trec_run(best_search_run_path)

    docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
    self.reranker.extractor.preprocess(
        qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
    )

    # dev run
    dev_run = defaultdict(dict)
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["dev"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(dev_run[qid]) == threshold, f"Expect {threshold} docs on each qid, got {len(dev_run[qid])} for query {qid}"
                    break
                dev_run[qid][docid] = score

    dev_dataset = PredSampler()
    dev_dataset.prepare(dev_run, self.benchmark.qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)

    # test run
    test_run = defaultdict(dict)
    # This is possible because best_search_run is an OrderedDict
    for qid, docs in best_search_run.items():
        if qid in self.benchmark.folds[fold]["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= threshold:
                    assert len(test_run[qid]) == threshold, f"Expect {threshold} docs on each qid, got {len(test_run[qid])} for query {qid}"
                    break
                test_run[qid][docid] = score

    unsampled_qrels = self.benchmark.unsampled_qrels if hasattr(self.benchmark, "unsampled_qrels") else self.benchmark.qrels
    test_dataset = PredSampler()
    test_dataset.prepare(test_run, unsampled_qrels, self.reranker.extractor, relevance_level=self.benchmark.relevance_level)
    logger.info("test prepared")

    # prediction
    dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset, dev_output_path)
    fold_dev_metrics = evaluator.eval_runs(dev_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s dev metrics: %s", fold, fold_dev_metrics)

    test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)
    fold_test_metrics = evaluator.eval_runs(test_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    logger.info("rerank: fold=%s test metrics: %s", fold, fold_test_metrics)

    wandb.save(str(dev_output_path))
    wandb.save(str(test_output_path))

    # add cross-validated results
    n_folds = len(self.benchmark.folds)
    folds_fn = {f"s{i}": savedir / f"s{i}" / "test" for i in range(1, n_folds + 1)}
    if not all([fn.exists() for fn in folds_fn.values()]):
        return {"fold_test_metrics": fold_test_metrics, "cv_metrics": None}

    all_preds = {}
    reranker_runs = {
        fold: {"dev": Searcher.load_trec_run(fn.parent / "dev"), "test": Searcher.load_trec_run(fn)}
        for fold, fn in folds_fn.items()
    }
    for fold, dev_test in reranker_runs.items():
        preds = dev_test["test"]
        qids = self.benchmark.folds[fold]["predict"]["test"]
        for qid, docscores in preds.items():
            if qid not in qids:
                continue
            all_preds.setdefault(qid, {})
            for docid, score in docscores.items():
                all_preds[qid][docid] = score

    cv_metrics = evaluator.eval_runs(all_preds, unsampled_qrels, self.metrics, self.benchmark.relevance_level)
    for metric, score in sorted(cv_metrics.items()):
        logger.info("%25s: %0.4f", metric, score)

    searcher_runs = {}
    rank_results = self.rank.evaluate()
    for fold in self.benchmark.folds:
        searcher_runs[fold] = {"dev": Searcher.load_trec_run(rank_results["path"][fold])}
        searcher_runs[fold]["test"] = searcher_runs[fold]["dev"]

    interpolated_results = evaluator.interpolated_eval(
        searcher_runs, reranker_runs, self.benchmark, self.config["optimize"], self.metrics
    )

    return {
        "fold_test_metrics": fold_test_metrics,
        "cv_metrics": cv_metrics,
        "interpolated_results": interpolated_results,
    }
def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
    """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

    Args:
        train_dataset (IterableDataset): training dataset
        train_output_path (Path): directory under which train_dataset runs and training loss will be saved
        dev_data (IterableDataset): dev dataset
        dev_output_path (Path): directory where dev_data runs and metrics will be saved
    """
    # Set up logging
    # TODO why not put this under train_output_path?
    summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path)

    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = reranker.model.to(self.device)
    self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"])

    if self.config["amp"] in ("both", "train"):
        self.amp_train_autocast = torch.cuda.amp.autocast
        self.scaler = torch.cuda.amp.GradScaler()
    else:
        self.amp_train_autocast = contextlib.nullcontext
        self.scaler = None

    # REF-TODO how to handle interactions between fastforward and schedule? --> just save its state
    self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer, lambda epoch: self.lr_multiplier(step=epoch * self.n_batch_per_iter)
    )

    if self.config["softmaxloss"]:
        self.loss = pair_softmax_loss
    else:
        self.loss = pair_hinge_loss

    dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metric_fn = self.get_paths_for_early_stopping(
        train_output_path, dev_output_path
    )

    num_workers = 1 if self.config["multithread"] else 0
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=num_workers
    )

    # if we're fastforwarding, set first iteration and load last saved weights
    initial_iter, metrics = (
        self.fastforward_training(reranker, weights_output_path, loss_fn, metric_fn) if self.config["fastforward"] else (0, {})
    )
    dev_best_metric = metrics.get(metric, -np.inf)
    logger.info("starting training from iteration %s/%s", initial_iter + 1, self.config["niters"])
    logger.info(f"Best metric loaded: {metric}={dev_best_metric}")

    train_loss = []
    # are we resuming training? fastforward loss and data if so
    if initial_iter > 0:
        train_loss = self.load_loss_file(loss_fn)

        # are we done training? if not, fastforward through prior batches
        if initial_iter < self.config["niters"]:
            logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
            self.exhaust_used_train_data(train_dataloader, n_batch_to_exhaust=initial_iter * self.n_batch_per_iter)

    logger.info(self.get_validation_schedule_msg(initial_iter))
    train_start_time = time.time()
    for niter in range(initial_iter, self.config["niters"]):
        niter = niter + 1  # index from 1
        model.train()

        iter_start_time = time.time()
        iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
        logger.info("A single iteration takes {}".format(time.time() - iter_start_time))
        train_loss.append(iter_loss_tensor.item())
        logger.info("iter = %d loss = %f", niter, train_loss[-1])

        # save model weights only when fastforward enabled
        if self.config["fastforward"]:
            weights_fn = weights_output_path / f"{niter}.p"
            reranker.save_weights(weights_fn, self.optimizer)

        # predict performance on dev set
        if niter % self.config["validatefreq"] == 0:
            pred_fn = dev_output_path / f"{niter}.run"
            preds = self.predict(reranker, dev_data, pred_fn)

            # log dev metrics
            metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
            logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
            summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter)
            summary_writer.add_scalar("map", metrics["map"], niter)
            summary_writer.add_scalar("P_20", metrics["P_20"], niter)

            # write best dev weights to file
            if metrics[metric] > dev_best_metric:
                dev_best_metric = metrics[metric]
                logger.info("new best dev metric: %0.4f", dev_best_metric)
                reranker.save_weights(dev_best_weight_fn, self.optimizer)
                self.write_to_metric_file(metric_fn, metrics)

        # write train_loss to file
        # loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
        self.write_to_loss_file(loss_fn, train_loss)

        summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter)
        reranker.add_summary(summary_writer, niter)
        summary_writer.flush()

    logger.info("training loss: %s", train_loss)
    logger.info("Training took {}".format(time.time() - train_start_time))
    summary_writer.close()
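# pair_hinge_loss and pair_softmax_loss are imported from elsewhere in the package; a hedged sketch of the
# standard pairwise forms they correspond to (an illustration over (pos, neg) score pairs, not the package's
# actual code):
#
#     import torch
#     import torch.nn.functional as F
#
#     def example_pair_hinge_loss(pos_scores, neg_scores, margin=1.0):
#         # hinge on the score difference of a relevant/non-relevant document pair
#         return F.relu(margin - pos_scores + neg_scores).mean()
#
#     def example_pair_softmax_loss(pos_scores, neg_scores):
#         # cross-entropy over the (pos, neg) pair, i.e. -log softmax of the positive score
#         return -F.log_softmax(torch.stack([pos_scores, neg_scores], dim=1), dim=1)[:, 0].mean()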
def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
    if self.tpu:
        # WARNING: not sure if pathlib is compatible with gs://
        train_output_path = Path(
            "{0}/{1}/{2}".format(
                self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
            )
        )

    dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metric_fn = self.get_paths_for_early_stopping(
        train_output_path, dev_output_path
    )

    train_records = self.get_tf_train_records(reranker, train_dataset)
    dev_records = self.get_tf_dev_records(reranker, dev_data)
    dev_dist_dataset = self.strategy.experimental_distribute_dataset(dev_records)

    # Does not vary much from https://www.tensorflow.org/tutorials/distribute/custom_training
    strategy_scope = self.strategy.scope()
    with strategy_scope:
        reranker.build_model()
        wrapped_model = self.get_wrapped_model(reranker.model)
        loss_object = self.get_loss(self.config["loss"])
        optimizer_1 = tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
        optimizer_2 = tf.keras.optimizers.Adam(learning_rate=self.config["bertlr"])

        # "You should remove the use of the LossScaleOptimizer when TPUs are used."
        if self.amp and not self.tpu:
            optimizer_2 = mixed_precision.LossScaleOptimizer(optimizer_2, loss_scale="dynamic")

        def compute_loss(labels, predictions):
            per_example_loss = loss_object(labels, predictions)
            return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.config["batch"])

        def is_bert_variable(name):
            if "bert" in name:
                return True
            if "electra" in name:
                return True
            return False

        def train_step(inputs):
            data, labels = inputs

            with tf.GradientTape() as tape:
                train_predictions = wrapped_model(data, training=True)
                loss = compute_loss(labels, train_predictions)
                if self.amp and not self.tpu:
                    loss = optimizer_2.get_scaled_loss(loss)

            gradients = tape.gradient(loss, wrapped_model.trainable_variables)
            if self.amp and not self.tpu:
                gradients = optimizer_2.get_unscaled_gradients(gradients)

            bert_variables = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if is_bert_variable(variable.name) and "classifier" not in variable.name
            ]
            classifier_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if "classifier" in variable.name
            ]
            other_vars = [
                (gradients[i], variable)
                for i, variable in enumerate(wrapped_model.trainable_variables)
                if not is_bert_variable(variable.name) and "classifier" not in variable.name
            ]

            # Making sure that we did not miss any variables
            assert len(bert_variables) + len(classifier_vars) + len(other_vars) == len(wrapped_model.trainable_variables)

            # TODO: Clean this up for general use
            optimizer_1.apply_gradients(classifier_vars)
            optimizer_2.apply_gradients(bert_variables)
            if other_vars:
                optimizer_1.apply_gradients(other_vars)

            return loss

        def test_step(inputs):
            data, labels = inputs
            predictions = wrapped_model.predict_step(data)
            return predictions

        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = self.strategy.run(train_step, args=(dataset_inputs,))
            return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return self.strategy.run(test_step, args=(dataset_inputs,))

        train_records = train_records.shuffle(100000)
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_records)

        initial_iter, metrics = (
            self.fastforward_training(wrapped_model, weights_output_path, loss_fn, metric_fn)
            if self.config["fastforward"]
            else (0, {})
        )
        dev_best_metric = metrics.get(metric, -np.inf)
        logger.info("starting training from iteration %s/%s", initial_iter + 1, self.config["niters"])
        logger.info(f"Best metric loaded: {metric}={dev_best_metric}")

        cur_step = initial_iter * self.n_batch_per_iter
        initial_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
        K.set_value(optimizer_2.lr, K.get_value(initial_lr))

        train_loss = self.load_loss_file(loss_fn) if initial_iter > 0 else []
        if 0 < initial_iter < self.config["niters"]:
            self.exhaust_used_train_data(train_dist_dataset, n_batch_to_exhaust=initial_iter * self.n_batch_per_iter)

        niter = initial_iter
        total_loss = 0
        trec_preds = {}
        iter_bar = tqdm(desc="Training iteration", total=self.n_batch_per_iter)
        # Goes through the dataset ONCE (i.e. niters * itersize). However, the dataset may already contain
        # multiple instances of the same sample, depending on which Sampler was used. If you want multiple
        # epochs, achieve it by tweaking the niters and itersize values.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            cur_step += 1
            iter_bar.update(1)

            # Do warmup and decay
            new_lr = self.change_lr(step=cur_step, lr=self.config["bertlr"])
            K.set_value(optimizer_2.lr, K.get_value(new_lr))

            if cur_step % self.n_batch_per_iter == 0:
                niter += 1
                iter_bar.close()
                iter_bar = tqdm(total=self.n_batch_per_iter)
                train_loss.append(total_loss / self.n_batch_per_iter)
                logger.info("iter={} loss = {}".format(niter, train_loss[-1]))
                self.write_to_loss_file(loss_fn, train_loss)
                total_loss = 0

                if self.config["fastforward"]:
                    wrapped_model.save_weights(f"{weights_output_path}/{niter}")

                if niter % self.config["validatefreq"] == 0:
                    dev_predictions = []
                    for x in tqdm(dev_dist_dataset, desc="validation"):
                        pred_batch = (
                            distributed_test_step(x).values
                            if self.strategy.num_replicas_in_sync > 1
                            else [distributed_test_step(x)]
                        )
                        for p in pred_batch:
                            dev_predictions.extend(p)

                    trec_preds = self.get_preds_in_trec_format(dev_predictions, dev_data)
                    metrics = evaluator.eval_runs(trec_preds, dict(qrels), evaluator.DEFAULT_METRICS, relevance_level)
                    logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
                    if metrics[metric] > dev_best_metric:
                        dev_best_metric = metrics[metric]
                        logger.info("new best dev metric: %0.4f", dev_best_metric)
                        self.write_to_metric_file(metric_fn, metrics)
                        wrapped_model.save_weights(dev_best_weight_fn)
                        Searcher.write_trec_run(trec_preds, outfn=(dev_output_path / "best").as_posix())

            if cur_step >= self.config["niters"] * self.n_batch_per_iter:
                break

    return trec_preds
def evaluate(self):
    fold = self.config["fold"]
    train_output_path = self.get_results_path()
    logger.debug("results path: %s", train_output_path)
    metrics = self.config["metrics"] if list(self.config["metrics"]) != ["default"] else evaluator.DEFAULT_METRICS

    searcher_runs, reranker_runs = self.find_crossvalidated_results()

    if fold not in reranker_runs:
        logger.error("could not find predictions; run the train command first")
        raise ValueError("could not find predictions; run the train command first")

    fold_dev_metrics = evaluator.eval_runs(reranker_runs[fold]["dev"], self.benchmark.qrels, metrics, self.benchmark.relevance_level)
    pretty_fold_dev_metrics = " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(fold_dev_metrics.items())])
    logger.info("rerank: fold=%s dev metrics: %s", fold, pretty_fold_dev_metrics)

    fold_test_metrics = evaluator.eval_runs(reranker_runs[fold]["test"], self.benchmark.qrels, metrics, self.benchmark.relevance_level)
    pretty_fold_test_metrics = " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(fold_test_metrics.items())])
    logger.info("rerank: fold=%s test metrics: %s", fold, pretty_fold_test_metrics)

    if len(reranker_runs) != len(self.benchmark.folds):
        logger.info(
            "rerank: skipping cross-validated metrics because results exist for only %s/%s folds",
            len(reranker_runs),
            len(self.benchmark.folds),
        )
        return {
            "fold_test_metrics": fold_test_metrics,
            "fold_dev_metrics": fold_dev_metrics,
            "cv_metrics": None,
            "interpolated_cv_metrics": None,
        }

    logger.info("rerank: average cross-validated metrics when choosing iteration based on '%s':", self.config["optimize"])
    all_preds = {}
    for preds in reranker_runs.values():
        for qid, docscores in preds["test"].items():
            all_preds.setdefault(qid, {})
            for docid, score in docscores.items():
                all_preds[qid][docid] = score

    cv_metrics = evaluator.eval_runs(all_preds, self.benchmark.qrels, metrics, self.benchmark.relevance_level)
    interpolated_results = evaluator.interpolated_eval(searcher_runs, reranker_runs, self.benchmark, self.config["optimize"], metrics)

    for metric, score in sorted(cv_metrics.items()):
        logger.info("%25s: %0.4f", metric, score)

    logger.info("interpolated with alphas = %s", sorted(interpolated_results["alphas"].values()))
    for metric, score in sorted(interpolated_results["score"].items()):
        logger.info("%25s: %0.4f", metric + " [interp]", score)

    return {
        "fold_test_metrics": fold_test_metrics,
        "fold_dev_metrics": fold_dev_metrics,
        "cv_metrics": cv_metrics,
        "interpolated_results": interpolated_results,
    }
def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
    """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

    Args:
        train_dataset (IterableDataset): training dataset
        train_output_path (Path): directory under which train_dataset runs and training loss will be saved
        dev_data (IterableDataset): dev dataset
        dev_output_path (Path): directory where dev_data runs and metrics will be saved
    """
    # Set up logging
    # TODO why not put this under train_output_path?
    summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path)

    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = reranker.model.to(self.device)
    self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"])

    if self.config["softmaxloss"]:
        self.loss = pair_softmax_loss
    else:
        self.loss = pair_hinge_loss

    dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping(
        train_output_path, dev_output_path
    )

    initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn) if self.config["fastforward"] else 0
    logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])

    num_workers = 1 if self.config["multithread"] else 0
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=num_workers
    )
    # dataiter = iter(train_dataloader)
    # sample_input = dataiter.next()
    # summary_writer.add_graph(
    #     reranker.model,
    #     [
    #         sample_input["query"].to(self.device),
    #         sample_input["posdoc"].to(self.device),
    #         sample_input["negdoc"].to(self.device),
    #     ],
    # )

    train_loss = []
    # are we resuming training?
    if initial_iter > 0:
        train_loss = self.load_loss_file(loss_fn)

        # are we done training?
        if initial_iter < self.config["niters"]:
            logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
            batches_per_epoch = self.config["itersize"] // self.config["batch"]
            for niter in range(initial_iter):
                for bi, batch in enumerate(train_dataloader):
                    if (bi + 1) % batches_per_epoch == 0:
                        break

    dev_best_metric = -np.inf
    validation_frequency = self.config["validatefreq"]
    train_start_time = time.time()
    for niter in range(initial_iter, self.config["niters"]):
        model.train()

        iter_start_time = time.time()
        iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
        logger.info("A single iteration takes {}".format(time.time() - iter_start_time))
        train_loss.append(iter_loss_tensor.item())
        logger.info("iter = %d loss = %f", niter, train_loss[-1])

        # write model weights to file
        weights_fn = weights_output_path / f"{niter}.p"
        reranker.save_weights(weights_fn, self.optimizer)

        # predict performance on dev set
        if niter % validation_frequency == 0:
            pred_fn = dev_output_path / f"{niter}.run"
            preds = self.predict(reranker, dev_data, pred_fn)

            # log dev metrics
            metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
            logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
            summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter)
            summary_writer.add_scalar("map", metrics["map"], niter)
            summary_writer.add_scalar("P_20", metrics["P_20"], niter)

            # write best dev weights to file
            if metrics[metric] > dev_best_metric:
                dev_best_metric = metrics[metric]
                logger.info("new best dev metric: %0.4f", dev_best_metric)
                reranker.save_weights(dev_best_weight_fn, self.optimizer)

        # write train_loss to file
        loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))

        summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter)
        reranker.add_summary(summary_writer, niter)
        summary_writer.flush()

    logger.info("training loss: %s", train_loss)
    logger.info("Training took {}".format(time.time() - train_start_time))
    summary_writer.close()
def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric):
    """Train a model following the trainer's config (specifying batch size, number of iterations, etc).

    Args:
        train_dataset (IterableDataset): training dataset
        train_output_path (Path): directory under which train_dataset runs and training loss will be saved
        dev_data (IterableDataset): dev dataset
        dev_output_path (Path): directory where dev_data runs and metrics will be saved
    """
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = reranker.model.to(self.device)
    self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.cfg["lr"])

    if self.cfg["softmaxloss"]:
        self.loss = pair_softmax_loss
    else:
        self.loss = pair_hinge_loss

    os.makedirs(dev_output_path, exist_ok=True)
    dev_best_weight_fn = train_output_path / "dev.best"
    weights_output_path = train_output_path / "weights"
    info_output_path = train_output_path / "info"
    os.makedirs(weights_output_path, exist_ok=True)
    os.makedirs(info_output_path, exist_ok=True)

    loss_fn = info_output_path / "loss.txt"
    metrics_fn = dev_output_path / "metrics.json"
    metrics_history = {}

    initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn)
    logger.info("starting training from iteration %s/%s", initial_iter, self.cfg["niters"])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=self.cfg["batch"], pin_memory=True, num_workers=0
    )

    train_loss = []
    # are we resuming training?
    if initial_iter > 0:
        train_loss = self.load_loss_file(loss_fn)

        # are we done training?
        if initial_iter < self.cfg["niters"]:
            logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
            batches_per_epoch = self.cfg["itersize"] // self.cfg["batch"]
            for niter in range(initial_iter):
                for bi, batch in enumerate(train_dataloader):
                    if (bi + 1) % batches_per_epoch == 0:
                        break

    dev_best_metric = -np.inf
    for niter in range(initial_iter, self.cfg["niters"]):
        model.train()

        iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
        train_loss.append(iter_loss_tensor.item())
        logger.info("iter = %d loss = %f", niter, train_loss[-1])

        # write model weights to file
        weights_fn = weights_output_path / f"{niter}.p"
        reranker.save_weights(weights_fn, self.optimizer)

        # predict performance on dev set
        pred_fn = dev_output_path / f"{niter}.run"
        preds = self.predict(reranker, dev_data, pred_fn)

        # log dev metrics
        metrics = evaluator.eval_runs(preds, qrels, ["ndcg_cut_20", "map", "P_20"])
        logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))

        # write best dev weights to file
        if metrics[metric] > dev_best_metric:
            reranker.save_weights(dev_best_weight_fn, self.optimizer)

        for m in metrics:
            metrics_history.setdefault(m, []).append(metrics[m])

        # write train_loss to file
        loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
        json.dump(metrics_history, open(metrics_fn, "w", encoding="utf-8"))

        plot_metrics(metrics_history, str(dev_output_path) + ".pdf", interactive=self.cfg["interactive"])
        plot_loss(train_loss, str(loss_fn).replace(".txt", ".pdf"), interactive=self.cfg["interactive"])
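# fastforward_training (used by the trainers above to resume interrupted runs) is defined elsewhere; as a
# hedged sketch of the resume logic it implies, i.e. finding the last saved "<iteration>.p" checkpoint
# (an assumption, not the actual implementation):
#
#     from pathlib import Path
#
#     def example_last_saved_iteration(weights_output_path):
#         # hypothetical helper: return the highest iteration with a saved weight file, or 0 if none exist
#         iters = [int(p.stem) for p in Path(weights_output_path).glob("*.p") if p.stem.isdigit()]
#         return max(iters, default=0)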