dest="max_context_length", type=int, help="Maximum length of context. (Don't set to inherit from training config)", ) # output folder parser.add_argument( "--output_path", dest="output_path", type=str, default="output", help="Path to the output.", ) parser.add_argument( "--use_cuda", dest="use_cuda", action="store_true", default=False, help="run on gpu" ) parser.add_argument( "--no_logger", dest="no_logger", action="store_true", default=False, help="don't log progress" ) args = parser.parse_args() logger = None if not args.no_logger: logger = utils.get_logger(args.output_path) logger.setLevel(10) models = load_models(args, logger) run(args, logger, *models)
# All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # import argparse import logging import numpy import os import time import torch from elq.index.faiss_indexer import DenseFlatIndexer, DenseIVFFlatIndexer, DenseHNSWFlatIndexer import elq.candidate_ranking.utils as utils logger = utils.get_logger() def main(params): output_path = params["output_path"] logger.info("Loading candidate encoding from path: %s" % params["candidate_encoding"]) candidate_encoding = torch.load(params["candidate_encoding"]) vector_size = candidate_encoding.size(1) index_buffer = params["index_buffer"] if params["faiss_index"] == "hnsw": logger.info("Using HNSW index in FAISS") index = DenseHNSWFlatIndexer(vector_size, index_buffer) elif params["faiss_index"] == "ivfflat": logger.info("Using IVF Flat index in FAISS") index = DenseIVFFlatIndexer(vector_size, 75, 100) else:
biencoder_params["path_to_model"] = args.path_to_model # entities to use biencoder_params["entity_dict_path"] = args.entity_dict_path biencoder_params["degug"] = False biencoder_params["data_parallel"] = True biencoder_params["no_cuda"] = False biencoder_params["max_context_length"] = 32 biencoder_params["encode_batch_size"] = args.batch_size saved_cand_ids = getattr(args, 'saved_cand_ids', None) encoding_save_file_dir = args.encoding_save_file_dir if encoding_save_file_dir is not None and not os.path.exists( encoding_save_file_dir): os.makedirs(encoding_save_file_dir, exist_ok=True) logger = utils.get_logger(biencoder_params.get("model_output_path", None)) biencoder = load_biencoder(biencoder_params) baseline_candidate_encoding = None if getattr(args, 'compare_saved_embeds', None) is not None: baseline_candidate_encoding = torch.load( getattr(args, 'compare_saved_embeds')) candidate_pool = load_candidate_pool( biencoder.tokenizer, biencoder_params, logger, getattr(args, 'saved_cand_ids', None), ) if args.test: candidate_pool = candidate_pool[:10]
def main(params):
    """Train the bi-encoder ranker and save one checkpoint per epoch.

    The function builds a ``BiEncoderRanker`` from ``params``, evaluates it
    once before training, then runs the epoch loop. After every epoch it
    saves the model and the optimizer/scheduler state under
    ``<output_path>/epoch_<idx>`` and evaluates on the validation subset and
    on the epoch's training data. When training finishes, it writes the
    elapsed time and saves the model into ``output_path``.

    Args:
        params: dict of settings. Keys read here: ``output_path``,
            ``gradient_accumulation_steps``, ``train_batch_size``,
            ``eval_batch_size``, ``seed``, ``data_path``,
            ``max_context_length``, ``max_cand_length``, ``context_key``,
            ``title_key``, ``silent``, ``debug``, ``freeze_cand_enc``,
            ``cand_enc_path`` and ``index_path`` (only when
            ``freeze_cand_enc`` is set), ``num_train_epochs``,
            ``dont_distribute_train_samples``, ``last_epoch``, ``shuffle``,
            ``adversarial_training``, ``print_interval``, ``max_grad_norm``,
            ``eval_interval``, ``get_losses``, ``evaluate`` and, optionally,
            ``path_to_trainer_state``.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is smaller than 1.

    Side effects:
        ``params["train_batch_size"]`` is replaced by the per-step batch
        size and ``params["path_to_model"]`` is overwritten.

    NOTE(review): ``args`` is not defined inside this function; it is
    presumably a module-level global set by the script entry point, so
    calling ``main`` after importing this module would fail -- confirm.
    """
    model_output_path = params["output_path"]
    os.makedirs(model_output_path, exist_ok=True)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # An effective batch size of `x`, when we are accumulating the gradient
    # across `y` batches, is achieved by having a batch size of `z = x / y`.
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Load train data
    train_samples = utils.read_dataset("train", params["data_path"])
    logger.info("Read %d train samples." % len(train_samples))
    logger.info("Finished reading all train samples")

    # Load eval data ("valid" split first, falling back to "dev")
    try:
        valid_samples = utils.read_dataset("valid", params["data_path"])
    except FileNotFoundError:
        valid_samples = utils.read_dataset("dev", params["data_path"])

    # The validation subset size MUST BE DIVISIBLE BY the number of GPUs.
    if len(valid_samples) > 1024:
        valid_subset = 1024
    else:
        # device_count() is 0 on a CPU-only machine; clamp to 1 so the
        # modulo below cannot raise ZeroDivisionError.
        num_devices = max(torch.cuda.device_count(), 1)
        valid_subset = len(valid_samples) - len(valid_samples) % num_devices
    logger.info("Read %d valid samples, choosing %d subset" %
                (len(valid_samples), valid_subset))

    # Evaluated once and reused for every process_mention_data call below.
    add_mention_bounds = not args.no_mention_bounds

    valid_data, valid_tensor_data, extra_ret_values = process_mention_data(
        samples=valid_samples[:valid_subset],  # use subset of valid data
        tokenizer=tokenizer,
        max_context_length=params["max_context_length"],
        max_cand_length=params["max_cand_length"],
        context_key=params["context_key"],
        title_key=params["title_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
        add_mention_bounds=add_mention_bounds,
        candidate_token_ids=None,
        params=params,
    )
    # Reused when the training data is processed further down.
    candidate_token_ids = extra_ret_values["candidate_token_ids"]
    valid_tensor_data = TensorDataset(*valid_tensor_data)
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(valid_tensor_data,
                                  sampler=valid_sampler,
                                  batch_size=eval_batch_size)

    # load candidate encodings
    cand_encs = None
    cand_encs_index = None
    if params["freeze_cand_enc"]:
        cand_encs = torch.load(params['cand_enc_path'])
        logger.info("Loaded saved entity encodings")
        if params["debug"]:
            cand_encs = cand_encs[:200]

        # build FAISS index
        cand_encs_index = DenseHNSWFlatIndexer(1)
        cand_encs_index.deserialize_from(params['index_path'])
        logger.info("Loaded FAISS index on entity encodings")
        # number of nearest entities fetched per mention as hard negatives
        num_neighbors = 10

    # evaluate before training
    evaluate(
        reranker,
        valid_dataloader,
        params,
        cand_encs=cand_encs,
        device=device,
        logger=logger,
        faiss_index=cand_encs_index,
    )

    time_start = time.time()

    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))

    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, False))

    num_train_epochs = params["num_train_epochs"]
    if params["dont_distribute_train_samples"]:
        # Every epoch sees the full training set, so it is processed once.
        num_samples_per_batch = len(train_samples)

        train_data, train_tensor_data_tuple, extra_ret_values = process_mention_data(
            samples=train_samples,
            tokenizer=tokenizer,
            max_context_length=params["max_context_length"],
            max_cand_length=params["max_cand_length"],
            context_key=params["context_key"],
            title_key=params["title_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            add_mention_bounds=add_mention_bounds,
            candidate_token_ids=candidate_token_ids,
            params=params,
        )
        logger.info("Finished preparing training data")
    else:
        # The training set is split evenly across epochs; each epoch
        # processes its own slice inside the loop below.
        num_samples_per_batch = len(train_samples) // num_train_epochs

    trainer_path = params.get("path_to_trainer_state", None)
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, num_samples_per_batch, logger)
    if trainer_path is not None and os.path.exists(trainer_path):
        training_state = torch.load(trainer_path)
        optimizer.load_state_dict(training_state["optimizer"])
        scheduler.load_state_dict(training_state["scheduler"])
        logger.info("Loaded saved training state")

    model.train()

    best_epoch_idx = -1
    best_score = -1
    logger.info("Num samples per batch : %d" % num_samples_per_batch)
    for epoch_idx in trange(params["last_epoch"] + 1,
                            int(num_train_epochs),
                            desc="Epoch"):
        tr_loss = 0

        if not params["dont_distribute_train_samples"]:
            start_idx = epoch_idx * num_samples_per_batch
            end_idx = (epoch_idx + 1) * num_samples_per_batch

            train_data, train_tensor_data_tuple, extra_ret_values = process_mention_data(
                samples=train_samples[start_idx:end_idx],
                tokenizer=tokenizer,
                max_context_length=params["max_context_length"],
                max_cand_length=params["max_cand_length"],
                context_key=params["context_key"],
                title_key=params["title_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                add_mention_bounds=add_mention_bounds,
                candidate_token_ids=candidate_token_ids,
                params=params,
            )
            logger.info(
                "Finished preparing training data for epoch {}: {} samples".
                format(epoch_idx, len(train_tensor_data_tuple[0])))

        batch_train_tensor_data = TensorDataset(*list(train_tensor_data_tuple))
        if params["shuffle"]:
            train_sampler = RandomSampler(batch_train_tensor_data)
        else:
            train_sampler = SequentialSampler(batch_train_tensor_data)

        train_dataloader = DataLoader(batch_train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input = batch[0]
            candidate_input = batch[1]
            # Labels are only consumed when the candidate encoder is frozen.
            label_ids = batch[2] if params["freeze_cand_enc"] else None
            mention_idxs = batch[-2]
            mention_idx_mask = batch[-1]
            if params["debug"] and label_ids is not None:
                # cand_encs is truncated to 200 rows in debug mode, so the
                # label ids are clamped to stay inside it.
                label_ids[label_ids > 199] = 199

            cand_encs_input = None
            label_input = None
            mention_reps_input = None
            mention_logits = None
            mention_bounds = None
            hard_negs_mask = None
            if params["adversarial_training"]:
                assert cand_encs is not None and label_ids is not None  # due to params["freeze_cand_enc"] being set

                # GET CLOSEST N CANDIDATES (AND APPROPRIATE LABELS)
                # (bs, num_spans, embed_size)
                pos_cand_encs_input = cand_encs[label_ids.to("cpu")]
                pos_cand_encs_input[~mention_idx_mask] = 0

                context_outs = reranker.encode_context(
                    context_input,
                    gold_mention_bounds=mention_idxs,
                    gold_mention_bounds_mask=mention_idx_mask,
                    get_mention_scores=True,
                )
                mention_logits = context_outs['all_mention_logits']
                mention_bounds = context_outs['all_mention_bounds']
                mention_reps = context_outs['mention_reps']
                # mention_reps: (bs, max_num_spans, embed_size) -> masked_mention_reps: (all_pred_mentions_batch, embed_size)
                masked_mention_reps = mention_reps[
                    context_outs['mention_masks']]

                # neg_cand_encs_input_idxs: (all_pred_mentions_batch, num_negatives)
                _, neg_cand_encs_input_idxs = cand_encs_index.search_knn(
                    masked_mention_reps.detach().cpu().numpy(), num_neighbors)
                neg_cand_encs_input_idxs = torch.from_numpy(
                    neg_cand_encs_input_idxs)
                # set "correct" closest entities to -1
                # masked_label_ids: (all_pred_mentions_batch)
                masked_label_ids = label_ids[mention_idx_mask]
                # neg_cand_encs_input_idxs: (max_spans_in_batch, num_negatives)
                neg_cand_encs_input_idxs[
                    neg_cand_encs_input_idxs -
                    masked_label_ids.to("cpu").unsqueeze(-1) == 0] = -1

                # reshape back tensor (extract num_spans dimension)
                # (bs, num_spans, num_negatives)
                neg_cand_encs_input_idxs_reconstruct = torch.zeros(
                    label_ids.size(0),
                    label_ids.size(1),
                    neg_cand_encs_input_idxs.size(-1),
                    dtype=neg_cand_encs_input_idxs.dtype)
                neg_cand_encs_input_idxs_reconstruct[
                    mention_idx_mask] = neg_cand_encs_input_idxs
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs_reconstruct

                # create neg_example_idx (corresponding example (in batch) for each negative)
                # neg_example_idx: (bs * num_negatives)
                neg_example_idx = torch.arange(
                    neg_cand_encs_input_idxs.size(0)).unsqueeze(-1)
                neg_example_idx = neg_example_idx.expand(
                    neg_cand_encs_input_idxs.size(0),
                    neg_cand_encs_input_idxs.size(2))
                neg_example_idx = neg_example_idx.flatten()

                # flatten and filter -1 (i.e. any correct/positive entities)
                # neg_cand_encs_input_idxs: (bs * num_negatives, num_spans)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.permute(
                    0, 2, 1)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.reshape(
                    -1, neg_cand_encs_input_idxs.size(-1))
                # mask invalid negatives (actually the positive example)
                # (bs * num_negatives)
                mask = ~((neg_cand_encs_input_idxs == -1).sum(1).bool()
                         )  # rows without any -1 entry
                # deletes corresponding negative for *all* spans in that example (deletes at most 3 of 10 negatives / example)
                # neg_cand_encs_input_idxs: (bs * num_negatives - invalid_negs, num_spans)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs[mask]
                # neg_cand_encs_input_idxs: (bs * num_negatives - invalid_negs)
                neg_example_idx = neg_example_idx[mask]
                # (bs * num_negatives - invalid_negs, num_spans, embed_size)
                neg_cand_encs_input = cand_encs[neg_cand_encs_input_idxs]
                # (bs * num_negatives - invalid_negs, num_spans, embed_size)
                neg_mention_idx_mask = mention_idx_mask[neg_example_idx]
                neg_cand_encs_input[~neg_mention_idx_mask] = 0

                # create input tensors (concat [pos examples, neg examples])
                # (bs + bs * num_negatives, num_spans, embed_size)
                mention_reps_input = torch.cat([
                    mention_reps,
                    mention_reps[neg_example_idx.to(device)],
                ])
                assert mention_reps.size(0) == pos_cand_encs_input.size(0)

                # (bs + bs * num_negatives, num_spans)
                label_input = torch.cat([
                    torch.ones(pos_cand_encs_input.size(0),
                               pos_cand_encs_input.size(1),
                               dtype=label_ids.dtype),
                    torch.zeros(neg_cand_encs_input.size(0),
                                neg_cand_encs_input.size(1),
                                dtype=label_ids.dtype),
                ]).to(device)
                # (bs + bs * num_negatives, num_spans, embed_size)
                cand_encs_input = torch.cat([
                    pos_cand_encs_input,
                    neg_cand_encs_input,
                ]).to(device)
                hard_negs_mask = torch.cat(
                    [mention_idx_mask, neg_mention_idx_mask])

            loss, _, _, _ = reranker(
                context_input,
                candidate_input,
                cand_encs=cand_encs_input,
                text_encs=mention_reps_input,
                mention_logits=mention_logits,
                mention_bounds=mention_bounds,
                label_input=label_input,
                gold_mention_bounds=mention_idxs,
                gold_mention_bounds_mask=mention_idx_mask,
                hard_negs_mask=hard_negs_mask,
                return_loss=True,
            )

            if grad_acc_steps > 1:
                # Scale so the accumulated gradient matches one full batch.
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                loss = None  # for GPU mem management
                mention_reps = None
                mention_reps_input = None
                label_input = None
                cand_encs_input = None

                evaluate(
                    reranker,
                    valid_dataloader,
                    params,
                    cand_encs=cand_encs,
                    device=device,
                    logger=logger,
                    faiss_index=cand_encs_index,
                    get_losses=params["get_losses"],
                )
                # Switch back to training mode after the evaluation pass.
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        torch.save(
            {
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }, os.path.join(epoch_output_folder_path, "training_state.th"))

        logger.info("Valid data evaluation")
        valid_results = evaluate(
            reranker,
            valid_dataloader,
            params,
            cand_encs=cand_encs,
            device=device,
            logger=logger,
            faiss_index=cand_encs_index,
            get_losses=params["get_losses"],
        )
        logger.info("Train data evaluation")
        # Run for its logging only; its result must not overwrite the
        # validation result used for best-epoch selection below.
        evaluate(
            reranker,
            train_dataloader,
            params,
            cand_encs=cand_encs,
            device=device,
            logger=logger,
            faiss_index=cand_encs_index,
            get_losses=params["get_losses"],
        )

        # Track the epoch with the highest validation normalized F1; on a
        # tie np.argmax returns index 0, keeping the earlier epoch.
        ls = [best_score, valid_results["normalized_f1"]]
        li = [best_epoch_idx, epoch_idx]
        best = np.argmax(ls)
        best_score = ls[best]
        best_epoch_idx = li[best]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    # NOTE(review): nothing visible here reloads the best epoch's
    # checkpoint, so this saves the weights currently held by
    # reranker.model (the last trained epoch) -- confirm this is intended.
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        # Same call shape as every other evaluate() call in this function:
        # reranker and dataloader first, then params and keyword options.
        evaluate(
            reranker,
            valid_dataloader,
            params,
            cand_encs=cand_encs,
            device=device,
            logger=logger,
            faiss_index=cand_encs_index,
        )