def convert_examples_to_dataset(
    task, examples: list, tokenizer, feat_spec: FeaturizationSpec, phase: str, verbose=False
):
    """Create ListDataset containing DataRows and metadata.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        ListDataset containing DataRows and metadata.

    """
    data_rows = tokenize_and_featurize(
        task=task,
        examples=examples,
        tokenizer=tokenizer,
        feat_spec=feat_spec,
        phase=phase,
        verbose=verbose,
    )
    metadata = {"example_id": list(range(len(data_rows)))}
    data = []
    for i, data_row in enumerate(data_rows):
        metadata_row = {k: v[i] for k, v in metadata.items()}
        data.append({"data_row": data_row, "metadata": metadata_row})
    return torch_utils.ListDataset(data)
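

# Usage sketch (a minimal, hedged example; `my_task`, `my_examples`, `my_tokenizer`, and
# `my_feat_spec` are placeholders for objects constructed elsewhere in the pipeline):
#
#     dataset = convert_examples_to_dataset(
#         task=my_task,
#         examples=my_examples,
#         tokenizer=my_tokenizer,
#         feat_spec=my_feat_spec,
#         phase="train",
#         verbose=True,
#     )
#     # Each element of dataset.data pairs a featurized DataRow with its metadata:
#     #     {"data_row": <DataRow>, "metadata": {"example_id": 0}}
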
def smart_truncate(dataset: torch_utils.ListDataset, max_seq_length: int, verbose: bool = False):
    """Truncate data to the length of the longest example in the dataset.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the longest sequence.

    """
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")
    valid_length_ls = []
    range_idx = np.arange(max_seq_length)
    for datum in dataset.data:
        # TODO: document why reshape and max happen here (for cola this isn't necessary).
        #       (issue #1185)
        indexer = datum["data_row"].input_mask.reshape(-1, max_seq_length).max(-2)
        valid_length_ls.append(range_idx[indexer.astype(bool)].max() + 1)
    max_valid_length = max(valid_length_ls)

    if max_valid_length == max_seq_length:
        return dataset, max_seq_length

    new_datum_ls = []
    for step, datum in enumerate(
        maybe_tqdm(dataset.data, desc="Smart truncate data", verbose=verbose)
    ):
        regular_log(logger, step)
        new_datum_ls.append(
            smart_truncate_datum(
                datum=datum,
                max_seq_length=max_seq_length,
                max_valid_length=max_valid_length,
            )
        )
    new_dataset = torch_utils.ListDataset(new_datum_ls)
    return new_dataset, max_valid_length
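

# Illustration of the valid-length computation in smart_truncate (a toy sketch; the
# input_mask values below are made up for demonstration):
#
#     >>> import numpy as np
#     >>> max_seq_length = 8
#     >>> input_mask = np.array([1, 1, 1, 1, 0, 0, 0, 0])
#     >>> indexer = input_mask.reshape(-1, max_seq_length).max(-2)
#     >>> np.arange(max_seq_length)[indexer.astype(bool)].max() + 1
#     4
#
# For rows whose input_mask covers multiple segments (shape (num_segments * max_seq_length,)),
# the reshape/max presumably collapses the segment axis so that the longest segment determines
# the valid length; see issue #1185 referenced above.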