Ejemplo n.º 1
0
  def one_round_computation(server_state, federated_dataset):
    """Runs a single round of federated optimization.

    The round broadcasts the server model via `broadcast_process`, maps
    `_compute_local_training_and_client_delta` over the clients, aggregates
    the client weight deltas via `aggregation_process`, applies
    `server_update`, and finally zips the new state and the measurements.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple `(new_server_state, measurements)`. `measurements` is a zipped
      `collections.OrderedDict` with the keys `broadcast`, `aggregation` and
      `train`; the `train` entry is the result of
      `tff.learning.Model.federated_output_computation`.
    """
    broadcast = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)

    training_inputs = (federated_dataset, broadcast.result)
    per_client = tff.federated_map(
        _compute_local_training_and_client_delta, training_inputs)

    aggregation = aggregation_process.next(
        server_state.delta_aggregate_state,
        per_client.weights_delta,
        per_client.weights_delta_weight)

    server_update_args = (
        server_state.model, aggregation.result, server_state.optimizer_state)
    updated_model, updated_optimizer_state = tff.federated_map(
        server_update, server_update_args)

    # The state of both processes is carried over into the next round.
    next_state = tff.federated_zip(
        ServerState(updated_model, updated_optimizer_state,
                    aggregation.state, broadcast.state))

    train_metrics = dummy_model_for_metadata.federated_output_computation(
        per_client.model_output)
    round_measurements = collections.OrderedDict([
        ('broadcast', broadcast.measurements),
        ('aggregation', aggregation.measurements),
        ('train', train_metrics),
    ])
    return next_state, tff.federated_zip(round_measurements)
Ejemplo n.º 2
0
 def next_comp(state, value):
     """Broadcasts `value` and reports a test-only measurement.

     Args:
       state: the current state; `_add_one` is mapped over it.
       value: the value passed to `tff.federated_broadcast`.

     Returns:
       A `collections.OrderedDict` with the keys `state`, `result` and
       `measurements`.
     """
     next_state = tff.federated_map(_add_one, state)
     value_at_clients = tff.federated_broadcast(value)
     # Arbitrary metrics for testing.
     norm_plus_three = tff.tf_computation(
         lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
     test_metric = tff.federated_map(norm_plus_three, value)
     return collections.OrderedDict([
         ('state', next_state),
         ('result', value_at_clients),
         ('measurements', test_metric),
     ])
Ejemplo n.º 3
0
 def next_comp(state, value, weight):
     """Computes a weighted federated mean and counts the clients.

     Args:
       state: the current state; `_add_one` is mapped over it.
       value: the client values to average.
       weight: the per-client weights passed to `tff.federated_mean`.

     Returns:
       A `collections.OrderedDict` with the keys `state`, `result` and
       `measurements`; `measurements` holds a single `num_clients` entry.
     """
     next_state = tff.federated_map(_add_one, state)
     weighted_mean = tff.federated_mean(value, weight)
     # Summing a constant 1 placed at every client yields the client count.
     one_per_client = tff.federated_value(1, tff.CLIENTS)
     client_count = tff.federated_sum(one_per_client)
     metrics = tff.federated_zip(
         collections.OrderedDict([('num_clients', client_count)]))
     return collections.OrderedDict([
         ('state', next_state),
         ('result', weighted_mean),
         ('measurements', metrics),
     ])
 def robust_aggregation_fn(value, weight):
     """Iteratively re-weighted federated mean of `value`.

     Starts from the plain weighted mean. Each of the remaining
     `num_communication_passes - 1` passes broadcasts the current mean,
     recomputes the client weights with `update_weight_fn` from the original
     `weight`, the broadcast mean and `value`, and averages again.

     Args:
       value: the client values to aggregate.
       weight: the initial per-client weights.

     Returns:
       The aggregate produced by the last pass.
     """
     current_mean = tff.federated_mean(value, weight=weight)
     extra_passes = num_communication_passes - 1
     for _ in range(extra_passes):
         reweighting_inputs = (
             weight, tff.federated_broadcast(current_mean), value)
         refreshed_weight = tff.federated_map(
             update_weight_fn, reweighting_inputs)
         current_mean = tff.federated_mean(value, weight=refreshed_weight)
     return current_mean
Ejemplo n.º 5
0
    def run_one_round_tff(server_state, federated_dataset):
        """Runs one round of federated optimization.

        Broadcasts the server model through `stateful_model_broadcast_fn`,
        maps `tf_client_delta` over the clients, aggregates the weight deltas
        with `stateful_delta_aggregate_fn`, and applies `tf_server_update` to
        obtain the next server state.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement
            tff.CLIENTS.

        Returns:
          A tuple of the updated `tff.learning.framework.ServerState` and the
          result of `tff.learning.Model.federated_output_computation`.
        """
        broadcaster_state, model_at_clients = stateful_model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        client_inputs = (federated_dataset, model_at_clients)
        client_results = tff.federated_map(tf_client_delta, client_inputs)

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        float_weights = tff.federated_map(
            _cast_weight_to_float, client_results.weights_delta_weight)
        aggregator_state, mean_model_delta = stateful_delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_results.weights_delta,
            weight=float_weights)

        # TODO(b/123408447): remove tff.federated_map and call
        # tf_server_update directly once T <-> T@SERVER isomorphism is
        # supported.
        server_update_args = (server_state, mean_model_delta,
                              aggregator_state, broadcaster_state)
        updated_server_state = tff.federated_map(
            tf_server_update, server_update_args)

        round_metrics = dummy_model_for_metadata.federated_output_computation(
            client_results.model_output)

        # TODO(b/131429028): Ideally this federated_zip shouldn't ever be needed.
        if not isinstance(round_metrics.type_signature, tff.NamedTupleType):
            return updated_server_state, round_metrics

        # Promote the FederatedType outside the NamedTupleType.
        return updated_server_state, tff.federated_zip(round_metrics)
