Example 1
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--model_path", default=None, type=str, required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Path to config.json for the model.")

    parser.add_argument("--sentence_shuffle_rate", default=0, type=float)
    parser.add_argument("--layoutlm_only_layout", action='store_true')

    # tokenizer_name
    parser.add_argument("--tokenizer_name", default=None, type=str, required=True,
                        help="tokenizer name")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--input_folder", type=str, help="Input folder")
    parser.add_argument("--cached_feature_file", type=str)
    parser.add_argument('--subset', type=int, default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=1, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    args = parser.parse_args()

    model_path = args.model_path
    assert os.path.exists(model_path), 'model_path ' + model_path + ' does not exist!'

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    if args.seed > 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
    else:
        random_seed = random.randint(0, 10000)
        logger.info("Set random seed as: {}".format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(random_seed)

    tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
        args.tokenizer_name, do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
        max_len=args.max_seq_length
    )

    if args.model_type == "roberta":
        vocab = tokenizer.encoder
    else:
        vocab = tokenizer.vocab

    # NOTE: tokenizer cannot setattr, so move this to the initialization step
    # tokenizer.max_len = args.max_seq_length

    config_file = args.config_path if args.config_path else os.path.join(args.model_path, "config.json")
    logger.info("Read decoding config from: %s" % config_file)
    config = BertConfig.from_json_file(config_file,
                                       # base_model_type=args.model_type
                                       layoutlm_only_layout_flag=args.layoutlm_only_layout
                                       )

    bi_uni_pipeline = []
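    # Pre-processing pipeline: converts each (source tokens, max source length in batch) pair into decoder input features.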
    bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
        list(vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
        max_tgt_length=args.max_tgt_length, pos_shift=args.pos_shift,
        source_type_id=config.source_type_id, target_type_id=config.target_type_id,
        cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token, pad_token=tokenizer.pad_token,
        layout_flag=args.model_type == 'layoutlm'
    ))

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_path)
    found_checkpoint_flag = False
    for model_recover_path in [args.model_path.strip()]:
        logger.info("***** Recover model: %s *****", model_recover_path)
        found_checkpoint_flag = True
        model = LayoutlmForSeq2SeqDecoder.from_pretrained(
            model_recover_path, config=config, mask_word_id=mask_word_id, search_beam_size=args.beam_size,
            length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode,
            max_position_embeddings=args.max_seq_length, pos_shift=args.pos_shift,
        )

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
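        # Source length budget: max_seq_length minus 2 (special tokens) minus the target length budget.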
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length
        max_tgt_length = args.max_tgt_length

        example_path = args.input_file if args.input_file else args.input_folder

        to_pred = load_and_cache_layoutlm_examples(
            example_path, tokenizer, local_rank=-1,
            cached_features_file=args.cached_feature_file, shuffle=False, layout_flag=args.model_type == 'layoutlm',
            src_shuffle_rate=args.sentence_shuffle_rate
        )

        input_lines = convert_src_layout_inputs_to_tokens(to_pred, tokenizer.convert_ids_to_tokens, max_src_length,
                                                          layout_flag=args.model_type == 'layoutlm')
        target_lines = convert_tgt_layout_inputs_to_tokens(to_pred, tokenizer.convert_ids_to_tokens, max_tgt_length,
                                                           layout_flag=args.model_type == 'layoutlm')
        target_geo_scores = [x['bleu'] for x in to_pred]

        if args.subset > 0:
            logger.info("Decoding subset: %d", args.subset)
            input_lines = input_lines[:args.subset]

        # NOTE: add the sequence index through enumerate
        input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))

        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        fn_out = args.output_file
        fout = open(fn_out, "w", encoding="utf-8")

        with tqdm(total=total_batch) as pbar:
            batch_count = 0
            first_batch = True
            while first_batch or (next_i + args.batch_size <= len(input_lines)):
            # while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                batch_count += 1
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids, token_type_ids,
                                   position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
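                        # The decoder predicts indices into the source sequence (reading order); map them back to tokens, shifting each index down by one.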
                        output_buf = get_tokens_from_src_and_index(src=buf[i], index=w_ids, modifier=lambda x: x-1)
                        output_tokens = []
                        for t in output_buf:
                            if t in (tokenizer.sep_token, tokenizer.pad_token):
                                break
                            output_tokens.append(t)
                        output_tokens = output_tokens[:len(target_lines[buf_id[i]])]
                        if args.model_type == "roberta":
                            output_sequence = tokenizer.convert_tokens_to_string(output_tokens)
                        else:
                            output_sequence = ' '.join(detokenize(output_tokens))
                        if '\n' in output_sequence:
                            output_sequence = " [X_SEP] ".join(output_sequence.split('\n'))

                        target = target_lines[buf_id[i]]
                        target = detokenize(target)
                        result = output_sequence.split()
                        score = sentence_bleu([target], result)

                        geo_score = target_geo_scores[buf_id[i]]
                        target_sequence = ' '.join(target)

                        fout.write('{}\t{:.8f}\t{:.8f}\t{}\t{}\n'.format(buf_id[i], score, geo_score, output_sequence, target_sequence))

                        if first_batch or batch_count % 50 == 0:
                            logger.info("{}: BLEU={:.4f} GEO={:.4f} | {}"
                                        .format(buf_id[i], score, target_geo_scores[buf_id[i]], output_sequence))
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
                pbar.update(1)
                first_batch = False

        fout.close()

        bleu_score, geo_score = {}, {}
        total_bleu = total_geo = 0.0
        with open(fn_out, encoding='utf-8') as outscore:
            for line in outscore:
                example_id, bleu, geo, out_seq, tgt_seq = line.split('\t')
                bleu_score[int(example_id)] = float(bleu)
                total_bleu += float(bleu)
                geo_score[int(example_id)] = float(geo)
                total_geo += float(geo)
        print("avg_bleu", round(100 * total_bleu / len(bleu_score), 1))
        print("avg_geo", round(100 * total_geo / len(geo_score), 1))
        # released model (layoutreader-base-readingbank): avg_bleu 98.2, avg_geo 69.7

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump(
                    {"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)

    if not found_checkpoint_flag:
        logger.info("Not found the model checkpoint file!")
Example 2
def valid_generator(valid_dataset, tokenizer, args, iters, test=False):
    input_lines = []
    if args.num_labels == 3:
        sentence_list, label_list = valid_dataset.sentence_list, valid_dataset.label_list
        for sen, label in zip(sentence_list, label_list):
            sep_token = tokenizer.sep_token
            label = 'positive' if label == 1 else 'negative'
            src = (sen + sep_token + label)
            src = tokenizer.tokenize(src)
            input_lines.append(src)
    else:
        sentence_list, rel_list = valid_dataset.sentence_list, valid_dataset.rel_list
        for sen, rel in zip(sentence_list, rel_list):
            sep_token = tokenizer.sep_token
            src = (sen + sep_token + rel)
            src = tokenizer.tokenize(src)
            input_lines.append(src)

    config_file = args.config_name if args.config_name else os.path.join(args.generator_path, "config.json")
    config = BertConfig.from_json_file(config_file)

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token])

    generator = BertForSeq2SeqDecoder.from_pretrained(
        args.generator_path, config=config, mask_word_id=mask_word_id, 
        eos_id=eos_word_ids, sos_id=sos_word_id)
    generator.to(args.device)

    preprocessor = seq2seq_loader.Preprocess4Seq2seqDecoder(
        list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_source_seq_length,
        max_tgt_length=args.max_target_seq_length, pos_shift=args.pos_shift,
        source_type_id=config.source_type_id, target_type_id=config.target_type_id,
        cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token,
        pad_token=tokenizer.pad_token)

    input_lines = list(enumerate(input_lines))
    output_lines = [""] * len(input_lines)
    next_i = 0
    with torch.no_grad():
        while next_i < len(input_lines):
            _chunk = input_lines[next_i:next_i + args.batch_size * 32]
            buf_id = [x[0] for x in _chunk]
            buf = [x[1] for x in _chunk]
            next_i += args.batch_size * 32
            max_a_len = max([len(x) for x in buf])
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                instances.append(preprocessor(instance))

            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(
                    instances)
                batch = [t.to(args.device) if t is not None else None for t in batch]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                traces = generator(input_ids, token_type_ids,
                                   position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces.tolist()

                def detokenize(tk_list):
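                    # Merge WordPiece continuation pieces ('##...') back onto the preceding token.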
                    r_list = []
                    for tk in tk_list:
                        if tk.startswith('##') and len(r_list) > 0:
                            r_list[-1] = r_list[-1] + tk[2:]
                        else:
                            r_list.append(tk)
                    return r_list

                for i in range(len(buf)):
                    w_ids = output_ids[i]
                    output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in (tokenizer.sep_token, tokenizer.pad_token):
                            break
                        output_tokens.append(t)
                    output_sequence = ' '.join(detokenize(output_tokens))
                    if '\n' in output_sequence:
                        output_sequence = " [X_SEP] ".join(output_sequence.split('\n'))
                    output_lines[buf_id[i]] = output_sequence

    save_path = os.path.join(args.output_dir, "output_generation")
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    if test:
        filename = 'test_explanation_{}.txt'.format(iters)
    else:
        filename = 'valid_explanation_{}.txt'.format(iters)
    with open(os.path.join(save_path, filename), 'w', encoding='utf-8') as f:
        f.write('\n'.join(output_lines))
Example 3
    def get_classifier_dataset(self, args, generator, retriever, classifier_dataset, iters):
        input_lines = []
        if args.num_labels == 3:
            for sen, label in zip(self.sentence_list, self.label_list):
                label = 'positive' if label == 1 else 'negative'
                src = (sen + self.generator_tokenizer.sep_token + label)
                src = self.generator_tokenizer.tokenize(src)
                input_lines.append(src)
        else:
            for sen, label in zip(self.sentence_list, self.rel_list):
                src = sen + self.generator_tokenizer.sep_token + label
                src = self.generator_tokenizer.tokenize(src)
                input_lines.append(src)

        config_file = args.config_name if args.config_name else os.path.join(args.generator_path, "config.json")
        config = BertConfig.from_json_file(config_file)
        config.vocab_size = len(self.generator_tokenizer)

        mask_word_id, eos_word_ids, sos_word_id = self.generator_tokenizer.convert_tokens_to_ids(
            [self.generator_tokenizer.mask_token, self.generator_tokenizer.sep_token, self.generator_tokenizer.sep_token])

        generator = BertForSeq2SeqDecoder.from_pretrained(
            args.generator_path, config=config, mask_word_id=mask_word_id,
            eos_id=eos_word_ids, sos_id=sos_word_id)
        generator.to(args.device)
        if torch.cuda.device_count() > 1:
            generator = torch.nn.DataParallel(generator)

        preprocessor = seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(self.generator_tokenizer.vocab.keys()), self.generator_tokenizer.convert_tokens_to_ids, pos_shift=args.pos_shift,
            source_type_id=config.source_type_id, target_type_id=config.target_type_id,
            cls_token=self.generator_tokenizer.cls_token, sep_token=self.generator_tokenizer.sep_token,
            pad_token=self.generator_tokenizer.pad_token)

        input_lines = list(enumerate(input_lines))
        output_lines = [""] * len(input_lines)
        next_i = 0

        temp_batch_size = args.batch_size * args.gradient_accumulation_steps * 4
        total_batch = math.ceil(len(input_lines) / temp_batch_size)
        with tqdm.tqdm(total=total_batch) as pbar:
            with torch.no_grad():
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + temp_batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += temp_batch_size
                    max_a_len = max([len(x) for x in buf])
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        instances.append(preprocessor(instance))

                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(
                            instances)
                        batch = [t.to(args.device) if t is not None else None for t in batch]
                        input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                        traces = generator(input_ids, token_type_ids,
                                        position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()

                        def detokenize(tk_list):
                            r_list = []
                            for tk in tk_list:
                                if tk.startswith('##') and len(r_list) > 0:
                                    r_list[-1] = r_list[-1] + tk[2:]
                                else:
                                    r_list.append(tk)
                            return r_list

                        for i in range(len(buf)):
                            w_ids = output_ids[i]
                            output_buf = self.generator_tokenizer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in (self.generator_tokenizer.sep_token, self.generator_tokenizer.pad_token):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(detokenize(output_tokens))
                            if '\n' in output_sequence:
                                output_sequence = " [X_SEP] ".join(output_sequence.split('\n'))
                            output_lines[buf_id[i]] = output_sequence
                    pbar.update(1)

        save_path = os.path.join(args.output_dir, "output_generation")
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        filename = 'train_explanation_{}.txt'.format(iters)
        with open(os.path.join(save_path, filename), 'w', encoding='utf-8') as f:
            f.write('\n'.join(output_lines))
        
        if not retriever.loaded:
            retriever.load_unlabeled_sen(args.unlabeled_data)
        
        retriever.update_exp(self.sentence_list, output_lines)

        classifier_dataset.update_unlabeled(self.sentence_list, self.label_list)
        return classifier_dataset
Example 4
def decode_all():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--model_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--model_ckpt",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the model checkpoint.")

    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Path to config.json for the model.")

    # tokenizer_name
    parser.add_argument("--tokenizer_name",
                        default=None,
                        type=str,
                        required=True,
                        help="tokenizer name")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--prepend_len",
        action='store_true',
        help=
        "Set this flag if you are using dataset with prepended length tokens in target sequences."
    )
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=1, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    if args.seed > 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
    else:
        random_seed = random.randint(0, 10000)
        logger.info("Set random seed as: {}".format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(random_seed)

    if args.prepend_len:
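        # Target-length buckets prepended to the target sequences are encoded with reserved [unused{n}] vocabulary entries
        # registered as additional special tokens.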
        tgt_segments = [85] + list(range(100, 400, 15)) + [400]
        additional_special_tokens = [f'[unused{seg}]' for seg in tgt_segments]
        logger.info(f'additional_special_tokens: {additional_special_tokens}')
        tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
            args.tokenizer_name,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None,
            additional_special_tokens=additional_special_tokens)
    else:
        tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
            args.tokenizer_name,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None)

    if args.model_type == "roberta":
        vocab = tokenizer.encoder
    else:
        vocab = tokenizer.vocab

    # tokenizer.max_len = args.max_seq_length
    tokenizer.model_max_length = args.max_seq_length

    # print(f'TORCH VERISON: {torch.version.cuda}')
    # print(f'CUDA HOME: {torch.utils.cpp_extension.CUDA_HOME}')
    checkpoints = args.model_ckpt.split(',')
    for ckpt in checkpoints:
        model_path = os.path.join(args.model_path, 'ckpt-{}'.format(ckpt))
        config_file = args.config_path if args.config_path else os.path.join(
            model_path, "config.json")
        logger.info("Read decoding config from: %s" % config_file)
        config = BertConfig.from_json_file(config_file)

        bi_uni_pipeline = []
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seqDecoder(
                list(vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                max_tgt_length=args.max_tgt_length,
                pos_shift=args.pos_shift,
                source_type_id=config.source_type_id,
                target_type_id=config.target_type_id,
                cls_token=tokenizer.cls_token,
                sep_token=tokenizer.sep_token,
                pad_token=tokenizer.pad_token))

        mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
            [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token])
        forbid_ignore_set = None
        if args.forbid_ignore_word:
            w_list = []
            for w in args.forbid_ignore_word.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
        print(model_path)
        found_checkpoint_flag = False
        for model_recover_path in [model_path.strip()]:
            logger.info("***** Recover model: %s *****", model_recover_path)
            found_checkpoint_flag = True
            model = BertForSeq2SeqDecoder.from_pretrained(
                model_recover_path,
                config=config,
                mask_word_id=mask_word_id,
                search_beam_size=args.beam_size,
                length_penalty=args.length_penalty,
                eos_id=eos_word_ids,
                sos_id=sos_word_id,
                forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
                forbid_ignore_set=forbid_ignore_set,
                ngram_size=args.ngram_size,
                min_len=args.min_len,
                mode=args.mode,
                max_position_embeddings=args.max_seq_length,
                pos_shift=args.pos_shift,
            )

            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)

            torch.cuda.empty_cache()
            model.eval()
            next_i = 0
            max_src_length = args.max_seq_length - 2 - args.max_tgt_length

            to_pred = load_and_cache_examples(args.input_file,
                                              tokenizer,
                                              local_rank=-1,
                                              cached_features_file=None,
                                              shuffle=False)

            input_lines = []
            for line in to_pred:
                input_lines.append(
                    tokenizer.convert_ids_to_tokens(
                        line["source_ids"])[:max_src_length])
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]

            input_lines = sorted(list(enumerate(input_lines)),
                                 key=lambda x: -len(x[1]))
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            total_batch = math.ceil(len(input_lines) / args.batch_size)

            with tqdm(total=total_batch) as pbar:
                batch_count = 0
                first_batch = True
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += args.batch_size
                    batch_count += 1
                    max_a_len = max([len(x) for x in buf])
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(
                            instances)
                        batch = [
                            t.to(device) if t is not None else None
                            for t in batch
                        ]
                        input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                        traces = model(input_ids,
                                       token_type_ids,
                                       position_ids,
                                       input_mask,
                                       task_idx=task_idx,
                                       mask_qkv=mask_qkv)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()
                        for i in range(len(buf)):
                            w_ids = output_ids[i]
                            output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in (tokenizer.sep_token,
                                         tokenizer.pad_token):
                                    break
                                output_tokens.append(t)
                            if args.model_type == "roberta":
                                output_sequence = tokenizer.convert_tokens_to_string(
                                    output_tokens)
                            else:
                                output_sequence = ' '.join(
                                    detokenize(output_tokens))
                            if '\n' in output_sequence:
                                output_sequence = " [X_SEP] ".join(
                                    output_sequence.split('\n'))
                            output_lines[buf_id[i]] = output_sequence
                            if first_batch or batch_count % 50 == 0:
                                logger.info("{} = {}".format(
                                    buf_id[i], output_sequence))
                            if args.need_score_traces:
                                score_trace_list[buf_id[i]] = {
                                    'scores': traces['scores'][i],
                                    'wids': traces['wids'][i],
                                    'ptrs': traces['ptrs'][i]
                                }
                    pbar.update(1)
                    first_batch = False
            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path + '.' + args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write(l)
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump(
                        {
                            "version": 0.0,
                            "num_samples": len(input_lines)
                        }, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)

        if not found_checkpoint_flag:
            logger.info("Not found the model checkpoint file!")


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--model_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Path to config.json for the model.")

    parser.add_argument("--embedding_model_path",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--doc_file", default=None, type=str, required=True)
    parser.add_argument("--top_k", default=5, type=int)
    # tokenizer_name
    parser.add_argument("--tokenizer_name",
                        default=None,
                        type=str,
                        required=True,
                        help="tokenizer name")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=1, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--cache_feature_file", default=str)

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    if args.seed > 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
    else:
        random_seed = random.randint(0, 10000)
        logger.info("Set random seed as: {}".format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(random_seed)

    tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
        args.tokenizer_name,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.model_type == "roberta":
        vocab = tokenizer.encoder
    else:
        vocab = tokenizer.vocab

    tokenizer.max_len = args.max_seq_length

    config_file = args.config_path if args.config_path else os.path.join(
        args.model_path, "config.json")
    logger.info("Read decoding config from: %s" % config_file)
    config = BertConfig.from_json_file(config_file)
    retrieval_config = RetrievalConfig.from_json_file(config_file)
    embedding_model_config = BertConfig.from_json_file(config_file)

    concator = DecoderConcator(list(vocab.keys()),
                               tokenizer.convert_tokens_to_ids,
                               args.max_seq_length,
                               max_tgt_length=args.max_tgt_length,
                               pos_shift=args.pos_shift,
                               source_type_id=config.source_type_id,
                               target_type_id=config.target_type_id,
                               cls_token=tokenizer.cls_token,
                               sep_token=tokenizer.sep_token,
                               pad_token=tokenizer.pad_token)

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_path)
    found_checkpoint_flag = False
    for model_recover_path in [args.model_path.strip()]:
        logger.info("***** Recover model: %s *****", model_recover_path)
        found_checkpoint_flag = True
        model = BertForRetrievalSeq2SeqDecoder.from_pretrained(
            model_recover_path,
            config=config,
            r_config=retrieval_config,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            concator=concator,
            top_k=args.top_k,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            pos_shift=args.pos_shift,
        )

        embedding_model = BertForRetrieval.from_pretrained(
            args.embedding_model_path,
            config=retrieval_config,
            model_type=args.model_type,
            reuse_position_embedding=True,
            cache_dir=None)

        doc_features = utils.load_and_cache_doc_examples(
            example_file=args.doc_file,
            tokenizer=tokenizer,
            local_rank=-1,
            cached_doc_features_file=args.cache_feature_file)

        model.retrieval.features = doc_features

        doc_dataset = utils.RetrievalSeq2seqDocDatasetForBert(
            features=doc_features,
            max_source_len=args.max_seq_length - args.max_tgt_length - 2,
            max_target_len=args.max_tgt_length,
            vocab_size=tokenizer.vocab_size,
            cls_id=tokenizer.cls_token_id,
            sep_id=tokenizer.sep_token_id,
            pad_id=tokenizer.pad_token_id,
            mask_id=tokenizer.mask_token_id,
            random_prob=0.15,
            keep_prob=0.1,
            offset=0,
            num_training_instances=len(doc_features))

        doc_sampler = SequentialSampler(doc_dataset)
        doc_dataloader = DataLoader(
            doc_dataset,
            sampler=doc_sampler,
            batch_size=16,
            collate_fn=utils.batch_list_to_batch_tensors)

        doc_iterator = tqdm(doc_dataloader,
                            initial=0,
                            desc="Embeding docs:",
                            disable=False)

        all_embeds = []

        if args.fp16:
            model.half()
        model.to(device)
        embedding_model.to(device)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        model.eval()
        model.zero_grad()

        for step, batch in enumerate(doc_iterator):
            batch = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                embeds = embedding_model.get_embeds(batch[0])
            all_embeds.extend(embeds.view(-1, 768).detach().cpu().tolist())

        if hasattr(model, "module"):
            model.module.retrieval.doc_embeds = torch.tensor(
                all_embeds, dtype=torch.float32)

            model.module.retrieval.build_indexs_from_embeds(
                model.module.retrieval.doc_embeds)
        else:
            model.retrieval.doc_embeds = torch.tensor(all_embeds,
                                                      dtype=torch.float32)

            model.retrieval.build_indexs_from_embeds(
                model.retrieval.doc_embeds)

        logger.info("start Decoding")

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        to_pred = utils.load_and_cache_examples(args.input_file,
                                                tokenizer,
                                                local_rank=-1,
                                                cached_features_file=None,
                                                shuffle=False)

        input_lines = []
        for line in to_pred:
            input_lines.append(
                tokenizer.convert_ids_to_tokens(
                    line["source_ids"])[:max_src_length])
        if args.subset > 0:
            logger.info("Decoding subset: %d", args.subset)
            input_lines = input_lines[:args.subset]

        # input_lines = sorted(list(enumerate(input_lines)),
        #                      key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        # total_batch = math.ceil(len(input_lines) / args.batch_size)
        total_batch = len(input_lines)

        with tqdm(total=total_batch) as pbar:
            batch_count = 0
            first_batch = True
            for line in input_lines:
                # while next_i < len(input_lines):
                #     _chunk = input_lines[next_i:next_i + args.batch_size]
                #     buf_id = [x[0] for x in _chunk]
                #     buf = [x[1] for x in _chunk]
                #     next_i += args.batch_size
                #     batch_count += 1
                #     max_a_len = max([len(x) for x in buf])

                #     instances = []
                # for instance in buf:

                # for instance in [(x, max_a_len) for x in buf]:
                #     for proc in bi_uni_pipeline:
                #         instances.append(proc(instance))
                with torch.no_grad():

                    traces = model(line)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                        output_buf = tokenizer.convert_ids_to_tokens(
                            output_ids[0])
                        output_tokens = []
                        for t in output_buf:
                            if t in (tokenizer.sep_token, tokenizer.pad_token):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        if '\n' in output_sequence:
                            output_sequence = " [X_SEP] ".join(
                                output_sequence.split('\n'))

                        if first_batch or batch_count % 50 == 0:
                            logger.info("{} = {}".format(
                                batch_count, output_sequence))
                        output_lines.append(output_sequence)

                    # for i in range(len(buf)):
                    #     w_ids = output_ids[i]
                    #     output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                    #     output_tokens = []
                    #     for t in output_buf:
                    #         if t in (tokenizer.sep_token, tokenizer.pad_token):
                    #             break
                    #         output_tokens.append(t)
                    #     if args.model_type == "roberta":
                    #         output_sequence = tokenizer.convert_tokens_to_string(output_tokens)
                    #     else:
                    #         output_sequence = ' '.join(detokenize(output_tokens))
                    #     if '\n' in output_sequence:
                    #         output_sequence = " [X_SEP] ".join(output_sequence.split('\n'))
                    #     output_lines[buf_id[i]] = output_sequence
                    #     if first_batch or batch_count % 50 == 0:
                    #         logger.info("{} = {}".format(buf_id[i], output_sequence))
                    #     if args.need_score_traces:
                    #         score_trace_list[buf_id[i]] = {
                    #             'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
                batch_count += 1
                pbar.update(1)
                first_batch = False
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write(l)
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(input_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)

    if not found_checkpoint_flag:
        logger.info("Not found the model checkpoint file!")