def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset`` with AdamW and linear warmup/decay.

    The function mutates ``args`` (``train_batch_size``, ``warmup_steps`` and,
    when ``max_steps > 0``, ``num_train_epochs``), optionally restores
    optimizer/scheduler state found next to ``args.model_name_or_path``, and
    periodically evaluates and writes ``checkpoint-<global_step>`` directories
    under ``args.output_dir``.

    Args:
        args: Run configuration namespace (batch sizes, learning rate, fp16,
            distributed rank, logging/saving intervals, adversarial options).
        train_dataset: Dataset handed to ``DataLoader``; each collated batch is
            indexed as ``[0]`` input ids, ``[1]`` attention mask,
            ``[2]`` token type ids, ``[3]`` labels.
        model: Model whose call returns a tuple with the loss at index 0.
        tokenizer: Tokenizer; its vocabulary is saved with every checkpoint.

    Returns:
        ``(global_step, tr_loss / global_step)``; the second element is ``0.0``
        when no optimizer update was performed.

    Raises:
        ValueError: If the dataloader has fewer batches than
            ``args.gradient_accumulation_steps`` (no update could ever run).
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    # Number of optimizer updates one pass over the dataloader produces.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if updates_per_epoch < 1:
        # Previously this surfaced as a ZeroDivisionError (immediately, or
        # only after a full run that never updated the weights).
        raise ValueError(
            "train_dataloader yields %d batches, fewer than "
            "gradient_accumulation_steps=%d; no optimizer step would run" %
            (len(train_dataloader), args.gradient_accumulation_steps))

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    decayed_params, exempt_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            exempt_params.append(param)
        else:
            decayed_params.append(param)
    optimizer_grouped_parameters = [
        {
            "params": decayed_params,
            "weight_decay": args.weight_decay,
        },
        {
            "params": exempt_params,
            "weight_decay": 0.0,
        },
    ]
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    optimizer_state_path = os.path.join(args.model_name_or_path,
                                        "optimizer.pt")
    scheduler_state_path = os.path.join(args.model_name_or_path,
                                        "scheduler.pt")
    if os.path.isfile(optimizer_state_path) and os.path.isfile(
            scheduler_state_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_state_path))
        scheduler.load_state_dict(torch.load(scheduler_state_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    batches_to_skip = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path
                      ) and "checkpoint" in args.model_name_or_path:
        checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
            "/")[0]
        try:
            # set global_step to global_step of last saved checkpoint from model path
            global_step = int(checkpoint_suffix)
        except ValueError:
            # A directory such as "checkpoint-bestf" carries no step count.
            logger.info(
                " Checkpoint suffix %r is not a step count, starting from global step 0",
                checkpoint_suffix)
        else:
            epochs_trained = global_step // updates_per_epoch
            updates_in_current_epoch = global_step % updates_per_epoch
            # The resume offset is counted in optimizer updates, while the
            # loop below skips individual batches.
            batches_to_skip = (updates_in_current_epoch *
                               args.gradient_accumulation_steps)
            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d",
                        global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        updates_in_current_epoch)

    tr_loss = 0.0
    if args.do_adv:
        fgm = FGM(model, emb_name=args.adv_name, epsilon=args.adv_epsilon)
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproducibility (even between python 2 and 3)
    # Start from epochs_trained so a resumed run does not repeat finished epochs.
    for _ in range(epochs_trained, int(args.num_train_epochs)):
        if args.max_steps > 0 and global_step >= args.max_steps:
            break
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # Skip past any already trained steps if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                # XLM and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = (batch[2] if args.model_type
                                            in ["bert", "xlnet"] else None)
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.do_adv:
                # Second forward/backward pass on FGM-perturbed weights; its
                # gradients add to the ones accumulated above.
                # NOTE(review): this loss is neither divided by
                # gradient_accumulation_steps nor amp-scaled - confirm intent.
                fgm.attack()
                loss_adv = model(**inputs)[0]
                if args.n_gpu > 1:
                    loss_adv = loss_adv.mean()
                loss_adv.backward()
                fgm.restore()
            pbar(step, {'loss': loss.item()})
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # The optimizer must step before the scheduler; the reverse
                # order skips the first value of the learning rate schedule.
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    print(" ")
                    if args.local_rank == -1:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        evaluate(args, model, tokenizer)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                if args.max_steps > 0 and global_step >= args.max_steps:
                    # The schedule was sized for max_steps updates; stop here.
                    break
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    mean_loss = tr_loss / global_step if global_step else 0.0
    return global_step, mean_loss
