def test_create_constant_client_lr_schedule_from_flags(self):
    """The 'constant' schedule holds the flag LR, modulated only by warmup."""
    constant_flags = {
        '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
        '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant',
    }
    # Without warmup the schedule is flat at the configured learning rate.
    with flag_sandbox(constant_flags):
        schedule = optimizer_utils.create_lr_schedule_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        for step in (0, 1, 105, 1042):
            self.assertNear(schedule(step), 3.0, err=1e-5)
    # With a 10-step warmup the LR ramps up linearly, then stays constant.
    warmup_flags = dict(constant_flags)
    warmup_flags['{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX)] = 10
    with flag_sandbox(warmup_flags):
        schedule = optimizer_utils.create_lr_schedule_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        for step, expected in ((0, 0.3), (1, 0.6), (10, 3.0), (11, 3.0),
                               (115, 3.0), (1052, 3.0)):
            self.assertNear(schedule(step), expected, err=1e-5)
Example #2
0
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
    """Builds a `tff.templates.IterativeProcess` instance from flags.

    The returned iterative process incorporates learning rate schedules that
    are configured via flags, for both the client and server optimizers.

    Args:
      input_spec: A value convertible to a `tff.Type`, representing the data
        which will be fed into the `tff.templates.IterativeProcess.next`
        function over the course of training. Generally, this can be found by
        accessing the `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
        object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`,
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled
    # keras model.
    def tff_model_fn() -> tff.learning.Model:
        """Adapts the Keras model, loss, and metrics to a `tff.learning.Model`."""
        return tff.learning.from_keras_model(
            keras_model=model_builder(),
            input_spec=input_spec,
            loss=loss_builder(),
            metrics=metrics_builder())

    # Optimizers and learning-rate schedules are read from the 'client'/'server'
    # prefixed flags at call time.
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=optimizer_utils.create_optimizer_fn_from_flags(
            'client'),
        client_lr=optimizer_utils.create_lr_schedule_from_flags('client'),
        server_optimizer_fn=optimizer_utils.create_optimizer_fn_from_flags(
            'server'),
        server_lr=optimizer_utils.create_lr_schedule_from_flags('server'),
        client_weight_fn=client_weight_fn)
Example #3
0
def main(argv):
    """Configures and runs federated training for the task in `--task`."""
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Flag-driven optimizer constructors and learning-rate schedules, per role.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates a FedAvg iterative process from a TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Lazily-evaluated task configurators: only the selected entry runs, so no
    # task-specific flag is read unless its task is chosen.
    configure_by_task = {
        'cifar100':
            lambda: federated_cifar100.configure_training(
                task_spec,
                crop_size=FLAGS.cifar100_crop_size,
                distort_train_images=FLAGS.cifar100_distort_train_images),
        'emnist_cr':
            lambda: federated_emnist.configure_training(
                task_spec, model=FLAGS.emnist_cr_model),
        'emnist_ae':
            lambda: federated_emnist_ae.configure_training(task_spec),
        'shakespeare':
            lambda: federated_shakespeare.configure_training(
                task_spec, sequence_length=FLAGS.shakespeare_sequence_length),
        'stackoverflow_nwp':
            lambda: federated_stackoverflow.configure_training(
                task_spec,
                vocab_size=FLAGS.so_nwp_vocab_size,
                num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
                sequence_length=FLAGS.so_nwp_sequence_length,
                max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
                num_validation_examples=FLAGS.so_nwp_num_validation_examples),
        'stackoverflow_lr':
            lambda: federated_stackoverflow_lr.configure_training(
                task_spec,
                vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
                vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
                max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
                num_validation_examples=FLAGS.so_lr_num_validation_examples),
    }
    if FLAGS.task not in configure_by_task:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    runner_spec = configure_by_task[FLAGS.task]()

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
Example #4
0
def main(argv):
    """Runs federated training, dispatching on the `--schedule` flag.

    The schedule flag selects both the federated-averaging variant
    (importance sampling, loss-based client selection, or plain FedAvg) and
    the matching iterative-process builder.

    Args:
      argv: Command-line arguments; no positional arguments are accepted.

    Raises:
      app.UsageError: If positional command-line arguments are supplied.
      ValueError: If `--task` is not one of the supported tasks.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Flag-configured optimizer constructors and learning-rate schedules,
    # separately for clients and the server.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    # Dispatch on --schedule exactly once. Previously the flag was tested in
    # two separate if/elif chains (one picking the schedule module, one
    # defining the builder), duplicating the branching logic.
    if FLAGS.schedule == 'importance':
        fed_avg_schedule = importance_schedule

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates an importance-sampling iterative process.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Accepted for signature parity with the other
                builders but unused; importance sampling derives its own
                aggregation weights.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            del client_weight_fn  # Unused by importance sampling.
            factory = importance_aggregation_factory.ImportanceSamplingFactory(
                FLAGS.clients_per_round)
            weights_type = importance_aggregation_factory.weights_type_from_model_fn(
                model_fn)
            importance_aggregation_process = factory.create(
                value_type=weights_type,
                weight_type=tff.TensorType(tf.float32))

            return importance_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                aggregation_process=importance_aggregation_process)
    elif FLAGS.schedule == 'loss':
        fed_avg_schedule = fed_loss

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates a loss-based-selection iterative process.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function that takes the output of
                `model.report_local_outputs` and returns a tensor providing
                the weight in the federated average of model deltas. If not
                provided, the default is the total number of examples
                processed on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_avg_schedule.build_fed_avg_process(
                total_clients=FLAGS.loss_pool_size,
                effective_num_clients=FLAGS.clients_per_round,
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn,
                aggregation_process=None)
    else:
        fed_avg_schedule = fed_avg

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates a plain FedAvg iterative process.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function that takes the output of
                `model.report_local_outputs` and returns a tensor providing
                the weight in the federated average of model deltas. If not
                provided, the default is the total number of examples
                processed on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_avg_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    if FLAGS.task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated
    elif FLAGS.task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif FLAGS.task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif FLAGS.task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif FLAGS.task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    elif FLAGS.task == 'synthetic':
        run_federated_fn = federated_synthetic.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args,
                     **task_args,
                     beta=FLAGS.beta,
                     hparam_dict=hparam_dict,
                     schedule=FLAGS.schedule)
