def setUpModule():
  """Defines the optimizer and LR-schedule flags for both test prefixes.

  `unittest` invokes `setUpModule` once for the module, so the flags are
  created here to ensure duplicate flags are not created.
  """
  flag_definers = (
      optimizer_utils.define_optimizer_flags,
      optimizer_utils.define_lr_schedule_flags,
  )
  flag_prefixes = (TEST_SERVER_FLAG_PREFIX, TEST_CLIENT_FLAG_PREFIX)
  # Outer loop over definers keeps the call order: optimizer flags for the
  # server then the client, followed by LR-schedule flags for the server then
  # the client.
  for define_flags in flag_definers:
    for prefix in flag_prefixes:
      define_flags(prefix)
from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
from tensorflow_federated.python.research.optimization.shared import optimizer_utils
from tensorflow_federated.python.research.optimization.stackoverflow import federated_stackoverflow
from tensorflow_federated.python.research.optimization.stackoverflow_lr import federated_stackoverflow_lr
from tensorflow_federated.python.research.utils import utils_impl

# Task names listed as supported by this module.
_SUPPORTED_TASKS = [
    'cifar100',
    'emnist_cr',
    'emnist_ae',
    'shakespeare',
    'stackoverflow_nwp',
    'stackoverflow_lr',
]

# Flags defined inside a `record_hparam_flags()` context are captured under the
# name bound by `as`.
with utils_impl.record_hparam_flags() as optimizer_flags:
  # Optimizer and learning-rate-schedule flags for the client and the server.
  optimizer_utils.define_optimizer_flags('client')
  optimizer_utils.define_optimizer_flags('server')
  optimizer_utils.define_lr_schedule_flags('client')
  optimizer_utils.define_lr_schedule_flags('server')

with utils_impl.record_hparam_flags() as shared_flags:
  # Hyperparameters of the federated training procedure.
  flags.DEFINE_integer(
      name='client_epochs_per_round',
      default=1,
      help='Number of epochs in the client to take per round.')
  flags.DEFINE_integer(
      name='client_batch_size',
      default=20,
      help='Batch size on the clients.')
  flags.DEFINE_integer(
      name='clients_per_round',
      default=10,
      help='How many clients to sample per round.')
  flags.DEFINE_integer(
      name='client_datasets_random_seed',
      default=1,
      help='Random seed for client sampling.')
  flags.DEFINE_integer(
      name='total_rounds',
      default=200,
      help='Number of total training rounds.')

  # Training loop configuration