def test_create_optimizer_fn_with_no_learning_rate(self):
    with flag_sandbox({
            '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd',
            '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): None
    }):
        with self.assertRaisesRegex(ValueError, 'Learning rate'):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)

def test_create_optimizer_fn_from_flags_flags_set_not_for_optimizer(self):
    with flag_sandbox(
            {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        # Set an Adam flag that isn't used in SGD.
        # We need to use `_parse_args` because that is the only way FLAGS is
        # notified that a non-default value is being used.
        bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
        FLAGS._parse_args(args=['--{}=0.5'.format(bad_adam_flag)],
                          known_only=True)
        with self.assertRaisesRegex(
                ValueError,
                r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'
        ):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
        FLAGS[bad_adam_flag].unparse()

def test_create_server_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
    commandline_set_learning_rate = 100.0
    with flag_sandbox({
            '{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX): optimizer_name,
            '{}_learning_rate'.format(TEST_SERVER_FLAG_PREFIX):
                commandline_set_learning_rate
    }):
        custom_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            TEST_SERVER_FLAG_PREFIX)
        custom_optimizer = custom_optimizer_fn()
        self.assertIsInstance(custom_optimizer, optimizer_cls)
        self.assertEqual(custom_optimizer.get_config()['learning_rate'],
                         commandline_set_learning_rate)
        custom_optimizer_with_arg = custom_optimizer_fn(11.0)
        self.assertIsInstance(custom_optimizer_with_arg, optimizer_cls)
        self.assertEqual(
            custom_optimizer_with_arg.get_config()['learning_rate'], 11.0)

def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
    FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
    with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
        optimizer_utils.create_optimizer_fn_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
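The tests above are methods of a test class (the parameterized `test_create_server_optimizer_from_flags` presumably receives its `optimizer_name` / `optimizer_cls` arguments from a `parameterized.named_parameters` decorator that is not shown), and they rely on a `flag_sandbox` helper that the snippet does not define. Below is a minimal sketch of such a context manager, inferred only from how the tests use it: it saves the named absl flag values, overrides them, and restores the originals on exit.

import contextlib

from absl import flags

FLAGS = flags.FLAGS


@contextlib.contextmanager
def flag_sandbox(flag_value_dict):
    """Temporarily overrides the given absl flags, restoring them on exit."""
    saved_values = {name: FLAGS[name].value for name in flag_value_dict}
    try:
        for name, value in flag_value_dict.items():
            FLAGS[name].value = value
        yield
    finally:
        for name, value in saved_values.items():
            FLAGS[name].value = value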
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.task == 'shakespeare' or FLAGS.task == 'stackoverflow_nwp':

            def client_weight_fn(local_outputs):
                return tf.cast(tf.squeeze(local_outputs['num_tokens']),
                               tf.float32)
        else:
            client_weight_fn = None

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            tau=FLAGS.tau,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model, cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(
            task_spec, cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec,
            sequence_length=FLAGS.shakespeare_sequence_length,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples,
            cache_dir=FLAGS.cache_dir)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
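This `main` assumes that the prefixed optimizer flags ('client_optimizer', 'client_learning_rate', 'server_optimizer', ...) and the task flags it reads were already defined at module load time. A rough sketch of that module-level setup follows, assuming the companion `define_optimizer_flags` / `define_lr_schedule_flags` helpers that usually accompany `create_optimizer_fn_from_flags` in these utilities; the import path, flag defaults, and the exact flag set shown are assumptions.

from absl import app
from absl import flags

from utils.optimizers import optimizer_utils  # import path is an assumption

_SUPPORTED_TASKS = [
    'cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare', 'stackoverflow_nwp',
    'stackoverflow_lr'
]

# Defines '<prefix>_optimizer', '<prefix>_learning_rate' and the per-optimizer
# hyperparameter flags read by create_optimizer_fn_from_flags(<prefix>).
optimizer_utils.define_optimizer_flags('client')
optimizer_utils.define_optimizer_flags('server')
# Defines the learning rate schedule flags read by
# create_lr_schedule_from_flags(<prefix>).
optimizer_utils.define_lr_schedule_flags('client')
optimizer_utils.define_lr_schedule_flags('server')

# A few of the generic training flags used directly in main(); the
# task-specific flags (cifar100_*, so_nwp_*, ...) are defined the same way.
flags.DEFINE_enum('task', None, _SUPPORTED_TASKS, 'Which task to train on.')
flags.DEFINE_integer('clients_per_round', 10, 'Clients sampled per round.')
flags.DEFINE_integer('client_epochs_per_round', 1,
                     'Local epochs per client per round.')
flags.DEFINE_integer('client_batch_size', 20, 'Client batch size.')
flags.DEFINE_integer('total_rounds', 200, 'Number of training rounds.')

FLAGS = flags.FLAGS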
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    hparams_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                            for name in hparam_flags])

    common_args = collections.OrderedDict([
        ('optimizer', optimizer),
        ('experiment_name', FLAGS.experiment_name),
        ('root_output_dir', FLAGS.root_output_dir),
        ('num_epochs', FLAGS.num_epochs),
        ('batch_size', FLAGS.batch_size),
        ('decay_epochs', FLAGS.decay_epochs),
        ('lr_decay', FLAGS.lr_decay),
        ('hparams_dict', hparams_dict),
    ])

    if FLAGS.task == 'cifar100':
        centralized_cifar100.run_centralized(
            **common_args,
            crop_size=FLAGS.cifar100_crop_size,
            cache_dir=FLAGS.cache_dir)

    elif FLAGS.task == 'emnist_cr':
        centralized_emnist.run_centralized(**common_args,
                                           emnist_model=FLAGS.emnist_cr_model,
                                           cache_dir=FLAGS.cache_dir)

    elif FLAGS.task == 'emnist_ae':
        centralized_emnist_ae.run_centralized(**common_args,
                                              cache_dir=FLAGS.cache_dir)

    elif FLAGS.task == 'shakespeare':
        centralized_shakespeare.run_centralized(
            **common_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            cache_dir=FLAGS.cache_dir)

    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_nwp_'):
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        centralized_stackoverflow.run_centralized(**common_args,
                                                  **so_nwp_flags,
                                                  cache_dir=FLAGS.cache_dir)

    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_lr_'):
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        centralized_stackoverflow_lr.run_centralized(**common_args,
                                                     **so_lr_flags,
                                                     cache_dir=FLAGS.cache_dir)

    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
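Because this script builds its optimizer from flags with the 'centralized' prefix, the flag names follow the '{prefix}_optimizer' / '{prefix}_learning_rate' pattern shown in the tests above, i.e. '--centralized_optimizer' and '--centralized_learning_rate'. A minimal entry point and an illustrative invocation follow; the script name and the flag values are hypothetical.

if __name__ == '__main__':
    app.run(main)

# Illustrative invocation (hypothetical script name and values):
#   python centralized_trainer.py \
#     --task=cifar100 \
#     --centralized_optimizer=sgd \
#     --centralized_learning_rate=0.1 \
#     --num_epochs=50 \
#     --batch_size=32 \
#     --experiment_name=cifar100_sgd_baseline \
#     --root_output_dir=/tmp/centralized_runs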