def ann_data_gen(args):
    last_checkpoint = args.last_checkpoint_dir
    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)
    output_num = ann_no + 1

    logger.info("starting output number %d", output_num)
    preloaded_data = None

    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)
        preloaded_data = load_data(args)

    while args.end_output_num == -1 or output_num <= args.end_output_num:
        next_checkpoint, latest_step_num = get_latest_checkpoint(args)

        if args.only_keep_latest_embedding_file:
            latest_step_num = 0

        if next_checkpoint == last_checkpoint:
            time.sleep(60)
        else:
            logger.info("start generate ann data number %d", output_num)
            logger.info("next checkpoint at " + next_checkpoint)
            generate_new_ann(args, output_num, next_checkpoint, preloaded_data,
                             latest_step_num)
            logger.info("finished generating ann data number %d", output_num)
            output_num += 1
            last_checkpoint = next_checkpoint
        if args.local_rank != -1:
            dist.barrier()
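# Illustrative sketch only (not the repo's actual helper): ann_data_gen() and the
# trainer communicate purely through files named "ann_training_data_<n>" and
# "ann_ndcg_<n>" in a shared directory, so a get_latest_ann_data-style lookup can be
# approximated by scanning for the largest <n>. The helper name and return convention
# below are assumptions used for clarity.
import json
import os


def _latest_ann_data_sketch(ann_dir):
    """Return (ann_no, ann_path, ndcg_json) for the newest ANN round, or (-1, None, None)."""
    if not os.path.exists(ann_dir):
        return -1, None, None
    ids = [
        int(name.split("_")[-1])
        for name in os.listdir(ann_dir)
        if name.startswith("ann_ndcg_")
    ]
    if not ids:
        return -1, None, None
    ann_no = max(ids)
    with open(os.path.join(ann_dir, "ann_ndcg_" + str(ann_no)), "r") as f:
        ndcg_json = json.load(f)
    ann_path = os.path.join(ann_dir, "ann_training_data_" + str(ann_no))
    return ann_no, ann_path, ndcg_json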
def compute_mrr(D, I, qids, ref_dict):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    if is_first_worker():
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # we pad with negative pids and distance -128 - if they make it to the top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :10]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                candidate_dict[qid] = [0] * 1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    j += 1
                    seen_pids.add(pid)
        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)
        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)
    return mrr
def save_checkpoint(args, model, tokenizer):
    # Saving best-practices: if you use default names for the model, you can
    # reload it using from_pretrained()
    if args.do_train and is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    dist.barrier()
def evaluation(args, model, tokenizer):
    # Evaluation
    results = {}
    if args.do_eval:
        model_dir = args.model_name_or_path if args.model_name_or_path else args.output_dir
        checkpoints = [model_dir]

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model.eval()
            reranking_mrr, full_ranking_mrr = passage_dist_eval(args, model, tokenizer)
            if is_first_worker():
                print("Reranking/Full ranking mrr: {0}/{1}".format(
                    str(reranking_mrr), str(full_ranking_mrr)))
            dist.barrier()
    return results
def load_model(args):
    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    args.output_mode = "classification"
    label_list = ["0", "1"]
    num_labels = len(label_list)

    # store args
    if args.local_rank != -1:
        args.world_size = torch.distributed.get_world_size()
        args.rank = dist.get_rank()

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    tokenizer = configObj.tokenizer_class.from_pretrained(
        "bert-base-uncased",
        do_lower_case=True,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

    model = configObj.model_class(args)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)
    return tokenizer, model
def generate_new_ann(args, output_num, checkpoint_path,
                     training_query_positive_id, dev_query_positive_id,
                     latest_step_num):
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False)
    logger.info("***** Done passage inference *****")

    if args.inference:
        return

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training

        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(args, dev_query_embedding2id,
                                                 passage_embedding2id,
                                                 dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            # guard before the modulo so ann_chunk_factor <= 0 falls back to a single chunk
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (
            q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info("Chunked {} query from {}".format(len(query_embedding),
                                                      num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args, query_embedding2id, passage_embedding2id,
            training_query_positive_id, I, effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid)
                             for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir,
                                        "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
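# Minimal, self-contained example of the exact inner-product search used above
# (faiss.IndexFlatIP). The embeddings here are random placeholders; the real code
# feeds the streamed passage/query embeddings and relies on the same (D, I) shapes.
import faiss
import numpy as np

dim, n_passages, n_queries, k = 128, 1000, 4, 10
passages = np.random.rand(n_passages, dim).astype('float32')   # FAISS expects float32
queries = np.random.rand(n_queries, dim).astype('float32')

index = faiss.IndexFlatIP(dim)    # brute-force (exact) inner-product index
index.add(passages)
D, I = index.search(queries, k)   # D: scores [n_queries, k], I: passage row ids [n_queries, k]
print(D.shape, I.shape)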
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    # everything not covered by the named layers above goes into one final group
    optimizer_grouped_parameters.append({
        "params": [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".format(
                args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
            os.path.join(args.model_name_or_path, "scheduler.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            # model path has no numeric checkpoint suffix
            logger.info(" Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            # sync gradients only at gradient accumulation boundaries
            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval * args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_ranking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    if is_first_worker():
                        for key, value in logs.items():
                            print(key, type(value))
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
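# Sketch of the DistributedDataParallel gradient-accumulation pattern used in the
# training loops above: gradients are only all-reduced at accumulation boundaries,
# and the other iterations run inside DDP's no_sync() context. The model, batch, and
# helper name below are placeholders (assumptions), not the repo's objects.
import torch


def accumulation_step(ddp_model, batch, optimizer, step, accumulation_steps):
    """One micro-batch; syncs gradients across ranks only every `accumulation_steps` steps."""
    sync_now = (step + 1) % accumulation_steps == 0
    if sync_now:
        loss = ddp_model(*batch)[0] / accumulation_steps
        loss.backward()               # gradients all-reduced across ranks here
        optimizer.step()
        optimizer.zero_grad()
    else:
        with ddp_model.no_sync():     # skip the all-reduce, just accumulate locally
            loss = ddp_model(*batch)[0] / accumulation_steps
            loss.backward()
    return loss.detach()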
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training

        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I,
                              dev_query_embedding2id, passage_embedding2id)

        # measure ANN mrr
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I,
                                     dev_query_embedding2id_trivia,
                                     passage_embedding2id)

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k)  # I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        logger.info("Effective qid length %d, search result length %d",
                    len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(
            args, passage_text, train_answers, query_embedding2id,
            passage_embedding2id, I, train_pos_id)
        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid)
                             for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir,
                                        "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(
                {
                    'top20': top_k_hits[19],
                    'top100': top_k_hits[99],
                    'top20_trivia': top_k_hits_trivia[19],
                    'top100_trivia': top_k_hits_trivia[99],
                    'checkpoint': checkpoint_path
                }, f)
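# The ANN training file written above is tab-separated:
# "<query_id>\t<pos_pid>\t<neg_pid_1>,<neg_pid_2>,...".
# A minimal parser for one such line (illustrative only; the repo's data-processing
# functions do this plus the tokenized-cache lookup):
def parse_ann_line(line):
    qid, pos_pid, neg_pids = line.rstrip('\n').split('\t')
    return int(qid), int(pos_pid), [int(p) for p in neg_pids.split(',')]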
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)  # nll loss for query
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler, saved_state)
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)

        nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
        dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(
            args, model, passage_cache, "-trivia")
        if is_first_worker():
            tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
            tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
            tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
            tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)

    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    with open(ann_path, 'r') as f:
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d", len(ann_training_data))
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1
                    ]
                    logger.info("Filtered training data line count: %d",
                                len(ann_training_data))
                    ann_checkpoint_path = ndcg_json['checkpoint']
                    ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d", len(ann_training_data))
                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(args, query_cache,
                                                        passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # add ndcg at checkpoint step used instead of current step
                        tb_writer.add_scalar("retrieval_accuracy/top20_nq",
                                             ndcg_json['top20'], ann_checkpoint_no)
                        tb_writer.add_scalar("retrieval_accuracy/top100_nq",
                                             ndcg_json['top100'], ann_checkpoint_no)
                        if 'top20_trivia' in ndcg_json:
                            tb_writer.add_scalar("retrieval_accuracy/top20_trivia",
                                                 ndcg_json['top20_trivia'],
                                                 ann_checkpoint_no)
                            tb_writer.add_scalar("retrieval_accuracy/top100_trivia",
                                                 ndcg_json['top100_trivia'],
                                                 ann_checkpoint_no)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:
                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(args, query_cache,
                                                           passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size * 2)
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1, global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
                nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(
                    args, model, passage_cache)
                dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(
                    args, model, passage_cache, "-trivia")
                if is_first_worker():
                    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
                    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
                    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
                    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()
        tr_loss += loss.item()

        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler, global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    # everything not covered by the named layers above goes into one final group
    optimizer_grouped_parameters.append({
        "params": [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        if "-" in args.model_name_or_path:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(args, query_cache,
                                                           passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg, ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()
        tr_loss += loss.item()

        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step