예제 #1
0
 def get_ptune_query(
     self,
     content: Dict,
     prompt_token_id: int,
     max_seq_len: int,
     templates: List[int],
     tokenizer: TokenizerSpec,
 ):
     """
     Builds the token id sequence of a paragraph/question p-tuning query.

     The result is laid out as ``templates[0]`` virtual tokens, the tokenized paragraph, ``templates[1]``
     virtual tokens, the tokenized question, ``templates[2]`` virtual tokens and the tokenized ``" Answer:"``
     suffix.

     Args:
         content: example with ``'sentence'`` (paragraph text) and ``'question'`` entries
         prompt_token_id: id repeated to fill every virtual token slot
         max_seq_len: length budget the query is trimmed towards
         templates: number of virtual tokens per slot; the first three entries are emitted, while the length
             check counts the sum of all entries
         tokenizer: object providing ``text_to_ids``
     Returns:
         list of token ids. When the untrimmed query is longer than ``max_seq_len``, the overflow is removed
         from the front of the paragraph only, so the result can still be too long if the paragraph has fewer
         tokens than the overflow.
     """
     paragraph, question = content['sentence'], content['question']
     paragraph_ids = tokenizer.text_to_ids(f" Paragraph: {paragraph}")
     question_ids = tokenizer.text_to_ids(f" Question: {question}?")
     answer_ids = tokenizer.text_to_ids(" Answer:")

     text_length = len(paragraph_ids) + len(question_ids) + len(answer_ids)
     overflow = text_length + sum(templates) - max_seq_len
     if overflow > 0:
         logging.warning(
             "Input sequence is longer than the LM model max seq, will cut it off to fit"
         )
     # Number of leading paragraph tokens to drop; zero when everything fits.
     drop = max(overflow, 0)

     def virtual_tokens(slot):
         # Run of prompt tokens filling the given template slot.
         return [prompt_token_id] * templates[slot]

     return (virtual_tokens(0) + paragraph_ids[drop:] + virtual_tokens(1) +
             question_ids + virtual_tokens(2) + answer_ids)
예제 #2
0
 def get_ptune_query(
     self,
     content: Dict,
     prompt_token_id: int,
     max_seq_len: int,
     templates: List[int],
     tokenizer: TokenizerSpec,
 ):
     """
     Builds the token id sequence of a p-tuning query from ``self.pieces``.

     Every entry of ``self.pieces`` is either a string template, whose ``{name}`` placeholders are filled from
     ``content`` before tokenization, or a non-string value used as an index into ``templates`` that selects
     how many virtual tokens are inserted at that position.

     Args:
         content: example fields; must contain every placeholder name used in the string pieces
         prompt_token_id: id repeated to fill virtual token slots
         max_seq_len: maximum allowed length of the returned sequence
         templates: number of virtual tokens for each virtual slot
         tokenizer: object providing ``text_to_ids``
     Returns:
         flat list of token ids; an empty list when ``self.pieces`` is empty. If the full query is longer than
         ``max_seq_len``, the overflow is removed from the front of every piece that contains the
         ``self.limit_length_field`` placeholder (each such piece is shortened by the full overflow).
     Raises:
         ValueError: if a piece holding ``self.limit_length_field`` has fewer tokens than the overflow
         KeyError: if a placeholder name is missing from ``content``
     """
     all_ids = []
     limits = []
     for piece in self.pieces:
         if isinstance(piece, str):
             # Fill in the placeholders and remember whether this piece may be shortened.
             variable_text = {}
             limit_length = False
             for var in re.findall(r'{\w*}', piece):
                 varname = var[1:-1]
                 variable_text[varname] = content[varname]
                 if varname == self.limit_length_field:
                     limit_length = True
             all_ids.append(tokenizer.text_to_ids(piece.format(**variable_text)))
             limits.append(limit_length)
         else:
             # Virtual token slot: ``piece`` selects the slot size in ``templates``.
             all_ids.append([prompt_token_id] * templates[piece])
             limits.append(False)

     cut = sum(len(ids) for ids in all_ids) - max_seq_len
     if cut > 0:
         logging.warning(
             "Input sequence is longer than the LM model max seq, will cut it off to fit"
         )
         trimmed = []
         for ids, can_trim in zip(all_ids, limits):
             if can_trim:
                 if len(ids) < cut:
                     raise ValueError(
                         f"Some other field length is too long, cutting {self.limit_length_field} is not enough"
                     )
                 ids = ids[cut:]
             trimmed.append(ids)
         all_ids = trimmed

     # Flatten in one pass; unlike ``functools.reduce`` without an initializer this also works for no pieces.
     return [token_id for ids in all_ids for token_id in ids]
