Example #1
0
    def test_fed_sgd_without_decay_decreases_loss(self):
        """Trains 5 rounds with decay disabled; loss drops, LRs stay fixed."""
        # decay_factor=1.0 makes each plateau "decay" a no-op, so both
        # callbacks should end training with their initial learning rates.
        make_callback = callbacks.create_reduce_lr_on_plateau
        client_lr_callback = make_callback(learning_rate=0.0,
                                           min_delta=0.5,
                                           window_size=2,
                                           decay_factor=1.0,
                                           cooldown=0)
        server_lr_callback = make_callback(learning_rate=0.1,
                                           min_delta=0.5,
                                           window_size=2,
                                           decay_factor=1.0,
                                           cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, train_outputs = self._run_rounds(iterative_process, 5)
        first_round, last_round = train_outputs[0], train_outputs[-1]
        self.assertLess(last_round['before_training']['loss'],
                        first_round['before_training']['loss'])
        self.assertLess(last_round['during_training']['loss'],
                        first_round['during_training']['loss'])
        # With decay disabled, both learning rates must be unchanged.
        self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
Example #2
0
    def test_comparable_to_fed_avg(self):
        """With decay disabled, adaptive fed-avg should match vanilla FedAvg."""
        # decay_factor=1.0 keeps both learning rates fixed at 0.1, mirroring
        # the reference federated averaging process configured below.
        make_callback = callbacks.create_reduce_lr_on_plateau
        client_lr_callback = make_callback(learning_rate=0.1,
                                           min_delta=0.5,
                                           window_size=2,
                                           decay_factor=1.0,
                                           cooldown=0)
        server_lr_callback = make_callback(learning_rate=0.1,
                                           min_delta=0.5,
                                           window_size=2,
                                           decay_factor=1.0,
                                           cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # Reference: plain TFF FedAvg with client LR 0.1 and the canonical
        # server LR of 1.0.
        reference_iterative_process = tff.learning.build_federated_averaging_process(
            _uncompiled_model_builder,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

        _, train_outputs = self._run_rounds(iterative_process, 5)
        _, reference_train_outputs = self._run_rounds_tff_fedavg(
            reference_iterative_process, 5)

        # The per-round training losses should agree to within 1e-4.
        for round_num in range(5):
            self.assertAllClose(
                train_outputs[round_num]['during_training']['loss'],
                reference_train_outputs[round_num]['loss'], 1e-4)
Example #3
0
    def test_iterative_process_type_signature(self):
        """Checks the TFF type signatures of `initialize` and `next`.

        Builds the adaptive process with no-op decay (decay_factor=1.0) and
        verifies that the server state, client dataset, and metric output
        types match the expected federated types.
        """
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # Both callbacks were built with identical settings, so one TFF type
        # describes both the client and server callback state.
        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        # Expected server state: weights for a 1x1 linear model, optimizer
        # state, and the two LR callbacks, all placed at the server.
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        # `next` reports a loss metric both before and during training.
        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type, federated_dataset=dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Example #4
0
 def test_raises_bad_decay_factor(self):
     """Decay factors outside the valid range must raise `ValueError`."""
     for bad_kwargs in ({'decay_factor': 2.0, 'cooldown': 0},
                        {'decay_factor': -1.0}):
         with self.assertRaises(ValueError):
             callbacks.create_reduce_lr_on_plateau(learning_rate=0.1,
                                                   **bad_kwargs)
Example #5
0
 def test_cooldown(self):
     """Updates during the cooldown period do not decay the learning rate.

     NOTE(review): the assertions show the callback starts with
     cooldown_counter == cooldown at creation time — confirm this is the
     intended contract of `create_reduce_lr_on_plateau`.
     """
     lr_callback = callbacks.create_reduce_lr_on_plateau(learning_rate=2.0,
                                                         decay_factor=0.5,
                                                         minimize=False,
                                                         window_size=1,
                                                         patience=0,
                                                         cooldown=3)
     logging.info('LR Callback: %s', lr_callback)
     self.assertEqual(lr_callback.learning_rate, 2.0)
     self.assertEqual(lr_callback.cooldown, 3)
     self.assertEqual(lr_callback.cooldown_counter, 3)
     # While cooling down, each update only decrements the counter; the LR
     # stays at 2.0 even though the metric (-1.0) is not improving.
     for i in range(2):
         lr_callback = callbacks.update_reduce_lr_on_plateau(
             lr_callback, -1.0)
         logging.info('LR Callback: %s', lr_callback)
         self.assertEqual(lr_callback.learning_rate, 2.0)
         self.assertEqual(lr_callback.wait, 0)
         self.assertEqual(lr_callback.cooldown, 3)
         self.assertEqual(lr_callback.cooldown_counter, 2 - i)
     # Cooldown exhausted: the next bad metric halves the LR (2.0 -> 1.0)
     # and the cooldown counter restarts at `cooldown`.
     lr_callback = callbacks.update_reduce_lr_on_plateau(lr_callback, -1)
     logging.info('LR Callback: %s', lr_callback)
     self.assertEqual(lr_callback.learning_rate, 1.0)
     self.assertEqual(lr_callback.wait, 0)
     self.assertEqual(lr_callback.cooldown, 3)
     self.assertEqual(lr_callback.cooldown_counter, 3)
Example #6
0
    def test_small_lr_comparable_zero_lr(self):
        """A tiny client LR (1e-8) should train almost identically to LR 0."""
        client_lr_callback1 = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        client_lr_callback2 = callbacks.create_reduce_lr_on_plateau(
            learning_rate=1e-8,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        # NOTE(review): this call passes the update function positionally
        # after each callback, unlike other callers in this file — presumably
        # a different `build_fed_avg_process` signature; verify before reuse.
        iterative_process1 = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback1,
            callbacks.update_reduce_lr_on_plateau,
            server_lr_callback,
            callbacks.update_reduce_lr_on_plateau,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        iterative_process2 = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback2,
            callbacks.update_reduce_lr_on_plateau,
            server_lr_callback,
            callbacks.update_reduce_lr_on_plateau,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state1, train_outputs1 = self._run_rounds(iterative_process1, 5)
        state2, train_outputs2 = self._run_rounds(iterative_process2, 5)

        # Both the final model weights and the per-round outputs should
        # agree to within 1e-4.
        self.assertAllClose(state1.model.trainable, state2.model.trainable,
                            1e-4)
        self.assertAllClose(train_outputs1, train_outputs2, 1e-4)
    def test_iterative_process_type_signature(self):
        """Checks `initialize`/`next` type signatures of a flag-built process."""
        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        # Build a throwaway callback with the same flags just to derive the
        # TFF type of the callback state.
        dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=FLAGS.client_learning_rate,
            decay_factor=FLAGS.client_decay_factor,
            min_delta=FLAGS.min_delta,
            min_lr=FLAGS.min_lr,
            window_size=FLAGS.window_size,
            patience=FLAGS.patience)
        lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

        # Expected server state: weights for a 1x1 linear model, optimizer
        # state, and the two LR callbacks, all placed at the server.
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        # This model reports MSE in addition to loss.
        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                loss=tff.TensorType(tf.float32)), tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=(server_state_type,
                                                    dataset_type),
                                         result=expected_result_type)

        # Equivalence (not equality) tolerates naming differences in the
        # parameter struct.
        actual_type = iterative_process.next.type_signature
        self.assertTrue(actual_type.is_equivalent_to(expected_type))
Example #8
0
 def test_min_lr(self):
     """The learning rate is clamped to `min_lr` and never decays below it."""
     lr_callback = callbacks.create_reduce_lr_on_plateau(learning_rate=0.1,
                                                         decay_factor=0.5,
                                                         min_lr=0.2,
                                                         minimize=False,
                                                         window_size=1,
                                                         patience=1,
                                                         cooldown=0)
     logging.info('LR Callback: %s', lr_callback)
     # learning_rate=0.1 is below min_lr=0.2, so it is clamped up to 0.2.
     self.assertEqual(lr_callback.learning_rate, 0.2)
     for round_num in range(5):
         metric = -float(round_num)
         lr_callback = lr_callback.update(metric)
         logging.info('LR Callback: %s', lr_callback)
         self.assertEqual(lr_callback.best, 0.0)
         # Already at the floor: the LR cannot decay any further.
         self.assertEqual(lr_callback.learning_rate, 0.2)
         self.assertEqual(lr_callback.wait, round_num + 1)
Example #9
0
    def test_lr_decay_after_patience_rounds(self):
        """The LR halves only once `patience` non-improving rounds elapse."""
        lr_callback = callbacks.create_reduce_lr_on_plateau(learning_rate=1.0,
                                                            decay_factor=0.5,
                                                            minimize=False,
                                                            window_size=3,
                                                            patience=5,
                                                            cooldown=0)
        logging.info('LR Callback: %s', lr_callback)
        self.assertEqual(lr_callback.metrics_window, [0.0, 0.0, 0.0])
        # Four non-improving rounds: the wait counter grows but the LR is
        # untouched because patience (5) has not yet been exhausted.
        for round_num in range(4):
            lr_callback = lr_callback.update(-1.0)
            logging.info('LR Callback: %s', lr_callback)
            self.assertEqual(lr_callback.best, 0.0)
            self.assertEqual(lr_callback.learning_rate, 1.0)
            self.assertEqual(lr_callback.wait, round_num + 1)

        # The fifth bad round exhausts patience: the LR halves and the wait
        # counter resets.
        lr_callback = lr_callback.update(-1.0)
        logging.info('LR Callback: %s', lr_callback)
        self.assertEqual(lr_callback.best, 0.0)
        self.assertEqual(lr_callback.learning_rate, 0.5)
        self.assertEqual(lr_callback.wait, 0)
Example #10
0
    def test_window_with_inf_values(self):
        """A window initialized to inf decays until finite metrics fill it.

        With minimize=True the window starts at infinity, so each early
        update leaves `best` at inf and (per the assertions) halves the LR;
        once the window holds only finite values, `best` becomes the window
        mean ([3.0, 3.0, 6.0] -> 4.0).
        """
        lr_callback = callbacks.create_reduce_lr_on_plateau(learning_rate=1.0,
                                                            decay_factor=0.5,
                                                            minimize=True,
                                                            window_size=3,
                                                            patience=1,
                                                            cooldown=0)
        logging.info('LR Callback: %s', lr_callback)
        # Use np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.assertEqual(lr_callback.metrics_window,
                         [np.inf for _ in range(3)])
        for i in range(2):
            lr_callback = lr_callback.update(3.0)
            logging.info('LR Callback: %s', lr_callback)
            self.assertEqual(lr_callback.best, np.inf)
            self.assertEqual(lr_callback.learning_rate, (0.5)**(i + 1))
            self.assertEqual(lr_callback.wait, 0)

        # Third update: the window is now fully finite, so `best` is 4.0.
        lr_callback = lr_callback.update(6.0)
        logging.info('LR Callback: %s', lr_callback)
        self.assertEqual(lr_callback.best, 4.0)
        self.assertEqual(lr_callback.learning_rate, 0.25)
        self.assertEqual(lr_callback.wait, 0)
