Example #1
File: run_seq2seq.py  Project: yumoxu/ZRKGC
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    #Train File
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data src file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The input data tgt file name.")
    parser.add_argument("--check_file",
                        default=None,
                        type=str,
                        help="The input check knowledge data file name")

    #KS File
    parser.add_argument("--ks_src_file",
                        default=None,
                        type=str,
                        help="The input ks data src file name.")
    parser.add_argument("--ks_tgt_file",
                        default=None,
                        type=str,
                        help="The input ks data tgt file name.")

    parser.add_argument("--predict_input_file",
                        default=None,
                        type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file",
                        default=None,
                        type=str,
                        help="predict_output_file")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu",
                        default=0.2,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--train_avg_bpe_length",
                        default=25,
                        type=int,
                        help="average bpe length for train.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help=
        "Proportion of training to perform linear learning rate warmup for. ")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=67,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # Random seed: make runs reproducible with the --seed argument
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    # Optionally set torch.backends.cudnn.deterministic = True for fully deterministic cuDNN kernels.

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    #Data process pipelines
    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    ks_predict_bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq_predict(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]

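    # Training uses two data streams: the QKR generation data (preprocessed by
    # C_bi_uni_pipeline) and the KS (knowledge selection) data (preprocessed by
    # bi_uni_pipeline); each gets its own dataset, sampler and dataloader below.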
    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)

        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src,
            fn_tgt,
            fn_check,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=C_bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src,
            ks_fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            ks_train_sampler = RandomSampler(ks_train_dataset,
                                             replacement=False)
            _batch_size = args.train_batch_size
        else:
            ks_train_sampler = DistributedSampler(ks_train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        ks_train_dataloader = torch.utils.data.DataLoader(
            ks_train_dataset,
            batch_size=_batch_size,
            sampler=ks_train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + (
        1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    #Recover model
    if args.model_recover_path:
        logger.info(" ** ** * Recover model: %s ** ** * ",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')
        global_step = 0

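    # Look up the ids of the special tokens used by the seq2seq model: [MASK] for masked
    # positions, [SEP] as the end-of-sequence marker, and [S2S_SOS] as the start token.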
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)

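    # Clone the BERT embedding tables into the auxiliary tmp_bert_emb and mul_bert_emb modules;
    # wrapping the clones in nn.Parameter makes them separate parameters, no longer tied to the
    # encoder's own embedding weights.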
    model.tmp_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.tmp_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.tmp_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    model.mul_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.mul_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.mul_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             # --warmup_proportion_step is a step count; convert it to a proportion
                             warmup=args.warmup_proportion_step / t_total,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(
                " ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        max_F1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        start_epoch = 1

        for i_epoch in trange(start_epoch,
                              start_epoch + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)

            step = 0
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                # ################# E step + M step + Mutual Information Loss ############################ #
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, check_ids = batch
                oracle_pos, oracle_weights, oracle_labels = None, None, None

                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   tgt_pos=tgt_pos,
                                   labels=labels.half(),
                                   ks_labels=ks_labels,
                                   check_ids=check_ids)

                masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, predict_kl_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()

                loss = masked_lm_loss + next_sentence_loss + KL_loss + predict_kl_loss + Mutual_loss + Golden_loss
                logger.info("In{}step, masked_lm_loss:{}".format(
                    step, masked_lm_loss))
                logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("In{}step, Mutual_loss:{}".format(
                    step, Mutual_loss))
                logger.info("In{}step, Golden_loss:{}".format(
                    step, Golden_loss))
                logger.info("In{}step, predict_kl_loss:{}".format(
                    step, predict_kl_loss))

                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total,
                        args.warmup_proportion_step / t_total)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # ################# Knowledge Selection Loss ############################ #
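                # With probability 1/5, also take a knowledge-selection (KS) training step on ks_batch.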
                if random.randint(0, 4) == 0:
                    ks_batch = [
                        t.to(device) if t is not None else None
                        for t in ks_batch
                    ]

                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, _, labels, ks_labels = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_ks=True)

                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        ks_loss = ks_loss.mean()
                    loss = ks_loss

                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total,
                            args.warmup_proportion_step / t_total)
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()

                step += 1
                # ################# Eval Every 5000 Steps ############################ #
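                # Dev evaluation flow: score each candidate knowledge line, re-rank with
                # knowledge_selection, decode responses for the re-ranked input, then compute F1
                # against the golden responses and stop early if it drops below the best so far.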
                if (global_step + 1) % 5000 == 0:
                    next_i = 0
                    model.eval()

                    # Know Rank Stage
                    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
                    with open(os.path.join(args.data_dir,
                                           args.predict_input_file),
                              "r",
                              encoding="utf-8") as file:
                        src_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           "train_tgt_pad.empty"),
                              "r",
                              encoding="utf-8") as file:
                        tgt_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           args.predict_output_file),
                              "w",
                              encoding="utf-8") as out:
                        while next_i < len(src_file):
                            batch_src = src_file[next_i:next_i +
                                                 args.eval_batch_size]
                            batch_tgt = tgt_file[next_i:next_i +
                                                 args.eval_batch_size]

                            next_i += args.eval_batch_size

                            ex_list = []
                            for src, tgt in zip(batch_src, batch_tgt):
                                src_tk = data_tokenizer.tokenize(src.strip())
                                tgt_tk = data_tokenizer.tokenize(tgt.strip())
                                ex_list.append((src_tk, tgt_tk))

                            batch = []
                            for idx in range(len(ex_list)):
                                instance = ex_list[idx]
                                for proc in ks_predict_bi_uni_pipeline:
                                    instance = proc(instance)
                                    batch.append(instance)

                            batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                                batch)
                            batch = [
                                t.to(device) if t is not None else None
                                for t in batch_tensor
                            ]

                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                            predict_bleu = args.predict_bleu * torch.ones(
                                [input_ids.shape[0]], device=input_ids.device)
                            oracle_pos, oracle_weights, oracle_labels = None, None, None
                            with torch.no_grad():
                                logits = model(input_ids,
                                               segment_ids,
                                               input_mask,
                                               lm_label_ids,
                                               is_next,
                                               masked_pos=masked_pos,
                                               masked_weights=masked_weights,
                                               task_idx=task_idx,
                                               masked_pos_2=oracle_pos,
                                               masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels,
                                               mask_qkv=mask_qkv,
                                               labels=predict_bleu,
                                               train_ks=True)

                                logits = torch.nn.functional.softmax(logits,
                                                                     dim=1)
                                labels = logits[:, 1].cpu().numpy()
                                for i in range(len(labels)):
                                    line = batch_src[i].strip()
                                    line += "\t"
                                    line += str(labels[i])
                                    out.write(line)
                                    out.write("\n")

                    data_path = os.path.join(args.data_dir,
                                             "qkr_dev.ks_score.tk")
                    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
                    src_out_path = os.path.join(args.data_dir,
                                                "rank_qkr_dev.src.tk")
                    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")

                    knowledge_selection(data_path, src_path, src_out_path)
                    logger.info(" ** ** * DEV Know Selection End ** ** * ")

                    # Decode Stage
                    logger.info(" ** ** * Dev Decode Begin ** ** * ")
                    with open(src_out_path, encoding="utf-8") as file:
                        dev_src_lines = file.readlines()
                    with open(tgt_path, encoding="utf-8") as file:
                        golden_response_lines = file.readlines()

                    decode_result = decode_batch(model, dev_src_lines)
                    logger.info(" ** ** * Dev Decode End ** ** * ")

                    # Compute dev F1
                    assert len(decode_result) == len(golden_response_lines)
                    C_F1 = f_one(decode_result, golden_response_lines)[0]
                    logger.info(
                        "** ** * Current F1 is {} ** ** * ".format(C_F1))
                    if C_F1 < max_F1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. So Stop Training ** ** * "
                        )
                        logger.info(
                            "** ** * The best model is {} ** ** * ".format(
                                best_step))
                        break
                    else:
                        max_F1 = C_F1
                        best_step = step
                        logger.info(
                            "** ** * Current F1 is larger than Previous F1. So Continue Training ** ** * "
                        )

                    # Save trained model
                    if (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer ** ** * "
                        )
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model itself
                        output_model_file = os.path.join(
                            args.output_dir,
                            "model.{}_{}.bin".format(i_epoch, global_step))
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)

                        #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
                        torch.cuda.empty_cache()

    # ################# Predict ############################ #
    if args.do_predict:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]

        next_i = 0
        model.eval()

        with open(os.path.join(args.data_dir, args.predict_input_file),
                  "r",
                  encoding="utf-8") as file:
            src_file = file.readlines()
        with open("train_tgt_pad.empty", "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file),
                  "w",
                  encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            for next_i in tqdm(
                    range(len(src_file) // args.eval_batch_size + 1)):
                #while next_i < len(src_file):
                batch_src = src_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                batch_tgt = tgt_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                #next_i += args.eval_batch_size

                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))

                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                        batch.append(instance)

                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                    batch)
                batch = [
                    t.to(device) if t is not None else None
                    for t in batch_tensor
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                predict_bleu = args.predict_bleu * torch.ones(
                    [input_ids.shape[0]], device=input_ids.device)
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   labels=predict_bleu,
                                   train_ks=True)

                    logits = torch.nn.functional.softmax(logits, dim=1)
                    labels = logits[:, 1].cpu().numpy()
                    for i in range(len(labels)):
                        line = batch_src[i].strip()
                        line += "\t"
                        line += str(labels[i])
                        out.write(line)
                        out.write("\n")
Example #2
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
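    # The loader batch is shrunk here; with gradient accumulation the effective batch
    # per optimizer update still equals the originally requested --train_batch_size.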

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
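        # Preprocess4Seq2seq converts each tokenized (src, tgt) pair into model tensors:
        # truncation according to truncate_config, random masking governed by
        # mask_prob / max_pred (optionally masking source words), and conversion to
        # input ids, segment ids and attention masks.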
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        print(len(train_dataset.ex_list))
        print(train_dataset.batch_size)

        # assert 1==0

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)
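        # batch_list_to_batch_tensors is the collate_fn that converts the per-example
        # lists produced by the pipeline into batched tensors.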

        # c = 0
        # for i_epoch in trange(0, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)):
        #     if args.local_rank != -1:
        #         train_sampler.set_epoch(i_epoch)
        #     iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
        #                     disable=args.local_rank not in (-1, 0))
        #     for step, batch in enumerate(iter_bar):
        #         batch = [
        #             t.to(device) if t is not None else None for t in batch]
        #         if args.has_sentence_oracle:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        #         else:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
        #             oracle_pos, oracle_weights, oracle_labels = None, None, None
        #         c += input_ids.shape[0]
        #
        #         # print(input_ids)
        #         #
        #         # print(input_ids.shape)
        #         # print(segment_ids)
        #         # print(segment_ids.shape)
        #         # print(is_next)
        #         # print(task_idx)
        #         # print(sop_label)
        #         # print(task_idx.shape)
        #         # for i in range(input_mask.shape[0]):
        #         #     print(input_mask[i])
        #     print(c)
        #     print(train_dataset.c)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)
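    # t_total is the total number of optimizer updates for the run
    # (batches per epoch * epochs / accumulation steps) and sizes the warmup schedule.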

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
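    # With --new_segment_ids the model uses 6 token-type ids (7 when --s2s_add_segment
    # is set); otherwise it falls back to the standard 2-segment BERT embedding table.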
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
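    # Standard BERT fine-tuning recipe: 0.01 weight decay on all parameters except
    # biases and LayerNorm weights, which are excluded from decay.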
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
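    # BertAdam applies the linear warmup/decay schedule internally (warmup, t_total);
    # on the fp16 FusedAdam path the learning rate is instead set manually at each
    # accumulation step inside the training loop below.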

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
                    print(sop_label)
                    print(task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                #                    masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                #                    masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                #                    masked_labels_2=oracle_labels, mask_qkv=mask_qkv)
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                print('mask_lm_loss {}'.format(masked_lm_loss))
                print('next_sentence_loss {}'.format(next_sentence_loss))
                print('----------------------------------------------')
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
Example #3
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default="../../../dataset/final_data/commongen",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bart_model",
        default="facebook/bart-large",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default="../../../output/train_kgbart",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default="../../../log/train_kgbart",
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=True,
                        type=bool,
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=True,
                        type=bool,
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        type=bool,
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=48,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0.1,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=10.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=6,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=True,
        type=bool,
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        default=True,
                        type=bool,
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=64,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=64,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        default=True,
        type=bool,
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.70,
        type=float,
        help=
        "Probability of masking a token for prediction; the number of masked positions can be less than max_pred when the sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Masking probability for the EOS token; the number of masked positions can be less than max_pred when the sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=30,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=5,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument('--pretraining_KG',
                        action='store_true',
                        help="Whether to pretrain on the knowledge graph.")
    parser.add_argument('--train_pretraining_num',
                        type=int,
                        default=20,
                        help="Number of training samples for KG pretraining.")
    parser.add_argument('--val_pretraining_num',
                        type=int,
                        default=20,
                        help="Number of validation samples for KG pretraining.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=64,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        default=False,
        type=bool,
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        default=False,
                        type=bool,
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        default=False,
        type=bool,
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--keep_last_epochs",
                        default=5,
                        type=int,
                        help="Keep the last few epochs.")

    args = parser.parse_args()

    # assert Path(args.model_recover_path).exists(
    # ), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BartTokenizer.from_pretrained(args.bart_model,
                                              do_lower_case=args.do_lower_case)
    # if args.max_position_embeddings:
    #     tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = tokenizer  # WhitespaceTokenizer() if args.tokenized_input else
    if args.local_rank == 0:
        dist.barrier()

    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Pretrain(
            list(tokenizer.encoder.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mode="s2s",
            pretraining_KG=args.pretraining_KG,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    file_oracle = None
    # KG resource files are fixed paths under data_dir (independent of --src_file/--tgt_file,
    # which name the source and target sentence files)
    entity_id = os.path.join(args.data_dir,
                             'CommonGen_KG/commongen_entity2id.txt')
    entity_embedding_path = os.path.join(args.data_dir,
                                         'CommonGen_KG/commongen_ent_embeddings')

    relation_id = os.path.join(args.data_dir,
                               'CommonGen_KG/commongen_relation2id.txt')
    relation_embedding_path = os.path.join(args.data_dir,
                                           'CommonGen_KG/commongen_rel_embeddings')

    entity_embedding = np.array(pickle.load(open(entity_embedding_path, "rb")))
    entity_embedding = np.array(
        list(np.zeros((4, 1024))) + list(entity_embedding))
    relation_embedding = np.array(
        pickle.load(open(relation_embedding_path, "rb")))
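    # Four all-zero 1024-d rows are prepended to the entity embedding matrix,
    # presumably reserving the lowest entity ids (padding / special entities).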

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir,
                                       'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(
            args.data_dir,
            args.src_file if args.src_file else 'commongen.train.src_new.txt')
        fn_tgt = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.train.tgt.txt')
        # one-hop KG neighbours file (fixed path under data_dir)
        fn_onehop = os.path.join(args.data_dir,
                                 'commongen.train.onehop_5.txt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            entity_id,
            relation_id,
            fn_onehop,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            pretraining_KG=file_oracle,
            pretraining_num=args.train_pretraining_num,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)
    if args.do_eval:
        print("Loading Dev Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir,
                                       'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(
            args.data_dir,
            args.src_file if args.src_file else 'commongen.dev.src_new.txt')
        fn_tgt = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.dev.tgt.txt')
        # one-hop KG neighbours file (fixed path under data_dir)
        fn_onehop = os.path.join(args.data_dir,
                                 'commongen.dev.onehop_5.txt')
        dev_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            entity_id,
            relation_id,
            fn_onehop,
            args.eval_batch_size,
            data_tokenizer,
            args.max_seq_length,
            pretraining_KG=file_oracle,
            pretraining_num=args.val_pretraining_num,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            dev_sampler = RandomSampler(dev_dataset, replacement=False)
            _batch_size = args.eval_batch_size
        else:
            dev_sampler = DistributedSampler(dev_dataset)
            _batch_size = args.eval_batch_size // dist.get_world_size()
        dev_dataloader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=_batch_size,
            sampler=dev_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
                      (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.pretraining_KG else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model,
            entity_weight=entity_embedding,
            relation_weight=relation_embedding)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model,
            state_dict=model_recover,
            entity_weight=entity_embedding,
            relation_weight=relation_embedding)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
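    # AdamW is paired with a linear schedule: the learning rate rises for warmup_steps
    # optimizer updates and then decays linearly to zero by t_total updates.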

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        schedule_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)),
                                      map_location='cpu')
        scheduler.load_state_dict(schedule_recover)

        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
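        # Note: apex's documentation recommends calling amp.initialize before wrapping
        # the model in DistributedDataParallel; here the order is reversed, which may
        # need adjusting for multi-GPU fp16 runs.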

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    best_dev_loss = 1000

    output_eval_file = os.path.join(args.log_dir, "eval_results.txt")
    writer = open(output_eval_file, "w")

    def checkpoint_paths(path, pattern=r"model(\d+)\.pt"):
        """Retrieves all checkpoints found in `path` directory.

        Checkpoints are identified by matching filename to the specified pattern. If
        the pattern contains groups, the result will be sorted by the first group in
        descending order.
        """
        pt_regexp = re.compile(pattern)
        files = os.listdir(path)

        entries = []
        for i, f in enumerate(files):
            m = pt_regexp.fullmatch(f)
            if m is not None:
                idx = float(m.group(1)) if len(m.groups()) > 0 else int(
                    f.split(".")[1])
                entries.append((idx, m.group(0)))
        return [
            os.path.join(path, x[1]) for x in sorted(entries, reverse=True)
        ]

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            model.train()
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            # iter_bar = tqdm(BackgroundGenerator(train_dataloader), desc='Iter (loss=X.XXX)',
            #                 disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Training",
                         position=0,
                         leave=True)):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                # the batch layout is the same with or without KG pretraining
                input_ids, input_entity_ids, subword_mask, word_mask, word_subword, decoder_input_ids, decoder_attention_mask, labels = batch

                loss_output = model(
                    input_ids,
                    input_entity_ids=input_entity_ids,
                    attention_mask=subword_mask,
                    word_mask=word_mask,
                    word_subword=word_subword,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    labels=labels,
                    label_smoothing=False)

                masked_lm_loss = loss_output.loss
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                # iter_bar.set_description('Iter %d (loss=%5.3f)' % (i_epoch, loss.item()))

                if step % 1000 == 0:
                    print('Iter %d  (Gen_loss=%5.3f)' % (i_epoch, loss.item()))

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.do_eval and (args.local_rank == -1
                                 or torch.distributed.get_rank() == 0):
                model.eval()
                cur_dev_loss = []
                with torch.no_grad():
                    for step, batch in enumerate(
                            tqdm(dev_dataloader,
                                 desc="Evaluating",
                                 position=0,
                                 leave=True)):
                        batch = [
                            t.to(device) if t is not None else None
                            for t in batch
                        ]
                        # the batch layout is the same with or without KG pretraining
                        input_ids, input_entity_ids, subword_mask, word_mask, word_subword, decoder_input_ids, decoder_attention_mask, labels = batch

                        loss_output = model(
                            input_ids,
                            input_entity_ids=input_entity_ids,
                            attention_mask=subword_mask,
                            word_mask=word_mask,
                            word_subword=word_subword,
                            decoder_input_ids=decoder_input_ids,
                            decoder_attention_mask=decoder_attention_mask,
                            labels=labels,
                            label_smoothing=False)

                        masked_lm_loss = loss_output.loss

                        if n_gpu > 1:  # mean() to average on multi-gpu.
                            # loss = loss.mean()
                            masked_lm_loss = masked_lm_loss.mean()
                            # next_sentence_loss = next_sentence_loss.mean()
                        loss = masked_lm_loss
                        cur_dev_loss.append(float(loss.item()))
                        # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                    dev_loss = sum(cur_dev_loss) / float(len(cur_dev_loss))
                    print("the epoch {} DEV loss is {}".format(
                        i_epoch, dev_loss))
                    if best_dev_loss > dev_loss:
                        best_dev_loss = dev_loss
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model itself
                        os.makedirs(args.output_dir + "/best_model",
                                    exist_ok=True)
                        output_model_file = os.path.join(
                            args.output_dir, "best_model/model.best.bin")
                        # output_optim_file = os.path.join(
                        #     args.output_dir, "best_model/optim.best.bin")
                        # output_schedule_file = os.path.join(
                        #     args.output_dir, "best_model/sched.best.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        # torch.save(optimizer.state_dict(), output_optim_file)
                        # torch.save(scheduler.state_dict(), output_schedule_file)

                    logger.info(
                        "** ** * Saving fine-tuned model and optimizer ** ** * "
                    )
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model itself
                    output_model_file = os.path.join(
                        args.output_dir, "model.{0}.bin".format(i_epoch))
                    output_optim_file = os.path.join(
                        args.output_dir, "optim.{0}.bin".format(i_epoch))
                    output_schedule_file = os.path.join(
                        args.output_dir, "sched.{0}.bin".format(i_epoch))
                    torch.save(model_to_save.state_dict(), output_model_file)
                    torch.save(optimizer.state_dict(), output_optim_file)
                    torch.save(scheduler.state_dict(), output_schedule_file)

                    writer.write("epoch " + str(i_epoch) + "\n")
                    writer.write("the current eval accuracy is: " +
                                 str(dev_loss) + "\n")

                    logger.info("***** CUDA.empty_cache() *****")
                    torch.cuda.empty_cache()

                    if args.keep_last_epochs > 0:
                        # remove old epoch checkpoints; checkpoints are sorted in descending order
                        checkpoints = checkpoint_paths(
                            args.output_dir, pattern=r"model.\d+.bin")
                        for old_chk in checkpoints[args.keep_last_epochs:]:
                            if os.path.lexists(old_chk):
                                os.remove(old_chk)

                        checkpoints = checkpoint_paths(
                            args.output_dir, pattern=r"optim.\d+.bin")
                        for old_chk in checkpoints[args.keep_last_epochs:]:
                            if os.path.lexists(old_chk):
                                os.remove(old_chk)

                        checkpoints = checkpoint_paths(
                            args.output_dir, pattern=r"sched.\d+.bin")
                        for old_chk in checkpoints[args.keep_last_epochs:]:
                            if os.path.lexists(old_chk):
                                os.remove(old_chk)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
Example #4
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1: idea1, 1.1: idea1 without theta, 2: idea2")
    parser.add_argument('--topic_model',
                        default=False,
                        type=bool,
                        help="Whether to use only the topic model.")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=192,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--train_batch_size",
        default=32,
        type=int,
        help="Total batch size for training.")  #batch_size = batch_size/n_gpus
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Probability of masking a token; the number of predictions can be less than max_pred when the sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "The number of predictions is sometimes less than max_pred when the sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)
    print("args.local_rank", args.local_rank)
    print("args.no_cuda", args.no_cuda)
    if args.local_rank == -1 or args.no_cuda:  # default: local_rank == -1 and no_cuda == False
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")  #device = cuda
        n_gpu = torch.cuda.device_count()
        print("n_gpu_1", n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
        print("n_gpu_1", n_gpu)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()
    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.data_dir,
            args.topic_model_dict_path,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
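            # With DistributedSampler each process sees a 1/world_size shard of the
            # data, so the per-process batch size is scaled down to keep the effective
            # global batch size equal to --train_batch_size.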
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)
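    # t_total estimates the total number of optimizer updates over the run
    # (mini-batches per epoch * epochs / gradient accumulation steps); it drives
    # the linear warmup/decay schedule of the learning rate further below.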
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2    ### type_vocab_size=6
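    # With --new_segment_ids the model uses an enlarged segment-type vocabulary
    # (6 ids, plus one more when --s2s_add_segment is set) instead of BERT's
    # default of 2, so the different LM modes get distinct segment embeddings.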
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
        global_step = 0

    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:  # the branch taken when starting from --model_recover_path
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
    # 1. Model initialization: the topic model's entry point is defined here
    gsm = GSM(train_dataset.vocabsize)
    gsm_checkpoint = torch.load(args.topic_model_recover_path)
    gsm.load_state_dict(gsm_checkpoint["net"])

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        unilm.half()
        gsm.half()
        if args.fp32_embedding:
            unilm.bert.embeddings.word_embeddings.float()
            unilm.bert.embeddings.position_embeddings.float()
            unilm.bert.embeddings.token_type_embeddings.float()
    unilm.to(device)
    gsm.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        unilm = DDP(unilm,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        unilm = DataParallelImbalance(unilm)
        gsm = DataParallelImbalance(gsm)
    # Prepare optimizer
    total = 0
    param_optimizer = list(unilm.named_parameters())
    param_optimizer_topic = list(gsm.named_parameters())
    for name, parameters in unilm.named_parameters():
        if "idea" in name:
            if "11" in name and "idea2" in name:
                total += np.prod(parameters.size())
                # print(name, ':', parameters.size())
        else:
            total += np.prod(parameters.size())
            # print(name, ':', parameters.size())

    print("gsm have {} paramerters in total".format(
        sum(x.numel() for x in gsm.parameters())))
    print("Number of parameter: %.6fM" % (total / 1e6))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if not args.topic_model:
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01,
            'topic':
            False
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0,
            'topic':
            False
        }, {
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    else:
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    # one parameter group uses weight decay, the other group has no weight decay
    # print("optimizer_grouped_parameters", optimizer_grouped_parameters)
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        unilm.train()
        gsm.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        print("000000", args.local_rank, start_epoch,
              int(args.num_train_epochs) + 1)
        topicloss = []
        unilmloss = []
        topicloss_lst = []
        unilmloss_lst = []
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):

            loss_sum = 0.0
            ppx_sum = 0.0
            word_count = 0.0
            doc_count = 0.0
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:  #false
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:  # bows (bag-of-words vectors) are added to the batch here
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows)
                if not args.topic_model:
                    loss_tuple = unilm(input_ids,
                                       theta,
                                       beta,
                                       topic_embedding,
                                       args.topic_mode,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv)
                    masked_lm_loss, next_sentence_loss = loss_tuple

                ## topic loss
                logsoftmax = torch.log(p_x + 1e-10)
                rec_loss = -1.0 * torch.sum(
                    bows * logsoftmax
                )  # bows * logsoftmax has shape [batch_size, |V|]; torch.sum adds the loss over every element (it could also be summed over a single dimension only).
                rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1)
                rec_loss_per = rec_loss_per.cpu().detach().numpy()

                kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) -
                                          log_vars.exp())
                loss_topic = rec_loss + kl_div
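                # VAE-style neural topic model objective: rec_loss is the negative
                # log-likelihood of the bag-of-words under the reconstruction p_x, and
                # kl_div is the closed-form KL divergence between the diagonal Gaussian
                # q(z|x) = N(mus, exp(log_vars)) and a standard normal prior,
                # i.e. -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).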
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    loss_topic = loss_topic.mean()
                    if not args.topic_model:
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                if not args.topic_model:
                    loss_unilm = masked_lm_loss + next_sentence_loss

                # calculate perplexity
                word_count_list = []
                loss_sum += loss_topic.item()
                for bow in bows:
                    word_num = torch.sum(bow).cpu().numpy()
                    word_count_list.append(word_num)
                    word_count += word_num

                word_count_np = np.array(word_count_list)
                doc_count += len(bows)
                ppx_sum += np.sum(np.true_divide(rec_loss_per, word_count_np))
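                # rec_loss_per / word_count is the per-document reconstruction NLL per
                # word; exponentiating its running average gives the document-level
                # perplexity printed at the end of the epoch, while ppx_word uses the
                # full topic loss divided by the total word count.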
                topicloss_lst.append(loss_topic.item() / len(bows))
                if not args.topic_model:
                    unilmloss_lst.append(loss_unilm.item())
                #topic_loss end
                if not args.topic_model:
                    loss = loss_unilm + loss_topic
                else:
                    loss = loss_topic
                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:  # =1
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
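                    # warmup_linear (presumably the pytorch_pretrained_bert-style
                    # schedule) ramps the learning rate up linearly for the first
                    # warmup_proportion of training and then decays it linearly, so
                    # lr_this_step follows a triangular schedule over t_total updates.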
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            if not param_group['topic']:
                                param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                if not args.topic_model:
                    iter_bar.set_description(
                        'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' %
                        (loss_unilm.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
                else:
                    iter_bar.set_description(
                        'Iter (loss_topic=%5.3f), (ppl=%5.3f)' %
                        (loss_topic.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
            # End of epoch: report topic-model perplexity, then save the trained models
            ppx_word = np.exp(loss_sum / word_count)
            ppx_document = np.exp(ppx_sum / doc_count)
            print("********")
            print("word_count", word_count)
            print("ppx_word", ppx_word)
            print("ppx_document", ppx_document)
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                #save unilm model
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                unilm_model_to_save = unilm.module if hasattr(
                    unilm, 'module') else unilm  # Only save the model itself
                output_unilm_model_file = os.path.join(
                    args.output_dir, "unilm.{0}.bin".format(i_epoch))
                torch.save(unilm_model_to_save.state_dict(),
                           output_unilm_model_file)
                #save topic model
                logger.info(
                    "** ** * Saving topic model and optimizer ** ** * ")
                topic_model_to_save = gsm.module if hasattr(
                    gsm, 'module') else gsm  # Only save the model itself
                output_topic_model_file = os.path.join(
                    args.output_dir, "topic.{0}.ckpt".format(i_epoch))
                torch.save(topic_model_to_save.state_dict(),
                           output_topic_model_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
        smth_pts = smooth_curve(topicloss_lst)
        # plt.plot(range(len(topicloss_lst)), topicloss_lst)
        plt.plot(range(len(smth_pts)), smth_pts)
        plt.xlabel('epochs')
        plt.title('Topic Model Train Loss')
        plt.savefig(args.output_dir + '/topic_loss.png')
        plt.cla()
        plt.plot(range(len(unilmloss_lst)), unilmloss_lst)
        plt.xlabel('epochs')
        plt.title('Unilm Train Loss')
        plt.savefig(args.output_dir + '/unilm_loss.png')
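
The plotting code above calls smooth_curve, which is not defined in this snippet. A minimal sketch of such a helper, assuming a simple exponential moving average over the per-step losses (the smoothing factor and exact behavior are assumptions, not taken from the original project):

def smooth_curve(points, factor=0.9):
    # Exponentially smooth a list of per-step loss values for plotting (assumed helper).
    smoothed = []
    for point in points:
        if smoothed:
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            smoothed.append(point)
    return smoothed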
Example #5
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=1,
                        help="Top k for output")
    parser.add_argument('--top_kk', type=int, default=0,
                        help="Top k sample method for output")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    
    # evaluate parameters
    parser.add_argument('--do_predict', action='store_true', help="do_predict")
    parser.add_argument("--do_evaluate", action="store_true", help="caculate the scores if have label file")
    parser.add_argument("--label_file", type=str, default="")
    parser.add_argument("--experiment", type=str, default="full", help="full/title/title-l1/hierachical/title-first/title-first-rouge")

    # ranker parameters
    parser.add_argument("--ranker_recover_path", type=str, help="ranker model for extract sentence")
    parser.add_argument("--ranker_max_len", type=int, default=192, help ="max length of the ranker input")
    parser.add_argument("--ranker_batch_size", type=int, default=128)

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="l2r"))
    
    if args.experiment == "segsep":
        bi_uni_pipeline = []
        bi_uni_pipeline.append(Preprocess4SegSepDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s"))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    if args.experiment == "segsep":
        type_vocab_size = 11
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
    if args.do_predict:
        for model_recover_path in glob.glob(args.model_recover_path.strip()):
            logger.info("***** Recover model: %s *****", model_recover_path)
            model_recover = torch.load(model_recover_path)
            model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=args.beam_size,
                                                        length_penalty=args.length_penalty, eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode, max_position_embeddings=args.max_seq_length)
            del model_recover

            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)

            torch.cuda.empty_cache()
            model.eval()
            next_i = 0
            max_src_length = args.max_seq_length - 2 - args.max_tgt_length

            if args.experiment in ["full", "title", "title-l1"]:
                input_lines = EvalDataset(args.input_file, args.experiment).proc()
            elif args.experiment == "single":
                input_lines, map_dict = EvalDataset(args.input_file, args.experiment).proc()
            elif args.experiment == "title-first":
                input_lines = EvalDataset(args.input_file, args.experiment, tokenizer, args.max_seq_length, args.max_seq_length).proc()
            elif args.experiment == "segsep":
                input_lines = EvalDataset(args.input_file, args.experiment, tokenizer, args.max_seq_length, args.max_seq_length).proc()
            elif args.experiment == "heirachical":
                logger.info("***** Recover rank model: %s *****", args.ranker_recover_path)
                # extract sentences before load data
                # load rank model
                rank_model_recover = torch.load(args.ranker_recover_path, map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(args.bert_model, state_dict=rank_model_recover, num_labels=2)
                
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                
                rank_model.to(device)

                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                
                DatasetFunc = ScoreEvalDataset
                
                # Load title + sentence pair
                print ("Loading Rank Dataset from ", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(max_pred, mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.ranker_max_len, new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': 64, 'max_len_b': 16, 'trunc_seg': 'a', 'always_truncate_tail': True}, mask_source_words=False, skipgram_prb=0.0, skipgram_size=1, mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                     fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer, args.ranker_max_len, bi_uni_pipeline=rank_bi_uni_pipeline
                )

                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size

                eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=_batch_size, sampler=eval_sampler, num_workers=24, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)


                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Runinning ranker *****")
                logger.info("   Batch size = %d", _batch_size)
                logger.info("   Num steps = %d", int(len(eval_dataset)/ args.ranker_batch_size))

                rank_model.to(device)
                rank_model.eval()

                iter_bar = tqdm(eval_dataloader, desc = "Iter: ")
                num_rank_labels = 2
                all_labels = []
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch
                    logits = rank_model(input_ids, task_idx=task_idx, mask_qkv=mask_qkv)
                    labels = torch.argmax(logits.view(-1, num_rank_labels), dim=-1)
                    all_labels.append(labels)
                
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                
                # collect results
                logger.info("**** Collect results ******")
                clu2doc_dict, doc2sent_dict, all_titles, all_sents = eval_dataset.get_maps()
                all_docs = []
                for i, doc in enumerate(doc2sent_dict):
                    text = all_titles[i]
                    sent_idx = doc2sent_dict[doc]
                    for idx in sent_idx:
                        if all_labels_results[idx] == 1:
                            text += ". " + all_sents[idx]
                    all_docs.append(text)
                
                input_lines = []
                for clu in tqdm(clu2doc_dict):
                    doc_idx = clu2doc_dict[clu]
                    input_line  = ""
                    for idx in doc_idx:
                        input_line += all_docs[idx]
                    input_lines.append(input_line)

            elif args.experiment == "title-first-rank":
                logger.info("***** Recover rank model: %s *****", args.ranker_recover_path)
                # extract sentences before load data
                # load rank model
                rank_model_recover = torch.load(args.ranker_recover_path, map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(args.bert_model, state_dict=rank_model_recover, num_labels=2)
                
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                
                rank_model.to(device)

                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                
                DatasetFunc = EvalRankDataset
                
                # Load title + sentence pair
                print ("Loading Rank Dataset from ", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(max_pred, mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': 512, 'max_len_b': 16, 'trunc_seg': 'a', 'always_truncate_tail': True}, mask_source_words=False, skipgram_prb=0.0, skipgram_size=1, mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                     fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer, args.max_seq_length, bi_uni_pipeline=rank_bi_uni_pipeline
                )

                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size

                eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=_batch_size, sampler=eval_sampler, num_workers=24, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Runinning ranker *****")
                logger.info("   Batch size = %d", _batch_size)
                logger.info("   Num steps = %d", int(len(eval_dataset)/ args.ranker_batch_size))

                rank_model.to(device)
                rank_model.eval()

                iter_bar = tqdm(eval_dataloader, desc = "Iter: ")
                num_rank_labels = 2
                all_labels = []
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch
                    # print("input_ids", len(input_ids[0]), "segment_ids", len(segment_ids[0]))
                    with torch.no_grad():
                        logits = rank_model(input_ids, task_idx=task_idx, mask_qkv=mask_qkv)
                    labels = logits.view(-1)
                    all_labels.append(labels)

                
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                
                print("test label results")
                print(all_labels_results[0])
                # collect results
                logger.info("**** Collect results ******")
                clu2sent_dict, all_sents, all_titles= eval_dataset.get_maps()
                all_clusters = []
                input_lines = []
                for i, clu_id in enumerate(clu2sent_dict):
                    text = all_titles[clu_id]
                    sent_idx = clu2sent_dict[clu_id]
                    sents_collect = []
                    for idx in sent_idx:
                        sents_collect.append([all_sents[idx], all_labels_results[idx]])
                    sents_collect_sort = sorted(sents_collect, key=lambda x:x[1])

                    sents_collect = [x[0] for x in sents_collect_sort]

                    text_tk = tokenizer.tokenize(text)
                    j = 0
                    while j < len(sents_collect) and len(text_tk) + len(tokenizer.tokenize(sents_collect[j])) <= args.max_seq_length:
                        text += " " + sents_collect[j]
                        j += 1
                    
                    input_lines.append(text)                
            else:
                input_lines = []
            
            data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
            input_lines = [data_tokenizer.tokenize(
                x)[:max_src_length] for x in input_lines]
            input_lines = sorted(list(enumerate(input_lines)),
                                key=lambda x: -len(x[1]))
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            total_batch = math.ceil(len(input_lines) / args.batch_size)

            with tqdm(total=total_batch) as pbar:
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += args.batch_size
                    max_a_len = max([len(x) for x in buf])
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(
                            instances)
                        # print("batch")
                        # print(batch)
                        # print(len(batch))
                        batch = [t.to(device) for t in batch if t is not None]
                        input_ids, token_type_ids, position_ids, input_mask, task_idx = batch
                        traces = model(input_ids, token_type_ids,
                                    position_ids, input_mask, task_idx=task_idx)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()
                        for i in range(len(buf)):
                            scores = traces['scores'][i]
                            wids_list = traces['wids'][i]
                            ptrs = traces['ptrs'][i]
                            eos_id = 102
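                            # At each decoding step 'wids' holds the token ids kept on
                            # the beam and 'ptrs' holds back-pointers into the previous
                            # step's beam, so a hypothesis is recovered by walking the
                            # pointers backwards. 102 is the [SEP] id in the standard
                            # BERT vocabulary, used here as the end-of-sequence marker.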
                            top_k = args.top_k
                            # first we need to find the eos frame where all symbols are eos
                            # any frames after the eos frame are invalid
                            last_frame_id = len(scores) - 1
                            for _i, wids in enumerate(wids_list):
                                if all(wid == eos_id for wid in wids):
                                    last_frame_id = _i
                                    break
                            frame_id = -1
                            pos_in_frame = -1
                            seqs = []
                            for fid in range(last_frame_id + 1):
                                for _i, wid in enumerate(wids_list[fid]):
                                    if wid == eos_id or fid == last_frame_id:
                                        s = scores[fid][_i]

                                        frame_id = fid
                                        pos_in_frame = _i

                                        if frame_id != -1 and s < 0:
                                            seq = [wids_list[frame_id][pos_in_frame]]
                                            for _fid in range(frame_id, 0, -1):
                                                pos_in_frame = ptrs[_fid][pos_in_frame]
                                                seq.append(wids_list[_fid - 1][pos_in_frame])
                                            seq.reverse()
                                            seqs.append([seq, s])
                            seqs = sorted(seqs, key= lambda x:x[1], reverse=True)
                            w_idss = [seq[0] for seq in seqs[:top_k]]
                            output_sequences = []
                            for w_ids in w_idss:
                                output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                                output_tokens = []
                                for t in output_buf:
                                    if t in ("[SEP]", "[PAD]"):
                                        break
                                    output_tokens.append(t)
                                output_sequence = ' '.join(detokenize(output_tokens))
                                output_sequences.append(output_sequence)
                            output_lines[buf_id[i]] = output_sequences
                            if args.need_score_traces:
                                score_trace_list[buf_id[i]] = {
                                    'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
                    pbar.update(1)
            # collect instances after split
            results = []
            if args.experiment == "single":
                for clu in map_dict:
                    record = []
                    clu_ixs = map_dict[clu]
                    for i in clu_ixs:
                        record.extend(output_lines[i])
                    record_top10 = Counter(record).most_common(10)
                    record_top10 = [x[0] for x in record_top10]
                    results.append(record_top10)

                output_lines = results

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path+'.'+args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write('\t'.join(l))
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump(
                        {"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)
        
        # Evaluate !
        if args.do_evaluate:
            labels = []
            if not os.path.exists(args.label_file):
                raise ValueError("Label file not exists")
            print("Loading label file from {}".format(args.label_file))
            with open(args.label_file) as f:
                for line in tqdm(f.readlines()):
                    line = line.strip().split("\t")
                    labels.append(line)
            results = output_lines

            ks = [1, 5, 10]
            results_dict = {}
            for k in ks:
                acc_cul = 0
                r_cul = 0
                f1_cul = 0
                cnt = 0
                for predict, true_label in zip(tqdm(results), tqdm(labels)):
                    predict = predict[:k]
                    true_label = true_label[:k]
                    if len(predict) > 0 and len(true_label) > 0:
                        acc_cul += acc_score(predict, true_label)
                        r_cul += recall_score(predict, true_label)
                        f1_cul += f1_score(acc_score(predict, true_label), recall_score(predict, true_label))
                        cnt += 1
                    
                results_dict["P@{}".format(k)] = acc_cul*1.000 / cnt
                results_dict["R@{}".format(k)] = r_cul*1.000 / cnt
                results_dict["F1@{}".format(k)] = f1_cul*1.000 / cnt
            
            print(results_dict)
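
The evaluation above relies on acc_score, recall_score, and f1_score, which are not shown in this example. A minimal sketch consistent with how they are called (set-style precision and recall over the top-k predictions); these definitions are an assumption, and the original project's implementations may differ:

def acc_score(predict, true_label):
    # Precision: fraction of predicted items that appear in the gold list (assumed helper).
    hits = sum(1 for p in predict if p in true_label)
    return hits / len(predict)


def recall_score(predict, true_label):
    # Recall: fraction of gold items that were predicted (assumed helper).
    hits = sum(1 for p in predict if p in true_label)
    return hits / len(true_label)


def f1_score(precision, recall):
    # Harmonic mean of precision and recall; zero when both are zero (assumed helper).
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)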
Example #6
0
File: PPL.py Project: temp1124/ZRKGC-1
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )

    parser.add_argument("--dev_src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument("--dev_check_file",
                        default=None,
                        type=str,
                        help="The output style response/know data file name.")
    parser.add_argument("--dev_style_file",
                        default=None,
                        type=str,
                        help="The output style response/know data file name.")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    parser.add_argument("--predict_bleu",
                        default=0.5,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    parser.add_argument("--train_vae",
                        action='store_true',
                        help="Whether to train vae.")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help="Number of steps of linear learning rate warmup to perform.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help="Masking probability; the number of predictions can be less than "
        "max_pred when the sequence is short.")
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help="Masking probability for the EOS token; the number of predictions "
        "can be less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
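    # The preprocessing pipeline above converts each (source, target) pair into
    # the per-example tensors that are unpacked from `batch` in the evaluation
    # loop below (input_ids, segment_ids, input_mask, lm_label_ids, masked
    # positions/weights, task indices, and the knowledge/style ids).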

    logger.info("Loading Dataset from {}".format(args.data_dir))

    fn_src = os.path.join(args.data_dir, args.dev_src_file)
    fn_tgt = os.path.join(args.data_dir, args.dev_tgt_file)
    dev_reddit_dataset = seq2seq_loader.C_Seq2SeqDataset(
        fn_src,
        fn_tgt,
        args.eval_batch_size,
        data_tokenizer,
        args.max_seq_length,
        file_oracle=None,
        bi_uni_pipeline=C_bi_uni_pipeline)
    if args.local_rank == -1:
        dev_reddit_sampler = RandomSampler(dev_reddit_dataset,
                                           replacement=False)
        _batch_size = args.eval_batch_size
    else:
        dev_reddit_sampler = DistributedSampler(dev_reddit_dataset)
        _batch_size = args.eval_batch_size // dist.get_world_size()
    dev_reddit_dataloader = torch.utils.data.DataLoader(
        dev_reddit_dataset,
        batch_size=_batch_size,
        sampler=dev_reddit_sampler,
        num_workers=args.num_workers,
        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
        pin_memory=False)
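    # In the distributed setting each process draws eval_batch_size // world_size
    # examples per step, so the global batch size stays at eval_batch_size.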

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    if args.model_recover_path:
        logger.info("***** Recover model: %s *****", args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb)

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
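        # --fp32_embedding keeps the embedding tables in fp32 for numerical
        # stability even when the rest of the model runs in half precision.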
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install a PyTorch build with distributed support to use DistributedDataParallel."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
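    # Standard BERT fine-tuning convention: bias and LayerNorm parameters are
    # excluded from weight decay; all other parameters use a decay of 0.01.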
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        # The optimizer is never stepped in this evaluation script, so no
        # warmup schedule is configured here.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate)

    if args.optim_recover_path is not None:
        logger.info("***** Recover optimizer from : {} *****".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:

        pretrain_step = -1
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                # dev_reddit_sampler is a DistributedSampler in this branch;
                # set_epoch keeps shuffling consistent across processes.
                dev_reddit_sampler.set_epoch(i_epoch)

            logger.info("***** Running QKR evaluation *****")
            logger.info("  Batch size = %d", args.eval_batch_size)
            dev_iter_bar = tqdm(dev_reddit_dataloader,
                                desc='Iter (loss=X.XXX)',
                                disable=args.local_rank not in (-1, 0))

            total_lm_loss = 0
            for qkr_dev_step, batch in enumerate(dev_iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, style_ids, style_labels, check_ids = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None

                with torch.no_grad():

                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       tgt_pos=tgt_pos,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_vae=args.train_vae,
                                       style_ids=style_ids,
                                       style_labels=style_labels,
                                       check_ids=check_ids,
                                       pretrain=None)

                    masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, cosine_similarity_loss, predict_kl_loss = loss_tuple
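                    # Only the masked LM loss is needed for perplexity; the
                    # auxiliary losses returned by the model are unused here.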
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        masked_lm_loss = masked_lm_loss.mean()

                    # accumulate the masked LM loss for this evaluation step
                    total_lm_loss += masked_lm_loss.item()

                    # running mean of the masked LM loss over the steps seen so far
                    total_mean_lm_loss = total_lm_loss / (qkr_dev_step + 1)
                    dev_iter_bar.set_description(
                        'Iter (loss=%5.3f)' % total_mean_lm_loss)

            logger.info("** ** * Evaling mean loss ** ** * ")
            logger.info("In{}epoch,dev_lm_loss:{}".format(
                i_epoch, total_mean_lm_loss))
            logger.info("ppl:{}".format(np.exp(total_mean_lm_loss)))
            logger.info("******************************************* ")
            break