Esempio n. 1
0
  def test_create_inv_lin_client_lr_schedule_from_flags(self):
    """Checks the 'inv_lin_decay' schedule with and without staircasing.

    The expected values are consistent with 5.0 / (1 + 10.0 * k), where k is
    floor(step / 10) when staircased and step / 10 otherwise.
    """

    def flag_name(suffix):
      return '{}_{}'.format(TEST_CLIENT_FLAG_PREFIX, suffix)

    common_flags = {
        flag_name('learning_rate'): 5.0,
        flag_name('lr_schedule'): 'inv_lin_decay',
        flag_name('lr_decay_steps'): 10,
        flag_name('lr_decay_rate'): 10.0,
    }
    # (staircase flag value, [(step, expected learning rate), ...])
    cases = (
        (True, [(0, 5.0), (1, 5.0), (10, 0.454545454545),
                (19, 0.454545454545), (20, 0.238095238095)]),
        (False, [(0, 5.0), (1, 2.5), (9, 0.5), (19, 0.25)]),
    )
    for staircase, expectations in cases:
      sandbox_flags = dict(common_flags)
      sandbox_flags[flag_name('lr_staircase')] = staircase
      with flag_sandbox(sandbox_flags):
        schedule = optimizer_utils.create_lr_schedule_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        for step, expected_lr in expectations:
          self.assertNear(schedule(step), expected_lr, err=1e-5)
Esempio n. 2
0
  def test_create_inv_sqrt_client_lr_schedule_from_flags(self):
    """Checks the 'inv_sqrt_decay' schedule with and without staircasing.

    The expected values are consistent with 2.0 / sqrt(1 + 10.0 * k), where k
    is floor(step / 10) when staircased and step / 10 otherwise.
    """

    def flag_name(suffix):
      return '{}_{}'.format(TEST_CLIENT_FLAG_PREFIX, suffix)

    common_flags = {
        flag_name('learning_rate'): 2.0,
        flag_name('lr_schedule'): 'inv_sqrt_decay',
        flag_name('lr_decay_steps'): 10,
        flag_name('lr_decay_rate'): 10.0,
    }
    # (staircase flag value, [(step, expected learning rate), ...])
    cases = (
        (True, [(0, 2.0), (1, 2.0), (10, 0.603022689155),
                (19, 0.603022689155), (20, 0.436435780472)]),
        (False, [(0, 2.0), (3, 1.0), (99, 0.2), (399, 0.1)]),
    )
    for staircase, expectations in cases:
      sandbox_flags = dict(common_flags)
      sandbox_flags[flag_name('lr_staircase')] = staircase
      with flag_sandbox(sandbox_flags):
        schedule = optimizer_utils.create_lr_schedule_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        for step, expected_lr in expectations:
          self.assertNear(schedule(step), expected_lr, err=1e-5)
 def test_create_constant_client_lr_schedule_from_flags(self):
     """Checks the 'constant' schedule, both plain and with 10 warmup steps.

     Without warmup the rate is 3.0 at every step. With warmup the expected
     values ramp up (0.3 at step 0, 0.6 at step 1) and are back at 3.0 from
     step 10 onwards.
     """

     def flag_name(suffix):
         return '{}_{}'.format(TEST_CLIENT_FLAG_PREFIX, suffix)

     constant_flags = {
         flag_name('learning_rate'): 3.0,
         flag_name('lr_schedule'): 'constant',
     }
     warmup_flags = dict(constant_flags)
     warmup_flags[flag_name('lr_warmup_steps')] = 10

     # (sandbox flags, [(step, expected learning rate), ...])
     cases = (
         (constant_flags, [(0, 3.0), (1, 3.0), (105, 3.0), (1042, 3.0)]),
         (warmup_flags, [(0, 0.3), (1, 0.6), (10, 3.0), (11, 3.0),
                         (115, 3.0), (1052, 3.0)]),
     )
     for sandbox_flags, expectations in cases:
         with flag_sandbox(sandbox_flags):
             schedule = optimizer_utils.create_lr_schedule_from_flags(
                 TEST_CLIENT_FLAG_PREFIX)
             for step, expected_lr in expectations:
                 self.assertNear(schedule(step), expected_lr, err=1e-5)
