# Top-level training pipeline fragment: computes the total number of optimizer
# steps so BertAdam's warmup schedule matches, trains/validates, then runs a
# data-augmentation pass and produces submission files for both the last and
# the best checkpoints.
# NOTE(review): depends on names defined elsewhere in the file (config, model,
# criterion, batch_size, epoch_total, train_valid, train_enhance, test_model).
train_df = pd.read_csv(os.path.join(config.processed_train, file_name + '.csv'), index_col=0)
# Per-file gradient-accumulation factor; fewer optimizer steps per epoch when > 1.
step_grad = config.gradient_accumulation_step_dict[file_name]
# Optimizer steps for this file: ceil-style batch count divided by the
# accumulation factor (the trailing +1 guards against rounding to zero).
epoch_total += (len(train_df) // batch_size + 1) // step_grad + 1
print('epoch_total', epoch_total)
# t_total drives BertAdam's linear warmup/decay schedule across all epochs.
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=config.learning_rate,
                     warmup=config.warmup_proportion,
                     t_total=epoch_total * config.epochs)
# Comment out the following line or not depending on actual needs.
model, optimizer = train_valid(model, criterion, optimizer)
# Per-epoch loss weighting used by the augmentation pass below.
loss_weights = EpochLossWeight()
loss_weights_dict = loss_weights.run()
# Run a final round of model data augmentation and test the result.
# if os.path.exists(config.checkpoint_file):
#     print('正在加载最后一个模型', config.checkpoint_file)
#     checkpoint = torch.load(config.checkpoint_file)
#     model.load_state_dict(checkpoint['model_state'])
#     optimizer.load_state_dict(checkpoint['optimizer_state'])
train_enhance(model, criterion, optimizer, loss_weights_dict)
test_model(model, 'last_predict', 'last_predict_submit.zip')
# Reload the best checkpoint, then run data augmentation and test it as well.
if os.path.exists(config.best_checkpoint_file):
    print('正在加载最佳模型', config.best_checkpoint_file)
    checkpoint = torch.load(config.best_checkpoint_file)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
train_enhance(model, criterion, optimizer, loss_weights_dict)
test_model(model, 'best_predict', 'best_predict_submit.zip')
# Optimizer setup with the standard BERT weight-decay grouping: bias and
# LayerNorm parameters are excluded from weight decay. The unused pooler
# parameters are dropped from the optimizer entirely.
param_optimizer = list(model.named_parameters())
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=lr)
# Resume from a saved checkpoint when one exists; otherwise start fresh.
if os.path.isfile(model_path):
    print('=' * 80)
    print('load training param, ', model_path)
    print('=' * 80)
    state = torch.load(model_path)
    model.load_state_dict(state['best_model_state'])
    optimizer.load_state_dict(state['best_opt_state'])
    # Continue epoch numbering from the checkpointed best epoch.
    epoch_list = range(state['best_epoch'] + 1, state['best_epoch'] + 1 + EPOCH)
    global_step = state['best_step']
    state = {}  ## FIXME: temporary workaround
else:
    state = None
    epoch_list = range(EPOCH)
    global_step = 0
grade = 0
# Print/validation cadence differs between Windows (local debug) and other runs.
if on_windows:
    print_every = 50
    val_every = [50, 70, 50, 35]
