def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    """Embed the records streamed from ``f`` and merge the per-process results.

    Args:
        args: Run options; ``per_gpu_eval_batch_size`` and ``local_rank`` are
            read here and ``args`` is forwarded to the helpers.
        model: Encoder handed to ``InferenceEmbeddingFromStreamDataLoader``.
        fn: Per-record processing function given to ``StreamingDataset``.
        prefix: Name prefix; ``"_emb_p_"`` / ``"_embid_p_"`` are appended for
            the two merge calls.
        f: Source of records given to ``StreamingDataset``.
        is_query_inference: Forwarded unchanged to the inference helper.

    Returns:
        ``(full_embedding, full_embedding2id)`` as returned by
        ``barrier_array_merge``.
    """
    # Per-process batch size; the multi-GPU scaling factor is not applied.
    per_process_batch = args.per_gpu_eval_batch_size
    loader = DataLoader(StreamingDataset(f, fn), batch_size=per_process_batch)

    running_distributed = args.local_rank != -1
    if running_distributed:
        dist.barrier()  # directory created

    local_emb, local_ids = InferenceEmbeddingFromStreamDataLoader(
        args,
        model,
        loader,
        is_query_inference=is_query_inference,
        prefix=prefix,
    )

    logger.info("merging embeddings")

    # preserve to memory
    merge_options = {"load_cache": False, "only_load_in_master": True}
    full_embedding = barrier_array_merge(
        args, local_emb, prefix=prefix + "_emb_p_", **merge_options)
    full_embedding2id = barrier_array_merge(
        args, local_ids, prefix=prefix + "_embid_p_", **merge_options)
    return full_embedding, full_embedding2id
def embedding_inference(args, path, model, fn, bz, num_workers=2, is_query=True):
    """Encode every record of the file at ``path`` and return embeddings and ids.

    Args:
        args: Run options; ``args.device`` and ``args.local_rank`` are read.
        path: Path of a UTF-8 text file whose records are streamed through
            ``fn``.
        model: Encoder exposing ``query_emb`` and ``body_emb``. If it has a
            ``module`` attribute (parallel wrapper) the wrapped model is used.
        fn: Per-record processing function handed to ``StreamingDataset``.
        bz: Batch size of the data loader.
        num_workers: Accepted for interface compatibility but currently
            unused; the loader always runs with ``num_workers=0``.
        is_query: Encode with ``model.query_emb`` when true, otherwise with
            ``model.body_emb``.

    Returns:
        ``(emb_arr, id_arr)``: NumPy arrays obtained by concatenating the
        per-batch embeddings and ids along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` when the file yields no batch.
    """
    model = model.module if hasattr(model, "module") else model
    emb_list, id_list = [], []

    # The context manager closes the file even if inference raises midway.
    with open(path, encoding="utf-8") as f:
        sds = StreamingDataset(f, fn)
        loader = DataLoader(sds, batch_size=bz, num_workers=0)
        model.eval()
        for _, batch in tqdm(enumerate(loader),
                             desc="Eval",
                             disable=args.local_rank not in [-1, 0]):
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0].long(),
                    "attention_mask": batch[1].long(),
                }
                idx = batch[3].long()
                if is_query:
                    embs = model.query_emb(**inputs)
                else:
                    embs = model.body_emb(**inputs)

                if len(embs.shape) == 3:
                    # Several embeddings per record: flatten to
                    # [b1c1, b1c2, b1c3, b1c4, b2c1 ....] and repeat each id
                    # once per embedding so rows stay aligned.
                    B, C, _ = embs.shape
                    embs = embs.view(B * C, -1)
                    idx = idx.repeat_interleave(C)

                assert embs.shape[0] == idx.shape[0]
                emb_list.append(embs.detach().cpu().numpy())
                id_list.append(idx.detach().cpu().numpy())

    emb_arr = np.concatenate(emb_list, axis=0)
    id_arr = np.concatenate(id_list, axis=0)
    return emb_arr, id_arr
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True, load_cache=False):
    """Embed the records streamed from ``f``; optionally merge them into memory.

    Two modes are visible in the code:

    * ``args.emb_file_multi_split_num > 0`` and ``"passage"`` in ``prefix``:
      inference is run and its return values are discarded;
      ``(None, None)`` is returned.
    * otherwise: inference is run (skipped when ``load_cache`` is true) and
      the two results are passed through ``barrier_array_merge``.

    Args:
        args: Run options forwarded to the helpers.
        model: Encoder handed to ``InferenceEmbeddingFromStreamDataLoader``.
        fn: Per-record processing function given to ``StreamingDataset``.
        prefix: Name prefix used by the helpers and for mode selection.
        f: Source of records given to ``StreamingDataset``.
        is_query_inference: Forwarded unchanged to the inference helper.
        load_cache: When true, skip inference and pass ``None`` arrays plus
            ``load_cache=True`` to ``barrier_array_merge``.

    Returns:
        ``(full_embedding, full_embedding2id)``; both ``None`` in split mode.
    """
    # Per-process batch size; the multi-GPU scaling factor is not applied.
    per_process_batch = args.per_gpu_eval_batch_size
    loader = DataLoader(StreamingDataset(f, fn), batch_size=per_process_batch)

    if args.local_rank != -1:
        dist.barrier()  # directory created

    if (args.emb_file_multi_split_num > 0) and ("passage" in prefix):
        # extra handling the memory problem by specifying the size of file
        _discarded_emb, _discarded_ids = InferenceEmbeddingFromStreamDataLoader(
            args,
            model,
            loader,
            is_query_inference=is_query_inference,
            prefix=prefix,
        )
        # TODO: loading ids for first_worker()
        return None, None

    if load_cache:
        local_emb, local_ids = None, None
    else:
        local_emb, local_ids = InferenceEmbeddingFromStreamDataLoader(
            args,
            model,
            loader,
            is_query_inference=is_query_inference,
            prefix=prefix,
        )

    not_loading = args.split_ann_search and ("passage" in prefix)

    # preserve to memory
    full_embedding = barrier_array_merge(
        args,
        local_emb,
        prefix=prefix + "_emb_p_",
        load_cache=load_cache,
        only_load_in_master=True,
        not_loading=not_loading,
    )
    del local_emb  # drop the local reference before the second merge

    full_embedding2id = barrier_array_merge(
        args,
        local_ids,
        prefix=prefix + "_embid_p_",
        load_cache=load_cache,
        only_load_in_master=True,
        not_loading=not_loading,
    )
    logger.info(
        f"finish saving embbedding of {prefix}, not loading into MEM: {not_loading}"
    )
    del local_ids

    return full_embedding, full_embedding2id