Esempio n. 4
0
  def test_create_exp_decay_client_lr_schedule_from_flags(self):
    """Checks the 'exp_decay' schedule with and without staircasing.

    The expected values are consistent with 3.0 * 0.1 ** k, where k is
    floor(step / 10) when staircased and step / 10 otherwise.
    """

    def flag_name(suffix):
      return '{}_{}'.format(TEST_CLIENT_FLAG_PREFIX, suffix)

    common_flags = {
        flag_name('learning_rate'): 3.0,
        flag_name('lr_schedule'): 'exp_decay',
        flag_name('lr_decay_steps'): 10,
        flag_name('lr_decay_rate'): 0.1,
    }
    # (staircase flag value, [(step, expected learning rate), ...])
    cases = (
        (True, [(0, 3.0), (3, 3.0), (10, 0.3), (19, 0.3), (20, 0.03)]),
        (False, [(0, 3.0), (1, 2.38298470417), (10, 0.3),
                 (25, 0.00948683298)]),
    )
    for staircase, expectations in cases:
      sandbox_flags = dict(common_flags)
      sandbox_flags[flag_name('lr_staircase')] = staircase
      with flag_sandbox(sandbox_flags):
        schedule = optimizer_utils.create_lr_schedule_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        for step, expected_lr in expectations:
          self.assertNear(schedule(step), expected_lr, err=1e-5)
Esempio n. 5
0
def main(argv):
  """Runs scheduled federated averaging on the task selected by `--task`.

  Args:
    argv: Positional command-line arguments remaining after flag parsing. At
      most one element (conventionally the program name) is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `--task` names an unsupported task.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Optimizer constructors and learning-rate schedules come from the flag
  # groups identified by the 'client' and 'server' prefixes.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Wraps `model_fn` in a FedAvg process using the flag-derived settings.

    Args:
      model_fn: No-arg callable returning a `tff.learning.Model`.
      client_weight_fn: Optional callable that maps the output of
        `model.report_local_outputs` to the client's weight in the federated
        average of model deltas. When omitted, the weight defaults to the
        total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_opt_fn,
        client_lr=client_lr,
        server_optimizer_fn=server_opt_fn,
        server_lr=server_lr,
        client_weight_fn=client_weight_fn)

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  task_args = _get_task_args()
  hparam_dict = _get_hparam_flags()

  # Every task module exposes the same `run_federated` entry point. The
  # attribute is looked up only on the module that is actually selected.
  task_modules = {
      'cifar100': federated_cifar100,
      'emnist_cr': federated_emnist,
      'emnist_ae': federated_emnist_ae,
      'shakespeare': federated_shakespeare,
      'stackoverflow_nwp': federated_stackoverflow,
      'stackoverflow_lr': federated_stackoverflow_lr,
  }
  selected_module = task_modules.get(FLAGS.task)
  if selected_module is None:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  selected_module.run_federated(
      **shared_args, **task_args, hparam_dict=hparam_dict)
Esempio n. 6
0
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Creates a FedAvg iterative process whose hyperparameters come from flags.

  Client and server optimizers, together with their learning-rate schedules,
  are read from the flag groups identified by the 'client' and 'server'
  prefixes.

  Args:
    input_spec: A value convertible to a `tff.Type` that describes the data
      passed to `tff.templates.IterativeProcess.next` during training. It is
      typically the `element_spec` of a client `tf.data.Dataset`.
    model_builder: No-arg callable returning an uncompiled `tf.keras.Model`.
    loss_builder: No-arg callable returning a `tf.keras.losses.Loss`.
    metrics_builder: No-arg callable returning a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: Optional callable that receives the result of
      `tff.learning.Model.report_local_outputs` for the model built from
      `model_builder` and returns a scalar client weight. If `None`, the
      weight is the number of examples processed across all batches.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  def tff_model_fn() -> tff.learning.Model:
    # Each call builds a fresh Keras model, loss object and metric list.
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_opt_fn,
      client_lr=client_lr,
      server_optimizer_fn=server_opt_fn,
      server_lr=server_lr,
      client_weight_fn=client_weight_fn)
def from_flags(dummy_batch,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
  """Creates a FedAvg `tff.utils.IterativeProcess` whose settings come from flags.

  Client and server optimizers, together with their learning-rate schedules,
  are read from the flag groups identified by the 'client' and 'server'
  prefixes.

  Args:
    dummy_batch: A nested structure of values convertible to batched tensors
      whose shapes and types match what the model's forward pass expects
      during training. Only the structure matters; the values can be any
      reasonable placeholder.
    model_builder: No-arg callable returning an uncompiled `tf.keras.Model`.
    loss_builder: No-arg callable returning a `tf.keras.losses.Loss`.
    metrics_builder: No-arg callable returning a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: Optional callable that receives the result of
      `tff.learning.Model.report_local_outputs` for the model built from
      `model_builder` and returns a scalar client weight. If `None`, the
      weight is the number of examples processed across all batches.

  Returns:
    A `tff.utils.IterativeProcess` instance.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  def tff_model_fn():
    # Each call builds a fresh Keras model, loss object and metric list.
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        dummy_batch=dummy_batch,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_opt_fn,
      client_lr=client_lr,
      server_optimizer_fn=server_opt_fn,
      server_lr=server_lr,
      client_weight_fn=client_weight_fn)
