def _build_arg_parser():
    """Return the command-line parser for QKR / knowledge-selection training and prediction."""
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    # QKR train files
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data src file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The input data tgt file name.")
    parser.add_argument("--check_file", default=None, type=str,
                        help="The input check knowledge data file name")
    # KS (knowledge selection) train files
    parser.add_argument("--ks_src_file", default=None, type=str,
                        help="The input ks data src file name.")
    parser.add_argument("--ks_tgt_file", default=None, type=str,
                        help="The input ks data tgt file name.")
    parser.add_argument("--predict_input_file", default=None, type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file", default=None, type=str,
                        help="predict_output_file")
    parser.add_argument(
        "--bert_model", default=None, type=str, required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--log_dir", default='', type=str, required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu", default=0.2, type=float,
                        help="The Predicted Bleu for KS Predict ")

    # Other parameters
    parser.add_argument(
        "--max_seq_length", default=128, type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--train_avg_bpe_length", default=25, type=int,
                        help="average bpe length for train.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing factor passed to the model.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step", default=300, type=int,
        help="Number of optimizer steps over which the learning rate is linearly warmed up.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=67,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps', type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding', action='store_true',
        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument(
        '--loss_scale', type=float, default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch', action='store_true',
        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Masking probability passed to the seq2seq preprocessing pipeline.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Masking probability for the EOS token (currently unused).")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle', action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment', action='store_true',
        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")
    return parser


def _read_lines(path):
    """Return all lines of the UTF-8 text file at *path* (newlines kept)."""
    with open(path, "r", encoding="utf-8") as file:
        return file.readlines()


def _make_preprocessor(preprocess_cls, args, tokenizer):
    """Instantiate one seq2seq_loader preprocessing stage configured from the CLI args.

    Every pipeline built by main() uses exactly these settings; only the class differs.
    """
    return preprocess_cls(
        args.max_pred,
        args.mask_prob,
        list(tokenizer.vocab.keys()),
        tokenizer.convert_tokens_to_ids,
        args.max_seq_length,
        new_segment_ids=args.new_segment_ids,
        truncate_config={
            'max_len_a': args.max_len_a,
            'max_len_b': args.max_len_b,
            'trunc_seg': args.trunc_seg,
            'always_truncate_tail': args.always_truncate_tail,
        },
        mask_source_words=args.mask_source_words,
        skipgram_prb=args.skipgram_prb,
        skipgram_size=args.skipgram_size,
        mask_whole_word=args.mask_whole_word,
        mode="s2s",
        has_oracle=args.has_sentence_oracle,
        num_qkv=args.num_qkv,
        s2s_special_token=args.s2s_special_token,
        s2s_add_segment=args.s2s_add_segment,
        s2s_share_segment=args.s2s_share_segment,
        pos_shift=args.pos_shift,
    )


def _make_train_dataloader(dataset, args):
    """Return ``(sampler, DataLoader)`` for *dataset*.

    Uses a RandomSampler in single-process mode; under distributed training uses a
    DistributedSampler and splits the batch size evenly across the world size.
    """
    if args.local_rank == -1:
        sampler = RandomSampler(dataset, replacement=False)
        batch_size = args.train_batch_size
    else:
        sampler = DistributedSampler(dataset)
        batch_size = args.train_batch_size // dist.get_world_size()
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
        pin_memory=False)
    return sampler, loader


def _backward(loss, args, optimizer, amp_handle):
    """Back-propagate *loss*, scaled down for gradient accumulation (fp16-aware)."""
    # Ensure that accumulated gradients are normalized.
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    if args.fp16:
        optimizer.backward(loss)
        if amp_handle:
            amp_handle._clear_cache()
    else:
        loss.backward()


def _optimizer_step(args, optimizer, global_step, t_total):
    """Apply one optimizer update and clear gradients.

    In fp16 mode the warm-up learning rate is written into the param groups by hand;
    otherwise the optimizer applies its own schedule.
    """
    if args.fp16:
        # Modify learning rate with the special warm-up BERT uses.
        lr_this_step = args.learning_rate * warmup_linear(
            global_step / t_total, args.warmup_proportion_step / t_total)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_this_step
    optimizer.step()
    optimizer.zero_grad()


def _score_knowledge(model, src_lines, tgt_lines, out_file, pipeline, data_tokenizer,
                     device, batch_size, predict_bleu, show_progress=False):
    """Score (src, tgt) line pairs with the model's KS head and write them to *out_file*.

    Each output line is the stripped source line, a tab, and softmax(logits)[:, 1] for
    that example. Pairs are formed with zip, so output stops at the shorter input.
    """
    batch_starts = range(0, len(src_lines), batch_size)
    if show_progress:
        batch_starts = tqdm(batch_starts)
    for start in batch_starts:
        batch_src = src_lines[start:start + batch_size]
        batch_tgt = tgt_lines[start:start + batch_size]
        instances = []
        for src, tgt in zip(batch_src, batch_tgt):
            instance = (data_tokenizer.tokenize(src.strip()),
                        data_tokenizer.tokenize(tgt.strip()))
            for proc in pipeline:
                instance = proc(instance)
            instances.append(instance)
        if not instances:
            # Nothing to score (tgt_lines ran out); never collate an empty batch.
            continue
        batch = [t.to(device) if t is not None else None
                 for t in seq2seq_loader.batch_list_to_batch_tensors(instances)]
        (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids,
         masked_pos, masked_weights, is_next, task_idx) = batch
        # Every example gets the same --predict_bleu value as its `labels` input.
        bleu = predict_bleu * torch.ones([input_ids.shape[0]], device=input_ids.device)
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                           masked_pos=masked_pos, masked_weights=masked_weights,
                           task_idx=task_idx, masked_pos_2=None, masked_weights_2=None,
                           masked_labels_2=None, mask_qkv=mask_qkv, labels=bleu,
                           train_ks=True)
        scores = torch.nn.functional.softmax(logits, dim=1)[:, 1].cpu().numpy()
        for src, score in zip(batch_src, scores):
            out_file.write(src.strip() + "\t" + str(score) + "\n")


def _run_dev_eval(model, args, data_tokenizer, pipeline, device):
    """Rank dev knowledge with the KS head, decode dev responses and return the dev F1.

    The caller must switch the model to eval mode before and back to train mode after.
    """
    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
    src_lines = _read_lines(os.path.join(args.data_dir, args.predict_input_file))
    tgt_lines = _read_lines(os.path.join(args.data_dir, "train_tgt_pad.empty"))
    with open(os.path.join(args.data_dir, args.predict_output_file), "w",
              encoding="utf-8") as out:
        _score_knowledge(model, src_lines, tgt_lines, out, pipeline, data_tokenizer,
                         device, args.eval_batch_size, args.predict_bleu)

    # NOTE(review): knowledge_selection reads the fixed "qkr_dev.ks_score.tk" rather than
    # --predict_output_file; presumably both name the same file — confirm.
    data_path = os.path.join(args.data_dir, "qkr_dev.ks_score.tk")
    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
    src_out_path = os.path.join(args.data_dir, "rank_qkr_dev.src.tk")
    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")
    knowledge_selection(data_path, src_path, src_out_path)
    logger.info(" ** ** * DEV Know Selection End ** ** * ")

    # Decode stage
    logger.info(" ** ** * Dev Decode Begin ** ** * ")
    dev_src_lines = _read_lines(src_out_path)
    golden_response_lines = _read_lines(tgt_path)
    decode_result = decode_batch(model, dev_src_lines)
    logger.info(" ** ** * Dev Decode End ** ** * ")

    assert len(decode_result) == len(golden_response_lines), \
        "decode_batch returned a different number of lines than the dev targets"
    return f_one(decode_result, golden_response_lines)[0]


