Example #1
def augment_dataset(dataset: tf.data.Dataset):
    """Augments a Sudoku dataset with validity-preserving transformations.

    Based on Sudoku isomorphisms: https://en.wikipedia.org/wiki/Mathematics_of_Sudoku
    """
    # Each pass appends a row-cycled copy of the current dataset,
    # doubling the number of examples.
    for _ in range(2):
        shuffled_rows = dataset.map(cycle_row_example,
                                    num_parallel_calls=AUTOTUNE)
        dataset = dataset.concatenate(shuffled_rows)

    # Same pattern with column cycling.
    for _ in range(2):
        shuffled_cols = dataset.map(cycle_col_example,
                                    num_parallel_calls=AUTOTUNE)
        dataset = dataset.concatenate(shuffled_cols)

    # Finally, append the transposed boards.
    transposed = dataset.map(transpose, num_parallel_calls=AUTOTUNE)
    return dataset.concatenate(transposed)
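
The helpers `cycle_row_example`, `cycle_col_example`, `transpose`, and `AUTOTUNE` come from the surrounding project. Below is a minimal, self-contained sketch of the same double-by-concatenation pattern, using a hypothetical `cycle_rows` stand-in instead of the original helpers.

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def cycle_rows(board):
    # Hypothetical stand-in for cycle_row_example: cycle rows by one position.
    return tf.roll(board, shift=1, axis=0)

boards = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 9, 9], tf.int32))  # 4 dummy boards
augmented = boards.concatenate(boards.map(cycle_rows, num_parallel_calls=AUTOTUNE))
print(augmented.cardinality().numpy())  # 8 -- each concatenate pass doubles the data
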
Example #2
def pad_dataset(dataset: tf.data.Dataset,
                *,
                batch_dims: Sequence[int],
                pad_up_to_batches: Optional[int] = None,
                cardinality: Optional[int] = None):
    """Adds padding to a dataset.

  Args:
    dataset: The dataset to be padded.
    batch_dims: List of size of batch dimensions. Multiple batch dimension can
      be used to provide inputs for multiple devices. E.g.
      [jax.local_device_count(), batch_size // jax.device_count()].
    pad_up_to_batches: Set this option to process the entire dataset. When set,
      then the dataset is first padded to the specified number of batches. A new
      feature called "mask" is added to every batch. This feature is set to
      `True` for every example that comes from `dataset_builder`, and to `False`
      for every example that is padded to get to the specified number of
      batches. Note that the specified `dataset_builder` and `split` must result
      in at least `pad_up_to_batches` (possibly partial) batches. If `None`,
      derives from `batch_dims` and `cardinality` such that `pad_up_to_batches *
      batch_dims == cardinality`. Note that `cardinality` is what you pass in,
      not necessarily the original full dataset size if you decide to shard it
      per host.
    cardinality: Number of examples in the dataset. Only needed when the
      cardinality cannot be retrieved via `ds.cardinalty()` (e.g. because of
      using `ds.filter()`).

  Returns:
    The padded dataset, with the added feature "mask" that is set to `True` for
    examples from the original `dataset` and to `False` for padded examples.
  """
    if not isinstance(dataset.element_spec, dict):
        raise ValueError("The dataset must have dictionary elements.")
    if cardinality is None:
        cardinality = dataset.cardinality()
        if cardinality == tf.data.UNKNOWN_CARDINALITY:
            raise ValueError(
                "Cannot determine dataset cardinality. This can happen when you "
                "use a `.filter()` on the dataset. Please provide the cardinality "
                "as an argument to `pad_dataset()`.")
    if "mask" in dataset.element_spec:
        raise ValueError("Dataset already contains a feature named \"mask\".")
    if pad_up_to_batches is None:
        pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))

    # A single filler example with the same structure as the dataset elements
    # and `mask=False`; it is repeated below to make up the padding.
    filler_element = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype)[None],
        dataset.element_spec)
    filler_element["mask"] = [False]
    filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)

    # Mark every real example with `mask=True`.
    dataset = dataset.map(lambda features: dict(mask=True, **features),
                          num_parallel_calls=AUTOTUNE)
    padding = pad_up_to_batches * np.prod(batch_dims) - int(cardinality)
    assert padding >= 0, (
        f"Invalid padding={padding} (batch_dims={batch_dims}, cardinality="
        f"{cardinality}, pad_up_to_batches={pad_up_to_batches})")
    # Append the repeated filler examples after the real data.
    return dataset.concatenate(filler_dataset.repeat(padding))
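
A hedged usage sketch, assuming the `pad_dataset` function above is in scope together with its module-level imports (`tensorflow as tf`, `numpy as np`, and `AUTOTUNE`).

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({"x": tf.range(10)})    # 10 examples
padded = pad_dataset(ds, batch_dims=[4], pad_up_to_batches=3)   # pad 10 -> 12 examples

for batch in padded.batch(4):
    # The final batch ends with two filler examples whose mask is False.
    print(batch["x"].numpy(), batch["mask"].numpy())
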
Example #3
def accumulated_batch(
    dataset: tf.data.Dataset,
    accumulator: Union[Accumulator, Mapping, Iterable],
    **map_kwargs,
):
    """Batches `dataset` by accumulating consecutive elements.

    Elements are appended to the accumulator until its validity conditions
    would be violated; at that point the accumulated state is finalized and
    emitted as a single element, and accumulation restarts with the element
    that did not fit.
    """
    accumulator = accumulator_structure(accumulator)

    def initial_map_fn(*args):
        # Pair every element with a `final=False` flag.
        if len(args) == 1:
            (args,) = args
        return args, False

    @tf.function
    def scan_fn(state, el_and_final):
        el, final = el_and_final
        new_state = accumulator.append(state, el)
        valid = tf.reduce_all(accumulator.valid_conditions(new_state))
        invalid = tf.logical_not(valid)
        if invalid:
            # The element does not fit: restart accumulation from a fresh
            # state and flag the previous state for emission downstream.
            new_state = accumulator.append(accumulator.initial_state(), el)
        return new_state, (state, tf.logical_or(invalid, final))

    def filter_fn(state, emit):
        # Keep only the states flagged for emission.
        del state
        return emit

    def map_fn(state, emit):
        del emit
        return accumulator.finalize(state)

    cardinality = tf.data.experimental.cardinality(dataset)

    dataset = dataset.map(initial_map_fn)
    if cardinality != tf.data.experimental.INFINITE_CARDINALITY:
        # Append an (empty, final=True) sentinel element so that the last,
        # possibly partial, accumulation is still emitted.
        el_spec = dataset.element_spec[0]
        empty_el = tf.nest.map_structure(
            lambda spec: tf.zeros(
                [1, *(0 if s is None else s for s in spec.shape)], dtype=spec.dtype
            ),
            el_spec,
        )
        true_el = tf.ones((1,), dtype=tf.bool)
        dataset = dataset.concatenate(
            tf.data.Dataset.from_tensor_slices((empty_el, true_el))
        )

    dataset = dataset.apply(
        tf.data.experimental.scan(accumulator.initial_state(), scan_fn)
    )

    dataset = dataset.filter(filter_fn)
    dataset = dataset.map(map_fn, **map_kwargs)
    return dataset
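
The `Accumulator` interface and `accumulator_structure` are project-specific, so the sketch below only illustrates the sentinel trick used above: concatenating a single `(dummy, final=True)` element so a scan-based batcher also flushes its last, partial accumulation. The "sum until the total would exceed 5" accumulator is made up for illustration and is not the original API.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.constant([3, 1, 4, 1, 5], tf.int32))
ds = ds.map(lambda x: (x, False))                                     # (element, final=False)
ds = ds.concatenate(
    tf.data.Dataset.from_tensors((tf.constant(0, tf.int32), True)))   # sentinel flushes the tail

def scan_fn(total, el_and_final):
    el, final = el_and_final
    new_total = total + el
    overflow = new_total > 5
    new_total = tf.where(overflow, el, new_total)   # restart accumulation on overflow
    return new_total, (total, tf.logical_or(overflow, final))

ds = ds.apply(tf.data.experimental.scan(tf.constant(0, tf.int32), scan_fn))
ds = ds.filter(lambda total, emit: emit).map(lambda total, emit: total)
print(list(ds.as_numpy_iterator()))   # [4, 5, 5]: sums of [3, 1], [4, 1], [5]
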
Example #4
def get_ratio_enforced_dataset(self, positive_training_dataset: tf.data.Dataset,
                               negative_training_dataset: tf.data.Dataset,
                               positive_to_negative_data_ratio: float) -> tf.data.Dataset:
    """Generates a dataset with an enforced positive-to-negative example ratio."""
    if positive_to_negative_data_ratio is not None:
        # Note: counting the examples materializes both datasets once.
        positive_count = len(list(positive_training_dataset))
        negative_count = len(list(negative_training_dataset))
        existing_ratio = positive_count / negative_count
        if existing_ratio < positive_to_negative_data_ratio:
            # Too few positives: oversample the positive dataset.
            desired_number_of_positive_examples = int(positive_to_negative_data_ratio * negative_count)
            positive_training_dataset = self.repeat_dataset_to_size(positive_training_dataset,
                                                                    desired_number_of_positive_examples)
        else:
            # Too few negatives: oversample the negative dataset.
            desired_number_of_negative_examples = int((1 / positive_to_negative_data_ratio) * positive_count)
            negative_training_dataset = self.repeat_dataset_to_size(negative_training_dataset,
                                                                    desired_number_of_negative_examples)
    return positive_training_dataset.concatenate(negative_training_dataset)
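
`repeat_dataset_to_size` is defined elsewhere in the original class; the sketch below assumes it simply cycles a dataset up to the target size, to show the oversample-then-concatenate idea in isolation.

import tensorflow as tf

def repeat_dataset_to_size(dataset: tf.data.Dataset, size: int) -> tf.data.Dataset:
    # Assumed behaviour of the original helper: cycle the dataset to `size` examples.
    return dataset.repeat().take(size)

positives = tf.data.Dataset.range(2)         # 2 positive examples
negatives = tf.data.Dataset.range(10, 16)    # 6 negative examples

ratio = 1.0                                  # desired positives per negative
positives = repeat_dataset_to_size(positives, int(ratio * 6))   # oversample 2 -> 6

balanced = positives.concatenate(negatives)  # 12 examples; typically shuffled afterwards
print(balanced.cardinality().numpy())        # 12
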
Example #5
def dataset_join(dataset_left: tf.data.Dataset,
                 dataset_right: tf.data.Dataset) -> tf.data.Dataset:
    """Concatenates `dataset_right` after `dataset_left`."""
    dataset_joined = dataset_left.concatenate(dataset_right)
    return dataset_joined
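
As a usage note, `concatenate` requires both datasets to have compatible element specs (same structure and dtypes, compatible shapes); a small sketch:

import tensorflow as tf

left = tf.data.Dataset.from_tensor_slices([1, 2, 3])
right = tf.data.Dataset.from_tensor_slices([4, 5])
joined = dataset_join(left, right)
print(list(joined.as_numpy_iterator()))   # [1, 2, 3, 4, 5]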