def test_constant_scheduler(self):
    scheduler = get_constant_schedule(self.optimizer)
    lrs = unwrap_schedule(scheduler, self.num_steps)
    expected_learning_rates = [10.0] * self.num_steps
    self.assertEqual(len(lrs[0]), 1)
    self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

    scheduler = get_constant_schedule(self.optimizer)
    lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
    self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
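# The test above asserts that get_constant_schedule() holds the learning rate
# fixed across steps. A minimal standalone sketch of the same behavior; the
# toy model and the 0.1 learning rate are illustrative, not from the test:
import torch
from transformers import get_constant_schedule

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = get_constant_schedule(optimizer)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # [0.1] every time: the schedule is constant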
def configure_optimizers(self):
    if self.hparams.optimize == 'basic':
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = get_constant_schedule(optimizer)
    elif self.hparams.optimize == 'bert':
        # Copied from: https://huggingface.co/transformers/training.html
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': self.hparams.weight_decay,
            },
            {
                'params': [p for n, p in self.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.lr)
        self.num_warmup_steps = int(self.num_train_steps * self.hparams.warmup_proportion)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_train_steps)
    else:
        raise ValueError(f"Unknown optimize setting: {self.hparams.optimize}")
    return [optimizer], [scheduler]
def configure_optimizers(self): "Prepare optimizer and schedule (linear warmup and decay)" model = self.model no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": self.hparams.weight_decay, }, { "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0, }, ] # Original optimizer from Transformers. It works but needs warmup. # optimizer = transformers.AdamW(optimizer_grouped_parameters, # lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) # The RAdam optimizer works approximately as well as Ranger. #optimizer = RAdam(optimizer_grouped_parameters, # lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) # The Ranger optimizer is the combination of RAdam and Lookahead. It # works well for this task. The best conditions seem to be learning # rate 1e-4 w/ RAdam or Ranger, gradient accumulation of 2 batches. optimizer = ranger.Ranger(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) # The constant scheduler does nothing. Replace with another # scheduler if required. scheduler = transformers.get_constant_schedule(optimizer) scheduler = { 'scheduler': scheduler, 'interval': 'step', 'frequency': 1 } return [optimizer], [scheduler]
def init_optimizer(self, model, lr):
    args = self.args
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    # TODO calculate t_total
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=args.adam_epsilon)
    # scheduler = WarmupLinearSchedule(
    #     optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    scheduler = get_constant_schedule(optimizer)
    return optimizer_grouped_parameters, optimizer, scheduler
def get_scheduler(optimizer, scheduler: str, warmup_steps: int, num_total: int):
    allowed = ["constantlr", "warmuplinear", "warmupconstant",
               "warmupcosine", "warmupcosinewithhardrestarts"]
    assert scheduler in allowed, \
        'scheduler should be one of {}'.format(allowed)
    if scheduler == 'constantlr':
        return transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        return transformers.get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        return transformers.get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
    elif scheduler == 'warmupcosine':
        return transformers.get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps,
            num_training_steps=num_total)
def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """Returns the correct learning rate scheduler."""
    scheduler = scheduler.lower()
    if scheduler == 'constantlr':
        return transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        return transformers.get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        return transformers.get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
    elif scheduler == 'warmupcosine':
        return transformers.get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
    else:
        raise ValueError("Unknown scheduler {}".format(scheduler))
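# Both factories above dispatch on a scheduler name and return a transformers
# scheduler that is stepped once per optimizer update. A self-contained sketch
# of the 'warmuplinear' branch in use; the model, step budget, and
# hyperparameters are illustrative placeholders, not values from the snippets:
import torch
import transformers

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_training_steps = 1000

# Warm up over the first 10% of steps, then decay linearly to zero;
# the 'constantlr' branch would instead keep the rate fixed at 2e-5.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps)

for step in range(num_training_steps):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    scheduler.step()  # transformers schedulers are stepped per batch, not per epoch
    optimizer.zero_grad()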
def configure_optimizers(self):
    optimizers = [
        LookaheadRMSprop(
            params=[
                {
                    "params": self.gate.g_hat.parameters(),
                    "lr": self.hparams.learning_rate,
                },
                {
                    "params": self.gate.placeholder.parameters()
                    if isinstance(self.gate.placeholder, torch.nn.ParameterList)
                    else [self.gate.placeholder],
                    "lr": self.hparams.learning_rate_placeholder,
                },
            ],
            centered=True,
        ),
        LookaheadRMSprop(
            params=[self.alpha]
            if isinstance(self.alpha, torch.Tensor)
            else self.alpha.parameters(),
            lr=self.hparams.learning_rate_alpha,
        ),
    ]

    schedulers = [
        {
            "scheduler": get_constant_schedule_with_warmup(optimizers[0], 12 * 100),
            "interval": "step",
        },
        get_constant_schedule(optimizers[1]),
    ]
    return optimizers, schedulers
def get_lr_scheduler(optimizer, scheduler_type, lr_warmup=None, num_steps=None):
    if scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(lr_warmup * num_steps), num_steps)
    elif scheduler_type == "constant":
        scheduler = get_constant_schedule(optimizer)
    else:
        raise ValueError(f"Unknown scheduler_type: {scheduler_type}")
    # Initialize the step count, as PopTorch does not call optimizer.step()
    # explicitly; this avoids the "scheduler.step() before optimizer.step()" warning.
    optimizer._step_count = 1
    return scheduler
def train_model(train_dataloader, val_dataloader, model, EPOCHS, BATCH_SIZE,
                lr, ACCUMULATION_STEPS):
    # Optimization
    num_train_optimization_steps = int(
        EPOCHS * len(train_dataloader) / BATCH_SIZE / ACCUMULATION_STEPS)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}  # no decay for bias and LayerNorm parameters
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=False)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100,
        num_training_steps=num_train_optimization_steps)
    scheduler0 = get_constant_schedule(optimizer)
    frozen = True

    # Training
    for epoch in range(EPOCHS + 1):
        print("\n--------Start training on Epoch %d/%d" % (epoch, EPOCHS))
        avg_loss = 0
        avg_accuracy = 0
        model.train()
        for i, (input_ids, attention_mask, label_batch) in enumerate(train_dataloader):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            label_batch = label_batch.to(device)
            optimizer.zero_grad()
            y_preds = model(input_ids, attention_mask, None)
            loss = torch.nn.functional.binary_cross_entropy(
                y_preds.to(device), label_batch.float().to(device))
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            avg_loss += loss.item() / len(train_dataloader)
        print("Loss training:", avg_loss)
        roc = eval(val_dataloader, model, device)
    return model
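# train_model() above builds both a warmup-linear schedule and a constant one
# (scheduler0) and sets a frozen flag, but never steps either of them. In
# kernels that use this pattern, the constant schedule is typically stepped
# while the encoder is frozen and the warmup-linear one after unfreezing. A
# hedged sketch of that convention; the helper name is mine, not from the source:
def step_after_update(optimizer, scheduler, scheduler0, frozen):
    """Step the optimizer, then whichever schedule matches the current phase."""
    optimizer.step()
    optimizer.zero_grad()
    (scheduler0 if frozen else scheduler).step()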
def get_lr_scheduler(optimizer, scheduler_type, warmup_steps=None,
                     num_steps=None, last_epoch=-1):
    if scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_steps)
    elif scheduler_type == "constant":
        scheduler = get_constant_schedule(optimizer)
    elif scheduler_type == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, warmup_steps, num_steps, last_epoch=last_epoch)
    else:
        raise ValueError(f"Unknown scheduler_type: {scheduler_type}")
    return scheduler
def configure_optimizers(self):
    optimizers = [
        LookaheadRMSprop(
            params=list(self.gate.parameters()) + [self.placeholder],
            lr=self.hparams.learning_rate,
            centered=True,
        ),
        LookaheadRMSprop(
            params=[self.alpha],
            lr=self.hparams.learning_rate_alpha,
        ),
    ]
    schedulers = [
        {
            "scheduler": get_constant_schedule_with_warmup(optimizers[0], 200),
            "interval": "step",
        },
        get_constant_schedule(optimizers[1]),
    ]
    return optimizers, schedulers
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str,
                        help="The input data dir. Should contain the .tsv files "
                             "(or other data files) for the task.")
    parser.add_argument("--teacher_model", default=None, type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model", default=None, type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name", default="SST-2", type=str,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir", default=None, type=str,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece "
                             "tokenization. Sequences longer than this will be truncated, "
                             "and sequences shorter than this will be padded.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate "
                             "warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before "
                             "performing a backward/update pass.")

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=float, default=0.1)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    # intermediate distillation default parameters
    default_params = {
        "cola": {"num_train_epochs": 50, "max_seq_length": 64},
        "mnli": {"num_train_epochs": 5, "max_seq_length": 128},
        "mrpc": {"num_train_epochs": 20, "max_seq_length": 128},
        "sst-2": {"num_train_epochs": 10, "max_seq_length": 64},
        "sts-b": {"num_train_epochs": 20, "max_seq_length": 128},
        "qqp": {"num_train_epochs": 5, "max_seq_length": 128},
        "qnli": {"num_train_epochs": 10, "max_seq_length": 128},
        "rte": {"num_train_epochs": 20, "max_seq_length": 128}
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]
    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name]["num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)
    student_config = BertConfig.from_pretrained(args.student_model,
                                                num_labels=num_labels,
                                                finetuning_task=args.task_name)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        train_data, _ = get_tensor_data(args, task_name, tokenizer, False,
                                        args.aug_train)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        num_train_optimization_steps = int(
            len(train_dataloader) /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_data, eval_labels = get_tensor_data(args, task_name, tokenizer, True, False)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    if not args.do_eval:
        teacher_config = BertConfig.from_pretrained(args.teacher_model,
                                                    num_labels=num_labels,
                                                    finetuning_task=args.task_name)
        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, config=teacher_config)
        teacher_model.to(device)

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, config=student_config)
    student_model.to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_data))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if n_gpu > 1:
            student_model = torch.nn.DataParallel(student_model)
            teacher_model = torch.nn.DataParallel(teacher_model)

        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()
        logger.info('Total parameters: {}'.format(size))

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(num_train_optimization_steps * args.warmup_proportion),
            num_training_steps=num_train_optimization_steps)
        if not args.pred_distill:
            scheduler = get_constant_schedule(optimizer)

        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
            student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.
                student_logits, student_atts, student_reps = student_model(
                    input_ids, segment_ids, input_mask, is_student=True)

                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(
                        input_ids, segment_ids, input_mask)

                if not args.pred_distill:
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    # print("teacher_layer_num:", teacher_layer_num)
                    # print("student_layer_num:", student_layer_num)
                    # print("teacher_reps num:", len(teacher_reps))
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num / student_layer_num)

                    new_teacher_atts = [
                        teacher_atts[i * layers_per_block + layers_per_block - 1]
                        for i in range(student_layer_num)
                    ]

                    for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                        student_att = torch.where(
                            student_att <= -1e2,
                            torch.zeros_like(student_att).to(device),
                            student_att)
                        teacher_att = torch.where(
                            teacher_att <= -1e2,
                            torch.zeros_like(teacher_att).to(device),
                            teacher_att)

                        tmp_loss = loss_mse(student_att, teacher_att)
                        att_loss += tmp_loss

                    new_teacher_reps = [
                        teacher_reps[i * layers_per_block]
                        for i in range(student_layer_num + 1)
                    ]
                    new_student_reps = student_reps
                    for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                        rep_loss += tmp_loss

                    loss = rep_loss + att_loss
                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                else:
                    if output_mode == "classification":
                        cls_loss = soft_cross_entropy(
                            student_logits / args.temperature,
                            teacher_logits / args.temperature)
                    elif output_mode == "regression":
                        loss_mse = MSELoss()
                        cls_loss = loss_mse(student_logits.view(-1),
                                            label_ids.view(-1))

                    loss = cls_loss
                    tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % int(
                        args.eval_step * num_train_optimization_steps) == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_data))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    if args.pred_distill:
                        result = do_eval(student_model, task_name,
                                         eval_dataloader, device, output_mode,
                                         eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)

                    if not args.pred_distill:
                        save_model = True
                    else:
                        save_model = False

                        if task_name in acc_tasks and result['acc'] > best_dev_acc:
                            best_dev_acc = result['acc']
                            save_model = True

                        if task_name in corr_tasks and result['corr'] > best_dev_acc:
                            best_dev_acc = result['corr']
                            save_model = True

                        if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                            best_dev_acc = result['mcc']
                            save_model = True

                    if save_model:
                        logger.info("***** Save model *****")

                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model

                        model_name = "pytorch_model.bin"
                        # if not args.pred_distill:
                        #     model_name = "step_{}_{}".format(global_step, "pytorch_model.bin")
                        output_model_file = os.path.join(args.output_dir, model_name)
                        output_config_file = os.path.join(args.output_dir, "config.json")

                        torch.save(model_to_save.state_dict(), output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        # Test mnli-mm
                        if args.pred_distill and task_name == "mnli":
                            task_name = "mnli-mm"
                            if not os.path.exists(args.output_dir + '-MM'):
                                os.makedirs(args.output_dir + '-MM')

                            eval_data, eval_labels = get_tensor_data(
                                args, task_name, tokenizer, True, False)

                            eval_sampler = SequentialSampler(eval_data)
                            eval_dataloader = DataLoader(eval_data,
                                                         sampler=eval_sampler,
                                                         batch_size=args.eval_batch_size)

                            logger.info("***** Running mm evaluation *****")
                            logger.info("  Num examples = %d", len(eval_data))
                            logger.info("  Batch size = %d", args.eval_batch_size)

                            result = do_eval(student_model, task_name,
                                             eval_dataloader, device,
                                             output_mode, eval_labels,
                                             num_labels)
                            result['global_step'] = global_step

                            tmp_output_eval_file = os.path.join(
                                args.output_dir + '-MM', "eval_results.txt")
                            result_to_file(result, tmp_output_eval_file)

                            task_name = 'mnli'

                    student_model.train()
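# A standalone toy run of the soft_cross_entropy() distillation loss defined
# in the training script above: dividing both logit sets by a temperature
# T > 1 softens the teacher distribution before matching. The tensors and the
# value of T here are illustrative only:
import torch

def soft_cross_entropy(predicts, targets):
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return (-targets_prob * student_likelihood).mean()

student_logits = torch.tensor([[2.0, 0.5, -1.0]])
teacher_logits = torch.tensor([[1.5, 0.8, -0.5]])
T = 2.0
print(soft_cross_entropy(student_logits / T, teacher_logits / T))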
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized, cp_emb = True, None
    else:
        use_contextualized = False
        cp_emb = [np.load(path) for path in args.ent_emb_paths]
        cp_emb = torch.tensor(np.concatenate(cp_emb, 1))
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = cal_2hop_rel_emb(rel_emb)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    # print('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMRelationNetDataLoader(args.train_statements, args.train_rel_paths,
                                      args.dev_statements, args.dev_rel_paths,
                                      args.test_statements, args.test_rel_paths,
                                      batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                      device=device,
                                      model_name=args.encoder,
                                      max_tuple_num=args.max_tuple_num, max_seq_length=args.max_seq_len,
                                      is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                      use_contextualized=use_contextualized,
                                      train_adj_path=args.train_adj, dev_adj_path=args.dev_adj,
                                      test_adj_path=args.test_adj,
                                      train_node_features_path=args.train_node_features,
                                      dev_node_features_path=args.dev_node_features,
                                      test_node_features_path=args.test_node_features,
                                      node_feature_type=args.node_feature_type,
                                      format=args.format)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMRelationNet(model_name=args.encoder, concept_num=concept_num, concept_dim=relation_dim,
                          relation_num=relation_num, relation_dim=relation_dim,
                          concept_in_dim=(dataset.get_node_feature_dim() if use_contextualized else concept_dim),
                          hidden_size=args.mlp_dim, num_hidden_layers=args.mlp_layer_num,
                          num_attention_heads=args.att_head_num, fc_size=args.fc_dim,
                          num_fc_layers=args.fc_layer_num, dropout=args.dropoutm,
                          pretrained_concept_emb=cp_emb, pretrained_relation_emb=rel_emb,
                          freeze_ent_emb=args.freeze_ent_emb,
                          init_range=args.init_range, ablation=args.ablation,
                          use_contextualized=use_contextualized, emb_scale=args.emb_scale,
                          encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        rel_grad = []
        linear_grad = []
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                print('encoder unfreezed')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                print('encoder refreezed')
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(
                            -1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                rel_grad.append(model.decoder.rel_emb.weight.grad.abs().mean().item())
                linear_grad.append(model.decoder.mlp.layers[8].weight.grad.abs().mean().item())
                optimizer.step()  # step the optimizer before the scheduler (PyTorch >= 1.1)
                scheduler.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(
                        global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    # print('| rel_grad: {:1.2e} | linear_grad: {:1.2e} |'.format(
                    #     sum(rel_grad) / len(rel_grad), sum(linear_grad) / len(linear_grad)))
                    total_loss = 0
                    rel_grad = []
                    linear_grad = []
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                torch.save([model, args], model_path)
                print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc, best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
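# freeze_net() and unfreeze_net() are project helpers that are not shown in
# this snippet. A plausible minimal implementation, assuming they simply
# toggle requires_grad on the module's parameters (hypothetical, for illustration):
import torch

def freeze_net(module: torch.nn.Module):
    # Hypothetical: disable gradients for every parameter of the module.
    for p in module.parameters():
        p.requires_grad = False

def unfreeze_net(module: torch.nn.Module):
    # Hypothetical: re-enable gradients.
    for p in module.parameters():
        p.requires_grad = True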
def train(args, train_dataset, model, model_config, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        # setup tb writer
        logger.info("Saving tensorboard logs to %s", args.tb_output_dir)
        tb_writer = SummaryWriter(log_dir=args.tb_output_dir, flush_secs=30)
        # Write config files to tensorboard
        tb_writer.add_text('encoder_config', str(model_config))
        # create train log file
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        output_train_file = os.path.join(args.output_dir, "train_results.txt")

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    # )
    scheduler = get_constant_schedule(optimizer)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
            os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    # Log once before training starts.
    # Only evaluate on a single GPU without torch.distributed, otherwise metrics may not average well
    if args.local_rank in [-1, 0]:
        if args.evaluate_during_training:
            results = evaluate(args, model, tokenizer, dev_set=True)
            for key, value in results.items():
                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                # log to wandb
                # wandb.log({f'eval_{key}': value}, step=0)
        # write to tensorboard
        tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)

    # Enter training loop
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1
                # log gradients before clipping them
                if args.local_rank in [-1, 0] and args.train_logging_steps > 0 \
                        and global_step % args.train_logging_steps == 0:
                    for name, param in model.named_parameters():
                        # tb_writer.add_histogram(name, param, global_step)
                        if param.grad is not None:
                            grads = param.grad.view(-1)
                            grads_norm = torch.norm(grads, p=2, dim=0)
                            tb_writer.add_scalar(name + '_grad_norm', grads_norm, global_step)
                        # else:  # For XLM transformer.lang_embeddings.weight grads are disabled
                        #     print('Gradients are disabled for:', name)

                if args.fp16:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                # log stuff
                if args.local_rank in [-1, 0] and args.train_logging_steps > 0 \
                        and global_step % args.train_logging_steps == 0:
                    # log weights and gradients after clipping
                    for name, param in model.named_parameters():
                        # Compute l2 norm of the gradients
                        if param.grad is not None:
                            # tb_writer.add_histogram(name, param, global_step)
                            # tb_writer.add_histogram(name + '_grad', param.grad, global_step)
                            grads = param.grad.view(-1)
                            grads_norm = torch.norm(grads, p=2, dim=0)
                            tb_writer.add_scalar(name + '_clipped_grad_norm', grads_norm, global_step)
                    tb_writer.add_scalar('total_grad_norm', total_norm, global_step)
                    # log learning rate
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    # log training loss
                    loss = (tr_loss - logging_loss) / args.train_logging_steps
                    tb_writer.add_scalar('train_loss', loss, global_step)
                    logging_loss = tr_loss
                    # log to wandb
                    wandb.log(
                        {
                            'loss': loss,
                            'total_grad_norm': total_norm,
                            'lr': scheduler.get_lr()[0]
                        },
                        step=global_step)
                    # write to logfile
                    with open(output_train_file, "a") as writer:
                        writer.write(f"{global_step}: train_loss = {loss}\n")

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 \
                        and global_step % args.logging_steps == 0:
                    logs = {}
                    if (args.local_rank == -1 and args.evaluate_during_training):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value
                            # log to wandb
                            wandb.log({eval_key: value}, step=0)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss
                    # log to wandb
                    wandb.log(
                        {
                            'loss': loss,
                            'total_grad_norm': total_norm,
                            'lr': scheduler.get_lr()[0]
                        },
                        step=global_step)

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 \
                        and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
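# The training loops above save optimizer.pt and scheduler.pt next to each
# model checkpoint and reload them when resuming. A self-contained round-trip
# of that convention; the model, learning rate, and directory name below are
# illustrative placeholders:
import os
import torch
from transformers import get_constant_schedule

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_constant_schedule(optimizer)

output_dir = "checkpoint-demo"
os.makedirs(output_dir, exist_ok=True)
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

# ...later, to resume training from the checkpoint:
optimizer.load_state_dict(torch.load(os.path.join(output_dir, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(output_dir, "scheduler.pt")))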
def train(args, train_dataset, model, tokenizer, orgin_dict):
    """ Train the model """
    record_result = []

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    # )
    scheduler = get_constant_schedule(optimizer)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
            os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    return_flag = False

    print('starting pruning')
    pruning_model(model, args.sparsity)
    rate_weight_equal_zero = see_weight_rate(model)
    print('zero_rate = ', rate_weight_equal_zero)

    print('starting rewinding')
    model_dict = model.state_dict()
    model_dict.update(orgin_dict)
    model.load_state_dict(model_dict)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 \
                        and global_step % args.logging_steps == 0:
                    logs = {}
                    if (args.local_rank == -1 and args.evaluate_during_training):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        rate_weight_equal_zero = see_weight_rate(model)
                        print('zero_rate = ', rate_weight_equal_zero)
                        results = evaluate(args, model, tokenizer)
                        # return_flag = True
                        record_result.append(results)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 \
                        and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(model, os.path.join(output_dir, "model.pt"))
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if return_flag:
                epoch_iterator.close()
                break
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if return_flag:
            epoch_iterator.close()
            break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    torch.save(record_result, os.path.join(args.output_dir, "record_result"))

    return global_step, tr_loss / global_step
def train(args): logging.info(f'{socket.gethostname()}: {os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "unknown"}') logging.info('python ' + ' '.join(sys.argv)) logging.info(args) model_path = os.path.join(args.save_dir, args.save_file_name) check_path(model_path) ################################################################################################### # Load data # ################################################################################################### cp_emb = [np.load(path) for path in args.ent_emb_paths] cp_emb = torch.tensor(np.concatenate(cp_emb, 1)) concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1) rel_emb = np.load(args.rel_emb_path) rel_emb = np.concatenate((rel_emb, -rel_emb), 0) rel_emb = torch.tensor(rel_emb) relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1) logging.info('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num)) device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") lm_data_loader = LMDataLoader(args.train_jsonl, args.dev_jsonl, args.test_jsonl, batch_size=args.mini_batch_size, eval_batch_size=args.eval_batch_size, device=device, model_name=args.encoder, max_seq_length=args.max_seq_len, is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subset_qids_path=args.subset_train_qids, format=args.format) logging.info(f'| # train questions: {lm_data_loader.train_size()} | # dev questions: {lm_data_loader.dev_size()} | # test questions: {lm_data_loader.test_size()} |') ################################################################################################### # Build model # ################################################################################################### graph_data_loader = GraphDataLoader(args.train_adj_pk, args.train_gen_pt, args.dev_adj_pk, args.dev_gen_pt, args.test_adj_pk, args.test_gen_pt, args.mini_batch_size, args.eval_batch_size, args.num_choice, args.ablation) train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(lm_data_loader.get_train_indexes(), stats_only=True) dev_lm_data_loader = lm_data_loader.dev() dev_graph_loader, dev_avg_node_num, dev_avg_edge_num = graph_data_loader.dev_graph_data() assert len(dev_graph_loader) == len(dev_lm_data_loader) if args.inhouse: test_index = lm_data_loader.get_test_indexes() test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.get_pyg_loader(test_index) else: test_index = None test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.test_graph_data() test_lm_data_loader = lm_data_loader.test(test_index) assert len(test_graph_loader) == len(test_lm_data_loader) logging.info(f'| train | avg node num: {train_avg_node_num:.2f} | avg edge num: {train_avg_edge_num:.2f} |') logging.info(f'| dev | avg node num: {dev_avg_node_num:.2f} | avg edge num: {dev_avg_edge_num:.2f} |') logging.info(f'| test | avg node num: {test_avg_node_num:.2f} | avg edge num: {test_avg_edge_num:.2f} |') model = LMGraphNet(model_name=args.encoder, encoder_pooler=args.encoder_pooler, concept_num=concept_num, concept_dim=relation_dim, relation_num=relation_num, relation_dim=relation_dim, concept_in_dim=concept_dim, hidden_size=args.mlp_dim, num_attention_heads=args.att_head_num, fc_size=args.fc_dim, num_fc_layers=args.fc_layer_num, dropout=args.dropoutm, edge_weight_dropout=args.edge_weight_dropout, pretrained_concept_emb=cp_emb, pretrained_relation_emb=rel_emb, freeze_ent_emb=args.freeze_ent_emb, 
num_layers=args.num_gnn_layers, ablation=args.ablation, emb_scale=args.emb_scale, aristo_path=args.aristo_path) model.to(device) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] grouped_parameters = [ {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr}, {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr}, ] optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters) if args.lr_schedule == 'fixed': scheduler = get_constant_schedule(optimizer) elif args.lr_schedule == 'warmup_constant': scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps) elif args.lr_schedule == 'warmup_linear': max_steps = int(args.n_epochs * (lm_data_loader.train_size() / args.batch_size)) if args.warmup_ratio is not None: warmup_steps = int(args.warmup_ratio * max_steps) else: warmup_steps = args.warmup_steps scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps) logging.info('parameters:') for name, param in model.decoder.named_parameters(): if param.requires_grad: logging.info('\t{:45}\ttrainable\t{}'.format(name, param.size())) else: logging.info('\t{:45}\tfixed\t{}'.format(name, param.size())) num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) logging.info(f'\ttotal: {num_params}') loss_func = nn.CrossEntropyLoss(reduction='mean') ################################################################################################### # Training # ################################################################################################### logging.info('') logging.info('-' * 71) global_step, eval_id, best_dev_id, best_dev_step = 0, 0, 0, 0 best_dev_acc, final_test_acc, best_test_acc, total_loss = 0.0, 0.0, 0.0, 0.0 best_test_acc = 0.0 exit_training = False train_start_time = time.time() start_time = train_start_time model.train() freeze_net(model.encoder) try: binary_score_lst = [] for epoch_id in range(args.n_epochs): if exit_training: break if epoch_id == args.unfreeze_epoch: logging.info('encoder unfreezed') unfreeze_net(model.encoder) if epoch_id == args.refreeze_epoch: logging.info('encoder refreezed') freeze_net(model.encoder) model.train() i = 0 optimizer.zero_grad() train_index = lm_data_loader.get_train_indexes() train_graph_loader, train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(train_index) train_lm_data_loader = lm_data_loader.train(train_index) assert len(train_graph_loader) == len(train_lm_data_loader) for graph, (qids, labels, *lm_input_data) in zip(train_graph_loader, train_lm_data_loader): graph = graph.to(device) edge_index = graph.edge_index row, col = edge_index node_batch = graph.batch num_of_nodes = graph.num_of_nodes num_of_edges = graph.num_of_edges rel_ids_embs = graph.edge_attr c_ids = graph.x c_types = graph.node_type logits, unnormalized_wts, normalized_wts = model(*lm_input_data, edge_index=edge_index, c_ids=c_ids, c_types=c_types, node_batch=node_batch, rel_ids_embs=rel_ids_embs, num_of_nodes=num_of_nodes, num_of_edges=num_of_edges) 
loss = loss_func(logits, labels) # scale: loss per question if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation: # add options for other kinds of sparsity log_wts = torch.log(normalized_wts + 0.0000001) entropy = - normalized_wts * log_wts # entropy: [num_of_edges in the batched graph, 1] entropy = scatter_mean(entropy, node_batch[row], dim=0, dim_size=args.mini_batch_size * args.num_choice) loss += args.alpha * torch.mean(entropy) # scale: entropy per graph (each question has num_choice graphs) loss = loss * args.mini_batch_size / args.batch_size # will be accumulated for (args.batch_size / args.mini_batch_size) times loss.backward() total_loss += loss.item() if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation: binary_score_lst += entropy.squeeze().tolist() else: binary_score_lst.append(0) i = i + args.mini_batch_size if i % args.batch_size == 0: if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() # bp: scale: loss per question scheduler.step() optimizer.zero_grad() global_step += 1 if global_step % args.log_interval == 0: total_loss /= args.log_interval ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval logging.info('| step {:5} | lr: {:9.7f} | loss {:7.20f} | entropy score {:7.4f} | ms/batch {:7.2f} |' .format(global_step, scheduler.get_lr()[0], total_loss, np.mean(binary_score_lst), ms_per_batch)) total_loss = 0 binary_score_lst = [] start_time = time.time() if args.eval_interval > 0: if global_step % args.eval_interval == 0: eval_id += 1 model.eval() dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device) test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device) # test_acc = 0.2 best_test_acc = max(best_test_acc, test_acc) logging.info('-' * 71) logging.info('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(global_step, dev_acc, test_acc)) logging.info('-' * 71) if dev_acc >= best_dev_acc: best_dev_acc = dev_acc final_test_acc = test_acc best_dev_id = eval_id best_dev_step = global_step if args.save_model: torch.save(model.state_dict(), model_path) copyfile(model_path, f'{model_path}_{global_step}_{dev_acc*100:.2f}_{test_acc*100:.2f}.pt') # tmp logging.info(f'model saved to {model_path}') else: logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}') model.train() if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience: exit_training = True break if args.eval_interval == 0: eval_id += 1 model.eval() dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device) test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device) best_test_acc = max(best_test_acc, test_acc) logging.info('-' * 71) logging.info('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, dev_acc, test_acc)) logging.info('-' * 71) if dev_acc >= best_dev_acc: best_dev_acc = dev_acc final_test_acc = test_acc best_dev_id = eval_id best_dev_step = global_step if args.save_model: torch.save(model.state_dict(), model_path) logging.info(f'model saved to {model_path}') else: logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}') model.train() if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience: exit_training = True break start_time = time.time() except KeyboardInterrupt: logging.info('-' * 89) logging.info('Exiting from training early') train_end_time = time.time() logging.info('') logging.info(f'training ends in 
{global_step} steps, {train_end_time - train_start_time:.0f} s') logging.info('best dev acc: {:.4f} (at step {})'.format(best_dev_acc, best_dev_step)) logging.info('final test acc: {:.4f}'.format(final_test_acc)) if args.use_last_epoch: logging.info(f'last dev acc: {dev_acc:.4f}') logging.info(f'last test acc: {test_acc:.4f}') return dev_acc, test_acc, best_test_acc else: return best_dev_acc, final_test_acc, best_test_acc
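# A hedged sketch of the edge-weight entropy regularizer used in the loop
# above: per-edge attention entropy is averaged per graph with scatter_mean and
# added to the task loss with weight alpha. torch_scatter is assumed installed;
# shapes and values below are illustrative stand-ins, not from the original.
import torch
from torch_scatter import scatter_mean

num_edges, num_graphs, alpha = 12, 3, 0.1
normalized_wts = torch.rand(num_edges, 1)                # attention weight per edge
graph_of_edge = torch.randint(num_graphs, (num_edges,))  # plays the role of node_batch[row]

log_wts = torch.log(normalized_wts + 1e-7)               # 1e-7 mirrors the 0.0000001 above
entropy = -normalized_wts * log_wts                      # [num_edges, 1]
entropy = scatter_mean(entropy, graph_of_edge, dim=0, dim_size=num_graphs)
task_loss = torch.tensor(0.0)                            # stand-in for the cross-entropy loss
loss = task_loss + alpha * torch.mean(entropy)           # entropy averaged per graph, as above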
def train(self, train_path: str, valid_path: str, types_path: str, input_reader_cls: BaseInputReader): args = self.args train_label, valid_label = 'train', 'valid' self._logger.info("Datasets: %s, %s" % (train_path, valid_path)) self._logger.info("Model type: %s" % args.model_type) # create log csv files self._init_train_logging(train_label) self._init_eval_logging(valid_label) # read datasets input_reader = input_reader_cls(types_path, args.bio_path, self._tokenizer, self._logger) input_reader.read({train_label: train_path, valid_label: valid_path}) self._log_datasets(input_reader) train_dataset = input_reader.get_dataset(train_label) train_sample_count = train_dataset.document_count updates_epoch = train_sample_count // args.train_batch_size updates_total = updates_epoch * args.epochs steps_before_rel = int(updates_total * self.args.before_rel) validation_dataset = input_reader.get_dataset(valid_label) self._logger.info("Updates per epoch: %s" % updates_epoch) self._logger.info("Updates total: %s" % updates_total) self._logger.info("Updates before relation: %s" % steps_before_rel) # create model model_class = models.get_model(self.args.model_type) # load model if args.model_type == 'table_filling': model = model_class.from_pretrained( self.args.model_path, cache_dir=self.args.cache_path, tokenizer=self._tokenizer, # table_filling model parameters relation_labels=input_reader.relation_label_count, entity_labels=input_reader.entity_label_count, att_hidden=self.args.att_hidden, prop_drop=self.args.prop_drop, entity_label_embedding=self.args.entity_label_embedding, freeze_transformer=self.args.freeze_transformer, device=self._device) # if self._device.type != 'cpu': # torch.distributed.init_process_group(backend='nccl', world_size=3, init_method='...') # model = torch.nn.parallel.DistributedDataParallel(model) model.to(self._device) # model.to(f'cuda:{model.device_ids[0]}') # create optimizer optimizer_params = self._get_optimizer_params(model) optimizer = AdamW(optimizer_params, lr=args.lr, weight_decay=args.weight_decay, correct_bias=False) # other_optimizer_params = self._get_optimizer_params([]) # create scheduler if args.scheduler == 'constant': scheduler = transformers.get_constant_schedule(optimizer) elif args.scheduler == 'constant_warmup': scheduler = transformers.get_constant_schedule_with_warmup( optimizer, num_warmup_steps=args.lr_warmup * updates_total) elif args.scheduler == 'linear_warmup': scheduler = transformers.get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.lr_warmup * updates_total, num_training_steps=updates_total) elif args.scheduler == 'cosine_warmup': scheduler = transformers.get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=args.lr_warmup * updates_total, num_training_steps=updates_total) elif args.scheduler == 'cosine_warmup_restart': scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup( optimizer, num_warmup_steps=args.lr_warmup * updates_total, num_training_steps=updates_total, num_cycles=args.num_cycles) # create loss function rel_criterion = torch.nn.CrossEntropyLoss(reduction='none') entity_criterion = torch.nn.CrossEntropyLoss(reduction='none') if args.model_type == 'table_filling': compute_loss = TableLoss(rel_criterion, entity_criterion, model, optimizer, scheduler, args.max_grad_norm) # eval validation set if args.init_eval: self._eval(model, compute_loss, validation_dataset, input_reader, 0, updates_epoch) # train for epoch in range(args.epochs): # train epoch self._train_epoch(model, compute_loss, 
optimizer, train_dataset, updates_epoch, epoch, input_reader.context_size, input_reader.entity_label_count, input_reader.relation_label_count, input_reader._start_entity_label, steps_before_rel) # eval validation sets if not args.final_eval or (epoch == args.epochs - 1): ner_acc, rel_acc, rel_ner_acc = self._eval( model, compute_loss, validation_dataset, input_reader, epoch, updates_epoch) if args.save_best: extra = dict(epoch=epoch, updates_epoch=updates_epoch, epoch_iteration=0) self._save_best(model=model, optimizer=optimizer if self.args.save_optimizer else None, accuracy=ner_acc[2], iteration=epoch * updates_epoch, label='ner_micro_f1', extra=extra) # save final model extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0) global_iteration = args.epochs * updates_epoch self._save_model( self._save_path, model, global_iteration, optimizer=optimizer if self.args.save_optimizer else None, extra=extra, include_iteration=False, name='final_model') self._logger.info("Logged in: %s" % self._log_path) self._logger.info("Saved in: %s" % self._save_path)
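# Every snippet in this collection rebuilds the same "no weight decay for
# bias / LayerNorm" parameter groups inline. A reusable sketch of that pattern;
# the helper name and defaults are illustrative, not from any of the originals.
import torch
from transformers import AdamW

def grouped_parameters(model, weight_decay, no_decay=('bias', 'LayerNorm.weight')):
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(nd in name for nd in no_decay) else decay).append(param)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': skip, 'weight_decay': 0.0}]

model = torch.nn.Linear(4, 2)
optimizer = AdamW(grouped_parameters(model, weight_decay=0.01), lr=5e-5)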
def train(args): util.ensure_dir(args["save_dir"]) model_file = args["save_dir"] + "/" + "phonlp.pt" tokenizer = AutoTokenizer.from_pretrained(args["pretrained_lm"], use_fast=False) config_phobert = AutoConfig.from_pretrained(args["pretrained_lm"], output_hidden_states=True) print("Loading data with batch size {}...".format(args["batch_size"])) train_doc_dep = Document( CoNLL.conll2dict(input_file=args["train_file_dep"])) vocab = BuildVocab(args, args["train_file_pos"], train_doc_dep, args["train_file_ner"]).vocab train_batch_pos = DataLoaderPOS( args["train_file_pos"], args["batch_size"], args, vocab=vocab, evaluation=False, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) train_batch_dep = DataLoaderDep( train_doc_dep, args["batch_size"], args, vocab=vocab, evaluation=False, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) train_batch_ner = DataLoaderNER( args["train_file_ner"], args["batch_size"], args, vocab=vocab, evaluation=False, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) dev_doc_dep = Document(CoNLL.conll2dict(input_file=args["eval_file_dep"])) dev_batch_pos = DataLoaderPOS( args["eval_file_pos"], args["batch_size"], args, vocab=vocab, sort_during_eval=True, evaluation=True, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) dev_batch_dep = DataLoaderDep( dev_doc_dep, args["batch_size"], args, vocab=vocab, sort_during_eval=True, evaluation=True, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) dev_batch_ner = DataLoaderNER( args["eval_file_ner"], args["batch_size"], args, vocab=vocab, evaluation=True, tokenizer=tokenizer, max_seq_length=args["max_sequence_length"], ) # pred and gold path system_pred_file = args["output_file_dep"] gold_file = args["eval_file_dep"] # ##POS dev_gold_tags = dev_batch_ner.tags # skip training if the language does not have training or dev data if len(train_batch_pos) == 0 or len(dev_batch_pos) == 0: print("Skip training because no data available...") sys.exit(0) print("Training jointmodel...") trainer = JointTrainer(args, vocab, None, config_phobert, args["cuda"]) # ### tsfm = trainer.model.phobert for child in tsfm.children(): for param in child.parameters(): if not param.requires_grad: print("whoopsies") param.requires_grad = True global_step = 0 las_score_history = 0 uas_score_history = 0 upos_score_history = 0 f1_score_history = 0 #### # start training train_loss = 0 train_loss_pos = 0 train_loss_dep = 0 train_loss_ner = 0 # Creating optimizer and lr schedulers param_optimizer = list(trainer.model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], "weight_decay": 0.01 }, { "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0 }, ] num_train_optimization_steps = int( args["num_epoch"] * len(train_batch_pos) / args["accumulation_steps"]) optimizer = AdamW( optimizer_grouped_parameters, lr=args["lr"], correct_bias=False ) # To reproduce BertAdam specific behavior set correct_bias=False scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=5, num_training_steps=num_train_optimization_steps) get_constant_schedule(optimizer) for epoch in range(args["num_epoch"]): #### optimizer.zero_grad() print(" EPOCH : ", epoch) step = 0 lambda_pos = args["lambda_pos"] lambda_ner = args["lambda_ner"] lambda_dep = args["lambda_dep"] epoch_size = max( [len(train_batch_pos), 
len(train_batch_dep), len(train_batch_ner)]) for i in tqdm(range(epoch_size)): step += 1 global_step += 1 batch_pos = train_batch_pos[i] batch_dep = train_batch_dep[i] batch_ner = train_batch_ner[i] ### loss, loss_pos, loss_ner = trainer.update( batch_dep, batch_pos, batch_ner, lambda_pos=lambda_pos, lambda_dep=lambda_dep, lambda_ner=lambda_ner) # update step train_loss += loss train_loss_pos += loss_pos # train_loss_dep += loss_dep train_loss_ner += loss_ner ### if i % args["accumulation_steps"] == 0: optimizer.step() optimizer.zero_grad() scheduler.step() if epoch_size == len(train_batch_pos): if step % len(train_batch_dep) == 0: train_batch_dep.reshuffle() if step % len(train_batch_ner) == 0: train_batch_ner.reshuffle() elif epoch_size == len(train_batch_ner): if step % len(train_batch_dep) == 0: train_batch_dep.reshuffle() if step % len(train_batch_pos) == 0: train_batch_pos.reshuffle() elif epoch_size == len(train_batch_dep): if step % len(train_batch_pos) == 0: train_batch_dep.reshuffle() if step % len(train_batch_ner) == 0: train_batch_ner.reshuffle() if step % args["eval_interval"] == 0: print("Evaluating on dev set...") dev_preds_dep = [] dev_preds_upos = [] dev_preds_ner = [] for batch in dev_batch_dep: preds_dep = trainer.predict_dep(batch) dev_preds_dep += preds_dep ### dev_preds_dep = util.unsort(dev_preds_dep, dev_batch_dep.data_orig_idx_dep) dev_batch_dep.doc_dep.set( [HEAD, DEPREL], [y for x in dev_preds_dep for y in x]) CoNLL.dict2conll(dev_batch_dep.doc_dep.to_dict(), system_pred_file) _, _, las_dev, uas_dev = score_dep.score( system_pred_file, gold_file) for batch in dev_batch_pos: preds_pos = trainer.predict_pos(batch) dev_preds_upos += preds_pos dev_preds_upos = util.unsort(dev_preds_upos, dev_batch_pos.data_orig_idx_pos) accuracy_pos_dev = score_pos.score_acc(dev_preds_upos, dev_batch_pos.upos) for batch in dev_batch_ner: preds_ner = trainer.predict_ner(batch) dev_preds_ner += preds_ner p, r, f1 = score_ner.score_by_entity(dev_preds_ner, dev_gold_tags) for i in range(len(dev_batch_ner)): assert len(dev_preds_ner[i]) == len(dev_gold_tags[i]) print( "step {}: dev_las_score = {:.4f}, dev_uas_score = {:.4f}, dev_pos = {:.4f}, dev_ner_p = {:.4f}, dev_ner_r = {:.4f}, dev_ner_f1 = {:.4f}" .format(global_step, las_dev, uas_dev, accuracy_pos_dev, p, r, f1)) # save best model if las_dev + accuracy_pos_dev + f1 >= (las_score_history + upos_score_history + f1_score_history): las_score_history = las_dev upos_score_history = accuracy_pos_dev uas_score_history = uas_dev f1_score_history = f1 trainer.save(model_file) print("new best model saved.") print("") print("Evaluating on dev set...") dev_preds_dep = [] dev_preds_upos = [] dev_preds_ner = [] for batch in dev_batch_dep: preds_dep = trainer.predict_dep(batch) dev_preds_dep += preds_dep dev_preds_dep = util.unsort(dev_preds_dep, dev_batch_dep.data_orig_idx_dep) dev_batch_dep.doc_dep.set([HEAD, DEPREL], [y for x in dev_preds_dep for y in x]) CoNLL.dict2conll(dev_batch_dep.doc_dep.to_dict(), system_pred_file) _, _, las_dev, uas_dev = score_dep.score(system_pred_file, gold_file) for batch in dev_batch_pos: preds_pos = trainer.predict_pos(batch) dev_preds_upos += preds_pos dev_preds_upos = util.unsort(dev_preds_upos, dev_batch_pos.data_orig_idx_pos) accuracy_pos_dev = score_pos.score_acc(dev_preds_upos, dev_batch_pos.upos) for batch in dev_batch_ner: preds_ner = trainer.predict_ner(batch) dev_preds_ner += preds_ner p, r, f1 = score_ner.score_by_entity(dev_preds_ner, dev_gold_tags) for i in range(len(dev_batch_ner)): assert 
len(dev_preds_ner[i]) == len(dev_gold_tags[i]) train_loss = train_loss / len(train_batch_pos) # avg loss per batch train_loss_dep = train_loss_dep / len(train_batch_pos) train_loss_pos = train_loss_pos / len(train_batch_pos) train_loss_ner = train_loss_ner / len(train_batch_pos) print( "step {}: train_loss = {:.6f}, train_loss_dep = {:.6f}, train_loss_pos = {:.6f}, train_loss_ner = {:.6f}, dev_las_score = {:.4f}, dev_uas_score = {:.4f}, dev_pos = {:.4f}, dev_ner_p = {:.4f}, dev_ner_r = {:.4f}, dev_ner_f1 = {:.4f} " .format( global_step, train_loss, train_loss_dep, train_loss_pos, train_loss_ner, las_dev, uas_dev, accuracy_pos_dev, p, r, f1, )) # save best model if las_dev + accuracy_pos_dev + f1 >= ( las_score_history + upos_score_history + f1_score_history): las_score_history = las_dev upos_score_history = accuracy_pos_dev uas_score_history = uas_dev f1_score_history = f1 trainer.save(model_file) print("new best model saved.") train_loss = 0 train_loss_pos = 0 train_loss_dep = 0 train_loss_ner = 0 print("") train_batch_dep.reshuffle() train_batch_pos.reshuffle() train_batch_ner.reshuffle() print("Training ended with {} epochs.".format(epoch)) best_las, uas, upos, f1 = ( las_score_history * 100, uas_score_history * 100, upos_score_history * 100, f1_score_history * 100, ) print("Best dev las = {:.2f}, uas = {:.2f}, upos = {:.2f}, f1 = {:.2f}". format(best_las, uas, upos, f1))
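# Two details worth noting in the snippet above: the bare
# get_constant_schedule(optimizer) call after the linear scheduler builds a
# second schedule and discards it, so it has no effect on training, and the
# `i % accumulation_steps == 0` test also fires on the very first micro-batch.
# (The third reshuffle branch also appears to reshuffle train_batch_dep where
# train_batch_pos was likely intended.) Below, a hedged sketch of the
# accumulation loop with the usual (i + 1) boundary and the post-PyTorch-1.1
# ordering (optimizer.step() before scheduler.step()); model, data, and step
# counts are stand-ins.
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5,
                                            num_training_steps=25)
accumulation_steps = 4
optimizer.zero_grad()
for i in range(100):
    loss = model(torch.randn(4, 8)).sum() / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:  # step once per accumulated batch
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()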
def train(args): print(args) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.cuda: torch.cuda.manual_seed(args.seed) model_path = os.path.join(args.save_dir, 'model.pt') check_path(model_path) logger = setup_logger(__name__, args.save_dir + "log.txt") logger.info(args) ################################################################################################### # Load data # ################################################################################################### device = torch.device( "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") dataset = LMDataLoader(args.train_statements, args.dev_statements, args.test_statements, batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=device, model_name=args.encoder, max_seq_length=args.max_seq_len, is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subsample=args.subsample, format=args.format) ################################################################################################### # Build model # ################################################################################################### lstm_config = get_lstm_config_from_args(args) model = LMForMultipleChoice(args.encoder, from_checkpoint=args.from_checkpoint, encoder_config=lstm_config) try: model.to(device) except RuntimeError as e: logger.info(e) logger.info('best dev acc: 0.0 (at epoch 0)') logger.info('final test acc: 0.0') print() return no_decay = ['bias', 'LayerNorm.weight'] grouped_parameters = [{ 'params': [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], 'lr': args.encoder_lr, 'weight_decay': args.weight_decay }, { 'params': [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], 'lr': args.encoder_lr, 'weight_decay': 0.0 }] optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters) if args.lr_schedule == 'fixed': scheduler = get_constant_schedule(optimizer) elif args.lr_schedule == 'warmup_constant': scheduler = get_constant_schedule_with_warmup( optimizer, warmup_steps=args.warmup_steps) elif args.lr_schedule == 'warmup_linear': max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size)) scheduler = get_linear_schedule_with_warmup( optimizer, warmup_steps=args.warmup_steps, t_total=max_steps) if args.loss == 'margin_rank': loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean') elif args.loss == 'cross_entropy': loss_func = nn.CrossEntropyLoss(reduction='mean') ################################################################################################### # Training # ################################################################################################### print() print('***** running training *****') logger.info( f'| batch_size: {args.batch_size} | num_epochs: {args.n_epochs} | num_train: {dataset.train_size()} |' f' num_dev: {dataset.dev_size()} | num_test: {dataset.test_size()}') global_step = 0 best_dev_acc = 0 best_dev_epoch = 0 final_test_acc = 0 try: for epoch in range(int(args.n_epochs)): model.train() tqdm_bar = tqdm(dataset.train(), desc="Training") for qids, labels, *input_data in tqdm_bar: optimizer.zero_grad() batch_loss = 0 bs = labels.size(0) for a in range(0, bs, args.mini_batch_size): b = min(a + args.mini_batch_size, bs) logits = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer) if args.loss == 'margin_rank': num_choice = logits.size(1) flat_logits = logits.view(-1) correct_mask = F.one_hot( labels, 
num_classes=num_choice).view( -1) # of length batch_size*num_choice correct_logits = flat_logits[ correct_mask == 1].contiguous().view(-1, 1).expand( -1, num_choice - 1).contiguous().view( -1) # of length batch_size*(num_choice-1) wrong_logits = flat_logits[ correct_mask == 0] # of length batch_size*(num_choice-1) y = wrong_logits.new_ones((wrong_logits.size(0), )) loss = loss_func(correct_logits, wrong_logits, y) # margin ranking loss elif args.loss == 'cross_entropy': loss = loss_func(logits, labels[a:b]) loss = loss * (b - a) / bs loss.backward() batch_loss += loss.item() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() tqdm_bar.desc = "loss: {:.2e} lr: {:.2e}".format( batch_loss, scheduler.get_lr()[0]) global_step += 1 model.eval() dev_acc = evaluate_accuracy(dataset.dev(), model) test_acc = evaluate_accuracy( dataset.test(), model) if dataset.test_size() > 0 else 0.0 if dev_acc > best_dev_acc: final_test_acc = test_acc best_dev_acc = dev_acc best_dev_epoch = epoch torch.save([model, args], model_path) logger.info( '| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format( epoch, dev_acc, test_acc)) if epoch - best_dev_epoch >= args.max_epochs_before_stop: break except (KeyboardInterrupt, RuntimeError) as e: print(e) print('***** training ends *****') print() logger.info('training ends in {} steps'.format(global_step)) logger.info('best dev acc: {:.4f} (at epoch {})'.format( best_dev_acc, best_dev_epoch)) logger.info('final test acc: {:.4f}'.format(final_test_acc)) print()
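# Note: on current transformers the scheduler calls above need
# num_warmup_steps= / num_training_steps=; warmup_steps= and t_total= are the
# older pytorch-transformers keywords and raise TypeError. Below, a hedged,
# self-contained sketch of the margin-ranking construction used above: each
# correct-choice logit is paired against every wrong-choice logit of the same
# question. Sizes are illustrative.
import torch
import torch.nn.functional as F
from torch import nn

logits = torch.randn(3, 5)                     # [batch, num_choice]
labels = torch.tensor([0, 2, 4])
num_choice = logits.size(1)
flat_logits = logits.view(-1)
correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)
correct_logits = flat_logits[correct_mask == 1].view(-1, 1).expand(
    -1, num_choice - 1).reshape(-1)            # [batch * (num_choice - 1)]
wrong_logits = flat_logits[correct_mask == 0]  # [batch * (num_choice - 1)]
y = wrong_logits.new_ones(wrong_logits.size(0))
loss = nn.MarginRankingLoss(margin=0.1, reduction='mean')(
    correct_logits, wrong_logits, y)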
def train(args): print(args) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.cuda: torch.cuda.manual_seed(args.seed) config_path = os.path.join(args.save_dir, 'config.json') model_path = os.path.join(args.save_dir, 'model.pt') log_path = os.path.join(args.save_dir, 'log.csv') export_config(args, config_path) check_path(model_path) with open(log_path, 'w') as fout: fout.write('step,dev_acc,test_acc\n') ################################################################################################### # Load data # ################################################################################################### cp_emb = [np.load(path) for path in args.ent_emb_paths] cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float) concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1) print('| num_concepts: {} |'.format(concept_num)) # try: if True: if torch.cuda.device_count() >= 2 and args.cuda: device0 = torch.device("cuda:0") device1 = torch.device("cuda:1") elif torch.cuda.device_count() == 1 and args.cuda: device0 = torch.device("cuda:0") device1 = torch.device("cuda:0") else: device0 = torch.device("cpu") device1 = torch.device("cpu") dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj, args.dev_statements, args.dev_adj, args.test_statements, args.test_adj, batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=(device0, device1), model_name=args.encoder, max_node_num=args.max_node_num, max_seq_length=args.max_seq_len, is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subsample=args.subsample, use_cache=args.use_cache) ################################################################################################### # Build model # ################################################################################################### model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num, concept_dim=args.gnn_dim, concept_in_dim=concept_dim, n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num, p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf, pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb, init_range=args.init_range, encoder_config={}) model.encoder.to(device0) model.decoder.to(device1) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] grouped_parameters = [ {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr}, {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr}, ] optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters) if args.lr_schedule == 'fixed': try: scheduler = ConstantLRSchedule(optimizer) except: scheduler = get_constant_schedule(optimizer) elif args.lr_schedule == 'warmup_constant': try: scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps) except: scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps) elif args.lr_schedule == 'warmup_linear': max_steps = 
int(args.n_epochs * (dataset.train_size() / args.batch_size)) try: scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps) except: scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps) print('parameters:') for name, param in model.decoder.named_parameters(): if param.requires_grad: print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device)) else: print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device)) num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) print('\ttotal:', num_params) if args.loss == 'margin_rank': loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean') elif args.loss == 'cross_entropy': loss_func = nn.CrossEntropyLoss(reduction='mean') ################################################################################################### # Training # ################################################################################################### print() print('-' * 71) global_step, best_dev_epoch = 0, 0 best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0 start_time = time.time() model.train() freeze_net(model.encoder) if True: # try: for epoch_id in range(args.n_epochs): if epoch_id == args.unfreeze_epoch: unfreeze_net(model.encoder) if epoch_id == args.refreeze_epoch: freeze_net(model.encoder) model.train() for qids, labels, *input_data in dataset.train(): optimizer.zero_grad() bs = labels.size(0) for a in range(0, bs, args.mini_batch_size): b = min(a + args.mini_batch_size, bs) logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer) if args.loss == 'margin_rank': num_choice = logits.size(1) flat_logits = logits.view(-1) correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1) # of length batch_size*num_choice correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1) # of length batch_size*(num_choice-1) wrong_logits = flat_logits[correct_mask == 0] y = wrong_logits.new_ones((wrong_logits.size(0),)) loss = loss_func(correct_logits, wrong_logits, y) # margin ranking loss elif args.loss == 'cross_entropy': loss = loss_func(logits, labels[a:b]) loss = loss * (b - a) / bs loss.backward() total_loss += loss.item() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) scheduler.step() optimizer.step() if (global_step + 1) % args.log_interval == 0: total_loss /= args.log_interval ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval print('| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch)) total_loss = 0 start_time = time.time() global_step += 1 model.eval() dev_acc = evaluate_accuracy(dataset.dev(), model) save_test_preds = args.save_model if not save_test_preds: test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0 else: eval_set = dataset.test() total_acc = [] count = 0 preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id)) with open(preds_path, 'w') as f_preds: with torch.no_grad(): for qids, labels, *input_data in tqdm(eval_set): count += 1 logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True) predictions = logits.argmax(1) #[bsize, ] preds_ranked = (-logits).argsort(1) #[bsize, n_choices] for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in 
enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)): acc = int(pred.item()==label.item()) print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds) f_preds.flush() total_acc.append(acc) test_acc = float(sum(total_acc))/len(total_acc) print('-' * 71) print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc)) print('-' * 71) with open(log_path, 'a') as fout: fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc)) if dev_acc >= best_dev_acc: best_dev_acc = dev_acc final_test_acc = test_acc best_dev_epoch = epoch_id if args.save_model: torch.save([model, args], model_path +".{}".format(epoch_id)) with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f: for p in model.named_parameters(): print (p, file=f) print(f'model saved to {model_path}') else: if args.save_model: torch.save([model, args], model_path +".{}".format(epoch_id)) with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f: for p in model.named_parameters(): print (p, file=f) print(f'model saved to {model_path}') model.train() start_time = time.time() if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop: break
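# The try/except pairs above bridge the old pytorch-transformers scheduler
# classes and the newer transformers functions. A hedged import-time shim that
# resolves the difference once instead of at every call site (the fallback
# assumes pytorch-transformers is installed). Also note the loop above calls
# scheduler.step() before optimizer.step(); since PyTorch 1.1 the optimizer
# should step first, otherwise the first value of the schedule is skipped.
try:
    from transformers import (get_constant_schedule,
                              get_constant_schedule_with_warmup,
                              get_linear_schedule_with_warmup)
except ImportError:
    from pytorch_transformers import (ConstantLRSchedule,
                                      WarmupConstantSchedule,
                                      WarmupLinearSchedule)

    def get_constant_schedule(optimizer):
        return ConstantLRSchedule(optimizer)

    def get_constant_schedule_with_warmup(optimizer, num_warmup_steps):
        return WarmupConstantSchedule(optimizer, warmup_steps=num_warmup_steps)

    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                        num_training_steps):
        return WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps,
                                    t_total=num_training_steps)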
def train(args): random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.cuda: torch.cuda.manual_seed(args.seed) print('configuration:') print('\n'.join('\t{:15} {}'.format(k + ':', str(v)) for k, v in sorted(dict(vars(args)).items()))) print() config_path = os.path.join(args.save_dir, 'config.json') model_path = os.path.join(args.save_dir, 'model.pt') log_path = os.path.join(args.save_dir, 'log.csv') export_config(args, config_path) check_path(model_path) with open(log_path, 'w') as fout: fout.write('step,train_acc,dev_acc\n') dic = {'transe': 0, 'numberbatch': 1} cp_emb, rel_emb = [np.load(args.ent_emb_paths[dic[source]]) for source in args.ent_emb], np.load(args.rel_emb_path) cp_emb = np.concatenate(cp_emb, axis=1) cp_emb = torch.tensor(cp_emb) rel_emb = np.concatenate((rel_emb, -rel_emb), 0) rel_emb = torch.tensor(rel_emb) concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1) print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim)) relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1) print('num_relations: {}, relation_dim: {}'.format(relation_num, relation_dim)) try: device0 = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") device1 = torch.device("cuda:1" if torch.cuda.is_available() and args.cuda else "cpu") dataset = KagNetDataLoader(args.train_statements, args.train_paths, args.train_graphs, args.dev_statements, args.dev_paths, args.dev_graphs, args.test_statements, args.test_paths, args.test_graphs, batch_size=args.mini_batch_size, eval_batch_size=args.eval_batch_size, device=(device0, device1), model_name=args.encoder, max_seq_length=args.max_seq_len, max_path_len=args.max_path_len, is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, use_cache=args.use_cache, format=args.format) print('dataset done') ################################################################################################### # Build model # ################################################################################################### lstm_config = get_lstm_config_from_args(args) model = LMKagNet(model_name=args.encoder, concept_dim=concept_dim, relation_dim=relation_dim, concept_num=concept_num, relation_num=relation_num, qas_encoded_dim=args.qas_encoded_dim, pretrained_concept_emb=cp_emb, pretrained_relation_emb=rel_emb, lstm_dim=args.lstm_dim, lstm_layer_num=args.lstm_layer_num, graph_hidden_dim=args.graph_hidden_dim, graph_output_dim=args.graph_output_dim, dropout=args.dropout, bidirect=args.bidirect, num_random_paths=args.num_random_paths, path_attention=args.path_attention, qa_attention=args.qa_attention, encoder_config=lstm_config) print('model done') if args.freeze_ent_emb: freeze_net(model.decoder.concept_emb) print('freezed') model.encoder.to(device0) print('encoder done') model.decoder.to(device1) print('decoder done') except RuntimeError as e: print(e) print('best dev acc: 0.0 (at epoch 0)') print('final test acc: 0.0') print() return no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] grouped_parameters = [ {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr}, {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': 
args.decoder_lr}, {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr}, ] optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters) if args.lr_schedule == 'fixed': scheduler = get_constant_schedule(optimizer) elif args.lr_schedule == 'warmup_constant': scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps=args.warmup_steps) elif args.lr_schedule == 'warmup_linear': max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size)) scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps) print('parameters:') for name, param in model.decoder.named_parameters(): if param.requires_grad: print('\t{:45}\ttrainable\t{}'.format(name, param.size())) else: print('\t{:45}\tfixed\t{}'.format(name, param.size())) num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) print('\ttotal:', num_params) if args.loss == 'margin_rank': loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean') elif args.loss == 'cross_entropy': loss_func = nn.CrossEntropyLoss(reduction='mean') print() print('-' * 71) global_step, last_best_step = 0, 0 best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0 start_time = time.time() model.train() freeze_net(model.encoder) try: for epoch_id in range(args.n_epochs): if epoch_id == args.unfreeze_epoch: unfreeze_net(model.encoder) if epoch_id == args.refreeze_epoch: freeze_net(model.encoder) for qids, labels, *input_data in dataset.train(): optimizer.zero_grad() bs = labels.size(0) for a in range(0, bs, args.mini_batch_size): print(00) b = min(a + args.mini_batch_size, bs) # print(11) # # print([x.device if isinstance(x, (torch.tensor,)) else None for x in input_data]) # print(type(input_data[0]), type(input_data[0][0]), input_data[0][0].size()) # print(type(input_data[1]), type(input_data[1][0]), input_data[1][0].size()) # print(type(input_data[2]), type(input_data[2][0]), input_data[2][0].size()) # print(type(input_data[3]), type(input_data[3][0]), input_data[3][0].size()) # print(type(input_data[4]), type(input_data[4][0])) # print(type(input_data[5]), type(input_data[5][0])) # print(type(input_data[6]), type(input_data[6][0])) # print(type(input_data[7]), type(input_data[7][0])) # print(type(input_data[8]), type(input_data[8][0])) # print(type(input_data[9])) # print(type(input_data[10])) logits, _ = model(*[x for x in input_data], layer_id=args.encoder_layer) if args.loss == 'margin_rank': num_choice = logits.size(1) flat_logits = logits.view(-1) correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1) # of length batch_size*num_choice correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1) # of length batch_size*(num_choice-1) wrong_logits = flat_logits[correct_mask == 0] # of length batch_size*(num_choice-1) y = wrong_logits.new_ones((wrong_logits.size(0),)) loss = loss_func(correct_logits, wrong_logits, y) # margin ranking loss elif args.loss == 'cross_entropy': loss = loss_func(logits, labels[a:b]) loss = loss * (b - a) / bs loss.backward() total_loss += loss.item() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) scheduler.step() optimizer.step() if (global_step + 1) % args.log_interval == 0: total_loss /= args.log_interval ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval print('| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} 
|'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch)) total_loss = 0 start_time = time.time() if (global_step + 1) % args.eval_interval == 0: model.eval() dev_acc = evaluate_accuracy(dataset.dev(), model) test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0 print('-' * 71) print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(global_step, dev_acc, test_acc)) print('-' * 71) with open(log_path, 'a') as fout: fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc)) if dev_acc >= best_dev_acc: best_dev_acc = dev_acc final_test_acc = test_acc last_best_step = global_step torch.save([model, args], model_path) print(f'model saved to {model_path}') model.train() start_time = time.time() global_step += 1 # if global_step >= args.max_steps or global_step - last_best_step >= args.max_steps_before_stop: # end_flag = True # break except (KeyboardInterrupt, RuntimeError) as e: print(e) print() print('training ends in {} steps'.format(global_step)) print('best dev acc: {:.4f} (at step {})'.format(best_dev_acc, last_best_step)) print('final test acc: {:.4f}'.format(final_test_acc))
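# The `loss * (b - a) / bs` scaling above makes the accumulated micro-batch
# gradients match those of a single full-batch mean loss. A small
# self-contained check of that equivalence (sizes are illustrative):
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 3)
x, y = torch.randn(8, 4), torch.randint(3, (8,))
loss_func = nn.CrossEntropyLoss(reduction='mean')

model.zero_grad()
loss_func(model(x), y).backward()
g_full = model.weight.grad.clone()          # full-batch gradient

model.zero_grad()
bs, mini_batch_size = 8, 3
for a in range(0, bs, mini_batch_size):
    b = min(a + mini_batch_size, bs)
    loss = loss_func(model(x[a:b]), y[a:b]) * (b - a) / bs
    loss.backward()                          # gradients accumulate across slices
assert torch.allclose(g_full, model.weight.grad, atol=1e-6)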
def train(self, eval_set="val", train_set_name="train"): model_short_name = self.args.base_model_path_or_name.split('/')[-1] load_path = os.path.join(self.args.data_path, f"{model_short_name}-ft-{train_set_name}-data.pt") tokenizer = AutoTokenizer.from_pretrained(self.args.base_model_path_or_name) if os.path.isfile(load_path) and not self.args.override: encoded_train_examples = torch.load(load_path) else: data_path = os.path.join(self.args.data_path, f"{train_set_name}.json") examples = read_jsonl(data_path) if self.args.eda_aug: self.logger.info('start eda augmentation') categories = [] for ex in examples: cates = [] for each in ex["categories"].split(","): cates.append(each.split("-")[-1]) categories.extend(cates) class_dist = Counter(categories) self.logger.info(f'dist of categories before augmentation: {json.dumps(class_dist, indent=2)}') examples, class_dist = self.aug_batch_with_eda(examples, class_dist, aug_target=500) self.logger.info(f'dist of categories after eda augmentation: {json.dumps(class_dist, indent=2)}') self.report_data_stats(examples) encoded_train_examples = self.encode_data(tokenizer, examples) torch.save(encoded_train_examples, load_path) categories = [] for each in encoded_train_examples["categories"]: categories.extend(each.split(",")) self.logger.info(f"the dist of categories (training): {json.dumps(Counter(categories), indent=2)}") self.logger.info( f"the dist of priority (training): {json.dumps(Counter(encoded_train_examples['priority']), indent=2)}") self.logger.info( f"the dist of tweets by events (training): {json.dumps(Counter(encoded_train_examples['events']), indent=2)}") eval_dataset = None if eval_set is not None and os.path.isfile(os.path.join(self.args.data_path, f"{eval_set}.json")): data_path = os.path.join(self.args.data_path, f"{eval_set}.json") examples = read_jsonl(data_path) encoded_eval_examples = self.encode_data(tokenizer, examples) eval_dataset = MyDataset(encoded_eval_examples) self.logger.info( f"the dist of tweets by events (eval): {json.dumps(Counter(encoded_eval_examples['events']), indent=2)}") model = MTLModelForSequenceClassification(self.args.base_model_path_or_name, len(self.cate_classes)) train_dataset = MyDataset(encoded_train_examples) model = self.get_model_by_device(model) train_loader = DataLoader(train_dataset, batch_size=self.args.train_batch_size_per_device * self.device_count, num_workers=1, shuffle=True) total_steps = len(train_loader) * self.args.train_epochs / self.args.accumulation_steps no_decay = ["bias", "LayerNorm.weight"] params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)] params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)] optim_groups = [ {"params": params_decay, "weight_decay": self.args.weight_decay}, {"params": params_nodecay, "weight_decay": 0.0}, ] optimizer = AdamW(optim_groups, lr=self.args.training_lr, eps=1e-8) # optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=self.training_args.pre_train_training_lr, eps=1e-8) if self.args.lr_scheduler == "linear": scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_ratio * total_steps, num_training_steps=total_steps) elif self.args.lr_scheduler == "linearconstant": scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=total_steps) else: scheduler = get_constant_schedule(optimizer) multi_label_loss_fn = nn.BCEWithLogitsLoss() regression_loss_fn = nn.MSELoss() model.train() global_step = 0 eval_loss = 
0 for i in range(self.args.train_epochs): self.logger.info(f"Epoch {i + 1}:") wrap_dataset_loader = tqdm(train_loader) model.zero_grad() total_epoch_loss = 0 for j, batch in enumerate(wrap_dataset_loader): batch.pop("categories") batch.pop("priority") batch.pop("raw_text") batch.pop("events") categories_indices = batch.pop("categories_indices").to(self.device) priority_score = batch.pop("priority_score").to(self.device) inputs = {k: batch[k].to(self.device) for k in batch} classification_logits, regression_logits = model(inputs) classification_loss = multi_label_loss_fn(classification_logits, categories_indices.float()) regression_loss = regression_loss_fn(regression_logits.view(-1).sigmoid(), priority_score.float()) loss = self.args.alpha * classification_loss + (1 - self.args.alpha) * regression_loss total_epoch_loss += loss.item() eval_loss += loss.item() loss.backward() if (j + 1) % self.args.accumulation_steps == 0: # Clip the norm of the gradients to 1.0. torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() model.zero_grad() global_step += 1 wrap_dataset_loader.update(1) wrap_dataset_loader.set_description( f"MTL-Training - epoch {i + 1}/{self.args.train_epochs} iter {j}/{len(wrap_dataset_loader)}: train loss {loss.item():.8f}. lr {scheduler.get_last_lr()[0]:e}") if self.args.eval_steps > 0 and global_step % self.args.eval_steps == 0: self.logger.info( f"\naverage training loss at global_step={global_step}: {eval_loss / self.args.eval_steps}") eval_loss = 0 if eval_dataset is not None: self.logger.info( f"evaluation during training on {eval_set} set ({model_short_name}_epoch{i + 1}): ") self.inference(model, eval_dataset) model.train() self.logger.info(f"Average training loss for epoch {i + 1}: {total_epoch_loss / len(train_loader)}") # evaluate at the end of epoch if eval_steps is smaller than or equal to 0 if self.args.eval_steps <= 0: self.logger.info(f"evaluation during training on {eval_set} set ({model_short_name}_epoch{i + 1}): ") self.inference(model, eval_dataset) model.train() # save up at end of each epoch! # model.save_pretrained(os.path.join(self.args.output_path, "mtl_train", model_short_name, f"epoch_{i + 1}")) # tokenizer.save_pretrained(os.path.join(self.args.output_path, "mtl_train", model_short_name, f"epoch_{i + 1}")) # save up at end of training! save_model_path = os.path.join(self.args.output_path, "mtl_train", model_short_name if not self.args.eda_aug else model_short_name + "-eda", "final_model") if isinstance(model, DataParallel): model.module.save_pretrained(save_model_path) else: model.save_pretrained(save_model_path) tokenizer.save_pretrained(save_model_path) # eval at the final model saved ck return_dict = {} if eval_dataset is not None: self.logger.info(f"evaluation on test set with mtl-trained model: {save_model_path}") return_dict = {f"mtl_train(eval_set={eval_set})": self.inference(model, eval_dataset)} return return_dict
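# The two-headed objective above in isolation: BCE-with-logits over the
# multi-label category vector plus MSE on a sigmoid-squashed priority score,
# mixed by alpha. All sizes and values are illustrative stand-ins.
import torch
from torch import nn

alpha, num_classes, batch = 0.7, 5, 4
classification_logits = torch.randn(batch, num_classes)
regression_logits = torch.randn(batch, 1)
categories_indices = torch.randint(2, (batch, num_classes))
priority_score = torch.rand(batch)

classification_loss = nn.BCEWithLogitsLoss()(classification_logits,
                                             categories_indices.float())
regression_loss = nn.MSELoss()(regression_logits.view(-1).sigmoid(),
                               priority_score.float())
loss = alpha * classification_loss + (1 - alpha) * regression_loss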
def train(args): cudnn.enabled = True cudnn.benchmark = True cudnn.deterministic = True print("torch_version:{}".format(torch.__version__)) print("CUDA_version:{}".format(torch.version.cuda)) print("cudnn_version:{}".format(cudnn.version())) init_seed(123456) data_path = args.base_data_path+args.dataset+'/' tokenizer, vocab2id, id2vocab = bert_tokenizer() detokenizer = bert_detokenizer() print('Vocabulary size', len(vocab2id)) if os.path.exists(data_path + 'train_DukeNet.pkl'): query = torch.load(data_path + 'query_DukeNet.pkl') train_samples = torch.load(data_path + 'train_DukeNet.pkl') passage = torch.load(data_path + 'passage_DukeNet.pkl') print("The number of train_samples:", len(train_samples)) else: samples, query, passage = load_default(args.dataset, data_path + args.dataset + '.answer', data_path + args.dataset + '.passage', data_path + args.dataset + '.pool', data_path + args.dataset + '.qrel', data_path + args.dataset + '.query', tokenizer) if args.dataset == "wizard_of_wikipedia": train_samples, dev_samples, test_seen_samples, test_unseen_samples = split_data(args.dataset, data_path + args.dataset + '.split', samples) print("The number of test_seen_samples:", len(test_seen_samples)) print("The number of test_unseen_samples:", len(test_unseen_samples)) torch.save(test_seen_samples, data_path + 'test_seen_DukeNet.pkl') torch.save(test_unseen_samples, data_path + 'test_unseen_DukeNet.pkl') elif args.dataset == "holl_e": train_samples, dev_samples, test_samples, = split_data(args.dataset, data_path + args.dataset + '.split', samples) print("The number of test_samples:", len(test_samples)) torch.save(test_samples, data_path + 'test_DukeNet.pkl') print("The number of train_samples:", len(train_samples)) print("The number of dev_samples:", len(dev_samples)) torch.save(query, data_path + 'query_DukeNet.pkl') torch.save(passage, data_path + 'passage_DukeNet.pkl') torch.save(train_samples, data_path + 'train_DukeNet.pkl') torch.save(dev_samples, data_path + 'dev_DukeNet.pkl') model = DukeNet(vocab2id, id2vocab, args) saved_model_path = os.path.join(args.base_output_path + args.name + "/", 'model/') if args.resume is True: print("Reading checkpoints...") with open(saved_model_path + "checkpoints.json", 'r', encoding='utf-8') as r: checkpoints = json.load(r) last_epoch = checkpoints["time"][-1] fuse_dict = torch.load(os.path.join(saved_model_path, '.'.join([str(last_epoch), 'pkl']))) model.load_state_dict(fuse_dict["model"]) print('Loading success, last_epoch is {}'.format(last_epoch)) else: init_params(model, "enc") freeze_params(model, "enc") last_epoch = -1 if not os.path.exists(saved_model_path): os.makedirs(saved_model_path) with open(saved_model_path + "checkpoints.json", 'w', encoding='utf-8') as w: checkpoints = {"time": []} json.dump(checkpoints, w) # construct an optimizer object model_optimizer = optim.Adam(model.parameters(), args.lr) # model.parameters() returns an iterator over module parameters; this is typically passed to an optimizer. model_scheduler = get_constant_schedule(model_optimizer) if args.resume is True: model_scheduler.load_state_dict(fuse_dict["scheduler"]) print('Loading scheduler, last_scheduler is', fuse_dict["scheduler"]) trainer = CumulativeTrainer(args.name, model, tokenizer, detokenizer, args.local_rank, accumulation_steps=args.accumulation_steps) model_optimizer.zero_grad() # Clears the gradients of all optimized torch.Tensors.
for i in range(last_epoch+1, args.epoches): if i==5: unfreeze_params(model, "enc") args.train_batch_size = 2 args.accumulation_steps = 16 train_dataset = Dataset(args.mode, train_samples, query, passage, vocab2id, args.max_knowledge_pool_when_train, args.max_knowledge_pool_when_inference, args.context_len, args.knowledge_sentence_len, args.max_dec_length) trainer.train_epoch('train', train_dataset, collate_fn, args.train_batch_size, i, model_optimizer, model_scheduler) del train_dataset trainer.serialize(i, model_scheduler, saved_model_path=saved_model_path)
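# The resume pattern above, reduced to its essentials: scheduler state is
# serialized together with model and optimizer state and restored on resume
# (get_constant_schedule returns a LambdaLR, which supports
# state_dict / load_state_dict). The checkpoint path is an assumed stand-in.
import torch
from transformers import get_constant_schedule

model = torch.nn.Linear(4, 2)
model_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model_scheduler = get_constant_schedule(model_optimizer)

torch.save({'model': model.state_dict(),
            'optimizer': model_optimizer.state_dict(),
            'scheduler': model_scheduler.state_dict()}, '0.pkl')

fuse_dict = torch.load('0.pkl')
model.load_state_dict(fuse_dict['model'])
model_optimizer.load_state_dict(fuse_dict['optimizer'])
model_scheduler.load_state_dict(fuse_dict['scheduler'])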
def main(): parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() data_files = dict(train=data_args.train_file, validation=data_args.validation_file) datasets = load_dataset('csv', data_files=data_files, cache_dir=model_args.cache_dir, delimiter='\t', column_names=['img_id', 'graph', '_', 'text']) # _ is unimportant model, tokenizer = load_model_and_tokenizer(model_args) prefix = "translate graph to text" column_names = datasets["train"].column_names padding = "max_length" if data_args.pad_to_max_length else False def preprocess_function(examples): inputs = [ex for ex in examples['graph']] targets = [ex for ex in examples['text']] inputs = [prefix + inp for inp in inputs] model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) # Setup the tokenizer for targets with tokenizer.as_target_tokenizer(): labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore # padding in the loss. if padding == "max_length" and data_args.ignore_pad_token_for_loss: labels["input_ids"] = [ [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] ] model_inputs["labels"] = labels["input_ids"] return model_inputs if training_args.do_train: train_dataset = datasets["train"] #train_dataset = train_dataset.select(range(1)) if "train" not in datasets: raise ValueError("--do_train requires a train dataset") train_dataset = train_dataset.map( preprocess_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if training_args.do_eval: if "validation" not in datasets: raise ValueError("--do_eval requires a validation dataset") eval_dataset = datasets["validation"] #eval_dataset = eval_dataset.select(range(1)) eval_dataset = eval_dataset.map( preprocess_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id if data_args.pad_to_max_length: data_collator = default_data_collator else: data_collator = DataCollatorForSeq2Seq( tokenizer, model=model, label_pad_token_id=label_pad_token_id, pad_to_multiple_of=8 if training_args.fp16 else None, ) metric = load_metric('sacrebleu') def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds] labels = [label.strip() for label in labels] labels = [[label] for label in labels] return preds, labels def compute_metrics(eval_preds): preds, labels = eval_preds if isinstance(preds, tuple): preds = preds[0] decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) if data_args.ignore_pad_token_for_loss: # Replace -100 in the labels as we can't decode them. 
labels = np.where(labels != -100, labels, tokenizer.pad_token_id) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) # Some simple post-processing decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) result = metric.compute(predictions=decoded_preds, references=decoded_labels) result = {"bleu": result["score"]} prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] result["gen_len"] = np.mean(prediction_lens) result = {k: round(v, 4) for k, v in result.items()} return result # this is the recommended t5 finetuning setup from # https://huggingface.co/transformers/main_classes/optimizer_schedules.html#adafactor-pytorch optimizer = Adafactor( model.parameters(), lr=3e-4, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, relative_step=False, scale_parameter=False, warmup_init=False) """ optimizer = AdamW(lr=3e-5) """ lr_scheduler = transformers.get_constant_schedule(optimizer) trainer = Seq2SeqTrainer( model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, optimizers=(optimizer, lr_scheduler) ) if training_args.do_train: train_result = trainer.train() trainer.save_model() metrics = train_result.metrics max_train_samples = len(train_dataset) metrics["train_samples"] = len(train_dataset) trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() if training_args.do_eval: metrics = trainer.evaluate( max_length=data_args.max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval" ) max_val_samples = len(eval_dataset) metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
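# The Adafactor + constant-schedule pairing above, in isolation: with
# relative_step=False Adafactor takes the fixed external learning rate, so the
# constant schedule is effectively a no-op placeholder that satisfies the
# Trainer's optimizers=(optimizer, lr_scheduler) interface. The model here is
# a stand-in.
import torch
import transformers
from transformers.optimization import Adafactor

model = torch.nn.Linear(8, 8)
optimizer = Adafactor(model.parameters(), lr=3e-4, eps=(1e-30, 1e-3),
                      clip_threshold=1.0, decay_rate=-0.8, beta1=None,
                      weight_decay=0.0, relative_step=False,
                      scale_parameter=False, warmup_init=False)
lr_scheduler = transformers.get_constant_schedule(optimizer)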
def start(self, model, data, evaluation): """ :param model: :type model: BertWrapperModel :param data: :type data: MultiData :type evaluation: Evaluation """ self.prepare_training(model, data) if hasattr(model.bert.config, 'adapter_attention'): self.logger.info( "Adapter attention detected. Freezing all weights except the adapter attention" ) for param in model.bert.bert.parameters(): param.requires_grad = False model.bert.bert.enable_adapters(unfreeze_adapters=False, unfreeze_attention=True) elif hasattr(model.bert.config, 'adapters'): self.logger.info( "Adapters detected. Freezing all weights except the adapters") for param in model.bert.bert.parameters(): param.requires_grad = False model.bert.bert.enable_adapters(unfreeze_adapters=True, unfreeze_attention=False) if self.config.get('freeze_bert', False): self.logger.warn('FREEZING BERT') for name, param in model.bert.bert.named_parameters(): self.logger.warn('freeze {}'.format(name)) param.requires_grad = False if self.config.get('freeze_head', False): self.logger.info("Freezing the weights of the classification head") for name, param in model.bert.lin_layer.named_parameters(): param.requires_grad = False # for param in model.bert.parameters(): # if param.requires_grad: # print(param) # Prepare BERT optimizer param_optimizer = list(model.bert.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': self.config.get('weight_decay', 0.0) }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] t_total = self.get_n_batches() * self.n_epochs num_warmup_steps = int(self.get_n_batches() * self.config.get('warmup_proportion', 0.1)) optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.get('learning_rate', 5e-5), eps=self.config.get('adam_epsilon', 1e-8)) # correct_bias=False) if num_warmup_steps > 0: scheduler = transformers.get_constant_schedule_with_warmup( optimizer, num_warmup_steps) else: # scheduler = WarmupConstantSchedule(optimizer=optimizer, warmup_steps=num_warmup_steps) scheduler = transformers.get_constant_schedule(optimizer) self.state.load(model.bert, optimizer, weights='last') start_epoch = self.state.recorded_epochs + 1 end_epoch = self.n_epochs + 1 if self.state.recorded_epochs > 0: self.logger.info( 'Loaded the weights of last epoch {} with valid score={}'. 
format(self.state.recorded_epochs, self.state.scores[-1])) # if start_epoch < end_epoch and not self.config.get('skip_restore_validation', False): # self.logger.info('Now calculating validation score (to verify the restoring success)') # valid_score = list(evaluation.start(model, data, valid_only=True)[0].values())[0] # self.logger.info('Score={:.4f}'.format(valid_score)) self.logger.info('Running from epoch {} to epoch {}'.format( start_epoch, end_epoch - 1)) global_step = self.get_n_batches() * self.state.recorded_epochs for epoch in range(start_epoch, end_epoch): self.logger.info('Epoch {}/{}'.format(epoch, self.n_epochs)) self.logger.debug('Preparing epoch') self.prepare_next_epoch(model, data, epoch) bar = self.create_progress_bar('loss') train_losses = [] # used to calculate the epoch train loss recent_train_losses = [] # used to calculate the display loss self.logger.debug('Training') self.logger.debug('{} minibatches with size {}'.format( self.get_n_batches(), self.batchsize)) for _ in bar(range(int(self.get_n_batches()))): # self.global_step += self.batchsize train_examples = self.get_next_batch(model, data) batch_loss = 0 batch_steps = int( np.ceil( len(train_examples) / (self.batchsize / self.gradient_accumulation_steps))) for i in range(batch_steps): step_size = self.batchsize // self.gradient_accumulation_steps step_examples = train_examples[i * step_size:(i + 1) * step_size] step_loss = self.get_loss(model, step_examples, self.adapter_task) step_loss = step_loss / self.gradient_accumulation_steps step_loss.backward() batch_loss += step_loss.item() if self.config.get('max_grad_norm', 0) > 0: torch.nn.utils.clip_grad_norm_( model.bert.parameters(), self.config.get('max_grad_norm')) # Gradient clipping is not in AdamW anymore (so you can use amp without issue) optimizer.step() scheduler.step(epoch) optimizer.zero_grad() global_step += 1 self.tensorboard.add_scalars('scores', {'train_loss': batch_loss}, global_step=global_step) recent_train_losses = ([batch_loss] + recent_train_losses)[:20] train_losses.append(recent_train_losses[0]) bar.dynamic_messages['loss'] = np.mean(recent_train_losses) self.logger.info('train loss={:.6f}'.format(np.mean(train_losses))) if self.config.get('evaluate_dev', True): self.logger.info('Now calculating validation score') valid_score, valid_score_other_measures = evaluation.start( model, data, valid_only=True) valid_score = list(valid_score.values( ))[0] # get only the dev split (there wont be any other split) valid_score_other_measures = list( valid_score_other_measures.values())[0] self.tensorboard.add_scalar('valid_score', valid_score, global_step=global_step) self.tensorboard.add_scalars( 'scores', dict([('valid_' + k, v) for k, v, in valid_score_other_measures.items()]), global_step=global_step) for key, value in valid_score_other_measures.items(): self.tensorboard.add_scalar(key, value, global_step=global_step) self.state.record(model.bert, optimizer if self.chkpt_optimizer else None, valid_score, self.backup_checkpoint_every) else: self.logger.info('Not validating dev. Setting score to epoch') self.state.record(model.bert, optimizer if self.chkpt_optimizer else None, epoch, self.backup_checkpoint_every) return self.state.best_epoch, self.state.best_score
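# The warmup fallback above, standalone: warm up over a proportion of one
# epoch's batches, else fall back to a plain constant schedule. Note that the
# loop above calls scheduler.step(epoch) every batch; passing the epoch index
# pins last_epoch to it, so a warmup schedule would advance per epoch rather
# than per optimizer step. A plain scheduler.step() is the usual call. Values
# below are assumed stand-ins.
import torch
import transformers

n_batches = 500
warmup_proportion = 0.1
num_warmup_steps = int(n_batches * warmup_proportion)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
if num_warmup_steps > 0:
    scheduler = transformers.get_constant_schedule_with_warmup(
        optimizer, num_warmup_steps)
else:
    scheduler = transformers.get_constant_schedule(optimizer)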
def train(args, train_dataset, model, tokenizer): """ Train the model """ if args.local_rank in [-1, 0]: tb_writer = SummaryWriterP(args.output_dir) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay) ], 'weight_decay': args.weight_decay }, { 'params': [ p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay) ], 'weight_decay': 0.0 }] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) warmup_steps = args.warmup_samples // args.train_batch_size if args.lr_decay: scheduler = get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) else: # get_constant_schedule takes no warmup argument; use the _with_warmup variant scheduler = get_constant_schedule_with_warmup( optimizer, num_warmup_steps=warmup_steps) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w.
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) try: with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c: global_step = int(c.readline()) except OSError as e: global_step = 0 tr_loss, logging_loss = 0.0, 0.0 moving_loss = MovingLoss(10000 // args.logging_steps) model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed( args) # Added here for reproducibility (even between python 2 and 3) try: for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): inputs, labels = mask_tokens( batch, tokenizer, args) if args.mlm else (batch, batch) inputs = inputs.to(args.device) labels = labels.to(args.device) model.train() outputs = model( inputs, masked_lm_labels=labels) if args.mlm else model( inputs, labels=labels) loss = outputs[ 0] # model outputs are always tuple in pytorch-transformers (see doc) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() moving_loss.add(loss.item()) if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 # Log metrics if args.local_rank == -1 and args.evaluate_during_training and global_step % args.eval_steps == 0: # Only evaluate when single GPU otherwise metrics may not average well results = evaluate(args, model, tokenizer, f"checkpoint-{global_step}") for key, value in results.items(): tb_writer.add_scalar('eval_{}'.format(key), value, global_step) if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) logging_loss = tr_loss epoch_iterator.set_postfix( MovingLoss=f'{moving_loss.loss:.2f}', Perplexity= f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}') if args.local_rank in [ -1, 0 ] and args.save_steps > 0 and global_step % args.save_steps == 0: # Save model checkpoint save_state(args, model, tokenizer, global_step) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break print_sample(model, tokenizer, args.device, args) if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break except (KeyboardInterrupt, SystemExit): save_state(args, model, tokenizer, global_step) raise if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
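The corrected scheduler selection above can be isolated into a small helper. A sketch under the current transformers API; the helper name and its arguments are illustrative, only the transformers calls are real:

import transformers

def build_scheduler(optimizer, lr_decay, warmup_samples, train_batch_size, t_total):
    # Warmup length derived from a sample budget, as in the snippet above.
    warmup_steps = warmup_samples // train_batch_size
    if lr_decay:
        return transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=t_total)
    # get_constant_schedule has no warmup parameter; the _with_warmup variant does.
    return transformers.get_constant_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps)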
def train(args, training_features, model, tokenizer): """ Train the model """ wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=args, name=args.run_name) wandb.watch(model) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) else: amp = None # model recover recover_step = utils.get_max_epoch_model(args.output_dir) # if recover_step: # model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step)) # logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint) # model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu') # optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step)) # checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu') # checkpoint_state_dict['model'] = model_state_dict # else: checkpoint_state_dict = None model.to(args.device) model, optimizer = prepare_for_training(args, model, checkpoint_state_dict, amp=amp) if args.n_gpu == 0 or args.no_cuda: per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps else: per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps train_batch_size = per_node_train_batch_size * ( torch.distributed.get_world_size() if args.local_rank != -1 else 1) global_step = recover_step if recover_step else 0 if args.num_training_steps == -1: args.num_training_steps = int(args.num_training_epochs * len(training_features) / train_batch_size) if args.warmup_portion: args.num_warmup_steps = int(args.warmup_portion * args.num_training_steps) if args.scheduler == "linear": scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.num_warmup_steps, num_training_steps=args.num_training_steps, last_epoch=-1) elif args.scheduler == "constant": scheduler = get_constant_schedule(optimizer, last_epoch=-1) elif args.scheduler == "1cycle": scheduler = OneCycleLR(optimizer, max_lr=args.learning_rate, total_steps=args.num_training_steps, pct_start=args.warmup_portion, anneal_strategy=args.anneal_strategy, final_div_factor=1e4, last_epoch=-1) else: raise ValueError("unknown scheduler: {}".format(args.scheduler)) if checkpoint_state_dict: scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"]) train_dataset = utils.Seq2seqDatasetForBert( features=training_features, max_source_len=args.max_source_seq_length, max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size, cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id, mask_id=tokenizer.mask_token_id, random_prob=args.random_prob, keep_prob=args.keep_prob, offset=train_batch_size * global_step, num_training_instances=train_batch_size * args.num_training_steps, ) logger.info("Check dataset:") for i in range(5): source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__( i) logger.info("Instance-%d" % i) logger.info("Source tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(source_ids))) logger.info("Target tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(target_ids))) logger.info("Model = %s" % str(model)) # Train!
logger.info(" ***** Running training ***** *") logger.info(" Num examples = %d", len(training_features)) logger.info(" Num Epochs = %.2f", len(train_dataset) / len(training_features)) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info(" Batch size per node = %d", per_node_train_batch_size) logger.info( " Total train batch size (w. parallel, distributed & accumulation) = %d", train_batch_size) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", args.num_training_steps) if args.num_training_steps <= global_step: logger.info( "Training is done. Please use a new dir or clean this dir!") else: # The training features are shuffled train_sampler = SequentialSampler(train_dataset) \ if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False) train_dataloader = DataLoader( train_dataset, sampler=train_sampler, batch_size=per_node_train_batch_size // args.gradient_accumulation_steps, collate_fn=utils.batch_list_to_batch_tensors) train_iterator = tqdm.tqdm(train_dataloader, initial=global_step, desc="Iter (loss=X.XXX, lr=X.XXXXXXX)", disable=args.local_rank not in [-1, 0]) model.train() model.zero_grad() tr_loss, logging_loss = 0.0, 0.0 for step, batch in enumerate(train_iterator): batch = tuple(t.to(args.device) for t in batch) inputs = { 'source_ids': batch[0], 'target_ids': batch[1], 'pseudo_ids': batch[2], 'num_source_tokens': batch[3], 'num_target_tokens': batch[4] } loss = model(**inputs) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel (not distributed) training train_iterator.set_description( 'Iter (loss=%5.3f) lr=%9.7f' % (loss.item(), scheduler.get_last_lr()[0])) if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() logging_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: wandb.log( { 'lr': scheduler.get_last_lr()[0], 'loss': logging_loss / args.logging_steps }, step=global_step) logger.info(" Step [%d ~ %d]: %.2f", global_step - args.logging_steps, global_step, logging_loss) logging_loss = 0.0 if args.local_rank in [-1, 0] and args.save_steps > 0 and \ (global_step % args.save_steps == 0 or global_step == args.num_training_steps): save_path = os.path.join(args.output_dir, "ckpt-%d" % global_step) os.makedirs(save_path, exist_ok=True) model_to_save = model.module if hasattr( model, "module") else model model_to_save.save_pretrained(save_path) # optim_to_save = { # "optimizer": optimizer.state_dict(), # "lr_scheduler": scheduler.state_dict(), # } # if args.fp16: # optim_to_save["amp"] = amp.state_dict() # torch.save( # optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step))) logger.info("Saving model checkpoint %d into %s", global_step, save_path) wandb.save(f'{save_path}/*')
def train(args): random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available() and args.cuda: torch.cuda.manual_seed(args.seed) print('configuration:') print('\n'.join('\t{:15} {}'.format(k + ':', str(v)) for k, v in sorted(dict(vars(args)).items()))) print() config_path = os.path.join(args.save_dir, 'config.json') model_path = os.path.join(args.save_dir, 'model.pt') log_path = os.path.join(args.save_dir, 'log.csv') if args.save: export_config(args, config_path) check_path(model_path) with open(log_path, 'w') as fout: fout.write('step,train_acc,dev_acc\n') ################################################################################################### # Load data # ################################################################################################### cp_emb = [np.load(path) for path in args.ent_emb_paths] cp_emb = torch.tensor(np.concatenate(cp_emb, 1)) concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1) print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim)) device = torch.device( "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") dataset = GconAttnDataLoader( train_statement_path=args.train_statements, train_concept_jsonl=args.train_concepts, dev_statement_path=args.dev_statements, dev_concept_jsonl=args.dev_concepts, test_statement_path=args.test_statements, test_concept_jsonl=args.test_concepts, concept2id_path=args.cpnet_vocab_path, batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=device, model_name=args.encoder, max_cpt_num=max_cpt_num[args.dataset], max_seq_length=args.max_seq_len, is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subsample=args.subsample, format=args.format) print('len(train_set): {} len(dev_set): {} len(test_set): {}'.format( dataset.train_size(), dataset.dev_size(), dataset.test_size())) print() ################################################################################################### # Build model # ################################################################################################### lstm_config = get_lstm_config_from_args(args) model = LMGconAttn(model_name=args.encoder, concept_num=concept_num, concept_dim=args.cpt_out_dim, concept_in_dim=concept_dim, freeze_ent_emb=args.freeze_ent_emb, pretrained_concept_emb=cp_emb, hidden_dim=args.decoder_hidden_dim, dropout=args.dropoutm, encoder_config=lstm_config) if args.freeze_ent_emb: freeze_net(model.decoder.concept_emb) try: model.to(device) except RuntimeError as e: print(e) print('best dev acc: 0.0 (at epoch 0)') print('final test acc: 0.0') print() return no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] grouped_parameters = [ { 'params': [ p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay) ], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr }, { 'params': [ p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': args.encoder_lr }, { 'params': [ p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay) ], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr }, { 'params': [ p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': args.decoder_lr }, ] optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters) if args.lr_schedule == 'fixed': scheduler = get_constant_schedule(optimizer) elif args.lr_schedule == 'warmup_constant': scheduler = 
get_constant_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps) elif args.lr_schedule == 'warmup_linear': max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size)) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps) print('parameters:') for name, param in model.decoder.named_parameters(): if param.requires_grad: print('\t{:45}\ttrainable\t{}'.format(name, param.size())) else: print('\t{:45}\tfixed\t{}'.format(name, param.size())) num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) print('\ttotal:', num_params) if args.loss == 'margin_rank': loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean') elif args.loss == 'cross_entropy': loss_func = nn.CrossEntropyLoss(reduction='mean') ################################################################################################### # Training # ################################################################################################### print('-' * 71) global_step, best_dev_epoch = 0, 0 best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0 start_time = time.time() model.train() freeze_net(model.encoder) try: for epoch_id in range(args.n_epochs): if epoch_id == args.unfreeze_epoch: unfreeze_net(model.encoder) if epoch_id == args.refreeze_epoch: freeze_net(model.encoder) model.train() for qids, labels, *input_data in dataset.train(): optimizer.zero_grad() bs = labels.size(0) for a in range(0, bs, args.mini_batch_size): b = min(a + args.mini_batch_size, bs) logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer) if args.loss == 'margin_rank': num_choice = logits.size(1) flat_logits = logits.view(-1) correct_mask = F.one_hot( labels[a:b], num_classes=num_choice).view( -1) # of length minibatch_size*num_choice; mask must match the logits slice correct_logits = flat_logits[ correct_mask == 1].contiguous().view(-1, 1).expand( -1, num_choice - 1).contiguous().view( -1) # of length minibatch_size*(num_choice-1) wrong_logits = flat_logits[ correct_mask == 0] # of length minibatch_size*(num_choice-1) y = wrong_logits.new_ones((wrong_logits.size(0), )) loss = loss_func(correct_logits, wrong_logits, y) # margin ranking loss elif args.loss == 'cross_entropy': loss = loss_func(logits, labels[a:b]) loss = loss * (b - a) / bs loss.backward() total_loss += loss.item() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # optimizer first, then the scheduler if (global_step + 1) % args.log_interval == 0: total_loss /= args.log_interval ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval print( '| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |' .format(global_step, scheduler.get_last_lr()[0], total_loss, ms_per_batch)) total_loss = 0 start_time = time.time() global_step += 1 model.eval() dev_acc = evaluate_accuracy(dataset.dev(), model) test_acc = evaluate_accuracy( dataset.test(), model) if args.test_statements else 0.0 print('-' * 71) print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format( global_step, dev_acc, test_acc)) print('-' * 71) if args.save: with open(log_path, 'a') as fout: fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc)) if dev_acc >= best_dev_acc: best_dev_acc = dev_acc final_test_acc = test_acc best_dev_epoch = epoch_id if args.save: torch.save([model, args], model_path) print(f'model saved to {model_path}') model.train() start_time = time.time() if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop: break except
(KeyboardInterrupt, RuntimeError) as e: print(e) print() print('training ends in {} steps'.format(global_step)) print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc, best_dev_epoch)) print('final test acc: {:.4f}'.format(final_test_acc)) print()
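Because the grouped parameters in the snippet above carry their own 'lr' entries, every schedule returned by get_constant_schedule and its relatives scales each group's base LR by one shared multiplicative factor. A standalone sketch with toy encoder/decoder modules (the modules and LR values are stand-ins) showing the per-group values:

import torch
import transformers

encoder = torch.nn.Linear(16, 16)
decoder = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 2e-5},  # plays the role of encoder_lr
    {"params": decoder.parameters(), "lr": 1e-3},  # plays the role of decoder_lr
])
scheduler = transformers.get_constant_schedule_with_warmup(
    optimizer, num_warmup_steps=4)

for step in range(6):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # one LR per parameter group
# Both groups warm up with the same factor but toward different base LRs.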
def train(args, train_dataset, model, tokenizer): """ Train the model """ if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) # Allow for different learning rate for final layers final_layers = [ 'span_outputs.weight', 'span_outputs.bias', 'type_output.weight', 'type_output.bias' ] if args.final_layers_lr == -1.0: args.final_layers_lr = args.learning_rate if args.final_layers_wd == -1.0: args.final_layers_wd = args.weight_decay final_layer_params = [(n, p) for n, p in model.named_parameters() if n in final_layers] non_final_layer_params = [(n, p) for n, p in model.named_parameters() if n not in final_layers] no_decay = ['bias', 'LayerNorm.weight'] final_layer_decaying_params = [ p for n, p in final_layer_params if not any(nd in n for nd in no_decay) ] final_layer_nondecaying_params = [ p for n, p in final_layer_params if any(nd in n for nd in no_decay) ] non_final_layer_decaying_params = [ p for n, p in non_final_layer_params if not any(nd in n for nd in no_decay) ] non_final_layer_nondecaying_params = [ p for n, p in non_final_layer_params if any(nd in n for nd in no_decay) ] optimizer_grouped_parameters = [ { 'params': final_layer_decaying_params, 'lr': args.final_layers_lr, 'weight_decay': args.final_layers_wd }, { 'params': final_layer_nondecaying_params, 'lr': args.final_layers_lr, 'weight_decay': 0.0 }, { 'params': non_final_layer_decaying_params, 'lr': args.learning_rate, 'weight_decay': args.weight_decay }, { 'params': non_final_layer_nondecaying_params, 'lr': args.learning_rate, 'weight_decay': 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) # Allow choice between lr schedules if args.constant_lr and args.warmup_steps == 0: scheduler = get_constant_schedule(optimizer) elif args.constant_lr and args.warmup_steps > 0: scheduler = get_constant_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps) else: scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 1 tr_loss, logging_loss = 0.0, 0.0 model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed( args) # Added here for reproducibility (even between python 2 and 3) for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): model.train() inputs = { 'input_ids': batch['input_ids'].to(args.device), 'attention_mask': batch['attention_mask'].to(args.device), 'token_type_ids': batch['token_type_ids'].to(args.device), 'start_positions': batch['start_positions'].to(args.device), 'end_positions': batch['end_positions'].to(args.device), 'instance_types': batch['instance_types'].to(args.device) } outputs = model(**inputs) loss = outputs[ 0] # model outputs are always tuple in transformers (see doc) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel (not distributed) training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: # Log metrics if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well results = evaluate(args, model, tokenizer, dataset_type='dev', prefix=str(global_step)) for key, value in results.items(): tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar('lr', scheduler.get_last_lr()[2], global_step) tb_writer.add_scalar('lr_final_layers', scheduler.get_last_lr()[0], global_step) # param groups 0/1 are the final layers, 2/3 the remaining parameters, so index 2 is the main LR tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) logging_loss = tr_loss if args.local_rank in [ -1, 0 ] and args.save_steps > 0 and global_step % args.save_steps == 0: # Save model checkpoint output_dir = os.path.join( args.output_dir, 'checkpoint-{}'.format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) model_to_save = model.module if hasattr( model, 'module' ) else model # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, 'training_args.bin')) logger.info("Saving model checkpoint to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
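One caveat when logging: get_lr() recomputes the multipliers and, on recent PyTorch, warns when called outside of step(), while get_last_lr() returns the values most recently applied, one per parameter group, which is what the TensorBoard calls above want. A minimal sketch of that logging loop (the toy model and step counts are my own):

import torch
from torch.utils.tensorboard import SummaryWriter
import transformers

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100)
tb_writer = SummaryWriter()

for global_step in range(1, 101):
    optimizer.step()
    scheduler.step()
    # One scalar per logged step; index 0 is the first (here, only) param group.
    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
tb_writer.close()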