Ejemplo n.º 6
0
 def next_fn_taking_client_ids(param):
   """Swaps the entry at `dataset_index` for client datasets and runs `process`.

   Args:
     param: the sequence of arguments; the entry at `dataset_index` is mapped
       through `dataset_computation`.

   Returns:
     The result of `process.next` called with the rebuilt argument list.
   """
   datasets_on_clients = tff.federated_map(dataset_computation,
                                           param[dataset_index])
   # Every entry is kept as-is except the one at `dataset_index`.
   rebuilt_param = [
       entry if position != dataset_index else datasets_on_clients
       for position, entry in enumerate(param)
   ]
   return process.next(rebuilt_param)
  def personalization_eval(server_model_weights, federated_client_input):
    """Broadcasts the server weights and samples per-client metrics.

    Args:
      server_model_weights: the model weights passed to
        `tff.federated_broadcast`.
      federated_client_input: the per-client input handed to
        `_client_computation` together with the broadcast weights.

    Returns:
      The result of `tff.utils.federated_sample` applied to the client
      metrics with `max_num_samples`.
    """
    weights_at_clients = tff.federated_broadcast(server_model_weights)
    client_inputs = (weights_at_clients, federated_client_input)
    per_client_metrics = tff.federated_map(_client_computation, client_inputs)

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    return tff.utils.federated_sample(per_client_metrics, max_num_samples)
Ejemplo n.º 8
0
    def one_round_computation(server_state, federated_dataset):
        """Runs a single round of federated optimization.

        Broadcasts the server model with `model_broadcast_fn`, maps
        `_compute_local_training_and_client_delta` over the clients,
        aggregates the weight deltas with `delta_aggregate_fn`, and applies
        `server_update` to produce the next model and optimizer state.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement
            tff.CLIENTS.

        Returns:
          A tuple of the updated `tff.learning.framework.ServerState` and the
          result of `tff.learning.Model.federated_output_computation`, both
          having `tff.SERVER` placement.
        """
        broadcaster_state, model_at_clients = model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        training_inputs = (federated_dataset, model_at_clients)
        client_results = tff.federated_map(
            _compute_local_training_and_client_delta, training_inputs)

        aggregator_state, mean_model_delta = delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_results.weights_delta,
            weight=client_results.weights_delta_weight)

        server_update_args = (server_state.model, mean_model_delta,
                              server_state.optimizer_state)
        updated_model, updated_optimizer_state = tff.federated_map(
            server_update, server_update_args)

        next_state = tff.federated_zip(
            ServerState(updated_model, updated_optimizer_state,
                        aggregator_state, broadcaster_state))

        round_metrics = dummy_model_for_metadata.federated_output_computation(
            client_results.model_output)

        if not isinstance(round_metrics.type_signature, tff.NamedTupleType):
            return next_state, round_metrics

        # Promote the FederatedType outside the NamedTupleType.
        return next_state, tff.federated_zip(round_metrics)