Example #5
0
def main(argv):
    """Runs flag-configured federated training for the selected `--task`."""
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Flag-configured optimizers and learning-rate schedules, per role.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates a FedAvg iterative process from a TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    # Table-driven selection of the task-specific runner.
    runners = {
        'cifar100': federated_cifar100.run_federated,
        'emnist_cr': federated_emnist.run_federated,
        'emnist_ae': federated_emnist_ae.run_federated,
        'shakespeare': federated_shakespeare.run_federated,
        'stackoverflow_nwp': federated_stackoverflow.run_federated,
        'stackoverflow_lr': federated_stackoverflow_lr.run_federated,
    }
    if FLAGS.task not in runners:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    run_federated_fn = runners[FLAGS.task]

    run_federated_fn(**shared_args, **task_args, hparam_dict=hparam_dict)
def main(argv):
    """Runs federated posterior averaging (FedPA) training, configured by flags."""
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Flag-configured optimizer constructors for clients and the server.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    # Flag-configured learning-rate schedules for clients and the server.
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    # FedPA-specific configuration: when clients count as "mixed in" and how
    # client update deltas are computed (with shrinkage parameter rho).
    client_mixedin_schedule_fn = fed_pa_schedule.create_mixin_check_fn(
        name=FLAGS.client_mixin_check_scheme,
        num_mixin_epochs=FLAGS.client_mixin_epochs_per_round,
        start_round=FLAGS.client_mixin_check_start_round)
    client_update_delta_fn = fed_pa_schedule.create_update_delta_fn(
        name=FLAGS.client_update_delta_scheme, rho=FLAGS.client_shrinkage_rho)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
        return fed_pa_schedule.build_fed_pa_process(
            model_fn=model_fn,
            client_update_epochs=FLAGS.client_epochs_per_round,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_mixedin_schedule_fn=client_mixedin_schedule_fn,
            client_update_delta_fn=client_update_delta_fn,
            mask_zeros_in_client_updates=FLAGS.mask_zeros_in_client_updates)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        # Since the number of epochs each client makes every round is handled
        # by the logic in client update functions, here we set it to 1.
        client_epochs_per_round=1,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Select and configure the task-specific runner from --task.
    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    # Run the full training loop; checkpointing/eval cadence comes from flags.
    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
    def test_create_inv_sqrt_client_lr_schedule_from_flags(self):
        """Covers staircase, smooth, and warmup variants of inv_sqrt_decay."""
        base_flags = {
            '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
            '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
            '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
            '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
        }

        def check(extra_flags, expectations):
            """Builds the schedule under the merged flags and checks each point."""
            flags = dict(base_flags)
            flags.update(extra_flags)
            with flag_sandbox(flags):
                schedule = optimizer_utils.create_lr_schedule_from_flags(
                    TEST_CLIENT_FLAG_PREFIX)
                for step, expected in expectations:
                    self.assertNear(schedule(step), expected, err=1e-5)

        # Staircase decay: the LR is flat within each 10-step bucket.
        check({'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True},
              [(0, 2.0), (1, 2.0), (10, 0.603022689155),
               (19, 0.603022689155), (20, 0.436435780472)])
        # Smooth decay with warmup disabled.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 2.0), (3, 1.0), (99, 0.2), (399, 0.1)])
        # Smooth decay preceded by a 10-step linear warmup.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 0.2), (1, 0.4), (10, 2.0), (13, 1.0), (109, 0.2),
               (409, 0.1)])
    def test_create_inv_lin_client_lr_schedule_from_flags(self):
        """Covers staircase, smooth, and warmup variants of inv_lin_decay."""
        base_flags = {
            '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
            '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
            '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
            '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
        }

        def check(extra_flags, expectations):
            """Builds the schedule under the merged flags and checks each point."""
            flags = dict(base_flags)
            flags.update(extra_flags)
            with flag_sandbox(flags):
                schedule = optimizer_utils.create_lr_schedule_from_flags(
                    TEST_CLIENT_FLAG_PREFIX)
                for step, expected in expectations:
                    self.assertNear(schedule(step), expected, err=1e-5)

        # Staircase decay: the LR is flat within each 10-step bucket.
        check({'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True},
              [(0, 5.0), (1, 5.0), (10, 0.454545454545),
               (19, 0.454545454545), (20, 0.238095238095)])
        # Smooth decay with warmup disabled.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 5.0), (1, 2.5), (9, 0.5), (19, 0.25)])
        # Smooth decay preceded by a 10-step linear warmup.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 0.5), (1, 1.0), (10, 5.0), (11, 2.5), (19, 0.5),
               (29, 0.25)])
    def test_create_exp_decay_client_lr_schedule_from_flags(self):
        """Covers staircase, smooth, and warmup variants of exp_decay."""
        base_flags = {
            '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
            '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
            '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
            '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
        }

        def check(extra_flags, expectations):
            """Builds the schedule under the merged flags and checks each point."""
            flags = dict(base_flags)
            flags.update(extra_flags)
            with flag_sandbox(flags):
                schedule = optimizer_utils.create_lr_schedule_from_flags(
                    TEST_CLIENT_FLAG_PREFIX)
                for step, expected in expectations:
                    self.assertNear(schedule(step), expected, err=1e-5)

        # Staircase decay: the LR is flat within each 10-step bucket.
        check({'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True},
              [(0, 3.0), (3, 3.0), (10, 0.3), (19, 0.3), (20, 0.03)])
        # Smooth exponential decay with warmup disabled.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 3.0), (1, 2.38298470417), (10, 0.3), (25, 0.00948683298)])
        # Smooth exponential decay preceded by a 10-step linear warmup.
        check({'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
               '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False},
              [(0, 0.3), (1, 0.6), (10, 3.0), (11, 2.38298470417),
               (20, 0.3), (35, 0.00948683298)])
Example #10
0
def main(argv):
    """Configures a federated training task from flags and runs the loop.

    NOTE(review): `iterative_process_builder`, used in the TaskSpec below, is
    not defined in this function and is not visible in this file chunk —
    presumably a module-level symbol; confirm it exists at module scope,
    otherwise this raises `NameError` at runtime.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # NOTE(review): the four values below are created but never referenced in
    # this function; the builder presumably captures equivalents elsewhere.
    # Confirm whether these locals are dead code before removing them.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    # Shared, task-independent training configuration built from flags.
    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Select and configure the task-specific runner from --task.
    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    # Run the full training loop; eval/checkpoint cadence comes from flags.
    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)