Example #1
    def get_features(all_sents,
                     tokenizer,
                     max_seq_length,
                     labels=None,
                     verbose=True):
        """Encode a list of sentences into a list of tuples of (input_ids, segment_ids, input_mask, label)."""
        features = []
        sent_lengths = []
        too_long_count = 0
        for sent_id, sent in enumerate(all_sents):
            if sent_id % 1000 == 0:
                logging.debug(f"Encoding sentence {sent_id}/{len(all_sents)}")
            sent_subtokens = [tokenizer.cls_token]
            for word in sent:
                word_tokens = tokenizer.text_to_tokens(word)
                sent_subtokens.extend(word_tokens)

            if max_seq_length > 0 and len(sent_subtokens) + 1 > max_seq_length:
                # Truncate so that the [SEP] token appended below still fits within max_seq_length.
                sent_subtokens = sent_subtokens[:max_seq_length - 1]
                too_long_count += 1

            sent_subtokens.append(tokenizer.sep_token)
            sent_lengths.append(len(sent_subtokens))

            input_ids = [tokenizer.tokens_to_ids(t) for t in sent_subtokens]

            # The mask has 1 for real tokens and 0 for padding tokens.
            # Only real tokens are attended to.
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)

            if verbose and sent_id < 2:
                logging.info("*** Example ***")
                logging.info(f"example {sent_id}: {sent}")
                logging.info("subtokens: %s" % " ".join(sent_subtokens))
                logging.info("input_ids: %s" % list2str(input_ids))
                logging.info("segment_ids: %s" % list2str(segment_ids))
                logging.info("input_mask: %s" % list2str(input_mask))
                logging.info("label: %s" %
                             labels[sent_id] if labels else "**Not Provided**")

            label = labels[sent_id] if labels else -1
            features.append([
                np.asarray(input_ids),
                np.asarray(segment_ids),
                np.asarray(input_mask), label
            ])

        if max_seq_length > -1 and too_long_count > 0:
            logging.warning(
                f'Found {too_long_count} out of {len(all_sents)} sentences with more than {max_seq_length} subtokens. '
                f'Truncated long sentences from the end.')
        if verbose:
            get_stats(sent_lengths)
        return features
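The features returned above still have per-sentence lengths, so a collation step is needed before batching. Below is a minimal padding sketch, not part of the original module; the helper name collate_features and the pad id 0 are assumptions.

import numpy as np

def collate_features(features, pad_id=0):
    """Right-pad [input_ids, segment_ids, input_mask, label] entries to a common length and stack them."""
    max_len = max(len(f[0]) for f in features)
    batch_ids, batch_segments, batch_mask, batch_labels = [], [], [], []
    for input_ids, segment_ids, input_mask, label in features:
        pad = max_len - len(input_ids)
        # Pad on the right; padded positions carry mask value 0 so the model ignores them.
        batch_ids.append(np.pad(input_ids, (0, pad), constant_values=pad_id))
        batch_segments.append(np.pad(segment_ids, (0, pad), constant_values=0))
        batch_mask.append(np.pad(input_mask, (0, pad), constant_values=0))
        batch_labels.append(label)
    return (np.stack(batch_ids), np.stack(batch_segments),
            np.stack(batch_mask), np.asarray(batch_labels))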
Example #2
def get_features(
    queries: List[str],
    max_seq_length: int,
    tokenizer: TokenizerSpec,
    punct_label_ids: dict = None,
    capit_label_ids: dict = None,
    pad_label: str = 'O',
    punct_labels_lines=None,
    capit_labels_lines=None,
    ignore_extra_tokens=False,
    ignore_start_end: Optional[bool] = False,
):
    """
    Processes the data and returns features.

    Args:
        queries: text sequences
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
        pad_label: pad value to use for labels. By default, it's the neutral label.
        punct_label_ids: dict to map punctuation labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        capit_label_ids: dict to map labels to label ids. Starts
            with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        punct_labels_lines: list of punctuation labels for every word in a sequence (str)
        capit_labels_lines: list of capitalization labels for every word in a sequence (str)
        ignore_extra_tokens: whether to ignore extra tokens in the loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask

    Returns:
        all_input_ids: input ids for all tokens
        all_segment_ids: token type ids
        all_input_mask: attention mask to use for BERT model
        all_subtokens_mask: masks out all subwords besides the first one
        all_loss_mask: loss mask to mask out tokens during training
        punct_all_labels: all labels for punctuation task (ints)
        capit_all_labels: all labels for capitalization task (ints)
        punct_label_ids: label (str) to id (int) map for punctuation task
        capit_label_ids: label (str) to id (int) map for capitalization task
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    punct_all_labels = []
    capit_all_labels = []
    with_label = False

    if punct_labels_lines and capit_labels_lines:
        with_label = True

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = punct_label_ids[pad_label]
            punct_labels = [pad_id]
            punct_query_labels = [
                punct_label_ids[lab] for lab in punct_labels_lines[i]
            ]

            capit_labels = [pad_id]
            capit_query_labels = [
                capit_label_ids[lab] for lab in capit_labels_lines[i]
            ]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] *
                             (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                punct_labels.extend([punct_query_labels[j]] * len(word_tokens))
                capit_labels.extend([capit_query_labels[j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            punct_labels.append(pad_id)
            punct_all_labels.append(punct_labels)
            capit_labels.append(pad_id)
            capit_all_labels.append(capit_labels)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)
                                ] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [
                0
            ] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                punct_all_labels[i] = [
                    pad_id
                ] + punct_all_labels[i][-max_seq_length + 1:]
                capit_all_labels[i] = [
                    pad_id
                ] + capit_all_labels[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append(tokenizer.tokens_to_ids(subtokens))

        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                punct_all_labels[i] = punct_all_labels[i] + [pad_id] * extra
                capit_all_labels[i] = capit_all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} queries are longer than {max_seq_length} subtokens')

    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" %
                     " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s" %
                     " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s" %
                     " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s" %
                     " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("punct_labels: %s" %
                         " ".join(list(map(str, punct_all_labels[i]))))
            logging.info("capit_labels: %s" %
                         " ".join(list(map(str, capit_all_labels[i]))))

    return (
        all_input_ids,
        all_segment_ids,
        all_input_mask,
        all_subtokens_mask,
        all_loss_mask,
        punct_all_labels,
        capit_all_labels,
        punct_label_ids,
        capit_label_ids,
    )
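Per the docstring, punct_label_ids and capit_label_ids must map pad_label to 0 and the remaining labels to ids in alphabetical order. A minimal sketch of building such a map from label lines follows; the helper name build_label_ids is hypothetical and not part of the original module.

def build_label_ids(labels_lines, pad_label='O'):
    """Build a label -> id map: pad_label gets 0, other labels follow in alphabetical order."""
    unique = {lab for line in labels_lines for lab in line}
    unique.discard(pad_label)
    label_ids = {pad_label: 0}
    for idx, lab in enumerate(sorted(unique), start=1):
        label_ids[lab] = idx
    return label_ids

# Example: build_label_ids([['O', ',', '.'], ['?', 'O']]) == {'O': 0, ',': 1, '.': 2, '?': 3}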
def get_features_infer(
    queries: List[str],
    tokenizer: TokenizerSpec,
    max_seq_length: int = 64,
    step: Optional[int] = 8,
    margin: Optional[int] = 16,
) -> Tuple[List[List[int]], List[List[int]], List[List[int]], List[List[int]],
           List[int], List[int], List[bool], List[bool], ]:
    """
    Processes the data and returns features.

    Args:
        queries: text sequences
        tokenizer: such as AutoTokenizer
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        step: relative shift of consecutive segments into which long queries are split. Long queries are split into
            segments which can overlap. Parameter ``step`` controls such overlapping. Imagine that queries are
            tokenized into characters, ``max_seq_length=5``, and ``step=2``. In such a case query "hello" is
            tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``.
        margin: number of subtokens near the edges of a segment which are not used for punctuation and
            capitalization prediction. The first segment does not have a left margin and the last segment
            does not have a right margin. For example, if the input sequence is tokenized into characters,
            ``max_seq_length=5``, ``step=1``, and ``margin=1``, then the query "hello" will be tokenized into
            segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'],
            ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final
            predictions are computed, the margins are removed. In the following list, subtokens whose logits
            are not used for the final predictions are marked with an asterisk:
            ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*], ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*],
            ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``.

    Returns:
        all_input_ids: list of input ids of all segments
        all_segment_ids: token type ids of all segments
        all_input_mask: attention mask to use for BERT model
        all_subtokens_mask: masks out all subwords besides the first one
        all_quantities_of_preceding_words: number of words in query preceding a segment. Used for joining
            predictions from overlapping segments.
        all_query_ids: index of a query to which segment belongs
        all_is_first: whether a segment is the first segment in its query
        all_is_last: whether a segment is the last segment in its query
    """
    st = []
    stm = []
    sent_lengths = []
    for i, query in enumerate(queries):
        subtokens, subtokens_mask = _get_subtokens_and_subtokens_mask(
            query, tokenizer)
        sent_lengths.append(len(subtokens))
        st.append(subtokens)
        stm.append(subtokens_mask)
    _check_max_seq_length_and_margin_and_step(max_seq_length, margin, step)
    if max_seq_length > max(sent_lengths) + 2:
        max_seq_length = max(sent_lengths) + 2
        # If `max_seq_length` is greater than the maximum length of the input queries, parameters ``margin``
        # and ``step`` will not be used.
        step = 1
        # Maximum number of word subtokens in segment. The first and the last tokens in segment are CLS and EOS
        length = max_seq_length - 2
    else:
        # Maximum number of word subtokens in segment. The first and the last tokens in segment are CLS and EOS
        length = max_seq_length - 2
        step = min(length - margin * 2, step)
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    all_input_ids, all_segment_ids, all_subtokens_mask, all_input_mask = [], [], [], []
    all_quantities_of_preceding_words, all_query_ids, all_is_first, all_is_last = [], [], [], []
    for q_i, query_st in enumerate(st):
        q_inp_ids, q_segment_ids, q_subtokens_mask, q_inp_mask, q_quantities_of_preceding_words = [], [], [], [], []
        for i in range(0, max(len(query_st), length) - length + step, step):
            subtokens = [tokenizer.cls_token
                         ] + query_st[i:i + length] + [tokenizer.sep_token]
            q_inp_ids.append(tokenizer.tokens_to_ids(subtokens))
            q_segment_ids.append([0] * len(subtokens))
            q_subtokens_mask.append([0] + stm[q_i][i:i + length] + [0])
            q_inp_mask.append([1] * len(subtokens))
            q_quantities_of_preceding_words.append(
                np.count_nonzero(stm[q_i][:i]))
        all_input_ids.append(q_inp_ids)
        all_segment_ids.append(q_segment_ids)
        all_subtokens_mask.append(q_subtokens_mask)
        all_input_mask.append(q_inp_mask)
        all_quantities_of_preceding_words.append(
            q_quantities_of_preceding_words)
        all_query_ids.append([q_i] * len(q_inp_ids))
        all_is_first.append([True] + [False] * (len(q_inp_ids) - 1))
        all_is_last.append([False] * (len(q_inp_ids) - 1) + [True])
    return (
        list(itertools.chain(*all_input_ids)),
        list(itertools.chain(*all_segment_ids)),
        list(itertools.chain(*all_input_mask)),
        list(itertools.chain(*all_subtokens_mask)),
        list(itertools.chain(*all_quantities_of_preceding_words)),
        list(itertools.chain(*all_query_ids)),
        list(itertools.chain(*all_is_first)),
        list(itertools.chain(*all_is_last)),
    )
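The windowing arithmetic inside get_features_infer can be reproduced in isolation. The sketch below uses the same range expression as the loop above to split a subtoken list into overlapping segments and reproduces the character-level "hello" example from the docstring; the function name and the literal '[CLS]'/'[SEP]' strings are stand-ins for the tokenizer's special tokens.

def split_into_segments(subtokens, max_seq_length=5, step=2, cls='[CLS]', sep='[SEP]'):
    """Split subtokens into overlapping segments of at most max_seq_length tokens, shifted by step."""
    length = max_seq_length - 2  # reserve two positions for [CLS] and [SEP]
    segments = []
    for i in range(0, max(len(subtokens), length) - length + step, step):
        segments.append([cls] + subtokens[i:i + length] + [sep])
    return segments

# split_into_segments(list("hello"), max_seq_length=5, step=2)
# -> [['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]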
Example #4
def get_features(
    queries: List[str],
    tokenizer: TokenizerSpec,
    max_seq_length: int = -1,
    label_ids: dict = None,
    pad_label: str = 'O',
    raw_labels: List[str] = None,
    ignore_extra_tokens: bool = False,
    ignore_start_end: bool = False,
):
    """
    Processes the data and returns features.
    Args:
        queries: text sequences
        tokenizer: such as AutoTokenizer
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]; when -1, use the max length from the data
        pad_label: pad value to use for labels. By default, it's the neutral label.
        raw_labels: list of labels for every word in a sequence
        label_ids: dict to map labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        ignore_extra_tokens: whether to ignore extra tokens in the loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_labels = []
    with_label = False

    if raw_labels is not None:
        with_label = True

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = label_ids[pad_label]
            labels = [pad_id]
            query_labels = [label_ids[lab] for lab in raw_labels[i]]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)

            # to handle emojis that could be neglected during tokenization
            if len(word.strip()) > 0 and len(word_tokens) == 0:
                word_tokens = [tokenizer.ids_to_tokens(tokenizer.unk_id)]

            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                labels.extend([query_labels[j]] * len(word_tokens))
        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            labels.append(pad_id)
            all_labels.append(labels)

    max_seq_length_data = max(sent_lengths)
    max_seq_length = min(max_seq_length, max_seq_length_data) if max_seq_length > 0 else max_seq_length_data
    logging.info(f'Setting Max Seq length to: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1 :]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1 :]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1 :]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1 :]

            if with_label:
                all_labels[i] = [pad_id] + all_labels[i][-max_seq_length + 1 :]
            too_long_count += 1

        all_input_ids.append(tokenizer.tokens_to_ids(subtokens))

        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_labels[i] = all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.warning(f'{too_long_count} queries are longer than {max_seq_length} subtokens')

    for i in range(min(len(all_input_ids), 1)):
        logging.info("*** Example ***")
        logging.info("i: %s", i)
        logging.info("subtokens: %s", " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s", " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s", " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s", " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("labels: %s", " ".join(list(map(str, all_labels[i]))))
    return (all_input_ids, all_segment_ids, all_input_mask, all_subtokens_mask, all_loss_mask, all_labels)
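A usage sketch for the variant above, assuming a NeMo environment where this module and its helpers (logging, get_stats) are importable; the get_tokenizer import path and the toy labels are assumptions, not taken from the original snippet.

# Import path may differ between NeMo versions.
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name='bert-base-uncased')
queries = ['hello world', 'how are you']
raw_labels = [['O', 'B-LOC'], ['O', 'O', 'O']]      # one label per word
label_ids = {'O': 0, 'B-LOC': 1}                    # pad_label 'O' -> 0, rest in alphabetical order

(input_ids, segment_ids, input_mask,
 subtokens_mask, loss_mask, labels) = get_features(
    queries=queries,
    tokenizer=tokenizer,
    max_seq_length=-1,   # -1: pad/truncate to the longest sequence in the data
    label_ids=label_ids,
    pad_label='O',
    raw_labels=raw_labels,
)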