from absl import flags
import tensorflow as tf
import tensorflow_federated as tff

from fedopt_guide.cifar10_resnet import federated_cifar10
from optimization.shared import optimizer_utils
from utils import utils_impl

_SUPPORTED_TASKS = ['cifar10']

with utils_impl.record_hparam_flags() as optimizer_flags:
  # Define optimizer and learning-rate-schedule flags for the client and
  # server optimizers, each namespaced by its prefix.
  optimizer_utils.define_optimizer_flags('client')
  optimizer_utils.define_optimizer_flags('server')
  optimizer_utils.define_lr_schedule_flags('client')
  optimizer_utils.define_lr_schedule_flags('server')

with utils_impl.record_hparam_flags() as shared_flags:
  # Federated training hyperparameters.
  flags.DEFINE_integer('client_epochs_per_round', 1,
                       'Number of epochs each client takes per round.')
  flags.DEFINE_integer('client_batch_size', 20, 'Batch size on the clients.')
  flags.DEFINE_integer('clients_per_round', 10,
                       'How many clients to sample per round.')
  flags.DEFINE_integer('client_datasets_random_seed', 1,
                       'Random seed for client sampling.')

  # Training loop configuration.
  flags.DEFINE_string(
      'experiment_name', None,
      'The name of this experiment. Will be appended to --root_output_dir to '
      'separate experiment results.')
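# A minimal sketch of how the flag groups recorded above are typically
# consumed downstream; `create_optimizer_fn_from_flags`, `lookup_flag_values`,
# and the `main` function are assumed helpers/entry points, not shown in the
# snippet above:


def main(argv):
  # Build no-argument optimizer constructors from the parsed `--client_*`
  # and `--server_*` flags defined above.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  # Collect every recorded hyperparameter flag into a plain dict, which is
  # convenient for logging experiment metadata alongside the results.
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)
  hparam_dict.update(utils_impl.lookup_flag_values(optimizer_flags))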
# Prefixes namespacing the test flags; the concrete values here are assumed
# for illustration.
TEST_SERVER_FLAG_PREFIX = 'test_server'
TEST_CLIENT_FLAG_PREFIX = 'test_client'


def setUpModule():
  # setUpModule runs once, before any test in this module. Create flags here
  # to ensure duplicate flags are not created.
  optimizer_utils.define_optimizer_flags(TEST_SERVER_FLAG_PREFIX)
  optimizer_utils.define_optimizer_flags(TEST_CLIENT_FLAG_PREFIX)
  optimizer_utils.define_lr_schedule_flags(TEST_SERVER_FLAG_PREFIX)
  optimizer_utils.define_lr_schedule_flags(TEST_CLIENT_FLAG_PREFIX)
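# For context, a hedged sketch of a test that could exercise these flags;
# the `{prefix}_optimizer` / `{prefix}_learning_rate` flag names and
# `create_optimizer_fn_from_flags` are assumptions about `optimizer_utils`,
# not confirmed by the snippet above:
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS


class OptimizerFlagsTest(tf.test.TestCase):

  def test_create_client_sgd_optimizer_from_flags(self):
    # Point the prefixed flags at SGD with a concrete learning rate.
    FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'sgd'
    FLAGS['{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX)].value = 0.1

    # Assumed helper: builds a no-argument optimizer constructor from the
    # flags registered under the given prefix.
    optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertIsInstance(optimizer_fn(), tf.keras.optimizers.SGD)


if __name__ == '__main__':
  tf.test.main()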