def str2bool(v):
    # argparse's `type=bool` treats any non-empty string as True; parse booleans explicitly instead.
    return str(v).lower() in ('true', '1', 'yes')


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--data_dir", default="../../../dataset/final_data/commongen", type=str,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str, help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str, help="The output data file name.")
    parser.add_argument("--bart_model", default="facebook/bart-large", type=str,
                        help="BART pre-trained model selected in the list: facebook/bart-base, facebook/bart-large.")
    parser.add_argument("--config_path", default=None, type=str, help="Model config file path.")
    parser.add_argument("--output_dir", default="../../../output/train_kgbart", type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default="../../../log/train_kgbart", type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument("--max_seq_length", default=64, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization.\n"
                             "Sequences longer than this will be truncated, and sequences shorter\n"
                             "than this will be padded.")
    parser.add_argument("--do_train", default=True, type=str2bool, help="Whether to run training.")
    parser.add_argument("--do_eval", default=True, type=str2bool, help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", default=False, type=str2bool,
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=48, type=int, help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=2, type=int, help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Label smoothing factor.")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=10.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available.")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=6,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', default=True, type=str2bool,
                        help="Whether to use 16-bit float precision instead of 32-bit.")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings.")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 is set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', default=True, type=str2bool, help="Whether to use amp for fp16.")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=64, help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=64, help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', default=True, type=str2bool,
                        help="Truncate_config: whether we should always truncate the tail.")
    parser.add_argument("--mask_prob", default=0.70, type=float,
                        help="Masking probability; the number of predictions can be less than max_pred "
                             "when the sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Masking probability for EOS; the number of predictions can be less than max_pred "
                             "when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=30, help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=5, type=int, help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training.")
    parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left-to-right training.")
    parser.add_argument('--pretraining_KG', action='store_true', help="Whether to pretrain on the KG.")
    parser.add_argument('--train_pretraining_num', type=int, default=20,
                        help="The number of samples for KG pretraining (train split).")
") parser.add_argument('--val_pretraining_num', type=int, default=20, help="The number of sample validing pretraining KG.") parser.add_argument('--max_position_embeddings', type=int, default=64, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', default=False, type=bool, help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', default=False, type=bool, help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', default=False, type=bool, help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument( "--fp16_opt_level", type=str, default="O1", help= "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html", ) parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") parser.add_argument("--keep_last_epochs", default=5, type=int, help="Keep the last few epochs.") args = parser.parse_args() # assert Path(args.model_recover_path).exists( # ), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BartTokenizer.from_pretrained(args.bart_model, do_lower_case=args.do_lower_case)
    # if args.max_position_embeddings:
    #     tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = tokenizer  # WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Pretrain(
            list(tokenizer.encoder.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b,
                             'trunc_seg': args.trunc_seg,
                             'always_truncate_tail': args.always_truncate_tail},
            mode="s2s", pretraining_KG=args.pretraining_KG, num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift)
    ]

    file_oracle = None
    # Knowledge-graph resources; these fall back to the CommonGen KG files when --tgt_file is unset.
    entity_id = os.path.join(args.data_dir,
                             args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_entity2id.txt')
    entity_embedding_path = os.path.join(args.data_dir,
                                         args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_ent_embeddings')
    relation_id = os.path.join(args.data_dir,
                               args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_relation2id.txt')
    relation_embedding_path = os.path.join(args.data_dir,
                                           args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_rel_embeddings')

    entity_embedding = np.array(pickle.load(open(entity_embedding_path, "rb")))
    # Prepend four zero rows for the special tokens that precede the KG entities.
    entity_embedding = np.array(list(np.zeros((4, 1024))) + list(entity_embedding))
    relation_embedding = np.array(pickle.load(open(relation_embedding_path, "rb")))

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir, 'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'commongen.train.src_new.txt')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'commongen.train.tgt.txt')
        fn_onehop = os.path.join(args.data_dir,
                                 args.tgt_file if args.tgt_file else 'commongen.train.onehop_5.txt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src, fn_tgt, entity_id, relation_id, fn_onehop, args.train_batch_size,
            data_tokenizer, args.max_seq_length, pretraining_KG=file_oracle,
            pretraining_num=args.train_pretraining_num, bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size, sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)
    if args.do_eval:
        print("Loading Dev Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir, 'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'commongen.dev.src_new.txt')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'commongen.dev.tgt.txt')
        fn_onehop = os.path.join(args.data_dir,
                                 args.tgt_file if args.tgt_file else 'commongen.dev.onehop_5.txt')
        dev_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src, fn_tgt, entity_id, relation_id, fn_onehop, args.eval_batch_size,
            data_tokenizer, args.max_seq_length, pretraining_KG=file_oracle,
            pretraining_num=args.val_pretraining_num, bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            dev_sampler = RandomSampler(dev_dataset, replacement=False)
            _batch_size = args.eval_batch_size
        else:
            dev_sampler = DistributedSampler(dev_dataset)
            _batch_size = args.eval_batch_size // dist.get_world_size()
        dev_dataloader = torch.utils.data.DataLoader(
            dev_dataset, batch_size=_batch_size, sampler=dev_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps)

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.pretraining_KG else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model, entity_weight=entity_embedding, relation_weight=relation_embedding)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path, map_location='cpu')
            global_step = 0
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model, state_dict=model_recover,
            entity_weight=entity_embedding, relation_weight=relation_embedding)
    if args.local_rank == 0:
        dist.barrier()
    model.to(device)

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        schedule_recover = torch.load(os.path.join(args.output_dir, "sched.{0}.bin".format(recover_step)),
                                      map_location='cpu')
        scheduler.load_state_dict(schedule_recover)

    if args.loss_scale == 0:
        logger.info("***** Recover optimizer: dynamic_loss_scale *****")
        optimizer.dynamic_loss_scale = True

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    best_dev_loss = 1000
    output_eval_file = os.path.join(args.log_dir, "eval_results.txt")
    writer = open(output_eval_file, "w")

    def checkpoint_paths(path, pattern=r"model(\d+)\.pt"):
        """Retrieves all checkpoints found in `path` directory.

        Checkpoints are identified by matching filename to the specified pattern.
        If the pattern contains groups, the result will be sorted by the first group
        in descending order.
        """
        pt_regexp = re.compile(pattern)
        files = os.listdir(path)
        entries = []
        for i, f in enumerate(files):
            m = pt_regexp.fullmatch(f)
            if m is not None:
                idx = float(m.group(1)) if len(m.groups()) > 0 else int(f.split(".")[1])
                entries.append((idx, m.group(0)))
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            model.train()
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            # iter_bar = tqdm(BackgroundGenerator(train_dataloader), desc='Iter (loss=X.XXX)',
            #                 disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(tqdm(train_dataloader, desc="Training", position=0, leave=True)):
                batch = [t.to(device) if t is not None else None for t in batch]
                # The batch layout is identical with and without KG pretraining,
                # so the original duplicated if/else branches are collapsed into one unpack.
                (input_ids, input_entity_ids, subword_mask, word_mask, word_subword,
                 decoder_input_ids, decoder_attention_mask, labels) = batch
                loss_output = model(input_ids, input_entity_ids=input_entity_ids,
                                    attention_mask=subword_mask, word_mask=word_mask,
                                    word_subword=word_subword, decoder_input_ids=decoder_input_ids,
                                    decoder_attention_mask=decoder_attention_mask, labels=labels,
                                    label_smoothing=False)
                masked_lm_loss = loss_output.loss
                if n_gpu > 1:
                    # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                # iter_bar.set_description('Iter %d (loss=%5.3f)' % (i_epoch, loss.item()))
                if step % 1000 == 0:
                    print('Iter %d (Gen_loss=%5.3f)' % (i_epoch, loss.item()))

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                model.eval()
                cur_dev_loss = []
                with torch.no_grad():
                    for step, batch in enumerate(tqdm(dev_dataloader, desc="Evaluating",
                                                      position=0, leave=True)):
                        batch = [t.to(device) if t is not None else None for t in batch]
                        # Same batch layout as in training.
                        (input_ids, input_entity_ids, subword_mask, word_mask, word_subword,
                         decoder_input_ids, decoder_attention_mask, labels) = batch
                        loss_output = model(input_ids, input_entity_ids=input_entity_ids,
                                            attention_mask=subword_mask, word_mask=word_mask,
                                            word_subword=word_subword, decoder_input_ids=decoder_input_ids,
                                            decoder_attention_mask=decoder_attention_mask, labels=labels,
                                            label_smoothing=False)
                        masked_lm_loss = loss_output.loss
                        if n_gpu > 1:
                            # mean() to average on multi-gpu.
                            # loss = loss.mean()
                            masked_lm_loss = masked_lm_loss.mean()
                            # next_sentence_loss = next_sentence_loss.mean()
                        loss = masked_lm_loss
                        cur_dev_loss.append(float(loss.item()))
                        # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                    dev_loss = sum(cur_dev_loss) / float(len(cur_dev_loss))
                    print("the epoch {} DEV loss is {}".format(i_epoch, dev_loss))

                    if best_dev_loss > dev_loss:
                        best_dev_loss = dev_loss
                        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                        os.makedirs(args.output_dir + "/best_model", exist_ok=True)
                        output_model_file = os.path.join(args.output_dir, "best_model/model.best.bin")
                        # output_optim_file = os.path.join(args.output_dir, "best_model/optim.best.bin")
                        # output_schedule_file = os.path.join(args.output_dir, "best_model/sched.best.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        # torch.save(optimizer.state_dict(), output_optim_file)
                        # torch.save(scheduler.state_dict(), output_schedule_file)

                logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir, "model.{0}.bin".format(i_epoch))
                output_optim_file = os.path.join(args.output_dir, "optim.{0}.bin".format(i_epoch))
                output_schedule_file = os.path.join(args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                torch.save(optimizer.state_dict(), output_optim_file)
                torch.save(scheduler.state_dict(), output_schedule_file)

                writer.write("epoch " + str(i_epoch) + "\n")
                writer.write("the current eval loss is: " + str(dev_loss) + "\n")

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

                if args.keep_last_epochs > 0:
                    # remove old epoch checkpoints; checkpoints are sorted in descending order
                    checkpoints = checkpoint_paths(args.output_dir, pattern=r"model.\d+.bin")
                    for old_chk in checkpoints[args.keep_last_epochs:]:
                        if os.path.lexists(old_chk):
                            os.remove(old_chk)
                    checkpoints = checkpoint_paths(args.output_dir, pattern=r"optim.\d+.bin")
                    for old_chk in checkpoints[args.keep_last_epochs:]:
                        if os.path.lexists(old_chk):
                            os.remove(old_chk)
                    checkpoints = checkpoint_paths(args.output_dir, pattern=r"sched.\d+.bin")
                    for old_chk in checkpoints[args.keep_last_epochs:]:
                        if os.path.lexists(old_chk):
                            os.remove(old_chk)

        logger.info("***** CUDA.empty_cache() *****")
        torch.cuda.empty_cache()
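
# `_get_max_epoch_model` is used above but not shown in this listing. A minimal sketch,
# assuming checkpoints follow the "model.<epoch>.bin" / "optim.<epoch>.bin" naming used
# here (this is a reconstruction, not necessarily the original helper):
import glob
from pathlib import Path


def _get_max_epoch_model(output_dir):
    """Return the largest epoch for which both a model and an optimizer checkpoint exist."""
    fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin"))
    fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin"))
    if (not fn_model_list) or (not fn_optim_list):
        return None
    epochs = set(int(Path(fn).stem.split('.')[-1]) for fn in fn_model_list) \
        & set(int(Path(fn).stem.split('.')[-1]) for fn in fn_optim_list)
    return max(epochs) if epochs else None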
class Trainer:
    def __init__(self, callback: TrainerCallback):
        self.callback = callback
        self.callback.trainer = self
        logging.basicConfig(level=logging.INFO)

    def parse_args(self):
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument('--train', action='store_true')
        self.parser.add_argument('--dev', action='store_true')
        self.parser.add_argument('--test', action='store_true')
        self.parser.add_argument('--debug', action='store_true')
        self.parser.add_argument("--per_gpu_train_batch_size", default=8, type=int)
        self.parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int)
        self.parser.add_argument("--learning_rate", default=5e-5, type=float)
        self.parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
        self.parser.add_argument("--weight_decay", default=0.0, type=float)
        self.parser.add_argument("--adam_epsilon", default=1e-8, type=float)
        self.parser.add_argument("--max_grad_norm", default=1.0, type=float)
        self.parser.add_argument("--epochs", default=2, type=int)
        self.parser.add_argument("--warmup_ratio", default=0.1, type=float)
        self.parser.add_argument("--logging_steps", type=int, default=500)
        self.parser.add_argument("--save_steps", type=int, default=10000)
        self.parser.add_argument("--seed", type=int, default=42)
        self.parser.add_argument("--num_workers", type=int, default=0)
        self.parser.add_argument("--local_rank", type=int, default=-1)
        self.parser.add_argument("--fp16", action="store_true")
        self.parser.add_argument("--fp16_opt_level", type=str, default="O1")
        self.parser.add_argument("--no_cuda", action="store_true")
        self.parser.add_argument("--load_checkpoint", default=None, type=str)
        self.parser.add_argument("--ignore_progress", action='store_true')
        self.parser.add_argument("--dataset_ratio", type=float, default=1.0)
        self.parser.add_argument("--no_save", action="store_true")
        self.callback.on_argument(self.parser)
        self.args = self.parser.parse_args()

        # Normalize any path-like string arguments to absolute paths.
        keys = list(self.args.__dict__.keys())
        for key in keys:
            value = getattr(self.args, key)
            if type(value) == str and os.path.exists(value):
                setattr(self.args, key, os.path.abspath(value))
        if not self.args.train:
            self.args.epochs = 1

        self.train = self.args.train
        self.dev = self.args.dev
        self.test = self.args.test
        self.debug = self.args.debug
        self.per_gpu_train_batch_size = self.args.per_gpu_train_batch_size
        self.per_gpu_eval_batch_size = self.args.per_gpu_eval_batch_size
        self.learning_rate = self.args.learning_rate
        self.gradient_accumulation_steps = self.args.gradient_accumulation_steps
        self.weight_decay = self.args.weight_decay
        self.adam_epsilon = self.args.adam_epsilon
        self.max_grad_norm = self.args.max_grad_norm
        self.epochs = self.args.epochs
        self.warmup_ratio = self.args.warmup_ratio
        self.logging_steps = self.args.logging_steps
        self.save_steps = self.args.save_steps
        self.seed = self.args.seed
        self.num_workers = self.args.num_workers
        self.local_rank = self.args.local_rank
        self.fp16 = self.args.fp16
        self.fp16_opt_level = self.args.fp16_opt_level
        self.no_cuda = self.args.no_cuda
        self.load_checkpoint = self.args.load_checkpoint
        self.ignore_progress = self.args.ignore_progress
        self.dataset_ratio = self.args.dataset_ratio
        self.no_save = self.args.no_save
        self.callback.args = self.args

    def set_env(self):
        if self.debug:
            sys.excepthook = IPython.core.ultratb.FormattedTB(
                mode='Verbose', color_scheme='Linux', call_pdb=1)
        if self.local_rank == -1 or self.no_cuda:
            device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu")
            self.n_gpu = 0 if self.no_cuda else torch.cuda.device_count()
        else:
            torch.cuda.set_device(self.local_rank)
            device = torch.device("cuda", self.local_rank)
            torch.distributed.init_process_group(backend="nccl")
            self.n_gpu = 1
        set_seed(self.seed, self.n_gpu)
        self.device = device

        # Snapshot the source files of this run under r/<run_id>/ for reproducibility.
        with self.once_barrier():
            if not os.path.exists('r'):
                os.mkdir('r')
            runs = os.listdir('r')
            i = max([int(c) for c in runs], default=-1) + 1
            os.mkdir(os.path.join('r', str(i)))
            src_names = [source for source in os.listdir() if source.endswith('.py')]
            for src_name in src_names:
                shutil.copy(src_name, os.path.join('r', str(i), src_name))
            os.mkdir(os.path.join('r', str(i), 'output'))
            os.mkdir(os.path.join('r', str(i), 'tmp'))
        runs = os.listdir('r')
        i = max([int(c) for c in runs])
        os.chdir(os.path.join('r', str(i)))
        with self.once_barrier():
            json.dump(sys.argv, open('output/args.json', 'w'))
        logging.info(
            "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, 16-bits training: {}"
            .format(self.local_rank, device, self.n_gpu, bool(self.local_rank != -1), self.fp16))
        self.train_batch_size = self.per_gpu_train_batch_size * max(1, self.n_gpu)
        self.eval_batch_size = self.per_gpu_eval_batch_size * max(1, self.n_gpu)
        if self.fp16:
            apex.amp.register_half_function(torch, "einsum")
        self.stream = torch.cuda.Stream()

    def set_model(self):
        self.model = self.callback.load_model()
        self.model.to(self.device)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             "weight_decay": self.weight_decay},
            {"params": [p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)
        if self.fp16:
            self.model, self.optimizer = apex.amp.initialize(
                self.model, self.optimizer, opt_level=self.fp16_opt_level)
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)
        if self.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank], output_device=self.local_rank,
                find_unused_parameters=True)

    def once(self):
        return Once(self.local_rank)

    def once_barrier(self):
        return OnceBarrier(self.local_rank)

    def cache(self):
        return Cache(self.local_rank)

    def load_data(self):
        self.train_step = 1
        self.epochs_trained = 0
        self.steps_trained_in_current_epoch = 0
        train_dataset, dev_dataset, test_dataset = self.callback.load_data()
        train_fn, dev_fn, test_fn = self.callback.collate_fn()
        if train_dataset:
            if self.dataset_ratio < 1:
                train_dataset = torch.utils.data.Subset(
                    train_dataset, list(range(int(len(train_dataset) * self.dataset_ratio))))
            self.train_dataset = train_dataset
            self.train_sampler = (RandomSampler(self.train_dataset) if self.local_rank == -1
                                  else DistributedSampler(self.train_dataset))
            self.train_dataloader = Prefetcher(
                DataLoader(self.train_dataset, sampler=self.train_sampler,
                           batch_size=self.train_batch_size, collate_fn=train_fn,
                           num_workers=self.num_workers),
                self.stream)
            self.t_total = len(self.train_dataloader) // self.gradient_accumulation_steps * self.epochs
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=int(self.t_total * self.warmup_ratio),
                num_training_steps=self.t_total)
        if dev_dataset:
            if self.dataset_ratio < 1:
                dev_dataset = torch.utils.data.Subset(
                    dev_dataset, list(range(int(len(dev_dataset) * self.dataset_ratio))))
            self.dev_dataset = dev_dataset
            self.dev_sampler = (SequentialSampler(self.dev_dataset) if self.local_rank == -1
                                else DistributedSampler(self.dev_dataset))
            self.dev_dataloader = Prefetcher(
                DataLoader(self.dev_dataset, sampler=self.dev_sampler,
                           batch_size=self.eval_batch_size, collate_fn=dev_fn,
                           num_workers=self.num_workers),
                self.stream)
        if test_dataset:
            if self.dataset_ratio < 1:
                test_dataset = torch.utils.data.Subset(
                    test_dataset, list(range(int(len(test_dataset) * self.dataset_ratio))))
            self.test_dataset = test_dataset
            self.test_sampler = (SequentialSampler(self.test_dataset) if self.local_rank == -1
                                 else DistributedSampler(self.test_dataset))
            self.test_dataloader = Prefetcher(
                DataLoader(self.test_dataset, sampler=self.test_sampler,
                           batch_size=self.eval_batch_size, collate_fn=test_fn,
                           num_workers=self.num_workers),
                self.stream)

    def restore_checkpoint(self, path, ignore_progress=False):
        if self.no_save:
            return
        model_to_load = self.model.module if hasattr(self.model, "module") else self.model
        model_to_load.load_state_dict(
            torch.load(os.path.join(path, 'pytorch_model.bin'), map_location=self.device))
        self.optimizer.load_state_dict(
            torch.load(os.path.join(path, "optimizer.pt"), map_location=self.device))
        self.scheduler.load_state_dict(
            torch.load(os.path.join(path, "scheduler.pt"), map_location=self.device))
        self.callback.on_load(path)
        if not ignore_progress:
            self.train_step = int(path.split("-")[-1])
            self.epochs_trained = self.train_step // (
                len(self.train_dataloader) // self.gradient_accumulation_steps)
            self.steps_trained_in_current_epoch = self.train_step % (
                len(self.train_dataloader) // self.gradient_accumulation_steps)
            logging.info("  Continuing training from checkpoint, will skip to saved train_step")
            logging.info("  Continuing training from epoch %d", self.epochs_trained)
            logging.info("  Continuing training from train step %d", self.train_step)
            logging.info("  Will skip the first %d steps in the first epoch",
                         self.steps_trained_in_current_epoch)

    def save_checkpoint(self):
        if self.no_save:
            return
        output_dir = os.path.join('output', "checkpoint-{}".format(self.train_step))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        self.callback.on_save(output_dir)

    def run(self):
        self.parse_args()
        self.set_env()
        with self.once():
            self.writer = SummaryWriter()
        self.set_model()
        self.load_data()
        if self.load_checkpoint is not None:
            self.restore_checkpoint(self.load_checkpoint, self.ignore_progress)
        best_performance = 0
        best_step = -1
        for epoch in range(self.epochs):
            if epoch < self.epochs_trained:
                continue
            with self.once():
                logging.info('epoch %d', epoch)
            if self.train:
                tr_loss, logging_loss = 0.0, 0.0
                self.model.zero_grad()
                self.model.train()
                self.callback.on_train_epoch_start(epoch)
                if self.local_rank >= 0:
                    self.train_sampler.set_epoch(epoch)
                for step, batch in enumerate(
                        tqdm(self.train_dataloader, disable=self.local_rank > 0)):
                    if step < self.steps_trained_in_current_epoch:
                        continue
                    inputs, extra = self.callback.process_train_data(batch)
                    outputs = self.model(**inputs)
                    loss = outputs[0]
                    if self.n_gpu > 1:
                        loss = loss.mean()
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    # Skip the DDP gradient all-reduce on non-boundary accumulation steps.
                    if self.local_rank < 0 or (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.fp16:
                            with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()
                    else:
                        with self.model.no_sync():
                            if self.fp16:
                                with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                    scaled_loss.backward()
                            else:
                                loss.backward()
                    tr_loss += loss.item()
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                apex.amp.master_params(self.optimizer), self.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                        self.optimizer.step()
                        self.scheduler.step()
                        self.model.zero_grad()
                        self.train_step += 1
                        with self.once():
                            if self.train_step % self.logging_steps == 0:
                                self.writer.add_scalar("lr", self.scheduler.get_lr()[0], self.train_step)
                                self.writer.add_scalar(
                                    "loss", (tr_loss - logging_loss) / self.logging_steps, self.train_step)
                                logging_loss = tr_loss
                            if self.train_step % self.save_steps == 0:
                                self.save_checkpoint()
                    self.callback.on_train_step(step, self.train_step, inputs, extra, loss.item(), outputs)
                with self.once():
                    self.save_checkpoint()
                self.callback.on_train_epoch_end(epoch)
            if self.dev:
                with torch.no_grad():
                    self.model.eval()
                    self.callback.on_dev_epoch_start(epoch)
                    for step, batch in enumerate(
                            tqdm(self.dev_dataloader, disable=self.local_rank > 0)):
                        inputs, extra = self.callback.process_dev_data(batch)
                        outputs = self.model(**inputs)
                        self.callback.on_dev_step(step, inputs, extra, outputs)
                    performance = self.callback.on_dev_epoch_end(epoch)
                    if performance > best_performance:
                        best_performance = performance
                        best_step = self.train_step
        if self.test:
            with torch.no_grad():
                if best_step > 0 and self.train:
                    self.restore_checkpoint(os.path.join('output', "checkpoint-{}".format(best_step)))
                self.model.eval()
                self.callback.on_test_epoch_start(epoch)
                for step, batch in enumerate(
                        tqdm(self.test_dataloader, disable=self.local_rank > 0)):
                    inputs, extra = self.callback.process_test_data(batch)
                    outputs = self.model(**inputs)
                    self.callback.on_test_step(step, inputs, extra, outputs)
                self.callback.on_test_epoch_end(epoch)
        with self.once():
            self.writer.close()
            json.dump(True, open('output/f.json', 'w'))

    def distributed_broadcast(self, l):
        assert type(l) == list or type(l) == dict
        if self.local_rank < 0:
            return l
        else:
            torch.distributed.barrier()
            process_number = torch.distributed.get_world_size()
            json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
            torch.distributed.barrier()
            objs = list()
            for i in range(process_number):
                objs.append(json.load(open(f'tmp/{i}.json')))
            if type(objs[0]) == list:
                ret = list()
                for i in range(process_number):
                    ret.extend(objs[i])
            else:
                ret = dict()
                for i in range(process_number):
                    for k, v in objs[i].items():  # merge per-rank dicts (was `objs.items()`, a bug)
                        assert k not in ret
                        ret[k] = v
            torch.distributed.barrier()
            return ret

    def distributed_merge(self, l):
        assert type(l) == list or type(l) == dict
        if self.local_rank < 0:
            return l
        else:
            torch.distributed.barrier()
            process_number = torch.distributed.get_world_size()
            json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
            torch.distributed.barrier()
            if self.local_rank == 0:
                objs = list()
                for i in range(process_number):
                    objs.append(json.load(open(f'tmp/{i}.json')))
                if type(objs[0]) == list:
                    ret = list()
                    for i in range(process_number):
                        ret.extend(objs[i])
                else:
                    ret = dict()
                    for i in range(process_number):
                        for k, v in objs[i].items():  # merge per-rank dicts (was `objs.items()`, a bug)
                            assert k not in ret
                            ret[k] = v
            else:
                ret = None
            torch.distributed.barrier()
            return ret

    def distributed_get(self, v):
        if self.local_rank < 0:
            return v
        else:
            torch.distributed.barrier()
            if self.local_rank == 0:
                json.dump(v, open('tmp/v.json', 'w'))
            torch.distributed.barrier()
            v = json.load(open('tmp/v.json'))
            torch.distributed.barrier()
            return v
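
# A minimal usage sketch (an assumption, not part of the original repo) of the callback
# contract the Trainer above drives: run() only calls the hooks shown here, so a toy
# callback needs a model, three datasets/collators, and the per-phase hooks. The names
# ToyCallback/Toy are illustrative; a real callback would return transformers models
# and datasets, and this sketch still needs a CUDA device because set_env() creates a
# torch.cuda.Stream().
class ToyCallback(TrainerCallback):
    def on_argument(self, parser):
        parser.add_argument("--hidden_size", type=int, default=8)

    def load_model(self):
        import torch.nn as nn

        class Toy(nn.Module):
            def __init__(self, h):
                super().__init__()
                self.lin = nn.Linear(h, 2)

            def forward(self, x=None, labels=None):
                logits = self.lin(x)
                loss = nn.functional.cross_entropy(logits, labels)
                return loss, logits  # run() reads outputs[0] as the loss

        return Toy(self.args.hidden_size)

    def load_data(self):
        def make(n):
            xs = torch.randn(n, self.args.hidden_size)
            ys = torch.randint(0, 2, (n,))
            return torch.utils.data.TensorDataset(xs, ys)
        return make(64), make(16), make(16)

    def collate_fn(self):
        return None, None, None  # fall back to the default DataLoader collation

    def process_train_data(self, batch):
        x, y = batch
        return {"x": x, "labels": y}, None  # (model kwargs, extra state for hooks)

    process_dev_data = process_train_data
    process_test_data = process_train_data

    def on_dev_epoch_end(self, epoch):
        return 0.0  # the metric run() compares to track best_step

    # The remaining hooks are no-ops in this sketch.
    def on_train_epoch_start(self, epoch): pass
    def on_train_step(self, *a): pass
    def on_train_epoch_end(self, epoch): pass
    def on_dev_epoch_start(self, epoch): pass
    def on_dev_step(self, *a): pass
    def on_test_epoch_start(self, epoch): pass
    def on_test_step(self, *a): pass
    def on_test_epoch_end(self, epoch): pass
    def on_load(self, path): pass
    def on_save(self, path): pass


# Entry point for the sketch: Trainer(ToyCallback()).run()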
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size, collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path
            and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization);
    # keep a handle to the unwrapped model, which is used for generation during eval.
    non_multi_model = model
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(non_multi_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    best_perplexity = float('inf')
    for i, epoch in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

        if args.do_eval:
            file_path = Path(args.data_dir, args.eval_data_file)
            out_file_path = Path(args.data_dir, "output_" + args.eval_data_file)
            id_to_json_map = {}
            with open(file_path, encoding="utf-8") as f:
                eval_loss = 0.0
                nb_eval_steps = 0
                for line in tqdm(f, desc="Evaluating"):
                    out_json = {}
                    line = json.loads(line)
                    example_id = line.get("example_id")
                    question_text = line.get("question_text")
                    prompt_text = question_text + " " + args.sep_token + " "
                    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False,
                                                      return_tensors="pt")
                    encoded_prompt = encoded_prompt.to(args.device)
                    output_sequences = non_multi_model.generate(
                        input_ids=encoded_prompt,
                        max_length=args.length + len(encoded_prompt[0]),
                        temperature=args.temperature,
                        top_k=args.k,
                        top_p=args.p,
                        repetition_penalty=args.repetition_penalty,
                        do_sample=True,
                        num_return_sequences=args.num_return_sequences,
                    )
                    if len(output_sequences.shape) > 2:
                        output_sequences.squeeze_()

                    generated_sequences = []
                    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
                        # print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
                        # generated_sequence = output_sequences[0]
                        generated_sequence = generated_sequence.tolist()

                        # Decode text
                        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
                        # Remove all text after the stop token
                        if args.stop_token:
                            text = text[:text.find(args.stop_token)]
                        # Add the prompt at the beginning of the sequence. Remove the excess
                        # text that was used for pre-processing.
                        total_sequence = (prompt_text + text[len(tokenizer.decode(
                            encoded_prompt[0], clean_up_tokenization_spaces=True)):])
                        # print(total_sequence)
                        out_json["journaling_input"], out_json["reflection_output"] = \
                            total_sequence.split(args.sep_token)[:2]

                        sample_dataset = GenerateTextDataset(tokenizer, total_sequence, args.block_size)

                        def collate(examples: List[torch.Tensor]):
                            if tokenizer._pad_token is None:
                                return pad_sequence(examples, batch_first=True)
                            return pad_sequence(examples, batch_first=True,
                                                padding_value=tokenizer.pad_token_id)

                        eval_sampler = SequentialSampler(sample_dataset)
                        eval_dataloader = DataLoader(sample_dataset, sampler=eval_sampler,
                                                     batch_size=1, collate_fn=collate)

                        model_lm = model
                        if args.n_gpu > 1:
                            model_lm = torch.nn.DataParallel(model_lm)
                        model_lm.eval()
                        for batch in eval_dataloader:
                            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
                            inputs = inputs.to(args.device)
                            labels = labels.to(args.device)
                            with torch.no_grad():
                                outputs = model_lm(inputs, masked_lm_labels=labels) if args.mlm \
                                    else model_lm(inputs, labels=labels)
                                lm_loss = outputs[0]
                                example_loss = lm_loss.mean().item()
                                eval_loss += example_loss
                                nb_eval_steps += 1
                        perplexity = torch.exp(torch.tensor(example_loss)).item()
                        # print(perplexity)
                        out_json["perplexity"] = perplexity
                        # Key each generated sequence separately (the original mutated
                        # `example_id` in place, accumulating "-idx" suffixes across sequences).
                        keyed_example_id = example_id + "-" + str(generated_sequence_idx)
                        id_to_json_map[keyed_example_id] = json.dumps(out_json, ensure_ascii=False)

                # result = {"perplexity": perplexity}
                eval_loss = eval_loss / nb_eval_steps
                total_perplexity = torch.exp(torch.tensor(eval_loss))
                logger.info(f"total_loss:: {eval_loss}")
                logger.info(f"total_perplexity:: {total_perplexity}")
                if total_perplexity < best_perplexity:
                    logger.info(f"Current best epoch::: {epoch}, with perplexity:: {total_perplexity}")
                    best_perplexity = total_perplexity
                    with open(out_file_path, "w+", encoding="utf-8") as out_file:
                        for _, out_json in id_to_json_map.items():
                            out_file.write(out_json + "\n")
                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

    return global_step, tr_loss / global_step
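
# `mask_tokens` is called in the loops above but not defined in this listing. The sketch
# below follows the standard BERT-style masking from the transformers
# run_language_modeling example (an assumption about which variant this code used);
# it expects an `args.mlm_probability` field (typically 0.15).
def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prepare masked inputs/labels for masked-LM: 80% [MASK], 10% random, 10% original."""
    if tokenizer.mask_token is None:
        raise ValueError("This tokenizer has no mask token, which is required for masked-LM training.")
    labels = inputs.clone()
    # Sample tokens for masking with probability args.mlm_probability,
    # never masking special or padding tokens.
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                           for val in labels.tolist()]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        probability_matrix.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # only compute loss on masked tokens

    # 80% of the time, replace masked input tokens with the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, replace masked input tokens with a random word.
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% of masked tokens keep their original value.
    return inputs, labels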
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default='/data/lq/tianchi/qg/model/unilm/data_file/', type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default='src_file/train_data.json', type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_file", default='dev_data.json', type=str, help="dev file.")
    parser.add_argument("--model_type", default='unilm', type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default='/data/lq/tianchi/qg/model/unilm/torch-model/',
                        type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: "
                             + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default='/data/lq/tianchi/qg/model/unilm/output_dir/', type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    parser.add_argument("--dev_batch_size", default=20, type=int, help="dev batch size.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization.\n"
                             "Sequences longer than this will be truncated, and sequences shorter\n"
                             "than this will be padded.")
    parser.add_argument('--max_position_embeddings', type=int, default=512, help="max position embeddings")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int, help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float, help="Label smoothing factor.")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=777, help="random seed for initialization") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--fp16_opt_level', type=str, default='O1', help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument('--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument('--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument("--mask_prob", default=0.20, type=float, help="Number of prediction is sometimes less than max_pred when sequence is short.") parser.add_argument("--mask_prob_eos", default=0, type=float, help="Number of prediction is sometimes less than max_pred when sequence is short.") parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") args = parser.parse_args() if not(args.model_recover_path and Path(args.model_recover_path).exists()): args.model_recover_path = None args.output_dir = args.output_dir.replace( '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace( '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) if args.log_dir: os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join( args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, 
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case)
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length, mask_source_words=False,
            skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word, tokenizer=data_tokenizer)]
        file = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            file, args.train_batch_size, data_tokenizer, args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size, sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors, pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps)

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path, map_location='cpu')
    model = model_class.from_pretrained(args.model_name_or_path, state_dict=model_recover, config=config)
    if args.local_rank == 0:
        dist.barrier()
    model.to(device)
Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any( nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any( nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*t_total), num_training_steps=t_total) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, optimizer = amp.initialize( model, optimizer, opt_level=args.fp16_opt_level) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel is required for distributed training.") model = DDP(model, device_ids=[ args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: model = torch.nn.DataParallel(model) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) #logger.info("***** Recover amp: %d *****", recover_step) #amp_recover = torch.load(os.path.join( # args.output_dir, "amp.{0}.bin".format(recover_step)), map_location='cpu') #amp.load_state_dict(amp_recover) logger.info("***** Recover scheduler: %d *****", recover_step) scheduler_recover = torch.load(os.path.join( args.output_dir, "sched.{0}.bin".format(recover_step)), map_location='cpu') scheduler.load_state_dict(scheduler_recover) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() ##################### ### Load dev data ### ##################### print("Loading Dev Dataset", args.data_dir) bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, mask_source_words=False, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, tokenizer=data_tokenizer)] file_dev = os.path.join(args.data_dir, args.dev_file if args.dev_file else 'train.tgt') dev_dataset = utils_seq2seq.Seq2SeqDataset(file_dev, args.dev_batch_size, data_tokenizer, args.max_seq_length, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: dev_sampler = RandomSampler(dev_dataset, replacement=False) _batch_size = args.dev_batch_size else: dev_sampler = DistributedSampler(dev_dataset) _batch_size = args.dev_batch_size // dist.get_world_size() dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=int(_batch_size), sampler=dev_sampler, num_workers=args.num_workers, collate_fn=utils_seq2seq.batch_list_to_batch_tensors, pin_memory=False) ##################### ## Training begins ## if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) model.train() if recover_step: start_epoch = recover_step+1 else: start_epoch = 1 for i_epoch in trange(start_epoch, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)): if args.local_rank != -1:
train_sampler.set_epoch(i_epoch) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)(lr=X.XXXX)', disable=args.local_rank not in (-1, 0)) for step, batch in enumerate(iter_bar): batch = [ t.to(device) if t is not None else None for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids, masked_pos=masked_pos, masked_weights=masked_weights) if n_gpu > 1: # mean() to average on multi-gpu. # loss = loss.mean() masked_lm_loss = masked_lm_loss.mean() loss = masked_lm_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) iter_bar.set_description('Iter (loss={:.2f})(lr={:0.2e})'.format(loss.item(), scheduler.get_lr()[0])) # ensure that accumulated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), args.max_grad_norm) if (step + 1) % args.gradient_accumulation_steps == 0: optimizer.step() scheduler.step() # Update learning rate schedule optimizer.zero_grad() global_step += 1 if step % 100 == 0: # Save a trained model if (args.local_rank == -1 or torch.distributed.get_rank() == 0): logger.info( "** ** * Saving fine-tuned model and optimizer :step{}** ** * ".format(step)) model_to_save = model.module if hasattr( model, 'module') else model # Only save the model itself output_model_file = os.path.join( args.output_dir, "model.{}.{}.bin".format(i_epoch, step)) torch.save(model_to_save.state_dict(), output_model_file) output_optim_file = os.path.join( args.output_dir, "optim.{}.{}.bin".format(i_epoch, step)) torch.save(optimizer.state_dict(), output_optim_file) output_sched_file = os.path.join( args.output_dir, "sched.{}.{}.bin".format(i_epoch, step)) torch.save(scheduler.state_dict(), output_sched_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() ################################### ################################### # Evaluate with the model checkpoint just saved at this step
################################### ################################### #dev_iter_bangbang1 = tqdm(dev_dataloader,disable=args.local_rank not in (-1, 0)) temp = [] with torch.no_grad(): for step, batch in enumerate(dev_dataloader): batch = [ t.to(device) if t is not None else None for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids, masked_pos=masked_pos, masked_weights=masked_weights) if n_gpu > 1: masked_lm_loss = masked_lm_loss.mean() dev_loss = masked_lm_loss temp.append(dev_loss.cpu().detach().numpy()) dev_loss_fin = sum(temp) / len(temp) print('#### Dev loss:', dev_loss_fin, output_model_file) # Save a trained model if (args.local_rank == -1 or torch.distributed.get_rank() == 0): logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model itself output_model_file = os.path.join( args.output_dir, "model.{0}.bin".format(i_epoch)) torch.save(model_to_save.state_dict(), output_model_file) output_optim_file = os.path.join( args.output_dir, "optim.{0}.bin".format(i_epoch)) torch.save(optimizer.state_dict(), output_optim_file) if args.fp16: output_amp_file = os.path.join( args.output_dir, "amp.{0}.bin".format(i_epoch)) torch.save(amp.state_dict(), output_amp_file) output_sched_file = os.path.join( args.output_dir, "sched.{0}.bin".format(i_epoch)) torch.save(scheduler.state_dict(), output_sched_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() #dev_iter_bangbang = tqdm(dev_dataloader, desc='Iter (loss=X.XXX)', # disable=args.local_rank not in (-1, 0)) temp = [] with torch.no_grad(): for step, batch in enumerate(dev_dataloader): batch = [ t.to(device) if t is not None else None for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids, masked_pos=masked_pos, masked_weights=masked_weights) if n_gpu > 1: # mean() to average on multi-gpu. # loss = loss.mean() masked_lm_loss = masked_lm_loss.mean() dev_loss = masked_lm_loss temp.append(dev_loss.cpu().detach().numpy()) dev_loss_fin = sum(temp) / len(temp) print('#### Dev loss:', dev_loss_fin, output_model_file)
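# ---------------------------------------------------------------------------
# NOTE: a minimal, self-contained sketch of the gradient-accumulation pattern
# used in the training loop above. All names here are illustrative, and the
# model API (a forward pass that returns a scalar loss) is an assumption, not
# the exact interface of the code above. Dividing each micro-batch loss by
# accum_steps makes the accumulated gradients equal the gradient of the mean
# loss over the effective (large) batch, so one optimizer.step() per
# accumulation window behaves like a single large-batch update.
# ---------------------------------------------------------------------------
def accumulation_step_sketch(model, optimizer, scheduler, batches, accum_steps=6):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(batches):
        loss = model(inputs, labels=labels)   # scalar loss for this micro-batch (assumed API)
        (loss / accum_steps).backward()       # scale so accumulated grads average, not sum
        if (step + 1) % accum_steps == 0:
            optimizer.step()                  # one parameter update per accumulation window
            scheduler.step()                  # advance the LR schedule once per update
            optimizer.zero_grad()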
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: """ Train the model """ if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) def collate(examples: List[torch.Tensor]): if tokenizer._pad_token is None: return pad_sequence(examples, batch_first=True) return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs model = model.module if hasattr( model, "module") else model # Take care of distributed/parallel training model.resize_token_embeddings(len(tokenizer)) # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if (args.model_name_or_path and os.path.isfile( os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( os.path.join(args.model_name_or_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) scheduler.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if args.model_name_or_path and os.path.exists(args.model_name_or_path): try: # set global_step to the global_step of the last saved checkpoint from the model path checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] global_step = int(checkpoint_suffix) epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // args.gradient_accumulation_steps) logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: logger.info(" Starting fine-tuning.") tr_loss, logging_loss = 0.0, 0.0 model.zero_grad() train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) if args.local_rank != -1: train_sampler.set_epoch(epoch) for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) inputs = inputs.to(args.device) labels = labels.to(args.device) model.train() outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) loss = outputs[0] # model outputs are always tuple in transformers (see doc) if args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: # Log metrics if (args.local_rank == -1 and args.evaluate_during_training): # Only evaluate on a single GPU, otherwise metrics may not average well results = evaluate(args, model, tokenizer) for key, value in results.items(): tb_writer.add_scalar("eval_{}".format(key), value, global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) logging_loss = tr_loss if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: checkpoint_prefix = "checkpoint" # Save model
checkpoint output_dir = os.path.join( args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) os.makedirs(output_dir, exist_ok=True) model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) _rotate_checkpoints(args, checkpoint_prefix) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
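# ---------------------------------------------------------------------------
# NOTE: a small sketch of the checkpoint-resume arithmetic used in train()
# above, assuming checkpoint directories named like "checkpoint-1500" where
# the suffix is the global optimizer step. All names here are illustrative.
# ---------------------------------------------------------------------------
def resume_position_sketch(checkpoint_dir, batches_per_epoch, accum_steps):
    # "checkpoint-1500" -> resume at global optimizer step 1500
    global_step = int(checkpoint_dir.split("-")[-1].split("/")[0])
    updates_per_epoch = batches_per_epoch // accum_steps
    epochs_trained = global_step // updates_per_epoch
    steps_to_skip_in_epoch = global_step % updates_per_epoch
    return global_step, epochs_trained, steps_to_skip_in_epoch

# e.g. resume_position_sketch("out/checkpoint-1500", 1000, 2) -> (1500, 3, 0)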
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model", train_count) printlog(model) per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if scale_tune: gradient_accumulation_steps = 1 num_train_epochs = 1 if rank < 0: # single process takes all samples sampler = RandomSampler(train_dataset_qa) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4) else: # special sampler that divides samples between processes sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size) steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule freeze_list = args.freeze_list.split(',') if args.freeze_list else [] named_params = [] for n, p in model.named_parameters(): if n.lower() != "none" and any(fn in n for fn in freeze_list): if rank in [-1, 0]: logger.warning("rank {} {} param is frozen and excluded from tune".format(rank, n)) continue named_params.append((n, p)) # split parameters into scale parameters and the rest named_params_scale = [(n, p) for n, p in named_params if '.scale' in n] named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n] if scale_tune: # keep only scale parameters named_params = named_params_scale named_params_rest = [] groups = [] if named_params_scale: groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01}) if named_params_rest: groups.append({'params': [p for n, p in named_params_rest], 'lr': args.learning_rate}) optimizer = AdamW( groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) printlog("scale_tune", scale_tune) printlog("dataset size", len(train_dataset_qa)) printlog("epochs", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) global_step = 0 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict() for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() set_output_hidden_states(rank, model, (model_t is not None)) utils.sync_models(rank, model) if model_t is not None: set_output_hidden_states(rank, model_t, True) model_t.train() if rank > -1: # set epoch so that the sample division between processes differs across epochs sampler.set_epoch(epoch) for step, batch in enumerate(dataloader): epoch_fp = epoch + step/len(dataloader) if epoch_fp > num_train_epochs: break losses = [] inputs = get_inputs(batch, args.device) targets = get_targets(batch,
args.device) outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None)) losses.append(outputs[0]) outputs = outputs[1:] if model_t is not None: with torch.no_grad(): outputs_t = model_t(**inputs, output_hidden_states=True) hidden_t = outputs_t[2] assert isinstance(hidden_t, (tuple, list)), "hidden states output is not detected right" assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right" if args.kd_weight > 0: # calculate the knowledge distillation loss kd_losses = [] for logit_s, logit_t in zip(outputs[0:2], outputs_t[0:2]): T = 1 prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1) logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1) kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) ) losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses)) hidden_s = outputs[2] assert isinstance(hidden_s, (tuple, list)), "hidden states output is not detected right" assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right" def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): # the student and teacher outputs are not aligned; try to find a teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)] assert len(out_s) == len(out_t), "cannot align the number of outputs between student and teacher" assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)] sw_losses = align_and_loss_outputs(hidden_s, hidden_t) losses.extend([args.supervision_weight*l for l in sw_losses]) # average over the batch losses = [l.mean() for l in losses] l = sum(losses)/len(losses) indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) (l/gradient_accumulation_steps).backward() del l if (step + 1) % gradient_accumulation_steps == 0: global_step += 1 utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() if global_step % 50 == 0: # Log metrics wall_time = epoch + step / len(dataloader) lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())]) str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp) for k, v in indicators.items(): v = np.array(v) if len(v.shape) == 1: v = v[:, None] if rank > -1: # sync indicators across processes vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format(k, " ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'time_last' in locals(): # estimate processing times dt_iter = (time.time() - time_last) / len(indicators['loss']) dt_ep = dt_iter * len(dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) model.eval() set_output_hidden_states(rank, model, False) result_s = evaluate(args, model, test_dataset_qa) for k, v in result_s.items():
logger.info("{} {} {}".format(check_point_name, k, v)) if rank > -1: torch.distributed.barrier()
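# ---------------------------------------------------------------------------
# NOTE: a self-contained sketch of the temperature-scaled distillation loss
# computed in the loop above (soft cross-entropy between teacher and student
# logits). The T*T*num_classes factor mirrors the code above; it keeps the
# loss magnitude comparable to a per-example cross-entropy as T varies.
# The function name is illustrative.
# ---------------------------------------------------------------------------
import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, T=1.0):
    prob_t = F.softmax(teacher_logits.detach() / T, dim=1)   # teacher soft targets (no grad)
    logprob_s = F.log_softmax(student_logits / T, dim=1)     # student log-probabilities
    return -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1])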
def train(local_rank, config): global_rank = config.node_rank * config.n_gpus + local_rank print(f"local rank: {[local_rank]}, global_rank: {[global_rank]}") # multi-gpu init if torch.cuda.is_available(): if config.world_size > 1: dist.init_process_group( backend='nccl', init_method='env://', world_size=config.world_size, rank=global_rank ) torch.cuda.set_device(local_rank) DEVICE = torch.device("cuda", local_rank) else: DEVICE = torch.device("cuda") else: DEVICE = torch.device("cpu") # build tokenizer tokenizer = T5Tokenizer( vocab_file="../data/tokenizer/google_sp.model", bos_token="<s>", eos_token="</s>", unk_token="<unk>", pad_token="[PAD]", cls_token="[CLS]", sep_token="[SEP]", mask_token="[MASK]", extra_ids=0, additional_special_tokens=(), do_lower_case=True ) # build data source and reporters trn_reporter = StatisticsReporter() dev_reporter = StatisticsReporter() # get data filepaths train_filepaths = [] dev_filepaths = [] for corpus in config.corpora: if corpus == "jp_cc100": from corpus.jp_cc100.config import Config corpus_config = Config() dev_file_idx = 42 corpus_filepaths = sorted(list(filter( lambda x: x.endswith(".txt"), os.listdir(corpus_config.doc_data_dir) ))) for file_idx, filepath in enumerate(corpus_filepaths): if file_idx == dev_file_idx: dev_filepaths.append(f"{corpus_config.doc_data_dir}/{corpus_filepaths[file_idx]}") else: train_filepaths.append(f"{corpus_config.doc_data_dir}/{corpus_filepaths[file_idx]}") if config.small_data: train_filepaths = train_filepaths[:2] # load dev data if global_rank == 0: dev_docs = [] for dev_filepath in dev_filepaths: dev_docs += load_docs_from_filepath(dev_filepath, tokenizer) dev_docs = dev_docs[:10000] mp_print("----- Loading dev data -----", global_rank) dev_data_source = DataSource(config, tokenizer, dev_docs, "dev", randomize=False) mp_print(str(dev_data_source.statistics), global_rank) dev_dataloader = torch.utils.data.DataLoader( dev_data_source, batch_size=config.eval_batch_size, num_workers=0, collate_fn=collate_fn, pin_memory=False ) # build model model_config = PretrainedConfig.from_json_file(config.model_config_filepath) model = GPT2LMHeadModel(model_config) model = model.to(DEVICE) # load model from checkpoint if config.checkpoint_path: mp_print("----- Checkpoint loaded -----", global_rank) mp_print("checkpoint path: {}".format(config.checkpoint_path), global_rank) checkpoint = torch.load(config.checkpoint_path, map_location=DEVICE) mp_print("loading model state dict...", global_rank) model.load_state_dict(checkpoint["model"]) model.tie_weights() # NOTE: don't forget to tie weights after loading weights # use mixed precision if config.use_amp: scaler = amp.GradScaler() # use multi gpus if config.world_size > 1: model = DDP( model, device_ids=[local_rank], find_unused_parameters=True ) # build optimizer optimizer = optim.AdamW( model.parameters(), lr=config.init_lr, weight_decay=config.l2_penalty ) # build lr scheduler lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=config.n_warmup_steps, num_training_steps=config.n_training_steps, ) # init environment or load from checkpoint if config.checkpoint_path: if config.resume_training: mp_print("loading optimizer state dict...", global_rank) optimizer.load_state_dict(checkpoint["optimizer"]) mp_print("recovering lr scheduler...", global_rank) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) mp_print("recovering others...", global_rank) n_step = checkpoint["n_step"] start_n_epoch = checkpoint["n_epoch"] start_train_file_idx = 
checkpoint["start_train_file_idx"] best_ppl = getattr(checkpoint, "best_ppl", float("inf")) else: n_step = 0 start_n_epoch = 0 start_train_file_idx = 0 best_ppl = float("inf") OUTPUT_FILEID = checkpoint["output_fileid"] del checkpoint else: n_step = 0 start_n_epoch = 0 start_train_file_idx = 0 best_ppl = float("inf") # names OUTPUT_FILEID = "gpt2-ja-{}.seed_{}.{}".format( config.model_size, config.seed, time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) ) if config.filename_note: OUTPUT_FILEID += f".{config.filename_note}" # define logger def mlog(s): if global_rank == 0: if config.enable_log: if not os.path.exists("../log/pretrain"): os.makedirs("../log/pretrain") with open(f"../log/pretrain/{OUTPUT_FILEID}.log", "a+", encoding="utf-8") as log_f: log_f.write(s+"\n") mp_print(s, global_rank) if config.enable_log: if global_rank == 0: tb_writer = SummaryWriter( log_dir=f"../log/pretrain/{OUTPUT_FILEID}", max_queue=5 ) # log hyper parameters start_time = time.time() mlog("----- Hyper-parameters -----") for k, v in sorted(dict(config.__dict__).items()): mlog("{}: {}".format(k, v)) for epoch_idx in range(start_n_epoch, config.n_epochs): for train_file_idx in range(start_train_file_idx, len(train_filepaths), config.n_train_files_per_group): group_train_filepaths = train_filepaths[train_file_idx:train_file_idx+config.n_train_files_per_group] with mp.Pool(processes=config.n_train_files_per_group) as pool: group_train_docs = pool.starmap( load_docs_from_filepath, [(train_filepath, tokenizer) for train_filepath in group_train_filepaths] ) train_docs = [doc for docs in group_train_docs for doc in docs] train_data_source = DataSource(config, tokenizer, train_docs, "train", randomize=True) mp_print(str(train_data_source.statistics), global_rank) # single gpu or cpu if config.world_size == 1 or not torch.cuda.is_available(): train_data_sampler = RandomSampler( train_data_source, replacement=False ) train_dataloader = torch.utils.data.DataLoader( train_data_source, batch_size=config.batch_size, sampler=train_data_sampler, num_workers=0, collate_fn=collate_fn, pin_memory=True ) # multi gpus else: train_data_sampler = DistributedSampler( train_data_source, num_replicas=config.world_size, rank=global_rank ) train_dataloader = torch.utils.data.DataLoader( train_data_source, batch_size=config.batch_size, sampler=train_data_sampler, num_workers=0, collate_fn=collate_fn, pin_memory=False ) if isinstance(train_data_sampler, DistributedSampler): train_data_sampler.set_epoch(epoch_idx) for batch_data in train_dataloader: n_step += 1 # stop if reaches the maximum tranining step if n_step >= config.n_training_steps: break # forward model.train() if config.use_amp: with amp.autocast(): loss, ppl = forward_step(model, tokenizer, batch_data) else: loss, ppl = forward_step(model, tokenizer, batch_data) # update statisitcs trn_reporter.update_data({"ppl": ppl.item(), "loss": loss.item()}) # backward loss /= config.n_accum_steps if config.use_amp: scaler.scale(loss).backward() else: loss.backward() del loss if n_step % config.n_accum_steps == 0: # clip gradient if config.max_grad_norm > 0.0: if config.use_amp: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) # update model parameters if config.use_amp: scaler.step(optimizer) scaler.update() else: optimizer.step() # zero gradients optimizer.zero_grad() # check loss if n_step > 0 and n_step % config.check_loss_after_n_step == 0: lr = list(lr_scheduler.optimizer.param_groups)[0]["lr"] log_s = 
f"{time.time()-start_time:.2f}s Epoch {epoch_idx}, step {n_step}, lr {lr:.5g} - " log_s += trn_reporter.to_string() mlog(log_s) if config.enable_log and global_rank == 0: for k, v in trn_reporter.items(): tb_writer.add_scalar(f"{k}/train", np.mean(v), n_step) trn_reporter.clear() # evaluation on dev dataset if global_rank == 0 and n_step > 0 and n_step % config.validate_after_n_step == 0: # forward with torch.no_grad(): model.eval() # use only 1 gpu for evaluation in multi-gpu situation if config.world_size > 1: eval_model = model.module else: eval_model = model for eval_batch_idx, eval_batch_data in enumerate(dev_dataloader): if config.use_amp: with amp.autocast(): loss, ppl = forward_step(eval_model, tokenizer, eval_batch_data) else: loss, ppl = forward_step(eval_model, tokenizer, eval_batch_data) dev_reporter.update_data({"ppl": ppl.item(), "loss": loss.item()}) if eval_batch_idx == len(dev_dataloader) - 1: break log_s = f"\n<Dev> - {time.time()-start_time:.3f}s - " log_s += dev_reporter.to_string() mlog(log_s) # Save model if it has better monitor measurement if config.save_model: if not os.path.exists("../data/model/pretrain"): os.makedirs("../data/model/pretrain") model_to_save = model.module if hasattr(model, 'module') else model # save current model checkpoint = { "model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "n_epoch": epoch_idx, "n_step": n_step, "start_train_file_idx": train_file_idx, "output_fileid": OUTPUT_FILEID, "best_ppl": best_ppl } torch.save( checkpoint, f"../data/model/pretrain/{OUTPUT_FILEID}.checkpoint" ) mlog(f"checkpoint saved to data/model/pretrain/{OUTPUT_FILEID}.checkpoint") # save best model cur_ppl = dev_reporter.get_value("ppl") if cur_ppl < best_ppl: best_ppl = cur_ppl torch.save( checkpoint, f"../data/model/pretrain/{OUTPUT_FILEID}.best.checkpoint" ) mlog(f"best checkpoint saved to data/model/pretrain/{OUTPUT_FILEID}.best.checkpoint") if config.enable_log: for k, v in dev_reporter.items(): tb_writer.add_scalar(f"{k}/dev", np.mean(v), n_step) dev_reporter.clear() # decay learning rate lr_scheduler.step(dev_reporter.get_value("ppl")) # reset starting training file index for every epoch (if might be set to a larger value if resuming from a checkpoint) start_train_file_idx = 0
def main(): args = process_args() if args.loss_type == 'mlm': assert args.neg_num == 0 and args.multiple_neg == 0 elif args.loss_type == 'nsp': assert int(args.bi_prob) == 1 and args.max_pred == 0 and args.neg_num > 0 if args.adaptive_weight == 1: assert args.neg_num > 1 if args.add_boundary == 1: assert args.inc_full_hist if args.world_size > 1: print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank)) # Input format: [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP] args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1 args.mask_image_regions = (args.vis_mask_prob > 0) # whether to mask out image regions args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir) # arguments inspection assert args.enable_butd, 'only support region attn! featmap attn deprecated' if args.enable_butd: if args.visdial_v == '1.0': assert (args.len_vis_input == 36) or (args.len_vis_input == 0) elif args.visdial_v == '0.9': if (args.len_vis_input == 100): args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file) args.region_det_file_prefix = os.path.join(args.image_root, args.region_det_file_prefix) if args.dataset in ( 'cc', 'coco') and args.region_det_file_prefix != '' else '' # output config os.makedirs(args.output_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join( args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode='w', format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) ch = logging.StreamHandler(sys.stdout) ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')) ch.setLevel(logging.INFO) logger.addHandler(ch) if args.local_rank == -1 or args.no_cuda: device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.global_rank) logger.info('Arguments: %s\n' % (' '.join(sys.argv[:]))) logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps)) args.train_batch_size = int( args.train_batch_size / args.gradient_accumulation_steps) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # plotting loss, optional if args.enable_visdom: import visdom vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir) vis_window = {'iter': None, 'score': None} tokenizer = BertTokenizer.from_pretrained( args.bert_model, do_lower_case=args.do_lower_case, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank)) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer assert args.do_train logger.info('Max seq length: %d, batch size: %d\n' % (args.max_seq_length,
args.train_batch_size)) bi_uni_pipeline = [Preprocess4TrainVisdial(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'len_vis_input': args.len_vis_input, 'max_len_hist_ques': args.max_len_hist_ques, 'max_len_ans': args.max_len_ans}, mask_image_regions=args.mask_image_regions, mode="s2s", vis_mask_prob=args.vis_mask_prob, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, image_features_hdfpath=args.image_features_hdfpath, visdial_v=args.visdial_v, pad_hist=args.pad_hist, finetune=args.finetune, only_mask_ans=args.only_mask_ans, add_boundary=args.add_boundary, only_qa=args.only_qa), Preprocess4TrainVisdial(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'len_vis_input': args.len_vis_input, 'max_len_hist_ques': args.max_len_hist_ques, 'max_len_ans': args.max_len_ans}, mask_image_regions=args.mask_image_regions, mode="bi", vis_mask_prob=args.vis_mask_prob, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, image_features_hdfpath=args.image_features_hdfpath, visdial_v=args.visdial_v, pad_hist=args.pad_hist, finetune=args.finetune, only_mask_ans=args.only_mask_ans, add_boundary=args.add_boundary, only_qa=args.only_qa)] train_dataset = VisdialDataset( args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs, bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel, inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) t_total = int(len(train_dataloader) * args.num_train_epochs * 1. 
/ args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == 'img2txt' else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"]) # index in BERT vocab: 103, 102, 0 if (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init assert args.scst == False, 'must init from maximum likelihood training' _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type, neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse, no_h0=args.no_h0, no_vision=args.no_vision) global_step = 0 else: if args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path) global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type, neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse, no_h0=args.no_h0, no_vision=args.no_vision) del model_recover torch.cuda.empty_cache() if args.fp16: model.half() # cnn.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: model = DataParallelImbalance(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any( nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any( nd in n for nd in no_decay)], 'weight_decay': 0.0} ] if args.fp16: try: # from apex.optimizers import FP16_Optimizer 
from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State( optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State( optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) logger.info(" Loader length = %d", len(train_dataloader)) model.train() start_epoch = 1 logger.info("Begin training from epoch = %d", start_epoch) t0 = time.time() for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"): if args.multiple_neg and i_epoch > 1: train_dataset = VisdialDataset( args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs, bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel, inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) if args.local_rank >= 0: train_sampler.set_epoch(i_epoch - 1) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)') nbatches = len(train_dataloader) losses = [] pretext_loss = [] mlm_losses = [] nsp_losses = [] for step, batch in enumerate(iter_bar): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, \ task_idx, vis_masked_pos, img, vis_pe = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data loss_tuple = model(conv_feats, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions, drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0) # disable pretext_loss_deprecated for now masked_lm_loss, pretext_loss_deprecated, nsp_loss = loss_tuple if n_gpu > 1: # mean() to average on multi-gpu. For dist, this is done through gradient addition. 
masked_lm_loss = masked_lm_loss.mean() pretext_loss_deprecated = pretext_loss_deprecated.mean() nsp_loss = nsp_loss.mean() loss = masked_lm_loss + pretext_loss_deprecated + nsp_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) losses.append(loss.item()) mlm_losses.append(masked_lm_loss.item()) pretext_loss.append(pretext_loss_deprecated.item()) nsp_losses.append(nsp_loss.item()) if step % max(1, nbatches // 10) == 0: logger.info( "Epoch {}, Iter {}, Loss {:.4f}, MLM {:.4f}, NSP {:.4f}, Elapse time {:.2f}\n".format( i_epoch, step, np.mean(losses), np.mean(mlm_losses), np.mean(nsp_losses), time.time() - t0)) if args.enable_visdom: if vis_window['iter'] is None: vis_window['iter'] = vis.line( X=np.tile(np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack((np.asarray([np.mean(losses)]),)), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total']) ) else: vis.line( X=np.tile(np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack((np.asarray([np.mean(losses)]),)), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total']), win=vis_window['iter'], update='append' ) # ensure that accumulated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step / t_total, args.warmup_proportion) if args.fp16: # modify learning rate with the special warmup BERT uses for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model itself output_model_file = os.path.join( args.output_dir, "model.%d.%.3f.bin" % (i_epoch, np.mean(losses))) if args.global_rank in (-1, 0): # save the model on the first device, or when not distributed torch.save(copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file) logger.info("Save model to %s", output_model_file) logger.info("Finish training epoch %d, avg loss: %.2f and takes %.2f seconds" % ( i_epoch, np.mean(losses), time.time() - t0)) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.world_size > 1: torch.distributed.barrier()
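# ---------------------------------------------------------------------------
# NOTE: a sketch of one common warmup_linear variant, matching how the fp16
# branch above rescales the learning rate by
# warmup_linear(global_step / t_total, args.warmup_proportion) on each update
# (BertAdam applies the same kind of schedule internally in the fp32 branch).
# Illustrative only, not the exact library implementation.
# ---------------------------------------------------------------------------
def warmup_linear_sketch(progress, warmup=0.1):
    # progress = global_step / t_total, in [0, 1]
    if progress < warmup:
        return progress / warmup      # linear ramp from 0 to 1 during warmup
    return max(0.0, 1.0 - progress)   # then linear decay toward 0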
def main(): parser = argparse.ArgumentParser() # General parser.add_argument( "--bert_model", default="bert-base-cased", type=str, help= "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default='tmp', type=str, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_file", default="training.log", type=str, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.") parser.add_argument( "--do_train", action='store_true', help="Whether to run training. This should ALWAYS be set to True.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=64, type=int, help="Total batch size for training.") parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="Label smoothing rate.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=30, type=int, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument("--global_rank", type=int, default=-1, help="global_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of update steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)."
) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--len_vis_input', type=int, default=100, help="The length of visual token input") parser.add_argument('--max_len_b', type=int, default=20, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='b', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate the tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Masking probability. The number of predictions can be less than max_pred when the sequence is short." ) parser.add_argument('--max_pred', type=int, default=3, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=4, type=int, help="Number of workers for the data loader.") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") # Others for VLP parser.add_argument( "--src_file", default=['/mnt/dat/COCO/annotations/dataset_coco.json'], type=str, nargs='+', help="The input data file name.") parser.add_argument('--enable_visdom', action='store_true') parser.add_argument('--visdom_port', type=int, default=8888) # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth') parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images') parser.add_argument('--dataset', default='coco', type=str, help='coco | flickr30k | cc') parser.add_argument('--split', type=str, nargs='+', default=['train', 'restval']) parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist_url', default='file://[PT_OUTPUT_DIR]/nonexistent_file', type=str, help='url used to set up distributed training') parser.add_argument( '--file_valid_jpgs', default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json', type=str) parser.add_argument('--sche_mode', default='warmup_linear', type=str, help="warmup_linear | warmup_constant | warmup_cosine") parser.add_argument('--drop_prob', default=0.1, type=float) parser.add_argument('--use_num_imgs', default=-1, type=int) parser.add_argument('--vis_mask_prob', default=0, type=float) parser.add_argument('--max_drop_worst_ratio', default=0, type=float) parser.add_argument('--drop_after', default=6, type=int) parser.add_argument( '--s2s_prob', default=1, type=float, help="Percentage of examples that are bi-uni-directional LM (seq2seq)."
) parser.add_argument( '--bi_prob', default=0, type=float, help="Percentage of examples that are bidirectional LM.") parser.add_argument('--enable_butd', action='store_true', help='set to take in region features') parser.add_argument( '--region_bbox_file', default= 'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str) parser.add_argument( '--region_det_file_prefix', default= 'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval', type=str) parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2') parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--scst', action='store_true', help='Self-critical sequence training') args = parser.parse_args() print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank)) args.max_seq_length = args.max_len_b + args.len_vis_input + 3 # +3 for 2x[SEP] and [CLS] args.mask_image_regions = (args.vis_mask_prob > 0 ) # whether to mask out image regions args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir) # arguments inspection assert (args.tasks in ('img2txt', 'vqa2')) assert args.enable_butd == True, 'only support region attn! featmap attn deprecated' assert ( not args.scst) or args.dataset == 'coco', 'scst support on coco only!' if args.scst: assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!' rl_crit = RewardCriterion() if args.enable_butd: assert (args.len_vis_input == 100) args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file) args.region_det_file_prefix = os.path.join( args.image_root, args.region_det_file_prefix) if args.dataset in ( 'cc', 'coco') and args.region_det_file_prefix != '' else '' # output config os.makedirs(args.output_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode='w', format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs torch.distributed.init_process_group( backend='nccl', init_method='tcp://localhost:10001', #args.dist_url, world_size=args.world_size, rank=args.global_rank) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # plotting loss, optional if args.enable_visdom: import visdom vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir) vis_window = {'iter': None, 'score': None} tokenizer = BertTokenizer.from_pretrained( args.bert_model, do_lower_case=args.do_lower_case, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank)) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer if args.do_train: bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_image_regions=args.mask_image_regions, mode="s2s", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == 'vqa2')) ] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_image_regions=args.mask_image_regions, mode="bi", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == 'vqa2'))) train_dataset = seq2seq_loader.Img2txtDataset( args.src_file, args.image_root, args.split, args.train_batch_size, data_tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs, bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, enable_butd=args.enable_butd, tasks=args.tasks) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) t_total = int( len(train_dataloader) * args.num_train_epochs * 1. 
/ args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == 'img2txt' else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"]) # index in BERT vocab: 103, 102, 0 if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init assert args.scst == False, 'must init from maximum likelihood training' _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load( os.path.join(args.output_dir, "model.{0}.bin".format(recover_step))) # recover_step == number of epochs global_step = math.floor(recover_step * t_total * 1. / args.num_train_epochs) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path) global_step = 0 if not args.scst: model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks) else: model = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, max_position_embeddings=args.max_position_embeddings, config_path=args.config_path, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, task_idx=task_idx_proj, mask_word_id=mask_word_id, search_beam_size=1, eos_id=eos_word_ids, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input) del model_recover torch.cuda.empty_cache() # deprecated # from vlp.resnet import resnet # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning if args.fp16: model.half() # cnn.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) # cnn.to(device) if args.local_rank != -1: try: # from apex.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from 
https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # cnn = DDP(cnn) elif n_gpu > 1: # model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # cnn = DataParallelImbalance(cnn) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load( os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))) if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) logger.info(" Loader length = %d", len(train_dataloader)) model.train() if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"): if args.local_rank >= 0: train_sampler.set_epoch(i_epoch - 1) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)') nbatches = len(train_dataloader) train_loss = [] pretext_loss = [] vqa2_loss = [] scst_reward = [] for step, batch in enumerate(iter_bar): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data else: conv_feats, _ = cnn(img.data) # Bx2048x7x7 conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1).permute(0, 2, 1).contiguous() if not args.scst: loss_tuple = model( conv_feats, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, ans_labels, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions, drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0) mean_reward = loss_tuple[0].new(1).fill_(0) else: # scst training model.eval() position_ids = torch.arange( input_ids.size(1), 
dtype=input_ids.dtype, device=input_ids.device).unsqueeze(0).expand_as( input_ids) input_dummy = input_ids[:, :args.len_vis_input + 2] # +2 for [CLS] and [SEP] greedy_res = input_ids.new( input_ids.size(0), input_ids.size(1) - args.len_vis_input - 2).fill_(0) gen_result = input_ids.new( input_ids.size(0), input_ids.size(1) - args.len_vis_input - 2).fill_(0) with torch.no_grad(): greedy_res_raw, _ = model(conv_feats, vis_pe, input_dummy, segment_ids, position_ids, input_mask, task_idx=task_idx, sample_mode='greedy') for b in range(greedy_res_raw.size(0)): for idx in range(greedy_res_raw.size(1)): if greedy_res_raw[b][idx] not in [ eos_word_ids, pad_word_ids ]: greedy_res[b][idx] = greedy_res_raw[b][idx] else: if greedy_res_raw[b][idx] == eos_word_ids: greedy_res[b][idx] = eos_word_ids break model.train() gen_result_raw, sample_logprobs = model( conv_feats, vis_pe, input_dummy, segment_ids, position_ids, input_mask, task_idx=task_idx, sample_mode='sample') for b in range(gen_result_raw.size(0)): for idx in range(gen_result_raw.size(1)): if gen_result_raw[b][idx] not in [ eos_word_ids, pad_word_ids ]: gen_result[b][idx] = gen_result_raw[b][idx] else: if gen_result_raw[b][idx] == eos_word_ids: gen_result[b][idx] = eos_word_ids break gt_ids = input_ids[:, args.len_vis_input + 2:] reward = get_self_critical_reward(greedy_res, gt_ids, gen_result, gt_ids.size(0)) reward = torch.from_numpy(reward).float().to( gen_result.device) mean_reward = reward.mean() loss = rl_crit(sample_logprobs, gen_result.data, reward) loss_tuple = [ loss, loss.new(1).fill_(0.), loss.new(1).fill_(0.) ] # disable pretext_loss_deprecated for now masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple if n_gpu > 1: # mean() to average on multi-gpu. For dist, this is done through gradient addition. 
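# (a descriptive note, not from the original: under DataParallel each element of
#  loss_tuple can arrive as a per-GPU vector rather than a scalar; the .mean() calls
#  below collapse them before the losses are summed)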
masked_lm_loss = masked_lm_loss.mean() pretext_loss_deprecated = pretext_loss_deprecated.mean() ans_loss = ans_loss.mean() loss = masked_lm_loss + pretext_loss_deprecated + ans_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) train_loss.append(loss.item()) pretext_loss.append(pretext_loss_deprecated.item()) vqa2_loss.append(ans_loss.item()) scst_reward.append(mean_reward.item()) if step % 100 == 0: logger.info( "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n" .format(i_epoch, step, np.mean(train_loss), np.mean(pretext_loss), np.mean(vqa2_loss), np.mean(scst_reward))) if args.enable_visdom: if vis_window['iter'] is None: vis_window['iter'] = vis.line( X=np.tile( np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack( (np.asarray([np.mean(train_loss)]), )), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total'])) else: vis.line(X=np.tile( np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack( (np.asarray([np.mean(train_loss)]), )), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total']), win=vis_window['iter'], update='append') # ensure that accumlated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join(args.output_dir, "model.{0}.bin".format(i_epoch)) output_optim_file = os.path.join(args.output_dir, "optim.{0}.bin".format(i_epoch)) if args.global_rank in ( -1, 0): # save model if the first device or no dist torch.save( copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file) # torch.save(optimizer.state_dict(), output_optim_file) # disable for now, need to sanitize state and ship everthing back to cpu logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.world_size > 1: torch.distributed.barrier()
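
# --- Illustrative sketch (not part of the original script) ---
# The SCST branch above optimizes loss = rl_crit(sample_logprobs, gen_result, reward):
# REINFORCE with the greedy decode as the baseline. A minimal RewardCriterion-style
# loss, assuming PAD id is 0 and all names below are hypothetical, could look like:
import torch

def scst_loss(sample_logprobs, seq, reward):
    """sample_logprobs: (B, T) log-probs of the sampled tokens; seq: (B, T) sampled
    token ids (0 = PAD); reward: advantage, broadcastable to (B, T)."""
    mask = (seq > 0).float()                       # train only on real (non-PAD) tokens
    loss = -sample_logprobs * reward * mask        # reward-weighted negative log-likelihood
    return loss.sum() / mask.sum().clamp(min=1.0)  # normalize by token count
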
def main(): args = _get_parser().parse_args() args.device_ids = list(map(int, args.device_ids.split(','))) set_seed(args) sanity_checks(args) init_gpu_params(args) tokenizer = Tokenizer( os.path.join(args.bert_model, "senti_vocab.txt"), os.path.join(args.bert_model, "RoBERTa_Sentiment_kor")) train_dataset = NSMCDataSet(data_split="train", tokenizer=tokenizer, max_seq_length=args.max_seq_length, pad_to_max=args.pad_to_max) train_sampler = RandomSampler( train_dataset) if not args.multi_gpu else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.per_gpu_train_batch_size, collate_fn=train_dataset.collate_fn) model = RobertaForSequenceClassification( classifier_dropout=args.classifier_dropout, bert_model_dir=args.bert_model, pre_trained_model=args.pretrained_bert_model) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) t_total = len(train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs warmup_steps = math.ceil(t_total * args.warmup_proportion) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) model.zero_grad() model.cuda() if args.multi_gpu: model = DistributedDataParallel( model, device_ids=[args.device_ids[args.local_rank]], output_device=args.device_ids[args.local_rank]) if args.is_master: logger.info(json.dumps(vars(args), indent=4)) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.per_gpu_train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.multi_gpu else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_steps = 0 for epoch in range(args.num_train_epochs): if args.multi_gpu: train_sampler.set_epoch(epoch) loss_bce = nn.BCEWithLogitsLoss() iter_loss = 0 model.train() pbar = tqdm(train_dataloader, desc="Iter", disable=not args.is_master) for step, batch in enumerate(pbar): input_ids, attention_mask, labels = batch inputs = { "input_ids": torch.tensor(input_ids, dtype=torch.long).cuda(), "attention_mask": torch.tensor(attention_mask, dtype=torch.long).cuda() } logits = model(**inputs) labels = torch.tensor(labels, dtype=torch.float).cuda() loss = loss_bce(input=logits.view(-1), target=labels.view(-1)) if args.gradient_accumulation_steps > 1: loss /= args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: if args.max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() optimizer.zero_grad() scheduler.step() global_steps += 1 if global_steps % args.save_checkpoints_steps == 0 and args.is_master: model_to_save = model.module if hasattr( model, 'module') else model save_path = os.path.join(args.save_checkpoints_dir, f"step_{global_steps}.ckpt") torch.save(model_to_save.state_dict(), save_path) iter_loss += loss.item() pbar.set_postfix({ "epoch": epoch, "global_steps": global_steps, "learning_rate": f"{scheduler.get_last_lr()[0]:.10f}", "avg_iter_loss": f"{iter_loss / (step + 1) * args.gradient_accumulation_steps:.5f}", "last_loss": f"{loss.item() * args.gradient_accumulation_steps:.5f}" }) pbar.close() if args.is_master: model_to_save = model.module if hasattr(model, 'module') else model save_path = os.path.join(args.save_checkpoints_dir, f"epoch_{epoch+1}.ckpt") torch.save(model_to_save.state_dict(), save_path)
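
# --- Illustrative sketch (not part of the original script) ---
# The loop above implements gradient accumulation: the per-step loss is divided by
# gradient_accumulation_steps, gradients accumulate across micro-batches, and the
# optimizer and scheduler step only at accumulation boundaries. Stripped to its core
# (model, dataloader, optimizer, scheduler, and loss_fn are stand-ins):
def train_with_accumulation(model, dataloader, optimizer, scheduler, loss_fn, accum_steps):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = loss_fn(model(inputs), labels) / accum_steps  # scale so the update matches one big batch
        loss.backward()                                      # grads accumulate across calls
        if (step + 1) % accum_steps == 0:
            optimizer.step()                                 # one update per accum_steps micro-batches
            optimizer.zero_grad()
            scheduler.step()
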
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_dataset", default='openi', type=str, help="mimic-cxr | openi")
    parser.add_argument("--vqa_rad", default="all", type=str, choices=["all", "chest", "head", "abd"])
    parser.add_argument("--data_set", default="train", type=str, help="train | valid")
    parser.add_argument('--img_hidden_sz', type=int, default=2048,
                        help="Dimensionality of the input image features.")
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.")
    parser.add_argument("--mlm_task", type=str, default=True,
                        help="If True, the model trains on the MLM task only. | True | False")
    parser.add_argument("--train_batch_size", default=2, type=int, help="Total batch size for training.")
    parser.add_argument("--num_train_epochs", default=5, type=int, help="Total number of training epochs to perform.")
    parser.add_argument('--from_scratch', action='store_true', default=False,
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument("--img_encoding", type=str, default='fully_use_cnn', choices=['random_sample', 'fully_use_cnn'])
    parser.add_argument('--len_vis_input', type=int, default=256,
                        help="The length of the visual token input")
    # If the fixed length of the visual tokens is 100, there will be 100 <Unknown> tokens,
    # and up to 100 words can be generated.
    parser.add_argument('--max_len_b', type=int, default=253, help="Truncate_config: maximum length of segment B.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Number of predictions is sometimes less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=10, help="Max tokens of prediction.")
    parser.add_argument('--s2s_prob', default=1, type=float,
                        help="Percentage of examples that are bi-uni-directional LM (seq2seq). This must be kept at 1 for this configuration.")
    parser.add_argument('--bi_prob', default=0, type=float,
                        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--bar', default=False, type=str, help="True or False")
    parser.add_argument("--config_path", default='./pretrained_model/non_cross/config.json', type=str,
                        help="Bert config file path.")
    parser.add_argument("--model_recover_path", default='./pretrained_model/non_cross/pytorch_model.bin', type=str,
                        help="The file of the fine-tuned pretraining model.")  # model load
    parser.add_argument("--output_dir", default='./output_model/base_noncross_mimic_2', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_file", default="training.log", type=str,
                        help="The file where the training log will be written.")
    parser.add_argument('--img_postion', default=True, help="Whether to produce image position embeddings.")
    parser.add_argument("--do_train", action='store_true', default=True,
                        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass."
) ############################################################################################################ parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument("--global_rank", type=int, default=-1, help="global_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=123, help="random seed for initialization") parser.add_argument( '--fp16', action='store_true', default=False, help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', default=False, help= "Whether to use 32-bit float precision instead of 32-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', default=False, help="Whether to use amp for fp16") parser.add_argument('--new_segment_ids', default=False, action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument( '--trunc_seg', default='b', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument("--num_workers", default=20, type=int, help="Number of workers for the data loader.") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument( '--image_root', type=str, default='/home/mimic-cxr/dataset/image_preprocessing/re_512_3ch/Train') parser.add_argument('--split', type=str, nargs='+', default=['train', 'valid']) parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist_url', default='file://[PT_OUTPUT_DIR]/nonexistent_file', type=str, help='url used to set up distributed training') parser.add_argument('--sche_mode', default='warmup_linear', type=str, help="warmup_linear | warmup_constant | warmup_cosine") parser.add_argument('--drop_prob', default=0.1, type=float) parser.add_argument('--use_num_imgs', default=-1, type=int) parser.add_argument('--max_drop_worst_ratio', default=0, type=float) parser.add_argument('--drop_after', default=6, type=int) parser.add_argument('--tasks', default='report_generation', help='report_generation | vqa') parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") args = parser.parse_args() print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank)) args.max_seq_length = args.max_len_b + args.len_vis_input + 3 # 
+3 for 2x[SEP] and [CLS] args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir) if args.tasks == 'vqa': wandb.init(config=args, project="VQA") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/dataset/data_RAD' args.file_valid_jpgs = '/home/mimic-cxr/dataset/vqa_rad_original_set.json' else: if args.generation_dataset == 'mimic-cxr': wandb.init(config=args, project="report_generation") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/new_dset/Train_253.jsonl' args.file_valid_jpgs = '/home/mimic-cxr/new_dset/Train_253.jsonl' else: wandb.init(config=args, project="report_generation") wandb.config["more"] = "custom" args.src_file = '/home/mimic-cxr/dataset/open_i/Train_openi.jsonl' args.file_valid_jpgs = '/home/mimic-cxr/dataset/open_i/Valid_openi.jsonl' print(" # PID :", os.getpid()) os.makedirs(args.output_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode='w', format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") print("device", device) n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) print("device", device) n_gpu = 1 torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.global_rank) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True) if args.do_train: print("args.mask_prob", args.mask_prob) print("args.train_batch_size", args.train_batch_size) bi_uni_pipeline = [ data_loader.Preprocess4Seq2seq( args, args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, args.bar, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mode="s2s", len_vis_input=args.len_vis_input, local_rank=args.local_rank, load_vqa_set=(args.tasks == 'vqa')) ] bi_uni_pipeline.append( data_loader.Preprocess4Seq2seq( args, args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, args.bar, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mode="bi", len_vis_input=args.len_vis_input, local_rank=args.local_rank, load_vqa_set=(args.tasks == 'vqa'))) train_dataset = data_loader.Img2txtDataset( args, args.data_set, args.src_file, 
args.image_root, args.split, args.train_batch_size, tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs, bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs, s2s_prob=args.s2s_prob, # this must be set to 1. bi_prob=args.bi_prob, tasks=args.tasks) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) t_total = int( len(train_dataloader) * args.num_train_epochs * 1. / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == 'report_generation' else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"]) # index in BERT vocab: 103, 102, 0 # BERT model will be loaded! from scratch if (args.model_recover_path is None): _state_dict = {} if args.from_scratch else None _state_dict = {} model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, args=args, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, len_vis_input=args.len_vis_input, tasks=args.tasks) print("scratch model's statedict : ") for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size()) global_step = 0 print("The model will train from scratch") else: print("Task :", args.tasks, args.s2s_prob) print("Recoverd model :", args.model_recover_path) for model_recover_path in glob.glob(args.model_recover_path.strip()): logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(model_recover_path) for key in list(model_recover.keys()): model_recover[key.replace('enc.', '').replace( 'mlm.', 'cls.')] = model_recover.pop(key) global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, args=args, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, len_vis_input=args.len_vis_input, tasks=args.tasks) model.load_state_dict(model_recover, strict=False) print("The pretrained model loaded and fine-tuning.") del model_recover torch.cuda.empty_cache() if args.fp16: model.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import 
DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: model = DataParallelImbalance(model) wandb.watch(model) param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load( os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))) if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") model.train() print("Total Parameters:", sum([p.nelement() for p in model.parameters()])) if recover_step: start_epoch = recover_step + 1 print("start_epoch", start_epoch) else: start_epoch = 1 for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"): if args.local_rank >= 0: train_sampler.set_epoch(i_epoch - 1) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)') nbatches = len(train_dataloader) train_loss = [] avg_loss = 0.0 batch_count = 0 for step, batch in enumerate(iter_bar): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, task_idx, img, vis_pe, ans_labels, ans_type, organ = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() loss_tuple = model(img, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, ans_labels, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0, ans_type=ans_type) masked_lm_loss, vqa_loss = loss_tuple batch_count += 1 if args.tasks == 'report_generation': masked_lm_loss = masked_lm_loss.mean() loss = masked_lm_loss else: vqa_loss = vqa_loss.mean() loss = vqa_loss iter_bar.set_description('Iter (loss=%5.3f)' % (loss.item())) train_loss.append(loss.item()) if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 wandb.log({"train_loss": np.mean(train_loss)}) logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_config_file = os.path.join(args.output_dir, 'config.json') with open(output_config_file, 'w') as f: f.write(model_to_save.config.to_json_string()) output_model_file = os.path.join(args.output_dir, 
"model.{0}.bin".format(i_epoch)) output_optim_file = os.path.join(args.output_dir, "optim.{0}.bin".format(i_epoch)) if args.global_rank in ( -1, 0): # save model if the first device or no dist torch.save( copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.world_size > 1: torch.distributed.barrier()
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc, fq_tune_only, model_controller): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model", train_count) printlog(model) q_dataset = train_dataset_qc.q_dataset per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size if fq_tune_only: gradient_accumulation_steps = 1 num_train_epochs = 1 else: gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if rank < 0: #single process take all q_sampler = RandomSampler(q_dataset) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=train_batch_size, num_workers=4) else: #special sampler that divide samples between processes q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset, rank=rank) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=per_gpu_train_batch_size) steps_total = int( len(q_dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule named_params, groups = utils.make_param_groups( rank, model, args. freeze_list, #list or str with subnames to define frozen parameters args.learning_rate, #learning rate for no FQ parameters 0.01, # learning rate for FQ parameters fq_tune_only, #true if only FQ parameters will be optimized model_controller) optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) printlog("fq_tune_only", fq_tune_only) printlog("dataset size", len(q_dataset)) printlog("epoches", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) global_step = 1 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) hnm_hist = {} for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() if model_t: model_t.train() if rank > -1: #set epoch to make different samples division betwen process for different epoches q_sampler.set_epoch(epoch) utils.sync_models(rank, model) for step, q_batch in enumerate(q_dataloader): epoch_fp = epoch + step / len(q_dataloader) if epoch_fp > num_train_epochs: break losses = [] context_ids_pos = q_batch[3] q_inputs = get_inputs(q_batch, args.device) q_outputs = model(**q_inputs, output_hidden_states=(model_t is not None)) q_vec = q_outputs[0] #get positive embeddings c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data] c_inputs = get_inputs(c_batch, args.device) c_outputs = model(**c_inputs, output_hidden_states=(model_t is not None)) c_vec_pos = c_outputs[0] if model_t is not None: q_emb_s, q_hidden_s = q_outputs c_emb_s, c_hidden_s = c_outputs with torch.no_grad(): q_emb_t, q_hidden_t = model_t(**q_inputs, output_hidden_states=True) c_emb_t, c_hidden_t = model_t(**c_inputs, output_hidden_states=True) def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): 
#the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [ out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s) ] assert len(out_s) == len( out_t ), "can not align number of outputs between student and teacher" assert all( s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape) ), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)] l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t) l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t) emb_loss = loss_cfg.get('emb_loss', '') if emb_loss == 'L2': l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean()) l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean()) elif emb_loss == 'L1': l_q.append((q_emb_s - q_emb_t.detach()).abs().mean()) l_c.append((c_emb_s - c_emb_t.detach()).abs().mean()) elif emb_loss.lower() not in ['', 'none', '0', 'disable']: raise Exception( 'emb_loss={} is unsupported'.format(emb_loss)) losses.extend([args.supervision_weight * l for l in l_c + l_q]) triplet_num = int(loss_cfg.get('triplet_num', 1)) if fq_tune_only: triplet_num = 0 if triplet_num > 0: #disable grad to select negatives with torch.no_grad(): hnm_scores = [] hnm_idxs = [] #check that current step has no HNM conext vector if global_step not in hnm_hist and args.hnm_num > 0: #generate the new one if world_size > 1 and (args.hnm_num % world_size) != 0: #aligh hnm_num per each replica hnm_plus = world_size - (args.hnm_num % world_size) args.hnm_num += hnm_plus logger.warning( "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas." .format(rank, hnm_plus, args.hnm_num - hnm_plus, args.hnm_num, world_size)) # generate random contexts to calc embedding context_ids_all = torch.randint( low=0, high=len(train_dataset_qc.c_dataset), size=[args.hnm_num]) if rank < 0: #single process take all context_ids = context_ids_all else: #broadcast one sigle indicies to all processes context_ids_all = context_ids_all.to(args.device) torch.distributed.broadcast(context_ids_all, 0) context_ids_all = context_ids_all.cpu() #each process take only small part to calc embedding s = ((rank + 0) * args.hnm_num) // world_size e = ((rank + 1) * args.hnm_num) // world_size context_ids = context_ids_all[s:e] batch_size = min(args.hnm_batch_size, context_ids.shape[0]) s, e = 0, batch_size c_outputs = [] while e > s: idx = context_ids.detach()[s:e] c_batch = train_dataset_qc.c_dataset[idx] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_outputs.append(outputs[0]) s, e = e, min(e + batch_size, context_ids.shape[0]) context_emb = torch.cat(c_outputs, dim=0) if rank < 0: # single process calculated all context_emb_all = context_emb else: context_emb_list = [ torch.zeros_like(context_emb) for _ in range(world_size) ] torch.distributed.all_gather( context_emb_list, context_emb) context_emb_all = torch.cat(context_emb_list, dim=0) hnm_hist[global_step] = (context_ids_all, context_emb_all) #check history size and crop the oldest one if len(hnm_hist) > args.hnm_hist_num: del hnm_hist[min(hnm_hist.keys())] #calc HNM scores for current question batch for hist_step, (c_idx, c_vec) in hnm_hist.items(): w = args.hnm_hist_alpha**(global_step - hist_step) t1 = q_vec[:, None, :] t2 = c_vec[None, :, :] d = (t1 - t2) score = -d.norm(2, dim=-1) score = score * w hnm_scores.append(score) hnm_idxs.append(c_idx) if hnm_scores: #choose the hardest negative if we have 
scores score = torch.cat(hnm_scores, dim=-1) idx = torch.cat(hnm_idxs, dim=-1) score = score.cpu() pos_mask = (context_ids_pos[:, None] == idx[None, :]).to( dtype=score.dtype, device=score.device) score = (1 - pos_mask) * score + pos_mask * score.min( ) #make positive context with small score to avoid chose it as hard neg hn_idx = score.argmax(dim=1, keepdim=True) context_ids_neg = idx[hn_idx] else: #just random selection in case of no scores for HNM size = (context_ids_pos.shape[0], 1) context_ids_neg = torch.randint( 0, len(train_dataset_qc.c_dataset) - 1, size) shift = (context_ids_neg >= context_ids_pos[:, None]) context_ids_neg = context_ids_neg + shift.to( dtype=context_ids_neg.dtype) d_pos = (q_vec - c_vec_pos).norm(2, dim=-1) # get negative embeddings and calc losses for neg_index in range(context_ids_neg.shape[1]): ids = context_ids_neg[:, neg_index] c_batch = train_dataset_qc.c_dataset[ids.detach()] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_vec_neg = outputs[0] for triplet_index in range(triplet_num): if triplet_index == 0: d_neg = (q_vec - c_vec_neg).norm(2, dim=-1) if triplet_index == 1: d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1) d_diff = d_pos - d_neg indicators['dd' + str(triplet_index)].append( [v.mean().item() for v in (d_pos, d_neg, d_diff)]) l = softplus(d_diff) losses.append(l) del d_neg del d_pos #average over batch losses = [l.mean() for l in losses] l = sum(losses) / len(losses) (l / gradient_accumulation_steps).backward() indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) #del losses del l if (step + 1) % gradient_accumulation_steps == 0: utils.sync_grads(rank, named_params, report_no_grad_params=(global_step == 1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if global_step % 10 == 0: # Log metrics wall_time = epoch + step / len(q_dataloader) lrp = [ '{:.2f}'.format(i) for i in np.log10(scheduler.get_last_lr()) ] str_out = "{} ep {:.2f} lrp {}".format( train_count, epoch_fp, " ".join(lrp)) for k, v in indicators.items(): v = np.array(v) if len(v.shape) == 1: v = v[:, None] if rank > -1: #sync indicators vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce( vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format( k, " ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'score' in locals(): str_out += " SS {}".format(list(score.shape)) if 'time_last' in locals(): dt_iter = (time.time() - time_last) / len( indicators['loss']) dt_ep = dt_iter * len(q_dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format( dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) result_s = evaluate(args, model.eval(), test_dataset_qc) for k, v in result_s.items(): logger.info("{} {} {}".format(check_point_name, k, v)) if rank > -1: torch.distributed.barrier()
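
# --- Illustrative sketch (not part of the original function) ---
# Core of the hard-negative mining above: score every cached candidate context by
# -||q - c||_2, mask out the true positive, pick the argmax as the hard negative, and
# apply the soft triplet loss softplus(d_pos - d_neg). All names here are illustrative:
import torch
import torch.nn.functional as F

def hard_negative_triplet_loss(q_vec, c_vec_pos, cand_emb, cand_ids, pos_ids):
    """q_vec, c_vec_pos: (B, D); cand_emb: (N, D); cand_ids: (N,); pos_ids: (B,)."""
    score = -torch.cdist(q_vec, cand_emb)               # (B, N): higher = harder negative
    pos_mask = pos_ids[:, None] == cand_ids[None, :]    # never select the positive itself
    score = score.masked_fill(pos_mask, score.min().item())
    hard_neg = cand_emb[score.argmax(dim=1)]            # (B, D) hardest negative embedding
    d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
    d_neg = (q_vec - hard_neg).norm(2, dim=-1)
    return F.softplus(d_pos - d_neg).mean()             # smooth margin: push d_neg above d_pos
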
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default=None, type=str, required=True, help= "The input data dir. Should contain the .tsv files (or other data files) for the task." ) parser.add_argument("--src_file", default=None, type=str, help="The input data file name.") parser.add_argument("--topic_model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining topic model.") parser.add_argument("--topic_model_dict_path", default=None, type=str, help="The file of fine-tuned pretraining topic model.") parser.add_argument("--tgt_file", default=None, type=str, help="The output data file name.") parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default=None, type=str, required=True, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default='', type=str, required=True, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, required=True, help="The file of fine-tuned pretraining model.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") parser.add_argument('--topic_mode', default=1, type=float, help="1:idea1 1.1:idea1_wo_theta 2:idea2 ") parser.add_argument('--topic_model', default=False, type=bool, help="if only use topic model") # Other parameters parser.add_argument( "--max_seq_length", default=192, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument( "--train_batch_size", default=32, type=int, help="Total batch size for training.") #batch_size = batch_size/n_gpus parser.add_argument("--eval_batch_size", default=16, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=30, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. 
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) print("args.local_rank", args.local_rank) print("args.no_cuda", args.no_cuda) if args.local_rank == -1 or args.no_cuda: #-1 False device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") #device = cuda n_gpu = torch.cuda.device_count() print("n_gpu_1", n_gpu) else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') print("n_gpu_1", n_gpu) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer if args.local_rank == 0: dist.barrier() if args.do_train: bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift) ] file_oracle = None if args.has_sentence_oracle: file_oracle = os.path.join(args.data_dir, 'train.oracle') fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src') fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt') train_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, args.data_dir, args.topic_model_dict_path, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: train_sampler = RandomSampler(train_dataset, replacement=False) _batch_size = args.train_batch_size else: train_sampler = DistributedSampler(train_dataset) _batch_size = args.train_batch_size // dist.get_world_size() train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size) t_total = int( len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 ### type_vocab_size=6 num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab 
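        # (non-master ranks block on this barrier while rank 0 downloads and caches the
        #  pretrained weights; the matching dist.barrier() under args.local_rank == 0
        #  further below releases them with a warm cache)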
dist.barrier() if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None unilm = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * t_total / args.num_train_epochs) elif args.model_recover_path: # here is the entrance logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 unilm = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) #1. 
gsm = GSM(train_dataset.vocabsize) gsm_checkpoint = torch.load(args.topic_model_recover_path) gsm.load_state_dict(gsm_checkpoint["net"]) if args.local_rank == 0: dist.barrier() if args.fp16: unilm.half() gsm.half() if args.fp32_embedding: unilm.bert.embeddings.word_embeddings.float() unilm.bert.embeddings.position_embeddings.float() unilm.bert.embeddings.token_type_embeddings.float() unilm.to(device) gsm.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel") unilm = DDP(unilm, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: # model = torch.nn.DataParallel(model) unilm = DataParallelImbalance(unilm) gsm = DataParallelImbalance(gsm) # Prepare optimizer total = 0 param_optimizer = list(unilm.named_parameters()) param_optimizer_topic = list(gsm.named_parameters()) for name, parameters in unilm.named_parameters(): if "idea" in name: if "11" in name and "idea2" in name: total += np.prod(parameters.size()) # print(name, ':', parameters.size()) else: total += np.prod(parameters.size()) # print(name, ':', parameters.size()) print("gsm has {} parameters in total".format( sum(x.numel() for x in gsm.parameters()))) print("Number of parameters: %.6fM" % (total / 1e6)) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] if not args.topic_model: optimizer_grouped_parameters = [{ 'params': [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': 0.01, 'topic': False }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'topic': False }, { 'params': [p for n, p in param_optimizer_topic], 'weight_decay': 0.0, 'lr': 1e-3, 'topic': True }] else: optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer_topic], 'weight_decay': 0.0, 'lr': 1e-3, 'topic': True }] # one parameter group uses weight decay, the other has no weight_decay # print("optimizer_grouped_parameters", optimizer_grouped_parameters) if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) unilm.train() gsm.train() if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 print("000000", args.local_rank, start_epoch, int(args.num_train_epochs) + 1) topicloss = [] unilmloss = [] topicloss_lst = [] unilmloss_lst = [] for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch", disable=args.local_rank not in (-1, 0)): loss_sum = 0.0 ppx_sum = 0.0 word_count = 0.0 doc_count = 0.0 if args.local_rank != -1: train_sampler.set_epoch(i_epoch) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)', disable=args.local_rank not in (-1, 0)) for step, batch in enumerate(iter_bar): batch = [ t.to(device) if t is not None else None for t in batch ] if args.has_sentence_oracle: # False for this task input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch else: # the bag-of-words vector bows is added to the batch here input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch oracle_pos, oracle_weights, oracle_labels = None, None, None p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows) if not args.topic_model: loss_tuple = unilm(input_ids, theta, beta, topic_embedding, args.topic_mode, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, masked_labels_2=oracle_labels, mask_qkv=mask_qkv) masked_lm_loss, next_sentence_loss = loss_tuple ## topic loss logsoftmax = torch.log(p_x + 1e-10) rec_loss = -1.0 * torch.sum( bows * logsoftmax ) # bows * logsoftmax has shape [batch_size, |V|]; torch.sum adds up the loss over all dimensions (one could also sum over a single dimension only) rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1) rec_loss_per = rec_loss_per.cpu().detach().numpy() kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) - log_vars.exp()) loss_topic = rec_loss + kl_div if n_gpu > 1: # mean() to average on multi-gpu.
loss_topic = loss_topic.mean() if not args.topic_model: masked_lm_loss = masked_lm_loss.mean() next_sentence_loss = next_sentence_loss.mean() if not args.topic_model: loss_unilm = masked_lm_loss + next_sentence_loss # calculate perplexity word_count_list = [] loss_sum += loss_topic.item() for bow in bows: word_num = torch.sum(bow).cpu().numpy() word_count_list.append(word_num) word_count += word_num word_count_np = np.array(word_count_list) doc_count += len(bows) ppx_sum += np.sum(np.true_divide(rec_loss_per, word_count_np)) topicloss_lst.append(loss_topic.item() / len(bows)) if not args.topic_model: unilmloss_lst.append(loss_unilm.item()) # end of topic loss if not args.topic_model: loss = loss_unilm + loss_topic else: loss = loss_topic # ensure that accumulated gradients are normalized if args.gradient_accumulation_steps > 1: # =1 loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: if not param_group['topic']: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 if not args.topic_model: iter_bar.set_description( 'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' % (loss_unilm.item(), np.sum(np.true_divide(rec_loss_per, word_count_np)))) else: iter_bar.set_description( 'Iter (loss_topic=%5.3f), (ppl=%5.3f)' % (loss_topic.item(), np.sum(np.true_divide(rec_loss_per, word_count_np)))) # Save a trained model ppx_word = np.exp(loss_sum / word_count) ppx_document = np.exp(ppx_sum / doc_count) print("********") print("word_count", word_count) print("ppx_word", ppx_word) print("ppx_document", ppx_document) if (args.local_rank == -1 or torch.distributed.get_rank() == 0): # save the unilm model logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") unilm_model_to_save = unilm.module if hasattr( unilm, 'module') else unilm # Only save the model itself output_unilm_model_file = os.path.join( args.output_dir, "unilm.{0}.bin".format(i_epoch)) torch.save(unilm_model_to_save.state_dict(), output_unilm_model_file) # save the topic model logger.info( "** ** * Saving topic model and optimizer ** ** * ") topic_model_to_save = gsm.module if hasattr( gsm, 'module') else gsm # Only save the model itself output_topic_model_file = os.path.join( args.output_dir, "topic.{0}.ckpt".format(i_epoch)) torch.save(topic_model_to_save.state_dict(), output_topic_model_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() smth_pts = smooth_curve(topicloss_lst) # plt.plot(range(len(topicloss_lst)), topicloss_lst) plt.plot(range(len(smth_pts)), smth_pts) plt.xlabel('epochs') plt.title('Topic Model Train Loss') plt.savefig(args.output_dir + '/topic_loss.png') plt.cla() plt.plot(range(len(unilmloss_lst)), unilmloss_lst) plt.xlabel('epochs') plt.title('Unilm Train Loss') plt.savefig(args.output_dir + '/unilm_loss.png')
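# Aside: the topic loss optimized in the loop above is the standard neural-topic-model
# (VAE) ELBO -- a bag-of-words reconstruction term plus the KL divergence between the
# Gaussian posterior and a unit Gaussian prior. Below is a minimal, self-contained
# sketch of the same computation; the tensor names mirror the loop above, while
# `batch_size`, `vocab_size` and `n_topics` are illustrative values, not taken from
# this script.
import torch


def topic_loss_sketch(p_x, bows, mus, log_vars):
    # p_x: reconstructed word distribution, shape [batch_size, vocab_size]
    # bows: bag-of-words counts, shape [batch_size, vocab_size]
    # mus, log_vars: Gaussian posterior parameters, shape [batch_size, n_topics]
    logsoftmax = torch.log(p_x + 1e-10)  # +1e-10 for numerical stability, as above
    rec_loss = -1.0 * torch.sum(bows * logsoftmax)  # summed over batch and vocabulary
    kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) - log_vars.exp())
    return rec_loss + kl_div


# Quick check: if the posterior equals the unit-Gaussian prior, the KL term vanishes.
# batch_size, vocab_size, n_topics = 2, 100, 10
# p_x = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
# bows = torch.randint(0, 3, (batch_size, vocab_size)).float()
# mus = torch.zeros(batch_size, n_topics); log_vars = torch.zeros(batch_size, n_topics)
# print(topic_loss_sketch(p_x, bows, mus, log_vars))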
def train(args, train_dataset, label_list, model, tokenizer): """ Train the model """ if args.local_rank in [-1, 0]: if args.log_dir: tb_writer = SummaryWriter(args.log_dir) else: tb_writer = SummaryWriter() log_writer = open(os.path.join(args.output_dir, "evaluate_logs.txt"), 'a') args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) # model recover recover_step = get_max_steps(args.output_dir) model_recover_path = None if recover_step: model_recover_path = os.path.join(args.output_dir, "checkpoint-{}".format(recover_step)) logger.info(" ** Recover model checkpoint in %s ** ", model_recover_path) model.load_state_dict( torch.load(os.path.join(model_recover_path, WEIGHTS_NAME), map_location='cpu')) model.to(args.device) # check if saved optimizer or scheduler states exist if os.path.isfile(os.path.join( model_recover_path, "optimizer.pt")) and os.path.isfile( os.path.join(model_recover_path, "scheduler.pt")): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(model_recover_path, "optimizer.pt"), map_location='cpu')) scheduler.load_state_dict( torch.load(os.path.join(model_recover_path, "scheduler.pt"))) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) logger.info(" Logging steps = %d", args.logging_steps) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_recover_path and os.path.exists(model_recover_path): # set global_step to the global_step of the last saved checkpoint from the model path global_step = recover_step epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) tr_loss, logging_loss, best_avg = 0.0, 0.0, 0.0 model.zero_grad() train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility def logging(): results = evaluate(args, model, tokenizer, label_list, single_gpu=True, splits=args.eval_splits.split(',')) for task, result in results.items(): for key, value in result.items(): tb_writer.add_scalar("eval_{}_{}".format(task, key), value, global_step) log_writer.write("{0}\t{1}\n".format(global_step, json.dumps(results))) log_writer.flush() tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) return results def save_checkpoint(cur_step): output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(cur_step)) logger.info("Saving model checkpoint to %s", output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Take care of distributed/parallel training model_to_save = model.module if hasattr(model, "module") else model model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) for epc in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) if not args.use_all_samples_per_epoch: if args.local_rank != -1: train_sampler.set_epoch(epc) if epc > 0: train_dataset.shuffle() for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue model.train() if args.filter_k == 0: # no cross-attention layer batch = tuple(t.to(args.device) for t in batch) inputs = { "input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3] } else: input_ids = [d[0].to(args.device) for d in batch] attention_mask = [d[1].to(args.device) for d in batch] labels = [d[3].to(args.device) for d in batch][0] inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels } if args.alpha > 0: soft_labels
= [d[4].to(args.device) for d in batch][0] inputs["soft_labels"] = soft_labels if args.model_type != "distilbert": inputs["token_type_ids"] = ( batch[2] if args.model_type in ["bert"] else None ) # XLM and DistilBERT don't use segment_ids outputs = model(**inputs) loss = outputs[ 0] # model outputs are always tuple in transformers (see doc) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) logging_loss = tr_loss # Save model checkpoint if args.local_rank in [ -1, 0 ] and args.save_steps > 0 and global_step % args.save_steps == 0: save_checkpoint(global_step) if args.evaluate_during_training: cur_result = logging() logger.info(cur_result) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.local_rank in [-1, 0] and args.logging_each_epoch: logging_loss = tr_loss save_checkpoint(global_step) if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() log_writer.close() return global_step, tr_loss / (global_step + 1)
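# Aside: train() above relies on get_linear_schedule_with_warmup, which scales the base
# learning rate by a multiplier that ramps linearly from 0 over num_warmup_steps
# optimizer steps and then decays linearly to 0 at num_training_steps. A hedged sketch
# of that multiplier as a plain function; `warmup_steps` and `t_total` stand in for
# args.warmup_steps and the t_total computed above, and this mirrors the transformers
# helper for illustration rather than replacing it.
def linear_warmup_multiplier(step, warmup_steps, t_total):
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # warmup phase: ramps 0 -> 1
    # decay phase: 1.0 at the end of warmup, 0.0 at t_total
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))


# e.g. with warmup_steps=100, t_total=1000:
#   linear_warmup_multiplier(50, 100, 1000)  -> 0.5  (halfway through warmup)
#   linear_warmup_multiplier(100, 100, 1000) -> 1.0  (peak learning rate)
#   linear_warmup_multiplier(550, 100, 1000) -> 0.5  (halfway through decay)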
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default=None, type=str, required=True, help= "The input data dir. Should contain the .tsv files (or other data files) for the task." ) parser.add_argument("--src_file", default=None, type=str, help="The input data file name.") parser.add_argument("--tgt_file", default=None, type=str, help="The output data file name.") parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default=None, type=str, required=True, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default='', type=str, required=True, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, required=True, help="The file of fine-tuned pretraining model.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") # Other parameters parser.add_argument( "--max_seq_length", default=128, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--local_debug", action='store_true', help="Whether to run in local debug mode.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=64, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="The label smoothing epsilon.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for.
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. 
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer if args.local_rank == 0: dist.barrier() if args.do_train: print("Loading Train Dataset", args.data_dir) bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift) ] file_oracle = None if args.has_sentence_oracle: file_oracle = os.path.join(args.data_dir, 'train.oracle') fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src') fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt') train_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline, corpus_preprocessors=corpus_preprocessors) train_dataset.initial() print(len(train_dataset.ex_list)) print(train_dataset.batch_size) # assert 1==0 if args.local_rank == -1: train_sampler = RandomSampler(train_dataset, replacement=False) _batch_size = args.train_batch_size else: train_sampler = DistributedSampler(train_dataset) _batch_size = args.train_batch_size // dist.get_world_size() train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) # c = 0 # for i_epoch in trange(0, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)): # if args.local_rank != -1: # train_sampler.set_epoch(i_epoch) # iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)', # disable=args.local_rank not in (-1, 0)) # for step, batch in enumerate(iter_bar): # batch = [ # t.to(device) if t is not None else None for t in batch] # if args.has_sentence_oracle: # input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch # else: # input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, 
is_next, task_idx, sop_label = batch # oracle_pos, oracle_weights, oracle_labels = None, None, None # c += input_ids.shape[0] # # # print(input_ids) # # # # print(input_ids.shape) # # print(segment_ids) # # print(segment_ids.shape) # # print(is_next) # # print(task_idx) # # print(sop_label) # # print(task_idx.shape) # # for i in range(input_mask.shape[0]): # # print(input_mask[i]) # print(c) # print(train_dataset.c) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size) t_total = int( len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, local_debug=args.local_debug) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * t_total / args.num_train_epochs) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, local_debug=args.local_debug) if args.local_rank == 0: dist.barrier() if args.fp16: model.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() 
model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel") model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: # model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) model.train() if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch", disable=args.local_rank not in (-1, 0)): if args.local_rank != -1: train_sampler.set_epoch(i_epoch) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)', disable=args.local_rank not in (-1, 0)) for step, batch in enumerate(iter_bar): batch = [ t.to(device) if t is not None else None for t in batch ] if args.has_sentence_oracle: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch else: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch print(sop_label) print(task_idx) oracle_pos, oracle_weights, oracle_labels = None, None, None # loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next, # masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, # masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, # masked_labels_2=oracle_labels, mask_qkv=mask_qkv) loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, sop_label, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, 
masked_labels_2=oracle_labels, mask_qkv=mask_qkv) masked_lm_loss, next_sentence_loss = loss_tuple if n_gpu > 1: # mean() to average on multi-gpu. # loss = loss.mean() masked_lm_loss = masked_lm_loss.mean() next_sentence_loss = next_sentence_loss.mean() print('mask_lm_loss {}'.format(masked_lm_loss)) print('next_sentence_loss {}'.format(next_sentence_loss)) print('----------------------------------------------') loss = masked_lm_loss + next_sentence_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) # ensure that accumulated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model if (args.local_rank == -1 or torch.distributed.get_rank() == 0): logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model itself output_model_file = os.path.join( args.output_dir, "model.{0}.bin".format(i_epoch)) torch.save(model_to_save.state_dict(), output_model_file) output_optim_file = os.path.join( args.output_dir, "optim.{0}.bin".format(i_epoch)) torch.save(optimizer.state_dict(), output_optim_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache()
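# Aside: in the fp16 branch above the learning rate has to be set by hand because
# FP16_Optimizer_State wraps the optimizer and bypasses BertAdam's built-in schedule.
# warmup_linear itself is imported from pytorch_pretrained_bert in these scripts; the
# sketch below is a reconstruction of that helper, consistent with how it is called
# above (x = global_step / t_total), shown for illustration rather than as a
# replacement for the actual import.
def warmup_linear_sketch(x, warmup=0.002):
    # linearly ramp from 0 to 1 over the first `warmup` fraction of training,
    # then decay linearly back to 0 as x approaches 1.0
    if x < warmup:
        return x / warmup
    return max((x - 1.0) / (warmup - 1.0), 0.0)


# e.g. with warmup=0.1: x=0.05 -> 0.5 (ramping up), x=0.10 -> 1.0 (peak),
# x=0.55 -> 0.5 (decaying), x=1.0 -> 0.0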
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: """ Train the model """ tb_writer = SummaryWriter() args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) def pad_examples(examples, padding_value=tokenizer.pad_token_id): if tokenizer._pad_token is None: return pad_sequence(examples, batch_first=True) return pad_sequence(examples, batch_first=True, padding_value=padding_value) def collate(examples): text_examples = [None] * len(examples) text_labels = [None] * len(examples) text_type_ids = [None] * len(examples) video_examples = [None] * len(examples) video_labels = [None] * len(examples) video_type_ids = [None] * len(examples) joint_examples = [None] * len(examples) joint_labels = [None] * len(examples) joint_type_ids = [None] * len(examples) for i, (te, tl, tti, ve, vl, vti, je, jl, jti) in enumerate(examples): text_examples[i] = te video_examples[i] = ve text_labels[i] = tl video_labels[i] = vl text_type_ids[i] = tti video_type_ids[i] = vti joint_examples[i] = je joint_labels[i] = jl joint_type_ids[i] = jti padded_text_ids = pad_examples(text_examples) text_attention_mask = torch.ones(padded_text_ids.shape, dtype=torch.int64) text_attention_mask[(padded_text_ids == 0)] = 0 padded_video_ids = pad_examples(video_examples) video_attention_mask = torch.ones(padded_video_ids.shape, dtype=torch.int64) video_attention_mask[(padded_video_ids == 0)] = 0 padded_joint_ids = pad_examples(joint_examples) joint_attention_mask = torch.ones(padded_joint_ids.shape, dtype=torch.int64) joint_attention_mask[(padded_joint_ids == 0)] = 0 return padded_text_ids, \ torch.tensor(text_labels, dtype=torch.int64), \ pad_examples(text_type_ids, padding_value=0), \ text_attention_mask, \ padded_video_ids, \ torch.tensor(video_labels, dtype=torch.int64), \ pad_examples(video_type_ids, padding_value=0), \ video_attention_mask, \ padded_joint_ids, \ torch.tensor(joint_labels, dtype=torch.int64), \ pad_examples(joint_type_ids, padding_value=0), \ joint_attention_mask train_sampler = RandomSampler(train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if (args.model_name_or_path and os.path.isfile( os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( os.path.join(args.model_name_or_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) scheduler.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) # Train! 
logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if args.model_name_or_path and os.path.exists(args.model_name_or_path): try: # set global_step to gobal_step of last saved checkpoint from model path checkpoint_suffix = args.model_name_or_path.split("-")[-1].split( "/")[0] global_step = int(checkpoint_suffix) epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: logger.info(" Starting fine-tuning.") tr_loss, logging_loss = 0.0, 0.0 model.zero_grad() train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) if args.local_rank != -1: train_sampler.set_epoch(epoch) for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue text_ids = batch[0] text_seq_labels = batch[1] text_token_type_ids = batch[2] text_attention_masks = batch[3] video_ids = batch[4] video_seq_labels = batch[5] video_token_type_ids = batch[6] video_attention_masks = batch[7] joint_ids = batch[8] joint_labels = batch[9] joint_token_type_ids = batch[10] joint_attention_masks = batch[11] text_inputs, text_mask_labels = mask_tokens( text_ids, tokenizer, args) if args.mlm else (text_ids, text_ids) video_inputs, video_mask_labels = mask_tokens( video_ids, tokenizer, args) if args.mlm else (video_ids, video_ids) joint_inputs, joint_mask_labels = mask_tokens( joint_ids, tokenizer, args) if args.mlm else (joint_ids, joint_ids) text_inputs = text_inputs.to(args.device) text_mask_labels = text_mask_labels.to(args.device) text_seq_labels = text_seq_labels.to(args.device) text_token_type_ids = text_token_type_ids.to(args.device) video_token_type_ids = video_token_type_ids.to(args.device) joint_token_type_ids = joint_token_type_ids.to(args.device) text_attention_masks = text_attention_masks.to(args.device) video_attention_masks = video_attention_masks.to(args.device) joint_attention_masks = joint_attention_masks.to(args.device) video_inputs = video_inputs.to(args.device) video_mask_labels = video_mask_labels.to(args.device) video_seq_labels = video_seq_labels.to(args.device) joint_inputs = joint_inputs.to(args.device) joint_mask_labels = 
joint_mask_labels.to(args.device) joint_labels = joint_labels.to(args.device) model.train() outputs = model( text_input_ids=text_inputs, video_input_ids=video_inputs, joint_input_ids=joint_inputs, text_token_type_ids=text_token_type_ids, video_token_type_ids=video_token_type_ids, joint_token_type_ids=joint_token_type_ids, text_attention_mask=text_attention_masks, video_attention_mask=video_attention_masks, joint_attention_mask=joint_attention_masks, text_masked_lm_labels=text_mask_labels, video_masked_lm_labels=video_mask_labels, joint_masked_lm_labels=joint_mask_labels, text_next_sentence_label=text_seq_labels, video_next_sentence_label=video_seq_labels, joint_vis_lin_label=joint_labels, ) loss = outputs[0] text_loss = outputs[1] video_loss = outputs[2] joint_loss = outputs[3] if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps # text_loss = text_loss / args.gradient_accumulation_steps # video_loss = video_loss / args.gradient_accumulation_steps # joint_loss = joint_loss / args.gradient_accumulation_steps loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) print('loss:', loss.item(), 'text loss:', text_loss.item(), 'video loss:', video_loss.item(), 'joint loss:', joint_loss.item()) # keep BERT embeddings frozen model.bert.embeddings.word_embeddings.weight.grad[ globals.frozen_indices] = 0 optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: print('writing tf logs...') tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) tb_writer.add_scalar("text_loss", text_loss.item(), global_step) tb_writer.add_scalar("video_loss", video_loss.item(), global_step) tb_writer.add_scalar("joint_loss", joint_loss.item(), global_step) logging_loss = tr_loss if args.local_rank in [ -1, 0 ] and args.save_steps > 0 and global_step % args.save_steps == 0: checkpoint_prefix = "checkpoint" # Save model checkpoint output_dir = os.path.join( args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) os.makedirs(output_dir, exist_ok=True) model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) _rotate_checkpoints(args, checkpoint_prefix) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
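# Aside: the "keep BERT embeddings frozen" step above freezes *selected rows* of the
# word embedding matrix by zeroing their gradient rows between backward() and
# optimizer.step(). A minimal sketch of the same trick on a toy embedding;
# `frozen_indices` here is an illustrative stand-in for globals.frozen_indices, and
# plain SGD is used so that a zero gradient really means "no update" (an optimizer
# with decoupled weight decay would still shrink the frozen rows).
import torch


def freeze_rows_demo():
    emb = torch.nn.Embedding(10, 4)
    opt = torch.optim.SGD(emb.parameters(), lr=0.1)
    frozen_indices = torch.tensor([0, 1, 2])  # rows that must not move
    before = emb.weight[frozen_indices].clone()
    loss = emb(torch.tensor([[0, 3, 5]])).sum()
    loss.backward()
    emb.weight.grad[frozen_indices] = 0  # cancel the update for the frozen rows
    opt.step()
    # rows 0-2 are unchanged; rows 3 and 5 received a normal SGD update
    assert torch.equal(emb.weight[frozen_indices], before)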
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default=None, type=str, required=True, help= "The input data dir. Should contain the .tsv files (or other data files) for the task." ) # Train files parser.add_argument("--src_file", default=None, type=str, help="The input data src file name.") parser.add_argument("--tgt_file", default=None, type=str, help="The input data tgt file name.") parser.add_argument("--check_file", default=None, type=str, help="The input check knowledge data file name.") # KS files parser.add_argument("--ks_src_file", default=None, type=str, help="The input ks data src file name.") parser.add_argument("--ks_tgt_file", default=None, type=str, help="The input ks data tgt file name.") parser.add_argument("--predict_input_file", default=None, type=str, help="predict_input_file") parser.add_argument("--predict_output_file", default=None, type=str, help="predict_output_file") parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default=None, type=str, required=True, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default='', type=str, required=True, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, required=True, help="The file of fine-tuned pretraining model.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") parser.add_argument("--predict_bleu", default=0.2, type=float, help="The predicted BLEU threshold for KS prediction.") # Other parameters parser.add_argument( "--max_seq_length", default=128, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_predict", action='store_true', help="Whether to run ks predict.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=64, type=int, help="Total batch size for eval.") parser.add_argument("--train_avg_bpe_length", default=25, type=int, help="average bpe length for train.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="The label smoothing epsilon.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion_step", default=300, type=int, help= "Number of steps to perform linear learning rate warmup for.
") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=67, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. 
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"), encoding='UTF-8') handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) console = logging.StreamHandler() console.setLevel(logging.DEBUG) logger.addHandler(handler) logger.addHandler(console) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # train_batch_size now becomes the per-forward micro-batch size; one
    # optimizer update is performed every gradient_accumulation_steps
    # micro-batches.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # Random seed
    # torch.backends.cudnn.enabled = False
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will
        # download model & vocab.
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = (WhitespaceTokenizer()
                      if args.tokenized_input else tokenizer)
    if args.local_rank == 0:
        dist.barrier()

    # Data processing pipelines. The three pipelines below take identical
    # keyword arguments, so those are factored out here.
    pipeline_kwargs = dict(
        new_segment_ids=args.new_segment_ids,
        truncate_config={
            'max_len_a': args.max_len_a,
            'max_len_b': args.max_len_b,
            'trunc_seg': args.trunc_seg,
            'always_truncate_tail': args.always_truncate_tail
        },
        mask_source_words=args.mask_source_words,
        skipgram_prb=args.skipgram_prb,
        skipgram_size=args.skipgram_size,
        mask_whole_word=args.mask_whole_word,
        mode="s2s",
        has_oracle=args.has_sentence_oracle,
        num_qkv=args.num_qkv,
        s2s_special_token=args.s2s_special_token,
        s2s_add_segment=args.s2s_add_segment,
        s2s_share_segment=args.s2s_share_segment,
        pos_shift=args.pos_shift)
    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            **pipeline_kwargs)
    ]
    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            **pipeline_kwargs)
    ]
    ks_predict_bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq_predict(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            **pipeline_kwargs)
    ]
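    # Three variants of the preprocessing pipeline are kept around:
    #   - C_bi_uni_pipeline (C_Preprocess4Seq2seq) additionally carries the
    #     check ids consumed by the generation (QKR) batches below;
    #   - bi_uni_pipeline (Preprocess4Seq2seq) feeds the knowledge-selection
    #     (KS) batches;
    #   - ks_predict_bi_uni_pipeline (Preprocess4Seq2seq_predict) builds
    #     inference-only inputs for ranking candidate knowledge.
    # This mapping is inferred from how each pipeline is used below, not from
    # separate documentation.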
    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)
        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src,
            fn_tgt,
            fn_check,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=C_bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src,
            ks_fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            ks_train_sampler = RandomSampler(ks_train_dataset,
                                             replacement=False)
            _batch_size = args.train_batch_size
        else:
            ks_train_sampler = DistributedSampler(ks_train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        ks_train_dataloader = torch.utils.data.DataLoader(
            ks_train_dataset,
            batch_size=_batch_size,
            sampler=ks_train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # note: args.train_batch_size has already been divided by
        # args.gradient_accumulation_steps above
        t_total = int(len(train_dataloader) * args.num_train_epochs /
                      args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)
                       if args.new_segment_ids else 2)
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will
        # download model & vocab.
        dist.barrier()

    # Recover model
    if args.model_recover_path:
        logger.info(" ** ** * Recover model: %s ** ** * ",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path,
                                   map_location='cpu')
    global_step = 0
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])
    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
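    # The model exposes two auxiliary embedding tables (tmp_bert_emb and
    # mul_bert_emb, names from the model class) that are initialized below as
    # clones of the main BERT embeddings. Wrapping each clone in a fresh
    # torch.nn.Parameter detaches it from the original weight, so the copies
    # can afterwards be updated independently of the originals.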
    model.to(device)
    model.tmp_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.tmp_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.tmp_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    model.mul_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.mul_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.mul_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "DistributedDataParallel is not available in this PyTorch build.")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer: no weight decay for biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex "
                "to use distributed and fp16 training.")
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path,
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(
                " ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    # logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        max_F1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", t_total)
        model.train()
        start_epoch = 1
        for i_epoch in trange(start_epoch,
                              start_epoch + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            step = 0
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                # ##### E step + M step + Mutual Information Loss ##### #
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids,
                 masked_pos, masked_weights, is_next, task_idx, tgt_pos,
                 labels, ks_labels, check_ids) = batch
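                # The forward pass below returns six loss components (masked
                # LM, next-sentence, KL, mutual-information, golden, and
                # predicted-KL losses) that are summed into one training
                # objective. The component names come from the tuple
                # unpacking a few lines down; their exact definitions live in
                # the model class.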
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   tgt_pos=tgt_pos,
                                   labels=labels.half(),
                                   ks_labels=ks_labels,
                                   check_ids=check_ids)
                (masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss,
                 Golden_loss, predict_kl_loss) = loss_tuple
                if n_gpu > 1:
                    # mean() to average on multi-gpu
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()
                loss = (masked_lm_loss + next_sentence_loss + KL_loss +
                        predict_kl_loss + Mutual_loss + Golden_loss)
                logger.info("In step {}, masked_lm_loss: {}".format(
                    step, masked_lm_loss))
                logger.info("In step {}, KL_loss: {}".format(step, KL_loss))
                logger.info("In step {}, Mutual_loss: {}".format(
                    step, Mutual_loss))
                logger.info("In step {}, Golden_loss: {}".format(
                    step, Golden_loss))
                logger.info("In step {}, predict_kl_loss: {}".format(
                    step, predict_kl_loss))
                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total,
                        args.warmup_proportion_step / t_total)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
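                # warmup_linear comes from the BERT optimization utilities
                # bundled with this code base. The classic implementation is
                # roughly (a sketch; the local copy may differ in detail):
                #
                #     def warmup_linear(x, warmup=0.002):
                #         if x < warmup:
                #             return x / warmup
                #         return 1.0 - x
                #
                # i.e. the learning rate ramps up linearly over the first
                # `warmup` fraction of training, then decays linearly to
                # zero. With, say, t_total = 20000 updates and a warmup
                # fraction of 0.1, the ramp covers the first 2000 updates
                # (illustrative numbers, not this repo's defaults).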
                # ##### Knowledge Selection Loss ##### #
                # Train the knowledge-selection head on roughly one in five
                # steps.
                if random.randint(0, 4) == 0:
                    ks_batch = [
                        t.to(device) if t is not None else None
                        for t in ks_batch
                    ]
                    (input_ids, segment_ids, input_mask, mask_qkv,
                     lm_label_ids, masked_pos, masked_weights, is_next,
                     task_idx, _, labels, ks_labels) = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_ks=True)
                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:
                        # mean() to average on multi-gpu
                        ks_loss = ks_loss.mean()
                    loss = ks_loss
                    logger.info("In step {}, ks_loss: {}".format(
                        step, ks_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total,
                            args.warmup_proportion_step / t_total)
                        if args.fp16:
                            # modify learning rate with special warm up BERT
                            # uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                step += 1

                # ##### Eval every 5000 steps ##### #
                if (global_step + 1) % 5000 == 0:
                    next_i = 0
                    model.eval()
                    # Knowledge ranking stage
                    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
                    with open(os.path.join(args.data_dir,
                                           args.predict_input_file),
                              "r", encoding="utf-8") as file:
                        src_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           "train_tgt_pad.empty"),
                              "r", encoding="utf-8") as file:
                        tgt_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           args.predict_output_file),
                              "w", encoding="utf-8") as out:
                        while next_i < len(src_file):
                            batch_src = src_file[next_i:next_i +
                                                 args.eval_batch_size]
                            batch_tgt = tgt_file[next_i:next_i +
                                                 args.eval_batch_size]
                            next_i += args.eval_batch_size
                            ex_list = []
                            for src, tgt in zip(batch_src, batch_tgt):
                                src_tk = data_tokenizer.tokenize(src.strip())
                                tgt_tk = data_tokenizer.tokenize(tgt.strip())
                                ex_list.append((src_tk, tgt_tk))
                            batch = []
                            for idx in range(len(ex_list)):
                                instance = ex_list[idx]
                                for proc in ks_predict_bi_uni_pipeline:
                                    instance = proc(instance)
                                batch.append(instance)
                            batch_tensor = \
                                seq2seq_loader.batch_list_to_batch_tensors(
                                    batch)
                            batch = [
                                t.to(device) if t is not None else None
                                for t in batch_tensor
                            ]
                            (input_ids, segment_ids, input_mask, mask_qkv,
                             lm_label_ids, masked_pos, masked_weights,
                             is_next, task_idx) = batch
                            predict_bleu = args.predict_bleu * torch.ones(
                                [input_ids.shape[0]],
                                device=input_ids.device)
                            oracle_pos, oracle_weights, oracle_labels = \
                                None, None, None
                            with torch.no_grad():
                                logits = model(
                                    input_ids,
                                    segment_ids,
                                    input_mask,
                                    lm_label_ids,
                                    is_next,
                                    masked_pos=masked_pos,
                                    masked_weights=masked_weights,
                                    task_idx=task_idx,
                                    masked_pos_2=oracle_pos,
                                    masked_weights_2=oracle_weights,
                                    masked_labels_2=oracle_labels,
                                    mask_qkv=mask_qkv,
                                    labels=predict_bleu,
                                    train_ks=True)
                            logits = torch.nn.functional.softmax(logits,
                                                                 dim=1)
                            labels = logits[:, 1].cpu().numpy()
                            for i in range(len(labels)):
                                line = batch_src[i].strip()
                                line += "\t"
                                line += str(labels[i])
                                out.write(line)
                                out.write("\n")
                    data_path = os.path.join(args.data_dir,
                                             "qkr_dev.ks_score.tk")
                    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
                    src_out_path = os.path.join(args.data_dir,
                                                "rank_qkr_dev.src.tk")
                    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")
                    knowledge_selection(data_path, src_path, src_out_path)
                    logger.info(" ** ** * DEV Know Selection End ** ** * ")

                    # Decode stage
                    logger.info(" ** ** * Dev Decode Begin ** ** * ")
                    with open(src_out_path, encoding="utf-8") as file:
                        dev_src_lines = file.readlines()
                    with open(tgt_path, encoding="utf-8") as file:
                        golden_response_lines = file.readlines()
                    decode_result = decode_batch(model, dev_src_lines)
                    logger.info(" ** ** * Dev Decode End ** ** * ")

                    # Compute dev F1
                    assert len(decode_result) == len(golden_response_lines)
                    C_F1 = f_one(decode_result, golden_response_lines)[0]
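                    # Early stopping on dev F1: training stops the first time
                    # the dev score drops below the best seen so far (see the
                    # comparison below). f_one, decode_batch and
                    # knowledge_selection are helpers defined elsewhere in
                    # this code base; f_one is assumed to return the F1 score
                    # as the first element of its result, which is how it is
                    # indexed above.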
                    logger.info(
                        "** ** * Current F1 is {} ** ** * ".format(C_F1))
                    if C_F1 < max_F1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. "
                            "So Stop Training ** ** * ")
                        logger.info("** ** * The best model is {} ** ** * ".
                                    format(best_step))
                        break
                    else:
                        max_F1 = C_F1
                        best_step = step
                        logger.info(
                            "** ** * Current F1 is larger than Previous F1. "
                            "So Continue Training ** ** * ")
                        # Save trained model
                        if (args.local_rank == -1
                                or torch.distributed.get_rank() == 0):
                            logger.info(
                                "** ** * Saving fine-tuned model and "
                                "optimizer ** ** * ")
                            # Only save the model itself
                            model_to_save = model.module if hasattr(
                                model, 'module') else model
                            output_model_file = os.path.join(
                                args.output_dir,
                                "model.{}_{}.bin".format(
                                    i_epoch, global_step))
                            torch.save(model_to_save.state_dict(),
                                       output_model_file)
                            output_optim_file = os.path.join(
                                args.output_dir, "optim.bin")
                            torch.save(optimizer.state_dict(),
                                       output_optim_file)
                            # logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
                            torch.cuda.empty_cache()
                    # Back to training mode after the dev evaluation; without
                    # this, dropout stays disabled for the remaining steps.
                    model.train()

    # ################# Predict ############################ #
    if args.do_predict:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(
                args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids, args.max_seq_length,
                **pipeline_kwargs)
        ]
        model.eval()
        with open(os.path.join(args.data_dir, args.predict_input_file),
                  "r", encoding="utf-8") as file:
            src_file = file.readlines()
        with open(os.path.join(args.data_dir, "train_tgt_pad.empty"),
                  "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file),
                  "w", encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            # Iterate over ceil(len(src_file) / eval_batch_size) batches; the
            # original `// args.eval_batch_size + 1` produced a trailing
            # empty batch whenever the sizes divided evenly.
            for next_i in tqdm(
                    range((len(src_file) + args.eval_batch_size - 1) //
                          args.eval_batch_size)):
                batch_src = src_file[next_i * args.eval_batch_size:
                                     (next_i + 1) * args.eval_batch_size]
                batch_tgt = tgt_file[next_i * args.eval_batch_size:
                                     (next_i + 1) * args.eval_batch_size]
                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))
                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                    batch.append(instance)
                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                    batch)
                batch = [
                    t.to(device) if t is not None else None
                    for t in batch_tensor
                ]
                (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids,
                 masked_pos, masked_weights, is_next, task_idx) = batch
                predict_bleu = args.predict_bleu * torch.ones(
                    [input_ids.shape[0]], device=input_ids.device)
                oracle_pos, oracle_weights, oracle_labels = None, None, None
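                # Queried with train_ks=True, the model returns the 2-way
                # knowledge-selection logits; after the softmax below, column
                # 1 is read as the probability that the candidate knowledge
                # matches the query (interpretation inferred from how the
                # scores are written out and later used for ranking).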
                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   labels=predict_bleu,
                                   train_ks=True)
                logits = torch.nn.functional.softmax(logits, dim=1)
                labels = logits[:, 1].cpu().numpy()
                for i in range(len(labels)):
                    line = batch_src[i].strip()
                    line += "\t"
                    line += str(labels[i])
                    out.write(line)
                    out.write("\n")
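            # Each output line has the form "<source line>\t<score>", one
            # candidate per line, the same tab-separated format the dev-time
            # knowledge_selection step reads when re-ranking candidates.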
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.float)
    train_data = TensorDataset(all_input_ids, all_input_mask,
                               all_segment_ids, all_label_ids)
    if not is_distributed:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    for epoch_num in trange(int(args.num_train_epochs), desc="Epoch"):
        model.train()
        if is_distributed:
            # Only DistributedSampler implements set_epoch(); calling it on
            # RandomSampler would raise an AttributeError.
            train_sampler.set_epoch(epoch_num)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            # define a new function to compute loss values for both
            # output_modes
            logits = model(input_ids, segment_ids, input_mask, labels=None)
            if output_mode == "classification":
                if args.focal:
                    loss_fct = FocalLoss(class_num=num_labels,