def main():
    """Entry point: optionally train, then save a checkpoint and evaluate.

    Returns:
        Whatever ``evaluation(args, model, tokenizer)`` returns.
    """
    args = get_arguments()
    set_env(args)
    config, tokenizer, model, configObj = load_stuff(args.train_model_type, args)

    # Training
    if args.do_train:
        logger.info("Training/evaluation parameters %s", args)

        def _process_line(line, i):
            # Bind the tokenizer and args so the dataset only supplies
            # the raw line and its index.
            return configObj.process_fn(line, i, tokenizer, args)

        training_file = os.path.join(args.data_dir, args.train_file)
        with open(training_file, encoding="utf-8-sig") as f:
            gpu_count = max(1, args.n_gpu)
            batch_size = args.per_gpu_train_batch_size * gpu_count
            dataset = StreamingDataset(f, _process_line)
            loader = DataLoader(dataset, batch_size=batch_size, num_workers=1)
            global_step, tr_loss = train(args, model, tokenizer, loader)
            logger.info(" global_step = %s, average loss = %s",
                        global_step, tr_loss)

    save_checkpoint(args, model, tokenizer)
    return evaluation(args, model, tokenizer)
def evaluate_dev(args, model, passage_cache, source=""):
    """Compute the mean loss and correct-prediction ratio on a dev set.

    The files ``dev-query{source}`` and ``dev-data{source}`` under
    ``args.data_dir`` are used. The model is switched to eval mode for the
    pass and back to train mode before returning.

    Args:
        args: Run options; ``data_dir`` and ``train_batch_size`` are read.
        model: Model passed to ``do_biencoder_fwd_pass``.
        passage_cache: Passage cache given to the data-processing function.
        source: Suffix appended to both dev file names.

    Returns:
        ``(total_loss, correct_ratio)``: the loss averaged over batches and
        correct predictions divided by
        ``batches * args.train_batch_size * world_size``.
    """
    query_cache = EmbeddingCache(
        os.path.join(args.data_dir, "dev-query{}".format(source)))

    logger.info('NLL validation ...')
    model.eval()

    log_every = 100
    batch_count = 0
    loss_sum = 0.0
    correct_sum = 0

    with query_cache:
        data_file = os.path.join(args.data_dir, "dev-data{}".format(source))
        with open(data_file, 'r') as f:
            dev_lines = f.readlines()

        process_fn = GetTrainingDataProcessingFn(
            args, query_cache, passage_cache, shuffle=False)
        loader = DataLoader(StreamingDataset(dev_lines, process_fn),
                            batch_size=args.train_batch_size * 2)

        for seen, batch in enumerate(loader, start=1):
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)
            loss.backward()  # get CUDA oom without this
            model.zero_grad()
            loss_sum += loss.item()
            correct_sum += correct_cnt
            batch_count = seen
            if seen % log_every == 0:
                logger.info('Eval step: %d , loss=%f ', seen - 1, loss.item())

    mean_loss = loss_sum / batch_count
    world_size = torch.distributed.get_world_size()
    total_samples = batch_count * args.train_batch_size * world_size
    correct_ratio = float(correct_sum / total_samples)
    logger.info(
        'NLL Validation: loss = %f. correct prediction ratio  %d/%d ~  %f',
        mean_loss, correct_sum, total_samples, correct_ratio)

    model.train()
    return mean_loss, correct_ratio
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    """Embed the records streamed from ``f`` and merge the per-process results.

    The embedding merge receives ``not_loading`` (true when
    ``args.split_ann_search`` is set and ``"passage"`` occurs in ``prefix``);
    the id merge is called without it.

    Args:
        args: Run options forwarded to the helpers.
        model: Encoder handed to ``InferenceEmbeddingFromStreamDataLoader``.
        fn: Per-record processing function given to ``StreamingDataset``.
        prefix: Name prefix; ``"_emb_p_"`` / ``"_embid_p_"`` are appended for
            the two merge calls.
        f: Source of records given to ``StreamingDataset``.
        is_query_inference: Forwarded unchanged to the inference helper.

    Returns:
        ``(full_embedding, full_embedding2id)`` as returned by
        ``barrier_array_merge``.
    """
    # Per-process batch size; the multi-GPU scaling factor is not applied.
    per_process_batch = args.per_gpu_eval_batch_size
    loader = DataLoader(StreamingDataset(f, fn), batch_size=per_process_batch)

    if args.local_rank != -1:
        dist.barrier()  # directory created

    local_emb, local_ids = InferenceEmbeddingFromStreamDataLoader(
        args,
        model,
        loader,
        is_query_inference=is_query_inference,
        prefix=prefix,
    )

    logger.info("merging embeddings")
    not_loading = args.split_ann_search and ("passage" in prefix)

    # preserve to memory
    full_embedding = barrier_array_merge(
        args,
        local_emb,
        prefix=prefix + "_emb_p_",
        load_cache=False,
        only_load_in_master=True,
        not_loading=not_loading,
    )
    del local_emb  # drop the local reference before merging the ids
    logger.info(
        f"finish saving embbedding of {prefix}, not loading into MEM: {not_loading}"
    )

    full_embedding2id = barrier_array_merge(
        args,
        local_ids,
        prefix=prefix + "_embid_p_",
        load_cache=False,
        only_load_in_master=True,
    )
    return full_embedding, full_embedding2id