def _get_subtokens_and_subtokens_mask(
        query: str, tokenizer: TokenizerSpec) -> Tuple[List[str], List[int]]:
    """
    Splits a query into whitespace separated words, tokenizes every word and marks word starts.

    The mask holds ``1`` at the position of the first subtoken of a word and ``0`` for every following subtoken
    of the same word. A word for which the tokenizer returns no subtokens still contributes a single ``1`` to
    the mask, so in that case the mask is longer than the subtokens list.
    Args:
        query: a string that will be tokenized
        tokenizer: an instance of tokenizer providing ``text_to_tokens``
    Returns:
        subtokens: list of subtokens of all words, in order
        subtokens_mask: list of ints marking the first subtoken of every word
    """
    subtokens = []
    subtokens_mask = []
    for word in query.strip().split():
        pieces = tokenizer.text_to_tokens(word)
        subtokens += pieces
        # One "word start" flag followed by a zero for each continuation subtoken.
        subtokens_mask += [1] + [0] * (len(pieces) - 1)
    return subtokens, subtokens_mask
    def convert_examples_to_features(
        self,
        examples: List[str],
        label_list: List[int],
        max_seq_length: int,
        tokenizer: TokenizerSpec,
        output_mode: str,
        bos_token: str = None,
        eos_token: str = '[SEP]',
        pad_token: str = '[PAD]',
        cls_token: str = '[CLS]',
        sep_token_extra: str = None,
        cls_token_at_end: bool = False,
        cls_token_segment_id: int = 0,
        pad_token_segment_id: int = 0,
        pad_on_left: bool = False,
        mask_padding_with_zero: bool = True,
        sequence_a_segment_id: int = 0,
        sequence_b_segment_id: int = 1,
    ):
        """
        Converts examples into a list of `InputFeatures` padded to `max_seq_length`.
        The `cls_token_at_end` defines the location of the CLS token:
            * False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            * True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        The `cls_token_segment_id` defines the segment id associated to the CLS token (0 for BERT, 2 for XLNet)

        The convention in BERT is:

            a. For sequence pairs:
                * tokens:   [CLS] is this jack ##ville ? [SEP] no it is not . [SEP]
                * type_ids:   0   0  0    0    0       0   0   1  1  1  1   1   1
            b. For single sequences:
                * tokens:   [CLS] the dog is hairy . [SEP]
                * type_ids:   0   0   0   0  0     0   0
        Where "type_ids" are used to indicate whether this is the first
        sequence or the second sequence. The embedding vectors for `type=0`
        and `type=1` were learned during pre-training and are added to the
        wordpiece embedding vector (and position vector). This is
        not *strictly* necessary since the [SEP] token unambiguously separates
        the sequences, but it makes it easier for the model to learn
        the concept of sequences.
        For classification tasks, the first vector (corresponding to [CLS])
        is used as the "sentence vector". Note that this only makes sense
        because the entire model is fine-tuned.

        The convention for NMT is:

            a. For sequence pairs:
                * tokens:<BOS> is this jack ##ville ? <EOS> <BOS> no it is not . <EOS>
                * type_ids:0   0  0    0    0       0   0     1   1  1  1  1   1   1
            b. For single sequences:
                * tokens:   <BOS> the dog is hairy . <EOS>
                * type_ids:   0   0   0   0  0     0   0

        Args:
            examples: example objects; the attributes `text_a`, `text_b`, `label` and `guid` are read
            label_list: all labels; a label is mapped to its position in this list for classification
            max_seq_length: exact length of `input_ids`, `input_mask` and `segment_ids` of every feature
            tokenizer: provides `text_to_tokens` and `tokens_to_ids`
            output_mode: "classification" or "regression"
            bos_token: token put in front of every sequence, skipped if empty
            eos_token: token put after every sequence, skipped if empty
            pad_token: token whose id is used for padding `input_ids`
            cls_token: classification token, skipped if empty
            sep_token_extra: extra separator inserted between sequence A and sequence B, skipped if empty
            cls_token_at_end: whether the classification token goes last instead of first
            cls_token_segment_id: segment id of the classification token
            pad_token_segment_id: segment id of padding positions
            pad_on_left: whether padding is added in front of the sequence instead of after it
            mask_padding_with_zero: if True real tokens get mask 1 and padding 0, otherwise the opposite
            sequence_a_segment_id: segment id of sequence A tokens
            sequence_b_segment_id: segment id of sequence B tokens
        Returns:
            list with one `InputFeatures` per example
        Raises:
            ValueError: if a built sequence does not have length `max_seq_length`
            KeyError: if `output_mode` is unknown or a classification label is missing from `label_list`
        """
        label_map = {label: i for i, label in enumerate(label_list)}

        features = []
        for ex_index, example in enumerate(examples):
            if ex_index % 10000 == 0:
                logging.info("Writing example %d of %d" %
                             (ex_index, len(examples)))

            tokens_a = tokenizer.text_to_tokens(example.text_a)

            tokens_b = None
            if example.text_b:
                tokens_b = tokenizer.text_to_tokens(example.text_b)

                special_tokens_count = 2 if eos_token else 0
                special_tokens_count += 1 if sep_token_extra else 0
                special_tokens_count += 2 if bos_token else 0
                special_tokens_count += 1 if cls_token else 0
                self._truncate_seq_pair(tokens_a, tokens_b,
                                        max_seq_length - special_tokens_count)
            else:
                # Reserve room for exactly the special tokens a single sequence receives: BOS, EOS and CLS.
                # CLS has to be counted, otherwise a truncated sequence ends up one token too long and fails
                # the length check below. `sep_token_extra` is only inserted between two sequences, so it is
                # not counted here.
                special_tokens_count = 1 if eos_token else 0
                special_tokens_count += 1 if bos_token else 0
                special_tokens_count += 1 if cls_token else 0
                if len(tokens_a) > max_seq_length - special_tokens_count:
                    tokens_a = tokens_a[:max_seq_length - special_tokens_count]
            # Add special tokens to sequence_a
            tokens = tokens_a
            if bos_token:
                tokens = [bos_token] + tokens
            if eos_token:
                tokens += [eos_token]
            segment_ids = [sequence_a_segment_id] * len(tokens)

            # Add sequence separator between sequences
            if tokens_b and sep_token_extra:
                tokens += [sep_token_extra]
                segment_ids += [sequence_a_segment_id]

            # Add special tokens to sequence_b
            if tokens_b:
                if bos_token:
                    tokens += [bos_token]
                    segment_ids += [sequence_b_segment_id]
                tokens += tokens_b
                segment_ids += [sequence_b_segment_id] * (len(tokens_b))
                if eos_token:
                    tokens += [eos_token]
                    segment_ids += [sequence_b_segment_id]

            # Add classification token - for BERT models
            if cls_token:
                if cls_token_at_end:
                    tokens += [cls_token]
                    segment_ids += [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids
            input_ids = tokenizer.tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = max_seq_length - len(input_ids)
            pad_token_id = tokenizer.tokens_to_ids([pad_token])[0]
            if pad_on_left:
                input_ids = ([pad_token_id] * padding_length) + input_ids
                input_mask = ([0 if mask_padding_with_zero else 1] *
                              padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] *
                               padding_length) + segment_ids
            else:
                input_ids = input_ids + ([pad_token_id] * padding_length)
                input_mask = input_mask + (
                    [0 if mask_padding_with_zero else 1] * padding_length)
                segment_ids = segment_ids + ([pad_token_segment_id] *
                                             padding_length)
            if len(input_ids) != max_seq_length:
                raise ValueError("input_ids must be of length max_seq_length")
            if len(input_mask) != max_seq_length:
                raise ValueError("input_mask must be of length max_seq_length")
            if len(segment_ids) != max_seq_length:
                raise ValueError(
                    "segment_ids must be of length max_seq_length")
            if output_mode == "classification":
                label_id = label_map[example.label]
            elif output_mode == "regression":
                label_id = np.float32(example.label)
            else:
                raise KeyError(output_mode)

            # Log the first few examples in full to make the conversion easy to inspect.
            if ex_index < 5:
                logging.info("*** Example ***")
                logging.info("guid: %s" % (example.guid))
                logging.info("tokens: %s" % " ".join(list(map(str, tokens))))
                logging.info("input_ids: %s" %
                             " ".join(list(map(str, input_ids))))
                logging.info("input_mask: %s" %
                             " ".join(list(map(str, input_mask))))
                logging.info("segment_ids: %s" %
                             " ".join(list(map(str, segment_ids))))
                logging.info("label: %s (id = %d)" % (example.label, label_id))

            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
        return features
