def model_fn():
  """Wraps a newly built original-FedAvg CNN as a TFF model.

  Returns:
    The object produced by `tff.learning.from_keras_model` for a fresh
    instance of the original FedAvg CNN. It uses `sample_batch` as the
    dummy batch, a `SparseCategoricalCrossentropy` loss object as the loss,
    and `METRICS_LIST` as the metrics.
  """
  cnn = models.create_original_fedavg_cnn_model()
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
  return tff.learning.from_keras_model(
      cnn,
      dummy_batch=sample_batch,
      loss=loss_object,
      metrics=METRICS_LIST)
def model_fn():
  """Wraps a newly built original-FedAvg CNN as a TFF model.

  Returns:
    The object produced by `tff.learning.from_keras_model` for a fresh
    instance of the original FedAvg CNN. It uses `sample_batch` as the
    dummy batch, the functional sparse categorical cross-entropy as the
    loss, and a single sparse categorical accuracy metric.
  """
  cnn = models.create_original_fedavg_cnn_model()
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  return tff.learning.from_keras_model(
      cnn,
      dummy_batch=sample_batch,
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      metrics=[accuracy])
def create_compiled_keras_model():
  """Returns the original FedAvg CNN, compiled and ready to train.

  The model is compiled with:
    * loss: functional sparse categorical cross-entropy,
    * optimizer: whatever `utils_impl.create_optimizer_from_flags('client')`
      builds,
    * metrics: a single sparse categorical accuracy metric.
  """
  cnn = models.create_original_fedavg_cnn_model()
  client_optimizer = utils_impl.create_optimizer_from_flags('client')
  cnn.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=client_optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return cnn
def create_compiled_keras_model():
  """Returns the original FedAvg CNN, compiled and ready to train.

  The CNN is built with `only_digits` taken from `FLAGS.digit_only_emnist`.
  It is then compiled with:
    * loss: functional sparse categorical cross-entropy,
    * optimizer: whatever `utils_impl.create_optimizer_from_flags('client')`
      builds,
    * metrics: a single sparse categorical accuracy metric.
  """
  cnn = models.create_original_fedavg_cnn_model(
      only_digits=FLAGS.digit_only_emnist)
  client_optimizer = utils_impl.create_optimizer_from_flags('client')
  cnn.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=client_optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return cnn
def create_compiled_keras_model():
  """Builds the Keras model chosen by `FLAGS.training_model` and compiles it.

  Recognized values of `FLAGS.training_model`:
    'cnn'      -> models.create_conv_dropout_model
    'orig_cnn' -> models.create_original_fedavg_cnn_model
    '2nn'      -> models.create_two_hidden_layer_model
    'resnet'   -> models.create_resnet with num_blocks=9
  Every builder receives `only_digits=FLAGS.only_digits`.

  The model is compiled with sparse categorical cross-entropy, an SGD
  optimizer configured from `FLAGS.learning_rate` / `FLAGS.momentum`, and a
  sparse categorical accuracy metric.

  Returns:
    The compiled Keras model.

  Raises:
    ValueError: If `FLAGS.training_model` is not one of the values above.
  """
  # Builders are wrapped in lambdas so that only the selected one runs.
  builders = {
      'cnn': lambda: models.create_conv_dropout_model(
          only_digits=FLAGS.only_digits),
      'orig_cnn': lambda: models.create_original_fedavg_cnn_model(
          only_digits=FLAGS.only_digits),
      '2nn': lambda: models.create_two_hidden_layer_model(
          only_digits=FLAGS.only_digits),
      'resnet': lambda: models.create_resnet(
          num_blocks=9, only_digits=FLAGS.only_digits),
  }
  requested = FLAGS.training_model
  if requested not in builders:
    raise ValueError('Model {} is not supported.'.format(requested))
  model = builders[requested]()

  sgd = tf.keras.optimizers.SGD(
      learning_rate=FLAGS.learning_rate, momentum=FLAGS.momentum)
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=sgd,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model
def run_experiment():
  """Data preprocessing and experiment execution.

  Steps:
    1. Load EMNIST train/test client data. When
       `FLAGS.num_pseudo_clients > 1`, expand both with `get_infinite`.
    2. Preprocess each example: flatten pixels to a vector and reshape
       labels to shape [1]. Training data is additionally shuffled,
       repeated `FLAGS.client_epochs_per_round` times and batched with
       `FLAGS.batch_size`. Test data from all clients is pooled and
       batched in groups of 100.
    3. Build a metrics hook around a compiled Keras model and the pooled
       test set.
    4. Build a DP query and aggregator. Then run the federated averaging
       training loop with uniform client weights.
  """
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
  if FLAGS.num_pseudo_clients > 1:
    emnist_train = tff.simulation.datasets.emnist.get_infinite(
        emnist_train, FLAGS.num_pseudo_clients)
    emnist_test = tff.simulation.datasets.emnist.get_infinite(
        emnist_test, FLAGS.num_pseudo_clients)

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    # Flatten the pixel tensor to 1-D and give the label shape [1].
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    return (dataset.map(element_fn).shuffle(buffer_size=10000).repeat(
        FLAGS.client_epochs_per_round).batch(FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # One concrete batch (as numpy arrays) from the first client. It is passed
  # to `from_keras_model` as `dummy_batch` below.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Sample clients uniformly at random, without replacement within a round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  # Compiled without an optimizer. It is only handed to the metrics hook
  # here; NOTE(review): presumably used for evaluation only — confirm
  # against MetricsHook.
  keras_model = models.create_original_fedavg_cnn_model()
  keras_model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  the_metrics_hook = metrics_hook.MetricsHook.build(FLAGS.exp_name,
                                                    FLAGS.root_output_dir,
                                                    emnist_test, hparam_dict,
                                                    keras_model)

  client_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'client')
  server_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'server')

  def model_fn():
    keras_model = models.create_original_fedavg_cnn_model()
    return tff.learning.from_keras_model(
        keras_model,
        dummy_batch=sample_batch,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  # A model instance is built here only to hand to build_dp_query.
  model = model_fn()

  # NOTE(review): arguments are positional, so their order must match
  # tff.utils.build_dp_query's signature. `train_clients_per_round` is
  # passed twice — presumably as the expected total weight and the expected
  # number of clients. With the uniform weight of 1.0 below, these coincide.
  dp_query = tff.utils.build_dp_query(
      FLAGS.clip, FLAGS.noise_multiplier, FLAGS.train_clients_per_round,
      FLAGS.adaptive_clip_learning_rate, FLAGS.target_unclipped_quantile,
      FLAGS.clipped_count_budget_allocation, FLAGS.train_clients_per_round,
      FLAGS.use_per_vector, model)

  # Uniform weighting.
  def client_weight_fn(outputs):
    del outputs  # unused.
    return 1.0

  dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)

  # The server optimizer is passed under `server_optimizer_fn`, mirroring
  # `client_optimizer_fn`. The previous keyword `server_fn` did not match
  # that parameter name.
  training_loops.federated_averaging_training_loop(
      model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_datasets_fn=client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook,
      client_weight_fn=client_weight_fn,
      stateful_delta_aggregate_fn=dp_aggregate_fn)