Code example #1
    def test_fed_sgd_without_decay_decreases_loss(self):
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, train_outputs = self._run_rounds(iterative_process, 5)
        self.assertLess(train_outputs[-1]['before_training']['loss'],
                        train_outputs[0]['before_training']['loss'])
        self.assertLess(train_outputs[-1]['during_training']['loss'],
                        train_outputs[0]['during_training']['loss'])
        self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
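The tests in these excerpts rely on two helpers that are not shown: `_uncompiled_model_builder` (a module-level function) and `_run_rounds` (a method on the test case). The sketch below is an assumption inferred from the shapes asserted in the type-signature test further down (a one-unit dense model over `[None, 1]` inputs, a `loss` metric, a single synthetic client); the real test module defines its own versions.

import collections

import tensorflow as tf
import tensorflow_federated as tff


def _uncompiled_model_builder():
    # Assumed helper: a one-unit linear model whose trainable weights match
    # the [1, 1] kernel and [1] bias shapes asserted in code example #4.
    keras_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(1,))])
    input_spec = collections.OrderedDict(
        x=tf.TensorSpec([None, 1], tf.float32),
        y=tf.TensorSpec([None, 1], tf.float32))
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        input_spec=input_spec,
        loss=tf.keras.losses.MeanSquaredError())

    def _run_rounds(self, iterative_process, num_rounds):
        # Assumed helper on the test case: drives `num_rounds` of training on
        # a single tiny synthetic client and collects the per-round metrics.
        client_data = [
            tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict(x=[[1.0], [2.0]],
                                        y=[[5.0], [6.0]])).batch(2)
        ]
        state = iterative_process.initialize()
        train_outputs = []
        for _ in range(num_rounds):
            state, metrics = iterative_process.next(state, client_data)
            train_outputs.append(metrics)
        return state, train_outputs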
Code example #2
    def test_comparable_to_fed_avg(self):
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        reference_iterative_process = tff.learning.build_federated_averaging_process(
            _uncompiled_model_builder,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

        _, train_outputs = self._run_rounds(iterative_process, 5)
        _, reference_train_outputs = self._run_rounds_tff_fedavg(
            reference_iterative_process, 5)

        for i in range(5):
            self.assertAllClose(train_outputs[i]['during_training']['loss'],
                                reference_train_outputs[i]['loss'], 1e-4)
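`_run_rounds_tff_fedavg` is likewise not shown. Under the same assumptions as the `_run_rounds` sketch above it would be essentially the same driver; note that indexing `reference_train_outputs[i]['loss']` assumes the stock FedAvg process reports its aggregated training loss at the top level of its metrics, which depends on the TFF version.

    def _run_rounds_tff_fedavg(self, iterative_process, num_rounds):
        # Assumed helper: same driver as `_run_rounds`, but for the process
        # built by `tff.learning.build_federated_averaging_process`.
        client_data = [
            tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict(x=[[1.0], [2.0]],
                                        y=[[5.0], [6.0]])).batch(2)
        ]
        state = iterative_process.initialize()
        train_outputs = []
        for _ in range(num_rounds):
            state, metrics = iterative_process.next(state, client_data)
            train_outputs.append(metrics)
        return state, train_outputs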
Code example #3
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_preprocess_comp: Optional[tff.Computation] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of values
        and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
        If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            dataset_preprocess_comp=dataset_preprocess_comp)
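The `client_weight_fn` hook takes the model's local outputs and returns a scalar weight tensor. A minimal, hypothetical example that weights every client equally (the default instead weights by the number of examples processed):

    def uniform_client_weight_fn(local_outputs):
        # Hypothetical example: ignore the local metrics and weight all
        # clients equally in the federated average of model deltas.
        del local_outputs  # Unused.
        return tf.constant(1.0)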
Code example #4
    def test_iterative_process_type_signature(self):
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type, federated_dataset=dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
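The assertions above imply that `adaptive_fed_avg.ServerState` is a simple container with four fields. Its real definition lives in `adaptive_fed_avg.py`; a minimal sketch, assuming an `attrs`-style class, would be:

import attr


@attr.s(eq=False, frozen=True)
class ServerState(object):
    # Sketch of the fields referenced by the type-signature test above.
    model = attr.ib()                # `tff.learning.ModelWeights`
    optimizer_state = attr.ib()      # server optimizer variables
    client_lr_callback = attr.ib()   # reduce-on-plateau callback state
    server_lr_callback = attr.ib()   # reduce-on-plateau callback state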
Code example #5
    def test_small_lr_comparable_zero_lr(self):
        client_lr_callback1 = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.0,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        client_lr_callback2 = callbacks.create_reduce_lr_on_plateau(
            learning_rate=1e-8,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process1 = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback1,
            callbacks.update_reduce_lr_on_plateau,
            server_lr_callback,
            callbacks.update_reduce_lr_on_plateau,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        iterative_process2 = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback2,
            callbacks.update_reduce_lr_on_plateau,
            server_lr_callback,
            callbacks.update_reduce_lr_on_plateau,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state1, train_outputs1 = self._run_rounds(iterative_process1, 5)
        state2, train_outputs2 = self._run_rounds(iterative_process2, 5)

        self.assertAllClose(state1.model.trainable, state2.model.trainable,
                            1e-4)
        self.assertAllClose(train_outputs1, train_outputs2, 1e-4)
Code example #6
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)
Code example #7
    def test_build_with_preprocess_function(self):
        test_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.FederatedType(
            tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_batch(x):
                return collections.OrderedDict(x=[float(x) * 1.0],
                                               y=[float(x) * 3.0 + 1.0])

            return ds.map(to_batch).repeat().batch(2).take(3)

        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)
        expected_result_type = (server_state_type, output_type)

        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Code example #8
def from_flags(input_spec,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
    """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data which
      will be fed into the `tff.templates.IterativeProcess.next` function over
      the course of training. Generally, this can be found by accessing the
      `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled `tf.keras.Model`
      object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess` instance.
  """
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    def tff_model_fn():
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return adaptive_fed_avg.build_fed_avg_process(
        tff_model_fn,
        client_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        server_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)
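A hedged usage sketch for `from_flags`: the builders below are illustrative stand-ins, and the learning-rate, decay, and optimizer flags (`client_learning_rate`, `client_decay_factor`, `server_learning_rate`, `server_decay_factor`, `min_delta`, `min_lr`, `window_size`, `patience`, plus the client/server optimizer flags) are assumed to have been defined and parsed by the surrounding training script.

def _toy_model_builder():
    # Illustrative stand-in for a task-specific Keras model builder.
    return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])


toy_input_spec = collections.OrderedDict(
    x=tf.TensorSpec([None, 1], tf.float32),
    y=tf.TensorSpec([None, 1], tf.float32))

iterative_process = from_flags(
    input_spec=toy_input_spec,
    model_builder=_toy_model_builder,
    loss_builder=tf.keras.losses.MeanSquaredError,
    metrics_builder=lambda: [tf.keras.metrics.MeanSquaredError()])

state = iterative_process.initialize()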