예제 #5
0
def get_features(
    queries: List[str],
    max_seq_length: int,
    tokenizer: TokenizerSpec,
    punct_label_ids: dict = None,
    capit_label_ids: dict = None,
    pad_label: str = 'O',
    punct_labels_lines=None,
    capit_labels_lines=None,
    ignore_extra_tokens=False,
    ignore_start_end: Optional[bool] = False,
):
    """
    Tokenizes queries and builds model inputs, masks and (optionally) punctuation/capitalization labels.

    Every query is split on whitespace, each word is tokenized, and the sequence is wrapped into
    ``tokenizer.cls_token`` ... ``tokenizer.sep_token``. Sequences longer than the effective maximum length keep
    only their last positions behind a fresh leading CLS token; shorter ones are padded on the right.

    Args:
        queries: text sequences
        max_seq_length: upper bound on the sequence length, counted with the CLS and SEP positions; it is
            lowered to the longest sequence found in ``queries``
        tokenizer: provides ``cls_token``, ``sep_token``, ``text_to_tokens`` and ``tokens_to_ids``
        punct_label_ids: dict to map punctuation labels to label ids, needed when label lines are given
        capit_label_ids: dict to map capitalization labels to label ids, needed when label lines are given
        pad_label: label whose id (looked up in ``punct_label_ids``) is used for CLS, SEP and padding
            positions of both label sequences
        punct_labels_lines: for every query, one punctuation label (str) per word
        capit_labels_lines: for every query, one capitalization label (str) per word
        ignore_extra_tokens: whether to ignore extra tokens (subtokens after the first one of a word) in the
            loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask

    Returns:
        all_input_ids: input ids for all tokens, padded with 0
        all_segment_ids: token type ids (all zeros)
        all_input_mask: attention mask, 1 for real tokens and 0 for padding
        all_subtokens_mask: 1 for the first subtoken of every word, 0 elsewhere
        all_loss_mask: loss mask to mask out tokens during training
        punct_all_labels: label ids for punctuation task, empty unless both label line lists are given
        capit_all_labels: label ids for capitalization task, empty unless both label line lists are given
        punct_label_ids: the ``punct_label_ids`` argument, returned unchanged
        capit_label_ids: the ``capit_label_ids`` argument, returned unchanged
    """
    # Labels are only produced when both label sources are present.
    with_label = bool(punct_labels_lines and capit_labels_lines)

    all_subtokens, all_loss_mask, all_subtokens_mask, all_input_mask = [], [], [], []
    punct_all_labels, capit_all_labels = [], []
    sent_lengths = []

    for query_idx, query in enumerate(queries):
        words = query.strip().split()

        # Leading CLS position.
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = punct_label_ids[pad_label]
            punct_query_labels = [punct_label_ids[lab] for lab in punct_labels_lines[query_idx]]
            capit_query_labels = [capit_label_ids[lab] for lab in capit_labels_lines[query_idx]]
            punct_labels = [pad_id]
            capit_labels = [pad_id]

        for word_idx, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            continuation = len(word_tokens) - 1
            subtokens += word_tokens

            # The first subtoken of a word always counts; the rest depend on ``ignore_extra_tokens``.
            loss_mask += [1] + [int(not ignore_extra_tokens)] * continuation
            subtokens_mask += [1] + [0] * continuation

            if with_label:
                # Every subtoken inherits the label of its word.
                punct_labels += [punct_query_labels[word_idx]] * len(word_tokens)
                capit_labels += [capit_query_labels[word_idx]] * len(word_tokens)

        # Trailing SEP position.
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)

        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            punct_labels.append(pad_id)
            capit_labels.append(pad_id)
            punct_all_labels.append(punct_labels)
            capit_all_labels.append(capit_labels)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)

    all_input_ids, all_segment_ids = [], []
    too_long_count = 0
    # Start index of the tail kept from an over-long sequence (one slot is left for the new CLS).
    tail_start = -max_seq_length + 1

    for idx, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            too_long_count += 1
            subtokens = [tokenizer.cls_token] + subtokens[tail_start:]
            all_input_mask[idx] = [1] + all_input_mask[idx][tail_start:]
            all_loss_mask[idx] = [int(not ignore_start_end)] + all_loss_mask[idx][tail_start:]
            all_subtokens_mask[idx] = [0] + all_subtokens_mask[idx][tail_start:]
            if with_label:
                punct_all_labels[idx] = [pad_id] + punct_all_labels[idx][tail_start:]
                capit_all_labels[idx] = [pad_id] + capit_all_labels[idx][tail_start:]

        input_ids = tokenizer.tokens_to_ids(subtokens)

        # Right-pad short sequences up to the common length.
        extra = max_seq_length - len(subtokens)
        if extra > 0:
            input_ids = input_ids + [0] * extra
            all_loss_mask[idx] = all_loss_mask[idx] + [0] * extra
            all_subtokens_mask[idx] = all_subtokens_mask[idx] + [0] * extra
            all_input_mask[idx] = all_input_mask[idx] + [0] * extra
            if with_label:
                punct_all_labels[idx] = punct_all_labels[idx] + [pad_id] * extra
                capit_all_labels[idx] = capit_all_labels[idx] + [pad_id] * extra

        all_input_ids.append(input_ids)
        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} are longer than {max_seq_length}')

    def _joined(values):
        # Space separated rendering used in the example log lines.
        return " ".join(map(str, values))

    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" % _joined(all_subtokens[i]))
        logging.info("loss_mask: %s" % _joined(all_loss_mask[i]))
        logging.info("input_mask: %s" % _joined(all_input_mask[i]))
        logging.info("subtokens_mask: %s" % _joined(all_subtokens_mask[i]))
        if with_label:
            logging.info("punct_labels: %s" % _joined(punct_all_labels[i]))
            logging.info("capit_labels: %s" % _joined(capit_all_labels[i]))

    return (
        all_input_ids,
        all_segment_ids,
        all_input_mask,
        all_subtokens_mask,
        all_loss_mask,
        punct_all_labels,
        capit_all_labels,
        punct_label_ids,
        capit_label_ids,
    )
