예제 #1
0
 def test_preprocess_element_spec(self, crop_shape, distort_image):
   """Checks the element spec of the batched, preprocessed dataset.

   Each element must be an `(image, label)` pair where images are float32 with
   shape `(None,) + crop_shape` and labels are int64 with shape `(None,)`; the
   leading batch dimension is left unknown.
   """
   raw_dataset = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   preprocess_kwargs = dict(
       num_epochs=1,
       batch_size=1,
       shuffle_buffer_size=1,
       crop_shape=crop_shape,
       distort_image=distort_image)
   preprocess = cifar100_dataset.create_preprocess_fn(**preprocess_kwargs)
   result_ds = preprocess(raw_dataset)

   batched_image_shape = (None,) + crop_shape
   image_spec = tf.TensorSpec(shape=batched_image_shape, dtype=tf.float32)
   label_spec = tf.TensorSpec(shape=(None,), dtype=tf.int64)
   self.assertEqual(result_ds.element_spec, (image_spec, label_spec))
예제 #2
0
  def test_preprocess_returns_correct_element(self, crop_shape, distort_image):
    """Checks the values of the first batch produced by the preprocess fn.

    The first element is expected to be an `(image, label)` batch holding one
    all-zero float32 image of shape `crop_shape` and the label 1.

    NOTE(review): a batch dimension of 1 (despite `batch_size=20`) and the
    zero images / one labels presumably reflect the contents of `TEST_DATA`,
    which is defined outside this view. Confirm against its definition.
    """
    ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess_fn = cifar100_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=20,
        shuffle_buffer_size=1,
        crop_shape=crop_shape,
        distort_image=distort_image)
    preprocessed_ds = preprocess_fn(ds)

    expected_element_shape = (1,) + crop_shape
    element = next(iter(preprocessed_ds))
    # Labels are int64 (see the element spec asserted in
    # `test_preprocess_element_spec`). `assertAllClose` does not compare
    # dtypes, so an int32 expectation would pass silently while misdescribing
    # the data.
    expected_element = (tf.zeros(
        shape=expected_element_shape,
        dtype=tf.float32), tf.ones(shape=(1,), dtype=tf.int64))
    self.assertAllClose(self.evaluate(element), expected_element)
예제 #3
0
def configure_training(
        task_spec: training_specs.TaskSpec,
        crop_size: int = 24,
        distort_train_images: bool = True) -> training_specs.RunnerSpec:
    """Configures federated training for the CIFAR-100 classification task.

    Loads the federated CIFAR-100 training data and a centralized test set,
    builds a ResNet-18 Keras model, and passes a TFF model function to
    `task_spec.iterative_process_builder`. The resulting iterative process is
    then composed with a dataset computation, so each round takes string
    client ids instead of client datasets.

    Args:
      task_spec: A `TaskSpec` class for creating federated training tasks.
      crop_size: An optional integer giving the side length of the square
        input images after preprocessing.
      distort_train_images: A boolean indicating whether to distort training
        images during preprocessing via random crops, as opposed to simply
        resizing the image.

    Returns:
      A `RunnerSpec` holding the composed training process, a client sampling
      function, and validation/test functions. Both of the latter evaluate the
      current model weights on the centralized test set.
    """
    image_shape = (crop_size, crop_size, 3)

    federated_train_data, _ = tff.simulation.datasets.cifar100.load_data()
    _, centralized_test_data = cifar100_dataset.get_centralized_datasets(
        train_batch_size=task_spec.client_batch_size, crop_shape=image_shape)

    client_preprocess = cifar100_dataset.create_preprocess_fn(
        num_epochs=task_spec.client_epochs_per_round,
        batch_size=task_spec.client_batch_size,
        crop_shape=image_shape,
        distort_image=distort_train_images)
    # The element type of a preprocessed client dataset doubles as the
    # Keras model's input spec.
    model_input_spec = client_preprocess.type_signature.result.element

    model_builder = functools.partial(
        resnet_models.create_resnet18,
        input_shape=image_shape,
        num_classes=NUM_CLASSES)
    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

    def metrics_builder():
        return [tf.keras.metrics.SparseCategoricalAccuracy()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(
            keras_model=model_builder(),
            input_spec=model_input_spec,
            loss=loss_builder(),
            metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
        raw_client_data = federated_train_data.dataset_computation(client_id)
        return client_preprocess(raw_client_data)

    # Let each round take client ids; the datasets are built from them.
    training_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process(
            build_train_dataset_from_client_id, iterative_process))
    training_process.get_model_weights = iterative_process.get_model_weights

    sample_client_ids = training_utils.build_sample_fn(
        federated_train_data.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)
    # Wrap the sampled ids in a list (rather than an np.ndarray) so they can
    # be fed directly to the iterative process.
    client_sampling_fn = lambda x: list(sample_client_ids(x))

    evaluate_on_test_set = training_utils.build_centralized_evaluate_fn(
        eval_dataset=centralized_test_data,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def test_fn(state):
        model_weights = iterative_process.get_model_weights(state)
        return evaluate_on_test_set(model_weights)

    def validation_fn(state, round_num):
        # The round number is unused; validation reuses the test-set evaluation.
        del round_num
        return test_fn(state)

    return training_specs.RunnerSpec(
        iterative_process=training_process,
        client_datasets_fn=client_sampling_fn,
        validation_fn=validation_fn,
        test_fn=test_fn)