def _split_by_weight_decay(named_params, lr, weight_decay):
    """Return two optimizer groups (decayed, decay-exempt) sharing one learning rate."""
    no_decay = ("bias", "LayerNorm.weight")
    decayed, exempt = [], []
    for name, param in named_params:
        if any(nd in name for nd in no_decay):
            exempt.append(param)
        else:
            decayed.append(param)
    return [
        {
            'params': decayed,
            'weight_decay': weight_decay,
            'lr': lr
        },
        {
            'params': exempt,
            'weight_decay': 0.0,
            'lr': lr
        },
    ]


def _build_param_groups(args, model):
    """Build the AdamW parameter groups for the architecture named by ``args.model_encdec``."""
    # model_encdec -> ((head sub-module attribute, args attribute with its lr), ...)
    head_specs = {
        'bert2crf': (('crf', 'crf_learning_rate'),
                     ('classifier', 'crf_learning_rate')),
        'bert2gru': (('decoder', 'crf_learning_rate'),
                     ('clsdense', 'crf_learning_rate')),
        'bert2soft': (('classifier', 'crf_learning_rate'), ),
        'multi2point': (('pointer', 'point_learning_rate'), ),
    }
    if args.model_encdec not in head_specs:
        raise ValueError("Unsupported model_encdec %r; expected one of %s" %
                         (args.model_encdec, sorted(head_specs)))
    groups = _split_by_weight_decay(model.bert.named_parameters(),
                                    args.learning_rate, args.weight_decay)
    for module_attr, lr_attr in head_specs[args.model_encdec]:
        groups.extend(
            _split_by_weight_decay(
                getattr(model, module_attr).named_parameters(),
                getattr(args, lr_attr), args.weight_decay))
    return groups


