Example 1
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.task == 'shakespeare' or FLAGS.task == 'stackoverflow_nwp':

            def client_weight_fn(local_outputs):
                return tf.cast(tf.squeeze(local_outputs['num_tokens']),
                               tf.float32)
        else:
            client_weight_fn = None

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)
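
This builder closes over names defined elsewhere in the enclosing script. A minimal sketch of that surrounding context, assuming the same flag-driven helpers shown in Example 4 below (these lines mirror that example and are not part of this snippet):

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')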
Example 2

    def test_execute_with_preprocess_function(self):
        test_dataset = tf.data.Dataset.range(1)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_example(x):
                del x  # Unused.
                return collections.OrderedDict(x=[3.0], y=[2.0])

            return ds.map(to_example).batch(1)

        iterproc = fed_avg_schedule.build_fed_avg_process(
            model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=0.01,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, iterproc)

        _, train_outputs, _ = self._run_rounds(iterproc, [test_dataset], 6)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
        train_gap_first_half = (
            train_outputs[0]['loss'] - train_outputs[2]['loss'])
        train_gap_second_half = (
            train_outputs[3]['loss'] - train_outputs[5]['loss'])
        self.assertLess(train_gap_second_half, train_gap_first_half)
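
Several of these tests call a `_run_rounds` helper that is not shown. A minimal sketch of what it is assumed to do (initialize the process, run `num_rounds` rounds, and return the final state, the per-round metrics, and the initial state, matching the `state, train_outputs, initial_state` unpacking used throughout):

    def _run_rounds(self, iterproc, federated_data, num_rounds):
        """Runs `num_rounds` of `iterproc`, collecting metrics per round."""
        train_outputs = []
        initial_state = iterproc.initialize()
        state = initial_state
        for _ in range(num_rounds):
            state, metrics = iterproc.next(state, federated_data)
            train_outputs.append(metrics)
        return state, train_outputs, initial_state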
Example 3
def iterative_process_builder(model_fn, client_weight_fn=None):
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=0.1,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=1.0,
        client_weight_fn=client_weight_fn)
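
A hypothetical usage of this builder (`model_fn` and `federated_train_data` are assumed here, not part of the snippet); the `next` call returns an updated server state and the round's training metrics, as seen in Example 12:

    process = iterative_process_builder(model_fn)
    state = process.initialize()
    state, metrics = process.next(state, federated_train_data)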
Example 4
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
    """Builds a `tff.templates.IterativeProcess` instance from flags.

    The iterative process is designed to incorporate learning rate schedules,
    which are configured via flags.

    Args:
      input_spec: A value convertible to a `tff.Type`, representing the data
        which will be fed into the `tff.templates.IterativeProcess.next`
        function over the course of training. Generally, this can be found by
        accessing the `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
        object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`,
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
    # model.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    model_input_spec = input_spec

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=model_input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)
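
A hypothetical call site for `from_flags`, assuming the client/server optimizer and learning-rate flags consumed by `optimizer_utils` have already been defined and parsed, and that `model_builder` returns an uncompiled `tf.keras.Model` as documented above:

    train_data = create_dataset()  # Any client `tf.data.Dataset`; assumed.
    iterative_process = from_flags(
        input_spec=train_data.element_spec,
        model_builder=model_builder,
        loss_builder=tf.keras.losses.MeanSquaredError,
        metrics_builder=lambda: [tf.keras.metrics.MeanAbsoluteError()])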
Example 5

    def test_fed_avg_without_schedule_decreases_loss(self):
        federated_data = [[_batch_fn()]]

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 5)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
Example 6

    def test_server_update_with_nan_data_is_noop(self):
        federated_data = [[_batch_fn(has_nan=True)]]

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, _, initial_state = self._run_rounds(iterproc, federated_data, 1)
        self.assertAllClose(state.model.trainable,
                            initial_state.model.trainable, 1e-8)
        self.assertAllClose(state.model.non_trainable,
                            initial_state.model.non_trainable, 1e-8)
Example 7

    def test_fed_avg_with_custom_client_weight_fn(self):
        federated_data = [[_batch_fn()]]

        def client_weight_fn(local_outputs):
            return 1.0 / (1.0 + local_outputs['loss'][-1])

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            client_weight_fn=client_weight_fn)

        _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 5)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
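
The tests above rely on `_batch_fn` and `_uncompiled_model_builder` helpers that are not shown. A plausible sketch, assuming the 784-feature MNIST-style setup implied by Example 9 (the exact definitions live in the original test module):

    _Batch = collections.namedtuple('Batch', ['x', 'y'])

    def _batch_fn(has_nan=False):
        # One batch of a single 784-feature example with an integer label.
        batch = _Batch(x=np.ones([1, 784], dtype=np.float32),
                       y=np.ones([1, 1], dtype=np.int64))
        if has_nan:
            batch[0][0, 0] = np.nan
        return batch

    def _uncompiled_model_builder():
        # Wraps an uncompiled Keras model as a `tff.learning.Model`.
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _Batch(x=tf.TensorSpec([None, 784], tf.float32),
                            y=tf.TensorSpec([None, 1], tf.int64))
        return tff.learning.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy())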
Example 8
  def test_fed_avg_with_client_and_server_schedules(self):
    federated_data = [[_batch_fn()]]

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=lambda x: 0.1 / (x + 1)**2,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=lambda x: 1.0 / (x + 1)**2)

    _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 6)
    self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
    train_gap_first_half = train_outputs[0]['loss'] - train_outputs[2]['loss']
    train_gap_second_half = train_outputs[3]['loss'] - train_outputs[5]['loss']
    self.assertLess(train_gap_second_half, train_gap_first_half)
Example 9
  def test_build_with_preprocess_function(self):
    test_dataset = tf.data.Dataset.range(5)
    client_datasets_type = tff.type_at_clients(
        tff.SequenceType(test_dataset.element_spec))

    @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
    def preprocess_dataset(ds):

      def to_batch(x):
        return _Batch(
            tf.fill(dims=(784,), value=float(x) * 2.0),
            tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

      return ds.map(to_batch).batch(2)

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
        preprocess_dataset, iterproc)

    with tf.Graph().as_default():
      test_model_for_types = _uncompiled_model_builder()

    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    metrics_type = test_model_for_types.federated_output_computation.type_signature.result

    expected_parameter_type = collections.OrderedDict(
        server_state=server_state_type,
        federated_dataset=client_datasets_type,
    )
    expected_result_type = (server_state_type, metrics_type)

    expected_type = tff.FunctionType(
        parameter=expected_parameter_type, result=expected_result_type)
    self.assertTrue(
        iterproc.next.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(
            s=iterproc.next.type_signature, t=expected_type))
Example 10

    def test_server_update_with_inf_weight_is_noop(self):
        federated_data = [create_dataset()]
        client_weight_fn = lambda x: np.inf

        iterproc = fed_avg_schedule.build_fed_avg_process(
            model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=0.01,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            client_weight_fn=client_weight_fn)

        state, _, initial_state = self._run_rounds(iterproc, federated_data, 1)
        self.assertAllClose(state.model.trainable,
                            initial_state.model.trainable, 1e-8)
        self.assertAllClose(state.model.non_trainable,
                            initial_state.model.non_trainable, 1e-8)
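
Examples 2, 10, and 12 use `create_dataset` and `model_builder` helpers that are not shown. A hypothetical minimal pair for a toy linear regression task; note that, unlike Example 4's `model_builder`, this one must already return a `tff.learning.Model`, since it is passed directly as the `model_fn`:

    def create_dataset():
        # Three (x, y) pairs from y = 2x, batched singly.
        x = [[1.0], [2.0], [3.0]]
        y = [[2.0], [4.0], [6.0]]
        return tf.data.Dataset.from_tensor_slices((x, y)).batch(1)

    def model_builder():
        keras_model = tf.keras.Sequential(
            [tf.keras.layers.Dense(1, input_shape=(1,))])
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=create_dataset().element_spec,
            loss=tf.keras.losses.MeanSquaredError())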
Example 11
  def test_fed_avg_with_server_schedule(self):
    federated_data = [[_batch_fn()]]

    @tf.function
    def lr_schedule(x):
      return 1.0 if x < 1.5 else 0.0

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=lr_schedule)

    _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 4)
    self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
    self.assertNear(
        train_outputs[2]['loss'], train_outputs[3]['loss'], err=1e-4)
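
The assertions follow from the schedule: assuming the round counter starts at 0, rounds 0 and 1 use a server learning rate of 1.0 while rounds 2 and 3 use 0.0, leaving the model (and hence the training loss) unchanged across the last two rounds. A standalone check of that schedule, mirroring the code above:

    schedule = lambda x: 1.0 if x < 1.5 else 0.0
    assert [schedule(r) for r in range(4)] == [1.0, 1.0, 0.0, 0.0]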
Example 12

    def test_get_model_weights(self):
        federated_data = [create_dataset()]

        iterative_process = fed_avg_schedule.build_fed_avg_process(
            model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=0.01,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        state = iterative_process.initialize()

        self.assertIsInstance(iterative_process.get_model_weights(state),
                              tff.learning.ModelWeights)
        self.assertAllClose(
            state.model.trainable,
            iterative_process.get_model_weights(state).trainable)

        for _ in range(3):
            state, _ = iterative_process.next(state, federated_data)
            self.assertIsInstance(iterative_process.get_model_weights(state),
                                  tff.learning.ModelWeights)
            self.assertAllClose(
                state.model.trainable,
                iterative_process.get_model_weights(state).trainable)
Example 13
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)
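
As in Example 1, this builder closes over flag-derived optimizer constructors and learning-rate schedules (see the context sketch after that example). A hypothetical end-to-end wiring, where every name except `iterative_process_builder` is assumed:

    process = iterative_process_builder(tff_model_fn)
    state = process.initialize()
    for round_num in range(num_rounds):
        sampled_data = [client_dataset]  # One sampled client per round.
        state, metrics = process.next(state, sampled_data)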