def get_features_infer(
    queries: List[str],
    tokenizer: TokenizerSpec,
    max_seq_length: int = 64,
    step: Optional[int] = 8,
    margin: Optional[int] = 16,
) -> Tuple[List[List[int]], List[List[int]], List[List[int]], List[List[int]],
           List[int], List[int], List[bool], List[bool], ]:
    """
    Splits queries into (possibly overlapping) segments and returns inference features for every segment.

    Args:
        queries: text sequences
        tokenizer: such as AutoTokenizer
        max_seq_length: maximum number of tokens in a segment, including the leading [CLS] and the trailing
            [SEP]. If it is greater than the longest tokenized query plus 2, it is reduced to that value and
            every query yields exactly one segment.
        step: relative shift of consequent segments into which long queries are split. Long queries are split into
            segments which can overlap. Parameter ``step`` controls such overlapping. Imagine that queries are
            tokenized into characters, ``max_seq_length=5``, and ``step=2``. In such a case query "hello" is
            tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``.
            The step actually used never exceeds ``max_seq_length - 2 - 2 * margin``.
        margin: number of subtokens near edges of segments which are not used for punctuation and capitalization
            prediction. The first segment does not have left margin and the last segment does not have right
            margin. For example, if input sequence is tokenized into characters, ``max_seq_length=5``,
            ``step=1``, and ``margin=1``, then query "hello" will be tokenized into segments
            ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'],
            ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions
            computation, margins are removed. In the next list, subtokens which logits are not used for final
            predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*],
            ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``.
            In this function ``margin`` is only validated and used to bound ``step``.

    Returns:
        all_input_ids: list of input ids of all segments
        all_segment_ids: token type ids of all segments
        all_input_mask: attention mask to use for BERT model
        all_subtokens_mask: masks out all subwords besides the first one
        all_quantities_of_preceding_words: number of words in query preceding a segment. Used for joining
            predictions from overlapping segments.
        all_query_ids: index of a query to which segment belongs
        all_is_first: is segment first segment in a query
        all_is_last: is segment last segment in a query
        All eight lists are empty when ``queries`` is empty.
    """
    st = []
    stm = []
    sent_lengths = []
    for query in queries:
        subtokens, subtokens_mask = _get_subtokens_and_subtokens_mask(
            query, tokenizer)
        sent_lengths.append(len(subtokens))
        st.append(subtokens)
        stm.append(subtokens_mask)
    _check_max_seq_length_and_margin_and_step(max_seq_length, margin, step)
    if not sent_lengths:
        # Nothing to segment; ``max()`` below would raise on an empty list.
        return [], [], [], [], [], [], [], []

    longest_query = max(sent_lengths)
    if max_seq_length > longest_query + 2:
        # Every query fits into one segment, so ``margin`` and ``step`` are not needed.
        max_seq_length = longest_query + 2
        step = 1
        # Maximum number of word subtokens in segment. The first and the last tokens in segment are CLS and EOS
        length = max_seq_length - 2
    else:
        # Maximum number of word subtokens in segment. The first and the last tokens in segment are CLS and EOS
        length = max_seq_length - 2
        step = min(length - margin * 2, step)
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)

    all_input_ids, all_segment_ids, all_subtokens_mask, all_input_mask = [], [], [], []
    all_quantities_of_preceding_words, all_query_ids, all_is_first, all_is_last = [], [], [], []
    for q_i, (query_st, query_stm) in enumerate(zip(st, stm)):
        q_inp_ids, q_segment_ids, q_subtokens_mask, q_inp_mask, q_quantities_of_preceding_words = [], [], [], [], []
        # Segment start offsets: at least one segment per query, then advance by ``step`` until the end of
        # the query is covered.
        for start in range(0, max(len(query_st), length) - length + step, step):
            subtokens = [tokenizer.cls_token] + query_st[start:start + length] + [tokenizer.sep_token]
            q_inp_ids.append(tokenizer.tokens_to_ids(subtokens))
            q_segment_ids.append([0] * len(subtokens))
            q_subtokens_mask.append([0] + query_stm[start:start + length] + [0])
            q_inp_mask.append([1] * len(subtokens))
            # Word starts before the segment equal the number of words preceding it.
            q_quantities_of_preceding_words.append(
                np.count_nonzero(query_stm[:start]))
        num_segments = len(q_inp_ids)
        all_input_ids.append(q_inp_ids)
        all_segment_ids.append(q_segment_ids)
        all_subtokens_mask.append(q_subtokens_mask)
        all_input_mask.append(q_inp_mask)
        all_quantities_of_preceding_words.append(
            q_quantities_of_preceding_words)
        all_query_ids.append([q_i] * num_segments)
        all_is_first.append([True] + [False] * (num_segments - 1))
        all_is_last.append([False] * (num_segments - 1) + [True])
    return (
        list(itertools.chain(*all_input_ids)),
        list(itertools.chain(*all_segment_ids)),
        list(itertools.chain(*all_input_mask)),
        list(itertools.chain(*all_subtokens_mask)),
        list(itertools.chain(*all_quantities_of_preceding_words)),
        list(itertools.chain(*all_query_ids)),
        list(itertools.chain(*all_is_first)),
        list(itertools.chain(*all_is_last)),
    )