Ejemplo n.º 9
0
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        The round proceeds in this order: broadcast the server model, run the
        client delta computation on every client, aggregate the weight deltas
        (casting integer weights to float first), apply the server update, and
        aggregate the client model outputs.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement
            tff.CLIENTS.

        Returns:
          A tuple of updated `tff.learning.framework.ServerState` and the
          result of `tff.learning.Model.federated_output_computation`.
        """
        model_weights_type = server_state_type.model

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_delta_tf(tf_dataset, initial_model_weights):
            """Performs client local model optimization.

            Args:
              tf_dataset: a `tf.data.Dataset` that provides training examples.
              initial_model_weights: a `model_utils.ModelWeights` containing
                the starting weights.

            Returns:
              A `ClientOutput` structure.
            """
            client_delta_fn = model_to_client_delta_fn(model_fn)
            client_output = client_delta_fn(tf_dataset, initial_model_weights)
            return client_output

        # Broadcast step: yields the broadcaster's new state and the model as
        # seen by the clients.
        new_broadcaster_state, client_model = stateful_model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        client_outputs = tff.federated_map(client_delta_tf,
                                           (federated_dataset, client_model))

        # The parameter types of the server update are read off the federated
        # `server_state` argument, so this definition has to come after the
        # value is available.
        @tff.tf_computation(
            server_state_type, model_weights_type.trainable,
            server_state.delta_aggregate_state.type_signature.member,
            server_state.model_broadcast_state.type_signature.member)
        def server_update_tf(server_state, model_delta,
                             new_delta_aggregate_state, new_broadcaster_state):
            """Converts args to correct python types and calls server_update_model."""
            # NOTE: `server_state` here is the parameter of this nested
            # computation and shadows the outer `server_state`.
            py_typecheck.check_type(server_state, ServerState)
            # Rebuild the state with the freshly produced aggregator and
            # broadcaster states before handing it to `server_update_model`.
            server_state = ServerState(
                model=server_state.model,
                optimizer_state=list(server_state.optimizer_state),
                delta_aggregate_state=new_delta_aggregate_state,
                model_broadcast_state=new_broadcaster_state)

            return server_update_model(server_state,
                                       model_delta,
                                       model_fn=model_fn,
                                       optimizer_fn=server_optimizer_fn)

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
        py_typecheck.check_type(fed_weight_type, tff.TensorType)
        if fed_weight_type.dtype.is_integer:

            @tff.tf_computation(fed_weight_type)
            def _cast_to_float(x):
                """Casts an integer weight to `tf.float32`."""
                return tf.cast(x, tf.float32)

            weight_denom = tff.federated_map(
                _cast_to_float, client_outputs.weights_delta_weight)
        else:
            # Non-integer weights are used unchanged.
            weight_denom = client_outputs.weights_delta_weight

        # Aggregation step: yields the aggregator's new state and the
        # aggregated model delta for this round.
        new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_outputs.weights_delta,
            weight=weight_denom)

        # TODO(b/123408447): remove tff.federated_apply and call
        # server_update_tf directly once T <-> T@SERVER isomorphism is
        # supported.
        server_state = tff.federated_apply(
            server_update_tf,
            (server_state, round_model_delta, new_delta_aggregate_state,
             new_broadcaster_state))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)

        # Promote the FederatedType outside the NamedTupleType
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Ejemplo n.º 10
0
 def federated_train(model, learning_rate, data):
     """Runs `local_train` on every client and averages the results.

     Args:
       model: the model value, broadcast to the clients.
       learning_rate: the learning rate, broadcast to the clients.
       data: the per-client data passed to `local_train` unchanged.

     Returns:
       The result of `tff.federated_average` over the `local_train` outputs.
     """
     model_at_clients = tff.federated_broadcast(model)
     learning_rate_at_clients = tff.federated_broadcast(learning_rate)
     locally_trained = tff.federated_map(
         local_train, [model_at_clients, learning_rate_at_clients, data])
     return tff.federated_average(locally_trained)
Ejemplo n.º 11
0
def _state_incrementing_broadcast_next(server_state, server_value):
    """Maps `_add_one` over the state and broadcasts `server_value`.

    Returns:
      A `(new_state, broadcast_value)` tuple.
    """
    bumped_state = tff.federated_map(_add_one, server_state)
    value_at_clients = tff.federated_broadcast(server_value)
    return bumped_state, value_at_clients
Ejemplo n.º 12
0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Maps `_add_one` over the state and averages `client_value`.

    Args:
      server_state: the state `_add_one` is mapped over.
      client_value: the client values passed to `tff.federated_mean`.
      weight: optional weights forwarded to `tff.federated_mean`.

    Returns:
      A `(new_state, mean)` tuple.
    """
    bumped_state = tff.federated_map(_add_one, server_state)
    mean_value = tff.federated_mean(client_value, weight=weight)
    return bumped_state, mean_value
