Example #1
    def _preprocess_qa_data(tokenizer):
        # Labeled data: WebQA + SogouQA
        webqa_data = json.load(open(os.path.join(FLAGS.data_dir,
                                                 'WebQA.json')))
        sogou_data = json.load(
            open(os.path.join(FLAGS.data_dir, 'SogouQA.json')))
        train_data = webqa_data + sogou_data

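        # Build the seq2seq preprocessing pipeline: randomly mask up to FLAGS.max_pred
        # target positions with probability FLAGS.mask_prob, with optional n-gram /
        # whole-word masking, truncating inputs to FLAGS.max_seq_length.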
        bi_uni_pipeline = [
            utils_seq2seq.Preprocess4Seq2seq(
                FLAGS.max_pred,
                FLAGS.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                FLAGS.max_seq_length,
                mask_source_words=False,
                skipgram_prb=FLAGS.skipgram_prb,
                skipgram_size=FLAGS.skipgram_size,
                mask_whole_word=FLAGS.mask_whole_word,
                tokenizer=tokenizer)
        ]

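        # Wrap the merged examples in a Seq2SeqDataset; the pipeline above is applied as batches are drawn.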
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            file_data=train_data,
            batch_size=FLAGS.batch_size,
            tokenizer=tokenizer,
            max_len=FLAGS.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)

        return train_dataset
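
# Usage sketch (illustrative; not part of the original example). It assumes a
# BERT-style tokenizer object and the utils_seq2seq collate helper used by the
# training script in Example #2:
# tokenizer = tokenizer_class.from_pretrained(FLAGS.model_name_or_path)
# train_dataset = _preprocess_qa_data(tokenizer)
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=FLAGS.batch_size,
#     collate_fn=utils_seq2seq.batch_list_to_batch_tensors)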
Example #2
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default='/data/lq/tianchi/qg/model/unilm/data_file/', type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default='src_file/train_data.json', type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_file", default='dev_data.json', type=str, help="dev file.")
    parser.add_argument("--model_type", default='unilm', type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default='/data/lq/tianchi/qg/model/unilm/torch-model/', type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default='/data/lq/tianchi/qg/model/unilm/output_dir/', type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    parser.add_argument("--dev_batch_size", default=20, type=str, help="dev batch size.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=777,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.20, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")

    args = parser.parse_args()


    if not(args.model_recover_path and Path(args.model_recover_path).exists()):
        args.model_recover_path = None

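    # Allow a [PT_OUTPUT_DIR] placeholder in the output/log paths (substituted from the environment if set).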
    args.output_dir = args.output_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__, open(os.path.join(
        args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2)

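    # Device setup: single process (CPU/GPU, optionally DataParallel) vs. one GPU per process with the NCCL backend for distributed training.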
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

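    # Training data: each example is converted into masked seq2seq tensors by Preprocess4Seq2seq.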
    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            mask_source_words=False, skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word,
            tokenizer=data_tokenizer)]

        file = os.path.join(
            args.data_dir, args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            file, args.train_batch_size, data_tokenizer, args.max_seq_length, bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=_batch_size, sampler=train_sampler,
                                                       num_workers=args.num_workers, collate_fn=utils_seq2seq.batch_list_to_batch_tensors, pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size))
    t_total = int(len(train_dataloader) * args.num_train_epochs /
                  args.gradient_accumulation_steps)

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
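    # Resume either from the newest epoch checkpoint in output_dir (recover_step) or from an explicit --model_recover_path.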
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(
                args.model_recover_path, map_location='cpu')
    model = model_class.from_pretrained(
        args.model_name_or_path, state_dict=model_recover, config=config)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
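    # Apply weight decay to all parameters except biases and LayerNorm weights.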
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
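    # Linear warmup over the first warmup_proportion * t_total steps, then linear decay to zero.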
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[
                    args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)

        #logger.info("***** Recover amp: %d *****", recover_step)
        #amp_recover = torch.load(os.path.join(
        #    args.output_dir, "amp.{0}.bin".format(recover_step)), map_location='cpu')
        #amp.load_state_dict(amp_recover)

        logger.info("***** Recover scheduler: %d *****", recover_step)
        scheduler_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)), map_location='cpu')
        scheduler.load_state_dict(scheduler_recover)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    #####################
    ### Load the dev (validation) data ###
    #####################
    print("Loading Dev Dataset", args.data_dir)
    bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
                                                        tokenizer.convert_tokens_to_ids, args.max_seq_length,
                                                        mask_source_words=False, skipgram_prb=args.skipgram_prb,
                                                        skipgram_size=args.skipgram_size,
                                                        mask_whole_word=args.mask_whole_word, tokenizer=data_tokenizer)]
    file_dev = os.path.join(args.data_dir, args.dev_file if args.dev_file else 'train.tgt')
    dev_dataset = utils_seq2seq.Seq2SeqDataset(file_dev, args.dev_batch_size, data_tokenizer, args.max_seq_length,
                                               bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        dev_sampler = RandomSampler(dev_dataset, replacement=False)
        _batch_size = args.dev_batch_size
    else:
        dev_sampler = DistributedSampler(dev_dataset)
        _batch_size = args.dev_batch_size // dist.get_world_size()

    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=int(_batch_size), sampler=dev_sampler,
                                                 num_workers=args.num_workers,
                                                 collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
                                                 pin_memory=False)

    #####################
    ###### Training starts #######
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step+1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)):

            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)(lr=X.XXXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                       masked_pos=masked_pos, masked_weights=masked_weights)
                if n_gpu > 1:    # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss={:.2f})(lr={:0.2e})'.format(loss.item(), scheduler.get_lr()[0]))

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1

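                # Every 100 batches: save a checkpoint (rank 0 only) and report the loss on the dev set.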
                if step % 100 == 0:
                    # Save a trained model
                    if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer  :step{}** ** * ".format(step))
                        model_to_save = model.module if hasattr(
                            model, 'module') else model  # Only save the model itself
                        output_model_file = os.path.join(
                            args.output_dir, "model.{}.{}.bin".format(i_epoch,step))
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.{}.{}.bin".format(i_epoch,step))
                        torch.save(optimizer.state_dict(), output_optim_file)
                        output_sched_file = os.path.join(
                            args.output_dir, "sched.{}.{}.bin".format(i_epoch,step))
                        torch.save(scheduler.state_dict(), output_sched_file)

                        logger.info("***** CUDA.empty_cache() *****")
                        torch.cuda.empty_cache()

                    ###################################
                    # Dev-set evaluation for the checkpoint just saved at this step
                    ###################################

                    #dev_iter_bangbang1 = tqdm(dev_dataloader,disable=args.local_rank not in (-1, 0))
                    temp = []
                    model.eval()  # disable dropout for the dev pass
                    with torch.no_grad():
                        for dev_batch in dev_dataloader:
                            dev_batch = [
                                t.to(device) if t is not None else None for t in dev_batch]
                            input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = dev_batch
                            masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                                   masked_pos=masked_pos, masked_weights=masked_weights)
                            if n_gpu > 1:
                                masked_lm_loss = masked_lm_loss.mean()
                            temp.append(masked_lm_loss.cpu().detach().numpy())
                        # average over all dev batches (was divided by the last batch index)
                        dev_loss_fin = sum(temp) / len(temp)
                        print('####   Dev loss:', dev_loss_fin, output_model_file)
                    model.train()  # back to training mode

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16:
                    output_amp_file = os.path.join(
                        args.output_dir, "amp.{0}.bin".format(i_epoch))
                    torch.save(amp.state_dict(), output_amp_file)
                output_sched_file = os.path.join(
                    args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(scheduler.state_dict(), output_sched_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()


            #dev_iter_bangbang = tqdm(dev_dataloader, desc='Iter (loss=X.XXX)',
            #                     disable=args.local_rank not in (-1, 0))
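            # End-of-epoch evaluation on the dev set.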
            temp = []
            model.eval()  # disable dropout for the dev pass
            with torch.no_grad():
                for dev_batch in dev_dataloader:
                    dev_batch = [
                        t.to(device) if t is not None else None for t in dev_batch]
                    input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = dev_batch
                    masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                           masked_pos=masked_pos, masked_weights=masked_weights)
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        masked_lm_loss = masked_lm_loss.mean()
                    temp.append(masked_lm_loss.cpu().detach().numpy())
                # average over all dev batches (was divided by the last batch index)
                dev_loss_fin = sum(temp) / len(temp)
                print('####   Dev loss:', dev_loss_fin, output_model_file)
            model.train()  # back to training mode
Example #3
def main():
    my_parser = argparse.ArgumentParser()

    # Required parameters
    my_parser.add_argument("--model_type",
                           default=None,
                           type=str,
                           required=True,
                           help="Model type selected in the list: " +
                           ", ".join(MODEL_CLASSES.keys()))
    my_parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    my_parser.add_argument("--model_recover_path",
                           default=None,
                           type=str,
                           help="The file of fine-tuned pretraining model.")
    my_parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    my_parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    my_parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    # decoding parameters
    my_parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    my_parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    my_parser.add_argument("--input_file", type=str, help="Input file")
    my_parser.add_argument('--subset',
                           type=int,
                           default=0,
                           help="Decode a subset of the input dataset.")
    my_parser.add_argument("--output_file", type=str, help="output file")
    my_parser.add_argument("--split",
                           type=str,
                           default="",
                           help="Data split (train/val/test).")
    my_parser.add_argument('--tokenized_input',
                           action='store_true',
                           help="Whether the input is tokenized.")
    my_parser.add_argument('--seed',
                           type=int,
                           default=123,
                           help="random seed for initialization")
    my_parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    my_parser.add_argument('--batch_size',
                           type=int,
                           default=4,
                           help="Batch size for decoding.")
    my_parser.add_argument('--beam_size',
                           type=int,
                           default=1,
                           help="Beam size for searching")
    my_parser.add_argument('--length_penalty',
                           type=float,
                           default=0,
                           help="Length penalty for beam search")
    my_parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    my_parser.add_argument(
        '--forbid_ignore_word',
        type=str,
        default=None,
        help="Forbid the word during forbid_duplicate_ngrams")
    my_parser.add_argument("--min_len", default=None, type=int)
    my_parser.add_argument('--need_score_traces', action='store_true')
    my_parser.add_argument('--ngram_size', type=int, default=3)
    my_parser.add_argument('--max_tgt_length',
                           type=int,
                           default=69,
                           help="maximum length of target sequence")

    args = my_parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_seq_length)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case)
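    # If the input is already tokenized, split on whitespace instead of running the model's WordPiece tokenizer.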
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        utils_seq2seq.Preprocess4Seq2seqDecode(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            tokenizer=data_tokenizer))

    test_dataset = utils_seq2seq.Seq2SeqDataset(
        args.input_file,
        args.batch_size,
        data_tokenizer,
        args.max_seq_length,
        bi_uni_pipeline=bi_uni_pipeline)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
        pin_memory=False)
    # Prepare model
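    # Ids of the special tokens used by the decoder: [MASK] (generation slot), [SEP] (EOS), [S2S_SOS] (SOS).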
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
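    # args.model_recover_path may be a glob pattern; decode with every matching checkpoint.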
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        my_logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path, map_location=device)
        model = model_class.from_pretrained(
            args.model_name_or_path,
            state_dict=model_recover,
            config=config,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len)
        del model_recover

        model.to(device)

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model = amp.initialize(model, opt_level=args.fp16_opt_level)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()

        output_lines = []
        score_trace_list = []

        iter_bar = tqdm(test_dataloader)
        for step, batch in enumerate(iter_bar):
            with torch.no_grad():
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, segment_ids, answer_tag, position_ids, input_mask = batch
                traces = model(input_ids, segment_ids, answer_tag,
                               position_ids, input_mask)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces.tolist()
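                # Map predicted ids back to tokens, stopping at the first [SEP] or [PAD].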
                for i in range(len(input_ids)):
                    w_ids = output_ids[i]
                    output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in ("[SEP]", "[PAD]"):
                            break
                        output_tokens.append(t)
                    output_sequence = ''.join(detokenize(output_tokens))
                    output_lines.append(output_sequence)
                    if args.need_score_traces:
                        score_trace_list.append({
                            'scores': traces['scores'][i],
                            'wids': traces['wids'][i],
                            'ptrs': traces['ptrs'][i]
                        })

        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write(l)
                fout.write("\n")

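        # Optionally dump beam-search score traces (scores, word ids, back-pointers) for later inspection.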
        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(output_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)