Example #1
0
    def __init__(self,
                 input_file,
                 max_seq_length,
                 tokenizer,
                 num_samples=-1,
                 shuffle=True):
        with open(input_file, "r") as f:
            sent_labels, all_sent_subtokens = [], []
            sent_lengths = []
            too_long_count = 0

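            # skip the first line of the file (assumed to be a header row)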
            lines = f.readlines()[1:]
            logging.info(f'{input_file}: {len(lines)}')

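            # shuffle with a fixed seed and optionally keep only num_samples lines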
            if shuffle or num_samples > -1:
                random.seed(0)
                random.shuffle(lines)
                if num_samples > 0:
                    lines = lines[:num_samples]

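            # each line holds the space-separated words followed by an integer sentence label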
            for index, line in enumerate(lines):
                if index % 20000 == 0:
                    logging.debug(f"Processing line {index}/{len(lines)}")

                sent_label = int(line.split()[-1])
                sent_labels.append(sent_label)
                sent_words = line.strip().split()[:-1]
                sent_subtokens = ['[CLS]']

                for word in sent_words:
                    word_tokens = tokenizer.tokenize(word)
                    sent_subtokens.extend(word_tokens)

                sent_subtokens.append('[SEP]')

                all_sent_subtokens.append(sent_subtokens)
                sent_lengths.append(len(sent_subtokens))

        get_stats(sent_lengths)
        self.max_seq_length = min(max_seq_length, max(sent_lengths))

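        # truncate long sentences: keep [CLS] plus the last self.max_seq_length - 1 subtokens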
        for i in range(len(all_sent_subtokens)):
            if len(all_sent_subtokens[i]) > self.max_seq_length:
                shorten_sent = all_sent_subtokens[i][-self.max_seq_length + 1:]
                all_sent_subtokens[i] = ['[CLS]'] + shorten_sent
                too_long_count += 1

        logging.info(f'{too_long_count} out of {len(sent_lengths)} '
                     f'sentences with more than {max_seq_length} subtokens.')

        self.convert_sequences_to_features(all_sent_subtokens, sent_labels,
                                           tokenizer, self.max_seq_length)

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
Example #2
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    pad_label=128,
    raw_slots=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_slots = []

    with_label = False
    if raw_slots is not None:
        with_label = True

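    # build [CLS] ... [SEP] subtoken sequences plus loss/subtoken masks (and slot labels if provided)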
    for i, query in enumerate(queries):
        words = query.strip().split()
        subtokens = ['[CLS]']
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            slots = [pad_label]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([not ignore_extra_tokens] *
                             (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                slots.extend([raw_slots[i][j]] * len(word_tokens))

        subtokens.append('[SEP]')
        loss_mask.append(not ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))
        if with_label:
            slots.append(pad_label)
            all_slots.append(slots)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

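    # truncate to max_seq_length (keeping [CLS] and the tail), convert tokens to ids, and pad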
    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = ['[CLS]'] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [1 - ignore_start_end
                                ] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [
                0
            ] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                all_slots[i] = [pad_label] + all_slots[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

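        # pad shorter sequences with zeros (and pad_label for slots) up to max_seq_length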
        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_slots[i] = all_slots[i] + [pad_label] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} sequences are longer than {max_seq_length}')

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask,
            all_subtokens_mask, all_slots)
Example #3
0
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    punct_label_ids=None,
    capit_label_ids=None,
    pad_label='O',
    punct_labels_lines=None,
    capit_labels_lines=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    """
    Args:
    queries (list of str): text sequences
    max_seq_length (int): max sequence length minus 2 for [CLS] and [SEP]
    tokenizer (TokenizerSpec): such as NemoBertTokenizer
    pad_label (str): pad value use for labels.
        by default, it's the neutral label.
    punct_label_ids (dict): dict to map punctuation labels to label ids.
        Starts with pad_label->0 and then increases in alphabetical order.
        Required for training and evaluation, not needed for inference.
    capit_label_ids (dict): dict to map labels to label ids. Starts
        with pad_label->0 and then increases in alphabetical order.
        Required for training and evaluation, not needed for inference.
    punct_labels (list of str): list of labels for every word in a sequence
    capit_labels (list of str): list of labels for every word in a sequence
    ignore_extra_tokens (bool): whether to ignore extra tokens in
        the loss_mask,
    ignore_start_end (bool): whether to ignore bos and eos tokens in
        the loss_mask
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    punct_all_labels = []
    capit_all_labels = []
    with_label = False

    if punct_labels_lines and capit_labels_lines:
        with_label = True

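    # build subtoken sequences; punctuation and capitalization label ids are replicated across a word's subtokens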
    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            pad_id = punct_label_ids[pad_label]
            punct_labels = [pad_id]
            punct_query_labels = [
                punct_label_ids[lab] for lab in punct_labels_lines[i]
            ]

            capit_labels = [pad_id]
            capit_query_labels = [
                capit_label_ids[lab] for lab in capit_labels_lines[i]
            ]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] *
                             (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                punct_labels.extend([punct_query_labels[j]] * len(word_tokens))
                capit_labels.extend([capit_query_labels[j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            punct_labels.append(pad_id)
            punct_all_labels.append(punct_labels)
            capit_labels.append(pad_id)
            capit_all_labels.append(capit_labels)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

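    # truncate overly long sequences to max_seq_length, map subtokens to ids, and pad to equal length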
    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)
                                ] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [
                0
            ] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                punct_all_labels[i] = [
                    pad_id
                ] + punct_all_labels[i][-max_seq_length + 1:]
                capit_all_labels[i] = [
                    pad_id
                ] + capit_all_labels[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

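        # zero-pad ids and masks, and pad shorter label sequences with pad_id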
        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                punct_all_labels[i] = punct_all_labels[i] + [pad_id] * extra
                capit_all_labels[i] = capit_all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} sequences are longer than {max_seq_length}')

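    # log the first few processed examples as a sanity check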
    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" %
                     " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s" %
                     " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s" %
                     " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s" %
                     " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("punct_labels: %s" %
                         " ".join(list(map(str, punct_all_labels[i]))))
            logging.info("capit_labels: %s" %
                         " ".join(list(map(str, capit_all_labels[i]))))

    return (
        all_input_ids,
        all_segment_ids,
        all_input_mask,
        all_loss_mask,
        all_subtokens_mask,
        punct_all_labels,
        capit_all_labels,
        punct_label_ids,
        capit_label_ids,
    )
Example #4
0
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    pad_label=128,
    raw_slots=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_slots = []

    with_label = False
    if raw_slots is not None:
        with_label = True

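    # tokenize each query into [CLS] ... [SEP] subtokens with matching masks and per-subtoken slot labels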
    for i, query in enumerate(queries):
        words = query.strip().split()
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            slots = [pad_label]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            loss_mask.append(1)
            loss_mask.extend([not ignore_extra_tokens] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                slots.extend([raw_slots[i][j]] * len(word_tokens))

        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))
        if with_label:
            slots.append(pad_label)
            all_slots.append(slots)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

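    # second pass: truncate sequences that exceed max_seq_length, build input ids, and pad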
    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1 :]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1 :]
            all_loss_mask[i] = [1 - ignore_start_end] + all_loss_mask[i][-max_seq_length + 1 :]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1 :]

            if with_label:
                all_slots[i] = [pad_label] + all_slots[i][-max_seq_length + 1 :]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

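        # pad ids and masks with zeros (and slots with pad_label) to max_seq_length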
        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_slots[i] = all_slots[i] + [pad_label] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} sequences are longer than {max_seq_length}')

    logging.info("*** Some Examples of Processed Data***")
    for i in range(min(len(all_input_ids), 5)):
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" % " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s" % " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s" % " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s" % " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("slots_label: %s" % " ".join(list(map(str, all_slots[i]))))

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask, all_subtokens_mask, all_slots)