def _get_client_datasets_fn(train_data):
  """Returns a function that supplies client datasets for each round."""
  if FLAGS.total_epochs is None:

    def client_datasets_fn(round_num: int, epoch: int):
      del round_num
      # Sample clients uniformly at random, without replacement within a round.
      sampled_clients = np.random.choice(
          train_data.client_ids, size=FLAGS.clients_per_round, replace=False)
      return [
          train_data.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ], epoch

    logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
  else:
    client_shuffler = training_loop.ClientIDShuffler(FLAGS.clients_per_round,
                                                     train_data)

    def client_datasets_fn(round_num: int, epoch: int):
      # Shuffle-and-partition sampling: the shuffler advances the epoch counter
      # once every full pass over the client IDs.
      sampled_clients, epoch = client_shuffler.sample_client_ids(
          round_num, epoch)
      return [
          train_data.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ], epoch

    logging.info('Shuffle clients for max %d epochs and %d rounds',
                 FLAGS.total_epochs, FLAGS.total_rounds)
  return client_datasets_fn
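# Usage sketch (illustrative, not part of the module): how the returned
# client_datasets_fn is typically driven round by round. `iterative_process`
# is assumed to be a TFF-style iterative process with initialize()/next(),
# and the helper name `_example_run_training` is hypothetical.
def _example_run_training(iterative_process, train_data, total_rounds):
  client_datasets_fn = _get_client_datasets_fn(train_data)
  state = iterative_process.initialize()
  epoch = 0
  for round_num in range(total_rounds):
    # Each call returns this round's client datasets plus the (possibly
    # advanced) epoch counter, which is threaded into the next call.
    client_datasets, epoch = client_datasets_fn(round_num, epoch)
    state, _ = iterative_process.next(state, client_datasets)
  return state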
def test_remainder(self):
  clients_data = tff.simulation.datasets.stackoverflow.get_synthetic()
  client_shuffler1 = training_loop.ClientIDShuffler(
      len(clients_data.client_ids) - 1, clients_data, drop_remainder=True)
  client_shuffler2 = training_loop.ClientIDShuffler(
      len(clients_data.client_ids) - 1, clients_data, drop_remainder=False)
  epoch1, epoch2, round_num = 0, 0, 0
  total_rounds = 2
  while round_num < total_rounds:
    clients1, epoch1 = client_shuffler1.sample_client_ids(round_num, epoch1)
    clients2, epoch2 = client_shuffler2.sample_client_ids(round_num, epoch2)
    round_num += 1
  self.assertEqual(len(clients1), len(clients_data.client_ids) - 1)
  self.assertEqual(len(clients2), 1)
  self.assertEqual(epoch1, 2)
  self.assertEqual(epoch2, 1)
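# Worked accounting for the remainder test above, assuming the synthetic
# Stack Overflow data exposes N client IDs (the exact N is whatever
# get_synthetic() provides):
#   clients_per_round = N - 1
#   drop_remainder=True : round 0 uses N - 1 IDs and the single leftover ID is
#                         dropped, so each epoch is one round and epoch1
#                         reaches 2 after two rounds.
#   drop_remainder=False: round 0 uses N - 1 IDs and round 1 uses the one
#                         remaining ID, so epoch2 reaches 1 after two rounds.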
def test_shuffling(self):
  clients_data = tff.simulation.datasets.stackoverflow.get_synthetic()
  client_shuffler = training_loop.ClientIDShuffler(1, clients_data)
  epoch, round_num = 0, 0
  total_epochs = 2
  epoch2clientid = [[] for _ in range(total_epochs)]
  while epoch < total_epochs:
    clients, new_epoch = client_shuffler.sample_client_ids(round_num, epoch)
    epoch2clientid[epoch].extend(clients)
    round_num += 1
    epoch = new_epoch
  self.assertCountEqual(epoch2clientid[0], epoch2clientid[1])
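# Usage sketch (illustrative, outside the tests): consume exactly one epoch of
# client IDs from the shuffler. The helper name below is hypothetical; the
# epoch value returned by sample_client_ids is threaded back in so the
# shuffler can reshuffle once a pass over the client IDs completes.
def _example_consume_one_epoch(clients_data, clients_per_round):
  shuffler = training_loop.ClientIDShuffler(clients_per_round, clients_data)
  round_num, epoch = 0, 0
  sampled = []
  while epoch == 0:
    client_ids, epoch = shuffler.sample_client_ids(round_num, epoch)
    sampled.extend(client_ids)
    round_num += 1
  return sampled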
def train_and_eval():
  """Trains and evaluates the Stack Overflow NWP task."""
  logging.info('Show FLAGS for debugging:')
  for f in HPARAM_FLAGS:
    logging.info('%s=%s', f, FLAGS[f].value)

  train_dataset_computation, train_set, validation_set, test_set = _preprocess_stackoverflow(
      FLAGS.vocab_size, FLAGS.num_oov_buckets, FLAGS.sequence_length,
      FLAGS.num_validation_examples, FLAGS.client_batch_size,
      FLAGS.client_epochs_per_round, FLAGS.max_elements_per_user)

  input_spec = train_dataset_computation.type_signature.result.element

  def tff_model_fn():
    keras_model = models.create_recurrent_model(
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return dp_fedavg.KerasModelWrapper(keras_model, input_spec, loss)

  # Noise standard deviation for the averaged update:
  # clip_norm * noise_multiplier / clients_per_round.
  noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
      FLAGS.clients_per_round)
  server_optimizer_fn = functools.partial(
      _server_optimizer_fn,
      name=FLAGS.server_optimizer,
      learning_rate=FLAGS.server_lr,
      noise_std=noise_std)
  client_optimizer_fn = functools.partial(
      _client_optimizer_fn,
      name=FLAGS.client_optimizer,
      learning_rate=FLAGS.client_lr)
  iterative_process = dp_fedavg.build_federated_averaging_process(
      tff_model_fn,
      dp_clip_norm=FLAGS.clip_norm,
      server_optimizer_fn=server_optimizer_fn,
      client_optimizer_fn=client_optimizer_fn)
  # Compose the preprocessing computation with the process so that it can be
  # fed client IDs instead of materialized datasets.
  iterative_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      dataset_computation=train_dataset_computation, process=iterative_process)

  keras_metrics = _get_stackoverflow_metrics(FLAGS.vocab_size,
                                             FLAGS.num_oov_buckets)
  model = tff_model_fn()

  def evaluate_fn(model_weights, dataset):
    model.from_weights(model_weights)
    metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                       keras_metrics)
    return collections.OrderedDict(
        (metric.name, metric.result().numpy()) for metric in metrics)

  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in HPARAM_FLAGS
  ])

  if FLAGS.total_epochs is None:

    def client_dataset_ids_fn(round_num: int, epoch: int):
      return _sample_client_ids(FLAGS.clients_per_round, train_set, round_num,
                                epoch)

    logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
    total_epochs = 0
  else:
    client_shuffler = training_loop.ClientIDShuffler(FLAGS.clients_per_round,
                                                     train_set)
    client_dataset_ids_fn = client_shuffler.sample_client_ids
    logging.info('Shuffle clients for max %d epochs and %d rounds',
                 FLAGS.total_epochs, FLAGS.total_rounds)
    total_epochs = FLAGS.total_epochs

  training_loop.run(
      iterative_process,
      client_dataset_ids_fn,
      validation_fn=functools.partial(evaluate_fn, dataset=validation_set),
      total_epochs=total_epochs,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      train_eval_fn=None,
      test_fn=functools.partial(evaluate_fn, dataset=test_set),
      root_output_dir=FLAGS.root_output_dir,
      hparam_dict=hparam_dict,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_train_eval=2000)
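# Hedged sketch of the `_sample_client_ids` helper referenced above; its
# definition is not shown in this excerpt. It is assumed to mirror the
# uniform-without-replacement branch of _get_client_datasets_fn but return
# client IDs rather than datasets, since the dataset computation is composed
# into the iterative process. Assumes numpy is imported as np in this module.
def _sample_client_ids(num_clients, client_data, round_num, epoch):
  del round_num  # Unused: clients are resampled uniformly every round.
  sampled_ids = np.random.choice(
      client_data.client_ids, size=num_clients, replace=False)
  return list(sampled_ids), epoch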