Example #1
The legacy TFF API: the model is constructed from a concrete dummy batch, and compression is plugged in through the stateful_delta_aggregate_fn and stateful_model_broadcast_fn arguments of build_federated_averaging_process.
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=FLAGS.only_digits)

    # Materialize one batch from a single client's dataset; the legacy
    # from_keras_model API uses this dummy batch to infer input types and shapes.
    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(example_dataset)))

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    # Build the client and server optimizers from the corresponding
    # command-line flags.
    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             dummy_batch=sample_batch,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `StatefulBroadcastFn` and `StatefulAggregateFn` by providing
        # the `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding
        # utilities. The fns are called once for each of the model weights created
        # by tff_model_fn, and return instances of appropriate encoders.
        encoded_broadcast_fn = (
            tff.learning.framework.build_encoded_broadcast_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_fn = tff.learning.framework.build_encoded_mean_from_model(
            tff_model_fn, _mean_encoder_fn)
    else:
        encoded_broadcast_fn = None
        encoded_mean_fn = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        stateful_delta_aggregate_fn=encoded_mean_fn,
        stateful_model_broadcast_fn=encoded_broadcast_fn)
    iterative_process = compression_process_adapter.CompressionProcessAdapter(
        iterative_process)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      evaluate_fn=evaluate_fn)
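
Both branches of the compression setup above assume two module-level helpers, _broadcast_encoder_fn and _mean_encoder_fn, that the snippet does not show. A minimal sketch of what they might look like, assuming the tensor_encoding utilities bundled with tensorflow_model_optimization; the 10,000-element threshold and 8-bit quantization are illustrative choices, not values taken from the example:

import tensorflow as tf
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


def _broadcast_encoder_fn(value):
    """Encoder for one server-to-client broadcast tensor (sketch)."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        # Uniformly quantize large weights to 8 bits before broadcasting.
        return te.encoders.as_simple_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    # Leave small weights (e.g. biases) untouched.
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def _mean_encoder_fn(value):
    """Encoder for one client-to-server aggregated tensor (sketch)."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_gather_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

Each helper receives one model weight and returns an encoder. Roughly, a simple encoder suffices for the one-way broadcast, while the mean needs a gather-style encoder that also participates in the aggregation.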
Example #2
The same experiment on the newer TFF API: the model is described by the dataset's element_spec instead of a dummy batch, and compression is supplied as MeasuredProcess-based broadcast and aggregation processes.
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=FLAGS.only_digits)

    # The newer from_keras_model API takes a type spec rather than a concrete
    # dummy batch, so only the element_spec of a client dataset is needed.
    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    # Evaluation needs a helper that copies the model weights held in the
    # compression-wrapped server state back into a Keras model.
    assign_weights_fn = compression_process_adapter.CompressionServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `MeasuredProcess` for broadcast process and a
        # `MeasuredProcess` for aggregate process by providing the
        # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities.
        # The fns are called once for each of the model weights created by
        # tff_model_fn, and return instances of appropriate encoders.
        encoded_broadcast_process = (
            tff.learning.framework.build_encoded_broadcast_process_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_process = (
            tff.learning.framework.build_encoded_mean_process_from_model(
                tff_model_fn, _mean_encoder_fn))
    else:
        encoded_broadcast_process = None
        encoded_mean_process = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        aggregation_process=encoded_mean_process,
        broadcast_process=encoded_broadcast_process)
    # Wrap the process so its server state is exposed as a CompressionServerState,
    # matching the assign_weights_fn used for evaluation above.
    iterative_process = compression_process_adapter.CompressionProcessAdapter(
        iterative_process)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      evaluate_fn=evaluate_fn)
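
Both examples also rely on module-level builders (model_builder, loss_builder, metrics_builder) and on absl flags defined elsewhere in the surrounding module. The stand-ins below are only illustrative assumptions about what that context could contain; the network architecture and flag defaults are not taken from the examples:

from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_integer('client_batch_size', 20, 'Batch size used on the clients.')
flags.DEFINE_integer('client_epochs_per_round', 1, 'Local epochs per round.')
flags.DEFINE_integer('clients_per_round', 10, 'Clients sampled per round.')
flags.DEFINE_boolean('only_digits', False, 'Use only the 10-digit EMNIST split.')
flags.DEFINE_boolean('use_compression', True, 'Compress broadcast and aggregation.')


def model_builder():
    """Returns an uncompiled Keras model for 28x28x1 EMNIST images (sketch)."""
    num_classes = 10 if FLAGS.only_digits else 62
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu',
                               input_shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes),
    ])


def loss_builder():
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]

Both versions of tff.learning.from_keras_model take the loss and metrics as separate arguments, so model_builder returns an uncompiled model.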