def test_federated_training_loop(self):
  """Runs a 3-round training loop on one fixed batch and compares losses.

  A single all-ones batch is served as the only client dataset in every
  round. Each call to the metrics hook loads the current server weights into
  a freshly compiled Keras model and records its loss on that batch. The test
  passes when the mean of all losses recorded after the first one is lower
  than the first recorded loss.
  """
  Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name
  features = np.ones([1, 784], dtype=np.float32)
  labels = np.ones([1, 1], dtype=np.int64)
  batch = Batch(x=features, y=labels)

  # One client whose dataset is a single batch.
  federated_data = [[batch]]
  loss_list = []

  def client_datasets_fn(round_num):
    """Returns the same federated data regardless of the round."""
    del round_num
    return federated_data

  def metrics_hook(state, metrics, round_num):
    """Appends the loss of the current server model on `batch`."""
    del round_num
    del metrics
    keras_model = models.create_keras_model(compile_model=True)
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    current_loss = keras_model.test_on_batch(batch.x, batch.y)
    loss_list.append(current_loss)

  def server_optimizer_fn():
    """Builds the server optimizer: SGD with learning rate 1.0."""
    return tf.keras.optimizers.SGD(learning_rate=1.0)

  training_loops.federated_averaging_training_loop(
      models.model_fn,
      server_optimizer_fn,
      client_datasets_fn,
      total_rounds=3,
      metrics_hook=metrics_hook)

  later_losses = loss_list[1:]
  self.assertLess(np.mean(later_losses), loss_list[0])
def run_experiment():
  """Data preprocessing and experiment execution.

  Loads the EMNIST train/test splits and, when `FLAGS.num_pseudo_clients` is
  greater than 1, passes both splits through
  `tff.simulation.datasets.emnist.get_infinite`. The training data is
  preprocessed per client; the test data is merged across all clients and
  batched. A differential-privacy query and the matching aggregate function
  are built from flags, and the federated averaging training loop is run with
  uniform client weighting.
  """
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
  if FLAGS.num_pseudo_clients > 1:
    emnist_train = tff.simulation.datasets.emnist.get_infinite(
        emnist_train, FLAGS.num_pseudo_clients)
    emnist_test = tff.simulation.datasets.emnist.get_infinite(
        emnist_test, FLAGS.num_pseudo_clients)

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    """Maps a raw element to an `Example` with 1-D pixels and a [1] label."""
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    # `tf.data.experimental.shuffle_and_repeat` is deprecated; TensorFlow's
    # documented replacement is `shuffle` followed by `repeat`.
    return (dataset.map(element_fn)
            .shuffle(buffer_size=10000)
            .repeat(FLAGS.client_epochs_per_round)
            .batch(FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # The first batch of the first client, converted to numpy, serves as the
  # sample batch handed to `from_compiled_keras_model`.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Sample distinct clients (no replacement) for this round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  keras_model = create_compiled_keras_model()
  the_metrics_hook = metrics_hook.MetricsHook.build(
      FLAGS.exp_name, FLAGS.root_output_dir, emnist_test, hparam_dict,
      keras_model)

  optimizer_fn = functools.partial(utils_impl.create_optimizer_from_flags,
                                   'server')

  model = tff.learning.from_compiled_keras_model(keras_model, sample_batch)
  # NOTE(review): `train_clients_per_round` is passed twice (third and seventh
  # positional arguments) -- presumably as the expected total weight and the
  # expected number of clients under uniform weighting; confirm against the
  # signature of `tff.utils.build_dp_query`.
  dp_query = tff.utils.build_dp_query(
      FLAGS.clip,
      FLAGS.noise_multiplier,
      FLAGS.train_clients_per_round,
      FLAGS.adaptive_clip_learning_rate,
      FLAGS.target_unclipped_quantile,
      FLAGS.clipped_count_budget_allocation,
      FLAGS.train_clients_per_round,
      FLAGS.use_per_vector,
      model)

  # Uniform weighting.
  def client_weight_fn(outputs):
    del outputs  # Unused.
    return 1.0

  dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)

  def model_fn():
    """Builds a fresh compiled Keras model wrapped as a TFF model."""
    return tff.learning.from_compiled_keras_model(
        create_compiled_keras_model(), sample_batch)

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook,
      client_weight_fn=client_weight_fn,
      stateful_delta_aggregate_fn=dp_aggregate_fn)
