def get_features(
    queries,
    max_seq_length,
    tokenizer,
    label_ids=None,
    pad_label='O',
    raw_labels=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    """Convert text queries (and optional word-level labels) into BERT-style features.

    Each query is whitespace-split into words, each word is sub-tokenized, and
    [CLS]/[SEP] tokens are added. Sequences longer than the effective max
    length are truncated from the left (keeping [CLS]); shorter ones are
    padded on the right.

    Args:
        queries (list of str): text sequences
        max_seq_length (int): max sequence length minus 2 for [CLS] and [SEP]
        tokenizer (Tokenizer): such as NemoBertTokenizer
        label_ids (dict): dict to map labels to label ids. Starts with
            pad_label->0 and then increases in alphabetical order. Required
            for training and evaluation, not needed for inference.
        pad_label (str): pad value use for labels. by default, it's the
            neutral label.
        raw_labels (list of str): list of labels for every word in a sequence
        ignore_extra_tokens (bool): whether to ignore extra (non-first
            subword) tokens in the loss_mask
        ignore_start_end (bool): whether to ignore bos and eos tokens in the
            loss_mask

    Returns:
        tuple: (all_input_ids, all_segment_ids, all_input_mask,
        all_loss_mask, all_subtokens_mask, all_labels). ``all_labels`` is
        empty when ``raw_labels`` is None.
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_labels = []
    with_label = raw_labels is not None

    # BUGFIX: resolve the padding label id once, up front. The original code
    # assigned pad_id inside the per-query loop and the truncation/padding
    # pass below relied on the leftover loop variable, which is fragile and
    # undefined for empty input.
    if with_label:
        pad_id = label_ids[pad_label]

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            labels = [pad_id]
            query_labels = [label_ids[lab] for lab in raw_labels[i]]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            subtokens.extend(word_tokens)

            # Only the first subtoken of each word always contributes to the
            # loss; the remaining subtokens do so only when
            # ignore_extra_tokens is False.
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            # subtokens_mask marks the first subtoken of every word.
            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                # Every subtoken of a word inherits the word's label.
                labels.extend([query_labels[j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))

        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            labels.append(pad_id)
            all_labels.append(labels)

    # Shrink max_seq_length to the longest observed sentence if that's shorter.
    max_seq_length = min(max_seq_length, max(sent_lengths))
    logging.info(f'Max length: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            # Truncate from the left, always keeping [CLS] at position 0.
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1:]
            # BUGFIX: keep the stored subtoken list in sync with the
            # truncation so the example logging below reflects what is
            # actually encoded (the original left all_subtokens[i] untouched).
            all_subtokens[i] = subtokens
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1:]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1:]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1:]

            if with_label:
                all_labels[i] = [pad_id] + all_labels[i][-max_seq_length + 1:]
            too_long_count += 1

        all_input_ids.append([tokenizer.tokens_to_ids(t) for t in subtokens])

        if len(subtokens) < max_seq_length:
            # Right-pad every feature list up to max_seq_length.
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_labels[i] = all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.warning(f'{too_long_count} are longer than {max_seq_length}')

    # Log the first few examples for manual inspection.
    for i in range(min(len(all_input_ids), 5)):
        logging.info("*** Example ***")
        logging.info("i: %s", i)
        logging.info("subtokens: %s", " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s", " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s", " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s", " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("labels: %s", " ".join(list(map(str, all_labels[i]))))

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask, all_subtokens_mask, all_labels)
def __init__(self, input_file, max_seq_length, tokenizer, num_samples=-1, shuffle=False, use_cache=False):
    """Build (or load from cache) sentence-classification features.

    Args:
        input_file (str): path to a whitespace-separated data file; the first
            line is treated as a header and skipped, and the last field of
            each remaining line is the integer sentence label.
        max_seq_length (int): maximum number of subtokens per sentence,
            including [CLS] and [SEP].
        tokenizer: tokenizer wrapper exposing ``cls_token``, ``sep_token``,
            ``text_to_tokens`` and an underlying ``tokenizer`` with
            ``vocab_size``.
        num_samples (int): if > 0, keep only that many (shuffled) lines;
            -1 keeps everything.
        shuffle (bool): whether to shuffle the input lines (seeded with 0 for
            reproducibility).
        use_cache (bool): whether to read/write an HDF5 feature cache placed
            next to the input file.
    """
    self.input_file = input_file
    self.max_seq_length = max_seq_length
    self.tokenizer = tokenizer
    self.num_samples = num_samples
    self.use_cache = use_cache
    self.shuffle = shuffle
    self.vocab_size = self.tokenizer.tokenizer.vocab_size

    if use_cache:
        data_dir, filename = os.path.split(input_file)
        tokenizer_type = type(tokenizer.tokenizer).__name__
        # BUGFIX: the original passed '.hdf5' as a fourth argument to a
        # three-placeholder format string, so str.format silently dropped the
        # extension. The extension now lives in the format string itself.
        cached_features_file = os.path.join(
            data_dir,
            "cached_{}_{}_{}.hdf5".format(filename[:-4], tokenizer_type, str(max_seq_length)),
        )

    if use_cache and os.path.exists(cached_features_file):
        self.load_cached_features(cached_features_file)
    else:
        with open(input_file, "r") as f:
            sent_labels, all_sent_subtokens = [], []
            sent_lengths = []
            too_long_count = 0

            # Skip the header line.
            lines = f.readlines()[1:]
            logging.info(f'{input_file}: {len(lines)}')

            if shuffle or num_samples > -1:
                random.seed(0)  # deterministic subsampling/shuffle
                random.shuffle(lines)
                if num_samples > 0:
                    lines = lines[:num_samples]

            for index, line in enumerate(lines):
                if index % 20000 == 0:
                    logging.debug(f"Processing line {index}/{len(lines)}")

                # Last whitespace-separated field is the integer label;
                # everything before it is the sentence.
                sent_label = int(line.split()[-1])
                sent_labels.append(sent_label)
                sent_words = line.strip().split()[:-1]
                sent_subtokens = [tokenizer.cls_token]

                for word in sent_words:
                    word_tokens = tokenizer.text_to_tokens(word)
                    sent_subtokens.extend(word_tokens)

                sent_subtokens.append(tokenizer.sep_token)

                all_sent_subtokens.append(sent_subtokens)
                sent_lengths.append(len(sent_subtokens))

        get_stats(sent_lengths)

        for i in range(len(all_sent_subtokens)):
            if len(all_sent_subtokens[i]) > max_seq_length:
                # Truncate from the left, always keeping [CLS] first.
                shorten_sent = all_sent_subtokens[i][-max_seq_length + 1:]
                all_sent_subtokens[i] = [tokenizer.cls_token] + shorten_sent
                too_long_count += 1

        logging.info(
            f'{too_long_count} out of {len(sent_lengths)} \
sentences with more than {max_seq_length} subtokens.'
        )

        self.convert_sequences_to_features(all_sent_subtokens, sent_labels, tokenizer, max_seq_length)

        if self.use_cache:
            self.cache_features(cached_features_file, self.features)

            # update self.features to use features from hdf5
            self.load_cached_features(cached_features_file)