예제 #7
0
def get_features(
    queries: List[str],
    tokenizer: TokenizerSpec,
    max_seq_length: int = -1,
    label_ids: dict = None,
    pad_label: str = 'O',
    raw_labels: List[str] = None,
    ignore_extra_tokens: bool = False,
    ignore_start_end: bool = False,
):
    """
    Tokenizes queries and builds model inputs, masks and (optionally) per-token labels.

    Every query is split on whitespace, each word is tokenized, and the sequence is wrapped into
    ``tokenizer.cls_token`` ... ``tokenizer.sep_token``. Sequences longer than the effective maximum length keep
    only their last positions behind a fresh leading CLS token; shorter ones are padded on the right.
    Args:
        queries: text sequences
        tokenizer: such as AutoTokenizer
        max_seq_length: upper bound on the sequence length, counted with the CLS and SEP positions. The
            effective value never exceeds the longest sequence in the data; when not positive (e.g. -1), the
            longest sequence length is used
        pad_label: label whose id is used for CLS, SEP and padding positions
        raw_labels: for every query, one label per word; labels are produced only when this is not None
        label_ids: dict to map labels to label ids.
            Required whenever ``raw_labels`` is given.
        ignore_extra_tokens: whether to ignore extra tokens (subtokens after the first one of a word) in the
            loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask
    Returns:
        tuple ``(all_input_ids, all_segment_ids, all_input_mask, all_subtokens_mask, all_loss_mask,
        all_labels)``; ``all_labels`` is empty when ``raw_labels`` is None
    """
    with_label = raw_labels is not None

    all_subtokens, all_loss_mask, all_subtokens_mask, all_input_mask = [], [], [], []
    all_labels = []
    sent_lengths = []

    for query_idx, query in enumerate(queries):
        words = query.strip().split()

        # Leading CLS position.
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = label_ids[pad_label]
            labels = [pad_id]
            query_labels = [label_ids[lab] for lab in raw_labels[query_idx]]

        for word_idx, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)

            # A non-blank word may yield no tokens (e.g. emojis dropped by the tokenizer);
            # represent it with the unknown token so the word keeps a position.
            if len(word_tokens) == 0 and len(word.strip()) > 0:
                word_tokens = [tokenizer.ids_to_tokens(tokenizer.unk_id)]

            continuation = len(word_tokens) - 1
            subtokens += word_tokens

            # The first subtoken of a word always counts; the rest depend on ``ignore_extra_tokens``.
            loss_mask += [1] + [int(not ignore_extra_tokens)] * continuation
            subtokens_mask += [1] + [0] * continuation

            if with_label:
                # Every subtoken inherits the label of its word.
                labels += [query_labels[word_idx]] * len(word_tokens)

        # Trailing SEP position.
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)

        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            labels.append(pad_id)
            all_labels.append(labels)

    max_seq_length_data = max(sent_lengths)
    if max_seq_length > 0:
        max_seq_length = min(max_seq_length, max_seq_length_data)
    else:
        max_seq_length = max_seq_length_data
    logging.info(f'Setting Max Seq length to: {max_seq_length}')
    get_stats(sent_lengths)

    all_input_ids, all_segment_ids = [], []
    too_long_count = 0
    # Start index of the tail kept from an over-long sequence (one slot is left for the new CLS).
    tail_start = -max_seq_length + 1

    for idx, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            too_long_count += 1
            subtokens = [tokenizer.cls_token] + subtokens[tail_start:]
            all_input_mask[idx] = [1] + all_input_mask[idx][tail_start:]
            all_loss_mask[idx] = [int(not ignore_start_end)] + all_loss_mask[idx][tail_start:]
            all_subtokens_mask[idx] = [0] + all_subtokens_mask[idx][tail_start:]
            if with_label:
                all_labels[idx] = [pad_id] + all_labels[idx][tail_start:]

        input_ids = tokenizer.tokens_to_ids(subtokens)

        # Right-pad short sequences up to the common length.
        extra = max_seq_length - len(subtokens)
        if extra > 0:
            input_ids = input_ids + [0] * extra
            all_loss_mask[idx] = all_loss_mask[idx] + [0] * extra
            all_subtokens_mask[idx] = all_subtokens_mask[idx] + [0] * extra
            all_input_mask[idx] = all_input_mask[idx] + [0] * extra
            if with_label:
                all_labels[idx] = all_labels[idx] + [pad_id] * extra

        all_input_ids.append(input_ids)
        all_segment_ids.append([0] * max_seq_length)

    logging.warning(f'{too_long_count} are longer than {max_seq_length}')

    def _joined(values):
        # Space separated rendering used in the example log lines.
        return " ".join(map(str, values))

    # Only the first example is logged.
    for i in range(min(len(all_input_ids), 1)):
        logging.info("*** Example ***")
        logging.info("i: %s", i)
        logging.info("subtokens: %s", _joined(all_subtokens[i]))
        logging.info("loss_mask: %s", _joined(all_loss_mask[i]))
        logging.info("input_mask: %s", _joined(all_input_mask[i]))
        logging.info("subtokens_mask: %s", _joined(all_subtokens_mask[i]))
        if with_label:
            logging.info("labels: %s", _joined(all_labels[i]))
    return (all_input_ids, all_segment_ids, all_input_mask, all_subtokens_mask, all_loss_mask, all_labels)