def test_preprocess_element_spec(self, crop_shape, distort_image):
    """Checks the element spec of a dataset run through the preprocess fn.

    The preprocessed dataset is expected to yield `(images, labels)` pairs
    where images are float32 with shape `(None,) + crop_shape` and labels are
    int64 with shape `(None,)`.

    Args:
      self: The enclosing test case instance.
      crop_shape: Tuple giving the per-image shape requested from
        preprocessing.
      distort_image: Flag forwarded to `create_preprocess_fn`.
    """
    raw_dataset = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess = cifar100_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=1,
        shuffle_buffer_size=1,
        crop_shape=crop_shape,
        distort_image=distort_image)
    actual_spec = preprocess(raw_dataset).element_spec

    # The leading (batch) dimension is left unspecified in the expected spec.
    image_spec = tf.TensorSpec(shape=(None,) + crop_shape, dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=(None,), dtype=tf.int64)
    self.assertEqual(actual_spec, (image_spec, label_spec))
def test_preprocess_returns_correct_element(self, crop_shape, distort_image):
    """Checks the first element produced by the preprocessed dataset.

    The first batch is compared against an all-zeros image tensor of shape
    `(1,) + crop_shape` and an all-ones label tensor of shape `(1,)`.

    Args:
      self: The enclosing test case instance.
      crop_shape: Tuple giving the per-image shape requested from
        preprocessing.
      distort_image: Flag forwarded to `create_preprocess_fn`.
    """
    # NOTE(review): presumably TEST_DATA holds a single zero image labelled 1,
    # which is why a batch_size of 20 still yields a leading dimension of 1 --
    # confirm against the TEST_DATA definition.
    raw_dataset = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess = cifar100_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=20,
        shuffle_buffer_size=1,
        crop_shape=crop_shape,
        distort_image=distort_image)
    processed_dataset = preprocess(raw_dataset)

    batch_shape = (1,) + crop_shape
    first_batch = next(iter(processed_dataset))

    expected_images = tf.zeros(shape=batch_shape, dtype=tf.float32)
    expected_labels = tf.ones(shape=(1,), dtype=tf.int32)
    self.assertAllClose(
        self.evaluate(first_batch), (expected_images, expected_labels))
def configure_training(
        task_spec: training_specs.TaskSpec,
        crop_size: int = 24,
        distort_train_images: bool = True) -> training_specs.RunnerSpec:
    """Sets up a federated CIFAR-100 classification task.

    Loads the federated training data and a centralized test set, builds the
    client-side preprocessing function and a ResNet-18 model builder, and
    hands a TFF model function to `task_spec.iterative_process_builder`. The
    resulting iterative process is composed with a dataset computation so
    that it takes client ids rather than datasets as input.

    Args:
      task_spec: A `TaskSpec` describing the federated training task (client
        batch size, epochs per round, clients per round, sampling seed and the
        iterative process builder are read from it).
      crop_size: Side length, in pixels, of the square images produced by
        preprocessing.
      distort_train_images: Passed to the training preprocessing function as
        `distort_image`; selects distortion of training images (e.g. random
        crops) rather than plain resizing.

    Returns:
      A `RunnerSpec` holding the composed training process, the per-round
      client id sampling function, and the validation and test functions.
    """
    image_shape = (crop_size, crop_size, 3)

    # Federated training split; the federated test split is not used.
    federated_train, _ = tff.simulation.datasets.cifar100.load_data()
    # Only the centralized test split is kept for evaluation.
    _, centralized_test = cifar100_dataset.get_centralized_datasets(
        train_batch_size=task_spec.client_batch_size, crop_shape=image_shape)

    train_preprocess_fn = cifar100_dataset.create_preprocess_fn(
        num_epochs=task_spec.client_epochs_per_round,
        batch_size=task_spec.client_batch_size,
        crop_shape=image_shape,
        distort_image=distort_train_images)
    element_spec = train_preprocess_fn.type_signature.result.element

    model_builder = functools.partial(
        resnet_models.create_resnet18,
        input_shape=image_shape,
        num_classes=NUM_CLASSES)
    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

    def metrics_builder():
        return [tf.keras.metrics.SparseCategoricalAccuracy()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(
            keras_model=model_builder(),
            input_spec=element_spec,
            loss=loss_builder(),
            metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
        raw_client_dataset = federated_train.dataset_computation(client_id)
        return train_preprocess_fn(raw_client_dataset)

    compose = tff.simulation.compose_dataset_computation_with_iterative_process
    training_process = compose(
        build_train_dataset_from_client_id, iterative_process)

    sample_client_ids = training_utils.build_sample_fn(
        federated_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)

    def client_sampling_fn(x):
        # The sampled ids are turned into a plain list (rather than an
        # np.ndarray) so they can be fed to the iterative process.
        return list(sample_client_ids(x))

    # Expose the original process's weight accessor on the composed process.
    training_process.get_model_weights = iterative_process.get_model_weights

    centralized_eval_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=centralized_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def test_fn(state):
        model_weights = iterative_process.get_model_weights(state)
        return centralized_eval_fn(model_weights)

    def validation_fn(state, round_num):
        # Validation ignores the round number and reuses the test evaluation.
        del round_num
        return test_fn(state)

    return training_specs.RunnerSpec(
        iterative_process=training_process,
        client_datasets_fn=client_sampling_fn,
        validation_fn=validation_fn,
        test_fn=test_fn)