Example #1
0
    def __repr__(self):
        s = ''
        s += 'qas_id: %s' % (printable_text(self.qas_id))
        s += ', question_text: %s' % (printable_text(self.question_text))
        s += ', paragraph_text: [%s]' % (' '.join(self.paragraph_text))
        if self.start_position:
            s += ', start_position: %d' % (self.start_position)
        if self.start_position:
            s += ', is_impossible: %r' % (self.is_impossible)
        return s
Example #2
0
    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (printable_text(self.qas_id))
        s += ", question_text: %s" % (printable_text(self.question_text))
        s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
        if self.start_position:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s
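Both __repr__ variants above format fields through printable_text from the XLNet preprocessing utilities, which normalizes str/bytes input before interpolation. A minimal Python 3 sketch of such a helper (the upstream version also handles Python 2 via six):

def printable_text(text):
    """Return text as a str suitable for printing or logging (Python 3-only sketch)."""
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % type(text))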
Example #3
0
    def record(self, predicts):
        decoded_results = []
        for predict in predicts:
            input_tokens = self.tokenizer.ids_to_tokens(predict["input_ids"])
            input_masks = predict["input_masks"]

            input_text = "".join(input_tokens).replace(
                prepro_utils.SPIECE_UNDERLINE, " ")

            decoded_result = {
                "text": prepro_utils.printable_text(input_text),
                "sent_label": self.sent_label_list[predict["sent_label_id"]],
                "sent_predict": self.sent_label_list[predict["sent_predict_id"]],
                "sent_score": float(predict["sent_predict_score"]),
                "sent_probs": [float(prob) for prob in predict["sent_predict_prob"]]
            }

            decoded_results.append(decoded_result)

        self._write_to_json(decoded_results, self.output_path)
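The record method above assumes each entry of predicts exposes the keys it reads (input_ids, input_masks, sent_label_id, sent_predict_id, sent_predict_score, sent_predict_prob) and that _write_to_json serializes the accumulated list. A minimal sketch of such a writer, shown here as a standalone function and assuming it is a thin json.dump wrapper:

import json

def _write_to_json(data_list, output_path):
    """Hypothetical helper: dump the decoded results as a JSON array."""
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data_list, f, indent=2, ensure_ascii=False)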
Example #4
0
    def record(self, predicts):
        decoded_results = []
        for predict in predicts:
            input_tokens = self.tokenizer.ids_to_tokens(predict["input_ids"])
            input_masks = predict["input_masks"]
            token_labels = [
                self.token_label_list[idx]
                for idx in predict["token_label_ids"]
            ]
            token_predicts = [
                self.token_label_list[idx]
                for idx in predict["token_predict_ids"]
            ]

            decoded_tokens = []
            decoded_labels = []
            decoded_predicts = []
            results = zip(input_tokens, input_masks, token_labels,
                          token_predicts)
            for input_token, input_mask, token_label, token_predict in results:
                if input_token in ["<cls>", "<sep>"] or input_mask == 1:
                    continue

                if token_predict in ["<pad>", "<cls>", "<sep>", "X"]:
                    token_predict = "O"

                if input_token.startswith(prepro_utils.SPIECE_UNDERLINE):
                    decoded_tokens.append(input_token)
                    decoded_labels.append(token_label)
                    decoded_predicts.append(token_predict)
                else:
                    # Continuation pieces (no SPIECE_UNDERLINE prefix) are merged
                    # back into the preceding word-initial piece.
                    decoded_tokens[-1] = decoded_tokens[-1] + input_token

            decoded_text = "".join(decoded_tokens).replace(
                prepro_utils.SPIECE_UNDERLINE, " ")
            decoded_label = " ".join(decoded_labels)
            decoded_predict = " ".join(decoded_predicts)

            decoded_result = {
                "text": prepro_utils.printable_text(decoded_text),
                "token_label": decoded_label,
                "token_predict": decoded_predict,
                "sent_label": self.sent_label_list[predict["sent_label_id"]],
                "sent_predict": self.sent_label_list[predict["sent_predict_id"]],
            }

            decoded_results.append(decoded_result)

        self._write_to_json(decoded_results, self.output_path)
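The inner loop above reassembles SentencePiece pieces into whitespace-delimited words by keying on the leading SPIECE_UNDERLINE marker and keeps only the label of each word-initial piece. A standalone sketch of that merge, with hypothetical pieces and labels:

SPIECE_UNDERLINE = u"▁"  # SentencePiece word-boundary marker

pieces = [u"▁New", u"▁York", u"er", u"▁is", u"▁nice"]  # hypothetical pieces
piece_labels = ["B-LOC", "I-LOC", "X", "O", "O"]       # "X" marks continuation pieces

words, word_labels = [], []
for piece, label in zip(pieces, piece_labels):
    if piece.startswith(SPIECE_UNDERLINE):
        words.append(piece)
        word_labels.append(label)
    else:
        words[-1] = words[-1] + piece  # glue the subword onto the previous word

text = "".join(words).replace(SPIECE_UNDERLINE, " ").strip()
# text        -> "New Yorker is nice"
# word_labels -> ["B-LOC", "I-LOC", "O", "O"]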
Example #5
0
def convert_examples_to_features(examples, sp_model, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 uncased):
    """Loads a data file into a list of `InputBatch`s."""

    cnt_pos, cnt_neg = 0, 0
    unique_id = 1000000000
    max_N, max_M = 1024, 1024
    f = np.zeros((max_N, max_M), dtype=np.float32)

    for (example_index, example) in enumerate(examples):

        if example_index % 100 == 0:
            print('Converting {}/{} pos {} neg {}'.format(
                example_index, len(examples), cnt_pos, cnt_neg))

        query_tokens = encode_ids(
            sp_model, preprocess_text(example.question_text, lower=uncased))

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        paragraph_text = example.paragraph_text
        para_tokens = encode_pieces(
            sp_model, preprocess_text(example.paragraph_text, lower=uncased))

        # Build character-level alignments over the concatenated pieces:
        # chartok_to_tok_index maps each character of the piece string to its
        # piece index; tok_start/end_to_chartok_index record the first and last
        # character position of each piece.
        chartok_to_tok_index = []
        tok_start_to_chartok_index = []
        tok_end_to_chartok_index = []
        char_cnt = 0
        for i, token in enumerate(para_tokens):
            chartok_to_tok_index.extend([i] * len(token))
            tok_start_to_chartok_index.append(char_cnt)
            char_cnt += len(token)
            tok_end_to_chartok_index.append(char_cnt - 1)

        tok_cat_text = ''.join(para_tokens).replace(SPIECE_UNDERLINE, ' ')
        N, M = len(paragraph_text), len(tok_cat_text)

        if N > max_N or M > max_M:
            max_N = max(N, max_N)
            max_M = max(M, max_M)
            f = np.zeros((max_N, max_M), dtype=np.float32)
            gc.collect()

        g = {}

        def _lcs_match(max_dist):
            f.fill(0)
            g.clear()

            ### longest common subsequence
            # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
            for i in range(N):

                # note(zhiliny):
                # unlike standard LCS, this is specifically optimized for the setting
                # because the mismatch between sentence pieces and original text will
                # be small
                for j in range(i - max_dist, i + max_dist):
                    if j >= M or j < 0: continue

                    if i > 0:
                        g[(i, j)] = 0
                        f[i, j] = f[i - 1, j]

                    if j > 0 and f[i, j - 1] > f[i, j]:
                        g[(i, j)] = 1
                        f[i, j] = f[i, j - 1]

                    f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
                    if (preprocess_text(paragraph_text[i],
                                        lower=uncased,
                                        remove_space=False) == tok_cat_text[j]
                            and f_prev + 1 > f[i, j]):
                        g[(i, j)] = 2
                        f[i, j] = f_prev + 1

        max_dist = abs(N - M) + 5
        for _ in range(2):
            _lcs_match(max_dist)
            if f[N - 1, M - 1] > 0.8 * N: break
            max_dist *= 2

        orig_to_chartok_index = [None] * N
        chartok_to_orig_index = [None] * M
        i, j = N - 1, M - 1
        while i >= 0 and j >= 0:
            if (i, j) not in g: break
            if g[(i, j)] == 2:
                orig_to_chartok_index[i] = j
                chartok_to_orig_index[j] = i
                i, j = i - 1, j - 1
            elif g[(i, j)] == 1:
                j = j - 1
            else:
                i = i - 1

        if all(v is None
               for v in orig_to_chartok_index) or f[N - 1, M - 1] < 0.8 * N:
            print('MISMATCH DETECTED!')
            continue

        tok_start_to_orig_index = []
        tok_end_to_orig_index = []
        for i in range(len(para_tokens)):
            start_chartok_pos = tok_start_to_chartok_index[i]
            end_chartok_pos = tok_end_to_chartok_index[i]
            start_orig_pos = _convert_index(chartok_to_orig_index,
                                            start_chartok_pos,
                                            N,
                                            is_start=True)
            end_orig_pos = _convert_index(chartok_to_orig_index,
                                          end_chartok_pos,
                                          N,
                                          is_start=False)

            tok_start_to_orig_index.append(start_orig_pos)
            tok_end_to_orig_index.append(end_orig_pos)

        if not is_training:
            tok_start_position = tok_end_position = None

        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1

        if is_training and not example.is_impossible:
            start_position = example.start_position
            end_position = start_position + len(example.orig_answer_text) - 1

            start_chartok_pos = _convert_index(orig_to_chartok_index,
                                               start_position,
                                               is_start=True)
            tok_start_position = chartok_to_tok_index[start_chartok_pos]

            end_chartok_pos = _convert_index(orig_to_chartok_index,
                                             end_position,
                                             is_start=False)
            tok_end_position = chartok_to_tok_index[end_chartok_pos]
            assert tok_start_position <= tok_end_position

        def _piece_to_id(x):
            if six.PY2 and isinstance(x, unicode):
                x = x.encode('utf-8')
            return sp_model.PieceToId(x)

        all_doc_tokens = list(map(_piece_to_id, para_tokens))

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_is_max_context = {}
            segment_ids = []
            p_mask = []

            cur_tok_start_to_orig_index = []
            cur_tok_end_to_orig_index = []

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i

                cur_tok_start_to_orig_index.append(
                    tok_start_to_orig_index[split_token_index])
                cur_tok_end_to_orig_index.append(
                    tok_end_to_orig_index[split_token_index])

                is_max_context = _check_is_max_context(doc_spans,
                                                       doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(SEG_ID_P)
                p_mask.append(0)

            paragraph_len = len(tokens)

            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_P)
            p_mask.append(1)

            # note(zhiliny): we put P before Q
            # because during pretraining, B is always shorter than A
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(SEG_ID_Q)
                p_mask.append(1)
            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_Q)
            p_mask.append(1)

            cls_index = len(segment_ids)
            tokens.append(CLS_ID)
            segment_ids.append(SEG_ID_CLS)
            p_mask.append(0)

            input_ids = tokens

            # The mask has 0 for real tokens and 1 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(1)
                segment_ids.append(SEG_ID_PAD)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(p_mask) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start
                        and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    # continue
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    # note(zhiliny): we put P before Q, so doc_offset should be zero.
                    # doc_offset = len(query_tokens) + 2
                    doc_offset = 0
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index

            if example_index < 0:
                print("*** Example ***")
                print("unique_id: %s" % (unique_id))
                print("example_index: %s" % (example_index))
                print("doc_span_index: %s" % (doc_span_index))
                print("tok_start_to_orig_index: %s" %
                      " ".join([str(x) for x in cur_tok_start_to_orig_index]))
                print("tok_end_to_orig_index: %s" %
                      " ".join([str(x) for x in cur_tok_end_to_orig_index]))
                print("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y)
                    for (x, y) in six.iteritems(token_is_max_context)
                ]))
                print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                print("input_mask: %s" % " ".join([str(x)
                                                   for x in input_mask]))
                print("segment_ids: %s" %
                      " ".join([str(x) for x in segment_ids]))

                if is_training and span_is_impossible:
                    print("impossible example span")

                if is_training and not span_is_impossible:
                    pieces = [
                        sp_model.IdToPiece(token)
                        for token in tokens[start_position:(end_position + 1)]
                    ]
                    answer_text = sp_model.DecodePieces(pieces)
                    print("start_position: %d" % (start_position))
                    print("end_position: %d" % (end_position))
                    print("answer: %s" % (printable_text(answer_text)))

                    # note(zhiliny): With multi processing,
                    # the example_index is actually the index within the current process
                    # therefore we use example_index=None to avoid being used in the future.
                    # The current code does not use example_index of training data.
            if is_training:
                feat_example_index = None
            else:
                feat_example_index = example_index

            feature = InputFeatures(
                unique_id=unique_id,
                example_index=feat_example_index,
                doc_span_index=doc_span_index,
                tok_start_to_orig_index=cur_tok_start_to_orig_index,
                tok_end_to_orig_index=cur_tok_end_to_orig_index,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                p_mask=p_mask,
                segment_ids=segment_ids,
                paragraph_len=paragraph_len,
                cls_index=cls_index,
                start_position=start_position,
                end_position=end_position,
                is_impossible=span_is_impossible)

            unique_id += 1
            if span_is_impossible:
                cnt_neg += 1
            else:
                cnt_pos += 1

            yield feature

    print("Total number of instances: {} = pos {} neg {}".format(
        cnt_pos + cnt_neg, cnt_pos, cnt_neg))
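The sliding-window chunking in the middle of this function can be read in isolation; a sketch of the same doc-span enumeration, extracted as a standalone helper for illustration:

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def enumerate_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Reproduces the chunking loop above: overlapping windows over the document tokens."""
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

# enumerate_doc_spans(10, max_tokens_for_doc=6, doc_stride=4)
# -> [DocSpan(start=0, length=6), DocSpan(start=4, length=6)]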
Example #6
0
    def record(self, predicts):
        """
        Write all tokens with labels to file, to be processed in separate script
        """

        decoded_tokens_file = open(
            "{}/results/decoded_tokens.txt".format(FLAGS.output_dir),
            "w", encoding="utf8")
        decoded_tokens_file.close()
        results_path = os.path.join(FLAGS.output_dir, FLAGS.output_file)
        results_file = open(results_path, "w", encoding="utf8")
        results_file.close()
        results_file = open(results_path, "a", encoding="utf8")

        decoded_results = []
        for predict, guid in zip(predicts, self.guids):
            try:
                input_tokens = self.tokenizer.ids_to_tokens(
                    predict["input_ids"])
                input_masks = predict["input_masks"]
                # Don't need labels; they are placeholders
                #input_labels = [self.label_list[idx] for idx in predict["label_ids"]]
                output_predicts = [
                    self.label_list[idx] for idx in predict["predict_ids"]
                ]
            except Exception as e:
                print(e)
                print(predict['label_ids'])
                print(self.label_list)
                raise

            decoded_tokens = []
            decoded_labels = []
            decoded_predicts = []
            results = zip(input_tokens, input_masks, output_predicts)
            for input_token, input_mask, output_predict in results:
                if input_token in ["<cls>", "<sep>", "<pad>"] or input_mask == 1:
                    continue
                if output_predict in ["<pad>", "<cls>", "<sep>"]:
                    output_predict = "O"

                if input_token.startswith(prepro_utils.SPIECE_UNDERLINE):
                    decoded_tokens.append(input_token)
                    decoded_predicts.append(output_predict)
                else:
                    # In SentencePiece, a word-initial piece starts with the
                    # SPIECE_UNDERLINE marker; continuation subwords lack it,
                    # so they are appended to the previous token.
                    decoded_tokens[-1] = decoded_tokens[-1] + input_token

                # Write all tokens to file
                input_token = input_token.strip(prepro_utils.SPIECE_UNDERLINE)
                results_file.write("{0}\t{1}\t{2}\n".format(
                    guid, input_token, output_predict))

            decoded_text = "".join(decoded_tokens).replace(
                prepro_utils.SPIECE_UNDERLINE, " ")
            decoded_predict = " ".join(decoded_predicts)

            decoded_result = {
                "text": prepro_utils.printable_text(decoded_text),
                "predict": decoded_predict,
            }

            decoded_results.append(decoded_result)
            self._write_output(decoded_tokens, decoded_predicts, guid)
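Each line appended to results_path above is tab-separated as guid, token, predicted label, so the file can be parsed back with a plain split. A sketch of a reader for that layout (the reader name is illustrative):

def read_token_predictions(results_path):
    """Parse guid<TAB>token<TAB>predict lines as written by record() above."""
    rows = []
    with open(results_path, "r", encoding="utf8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            guid, token, predict = line.split("\t")
            rows.append({"guid": guid, "token": token, "predict": predict})
    return rows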
Example #7
0
    def convert_single_example(self, example, logging=False):
        """Converts a single `InputExample` into a single `InputFeatures`."""
        default_feature = InputFeatures(input_ids=[0] * self.max_seq_length,
                                        input_masks=[1] * self.max_seq_length,
                                        segment_ids=[0] * self.max_seq_length,
                                        label_ids=[0] * self.max_seq_length)

        if isinstance(example, PaddingInputExample):
            return default_feature

        token_items = self.tokenizer.tokenize(example.text)

        if FLAGS.do_train:
            label_items = example.label.split(" ")

            if len(label_items) != len([
                    token for token in token_items
                    if token.startswith(prepro_utils.SPIECE_UNDERLINE)
            ]):
                return default_feature

        tokens = []
        labels = []
        idx = 0
        for token in token_items:
            if FLAGS.do_train:
                if token.startswith(prepro_utils.SPIECE_UNDERLINE):
                    label = label_items[idx]
                    idx += 1
                else:
                    label = "X"
                labels.append(label)
            tokens.append(token)

        if len(tokens) > self.max_seq_length - 2:
            tokens = tokens[0:(self.max_seq_length - 2)]

        if FLAGS.do_train:
            if len(labels) > self.max_seq_length - 2:
                labels = labels[0:(self.max_seq_length - 2)]

        printable_tokens = [
            prepro_utils.printable_text(token) for token in tokens
        ]

        # The convention in XLNet is:
        # (a) For sequence pairs:
        #  tokens:      is it a dog ? [SEP] no , it is not . [SEP] [CLS]
        #  segment_ids: 0  0  0 0   0 0     1  1 1  1  1   1 1     2
        # (b) For single sequences:
        #  tokens:      this dog is big . [SEP] [CLS]
        #  segment_ids: 0    0   0  0   0 0     2
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the last vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense when
        # the entire model is fine-tuned.
        input_tokens = []
        segment_ids = []
        label_ids = []

        for i, token in enumerate(tokens):
            input_tokens.append(token)
            segment_ids.append(self.segment_vocab_map["<a>"])
            if FLAGS.do_train:
                label_ids.append(self.label_map[labels[i]])

        input_tokens.append("<sep>")
        segment_ids.append(self.segment_vocab_map["<a>"])
        if FLAGS.do_train:
            label_ids.append(self.label_map["<sep>"])
            label_ids.append(self.label_map["<cls>"])
        input_tokens.append("<cls>")
        segment_ids.append(self.segment_vocab_map["<cls>"])

        input_ids = self.tokenizer.tokens_to_ids(input_tokens)
        # The mask has 0 for real tokens and 1 for padding tokens. Only real tokens are attended to.
        input_masks = [0] * len(input_ids)

        # Zero-pad up to the sequence length.
        if len(input_ids) < self.max_seq_length:
            pad_seq_length = self.max_seq_length - len(input_ids)
            input_ids = [self.special_vocab_map["<pad>"]] * pad_seq_length + input_ids
            input_masks = [1] * pad_seq_length + input_masks
            segment_ids = [self.segment_vocab_map["<pad>"]] * pad_seq_length + segment_ids
            if FLAGS.do_train:
                label_ids = [self.label_map["<pad>"]] * pad_seq_length + label_ids

        if not FLAGS.do_train:
            # "Hack" to not have to deal with labels
            # for inference
            label_len = len(segment_ids)
            for i in range(label_len):
                label_ids.append(4)

        assert len(input_ids) == self.max_seq_length
        assert len(input_masks) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length
        assert len(label_ids) == self.max_seq_length

        if logging:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join(printable_tokens))
            tf.logging.info("labels: %s" % " ".join(labels))
            tf.logging.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_masks: %s" %
                            " ".join([str(x) for x in input_masks]))
            tf.logging.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
            if FLAGS.do_train:
                tf.logging.info("label_ids: %s" %
                                " ".join([str(x) for x in label_ids]))

        example.sp_tokens = input_tokens

        feature = InputFeatures(input_ids=input_ids,
                                input_masks=input_masks,
                                segment_ids=segment_ids,
                                label_ids=label_ids)

        return feature
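Note that padding is prepended rather than appended, matching the XLNet convention of left-padding so the real tokens, including the trailing <sep> and <cls>, stay at the end of the sequence. A toy illustration with hypothetical ids:

max_seq_length = 8
pad_id, sep_id, cls_id = 5, 4, 3        # hypothetical special-token ids
token_ids = [17, 28, 36]                # hypothetical content tokens

input_ids = token_ids + [sep_id, cls_id]
input_masks = [0] * len(input_ids)      # 0 = real token, 1 = padding
pad_len = max_seq_length - len(input_ids)
input_ids = [pad_id] * pad_len + input_ids
input_masks = [1] * pad_len + input_masks
# input_ids   -> [5, 5, 5, 17, 28, 36, 4, 3]
# input_masks -> [1, 1, 1, 0, 0, 0, 0, 0]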
Example #8
0
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenize_fn):
    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i
    with open(FLAGS.output_dir + "/label2id.pkl", "wb") as w:
        pickle.dump(label_map, w)
    textlist = example.text.split(" ")
    labellist = example.label.split(" ")
    tokens = []
    labels = []
    for i, word in enumerate(textlist):
        # print(word)
        token = [preprocess_text(word, lower=True)]
        # print(token)
        tokens.extend(token)
        label_1 = labellist[i]
        for m in range(len(token)):
            if m == 0:
                labels.append(label_1)
            else:
                labels.append("X")

    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
        labels = labels[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    input_ids = []
    ntokens.append("[CLS]")
    segment_ids.append(0)
    # append("O") or append("[CLS]") not sure!
    label_ids.append(label_map["[CLS]"])
    # print('tokens')
    # print(tokens)
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(label_map[labels[i]])
    #ntokens.append("[SEP]")
    #segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    #label_ids.append(label_map["[SEP]"])
    for i in ntokens:
        input_ids.append(tokenize_fn(i))
    input_mask = [1] * len(input_ids)

    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        ntokens.append(0)  # placeholder entry so the length assert below holds
        label_ids.append(0)
    # print(len(input_mask))
    # print(len(input_ids))

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    assert len(ntokens) == max_seq_length
    if ex_index < 3:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join([printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))

    # print('input_ids')
    # print(input_ids)
    # print(type(input_ids))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids)
    return feature
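Because the label-to-id mapping is pickled to label2id.pkl, decoding code downstream typically loads and inverts it. A small sketch, assuming the same FLAGS object is in scope:

import os
import pickle

with open(os.path.join(FLAGS.output_dir, "label2id.pkl"), "rb") as f:
    label2id = pickle.load(f)
id2label = {idx: label for label, idx in label2id.items()}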
Example #9
0
    def convert_single_example(self, example, logging=False):
        """Converts a single `InputExample` into a single `InputFeatures`."""
        default_feature = InputFeatures(input_ids=[0] * self.max_seq_length,
                                        input_mask=[1] * self.max_seq_length,
                                        segment_ids=[0] * self.max_seq_length,
                                        label_id=0)

        if isinstance(example, PaddingInputExample):
            return default_feature

        tokens_a = self.tokenizer.tokenize(example.text_a)
        tokens_b = None
        if example.text_b:
            tokens_b = self.tokenizer.tokenize(example.text_b)

        if tokens_b:
            self._truncate_seq_pair(tokens_a, tokens_b,
                                    self.max_seq_length - 3)
        else:
            if len(tokens_a) > self.max_seq_length - 2:
                tokens_a = tokens_a[0:(self.max_seq_length - 2)]

        # The convention in XLNet is:
        # (a) For sequence pairs:
        #  tokens:      is it a dog ? [SEP] no , it is not . [SEP] [CLS]
        #  segment_ids: 0  0  0 0   0 0     1  1 1  1  1   1 1     2
        # (b) For single sequences:
        #  tokens:      this dog is big . [SEP] [CLS]
        #  segment_ids: 0    0   0  0   0 0     2
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the last vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense when
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []

        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(self.segment_vocab_map["<a>"])

        tokens.append("<sep>")
        segment_ids.append(self.segment_vocab_map["<a>"])

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(self.segment_vocab_map["<b>"])

            tokens.append("<sep>")
            segment_ids.append(self.segment_vocab_map["<b>"])

        tokens.append("<cls>")
        segment_ids.append(self.segment_vocab_map["<cls>"])

        input_ids = self.tokenizer.tokens_to_ids(tokens)

        # The mask has 0 for real tokens and 1 for padding tokens. Only real tokens are attended to.
        input_mask = [0] * len(input_ids)

        # Zero-pad up to the sequence length.
        if len(input_ids) < self.max_seq_length:
            pad_seq_length = self.max_seq_length - len(input_ids)
            input_ids = [self.special_vocab_map["<pad>"]] * pad_seq_length + input_ids
            input_mask = [1] * pad_seq_length + input_mask
            segment_ids = [self.segment_vocab_map["<pad>"]] * pad_seq_length + segment_ids

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length

        label_id = self.label_map[example.label]
        if logging:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join(
                [prepro_utils.printable_text(token) for token in tokens]))
            tf.logging.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" %
                            " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        feature = InputFeatures(input_ids=input_ids,
                                input_mask=input_mask,
                                segment_ids=segment_ids,
                                label_id=label_id)

        return feature
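_truncate_seq_pair is not shown in this snippet; the widely used BERT/XLNet helper it refers to trims the longer of the two sequences one token at a time, roughly as follows (a sketch under that assumption, shown as a plain function):

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate a sequence pair in place until the combined length fits max_length."""
    while len(tokens_a) + len(tokens_b) > max_length:
        # Drop from the longer sequence so both keep proportionate context.
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()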