Example #1
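  # Both the federated and centralized CIFAR-100 loaders should reject a
  # two-element crop_shape; a valid crop_shape has three dimensions (H, W, C).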
  def test_raises_length_2_crop(self):
    with self.assertRaises(ValueError):
      dataset.get_federated_cifar100(
          client_epochs_per_round=1, train_batch_size=10, crop_shape=(32, 32))
    with self.assertRaises(ValueError):
      dataset.get_centralized_cifar100(
          train_batch_size=10, crop_shape=(32, 32))
Example #2
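# Estimates an initial accumulator value for the Yogi optimizer by building a
# ResNet-18 TFF model over the federated CIFAR-100 training data and logging
# the result.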
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    cifar_train, _ = dataset.get_federated_cifar100(
        client_epochs_per_round=1,
        train_batch_size=FLAGS.client_batch_size,
        crop_shape=CROP_SHAPE)
    input_spec = cifar_train.create_tf_dataset_for_client(
        cifar_train.client_ids[0]).element_spec

    model_builder = functools.partial(resnet_models.create_resnet18,
                                      input_shape=CROP_SHAPE,
                                      num_classes=NUM_CLASSES)
    tff_model = tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy())

    yogi_init_accum_estimate = optimizer_utils.compute_yogi_init(
        cifar_train, tff_model, num_clients=FLAGS.num_clients)
    logging.info('Yogi initializer: {:s}'.format(
        format(yogi_init_accum_estimate, '10.6E')))
Example #3
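  # A batch drawn from a single client's dataset should have shape
  # (batch_size, height, width, channels) matching the requested crop_shape.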
  def test_federated_cifar_structure(self):
    crop_shape = (32, 32, 3)
    cifar_train, _ = dataset.get_federated_cifar100(
        client_epochs_per_round=1, train_batch_size=10, crop_shape=crop_shape)
    client_id = cifar_train.client_ids[0]
    client_dataset = cifar_train.create_tf_dataset_for_client(client_id)
    train_batch = next(iter(client_dataset))
    train_batch_shape = tuple(train_batch[0].shape)
    self.assertEqual(train_batch_shape, (10, 32, 32, 3))
Example #4
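  # With client_epochs_per_round=-1 the client datasets repeat indefinitely,
  # so max_batches_per_client caps each client's dataset at exactly 10 batches.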
  def test_take_with_repeat(self):
    cifar_train, _ = dataset.get_federated_cifar100(
        client_epochs_per_round=-1,
        train_batch_size=10,
        max_batches_per_client=10)
    self.assertEqual(len(cifar_train.client_ids), 500)
    for i in range(10):
      client_ds = cifar_train.create_tf_dataset_for_client(
          cifar_train.client_ids[i])
      self.assertEqual(_compute_length_of_dataset(client_ds), 10)
Example #5
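# Full training entry point: builds a ResNet-18 model, constructs a federated
# training process from command-line flags, and runs the training loop with
# evaluation on the CIFAR-100 test data.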
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  input_spec = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0]).element_spec

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      input_spec=input_spec,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset=cifar_train,
      train_clients_per_round=FLAGS.clients_per_round,
      random_seed=FLAGS.client_datasets_random_seed)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=cifar_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=assign_weights_fn)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn)
Example #6
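# Training entry point in which the iterative process is built from a sample
# (dummy) batch taken from one client's dataset rather than from the dataset's
# element_spec.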
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  sample_client_dataset = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0])

  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(sample_client_dataset)))

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=training_utils.build_client_datasets_fn(
          cifar_train, FLAGS.clients_per_round),
      evaluate_fn=training_utils.build_evaluate_fn(
          eval_dataset=cifar_test,
          model_builder=model_builder,
          loss_builder=loss_builder,
          metrics_builder=metrics_builder),
  )
Example #7
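  # client_epochs_per_round=-1 (repeat indefinitely) combined with
  # max_batches_per_client=-1 (no cap) would yield an unbounded client
  # dataset, so the loader raises a ValueError.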
  def test_raises_no_repeat_and_no_take(self):
    with self.assertRaisesRegex(
        ValueError, 'Argument client_epochs_per_round is set to -1'):
      dataset.get_federated_cifar100(
          client_epochs_per_round=-1,
          train_batch_size=10,
          max_batches_per_client=-1)