Example #1
  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
      dataset_preprocess_comp: Optional[tff.Computation] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of values
        and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
        If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn,
        dataset_preprocess_comp=dataset_preprocess_comp)
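
A hedged usage sketch for the builder above. Note that `client_optimizer_fn`, `client_lr_schedule`, `server_optimizer_fn`, and `server_lr_schedule` are captured from the enclosing scope (typically created from flags, as in the `from_flags` examples below); the `my_model_fn` helper here is hypothetical.

import collections
import tensorflow as tf
import tensorflow_federated as tff

def my_model_fn() -> tff.learning.Model:
  # Hypothetical no-arg model function built from a plain Keras model.
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(10, input_shape=(784,))])
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec([None, 784], tf.float32),
          y=tf.TensorSpec([None, 1], tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy())

process = iterative_process_builder(my_model_fn)
state = process.initialize()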
Example #2
    def test_execute_with_preprocess_function(self):
        test_dataset = tf.data.Dataset.range(1)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_example(x):
                del x  # Unused.
                return _Batch(x=np.ones([784], dtype=np.float32),
                              y=np.ones([1], dtype=np.int64))

            return ds.map(to_example).batch(1)

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        _, train_outputs, _ = self._run_rounds(iterproc, [test_dataset], 6)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
        train_gap_first_half = (
            train_outputs[0]['loss'] - train_outputs[2]['loss'])
        train_gap_second_half = (
            train_outputs[3]['loss'] - train_outputs[5]['loss'])
        self.assertLess(train_gap_second_half, train_gap_first_half)
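
The `_Batch` and `_batch_fn` fixtures used throughout these tests are not shown in the snippets. A plausible reconstruction, assuming the MNIST-style shapes (784 features, scalar int64 labels) visible above; the real definitions may differ:

import collections
import numpy as np

# Plausible reconstruction of the shared test fixtures.
_Batch = collections.namedtuple('Batch', ['x', 'y'])

def _batch_fn(has_nan=False):
  batch = _Batch(
      x=np.ones([1, 784], dtype=np.float32),
      y=np.ones([1, 1], dtype=np.int64))
  if has_nan:
    # Poison a single feature value to exercise the NaN-handling tests.
    batch.x[0, 0] = np.nan
  return batch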
Example #3
    def test_build_evaluate_fn(self):

        loss_builder = tf.keras.losses.MeanSquaredError
        metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

        def tff_model_fn():
            return tff.learning.from_keras_model(
                keras_model=model_builder(),
                dummy_batch=get_sample_batch(),
                loss=loss_builder(),
                metrics=metrics_builder())

        iterative_process = fed_avg_schedule.build_fed_avg_process(
            tff_model_fn, client_optimizer_fn=tf.keras.optimizers.SGD)

        state = iterative_process.initialize()

        test_dataset = create_tf_dataset_for_client(1)

        evaluate_fn = training_utils.build_evaluate_fn(test_dataset,
                                                       model_builder,
                                                       loss_builder,
                                                       metrics_builder)
        test_metrics = evaluate_fn(state)
        self.assertIn('loss', test_metrics)
Example #4
def iterative_process_builder(model_fn, client_weight_fn=None):
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=0.1,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=1.0,
        client_weight_fn=client_weight_fn)
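
A sketch of driving the returned process by hand, following the `initialize`/`next` pattern that the test examples exercise via `_run_rounds`; `model_fn` and `federated_data` are placeholders:

process = iterative_process_builder(model_fn)
state = process.initialize()
for round_num in range(5):
    # `next` consumes the current server state plus a list of client datasets
    # and returns the updated state and that round's training metrics.
    state, metrics = process.next(state, federated_data)
    print(round_num, metrics['loss'])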
Example #5
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
    """Builds a `tff.templates.IterativeProcess` instance from flags.

    The iterative process is designed to incorporate learning rate schedules,
    which are configured via flags.

    Args:
      input_spec: A value convertible to a `tff.Type`, representing the data
        which will be fed into the `tff.templates.IterativeProcess.next`
        function over the course of training. Generally, this can be found by
        accessing the `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
        object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`,
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
    # model.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    model_input_spec = input_spec

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=model_input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)
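
A usage sketch for `from_flags`. It assumes the companion flag-definition helpers in `optimizer_utils` (the `define_*` names below are assumptions) were invoked at import time, so that the `create_*_from_flags` calls inside `from_flags` can read the resulting flags:

# Assumed flag-definition helpers; exact names may differ.
optimizer_utils.define_optimizer_flags('client')
optimizer_utils.define_optimizer_flags('server')
optimizer_utils.define_lr_schedule_flags('client')
optimizer_utils.define_lr_schedule_flags('server')

def model_builder():
  return tf.keras.Sequential(
      [tf.keras.layers.Dense(10, input_shape=(784,))])

example_dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(
        x=np.zeros([10, 784], dtype=np.float32),
        y=np.zeros([10, 1], dtype=np.int64))).batch(2)

process = from_flags(
    input_spec=example_dataset.element_spec,
    model_builder=model_builder,
    loss_builder=tf.keras.losses.SparseCategoricalCrossentropy,
    metrics_builder=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])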
Example #6
    def test_server_update_with_nan_data_is_noop(self):
        federated_data = [[_batch_fn(has_nan=True)]]

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        state, _, initial_state = self._run_rounds(iterproc, federated_data, 1)
        self.assertAllClose(state.model, initial_state.model, 1e-8)
Example #7
    def test_fed_avg_without_schedule_decreases_loss(self):
        federated_data = [[_batch_fn()]]

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 5)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
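
The `_run_rounds` helper these tests call is not shown. A plausible reconstruction, consistent with how its three return values (final state, per-round metrics, initial state) are unpacked above:

    def _run_rounds(self, iterproc, federated_data, num_rounds):
      train_outputs = []
      initial_state = iterproc.initialize()
      state = initial_state
      for _ in range(num_rounds):
        # Each round consumes the server state and the client datasets and
        # yields the updated state plus that round's training metrics.
        state, metrics = iterproc.next(state, federated_data)
        train_outputs.append(metrics)
      return state, train_outputs, initial_state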
Example #8
def from_flags(dummy_batch,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
    """Builds a `tff.utils.IterativeProcess` instance from flags.

    The iterative process is designed to incorporate learning rate schedules,
    which are configured via flags.

    Args:
      dummy_batch: A nested structure of values that are convertible to
        batched tensors with the same shapes and types as expected in the
        forward pass of training. The actual values are not important and can
        hold any reasonable value.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
        object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`,
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.utils.IterativeProcess` instance.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
    # model.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def tff_model_fn():
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             dummy_batch=dummy_batch,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)
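
Per the docstring, `dummy_batch` only needs the right shapes and dtypes; its values are ignored. A minimal sketch matching the 784-feature, int64-label structure used elsewhere in these examples:

import collections
import numpy as np

# Values are irrelevant; only shapes and dtypes inform model tracing.
dummy_batch = collections.OrderedDict(
    x=np.zeros([1, 784], dtype=np.float32),
    y=np.zeros([1, 1], dtype=np.int64))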
Example #9
  def test_conversion_from_tff_result(self):
    federated_data = [[_batch_fn()]]

    iterative_process = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    state, _, _ = self._run_rounds(iterative_process, federated_data, 1)
    converted_state = fed_avg_schedule.ServerState.from_tff_result(state)
    self.assertIsInstance(converted_state, fed_avg_schedule.ServerState)
    self.assertIsInstance(converted_state.model, fed_avg_schedule.ModelWeights)
Example #10
    def test_server_update_with_inf_weight_is_noop(self):
        federated_data = [[_batch_fn()]]
        client_weight_fn = lambda x: np.inf

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            client_weight_fn=client_weight_fn)

        state, _, initial_state = self._run_rounds(iterproc, federated_data, 1)
        self.assertAllClose(state.model, initial_state.model, 1e-8)
Example #11
    def test_fed_avg_with_custom_client_weight_fn(self):
        federated_data = [[_batch_fn()]]

        def client_weight_fn(local_outputs):
            return 1.0 / (1.0 + local_outputs['loss'][-1])

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            client_weight_fn=client_weight_fn)

        _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 5)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
Example #12
  def test_fed_avg_with_client_and_server_schedules(self):
    federated_data = [[_batch_fn()]]

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=lambda x: 0.1 / (x + 1)**2,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=lambda x: 1.0 / (x + 1)**2)

    _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 6)
    self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
    train_gap_first_half = train_outputs[0]['loss'] - train_outputs[2]['loss']
    train_gap_second_half = train_outputs[3]['loss'] - train_outputs[5]['loss']
    self.assertLess(train_gap_second_half, train_gap_first_half)
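
Taken together, Examples #4 and #12 show that `client_lr` and `server_lr` accept either a constant or a callable mapping the round number to a learning rate. Both forms side by side:

    # Constant learning rates (as in Example #4):
    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=0.1,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=1.0)

    # Round-indexed schedules (as in Example #12):
    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        client_lr=lambda round_num: 0.1 / (round_num + 1)**2,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=lambda round_num: 1.0 / (round_num + 1)**2)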
Example #13
  def test_build_with_preprocess_function(self):
    test_dataset = tf.data.Dataset.range(5)
    client_datasets_type = tff.FederatedType(
        tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
    def preprocess_dataset(ds):

      def to_batch(x):
        return _Batch(
            tf.fill(dims=(784,), value=tf.cast(x, tf.float32) * 2.0),
            tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

      return ds.map(to_batch).batch(2)

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
        preprocess_dataset, iterproc)

    with tf.Graph().as_default():
      test_model_for_types = _uncompiled_model_builder()

    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    metrics_type = test_model_for_types.federated_output_computation.type_signature.result

    expected_parameter_type = collections.OrderedDict(
        server_state=server_state_type,
        federated_dataset=client_datasets_type,
    )
    expected_result_type = (server_state_type, metrics_type)

    expected_type = tff.FunctionType(
        parameter=expected_parameter_type, result=expected_result_type)
    self.assertTrue(
        iterproc.next.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(
            s=iterproc.next.type_signature, t=expected_type))
Example #14
  def test_fed_avg_with_server_schedule(self):
    federated_data = [[_batch_fn()]]

    @tf.function
    def lr_schedule(x):
      return 1.0 if x < 1.5 else 0.0

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        server_lr=lr_schedule)

    _, train_outputs, _ = self._run_rounds(iterproc, federated_data, 4)
    self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
    self.assertNear(
        train_outputs[2]['loss'], train_outputs[3]['loss'], err=1e-4)
Example #15
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)
Example #16
def _build_federated_averaging_process():
    return fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_fn,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
Example #17
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
    *,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data which
      will be fed into the `tff.templates.IterativeProcess.next` function over
      the course of training. Generally, this can be found by accessing the
      `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled `tf.keras.Model`
      object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied. If specified,
      `input_spec` is optional, as the necessary type signatures will be taken
      from the computation.

  Returns:
    A `fed_avg_schedule.FederatedAveragingProcessAdapter`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  if dataset_preprocess_comp is not None:
    if input_spec is not None:
      print('Specified both `dataset_preprocess_comp` and `input_spec` when '
            'only one is necessary. Ignoring `input_spec` and using type '
            'signature of `dataset_preprocess_comp`.')
    model_input_spec = dataset_preprocess_comp.type_signature.result.element
  else:
    model_input_spec = input_spec

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
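
A usage sketch for the preprocessing path: per the docstring, when `dataset_preprocess_comp` is supplied, `input_spec` may be `None`, since the model's input spec is read from the computation's result type. The builders passed in are placeholders, as in Example #5:

raw_dataset = tf.data.Dataset.range(5)

@tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
def preprocess_dataset(ds):
  # Map raw integers to batched (x, y) examples.
  def to_batch(x):
    return collections.OrderedDict(
        x=tf.fill(dims=(784,), value=tf.cast(x, tf.float32)),
        y=tf.expand_dims(tf.cast(x, tf.int64), axis=0))
  return ds.map(to_batch).batch(2)

process = from_flags(
    input_spec=None,
    model_builder=model_builder,  # placeholder builders, as in Example #5
    loss_builder=loss_builder,
    metrics_builder=metrics_builder,
    dataset_preprocess_comp=preprocess_dataset)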