def test_raises_length_2_crop(self):
  """A two-element `crop_shape` is rejected by both CIFAR-100 loaders.

  Both the federated and the centralized loader are called with a
  (height, width)-only crop shape, and each call must raise ValueError.
  """
  self.assertRaises(
      ValueError,
      dataset.get_federated_cifar100,
      client_epochs_per_round=1,
      train_batch_size=10,
      crop_shape=(32, 32))
  self.assertRaises(
      ValueError,
      dataset.get_centralized_cifar100,
      train_batch_size=10,
      crop_shape=(32, 32))
def main(argv):
  """Estimates an initial Yogi accumulator value on federated CIFAR-100.

  Loads the federated CIFAR-100 training split and wraps a ResNet-18 Keras
  model as a TFF model. It then calls `optimizer_utils.compute_yogi_init`
  over `FLAGS.num_clients` clients and logs the resulting estimate in
  scientific notation.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If any extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  federated_train, _ = dataset.get_federated_cifar100(
      client_epochs_per_round=1,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  # The element spec of one client's dataset describes the model input.
  sample_client_id = federated_train.client_ids[0]
  element_spec = federated_train.create_tf_dataset_for_client(
      sample_client_id).element_spec

  build_resnet = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)
  keras_model = build_resnet()
  tff_model = tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy())

  accum_estimate = optimizer_utils.compute_yogi_init(
      federated_train, tff_model, num_clients=FLAGS.num_clients)
  estimate_text = format(accum_estimate, '10.6E')
  logging.info('Yogi initializer: {:s}'.format(estimate_text))
def test_federated_cifar_structure(self):
  """The first batch of a client dataset has shape (batch_size,) + crop_shape."""
  crop_shape = (32, 32, 3)
  batch_size = 10
  federated_train, _ = dataset.get_federated_cifar100(
      client_epochs_per_round=1,
      train_batch_size=batch_size,
      crop_shape=crop_shape)

  first_client = federated_train.client_ids[0]
  client_data = federated_train.create_tf_dataset_for_client(first_client)
  first_batch = next(iter(client_data))

  # NOTE(review): element 0 of a batch is presumably the image tensor;
  # confirm against the dataset module.
  observed_shape = tuple(first_batch[0].shape)
  self.assertEqual(observed_shape, (batch_size,) + crop_shape)
def test_take_with_repeat(self):
  """With client_epochs_per_round=-1, each checked client yields 10 batches.

  Also checks that the federated training split exposes 500 client ids.
  Only the first ten clients are inspected.
  """
  max_batches = 10
  federated_train, _ = dataset.get_federated_cifar100(
      client_epochs_per_round=-1,
      train_batch_size=10,
      max_batches_per_client=max_batches)

  all_ids = federated_train.client_ids
  self.assertEqual(len(all_ids), 500)

  for client_id in all_ids[:10]:
    client_ds = federated_train.create_tf_dataset_for_client(client_id)
    self.assertEqual(_compute_length_of_dataset(client_ds), max_batches)
def main(argv):
  """Runs federated training of a ResNet-18 on federated CIFAR-100.

  Loads the federated CIFAR-100 train/test splits and builds a ResNet-18
  model builder. It then constructs an iterative training process from flags
  and hands that process to `training_loop.run`, together with a client
  sampling function and an evaluation function over the test split.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If any extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))
  tf.compat.v1.enable_v2_behavior()

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  # One client's element spec describes the model input for TFF.
  input_spec = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0]).element_spec

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  # Keras `Model.summary()` prints to stdout and returns None, so passing its
  # result to logging.info only logged "None". Send every summary line
  # through the logger instead.
  model_builder().summary(print_fn=logging.info)

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      input_spec=input_spec,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset=cifar_train,
      train_clients_per_round=FLAGS.clients_per_round,
      random_seed=FLAGS.client_datasets_random_seed)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model
  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=cifar_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=assign_weights_fn)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn)
def main(argv):
  """Runs federated training of a ResNet-18 on federated CIFAR-100.

  Loads the federated CIFAR-100 train/test splits and takes one sample batch
  (converted to NumPy) to use as the `dummy_batch` for building the
  iterative training process from flags. It then runs `training_loop.run`
  with a client sampling function and an evaluation function over the test
  split.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If any extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))
  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  # A concrete batch from the first client, as NumPy arrays, serves as the
  # dummy batch used to build the training process.
  sample_client_dataset = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(sample_client_dataset)))

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  # Keras `Model.summary()` prints to stdout and returns None, so passing its
  # result to logging.info only logged "None". Send every summary line
  # through the logger instead.
  model_builder().summary(print_fn=logging.info)

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=training_utils.build_client_datasets_fn(
          cifar_train, FLAGS.clients_per_round),
      evaluate_fn=training_utils.build_evaluate_fn(
          eval_dataset=cifar_test,
          model_builder=model_builder,
          loss_builder=loss_builder,
          metrics_builder=metrics_builder),
  )
def test_raises_no_repeat_and_no_take(self):
  """Unbounded epochs (-1) with no batch cap (-1) must raise ValueError.

  The raised error's message must match the expected regex.
  """
  expected_message = 'Argument client_epochs_per_round is set to -1'
  with self.assertRaisesRegex(ValueError, expected_message):
    dataset.get_federated_cifar100(
        client_epochs_per_round=-1,
        train_batch_size=10,
        max_batches_per_client=-1)