def robust_aggregation_fn(value, weight):
    """Weighted federated mean refined over several re-weighting passes.

    The first pass is a plain `tff.federated_mean` weighted by `weight`. Each
    of the remaining `num_communication_passes - 1` passes does three things:
    broadcasts the current aggregate, computes new weights with
    `update_weight_fn(weight, aggregate, value)`, and recomputes the mean with
    those weights. New weights are always derived from the original `weight`,
    not from the previous pass's weights.

    `num_communication_passes` and `update_weight_fn` are read from the
    enclosing scope (not visible here).

    Args:
      value: values to aggregate (passed to `tff.federated_mean`).
      weight: initial per-value weights.

    Returns:
      The aggregate produced by the final pass.
    """
    current = tff.federated_mean(value, weight=weight)
    extra_passes = num_communication_passes - 1
    for _ in range(extra_passes):
        broadcast_aggregate = tff.federated_broadcast(current)
        reweighted = tff.federated_map(
            update_weight_fn, (weight, broadcast_aggregate, value))
        current = tff.federated_mean(value, weight=reweighted)
    return current
Exemplo n.º 2
0
 def next_comp(state, value, weight):
     """Aggregation step that increments its state and counts clients.

     Args:
       state: state passed through `_add_one` via `tff.federated_map`.
       value: values averaged with `tff.federated_mean`.
       weight: weights for the mean.

     Returns:
       An `OrderedDict` with keys `state`, `result` and `measurements`.
       `measurements` is a zipped `OrderedDict` holding `num_clients`, which
       is the federated sum of a constant 1 placed at `tff.CLIENTS`.
     """
     incremented_state = tff.federated_map(_add_one, state)
     weighted_mean = tff.federated_mean(value, weight)
     client_count = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
     zipped_measurements = tff.federated_zip(
         collections.OrderedDict(num_clients=client_count))
     return collections.OrderedDict(
         state=incremented_state,
         result=weighted_mean,
         measurements=zipped_measurements)
Exemplo n.º 3
0
 def fed_output(local_outputs):
     """Aggregate per-client outputs into round-level metrics.

     Args:
       local_outputs: a structure exposing `num_examples`,
         `num_examples_float` and `loss` attributes.

     Returns:
       A dict with `num_examples` (federated sum of `num_examples`) and
       `loss` (mean of `loss` weighted by `num_examples_float`).
     """
     total_examples = tff.federated_sum(local_outputs.num_examples)
     # TODO(b/124070381): Remove need for using num_examples_float here.
     mean_loss = tff.federated_mean(
         local_outputs.loss, weight=local_outputs.num_examples_float)
     return {'num_examples': total_examples, 'loss': mean_loss}
Exemplo n.º 4
0
 def test_fails_stateful_aggregate_and_process(self):
     """Passing both aggregation arguments must raise DisjointArgumentError.

     `federated_averaging.build_federated_averaging_process` receives both a
     `stateful_delta_aggregate_fn` and an `aggregation_process`. The test
     expects `optimizer_utils.DisjointArgumentError`.
     """
     weights_type = model_utils.weights_type_from_model(
         model_examples.LinearRegression)

     def _passthrough_mean_next(state, value, weight=None):
         # Leaves the state unchanged and returns the weighted mean.
         return state, tff.federated_mean(value, weight)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         delta_aggregate_fn = tff.utils.StatefulAggregateFn(
             initialize_fn=lambda: (), next_fn=_passthrough_mean_next)
         stateless_process = optimizer_utils.build_stateless_mean(
             model_delta_type=weights_type.trainable)
         federated_averaging.build_federated_averaging_process(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_delta_aggregate_fn=delta_aggregate_fn,
             aggregation_process=stateless_process)
Exemplo n.º 5
0
 def test_fails_stateful_aggregate_and_process(self):
     """Passing both aggregation arguments must raise DisjointArgumentError.

     `optimizer_utils.build_model_delta_optimizer_process` receives both a
     `stateful_delta_aggregate_fn` and an `aggregation_process`. The test
     expects `optimizer_utils.DisjointArgumentError`.
     """
     # Build a throwaway model in its own graph so only its weight types are
     # captured.
     with tf.Graph().as_default():
         scratch_model = model_examples.LinearRegression()
         weights_type = tff.framework.type_from_tensors(
             model_utils.ModelWeights.from_model(scratch_model))

     def _passthrough_mean_next(state, value, weight=None):
         # Leaves the state unchanged and returns the weighted mean.
         return state, tff.federated_mean(value, weight)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         delta_aggregate_fn = tff.utils.StatefulAggregateFn(
             initialize_fn=lambda: (), next_fn=_passthrough_mean_next)
         stateless_process = optimizer_utils.build_stateless_mean(
             model_delta_type=weights_type.trainable)
         optimizer_utils.build_model_delta_optimizer_process(
             model_fn=model_examples.LinearRegression,
             model_to_client_delta_fn=DummyClientDeltaFn,
             server_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_delta_aggregate_fn=delta_aggregate_fn,
             aggregation_process=stateless_process)
