def display_batch(batch, tokenizer: BertTokenizer):
    id, sent_ids, qlen, chunk_lengths, chunk_tokens, segment_ids, sent_starts, sent_ends, sent_targets = batch
    assert chunk_tokens.shape[0] == len(chunk_lengths)
    chunk_tokens = chunk_tokens.numpy()
    segment_ids = segment_ids.numpy()
    sent_starts = sent_starts.numpy()
    sent_ends = sent_ends.numpy()
    sent_targets = sent_targets.numpy()
    logger.info(f'{id}')
    all_toks = []
    for ci in range(len(chunk_lengths)):
        clen = chunk_lengths[ci]
        chunk_toks = tokenizer.convert_ids_to_tokens(chunk_tokens[ci])
        all_toks.extend(chunk_toks[qlen:qlen + clen])
        segments = segment_ids[ci]
        logger.info(f'{str(list(zip(chunk_toks, segments)))}')
    for si in range(len(sent_ids)):
        logger.info(
            f'{sent_targets[si]} {sent_ids[si]} = {str(all_toks[sent_starts[si]:sent_ends[si]+1])}'
        )
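A toy invocation of display_batch, to make the expected 9-field batch layout concrete. This is an editor's sketch, not part of the original example: the vocab path is a placeholder, the tensors are dummy values, and the tokenizer is assumed to be pytorch_pretrained_bert's BertTokenizer; display_batch and its module-level logger come from the code above.

# Sketch only: a dummy batch matching the 9-field layout consumed above.
import logging
import torch
from pytorch_pretrained_bert import BertTokenizer

logging.basicConfig(level=logging.INFO)

tokenizer = BertTokenizer('vocab.txt')  # placeholder vocab path
qlen = 2  # number of leading question tokens in each chunk
batch = (
    'example-0',                          # id
    [0],                                  # sent_ids
    qlen,                                 # qlen
    [3],                                  # chunk_lengths (chunk tokens after the question)
    torch.tensor([[1, 2, 3, 4, 5]]),      # chunk_tokens: 1 chunk of 5 token ids
    torch.tensor([[0, 0, 0, 0, 0]]),      # segment_ids
    torch.tensor([0]),                    # sent_starts (into the concatenated chunk tokens)
    torch.tensor([2]),                    # sent_ends (inclusive)
    torch.tensor([1.0]),                  # sent_targets
)
display_batch(batch, tokenizer)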
Example #2
class BertCorrector(Detector):
    def __init__(self,
                 bert_model_dir='',
                 bert_model_vocab='',
                 max_seq_length=384):
        super(BertCorrector, self).__init__()
        self.name = 'bert_corrector'
        self.bert_model_dir = os.path.join(pwd_path, bert_model_dir)
        self.bert_model_vocab = os.path.join(pwd_path, bert_model_vocab)
        self.max_seq_length = max_seq_length
        self.initialized_bert_corrector = False

    def check_bert_corrector_initialized(self):
        if not self.initialized_bert_corrector:
            self.initialize_bert_corrector()

    def initialize_bert_corrector(self):
        t1 = time.time()
        self.bert_tokenizer = BertTokenizer(self.bert_model_vocab)
        # Prepare model
        self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
        print("Loaded model: %s, vocab file: %s, spend: %.3f s." %
              (self.bert_model_dir, self.bert_model_vocab, time.time() - t1))
        self.initialized_bert_corrector = True

    def convert_sentence_to_features(self,
                                     sentence,
                                     tokenizer,
                                     max_seq_length,
                                     error_begin_idx=0,
                                     error_end_idx=0):
        """Loads a sentence into a list of `InputBatch`s."""
        self.check_bert_corrector_initialized()
        features = []
        tokens_a = list(sentence)

        # For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0      0   0   0  0    0   0
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        # mask the error span; indices shift by 1 for the leading [CLS] token
        for k in range(error_begin_idx + 1, error_end_idx + 1):
            tokens[k] = '[MASK]'
        segment_ids = [0] * len(tokens)

        input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)
        mask_ids = [i for i, v in enumerate(input_ids) if v == MASK_ID]
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          mask_ids=mask_ids,
                          segment_ids=segment_ids,
                          input_tokens=tokens))
        return features

    def check_vocab_has_all_token(self, sentence):
        self.check_bert_corrector_initialized()
        flag = True
        for i in list(sentence):
            if i not in self.bert_tokenizer.vocab:
                flag = False
                break
        return flag

    def bert_lm_infer(self, sentence, error_begin_idx=0, error_end_idx=0):
        self.check_bert_corrector_initialized()
        corrected_item = sentence[error_begin_idx:error_end_idx]
        eval_features = self.convert_sentence_to_features(
            sentence=sentence,
            tokenizer=self.bert_tokenizer,
            max_seq_length=self.max_seq_length,
            error_begin_idx=error_begin_idx,
            error_end_idx=error_end_idx)

        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = self.model(input_ids, segment_ids)
            # take the argmax prediction at each masked position
            masked_ids = f.mask_ids
            if masked_ids:
                predicted_tokens = []
                for i in masked_ids:
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = self.bert_tokenizer.convert_ids_to_tokens(
                        [predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
                    predicted_tokens.append(predicted_token)
                # join the predictions so a multi-character error is fully replaced
                corrected_item = ''.join(predicted_tokens)
        return corrected_item

    def correct(self, sentence=''):
        """
        Correct a sentence.
        :param sentence: sentence text
        :return: corrected sentence, list(wrong, right, begin_idx, end_idx)
        """
        detail = []
        maybe_errors = self.detect(sentence)
        maybe_errors = sorted(maybe_errors,
                              key=operator.itemgetter(2),
                              reverse=False)
        for item, begin_idx, end_idx, err_type in maybe_errors:
            # correct the errors one by one
            before_sent = sentence[:begin_idx]
            after_sent = sentence[end_idx:]

            # for words in the custom confusion set, take the mapped result directly
            if err_type == error_type["confusion"]:
                corrected_item = self.custom_confusion[item]
            elif err_type == error_type["char"]:
                # skip wrong characters that are not Chinese
                if not is_chinese_string(item):
                    continue
                if not self.check_vocab_has_all_token(sentence):
                    continue
                # infer the most likely correct character(s) with the masked LM
                corrected_item = self.bert_lm_infer(sentence,
                                                    error_begin_idx=begin_idx,
                                                    error_end_idx=end_idx)
            elif err_type == error_type["word"]:
                corrected_item = item
            else:
                print('unknown error_type')
                continue
            # output
            if corrected_item != item:
                sentence = before_sent + corrected_item + after_sent
                detail_word = [item, corrected_item, begin_idx, end_idx]
                detail.append(detail_word)
        detail = sorted(detail, key=operator.itemgetter(2))
        return sentence, detail
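A hypothetical usage sketch for the class above, not part of the original example. It assumes the surrounding project supplies the Detector base class, error_type, pwd_path and the custom confusion dictionary, and that the model directory and vocab file exist locally; the paths and the sample sentence are placeholders.

# Sketch only: construct the corrector and correct one sentence.
corrector = BertCorrector(
    bert_model_dir='data/bert_models/chinese_finetuned_lm/',
    bert_model_vocab='data/bert_models/chinese_finetuned_lm/vocab.txt')
corrected_sentence, detail = corrector.correct('少先队员因该为老人让座')
print(corrected_sentence)
for wrong, right, begin_idx, end_idx in detail:
    print(wrong, '->', right, 'at', begin_idx, end_idx)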
Example #3
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model_dir", default=None, type=str, required=True,
                        help="Bert pre-trained model config dir")
    parser.add_argument("--bert_model_vocab", default=None, type=str, required=True,
                        help="Bert pre-trained model vocab path")
    parser.add_argument("--output_dir", default="./output", type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # Other parameters
    parser.add_argument("--predict_file", default=None, type=str,
                        help="for predictions.")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.load_state_dict(model_state_dict)
    model.to(device)

    # Tokenized input
    text = "吸 烟 的 人 容 易 得 癌 症"
    print(text)
    tokenized_text = tokenizer.tokenize(text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'
    print(tokenized_text)

    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Single sentence, so every token gets segment id 0
    segments_ids = [0] * len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # Put the model in evaluation mode
    model.eval()

    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)

    # take the most probable token at the masked position
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            # take the top prediction at each masked position
            masked_ids = f.mask_ids
            if masked_ids:
                print(masked_ids)
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
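The script above feeds the bare characters to BertForMaskedLM without [CLS]/[SEP] markers and without torch.no_grad(). Below is a minimal, self-contained sketch of the same masked-token prediction with those details added; it is an editor's sketch, the model and vocab paths are placeholders, and pytorch_pretrained_bert is assumed.

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

# Placeholder paths; point them at a real (fine-tuned) Chinese BERT checkpoint.
tokenizer = BertTokenizer('path/to/vocab.txt')
model = BertForMaskedLM.from_pretrained('path/to/bert_model_dir')
model.eval()

tokens = ['[CLS]'] + list('吸烟的人容易得癌症') + ['[SEP]']
masked_index = 8                # position of '癌' after the leading [CLS]
tokens[masked_index] = '[MASK]'
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
segment_ids = torch.zeros_like(input_ids)  # single sentence, all segment ids 0

with torch.no_grad():           # inference only, no gradient graph needed
    predictions = model(input_ids, segment_ids)
predicted_id = torch.argmax(predictions[0, masked_index]).item()
print(tokenizer.convert_ids_to_tokens([predicted_id])[0])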
Example #4
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model_dir",
                        default='../data/bert_models/chinese_finetuned_lm/',
                        type=str,
                        help="Bert pre-trained model config dir")
    parser.add_argument(
        "--bert_model_vocab",
        default='../data/bert_models/chinese_finetuned_lm/vocab.txt',
        type=str,
        help="Bert pre-trained model vocab path")
    parser.add_argument(
        "--output_dir",
        default="./output",
        type=str,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    # Other parameters
    parser.add_argument("--predict_file",
                        default='../data/cn/lm_test_zh.txt',
                        type=str,
                        help="for predictions.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=64,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)
    MASK_ID = tokenizer.convert_tokens_to_ids([MASK_TOKEN])[0]
    print('MASK_ID:', MASK_ID)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.load_state_dict(model_state_dict)
    model.to(device)

    # Tokenized input
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    print(text, '=>', tokenized_text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'

    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Single sentence, so every token gets segment id 0
    segments_ids = [0] * len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    print('tokens, segments_ids:', indexed_tokens, segments_ids)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # Put the model in evaluation mode
    model.eval()
    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    # predict ppl and prob of each word
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Single sentence, so every token gets segment id 0
    segments_ids = [0] * len(indexed_tokens)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    sentence_loss = 0.0
    sentence_count = 0
    for idx, label in enumerate(text):
        print(label)
        label_id = tokenizer.convert_tokens_to_ids([label])[0]
        # score one character at a time; -1 is the ignore index for masked_lm_labels
        lm_labels = [-1] * len(indexed_tokens)
        lm_labels[idx] = label_id
        print(lm_labels)
        masked_lm_labels = torch.tensor([lm_labels])

        # with masked_lm_labels given, the model returns the masked-LM loss
        loss = model(tokens_tensor,
                     segments_tensors,
                     masked_lm_labels=masked_lm_labels)
        print('loss:', loss)
        prob = float(np.exp(-loss.item()))
        print('prob:', prob)
        # accumulate the per-character negative log-likelihood (loss), not the probability
        sentence_loss += loss.item()
        sentence_count += 1
    ppl = float(np.exp(sentence_loss / sentence_count))
    print('ppl:', ppl)

    # infer each character by masking one position at a time
    text = "吸烟的人容易得癌症"
    for masked_index, label in enumerate(text):
        tokenized_text = tokenizer.tokenize(text)
        print(text, '=>', tokenized_text)
        tokenized_text[masked_index] = '[MASK]'
        print(tokenized_text)

        # Convert token to vocabulary indices
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])
        predictions = model(tokens_tensor, segments_tensors)
        print('expected label:', label)

        predicted_index = torch.argmax(predictions[0, masked_index]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        print('predict label:', predicted_token)

        scores = predictions[0, masked_index]
        # predicted_index = torch.argmax(scores).item()
        top_scores = torch.sort(scores, 0, True)
        top_score_val = top_scores[0][:5]
        top_score_idx = top_scores[1][:5]
        for j in range(len(top_score_idx)):
            print(
                'Mask predict is:',
                tokenizer.convert_ids_to_tokens([top_score_idx[j].item()])[0],
                ' score:', top_score_val[j].item())
        print()

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            mask_token=MASK_TOKEN,
            mask_id=MASK_ID)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            # take the top-k predictions at each masked position
            mask_positions = f.mask_positions

            if mask_positions:
                for idx, i in enumerate(mask_positions):
                    if not i:
                        continue
                    scores = predictions[0, i]
                    # predicted_index = torch.argmax(scores).item()
                    top_scores = torch.sort(scores, 0, True)
                    top_score_val = top_scores[0][:5]
                    top_score_idx = top_scores[1][:5]
                    # predicted_prob = predictions[0, i][predicted_index].item()
                    # predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    # print('Mask predict is:', predicted_token, ' prob:', predicted_prob)
                    for j in range(len(top_score_idx)):
                        print(
                            'Mask predict is:',
                            tokenizer.convert_ids_to_tokens(
                            [top_score_idx[j].item()])[0], ' score:',
                            top_score_val[j].item())
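The per-feature loop above can be folded into a small helper. The function below is a hypothetical refactoring (its name and signature are not from the original code); it assumes the same pytorch_pretrained_bert model and tokenizer objects and returns the top-k candidate tokens for each masked position.

import torch

def topk_for_masks(model, tokenizer, input_ids, segment_ids, mask_positions, k=5):
    """Return {position: [(token, score), ...]}; scores are raw logits, not probabilities."""
    model.eval()
    with torch.no_grad():
        predictions = model(torch.tensor([input_ids]), torch.tensor([segment_ids]))
    results = {}
    for pos in mask_positions:
        values, indices = torch.topk(predictions[0, pos], k)
        tokens = tokenizer.convert_ids_to_tokens(indices.tolist())
        results[pos] = list(zip(tokens, values.tolist()))
    return results

# e.g. topk_for_masks(model, tokenizer, f.input_ids, f.segment_ids, f.mask_positions)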
Example #5
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument('--topk', type=int, default=10, help="Value of K.")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # tokenizer = BertTokenizer.from_pretrained(
    #     args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        vocab_file=
        '/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt',
        do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift,
            topk=args.topk,
            config_path=args.config_path)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        ## for YFG style json
        # testset = loads_json(args.input_file, 'Load Test Set: '+args.input_file)
        # if args.subset > 0:
        #     logger.info("Decoding subset: %d", args.subset)
        #     testset = testset[:args.subset]

        with open(args.input_file, encoding="utf-8") as fin:
            data = json.load(fin)
        #     input_lines = [x.strip() for x in fin.readlines()]
        #     if args.subset > 0:
        #         logger.info("Decoding subset: %d", args.subset)
        #         input_lines = input_lines[:args.subset]
        # data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        # input_lines = [data_tokenizer.tokenize(
        #     x)[:max_src_length] for x in input_lines]
        # input_lines = sorted(list(enumerate(input_lines)),
        #                      key=lambda x: -len(x[1]))
        # output_lines = [""] * len(input_lines)
        # score_trace_list = [None] * len(input_lines)
        # total_batch = math.ceil(len(input_lines) / args.batch_size)

        data_tokenizer = WhitespaceTokenizer(
        ) if args.tokenized_input else tokenizer
        PQA_dict = {}  #will store the generated distractors
        dis_tot = 0
        dis_n = 0
        len_tot = 0
        hypothesis = {}
        ##change to process one by one and store the distractors in PQA json form
        ##with tqdm(total=total_batch) as pbar:
        # for example in tqdm(testset):
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id in hypothesis:
        #         continue
        # dis_n += 1
        # if dis_n % 2000 == 0:
        #     logger.info("Already processed: "+str(dis_n))
        counter = 0
        for race_id, example in tqdm(data.items()):
            counter += 1
            if args.subset > 0 and counter >= args.subset:
                break
            eg_dict = {}
            # eg_dict["question_id"] = question_id
            # eg_dict["question"] = ' '.join(example['question'])
            # eg_dict["context"] = ' '.join(example['article'])

            eg_dict["question"] = example['question']
            eg_dict["context"] = example['context']
            label = int(example["label"])
            options = example["options"]
            answer = options[label]
            #new_distractors = []
            pred1 = None
            pred2 = None
            pred3 = None
            #while next_i < len(input_lines):
            #_chunk = input_lines[next_i:next_i + args.batch_size]
            #line = example["context"].strip() + ' ' + example["question"].strip()
            question = example['question']
            question = question.replace('_', ' ')
            line = ' '.join(
                nltk.word_tokenize(example['context']) +
                nltk.word_tokenize(question))
            line = [data_tokenizer.tokenize(line)[:max_src_length]]
            # buf_id = [x[0] for x in _chunk]
            # buf = [x[1] for x in _chunk]
            buf = line
            #next_i += args.batch_size
            max_a_len = max([len(x) for x in buf])
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                # for i in range(1):
                #try max 10 times
                # if len(new_distractors) >= 3:
                #     break
                traces = model(input_ids,
                               token_type_ids,
                               position_ids,
                               input_mask,
                               task_idx=task_idx,
                               mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                    # print (np.array(output_ids).shape)
                    # print (output_ids)
                else:
                    output_ids = traces.tolist()
                # now only supports single batch decoding!!!
                # will keep the second and third sequence as backup
                for i in range(len(buf)):
                    # print (len(buf), buf)
                    backup_1, backup_2 = None, None  # fallbacks when few candidate sequences come back
                    for s in range(len(output_ids)):
                        output_seq = output_ids[s]
                        #w_ids = output_ids[i]
                        #output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_buf = tokenizer.convert_ids_to_tokens(
                            output_seq)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        if s == 1:
                            backup_1 = output_tokens
                        if s == 2:
                            backup_2 = output_tokens
                        if pred1 is None:
                            pred1 = output_tokens
                        elif jaccard_similarity(pred1, output_tokens) < 0.5:
                            if pred2 is None:
                                pred2 = output_tokens
                            elif pred3 is None:
                                if jaccard_similarity(pred2,
                                                      output_tokens) < 0.5:
                                    pred3 = output_tokens
                        if pred1 is not None and pred2 is not None and pred3 is not None:
                            break
                    if pred2 is None:
                        pred2 = backup_1
                        if pred3 is None:
                            pred3 = backup_2
                    elif pred3 is None:
                        pred3 = backup_1
                        # output_sequence = ' '.join(detokenize(output_tokens))
                        # print (output_sequence)
                        # print (output_sequence)
                        # if output_sequence.lower().strip() == answer.lower().strip():
                        #     continue
                        # repeated = False
                        # for cand in new_distractors:
                        #     if output_sequence.lower().strip() == cand.lower().strip():
                        #         repeated = True
                        #         break
                        # if not repeated:
                        #     new_distractors.append(output_sequence.strip())

            #hypothesis[question_id] = [pred1, pred2, pred3]
            new_distractors = [pred1, pred2, pred3]
            # print (new_distractors)
            # dis_tot += len(new_distractors)
            # # fill the missing ones with original distractors
            # for i in range(4):
            #     if len(new_distractors) >= 3:
            #         break
            #     elif i == label:
            #         continue
            #     else:
            #         new_distractors.append(options[i])
            for dis in new_distractors:
                if dis is None:
                    continue
                len_tot += len(dis)
                dis_n += 1
            new_distractors = [
                ' '.join(detokenize(dis)) for dis in new_distractors
                if dis is not None
            ]
            assert len(new_distractors) == 3, "Number of distractors WRONG"
            new_distractors.insert(label, answer)
            #eg_dict["generated_distractors"] = new_distractors
            eg_dict["options"] = new_distractors
            eg_dict["label"] = label
            #PQA_dict[question_id] = eg_dict
            PQA_dict[race_id] = eg_dict

        # reference = {}
        # for example in testset:
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id not in reference.keys():
        #         reference[question_id] = [example['distractor']]
        #     else:
        #         reference[question_id].append(example['distractor'])

        # _ = eval(hypothesis, reference)
        # assert len(PQA_dict) == len(data), "Number of examples WRONG"
        # logger.info("Average number of GENERATED distractor per question: "+str(dis_tot/dis_n))
        logger.info("Average length of distractors: " + str(len_tot / dis_n))
        with open(args.output_file, mode='w', encoding='utf-8') as f:
            json.dump(PQA_dict, f, indent=4)
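jaccard_similarity and detokenize are called above but not defined in this excerpt. A minimal token-set version of jaccard_similarity, consistent with how it is called (two token lists in, a float in [0, 1] out), might look like the sketch below; this is an assumption, not the original implementation.

def jaccard_similarity(tokens_a, tokens_b):
    """Jaccard similarity of two token lists, treated as sets."""
    set_a, set_b = set(tokens_a), set(tokens_b)
    if not set_a and not set_b:
        return 1.0
    return len(set_a & set_b) / len(set_a | set_b)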