def main():
    """Train the QKR model with a knowledge-selection (KS) head and/or score knowledge.

    With --do_train, the model restored from --model_recover_path is trained on the QKR
    data (src/tgt/check files). On about one batch in five (random.randint(0, 4) == 0)
    an extra KS-loss step on the KS data follows. Every 5000 optimizer updates the dev
    set is ranked, decoded and scored with F1. Training stops at the first F1 drop, and
    a checkpoint is saved after every non-decreasing F1.

    With --do_predict, every line of --predict_input_file is scored by the KS head and
    written to --predict_output_file as "<line><TAB><score>".
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert Path(args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    # Output/log paths may contain a [PT_OUTPUT_DIR] placeholder filled from the environment.
    pt_output_dir = os.getenv('PT_OUTPUT_DIR', '')
    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', pt_output_dir)
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', pt_output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"), encoding='UTF-8')
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    logger.addHandler(console)

    # Record the resolved options next to the checkpoints; `with` closes the file.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # Count only GPUs that will actually be used, so --no_cuda on a multi-GPU host
        # does not enable the multi-GPU code paths below.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    # Apply --seed. This also makes the random KS-step schedule reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    # Data process pipelines
    bi_uni_pipeline = [
        _make_preprocessor(seq2seq_loader.Preprocess4Seq2seq, args, tokenizer)]
    C_bi_uni_pipeline = [
        _make_preprocessor(seq2seq_loader.C_Preprocess4Seq2seq, args, tokenizer)]
    ks_predict_bi_uni_pipeline = [
        _make_preprocessor(seq2seq_loader.Preprocess4Seq2seq_predict, args, tokenizer)]

    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)
        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src, fn_tgt, fn_check, args.train_batch_size, data_tokenizer,
            args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=C_bi_uni_pipeline)
        train_sampler, train_dataloader = _make_train_dataloader(train_dataset, args)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src, ks_fn_tgt, args.train_batch_size, data_tokenizer,
            args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline)
        ks_train_sampler, ks_train_dataloader = _make_train_dataloader(ks_train_dataset, args)

        # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
        # Clamped to >= 1 so the warm-up computations never divide by zero.
        t_total = max(1, int(len(train_dataloader) * args.num_train_epochs /
                             args.gradient_accumulation_steps))
    else:
        # No optimisation happens without --do_train. -1 is used here as "no schedule";
        # assumes BertAdam follows pytorch_pretrained_bert's convention — confirm.
        t_total = -1

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    # Recover model (--model_recover_path is required and checked to exist above).
    logger.info(" ** ** * Recover model: %s ** ** * ", args.model_recover_path)
    model_recover = torch.load(args.model_recover_path, map_location='cpu')
    global_step = 0
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])
    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)

    # Initialise the two auxiliary embedding modules as independent copies of BERT's.
    bert_emb = model.bert.embeddings
    for aux_emb in (model.tmp_bert_emb, model.mul_bert_emb):
        aux_emb.word_embeddings.weight = torch.nn.Parameter(
            bert_emb.word_embeddings.weight.clone())
        aux_emb.token_type_embeddings.weight = torch.nn.Parameter(
            bert_emb.token_type_embeddings.weight.clone())
        aux_emb.position_embeddings.weight = torch.nn.Parameter(
            bert_emb.position_embeddings.weight.clone())

    if args.local_rank != -1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err
        optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate,
                              bias_correction=False, max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        # BertAdam expects warm-up as a fraction of t_total; the CLI gives it in steps.
        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion_step / t_total if t_total > 0 else -1,
            t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(" ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        best_f1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", t_total)
        model.train()
        # NOTE(review): exactly one epoch is run regardless of --num_train_epochs (which
        # only scales t_total); dev-F1 early stopping ends training.
        start_epoch = 1
        for i_epoch in trange(start_epoch, start_epoch + 1, desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
                ks_train_sampler.set_epoch(i_epoch)
            step = 0
            # zip() stops at whichever of the two loaders is exhausted first.
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                # ########## E step + M step + Mutual Information Loss ########## #
                batch = [t.to(device) if t is not None else None for t in batch]
                (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos,
                 masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels,
                 check_ids) = batch
                # NOTE(review): labels are cast to fp16 unconditionally (the KS step below
                # does not cast); confirm this is intended for fp32 runs.
                loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                                   masked_pos=masked_pos, masked_weights=masked_weights,
                                   task_idx=task_idx, masked_pos_2=None,
                                   masked_weights_2=None, masked_labels_2=None,
                                   mask_qkv=mask_qkv, tgt_pos=tgt_pos,
                                   labels=labels.half(), ks_labels=ks_labels,
                                   check_ids=check_ids)
                (masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss,
                 predict_kl_loss) = loss_tuple
                if n_gpu > 1:    # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()
                loss = (masked_lm_loss + next_sentence_loss + KL_loss + predict_kl_loss +
                        Mutual_loss + Golden_loss)
                logger.info("In{}step, masked_lm_loss:{}".format(step, masked_lm_loss))
                logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("In{}step, Mutual_loss:{}".format(step, Mutual_loss))
                logger.info("In{}step, Golden_loss:{}".format(step, Golden_loss))
                logger.info("In{}step, predict_kl_loss:{}".format(step, predict_kl_loss))
                logger.info("******************************************* ")

                _backward(loss, args, optimizer, amp_handle)
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    _optimizer_step(args, optimizer, global_step, t_total)
                    global_step += 1

                # ########## Knowledge Selection Loss (about 1 batch in 5) ########## #
                if random.randint(0, 4) == 0:
                    ks_batch = [t.to(device) if t is not None else None for t in ks_batch]
                    (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos,
                     masked_weights, is_next, task_idx, _, labels, ks_labels) = ks_batch
                    loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                       is_next, masked_pos=masked_pos,
                                       masked_weights=masked_weights, task_idx=task_idx,
                                       masked_pos_2=None, masked_weights_2=None,
                                       masked_labels_2=None, mask_qkv=mask_qkv,
                                       labels=labels, ks_labels=ks_labels, train_ks=True)
                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:    # mean() to average on multi-gpu.
                        ks_loss = ks_loss.mean()
                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    logger.info("******************************************* ")

                    _backward(ks_loss, args, optimizer, amp_handle)
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        _optimizer_step(args, optimizer, global_step, t_total)

                step += 1

                # ########## Eval every 5000 optimizer updates ########## #
                # Only right after an update, so accumulation steps that share a
                # global_step do not trigger repeated evaluations.
                just_updated = step % args.gradient_accumulation_steps == 0
                if just_updated and (global_step + 1) % 5000 == 0:
                    model.eval()
                    current_f1 = _run_dev_eval(model, args, data_tokenizer,
                                               ks_predict_bi_uni_pipeline, device)
                    logger.info("** ** * Current F1 is {} ** ** * ".format(current_f1))
                    if current_f1 < best_f1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. So Stop Training ** ** * ")
                        logger.info("** ** * The best model is {} ** ** * ".format(best_step))
                        break
                    best_f1 = current_f1
                    # global_step is the number used in the checkpoint file name below.
                    best_step = global_step
                    logger.info(
                        "** ** * Current F1 is larger than Previous F1. So Continue Training ** ** * ")
                    # Save trained model
                    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                        logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                        # Only save the model itself (unwrap DDP / DataParallel).
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(
                            args.output_dir, "model.{}_{}.bin".format(i_epoch, global_step))
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_optim_file = os.path.join(args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)
                    torch.cuda.empty_cache()
                    # Dev evaluation switched the model to eval mode; training must resume
                    # with dropout enabled again.
                    model.train()

    # ################# Predict ############################ #
    if args.do_predict:
        predict_pipeline = [
            _make_preprocessor(seq2seq_loader.Preprocess4Seq2seq_predict, args, tokenizer)]
        model.eval()
        src_lines = _read_lines(os.path.join(args.data_dir, args.predict_input_file))
        # Prefer data_dir (as dev evaluation does); fall back to the working-directory
        # path that earlier versions of this script read.
        tgt_path = os.path.join(args.data_dir, "train_tgt_pad.empty")
        if not os.path.exists(tgt_path):
            tgt_path = "train_tgt_pad.empty"
        tgt_lines = _read_lines(tgt_path)
        with open(os.path.join(args.data_dir, args.predict_output_file), "w",
                  encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            _score_knowledge(model, src_lines, tgt_lines, out, predict_pipeline,
                             data_tokenizer, device, args.eval_batch_size,
                             args.predict_bleu, show_progress=True)
def main():
    """Parse CLI options, then fine-tune ``BertForPreTrainingLossMask`` on a seq2seq corpus.

    Steps:
      1. Parse and validate arguments, dump them to ``<output_dir>/opt.json``.
      2. Set up the device (optionally NCCL-distributed via ``--local_rank``),
         the random seeds and the tokenizer.
      3. Build the training dataset / dataloader from the files in ``--data_dir``.
      4. Build the model.
         - It resumes from ``model.{step}.bin`` in ``--output_dir`` when
           ``_get_max_epoch_model`` reports a checkpoint there.
         - Otherwise it loads ``--model_recover_path``.
      5. Train. After every epoch, rank 0 writes ``model.{epoch}.bin`` and
         ``optim.{epoch}.bin`` to ``--output_dir``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_eval`` is set, or if ``--do_train`` is
            not set (this script has no evaluation loop).
        SystemExit: via ``parser.error`` when ``--model_recover_path`` does
            not exist.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model", default=None, type=str, required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--log_dir", default='', type=str, required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length", default=128, type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--local_debug", action='store_true',
                        help="Forwarded to the model as its `local_debug` option.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing factor passed to the model.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion", default=0.1, type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps', type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding', action='store_true',
        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument(
        '--loss_scale', type=float, default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch', action='store_true',
        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg', default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail', action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob", default=0.15, type=float,
        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument(
        "--mask_prob_eos", default=0, type=float,
        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle', action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token', action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment', action='store_true',
        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not Path(args.model_recover_path).exists():
        parser.error("--model_recover_path doesn't exist")

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Context manager so the handle is closed (and the JSON flushed) immediately.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    # From here on, train_batch_size is the per-accumulation-step batch size.
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
    if not args.do_train:
        # The schedule length (t_total), optimizer and checkpointing below are all
        # derived from the training dataloader, and no evaluation loop exists here.
        # Previously `--do_eval` alone crashed later with a NameError.
        raise ValueError("`do_train` must be True: this script has no evaluation loop.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail,
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt')
        # NOTE(review): `corpus_preprocessors` is not defined in this function;
        # presumably a module-level name elsewhere in the file -- confirm it exists.
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src, fn_tgt, args.train_batch_size, data_tokenizer, args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        logger.info("Train examples: %s", len(train_dataset.ex_list))
        logger.info("Train dataset batch size: %s", train_dataset.batch_size)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            # The global per-step batch is split evenly across ranks.
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    # Conditional expression binds loosest: (6 [+1]) with new segment ids, else 2.
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    # Constructor options shared by the fresh-init and recover paths, kept in one
    # place so the two calls cannot drift apart.
    model_kwargs = dict(
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        local_debug=args.local_debug)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict, **model_kwargs)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)),
                map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path, map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover, **model_kwargs)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError("DistributedDataParallel") from err
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        # Honour --weight_decay (was hard-coded to its default, 0.01).
        'weight_decay': args.weight_decay,
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    }]
    if args.fp16:
        try:
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)),
            map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", t_total)

        model.train()
        start_epoch = recover_step + 1 if recover_step else 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                # Reshuffle the per-rank shards differently every epoch.
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) if t is not None else None for t in batch]
                if args.has_sentence_oracle:
                    (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos,
                     masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights,
                     oracle_labels) = batch
                else:
                    (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos,
                     masked_weights, is_next, task_idx, sop_label) = batch
                    # Lazy %-args: tensors are only formatted (and synced to host)
                    # when DEBUG logging is enabled; print() did this every step.
                    logger.debug("sop_label=%s task_idx=%s", sop_label, task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # The second label slot receives sop_label (not is_next).
                # NOTE(review): presumably a sentence-order label -- confirm with the dataset.
                loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                logger.debug("mask_lm_loss %s next_sentence_loss %s",
                             masked_lm_loss, next_sentence_loss)
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total, args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # (BertAdam applies its own schedule in the non-fp16 path)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                # Unwrap DDP / DataParallel so only the underlying model is saved.
                model_to_save = model.module if hasattr(model, 'module') else model
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default="../../../dataset/final_data/commongen", type=str, help= "The input data dir. Should contain the .tsv files (or other data files) for the task." ) parser.add_argument("--src_file", default=None, type=str, help="The input data file name.") parser.add_argument("--tgt_file", default=None, type=str, help="The output data file name.") parser.add_argument( "--bart_model", default="facebook/bart-large", type=str, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default="../../../output/train_kgbart", type=str, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default="../../../log/train_kgbart", type=str, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") # Other parameters parser.add_argument( "--max_seq_length", default=64, type=int, help= "The maximum total input sequence length after WordPiece tokenization. 
\n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_train", default=True, type=bool, help="Whether to run training.") parser.add_argument("--do_eval", default=True, type=bool, help="Whether to run eval on the dev set.") parser.add_argument( "--do_lower_case", default=False, type=bool, help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=48, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=2, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0.1, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=6, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', default=True, type=bool, help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', default=True, type=bool, help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." 
) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=64, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=64, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', default=True, type=bool, help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.70, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=30, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=5, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument('--pretraining_KG', action='store_true', help="Whether to pretraining KG. 
") parser.add_argument('--train_pretraining_num', type=int, default=20, help="The number of sample training pretraining KG. ") parser.add_argument('--val_pretraining_num', type=int, default=20, help="The number of sample validing pretraining KG.") parser.add_argument('--max_position_embeddings', type=int, default=64, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', default=False, type=bool, help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', default=False, type=bool, help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', default=False, type=bool, help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument( "--fp16_opt_level", type=str, default="O1", help= "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
"See details at https://nvidia.github.io/apex/amp.html", ) parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") parser.add_argument("--keep_last_epochs", default=5, type=int, help="Keep the last few epochs.") args = parser.parse_args() # assert Path(args.model_recover_path).exists( # ), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BartTokenizer.from_pretrained(args.bart_model, do_lower_case=args.do_lower_case) # if args.max_position_embeddings: # tokenizer.max_len = args.max_position_embeddings data_tokenizer = tokenizer # WhitespaceTokenizer() if args.tokenized_input else if args.local_rank == 0: dist.barrier() bi_uni_pipeline = [ seq2seq_loader.Preprocess4Pretrain( list(tokenizer.encoder.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mode="s2s", pretraining_KG=args.pretraining_KG, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift) ] file_oracle = None entity_id = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_entity2id.txt') entity_embedding_path = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_ent_embeddings') relation_id = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_relation2id.txt') relation_embedding_path = os.path.join( 
args.data_dir, args.tgt_file if args.tgt_file else 'CommonGen_KG/commongen_rel_embeddings') entity_embedding = np.array(pickle.load(open(entity_embedding_path, "rb"))) entity_embedding = np.array( list(np.zeros((4, 1024))) + list(entity_embedding)) relation_embedding = np.array( pickle.load(open(relation_embedding_path, "rb"))) if args.do_train: print("Loading Train Dataset", args.data_dir) if args.pretraining_KG: file_oracle = os.path.join(args.data_dir, 'CommonGen_KG/commongen_entity2id.txt') fn_src = os.path.join( args.data_dir, args.src_file if args.src_file else 'commongen.train.src_new.txt') fn_tgt = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'commongen.train.tgt.txt') fn_onehop = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'commongen.train.onehop_5.txt') train_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, entity_id, relation_id, fn_onehop, args.train_batch_size, data_tokenizer, args.max_seq_length, pretraining_KG=file_oracle, pretraining_num=args.train_pretraining_num, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: train_sampler = RandomSampler(train_dataset, replacement=False) _batch_size = args.train_batch_size else: train_sampler = DistributedSampler(train_dataset) _batch_size = args.train_batch_size // dist.get_world_size() train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) if args.do_eval: print("Loading Dev Dataset", args.data_dir) if args.pretraining_KG: file_oracle = os.path.join(args.data_dir, 'CommonGen_KG/commongen_entity2id.txt') fn_src = os.path.join( args.data_dir, args.src_file if args.src_file else 'commongen.dev.src_new.txt') fn_tgt = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 'commongen.dev.tgt.txt') fn_onehop = os.path.join( args.data_dir, args.tgt_file if args.tgt_file else 
'commongen.dev.onehop_5.txt') dev_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, entity_id, relation_id, fn_onehop, args.eval_batch_size, data_tokenizer, args.max_seq_length, pretraining_KG=file_oracle, pretraining_num=args.val_pretraining_num, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: dev_sampler = RandomSampler(dev_dataset, replacement=False) _batch_size = args.eval_batch_size else: dev_sampler = DistributedSampler(dev_dataset) _batch_size = args.eval_batch_size // dist.get_world_size() dev_dataloader = torch.utils.data.DataLoader( dev_dataset, batch_size=_batch_size, sampler=dev_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size) t_total = int( len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps) # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 num_sentlvl_labels = 2 if args.pretraining_KG else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None model = KGBartForConditionalGeneration.from_pretrained( args.bart_model, entity_weight=entity_embedding, relation_weight=relation_embedding) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_dir, "model.{0}.bin".format(recover_step)), 
map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * t_total / args.num_train_epochs) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 model = KGBartForConditionalGeneration.from_pretrained( args.bart_model, state_dict=model_recover, entity_weight=entity_embedding, relation_weight=relation_embedding) if args.local_rank == 0: dist.barrier() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel") model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: # model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) schedule_recover = torch.load(os.path.join( args.output_dir, "sched.{0}.bin".format(recover_step)), map_location='cpu') scheduler.load_state_dict(schedule_recover) if args.loss_scale == 0: 
logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() best_dev_loss = 1000 output_eval_file = os.path.join(args.log_dir, "eval_results.txt") writer = open(output_eval_file, "w") def checkpoint_paths(path, pattern=r"model(\d+)\.pt"): """Retrieves all checkpoints found in `path` directory. Checkpoints are identified by matching filename to the specified pattern. If the pattern contains groups, the result will be sorted by the first group in descending order. """ pt_regexp = re.compile(pattern) files = os.listdir(path) entries = [] for i, f in enumerate(files): m = pt_regexp.fullmatch(f) if m is not None: idx = float(m.group(1)) if len(m.groups()) > 0 else int( f.split(".")[1]) entries.append((idx, m.group(0))) return [ os.path.join(path, x[1]) for x in sorted(entries, reverse=True) ] if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch", disable=args.local_rank not in (-1, 0)): model.train() if args.local_rank != -1: train_sampler.set_epoch(i_epoch) # iter_bar = tqdm(BackgroundGenerator(train_dataloader), desc='Iter (loss=X.XXX)', # disable=args.local_rank not in (-1, 0)) for step, batch in enumerate( tqdm(train_dataloader, desc="Training", position=0, leave=True)): batch = [ t.to(device) if t is not None else None for t in batch ] if args.pretraining_KG: input_ids, input_entity_ids, subword_mask, word_mask, word_subword, decoder_input_ids, 
decoder_attention_mask, labels = batch else: input_ids, input_entity_ids, subword_mask, word_mask, word_subword, decoder_input_ids, decoder_attention_mask, labels = batch loss_output = model( input_ids, input_entity_ids=input_entity_ids, attention_mask=subword_mask, word_mask=word_mask, word_subword=word_subword, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, label_smoothing=False) masked_lm_loss = loss_output.loss if n_gpu > 1: # mean() to average on multi-gpu. # loss = loss.mean() masked_lm_loss = masked_lm_loss.mean() loss = masked_lm_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) # iter_bar.set_description('Iter %d (loss=%5.3f)' % (i_epoch, loss.item())) if step % 1000 == 0: print('Iter %d (Gen_loss=%5.3f)' % (i_epoch, loss.item())) # ensure that accumlated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): model.eval() cur_dev_loss = [] with torch.no_grad(): for step, batch in enumerate( tqdm(dev_dataloader, desc="Evaluating", position=0, leave=True)): batch = [ t.to(device) if t is not None else None for t in batch ] if args.pretraining_KG: input_ids, input_entity_ids, subword_mask, word_mask, word_subword, decoder_input_ids, decoder_attention_mask, labels = batch else: input_ids, input_entity_ids, subword_mask, 
word_mask, word_subword, decoder_input_ids, decoder_attention_mask, labels = batch loss_output = model( input_ids, input_entity_ids=input_entity_ids, attention_mask=subword_mask, word_mask=word_mask, word_subword=word_subword, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, label_smoothing=False) masked_lm_loss = loss_output.loss if n_gpu > 1: # mean() to average on multi-gpu. # loss = loss.mean() masked_lm_loss = masked_lm_loss.mean() # next_sentence_loss = next_sentence_loss.mean() loss = masked_lm_loss cur_dev_loss.append(float(loss.item())) # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) dev_loss = sum(cur_dev_loss) / float(len(cur_dev_loss)) print("the epoch {} DEV loss is {}".format( i_epoch, dev_loss)) if best_dev_loss > dev_loss: best_dev_loss = dev_loss model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self os.makedirs(args.output_dir + "/best_model", exist_ok=True) output_model_file = os.path.join( args.output_dir, "best_model/model.best.bin") # output_optim_file = os.path.join( # args.output_dir, "best_model/optim.best.bin") # output_schedule_file = os.path.join( # args.output_dir, "best_model/sched.best.bin") torch.save(model_to_save.state_dict(), output_model_file) # torch.save(optimizer.state_dict(), output_optim_file) # torch.save(scheduler.state_dict(), output_schedule_file) logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * " ) model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join( args.output_dir, "model.{0}.bin".format(i_epoch)) output_optim_file = os.path.join( args.output_dir, "optim.{0}.bin".format(i_epoch)) output_schedule_file = os.path.join( args.output_dir, "sched.{0}.bin".format(i_epoch)) torch.save(model_to_save.state_dict(), output_model_file) torch.save(optimizer.state_dict(), output_optim_file) 
torch.save(scheduler.state_dict(), output_schedule_file) writer.write("epoch " + str(i_epoch) + "\n") writer.write("the current eval accuracy is: " + str(dev_loss) + "\n") logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( args.output_dir, pattern=r"model.\d+.bin") for old_chk in checkpoints[args.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk) checkpoints = checkpoint_paths( args.output_dir, pattern=r"optim.\d+.bin") for old_chk in checkpoints[args.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk) checkpoints = checkpoint_paths( args.output_dir, pattern=r"sched.\d+.bin") for old_chk in checkpoints[args.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache()
def parse_bool_arg(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``, so
    any non-empty value would switch a flag on. This helper parses the usual
    spellings explicitly instead.

    Args:
        value: the raw argument string. A ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/on/1 and ``False`` for false/f/no/n/off/0 or
        the empty string. Matching is case-insensitive and ignores
        surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports
            a usage error instead of silently guessing.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def main():
    """Fine-tune a UniLM model (``BertForPreTrainingLossMask``) jointly with a GSM topic model.

    Steps:
    1. Parse the command line.
    2. Build the training dataset and dataloader.
    3. Restore the UniLM weights (``--model_recover_path``) and the GSM
       weights (``--topic_model_recover_path``).
    4. Train for ``--num_train_epochs`` epochs. The loss is the UniLM loss
       plus the topic-model loss; with ``--topic_model`` only the topic-model
       loss and parameters are used.
    5. After every epoch, rank 0 writes ``unilm.<epoch>.bin`` and
       ``topic.<epoch>.ckpt`` to ``--output_dir``.
    6. At the end of training, loss curves are saved as PNG files there too.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_eval`` is given, or if ``--do_train`` is
            missing (this entry point only implements training).
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1:idea1 1.1:idea1_wo_theta 2:idea2 ")
    # `type=bool` would turn "--topic_model False" into True; parse explicitly.
    # A bare "--topic_model" (no value) also enables it.
    parser.add_argument('--topic_model',
                        default=False,
                        type=parse_bool_arg,
                        nargs='?',
                        const=True,
                        help="if only use topic model")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=192,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--train_batch_size",
        default=32,
        type=int,
        help="Total batch size for training.")  # batch_size = batch_size/n_gpus
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing factor for the LM loss.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Record the full run configuration; `with` guarantees the file is closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    print("args.local_rank", args.local_rank)
    print("args.no_cuda", args.no_cuda)
    if args.local_rank == -1 or args.no_cuda:
        # Single-process mode (possibly multi-GPU via DataParallelImbalance).
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print("n_gpu_1", n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
        print("n_gpu_1", n_gpu)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")
    if not args.do_train:
        # The GSM vocabulary size and the step count t_total are both taken
        # from the training data, and this entry point has no eval loop, so
        # running without --do_train used to die with a NameError later on.
        raise ValueError(
            "`do_train` must be True: this script only implements training.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.data_dir,
            args.topic_model_dict_path,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    # Parsed as (6 + (1 if s2s_add_segment else 0)) if new_segment_ids else 2.
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            # Usual entry path: --model_recover_path is a required argument.
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)

    # 1. Model initialization: build the GSM topic model over the training
    # vocabulary and restore its pretrained weights. Loaded onto the CPU
    # (like the other checkpoints) and moved to `device` below.
    gsm = GSM(train_dataset.vocabsize)
    gsm_checkpoint = torch.load(args.topic_model_recover_path,
                                map_location='cpu')
    gsm.load_state_dict(gsm_checkpoint["net"])

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        unilm.half()
        gsm.half()
        if args.fp32_embedding:
            unilm.bert.embeddings.word_embeddings.float()
            unilm.bert.embeddings.position_embeddings.float()
            unilm.bert.embeddings.token_type_embeddings.float()
    unilm.to(device)
    gsm.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        unilm = DDP(unilm,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # The topic model is trained too, so its gradients must also be
        # all-reduced; otherwise every rank would update a diverging GSM copy.
        gsm = DDP(gsm,
                  device_ids=[args.local_rank],
                  output_device=args.local_rank,
                  find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        unilm = DataParallelImbalance(unilm)
        gsm = DataParallelImbalance(gsm)

    # Prepare optimizer
    total = 0
    param_optimizer = list(unilm.named_parameters())
    param_optimizer_topic = list(gsm.named_parameters())
    # Parameter count reported below: every UniLM parameter whose name does
    # not contain "idea", plus the "idea2" parameters of layer "11".
    for name, parameters in unilm.named_parameters():
        if "idea" in name:
            if "11" in name and "idea2" in name:
                total += np.prod(parameters.size())
        else:
            total += np.prod(parameters.size())
    print("gsm have {} paramerters in total".format(
        sum(x.numel() for x in gsm.parameters())))
    print("Number of parameter: %.6fM" % (total / 1e6))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # The topic-model group uses its own fixed lr (1e-3); the 'topic' key lets
    # the fp16 warmup loop below skip it.
    topic_group = {
        'params': [p for n, p in param_optimizer_topic],
        'weight_decay': 0.0,
        'lr': 1e-3,
        'topic': True
    }
    if not args.topic_model:
        # UniLM parameters are split into a weight-decayed group and a
        # no-decay group (biases and LayerNorm).
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': args.weight_decay,
            'topic': False
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'topic': False
        }, topic_group]
    else:
        # Topic-model-only training: UniLM parameters are not optimised.
        optimizer_grouped_parameters = [topic_group]

    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale)
    else:
        # BertAdam receives warmup/t_total directly, so the manual lr schedule
        # in the training loop is only applied on the fp16 path.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        unilm.train()
        gsm.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        print("000000", args.local_rank, start_epoch,
              int(args.num_train_epochs) + 1)
        topicloss_lst = []  # per-step topic loss divided by batch size
        unilmloss_lst = []  # per-step UniLM loss (empty with --topic_model)
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            # Running sums used for the epoch-level topic-model perplexities.
            loss_sum = 0.0
            ppx_sum = 0.0
            word_count = 0.0
            doc_count = 0.0
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    # NOTE(review): this branch never assigns `bows`, so the
                    # GSM forward below would fail; the oracle path appears
                    # unsupported in this script — confirm.
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    # The batch carries an extra bag-of-words tensor for GSM.
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows)
                if not args.topic_model:
                    loss_tuple = unilm(input_ids,
                                       theta,
                                       beta,
                                       topic_embedding,
                                       args.topic_mode,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv)
                    masked_lm_loss, next_sentence_loss = loss_tuple

                # Topic loss: bag-of-words reconstruction + KL divergence.
                logsoftmax = torch.log(p_x + 1e-10)
                # bows * logsoftmax has the same shape as bows; this sums every
                # element into one scalar (a per-row sum is taken just below).
                rec_loss = -1.0 * torch.sum(bows * logsoftmax)
                rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1)
                rec_loss_per = rec_loss_per.cpu().detach().numpy()
                kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) -
                                          log_vars.exp())
                loss_topic = rec_loss + kl_div
                if n_gpu > 1:
                    # mean() to average on multi-gpu.
                    loss_topic = loss_topic.mean()
                    if not args.topic_model:
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                if not args.topic_model:
                    loss_unilm = masked_lm_loss + next_sentence_loss

                # Perplexity bookkeeping. Word counts per document are computed
                # with one row-wise sum instead of a device sync per document.
                loss_sum += loss_topic.item()
                word_count_np = torch.sum(bows, dim=1).cpu().numpy()
                word_count += float(np.sum(word_count_np))
                doc_count += len(bows)
                batch_ppx = np.sum(np.true_divide(rec_loss_per, word_count_np))
                ppx_sum += batch_ppx
                topicloss_lst.append(loss_topic.item() / len(bows))
                if not args.topic_model:
                    unilmloss_lst.append(loss_unilm.item())
                # topic loss end

                if not args.topic_model:
                    loss = loss_unilm + loss_topic
                else:
                    loss = loss_topic

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses;
                        # the topic-model group keeps its fixed lr.
                        for param_group in optimizer.param_groups:
                            if not param_group['topic']:
                                param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                if not args.topic_model:
                    iter_bar.set_description(
                        'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' %
                        (loss_unilm.item(), batch_ppx))
                else:
                    iter_bar.set_description(
                        'Iter (loss_topic=%5.3f), (ppl=%5.3f)' %
                        (loss_topic.item(), batch_ppx))

            # Save a trained model
            ppx_word = np.exp(loss_sum / word_count)
            ppx_document = np.exp(ppx_sum / doc_count)
            print("********")
            print("word_count", word_count)
            print("ppx_word", ppx_word)
            print("ppx_document", ppx_document)
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                # save unilm model
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                unilm_model_to_save = unilm.module if hasattr(
                    unilm, 'module') else unilm  # Only save the model it-self
                output_unilm_model_file = os.path.join(
                    args.output_dir, "unilm.{0}.bin".format(i_epoch))
                torch.save(unilm_model_to_save.state_dict(),
                           output_unilm_model_file)
                # save topic model
                logger.info(
                    "** ** * Saving topic model and optimizer ** ** * ")
                topic_model_to_save = gsm.module if hasattr(
                    gsm, 'module') else gsm  # Only save the model it-self
                output_topic_model_file = os.path.join(
                    args.output_dir, "topic.{0}.ckpt".format(i_epoch))
                torch.save(topic_model_to_save.state_dict(),
                           output_topic_model_file)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

        # Loss curves are written once, by the same process that saves the
        # checkpoints, so distributed ranks do not overwrite each other's PNGs.
        if args.local_rank == -1 or torch.distributed.get_rank() == 0:
            smth_pts = smooth_curve(topicloss_lst)
            plt.plot(range(len(smth_pts)), smth_pts)
            plt.xlabel('epochs')
            plt.title('Topic Model Train Loss')
            plt.savefig(args.output_dir + '/topic_loss.png')
            plt.cla()
            plt.plot(range(len(unilmloss_lst)), unilmloss_lst)
            plt.xlabel('epochs')
            plt.title('Unilm Train Loss')
            plt.savefig(args.output_dir + '/unilm_loss.png')
def _load_rank_model(args, device, n_gpu):
    """Restore the sentence ranker from ``args.ranker_recover_path`` for inference.

    The checkpoint is loaded on CPU into ``BertForSentenceRanker``, converted to
    fp16 when ``args.fp16`` is set, moved to ``device``, wrapped in
    ``DataParallelImbalance`` when more than one GPU is visible, and put in eval
    mode.
    """
    logger.info("***** Recover rank model: %s *****", args.ranker_recover_path)
    rank_model_recover = torch.load(args.ranker_recover_path, map_location="cpu")
    rank_model = BertForSentenceRanker.from_pretrained(
        args.bert_model, state_dict=rank_model_recover, num_labels=2)
    del rank_model_recover
    if args.fp16:
        rank_model.half()
    rank_model.to(device)
    if n_gpu > 1:
        rank_model = DataParallelImbalance(rank_model)
    rank_model.eval()
    return rank_model


def _build_rank_pipeline(args, tokenizer, max_len, max_len_a):
    """Return the one-stage ``Preprocess4Seq2cls`` pipeline used by the ranker datasets.

    ``max_len`` is the ranker sequence length; ``max_len_a`` is the truncation
    budget for segment A (segment B is capped at 16).
    """
    max_pred = 16
    mask_prob = 0.7
    return [Preprocess4Seq2cls(
        max_pred, mask_prob, list(tokenizer.vocab.keys()),
        tokenizer.convert_tokens_to_ids, max_len,
        new_segment_ids=args.new_segment_ids,
        truncate_config={'max_len_a': max_len_a, 'max_len_b': 16,
                         'trunc_seg': 'a', 'always_truncate_tail': True},
        mask_source_words=False, skipgram_prb=0.0, skipgram_size=1,
        mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0,
        s2s_special_token=False, s2s_add_segment=False,
        s2s_share_segment=False, pos_shift=False, eval=True)]


def _run_ranker(rank_model, eval_dataset, args, device, reduce_logits):
    """Score every example of ``eval_dataset`` (in order) with ``rank_model``.

    ``reduce_logits`` maps one batch of logits to a 1-D tensor with one value
    per example. Returns a flat list of those values. Inference runs under
    ``torch.no_grad()``.
    """
    eval_sampler = SequentialSampler(eval_dataset)
    _batch_size = args.ranker_batch_size
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=_batch_size, sampler=eval_sampler,
        num_workers=24, collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
        pin_memory=False)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()
    logger.info("***** Runinning ranker *****")
    logger.info(" Batch size = %d", _batch_size)
    logger.info(" Num steps = %d", int(len(eval_dataset) / args.ranker_batch_size))

    results = []
    for batch in tqdm(eval_dataloader, desc="Iter: "):
        batch = [t.to(device) if t is not None else None for t in batch]
        input_ids, _, _, mask_qkv, _, _, _, _, task_idx = batch
        with torch.no_grad():
            logits = rank_model(input_ids, task_idx=task_idx, mask_qkv=mask_qkv)
        results.extend(reduce_logits(logits).detach().cpu().numpy())
    return results


def _hierarchical_input_lines(args, tokenizer, device, n_gpu):
    """Keep ranker-selected sentences per document, then join documents per cluster."""
    rank_model = _load_rank_model(args, device, n_gpu)
    print("Loading Rank Dataset from ", args.input_file)
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    eval_dataset = ScoreEvalDataset(
        args.input_file, None, args.ranker_batch_size, data_tokenizer,
        args.ranker_max_len,
        bi_uni_pipeline=_build_rank_pipeline(args, tokenizer, args.ranker_max_len, 64))
    num_rank_labels = 2
    sentence_labels = _run_ranker(
        rank_model, eval_dataset, args, device,
        lambda logits: torch.argmax(logits.view(-1, num_rank_labels), dim=-1))

    logger.info("**** Collect results ******")
    clu2doc_dict, doc2sent_dict, all_titles, all_sents = eval_dataset.get_maps()
    all_docs = []
    for i, doc in enumerate(doc2sent_dict):
        text = all_titles[i]
        for idx in doc2sent_dict[doc]:
            if sentence_labels[idx] == 1:  # class 1 = keep this sentence
                text += ". " + all_sents[idx]
        all_docs.append(text)
    # Separate documents with a space: plain concatenation fused the last word
    # of one document with the first word of the next.
    return [" ".join(all_docs[idx] for idx in clu2doc_dict[clu])
            for clu in tqdm(clu2doc_dict)]


def _title_first_rank_input_lines(args, tokenizer, device, n_gpu):
    """Build "title + ranked sentences" sources, packed to ``args.max_seq_length`` tokens."""
    rank_model = _load_rank_model(args, device, n_gpu)
    print("Loading Rank Dataset from ", args.input_file)
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    eval_dataset = EvalRankDataset(
        args.input_file, None, args.ranker_batch_size, data_tokenizer,
        args.max_seq_length,
        bi_uni_pipeline=_build_rank_pipeline(args, tokenizer, args.max_seq_length, 512))
    sent_scores = _run_ranker(rank_model, eval_dataset, args, device,
                              lambda logits: logits.view(-1))
    if sent_scores:
        print("test label results")
        print(sent_scores[0])

    logger.info("**** Collect results ******")
    clu2sent_dict, all_sents, all_titles = eval_dataset.get_maps()
    input_lines = []
    for clu_id in clu2sent_dict:
        text = all_titles[clu_id]
        # NOTE(review): sentences are taken in ascending ranker-score order (as
        # before); confirm that a lower score means "more relevant".
        ranked = sorted(((all_sents[idx], sent_scores[idx]) for idx in clu2sent_dict[clu_id]),
                        key=lambda pair: pair[1])
        # Greedily append sentences while the *cumulative* token count fits;
        # the previous code never updated the running length.
        used = len(tokenizer.tokenize(text))
        for sent, _ in ranked:
            sent_len = len(tokenizer.tokenize(sent))
            if used + sent_len > args.max_seq_length:
                break
            text += " " + sent
            used += sent_len
        input_lines.append(text)
    return input_lines


def _prepare_input_lines(args, tokenizer, device, n_gpu):
    """Return ``(input_lines, map_dict)`` of raw source strings for ``args.experiment``.

    ``map_dict`` is only produced by the "single" experiment and is ``None``
    otherwise.
    """
    experiment = args.experiment
    map_dict = None
    if experiment in ["full", "title", "title-l1"]:
        input_lines = EvalDataset(args.input_file, experiment).proc()
    elif experiment == "single":
        input_lines, map_dict = EvalDataset(args.input_file, experiment).proc()
    elif experiment in ("title-first", "segsep"):
        input_lines = EvalDataset(args.input_file, experiment, tokenizer,
                                  args.max_seq_length, args.max_seq_length).proc()
    elif experiment in ("heirachical", "hierachical", "hierarchical"):
        # Accept the spelling advertised by --help as well as the historical one.
        input_lines = _hierarchical_input_lines(args, tokenizer, device, n_gpu)
    elif experiment == "title-first-rank":
        input_lines = _title_first_rank_input_lines(args, tokenizer, device, n_gpu)
    else:
        logger.warning("Unknown --experiment %r: no input lines will be decoded.", experiment)
        input_lines = []
    return input_lines, map_dict


def _extract_top_k_sequences(scores, wids_list, ptrs, eos_id, top_k):
    """Recover the ``top_k`` best finished hypotheses from beam-search traces.

    ``scores[f][b]`` / ``wids_list[f][b]`` describe beam slot ``b`` at frame
    ``f``; ``ptrs[f][b]`` is the slot in frame ``f - 1`` it extends. Frames
    after the first frame whose slots are all ``eos_id`` are ignored. A
    hypothesis ends at an ``eos_id`` token or at any slot of the last valid
    frame, and is kept only when its score is negative. Returns token-id
    lists ordered by descending score.
    """
    last_frame_id = len(scores) - 1
    for frame_idx, frame_wids in enumerate(wids_list):
        if all(wid == eos_id for wid in frame_wids):
            last_frame_id = frame_idx
            break

    finished = []
    for fid in range(last_frame_id + 1):
        for slot, wid in enumerate(wids_list[fid]):
            if wid != eos_id and fid != last_frame_id:
                continue
            score = scores[fid][slot]
            if not score < 0:  # also drops NaN, like the original `s < 0` test
                continue
            seq = [wid]
            parent = slot
            for back_fid in range(fid, 0, -1):
                parent = ptrs[back_fid][parent]
                seq.append(wids_list[back_fid - 1][parent])
            seq.reverse()
            finished.append((seq, score))

    finished.sort(key=lambda item: item[1], reverse=True)
    return [seq for seq, _ in finished[:top_k]]


def _ids_to_sentence(tokenizer, w_ids):
    """Convert token ids to text, stopping at the first [SEP]/[PAD] token."""
    output_tokens = []
    for token in tokenizer.convert_ids_to_tokens(w_ids):
        if token in ("[SEP]", "[PAD]"):
            break
        output_tokens.append(token)
    return ' '.join(detokenize(output_tokens))


def _evaluate_predictions(output_lines, label_file, ks=(1, 5, 10)):
    """Print and return P@k / R@k / F1@k of ``output_lines`` against ``label_file``.

    ``label_file`` holds one tab-separated list of gold items per line, aligned
    with ``output_lines``. Pairs where either side is empty after truncation to
    ``k`` are skipped; when no pair qualifies the metrics are reported as 0.0.

    Raises:
        ValueError: if ``label_file`` does not exist.
    """
    if not os.path.exists(label_file):
        raise ValueError("Label file not exists")
    print("Loading label file from {}".format(label_file))
    with open(label_file, encoding="utf-8") as f:
        labels = [line.strip().split("\t") for line in tqdm(f.readlines())]

    results_dict = {}
    for k in ks:
        acc_cul = r_cul = f1_cul = 0
        cnt = 0
        for predict, true_label in zip(output_lines, labels):
            predict = predict[:k]
            true_label = true_label[:k]
            if len(predict) > 0 and len(true_label) > 0:
                acc = acc_score(predict, true_label)
                rec = recall_score(predict, true_label)
                acc_cul += acc
                r_cul += rec
                f1_cul += f1_score(acc, rec)
                cnt += 1
        if cnt == 0:
            logger.warning("No non-empty prediction/label pair for k=%d; reporting 0.", k)
        denom = cnt if cnt else 1
        results_dict["P@{}".format(k)] = acc_cul * 1.000 / denom
        results_dict["R@{}".format(k)] = r_cul * 1.000 / denom
        results_dict["F1@{}".format(k)] = f1_cul * 1.000 / denom
    print(results_dict)
    return results_dict


def main():
    """Decode with every checkpoint matching ``--model_recover_path`` and optionally evaluate.

    With ``--do_predict`` each matching checkpoint is loaded and decoded.
    Predictions (up to ``--top_k`` tab-separated candidates per line) are
    written to ``--output_file``, or to ``<checkpoint>.<split>``. With
    ``--do_evaluate`` the last predictions are scored against ``--label_file``.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                        "Sequences longer than this will be truncated, and sequences shorter \n"
                        "than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=1,
                        help="Top k for output")
    parser.add_argument('--top_kk', type=int, default=0,
                        help="Top k sample method for output")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")

    # evaluate parameters
    parser.add_argument('--do_predict', action='store_true', help="do_predict")
    parser.add_argument("--do_evaluate", action="store_true",
                        help="caculate the scores if have label file")
    parser.add_argument("--label_file", type=str, default="")
    parser.add_argument("--experiment", type=str, default="full",
                        help="full/title/title-l1/single/title-first/segsep/hierachical/title-first-rank")

    # ranker parameters
    parser.add_argument("--ranker_recover_path", type=str,
                        help="ranker model for extract sentence")
    parser.add_argument("--ranker_max_len", type=int, default=192,
                        help="max length of the ranker input")
    parser.add_argument("--ranker_batch_size", type=int, default=128)

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")
    if args.do_predict and not args.model_recover_path:
        raise ValueError("--model_recover_path is required with --do_predict.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="l2r"))
    if args.experiment == "segsep":
        # segsep replaces the mode-based pipelines with its own decoder.
        bi_uni_pipeline = [Preprocess4SegSepDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="s2s")]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    if args.experiment == "segsep":
        type_vocab_size = 11
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)

    output_lines = None  # set by the prediction loop; required by --do_evaluate
    if args.do_predict:
        for model_recover_path in glob.glob(args.model_recover_path.strip()):
            logger.info("***** Recover model: %s *****", model_recover_path)
            model_recover = torch.load(model_recover_path)
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model, state_dict=model_recover,
                num_labels=cls_num_labels, num_rel=pair_num_relation,
                type_vocab_size=type_vocab_size, task_idx=3,
                mask_word_id=mask_word_id, search_beam_size=args.beam_size,
                length_penalty=args.length_penalty, eos_id=eos_word_ids,
                forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
                forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
                min_len=args.min_len, mode=args.mode,
                max_position_embeddings=args.max_seq_length)
            del model_recover

            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)

            torch.cuda.empty_cache()
            model.eval()
            next_i = 0
            max_src_length = args.max_seq_length - 2 - args.max_tgt_length

            input_lines, map_dict = _prepare_input_lines(args, tokenizer, device, n_gpu)

            data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
            input_lines = [data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines]
            # Longest first, remembering the original position of each line.
            input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            total_batch = math.ceil(len(input_lines) / args.batch_size)

            with tqdm(total=total_batch) as pbar:
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += args.batch_size
                    max_a_len = max(len(x) for x in buf)
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                        batch = [t.to(device) for t in batch if t is not None]
                        input_ids, token_type_ids, position_ids, input_mask, task_idx = batch
                        traces = model(input_ids, token_type_ids, position_ids,
                                       input_mask, task_idx=task_idx)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()
                        for i in range(len(buf)):
                            if args.beam_size > 1:
                                candidates = _extract_top_k_sequences(
                                    traces['scores'][i], traces['wids'][i],
                                    traces['ptrs'][i], eos_word_ids, args.top_k)
                            else:
                                # Greedy decoding returns plain id sequences, not traces.
                                candidates = [output_ids[i]]
                            output_lines[buf_id[i]] = [
                                _ids_to_sentence(tokenizer, w_ids) for w_ids in candidates]
                            if args.need_score_traces:
                                score_trace_list[buf_id[i]] = {
                                    'scores': traces['scores'][i],
                                    'wids': traces['wids'][i],
                                    'ptrs': traces['ptrs'][i]}
                    pbar.update(1)

            # collect instances after split
            if args.experiment == "single":
                results = []
                for clu in map_dict:
                    record = []
                    for i in map_dict[clu]:
                        record.extend(output_lines[i])
                    results.append([x[0] for x in Counter(record).most_common(10)])
                output_lines = results

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path + '.' + args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write('\t'.join(l))
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump(
                        {"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)

    # Evaluate !
    if args.do_evaluate:
        if output_lines is None:
            raise ValueError(
                "--do_evaluate needs predictions: pass --do_predict with a "
                "--model_recover_path that matches at least one checkpoint.")
        _evaluate_predictions(output_lines, args.label_file)
def main():
    """Compute the masked-LM loss and perplexity of a recovered model on a dev set.

    The dev pair (``--dev_src_file``, ``--dev_tgt_file`` under ``--data_dir``)
    is loaded with ``C_Seq2SeqDataset``. The model is restored from
    ``--model_recover_path``. With ``--do_train`` or ``--do_eval``, one pass
    over the dev set is made and the mean loss and ``exp(mean loss)`` are
    logged. No parameter update is performed.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--dev_src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument("--dev_check_file", default=None, type=str,
                        help="The output style response/know data file name.")
    parser.add_argument("--dev_style_file", default=None, type=str,
                        help="The output style response/know data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu", default=0.5, type=float,
                        help="The Predicted Bleu for KS Predict ")
    parser.add_argument("--train_vae", action='store_true',
                        help="Whether to train vae.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help=
        "Proportion of training to perform linear learning rate warmup for. ")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not Path(args.model_recover_path).exists():
        raise ValueError("--model_recover_path doesn't exist")

    args.output_dir = args.output_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"), encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    logger.addHandler(console)

    # Close the options file deterministically (was an unclosed open()).
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]

    logger.info("Loading Dataset from {}".format(args.data_dir))
    fn_src = os.path.join(args.data_dir, args.dev_src_file)
    fn_tgt = os.path.join(args.data_dir, args.dev_tgt_file)
    dev_reddit_dataset = seq2seq_loader.C_Seq2SeqDataset(
        fn_src, fn_tgt, args.eval_batch_size, data_tokenizer,
        args.max_seq_length, file_oracle=None,
        bi_uni_pipeline=C_bi_uni_pipeline)
    if args.local_rank == -1:
        dev_reddit_sampler = RandomSampler(dev_reddit_dataset, replacement=False)
        _batch_size = args.eval_batch_size
    else:
        dev_reddit_sampler = DistributedSampler(dev_reddit_dataset)
        _batch_size = args.eval_batch_size // dist.get_world_size()
    dev_reddit_dataloader = torch.utils.data.DataLoader(
        dev_reddit_dataset,
        batch_size=_batch_size,
        sampler=dev_reddit_sampler,
        num_workers=args.num_workers,
        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
        pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if args.model_recover_path:
        logger.info("***** Recover model: %s *****", args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    if args.fp16:
        try:
            from optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        # This script never steps the optimizer, so no LR schedule is needed.
        # The previous call referenced undefined `t_total` and
        # `args.warmup_proportion` and crashed; -1/-1 are BertAdam's
        # "no warmup schedule" values.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=-1,
                             t_total=-1)

    if args.optim_recover_path is not None:
        logger.info("***** Recover optimizer from : {} *****".format(args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if not (args.do_train or args.do_eval):
        return

    logger.info("***** Running training *****")
    logger.info(" Batch size = %d", args.train_batch_size)
    # Only the dev loss is computed below (no backward/step): disable dropout.
    model.eval()
    start_epoch = recover_step + 1 if recover_step else 1
    if start_epoch > int(args.num_train_epochs):
        # Matches the former empty epoch range: nothing to evaluate.
        return
    # The former epoch loop always broke after its first iteration.
    i_epoch = start_epoch

    if args.has_sentence_oracle:
        # The oracle batch layout lacks tgt_pos/labels/ks_labels/style/check
        # tensors that the forward call below requires.
        raise ValueError("--has_sentence_oracle is not supported by this dev evaluation.")

    logger.info("***** Running QKR evaling *****")
    logger.info(" Batch size = %d", args.eval_batch_size)
    if args.local_rank != -1:
        dev_reddit_sampler.set_epoch(i_epoch)
    dev_iter_bar = tqdm(dev_reddit_dataloader, desc='Iter (loss=X.XXX)',
                        disable=args.local_rank not in (-1, 0))
    total_lm_loss = 0
    num_steps = 0
    for batch in dev_iter_bar:
        batch = [t.to(device) if t is not None else None for t in batch]
        (input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos,
         masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels,
         style_ids, style_labels, check_ids) = batch
        with torch.no_grad():
            loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                               masked_pos=masked_pos, masked_weights=masked_weights,
                               task_idx=task_idx, masked_pos_2=None,
                               masked_weights_2=None, masked_labels_2=None,
                               mask_qkv=mask_qkv, tgt_pos=tgt_pos, labels=labels,
                               ks_labels=ks_labels, train_vae=args.train_vae,
                               style_ids=style_ids, style_labels=style_labels,
                               check_ids=check_ids, pretrain=None)
        masked_lm_loss, _, _, _, _, _, _ = loss_tuple
        if n_gpu > 1:  # mean() to average on multi-gpu.
            masked_lm_loss = masked_lm_loss.mean()
        total_lm_loss += masked_lm_loss.item()
        num_steps += 1

    if num_steps == 0:
        logger.warning("Dev dataloader is empty; no loss to report.")
        return
    total_mean_lm_loss = total_lm_loss / num_steps
    print(total_mean_lm_loss)
    logger.info("** ** * Evaling mean loss ** ** * ")
    logger.info("In{}epoch,dev_lm_loss:{}".format(i_epoch, total_mean_lm_loss))
    logger.info("ppl:{}".format(np.exp(total_mean_lm_loss)))
    logger.info("******************************************* ")