예제 #1
0
    def test_fed_sgd_without_decay_decreases_loss(self):
        """Zero client LR with a fixed server LR should still reduce loss.

        decay_factor=1.0 disables decay, so both learning rates must be
        unchanged after training.
        """
        # Settings shared by both plateau callbacks.
        plateau_kwargs = dict(
            min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
        client_cb = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0, **plateau_kwargs)
        server_cb = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, **plateau_kwargs)

        process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_cb,
            server_cb,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, outputs = self._run_rounds(process, 5)
        # Loss must drop between the first and last round in both phases.
        for phase in ('before_training', 'during_training'):
            self.assertLess(outputs[-1][phase]['loss'],
                            outputs[0][phase]['loss'])
        # decay_factor=1.0 means neither learning rate changes.
        self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
예제 #2
0
    def test_comparable_to_fed_avg(self):
        """With decay disabled, the process matches vanilla TFF FedAvg losses."""
        # Identical client/server plateau callbacks; decay_factor=1.0 keeps
        # the learning rates fixed so the comparison is apples-to-apples.
        plateau_kwargs = dict(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        client_cb = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)
        server_cb = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)

        adaptive_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_cb,
            server_cb,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # Reference: standard FedAvg with the same client LR and server LR 1.0.
        reference_process = tff.learning.build_federated_averaging_process(
            _uncompiled_model_builder,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

        _, adaptive_outputs = self._run_rounds(adaptive_process, 5)
        _, reference_outputs = self._run_rounds_tff_fedavg(
            reference_process, 5)

        # Per-round training losses must agree to within 1e-4.
        for round_num in range(5):
            self.assertAllClose(
                adaptive_outputs[round_num]['during_training']['loss'],
                reference_outputs[round_num]['loss'], 1e-4)
예제 #3
0
    def test_iterative_process_type_signature(self):
        """Checks the TFF type signatures of `initialize` and `next`.

        Builds the expected server-state, dataset, and metrics types by hand
        and compares them structurally against the iterative process.
        """
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # Derive the callback's TFF type from a concrete instance; both
        # callbacks were built identically, so one type serves for both.
        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        # Expected server state: a [1, 1] kernel and [1] bias (the model is
        # presumably a single-unit dense layer — matches the dataset type
        # below), optimizer state, and the two LR callbacks.
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        # Clients hold sequences of scalar-feature (x) / scalar-label (y)
        # batches with unknown batch dimension.
        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        # `next` reports a loss metric both before and during training.
        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type, federated_dataset=dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        # Include both signatures in the failure message for easier diffing.
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
예제 #4
0
 def test_raises_bad_decay_factor(self):
     """Decay factors outside the valid range are rejected."""
     # A factor above 1 would grow the LR on plateau instead of shrinking it.
     with self.assertRaises(ValueError):
         callbacks.create_reduce_lr_on_plateau(
             learning_rate=0.1, decay_factor=2.0, cooldown=0)
     # Negative factors are likewise invalid.
     with self.assertRaises(ValueError):
         callbacks.create_reduce_lr_on_plateau(
             learning_rate=0.1, decay_factor=-1.0)
예제 #5
0
    def test_small_lr_comparable_zero_lr(self):
        """A tiny (1e-8) client LR should train essentially like a zero LR."""
        plateau_kwargs = dict(
            min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
        zero_lr_cb = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0, **plateau_kwargs)
        tiny_lr_cb = callbacks.create_reduce_lr_on_plateau(
            learning_rate=1e-8, **plateau_kwargs)

        server_cb = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, **plateau_kwargs)

        def build_process(client_cb):
            # Both processes share the server callback and SGD optimizers;
            # only the client callback differs.
            return adaptive_fed_avg.build_fed_avg_process(
                _uncompiled_model_builder,
                client_cb,
                server_cb,
                client_optimizer_fn=tf.keras.optimizers.SGD,
                server_optimizer_fn=tf.keras.optimizers.SGD)

        zero_process = build_process(zero_lr_cb)
        tiny_process = build_process(tiny_lr_cb)

        state1, outputs1 = self._run_rounds(zero_process, 5)
        state2, outputs2 = self._run_rounds(tiny_process, 5)

        # Final weights and all per-round metrics must match to 1e-4.
        self.assertAllClose(state1.model.trainable, state2.model.trainable,
                            1e-4)
        self.assertAllClose(outputs1, outputs2, 1e-4)