Ejemplo n.º 13
0
 def server_eval(server_model_weights, federated_dataset):
     """Evaluates the broadcast server weights on every client dataset.

     Args:
       server_model_weights: the weights broadcast to the clients.
       federated_dataset: the per-client datasets handed to `client_eval`.

     Returns:
       The result of `model.federated_output_computation` applied to the
       clients' `local_outputs`.
     """
     weights_at_clients = tff.federated_broadcast(server_model_weights)
     eval_results = tff.federated_map(
         client_eval, [weights_at_clients, federated_dataset])
     return model.federated_output_computation(eval_results.local_outputs)
Ejemplo n.º 14
0
 def next_fn(empty_tup, x):
     """Maps `reduce_dataset` over `x` and sums the results.

     Args:
       empty_tup: ignored.
       x: the federated value passed to `reduce_dataset`.

     Returns:
       The `tff.federated_sum` of the mapped values.
     """
     del empty_tup  # Unused
     reduced_per_client = tff.federated_map(reduce_dataset, x)
     return tff.federated_sum(reduced_per_client)
Ejemplo n.º 15
0
 def cast_to_float_mean(state, value, weight):
     """Weighted federated mean with weights mapped through a cast.

     Args:
       state: returned unchanged.
       value: the client values to average.
       weight: the client weights; `_cast_weight_to_float` is mapped over
         them before averaging.

     Returns:
       A `(state, mean)` tuple.
     """
     float_weight = tff.federated_map(_cast_weight_to_float, weight)
     weighted_mean = tff.federated_mean(value, weight=float_weight)
     return state, weighted_mean
Ejemplo n.º 16
0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Adds one to the int32 state and averages `client_value`.

    Args:
      server_state: the state that is incremented.
      client_value: the client values passed to `tff.federated_mean`.
      weight: optional weights forwarded to `tff.federated_mean`.

    Returns:
      A `(new_state, mean)` tuple.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    bumped_state = tff.federated_map(increment, server_state)
    mean_value = tff.federated_mean(client_value, weight=weight)
    return bumped_state, mean_value
Ejemplo n.º 17
0
def _state_incrementing_broadcast_next(server_state, server_value):
    """Adds one to the int32 state and broadcasts `server_value`.

    Returns:
      A `(new_state, broadcast_value)` tuple.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    bumped_state = tff.federated_map(increment, server_state)
    value_at_clients = tff.federated_broadcast(server_value)
    return bumped_state, value_at_clients
Ejemplo n.º 18
0
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    The round proceeds in this order: broadcast the server model and run the
    client delta computation on every client, average the weight deltas
    (casting integer weights to float first), apply the server update, and
    aggregate the client model outputs.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    # The server model is broadcast inline as the second element of the
    # argument tuple handed to every client.
    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # NOTE: `server_state` here is the parameter of this nested computation
      # and shadows the outer `server_state`.
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        """Casts an integer weight to `tf.float32`."""
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      # Non-integer weights are used unchanged.
      weight_denom = client_outputs.weights_delta_weight
    # Weighted average of the client weight deltas for this round.
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs