示例#1
0
 def test_load_from_gcs(self):
   """Checks the client ids, element type and labels of the CIFAR-100 test split.

   The test is skipped unconditionally; delete the `skipTest` call to run it
   locally. When it runs, it verifies that the test split has 100 clients
   with ids '0'..'99', that each client holds 100 examples, and that across
   all clients each of the 20 coarse labels occurs 500 times and each of the
   100 fine labels occurs 100 times.
   """
   self.skipTest(
       "CI infrastructure doesn't support downloading from GCS. Remove "
       'skipTest to run test locally.')
   loaded_splits = cifar100.load_data(FLAGS.test_tmpdir)
   test_split = loaded_splits[1]

   # Client ids are the decimal strings '0' through '99'.
   self.assertLen(test_split.client_ids, 100)
   self.assertCountEqual(test_split.client_ids, [str(i) for i in range(100)])
   self.assertEqual(test_split.element_type_structure, EXPECTED_ELEMENT_TYPE)

   # Expected label multisets over the whole test split.
   want_coarse = [coarse for coarse in range(20) for _ in range(500)]
   want_fine = [fine for fine in range(100) for _ in range(100)]

   seen_coarse = []
   seen_fine = []
   for cid in test_split.client_ids:
     examples = self.evaluate(
         list(test_split.create_tf_dataset_for_client(cid)))
     self.assertLen(examples, 100)
     for example in examples:
       seen_coarse.append(example['coarse_label'])
       seen_fine.append(example['label'])

   self.assertLen(seen_coarse, 10000)
   self.assertLen(seen_fine, 10000)
   self.assertCountEqual(seen_coarse, want_coarse)
   self.assertCountEqual(seen_fine, want_fine)
示例#2
0
    def test_load_test_data(self):
        """Checks the client ids, element type and labels of the test split.

        Verifies that the test split has 100 clients with ids '0'..'99',
        that its element structure is an ordered mapping of `coarse_label`,
        `image` and `label` specs, that each client holds 100 examples, and
        that across all clients each of the 20 coarse labels occurs 500
        times and each of the 100 fine labels occurs 100 times.
        """
        _, test_split = cifar100.load_data()

        # Client ids are the decimal strings '0' through '99'.
        self.assertLen(test_split.client_ids, 100)
        self.assertCountEqual(test_split.client_ids,
                              [str(i) for i in range(100)])

        want_element_spec = collections.OrderedDict([
            ('coarse_label', tf.TensorSpec(shape=(), dtype=tf.int64)),
            ('image', tf.TensorSpec(shape=(32, 32, 3), dtype=tf.uint8)),
            ('label', tf.TensorSpec(shape=(), dtype=tf.int64)),
        ])
        self.assertEqual(test_split.element_type_structure, want_element_spec)

        # Expected label multisets over the whole test split.
        want_coarse = [coarse for coarse in range(20) for _ in range(500)]
        want_fine = [fine for fine in range(100) for _ in range(100)]

        seen_coarse = []
        seen_fine = []
        for cid in test_split.client_ids:
            examples = self.evaluate(
                list(test_split.create_tf_dataset_for_client(cid)))
            self.assertLen(examples, 100)
            for example in examples:
                seen_coarse.append(example['coarse_label'])
                seen_fine.append(example['label'])

        self.assertLen(seen_coarse, 10000)
        self.assertLen(seen_fine, 10000)
        self.assertCountEqual(seen_coarse, want_coarse)
        self.assertCountEqual(seen_fine, want_fine)
def create_image_classification_task(
        train_client_spec: client_spec.ClientSpec,
        eval_client_spec: Optional[client_spec.ClientSpec] = None,
        model_id: Union[str, ResnetModel] = 'resnet18',
        crop_height: int = DEFAULT_CROP_HEIGHT,
        crop_width: int = DEFAULT_CROP_WIDTH,
        distort_train_images: bool = False,
        cache_dir: Optional[str] = None,
        use_synthetic_data: bool = False) -> baseline_task.BaselineTask:
    """Creates a baseline task for image classification on CIFAR-100.

    The goal of the task is to minimize the sparse categorical crossentropy
    between the output labels of the model and the true label of the image.

    This function only selects the datasets (real or synthetic) and then
    delegates to `create_image_classification_task_with_datasets`.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying
        how to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to
        `None`, the evaluation datasets will use a batch size of 64 with no
        extra preprocessing.
      model_id: An identifier (a string or a `ResnetModel`) for the image
        classification model. Must be one of `resnet18`, `resnet34`,
        `resnet50`, `resnet101` and `resnet152`. These correspond to various
        ResNet architectures. Unlike standard ResNet architectures though,
        the batch normalization layers are replaced with group
        normalization.
      crop_height: An integer specifying the desired height for cropping
        images. Must be between 1 and 32 (the height of uncropped CIFAR-100
        images). By default, this is set to
        `tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT`.
      crop_width: An integer specifying the desired width for cropping
        images. Must be between 1 and 32 (the width of uncropped CIFAR-100
        images). By default, this is set to
        `tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH`.
      distort_train_images: Whether to distort images in the train
        preprocessing function.
      cache_dir: An optional directory to cache the downloaded datasets. If
        `None`, they will be cached to `~/.tff/`. Ignored when
        `use_synthetic_data` is `True`.
      use_synthetic_data: A boolean indicating whether to use synthetic
        CIFAR-100 data. This option should only be used for testing
        purposes, in order to avoid downloading the entire CIFAR-100
        dataset.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.
    """
    if not use_synthetic_data:
        train_data, test_data = cifar100.load_data(cache_dir=cache_dir)
    else:
        # The same synthetic dataset serves as both the train and test data.
        train_data = test_data = cifar100.get_synthetic()

    return create_image_classification_task_with_datasets(
        train_client_spec,
        eval_client_spec,
        model_id,
        crop_height,
        crop_width,
        distort_train_images,
        train_data,
        test_data)
