Example #1
def decode_batch(model, big_batch_data):

    torch.cuda.empty_cache()
    model.eval()

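    # reserve 2 positions for special tokens and max_tgt_length positions
    # for the generated target when budgeting the source length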
    max_src_length = max_seq_length - 2 - max_tgt_length
    input_lines = [x.strip() for x in big_batch_data]


    all_input_lines = [data_tokenizer.tokenize(x)[:max_src_length]
                       for x in input_lines]

    total_length = len(input_lines)
    # number of batches, rounded up (ceil division without floats)
    total_iter = total_length // batch_size
    if total_iter * batch_size < total_length:
        total_iter += 1

    all_output_lines = []

    for cur_iter in range(total_iter):
        # take the next chunk, sorted by length (longest first) to minimize
        # padding; keep the original indices so output order can be restored
        chunk = all_input_lines[cur_iter * batch_size:(cur_iter + 1) * batch_size]
        chunk = sorted(enumerate(chunk), key=lambda x: -len(x[1]))
        output_lines = [""] * len(chunk)
        buf_id = [x[0] for x in chunk]
        buf = [x[1] for x in chunk]

        max_a_len = max([len(x) for x in buf])
        instances = []
        for instance in [(x, max_a_len) for x in buf]:
            for proc in bi_uni_pipeline:
                instances.append(proc(instance))
        with torch.no_grad():
            batch = seq2seq_loader.batch_list_to_batch_tensors(
                instances)
            batch = [
                t.to(device) if t is not None else None for t in batch]
            input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
            traces = model(input_ids=input_ids, token_type_ids=token_type_ids,
                           position_ids=position_ids, attention_mask=input_mask,
                           task_idx=task_idx, mask_qkv=mask_qkv, decode=True)
            if beam_size > 1:
                traces = {k: v.tolist() for k, v in traces.items()}
                output_ids = traces['pred_seq']
            else:
                output_ids = traces.tolist()
            for i in range(len(buf)):
                w_ids = output_ids[i]
                output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                output_tokens = []
                for t in output_buf:
                    if t in ("[SEP]", "[PAD]"):
                        break
                    output_tokens.append(t)
                output_sequence = ' '.join(detokenize(output_tokens))
                output_lines[buf_id[i]] = output_sequence

        all_output_lines.extend(output_lines)

    assert len(all_output_lines) == len(all_input_lines)
    return all_output_lines
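
The function above relies on module-level globals (batch_size, beam_size, max_seq_length, max_tgt_length, tokenizer, data_tokenizer, bi_uni_pipeline, seq2seq_loader, device, detokenize) that the surrounding script is assumed to provide. Its core pattern is to sort the inputs by length, longest first, so each batch needs minimal padding, and then to restore the original order through the saved indices. A minimal, self-contained sketch of that sort-and-restore pattern, with a stand-in for the model call:

def sort_and_restore(lines, process_batch, batch_size=2):
    # sort by length, longest first, remembering original positions
    indexed = sorted(enumerate(lines), key=lambda x: -len(x[1]))
    outputs = [""] * len(lines)
    for start in range(0, len(indexed), batch_size):
        chunk = indexed[start:start + batch_size]
        buf_id = [i for i, _ in chunk]
        buf = [x for _, x in chunk]
        # write each result back to its original slot
        for i, result in zip(buf_id, process_batch(buf)):
            outputs[i] = result
    return outputs

# stand-in "model": upper-case each line
print(sort_and_restore(["bb", "a", "cccc"], lambda buf: [s.upper() for s in buf]))
# -> ['BB', 'A', 'CCCC']
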
Example #2
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    #Train File
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data src file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The input data tgt file name.")
    parser.add_argument("--check_file",
                        default=None,
                        type=str,
                        help="The input check knowledge data file name")

    #KS File
    parser.add_argument("--ks_src_file",
                        default=None,
                        type=str,
                        help="The input ks data src file name.")
    parser.add_argument("--ks_tgt_file",
                        default=None,
                        type=str,
                        help="The input ks data tgt file name.")

    parser.add_argument("--predict_input_file",
                        default=None,
                        type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file",
                        default=None,
                        type=str,
                        help="predict_output_file")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu",
                        default=0.2,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--train_avg_bpe_length",
                        default=25,
                        type=int,
                        help="average bpe length for train.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help="Number of steps to perform linear learning rate warmup for.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=67,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Probability of masking a token; the number of predictions can be less than max_pred when the sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Masking probability for the EOS token; the number of predictions can be less than max_pred when the sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # Random seed (args.seed previously had no effect here)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    #Data process pipelines
    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    ks_predict_bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq_predict(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]

    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)

        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src,
            fn_tgt,
            fn_check,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=C_bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src,
            ks_fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            ks_train_sampler = RandomSampler(ks_train_dataset,
                                             replacement=False)
            _batch_size = args.train_batch_size
        else:
            ks_train_sampler = DistributedSampler(ks_train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        ks_train_dataloader = torch.utils.data.DataLoader(
            ks_train_dataset,
            batch_size=_batch_size,
            sampler=ks_train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + (
        1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
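    # 6 segment-type ids in s2s mode with --new_segment_ids (7 with
    # --s2s_add_segment); otherwise the standard BERT value of 2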
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    #Recover model
    if args.model_recover_path:
        logger.info(" ** ** * Recover model: %s ** ** * ",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')
        global_step = 0

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)

    model.tmp_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.tmp_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.tmp_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    model.mul_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.mul_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.mul_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        # use the CLI flag instead of a hard-coded 0.01
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             # derive the warmup proportion from the step count
                             # (args.warmup_proportion is not defined)
                             warmup=args.warmup_proportion_step / t_total,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(
                " ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        max_F1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        start_epoch = 1

        # note: trange(start_epoch, start_epoch + 1) runs a single epoch;
        # training length is governed by the dev-F1 early-stopping check below
        for i_epoch in trange(start_epoch,
                              start_epoch + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)

            step = 0
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
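                # zip() stops at the shorter loader, so one epoch covers
                # min(len(train_dataloader), len(ks_train_dataloader)) batches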
                # ################# E step + M step + Mutual Information Loss ############################ #
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, check_ids = batch
                oracle_pos, oracle_weights, oracle_labels = None, None, None

                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   tgt_pos=tgt_pos,
                                   labels=labels.half(),
                                   ks_labels=ks_labels,
                                   check_ids=check_ids)

                masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, predict_kl_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()

                loss = masked_lm_loss + next_sentence_loss + KL_loss + predict_kl_loss + Mutual_loss + Golden_loss
                logger.info("In{}step, masked_lm_loss:{}".format(
                    step, masked_lm_loss))
                logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("In{}step, Mutual_loss:{}".format(
                    step, Mutual_loss))
                logger.info("In{}step, Golden_loss:{}".format(
                    step, Golden_loss))
                logger.info("In{}step, predict_kl_loss:{}".format(
                    step, predict_kl_loss))

                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total,
                        args.warmup_proportion_step / t_total)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # ################# Knowledge Selection Loss ############################ #
                if random.randint(0, 4) == 0:
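                    # randint(0, 4) == 0 holds ~20% of the time, so the
                    # knowledge-selection update runs on roughly 1 in 5 steps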
                    ks_batch = [
                        t.to(device) if t is not None else None
                        for t in ks_batch
                    ]

                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, _, labels, ks_labels = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_ks=True)

                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        ks_loss = ks_loss.mean()
                    loss = ks_loss

                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total,
                            args.warmup_proportion_step / t_total)
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()

                step += 1
                # ################# Eval every 5000 steps ############################ #
                if (global_step + 1) % 5000 == 0:
                    next_i = 0
                    model.eval()

                    # Know Rank Stage
                    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
                    with open(os.path.join(args.data_dir,
                                           args.predict_input_file),
                              "r",
                              encoding="utf-8") as file:
                        src_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           "train_tgt_pad.empty"),
                              "r",
                              encoding="utf-8") as file:
                        tgt_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           args.predict_output_file),
                              "w",
                              encoding="utf-8") as out:
                        while next_i < len(src_file):
                            batch_src = src_file[next_i:next_i +
                                                 args.eval_batch_size]
                            batch_tgt = tgt_file[next_i:next_i +
                                                 args.eval_batch_size]

                            next_i += args.eval_batch_size

                            ex_list = []
                            for src, tgt in zip(batch_src, batch_tgt):
                                src_tk = data_tokenizer.tokenize(src.strip())
                                tgt_tk = data_tokenizer.tokenize(tgt.strip())
                                ex_list.append((src_tk, tgt_tk))

                            batch = []
                            for idx in range(len(ex_list)):
                                instance = ex_list[idx]
                                for proc in ks_predict_bi_uni_pipeline:
                                    instance = proc(instance)
                                    batch.append(instance)

                            batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                                batch)
                            batch = [
                                t.to(device) if t is not None else None
                                for t in batch_tensor
                            ]

                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                            predict_bleu = args.predict_bleu * torch.ones(
                                [input_ids.shape[0]], device=input_ids.device)
                            oracle_pos, oracle_weights, oracle_labels = None, None, None
                            with torch.no_grad():
                                logits = model(input_ids,
                                               segment_ids,
                                               input_mask,
                                               lm_label_ids,
                                               is_next,
                                               masked_pos=masked_pos,
                                               masked_weights=masked_weights,
                                               task_idx=task_idx,
                                               masked_pos_2=oracle_pos,
                                               masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels,
                                               mask_qkv=mask_qkv,
                                               labels=predict_bleu,
                                               train_ks=True)

                                logits = torch.nn.functional.softmax(logits,
                                                                     dim=1)
                                labels = logits[:, 1].cpu().numpy()
                                for i in range(len(labels)):
                                    line = batch_src[i].strip()
                                    line += "\t"
                                    line += str(labels[i])
                                    out.write(line)
                                    out.write("\n")

                    data_path = os.path.join(args.data_dir,
                                             "qkr_dev.ks_score.tk")
                    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
                    src_out_path = os.path.join(args.data_dir,
                                                "rank_qkr_dev.src.tk")
                    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")

                    knowledge_selection(data_path, src_path, src_out_path)
                    logger.info(" ** ** * DEV Know Selection End ** ** * ")

                    # Decode Stage
                    logger.info(" ** ** * Dev Decode Begin ** ** * ")
                    with open(src_out_path, encoding="utf-8") as file:
                        dev_src_lines = file.readlines()
                    with open(tgt_path, encoding="utf-8") as file:
                        golden_response_lines = file.readlines()

                    decode_result = decode_batch(model, dev_src_lines)
                    logger.info(" ** ** * Dev Decode End ** ** * ")

                    # Compute dev F1
                    assert len(decode_result) == len(golden_response_lines)
                    C_F1 = f_one(decode_result, golden_response_lines)[0]
                    logger.info(
                        "** ** * Current F1 is {} ** ** * ".format(C_F1))
                    if C_F1 < max_F1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. So Stop Training ** ** * "
                        )
                        logger.info(
                            "** ** * The best model is {} ** ** * ".format(
                                best_step))
                        break
                    else:
                        max_F1 = C_F1
                        best_step = step
                        logger.info(
                            "** ** * Current F1 is not lower than the previous best. So Continue Training ** ** * "
                        )

                    # Save trained model
                    if (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer ** ** * "
                        )
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir,
                            "model.{}_{}.bin".format(i_epoch, global_step))
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)

                        torch.cuda.empty_cache()

                    # put the model back into training mode after the eval pass
                    # (model.eval() above would otherwise persist into training)
                    model.train()

    # ################# Predict ############################ #
    if args.do_predict:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]

        next_i = 0
        model.eval()

        with open(os.path.join(args.data_dir, args.predict_input_file),
                  "r",
                  encoding="utf-8") as file:
            src_file = file.readlines()
        with open("train_tgt_pad.empty", "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file),
                  "w",
                  encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            # ceil division: avoids an empty trailing batch when len(src_file)
            # is an exact multiple of eval_batch_size
            for next_i in tqdm(
                    range((len(src_file) + args.eval_batch_size - 1) //
                          args.eval_batch_size)):
                batch_src = src_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                batch_tgt = tgt_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]

                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))

                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                        batch.append(instance)

                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                    batch)
                batch = [
                    t.to(device) if t is not None else None
                    for t in batch_tensor
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                predict_bleu = args.predict_bleu * torch.ones(
                    [input_ids.shape[0]], device=input_ids.device)
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   labels=predict_bleu,
                                   train_ks=True)

                    logits = torch.nn.functional.softmax(logits, dim=1)
                    labels = logits[:, 1].cpu().numpy()
                    for i in range(len(labels)):
                        line = batch_src[i].strip()
                        line += "\t"
                        line += str(labels[i])
                        out.write(line)
                        out.write("\n")
Example #3
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    parser.add_argument("--train_vae",
                        action='store_true',
                        help="Whether to train vae.")

    parser.add_argument('--bleu', type=float, default=0.2,
                        help="BLEU value passed to the model during decoding.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=1, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=1)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + (
        1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

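    # e.g. "[sep]|the" -> ids for "[SEP]" and "the"; bracketed special
    # tokens are upper-cased before the vocabulary lookup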
    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
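            # single-process data parallelism across all visible GPUs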
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
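        # length budget: 2 positions for the special tokens, max_tgt_length
        # for the generated target, the remainder for the (truncated) source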
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]
        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        input_lines = [
            data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines
        ]
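        # sort sources longest-first so each batch pads to a similar length;
        # the original indices survive in buf_id to restore output order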
        input_lines = sorted(list(enumerate(input_lines)),
                             key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch
                    ]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids,
                                   token_type_ids,
                                   position_ids,
                                   input_mask,
                                   task_idx=task_idx,
                                   mask_qkv=mask_qkv,
                                   bleu=args.bleu)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_lines[buf_id[i]] = output_sequence
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                pbar.update(1)
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write(l)
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(input_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
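
The trace pickle written above is a stream: one header dict followed by one
record per sample, each written with a separate pickle.dump call. A minimal
reader (the function name is illustrative, not part of the script) mirrors
that order:

import pickle

def load_score_traces(path):
    # read the header, then one trace record per sample
    with open(path, "rb") as f:
        header = pickle.load(f)  # {"version": 0.0, "num_samples": N}
        traces = [pickle.load(f) for _ in range(header["num_samples"])]
    return header, traces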
Example #4
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument('--topk', type=int, default=10,
                        help="Top-K value passed to the decoder.")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segment embedding for the S2S encoder.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Use position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Tokens ('|'-separated) that must not be "
                        "generated during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # the vocab is loaded from a fixed local checkpoint path here rather than
    # via BertTokenizer.from_pretrained(args.bert_model, ...)
    tokenizer = BertTokenizer(
        vocab_file='/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt',
        do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) \
        if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    logger.info("model_recover_path: %s", args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift,
            topk=args.topk,
            config_path=args.config_path)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        # this variant reads a RACE-style json dict keyed by example id
        # (an earlier line-based / YFG-json input path was removed)
        with open(args.input_file, encoding="utf-8") as fin:
            data = json.load(fin)

        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        PQA_dict = {}  # the generated distractors, keyed by race_id
        dis_n = 0      # number of generated distractors
        len_tot = 0    # summed token length of generated distractors
        # examples are processed one at a time (effective batch size 1), and
        # the distractors are stored back in PQA json form
        counter = 0
        for race_id, example in tqdm(data.items()):
            counter += 1
            if args.subset > 0 and counter > args.subset:
                # decode only the first args.subset examples
                break
            eg_dict = {}
            eg_dict["question"] = example['question']
            eg_dict["context"] = example['context']
            label = int(example["label"])
            options = example["options"]
            answer = options[label]
            # up to three mutually dissimilar predictions, plus beam backups;
            # the backups are reset per example so they cannot leak across
            # iterations or be referenced before assignment
            pred1 = None
            pred2 = None
            pred3 = None
            backup_1 = None
            backup_2 = None
            question = example['question'].replace('_', ' ')  # fill cloze blanks
            line = ' '.join(
                nltk.word_tokenize(example['context']) +
                nltk.word_tokenize(question))
            buf = [data_tokenizer.tokenize(line)[:max_src_length]]
            max_a_len = max([len(x) for x in buf])
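            # pair the source with the batch-wide max length so the pipeline
            # pads every instance to the same width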
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                traces = model(input_ids,
                               token_type_ids,
                               position_ids,
                               input_mask,
                               task_idx=task_idx,
                               mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces.tolist()
                # now only supports single batch decoding!!!
                # will keep the second and third sequence as backup
                for i in range(len(buf)):
                    # print (len(buf), buf)
                    for s in range(len(output_ids)):
                        output_seq = output_ids[s]
                        #w_ids = output_ids[i]
                        #output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_buf = tokenizer.convert_ids_to_tokens(
                            output_seq)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        if s == 1:
                            backup_1 = output_tokens
                        if s == 2:
                            backup_2 = output_tokens
                        if pred1 is None:
                            pred1 = output_tokens
                        elif jaccard_similarity(pred1, output_tokens) < 0.5:
                            if pred2 is None:
                                pred2 = output_tokens
                            elif pred3 is None:
                                if jaccard_similarity(pred2,
                                                      output_tokens) < 0.5:
                                    pred3 = output_tokens
                        if pred1 is not None and pred2 is not None and pred3 is not None:
                            break
                    if pred2 is None:
                        pred2 = backup_1
                        if pred3 is None:
                            pred3 = backup_2
                    elif pred3 is None:
                        pred3 = backup_1

            new_distractors = [pred1, pred2, pred3]
            for dis in new_distractors:
                if dis is None:  # beam produced fewer than three usable sequences
                    continue
                len_tot += len(dis)
                dis_n += 1
            new_distractors = [
                ' '.join(detokenize(dis)) for dis in new_distractors
                if dis is not None
            ]
            assert len(new_distractors) == 3, "Number of distractors WRONG"
            new_distractors.insert(label, answer)  # restore the gold answer at its label index
            eg_dict["options"] = new_distractors
            eg_dict["label"] = label
            PQA_dict[race_id] = eg_dict

        logger.info("Average length of distractors: " + str(len_tot / dis_n))
        with open(args.output_file, mode='w', encoding='utf-8') as f:
            json.dump(PQA_dict, f, indent=4)
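
WhitespaceTokenizer, jaccard_similarity, and detokenize are used above but
defined elsewhere in the script. Minimal sketches consistent with how they are
called here (token lists in; a float, respectively a word list, out) could
look like this; treat them as assumptions, not the original implementations:

class WhitespaceTokenizer(object):
    # used with --tokenized_input: the text is assumed to be pre-tokenized,
    # so splitting on whitespace is enough
    def tokenize(self, text):
        return text.split()

def jaccard_similarity(a, b):
    # token-set overlap in [0, 1]; 0.0 when both lists are empty
    sa, sb = set(a), set(b)
    return len(sa & sb) / len(sa | sb) if (sa | sb) else 0.0

def detokenize(tokens):
    # undo WordPiece: merge "##"-prefixed pieces into the preceding token
    words = []
    for t in tokens:
        if t.startswith("##") and words:
            words[-1] += t[2:]
        else:
            words.append(t)
    return words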