Example #1
    def dataset_split_fn(
        client_dataset: tf.data.Dataset
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """A `DatasetSplitFn` built with the given arguments.

        Args:
          client_dataset: `tf.data.Dataset` representing client data.

        Returns:
          A tuple of two `tf.data.Dataset`s, the first to be used for
          reconstruction, the second to be used post-reconstruction.
        """
        # Split dataset if needed. This assumes the dataset has a consistent
        # order across iterations.
        if split_dataset:
            recon_dataset = client_dataset.enumerate().filter(
                recon_condition).map(get_entry)
            post_recon_dataset = client_dataset.enumerate().filter(
                post_recon_condition).map(get_entry)
        else:
            recon_dataset = client_dataset
            post_recon_dataset = client_dataset

        # Apply `recon_epochs` before limiting to a maximum number of batches
        # if needed.
        recon_dataset = recon_dataset.repeat(recon_epochs)
        if recon_steps_max is not None:
            recon_dataset = recon_dataset.take(recon_steps_max)

        # Do the same for post-reconstruction.
        post_recon_dataset = post_recon_dataset.repeat(post_recon_epochs)
        if post_recon_steps_max is not None:
            post_recon_dataset = post_recon_dataset.take(post_recon_steps_max)

        return recon_dataset, post_recon_dataset
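
The names `split_dataset`, `recon_condition`, `post_recon_condition`, `get_entry`, `recon_epochs`, `recon_steps_max`, `post_recon_epochs`, and `post_recon_steps_max` above are captured from an enclosing factory that is not shown. A minimal stand-alone sketch of the same enumerate/filter/map split pattern, where the even/odd interleaving and the toy dataset are assumptions made purely for illustration:

import tensorflow as tf

client_dataset = tf.data.Dataset.range(6)  # toy stand-in for client data

# Even-indexed elements go to reconstruction, odd-indexed ones to
# post-reconstruction; `map(get_entry)` drops the enumeration index again.
get_entry = lambda i, entry: entry
recon_dataset = (client_dataset.enumerate()
                 .filter(lambda i, _: tf.equal(i % 2, 0))
                 .map(get_entry))
post_recon_dataset = (client_dataset.enumerate()
                      .filter(lambda i, _: tf.greater(i % 2, 0))
                      .map(get_entry))

print([int(x) for x in recon_dataset])       # [0, 2, 4]
print([int(x) for x in post_recon_dataset])  # [1, 3, 5]

This only works if the dataset yields its elements in the same order on every pass, which is exactly what the comment in the example points out.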
Example #2
def orb_elts0(ds: tf.data.Dataset):
    """Get the initial orbital elements in a dataset"""
    # Get array with the starting values of the six orbital elements
    orb_a = np.concatenate([data[1]['orb_a'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    orb_e = np.concatenate([data[1]['orb_e'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    orb_inc = np.concatenate([data[1]['orb_inc'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    orb_Omega = np.concatenate([data[1]['orb_Omega'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    orb_omega = np.concatenate([data[1]['orb_omega'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    orb_f = np.concatenate([data[1]['orb_f'][:, 0, :] for i, data in ds.enumerate()], axis=0)
    # H = np.concatenate([data[1]['H'][:] for i, data in ds.enumerate()], axis=0)
    
    # Combined orbital elements; array of shape (num_trajectories, 12)
    orb_elt = np.concatenate([orb_a, orb_e, orb_inc, orb_Omega, orb_omega, orb_f], axis=1)
    
    return orb_elt
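
The element structure this function expects is not shown; the sketch below assumes batched (inputs, targets) tuples whose targets dictionary holds arrays of shape (batch, time_steps, 2) for each of the six orbital elements, which is consistent with the (num_trajectories, 12) shape mentioned in the comment. All names and shapes here are illustrative assumptions:

import numpy as np
import tensorflow as tf

num_traj, time_steps = 8, 5
names = ['orb_a', 'orb_e', 'orb_inc', 'orb_Omega', 'orb_omega', 'orb_f']
targets = {name: np.random.rand(num_traj, time_steps, 2).astype(np.float32)
           for name in names}
inputs = {'t': np.zeros((num_traj, time_steps), dtype=np.float32)}
ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(4)

orb_elt = orb_elts0(ds)
print(orb_elt.shape)  # (8, 12): six orbital elements, two components each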
Example #3
def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
                                     preprocess_fn: Callable[[Features],
                                                             Features], *,
                                     rng: jnp.ndarray) -> tf.data.Dataset:
    """Maps `ds` using the preprocess_fn and a deterministic RNG per example.

    Args:
      ds: Dataset containing a Python dictionary with the features. The 'rng'
        feature should not exist.
      preprocess_fn: Preprocessing function that takes a Python dictionary of
        tensors and returns a Python dictionary of tensors. The function should
        be convertible into a TF graph.
      rng: Base RNG to use. Per-example RNGs will be derived from this by
        folding in the example index.

    Returns:
      The dataset mapped by the `preprocess_fn`.
    """
    def _fn(example_index: int, features: Features) -> Features:
        example_index = tf.cast(example_index, tf.int32)
        features["rng"] = tf.random.experimental.stateless_fold_in(
            tf.cast(rng, tf.int64), example_index)
        processed = preprocess_fn(features)
        if isinstance(processed, dict) and "rng" in processed:
            del processed["rng"]
        return processed

    return ds.enumerate().map(_fn, num_parallel_calls=AUTOTUNE)
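
A small usage sketch for `_preprocess_with_per_example_rng`. `Features` is assumed to be Dict[str, tf.Tensor] and `AUTOTUNE` to be `tf.data.AUTOTUNE` in the surrounding module; the preprocessing function below is only an illustration of consuming the injected per-example seed with a stateless TF random op:

import jax
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # assumed module-level constant

def add_noise(features):
    # Deterministic per-example noise derived from the folded-in seed.
    noise = tf.random.stateless_normal([], seed=features["rng"])
    return {"x": features["x"] + noise}

ds = tf.data.Dataset.from_tensor_slices({"x": tf.zeros([4])})
rng = jax.random.PRNGKey(0)  # shape (2,) base seed, as fold_in expects
processed = _preprocess_with_per_example_rng(ds, add_noise, rng=rng)
print([float(ex["x"]) for ex in processed])  # same values on every run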
Example #4
    def split_dataset(self, dataset: tf.data.Dataset,
                      validation_data_fraction: float):
        """
        Splits a dataset of type tf.data.Dataset into a training and a validation dataset using the given
        fraction, which is rounded to the nearest whole percent.
        @param dataset: the input dataset to split.
        @param validation_data_fraction: the fraction of the validation data as a float between 0 and 1.
        @return: a tuple of two tf.data.Datasets as (training, validation)
        Reference URL:
        https://stackoverflow.com/questions/59669413/
        what-is-the-canonical-way-to-split-tf-dataset-into-test-and-validation-subsets
        """

        validation_data_percent = round(validation_data_fraction * 100)
        if not (0 <= validation_data_percent <= 100):
            raise ValueError("validation data fraction must be ∈ [0,1]")

        dataset = dataset.enumerate()
        # Indices 0..(p-1) of every 100 consecutive elements go to validation,
        # so the validation split receives exactly `validation_data_percent`
        # percent of the data.
        train_dataset = dataset.filter(
            lambda f, data: f % 100 >= validation_data_percent)
        validation_dataset = dataset.filter(
            lambda f, data: f % 100 < validation_data_percent)

        # remove enumeration
        train_dataset = train_dataset.map(lambda f, data: data)
        validation_dataset = validation_dataset.map(lambda f, data: data)

        return train_dataset, validation_dataset
Example #5
    def split_dataset(dataset: tf.data.Dataset,
                      validation_data_fraction: float):
        """
        Splits a dataset of type tf.data.Dataset into a training and validation dataset using given ratio. Fractions are
        rounded up to two decimal places.
        @param dataset: the input dataset to split.
        @param validation_data_fraction: the fraction of the validation data as a float between 0 and 1.
        @return: a tuple of two tf.data.Datasets as (training, validation)
        """

        validation_data_percent = round(validation_data_fraction * 100)
        if not (0 <= validation_data_percent <= 100):
            raise ValueError("validation data fraction must be ∈ [0,1]")

        dataset = dataset.enumerate()
        # Indices 0..(p-1) of every 100 consecutive elements go to validation,
        # so the validation split receives exactly `validation_data_percent`
        # percent of the data.
        train_dataset = dataset.filter(
            lambda f, data: f % 100 >= validation_data_percent)
        validation_dataset = dataset.filter(
            lambda f, data: f % 100 < validation_data_percent)

        # remove enumeration
        train_dataset = train_dataset.map(lambda f, data: data)
        validation_dataset = validation_dataset.map(lambda f, data: data)

        return train_dataset, validation_dataset
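
A quick check of the split, assuming the function is available at module level (here it appears to be a static helper inside a class); the 100-element range and the 0.2 fraction are made-up values:

import tensorflow as tf

train_ds, val_ds = split_dataset(tf.data.Dataset.range(100), 0.2)
count = lambda ds: int(ds.reduce(0, lambda x, _: x + 1))
print(count(train_ds), count(val_ds))  # roughly 80 and 20 elements

Because the split is keyed on the element index modulo 100, the same elements land in the same split on every pass, but the underlying dataset must be deterministic for the two splits to stay disjoint.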
Example #6
    def from_tfds(cls, dataset: tf.data.Dataset,
                  image_shape: Tuple[int, int, int]):
        d = dataset.enumerate()
        # Rescale uint8 images from [0, 255] to [-1, 1] and keep the
        # enumeration index alongside each example.
        d = d.map(lambda index, x: dict(
            image=tf.cast(x['image'], tf.float32) / 127.5 - 1,
            label=x['label'],
            index=index))
        return cls(d, image_shape, parse_fn=None)
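
`cls` above is whatever dataset-wrapper class defines `from_tfds`, and the elements are assumed to be TFDS-style {'image', 'label'} dictionaries. The snippet below exercises only the mapping step on a toy dataset with that assumed structure:

import tensorflow as tf

raw = tf.data.Dataset.from_tensor_slices({
    'image': tf.zeros([4, 8, 8, 3], dtype=tf.uint8),
    'label': tf.constant([0, 1, 0, 1]),
})
mapped = raw.enumerate().map(lambda index, x: dict(
    image=tf.cast(x['image'], tf.float32) / 127.5 - 1,
    label=x['label'],
    index=index))

first = next(iter(mapped))
print(int(first['index']), first['image'].shape)  # 0 (8, 8, 3)
print(float(tf.reduce_min(first['image'])))       # -1.0 for all-zero input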
Example #7
    def dataset_split_fn(
            client_dataset: tf.data.Dataset,
            round_num: tf.Tensor) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """A `DatasetSplitFn` built with the given arguments.

        Args:
          client_dataset: `tf.data.Dataset` representing client data.
          round_num: Scalar tf.int64 tensor representing the 1-indexed round
            number during training. During evaluation, this is 0.

        Returns:
          A tuple of two `tf.data.Dataset`s, the first to be used for
          reconstruction, the second to be used post-reconstruction.
        """
        # Split dataset if needed. This assumes the dataset has a consistent
        # order across iterations.
        if split_dataset:
            recon_dataset = client_dataset.enumerate().filter(
                recon_condition).map(get_entry)
            post_recon_dataset = client_dataset.enumerate().filter(
                post_recon_condition).map(get_entry)
        else:
            recon_dataset = client_dataset
            post_recon_dataset = client_dataset

        # Number of reconstruction epochs is exactly recon_epochs_max if
        # recon_epochs_constant is True, and min(round_num, recon_epochs_max) if
        # not.
        num_recon_epochs = recon_epochs_max
        if not recon_epochs_constant:
            num_recon_epochs = tf.math.minimum(round_num, recon_epochs_max)

        # Apply `num_recon_epochs` before limiting to a maximum number of batches
        # if needed.
        recon_dataset = recon_dataset.repeat(num_recon_epochs)
        if recon_steps_max is not None:
            recon_dataset = recon_dataset.take(recon_steps_max)

        # Do the same for post-reconstruction.
        post_recon_dataset = post_recon_dataset.repeat(post_recon_epochs)
        if post_recon_steps_max is not None:
            post_recon_dataset = post_recon_dataset.take(post_recon_steps_max)

        return recon_dataset, post_recon_dataset
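
The round-dependent epoch count above is a simple ramp-up schedule: the number of reconstruction epochs grows with the round number until it saturates at `recon_epochs_max`. A tiny stand-alone illustration with made-up constants:

import tensorflow as tf

recon_epochs_max = 5
for round_num in [1, 3, 10]:
    num_recon_epochs = tf.math.minimum(
        tf.constant(round_num, dtype=tf.int64), recon_epochs_max)
    print(round_num, int(num_recon_epochs))  # 1 -> 1, 3 -> 3, 10 -> 5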
Example #8
def Split_Dataset(dataset: tf.data.Dataset, validation_data_fraction: float):
    """Splits a tf.data.Dataset into (training, validation) datasets by element index."""

    validation_data_percent = round(validation_data_fraction * 100)
    if not (0 <= validation_data_percent <= 100):
        raise ValueError("validation data fraction must be ∈ [0,1]")

    dataset = dataset.enumerate()
    train_dataset = dataset.filter(
        lambda f, data: f % 100 >= validation_data_percent)
    validation_dataset = dataset.filter(
        lambda f, data: f % 100 < validation_data_percent)

    # remove enumeration
    train_dataset = train_dataset.map(lambda f, data: data)
    validation_dataset = validation_dataset.map(lambda f, data: data)

    return train_dataset, validation_dataset
Example #9
def split_dataset(dataset: tf.data.Dataset, val_split: float,
                  test_split: float):
    """Splits a tf.data.Dataset into (training, validation, test) datasets.

    The fractions are rounded to the nearest whole percent, and the validation
    percentage is applied to the elements remaining after the test split.

    Args:
        dataset: the input dataset to split.
        val_split: the fraction of validation data as a float between 0 and 1.
        test_split: the fraction of test data as a float between 0 and 1.

    Returns:
        A tuple of three tf.data.Datasets as (training, validation, test).

    Source: https://stackoverflow.com/questions/59669413/what-is-the-canonical-way-to-split-tf-dataset-into-test-and-validation-subsets
    """

    test_data_percent = round(test_split * 100)
    if not (0 <= test_data_percent <= 100):
        raise ValueError("test data fraction must be ∈ [0,1]")

    val_data_percent = round(val_split * 100)
    if not (0 <= val_data_percent <= 100):
        raise ValueError("val data fraction must be ∈ [0,1]")

    dataset = dataset.enumerate()
    # Elements whose index modulo 100 falls below the test percentage form the
    # test split; the rest are kept for training and validation.
    train_val_dataset = dataset.filter(
        lambda f, data: f % 100 >= test_data_percent)
    test_dataset = dataset.filter(lambda f, data: f % 100 < test_data_percent)

    # remove enumeration
    train_val_dataset = train_val_dataset.map(lambda f, data: data)
    test_dataset = test_dataset.map(lambda f, data: data)

    # add validation from training
    train_val_dataset = train_val_dataset.enumerate()
    train_dataset = train_val_dataset.filter(
        lambda f, data: f % 100 >= val_data_percent)
    val_dataset = train_val_dataset.filter(
        lambda f, data: f % 100 < val_data_percent)

    # remove enumeration
    train_dataset = train_dataset.map(lambda f, data: data)
    val_dataset = val_dataset.map(lambda f, data: data)

    return train_dataset, val_dataset, test_dataset
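
Note that `val_split` is applied to the elements remaining after the test split, so the effective validation share of the whole dataset is roughly val_split * (1 - test_split). A quick check with made-up fractions:

import tensorflow as tf

train_ds, val_ds, test_ds = split_dataset(tf.data.Dataset.range(1000), 0.1, 0.2)
count = lambda ds: int(ds.reduce(0, lambda x, _: x + 1))
print(count(train_ds), count(val_ds), count(test_ds))  # roughly 720, 80, 200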
Example #10
    def dataset_split_fn(
            client_dataset: tf.data.Dataset,
            round_num: tf.Tensor) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """A `DatasetSplitFn` built with the given arguments.

        Args:
          client_dataset: `tf.data.Dataset` representing client data.
          round_num: Scalar tf.int64 tensor representing the 1-indexed round
            number during training. During evaluation, this is 0.

        Returns:
          A tuple of two `tf.data.Dataset`s, the first to be used for
          reconstruction, the second to be used post-reconstruction.
        """
        get_entry = lambda i, entry: entry
        if split_dataset:
            if split_dataset_strategy == SPLIT_STRATEGY_SKIP:

                def recon_condition(i, _):
                    return tf.equal(
                        tf.math.floormod(i, split_dataset_proportion), 0)

                def post_recon_condition(i, _):
                    return tf.greater(
                        tf.math.floormod(i, split_dataset_proportion), 0)

            elif split_dataset_strategy == SPLIT_STRATEGY_AGGREGATED:
                num_elements = client_dataset.reduce(
                    tf.constant(0.0, dtype=tf.float32), lambda x, _: x + 1)

                def recon_condition(i, _):
                    return i <= tf.cast(
                        num_elements / split_dataset_proportion,
                        dtype=tf.int64)

                def post_recon_condition(i, _):
                    return i > tf.cast(num_elements / split_dataset_proportion,
                                       dtype=tf.int64)

            else:
                raise ValueError(
                    'Unimplemented `split_dataset_strategy`: Must be one of '
                    '`{}`, or `{}`. Found {}'.format(
                        SPLIT_STRATEGY_SKIP, SPLIT_STRATEGY_AGGREGATED,
                        split_dataset_strategy))
        # split_dataset=False.
        else:
            recon_condition = lambda i, _: True
            post_recon_condition = lambda i, _: True

        recon_dataset = client_dataset.enumerate().filter(recon_condition).map(
            get_entry)
        post_recon_dataset = client_dataset.enumerate().filter(
            post_recon_condition).map(get_entry)

        # Number of reconstruction epochs is exactly recon_epochs_max if
        # recon_epochs_constant is True, and min(round_num, recon_epochs_max) if
        # not.
        num_recon_epochs = recon_epochs_max
        if not recon_epochs_constant:
            num_recon_epochs = tf.math.minimum(round_num, recon_epochs_max)

        # Apply `num_recon_epochs` before limiting to a maximum number of batches
        # if needed.
        recon_dataset = recon_dataset.repeat(num_recon_epochs)
        if recon_steps_max is not None:
            recon_dataset = recon_dataset.take(recon_steps_max)

        # Do the same for post-reconstruction.
        post_recon_dataset = post_recon_dataset.repeat(post_recon_epochs)
        if post_recon_steps_max is not None:
            post_recon_dataset = post_recon_dataset.take(post_recon_steps_max)

        return recon_dataset, post_recon_dataset
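
The two strategies differ only in which indices the reconstruction filter keeps: SPLIT_STRATEGY_SKIP interleaves every `split_dataset_proportion`-th element into the reconstruction set, while SPLIT_STRATEGY_AGGREGATED gives reconstruction a contiguous prefix of the client dataset. A stand-alone sketch of the two filters with made-up constants (note that the example above uses `i <=` for the aggregated cut-off, which also keeps the element at the boundary index):

import tensorflow as tf

client_dataset = tf.data.Dataset.range(8)
proportion = 2
get_entry = lambda i, entry: entry

# Skip strategy: every `proportion`-th element is used for reconstruction.
skip_recon = client_dataset.enumerate().filter(
    lambda i, _: tf.equal(i % proportion, 0)).map(get_entry)

# Aggregated strategy: count the dataset first, then take a contiguous prefix
# of roughly 1/proportion of the elements for reconstruction.
num_elements = client_dataset.reduce(0.0, lambda x, _: x + 1)
cutoff = tf.cast(num_elements / proportion, tf.int64)
agg_recon = client_dataset.enumerate().filter(
    lambda i, _: i < cutoff).map(get_entry)

print([int(x) for x in skip_recon])  # [0, 2, 4, 6]
print([int(x) for x in agg_recon])   # [0, 1, 2, 3]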