Exemplo n.º 6
0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Increment the state by one and return the federated mean.

    Args:
      server_state: counter state; its member type must be `tf.int32` to
        match the incrementing `tff.tf_computation`.
      client_value: values to average.
      weight: optional weights for the mean.

    Returns:
      A `(new_state, mean)` tuple.
    """
    increment = tff.tf_computation(lambda counter: counter + 1, tf.int32)
    bumped_state = tff.federated_map(increment, server_state)
    weighted_mean = tff.federated_mean(client_value, weight=weight)
    return bumped_state, weighted_mean
Exemplo n.º 7
0
def build_stateless_mean():
    """Just tff.federated_mean with empty state, to use as a default.

    Returns:
      A `tff.utils.StatefulAggregateFn`. Its state is the empty tuple, and
      each step passes the state through unchanged alongside
      `tff.federated_mean(value, weight=weight)`.
    """
    def _initialize():
        return ()

    def _next(state, value, weight=None):
        return state, tff.federated_mean(value, weight=weight)

    return tff.utils.StatefulAggregateFn(initialize_fn=_initialize,
                                         next_fn=_next)
Exemplo n.º 8
0
 def federated_train(model, learning_rate, data):
     """Run `local_train` on every client and average the results.

     Args:
       model: model broadcast to all clients.
       learning_rate: learning rate broadcast to all clients.
       data: per-client data passed to `local_train`.

     Returns:
       The unweighted `tff.federated_mean` of the `local_train` outputs.
     """
     local_train_args = [
         tff.federated_broadcast(model),
         tff.federated_broadcast(learning_rate),
         data,
     ]
     trained_models = tff.federated_map(local_train, local_train_args)
     return tff.federated_mean(trained_models)
Exemplo n.º 9
0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Advance the state with `_add_one` and return the federated mean.

    Args:
      server_state: state mapped through `_add_one`.
      client_value: values to average.
      weight: optional weights for the mean.

    Returns:
      A `(new_state, mean)` tuple.
    """
    bumped_state = tff.federated_map(_add_one, server_state)
    weighted_mean = tff.federated_mean(client_value, weight=weight)
    return bumped_state, weighted_mean
Exemplo n.º 10
0
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    The round proceeds in five steps:
      1. Broadcast `server_state.model` to the clients.
      2. Run `client_delta_tf` on each client.
      3. Average the client weight deltas, weighted by
         `weights_delta_weight` (cast to float32 if it is an integer type).
      4. Apply the averaged delta on the server via `server_update_model`.
      5. Aggregate the clients' model outputs with the dummy model's
         `federated_output_computation`.

    This closes over names from the enclosing scope that are not visible
    here: `federated_server_state_type`, `tf_dataset_type`,
    `server_state_type`, `model_fn`, `model_to_client_delta_fn`,
    `server_optimizer_fn`, `server_update_model`, `ServerState`, `g` and
    `dummy_model_for_metadata`.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    # Type of the (unplaced) model weights held in the server state.
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    # Broadcast the current model, then run local optimization per client.
    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    # Integer weights are cast to float32 before being used as the
    # federated_mean weight; other dtypes are passed through unchanged.
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    # Weighted average of the client model deltas.
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
Exemplo n.º 11
0
 def stateless_mean(state, value, weight):
   """Weighted federated mean that carries its state through untouched.

   Args:
     state: aggregation state, returned unchanged.
     value: values to average.
     weight: weights for the mean.

   Returns:
     An `OrderedDict` with `state`, `result` (the weighted mean) and
     `measurements` (an empty tuple placed at `tff.SERVER`).
   """
   no_measurements = tff.federated_value((), tff.SERVER)
   weighted_mean = tff.federated_mean(value, weight=weight)
   return collections.OrderedDict(
       state=state, result=weighted_mean, measurements=no_measurements)
Exemplo n.º 12
0
 def cast_to_float_mean(state, value, weight):
     """Weighted mean whose weights are first converted by `_cast_weight_to_float`.

     Args:
       state: aggregation state, returned unchanged.
       value: values to average.
       weight: weights mapped through `_cast_weight_to_float` before use.

     Returns:
       A `(state, mean)` tuple.
     """
     float_weight = tff.federated_map(_cast_weight_to_float, weight)
     weighted_mean = tff.federated_mean(value, weight=float_weight)
     return state, weighted_mean