Example #1
0
def main(argv):
  """Entry point: builds a scheduled-LR FedAvg process and runs `--task`.

  Client/server optimizer factories and learning-rate schedules are created
  from flags via `optimizer_utils`, closed over by an iterative-process
  builder, and passed (together with shared client/round flags) to the
  `run_federated` function of the task selected by `FLAGS.task`.

  Args:
    argv: Positional command-line arguments; anything beyond `argv[0]` is
      rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `FLAGS.task` is not one of the tasks handled below.
  """
  if len(argv) > 1:
    raise app.UsageError(
        'Expected no command-line arguments, got: {}'.format(argv))

  # Built in the same order as before: optimizers first, then schedules.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
      dataset_preprocess_comp: Optional[tff.Computation] = None,
  ) -> tff.templates.IterativeProcess:
    """Builds a scheduled-learning-rate FedAvg process around `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of
        values and return a sequence of values, or in TFF type shorthand
        `(U* -> V*)`. If `None`, no dataset preprocessing is applied.

    Returns:
      The `tff.templates.IterativeProcess` produced by
      `fed_avg_schedule.build_fed_avg_process`.
    """
    process_kwargs = dict(
        model_fn=model_fn,
        client_optimizer_fn=client_opt_fn,
        client_lr=client_schedule,
        server_optimizer_fn=server_opt_fn,
        server_lr=server_schedule,
        client_weight_fn=client_weight_fn,
        dataset_preprocess_comp=dataset_preprocess_comp)
    return fed_avg_schedule.build_fed_avg_process(**process_kwargs)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  # Arguments every task runner receives; insertion order matches the
  # original literal.
  common_args = collections.OrderedDict()
  common_args['iterative_process_builder'] = iterative_process_builder
  common_args['assign_weights_fn'] = assign_weights_fn
  for flag_name in ('client_epochs_per_round', 'client_batch_size',
                    'clients_per_round', 'client_datasets_random_seed'):
    common_args[flag_name] = getattr(FLAGS, flag_name)

  def flags_with_prefix(prefix):
    """Returns values of every flag named `prefix...`, keyed without prefix."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in FLAGS
        if name.startswith(prefix))

  # Each entry defers its call (and any task-specific flag reads) until the
  # task has been selected.
  task_runners = {
      'cifar100': lambda: federated_cifar100.run_federated(
          **common_args, crop_size=FLAGS.cifar100_crop_size),
      'emnist_cr': lambda: federated_emnist.run_federated(
          **common_args, emnist_model=FLAGS.emnist_cr_model),
      'emnist_ae': lambda: federated_emnist_ae.run_federated(**common_args),
      'shakespeare': lambda: federated_shakespeare.run_federated(
          **common_args, sequence_length=FLAGS.shakespeare_sequence_length),
      'stackoverflow_nwp': lambda: federated_stackoverflow.run_federated(
          **common_args, **flags_with_prefix('so_nwp_')),
      'stackoverflow_lr': lambda: federated_stackoverflow_lr.run_federated(
          **common_args, **flags_with_prefix('so_lr_')),
  }

  run_task = task_runners.get(FLAGS.task)
  if run_task is None:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))
  run_task()
Example #2
0
def main(argv):
    """Entry point: builds an adaptive-LR FedAvg process and runs `--task`.

    Client and server learning rates are driven by callbacks from
    `callbacks.create_reduce_lr_on_plateau`; the resulting iterative-process
    builder is passed, with the shared flag values and the hyperparameter
    dict, to the `run_federated` function of the task selected by
    `FLAGS.task`.

    Args:
      argv: Positional command-line arguments; anything beyond `argv[0]` is
        rejected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
      ValueError: If `FLAGS.task` is not one of the tasks handled below.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    # Both callbacks share the plateau-detection settings (min_delta, min_lr,
    # window_size, patience); only the initial rate and decay factor differ.
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)

    def flags_with_prefix(prefix):
        """Maps each `task_flags` name starting with `prefix` to its value.

        Keys have `prefix` stripped, so they can be passed as keyword
        arguments to a task runner.
        """
        prefixed = collections.OrderedDict()
        for flag_name in task_flags:
            if flag_name.startswith(prefix):
                prefixed[flag_name[len(prefix):]] = FLAGS[flag_name].value
        return prefixed

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder

    if FLAGS.task == 'cifar100':
        hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
        federated_cifar100.run_federated(**shared_args,
                                         crop_size=FLAGS.cifar100_crop_size,
                                         hparam_dict=hparam_dict)

    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**shared_args,
                                       emnist_model=FLAGS.emnist_cr_model,
                                       hparam_dict=hparam_dict)

    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**shared_args,
                                          hparam_dict=hparam_dict)

    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **shared_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            hparam_dict=hparam_dict)

    elif FLAGS.task == 'stackoverflow_nwp':
        federated_stackoverflow.run_federated(**shared_args,
                                              **flags_with_prefix('so_nwp_'),
                                              hparam_dict=hparam_dict)

    elif FLAGS.task == 'stackoverflow_lr':
        federated_stackoverflow_lr.run_federated(**shared_args,
                                                 **flags_with_prefix('so_lr_'),
                                                 hparam_dict=hparam_dict)

    else:
        # Without this branch an unrecognised --task made main() return
        # successfully without running any training.
        supported_tasks = ('cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare',
                           'stackoverflow_nwp', 'stackoverflow_lr')
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, supported_tasks))
Example #3
0
def main(argv):
  """Entry point: builds a scheduled-LR FedAvg process and runs `--task`.

  Optimizer factories and learning-rate schedules for both client and server
  come from flags via `optimizer_utils`. The selected task's `run_federated`
  receives the shared flag values, the iterative-process builder, the Keras
  weight-assignment function, any task-specific arguments, and `hparam_dict`
  (the hyperparameter flag values, plus `cifar100_crop_size` for CIFAR-100).

  Args:
    argv: Positional command-line arguments; anything beyond `argv[0]` is
      rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `FLAGS.task` is not one of the tasks handled below.
  """
  if len(argv) > 1:
    raise app.UsageError(
        'Expected no command-line arguments, got: {}'.format(argv))

  # Unpacked from generators so each side is built client-first, matching the
  # original call order.
  client_opt, server_opt = (
      optimizer_utils.create_optimizer_fn_from_flags(side)
      for side in ('client', 'server'))
  client_schedule, server_schedule = (
      optimizer_utils.create_lr_schedule_from_flags(side)
      for side in ('client', 'server'))

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Wraps `fed_avg_schedule.build_fed_avg_process` for a given model.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_opt,
        client_lr=client_schedule,
        server_optimizer_fn=server_opt,
        server_lr=server_schedule,
        client_weight_fn=client_weight_fn)

  def task_flags_with_prefix(prefix):
    """Returns `task_flags` values whose names start with `prefix`, unprefixed."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in task_flags
        if name.startswith(prefix))

  weight_assigner = fed_avg_schedule.ServerState.assign_weights_to_keras_model
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  shared_args['assign_weights_fn'] = weight_assigner

  # Each branch picks the runner and its task-specific keyword arguments; the
  # single call at the bottom then runs it.
  task = FLAGS.task
  if task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    run_federated = federated_cifar100.run_federated
    task_kwargs = {'crop_size': FLAGS.cifar100_crop_size}
  elif task == 'emnist_cr':
    run_federated = federated_emnist.run_federated
    task_kwargs = {'emnist_model': FLAGS.emnist_cr_model}
  elif task == 'emnist_ae':
    run_federated = federated_emnist_ae.run_federated
    task_kwargs = {}
  elif task == 'shakespeare':
    run_federated = federated_shakespeare.run_federated
    task_kwargs = {'sequence_length': FLAGS.shakespeare_sequence_length}
  elif task == 'stackoverflow_nwp':
    run_federated = federated_stackoverflow.run_federated
    task_kwargs = task_flags_with_prefix('so_nwp_')
  elif task == 'stackoverflow_lr':
    run_federated = federated_stackoverflow_lr.run_federated
    task_kwargs = task_flags_with_prefix('so_lr_')
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            task, _SUPPORTED_TASKS))

  run_federated(**shared_args, **task_kwargs, hparam_dict=hparam_dict)