def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    for layer_name in ["roberta.embeddings", "score_out", "downsample1",
                       "downsample2", "downsample3"]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})

    if len(optimizer_grouped_parameters) == 0:
        # fall back to the standard two-group (decay / no-decay) setup
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(args.optimizer))

    # Check if a saved optimizer state exists
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) \
            and args.load_optimizer_scheduler:
        # Load in the optimizer state
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                real_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        if "-" in args.model_name_or_path:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ANN training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new ANN data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                # truncate so every worker sees the same number of examples
                aligned_size = (len(ann_training_data) // args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache, passage_cache))
                train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at the checkpoint step it was measured at, instead of the current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg, ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        # assumes a first batch of ANN data is already available at startup
        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long(),
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6],
            }

        # sync gradients only at the gradient accumulation boundary
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                logs["learning_rate"] = scheduler.get_lr()[0]
                logs["loss"] = tr_loss / args.logging_steps
                tr_loss = 0
                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)
                torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s", output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
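
# --- Illustration (not part of the original source) -------------------------
# The loop above wraps non-boundary forward/backward passes in DDP's
# `no_sync()` so gradients are all-reduced only once per accumulation window.
# Below is a minimal, self-contained sketch of that pattern, assuming
# torch.distributed is initialized and `model` is wrapped in
# DistributedDataParallel; `accumulate_step` is a hypothetical helper name.
import contextlib

import torch


def accumulate_step(model, optimizer, micro_batches, accumulation_steps, max_grad_norm=1.0):
    """One optimizer update over `accumulation_steps` micro-batches,
    syncing gradients only on the final micro-batch."""
    model.zero_grad()
    for i, (x, y) in enumerate(micro_batches):
        boundary = (i + 1) % accumulation_steps == 0
        # Suppress DDP's gradient all-reduce except at the boundary step.
        ctx = contextlib.nullcontext() if boundary else model.no_sync()
        with ctx:
            loss = torch.nn.functional.mse_loss(model(x), y)
            (loss / accumulation_steps).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
# ---------------------------------------------------------------------------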
# Epoch-level training loop with simple early stopping (assumes `epochs`,
# `metrics`, `valid_loss_prev`, `lives`, `ckpt_dir`, and `optimizer` are
# defined above; `lives` counts remaining non-improving epochs and is never
# reset on improvement).
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    start_time = time.time()

    train_loss = train(args, epoch, writer, model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset, writer, epoch)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    metrics['train_losses'].append(train_loss)
    metrics['valid_losses'].append(valid_loss)
    metrics['ems'].append(em)
    metrics['f1s'].append(f1)

    if valid_loss < valid_loss_prev:
        state = {'epoch': epoch,
                 'model_state_dict': model.module.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict()}
        fname = os.path.join(ckpt_dir, 'best_weights.pt')
        torch.save(state, fname)
    else:
        lives -= 1
        if lives == 0:
            break
    valid_loss_prev = valid_loss

    with open(os.path.join(ckpt_dir, 'metrics.p'), 'wb') as f:
        pickle.dump(metrics, f)

    print(f"Epoch train loss: {train_loss} | Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")
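
# --- Illustration (not part of the original source) -------------------------
# The loop above relies on an `epoch_time` helper that is not shown. A minimal
# sketch of what such a helper typically looks like; this exact implementation
# is an assumption, not taken from the original source.
def epoch_time(start_time, end_time):
    """Split an elapsed wall-clock interval into whole (minutes, seconds)."""
    elapsed = end_time - start_time
    elapsed_mins = int(elapsed // 60)
    elapsed_secs = int(elapsed - elapsed_mins * 60)
    return elapsed_mins, elapsed_secs
# ---------------------------------------------------------------------------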
def train(args, model, tokenizer, train_dataloader):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in ["roberta.embeddings", "score_out", "downsample1",
                       "downsample2", "downsample3", "embeddingHead"]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    # catch-all group for any parameters not covered above
    optimizer_grouped_parameters.append(
        {"params": [p for p in model.parameters() if p not in layer_optim_params]})

    if len(optimizer_grouped_parameters) == 0:
        # standard two-group (decay / no-decay) fallback; note that with the
        # catch-all group appended above, this branch is effectively unreachable
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".format(args.scheduler))

    # Check if saved optimizer and scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and \
            os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) and \
            args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                real_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        for step, batch in tqdm(enumerate(train_dataloader), desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            # sync gradients only at the gradient accumulation boundary
            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval * args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {"reranking": float(reranking_mrr),
                                        "full_ranking": float(full_ranking_mrr)}
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)
                    logs["learning_rate"] = scheduler.get_lr()[0]
                    logs["loss"] = (tr_loss - logging_loss) / args.logging_steps
                    logging_loss = tr_loss
                    if is_first_worker():
                        for key, value in logs.items():
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()
    return global_step, tr_loss / global_step
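
# --- Illustration (not part of the original source) -------------------------
# The scheduler dispatch in the train() above supports "linear" and "cosine".
# A standalone sketch of the same dispatch, isolated for reuse or testing;
# `build_scheduler` is a hypothetical helper name, and the eta_min floor of
# 1e-8 mirrors the value used above.
from torch.optim.lr_scheduler import CosineAnnealingLR

from transformers import get_linear_schedule_with_warmup


def build_scheduler(name, optimizer, warmup_steps, total_steps):
    """Return a learning-rate scheduler by name."""
    name = name.lower()
    if name == "linear":
        # linear warmup to the base LR, then linear decay to zero
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    if name == "cosine":
        # cosine decay from the base LR down to eta_min over total_steps
        return CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-8)
    raise ValueError(
        "Scheduler {0} not recognized! Can only be linear or cosine".format(name))
# ---------------------------------------------------------------------------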
def train(args, train_dataset, model_d, model_g, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 \
        else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizers and schedules (linear warmup and decay),
    # one pair for the discriminator and one for the generator
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_d_grouped_parameters = [
        {
            "params": [p for n, p in model_d.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model_d.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer_d = Lamb(optimizer_d_grouped_parameters,
                       lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-6)
    scheduler_d = get_linear_schedule_with_warmup(
        optimizer_d, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    optimizer_g_grouped_parameters = [
        {
            "params": [p for n, p in model_g.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model_g.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer_g = Lamb(optimizer_g_grouped_parameters,
                       lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-6)
    scheduler_g = get_linear_schedule_with_warmup(
        optimizer_g, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer_d.pt")) and \
            os.path.isfile(os.path.join(args.model_name_or_path, "scheduler_d.pt")):
        # Load in discriminator optimizer and scheduler states
        optimizer_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer_d.pt")))
        scheduler_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler_d.pt")))
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer_g.pt")) and \
            os.path.isfile(os.path.join(args.model_name_or_path, "scheduler_g.pt")):
        # Load in generator optimizer and scheduler states
        optimizer_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer_g.pt")))
        scheduler_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler_g.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model_d, optimizer_d = amp.initialize(model_d, optimizer_d,
                                              opt_level=args.fp16_opt_level)
        model_g, optimizer_g = amp.initialize(model_g, optimizer_g,
                                              opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_d = torch.nn.DataParallel(model_d)
        model_g = torch.nn.DataParallel(model_g)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_d = torch.nn.parallel.DistributedDataParallel(
            model_d,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        model_g = torch.nn.parallel.DistributedDataParallel(
            model_g,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    # Unwrapped modules, kept in case token-embedding resizing is re-enabled
    model_to_resize_d = model_d.module if hasattr(model_d, "module") else model_d
    model_to_resize_g = model_g.module if hasattr(model_g, "module") else model_g

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_d, logging_loss_d = 0.0, 0.0
    tr_loss_g, logging_loss_g = 0.0, 0.0
    model_d.zero_grad()
    model_g.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model_d.train()
            model_g.train()
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3],
            }
            if args.model_type != "distilbert":
                # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None)

            # Mask tokens and run the generator (MLM) on the masked sequence
            masked_input_ids, mask_labels = mask_tokens(inputs['input_ids'], tokenizer, args)
            outputs_g = model_g(
                input_ids=masked_input_ids.to(args.device),
                masked_lm_labels=mask_labels.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device))
            masked_lm_loss, prediction_scores_g = outputs_g[0], outputs_g[1]
            prediction_g = prediction_scores_g.max(dim=-1)[1].cpu()
            acc_g = (prediction_g[mask_labels >= 0] ==
                     mask_labels[mask_labels >= 0]).float().mean().item()

            # Sample replacement tokens from the generator distribution
            prediction_probs_g = F.softmax(prediction_scores_g, dim=-1).cpu()
            bsz, seq_len, vocab_size = prediction_probs_g.size()
            prediction_samples_g = torch.multinomial(
                prediction_probs_g.view(-1, vocab_size), num_samples=1)
            prediction_samples_g = prediction_samples_g.view(bsz, seq_len)
            input_ids_replace = inputs['input_ids'].clone()
            input_ids_replace[mask_labels >= 0] = prediction_samples_g[mask_labels >= 0]

            # Discriminator labels: 1 = original token, 0 = replaced, -100 = ignored
            labels_d = input_ids_replace.eq(inputs['input_ids']).long()
            special_tokens_mask = [
                tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in inputs['input_ids'].tolist()
            ]
            labels_d.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool),
                                  value=-100)
            padding_mask = inputs['input_ids'].eq(tokenizer.pad_token_id)
            labels_d.masked_fill_(padding_mask, value=-100)
            labels_d_ones = labels_d[labels_d >= 0].float().mean().item()
            acc_replace = 1 - ((labels_d == 0).sum().float() /
                               (mask_labels >= 0).sum().float()).item()

            # Run the discriminator on the corrupted sequence
            outputs_d = model_d(
                input_ids=input_ids_replace.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device),
                labels=labels_d.to(args.device))
            loss_d, prediction_scores_d = outputs_d[0], outputs_d[1]
            prediction_d = prediction_scores_d.max(dim=-1)[1].cpu()
            acc_d = (prediction_d[labels_d >= 0] ==
                     labels_d[labels_d >= 0]).float().mean().item()
            acc_d_0 = (prediction_d[labels_d == 0] ==
                       labels_d[labels_d == 0]).float().mean().item()
            acc_d_1 = (prediction_d[labels_d == 1] ==
                       labels_d[labels_d == 1]).float().mean().item()

            if args.n_gpu > 1:
                loss_d = loss_d.mean()  # mean() to average on multi-gpu parallel training
                masked_lm_loss = masked_lm_loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss_d = loss_d / args.gradient_accumulation_steps
                masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps

            lambd = 50  # weight of the discriminator loss relative to the MLM loss
            loss = loss_d * lambd + masked_lm_loss
            if args.fp16:
                loss_d = loss_d * lambd
                with amp.scale_loss(loss_d, optimizer_d) as scaled_loss_d:
                    scaled_loss_d.backward()
                with amp.scale_loss(masked_lm_loss, optimizer_g) as scaled_loss_g:
                    scaled_loss_g.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_loss_d += loss_d.item()
            tr_loss_g += masked_lm_loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_d),
                                                   args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_g),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_d.parameters(), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_g.parameters(), args.max_grad_norm)
                optimizer_d.step()
                scheduler_d.step()  # Update learning rate schedule
                model_d.zero_grad()
                optimizer_g.step()
                scheduler_g.step()  # Update learning rate schedule
                model_g.zero_grad()

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and \
                        global_step % args.logging_steps == 0:
                    logs = {}
                    logs["learning_rate_d"] = scheduler_d.get_lr()[0]
                    logs["learning_rate_g"] = scheduler_g.get_lr()[0]
                    logs["loss"] = (tr_loss - logging_loss) / args.logging_steps
                    logs["loss_d"] = (tr_loss_d - logging_loss_d) / args.logging_steps
                    logs["loss_g"] = (tr_loss_g - logging_loss_g) / args.logging_steps
                    logs["acc_replace"] = acc_replace
                    logs["acc_d"] = acc_d
                    logs["acc_d_0"] = acc_d_0
                    logs["acc_d_1"] = acc_d_1
                    logs["acc_g"] = acc_g
                    logs["labels_d_ones"] = labels_d_ones
                    logs["masked_ratio"] = (mask_labels >= 0).float().sum().item() / \
                        (labels_d >= 0).sum().float().item()
                    logging_loss = tr_loss
                    logging_loss_d = tr_loss_d
                    logging_loss_g = tr_loss_g
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    output_dir_d = os.path.join(output_dir,
                                                "checkpoint-d-{}".format(global_step))
                    output_dir_g = os.path.join(output_dir,
                                                "checkpoint-g-{}".format(global_step))
                    if not os.path.exists(output_dir_d):
                        os.makedirs(output_dir_d)
                    if not os.path.exists(output_dir_g):
                        os.makedirs(output_dir_g)
                    model_to_save_d = (
                        model_d.module if hasattr(model_d, "module") else model_d
                    )  # Take care of distributed/parallel training
                    model_to_save_g = (
                        model_g.module if hasattr(model_g, "module") else model_g)
                    model_to_save_d.save_pretrained(output_dir_d)
                    model_to_save_g.save_pretrained(output_dir_g)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer_d.state_dict(),
                               os.path.join(output_dir_d, "optimizer_d.pt"))
                    torch.save(scheduler_d.state_dict(),
                               os.path.join(output_dir_d, "scheduler_d.pt"))
                    # generator states go to the generator checkpoint directory
                    torch.save(optimizer_g.state_dict(),
                               os.path.join(output_dir_g, "optimizer_g.pt"))
                    torch.save(scheduler_g.state_dict(),
                               os.path.join(output_dir_g, "scheduler_g.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
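
# --- Illustration (not part of the original source) -------------------------
# The discriminator labels above follow ELECTRA-style replaced-token
# detection: sample tokens from the generator at masked positions, then label
# each position 1 if the token still matches the original input and 0 if it
# was replaced, with -100 marking special/padding positions excluded from the
# loss. A minimal sketch of just that labeling step; `build_rtd_labels` is a
# hypothetical helper name, not from the source.
import torch


def build_rtd_labels(input_ids, sampled_ids, mask_positions, special_mask, pad_mask):
    """Return (corrupted_ids, labels): 1 = token kept, 0 = replaced,
    -100 = ignored (special or padding position)."""
    corrupted = input_ids.clone()
    corrupted[mask_positions] = sampled_ids[mask_positions]
    labels = corrupted.eq(input_ids).long()
    labels.masked_fill_(special_mask, -100)  # drop [CLS]/[SEP]-style tokens
    labels.masked_fill_(pad_mask, -100)      # drop padding
    return corrupted, labels

# With mask_positions = (mask_labels >= 0), this reproduces the labels_d
# construction in the training loop above.
# ---------------------------------------------------------------------------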