def run_experiment():
  """Data preprocessing and experiment execution.

  Seeds NumPy and TensorFlow from `FLAGS.random_seed`, loads the EMNIST
  train/test splits, preprocesses the training data per client and the test
  data merged across all clients, then runs the federated averaging training
  loop with a metrics hook built from the experiment flags.
  """
  np.random.seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    """Maps a raw element to an `Example` with 1-D pixels and a [1] label."""
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    # `tf.data.experimental.shuffle_and_repeat` is deprecated; TensorFlow's
    # documented replacement is `shuffle` followed by `repeat`.
    return (dataset.map(element_fn)
            .shuffle(buffer_size=10000)
            .repeat(FLAGS.client_epochs_per_round)
            .batch(FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # The first batch of the first client, converted to numpy, serves as the
  # sample batch handed to `from_compiled_keras_model`.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def model_fn():
    """Builds a fresh compiled Keras model wrapped as a TFF model."""
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Sample distinct clients (no replacement) for this round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  metrics_hook = MetricsHook.build(FLAGS.exp_name, FLAGS.root_output_dir,
                                   emnist_test, hparam_dict)

  def optimizer_fn():
    """Builds the server optimizer from flags."""
    return utils_impl.get_optimizer_from_flags('server')

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
def run_experiment():
  """Data preprocessing and experiment execution.

  Loads the EMNIST train/test splits, preprocesses the training data per
  client and the test data merged across all clients, then runs the federated
  averaging training loop. When `FLAGS.use_compression` is set, encoded
  broadcast and mean functions are built from `model_fn` and passed to the
  loop; otherwise `None` is passed for both.
  """
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    """Maps a raw element to an `Example` with 1-D pixels and a [1] label."""
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    # `tf.data.experimental.shuffle_and_repeat` is deprecated; TensorFlow's
    # documented replacement is `shuffle` followed by `repeat`.
    return (dataset.map(element_fn)
            .shuffle(buffer_size=10000)
            .repeat(FLAGS.client_epochs_per_round)
            .batch(FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # The first batch of the first client, converted to numpy, serves as the
  # sample batch handed to `from_compiled_keras_model`.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def model_fn():
    """Builds a fresh compiled Keras model wrapped as a TFF model."""
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Sample distinct clients (no replacement) for this round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  the_metrics_hook = metrics_hook.MetricsHook.build(
      FLAGS.exp_name, FLAGS.root_output_dir, emnist_test, hparam_dict,
      create_compiled_keras_model())

  def optimizer_fn():
    """Builds the server optimizer from flags."""
    return utils_impl.create_optimizer_from_flags('server')

  if FLAGS.use_compression:
    # We create a `StatefulBroadcastFn` and `StatefulAggregateFn` by providing
    # the `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding
    # utilities. The fns are called once for each of the model weights created
    # by model_fn, and return instances of appropriate encoders.
    encoded_broadcast_fn = (
        tff.learning.framework.build_encoded_broadcast_from_model(
            model_fn, _broadcast_encoder_fn))
    encoded_mean_fn = tff.learning.framework.build_encoded_mean_from_model(
        model_fn, _mean_encoder_fn)
  else:
    encoded_broadcast_fn = None
    encoded_mean_fn = None

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook,
      stateful_model_broadcast_fn=encoded_broadcast_fn,
      stateful_delta_aggregate_fn=encoded_mean_fn)
def run_experiment():
  """Prepares the EMNIST data and launches federated averaging training.

  The training split is mapped, shuffled, repeated and batched per client;
  the test split is merged across all clients, mapped and batched. Separate
  client and server optimizer factories are created from flags and handed to
  the training loop together with a metrics hook.
  """
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    """Converts a raw element into an `Example` of 1-D pixels and a label."""
    pixels = tf.reshape(element['pixels'], [-1])
    label = tf.reshape(element['label'], [1])
    return example_tuple(x=pixels, y=label)

  def preprocess_train_dataset(dataset):
    """Maps, shuffles, repeats and batches a client's training dataset."""
    mapped = dataset.map(element_fn)
    shuffled = mapped.shuffle(buffer_size=10000)
    repeated = shuffled.repeat(FLAGS.client_epochs_per_round)
    return repeated.batch(FLAGS.batch_size)

  def preprocess_test_dataset(dataset):
    """Maps the test dataset and batches it in groups of 100."""
    mapped = dataset.map(element_fn)
    return mapped.batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  all_clients_test = emnist_test.create_tf_dataset_from_all_clients()
  emnist_test = preprocess_test_dataset(all_clients_test)

  # Pull one batch from the first client and convert it to numpy; it is used
  # as the dummy batch when wrapping the Keras model.
  first_client_id = emnist_train.client_ids[0]
  example_dataset = emnist_train.create_tf_dataset_for_client(first_client_id)
  first_batch = next(iter(example_dataset))
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(), first_batch)

  def model_fn():
    """Wraps a fresh Keras model as a TFF model with loss and metrics."""
    return tff.learning.from_keras_model(
        models.create_original_fedavg_cnn_model(),
        dummy_batch=sample_batch,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=METRICS_LIST)

  def client_datasets_fn(round_num):
    """Returns datasets for a random sample of distinct clients."""
    del round_num  # Unused.
    chosen_ids = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client_id)
        for client_id in chosen_ids
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)

  # Flag name -> flag value, in the order given by `hparam_flags`.
  hparam_dict = collections.OrderedDict(
      (name, FLAGS[name].value) for name in hparam_flags)

  eval_model = create_compiled_keras_model()
  the_metrics_hook = metrics_hook.MetricsHook.build(
      FLAGS.exp_name, FLAGS.root_output_dir, emnist_test, hparam_dict,
      eval_model)

  make_optimizer = utils_impl.create_optimizer_from_flags
  client_optimizer_fn = functools.partial(make_optimizer, 'client')
  server_optimizer_fn = functools.partial(make_optimizer, 'server')

  training_loops.federated_averaging_training_loop(
      model_fn,
      client_optimizer_fn,
      server_optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook)