Example #1

from typing import Any, Callable, Optional

from absl import app
from absl import flags
import tensorflow as tf
import tensorflow_federated as tff

from fedopt_guide.gld23k_mobilenet import dataset
from fedopt_guide.gld23k_mobilenet import federated_main
from optimization.shared import optimizer_utils
from utils import utils_impl

with utils_impl.record_hparam_flags() as optimizer_flags:
    # Defining optimizer flags
    optimizer_utils.define_optimizer_flags('client')
    optimizer_utils.define_optimizer_flags('server')

with utils_impl.record_hparam_flags() as shared_flags:
    # Federated training hyperparameters
    flags.DEFINE_integer('client_epochs_per_round', 1,
                         'Number of local epochs each client runs per round.')
    flags.DEFINE_integer('client_batch_size', 16, 'Batch size on the clients.')
    flags.DEFINE_integer('clients_per_round', 10,
                         'How many clients to sample per round.')
    flags.DEFINE_integer('client_datasets_random_seed', None,
                         'Random seed for client sampling.')

    # Training loop configuration
    flags.DEFINE_string(
        'experiment_name', None,
        'Name of the experiment. Part of the name of the output directory.')
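
The `with utils_impl.record_hparam_flags() as ...:` blocks above record which flags are defined inside them, so the chosen hyperparameters can later be written out alongside the results. A minimal, self-contained sketch of that pattern (an illustration of the idea, not the actual utils_impl implementation; all names below are made up):

import contextlib

from absl import flags

FLAGS = flags.FLAGS


@contextlib.contextmanager
def record_flag_names(recorded):
    """Collects the names of flags defined inside the `with` body."""
    existing = set(FLAGS)  # Flag names registered before the block.
    yield recorded
    recorded.extend(sorted(set(FLAGS) - existing))


hparam_names = []
with record_flag_names(hparam_names):
    flags.DEFINE_integer('total_rounds', 100, 'Number of training rounds.')

# After argv is parsed (e.g. via app.run), the hyperparameters can be dumped as
# {name: FLAGS[name].value for name in hparam_names}.
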
Example #2

def setUpModule():
    # Create flags here to ensure duplicate flags are not created.
    optimizer_utils.define_optimizer_flags(TEST_SERVER_FLAG_PREFIX)
    optimizer_utils.define_optimizer_flags(TEST_CLIENT_FLAG_PREFIX)
    optimizer_utils.define_lr_schedule_flags(TEST_SERVER_FLAG_PREFIX)
    optimizer_utils.define_lr_schedule_flags(TEST_CLIENT_FLAG_PREFIX)
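
The comment in setUpModule() points at absl's one-definition rule: defining a flag whose name is already registered raises an error, which is why test modules define shared flags exactly once at module set-up. A small illustration (the flag name below is hypothetical):

from absl import flags

flags.DEFINE_integer('example_epochs', 1, 'First definition succeeds.')
try:
    flags.DEFINE_integer('example_epochs', 2, 'Redefining the same name fails.')
except flags.DuplicateFlagError:
    pass  # absl refuses to register two flags with the same name.
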
Example #3

from absl import flags

from optimization.shared import optimizer_utils
from optimization.stackoverflow import centralized_stackoverflow
from optimization.stackoverflow_lr import centralized_stackoverflow_lr
from utils import utils_impl

_SUPPORTED_TASKS = [
    'cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare', 'stackoverflow_nwp',
    'stackoverflow_lr'
]

with utils_impl.record_new_flags() as hparam_flags:
    flags.DEFINE_enum('task', None, _SUPPORTED_TASKS,
                      'Which task to perform centralized training on.')

    # Generic centralized training flags
    optimizer_utils.define_optimizer_flags('centralized')
    flags.DEFINE_string(
        'experiment_name', None,
        'Name of the experiment. Part of the name of the output directory.')
    flags.DEFINE_string('cache_dir', None,
                        'Directory in which to cache the datasets.')
    flags.mark_flag_as_required('experiment_name')
    flags.DEFINE_string(
        'root_output_dir', '/tmp/centralized_opt',
        'The top-level output directory for experiment runs. --experiment_name will '
        'be appended, and the directory will contain tensorboard logs, metrics '
        'written as CSVs, and a CSV of hyperparameter choices.')
    flags.DEFINE_integer('num_epochs', 50, 'Number of epochs to train.')
    flags.DEFINE_integer('batch_size', 32,
                         'Size of batches for training and eval.')
    flags.DEFINE_integer(
        'decay_epochs', None, 'Number of epochs before decaying '
        'the learning rate.')
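
A decay_epochs value is typically paired with a decay factor to drive a step-wise learning-rate schedule during centralized training. A hedged sketch of one way such a flag could be consumed via a Keras callback; the helper and its lr_decay argument are assumptions for illustration, not the repo's actual training loop:

import tensorflow as tf


def make_lr_decay_callback(initial_lr, decay_epochs, lr_decay):
    """Decays the learning rate by `lr_decay` every `decay_epochs` epochs."""

    def schedule(epoch, lr):
        del lr  # The rate is recomputed from the epoch index.
        if decay_epochs is None:
            return initial_lr
        return initial_lr * (lr_decay ** (epoch // decay_epochs))

    return tf.keras.callbacks.LearningRateScheduler(schedule)

# Usage sketch: model.fit(..., callbacks=[make_lr_decay_callback(0.1, 25, 0.1)])
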
Example #4

from absl import flags

from reconstruction.movielens import federated_movielens
from reconstruction.shared import federated_evaluation
from reconstruction.shared import federated_trainer_utils
from reconstruction.stackoverflow import federated_stackoverflow
from utils import utils_impl

_SUPPORTED_TASKS = [
    'stackoverflow_nwp', 'movielens_mf', 'stackoverflow_nwp_finetune'
]

with utils_impl.record_hparam_flags() as optimizer_flags:
    # Define optimizer flags.
    # define_optimizer_flags defines flags prefixed by its argument.
    # For each prefix, two flags get created: <prefix>_optimizer, and
    # <prefix>_learning_rate.
    optimizer_utils.define_optimizer_flags('client')
    optimizer_utils.define_optimizer_flags('server')
    # Ignored when the task is `stackoverflow_nwp_finetune`.
    optimizer_utils.define_optimizer_flags('reconstruction')
    # Only used when the task is `stackoverflow_nwp_finetune`. Ignored otherwise.
    optimizer_utils.define_optimizer_flags('finetune')

with utils_impl.record_hparam_flags() as shared_flags:
    # Federated training hyperparameters.
    # The flags below are passed to `run_federated`, and are common to all the
    # tasks.
    flags.DEFINE_integer('client_batch_size', 20, 'Batch size on the clients.')
    flags.DEFINE_integer(
        'clients_per_round', 100,
        'How many clients to sample per round, for both train '
        'and evaluation.')
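
As the comment in the optimizer-flags block above notes, each define_optimizer_flags(prefix) call creates <prefix>_optimizer and <prefix>_learning_rate. After flag parsing, the trainer can read them back by prefix; a hedged sketch of that lookup (the helper below is illustrative, not the repo's optimizer_utils API):

from absl import flags

FLAGS = flags.FLAGS


def prefixed_optimizer_settings(prefix):
    """Returns the parsed values of the two flags created for `prefix`."""
    return {
        'optimizer': FLAGS[prefix + '_optimizer'].value,
        'learning_rate': FLAGS[prefix + '_learning_rate'].value,
    }

# e.g. prefixed_optimizer_settings('client') might return
# {'optimizer': 'sgd', 'learning_rate': 0.1} depending on the command line.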