from absl import app from absl import flags from absl import logging import tensorflow as tf import tensorflow_federated as tff from tensorflow_federated.python.research.differential_privacy import dp_utils from tensorflow_federated.python.research.optimization.shared import optimizer_utils from tensorflow_federated.python.research.utils import training_loop from tensorflow_federated.python.research.utils import training_utils from tensorflow_federated.python.research.utils import utils_impl from tensorflow_federated.python.research.utils.datasets import emnist_dataset from tensorflow_federated.python.research.utils.models import emnist_models with utils_impl.record_hparam_flags(): # Experiment hyperparameters flags.DEFINE_enum( 'model', 'cnn', ['cnn', '2nn'], 'Which model to use. This ' 'can be a convolutional model (cnn) or a two hidden-layer ' 'densely connected network (2nn).') flags.DEFINE_integer('client_batch_size', 20, 'Batch size used on the client.') flags.DEFINE_integer('clients_per_round', 10, 'How many clients to sample per round.') flags.DEFINE_integer( 'client_epochs_per_round', 1, 'Number of client (inner optimizer) epochs per federated round.') flags.DEFINE_boolean( 'uniform_weighting', False, 'Whether to weigh clients uniformly. If false, clients '
from tensorflow_federated.python.research.adaptive_lr_decay import callbacks from tensorflow_federated.python.research.optimization.cifar100 import federated_cifar100 from tensorflow_federated.python.research.optimization.emnist import federated_emnist from tensorflow_federated.python.research.optimization.emnist_ae import federated_emnist_ae from tensorflow_federated.python.research.optimization.shakespeare import federated_shakespeare from tensorflow_federated.python.research.optimization.shared import optimizer_utils from tensorflow_federated.python.research.optimization.stackoverflow import federated_stackoverflow from tensorflow_federated.python.research.optimization.stackoverflow_lr import federated_stackoverflow_lr from tensorflow_federated.python.research.utils import utils_impl _SUPPORTED_TASKS = [ 'cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare', 'stackoverflow_nwp', 'stackoverflow_lr' ] with utils_impl.record_hparam_flags() as optimizer_flags: optimizer_utils.define_optimizer_flags('client') optimizer_utils.define_optimizer_flags('server') with utils_impl.record_hparam_flags() as callback_flags: flags.DEFINE_float( 'client_decay_factor', 0.1, 'Amount to decay the client learning rate ' 'upon reaching a plateau.') flags.DEFINE_float( 'server_decay_factor', 0.9, 'Amount to decay the server learning rate ' 'upon reaching a plateau.') flags.DEFINE_float( 'min_delta', 1e-4, 'Minimum delta for improvement in the learning rate ' 'callbacks.') flags.DEFINE_integer(
from tensorflow_federated.python.research.optimization.cifar100 import federated_cifar100 from tensorflow_federated.python.research.optimization.emnist import federated_emnist from tensorflow_federated.python.research.optimization.emnist_ae import federated_emnist_ae from tensorflow_federated.python.research.optimization.shakespeare import federated_shakespeare from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule from tensorflow_federated.python.research.optimization.shared import optimizer_utils from tensorflow_federated.python.research.optimization.stackoverflow import federated_stackoverflow from tensorflow_federated.python.research.optimization.stackoverflow_lr import federated_stackoverflow_lr from tensorflow_federated.python.research.utils import utils_impl _SUPPORTED_TASKS = [ 'cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare', 'stackoverflow_nwp', 'stackoverflow_lr' ] with utils_impl.record_hparam_flags() as optimizer_flags: # Defining optimizer flags optimizer_utils.define_optimizer_flags('client') optimizer_utils.define_optimizer_flags('server') optimizer_utils.define_lr_schedule_flags('client') optimizer_utils.define_lr_schedule_flags('server') with utils_impl.record_hparam_flags() as shared_flags: # Federated training hyperparameters flags.DEFINE_integer('client_epochs_per_round', 1, 'Number of epochs in the client to take per round.') flags.DEFINE_integer('client_batch_size', 20, 'Batch size on the clients.') flags.DEFINE_integer('clients_per_round', 10, 'How many clients to sample per round.') flags.DEFINE_integer('client_datasets_random_seed', 1, 'Random seed for client sampling.')