# NOTE(review): the source is truncated below — the else-branch body is not
# visible in this chunk.
else:
class QAModel(object):
    """
    High level model that handles initializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    def __init__(self, opt, embedding=None, state_dict=None):
        """Build the FlowQA network and its optimizer(s).

        :param opt: dict of hyperparameters / runtime options.
        :param embedding: optional pretrained embedding tensor for FlowQA.
        :param state_dict: optional checkpoint dict with keys 'network',
            'optimizer', 'updates' and (optionally) 'bertadam' to resume from.
        """
        # Book-keeping.
        self.opt = opt
        self.updates = state_dict['updates'] if state_dict else 0
        # When True, trained embeddings must be copied into the (larger)
        # evaluation embedding before the next predict() call.
        self.eval_embed_transfer = True
        self.train_loss = AverageMeter()

        # Building network.
        self.network = FlowQA(opt, embedding)
        if state_dict:
            # Drop checkpoint entries that no longer exist in the current
            # architecture before loading (allows minor architecture drift).
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['network'].keys()):
                if k not in new_state:
                    del state_dict['network'][k]
            self.network.load_state_dict(state_dict['network'])

        parameters = [p for p in self.network.parameters() if p.requires_grad]
        self.total_param = sum([p.nelement() for p in parameters])

        # Building optimizer.
        if opt['finetune_bert'] != 0:
            # BERT parameters get their own BertAdam with warmup; everything
            # else goes to the main optimizer selected below.
            bert_params = [
                p for p in self.network.bert.parameters() if p.requires_grad
            ]
            self.bertadam = BertAdam(bert_params,
                                     lr=opt['bert_lr'],
                                     warmup=opt['bert_warmup'],
                                     t_total=opt['bert_t_total'])
            # Filter bert_params out of `parameters` by identity; the
            # for/else fires `append` only when no match was found.
            non_bert_params = []
            for p in parameters:
                for bp in bert_params:
                    if p is bp:
                        break
                else:
                    non_bert_params.append(p)
            parameters = non_bert_params

        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(parameters,
                                       opt['learning_rate'],
                                       momentum=opt['momentum'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = optim.Adamax(parameters,
                                          weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(parameters,
                                            rho=0.95,
                                            weight_decay=opt['weight_decay'])
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
            if opt['cuda']:
                # Optimizer state tensors are loaded on CPU; move them to GPU
                # so step() does not mix devices.
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()
            if opt['finetune_bert'] != 0 and 'bertadam' in state_dict:
                self.bertadam.load_state_dict(state_dict['bertadam'])
                if opt['cuda']:
                    for state in self.bertadam.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()

        # NOTE(review): wvec_size is computed but never used in this class —
        # possibly a leftover from an earlier implementation.
        if opt['fix_embeddings']:
            wvec_size = 0
        else:
            wvec_size = (opt['vocab_size'] - opt['tune_partial']) * opt['embedding_dim']

    def update(self, batch):
        """Run one forward/backward pass on a training batch.

        Accumulates gradients only (scaled by opt['aggregate_grad_steps']);
        the actual optimizer step happens in take_step(). Returns the loss
        tensor for this batch.
        """
        # Train mode
        self.network.train()
        torch.set_grad_enabled(True)

        # Fixed batch layout: [:9] model inputs, 9 overall_mask, 10-14 answer
        # and rationale targets, 19-22 BERT pieces (when use_bert).
        if self.opt['use_bert']:
            context_bertidx = batch[19]
            context_bert_spans = batch[20]
            question_bertidx = batch[21]
            question_bert_spans = batch[22]

        # Transfer to GPU
        if self.opt['cuda']:
            inputs = [e.cuda(non_blocking=True) for e in batch[:9]]
            overall_mask = batch[9].cuda(non_blocking=True)
            answer_s = batch[10].cuda(non_blocking=True)
            answer_e = batch[11].cuda(non_blocking=True)
            answer_c = batch[12].cuda(non_blocking=True)
            rationale_s = batch[13].cuda(non_blocking=True)
            rationale_e = batch[14].cuda(non_blocking=True)
            if self.opt['use_bert']:
                context_bertidx = [
                    x.cuda(non_blocking=True) for x in context_bertidx
                ]
        else:
            inputs = [e for e in batch[:9]]
            overall_mask = batch[9]
            answer_s = batch[10]
            answer_e = batch[11]
            answer_c = batch[12]
            rationale_s = batch[13]
            rationale_e = batch[14]

        # Run forward
        # output: [batch_size, question_num, context_len], [batch_size, question_num]
        if self.opt['use_bert']:
            score_s, score_e, score_c = self.network(*inputs, context_bertidx,
                                                     context_bert_spans,
                                                     question_bertidx,
                                                     question_bert_spans)
        else:
            score_s, score_e, score_c = self.network(*inputs)

        # Compute loss and accuracies
        if self.opt['use_elmo']:
            loss = self.opt['elmo_lambda'] * (
                self.network.elmo.scalar_mix_0.scalar_parameters[0]**2 +
                self.network.elmo.scalar_mix_0.scalar_parameters[1]**2 +
                self.network.elmo.scalar_mix_0.scalar_parameters[2]**2
            )  # ELMo L2 regularization
        else:
            loss = 0

        # answer_c == 3 marks "span" answers; everything else has no span, so
        # its span/rationale targets are masked to the cross-entropy ignore index.
        all_no_span = (answer_c != 3)
        answer_s.masked_fill_(all_no_span, -100)  # ignore_index is -100 in F.cross_entropy
        answer_e.masked_fill_(all_no_span, -100)
        rationale_s.masked_fill_(
            all_no_span, -100)  # ignore_index is -100 in F.cross_entropy
        rationale_e.masked_fill_(all_no_span, -100)

        for i in range(overall_mask.size(0)):
            q_num = sum(overall_mask[i]
                        )  # the true question number for this sampled context
            target_s = answer_s[i, :q_num]  # Size: q_num
            target_e = answer_e[i, :q_num]
            target_c = answer_c[i, :q_num]
            target_s_r = rationale_s[i, :q_num]
            target_e_r = rationale_e[i, :q_num]
            target_no_span = all_no_span[i, :q_num]

            # single_loss is averaged across q_num
            # NOTE(review): the /15.0 and /12.0 factors look like dataset-specific
            # normalization constants — confirm against the training config.
            single_loss = (F.cross_entropy(score_c[i, :q_num], target_c) *
                           q_num.item() / 15.0 +
                           F.cross_entropy(score_s[i, :q_num], target_s) *
                           (q_num - sum(target_no_span)).item() / 12.0 +
                           F.cross_entropy(score_e[i, :q_num], target_e) *
                           (q_num - sum(target_no_span)).item() / 12.0)
            #+ self.opt['rationale_lambda'] * F.cross_entropy(score_s_r[i, :q_num], target_s_r) * (q_num - sum(target_no_span)).item() / 12.0
            #+ self.opt['rationale_lambda'] * F.cross_entropy(score_e_r[i, :q_num], target_e_r) * (q_num - sum(target_no_span)).item() / 12.0)
            loss = loss + (single_loss / overall_mask.size(0))

        self.train_loss.update(loss.item(), overall_mask.size(0))
        '''
        # Clear gradients and run backward
        self.optimizer.zero_grad()
        loss.backward()
        '''
        # Scale by the gradient-aggregation factor; the optimizer step itself
        # is deferred to take_step() (gradient accumulation).
        loss = loss / self.opt['aggregate_grad_steps']
        loss.backward()
        '''
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.opt['grad_clipping'])
        # Update parameters
        self.optimizer.step()
        self.updates += 1
        '''
        return loss

    def take_step(self):
        """Clip accumulated gradients, apply the optimizer step(s), and reset
        gradients — the second half of the accumulation cycle started by update().
        """
        # Clip Gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.opt['grad_clipping'])
        # Update parameters
        self.optimizer.step()
        if self.opt['finetune_bert']:
            self.bertadam.step()
        self.updates += 1

        # Reset any partially fixed parameters (e.g. rare words)
        self.reset_embeddings()
        # Trained embeddings changed, so the eval embedding must be re-synced.
        self.eval_embed_transfer = True

        # Clear gradients and run backward
        self.optimizer.zero_grad()
        if self.opt['finetune_bert']:
            self.bertadam.zero_grad()

    def predict(self, batch):
        """Predict answers for an evaluation batch.

        Returns a flat list of answer strings: "unknown" / "Yes" / "No" for
        classification answers, or the argmax text span extracted from the
        context otherwise.
        """
        # Eval mode
        self.network.eval()
        torch.set_grad_enabled(False)

        # Transfer trained embedding to evaluation embedding
        if self.eval_embed_transfer:
            self.update_eval_embed()
            self.eval_embed_transfer = False

        if self.opt['use_bert']:
            context_bertidx = batch[19]
            context_bert_spans = batch[20]
            question_bertidx = batch[21]
            question_bert_spans = batch[22]

        # Transfer to GPU
        if self.opt['cuda']:
            inputs = [e.cuda(non_blocking=True) for e in batch[:9]]
            if self.opt['use_bert']:
                context_bertidx = [
                    x.cuda(non_blocking=True) for x in context_bertidx
                ]
        else:
            inputs = [e for e in batch[:9]]

        # Run forward
        # output: [batch_size, question_num, context_len], [batch_size, question_num]
        if self.opt['use_bert']:
            score_s, score_e, score_c = self.network(*inputs, context_bertidx,
                                                     context_bert_spans,
                                                     question_bertidx,
                                                     question_bert_spans)
        else:
            score_s, score_e, score_c = self.network(*inputs)
        score_s = F.softmax(score_s, dim=2)
        score_e = F.softmax(score_e, dim=2)

        # Transfer to CPU/normal tensors for numpy ops
        score_s = score_s.data.cpu()
        score_e = score_e.data.cpu()
        score_c = score_c.data.cpu()

        # Get argmax text spans
        text = batch[15]
        spans = batch[16]
        overall_mask = batch[9]

        predictions = []
        max_len = self.opt['max_len'] or score_s.size(2)

        for i in range(overall_mask.size(0)):
            for j in range(overall_mask.size(1)):
                if overall_mask[i, j] == 0:  # this dialog has ended
                    break

                ans_type = np.argmax(score_c[i, j])
                if ans_type == 0:
                    predictions.append("unknown")
                elif ans_type == 1:
                    predictions.append("Yes")
                elif ans_type == 2:
                    predictions.append("No")
                else:
                    # Outer product of start/end probabilities, constrained to
                    # valid spans (start <= end, length < max_len) via triu/tril.
                    scores = torch.ger(score_s[i, j], score_e[i, j])
                    scores.triu_().tril_(max_len - 1)
                    scores = scores.numpy()
                    s_idx, e_idx = np.unravel_index(np.argmax(scores),
                                                    scores.shape)
                    s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1]
                    predictions.append(text[i][s_offset:e_offset])

        return predictions  # list of (list of strings)

    # allow the evaluation embedding be larger than training embedding
    # this is helpful if we have pretrained word embeddings
    def setup_eval_embed(self, eval_embed, padding_idx=0):
        """Install a frozen evaluation embedding (a supermatrix of the
        training embedding) on the network, and on CoVe if present."""
        # eval_embed should be a supermatrix of training embedding
        self.network.eval_embed = nn.Embedding(eval_embed.size(0),
                                               eval_embed.size(1),
                                               padding_idx=padding_idx)
        self.network.eval_embed.weight.data = eval_embed
        for p in self.network.eval_embed.parameters():
            p.requires_grad = False
        self.eval_embed_transfer = True

        if hasattr(self.network, 'CoVe'):
            self.network.CoVe.setup_eval_embed(eval_embed)

    def update_eval_embed(self):
        """Copy the (partially) trained embedding rows into the evaluation
        embedding so predict() uses up-to-date vectors."""
        # update evaluation embedding to trained embedding
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            self.network.eval_embed.weight.data[0:offset] \
                = self.network.embedding.weight.data[0:offset]
        else:
            # NOTE(review): hard-coded fallback of 10 rows — presumably the
            # special tokens; confirm against the vocabulary layout.
            offset = 10
            self.network.eval_embed.weight.data[0:offset] \
                = self.network.embedding.weight.data[0:offset]

    def reset_embeddings(self):
        """Restore the rows beyond tune_partial to their original (fixed)
        values after an optimizer step touched them."""
        # Reset fixed embeddings to original value
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            if offset < self.network.embedding.weight.data.size(0):
                self.network.embedding.weight.data[offset:] \
                    = self.network.fixed_embedding

    def get_pretrain(self, state_dict):
        """Copy matching parameters from a pretrained state_dict into the
        network, silently skipping shape mismatches and unknown names."""
        own_state = self.network.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            # NOTE(review): bare except — swallows any copy failure (usually a
            # shape mismatch) with only a print; consider narrowing.
            except:
                print("Skip", name)
                continue

    def save(self, filename, epoch):
        """Save a full training checkpoint (network, optimizer(s), update
        count, config, epoch) to `filename`; failures are logged, not raised."""
        params = {
            'state_dict': {
                'network': self.network.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'updates': self.updates  # how many updates
            },
            'config': self.opt,
            'epoch': epoch
        }
        if self.opt['finetune_bert']:
            params['state_dict']['bertadam'] = self.bertadam.state_dict()
        try:
            torch.save(params, filename)
            logger.info('model saved to {}'.format(filename))
        except BaseException:
            logger.warn('[ WARN: Saving failed... continuing anyway. ]')

    def save_for_predict(self, filename, epoch):
        """Save an inference-only checkpoint: network weights without CoVe,
        evaluation embedding, or the fixed-embedding buffer."""
        network_state = dict([(k, v)
                              for k, v in self.network.state_dict().items()
                              if k[0:4] != 'CoVe'])
        if 'eval_embed.weight' in network_state:
            del network_state['eval_embed.weight']
        if 'fixed_embedding' in network_state:
            del network_state['fixed_embedding']
        params = {
            'state_dict': {
                'network': network_state
            },
            'config': self.opt,
        }
        try:
            torch.save(params, filename)
            logger.info('model saved to {}'.format(filename))
        except BaseException:
            logger.warn('[ WARN: Saving failed... continuing anyway. ]')

    def cuda(self):
        """Move the underlying network to the GPU."""
        self.network.cuda()
class BertTrainer:
    """Common training harness for BERT models: optimizer construction with
    weight-decay grouping, fp16/apex support, gradient accumulation, loss
    tracking, and checkpoint save/resume."""

    def __init__(self, hypers: Hypers, model_name, checkpoint,
                 **extra_model_args):
        """
        initialize the BertOptimizer, with common logic for setting weight_decay_rate,
        doing gradient accumulation and tracking loss
        :param hypers: the core hyperparameters for the bert model
        :param model_name: the fully qualified name of the bert model we will train
               like pytorch_pretrained_bert.modeling.BertForQuestionAnswering
        :param checkpoint: if resuming training, this is the checkpoint that
               contains the optimizer state as checkpoint['optimizer']
        """
        self.init_time = time.time()
        self.model = self.get_model(hypers, model_name, checkpoint,
                                    **extra_model_args)
        # `step` counts micro-batches within a gradient-accumulation cycle;
        # `global_step` counts actual optimizer updates.
        self.step = 0
        self.hypers = hypers
        self.train_stats = TrainStats(hypers)
        self.model.train()
        logger.info('configured model for training')

        # show parameter names
        # logger.info(str([n for (n, p) in self.model.named_parameters()]))

        # Prepare optimizer
        if hasattr(hypers, 'exclude_pooler') and hypers.exclude_pooler:
            # module.bert.pooler.dense.weight, module.bert.pooler.dense.bias
            # see https://github.com/NVIDIA/apex/issues/131
            self.param_optimizer = [
                (n, p) for (n, p) in self.model.named_parameters()
                if '.pooler.' not in n
            ]
        else:
            self.param_optimizer = list(self.model.named_parameters())

        # Standard BERT grouping: no weight decay on bias/LayerNorm params.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.01
        }, {
            'params': [
                p for n, p in self.param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.0
        }]

        self.t_total = hypers.num_train_steps
        self.global_step = hypers.global_step

        if hypers.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )
            # With fp16, warmup is applied manually in step_loss(), so
            # FusedAdam runs without bias correction or internal clipping > 1.
            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=hypers.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if hypers.loss_scale == 0:
                self.optimizer = FP16_Optimizer(
                    optimizer,
                    dynamic_loss_scale=True,
                    verbose=(hypers.global_rank == 0))
            else:
                self.optimizer = FP16_Optimizer(
                    optimizer,
                    static_loss_scale=hypers.loss_scale,
                    verbose=(hypers.global_rank == 0))
        else:
            self.optimizer = BertAdam(optimizer_grouped_parameters,
                                      lr=hypers.learning_rate,
                                      warmup=hypers.warmup_proportion,
                                      t_total=self.t_total)
        logger.info('created optimizer')

        # Resume optimizer state and progress counters from the checkpoint.
        if checkpoint and type(
                checkpoint) is dict and 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if hypers.fp16:
                pass
            else:
                # if we load this state, we need to set the t_total to what we
                # passed, not what was saved
                self.optimizer.set_t_total(self.t_total)
                # show state of optimizer
                lrs = self.optimizer.get_lr()
                logger.info('Min and max learn rate: %s',
                            str([min(lrs), max(lrs)]))
                logger.info('Min and max step in state: %s',
                            str(self.optimizer.get_steps()))
            instances_per_step = hypers.train_batch_size * hypers.gradient_accumulation_steps * hypers.world_size
            if 'seen_instances' in checkpoint:
                # Derive global_step from how many instances were already seen.
                self.global_step = int(checkpoint['seen_instances'] /
                                       instances_per_step)
                self.train_stats.previous_instances = checkpoint[
                    'seen_instances']
                logger.info('got global step from checkpoint = %i',
                            self.global_step)
            logger.info('Loaded optimizer state:')
            logger.info(repr(self.optimizer))

    def reset(self):
        """
        reset any gradient accumulation
        :return:
        """
        self.model.zero_grad()
        self.step = 0

    def should_continue(self):
        """
        :return: True if training should continue
        """
        if self.global_step >= self.t_total:
            logger.info(
                'stopping due to train step %i >= target train steps %i',
                self.global_step, self.t_total)
            return False
        # A non-positive time_limit disables the wall-clock cutoff.
        if 0 < self.hypers.time_limit <= (time.time() - self.init_time):
            logger.info('stopping due to time out %i seconds',
                        self.hypers.time_limit)
            return False
        return True

    def save_simple(self, filename):
        """Save only the model parameters (no optimizer state); rank-0 only."""
        if self.hypers.global_rank != 0:
            logger.info('skipping save in %i', torch.distributed.get_rank())
            return
        # Unwrap DataParallel/DDP before saving.
        model_to_save = self.model.module if hasattr(
            self.model, 'module') else self.model  # Only save the model itself
        torch.save(model_to_save.state_dict(), filename)
        logger.info(f'saved model only to (unknown)')

    def save(self, filename, **extra_checkpoint_info):
        """
        save a checkpoint with the model parameters, the optimizer state
        and any additional checkpoint info
        :param filename:
        :param extra_checkpoint_info:
        :return:
        """
        # only local_rank 0, in fact only global rank 0
        if self.hypers.global_rank != 0:
            logger.info('skipping save in %i', torch.distributed.get_rank())
            return
        start_time = time.time()
        checkpoint = extra_checkpoint_info
        model_to_save = self.model.module if hasattr(
            self.model, 'module') else self.model  # Only save the model itself
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # also save the optimizer state, since we will likely resume from partial pre-training
        checkpoint['state_dict'] = model_to_save.state_dict()
        checkpoint['optimizer'] = self.optimizer.state_dict()
        # include world size in instances_per_step calculation
        instances_per_step = self.hypers.train_batch_size * \
            self.hypers.gradient_accumulation_steps * \
            self.hypers.world_size
        checkpoint['seen_instances'] = self.global_step * instances_per_step
        checkpoint['num_instances'] = self.t_total * instances_per_step
        # CONSIDER: also save hypers?
        torch.save(checkpoint, filename)
        logger.info(
            f'saved model to (unknown) in {time.time()-start_time} seconds')

    def get_instance_count(self):
        """Return the number of training instances processed so far."""
        instances_per_step = self.hypers.train_batch_size * \
            self.hypers.gradient_accumulation_steps * \
            self.hypers.world_size
        return self.global_step * instances_per_step

    def step_loss(self, loss):
        """
        accumulates the gradient, tracks the loss and applies the gradient to the model
        :param loss: the loss from evaluating the model
        """
        if self.global_step == 0:
            logger.info('first step_loss')
        if self.hypers.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.

        self.train_stats.note_loss(loss.item())

        if self.hypers.gradient_accumulation_steps > 1:
            loss = loss / self.hypers.gradient_accumulation_steps
        if self.hypers.fp16:
            self.optimizer.backward(loss)
        else:
            loss.backward()

        # Apply the optimizer step once per accumulation cycle.
        if (self.step + 1) % self.hypers.gradient_accumulation_steps == 0:
            # NOTE(review): the LR is overridden with warmup_linear here even
            # on the non-fp16 path, where BertAdam already schedules warmup
            # internally — confirm this double-scheduling is intended.
            lr_this_step = self.hypers.learning_rate * warmup_linear(
                self.global_step / self.t_total,
                self.hypers.warmup_proportion)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr_this_step
            self.optimizer.step()
            self.model.zero_grad()
            self.global_step += 1
        self.step += 1

    @classmethod
    def get_files(cls, train_file, completed_files):
        """Resolve the list of training files (file or directory), excluding
        already-completed ones; when all are complete, start a new epoch."""
        logger.info('completed files = %s, count = %i',
                    str(completed_files[:min(5, len(completed_files))]),
                    len(completed_files))
        # multiple train files
        if not os.path.isdir(train_file):
            train_files = [train_file]
        else:
            if not train_file.endswith('/'):
                train_file = train_file + '/'
            train_files = glob.glob(train_file + '**', recursive=True)
            train_files = [f for f in train_files if not os.path.isdir(f)]
        # exclude completed files
        if not set(train_files) == set(completed_files):
            train_files = [f for f in train_files if f not in completed_files]
        else:
            completed_files = []  # new epoch
        logger.info('train files = %s, count = %i',
                    str(train_files[:min(5, len(train_files))]),
                    len(train_files))
        return train_files, completed_files

    @classmethod
    def get_model(cls, hypers, model_name, checkpoint, **extra_model_args):
        """Instantiate the model class named by `model_name` (dynamically
        imported), optionally loading checkpoint weights, then configure it
        for fp16 / distributed / multi-GPU execution."""
        override_state_dict = None
        if checkpoint:
            if type(checkpoint) is dict and 'state_dict' in checkpoint:
                logger.info('loading from multi-part checkpoint')
                override_state_dict = checkpoint['state_dict']
            else:
                logger.info('loading from saved model parameters')
                override_state_dict = checkpoint

        # create the model object by name
        # https://stackoverflow.com/questions/4821104/python-dynamic-instantiation-from-string-name-of-a-class-in-dynamically-imported
        import importlib
        clsdot = model_name.rfind('.')
        class_ = getattr(importlib.import_module(model_name[0:clsdot]),
                         model_name[clsdot + 1:])
        model_args = {
            'state_dict': override_state_dict,
            'cache_dir': PYTORCH_PRETRAINED_BERT_CACHE
        }
        model_args.update(extra_model_args)
        # logger.info(pprint.pformat(extra_model_args, indent=4))
        model = class_.from_pretrained(hypers.bert_model, **model_args)
        logger.info('built model')

        # configure model for fp16, multi-gpu and/or distributed training
        if hypers.fp16:
            model.half()
            logger.info('model halved')
        logger.info('sending model to %s', str(hypers.device))
        model.to(hypers.device)
        logger.info('sent model to %s', str(hypers.device))

        if hypers.local_rank != -1:
            if not hypers.no_apex:
                try:
                    from apex.parallel import DistributedDataParallel as DDP
                    model = DDP(model)
                except ImportError:
                    raise ImportError("Please install apex")
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[hypers.local_rank],
                    output_device=hypers.local_rank)
            logger.info('using DistributedDataParallel for world size %i',
                        hypers.world_size)
        elif hypers.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        return model

    @classmethod
    def get_base_parser(cls):
        """Build the argparse parser with the hyperparameters shared by all
        BERT training scripts; callers may add task-specific arguments."""
        parser = argparse.ArgumentParser()
        # Required parameters
        parser.add_argument(
            "--bert_model",
            default=None,
            type=str,
            required=True,
            help=
            "Bert pre-trained model selected in the list: bert-base-uncased, "
            "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
        )
        # Other parameters
        parser.add_argument(
            "--num_instances",
            default=-1,
            type=int,
            help="Total number of training instances to train over.")
        parser.add_argument(
            "--seen_instances",
            default=-1,
            type=int,
            help=
            "When resuming training, the number of instances we have already trained over."
        )
        parser.add_argument("--train_batch_size",
                            default=32,
                            type=int,
                            help="Total batch size for training.")
        parser.add_argument("--learning_rate",
                            default=5e-5,
                            type=float,
                            help="The initial learning rate for Adam.")
        parser.add_argument(
            "--warmup_proportion",
            default=0.1,
            type=float,
            help=
            "Proportion of training to perform linear learning rate warmup for. "
            "E.g., 0.1 = 10% of training.")
        parser.add_argument("--no_cuda",
                            default=False,
                            action='store_true',
                            help="Whether not to use CUDA when available")
        parser.add_argument("--no_apex",
                            default=False,
                            action='store_true',
                            help="Whether not to use apex when available")
        parser.add_argument('--seed',
                            type=int,
                            default=42,
                            help="random seed for initialization")
        parser.add_argument(
            '--gradient_accumulation_steps',
            type=int,
            default=1,
            help=
            "Number of updates steps to accumulate before performing a backward/update pass."
        )
        parser.add_argument(
            '--optimize_on_cpu',
            default=False,
            action='store_true',
            help=
            "Whether to perform optimization and keep the optimizer averages on CPU"
        )
        parser.add_argument(
            '--fp16',
            default=False,
            action='store_true',
            help="Whether to use 16-bit float precision instead of 32-bit")
        parser.add_argument(
            '--loss_scale',
            type=float,
            default=0,
            help=
            'Loss scaling, positive power of 2 values can improve fp16 convergence. '
            'Leave at zero to use dynamic loss scaling')
        return parser
def main(): parser = argparse.ArgumentParser() parser.add_argument("--generation_dataset", default='openi', type=str, help=["mimic-cxr, openi"]) parser.add_argument("--vqa_rad", default="all", type=str, choices=["all", "chest", "head", "abd"]) parser.add_argument("--data_set", default="train", type=str, help="train | valid") parser.add_argument('--img_hidden_sz', type=int, default=2048, help="Whether to use amp for fp16") parser.add_argument( "--bert_model", default="bert-base-uncased", type=str, help= "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased." ) parser.add_argument( "--mlm_task", type=str, default=True, help="The model will train only mlm task!! | True | False") parser.add_argument("--train_batch_size", default=2, type=int, help="Total batch size for training.") parser.add_argument("--num_train_epochs", default=5, type=int, help="Total number of training epochs to perform.") parser.add_argument( '--from_scratch', action='store_true', default=False, help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument("--img_encoding", type=str, default='fully_use_cnn', choices=['random_sample', 'fully_use_cnn']) parser.add_argument( '--len_vis_input', type=int, default=256, help="The length of visual token input" ) #visual token의 fixed length를 100이라 하면, <Unknown> token 100개가 되고, 100개의 word 생성 가능. parser.add_argument('--max_len_b', type=int, default=253, help="Truncate_config: maximum length of segment B.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=10, help="Max tokens of prediction.") parser.add_argument( '--s2s_prob', default=1, type=float, help= "Percentage of examples that are bi-uni-directional LM (seq2seq). This must be turned off!!!!!!! because this is not for seq2seq model!!!" 
) parser.add_argument( '--bi_prob', default=0, type=float, help="Percentage of examples that are bidirectional LM.") parser.add_argument('--hidden_size', type=int, default=768) parser.add_argument('--bar', default=False, type=str, help="True or False") parser.add_argument("--config_path", default='./pretrained_model/non_cross/config.json', type=str, help="Bert config file path.") parser.add_argument( "--model_recover_path", default='./pretrained_model/non_cross/pytorch_model.bin', type=str, help="The file of fine-tuned pretraining model.") # model load parser.add_argument( "--output_dir", default='./output_model/base_noncross_mimic_2', type=str, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_file", default="training.log", type=str, help="The output directory where the log will be written.") parser.add_argument('--img_postion', default=True, help="It will produce img_position.") parser.add_argument( "--do_train", action='store_true', default=True, help="Whether to run training. This should ALWAYS be set to True.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." 
) ############################################################################################################ parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument("--global_rank", type=int, default=-1, help="global_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=123, help="random seed for initialization") parser.add_argument( '--fp16', action='store_true', default=False, help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', default=False, help= "Whether to use 32-bit float precision instead of 32-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', default=False, help="Whether to use amp for fp16") parser.add_argument('--new_segment_ids', default=False, action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument( '--trunc_seg', default='b', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument("--num_workers", default=20, type=int, help="Number of workers for the data loader.") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument( '--image_root', type=str, default='/home/mimic-cxr/dataset/image_preprocessing/re_512_3ch/Train') parser.add_argument('--split', type=str, nargs='+', default=['train', 'valid']) parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist_url', default='file://[PT_OUTPUT_DIR]/nonexistent_file', type=str, help='url used to set up distributed training') parser.add_argument('--sche_mode', default='warmup_linear', type=str, help="warmup_linear | warmup_constant | warmup_cosine") parser.add_argument('--drop_prob', default=0.1, type=float) parser.add_argument('--use_num_imgs', default=-1, type=int) parser.add_argument('--max_drop_worst_ratio', default=0, type=float) parser.add_argument('--drop_after', default=6, type=int) parser.add_argument('--tasks', default='report_generation', help='report_generation | vqa') parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") args = parser.parse_args() print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank)) args.max_seq_length = args.max_len_b + 
args.len_vis_input + 3 # +3 for 2x[SEP] and [CLS] args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir) if args.tasks == 'vqa': wandb.init(config=args, project="VQA") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/dataset/data_RAD' args.file_valid_jpgs = '/home/mimic-cxr/dataset/vqa_rad_original_set.json' else: if args.generation_dataset == 'mimic-cxr': wandb.init(config=args, project="report_generation") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/new_dset/Train_253.jsonl' args.file_valid_jpgs = '/home/mimic-cxr/new_dset/Train_253.jsonl' else: wandb.init(config=args, project="report_generation") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/dataset/open_i/Train_openi.jsonl' args.file_valid_jpgs = '/home/mimic-cxr/dataset/open_i/Valid_openi.jsonl' print(" # PID :", os.getpid()) os.makedirs(args.output_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode='w', format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") print("device", device) n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) print("device", device) n_gpu = 1 torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.global_rank) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True) if args.do_train: print("args.mask_prob", args.mask_prob) print("args.train_batch_size", args.train_batch_size) bi_uni_pipeline = [ data_loader.Preprocess4Seq2seq( args, args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, args.bar, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mode="s2s", len_vis_input=args.len_vis_input, local_rank=args.local_rank, load_vqa_set=(args.tasks == 'vqa')) ] bi_uni_pipeline.append( data_loader.Preprocess4Seq2seq( args, args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, args.bar, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mode="bi", len_vis_input=args.len_vis_input, local_rank=args.local_rank, load_vqa_set=(args.tasks == 'vqa'))) train_dataset = data_loader.Img2txtDataset( args, args.data_set, args.src_file, args.image_root, args.split, args.train_batch_size, tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs, bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs, s2s_prob=args.s2s_prob, # this 
must be set to 1. bi_prob=args.bi_prob, tasks=args.tasks) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) t_total = int( len(train_dataloader) * args.num_train_epochs * 1. / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == 'report_generation' else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"]) # index in BERT vocab: 103, 102, 0 # BERT model will be loaded! 
from scratch if (args.model_recover_path is None): _state_dict = {} if args.from_scratch else None _state_dict = {} model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, args=args, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, len_vis_input=args.len_vis_input, tasks=args.tasks) print("scratch model's statedict : ") for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size()) global_step = 0 print("The model will train from scratch") else: print("Task :", args.tasks, args.s2s_prob) print("Recoverd model :", args.model_recover_path) for model_recover_path in glob.glob(args.model_recover_path.strip()): logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(model_recover_path) for key in list(model_recover.keys()): model_recover[key.replace('enc.', '').replace( 'mlm.', 'cls.')] = model_recover.pop(key) global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, args=args, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, len_vis_input=args.len_vis_input, tasks=args.tasks) model.load_state_dict(model_recover, strict=False) print("The pretrained model loaded and fine-tuning.") del model_recover torch.cuda.empty_cache() if args.fp16: 
model.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: model = DataParallelImbalance(model) wandb.watch(model) param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load( os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))) if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") model.train() print("Total Parameters:", sum([p.nelement() for p in model.parameters()])) if recover_step: start_epoch = recover_step + 1 print("start_epoch", start_epoch) else: start_epoch = 1 for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"): if args.local_rank >= 0: 
train_sampler.set_epoch(i_epoch - 1) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)') nbatches = len(train_dataloader) train_loss = [] avg_loss = 0.0 batch_count = 0 for step, batch in enumerate(iter_bar): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, task_idx, img, vis_pe, ans_labels, ans_type, organ = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() loss_tuple = model(img, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, ans_labels, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0, ans_type=ans_type) masked_lm_loss, vqa_loss = loss_tuple batch_count += 1 if args.tasks == 'report_generation': masked_lm_loss = masked_lm_loss.mean() loss = masked_lm_loss else: vqa_loss = vqa_loss.mean() loss = vqa_loss iter_bar.set_description('Iter (loss=%5.3f)' % (loss.item())) train_loss.append(loss.item()) if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 wandb.log({"train_loss": np.mean(train_loss)}) logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_config_file = os.path.join(args.output_dir, 'config.json') with open(output_config_file, 'w') as f: f.write(model_to_save.config.to_json_string()) output_model_file = os.path.join(args.output_dir, "model.{0}.bin".format(i_epoch)) output_optim_file = os.path.join(args.output_dir, "optim.{0}.bin".format(i_epoch)) if args.global_rank in ( -1, 0): # save model if the 
first device or no dist torch.save( copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.world_size > 1: torch.distributed.barrier()
def train_eval(args, train_data_path, valid_data_path):
    """Train an NER model and periodically evaluate / checkpoint it.

    Loads the word/tag vocabularies, builds train and validation dataloaders,
    constructs either ``NERBert`` or ``NERModel`` depending on ``args.model``,
    optionally restores a checkpoint, then runs the training loop with
    gradient accumulation.  Every ``args.valid_step`` optimizer updates the
    validation set is scored and the model is saved whenever the F-score
    matches or beats the best seen so far.

    Args:
        args: parsed command-line namespace (paths, batch sizes, model
            choice, learning rate, epochs, gradient_accumulation_steps, ...).
        train_data_path: path to the training data file.
        valid_data_path: path to the validation data file.
    """
    # Vocabulary / label index maps produced by the preprocessing step.
    index = read_pickle(args.index_path)
    word2index, tag2index = index['word2id'], index['tag2id']
    args.num_labels = len(tag2index)
    args.vocab_size = len(word2index) + 1  # +1 presumably reserves id 0 for padding — TODO confirm
    set_seed(args.seed_num)
    train_dataloader, train_samples = get_dataloader(train_data_path, args.train_batch_size, True)
    valid_dataloader, _ = get_dataloader(valid_data_path, args.valid_batch_size, False)
    if args.model == 'bert':
        # BERT-based tagger initialized from a local state dict (non-strict so
        # the task head may be freshly initialized).
        bert_config = BertConfig(args.bert_config_path)
        model = NERBert(bert_config, args)
        model.load_state_dict(torch.load(args.bert_model_path), strict=False)
        # model = NERBert.from_pretrained('bert_chinese',
        #                                 # cache_dir='/home/dutir/yuetianchi/.pytorch_pretrained_bert',
        #                                 num_labels=args.num_labels)
    else:
        # Non-BERT model, optionally with pretrained word embeddings.
        if args.embedding:
            word_embedding_matrix = read_pickle(args.embedding_data_path)
            model = NERModel(args, word_embedding_matrix)
        else:
            model = NERModel(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.model == 'bert':
        # BertAdam with three parameter groups: task head at a fixed 5e-5,
        # BERT weights with/without weight decay (no decay for bias/LayerNorm).
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if 'bert' not in n],
             'lr': 5e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.0}
        ]
        warmup_proportion = 0.1
        # Total optimizer updates = (batches per epoch / accumulation) * epochs.
        num_train_optimization_steps = int(
            train_samples / args.train_batch_size / args.gradient_accumulation_steps) * args.epochs
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=warmup_proportion,
                             t_total=num_train_optimization_steps)
    else:
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=current_learning_rate
        )
    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # NOTE(review): current_learning_rate / warm_up_steps are restored
            # here but never used afterwards in this function — confirm whether
            # they should feed back into the optimizer.
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0
    global_step = init_step
    best_score = 0.0
    logging.info('Start Training...')
    logging.info('init_step = %d' % global_step)
    for epoch_id in range(int(args.epochs)):
        tr_loss = 0
        model.train()
        for step, train_batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in train_batch)
            # Model returns (logits, loss) when given inputs and labels.
            _, loss = model(batch[0], batch[1])
            if n_gpu > 1:
                loss = loss.mean()  # average across DataParallel replicas
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            loss.backward()
            # Apply the accumulated gradients once per accumulation window.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
            if (step + 1) % 500 == 0:
                print(loss.item())
            # Periodic validation: `% args.valid_step == 1` fires shortly after
            # each valid_step-sized window of optimizer updates.
            if args.do_valid and global_step % args.valid_step == 1:
                true_res = []
                pred_res = []
                len_res = []
                model.eval()
                for valid_step, valid_batch in enumerate(valid_dataloader):
                    valid_batch = tuple(t.to(device) for t in valid_batch)
                    with torch.no_grad():
                        logit = model(valid_batch[0])
                    if args.model == 'bert':
                        # The first token is '[CLS]', so drop it from lengths,
                        # labels and predictions before scoring.
                        len_res.extend(torch.sum(valid_batch[0].gt(0), dim=-1).detach().cpu().numpy() - 1)
                        true_res.extend(valid_batch[1].detach().cpu().numpy()[:, 1:])
                        pred_res.extend(logit.detach().cpu().numpy()[:, 1:])
                    else:
                        len_res.extend(torch.sum(valid_batch[0].gt(0), dim=-1).detach().cpu().numpy())
                        true_res.extend(valid_batch[1].detach().cpu().numpy())
                        pred_res.extend(logit.detach().cpu().numpy())
                acc, score = cal_score(true_res, pred_res, len_res, tag2index)
                # NOTE(review): this immediately overwrites the score returned
                # by cal_score above — confirm which metric is intended to
                # drive model selection.
                score = f1_score(true_res, pred_res, len_res, tag2index)
                # NOTE(review): the 'step' field actually logs epoch_id.
                logging.info('Evaluation:step:{},acc:{},fscore:{}'.format(str(epoch_id), acc, score))
                if score >= best_score:
                    best_score = score
                    if args.model == 'bert':
                        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                        output_dir = '{}_{}'.format('bert', str(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_config_file = os.path.join(output_dir, CONFIG_NAME)
                        with open(output_config_file, 'w') as f:
                            f.write(model_to_save.config.to_json_string())
                    else:
                        save_variable_list = {
                            'step': global_step,
                            'current_learning_rate': args.learning_rate,
                            # NOTE(review): 'warm_up_steps' stores the batch
                            # index, not a warm-up count — verify on reload.
                            'warm_up_steps': step
                        }
                        save_model(model, optimizer, save_variable_list, args)
                # Resume training mode after the eval pass.
                model.train()
def main():
    """Entry point for UniLM-style BERT pretraining / fine-tuning.

    Parses command-line arguments, sets up (optionally distributed) devices
    and seeding, builds the tokenizer and seq2seq preprocessing pipeline,
    constructs or recovers ``BertForPreTrainingLossMask``, then runs the
    training loop, checkpointing model/optimizer (and amp state) every
    ``args.checkpoint_steps`` steps into ``args.model_output_dir``.

    Fixes relative to the previous revision:
      * resume path used ``args.output_model_dir`` / ``args.checkpoint_step``,
        neither of which exists (the parser defines ``--model_output_dir`` and
        ``--checkpoint_steps``) — resuming raised AttributeError;
      * ``checkpoint_index`` was computed with ``%`` inside a guard that
        requires that modulo to be 0, so every save overwrote ``model.0.bin``;
        it is now ``//`` (the inverse of the recovery formula
        ``global_step = recover_step * checkpoint_steps``);
      * ``--bert_model`` is now a declared (optional) argument: it was read as
        ``args.bert_model`` but never added to the parser, so the
        from-scratch branch raised AttributeError.
    """
    parser = argparse.ArgumentParser()

    # Path parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The raw data dir.")
    parser.add_argument("--vocab_path", default=None, type=str, required=True,
                        help="bert vocab path")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Pretrained bert model name or path used for initialization.")
    parser.add_argument("--model_output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str, required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The param init of pretrain or finetune")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Data Process Parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                        "Sequences longer than this will be truncated, and sequences shorter \n"
                        "than this will be padded.")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument('--has_sentence_oracle', action='store_true',
                        help="Whether to have sentence level oracle for training. "
                        "Only useful for summary generation")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")

    # Model Paramters
    parser.add_argument("--sop", action='store_true',
                        help="whether use sop task.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")

    # Train Eval Test Paramters
    parser.add_argument("--checkpoint_steps", required=True, type=int,
                        help="save model eyery checkpoint_steps")
    parser.add_argument("--total_steps", required=True, type=int,
                        help="all steps of training model")
    parser.add_argument("--max_checkpoint", required=True, type=int,
                        help="max saved model in model_output_dir")
    parser.add_argument("--examples_size_once", type=int, default=1000,
                        help="read how many examples every time in pretrain or finetune")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="process rank in local")
    parser.add_argument("--local_debug", action='store_true',
                        help="whether debug")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--fine_tune", action='store_true',
                        help="Whether to run fine_tune.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    # FIX: help text was a copy-paste of the learning-rate description.
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="The label smoothing value for the LM loss.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument('--loss_scale', type=str, default='dynamic',
                        help='(float or str, optional, default=None): Optional property override. '
                        'If passed as a string,must be a string representing a number, e.g., "128.0", or the string "dynamic".')
    parser.add_argument('--opt_level', type=str, default='O1',
                        help=' (str, optional, default="O1"): Pure or mixed precision optimization level. '
                        'Accepted values are "O0", "O1", "O2", and "O3", explained in detail above.')
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")

    # Other Patameters
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--rank', type=int, default=0,
                        help="global rank of current process")
    parser.add_argument("--world_size", default=2, type=int,
                        help="Number of process(显卡)")

    args = parser.parse_args()

    # Rank/world size come from the launcher's environment (torch.distributed).
    cur_env = os.environ
    args.rank = int(cur_env.get('RANK', -1))
    args.world_size = int(cur_env.get('WORLD_SIZE', -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    # The per-forward batch size; gradients are accumulated back to the total.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    assert args.train_batch_size >= 1, 'batch_size < 1 '

    # How many examples one full parameter update consumes across all ranks;
    # round examples_size_once down to a whole number of updates.
    examples_per_update = args.world_size * args.train_batch_size * args.gradient_accumulation_steps
    args.examples_size_once = args.examples_size_once // examples_per_update * examples_per_update
    if args.fine_tune:
        args.examples_size_once = examples_per_update

    os.makedirs(args.model_output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Persist the effective configuration next to the checkpoints.
    json.dump(args.__dict__,
              open(os.path.join(args.model_output_dir, 'unilm_config.json'), 'w'),
              sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = torch.cuda.device_count()
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
    logger.info(
        "world_size:{}, rank:{}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
        .format(args.world_size, args.rank, device, n_gpu,
                bool(args.world_size > 1), args.fp16))

    # Fix random seeds for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.fine_tune and not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    if args.local_rank == 0:
        dist.barrier()

    # Seq2seq masking/truncation pipeline applied to each raw example.
    bi_uni_pipeline = [
        Preprocess4Seq2seq(args.max_pred,
                           args.mask_prob,
                           list(tokenizer.vocab.keys()),
                           tokenizer.convert_tokens_to_ids,
                           args.max_seq_length,
                           new_segment_ids=args.new_segment_ids,
                           truncate_config={
                               'max_len_a': args.max_len_a,
                               'max_len_b': args.max_len_b,
                               'trunc_seg': args.trunc_seg,
                               'always_truncate_tail': args.always_truncate_tail
                           },
                           mask_source_words=args.mask_source_words,
                           skipgram_prb=args.skipgram_prb,
                           skipgram_size=args.skipgram_size,
                           mask_whole_word=args.mask_whole_word,
                           mode="s2s",
                           has_oracle=args.has_sentence_oracle,
                           num_qkv=args.num_qkv,
                           s2s_special_token=args.s2s_special_token,
                           s2s_add_segment=args.s2s_add_segment,
                           s2s_share_segment=args.s2s_share_segment,
                           pos_shift=args.pos_shift,
                           fine_tune=args.fine_tune)
    ]

    # Kept for the (currently disabled) sentence-oracle path.
    file_oracle = None
    if args.has_sentence_oracle:
        file_oracle = os.path.join(args.data_dir, 'train.oracle')

    # t_total is the number of parameter updates
    # t_total = args.train_steps

    # Prepare model
    recover_step = _get_max_epoch_model(args.model_output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            # FIX: was args.output_model_dir / args.checkpoint_step, neither of
            # which exists on the namespace (see parser above).
            model_recover = torch.load(os.path.join(
                args.model_output_dir, "model.{0}.bin".format(recover_step)),
                map_location='cpu')
            # recover_step == checkpoint index; invert the save-time formula
            # checkpoint_index = (global_step + 1) // checkpoint_steps.
            global_step = math.floor(recover_step * args.checkpoint_steps)
        # Pretrain-time parameter init, e.g. from a chinese-bert-base checkpoint.
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        # NOTE(review): unlike the from-scratch branch, this call passes no
        # model name/path positional — confirm from_pretrained's signature
        # tolerates it (config_path is supplied).
        model = BertForPreTrainingLossMask.from_pretrained(
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)

    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    logger.info("模型参数: {}".format(total_trainable_params))
    if args.local_rank == 0:
        dist.barrier()
    model.to(device)

    # Weight decay on everything except bias and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=args.total_steps)

    if args.amp and args.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale=args.loss_scale)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    if recover_step:
        logger.info("** ** * Recover optimizer: %d * ** **", recover_step)
        optim_recover = torch.load(os.path.join(
            args.model_output_dir, "optim.{0}.bin".format(recover_step)),
            map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.fp16 and args.amp:
            amp_recover = torch.load(os.path.join(
                args.model_output_dir, "amp.{0}.bin".format(recover_step)),
                map_location='cpu')
            logger.info("** ** * Recover amp: %d * ** **", recover_step)
            amp.load_state_dict(amp_recover)

    logger.info("** ** * CUDA.empty_cache() * ** **")
    torch.cuda.empty_cache()

    if args.rank == 0:
        writer = SummaryWriter(log_dir=args.log_dir)
    logger.info("***** Running training *****")
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Param Update Num = %d", args.total_steps)
    model.train()
    PRE = "rank{},local_rank {},".format(args.rank, args.local_rank)
    start = time.time()
    train_data_loader = TrainDataLoader(
        bi_uni_pipline=bi_uni_pipeline,
        examples_size_once=args.examples_size_once,
        world_size=args.world_size,
        train_batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        max_len=args.max_seq_length)
    # Reserved for the (commented-out) periodic-eval model selection below.
    best_result = -float('inf')

    for global_step, batch in enumerate(train_data_loader, start=global_step):
        batch = [t.to(device) if t is not None else None for t in batch]
        if args.has_sentence_oracle:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        else:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label = batch
            oracle_pos, oracle_weights, oracle_labels = None, None, None
        if not args.sop:
            # SOP (sentence-order prediction) task disabled.
            sop_label = None
        loss_tuple = model(input_ids,
                           segment_ids,
                           input_mask,
                           masked_lm_labels=lm_label_ids,
                           next_sentence_label=sop_label,
                           masked_pos=masked_pos,
                           masked_weights=masked_weights,
                           task_idx=task_idx,
                           masked_pos_2=oracle_pos,
                           masked_weights_2=oracle_weights,
                           masked_labels_2=oracle_labels,
                           mask_qkv=mask_qkv)
        masked_lm_loss, next_sentence_loss = loss_tuple

        # mean() to average on multi-gpu.
        if n_gpu > 1:
            masked_lm_loss = masked_lm_loss.mean()
            next_sentence_loss = next_sentence_loss.mean()
        # ensure that accumulated gradients are normalized
        if args.gradient_accumulation_steps > 1:
            masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps
            next_sentence_loss = next_sentence_loss / args.gradient_accumulation_steps
        if not args.sop:
            loss = masked_lm_loss
        else:
            loss = masked_lm_loss + next_sentence_loss

        if args.fp16 and args.amp:
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()

        if (global_step + 1) % args.gradient_accumulation_steps == 0:
            if args.rank == 0:
                writer.add_scalar('unilm/mlm_loss', masked_lm_loss,
                                  global_step)
                writer.add_scalar('unilm/sop_loss', next_sentence_loss,
                                  global_step)
            # modify learning rate with special warm up BERT uses
            lr_this_step = args.learning_rate * warmup_linear(
                global_step / args.total_steps, args.warmup_proportion)
            if args.fp16:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            #global_step += 1

        # Time for one parameter update, in seconds.
        cost_time_per_update = time.time() - start
        # Projected time to finish all updates, in hours.
        need_time = cost_time_per_update * (args.total_steps - global_step) / 3600.0
        cost_time_per_chectpoint = cost_time_per_update * args.checkpoint_steps / 3600.0
        start = time.time()
        if args.local_rank in [-1, 0]:
            INFO = PRE + '当前/chcklpoint_steps/total:{}/{}/{},loss{}/{},更新一次参数{}秒,checkpoint_steps {}小时,' \
                   '训练完成{}小时\n'.format(global_step, args.checkpoint_steps, args.total_steps,
                                            round(masked_lm_loss.item(), 5),
                                            round(next_sentence_loss.item(), 5),
                                            round(cost_time_per_update, 4),
                                            round(cost_time_per_chectpoint, 3),
                                            round(need_time, 3))
            print(INFO)

        # Save a trained model
        if (global_step + 1) % args.checkpoint_steps == 0:
            # FIX: was `%`, which is always 0 inside this guard, so every
            # checkpoint overwrote model.0.bin; `//` yields 1, 2, 3, ...
            checkpoint_index = (global_step + 1) // args.checkpoint_steps
            if args.rank >= 0:
                train_data_loader.train_sampler.set_epoch(checkpoint_index)
            # if args.eval:
            #     # if pretraining, validate MLM; if fine-tuning, validate the metric
            #     result = None
            # if best_result < result and _get_checkpont_num(args.model_output_num):
            if args.rank in [0, -1]:
                logger.info("** ** * Saving model and optimizer * ** **")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.model_output_dir,
                    "model.{0}.bin".format(checkpoint_index))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.model_output_dir,
                    "optim.{0}.bin".format(checkpoint_index))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16 and args.amp:
                    logger.info("** ** * Saving amp state * ** **")
                    output_amp_file = os.path.join(
                        args.model_output_dir,
                        "amp.{0}.bin".format(checkpoint_index))
                    torch.save(amp.state_dict(), output_amp_file)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

    if args.rank == 0:
        writer.close()
    print('** ** * train finished * ** **')
def train(config, model, train_iter, dev_iter):
    """Train `model` on `train_iter` with periodic evaluation on `dev_iter`.

    Resumes model/optimizer state from `config.save_path` when that file
    exists, evaluates every 100 batches, checkpoints whenever the dev loss
    improves, halves the learning rate after two consecutive evaluations
    without a dev-loss drop, and early-stops once
    `config.require_improvement` stale evaluations accumulate.
    """
    start_time = time.time()
    # Load the resume checkpoint once; the original code called torch.load()
    # twice on the same file (once for the model state, once for the
    # optimizer state), deserializing it redundantly.
    checkpoint = None
    if os.path.exists(config.save_path):
        checkpoint = torch.load(config.save_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    model.train()
    param_optimizer = list(model.named_parameters())
    # BERT convention: no weight decay on biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=0.05,
                         t_total=len(train_iter) * config.num_epochs)
    # Used to halve the LR when the dev loss stalls (see `no_improve` below).
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint = None  # release the (potentially large) state dicts
    total_batch = 0
    dev_best_loss = float('inf')
    dev_last_loss = float('inf')
    no_improve = 0  # consecutive evaluations without a dev-loss drop
    flag = False    # set when early stopping triggers
    model.train()
    # plot_model(model, to_file= config.save_dic+'.png')
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Report train metrics on the current batch and evaluate on dev.
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                train_loss = loss.item()
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    # New best dev loss: checkpoint model + optimizer state.
                    state = {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }
                    dev_best_loss = dev_loss
                    torch.save(state, config.save_dic + str(total_batch) + '.pth')
                    improve = '*'
                    del state
                else:
                    improve = ''
                if dev_last_loss > dev_loss:
                    no_improve = 0
                elif no_improve % 2 == 0:
                    # Every second stale evaluation, decay the LR once.
                    no_improve += 1
                    scheduler.step()
                else:
                    no_improve += 1
                dev_last_loss = dev_loss
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(
                    msg.format(total_batch, train_loss, train_acc, dev_loss,
                               dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if no_improve > config.require_improvement:
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate) if os.path.exists(state_save_path): optimizer.load_state_dict(state['opt_state']) device = torch.device("cuda") tr_total = int(train_dataset.__len__() / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) print_freq = args.print_freq eval_freq = len(train_dataloader) // 4 print('Print freq:', print_freq, "Eval freq:", eval_freq) for epoch in range(epoch_start, int(args.num_train_epochs) + 1): tr_loss = 0 nb_tr_examples, nb_tr_steps = 0, 0 with tqdm(total=len(train_dataloader)) as bar: for step, batch in enumerate(train_dataloader, start=1): model.train()
def train_eval(args):
    """Train (and optionally validate) an entity-typing model from `args`.

    Loads pickled train/valid splits, builds an AttentiveLSTM (the 'bert'
    path is currently disabled and returns immediately), runs the training
    loop with gradient accumulation, and checkpoints whenever the loose
    micro-F1 on the validation set reaches a new best.

    Returns None; results are written to logs and checkpoint files.
    """
    set_seed(args.seed_num)
    train_x_left, train_x_entity, train_x_right, train_y = read_pickle(args.train_data_path)
    valid_x_left, valid_x_entity, valid_x_right, valid_y = read_pickle(args.valid_data_path)
    args.num_labels = train_y.shape[1]
    train_dataloader = get_dataloader(train_y, args.train_batch_size, True,
                                      train_x_left, train_x_entity, train_x_right)
    valid_dataloader = get_dataloader(valid_y, args.valid_batch_size, False,
                                      valid_x_left, valid_x_entity, valid_x_right)
    if args.model == 'bert':
        # The BERT path is disabled for now, so every `args.model == 'bert'`
        # branch further down is unreachable until this early return is removed.
        return None
        # bert_config = BertConfig(args.bert_config_path)
        # model = NERBert(bert_config, args)
        # model.load_state_dict(torch.load(args.bert_model_path), strict=False)
        # model = NERBert.from_pretrained('bert_chinese',
        #                                 # cache_dir='/home/dutir/yuetianchi/.pytorch_pretrained_bert',
        #                                 num_labels=args.num_labels)
    else:
        if args.embedding:
            word_embedding_matrix = read_pickle(args.embedding_data_path)
            args.vocab_size = len(word_embedding_matrix)
            model = AttentiveLSTM(args, word_embedding_matrix)
        else:
            logging.error("args.embedding should be true")
            return None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.model == 'bert':
        # Unreachable while the early `return None` above is in place;
        # kept for when the BERT path is re-enabled.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if 'bert' not in n],
             'lr': 5e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.0}
        ]
        warmup_proportion = 0.1
        # NOTE(review): `train_samples` is not defined in this function —
        # confirm it exists at module scope before re-enabling this branch.
        num_train_optimization_steps = int(
            train_samples / args.train_batch_size / args.gradient_accumulation_steps) * args.epochs
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=warmup_proportion,
                             t_total=num_train_optimization_steps)
    else:
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=current_learning_rate
        )
    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        # Typo fixed in the log message ('Ramdomly' -> 'Randomly').
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0
    global_step = init_step
    best_score = 0.0
    logging.info('Start Training...')
    logging.info('init_step = %d' % global_step)
    for epoch_id in range(int(args.epochs)):
        train_loss = 0
        for step, train_batch in enumerate(train_dataloader):
            model.train()
            batch = tuple(t.to(device) for t in train_batch)
            train_x = (batch[0], batch[1], batch[2])
            train_y = batch[3]
            loss = model(train_x, train_y)
            if n_gpu > 1:
                loss = loss.mean()  # average the per-GPU losses from DataParallel
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            train_loss += loss.item()
            loss.backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                if args.do_valid and global_step % args.valid_step == 1:
                    true_res = []
                    pred_res = []
                    valid_losses = []
                    model.eval()
                    for valid_step, valid_batch in enumerate(valid_dataloader):
                        valid_batch = tuple(t.to(device) for t in valid_batch)
                        valid_x = (valid_batch[0], valid_batch[1], valid_batch[2])
                        valid_y = valid_batch[3]
                        with torch.no_grad():
                            valid_logit = model(valid_x)
                            valid_loss = F.binary_cross_entropy_with_logits(valid_logit, valid_y)
                            # torch.sigmoid replaces the deprecated F.sigmoid alias.
                            valid_logit = torch.sigmoid(valid_logit)
                        if args.model == 'bert':
                            # the first token is '[CLS]'
                            valid_losses.append(valid_loss.item())
                            true_res.extend(valid_y.detach().cpu().numpy())
                            pred_res.extend(valid_logit.detach().cpu().numpy())
                        else:
                            valid_losses.append(valid_loss.item())
                            true_res.extend(valid_y.detach().cpu().numpy())
                            pred_res.extend(valid_logit.detach().cpu().numpy())
                    metric_res = acc_hook(pred_res, true_res)
                    logging.info('Evaluation:step:{},train_loss:{},valid_loss:{},microf1:{},macrof1:{}'.
                                 format(str(global_step), train_loss / args.valid_step,
                                        np.average(valid_losses),
                                        metric_res['loose_micro_f1'],
                                        metric_res['loose_macro_f1']))
                    if metric_res['loose_micro_f1'] >= best_score:
                        best_score = metric_res['loose_micro_f1']
                        if args.model == 'bert':
                            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                            output_dir = '{}_{}'.format('bert', str(global_step))
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
                            torch.save(model_to_save.state_dict(), output_model_file)
                            output_config_file = os.path.join(output_dir, CONFIG_NAME)
                            with open(output_config_file, 'w') as f:
                                f.write(model_to_save.config.to_json_string())
                        else:
                            save_variable_list = {
                                'step': global_step,
                                'current_learning_rate': args.learning_rate,
                                'warm_up_steps': step
                            }
                            save_model(model, optimizer, save_variable_list, args)
                    train_loss = 0.0
def main(train_file,
         matched_valid_file,
         mismatched_valid_file,
         test_file,
         bert_path,
         target_dir,
         hidden_size=768,
         dropout=0.5,
         num_classes=2,
         epochs=10,
         batch_size=128,
         learning_rate=5e-5,
         patience=5,
         max_grad_norm=10.0,
         checkpoint=None):
    """
    Train the BERT model on the text_similarity dataset.

    Args:
        train_file: Path to the preprocessed training data.
        matched_valid_file: Path to the preprocessed matched validation data.
        mismatched_valid_file: Path to the preprocessed mismatched validation data.
        test_file: Path to the preprocessed test data.
        bert_path: Path to the pretrained BERT weights used to build the model.
        target_dir: Directory where trained model checkpoints are saved.
        hidden_size: Size of the model's hidden layers. Defaults to 768.
        dropout: Dropout rate to use in the model. Defaults to 0.5.
        num_classes: Number of classes in the model output. Defaults to 2.
        epochs: Maximum number of training epochs. Defaults to 10.
        batch_size: Size of the batches for training. Defaults to 128.
        learning_rate: Learning rate for the optimizer. Defaults to 5e-5.
        patience: Patience for early stopping. Defaults to 5.
        max_grad_norm: Gradient clipping threshold passed to `train`.
            Defaults to 10.0.
        checkpoint: A checkpoint from which to continue training. If None,
            training starts from scratch. Defaults to None.
    """
    # NOTE(review): GPU index 6 is hard-coded — confirm it matches the
    # intended device on the training machine.
    device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
    print(20 * "=", " Preparing for training ", 20 * "=")
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # -------------------- Data loading ------------------- #
    print("\t* Loading training data...")
    with open(train_file, "rb") as pkl:
        train_data = TEXTSIMILARITYDataset(pickle.load(pkl))
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    print("\t* Loading validation data...")
    with open(matched_valid_file, "rb") as pkl:
        matched_valid_data = TEXTSIMILARITYDataset(pickle.load(pkl))
    with open(mismatched_valid_file, "rb") as pkl:
        mismatched_valid_data = TEXTSIMILARITYDataset(pickle.load(pkl))
    with open(test_file, "rb") as pkl:
        # Keep the raw pickle around: `test(...)` consumes it alongside the
        # wrapped dataset.
        test_raw_data = pickle.load(pkl)
        test_data = TEXTSIMILARITYDataset(test_raw_data)
    matched_valid_loader = DataLoader(matched_valid_data, shuffle=False, batch_size=batch_size)
    mismatched_valid_loader = DataLoader(mismatched_valid_data, shuffle=False, batch_size=batch_size)
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
    # -------------------- Model definition ------------------- #
    model = BERT(bert_path, hidden_size, num_classes=num_classes, device=device).to(device)
    # -------------------- Preparation for training ------------------- #
    criterion = nn.CrossEntropyLoss()
    param_optimizer = list(model.named_parameters())
    # BERT convention: no weight decay on biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=0.05,
                         t_total=len(train_loader) * epochs)
    # Halve the LR whenever matched-validation accuracy stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=0)
    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot.
    epochs_count = []
    train_losses = []
    matched_valid_losses = []
    mismatched_valid_losses = []
    test_losses = []
    # Continuing training from a checkpoint if one was given as argument.
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        matched_valid_losses = checkpoint["matched_valid_losses"]
        mismatched_valid_losses = checkpoint["mismatched_valid_losses"]
        test_losses = checkpoint["test_losses"]
    # Compute loss and accuracy before starting (or resuming) training.
    _, matched_valid_loss, matched_valid_accuracy, precisions, recalls, f1s = validate(model, matched_valid_loader, criterion)
    print("\t* Validation loss before training on matched valid data: {:.4f}, accuracy: {:.4f}%"
          .format(matched_valid_loss, (matched_valid_accuracy*100)))
    _, mismatched_valid_loss, mismatched_valid_accuracy, precisions, recalls, f1s = validate(model, mismatched_valid_loader, criterion)
    print("\t* Validation loss before training on mismatched valid data: {:.4f}, accuracy: {:.4f}%"
          .format(mismatched_valid_loss, (mismatched_valid_accuracy*100)))
    # -------------------- Training epochs ------------------- #
    print("\n", 20 * "=", "Training ESIM model on device: {}".format(device), 20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs+1):
        epochs_count.append(epoch)
        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, criterion, epoch, max_grad_norm)
        train_losses.append(epoch_loss)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
        print("* Validation for epoch {} on matched data:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, precisions, recalls, f1s = validate(model, matched_valid_loader, criterion)
        matched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
        print("* Validation for epoch {} on mismatched data:".format(epoch))
        epoch_time, epoch_loss, mis_epoch_accuracy, precisions, recalls, f1s = validate(model, mismatched_valid_loader, criterion)
        mismatched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (mis_epoch_accuracy*100)))
        print("* Validation for epoch {} on test data:".format(epoch))
        print("test data size: ", len(test_data))
        epoch_time, epoch_loss, test_epoch_accuracy = test(model, test_loader, test_raw_data, criterion)
        test_losses.append(epoch_loss)
        print("-> Test. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (test_epoch_accuracy*100)))
        # Update the optimizer's learning rate with the scheduler.
        # `epoch_accuracy` still holds the matched-validation accuracy here
        # (the mismatched/test runs bind to differently named variables).
        scheduler.step(epoch_accuracy)
        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            # Save the best model. The optimizer is not saved to avoid having
            # a checkpoint file that is too heavy to be shared. To resume
            # training from the best model, use the 'esim_*.pth.tar'
            # checkpoints instead.
            # FIX: save the loss histories under the same keys the resume
            # code above reads ("matched_valid_losses"/"mismatched_valid_losses");
            # the old "match_valid_losses" keys made resuming raise KeyError.
            torch.save({"epoch": epoch,
                        "model": model.state_dict(),
                        "best_score": best_score,
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "matched_valid_losses": matched_valid_losses,
                        "mismatched_valid_losses": mismatched_valid_losses,
                        "test_losses": test_losses},
                       os.path.join(target_dir, "best.pth.tar"))
        # Save the model at each epoch.
        torch.save({"epoch": epoch,
                    "model": model.state_dict(),
                    "best_score": best_score,
                    "optimizer": optimizer.state_dict(),
                    "epochs_count": epochs_count,
                    "train_losses": train_losses,
                    "matched_valid_losses": matched_valid_losses,
                    "mismatched_valid_losses": mismatched_valid_losses,
                    "test_losses": test_losses},
                   os.path.join(target_dir, "esim_{}.pth.tar".format(epoch)))
        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
    print("* Validation on test1 data:")
    epoch_time, epoch_loss, test_epoch_accuracy = test(model, test_loader, test_raw_data, criterion)
    print("-> Test. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
          .format(epoch_time, epoch_loss, (test_epoch_accuracy*100)))
    # Plotting of the loss curves for the train and validation sets.
    plt.figure()
    plt.plot(epochs_count, train_losses, "-r")
    plt.plot(epochs_count, matched_valid_losses, "-b")
    plt.plot(epochs_count, mismatched_valid_losses, "-g")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend(["Training loss", "Validation loss (matched set)", "Validation loss (mismatched set)"])
    plt.title("Cross entropy loss")
    plt.savefig("loss.jpg")