예제 #6
0
    def test_get_model_weights(self):
        """`get_model_weights` mirrors `state.model` before and after training."""
        plateau_kwargs = dict(
            learning_rate=0.1,
            window_size=1,
            patience=1,
            decay_factor=1.0,
            cooldown=0)
        client_cb = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)
        server_cb = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)

        process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_cb,
            server_cb,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        def check_weights(state):
            # The accessor must return tff ModelWeights whose trainable
            # variables equal those stored in the server state.
            self.assertIsInstance(process.get_model_weights(state),
                                  tff.learning.ModelWeights)
            self.assertAllClose(
                state.model.trainable,
                process.get_model_weights(state).trainable)

        # Holds for the freshly initialized state...
        check_weights(process.initialize())

        # ...and after several rounds of training.
        trained_state, _ = self._run_rounds(process, 5)
        check_weights(trained_state)
예제 #7
0
 def test_min_lr(self):
     """The LR is clamped to min_lr at creation and never decays below it."""
     lr_callback = callbacks.create_reduce_lr_on_plateau(
         learning_rate=0.1,
         decay_factor=0.5,
         min_lr=0.2,
         minimize=False,
         window_size=1,
         patience=1,
         cooldown=0)
     logging.info('LR Callback: %s', lr_callback)
     # learning_rate=0.1 is below min_lr=0.2, so it is raised to min_lr.
     self.assertEqual(lr_callback.learning_rate, 0.2)
     for step in range(5):
         # Ever-worsening metrics (maximizing, values go negative): the wait
         # counter grows but the LR stays pinned at min_lr.
         lr_callback = lr_callback.update(-float(step))
         logging.info('LR Callback: %s', lr_callback)
         self.assertEqual(lr_callback.best, 0.0)
         self.assertEqual(lr_callback.learning_rate, 0.2)
         self.assertEqual(lr_callback.wait, step + 1)
예제 #8
0
    def test_lr_decay_after_patience_rounds(self):
        """The LR halves only after `patience` consecutive bad rounds."""
        lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=1.0,
            decay_factor=0.5,
            minimize=False,
            window_size=3,
            patience=5,
            cooldown=0)
        logging.info('LR Callback: %s', lr_callback)
        # Maximizing: the window is seeded with zeros.
        self.assertEqual(lr_callback.metrics_window, [0.0, 0.0, 0.0])

        # Four non-improving updates: the wait counter grows, LR untouched.
        for round_num in range(4):
            lr_callback = lr_callback.update(-1.0)
            logging.info('LR Callback: %s', lr_callback)
            self.assertEqual(lr_callback.best, 0.0)
            self.assertEqual(lr_callback.learning_rate, 1.0)
            self.assertEqual(lr_callback.wait, round_num + 1)

        # The fifth bad update exhausts patience: LR decays, wait resets.
        lr_callback = lr_callback.update(-1.0)
        logging.info('LR Callback: %s', lr_callback)
        self.assertEqual(lr_callback.best, 0.0)
        self.assertEqual(lr_callback.learning_rate, 0.5)
        self.assertEqual(lr_callback.wait, 0)
예제 #9
0
    def test_window_with_inf_values(self):
        """The metrics window starts at +inf; LR decays until inf is flushed.

        Fix: `np.Inf` was a deprecated alias removed in NumPy 2.0; use the
        canonical `np.inf` so the test runs on current NumPy.
        """
        lr_callback = callbacks.create_reduce_lr_on_plateau(learning_rate=1.0,
                                                            decay_factor=0.5,
                                                            minimize=True,
                                                            window_size=3,
                                                            patience=1,
                                                            cooldown=0)
        logging.info('LR Callback: %s', lr_callback)
        # Minimizing: the window is seeded with +inf sentinels.
        self.assertEqual(lr_callback.metrics_window,
                         [np.inf for _ in range(3)])
        for i in range(2):
            lr_callback = lr_callback.update(3.0)
            logging.info('LR Callback: %s', lr_callback)
            # While an inf remains in the window the aggregate never
            # improves, so the LR halves on every update (patience=1).
            self.assertEqual(lr_callback.best, np.inf)
            self.assertEqual(lr_callback.learning_rate, (0.5)**(i + 1))
            self.assertEqual(lr_callback.wait, 0)

        # Third update flushes the last inf; window is [3, 3, 6] -> best 4.0.
        lr_callback = lr_callback.update(6.0)
        logging.info('LR Callback: %s', lr_callback)
        self.assertEqual(lr_callback.best, 4.0)
        self.assertEqual(lr_callback.learning_rate, 0.25)
        self.assertEqual(lr_callback.wait, 0)
