Example 1
    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
            "runn", "##ing", ","
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w",
                  encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            vocab_file = vocab_writer.name

        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens,
                             ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                             [7, 4, 5, 10, 8, 9])

        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens,
                             ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                             [7, 4, 5, 10, 8, 9])
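
For reference, the same round trip can be run outside the test harness. This is a minimal sketch assuming the pytorch_pretrained_bert-style `BertTokenizer` exercised above; the import path, temporary-file handling, and expected outputs simply mirror the test.

import os
import tempfile

from pytorch_pretrained_bert import BertTokenizer  # assumed import path

# Write a toy WordPiece vocab, one token per line, as BertTokenizer expects.
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
                "runn", "##ing", ","]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False,
                                 encoding="utf-8") as vocab_writer:
    vocab_writer.write("".join(t + "\n" for t in vocab_tokens))
    vocab_file = vocab_writer.name

tokenizer = BertTokenizer(vocab_file)
os.remove(vocab_file)

tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
print(tokens)                                  # ['un', '##want', '##ed', ',', 'runn', '##ing']
ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)                                     # [7, 4, 5, 10, 8, 9]
assert tokenizer.convert_ids_to_tokens(ids) == tokens  # ids round-trip back to tokens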
Example 2
    def test_full_tokenizer_raises_error_for_long_sequences(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ","
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        tokenizer = BertTokenizer(vocab_file, max_len=10)
        os.remove(vocab_file)
        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
        indices = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(indices, [0 for _ in range(10)])

        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
        self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
Example 3
    def __init__(self, tokens: List[str], tokenizer: BertTokenizer,
                 max_seq_length: int):
        self.max_seq_length = max_seq_length if max_seq_length else 0
        self.tokens = [CLS] + tokenizer.tokenize(" ".join(tokens)) + [SEP]
        self.token_mask_ids = [
            idx for idx, token in enumerate(self.tokens) if token == MASK
        ]
        self.len = len(self.tokens)

        if self.max_seq_length and self.len > self.max_seq_length:
            logger.warning("'tokens_a' is over {}: {}".format(
                max_seq_length, self.len))
            # raise RuntimeError("'tokens_a' is over {}: {}".format(max_seq_length, self.len))
        else:
            self.input_ids = tokenizer.convert_tokens_to_ids(
                self.tokens) + [0] * max(self.max_seq_length - self.len, 0)
            self.attention_mask = [1] * self.len + [0] * max(
                self.max_seq_length - self.len, 0)
Example 4
def convert_example_to_features(uid: int, text_a: str, seq_length: int,
                                tokenizer: BertTokenizer) -> InputFeatures:
    tokens_a = tokenizer.tokenize(text_a)

    # We only handle a single sentence and truncate long ones, so we only need [CLS] at the head and [SEP] at the tail
    tokens_a = tokens_a[:seq_length - 2]

    # For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    input_type_ids = [0] * (len(tokens_a) + 2)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Zero-pad up to the sequence length.
    x = [0] * (seq_length - len(input_ids))
    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids) + x
    input_ids += x
    input_type_ids += x

    assert len(input_ids) == seq_length
    assert len(input_mask) == seq_length
    assert len(input_type_ids) == seq_length
    return InputFeatures(
        unique_id=uid,
        tokens=tokens,
        input_ids=input_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
    )
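
The comment block above describes the single-sequence layout ([CLS] tokens [SEP], all type_ids 0) and the padding/mask convention. A hypothetical helper, not part of the original module, makes the arithmetic explicit:

from typing import List, Tuple


def pad_single_sequence(token_ids: List[int],
                        seq_length: int) -> Tuple[List[int], List[int], List[int]]:
    """Zero-pad a single [CLS] ... [SEP] sequence to seq_length.

    Returns (input_ids, input_mask, input_type_ids): the mask is 1 for real
    tokens and 0 for padding, and type_ids are all 0 for a single sequence.
    """
    assert len(token_ids) <= seq_length
    padding = [0] * (seq_length - len(token_ids))
    input_ids = token_ids + padding
    input_mask = [1] * len(token_ids) + padding
    input_type_ids = [0] * len(token_ids) + padding
    return input_ids, input_mask, input_type_ids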
Example 5
    def __init__(self, text: str, tokenizer: BertTokenizer,
                 max_seq_length: int, language: str):
        self.max_seq_length = max_seq_length if max_seq_length else 0
        self.lang = language
        if self.lang == "ja":
            tokens = []
            for v, group in groupby(text, key=lambda x: x == "M"):
                if v:
                    tokens += [MASK for _ in group]
                else:
                    tokens += [
                        morph.midasi
                        for morph in juman.analysis("".join(group))
                    ]
        elif self.lang == "en":
            tokens = [
                MASK if token == "M" else token for token in text.split(" ")
            ]
        else:
            raise ValueError("Unsupported value: {}".format(self.lang))

        self.original_tokens = tokens
        self.original_token_mask_ids = [
            idx for idx, token in enumerate(self.original_tokens)
            if token == MASK
        ]
        self.tokens = [CLS] + tokenizer.tokenize(" ".join(tokens)) + [SEP]
        self.token_mask_ids = [
            idx for idx, token in enumerate(self.tokens) if token == MASK
        ]
        self.len = len(self.tokens)

        if self.max_seq_length and self.len > self.max_seq_length:
            raise RuntimeError("'tokens_a' is over {}: {}".format(
                max_seq_length, self.len))

        self.input_ids = tokenizer.convert_tokens_to_ids(
            self.tokens) + [0] * max(self.max_seq_length - self.len, 0)
        self.attention_mask = [1] * self.len + [0] * max(
            self.max_seq_length - self.len, 0)
Example 6
    def convert_examples_to_features(examples: List[SQuADFullExample],
                                     tokenizer: BertTokenizer, max_seq_length,
                                     doc_stride, max_query_length,
                                     is_training: bool):
        """Loads a data file into a list of `InputBatch`s."""

        unique_id = 1000000000
        features = []
        for (example_index,
             example) in tqdm(enumerate(examples),
                              desc='Convert examples to features',
                              total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)

            if len(query_tokens) > max_query_length:
                # query_tokens = query_tokens[0:max_query_length]
                # Remove the tokens appended at the front of query, which may belong to last query and answer.
                query_tokens = query_tokens[-max_query_length:]

            # word piece index -> token index
            tok_to_orig_index = []
            # token index -> word pieces group start index
            # BertTokenizer.tokenize(doc_tokens[i]) = all_doc_tokens[orig_to_tok_index[i]: orig_to_tok_index[i + 1]]
            orig_to_tok_index = []
            # word pieces for all doc tokens
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            # Process sentence span list
            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            # The -3 accounts for [CLS], [SEP] and [SEP]
            max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

            # We can have documents that are longer than the maximum sequence length.
            # To deal with this we do a sliding window approach, where we take chunks
            # of the up to our max length with a stride of `doc_stride`.
            _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
            doc_spans = []
            start_offset = 0
            while start_offset < len(all_doc_tokens):
                length = len(all_doc_tokens) - start_offset
                if length > max_tokens_for_doc:
                    length = max_tokens_for_doc
                doc_spans.append(_DocSpan(start=start_offset, length=length))
                if start_offset + length == len(all_doc_tokens):
                    break
                start_offset += min(length, doc_stride)

            sentence_spans_list = []
            sentence_ids_list = []
            for span_id, doc_span in enumerate(doc_spans):
                span_start = doc_span.start
                span_end = span_start + doc_span.length - 1

                span_sentence = []
                sen_ids = []

                for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                    if sen_end < span_start:
                        continue
                    if sen_start > span_end:
                        break
                    span_sentence.append(
                        (max(sen_start, span_start), min(sen_end, span_end)))
                    sen_ids.append(sen_idx)

                sentence_spans_list.append(span_sentence)
                sentence_ids_list.append(sen_ids)

            ini_sen_id = example.sentence_id
            for (doc_span_index, doc_span) in enumerate(doc_spans):
                # Store the input tokens to transform into input ids later.
                tokens = []
                token_to_orig_map = {}
                token_is_max_context = {}
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in query_tokens:
                    tokens.append(token)
                    segment_ids.append(0)
                tokens.append("[SEP]")
                segment_ids.append(0)

                doc_start = doc_span.start
                doc_offset = len(query_tokens) + 2
                sentence_list = sentence_spans_list[doc_span_index]
                cur_sentence_list = []
                for sen_id, sen in enumerate(sentence_list):
                    new_sen = (sen[0] - doc_start + doc_offset,
                               sen[1] - doc_start + doc_offset)
                    cur_sentence_list.append(new_sen)

                for i in range(doc_span.length):
                    split_token_index = doc_span.start + i  # Original index of word piece in all_doc_tokens
                    # Index of word piece in input sequence -> Original word index in doc_tokens
                    token_to_orig_map[len(
                        tokens)] = tok_to_orig_index[split_token_index]
                    # Check if the word piece has the max context in all doc spans.
                    is_max_context = utils.check_is_max_context(
                        doc_spans, doc_span_index, split_token_index)

                    token_is_max_context[len(tokens)] = is_max_context
                    tokens.append(all_doc_tokens[split_token_index])
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                input_ids = tokenizer.convert_tokens_to_ids(tokens)

                # The mask has 1 for real tokens and 0 for padding tokens. Only real
                # tokens are attended to.
                input_mask = [1] * len(input_ids)

                # Zero-pad up to the sequence length.
                while len(input_ids) < max_seq_length:
                    input_ids.append(0)
                    input_mask.append(0)
                    segment_ids.append(0)

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                # ral_start = None
                # ral_end = None
                # answer_choice = None

                answer_choice = -1

                # Process sentence id
                span_sen_id = -1
                for piece_sen_id, sen_id in enumerate(
                        sentence_ids_list[doc_span_index]):
                    if ini_sen_id == sen_id:
                        span_sen_id = piece_sen_id
                meta_data = {
                    'span_sen_to_orig_sen_map':
                    sentence_ids_list[doc_span_index]
                }

                if example_index < 0:
                    logger.info("*** Example ***")
                    logger.info("unique_id: %s" % unique_id)
                    logger.info("example_index: %s" % example_index)
                    logger.info("doc_span_index: %s" % doc_span_index)
                    logger.info("sentence_spans_list: %s" %
                                " ".join([(str(x[0]) + '-' + str(x[1]))
                                          for x in cur_sentence_list]))

                    logger.info("answer choice: %s" % str(answer_choice))

                features.append(
                    QAFullInputFeatures(
                        qas_id=example.qas_id,
                        unique_id=unique_id,
                        example_index=example_index,
                        doc_span_index=doc_span_index,
                        sentence_span_list=cur_sentence_list,
                        tokens=tokens,
                        token_to_orig_map=token_to_orig_map,
                        token_is_max_context=token_is_max_context,
                        input_ids=input_ids,
                        input_mask=input_mask,
                        segment_ids=segment_ids,
                        is_impossible=answer_choice,
                        sentence_id=span_sen_id,
                        start_position=None,
                        end_position=None,
                        ral_start_position=-1,
                        ral_end_position=-1,
                        meta_data=meta_data))

                unique_id += 1

        return features
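
The sliding-window step above (chunks of up to max_tokens_for_doc word pieces taken with a stride of doc_stride) can be isolated. A sketch of the same logic as a hypothetical standalone helper:

import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])


def split_doc_into_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Return overlapping (start, length) spans covering num_doc_tokens word pieces."""
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans


# e.g. split_doc_into_spans(300, 128, 64) -> spans starting at 0, 64, 128, 192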
Example 7
class BertCorrector(Detector):
    def __init__(self,
                 bert_model_dir='',
                 bert_model_vocab='',
                 max_seq_length=384):
        super(BertCorrector, self).__init__()
        self.name = 'bert_corrector'
        self.bert_model_dir = os.path.join(pwd_path, bert_model_dir)
        self.bert_model_vocab = os.path.join(pwd_path, bert_model_vocab)
        self.max_seq_length = max_seq_length
        self.initialized_bert_corrector = False

    def check_bert_corrector_initialized(self):
        if not self.initialized_bert_corrector:
            self.initialize_bert_corrector()

    def initialize_bert_corrector(self):
        t1 = time.time()
        self.bert_tokenizer = BertTokenizer(self.bert_model_vocab)
        # Prepare model
        self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
        print("Loaded model: %s, vocab file: %s, spend: %.3f s." %
              (self.bert_model_dir, self.bert_model_vocab, time.time() - t1))
        self.initialized_bert_corrector = True

    def convert_sentence_to_features(self,
                                     sentence,
                                     tokenizer,
                                     max_seq_length,
                                     error_begin_idx=0,
                                     error_end_idx=0):
        """Loads a sentence into a list of `InputBatch`s."""
        self.check_bert_corrector_initialized()
        features = []
        tokens_a = list(sentence)

        # For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0      0   0   0  0    0   0
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        k = error_begin_idx + 1
        for i in range(error_end_idx - error_begin_idx):
            tokens[k] = '[MASK]'
            k += 1
        segment_ids = [0] * len(tokens)

        input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)
        mask_ids = [i for i, v in enumerate(input_ids) if v == MASK_ID]
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          mask_ids=mask_ids,
                          segment_ids=segment_ids,
                          input_tokens=tokens))
        return features

    def check_vocab_has_all_token(self, sentence):
        self.check_bert_corrector_initialized()
        flag = True
        for i in list(sentence):
            if i not in self.bert_tokenizer.vocab:
                flag = False
                break
        return flag

    def bert_lm_infer(self, sentence, error_begin_idx=0, error_end_idx=0):
        self.check_bert_corrector_initialized()
        corrected_item = sentence[error_begin_idx:error_end_idx]
        eval_features = self.convert_sentence_to_features(
            sentence=sentence,
            tokenizer=self.bert_tokenizer,
            max_seq_length=self.max_seq_length,
            error_begin_idx=error_begin_idx,
            error_end_idx=error_end_idx)

        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = self.model(input_ids, segment_ids)
            # Predict the token at each masked position
            masked_ids = f.mask_ids
            if masked_ids:
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = self.bert_tokenizer.convert_ids_to_tokens(
                        [predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
                    corrected_item = predicted_token
        return corrected_item

    def correct(self, sentence=''):
        """
        句子改错
        :param sentence: 句子文本
        :return: 改正后的句子, list(wrong, right, begin_idx, end_idx)
        """
        detail = []
        maybe_errors = self.detect(sentence)
        maybe_errors = sorted(maybe_errors,
                              key=operator.itemgetter(2),
                              reverse=False)
        for item, begin_idx, end_idx, err_type in maybe_errors:
            # Correct the errors one by one
            before_sent = sentence[:begin_idx]
            after_sent = sentence[end_idx:]

            # For words in the custom confusion set, take the specified correction directly
            if err_type == error_type["confusion"]:
                corrected_item = self.custom_confusion[item]
            elif err_type == error_type["char"]:
                # Skip wrong characters that are not Chinese
                if not is_chinese_string(item):
                    continue
                if not self.check_vocab_has_all_token(sentence):
                    continue
                # Get all possibly correct characters
                corrected_item = self.bert_lm_infer(sentence,
                                                    error_begin_idx=begin_idx,
                                                    error_end_idx=end_idx)
            elif err_type == error_type["word"]:
                corrected_item = item
            else:
                print('unknown error_type')
                continue
            # output
            if corrected_item != item:
                sentence = before_sent + corrected_item + after_sent
                detail_word = [item, corrected_item, begin_idx, end_idx]
                detail.append(detail_word)
        detail = sorted(detail, key=operator.itemgetter(2))
        return sentence, detail
Example 8
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model_dir", default=None, type=str, required=True,
                        help="Bert pre-trained model config dir")
    parser.add_argument("--bert_model_vocab", default=None, type=str, required=True,
                        help="Bert pre-trained model vocab path")
    parser.add_argument("--output_dir", default="./output", type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # Other parameters
    parser.add_argument("--predict_file", default=None, type=str,
                        help="for predictions.")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.to(device)

    # Tokenized input
    text = "吸 烟 的 人 容 易 得 癌 症"
    print(text)
    tokenized_text = tokenizer.tokenize(text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'
    print(tokenized_text)

    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0]

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # Load pre-trained model (weights)
    model.eval()

    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)

    # Take the argmax prediction at the masked position
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            # Predict the token at each masked position
            masked_ids = f.mask_ids
            if masked_ids:
                print(masked_ids)
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
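
The prediction loop above (argmax over the masked-LM logits at each [MASK] position) can be factored into a small helper. A sketch assuming the same pytorch_pretrained_bert-style API, where BertForMaskedLM returns a [batch, seq_len, vocab] score tensor when no labels are passed; the helper name is illustrative:

import torch


def predict_masked_tokens(model, tokenizer, input_ids, segment_ids, masked_positions):
    """Return the most likely token for each masked position of a single sequence."""
    model.eval()
    with torch.no_grad():
        predictions = model(torch.tensor([input_ids]), torch.tensor([segment_ids]))
    predicted_tokens = []
    for pos in masked_positions:
        predicted_index = torch.argmax(predictions[0, pos]).item()
        predicted_tokens.append(tokenizer.convert_ids_to_tokens([predicted_index])[0])
    return predicted_tokens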
Example 9
    def convert_examples_to_features(examples: List[QAFullExample],
                                     tokenizer: BertTokenizer, max_seq_length,
                                     doc_stride, max_query_length,
                                     is_training: bool):
        """Loads a data file into a list of `InputBatch`s."""

        unique_id = 1000000000
        features = []
        drop = 0
        for (example_index,
             example) in tqdm(enumerate(examples),
                              desc='Convert examples to features',
                              total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)

            # if len(query_tokens) > max_query_length:
            # query_tokens = query_tokens[0:max_query_length]
            # Remove the tokens appended at the front of query, which may belong to last query and answer.
            # query_tokens = query_tokens[-max_query_length:]
            query_tokens = ["[CLS]"] + query_tokens + ["[SEP]"]
            ques_input_ids = tokenizer.convert_tokens_to_ids(query_tokens)
            ques_input_mask = [1] * len(ques_input_ids)
            assert len(ques_input_ids) <= max_query_length
            while len(ques_input_ids) < max_query_length:
                ques_input_ids.append(0)
                ques_input_mask.append(0)
            assert len(ques_input_ids) == max_query_length
            assert len(ques_input_mask) == max_query_length

            doc_sen_tokens = example.doc_tokens
            all_doc_tokens = []
            for sentence in doc_sen_tokens:
                cur_sen_doc_tokens = ["[CLS]"]
                for token in sentence:
                    sub_tokens = tokenizer.tokenize(token)
                    if len(cur_sen_doc_tokens) + 1 + len(
                            sub_tokens) > max_seq_length:
                        drop += 1
                        break
                    cur_sen_doc_tokens.extend(sub_tokens)
                cur_sen_doc_tokens.append("[SEP]")
                all_doc_tokens.append(cur_sen_doc_tokens)

            pass_input_ids = []
            pass_input_mask = []
            for sentence in all_doc_tokens:
                sentence_input_ids = tokenizer.convert_tokens_to_ids(sentence)
                sentence_input_mask = [1] * len(sentence_input_ids)

                assert len(sentence_input_ids) <= max_seq_length, len(
                    sentence_input_ids)

                while len(sentence_input_ids) < max_seq_length:
                    sentence_input_ids.append(0)
                    sentence_input_mask.append(0)
                assert len(sentence_input_ids) == max_seq_length
                assert len(sentence_input_mask) == max_seq_length

                pass_input_ids.append(sentence_input_ids)
                pass_input_mask.append(sentence_input_mask)

            features.append(
                SingleSentenceFeature(qas_id=example.qas_id,
                                      unique_id=unique_id,
                                      example_index=example_index,
                                      tokens=all_doc_tokens,
                                      ques_input_ids=ques_input_ids,
                                      ques_input_mask=ques_input_mask,
                                      pass_input_ids=pass_input_ids,
                                      pass_input_mask=pass_input_mask,
                                      is_impossible=example.is_impossible,
                                      sentence_id=example.sentence_id))
            unique_id += 1
        logger.info(
            f'Read {len(features)} features and truncated {drop} sentences')
        return features
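
Each document sentence above is encoded on its own as [CLS] word pieces [SEP], truncated so it never exceeds max_seq_length and then zero-padded. A standalone sketch of that per-sentence step, written as a hypothetical helper over the same tokenizer API:

def encode_sentence(tokenizer, words, max_seq_length):
    """WordPiece-encode one sentence as [CLS] ... [SEP], truncating and zero-padding to max_seq_length."""
    tokens = ["[CLS]"]
    for word in words:
        pieces = tokenizer.tokenize(word)
        if len(tokens) + len(pieces) + 1 > max_seq_length:  # keep room for the final [SEP]
            break
        tokens.extend(pieces)
    tokens.append("[SEP]")
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    padding = [0] * (max_seq_length - len(input_ids))
    return input_ids + padding, input_mask + padding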
Example 10
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model_dir",
                        default='../data/bert_models/chinese_finetuned_lm/',
                        type=str,
                        help="Bert pre-trained model config dir")
    parser.add_argument(
        "--bert_model_vocab",
        default='../data/bert_models/chinese_finetuned_lm/vocab.txt',
        type=str,
        help="Bert pre-trained model vocab path")
    parser.add_argument(
        "--output_dir",
        default="./output",
        type=str,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    # Other parameters
    parser.add_argument("--predict_file",
                        default='../data/cn/lm_test_zh.txt',
                        type=str,
                        help="for predictions.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=64,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)
    MASK_ID = tokenizer.convert_tokens_to_ids([MASK_TOKEN])[0]
    print('MASK_ID,', MASK_ID)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.to(device)

    # Tokenized input
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    print(text, '=>', tokenized_text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'

    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0]

    # Convert inputs to PyTorch tensors
    print('tokens, segments_ids:', indexed_tokens, segments_ids)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # Load pre-trained model (weights)
    model.eval()
    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    # predict ppl and prob of each word
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    sentence_loss = 0.0
    sentence_count = 0
    for idx, label in enumerate(text):
        print(label)
        label_id = tokenizer.convert_tokens_to_ids([label])[0]
        lm_labels = [-1, -1, -1, -1, -1, -1, -1, -1, -1]
        if idx != 0:
            lm_labels[idx] = label_id
        if idx == 1:
            lm_labels = indexed_tokens
        print(lm_labels)
        masked_lm_labels = torch.tensor([lm_labels])

        # Predict all tokens
        loss = model(tokens_tensor,
                     segments_tensors,
                     masked_lm_labels=masked_lm_labels)
        print('loss:', loss)
        prob = float(np.exp(-loss.item()))
        print('prob:', prob)
        sentence_loss += loss.item()
        sentence_count += 1
    ppl = float(np.exp(sentence_loss / sentence_count))
    print('ppl:', ppl)

    # Infer each word by masking one position at a time
    text = "吸烟的人容易得癌症"
    for masked_index, label in enumerate(text):
        tokenized_text = tokenizer.tokenize(text)
        print(text, '=>', tokenized_text)
        tokenized_text[masked_index] = '[MASK]'
        print(tokenized_text)

        # Convert token to vocabulary indices
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])
        predictions = model(tokens_tensor, segments_tensors)
        print('expected label:', label)

        predicted_index = torch.argmax(predictions[0, masked_index]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        print('predict label:', predicted_token)

        scores = predictions[0, masked_index]
        # predicted_index = torch.argmax(scores).item()
        top_scores = torch.sort(scores, 0, True)
        top_score_val = top_scores[0][:5]
        top_score_idx = top_scores[1][:5]
        for j in range(len(top_score_idx)):
            print(
                'Mask predict is:',
                tokenizer.convert_ids_to_tokens([top_score_idx[j].item()])[0],
                ' prob:', top_score_val[j].item())
        print()

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            mask_token=MASK_TOKEN,
            mask_id=MASK_ID)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            # Predict the token at each masked position
            mask_positions = f.mask_positions

            if mask_positions:
                for idx, i in enumerate(mask_positions):
                    if not i:
                        continue
                    scores = predictions[0, i]
                    # predicted_index = torch.argmax(scores).item()
                    top_scores = torch.sort(scores, 0, True)
                    top_score_val = top_scores[0][:5]
                    top_score_idx = top_scores[1][:5]
                    # predicted_prob = predictions[0, i][predicted_index].item()
                    # predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    # print('Mask predict is:', predicted_token, ' prob:', predicted_prob)
                    for j in range(len(top_score_idx)):
                        print(
                            'Mask predict is:',
                            tokenizer.convert_ids_to_tokens(
                                [top_score_idx[j].item()])[0], ' prob:',
                            top_score_val[j].item())
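
The per-word loss loop in this example can also be written as one scoring helper. This is a sketch of the same accumulation (the input token stays visible, as in the loop above; only the label at one position is kept), assuming the pytorch_pretrained_bert convention that BertForMaskedLM returns the masked-LM cross-entropy loss when masked_lm_labels is supplied and ignores positions labeled -1:

import numpy as np
import torch


def sentence_score(model, tokenizer, text):
    """Average the per-token masked-LM loss over a sentence and return exp(mean loss)."""
    tokens = tokenizer.tokenize(text)
    indexed = tokenizer.convert_tokens_to_ids(tokens)
    tokens_tensor = torch.tensor([indexed])
    segments_tensor = torch.tensor([[0] * len(indexed)])
    total_loss = 0.0
    for idx, token_id in enumerate(indexed):
        lm_labels = [-1] * len(indexed)
        lm_labels[idx] = token_id          # only this position contributes to the loss
        with torch.no_grad():
            loss = model(tokens_tensor, segments_tensor,
                         masked_lm_labels=torch.tensor([lm_labels]))
        total_loss += loss.item()
    return float(np.exp(total_loss / len(indexed)))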
Example 11
    def convert_examples_to_features(examples: List[QAFullExample], tokenizer: BertTokenizer, max_seq_length: int, doc_stride: int,
                                     max_query_length: int, is_training: bool):
        unique_id = 1000000000
        features = []
        for (example_index, example) in tqdm(enumerate(examples), desc='Converting examples to features..', total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)

            if len(query_tokens) > max_query_length:
                query_tokens = query_tokens[-max_query_length:]

            tok_to_orig_index = []
            orig_to_tok_index = []
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

            _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
            doc_spans = []
            start_offset = 0
            while start_offset < len(all_doc_tokens):
                length = len(all_doc_tokens) - start_offset
                if length > max_tokens_for_doc:
                    length = max_tokens_for_doc
                doc_spans.append(_DocSpan(start=start_offset, length=length))
                if start_offset + length == len(all_doc_tokens):
                    break
                start_offset += min(length, doc_stride)

            sentence_spans_list = []
            sentence_ids_list = []
            for span_id, doc_span in enumerate(doc_spans):
                span_start = doc_span.start
                span_end = span_start + doc_span.length - 1

                span_sentence = []
                sen_ids = []
                for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                    if sen_end < span_start:
                        continue
                    if sen_start > span_end:
                        break
                    span_sentence.append((max(sen_start, span_start), min(sen_end, span_end)))
                    sen_ids.append(sen_idx)

                sentence_spans_list.append(span_sentence)
                sentence_ids_list.append(sen_ids)

            ini_sen_id: List[int] = example.sentence_id
            for (doc_span_index, doc_span) in enumerate(doc_spans):

                token_to_orig_map = {}
                token_is_max_context = {}
                tokens = ["[CLS]"] + query_tokens + ["[SEP]"]
                segment_ids = [0] * len(tokens)

                doc_start = doc_span.start
                doc_offset = len(query_tokens) + 2
                sentence_list = sentence_spans_list[doc_span_index]
                cur_sentence_list = []
                for sen_id, sen in enumerate(sentence_list):
                    new_sen = (sen[0] - doc_start + doc_offset, sen[1] - doc_start + doc_offset)
                    cur_sentence_list.append(new_sen)

                for i in range(doc_span.length):
                    split_token_index = doc_span.start + i
                    token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
                    is_max_context = utils.check_is_max_context(doc_spans, doc_span_index, split_token_index)

                    token_is_max_context[len(tokens)] = is_max_context
                    tokens.append(all_doc_tokens[split_token_index])
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1] * len(input_ids)

                while len(input_ids) < max_seq_length:
                    input_ids.append(0)
                    input_mask.append(0)
                    segment_ids.append(0)

                assert len(input_ids) == len(input_mask) == len(segment_ids) == max_seq_length

                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1

                """
                There are multiple evidence sentences in some examples. To avoid multi-label setting,
                we choose to use the evidence sentence with the max length.
                """
                span_sen_id = -1
                max_evidence_length = 0
                for piece_sen_id, sen_id in enumerate(sentence_ids_list[doc_span_index]):
                    if sen_id in ini_sen_id:
                        evidence_length = cur_sentence_list[piece_sen_id][1] - cur_sentence_list[piece_sen_id][0]
                        if evidence_length > max_evidence_length:
                            max_evidence_length = evidence_length
                            span_sen_id = piece_sen_id
                meta_data = {
                    'span_sen_to_orig_sen_map': sentence_ids_list[doc_span_index]
                }

                if span_sen_id == -1:
                    answer_choice = 0
                else:
                    answer_choice = example.is_impossible + 1

                features.append(QAFullInputFeatures(
                    qas_id=example.qas_id,
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    sentence_span_list=cur_sentence_list,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    is_impossible=answer_choice,
                    sentence_id=span_sen_id,
                    start_position=None,
                    end_position=None,
                    ral_start_position=None,
                    ral_end_position=None,
                    meta_data=meta_data
                ))
            unique_id += 1
        return features
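
The comment inside this example explains that when several evidence sentences fall inside one document span, only the longest is kept to avoid a multi-label setting. That selection can be read as a small function; a sketch with hypothetical names mirroring the loop above:

from typing import List, Tuple


def pick_longest_evidence(span_sentences: List[Tuple[int, int]],
                          span_sentence_ids: List[int],
                          evidence_ids: List[int]) -> int:
    """Return the in-span index of the longest evidence sentence, or -1 if none is present."""
    best_idx, best_len = -1, 0
    for piece_idx, sen_id in enumerate(span_sentence_ids):
        if sen_id in evidence_ids:
            start, end = span_sentences[piece_idx]
            if end - start > best_len:
                best_len, best_idx = end - start, piece_idx
    return best_idx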
Example 12
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument('--topk', type=int, default=10, help="Value of K.")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # tokenizer = BertTokenizer.from_pretrained(
    #     args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        vocab_file=
        '/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt',
        do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift,
            topk=args.topk,
            config_path=args.config_path)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        ## for YFG style json
        # testset = loads_json(args.input_file, 'Load Test Set: '+args.input_file)
        # if args.subset > 0:
        #     logger.info("Decoding subset: %d", args.subset)
        #     testset = testset[:args.subset]

        with open(args.input_file, encoding="utf-8") as fin:
            data = json.load(fin)
        #     input_lines = [x.strip() for x in fin.readlines()]
        #     if args.subset > 0:
        #         logger.info("Decoding subset: %d", args.subset)
        #         input_lines = input_lines[:args.subset]
        # data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        # input_lines = [data_tokenizer.tokenize(
        #     x)[:max_src_length] for x in input_lines]
        # input_lines = sorted(list(enumerate(input_lines)),
        #                      key=lambda x: -len(x[1]))
        # output_lines = [""] * len(input_lines)
        # score_trace_list = [None] * len(input_lines)
        # total_batch = math.ceil(len(input_lines) / args.batch_size)

        data_tokenizer = WhitespaceTokenizer(
        ) if args.tokenized_input else tokenizer
        PQA_dict = {}  #will store the generated distractors
        dis_tot = 0
        dis_n = 0
        len_tot = 0
        hypothesis = {}
        ##change to process one by one and store the distractors in PQA json form
        ##with tqdm(total=total_batch) as pbar:
        # for example in tqdm(testset):
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id in hypothesis:
        #         continue
        # dis_n += 1
        # if dis_n % 2000 == 0:
        #     logger.info("Already processed: "+str(dis_n))
        counter = 0
        for race_id, example in tqdm(data.items()):
            counter += 1
            if args.subset > 0 and counter >= args.subset:
                break
            eg_dict = {}
            # eg_dict["question_id"] = question_id
            # eg_dict["question"] = ' '.join(example['question'])
            # eg_dict["context"] = ' '.join(example['article'])

            eg_dict["question"] = example['question']
            eg_dict["context"] = example['context']
            label = int(example["label"])
            options = example["options"]
            answer = options[label]
            #new_distractors = []
            pred1 = None
            pred2 = None
            pred3 = None
            #while next_i < len(input_lines):
            #_chunk = input_lines[next_i:next_i + args.batch_size]
            #line = example["context"].strip() + ' ' + example["question"].strip()
            question = example['question']
            question = question.replace('_', ' ')
            line = ' '.join(
                nltk.word_tokenize(example['context']) +
                nltk.word_tokenize(question))
            line = [data_tokenizer.tokenize(line)[:max_src_length]]
            # buf_id = [x[0] for x in _chunk]
            # buf = [x[1] for x in _chunk]
            buf = line
            #next_i += args.batch_size
            max_a_len = max([len(x) for x in buf])
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                # for i in range(1):
                #try max 10 times
                # if len(new_distractors) >= 3:
                #     break
                traces = model(input_ids,
                               token_type_ids,
                               position_ids,
                               input_mask,
                               task_idx=task_idx,
                               mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                    # print (np.array(output_ids).shape)
                    # print (output_ids)
                else:
                    output_ids = traces.tolist()
                # now only supports single batch decoding!!!
                # will keep the second and third sequence as backup
                for i in range(len(buf)):
                    # print (len(buf), buf)
                    for s in range(len(output_ids)):
                        output_seq = output_ids[s]
                        #w_ids = output_ids[i]
                        #output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_buf = tokenizer.convert_ids_to_tokens(
                            output_seq)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        if s == 1:
                            backup_1 = output_tokens
                        if s == 2:
                            backup_2 = output_tokens
                        if pred1 is None:
                            pred1 = output_tokens
                        elif jaccard_similarity(pred1, output_tokens) < 0.5:
                            if pred2 is None:
                                pred2 = output_tokens
                            elif pred3 is None:
                                if jaccard_similarity(pred2,
                                                      output_tokens) < 0.5:
                                    pred3 = output_tokens
                        if pred1 is not None and pred2 is not None and pred3 is not None:
                            break
                    if pred2 is None:
                        pred2 = backup_1
                        if pred3 is None:
                            pred3 = backup_2
                    elif pred3 is None:
                        pred3 = backup_1
                        # output_sequence = ' '.join(detokenize(output_tokens))
                        # print (output_sequence)
                        # print (output_sequence)
                        # if output_sequence.lower().strip() == answer.lower().strip():
                        #     continue
                        # repeated = False
                        # for cand in new_distractors:
                        #     if output_sequence.lower().strip() == cand.lower().strip():
                        #         repeated = True
                        #         break
                        # if not repeated:
                        #     new_distractors.append(output_sequence.strip())

            #hypothesis[question_id] = [pred1, pred2, pred3]
            new_distractors = [pred1, pred2, pred3]
            # print (new_distractors)
            # dis_tot += len(new_distractors)
            # # fill the missing ones with original distractors
            # for i in range(4):
            #     if len(new_distractors) >= 3:
            #         break
            #     elif i == label:
            #         continue
            #     else:
            #         new_distractors.append(options[i])
            for dis in new_distractors:
                # a slot can still be None if fewer than three diverse sequences were decoded
                if dis is None:
                    continue
                len_tot += len(dis)
                dis_n += 1
            new_distractors = [
                ' '.join(detokenize(dis)) for dis in new_distractors
                if dis is not None
            ]
            assert len(new_distractors) == 3, "Number of distractors WRONG"
            new_distractors.insert(label, answer)
            #eg_dict["generated_distractors"] = new_distractors
            eg_dict["options"] = new_distractors
            eg_dict["label"] = label
            #PQA_dict[question_id] = eg_dict
            PQA_dict[race_id] = eg_dict

        # reference = {}
        # for example in testset:
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id not in reference.keys():
        #         reference[question_id] = [example['distractor']]
        #     else:
        #         reference[question_id].append(example['distractor'])

        # _ = eval(hypothesis, reference)
        # assert len(PQA_dict) == len(data), "Number of examples WRONG"
        # logger.info("Average number of GENERATED distractor per question: "+str(dis_tot/dis_n))
        logger.info("Average length of distractors: " + str(len_tot / dis_n))
        with open(args.output_file, mode='w', encoding='utf-8') as f:
            json.dump(PQA_dict, f, indent=4)
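
# The helpers jaccard_similarity and detokenize are called above but are not shown in
# this example. The definitions below are only a minimal sketch of what they plausibly
# do (set-level Jaccard overlap and WordPiece merging); they are not the original code.
def jaccard_similarity(tokens_a, tokens_b):
    set_a, set_b = set(tokens_a), set(tokens_b)
    if not set_a and not set_b:
        return 1.0
    return len(set_a & set_b) / len(set_a | set_b)


def detokenize(tokens):
    # merge WordPiece sub-tokens ("##ing") back onto the preceding token
    words = []
    for token in tokens:
        if token.startswith("##") and words:
            words[-1] += token[2:]
        else:
            words.append(token)
    return words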
Esempio n. 13
0
def batch_generator_with_multi(file_path: str, tokenizer: BertTokenizer,
                               max_seq_length: int, batch_size: int, device,
                               data_limit: int):
    batch_inputs = []
    batch_att_mask = []
    batch_target_ids = []
    batch_position = []

    for n, instance in enumerate(read_instance(file_path)):
        if data_limit == n:
            logger.info(
                "The maximum number of rows has been reached: {}".format(
                    data_limit))
            batch_inputs = []
            break

        tokens_with_mask = instance["surfaces"]
        mask_ids = [
            idx for idx, token in enumerate(tokens_with_mask) if token == MASK
        ]

        tokenized_tokens = tokenizer.tokenize(" ".join(tokens_with_mask))

        if n < 3:
            logger.debug(tokens_with_mask)
            logger.debug(tokenized_tokens)
            logger.debug(mask_ids)

        subword_mask_ids = [
            idx for idx, subword in enumerate(tokenized_tokens)
            if subword == MASK
        ]
        within_mask_ids = [
            idx for idx in subword_mask_ids if idx < max_seq_length - 2
        ]
        out_mask_ids = [
            idx for idx in subword_mask_ids if idx >= max_seq_length - 2
        ]

        buffer = []
        if within_mask_ids:
            in_tokens = [CLS] + tokenized_tokens[0:max_seq_length - 2] + [SEP]
            target_ids = [
                idx for idx, token in enumerate(in_tokens) if token == MASK
            ]
            buffer.append((in_tokens, target_ids))
        if out_mask_ids:
            logger.debug("exceed {}".format(max_seq_length))
            in_tokens = [CLS] + tokenized_tokens[len(tokenized_tokens) -
                                                 max_seq_length + 2:] + [SEP]
            target_ids = [
                idx for idx, token in enumerate(in_tokens) if token == MASK
            ][-len(out_mask_ids):]
            assert len(in_tokens) == max_seq_length
            logger.debug(in_tokens)
            logger.debug(target_ids)
            buffer.append((in_tokens, target_ids))
            if out_mask_ids[-1] >= (max_seq_length - 2) * 2:
                raise RuntimeError("Sentence is too long.")

        for in_tokens, target_ids in buffer:
            # row index of this entry within the current batch; computed inside the
            # loop so a second buffer entry does not reuse the first entry's row
            row = len(batch_inputs)
            input_ids = tokenizer.convert_tokens_to_ids(
                in_tokens) + [0] * (max_seq_length - len(in_tokens))
            att_mask = [1] * len(in_tokens) + [0] * (max_seq_length -
                                                     len(in_tokens))
            batch_inputs.append(input_ids)
            batch_att_mask.append(att_mask)
            batch_target_ids += [(row, i) for i in target_ids]
        batch_position += [(instance["unique_id"], instance["sentence id"],
                            instance["file name"], mask_idx)
                           for mask_idx in mask_ids]

        if len(batch_inputs) >= batch_size:
            assert len(batch_inputs) == len(batch_att_mask)
            batch_inputs = torch.LongTensor(batch_inputs).to(device)
            batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
            batch_target_ids = [[i[0] for i in batch_target_ids],
                                [i[1] for i in batch_target_ids]]

            yield batch_inputs, batch_att_mask, batch_target_ids, batch_position

            batch_inputs = []
            batch_att_mask = []
            batch_target_ids = []
            batch_position = []

    if batch_inputs:
        assert len(batch_inputs) == len(batch_att_mask)
        batch_inputs = torch.LongTensor(batch_inputs).to(device)
        batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
        batch_target_ids = [[i[0] for i in batch_target_ids],
                            [i[1] for i in batch_target_ids]]

        yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
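
# A hedged consumption sketch for batch_generator_with_multi: the calling code is not
# part of this example, so the library (pytorch_pretrained_bert), the model name, the
# file path and the hyper-parameters below are assumptions; tokenizer and device are
# taken to exist already, as in the generator's own signature. A negative data_limit
# is passed because the generator stops when data_limit == n, so -1 disables the cap.
import torch
from pytorch_pretrained_bert import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device).eval()
with torch.no_grad():
    for inputs, att_mask, target_ids, positions in batch_generator_with_multi(
            "masked_sentences.jsonl", tokenizer, max_seq_length=128,
            batch_size=32, device=device, data_limit=-1):
        logits = model(inputs, attention_mask=att_mask)      # (batch, seq_len, vocab)
        mask_logits = logits[target_ids[0], target_ids[1]]   # one row per [MASK] target
        top_ids = mask_logits.argmax(dim=-1).tolist()
        predicted_tokens = tokenizer.convert_ids_to_tokens(top_ids)
        for (uid, sent_id, file_name, mask_idx), token in zip(positions, predicted_tokens):
            print(uid, sent_id, file_name, mask_idx, token)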
Esempio n. 14
0
def batch_generator_with_single(file_path: str, tokenizer: BertTokenizer,
                                max_seq_length: int, batch_size: int, device,
                                data_limit: int):
    batch_inputs = []
    batch_att_mask = []
    batch_target_ids = []
    batch_position = []

    for n, instance in enumerate(read_instance(file_path)):
        if data_limit == n:
            logger.info(
                "The maximum number of rows has been reached: {}".format(
                    data_limit))
            batch_inputs = []
            break

        if n % 1000 == 0:
            logger.info("Unique ID: {}".format(instance["unique_id"]))

        tokens_with_mask = instance["surfaces"]
        mask_ids = [
            idx for idx, token in enumerate(tokens_with_mask) if token == MASK
        ]
        for mask_idx in mask_ids:
            original_tokens = copy.deepcopy(instance["original_surfaces"])
            original_tokens[mask_idx] = MASK
            tokenized_tokens = tokenizer.tokenize(" ".join(original_tokens))

            if n < 3:
                logger.debug(original_tokens)
                logger.debug(tokenized_tokens)
                logger.debug(mask_idx)

            if tokenized_tokens.index(MASK) < max_seq_length - 2:
                in_tokens = [CLS] + tokenized_tokens[:max_seq_length - 2] + [SEP]
            elif len(tokenized_tokens) < (max_seq_length - 2) * 2:
                logger.debug("exceed {}".format(max_seq_length))
                # keep the tail window so that a [MASK] beyond the limit is still included
                in_tokens = ([CLS] +
                             tokenized_tokens[len(tokenized_tokens) - max_seq_length + 2:] +
                             [SEP])
                assert len(in_tokens) == max_seq_length
            else:
                raise RuntimeError("Sentence is too long.")

            assert MASK in in_tokens
            input_ids = tokenizer.convert_tokens_to_ids(
                in_tokens) + [0] * (max_seq_length - len(in_tokens))
            att_mask = [1] * len(in_tokens) + [0] * (max_seq_length -
                                                     len(in_tokens))

            batch_inputs.append(input_ids)
            batch_att_mask.append(att_mask)
            batch_target_ids.append(in_tokens.index(MASK))
            batch_position.append(
                (instance["unique_id"], instance["sentence id"],
                 instance["file name"], mask_idx))
            if len(batch_inputs) == batch_size:
                assert len(batch_inputs) == len(batch_att_mask) == len(
                    batch_target_ids) == len(batch_position)
                batch_inputs = torch.LongTensor(batch_inputs).to(device)
                batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
                batch_target_ids = [[i for i in range(batch_size)],
                                    batch_target_ids]

                yield batch_inputs, batch_att_mask, batch_target_ids, batch_position

                batch_inputs = []
                batch_att_mask = []
                batch_target_ids = []
                batch_position = []
    if batch_inputs:
        batch_target_ids = [[i for i in range(len(batch_inputs))],
                            batch_target_ids]
        batch_inputs = torch.LongTensor(batch_inputs).to(device)
        batch_att_mask = torch.LongTensor(batch_att_mask).to(device)

        yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
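
# read_instance is referenced by both generators above but not included in this
# example. The reader below is only a sketch under the assumption that the input is a
# JSON Lines file whose records already carry the keys the generators access; the
# actual on-disk format used by the original project may differ.
import json

def read_instance(file_path: str):
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            # expected keys: "unique_id", "sentence id", "file name",
            # "surfaces" (tokens with [MASK] already inserted) and "original_surfaces"
            yield record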
Esempio n. 15
0
class CreateDataset(Dataset):
    def __init__(self, data_path, max_seq_len, vocab_path, example_type, seed):
        self.seed = seed
        self.max_seq_len = max_seq_len
        self.example_type = example_type
        self.data_path = data_path
        self.vocab_path = vocab_path
        self.reset()

    # Initialization
    def reset(self):
        # Load the vocabulary that ships with the pretrained BERT model
        self.tokenizer = BertTokenizer(vocab_file=self.vocab_path)
        # Build the examples
        self.build_examples()

    # Read the dataset
    def read_data(self, quotechar=None):
        '''
        The data is assumed to be tab-separated by default.
        :param quotechar:
        :return:
        '''
        lines = []
        with open(self.data_path, 'r', encoding='utf-8') as fr:
            reader = csv.reader(fr, delimiter='\t', quotechar=quotechar)
            for line in reader:
                lines.append(line)
        return lines

    # Build the data examples
    def build_examples(self):
        lines = self.read_data()
        self.examples = []
        for i, line in enumerate(lines):
            guid = '%s-%d' % (self.example_type, i)
            label = line[0]
            text_a = line[1]
            example = InputExample(guid=guid, text_a=text_a, label=label)
            self.examples.append(example)
        del lines

    # Convert an example into a feature
    def build_features(self, example):
        '''
        # For a pair of sentences:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0        0 0     1  1  1  1  1 1

        # For a single sentence:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        # type_ids indicate whether a token belongs to the first or the second sentence
        '''
        # Convert the text into tokens
        tokens_a = self.tokenizer.tokenize(example.text_a)
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > self.max_seq_len - 2:
            tokens_a = tokens_a[:(self.max_seq_len - 2)]
        # Add the special tokens at the start and the end of the sentence
        tokens = ['[CLS]'] + tokens_a + ['[SEP]']
        segment_ids = [0] * len(tokens)  # corresponds to type_ids
        # Map the tokens to their ids in the vocabulary
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # Input mask
        input_mask = [1] * len(input_ids)
        # Pad with zeros up to max_seq_len
        padding = [0] * (self.max_seq_len - len(input_ids))

        input_ids += padding
        input_mask += padding
        segment_ids += padding

        # Label
        label_id = int(example.label)
        feature = InputFeature(input_ids=input_ids,
                               input_mask=input_mask,
                               segment_ids=segment_ids,
                               label_id=label_id)
        return feature

    def _preprocess(self, index):
        example = self.examples[index]
        feature = self.build_features(example)
        return np.array(feature.input_ids),np.array(feature.input_mask),\
               np.array(feature.segment_ids),np.array(feature.label_id)

    def __getitem__(self, index):
        return self._preprocess(index)

    def __len__(self):
        return len(self.examples)
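
# A minimal usage sketch for CreateDataset: the paths, example_type, seed and batch
# size below are placeholders rather than values from the original project. Because
# __getitem__ returns numpy arrays, the default collate_fn stacks them into tensors.
from torch.utils.data import DataLoader

dataset = CreateDataset(data_path='train.tsv', max_seq_len=128,
                        vocab_path='vocab.txt', example_type='train', seed=42)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for input_ids, input_mask, segment_ids, label_ids in loader:
    # input_ids / input_mask / segment_ids: (batch, max_seq_len); label_ids: (batch,)
    pass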
Esempio n. 16
0
def get_data(filename, tokenizer: BertTokenizer, opts: DataOptions, limit=0):
    dataset = []
    max_chunks = 0
    max_sent_len = 0
    qid2supportingfacts = dict()
    qid2sentids = dict()
    for features in get_features(filename, tokenizer, opts):
        # convert to torch tensors
        slen = max([len(ct) for ct in features.chunk_tokens])
        chunk_token_ids = [
            tokenizer.convert_tokens_to_ids(ct) + [0] * (slen - len(ct))
            for ct in features.chunk_tokens
        ]
        segment_ids = [
            sids + [0] * (slen - len(sids)) for sids in features.segment_ids
        ]
        if len(features.chunk_tokens) > max_chunks:
            max_chunks = len(features.chunk_tokens)
        max_sent_len = max(max_sent_len,
                           (np.array(features.sent_ends) -
                            np.array(features.sent_starts)).max())
        sent_targets = None
        if features.supporting_facts is not None:
            qid2supportingfacts[features.id] = features.supporting_facts
            for sid in features.sent_ids:
                qid2sentids.setdefault(features.id, set()).add(sid)
            sent_targets = torch.zeros(len(features.sent_ids),
                                       dtype=torch.float)
            for sf in features.supporting_facts:
                if sf not in features.sent_ids:
                    continue
                sent_targets[features.sent_ids.index(sf)] = 1
        assert len(features.sent_starts) == len(features.sent_ends) == len(
            features.sent_ids)
        assert len(chunk_token_ids) == len(features.chunk_lengths)
        dataset.append((features.id, features.sent_ids, features.question_len,
                        features.chunk_lengths,
                        torch.tensor(chunk_token_ids, dtype=torch.long),
                        torch.tensor(segment_ids, dtype=torch.long),
                        torch.tensor(features.sent_starts, dtype=torch.long),
                        torch.tensor(features.sent_ends,
                                     dtype=torch.long), sent_targets))
        if 0 < limit <= len(dataset):
            break
        if len(dataset) % 5000 == 0:
            logger.info(f'loading dataset item {len(dataset)} from {filename}')
    logger.info(
        f'in {filename}: max_chunks = {max_chunks}, max_sent_length = {max_sent_len}'
    )
    out_of_recall = 0
    total_positives = 0
    for qid, sps in qid2supportingfacts.items():
        total_positives += len(sps)
        # default to an empty set so questions with no retained sentences count as misses
        sent_ids = qid2sentids.get(qid, set())
        for sp in sps:
            if sp not in sent_ids:
                out_of_recall += 1
    if len(qid2supportingfacts) > 0:
        logger.info(
            f'in {filename}, due to truncations we have lost {out_of_recall} out of {total_positives} positives'
        )
    return dataset, qid2supportingfacts
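
# A small follow-up sketch reusing the returned structures to report, per question, how
# many supporting facts survived chunking and truncation; the file name and the limit
# are placeholders, and tokenizer/opts are assumed to be whatever the caller built.
dataset, qid2sf = get_data('dev.json', tokenizer, opts, limit=100)
kept_sent_ids = {qid: set(sent_ids) for qid, sent_ids, *_ in dataset}
for qid, sfs in qid2sf.items():
    hits = sum(sf in kept_sent_ids.get(qid, set()) for sf in sfs)
    print(f'{qid}: {hits}/{len(sfs)} supporting facts retained')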