def train(args, model, tokenizer, query_cache, passage_cache):
    """Train ``model`` on ANN-generated data until ``args.max_steps`` steps.

    The function builds a per-layer-group optimizer (Lamb or AdamW),
    optionally restores the optimizer state, wraps the model for fp16 /
    DataParallel / DistributedDataParallel, and then loops: it periodically
    asks ``get_latest_ann_data(args.ann_dir)`` for a newer training file,
    rebuilds the data loader when one appears, and runs forward/backward
    passes with gradient accumulation. Checkpoints are written by the first
    worker every ``args.save_steps`` optimizer steps.

    Args:
        args: Run options. ``args.train_batch_size`` is set by this function.
        model: Model called as ``model(**inputs)``; its first output is the
            loss.
        tokenizer: Saved next to every model checkpoint.
        query_cache: Passed to the training-data processing function.
        passage_cache: Passed to the training-data processing function.

    Returns:
        The number of optimizer steps performed (``global_step``).

    Raises:
        Exception: If ``args.optimizer`` is neither ``lamb`` nor ``adamw``.
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Effective batch size over processes and accumulation steps; it is
    # computed here but not read again in this function.
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # One optimizer parameter group per named sub-module that exists on the
    # model, one per encoder layer, and a final group for everything else.
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    # (only the optimizer state is actually restored here).
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model
        # path (the text after the last "-" and before the next "/")
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    save_no = 0

    # With single_warmup the schedule is created once; otherwise it is
    # (re)created each time a new ANN file is picked up below.
    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        # Only look for new data on an accumulation boundary, on a logging
        # step, and within the first 1/20 of a save interval.
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0 and global_step % args.save_steps < args.save_steps / 20:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                # Truncate so the line count is a multiple of world_size.
                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        # NOTE(review): if no ANN file has been found yet,
        # train_dataloader_iter is still None here — confirm callers
        # guarantee data exists before training starts.
        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        # Map positional batch tensors to the model's keyword arguments.
        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                # tr_loss is reset after every log, so this is the mean
                # over the last logging window.
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

                # From the second checkpoint on, block until a newer ANN
                # file than the one being trained on is reported.
                # NOTE(review): nesting of this wait under the first-worker
                # save branch is inferred — confirm against the original.
                save_no += 1
                if save_no > 1:
                    ann_no, ann_path, ndcg_json = get_latest_ann_data(
                        args.ann_dir)
                    while (ann_no == last_ann_no):
                        print("Waiting for new ann_data. Sleeping for 1hr!!")
                        time.sleep(3600)
                        ann_no, ann_path, ndcg_json = get_latest_ann_data(
                            args.ann_dir)

            # NOTE(review): unlike the barrier above, this one is not
            # guarded by ``args.local_rank != -1`` — confirm it is safe in
            # non-distributed runs.
            dist.barrier()

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
def train(args, model, tokenizer, query_cache, passage_cache):
    """Train ``model`` until ``args.max_steps`` steps or ``args.num_epoch`` epochs.

    Two data modes are visible in the code:

    * ``args.num_epoch == 0``: the newest ANN training file reported by
      ``get_latest_ann_data(args.ann_dir)`` is polled and loaded whenever
      its number changes.
    * ``args.num_epoch != 0``: ``train-data`` under ``args.data_dir`` is
      loaded once at step 0; each time the loader is exhausted the epoch
      counter advances and the latest ANN file (if any) replaces the data.

    Args:
        args: Run options. ``args.train_batch_size`` is set by this function.
        model: Model passed to ``triplet_fwd_pass`` / ``do_biencoder_fwd_pass``.
        tokenizer: Not used in this function.
        query_cache: Passed to the training-data processing function.
        passage_cache: Passed to the training-data processing function.

    Returns:
        The number of optimizer steps performed (``global_step``).

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(
        1, args.n_gpu)  # nll loss for query
    # Effective batch size over processes and accumulation steps; it is
    # computed here but not read again in this function.
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0  # number of completed passes over the data (epoch mode)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    # Any model path other than the stock name is treated as a checkpoint
    # whose model/optimizer/scheduler state should be restored.
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)
        # (Dev NLL evaluation via evaluate_dev is currently disabled here.)

    print(args.num_epoch)
    print(step, args.max_steps, global_step)
    # NOTE(review): this discards the global_step restored from the
    # checkpoint above, so the step counter always restarts at 0 — confirm
    # this is intended.
    global_step = 0

    while global_step < args.max_steps:

        # Data (re)loading is only considered on an accumulation boundary
        # that is also a logging step.
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:

            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    # NOTE(review): fixed 3-minute pause before reading;
                    # presumably waits for the file to be fully written —
                    # confirm.
                    time.sleep(180)
                    with open(ann_path, 'r') as f:
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d",
                                len(ann_training_data))
                    # Keep only lines whose third tab-separated field has
                    # more than one comma-separated entry.
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1
                    ]
                    logger.info("Filtered training data line count: %d",
                                len(ann_training_data))

                    # Truncate so the line count is a multiple of world_size.
                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset,
                            batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # (Retrieval-accuracy scalars are currently
                        # disabled; only the epoch marker is written.)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no

            elif step == 0:
                # Epoch mode: load the fixed training file once.
                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size * 2)
                # Full pass over the loader just to count the batches.
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
            # (Dev NLL evaluation via evaluate_dev is currently disabled
            # here.)

            # Switch to the latest ANN file, if one exists, before
            # restarting the iterator.
            # NOTE(review): this reload always uses the non-triplet
            # processing function and batch size, regardless of
            # args.triplet — confirm.
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                with open(ann_path, 'r') as f:
                    print(ann_path)
                    ann_training_data = f.readlines()
                logger.info("Training data line count: %d",
                            len(ann_training_data))
                ann_training_data = [
                    l for l in ann_training_data
                    if len(l.split('\t')[2].split(',')) > 1
                ]
                logger.info("Filtered training data line count: %d",
                            len(ann_training_data))

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                train_dataset = StreamingDataset(
                    ann_training_data,
                    GetTrainingDataProcessingFn(args, query_cache,
                                                passage_cache))
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size * 2)

            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            # NOTE(review): not guarded by ``args.local_rank != -1`` —
            # confirm it is safe in non-distributed runs.
            dist.barrier()

        # Epoch mode: stop once more than num_epoch passes have completed.
        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # sync gradients only at gradient accumulation step
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                # tr_loss is reset after every log, so this is the mean
                # over the last logging window.
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
def train(args, model, tokenizer, f, train_fn):
    """Train ``model`` on examples streamed from the open file ``f``.

    Each epoch rewinds ``f`` and streams it through ``train_fn``. The
    optimizer uses one parameter group per named layer (Lamb or AdamW) with
    a linear-warmup or cosine schedule. Gradients are accumulated over
    ``args.gradient_accumulation_steps`` batches and only synchronised on
    the step that applies them. The first worker writes checkpoints every
    ``args.save_steps`` optimizer steps and TensorBoard scalars every
    ``args.logging_steps`` steps.

    Args:
        args: Run options. ``args.train_batch_size`` is set by this function.
        model: Model called as ``model(*batch)``; its first two outputs are
            read as loss and accuracy.
        tokenizer: Saved with non-fairseq checkpoints and passed to
            ``passage_dist_eval``.
        f: Seekable open file holding the training records.
        train_fn: Per-record processing function for ``StreamingDataset``.

    Returns:
        ``(global_step, average_loss)``: ``average_loss`` is the summed loss
        divided by ``global_step`` (``0.0`` when no step was performed).

    Raises:
        Exception: If ``args.optimizer`` or ``args.scheduler`` is not one of
            the recognised names.
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    # layerwise optimization for lamb: one parameter group per named
    # sub-module present on the model, one per encoder layer, and a final
    # group holding every remaining parameter.
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(
            scheduler_path) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model
        # path (the text after the last "-" and before the next "/")
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            steps_per_epoch = (args.expected_train_size //
                               args.gradient_accumulation_steps)
            epochs_trained = global_step // steps_per_epoch
            steps_trained_in_current_epoch = global_step % steps_per_epoch

            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d",
                        global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except (ValueError, ZeroDivisionError):
            # The path carries no usable step number: treat it as a plain
            # pretrained model rather than a training checkpoint.
            logger.info(" Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    tr_acc, logging_acc = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    for m_epoch in train_iterator:
        f.seek(0)  # restart the stream for every epoch
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):

            # Skip past any already trained steps if resuming training
            if not args.reset_iter:
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            # sync gradients only at gradient accumulation step
            sync_step = (step + 1) % args.gradient_accumulation_steps == 0
            if sync_step:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            acc = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
                acc = acc.float().mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if sync_step:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            tr_acc += acc.item()
            if sync_step:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker(
                ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if 'fairseq' not in args.train_model_type:
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                    else:
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, 'model.pt'))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                # All ranks wait here so nobody races ahead of the save.
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval *
                            args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_raking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    # Means over the window since the previous log.
                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    acc_scalar = (tr_acc - logging_acc) / args.logging_steps
                    logs["acc"] = acc_scalar
                    logging_acc = tr_acc

                    if is_first_worker():
                        for key, value in logs.items():
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break

        # Leaving the batch loop is not enough: without this check the next
        # epoch would start and train one more batch past max_steps.
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    # Close the writer wherever it was created (first worker only).
    if tb_writer is not None:
        tb_writer.close()

    average_loss = tr_loss / global_step if global_step else 0.0
    return global_step, average_loss
def train(args, model, tokenizer, f, train_fn):
    """Set up training state, then run a forward-only statistics pass.

    Despite its name, this variant never back-propagates or steps the
    optimizer: every batch is evaluated under ``torch.no_grad()`` and only
    running sums of the first two model outputs are collected.  Once 1024
    batches have been seen within one epoch the averages are printed and the
    function returns early.

    Args:
        args: Run configuration.  ``args.train_batch_size`` is overwritten
            here.
        model: Model to evaluate; it is wrapped with apex AMP,
            ``DataParallel`` and/or ``DistributedDataParallel`` depending on
            ``args``.
        tokenizer: Unused in this function.
        f: Seekable handle; rewound with ``f.seek(0)`` at the start of every
            epoch and handed to ``StreamingDataset``.
        train_fn: Second argument of ``StreamingDataset`` (presumably the
            per-record processing function -- confirm against its
            definition).

    Returns:
        ``None`` when the 1024-batch statistics were printed.  Otherwise the
        tuple ``(global_step, tr_loss / global_step)``, with the divisor
        clamped to 1 so that a run with no optimisation steps does not raise
        ``ZeroDivisionError``.
    """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs
    print('????t_total', t_total)

    # layerwise optimization for lamb: one parameter group per named layer,
    # one per encoder layer, and a final group with everything left over.
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Restore optimizer and scheduler state only when both files exist and
    # the caller asked for it.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(
            scheduler_path) and args.load_optimizer_scheduler:
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            steps_per_epoch = (args.expected_train_size //
                               args.gradient_accumulation_steps)
            epochs_trained = global_step // steps_per_epoch
            steps_trained_in_current_epoch = global_step % steps_per_epoch
            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d",
                        global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except (ValueError, TypeError, ZeroDivisionError):
            # The path does not end in a usable step number (or the epoch
            # size is unusable): treat it as a plain pretrained model.
            # Deliberately narrower than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            logger.info(" Start training from a pretrained model")

    tr_loss = 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    try:
        for _ in train_iterator:
            f.seek(0)
            sds = StreamingDataset(f, train_fn)
            epoch_iterator = DataLoader(
                sds,
                batch_size=args.per_gpu_train_batch_size,
                num_workers=1)
            count = 0
            avg_cls_norm = 0
            loss_avg = 0
            for _, batch in tqdm(enumerate(epoch_iterator),
                                 desc="Iteration",
                                 disable=args.local_rank not in [-1, 0]):
                model.train()
                batch = tuple(t.to(args.device).long() for t in batch)
                # Forward pass only: no gradients are built or applied.
                with torch.no_grad():
                    outputs = model(*batch)
                cls_norm = outputs[1]
                loss = outputs[0]
                count += 1
                avg_cls_norm += float(cls_norm.cpu().data)
                loss_avg += float(loss.cpu().data)
                print(
                    "SEED-Encoder norm: ",
                    cls_norm,
                )
                if count == 1024:
                    # Enough batches sampled: report the averages and stop.
                    print('avg_cls_sim: ', float(avg_cls_norm) / count)
                    print('avg_loss: ', float(loss_avg) / count)
                    return
            # NOTE(review): the source formatting was lost; this check is
            # assumed to sit at epoch level because it closes train_iterator.
            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break
    finally:
        # Close the writer on every exit path, including the early return.
        if tb_writer is not None:
            tb_writer.close()

    # global_step is never advanced in this function, so guard the division.
    return global_step, tr_loss / max(global_step, 1)
def train(args, model, tokenizer, query_cache, passage_cache):
    """Train the model on periodically refreshed ANN training data.

    The loop runs until ``global_step`` reaches ``args.max_steps``.  At
    logging boundaries it asks ``get_latest_ann_data`` for a newer ANN
    training file; when one appears the data loader is rebuilt on it (and,
    unless ``args.single_warmup`` is set, the learning-rate schedule is
    restarted).  Losses are accumulated per component and written to
    TensorBoard / the logger every ``args.logging_steps`` optimisation
    steps, and checkpoints are saved every ``args.save_steps`` steps by the
    first worker.

    Args:
        args: Run configuration.  ``args.train_batch_size`` is overwritten
            and ``args.dual_loss_weight`` is forced to 0.0 when
            ``args.dual_training`` is false.
        model: Model to train; it is moved to ``args.device`` and wrapped
            with apex AMP, ``DataParallel`` and/or
            ``DistributedDataParallel`` depending on ``args``.
        tokenizer: Saved next to every model checkpoint.
        query_cache: Passed to the training-data processing-function
            factories.
        passage_cache: Passed to the training-data processing-function
            factories.

    Returns:
        The final ``global_step``.

    Raises:
        RuntimeError: If no ANN training data can be loaded before the
            first batch is requested.
    """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    # One parameter group per named layer, one per encoder layer, and a
    # final group with everything left over.
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    def optimizer_to(optim, device):
        """Move every tensor held in the optimizer state to ``device``."""
        for param in optim.state.values():
            # Not sure there are any global tensors in the state dict
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(
                                device)
        # NOTE(review): the source formatting was lost; this call is assumed
        # to belong to the helper rather than to train() itself.
        torch.cuda.empty_cache()

    def incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss):
        """Add one step's loss components and total to the running sums."""
        for k in loss_item_dict:
            tr_loss_dict[k] = tr_loss_dict.get(k, 0) + loss_item_dict[k].item()
        tr_loss_dict["loss_total"] = (tr_loss_dict.get("loss_total", 0) +
                                      total_loss.item())
        return tr_loss_dict

    # Check if saved optimizer or scheduler states exist
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    if os.path.isfile(optimizer_path) and args.load_optimizer_scheduler:
        # Load on CPU first, then move the state to the training device.
        optimizer.load_state_dict(
            torch.load(optimizer_path, map_location='cpu'))
        # NOTE(review): assumed to be part of the load branch; the call is a
        # no-op on an optimizer without state.
        optimizer_to(optimizer, args.device)
    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model
        # path
        if "-" in args.model_name_or_path:
            try:
                global_step = int(
                    args.model_name_or_path.split("-")[-1].split("/")[0])
            except ValueError:
                # The path contains "-" but does not end in a step number
                # (e.g. ".../my-exp/model"): start counting from zero.
                global_step = 0
        else:
            global_step = 0
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)

    is_hypersphere_training = (args.hyper_align_weight > 0
                               or args.hyper_unif_weight > 0)
    if is_hypersphere_training:
        logger.info(
            f"training with hypersphere property regularization, align weight {args.hyper_align_weight}, unif weight {args.hyper_unif_weight}"
        )
    if not args.dual_training:
        args.dual_loss_weight = 0.0

    tr_loss_dict = {}
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)
        # NOTE(review): nested under single_warmup because `scheduler` only
        # exists on this path at this point.
        scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
        if os.path.isfile(scheduler_path) and args.load_optimizer_scheduler:
            # Load in optimizer and scheduler states
            scheduler.load_state_dict(torch.load(scheduler_path))

    while global_step < args.max_steps:
        # Look for fresh ANN data at logging boundaries, and also whenever
        # nothing has been loaded yet (e.g. when resuming from a step that
        # is not a multiple of logging_steps).
        at_refresh_point = (step % args.gradient_accumulation_steps == 0
                            and global_step % args.logging_steps == 0)
        if train_dataloader_iter is None or at_refresh_point:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(
                args.ann_dir, is_grouped=(args.grouping_ann_data > 0))
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                time.sleep(30)  # wait until transmission finished
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                logger.info(f"loading:\n{ndcg_json}")
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                # Drop the tail so the data splits evenly across workers.
                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                # grouping_ann_data <= 0 means "not grouped" everywhere else
                # in this function; use the same test here so that a value
                # of 0 does not produce a zero-length schedule.
                if args.grouping_ann_data > 0:
                    total_ann_queries = (len(ann_training_data) *
                                         args.grouping_ann_data)
                else:
                    total_ann_queries = len(ann_training_data)
                logger.info("Total ann queries: %d", total_ann_queries)

                if args.grouping_ann_data > 0:
                    if args.polling_loaded_data_batch_from_group:
                        processing_fn = GetGroupedTrainingDataProcessingFn_polling(
                            args, query_cache, passage_cache)
                    else:
                        processing_fn = GetGroupedTrainingDataProcessingFn_origin(
                            args, query_cache, passage_cache)
                elif args.dual_training:
                    # return quadruplet
                    processing_fn = GetQuadrapuletTrainingDataProcessingFn(
                        args, query_cache, passage_cache)
                elif args.triplet:
                    processing_fn = GetTripletTrainingDataProcessingFn(
                        args, query_cache, passage_cache)
                else:
                    processing_fn = GetTrainingDataProcessingFn(
                        args, query_cache, passage_cache)
                train_dataset = StreamingDataset(ann_training_data,
                                                 processing_fn)
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=total_ann_queries)

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    for key in ndcg_json:
                        if "marcodev" in key:
                            tb_writer.add_scalar(key, ndcg_json[key],
                                                 ann_checkpoint_no)
                    if 'trec2019_ndcg' in ndcg_json:
                        tb_writer.add_scalar("trec2019_ndcg",
                                             ndcg_json['trec2019_ndcg'],
                                             ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        if train_dataloader_iter is None:
            # Fail with a clear message instead of `next(None)`.
            raise RuntimeError(
                "No ANN training data could be loaded from {}".format(
                    args.ann_dir))

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        # original way
        if args.grouping_ann_data <= 0:
            batch = tuple(t.to(args.device) for t in batch)
            if args.triplet:
                inputs = {
                    "query_ids": batch[0].long(),
                    "attention_mask_q": batch[1].long(),
                    "input_ids_a": batch[3].long(),
                    "attention_mask_a": batch[4].long(),
                    "input_ids_b": batch[6].long(),
                    "attention_mask_b": batch[7].long()
                }
                # NOTE(review): the two loss weights are assumed to belong
                # to the dual-training branch -- confirm.
                if args.dual_training:
                    inputs["neg_query_ids"] = batch[9].long()
                    inputs["attention_mask_neg_query"] = batch[10].long()
                    inputs["prime_loss_weight"] = args.prime_loss_weight
                    inputs["dual_loss_weight"] = args.dual_loss_weight
            else:
                inputs = {
                    "input_ids_a": batch[0].long(),
                    "attention_mask_a": batch[1].long(),
                    "input_ids_b": batch[3].long(),
                    "attention_mask_b": batch[4].long(),
                    "labels": batch[6]
                }
        else:
            # the default collate_fn will convert item["q_pos"] into batch
            # format ... I guess
            inputs = {
                "query_ids": batch["q_pos"][0].to(args.device).long(),
                "attention_mask_q": batch["q_pos"][1].to(args.device).long(),
                "input_ids_a": batch["d_pos"][0].to(args.device).long(),
                "attention_mask_a": batch["d_pos"][1].to(args.device).long(),
                "input_ids_b": batch["d_neg"][0].to(args.device).long(),
                "attention_mask_b": batch["d_neg"][1].to(args.device).long(),
            }
            if args.dual_training:
                inputs["neg_query_ids"] = batch["q_neg"][0].to(
                    args.device).long()
                inputs["attention_mask_neg_query"] = batch["q_neg"][1].to(
                    args.device).long()
                inputs["prime_loss_weight"] = args.prime_loss_weight
                inputs["dual_loss_weight"] = args.dual_loss_weight

        # NOTE(review): assumed to apply to every batch layout, not only the
        # grouped one -- confirm.
        inputs["temperature"] = args.temperature
        inputs["loss_objective"] = args.loss_objective_function
        if is_hypersphere_training:
            inputs["alignment_weight"] = args.hyper_align_weight
            inputs["uniformity_weight"] = args.hyper_unif_weight

        step += 1
        is_sync_step = step % args.gradient_accumulation_steps == 0

        if args.local_rank != -1 and not is_sync_step:
            # sync gradients only at gradient accumulation step
            with model.no_sync():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]
        loss_item_dict = outputs[1]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k].mean()
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
            for k in loss_item_dict:
                loss_item_dict[k] = (loss_item_dict[k] /
                                     args.gradient_accumulation_steps)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        elif args.local_rank != -1 and not is_sync_step:
            with model.no_sync():
                loss.backward()
        else:
            loss.backward()

        tr_loss_dict = incremental_tr_loss(tr_loss_dict,
                                           loss_item_dict,
                                           total_loss=loss)

        if is_sync_step:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                for k in tr_loss_dict:
                    logs[k] = tr_loss_dict[k] / args.logging_steps
                # Start a fresh accumulation window.
                tr_loss_dict = {}
                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                os.makedirs(output_dir, exist_ok=True)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir,
                                              "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)
                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    # Close exactly the writer this process opened.
    if tb_writer is not None:
        tb_writer.close()

    return global_step