    def test_ds_length_is_ceil_num_epochs_over_batch_size(
            self, num_epochs, batch_size):
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
        preprocess_spec = client_spec.ClientSpec(
            num_epochs=num_epochs, batch_size=batch_size)
        preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
            preprocess_spec)
        preprocessed_ds = preprocess_fn(ds)
        self.assertEqual(
            _compute_length_of_dataset(preprocessed_ds),
            tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))

    def test_ds_length_with_max_elements(self, max_elements):
        repeat_size = 10
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA).repeat(repeat_size)
        preprocess_spec = client_spec.ClientSpec(
            num_epochs=1, batch_size=1, max_elements=max_elements)
        preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
            preprocess_spec)
        preprocessed_ds = preprocess_fn(ds)
        self.assertEqual(
            _compute_length_of_dataset(preprocessed_ds),
            min(repeat_size, max_elements))

    def test_preprocess_fn_returns_correct_element(self, crop_shape,
                                                   distort_image):
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
        preprocess_spec = client_spec.ClientSpec(num_epochs=1,
                                                 batch_size=1,
                                                 shuffle_buffer_size=1)
        preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
            preprocess_spec,
            crop_shape=crop_shape,
            distort_image=distort_image)
        preprocessed_ds = preprocess_fn(ds)
        expected_element_spec_shape = (None,) + crop_shape
        self.assertEqual(
            preprocessed_ds.element_spec,
            (tf.TensorSpec(shape=expected_element_spec_shape, dtype=tf.float32),
             tf.TensorSpec(shape=(None,), dtype=tf.int64)))

        expected_element_shape = (1,) + crop_shape
        element = next(iter(preprocessed_ds))
        expected_element = (
            tf.zeros(shape=expected_element_shape, dtype=tf.float32),
            tf.ones(shape=(1,), dtype=tf.int32))
        self.assertAllClose(self.evaluate(element), expected_element)

    def test_raises_iterable_length_2_crop(self):
        preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
        with self.assertRaisesRegex(
                ValueError, 'The crop_shape must have length 3'):
            image_classification_preprocessing.create_preprocess_fn(
                preprocess_spec, crop_shape=(32, 32))

    def test_raises_non_iterable_crop(self):
        preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
        with self.assertRaisesRegex(
                TypeError, 'crop_shape must be an iterable'):
            image_classification_preprocessing.create_preprocess_fn(
                preprocess_spec, crop_shape=32)
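

# The tests above reference `TEST_DATA` and `_compute_length_of_dataset`, which
# are defined elsewhere in the test module. The sketches below illustrate what
# they might look like; the feature names, shapes, and dtypes are assumptions
# inferred from the element specs asserted in the tests, not the actual
# definitions.
_TEST_DATA_SKETCH = {
    'image': [tf.zeros((32, 32, 3), dtype=tf.uint8)],
    'label': [tf.constant(1, dtype=tf.int64)],
}


def _compute_length_of_dataset_sketch(ds):
    # Count elements by folding over the (finite) dataset.
    return ds.reduce(0, lambda x, _: x + 1)
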
def create_image_classification_task_with_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, ResnetModel],
    crop_height: int,
    crop_width: int,
    distort_train_images: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for image classification on CIFAR-100.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for a digit recognition model. Must be one of
      `resnet18`, `resnet34`, `resnet50`, `resnet101` and `resnet152. These
      correspond to various ResNet architectures. Unlike standard ResNet
      architectures though, the batch normalization layers are replaced with
      group normalization.
    crop_height: An integer specifying the desired height for cropping images.
      Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By
      default, this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT`.
    crop_width: An integer specifying the desired width for cropping images.
      Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By
      default this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH`.
    distort_train_images: Whether to distort images in the train preprocessing
      function.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    if crop_height < 1 or crop_width < 1 or crop_height > 32 or crop_width > 32:
        raise ValueError(
            'The crop_height and crop_width must be between 1 and 32.')
    crop_shape = (crop_height, crop_width, 3)

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
        train_client_spec,
        crop_shape=crop_shape,
        distort_image=distort_train_images)
    eval_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
        eval_client_spec, crop_shape=crop_shape)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=None,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=_get_resnet_model(model_id, crop_shape),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            input_spec=task_datasets.element_type_structure,
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return baseline_task.BaselineTask(task_datasets, model_fn)
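

# Illustrative usage sketch (not part of the library API): builds a CIFAR-100
# task from the federated CIFAR-100 datasets. The helper name below is
# hypothetical, and the crop sizes and client spec values are arbitrary
# examples; adjust them for real experiments.
def _example_build_cifar100_task() -> baseline_task.BaselineTask:
    """Builds an example CIFAR-100 task from the TFF CIFAR-100 datasets."""
    import tensorflow_federated as tff  # Assumed importable here.

    cifar_train, cifar_test = tff.simulation.datasets.cifar100.load_data()
    train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=32)
    return create_image_classification_task_with_datasets(
        train_client_spec=train_spec,
        eval_client_spec=None,
        model_id='resnet18',
        crop_height=24,
        crop_width=24,
        distort_train_images=True,
        train_data=cifar_train,
        test_data=cifar_test)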