def robust_aggregation_fn(value, weight):
    # `num_communication_passes` and `update_weight_fn` are free variables
    # captured from the enclosing scope (see the sketch below).
    aggregate = tff.federated_mean(value, weight=weight)
    for _ in range(num_communication_passes - 1):
        aggregate_at_client = tff.federated_broadcast(aggregate)
        updated_weight = tff.federated_map(
            update_weight_fn, (weight, aggregate_at_client, value))
        aggregate = tff.federated_mean(value, weight=updated_weight)
    return aggregate
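The loop above assumes two names from its enclosing scope: `num_communication_passes` (an int) and `update_weight_fn` (a client-side computation). A minimal sketch of what such a re-weighting function could look like, assuming scalar float values and an inverse-distance rule in the spirit of a smoothed Weiszfeld step; the original helper is not shown here:

import tensorflow as tf
import tensorflow_federated as tff

num_communication_passes = 3  # Assumed; not part of the original snippet.

@tff.tf_computation(tf.float32, tf.float32, tf.float32)
def update_weight_fn(weight, aggregate_at_client, value):
    # Downweight clients whose value lies far from the current aggregate,
    # clipping the distance away from zero for numerical stability.
    distance = tf.abs(value - aggregate_at_client)
    return weight / tf.maximum(distance, 1e-6)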
Example #2
def next_comp(state, value):
    return collections.OrderedDict(
        state=tff.federated_map(_add_one, state),
        result=tff.federated_broadcast(value),
        # Arbitrary metrics for testing.
        measurements=tff.federated_map(
            tff.tf_computation(
                lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0),
            value))
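`next_comp` also leans on an `_add_one` computation defined elsewhere; a stand-in equivalent to the inline computation in example #6 below, assuming `int32` server state:

_add_one = tff.tf_computation(lambda x: x + 1, tf.int32)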
Example #3
def personalization_eval(server_model_weights, federated_client_input):
  """TFF orchestration logic."""
  client_init_weights = tff.federated_broadcast(server_model_weights)
  client_final_metrics = tff.federated_map(
      _client_computation, (client_init_weights, federated_client_input))

  # WARNING: Collecting information from clients can be risky. Users must
  # make sure it is appropriate to collect those metrics from clients.
  # TODO(b/147889283): Add a link to the TFF doc once it exists.
  results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
  return results
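Here `_client_computation` and `max_num_samples` come from the surrounding builder and are not shown. A toy stand-in with the right shape, treating the model weights and client input as single floats (names and types are illustrative only):

import collections

max_num_samples = 100  # Assumed cap on sampled client metrics.

@tff.tf_computation(tf.float32, tf.float32)
def _client_computation(initial_model_weights, client_input):
    # Toy stand-in: report the squared error of the broadcast "model"
    # against the client's single float input.
    return collections.OrderedDict(
        loss=tf.square(initial_model_weights - client_input))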
Example #4
def test_fails_stateful_broadcast_and_process(self):
    model_weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)
    with self.assertRaises(optimizer_utils.DisjointArgumentError):
        federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_model_broadcast_fn=tff.utils.StatefulBroadcastFn(
                initialize_fn=lambda: (),
                next_fn=lambda state, weights:  # pylint: disable=g-long-lambda
                (state, tff.federated_broadcast(weights))),
            broadcast_process=optimizer_utils.build_stateless_broadcaster(
                model_weights_type=model_weights_type))
Example #5
def test_fails_stateful_broadcast_and_process(self):
    with tf.Graph().as_default():
        model_weights_type = tff.framework.type_from_tensors(
            model_utils.ModelWeights.from_model(
                model_examples.LinearRegression()))
    with self.assertRaises(optimizer_utils.DisjointArgumentError):
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_model_broadcast_fn=tff.utils.StatefulBroadcastFn(
                initialize_fn=lambda: (),
                next_fn=lambda state, weights:  # pylint: disable=g-long-lambda
                (state, tff.federated_broadcast(weights))),
            broadcast_process=optimizer_utils.build_stateless_broadcaster(
                model_weights_type=model_weights_type))
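Both tests above pin down the same contract: the deprecated `stateful_model_broadcast_fn` and the newer `broadcast_process` are mutually exclusive, and passing both raises `DisjointArgumentError`. For contrast, a sketch of a valid call that supplies only the process, reusing the arguments from the tests:

# Supplying exactly one of the two broadcast arguments avoids
# DisjointArgumentError. `model_weights_type` as computed in the tests.
iterative_process = federated_averaging.build_federated_averaging_process(
    model_fn=model_examples.LinearRegression,
    client_optimizer_fn=tf.keras.optimizers.SGD,
    broadcast_process=optimizer_utils.build_stateless_broadcaster(
        model_weights_type=model_weights_type))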
Example #6
def _state_incrementing_broadcast_next(server_state, server_value):
    add_one = tff.tf_computation(lambda x: x + 1, tf.int32)
    new_state = tff.federated_map(add_one, server_state)
    return (new_state, tff.federated_broadcast(server_value))
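On its own, `_state_incrementing_broadcast_next` is only a `next_fn`; to act as a broadcaster it would be paired with an initializer, following the `StatefulBroadcastFn` pattern of example #7 below. A hedged sketch (the zero initializer is an assumption, not from the original source):

counting_broadcaster = tff.utils.StatefulBroadcastFn(
    initialize_fn=lambda: 0,  # Assumed: int32 round counter starting at zero.
    next_fn=_state_incrementing_broadcast_next)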
Example #7
def build_stateless_broadcaster():
    """Just tff.federated_broadcast with empty state, to use as a default."""
    return tff.utils.StatefulBroadcastFn(
        initialize_fn=lambda: (),
        next_fn=lambda state, value: (  # pylint: disable=g-long-lambda
            state, tff.federated_broadcast(value)))
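As a usage sketch, the returned object plugs into the same builder the tests above exercise, this time via the legacy `stateful_model_broadcast_fn` argument (names reused from example #5; this mirrors, rather than reproduces, the library's own default wiring):

# Sketch: the stateless broadcaster as the default broadcast strategy.
process = optimizer_utils.build_model_delta_optimizer_process(
    model_fn=model_examples.LinearRegression,
    model_to_client_delta_fn=DummyClientDeltaFn,
    server_optimizer_fn=tf.keras.optimizers.SGD,
    stateful_model_broadcast_fn=build_stateless_broadcaster())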
Example #8
def federated_train(model, learning_rate, data):
    # `tff.federated_average` is the early-TFF name for what later became
    # `tff.federated_mean`.
    return tff.federated_average(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data
        ]))
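`local_train` is defined elsewhere (this snippet matches the early TFF overview example). A minimal stand-in with the expected (model, learning_rate, data) signature, using toy float types as an assumption; the real computation would run local SGD over the client dataset:

@tff.tf_computation(tf.float32, tf.float32, tff.SequenceType(tf.float32))
def local_train(model, learning_rate, data):
    # Toy stand-in: take one step from `model` toward the mean of the
    # client's data, scaled by the learning rate.
    total, count = data.reduce(
        (tf.constant(0.0), tf.constant(0.0)),
        lambda state, x: (state[0] + x, state[1] + 1.0))
    return model - learning_rate * (model - total / count)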
Example #9
def _state_incrementing_broadcast_next(server_state, server_value):
    new_state = tff.federated_map(_add_one, server_state)
    return (new_state, tff.federated_broadcast(server_value))
Example #10
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.data.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF.
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use the graph used to construct `model`, since it holds the variables
    # that need to be read in federated_output_computation to get the correct
    # shapes and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
Example #11
def stateless_broadcast(state, value):
  empty_metrics = tff.federated_value((), tff.SERVER)
  return collections.OrderedDict(
      state=state,
      result=tff.federated_broadcast(value),
      measurements=empty_metrics)
Example #12
def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)
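Here `client_eval` would be a client-side computation that runs the broadcast weights over the local dataset and returns a structure with a `local_outputs` field, and `model` is the server-side model providing `federated_output_computation`. A toy stand-in (names and types assumed, not from the original source):

@tff.tf_computation(tf.float32, tff.SequenceType(tf.float32))
def client_eval(model_weights, dataset):
    # Toy stand-in: accumulate a squared-error "loss" over the client data.
    loss = dataset.reduce(
        tf.constant(0.0), lambda acc, x: acc + tf.square(model_weights - x))
    return collections.OrderedDict(
        local_outputs=collections.OrderedDict(loss=loss))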