Example #11
0
def main(argv):
    """Configures and runs a flag-selected federated training task.

    Builds client/server optimizers and reduce-on-plateau LR callbacks from
    flags, wraps them in an adaptive fed-avg iterative process builder, and
    dispatches to the task-specific `run_federated` entry point.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If any positional command-line arguments are supplied.
      ValueError: If `--task` does not name a supported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder

    if FLAGS.task == 'cifar100':
        hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
        federated_cifar100.run_federated(**shared_args,
                                         crop_size=FLAGS.cifar100_crop_size,
                                         hparam_dict=hparam_dict)

    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**shared_args,
                                       emnist_model=FLAGS.emnist_cr_model,
                                       hparam_dict=hparam_dict)

    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**shared_args,
                                          hparam_dict=hparam_dict)

    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **shared_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            hparam_dict=hparam_dict)

    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in task_flags:
            if flag_name.startswith('so_nwp_'):
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        federated_stackoverflow.run_federated(**shared_args,
                                              **so_nwp_flags,
                                              hparam_dict=hparam_dict)

    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in task_flags:
            if flag_name.startswith('so_lr_'):
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        federated_stackoverflow_lr.run_federated(**shared_args,
                                                 **so_lr_flags,
                                                 hparam_dict=hparam_dict)

    else:
        # Previously an unsupported --task value fell through silently and
        # the program exited without training; fail loudly instead.
        raise ValueError(
            '--task flag {} is not supported.'.format(FLAGS.task))
Example #12
0
    def test_build_with_preprocess_funtion(self):
        """Builds a process with a dataset preprocessing computation and
        checks that `next` accepts the *raw* (pre-preprocessing) dataset type.

        NOTE(review): the method name has a typo ('funtion'); renaming would
        change the public test identifier, so it is left as-is.
        """
        test_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.FederatedType(
            tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            """Maps raw integers to (x, y) batches for the linear model."""
            def to_batch(x):
                # y = 3x + 1, the regression target used by this test model.
                return collections.OrderedDict(x=[float(x) * 1.0],
                                               y=[float(x) * 3.0 + 1.0])

            return ds.map(to_batch).repeat().batch(2).take(3)

        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        # Both callbacks share identical settings, so one TFF type suffices.
        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        # Expected server state: weights for a 1x1 linear model, optimizer
        # state, and the two LR callbacks, all placed at the server.
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)
        expected_result_type = (server_state_type, output_type)

        # `next` takes the raw sequence-of-int dataset; preprocessing is part
        # of the computation itself.
        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
def from_flags(input_spec,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
    """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data which
      will be fed into the `tff.templates.IterativeProcess.next` function over
      the course of training. Generally, this can be found by accessing the
      `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled `tf.keras.Model`
      object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess` instance.
  """
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    # The client and server callbacks share every flag-configured setting
    # except the initial learning rate and the decay factor.
    shared_callback_kwargs = dict(min_delta=FLAGS.min_delta,
                                  min_lr=FLAGS.min_lr,
                                  window_size=FLAGS.window_size,
                                  patience=FLAGS.patience)
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        **shared_callback_kwargs)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        **shared_callback_kwargs)

    def tff_model_fn():
        """Wraps the Keras model, loss, and metrics as a `tff.learning.Model`."""
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return adaptive_fed_avg.build_fed_avg_process(
        tff_model_fn,
        client_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        server_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)
Example #14
0
def main(argv):
    """Configures and runs a flag-selected federated training task.

    Builds client/server optimizers and reduce-on-plateau LR callbacks from
    flags, wraps them in an adaptive fed-avg iterative process builder (with
    optional dataset preprocessing), and dispatches to the task-specific
    `run_federated` entry point.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If any positional command-line arguments are supplied.
      ValueError: If `--task` does not name a supported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_preprocess_comp: Optional[tff.Computation] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of values
        and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
        If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            dataset_preprocess_comp=dataset_preprocess_comp)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    # Arguments shared by every task's run_federated entry point.
    common_args = collections.OrderedDict([
        ('iterative_process_builder', iterative_process_builder),
        ('assign_weights_fn', assign_weights_fn),
        ('client_epochs_per_round', FLAGS.client_epochs_per_round),
        ('client_batch_size', FLAGS.client_batch_size),
        ('clients_per_round', FLAGS.clients_per_round),
        ('max_batches_per_client', FLAGS.max_batches_per_client),
        ('client_datasets_random_seed', FLAGS.client_datasets_random_seed)
    ])

    if FLAGS.task == 'cifar100':
        federated_cifar100.run_federated(**common_args,
                                         crop_size=FLAGS.cifar100_crop_size)

    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**common_args,
                                       emnist_model=FLAGS.emnist_cr_model)

    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**common_args)

    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **common_args, sequence_length=FLAGS.shakespeare_sequence_length)

    elif FLAGS.task == 'stackoverflow_nwp':
        # Collect task-specific flags, stripping the 'so_nwp_' prefix.
        so_nwp_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_nwp_'):
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        federated_stackoverflow.run_federated(**common_args, **so_nwp_flags)

    elif FLAGS.task == 'stackoverflow_lr':
        # Collect task-specific flags, stripping the 'so_lr_' prefix.
        so_lr_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_lr_'):
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        federated_stackoverflow_lr.run_federated(**common_args, **so_lr_flags)

    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))