def get_centralized_datasets(
    train_batch_size: int = 20,
    test_batch_size: int = 100,
    train_shuffle_buffer_size: int = 10000,
    test_shuffle_buffer_size: int = 1,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Loads and preprocesses centralized CIFAR100 training and testing sets.

    Both splits are flattened across all clients into a single dataset and
    preprocessed for one epoch. Image distortion is requested for the
    training set only.

    Args:
      train_batch_size: The batch size for the training dataset.
      test_batch_size: The batch size for the test dataset.
      train_shuffle_buffer_size: An integer specifying the buffer size used
        to shuffle the train dataset via `tf.data.Dataset.shuffle`. If set to
        an integer less than or equal to 1, no shuffling occurs.
      test_shuffle_buffer_size: An integer specifying the buffer size used to
        shuffle the test dataset via `tf.data.Dataset.shuffle`. If set to an
        integer less than or equal to 1, no shuffling occurs.
      crop_shape: An iterable of integers specifying the desired crop shape
        for pre-processing. Must be convertible to a tuple of integers
        (CROP_HEIGHT, CROP_WIDTH, NUM_CHANNELS) which cannot have elements
        that exceed (32, 32, 3), element-wise. The element in the last index
        should be set to 3 to maintain the RGB image structure of the
        elements.

    Returns:
      A tuple (cifar_train, cifar_test) of `tf.data.Dataset` instances
      representing the centralized training and test datasets.
    """
    train_clients, test_clients = cifar100.load_data()
    train_dataset = train_clients.create_tf_dataset_from_all_clients()
    test_dataset = test_clients.create_tf_dataset_from_all_clients()

    # Training pipeline: single epoch, with image distortion.
    train_dataset = create_preprocess_fn(
        num_epochs=1,
        batch_size=train_batch_size,
        shuffle_buffer_size=train_shuffle_buffer_size,
        crop_shape=crop_shape,
        distort_image=True)(train_dataset)

    # Test pipeline: single epoch, no distortion.
    test_dataset = create_preprocess_fn(
        num_epochs=1,
        batch_size=test_batch_size,
        shuffle_buffer_size=test_shuffle_buffer_size,
        crop_shape=crop_shape,
        distort_image=False)(test_dataset)

    return train_dataset, test_dataset
def get_federated_datasets(
    train_client_batch_size: int = 20,
    test_client_batch_size: int = 100,
    train_client_epochs_per_round: int = 1,
    test_client_epochs_per_round: int = 1,
    train_shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    test_shuffle_buffer_size: int = 1,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    serializable: bool = False
) -> Tuple[client_data.ClientData, client_data.ClientData]:
    """Loads and preprocesses federated CIFAR100 training and testing sets.

    Args:
      train_client_batch_size: The batch size for all train clients.
      test_client_batch_size: The batch size for all test clients.
      train_client_epochs_per_round: The number of epochs each train client
        should iterate over their local dataset, via
        `tf.data.Dataset.repeat`. Must be set to a positive integer.
      test_client_epochs_per_round: The number of epochs each test client
        should iterate over their local dataset, via
        `tf.data.Dataset.repeat`. Must be set to a positive integer.
      train_shuffle_buffer_size: An integer representing the shuffle buffer
        size (as in `tf.data.Dataset.shuffle`) for each train client's
        dataset. By default, this is set to the largest dataset size among
        all clients. If set to some integer less than or equal to 1, no
        shuffling occurs.
      test_shuffle_buffer_size: An integer representing the shuffle buffer
        size (as in `tf.data.Dataset.shuffle`) for each test client's
        dataset. If set to some integer less than or equal to 1, no shuffling
        occurs.
      crop_shape: An iterable of integers specifying the desired crop shape
        for pre-processing. Must be convertible to a tuple of integers
        (CROP_HEIGHT, CROP_WIDTH, NUM_CHANNELS) which cannot have elements
        that exceed (32, 32, 3), element-wise. The element in the last index
        should be set to 3 to maintain the RGB image structure of the
        elements.
      serializable: Boolean indicating whether the returned datasets are
        intended to be serialized and shipped across RPC channels. If `True`,
        stateful transformations will be disallowed; in this function that
        means training images are not distorted.

    Returns:
      A tuple (cifar_train, cifar_test) of `tff.simulation.ClientData`
      instances representing the federated training and test datasets.

    Raises:
      TypeError: If `serializable` is not a boolean.
    """
    if not isinstance(serializable, bool):
        message = 'serializable must be a Boolean; you passed {} of type {}.'
        raise TypeError(message.format(serializable, type(serializable)))

    train_clients, test_clients = cifar100.load_data()

    # Distortion is only requested for training data, and only when the
    # result does not need to be serializable.
    preprocess_train = create_preprocess_fn(
        num_epochs=train_client_epochs_per_round,
        batch_size=train_client_batch_size,
        shuffle_buffer_size=train_shuffle_buffer_size,
        crop_shape=crop_shape,
        distort_image=not serializable)
    preprocess_test = create_preprocess_fn(
        num_epochs=test_client_epochs_per_round,
        batch_size=test_client_batch_size,
        shuffle_buffer_size=test_shuffle_buffer_size,
        crop_shape=crop_shape,
        distort_image=False)

    return (train_clients.preprocess(preprocess_train),
            test_clients.preprocess(preprocess_test))