def __init__(self, input_file, max_seq_length, tokenizer, num_samples=-1, shuffle=True):
    with open(input_file, "r") as f:
        sent_labels, all_sent_subtokens = [], []
        sent_lengths = []
        too_long_count = 0

        # skip the header line
        lines = f.readlines()[1:]
        logging.info(f'{input_file}: {len(lines)}')

        if shuffle or num_samples > -1:
            random.seed(0)
            random.shuffle(lines)
            if num_samples > 0:
                lines = lines[:num_samples]

        for index, line in enumerate(lines):
            if index % 20000 == 0:
                logging.debug(f"Processing line {index}/{len(lines)}")

            # the last whitespace-separated column is the sentence label
            sent_label = int(line.split()[-1])
            sent_labels.append(sent_label)
            sent_words = line.strip().split()[:-1]
            sent_subtokens = ['[CLS]']

            for word in sent_words:
                word_tokens = tokenizer.tokenize(word)
                sent_subtokens.extend(word_tokens)

            sent_subtokens.append('[SEP]')

            all_sent_subtokens.append(sent_subtokens)
            sent_lengths.append(len(sent_subtokens))

    get_stats(sent_lengths)
    self.max_seq_length = min(max_seq_length, max(sent_lengths))

    for i in range(len(all_sent_subtokens)):
        if len(all_sent_subtokens[i]) > self.max_seq_length:
            # truncate from the left, keeping [CLS] and the trailing subtokens
            shorten_sent = all_sent_subtokens[i][-self.max_seq_length + 1:]
            all_sent_subtokens[i] = ['[CLS]'] + shorten_sent
            too_long_count += 1

    logging.info(
        f'{too_long_count} out of {len(sent_lengths)} '
        f'sentences with more than {self.max_seq_length} subtokens.'
    )

    self.convert_sequences_to_features(all_sent_subtokens, sent_labels, tokenizer, self.max_seq_length)

    self.tokenizer = tokenizer
    self.vocab_size = self.tokenizer.vocab_size
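# A minimal usage sketch for the constructor above. Everything in this example
# is an assumption, not part of the original source: the enclosing class name
# `BertTextClassificationDataset` is hypothetical (only __init__ is shown), and
# any tokenizer exposing .tokenize() and .vocab_size should work, e.g. a
# HuggingFace BertTokenizer. The input file is expected to have a header line
# (dropped by readlines()[1:]) followed by whitespace-separated rows whose last
# column is the integer sentence label.
def _example_build_dataset():
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Expected row layout: "<sentence words> <label>", e.g. "this movie was great 1"
    return BertTextClassificationDataset(
        input_file='train.tsv',
        max_seq_length=128,
        tokenizer=tokenizer,
        num_samples=-1,  # -1 keeps every row
        shuffle=True,
    )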
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    pad_label=128,
    raw_slots=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_slots = []

    with_label = False
    if raw_slots is not None:
        with_label = True

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = ['[CLS]']
        loss_mask = [int(not ignore_start_end)]
        subtokens_mask = [0]
        if with_label:
            slots = [pad_label]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            # only the first subtoken of each word contributes to the loss,
            # unless extra tokens are explicitly kept
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                # repeat the word-level slot label over all of its subtokens
                slots.extend([raw_slots[i][j]] * len(word_tokens))

        # add eos token
        subtokens.append('[SEP]')
        loss_mask.append(int(not ignore_start_end))
        subtokens_mask.append(0)

        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            slots.append(pad_label)
            all_slots.append(slots)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            # truncate from the left, keeping [CLS] and the trailing subtokens
            subtokens = ['[CLS]'] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                all_slots[i] = [pad_label] + all_slots[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

        if len(subtokens) < max_seq_length:
            # pad everything on the right up to max_seq_length
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_slots[i] = all_slots[i] + [pad_label] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} are longer than {max_seq_length}')

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask, all_subtokens_mask, all_slots)
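# A minimal call sketch for get_features() above. The queries and slot ids are
# illustrative assumptions, and `tokenizer` is assumed to expose
# text_to_tokens() and tokens_to_ids(), the interface the function body relies
# on. Note that raw_slots carries one label per *word*; the function expands
# each label across that word's subtokens.
def _example_get_slot_features(tokenizer):
    queries = ['play some jazz', 'book a table for two']
    raw_slots = [
        [0, 0, 54],        # one slot id per word of the first query
        [0, 0, 12, 0, 0],  # one slot id per word of the second query
    ]
    (input_ids, segment_ids, input_mask,
     loss_mask, subtokens_mask, slots) = get_features(
        queries,
        max_seq_length=16,
        tokenizer=tokenizer,
        pad_label=128,  # id used to pad the slot-label sequences
        raw_slots=raw_slots,
    )
    return input_ids, slots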
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    punct_label_ids=None,
    capit_label_ids=None,
    pad_label='O',
    punct_labels_lines=None,
    capit_labels_lines=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    """
    Args:
        queries (list of str): text sequences
        max_seq_length (int): max sequence length; [CLS] and [SEP] occupy
            two of these positions
        tokenizer (TokenizerSpec): such as NemoBertTokenizer
        punct_label_ids (dict): dict to map punctuation labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        capit_label_ids (dict): dict to map capitalization labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        pad_label (str): pad value to use for labels. By default, it's the
            neutral label.
        punct_labels_lines (list of list of str): punctuation labels for every
            word in each sequence
        capit_labels_lines (list of list of str): capitalization labels for
            every word in each sequence
        ignore_extra_tokens (bool): whether to ignore extra tokens in
            the loss_mask
        ignore_start_end (bool): whether to ignore bos and eos tokens in
            the loss_mask
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    punct_all_labels = []
    capit_all_labels = []

    with_label = False
    if punct_labels_lines and capit_labels_lines:
        with_label = True
        # pad_label is expected to map to 0 in both dicts; look the ids up
        # separately so the code stays correct for any mapping
        punct_pad_id = punct_label_ids[pad_label]
        capit_pad_id = capit_label_ids[pad_label]

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [int(not ignore_start_end)]
        subtokens_mask = [0]

        if with_label:
            punct_labels = [punct_pad_id]
            punct_query_labels = [punct_label_ids[lab] for lab in punct_labels_lines[i]]
            capit_labels = [capit_pad_id]
            capit_query_labels = [capit_label_ids[lab] for lab in capit_labels_lines[i]]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            # only the first subtoken of each word contributes to the loss,
            # unless extra tokens are explicitly kept
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                # repeat the word-level labels over all of the word's subtokens
                punct_labels.extend([punct_query_labels[j]] * len(word_tokens))
                capit_labels.extend([capit_query_labels[j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(int(not ignore_start_end))
        subtokens_mask.append(0)

        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            punct_labels.append(punct_pad_id)
            punct_all_labels.append(punct_labels)
            capit_labels.append(capit_pad_id)
            capit_all_labels.append(capit_labels)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            # truncate from the left, keeping [CLS] and the trailing subtokens
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                punct_all_labels[i] = [punct_pad_id] + punct_all_labels[i][-max_seq_length + 1:]
                capit_all_labels[i] = [capit_pad_id] + capit_all_labels[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

        if len(subtokens) < max_seq_length:
            # pad everything on the right up to max_seq_length
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                punct_all_labels[i] = punct_all_labels[i] + [punct_pad_id] * extra
                capit_all_labels[i] = capit_all_labels[i] + [capit_pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} are longer than {max_seq_length}')

    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" % " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s" % " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s" % " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s" % " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("punct_labels: %s" % " ".join(list(map(str, punct_all_labels[i]))))
            logging.info("capit_labels: %s" % " ".join(list(map(str, capit_all_labels[i]))))

    return (
        all_input_ids,
        all_segment_ids,
        all_input_mask,
        all_loss_mask,
        all_subtokens_mask,
        punct_all_labels,
        capit_all_labels,
        punct_label_ids,
        capit_label_ids,
    )
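# A minimal call sketch for the punctuation/capitalization variant above. The
# queries and label sets are illustrative assumptions. As the docstring
# requires, both id dicts map pad_label ('O') to 0, with the remaining labels
# in alphabetical order; the tokenizer must expose cls_token, sep_token,
# text_to_tokens() and tokens_to_ids().
def _example_get_punct_capit_features(tokenizer):
    queries = ['hello how are you', 'i am fine thanks']
    punct_labels_lines = [['O', 'O', 'O', '?'], ['O', 'O', 'O', '.']]
    capit_labels_lines = [['U', 'O', 'O', 'O'], ['U', 'O', 'O', 'O']]
    return get_features(
        queries,
        max_seq_length=32,
        tokenizer=tokenizer,
        punct_label_ids={'O': 0, '.': 1, '?': 2},
        capit_label_ids={'O': 0, 'U': 1},
        pad_label='O',
        punct_labels_lines=punct_labels_lines,
        capit_labels_lines=capit_labels_lines,
    )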
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    pad_label=128,
    raw_slots=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_slots = []

    with_label = False
    if raw_slots is not None:
        with_label = True

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [int(not ignore_start_end)]
        subtokens_mask = [0]
        if with_label:
            slots = [pad_label]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            # only the first subtoken of each word contributes to the loss,
            # unless extra tokens are explicitly kept
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                # repeat the word-level slot label over all of its subtokens
                slots.extend([raw_slots[i][j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(int(not ignore_start_end))
        subtokens_mask.append(0)

        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            slots.append(pad_label)
            all_slots.append(slots)

    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            # truncate from the left, keeping [CLS] and the trailing subtokens
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                all_slots[i] = [pad_label] + all_slots[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

        if len(subtokens) < max_seq_length:
            # pad everything on the right up to max_seq_length
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_slots[i] = all_slots[i] + [pad_label] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.info(f'{too_long_count} are longer than {max_seq_length}')

    logging.info("*** Some Examples of Processed Data ***")
    for i in range(min(len(all_input_ids), 5)):
        logging.info("i: %s" % (i))
        logging.info("subtokens: %s" % " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s" % " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s" % " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s" % " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("slots_label: %s" % " ".join(list(map(str, all_slots[i]))))

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask, all_subtokens_mask, all_slots)
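# Worked illustration of the truncation rule shared by the functions above
# (tokens are illustrative): an over-long sequence keeps [CLS] plus the *last*
# max_seq_length - 1 subtokens, so the head of the sentence is dropped and
# [SEP] survives. The masks and label lists are sliced with the same
# [-max_seq_length + 1:] expression so they stay aligned with the subtokens.
def _example_truncation():
    subtokens = ['[CLS]', 'a', 'b', 'c', 'd', 'e', '[SEP]']  # length 7
    max_seq_length = 5

    if len(subtokens) > max_seq_length:
        # keep [CLS] plus the trailing max_seq_length - 1 = 4 subtokens
        subtokens = ['[CLS]'] + subtokens[-max_seq_length + 1:]

    assert subtokens == ['[CLS]', 'c', 'd', 'e', '[SEP]']
    return subtokens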