def get_features(
    queries,
    max_seq_length,
    tokenizer,
    label_ids=None,
    pad_label='O',
    raw_labels=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    """
    Args:
    queries (list of str): text sequences
    max_seq_length (int): max sequence length minus 2 for [CLS] and [SEP]
    tokenizer (Tokenizer): such as NemoBertTokenizer
    pad_label (str): pad value use for labels.
        by default, it's the neutral label.
    raw_labels (list of str): list of labels for every word in a sequence
    label_ids (dict): dict to map labels to label ids. Starts
        with pad_label->0 and then increases in alphabetical order.
        Required for training and evaluation, not needed for inference.
    ignore_extra_tokens (bool): whether to ignore extra tokens in
        the loss_mask,
    ignore_start_end (bool): whether to ignore bos and eos tokens in
        the loss_mask
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_labels = []
    with_label = False

    if raw_labels is not None:
        with_label = True

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = label_ids[pad_label]
            labels = [pad_id]
            query_labels = [label_ids[lab] for lab in raw_labels[i]]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] *
                             (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                labels.extend([query_labels[j]] * len(word_tokens))
        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            labels.append(pad_id)
            all_labels.append(labels)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0
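
    # Truncate sequences that exceed max_seq_length: keep [CLS] plus the last
    # (max_seq_length - 1) subtokens, then convert subtokens to ids and
    # zero-pad everything to max_seq_length.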

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            # keep all_subtokens[i] in sync so the logged example matches the ids
            all_subtokens[i] = subtokens
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)
                                ] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [
                0
            ] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                all_labels[i] = [pad_id] + all_labels[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_labels[i] = all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.warning(
        f'{too_long_count} sentences are longer than {max_seq_length} '
        'subtokens and were truncated')

    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s", i)
        logging.info("subtokens: %s", " ".join(list(map(str,
                                                        all_subtokens[i]))))
        logging.info("loss_mask: %s", " ".join(list(map(str,
                                                        all_loss_mask[i]))))
        logging.info("input_mask: %s",
                     " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s",
                     " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("labels: %s", " ".join(list(map(str, all_labels[i]))))
    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask,
            all_subtokens_mask, all_labels)
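
A minimal usage sketch (not from the original source): the queries, labels and label_ids below are illustrative, `tokenizer` is assumed to be an already constructed NeMo-style tokenizer (e.g. NemoBertTokenizer) providing cls_token, sep_token, text_to_tokens and tokens_to_ids, and logging and get_stats are assumed to be available at module level.

queries = ['we saw paris in june', 'nvidia is in santa clara']
raw_labels = [['O', 'O', 'B-LOC', 'O', 'O'],
              ['B-ORG', 'O', 'O', 'B-LOC', 'I-LOC']]
label_ids = {'O': 0, 'B-LOC': 1, 'B-ORG': 2, 'I-LOC': 3}  # pad_label 'O' -> 0

(input_ids, segment_ids, input_mask, loss_mask,
 subtokens_mask, labels) = get_features(queries,
                                         max_seq_length=128,
                                         tokenizer=tokenizer,
                                         label_ids=label_ids,
                                         raw_labels=raw_labels)
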
Example #2
    def __init__(self,
                 input_file,
                 max_seq_length,
                 tokenizer,
                 num_samples=-1,
                 shuffle=False,
                 use_cache=False):
        self.input_file = input_file
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.num_samples = num_samples
        self.use_cache = use_cache
        self.shuffle = shuffle
        self.vocab_size = self.tokenizer.tokenizer.vocab_size

        if use_cache:
            data_dir, filename = os.path.split(input_file)
            tokenizer_type = type(tokenizer.tokenizer).__name__
            # The cache file name encodes the source file name, tokenizer type
            # and max sequence length.
            cached_features_file = os.path.join(
                data_dir,
                "cached_{}_{}_{}{}".format(filename[:-4], tokenizer_type,
                                           str(max_seq_length), '.hdf5'))

        if use_cache and os.path.exists(cached_features_file):
            self.load_cached_features(cached_features_file)

        else:
            with open(input_file, "r") as f:
                sent_labels, all_sent_subtokens = [], []
                sent_lengths = []
                too_long_count = 0

                lines = f.readlines()[1:]  # drop the header line
                logging.info(f'{input_file}: {len(lines)}')

                if shuffle or num_samples > -1:
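                    # fixed seed so shuffling and subsampling are reproducible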
                    random.seed(0)
                    random.shuffle(lines)
                    if num_samples > 0:
                        lines = lines[:num_samples]

                for index, line in enumerate(lines):
                    if index % 20000 == 0:
                        logging.debug(f"Processing line {index}/{len(lines)}")

                    sent_label = int(line.split()[-1])
                    sent_labels.append(sent_label)
                    sent_words = line.strip().split()[:-1]
                    sent_subtokens = [tokenizer.cls_token]

                    for word in sent_words:
                        word_tokens = tokenizer.text_to_tokens(word)
                        sent_subtokens.extend(word_tokens)

                    sent_subtokens.append(tokenizer.sep_token)

                    all_sent_subtokens.append(sent_subtokens)
                    sent_lengths.append(len(sent_subtokens))
            get_stats(sent_lengths)

            for i in range(len(all_sent_subtokens)):
                if len(all_sent_subtokens[i]) > max_seq_length:
                    shorten_sent = all_sent_subtokens[i][-max_seq_length + 1:]
                    all_sent_subtokens[i] = [tokenizer.cls_token
                                             ] + shorten_sent
                    too_long_count += 1

            logging.info(
                f'{too_long_count} out of {len(sent_lengths)} sentences '
                f'have more than {max_seq_length} subtokens.')

            self.convert_sequences_to_features(all_sent_subtokens, sent_labels,
                                               tokenizer, max_seq_length)

            if self.use_cache:
                self.cache_features(cached_features_file, self.features)

                # update self.features to use features from hdf5
                self.load_cached_features(cached_features_file)
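
A minimal usage sketch (hypothetical names, not from the original source): this __init__ is assumed to belong to a text-classification dataset class, called here BertTextClassificationDataset, whose input file contains a header row followed by one "<text> <label>" line per sample, and whose tokenizer argument is a NeMo-style wrapper exposing .tokenizer, .cls_token, .sep_token and .text_to_tokens().

dataset = BertTextClassificationDataset(
    input_file='train.tsv',   # hypothetical path
    max_seq_length=128,
    tokenizer=tokenizer,      # assumed NeMo-style tokenizer wrapper
    num_samples=-1,           # -1 keeps every sample after the header
    shuffle=True,
    use_cache=False)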