Esempio n. 8
0
def main(argv):
  """Runs scheduled federated averaging on the task selected by `--task`.

  Args:
    argv: Positional command-line arguments remaining after flag parsing. At
      most one element (conventionally the program name) is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `--task` names an unsupported task.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Optimizer constructors and learning-rate schedules come from the flag
  # groups identified by the 'client' and 'server' prefixes.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
      dataset_preprocess_comp: Optional[tff.Computation] = None,
  ) -> tff.templates.IterativeProcess:
    """Wraps `model_fn` in a FedAvg process using the flag-derived settings.

    Args:
      model_fn: No-arg callable returning a `tff.learning.Model`.
      client_weight_fn: Optional callable that maps the output of
        `model.report_local_outputs` to the client's weight in the federated
        average of model deltas. When omitted, the weight defaults to the
        total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that builds the
        client-side data pipeline. It must map a sequence to a sequence, which
        is `(U* -> V*)` in TFF type shorthand. When `None`, no dataset
        preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_opt_fn,
        client_lr=client_lr,
        server_optimizer_fn=server_opt_fn,
        server_lr=server_lr,
        client_weight_fn=client_weight_fn,
        dataset_preprocess_comp=dataset_preprocess_comp)

  common_args = collections.OrderedDict([
      ('iterative_process_builder', iterative_process_builder),
      ('assign_weights_fn',
       fed_avg_schedule.ServerState.assign_weights_to_keras_model),
      ('client_epochs_per_round', FLAGS.client_epochs_per_round),
      ('client_batch_size', FLAGS.client_batch_size),
      ('clients_per_round', FLAGS.clients_per_round),
      ('client_datasets_random_seed', FLAGS.client_datasets_random_seed),
  ])

  def flags_with_prefix(prefix):
    """Maps every flag named `prefix + X` to X, keeping the flag's value."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in FLAGS
        if name.startswith(prefix))

  # Each branch picks the task's entry point and its task-specific kwargs.
  # A single call below then runs the chosen task.
  task = FLAGS.task
  if task == 'cifar100':
    run_fn = federated_cifar100.run_federated
    task_kwargs = {'crop_size': FLAGS.cifar100_crop_size}
  elif task == 'emnist_cr':
    run_fn = federated_emnist.run_federated
    task_kwargs = {'emnist_model': FLAGS.emnist_cr_model}
  elif task == 'emnist_ae':
    run_fn = federated_emnist_ae.run_federated
    task_kwargs = {}
  elif task == 'shakespeare':
    run_fn = federated_shakespeare.run_federated
    task_kwargs = {'sequence_length': FLAGS.shakespeare_sequence_length}
  elif task == 'stackoverflow_nwp':
    run_fn = federated_stackoverflow.run_federated
    task_kwargs = flags_with_prefix('so_nwp_')
  elif task == 'stackoverflow_lr':
    run_fn = federated_stackoverflow_lr.run_federated
    task_kwargs = flags_with_prefix('so_lr_')
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            task, _SUPPORTED_TASKS))

  run_fn(**common_args, **task_kwargs)
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
    *,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
  """Creates a FedAvg process, with learning-rate schedules, from flags.

  Client and server optimizers, together with their learning-rate schedules,
  are read from the flag groups identified by the 'client' and 'server'
  prefixes.

  Args:
    input_spec: A value convertible to a `tff.Type` that describes the data
      passed to `tff.templates.IterativeProcess.next` during training. It is
      typically the `element_spec` of a client `tf.data.Dataset`. It is
      ignored when `dataset_preprocess_comp` is given, and may be `None` in
      that case.
    model_builder: No-arg callable returning an uncompiled `tf.keras.Model`.
    loss_builder: No-arg callable returning a `tf.keras.losses.Loss`.
    metrics_builder: No-arg callable returning a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: Optional callable that receives the result of
      `tff.learning.Model.report_local_outputs` for the model built from
      `model_builder` and returns a scalar client weight. If `None`, the
      weight is the number of examples processed across all batches.
    dataset_preprocess_comp: Optional `tff.Computation` that builds the
      client-side data pipeline. It must map a sequence of values to a
      sequence of values, which is `(U* -> V*)` in TFF type shorthand. If
      `None`, no dataset preprocessing is applied. If it is given, the
      model's input spec is taken from the element type of its result, and
      `input_spec` becomes optional.

  Returns:
    A `fed_avg_schedule.FederatedAveragingProcessAdapter`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  # The preprocessing computation, when present, determines the element type
  # that actually reaches the model, so it takes precedence over `input_spec`.
  model_input_spec = input_spec
  if dataset_preprocess_comp is not None:
    if input_spec is not None:
      print('Specified both `dataset_preprocess_comp` and `input_spec` when '
            'only one is necessary. Ignoring `input_spec` and using type '
            'signature of `dataset_preprocess_comp`.')
    model_input_spec = dataset_preprocess_comp.type_signature.result.element

  def tff_model_fn() -> tff.learning.Model:
    # Each call builds a fresh Keras model, loss object and metric list.
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_opt_fn,
      client_lr=client_lr,
      server_optimizer_fn=server_opt_fn,
      server_lr=server_lr,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
