Example #1
0
 def fed_output(local_outputs):
     """Aggregates per-client outputs into a dict of federated metrics.

     Args:
       local_outputs: a structure with `num_examples`, `num_examples_float`,
         and `loss` members placed at clients.

     Returns:
       A dict with the federated sum of example counts and the
       example-weighted federated average of the loss.
     """
     # TODO(b/124070381): Remove need for using num_examples_float here.
     total_examples = tff.federated_sum(local_outputs.num_examples)
     mean_loss = tff.federated_average(
         local_outputs.loss, weight=local_outputs.num_examples_float)
     return {'num_examples': total_examples, 'loss': mean_loss}
Example #2
0
 def federated_train(model, learning_rate, data):
     """Runs one round of federated averaging.

     Broadcasts the model and learning rate to clients, trains locally on
     each client's data via `local_train`, and averages the resulting
     client models.

     Args:
       model: the server model to broadcast to clients.
       learning_rate: the learning rate to broadcast to clients.
       data: client-placed training data.

     Returns:
       The federated average of the locally trained models.
     """
     broadcast_model = tff.federated_broadcast(model)
     broadcast_lr = tff.federated_broadcast(learning_rate)
     client_models = tff.federated_map(
         local_train, [broadcast_model, broadcast_lr, data])
     return tff.federated_average(client_models)
Example #3
0
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        One round consists of: (1) broadcasting the server model to clients
        and computing per-client deltas, (2) taking the weighted federated
        average of those deltas, (3) applying the averaged delta to the
        server model, and (4) aggregating client metrics via the model's
        `federated_output_computation`.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

        Returns:
          A tuple of updated `tff.learning.framework.ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        # NOTE(review): relies on `federated_server_state_type`,
        # `tf_dataset_type`, `model_fn`, `server_optimizer_fn`, `g`, and
        # `dummy_model_for_metadata` from the enclosing scope (not visible
        # in this chunk).
        model_weights_type = federated_server_state_type.member.model

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_delta_tf(tf_dataset, initial_model_weights):
            """Performs client local model optimization.

            Args:
              tf_dataset: a `tf.data.Dataset` that provides training examples.
              initial_model_weights: a `model_utils.ModelWeights` containing the
                starting weights.

            Returns:
              A `ClientOutput` structure.
            """
            client_delta_fn = model_to_client_delta_fn(model_fn)

            # TODO(b/123092620): this can be removed once AnonymousTuple works with
            # tf.contrib.framework.nest, or the following behavior is moved to
            # anonymous_tuple module.
            if isinstance(initial_model_weights,
                          anonymous_tuple.AnonymousTuple):
                initial_model_weights = model_utils.ModelWeights.from_tff_value(
                    initial_model_weights)

            client_output = client_delta_fn(tf_dataset, initial_model_weights)
            return client_output

        # Run the client computation on every client, pairing each client's
        # dataset with a broadcast copy of the current server model.
        client_outputs = tff.federated_map(
            client_delta_tf,
            (federated_dataset, tff.federated_broadcast(server_state.model)))

        @tff.tf_computation(server_state_type, model_weights_type.trainable)
        def server_update_model_tf(server_state, model_delta):
            """Converts args to correct python types and calls server_update_model."""
            # We need to convert TFF types to the types server_update_model expects.
            # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
            # pretty, fold this into anonymous_tuple module or get working with
            # tf.contrib.framework.nest.
            py_typecheck.check_type(model_delta,
                                    anonymous_tuple.AnonymousTuple)
            model_delta = anonymous_tuple.to_odict(model_delta)
            py_typecheck.check_type(server_state,
                                    anonymous_tuple.AnonymousTuple)
            server_state = ServerState(
                model=model_utils.ModelWeights.from_tff_value(
                    server_state.model),
                optimizer_state=list(server_state.optimizer_state))

            return server_update_model(server_state,
                                       model_delta,
                                       model_fn=model_fn,
                                       optimizer_fn=server_optimizer_fn)

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        # Federated averaging requires a float weight, so integer client
        # weights are mapped through a cast before the weighted average.
        fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
        py_typecheck.check_type(fed_weight_type, tff.TensorType)
        if fed_weight_type.dtype.is_integer:

            @tff.tf_computation(fed_weight_type)
            def _cast_to_float(x):
                return tf.cast(x, tf.float32)

            weight_denom = tff.federated_map(
                _cast_to_float, client_outputs.weights_delta_weight)
        else:
            weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_average(client_outputs.weights_delta,
                                                  weight=weight_denom)

        # TODO(b/123408447): remove tff.federated_apply and call
        # server_update_model_tf directly once T <-> T@SERVER isomorphism is
        # supported.
        server_state = tff.federated_apply(server_update_model_tf,
                                           (server_state, round_model_delta))

        # Re-use graph used to construct `model`, since it has the variables, which
        # need to be read in federated_output_computation to get the correct shapes
        # and types for the federated aggregation.
        with g.as_default():
            aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
                client_outputs.model_output)

        # Promote the FederatedType outside the NamedTupleType
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs