def test_load_from_gcs(self):
  """Checks client ids, element spec and label counts of the GCS test split."""
  self.skipTest(
      "CI infrastructure doesn't support downloading from GCS. Remove "
      'skipTest to run test locally.')
  cifar_test = cifar100.load_data(FLAGS.test_tmpdir)[1]
  self.assertLen(cifar_test.client_ids, 100)
  self.assertCountEqual(cifar_test.client_ids, [str(i) for i in range(100)])
  self.assertEqual(cifar_test.element_type_structure, EXPECTED_ELEMENT_TYPE)
  # Each of the 20 coarse labels should cover 500 test examples, and each of
  # the 100 fine labels should cover 100 test examples.
  expected_coarse_labels = [c for c in range(20) for _ in range(500)]
  expected_labels = [f for f in range(100) for _ in range(100)]
  coarse_labels = []
  labels = []
  for client_id in cifar_test.client_ids:
    examples = self.evaluate(
        list(cifar_test.create_tf_dataset_for_client(client_id)))
    self.assertLen(examples, 100)
    for example in examples:
      coarse_labels.append(example['coarse_label'])
      labels.append(example['label'])
  self.assertLen(coarse_labels, 10000)
  self.assertLen(labels, 10000)
  self.assertCountEqual(coarse_labels, expected_coarse_labels)
  self.assertCountEqual(labels, expected_labels)
def test_load_test_data(self):
  """Checks client ids, element spec and label counts of the test split."""
  _, cifar_test = cifar100.load_data()
  self.assertLen(cifar_test.client_ids, 100)
  self.assertCountEqual(cifar_test.client_ids, [str(i) for i in range(100)])
  expected_structure = collections.OrderedDict([
      ('coarse_label', tf.TensorSpec(shape=(), dtype=tf.int64)),
      ('image', tf.TensorSpec(shape=(32, 32, 3), dtype=tf.uint8)),
      ('label', tf.TensorSpec(shape=(), dtype=tf.int64)),
  ])
  self.assertEqual(cifar_test.element_type_structure, expected_structure)
  # Each of the 20 coarse labels should cover 500 test examples, and each of
  # the 100 fine labels should cover 100 test examples.
  expected_coarse_labels = [c for c in range(20) for _ in range(500)]
  expected_labels = [f for f in range(100) for _ in range(100)]
  coarse_labels = []
  labels = []
  for client_id in cifar_test.client_ids:
    examples = self.evaluate(
        list(cifar_test.create_tf_dataset_for_client(client_id)))
    self.assertLen(examples, 100)
    for example in examples:
      coarse_labels.append(example['coarse_label'])
      labels.append(example['label'])
  self.assertLen(coarse_labels, 10000)
  self.assertLen(labels, 10000)
  self.assertCountEqual(coarse_labels, expected_coarse_labels)
  self.assertCountEqual(labels, expected_labels)
def create_image_classification_task(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec] = None,
    model_id: Union[str, ResnetModel] = 'resnet18',
    crop_height: int = DEFAULT_CROP_HEIGHT,
    crop_width: int = DEFAULT_CROP_WIDTH,
    distort_train_images: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False) -> baseline_task.BaselineTask:
  """Creates a baseline task for image classification on CIFAR-100.

  The goal of the task is to minimize the sparse categorical crossentropy
  between the output labels of the model and the true label of the image.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for an image classification model. Must be
      one of `resnet18`, `resnet34`, `resnet50`, `resnet101` and `resnet152`.
      These correspond to various ResNet architectures. Unlike standard ResNet
      architectures though, the batch normalization layers are replaced with
      group normalization.
    crop_height: An integer specifying the desired height for cropping images.
      Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By
      default, this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT`.
    crop_width: An integer specifying the desired width for cropping images.
      Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By
      default this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH`.
    distort_train_images: Whether to distort images in the train preprocessing
      function.
    cache_dir: An optional directory to cache the downloaded datasets. If
      `None`, they will be cached to `~/.tff/`.
    use_synthetic_data: A boolean indicating whether to use synthetic
      CIFAR-100 data. This option should only be used for testing purposes,
      in order to avoid downloading the entire CIFAR-100 dataset.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if use_synthetic_data:
    # A single synthetic ClientData serves as both the train and test split.
    synthetic_data = cifar100.get_synthetic()
    cifar_train = synthetic_data
    cifar_test = synthetic_data
  else:
    cifar_train, cifar_test = cifar100.load_data(cache_dir=cache_dir)

  return create_image_classification_task_with_datasets(
      train_client_spec, eval_client_spec, model_id, crop_height, crop_width,
      distort_train_images, cifar_train, cifar_test)
def get_centralized_datasets(
    train_batch_size: int = 20,
    test_batch_size: int = 100,
    train_shuffle_buffer_size: int = 10000,
    test_shuffle_buffer_size: int = 1,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
  """Loads and preprocesses centralized CIFAR100 training and testing sets.

  Args:
    train_batch_size: The batch size for the training dataset.
    test_batch_size: The batch size for the test dataset.
    train_shuffle_buffer_size: An integer specifying the buffer size used to
      shuffle the train dataset via `tf.data.Dataset.shuffle`. If set to an
      integer less than or equal to 1, no shuffling occurs.
    test_shuffle_buffer_size: An integer specifying the buffer size used to
      shuffle the test dataset via `tf.data.Dataset.shuffle`. If set to an
      integer less than or equal to 1, no shuffling occurs.
    crop_shape: An iterable of integers specifying the desired crop shape for
      pre-processing. Must be convertable to a tuple of integers
      (CROP_HEIGHT, CROP_WIDTH, NUM_CHANNELS) which cannot have elements that
      exceed (32, 32, 3), element-wise. The element in the last index should
      be set to 3 to maintain the RGB image structure of the elements.

  Returns:
    A tuple (cifar_train, cifar_test) of `tf.data.Dataset` instances
    representing the centralized training and test datasets.
  """
  # Build the preprocessing pipelines first; images are only distorted for
  # training, and both pipelines make a single pass over the data.
  preprocess_train = create_preprocess_fn(
      num_epochs=1,
      batch_size=train_batch_size,
      shuffle_buffer_size=train_shuffle_buffer_size,
      crop_shape=crop_shape,
      distort_image=True)
  preprocess_test = create_preprocess_fn(
      num_epochs=1,
      batch_size=test_batch_size,
      shuffle_buffer_size=test_shuffle_buffer_size,
      crop_shape=crop_shape,
      distort_image=False)

  # Flatten the federated splits into single centralized datasets.
  raw_train, raw_test = cifar100.load_data()
  centralized_train = raw_train.create_tf_dataset_from_all_clients()
  centralized_test = raw_test.create_tf_dataset_from_all_clients()

  return preprocess_train(centralized_train), preprocess_test(centralized_test)
def get_federated_datasets(
    train_client_batch_size: int = 20,
    test_client_batch_size: int = 100,
    train_client_epochs_per_round: int = 1,
    test_client_epochs_per_round: int = 1,
    train_shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    test_shuffle_buffer_size: int = 1,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    serializable: bool = False
) -> Tuple[client_data.ClientData, client_data.ClientData]:
  """Loads and preprocesses federated CIFAR100 training and testing sets.

  Args:
    train_client_batch_size: The batch size for all train clients.
    test_client_batch_size: The batch size for all test clients.
    train_client_epochs_per_round: The number of epochs each train client
      should iterate over their local dataset, via `tf.data.Dataset.repeat`.
      Must be set to a positive integer.
    test_client_epochs_per_round: The number of epochs each test client should
      iterate over their local dataset, via `tf.data.Dataset.repeat`. Must be
      set to a positive integer.
    train_shuffle_buffer_size: An integer representing the shuffle buffer size
      (as in `tf.data.Dataset.shuffle`) for each train client's dataset. By
      default, this is set to the largest dataset size among all clients. If
      set to some integer less than or equal to 1, no shuffling occurs.
    test_shuffle_buffer_size: An integer representing the shuffle buffer size
      (as in `tf.data.Dataset.shuffle`) for each test client's dataset. If
      set to some integer less than or equal to 1, no shuffling occurs.
    crop_shape: An iterable of integers specifying the desired crop shape for
      pre-processing. Must be convertable to a tuple of integers
      (CROP_HEIGHT, CROP_WIDTH, NUM_CHANNELS) which cannot have elements that
      exceed (32, 32, 3), element-wise. The element in the last index should
      be set to 3 to maintain the RGB image structure of the elements.
    serializable: Boolean indicating whether the returned datasets are
      intended to be serialized and shipped across RPC channels. If `True`,
      stateful transformations will be disallowed.

  Returns:
    A tuple (cifar_train, cifar_test) of `tff.simulation.ClientData` instances
    representing the federated training and test datasets.

  Raises:
    TypeError: If `serializable` is not a boolean.
  """
  # Validate eagerly, before any dataset work is done.
  if not isinstance(serializable, bool):
    raise TypeError(
        f'serializable must be a Boolean; you passed {serializable} of type '
        f'{type(serializable)}.')

  # Image distortion is a stateful transformation, so it is disabled whenever
  # the datasets must be serializable.
  preprocess_train = create_preprocess_fn(
      num_epochs=train_client_epochs_per_round,
      batch_size=train_client_batch_size,
      shuffle_buffer_size=train_shuffle_buffer_size,
      crop_shape=crop_shape,
      distort_image=not serializable)
  preprocess_test = create_preprocess_fn(
      num_epochs=test_client_epochs_per_round,
      batch_size=test_client_batch_size,
      shuffle_buffer_size=test_shuffle_buffer_size,
      crop_shape=crop_shape,
      distort_image=False)

  raw_train, raw_test = cifar100.load_data()
  return (raw_train.preprocess(preprocess_train),
          raw_test.preprocess(preprocess_test))