def _save_best_checkpoint(args, model, tokenizer, optimizer, scheduler,
                          output_dir):
    """Write model, training args, vocabulary, optimizer and scheduler state to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    # Take care of distributed/parallel training
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)
    tokenizer.save_vocabulary(output_dir)
    torch.save(optimizer.state_dict(), os.path.join(output_dir,
                                                    "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir,
                                                    "scheduler.pt"))
    logger.info("Saving optimizer and scheduler states to %s", output_dir)


def train(args, train_features, model, tokenizer, use_crf):
    """Train a BERT-based tagging/pointer model and keep the best dev checkpoint.

    The encoder (``model.bert``) is optimized with ``args.learning_rate``; the
    task heads use ``args.crf_learning_rate`` (``bert2crf``, ``bert2gru``,
    ``bert2soft``) or ``args.point_learning_rate`` (``multi2point``). Every
    ``args.logging_steps`` updates (single-process runs only) the model is
    evaluated on "dev"; when ``span_f`` improves, "test" is evaluated too and
    the model is saved to ``<args.output_dir>/checkpoint-bestf``.

    Args:
        args: Run configuration namespace; ``warmup_steps`` is written to it.
        train_features: Sized sequence of training features. It is copied, so
            the caller's list is neither extended nor shuffled.
        model: Model exposing ``bert`` plus the head sub-modules required by
            ``args.model_encdec``; its call returns a tuple with the loss first.
        tokenizer: Tokenizer used for evaluation, dev loading and checkpoints.
        use_crf: Forwarded to ``batch_generator`` and ``evaluate``.

    Returns:
        ``(global_step, tr_loss / global_step, test_results)``. The mean loss
        is ``0.0`` when no update ran; ``test_results`` is ``{}`` when no
        evaluation ever improved on the initial best score.

    Raises:
        ValueError: If ``args.model_encdec`` is not a supported architecture.
    """
    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = _build_param_groups(args, model)

    # Work on a private copy: the loop below extends and shuffles the list.
    train_features = list(train_features)

    t_total = len(train_features) // args.batch_size * args.num_train_epochs
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    optimizer_state_path = os.path.join(args.model_name_or_path,
                                        "optimizer.pt")
    scheduler_state_path = os.path.join(args.model_name_or_path,
                                        "scheduler.pt")
    if os.path.isfile(optimizer_state_path) and os.path.isfile(
            scheduler_state_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_state_path))
        scheduler.load_state_dict(torch.load(scheduler_state_path))

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_features))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.batch_size *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    best_spanf = -1
    test_results = {}
    num_epochs = int(args.num_train_epochs)
    is_pointer_model = args.model_encdec == 'multi2point'
    best_dir = os.path.join(args.output_dir, "checkpoint-bestf")

    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproducibility (even between python 2 and 3)
    for ep in range(num_epochs):
        pbar = ProgressBar(n_total=len(train_features) // args.batch_size,
                           desc='Training')
        if ep == num_epochs - 1:
            # The final epoch also trains on the dev split.
            # NOTE(review): dev evaluations during this epoch score data the
            # model is being trained on, and t_total above does not account
            # for the extra batches - confirm this is intended.
            eval_features = load_and_cache_examples(args,
                                                    args.data_type,
                                                    tokenizer,
                                                    data_type='dev')
            train_features.extend(eval_features)
        batches = batch_generator(features=train_features,
                                  batch_size=args.batch_size,
                                  use_crf=use_crf,
                                  answer_seq_len=args.answer_seq_len)
        for step, batch in enumerate(batches):
            model.train()
            # Only the leading batch entries are model inputs; they are
            # addressed by position and moved to the training device.
            if is_pointer_model:
                batch_inputs = tuple(t.to(args.device) for t in batch[0:5])
                inputs = {
                    "input_ids": batch_inputs[0],
                    "attention_mask": batch_inputs[1],
                    "token_type_ids": batch_inputs[2],
                    "span_label": batch_inputs[4],
                    "testing": False
                }
            else:
                batch_inputs = tuple(t.to(args.device) for t in batch[0:6])
                inputs = {
                    "input_ids": batch_inputs[0],
                    "attention_mask": batch_inputs[1],
                    "token_type_ids": batch_inputs[2],
                    "context_mask": batch_inputs[5],
                    "labels": batch_inputs[3],
                    "testing": False
                }
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            loss.backward()
            if step % 15 == 0:
                pbar(step, {'epoch': ep, 'loss': loss.item()})
            tr_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            # The optimizer must step before the scheduler; the reverse order
            # skips the first value of the learning rate schedule.
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.local_rank in [
                    -1, 0
            ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Log metrics
                print("start evalue")
                if args.local_rank == -1:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    results = evaluate(args=args,
                                       model=model,
                                       tokenizer=tokenizer,
                                       prefix="dev",
                                       use_crf=use_crf)
                    span_f = results['span_f']
                    if span_f > best_spanf:
                        if os.path.exists(best_dir):
                            shutil.rmtree(best_dir)
                            # Report the directory that was actually removed.
                            print('remove file', best_dir)
                        print('\n\n eval results:', results)
                        test_results = evaluate(args=args,
                                                model=model,
                                                tokenizer=tokenizer,
                                                prefix="test",
                                                use_crf=use_crf)
                        print('\n\n test results', test_results)
                        print('\n epoch = :', ep)
                        best_spanf = span_f
                        _save_best_checkpoint(args, model, tokenizer,
                                              optimizer, scheduler, best_dir)
        # NOTE(review): seed() without an argument re-seeds NumPy from fresh
        # entropy, so the shuffle order is not tied to args.seed - confirm.
        np.random.seed()
        np.random.shuffle(train_features)
        logger.info("\n")
    mean_loss = tr_loss / global_step if global_step else 0.0
    return global_step, mean_loss, test_results