예제 #10
0
 def test_cooldown(self):
     """No decay happens until the cooldown counter has run down to zero."""
     lr_callback = callbacks.create_reduce_lr_on_plateau(
         learning_rate=2.0,
         decay_factor=0.5,
         minimize=False,
         window_size=1,
         patience=0,
         cooldown=3)
     logging.info('LR Callback: %s', lr_callback)
     self.assertEqual(lr_callback.learning_rate, 2.0)
     self.assertEqual(lr_callback.cooldown, 3)
     self.assertEqual(lr_callback.cooldown_counter, 3)

     # Two bad updates inside the cooldown window: the counter ticks down
     # while the LR and wait counter stay put.
     for step in range(2):
         lr_callback = lr_callback.update(-1.0)
         logging.info('LR Callback: %s', lr_callback)
         self.assertEqual(lr_callback.learning_rate, 2.0)
         self.assertEqual(lr_callback.wait, 0)
         self.assertEqual(lr_callback.cooldown, 3)
         self.assertEqual(lr_callback.cooldown_counter, 2 - step)

     # Cooldown expired: the next bad update halves the LR (patience=0) and
     # restarts the cooldown counter.
     lr_callback = lr_callback.update(-1.0)
     logging.info('LR Callback: %s', lr_callback)
     self.assertEqual(lr_callback.learning_rate, 1.0)
     self.assertEqual(lr_callback.wait, 0)
     self.assertEqual(lr_callback.cooldown, 3)
     self.assertEqual(lr_callback.cooldown_counter, 3)
예제 #11
0
def main(argv):
    """Runs adaptive-FedAvg federated training for the task in --task.

    Builds client/server ReduceLROnPlateau callbacks from flags, wires them
    into an adaptive FedAvg iterative process, configures the task-specific
    runner spec, and drives the training loop.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If positional command-line arguments are supplied.
      ValueError: If --task is not one of the supported tasks.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Optimizer constructors are parsed from the 'client'/'server' flag
    # prefixes — presumably flags like --client_optimizer; confirm against
    # optimizer_utils.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    # Separate callbacks so client and server LRs decay independently, though
    # they share min_delta / min_lr / window_size / patience settings.
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an adaptive FedAvg iterative process for a TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Dispatch to the task-specific configuration; each branch forwards only
    # the flags relevant to that task and returns a runner spec bundling the
    # iterative process with dataset and evaluation functions.
    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec, crop_size=FLAGS.cifar100_crop_size)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
예제 #12
0
def main(argv):
  """Runs adaptive-FedAvg federated training for the task in --task.

  Builds client/server ReduceLROnPlateau callbacks from flags, constructs an
  adaptive FedAvg iterative-process builder, then dispatches to the
  task-specific `run_federated` entry point with the shared flag values and
  an hparam dict for logging.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If positional command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Optimizer constructors are parsed from the 'client'/'server' flag
  # prefixes — presumably flags like --client_optimizer; confirm against
  # optimizer_utils.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  # Separate callbacks so client and server LRs decay independently, though
  # they share min_delta / min_lr / window_size / patience settings.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  # Flag values captured purely for experiment logging/bookkeeping.
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  # Task dispatch: each branch forwards only the flags relevant to that task.
  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    # Collect the so_nwp_* task flags, stripping the 'so_nwp_' prefix
    # (7 chars) to recover the keyword-argument names.
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    # Same pattern for so_lr_* flags ('so_lr_' prefix is 6 chars).
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)