Exemplo n.º 1
0
def convert_examples_to_dataset(task,
                                examples: list,
                                tokenizer,
                                feat_spec: FeaturizationSpec,
                                phase: str,
                                verbose=False):
    """Featurize task examples and wrap them, with per-row metadata, in a ListDataset.

    Each element of the resulting dataset is a dict with two keys:
    ``"data_row"`` (one featurized row from ``tokenize_and_featurize``) and
    ``"metadata"`` (a dict holding that row's ``"example_id"``, which is simply
    the row's position in the featurized output).

    Args:
        task (Task): Task object.
        examples (list[Example]): list of task Examples.
        tokenizer: tokenizer forwarded to ``tokenize_and_featurize``
            (type still TODO, Issue #44).
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        ListDataset containing DataRows and metadata.

    """
    featurized_rows = tokenize_and_featurize(
        task=task,
        examples=examples,
        tokenizer=tokenizer,
        feat_spec=feat_spec,
        phase=phase,
        verbose=verbose,
    )
    # Column-oriented metadata; each row gets the value at its own index.
    metadata_columns = {"example_id": list(range(len(featurized_rows)))}
    records = [
        {
            "data_row": row,
            "metadata": {name: column[idx] for name, column in metadata_columns.items()},
        }
        for idx, row in enumerate(featurized_rows)
    ]
    return torch_utils.ListDataset(records)
Exemplo n.º 2
0
def smart_truncate(dataset: torch_utils.ListDataset,
                   max_seq_length: int,
                   verbose: bool = False):
    """Truncate data to the length of the longest example in the dataset.

    The valid length of an example is one past the last position set in its
    ``input_mask``. If the longest valid length across the dataset is shorter
    than ``max_seq_length``, every datum is truncated to that length via
    ``smart_truncate_datum``. Otherwise the dataset is returned unchanged.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the
        longest sequence. An empty dataset is returned unchanged together with
        ``max_seq_length``.

    Raises:
        RuntimeError: if the first data row has no ``input_mask`` field.

    """
    # Nothing to measure or truncate. Without this guard, dataset.data[0]
    # raises IndexError (and max() below would fail on an empty list).
    if not dataset.data:
        return dataset, max_seq_length
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")
    valid_length_ls = []
    range_idx = np.arange(max_seq_length)
    for datum in dataset.data:
        # input_mask may pack several sequences of length max_seq_length
        # (presumably one per choice for multiple-choice tasks; confirm).
        # Reshaping to (-1, max_seq_length) and taking the max over the rows
        # marks a position valid if ANY packed sequence uses it. For a single
        # sequence the reshape yields one row, so the max is a no-op
        # (addresses issue #1185).
        indexer = datum["data_row"].input_mask.reshape(-1,
                                                       max_seq_length).max(-2)
        # Index of the last valid position, plus one.
        # NOTE(review): an all-zero mask makes this max() raise ValueError
        # (empty selection); confirm such rows cannot occur.
        valid_length_ls.append(range_idx[indexer.astype(bool)].max() + 1)
    # Cast from a numpy integer so callers receive the documented plain int.
    max_valid_length = int(max(valid_length_ls))

    if max_valid_length == max_seq_length:
        # Some example already uses the full length; truncation is a no-op.
        return dataset, max_seq_length

    new_datum_ls = []
    for step, datum in enumerate(
            maybe_tqdm(dataset.data,
                       desc="Smart truncate data",
                       verbose=verbose)):
        regular_log(logger, step)
        new_datum_ls.append(
            smart_truncate_datum(
                datum=datum,
                max_seq_length=max_seq_length,
                max_valid_length=max_valid_length,
            ))
    new_dataset = torch_utils.ListDataset(new_datum_ls)
    return new_dataset, max_valid_length