Esempio n. 10
0
def main(argv):
  """Runs scheduled federated averaging on the task selected by `--task`.

  Hyperparameter flag values are collected into `hparam_dict` and forwarded
  to the task's `run_federated`. For 'cifar100', the crop size is also
  recorded there.

  Args:
    argv: Positional command-line arguments remaining after flag parsing. At
      most one element (conventionally the program name) is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
    ValueError: If `--task` names an unsupported task.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Optimizer constructors and learning-rate schedules come from the flag
  # groups identified by the 'client' and 'server' prefixes.
  client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Wraps `model_fn` in a FedAvg process using the flag-derived settings.

    Args:
      model_fn: No-arg callable returning a `tff.learning.Model`.
      client_weight_fn: Optional callable that maps the output of
        `model.report_local_outputs` to the client's weight in the federated
        average of model deltas. When omitted, the weight defaults to the
        total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_opt_fn,
        client_lr=client_lr,
        server_optimizer_fn=server_opt_fn,
        server_lr=server_lr,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  shared_args['assign_weights_fn'] = (
      fed_avg_schedule.ServerState.assign_weights_to_keras_model)

  def task_flags_with_prefix(prefix):
    """Maps every task flag named `prefix + X` to X, keeping its value."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in task_flags
        if name.startswith(prefix))

  # Each branch picks the task's entry point and its task-specific kwargs.
  # A single call below then runs the chosen task.
  task = FLAGS.task
  if task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    run_fn = federated_cifar100.run_federated
    task_kwargs = {'crop_size': FLAGS.cifar100_crop_size}
  elif task == 'emnist_cr':
    run_fn = federated_emnist.run_federated
    task_kwargs = {'emnist_model': FLAGS.emnist_cr_model}
  elif task == 'emnist_ae':
    run_fn = federated_emnist_ae.run_federated
    task_kwargs = {}
  elif task == 'shakespeare':
    run_fn = federated_shakespeare.run_federated
    task_kwargs = {'sequence_length': FLAGS.shakespeare_sequence_length}
  elif task == 'stackoverflow_nwp':
    run_fn = federated_stackoverflow.run_federated
    task_kwargs = task_flags_with_prefix('so_nwp_')
  elif task == 'stackoverflow_lr':
    run_fn = federated_stackoverflow_lr.run_federated
    task_kwargs = task_flags_with_prefix('so_lr_')
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            task, _SUPPORTED_TASKS))

  run_fn(**shared_args, **task_kwargs, hparam_dict=hparam_dict)