def create_compiled_keras_model():
  """Create compiled keras model based on the original FedAvg CNN."""
  model = models.create_original_fedavg_cnn_model()
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=utils_impl.get_optimizer_from_flags('client'),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model
def test_get_server_optimizer_from_flags(self, optimizer_name, optimizer_cls):
  FLAGS['{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX)].value = optimizer_name
  # Construct a default optimizer.
  default_optimizer = utils_impl.get_optimizer_from_flags(
      TEST_SERVER_FLAG_PREFIX)
  self.assertIsInstance(default_optimizer, optimizer_cls)
  # Set a flag to a non-default value.
  FLAGS['{}_{}'.format(TEST_SERVER_FLAG_PREFIX,
                       'learning_rate')].value = 100.0
  custom_optimizer = utils_impl.get_optimizer_from_flags(
      TEST_SERVER_FLAG_PREFIX)
  self.assertIsInstance(custom_optimizer, optimizer_cls)
  self.assertEqual(custom_optimizer.get_config()['learning_rate'], 100.0)
  # Override the flag value.
  custom_optimizer = utils_impl.get_optimizer_from_flags(
      TEST_SERVER_FLAG_PREFIX, {'learning_rate': 5.0})
  self.assertIsInstance(custom_optimizer, optimizer_cls)
  self.assertEqual(custom_optimizer.get_config()['learning_rate'], 5.0)
def test_get_optimizer_from_flags_flags_set_not_for_optimizer(self):
  FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'sgd'
  # Set an Adam flag that isn't used in SGD.
  # We need to use `_parse_args` because that is the only way FLAGS is
  # notified that a non-default value is being used.
  FLAGS._parse_args(
      args=['--{}_adam_beta_1=0.5'.format(TEST_CLIENT_FLAG_PREFIX)],
      known_only=True)
  with self.assertRaisesRegex(
      ValueError,
      r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'):
    _ = utils_impl.get_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
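# The test above expects `utils_impl.get_optimizer_from_flags` to reject
# explicitly-set flags that belong to a different optimizer. A minimal sketch
# of that validation pattern (an illustration only, not the repo's actual
# implementation; `_SUPPORTED_OPTIMIZERS` and the helper name are assumptions):

from absl import flags

_SUPPORTED_OPTIMIZERS = ('sgd', 'adam')


def _check_no_unused_optimizer_flags(prefix, optimizer_name):
  """Raises if a `<prefix>_<other_optimizer>_*` flag was explicitly set."""
  for other in _SUPPORTED_OPTIMIZERS:
    if other == optimizer_name:
      continue
    unused_prefix = '{}_{}_'.format(prefix, other)
    for flag_name in flags.FLAGS:
      # `present` is nonzero only for flags set on the command line,
      # not for flags still at their defaults.
      if flag_name.startswith(unused_prefix) and flags.FLAGS[flag_name].present:
        raise ValueError(
            'Commandline flags for optimizers other than [{}] are set; '
            'found [{!r}].'.format(optimizer_name, flag_name))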
def run_experiment():
  """Data preprocessing and experiment execution."""
  np.random.seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  if FLAGS.num_pseudo_clients > 1:
    emnist_train = tff.simulation.datasets.emnist.get_infinite(
        emnist_train, FLAGS.num_pseudo_clients)
    emnist_test = tff.simulation.datasets.emnist.get_infinite(
        emnist_test, FLAGS.num_pseudo_clients)

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    return dataset.map(element_fn).apply(
        tf.data.experimental.shuffle_and_repeat(
            buffer_size=10000,
            count=FLAGS.client_epochs_per_round)).batch(FLAGS.batch_size)

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  keras_model = create_compiled_keras_model()
  the_metrics_hook = metrics_hook.MetricsHook.build(FLAGS.exp_name,
                                                    FLAGS.root_output_dir,
                                                    emnist_test, hparam_dict,
                                                    keras_model)

  optimizer_fn = lambda: utils_impl.get_optimizer_from_flags('server')

  # Under uniform client weighting, the expected total weight equals the
  # number of clients per round, so `train_clients_per_round` is passed both
  # as the expected total weight and as the expected number of clients.
  model = tff.learning.from_compiled_keras_model(keras_model, sample_batch)
  dp_query = tff.utils.build_dp_query(
      FLAGS.clip, FLAGS.noise_multiplier, FLAGS.train_clients_per_round,
      FLAGS.adaptive_clip_learning_rate, FLAGS.target_unclipped_quantile,
      FLAGS.clipped_count_budget_allocation, FLAGS.train_clients_per_round,
      FLAGS.use_per_vector, model)

  # Uniform weighting.
  def client_weight_fn(outputs):
    del outputs  # Unused.
    return 1.0

  dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)

  def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook,
      client_weight_fn=client_weight_fn,
      stateful_delta_aggregate_fn=dp_aggregate_fn)
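# Rough intuition for the DP parameters above (standard DP-FedAvg accounting;
# a back-of-the-envelope sketch, not code from this repo): each clipped client
# update contributes at most `clip` in L2 norm, Gaussian noise with stddev
# `noise_multiplier * clip` is added to the summed update, and the sum is
# normalized by the expected number of clients, so the noise on the averaged
# model delta scales as follows:


def _effective_noise_stddev(clip, noise_multiplier, clients_per_round):
  """Stddev of noise on the averaged model delta, assuming uniform weights."""
  return noise_multiplier * clip / clients_per_round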
def run_experiment():
  """Data preprocessing and experiment execution."""
  np.random.seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    return dataset.map(element_fn).apply(
        tf.data.experimental.shuffle_and_repeat(
            buffer_size=10000,
            count=FLAGS.client_epochs_per_round)).batch(FLAGS.batch_size)

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  metrics_hook = MetricsHook.build(FLAGS.exp_name, FLAGS.root_output_dir,
                                   emnist_test, hparam_dict)

  optimizer_fn = lambda: utils_impl.get_optimizer_from_flags('server')

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
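# Quick sanity check of `element_fn`'s output structure on a synthetic element
# (EMNIST images are 28x28, so `x` flattens to shape [784]); illustrative only:

import collections

import tensorflow as tf

Example = collections.namedtuple('Example', ['x', 'y'])
fake_element = {'pixels': tf.zeros([28, 28]), 'label': tf.constant(3)}
example = Example(
    x=tf.reshape(fake_element['pixels'], [-1]),
    y=tf.reshape(fake_element['label'], [1]))
assert example.x.shape == (784,)
assert example.y.shape == (1,)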
def test_get_optimizer_from_flags(self):
  utils_impl.define_optimizer_flags(
      'server', defaults=dict(learning_rate=1.25))
  self.assertEqual(FLAGS.server_learning_rate, 1.25)
  optimizer = utils_impl.get_optimizer_from_flags('server')
  self.assertEqual(optimizer.get_config()['learning_rate'], 1.25)
def test_get_optimizer_from_flags_invalid_optimizer(self):
  FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
  with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
    _ = utils_impl.get_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
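# These tests exercise `utils_impl.define_optimizer_flags` and
# `utils_impl.get_optimizer_from_flags`, which are defined elsewhere in this
# repo. A minimal sketch of the prefixed-flag pattern they rely on (the names,
# defaults, and two-optimizer registry here are assumptions for illustration,
# not the repo's implementation):

from absl import flags
import tensorflow as tf


def _define_demo_optimizer_flags(prefix, default_learning_rate=0.1):
  """Defines `<prefix>_optimizer` and `<prefix>_learning_rate` flags."""
  flags.DEFINE_string('{}_optimizer'.format(prefix), 'sgd', 'Optimizer name.')
  flags.DEFINE_float('{}_learning_rate'.format(prefix), default_learning_rate,
                     'Optimizer learning rate.')


def _demo_optimizer_from_flags(prefix, overrides=None):
  """Builds a Keras optimizer from the `<prefix>_*` flags."""
  kwargs = {
      'learning_rate': flags.FLAGS['{}_learning_rate'.format(prefix)].value
  }
  # Explicit overrides take precedence over flag values.
  kwargs.update(overrides or {})
  name = flags.FLAGS['{}_optimizer'.format(prefix)].value
  if name == 'sgd':
    return tf.keras.optimizers.SGD(**kwargs)
  elif name == 'adam':
    return tf.keras.optimizers.Adam(**kwargs)
  raise ValueError('"{}" is not a valid optimizer.'.format(name))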