Example #1
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
    """Builds a `tff.templates.IterativeProcess` instance from flags.

    The iterative process is designed to incorporate learning rate schedules,
    which are configured via flags.

    Args:
      input_spec: A value convertible to a `tff.Type`, representing the data
        which will be fed into the `tff.templates.IterativeProcess.next`
        function over the course of training. Generally, this can be found by
        accessing the `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`, defaults
        to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
    # model.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    model_input_spec = input_spec

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=model_input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)
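
The process returned by `from_flags` is driven like any `tff.templates.IterativeProcess`: call `initialize` once, then feed client datasets to `next` each round. A minimal driver sketch, assuming `client_datasets` is a list of client `tf.data.Dataset`s matching `input_spec` and `NUM_ROUNDS` is defined elsewhere:

process = from_flags(input_spec, model_builder, loss_builder, metrics_builder)
state = process.initialize()
for round_num in range(NUM_ROUNDS):
    # One round of federated averaging over the given client datasets.
    state, metrics = process.next(state, client_datasets)
    print('round {}: {}'.format(round_num, metrics))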
Example #2
def test_create_optimizer_fn_with_no_learning_rate(self):
    with flag_sandbox({
            '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd',
            '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): None
    }):
        with self.assertRaisesRegex(ValueError, 'Learning rate'):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
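
`flag_sandbox` is a test helper that temporarily overrides flag values and restores them on exit. An illustrative reimplementation (an assumption, not necessarily the project's actual helper):

import contextlib

@contextlib.contextmanager
def flag_sandbox(flag_value_dict):
    # Save the current values, apply the overrides, and restore them
    # afterwards so each test sees a pristine FLAGS object.
    old_values = {name: FLAGS[name].value for name in flag_value_dict}
    for name, value in flag_value_dict.items():
        FLAGS[name].value = value
    try:
        yield
    finally:
        for name, value in old_values.items():
            FLAGS[name].value = value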
Example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            use_experimental_simulation_loop=True)

    dataset_type = dataset.DatasetType.GLD23K
    if FLAGS.dataset_type == 'gld160k':
        dataset_type = dataset.DatasetType.GLD160K

    federated_main.run_federated(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        max_elements_per_user=FLAGS.max_elements_per_user,
        image_size=FLAGS.image_size,
        num_groups=FLAGS.num_groups,
        total_rounds=FLAGS.total_rounds,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        hparam_dict=get_hparam_flags())
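
For the `'client'` and `'server'` prefixes to resolve, matching flags must already be defined at module load time. A sketch of the assumed module-level setup, using `define_optimizer_flags` as the defining counterpart of `create_optimizer_fn_from_flags` (treat the helper name as an assumption):

# Creates flags such as --client_optimizer, --client_learning_rate and the
# per-optimizer flags (e.g. --client_adam_beta_1), and likewise for 'server'.
optimizer_utils.define_optimizer_flags('client')
optimizer_utils.define_optimizer_flags('server')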
Example #4
def test_create_optimizer_fn_from_flags_flags_set_not_for_optimizer(self):
    with flag_sandbox(
            {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        # Set an Adam flag that isn't used in SGD.
        # We need to use `_parse_args` because that is the only way FLAGS is
        # notified that a non-default value is being used.
        bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
        FLAGS._parse_args(args=['--{}=0.5'.format(bad_adam_flag)],
                          known_only=True)
        with self.assertRaisesRegex(
                ValueError,
                r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
        FLAGS[bad_adam_flag].unparse()
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    hparams_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                            for name in hparam_flags])

    common_args = collections.OrderedDict([
        ('optimizer', optimizer),
        ('experiment_name', FLAGS.experiment_name),
        ('root_output_dir', FLAGS.root_output_dir),
        ('num_epochs', FLAGS.num_epochs),
        ('batch_size', FLAGS.batch_size),
        ('decay_epochs', FLAGS.decay_epochs),
        ('lr_decay', FLAGS.lr_decay),
        ('hparams_dict', hparams_dict),
    ])

    if FLAGS.task == 'cifar10':
        centralized_cifar10.run_centralized(**common_args,
                                            crop_size=FLAGS.cifar100_crop_size)

    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
Example #6
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    hparams_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                            for name in hparam_flags])

    common_args = collections.OrderedDict([
        ('optimizer', optimizer),
        ('experiment_name', FLAGS.experiment_name),
        ('root_output_dir', FLAGS.root_output_dir),
        ('num_epochs', FLAGS.num_epochs),
        ('batch_size', FLAGS.batch_size),
        ('decay_epochs', FLAGS.decay_epochs),
        ('lr_decay', FLAGS.lr_decay),
        ('hparams_dict', hparams_dict),
    ])

    if FLAGS.task == 'cifar100':
        centralized_cifar100.run_centralized(
            **common_args, crop_size=FLAGS.cifar100_crop_size)

    elif FLAGS.task == 'emnist_cr':
        centralized_emnist.run_centralized(**common_args,
                                           emnist_model=FLAGS.emnist_cr_model)

    elif FLAGS.task == 'emnist_ae':
        centralized_emnist_ae.run_centralized(**common_args)

    elif FLAGS.task == 'shakespeare':
        centralized_shakespeare.run_centralized(
            **common_args, sequence_length=FLAGS.shakespeare_sequence_length)

    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_nwp_'):
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        centralized_stackoverflow.run_centralized(**common_args,
                                                  **so_nwp_flags)

    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_lr_'):
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        centralized_stackoverflow_lr.run_centralized(**common_args,
                                                     **so_lr_flags)

    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
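
The `if`/`elif` chain over `FLAGS.task` can equally be written as a dispatch table, which keeps the supported-task list and the error message in sync; an abbreviated sketch of that alternative:

# Dispatch-table variant; the supported-task list is then derived from the
# table keys rather than maintained separately in _SUPPORTED_TASKS.
task_runners = {
    'cifar100': lambda: centralized_cifar100.run_centralized(
        **common_args, crop_size=FLAGS.cifar100_crop_size),
    'emnist_ae': lambda: centralized_emnist_ae.run_centralized(**common_args),
}
if FLAGS.task not in task_runners:
    raise ValueError('--task flag {} is not supported, must be one of {}.'
                     .format(FLAGS.task, sorted(task_runners)))
task_runners[FLAGS.task]()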
Example #7
def test_create_server_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
    commandline_set_learning_rate = 100.0
    with flag_sandbox({
            '{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX): optimizer_name,
            '{}_learning_rate'.format(TEST_SERVER_FLAG_PREFIX):
                commandline_set_learning_rate
    }):
        custom_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            TEST_SERVER_FLAG_PREFIX)
        custom_optimizer = custom_optimizer_fn()
        self.assertIsInstance(custom_optimizer, optimizer_cls)
        self.assertEqual(custom_optimizer.get_config()['learning_rate'],
                         commandline_set_learning_rate)
        custom_optimizer_with_arg = custom_optimizer_fn(11.0)
        self.assertIsInstance(custom_optimizer_with_arg, optimizer_cls)
        self.assertEqual(
            custom_optimizer_with_arg.get_config()['learning_rate'], 11.0)
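
As the test above demonstrates, the factory returned by `create_optimizer_fn_from_flags` takes an optional learning-rate override; called with no argument, it uses the value of the corresponding `*_learning_rate` flag:

opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
default_opt = opt_fn()       # learning rate read from --server_learning_rate
override_opt = opt_fn(11.0)  # explicit learning-rate override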
Example #8
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       hparams_dict)

    centralized_main.run_centralized(
        optimizer_utils.create_optimizer_fn_from_flags('centralized')(),
        FLAGS.num_epochs,
        FLAGS.batch_size,
        vocab_size=FLAGS.vocab_size,
        d_embed=FLAGS.d_embed,
        d_model=FLAGS.d_model,
        d_hidden=FLAGS.d_hidden,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        dropout=FLAGS.dropout,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        max_batches=FLAGS.max_batches,
        hparams_dict=hparams_dict)
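
`remove_unused_flags` prunes optimizer flags that do not apply to the optimizer actually selected, so only meaningful hyperparameters get logged. A hypothetical illustration of the intent (the flag names and exact key handling here are assumptions):

import collections

# With --centralized_optimizer=sgd, Adam-only entries would be dropped.
before = collections.OrderedDict([
    ('centralized_optimizer', 'sgd'),
    ('centralized_learning_rate', 0.1),
    ('centralized_adam_beta_1', 0.9),  # set but unused under SGD
])
after = optimizer_utils.remove_unused_flags('centralized', before)
# `after` keeps the SGD-relevant entries and discards the Adam-only one.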
Example #9
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       hparams_dict)

    dataset_type = dataset.DatasetType.GLD23K
    if FLAGS.dataset_type == 'gld160k':
        dataset_type = dataset.DatasetType.GLD160K

    centralized_main.run_centralized(
        optimizer=optimizer_utils.create_optimizer_fn_from_flags(
            'centralized')(),
        image_size=FLAGS.image_size,
        num_epochs=FLAGS.num_epochs,
        batch_size=FLAGS.batch_size,
        num_groups=FLAGS.num_groups,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        hparams_dict=hparams_dict)
Example #10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.task == 'shakespeare' or FLAGS.task == 'stackoverflow_nwp':

            def client_weight_fn(local_outputs):
                return tf.cast(tf.squeeze(local_outputs['num_tokens']),
                               tf.float32)
        else:
            client_weight_fn = None

        return fed_avg_local_adaptivity.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=FLAGS.client_learning_rate,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=FLAGS.server_learning_rate,
            client_weight_fn=client_weight_fn,
            correction=FLAGS.correction_type)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
Example #11
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
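
The schedules returned by `create_lr_schedule_from_flags` are callables mapping the round number to a learning rate, which is the shape `fed_avg_schedule.build_fed_avg_process` expects for `client_lr` and `server_lr`. A hand-rolled stand-in with the same shape (the constants are assumptions):

def example_lr_schedule(round_num):
    # Exponential decay: start at 0.1 and shrink by 1% per round.
    return 0.1 * (0.99 ** round_num)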
Example #12
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)
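
The reduce-on-plateau callbacks encode a simple rule: watch a metric and decay the learning rate once it stops improving for `patience` updates, never dropping below `min_lr`. A schematic of the core logic (simplified and stateful; the real callback is built for TFF and also smooths the metric over `window_size` rounds):

class ReduceLROnPlateauSketch:
    def __init__(self, learning_rate, decay_factor, min_delta, min_lr,
                 patience):
        self.lr = learning_rate
        self.decay_factor = decay_factor
        self.min_delta = min_delta
        self.min_lr = min_lr
        self.patience = patience
        self.best = float('inf')
        self.wait = 0

    def update(self, metric):
        # Reset the counter on a sufficient improvement; otherwise decay
        # the learning rate after `patience` stagnant updates.
        if metric < self.best - self.min_delta:
            self.best = metric
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.lr = max(self.lr * self.decay_factor, self.min_lr)
                self.wait = 0
        return self.lr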
Example #13
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec, crop_size=FLAGS.cifar100_crop_size)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
Example #14
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    if FLAGS.schedule == 'importance':
        fed_avg_schedule = importance_schedule
    elif FLAGS.schedule == 'loss':
        fed_avg_schedule = fed_loss
    else:
        fed_avg_schedule = fed_avg

    if FLAGS.schedule == 'importance':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:

            factory = importance_aggregation_factory.ImportanceSamplingFactory(
                FLAGS.clients_per_round)
            weights_type = importance_aggregation_factory.weights_type_from_model_fn(
                model_fn)
            importance_aggregation_process = factory.create(
                value_type=weights_type,
                weight_type=tff.TensorType(tf.float32))

            return importance_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                aggregation_process=importance_aggregation_process)
    elif FLAGS.schedule == 'loss':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates an iterative process using a given TFF `model_fn`.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function that takes the output of
                `model.report_local_outputs` and returns a tensor providing the
                weight in the federated average of model deltas. If not
                provided, the default is the total number of examples processed
                on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_avg_schedule.build_fed_avg_process(
                total_clients=FLAGS.loss_pool_size,
                effective_num_clients=FLAGS.clients_per_round,
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn,
                aggregation_process=None)
    else:

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates an iterative process using a given TFF `model_fn`.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function that takes the output of
                `model.report_local_outputs` and returns a tensor providing the
                weight in the federated average of model deltas. If not
                provided, the default is the total number of examples processed
                on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_avg_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()
    # shared_args['prob_transmit'] = FLAGS.prob_transmit

    if FLAGS.task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated

    elif FLAGS.task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif FLAGS.task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif FLAGS.task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif FLAGS.task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    elif FLAGS.task == 'synthetic':
        run_federated_fn = federated_synthetic.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    run_federated_fn(**shared_args,
                     **task_args,
                     beta=FLAGS.beta,
                     hparam_dict=hparam_dict,
                     schedule=FLAGS.schedule)
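
`_get_task_args` follows the same prefix-stripping pattern used for the `so_nwp_` and `so_lr_` flags in the earlier examples; a representative sketch, assuming task-specific flags are named `<task>_<argument>`:

import collections

def _get_task_args():
    # Collect flags belonging to the selected task and strip the task prefix
    # so the values can be passed through as keyword arguments.
    task_prefix = '{}_'.format(FLAGS.task)
    task_args = collections.OrderedDict()
    for flag_name in FLAGS:
        if flag_name.startswith(task_prefix):
            task_args[flag_name[len(task_prefix):]] = FLAGS[flag_name].value
    return task_args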
Example #15
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    else:
        client_weight_fn = None  #  Defaults to the number of examples per client.

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=(
                FLAGS.clipped_count_budget_allocation),
            expected_clients_per_round=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
Example #16
def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
    FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
    with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
        optimizer_utils.create_optimizer_fn_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
Example #17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    if FLAGS.task == 'stackoverflow_nwp_finetune':
        if not FLAGS.global_variables_only:
            raise ValueError('`FLAGS.global_variables_only` must be True for '
                             'a `stackoverflow_nwp_finetune` task.')
        if not FLAGS.client_epochs_per_round:
            raise ValueError('`FLAGS.client_epochs_per_round` must be set for '
                             'a `stackoverflow_nwp_finetune` task.')
        finetune_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            'finetune')
    else:
        reconstruction_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            'reconstruction')

    def iterative_process_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], List[tf.keras.losses.Loss]],
        metrics_fn: Optional[Callable[[],
                                      List[tf.keras.metrics.Metric]]] = None,
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn
        ] = reconstruction_utils.build_dataset_split_fn,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        For a `stackoverflow_nwp_finetune` task, the `model_fn` must return a
        model that has only global variables, and the argument
        `dataset_split_fn_builder` is ignored. The returned iterative process
        is essentially the same as the one created by the standard
        `tff.learning.build_federated_averaging_process`.

        For other tasks, the returned iterative process performs the federated
        reconstruction algorithm defined by
        `training_process.build_federated_reconstruction_process`.

        Args:
          model_fn: A no-arg function returning a
            `reconstruction_model.ReconstructionModel`. The returned model must
            have only global variables for a `stackoverflow_nwp_finetune` task.
          loss_fn: A no-arg function returning a list of
            `tf.keras.losses.Loss`.
          metrics_fn: A no-arg function returning a list of
            `tf.keras.metrics.Metric`.
          client_weight_fn: Optional function that takes the local model's
            output and returns a tensor providing the weight in the federated
            average of model deltas. If not provided, the default is the total
            number of examples processed on device. If DP is used, this
            argument is ignored and uniform client weighting is used.
          dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method
            used to split the examples into a reconstruction set and a
            post-reconstruction set. Ignored for a `stackoverflow_nwp_finetune`
            task.

        Raises:
          ValueError: If `model_fn` returns a model with local variables for a
            `stackoverflow_nwp_finetune` task.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        # Get aggregation factory for DP, if needed.
        aggregation_factory = None
        client_weighting = client_weight_fn
        if FLAGS.dp_noise_multiplier is not None:
            aggregation_factory = tff.learning.dp_aggregator(
                noise_multiplier=FLAGS.dp_noise_multiplier,
                clients_per_round=float(FLAGS.clients_per_round),
                zeroing=FLAGS.dp_zeroing)
            # DP is only implemented for uniform weighting.
            client_weighting = lambda _: 1.0

        if FLAGS.task == 'stackoverflow_nwp_finetune':

            if not reconstruction_utils.has_only_global_variables(model_fn()):
                raise ValueError(
                    '`model_fn` should return a model with only global variables. '
                )

            def fake_dataset_split_fn(
                client_dataset: tf.data.Dataset, round_num: tf.Tensor
            ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
                del round_num
                return client_dataset.repeat(0), client_dataset.repeat(
                    FLAGS.client_epochs_per_round)

            return training_process.build_federated_reconstruction_process(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                server_optimizer_fn=lambda: server_optimizer_fn(
                    FLAGS.server_learning_rate),
                client_optimizer_fn=lambda: client_optimizer_fn(
                    FLAGS.client_learning_rate),
                dataset_split_fn=fake_dataset_split_fn,
                client_weight_fn=client_weighting,
                aggregation_factory=aggregation_factory)

        return training_process.build_federated_reconstruction_process(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=lambda: server_optimizer_fn(
                FLAGS.server_learning_rate),
            client_optimizer_fn=lambda: client_optimizer_fn(
                FLAGS.client_learning_rate),
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn_builder(
                recon_epochs_max=FLAGS.recon_epochs_max,
                recon_epochs_constant=FLAGS.recon_epochs_constant,
                recon_steps_max=FLAGS.recon_steps_max,
                post_recon_epochs=FLAGS.post_recon_epochs,
                post_recon_steps_max=FLAGS.post_recon_steps_max,
                split_dataset=FLAGS.split_dataset),
            evaluate_reconstruction=FLAGS.evaluate_reconstruction,
            jointly_train_variables=FLAGS.jointly_train_variables,
            client_weight_fn=client_weighting,
            aggregation_factory=aggregation_factory)

    def evaluation_computation_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], tf.losses.Loss],
        metrics_fn: Callable[[], List[tf.metrics.Metric]],
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn
        ] = reconstruction_utils.build_dataset_split_fn,
    ) -> tff.Computation:
        """Creates a `tff.Computation` for federated evaluation.

        For a `stackoverflow_nwp_finetune` task, the returned `tff.Computation`
        is created by `federated_evaluation.build_federated_finetune_evaluation`.
        For other tasks, the returned `tff.Computation` is given by
        `evaluation_computation.build_federated_reconstruction_evaluation`.

        Args:
          model_fn: A no-arg function that returns a `ReconstructionModel`. The
            returned model must have only global variables for a
            `stackoverflow_nwp_finetune` task. This method must *not* capture
            TensorFlow tensors or variables and use them. The model must be
            constructed entirely from scratch on each invocation; returning the
            same model on each call will result in an error.
          loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use
            to evaluate the model. The final loss metric is the example-weighted
            mean loss across batches (and across clients).
          metrics_fn: A no-arg function returning a list of
            `tf.keras.metrics.Metric`s to use to evaluate the model. The final
            metrics are the example-weighted mean metrics across batches (and
            across clients).
          dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method
            used to split the examples into a reconstruction set (which is used
            as a fine-tuning set for a `stackoverflow_nwp_finetune` task) and an
            evaluation set.

        Returns:
          A `tff.Computation` for federated evaluation.
        """

        # For a `stackoverflow_nwp_finetune` task, the first dataset returned by
        # `dataset_split_fn` is used for fine-tuning global variables. For other
        # tasks, the first dataset is used for reconstructing local variables.
        dataset_split_fn = dataset_split_fn_builder(
            recon_epochs_max=FLAGS.recon_epochs_max,
            recon_epochs_constant=FLAGS.recon_epochs_constant,
            recon_steps_max=FLAGS.recon_steps_max,
            post_recon_epochs=FLAGS.post_recon_epochs,
            post_recon_steps_max=FLAGS.post_recon_steps_max,
            # Getting meaningful evaluation metrics requires splitting the data.
            split_dataset=True)

        if FLAGS.task == 'stackoverflow_nwp_finetune':
            return federated_evaluation.build_federated_finetune_evaluation(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                finetune_optimizer_fn=functools.partial(
                    finetune_optimizer_fn, FLAGS.finetune_learning_rate),
                dataset_split_fn=dataset_split_fn)

        return evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn)

    # Shared args, useful to support more tasks.
    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    shared_args[
        'evaluation_computation_builder'] = evaluation_computation_builder

    task_args = _get_task_args()
    _write_hparam_flags()

    if FLAGS.task in ['stackoverflow_nwp', 'stackoverflow_nwp_finetune']:
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'movielens_mf':
        run_federated_fn = federated_movielens.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args, **task_args)
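
A `DatasetSplitFn` maps a client dataset (plus the round number) to a pair of datasets: the first is used for reconstruction (or fine-tuning), the second for post-reconstruction training or evaluation, as `fake_dataset_split_fn` above shows. For intuition, a minimal even/odd split (purely illustrative, not the library's builder):

def even_odd_split_fn(client_dataset, round_num):
    # Reconstruct on even-indexed examples, train/evaluate on the odd ones.
    del round_num
    recon = client_dataset.shard(num_shards=2, index=0)
    post_recon = client_dataset.shard(num_shards=2, index=1)
    return recon, post_recon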
Example #18
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    client_mixedin_schedule_fn = fed_pa_schedule.create_mixin_check_fn(
        name=FLAGS.client_mixin_check_scheme,
        num_mixin_epochs=FLAGS.client_mixin_epochs_per_round,
        start_round=FLAGS.client_mixin_check_start_round)
    client_update_delta_fn = fed_pa_schedule.create_update_delta_fn(
        name=FLAGS.client_update_delta_scheme, rho=FLAGS.client_shrinkage_rho)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_pa_schedule.build_fed_pa_process(
            model_fn=model_fn,
            client_update_epochs=FLAGS.client_epochs_per_round,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_mixedin_schedule_fn=client_mixedin_schedule_fn,
            client_update_delta_fn=client_update_delta_fn,
            mask_zeros_in_client_updates=FLAGS.mask_zeros_in_client_updates)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        # Since the number of epochs each client makes every round is handled
        # by the logic in client update functions, here we set it to 1.
        client_epochs_per_round=1,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
Example #19
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_train, _ = emnist_dataset.get_federated_datasets(
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      only_digits=False)

  _, emnist_test = emnist_dataset.get_centralized_datasets()

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  if FLAGS.uniform_weighting:
    client_weighting = tff.learning.ClientWeighting.UNIFORM
  else:
    client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=emnist_test.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    if FLAGS.noise_multiplier <= 0:
      raise ValueError('noise_multiplier must be positive if DP is enabled.')
    if FLAGS.clip is None or FLAGS.clip <= 0:
      raise ValueError('clip must be positive if DP is enabled.')

    if not FLAGS.adaptive_clip_learning_rate:
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          clip=FLAGS.clip)
    else:
      if FLAGS.adaptive_clip_learning_rate <= 0:
        raise ValueError('adaptive_clip_learning_rate must be positive if '
                         'adaptive clipping is enabled.')
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          initial_l2_norm_clip=FLAGS.clip,
          target_unclipped_quantile=FLAGS.target_unclipped_quantile,
          learning_rate=FLAGS.adaptive_clip_learning_rate)
  else:
    if FLAGS.uniform_weighting:
      aggregation_factory = tff.aggregators.UnweightedMeanFactory()
    else:
      aggregation_factory = tff.aggregators.MeanFactory()

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=client_weighting,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)
  validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  # Log hyperparameters to CSV
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_profile=FLAGS.rounds_per_profile)
Example #20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    if FLAGS.task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated
    elif FLAGS.task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif FLAGS.task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif FLAGS.task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif FLAGS.task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args, **task_args, hparam_dict=hparam_dict)
Example No. 21
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    logging.info('Trainable weights:')
    for weight in model_fn().weights.trainable:
      logging.info('name: %s  shape: %s', weight.name, weight.shape)

    if FLAGS.uniform_weighting:
      client_weighting = tff.learning.ClientWeighting.UNIFORM
    elif FLAGS.task == 'shakespeare' or FLAGS.task == 'stackoverflow_nwp':

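      # For language tasks, weight each client's update by the number of
      # non-padding tokens it processed rather than by its number of examples.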
      def client_weighting(local_outputs):
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
    else:
      client_weighting = None

    if FLAGS.noise_multiplier is None:
      if FLAGS.uniform_weighting:
        aggregation_factory = tff.aggregators.UnweightedMeanFactory()
      else:
        aggregation_factory = tff.aggregators.MeanFactory()
      if FLAGS.clip is not None:
        if FLAGS.clip <= 0:
          raise ValueError('clip must be positive if clipping is enabled.')
        if FLAGS.adaptive_clip_learning_rate is None:
          clip = FLAGS.clip
        else:
          if FLAGS.adaptive_clip_learning_rate <= 0:
            raise ValueError('adaptive_clip_learning_rate must be positive if '
                             'adaptive clipping is enabled.')
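          # Adaptive clipping without DP noise: the clip norm is adjusted each
          # round so that roughly `target_unclipped_quantile` of client
          # updates fall below it.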
          clip = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
              initial_estimate=FLAGS.clip,
              target_quantile=FLAGS.target_unclipped_quantile,
              learning_rate=FLAGS.adaptive_clip_learning_rate)
        aggregation_factory = tff.aggregators.clipping_factory(
            clip, aggregation_factory)
    else:
      if not FLAGS.uniform_weighting:
        raise ValueError(
            'Differential privacy is only implemented for uniform weighting.')
      if FLAGS.noise_multiplier <= 0:
        raise ValueError('noise_multiplier must be positive if DP is enabled.')
      if FLAGS.clip is None or FLAGS.clip <= 0:
        raise ValueError('clip must be positive if DP is enabled.')
      if FLAGS.adaptive_clip_learning_rate is None:
        aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
            noise_multiplier=FLAGS.noise_multiplier,
            clients_per_round=FLAGS.clients_per_round,
            clip=FLAGS.clip)
      else:
        if FLAGS.adaptive_clip_learning_rate <= 0:
          raise ValueError('adaptive_clip_learning_rate must be positive if '
                           'adaptive clipping is enabled.')
        aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
            noise_multiplier=FLAGS.noise_multiplier,
            clients_per_round=FLAGS.clients_per_round,
            initial_l2_norm_clip=FLAGS.clip,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            learning_rate=FLAGS.adaptive_clip_learning_rate)

    return tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(task_spec)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(task_spec)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(task_spec)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(task_spec)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  _write_hparam_flags()

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
Example No. 22
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
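    # Bound the fanout of the local simulation runtime's aggregation tree to
    # limit per-node memory use when many clients participate in a round.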
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
        vocab_size=FLAGS.vocab_size,
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_sequence_length=FLAGS.sequence_length,
        max_elements_per_train_client=FLAGS.max_elements_per_user)
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=FLAGS.vocab_size,
        max_sequence_length=FLAGS.sequence_length,
        num_validation_examples=FLAGS.num_validation_examples)

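    # UNIFORM weights every client equally in the federated average;
    # NUM_EXAMPLES weights each client by the number of examples it processed.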
    if FLAGS.uniform_weighting:
        client_weighting = tff.learning.ClientWeighting.UNIFORM
    else:
        client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )
        if FLAGS.noise_multiplier <= 0:
            raise ValueError(
                'noise_multiplier must be positive if DP is enabled.')
        if FLAGS.clip is None or FLAGS.clip <= 0:
            raise ValueError('clip must be positive if DP is enabled.')

        if not FLAGS.adaptive_clip_learning_rate:
            aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                clip=FLAGS.clip)
        else:
            if FLAGS.adaptive_clip_learning_rate <= 0:
                raise ValueError(
                    'adaptive_clip_learning_rate must be positive if '
                    'adaptive clipping is enabled.')
            aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                initial_l2_norm_clip=FLAGS.clip,
                target_unclipped_quantile=FLAGS.target_unclipped_quantile,
                learning_rate=FLAGS.adaptive_clip_learning_rate)
    else:
        if FLAGS.uniform_weighting:
            aggregation_factory = tff.aggregators.UnweightedMeanFactory()
        else:
            aggregation_factory = tff.aggregators.MeanFactory()

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    validation_fn = lambda state, round_num: evaluate_fn(state.model)

    evaluate_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    test_fn = lambda state: evaluate_test_fn(state.model)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
Example No. 23
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    datasets = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
    train_dataset, validation_dataset, test_dataset = datasets

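    # Older API: the client weight is computed from the model's local outputs;
    # a constant yields uniform weighting, while `num_tokens` weights each
    # client by the amount of data it processed.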
    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

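        # Older DP API: `build_dp_query` constructs a TensorFlow Privacy
        # `DPQuery` combining (optionally adaptive) L2 clipping with Gaussian
        # noise, which is wrapped into a TFF aggregation process below.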
        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=(
                FLAGS.clipped_count_budget_allocation),
            expected_clients_per_round=FLAGS.clients_per_round)

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=test_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
Example No. 24
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

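    # NOTE: the original snippet uses `iterative_process_builder` below
    # without defining it; the definition here is a minimal reconstruction
    # modeled on the identical builder in Example No. 20, included so the
    # example is self-contained.
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`."""
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)
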
    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)