def test_comparable_to_fed_avg(self):
        """Checks per-round training loss against TFF federated averaging.

        Runs five rounds of the adaptive process and five rounds of the process
        returned by `tff.learning.build_federated_averaging_process`, then
        compares each round's `during_training` loss with the reference `loss`.
        """
        num_rounds = 5

        # Client and server callbacks are created from identical settings.
        plateau_kwargs = dict(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)

        adaptive_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # NOTE(review): the server callback above starts at 0.1 while the
        # reference server optimizer uses 1.0 -- confirm this is intended.
        reference_process = tff.learning.build_federated_averaging_process(
            _uncompiled_model_builder,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

        _, train_outputs = self._run_rounds(adaptive_process, num_rounds)
        _, reference_train_outputs = self._run_rounds_tff_fedavg(
            reference_process, num_rounds)

        for round_num in range(num_rounds):
            adaptive_loss = train_outputs[round_num]['during_training']['loss']
            reference_loss = reference_train_outputs[round_num]['loss']
            self.assertAllClose(adaptive_loss, reference_loss, 1e-4)
    def test_fed_sgd_without_decay_decreases_loss(self):
        """Checks that loss drops over five rounds with a zero client rate.

        The client callback starts at learning rate 0.0 and the server callback
        at 0.1. After five rounds both reported losses must be lower than in
        the first round, and both callback learning rates must still equal
        their initial values.
        """
        # Everything except the learning rate is shared by the two callbacks.
        shared_kwargs = dict(
            min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0, **shared_kwargs)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, **shared_kwargs)

        process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, train_outputs = self._run_rounds(process, 5)

        first_round = train_outputs[0]
        last_round = train_outputs[-1]
        for phase in ('before_training', 'during_training'):
            self.assertLess(last_round[phase]['loss'],
                            first_round[phase]['loss'])

        # The learning rates stored in the final state are unchanged.
        self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
    def test_iterative_process_type_signature(self):
        """Verifies the `initialize` and `next` type signatures of the process.

        `initialize` is expected to take no parameter and return the
        server-placed state. `next` is expected to take the server state plus a
        client-placed dataset and return the state together with
        `before_training` / `during_training` loss metrics.
        """
        plateau_kwargs = dict(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        # The type derived from the client callback is reused for the server
        # callback field as well.
        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        model_weights_type = tff.learning.ModelWeights(
            trainable=(tff.TensorType(tf.float32, [1, 1]),
                       tff.TensorType(tf.float32, [1])),
            non_trainable=())
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(
                model=model_weights_type,
                optimizer_state=[tf.int64],
                client_lr_callback=lr_callback_type,
                server_lr_callback=lr_callback_type),
            tff.SERVER)

        expected_initialize_type = tff.FunctionType(
            parameter=None, result=server_state_type)
        self.assertEqual(iterative_process.initialize.type_signature,
                         expected_initialize_type)

        batch_type = collections.OrderedDict(
            x=tff.TensorType(tf.float32, [None, 1]),
            y=tff.TensorType(tf.float32, [None, 1]))
        dataset_type = tff.FederatedType(
            tff.SequenceType(batch_type), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(
            before_training=metrics_type, during_training=metrics_type)

        next_parameter_type = collections.OrderedDict(
            server_state=server_state_type, federated_dataset=dataset_type)
        expected_type = tff.FunctionType(
            parameter=next_parameter_type,
            result=(server_state_type, output_type))

        actual_type = iterative_process.next.type_signature
        mismatch_message = '{s}\n!={t}'.format(s=actual_type, t=expected_type)
        self.assertEqual(actual_type, expected_type, msg=mismatch_message)
    def test_small_lr_comparable_zero_lr(self):
        """Compares training with client rate 0.0 against client rate 1e-8.

        Two processes are built that differ only in the client callback's
        initial learning rate and share one server callback. After five rounds
        their trainable weights and their per-round outputs must be close.
        """
        shared_kwargs = dict(
            min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)

        zero_lr_client_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0, **shared_kwargs)
        small_lr_client_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=1e-8, **shared_kwargs)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, **shared_kwargs)

        def build_process(client_lr_callback):
            # Only the client callback varies between the two processes.
            return adaptive_fed_avg.build_fed_avg_process(
                _uncompiled_model_builder,
                client_lr_callback,
                server_lr_callback,
                client_optimizer_fn=tf.keras.optimizers.SGD,
                server_optimizer_fn=tf.keras.optimizers.SGD)

        zero_lr_process = build_process(zero_lr_client_callback)
        small_lr_process = build_process(small_lr_client_callback)

        zero_lr_state, zero_lr_outputs = self._run_rounds(zero_lr_process, 5)
        small_lr_state, small_lr_outputs = self._run_rounds(
            small_lr_process, 5)

        self.assertAllClose(zero_lr_state.model.trainable,
                            small_lr_state.model.trainable, 1e-4)
        self.assertAllClose(zero_lr_outputs, small_lr_outputs, 1e-4)
示例#5
0
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Builds the adaptive federated averaging process for `model_fn`.

        The learning rate callbacks and optimizer functions are not
        parameters; they are taken from the surrounding scope.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        optimizer_fns = dict(
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn)
        process = adaptive_fed_avg.build_fed_avg_process(
            model_fn, client_lr_callback, server_lr_callback, **optimizer_fns)
        return process
    def test_get_model_weights(self):
        """Checks `get_model_weights` on the initial and the trained state.

        For the state returned by `initialize()` and again for the state after
        five rounds, the extracted value must be a `tff.learning.ModelWeights`
        whose trainable weights are close to `state.model.trainable`.
        """
        plateau_kwargs = dict(
            learning_rate=0.1,
            window_size=1,
            patience=1,
            decay_factor=1.0,
            cooldown=0)
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            **plateau_kwargs)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        def assert_weights_match(server_state):
            # Same pair of checks for both the initial and the trained state.
            self.assertIsInstance(
                iterative_process.get_model_weights(server_state),
                tff.learning.ModelWeights)
            extracted = iterative_process.get_model_weights(server_state)
            self.assertAllClose(server_state.model.trainable,
                                extracted.trainable)

        initial_state = iterative_process.initialize()
        assert_weights_match(initial_state)

        trained_state, _ = self._run_rounds(iterative_process, 5)
        assert_weights_match(trained_state)
示例#7
0
  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Builds the adaptive federated averaging process for `model_fn`.

    The learning rate callbacks and optimizer functions are not parameters;
    they are taken from the surrounding scope.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    builder_kwargs = dict(
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)
    process = adaptive_fed_avg.build_fed_avg_process(
        model_fn, client_lr_callback, server_lr_callback, **builder_kwargs)
    return process