Example #1
def _get_client_datasets_fn(train_data):
    """Returns function for client datasets per round."""
    if FLAGS.total_epochs is None:

        def client_datasets_fn(round_num: int, epoch: int):
            del round_num
            sampled_clients = np.random.choice(train_data.client_ids,
                                               size=FLAGS.clients_per_round,
                                               replace=False)
            return [
                train_data.create_tf_dataset_for_client(client)
                for client in sampled_clients
            ], epoch

        logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
    else:
        client_shuffler = training_loop.ClientIDShuffler(
            FLAGS.clients_per_round, train_data)

        def client_datasets_fn(round_num: int, epoch: int):
            sampled_clients, epoch = client_shuffler.sample_client_ids(
                round_num, epoch)
            return [
                train_data.create_tf_dataset_for_client(client)
                for client in sampled_clients
            ], epoch

        logging.info('Shuffle clients for max %d epochs and %d rounds',
                     FLAGS.total_epochs, FLAGS.total_rounds)
    return client_datasets_fn
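
For context, a minimal sketch of how the returned client_datasets_fn might be consumed, assuming a TFF-style iterative process whose next() takes the server state and a list of client datasets; iterative_process, state, and total_rounds are illustrative stand-ins, not part of the snippet above.

state = iterative_process.initialize()
epoch = 0
for round_num in range(total_rounds):
    # Each call returns this round's client datasets plus the
    # (possibly advanced) epoch counter.
    client_datasets, epoch = client_datasets_fn(round_num, epoch)
    state, metrics = iterative_process.next(state, client_datasets)
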
Example #2
def test_remainder(self):
    """Tests drop_remainder behavior when clients don't divide evenly."""
    clients_data = tff.simulation.datasets.stackoverflow.get_synthetic()
    client_shuffler1 = training_loop.ClientIDShuffler(
        len(clients_data.client_ids) - 1,
        clients_data,
        drop_remainder=True)
    client_shuffler2 = training_loop.ClientIDShuffler(
        len(clients_data.client_ids) - 1,
        clients_data,
        drop_remainder=False)
    epoch1, epoch2, round_num = 0, 0, 0
    total_rounds = 2
    while round_num < total_rounds:
        clients1, epoch1 = client_shuffler1.sample_client_ids(
            round_num, epoch1)
        clients2, epoch2 = client_shuffler2.sample_client_ids(
            round_num, epoch2)
        round_num += 1
    # With drop_remainder=True the leftover client is skipped, so each round
    # serves a full cohort and the epoch advances; with drop_remainder=False
    # the single remaining client is served before the epoch ends.
    self.assertEqual(len(clients1), len(clients_data.client_ids) - 1)
    self.assertEqual(len(clients2), 1)
    self.assertEqual(epoch1, 2)
    self.assertEqual(epoch2, 1)
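
The assertions above pin down the sampler's contract: one shuffled pass over the client IDs per epoch, served in cohorts of clients_per_round, with drop_remainder deciding whether a trailing partial cohort is skipped (ending the epoch early) or served as-is. Below is a minimal re-implementation sketch consistent with that contract; SimpleClientIDShuffler is a hypothetical name, not the library's actual class.

import random

class SimpleClientIDShuffler:
    """Sketch: serve shuffled client IDs in fixed-size cohorts, one pass per epoch."""

    def __init__(self, clients_per_round, client_data, drop_remainder=True):
        self._clients_per_round = clients_per_round
        self._client_ids = list(client_data.client_ids)
        self._drop_remainder = drop_remainder
        self._epoch = 0
        self._start = 0

    def _reshuffle(self):
        random.shuffle(self._client_ids)
        self._start = 0
        self._epoch += 1

    def sample_client_ids(self, round_num, epoch):
        del round_num, epoch  # Progress is tracked by the internal cursor.
        num_clients = len(self._client_ids)
        end = min(self._start + self._clients_per_round, num_clients)
        sampled = self._client_ids[self._start:end]
        # End the epoch when the pass is complete, or when drop_remainder
        # would leave only a partial cohort for the next round.
        if end >= num_clients or (self._drop_remainder and
                                  end + self._clients_per_round > num_clients):
            self._reshuffle()
        else:
            self._start = end
        return sampled, self._epoch
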
Example #3
def test_shuffling(self):
    """Tests that each epoch visits every client ID exactly once."""
    clients_data = tff.simulation.datasets.stackoverflow.get_synthetic()
    client_shuffler = training_loop.ClientIDShuffler(1, clients_data)
    epoch, round_num = 0, 0
    total_epochs = 2
    epoch2clientid = [[] for _ in range(total_epochs)]
    while epoch < total_epochs:
        clients, new_epoch = client_shuffler.sample_client_ids(
            round_num, epoch)
        epoch2clientid[epoch].extend(clients)
        round_num += 1
        epoch = new_epoch
    # Same IDs in both epochs, independent of per-epoch shuffle order.
    self.assertCountEqual(epoch2clientid[0], epoch2clientid[1])
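
assertCountEqual is the right check here: each epoch is an independent shuffle, so the two epochs contain the same client IDs in a generally different order. Assuming the shuffler draws from Python's or NumPy's global RNG (the snippets don't show its internals), seeding at the start of a test would make the order deterministic:

import random
import numpy as np

random.seed(42)     # If the shuffler uses the random module.
np.random.seed(42)  # If it uses NumPy, as np.random.choice in Example #1 does.
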
Example #4
def train_and_eval():
    """Train and evaluate StackOver NWP task."""
    logging.info('Show FLAGS for debugging:')
    for f in HPARAM_FLAGS:
        logging.info('%s=%s', f, FLAGS[f].value)

    train_dataset_computation, train_set, validation_set, test_set = _preprocess_stackoverflow(
        FLAGS.vocab_size, FLAGS.num_oov_buckets, FLAGS.sequence_length,
        FLAGS.num_validation_examples, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.max_elements_per_user)

    input_spec = train_dataset_computation.type_signature.result.element

    def tff_model_fn():
        keras_model = models.create_recurrent_model(
            vocab_size=FLAGS.vocab_size,
            embedding_size=FLAGS.embedding_size,
            latent_size=FLAGS.latent_size,
            num_layers=FLAGS.num_layers,
            shared_embedding=FLAGS.shared_embedding)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return dp_fedavg.KerasModelWrapper(keras_model, input_spec, loss)

    noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
        FLAGS.clients_per_round)
    server_optimizer_fn = functools.partial(_server_optimizer_fn,
                                            name=FLAGS.server_optimizer,
                                            learning_rate=FLAGS.server_lr,
                                            noise_std=noise_std)
    client_optimizer_fn = functools.partial(_client_optimizer_fn,
                                            name=FLAGS.client_optimizer,
                                            learning_rate=FLAGS.client_lr)
    iterative_process = dp_fedavg.build_federated_averaging_process(
        tff_model_fn,
        dp_clip_norm=FLAGS.clip_norm,
        server_optimizer_fn=server_optimizer_fn,
        client_optimizer_fn=client_optimizer_fn)
    iterative_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        dataset_computation=train_dataset_computation,
        process=iterative_process)

    keras_metrics = _get_stackoverflow_metrics(FLAGS.vocab_size,
                                               FLAGS.num_oov_buckets)
    model = tff_model_fn()

    def evaluate_fn(model_weights, dataset):
        model.from_weights(model_weights)
        metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                           keras_metrics)
        return collections.OrderedDict(
            (metric.name, metric.result().numpy()) for metric in metrics)

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in HPARAM_FLAGS])

    if FLAGS.total_epochs is None:

        def client_dataset_ids_fn(round_num: int, epoch: int):
            return _sample_client_ids(FLAGS.clients_per_round, train_set,
                                      round_num, epoch)

        logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
        total_epochs = 0
    else:
        client_shuffler = training_loop.ClientIDShuffler(
            FLAGS.clients_per_round, train_set)
        client_dataset_ids_fn = client_shuffler.sample_client_ids
        logging.info('Shuffle clients for max %d epochs and %d rounds',
                     FLAGS.total_epochs, FLAGS.total_rounds)
        total_epochs = FLAGS.total_epochs

    training_loop.run(iterative_process,
                      client_dataset_ids_fn,
                      validation_fn=functools.partial(evaluate_fn,
                                                      dataset=validation_set),
                      total_epochs=total_epochs,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      train_eval_fn=None,
                      test_fn=functools.partial(evaluate_fn, dataset=test_set),
                      root_output_dir=FLAGS.root_output_dir,
                      hparam_dict=hparam_dict,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_train_eval=2000)
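
Example #4 calls a _sample_client_ids helper that is not shown. A plausible sketch, mirroring the uniform sampling in Example #1 but returning client IDs instead of materialized datasets (the dataset computation is already composed into the iterative process, so only IDs are needed); the signature is an assumption inferred from the call site.

import numpy as np

def _sample_client_ids(num_clients, client_data, round_num, epoch):
    """Sketch: uniformly sample client IDs without replacement for one round."""
    del round_num  # Each round draws independently; the round index is unused.
    sampled = np.random.choice(client_data.client_ids,
                               size=num_clients,
                               replace=False)
    return list(sampled), epoch
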