Exemple #1
0
def build_encoded_sum(values, encoders):
  """Creates a `StatefulAggregateFn` that sums `values` encoded by `encoders`.

  Args:
    values: The values the returned `StatefulAggregateFn` will encode. Must be
      convertible to `tff.Value`.
    encoders: `GatherEncoder` objects used to encode `values`, arranged in the
      same structure as `values`.

  Returns:
    A `StatefulAggregateFn` whose `next_fn` encodes its input at `tff.CLIENTS`
    and computes the sum at `tff.SERVER`. The decoding part is split
    automatically based on its commutativity with sum.

  Raises:
    ValueError: If `values` and `encoders` do not have the same structure.
    TypeError: If `encoders` are not instances of `GatherEncoder`, or if
      `values` are not compatible with the expected input of the `encoders`.
  """

  def _check_pair(encoder, value):
    return _validate_encoder(encoder, value,
                             tensor_encoding.core.GatherEncoder)

  # The structures must agree before encoders can be paired with their values.
  tf.nest.assert_same_structure(values, encoders)
  tf.nest.map_structure(_check_pair, encoders, values)

  value_type = type_utils.type_from_tensors(values)
  initial_state_fn, state_type = _build_initial_state_tf_computation(encoders)
  gather_computations = _build_tf_computations_for_gather(
      state_type, value_type, encoders)

  return computation_utils.StatefulAggregateFn(
      initialize_fn=initial_state_fn,
      next_fn=_build_encoded_sum_fn(gather_computations))
    def test_execute_with_default_weight(self):
        """Checks the unweighted call path of a `StatefulAggregateFn`.

        Verifies both the type signature of the wrapping federated computation
        and the values it returns for three client inputs.
        """
        aggregate_fn = computation_utils.StatefulAggregateFn(
            next_fn=agg_next_fn, initialize_fn=agg_initialize_fn)
        client_floats_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)

        @computations.federated_computation(client_floats_type)
        def federated_aggregate_test(args):
            initial_state = aggregate_fn.initialize()
            state = intrinsics.federated_value(initial_state,
                                               placements.SERVER)
            return aggregate_fn(state, args)

        # Expected result: a (state, mean) pair, both placed at the server.
        server_state_type = computation_types.FederatedType(
            collections.OrderedDict([('call_count', tf.int32)]),
            placements.SERVER)
        server_mean_type = computation_types.FederatedType(
            tf.float32, placements.SERVER)
        expected_type_signature = computation_types.FunctionType(
            parameter=client_floats_type,
            result=computation_types.NamedTupleType(
                [server_state_type, server_mean_type]))
        self.assertEqual(federated_aggregate_test.type_signature,
                         expected_type_signature)

        result_state, result_mean = federated_aggregate_test([1.0, 2.0, 3.0])
        # (1 + 2 + 3) / (1 + 1 + 1)
        self.assertAlmostEqual(result_mean, 2.0)
        expected_state = {'call_count': 1}
        self.assertDictEqual(result_state._asdict(), expected_state)
Exemple #3
0
def build_encoded_mean(values, encoders):
    """Builds `StatefulAggregateFn` for `values`, to be encoded by `encoders`.

    Deprecated: emits a `DeprecationWarning` pointing at
    `tff.utils.build_encoded_mean_process()` every time it is called.

    Args:
      values: Values to be encoded by the `StatefulAggregateFn`. Must be
        convertible to `tff.Value`.
      encoders: A collection of `GatherEncoder` objects to be used for encoding
        `values`. Must have the same structure as `values`.

    Returns:
      A `StatefulAggregateFn` of which `next_fn` encodes the input at
      `tff.CLIENTS`, and computes their mean at `tff.SERVER`, automatically
      splitting the decoding part based on its commutativity with sum.

    Raises:
      ValueError: If `values` and `encoders` do not have the same structure.
      TypeError: If `encoders` are not instances of `GatherEncoder`, or if
        `values` are not compatible with the expected input of the `encoders`.
    """
    # stacklevel=2 attributes the warning to the caller of this function.
    # Python's default filters only display DeprecationWarning when it is
    # attributed to `__main__`, so without it the warning would be reported
    # against this library module and silently dropped for most users.
    warnings.warn(
        'Deprecation warning: tff.utils.build_encoded_mean() is deprecated, use '
        'tff.utils.build_encoded_mean_process() instead.',
        DeprecationWarning,
        stacklevel=2)

    tf.nest.assert_same_structure(values, encoders)
    tf.nest.map_structure(
        lambda e, v: _validate_encoder(e, v, tensor_encoding.core.GatherEncoder
                                       ), encoders, values)

    value_type = type_conversions.type_from_tensors(values)

    initial_state_fn, state_type = _build_initial_state_tf_computation(
        encoders)

    nest_encoder = _build_tf_computations_for_gather(state_type, value_type,
                                                     encoders)
    encoded_sum_fn = _build_encoded_sum_fn(nest_encoder)

    @computations.tf_computation(value_type, tf.float32)
    def multiply_fn(value, weight):
        # Scale every leaf by the weight, cast to the leaf's own dtype.
        return tf.nest.map_structure(lambda v: v * tf.cast(weight, v.dtype),
                                     value)

    @computations.tf_computation(value_type, tf.float32)
    def divide_fn(value, denominator):
        # Divide every leaf by the denominator, cast to the leaf's own dtype.
        return tf.nest.map_structure(
            lambda v: v / tf.cast(denominator, v.dtype), value)

    def encoded_mean_fn(state, values, weight):
        """Computes the weighted mean as encoded_sum(v * w) / sum(w)."""
        weighted_values = intrinsics.federated_map(multiply_fn,
                                                   [values, weight])
        updated_state, summed_decoded_values = encoded_sum_fn(
            state, weighted_values)
        summed_weights = intrinsics.federated_sum(weight)
        decoded_values = intrinsics.federated_map(
            divide_fn, [summed_decoded_values, summed_weights])
        return updated_state, decoded_values

    return computation_utils.StatefulAggregateFn(
        initialize_fn=initial_state_fn, next_fn=encoded_mean_fn)
Exemple #4
0
 def test_fails_stateful_aggregate_and_process(self):
   """Passing both aggregator arguments raises `DisjointArgumentError`."""
   model_weights_type = model_utils.weights_type_from_model(
       model_examples.LinearRegression)

   def _mean_next(state, value, weight=None):
     return (state, intrinsics.federated_mean(value, weight))

   with self.assertRaises(optimizer_utils.DisjointArgumentError):
     # Both objects are built inside the `with` block so that the set of
     # statements covered by `assertRaises` is unchanged.
     legacy_aggregate_fn = computation_utils.StatefulAggregateFn(
         initialize_fn=lambda: (), next_fn=_mean_next)
     mean_process = optimizer_utils.build_stateless_mean(
         model_delta_type=model_weights_type.trainable)
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         stateful_delta_aggregate_fn=legacy_aggregate_fn,
         aggregation_process=mean_process)
    def test_execute_with_explicit_weights(self):
        """Checks the weighted call path of a `StatefulAggregateFn`."""
        aggregate_fn = computation_utils.StatefulAggregateFn(
            next_fn=agg_next_fn, initialize_fn=agg_initialize_fn)
        client_values_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)
        client_weights_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)

        @computations.federated_computation(client_values_type,
                                            client_weights_type)
        def federated_aggregate_test(args, weights):
            initial_state = aggregate_fn.initialize()
            state = intrinsics.federated_value(initial_state,
                                               placements.SERVER)
            return aggregate_fn(state, args, weights)

        client_values = [1.0, 2.0, 3.0]
        client_weights = [4.0, 1.0, 1.0]
        result_state, result_mean = federated_aggregate_test(
            client_values, client_weights)
        # (1*4 + 2*1 + 3*1) / (4 + 1 + 1)
        self.assertAlmostEqual(result_mean, 1.5)
        expected_state = {'call_count': 1}
        self.assertDictEqual(result_state._asdict(), expected_state)
Exemple #6
0
            model_output=self._model.report_local_outputs(),
            optimizer_output=collections.OrderedDict([('client_weight',
                                                       client_weight)]))


@computations.tf_computation(tf.int32)
def _add_one(x):
    """Returns `x` incremented by one."""
    incremented = x + 1
    return incremented


def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Bumps the state via `_add_one` and returns it with the federated mean.

    Args:
      server_state: State value that `_add_one` is mapped over.
      client_value: Value passed to `intrinsics.federated_mean`.
      weight: Optional weight forwarded to `intrinsics.federated_mean`.

    Returns:
      A `(new_state, mean)` tuple.
    """
    incremented_state = intrinsics.federated_map(_add_one, server_state)
    client_mean = intrinsics.federated_mean(client_value, weight=weight)
    return incremented_state, client_mean


# Test aggregator: state starts at constant 0 and is advanced by
# `_state_incrementing_mean_next` on every call.
state_incrementing_mean = computation_utils.StatefulAggregateFn(
    initialize_fn=lambda: tf.constant(0),
    next_fn=_state_incrementing_mean_next)


def _state_incrementing_broadcast_next(server_state, server_value):
    """Bumps the state via `_add_one` and returns it with the broadcast value.

    Args:
      server_state: State value that `_add_one` is mapped over.
      server_value: Value passed to `intrinsics.federated_broadcast`.

    Returns:
      A `(new_state, broadcast_value)` tuple.
    """
    incremented_state = intrinsics.federated_map(_add_one, server_state)
    value_at_clients = intrinsics.federated_broadcast(server_value)
    return incremented_state, value_at_clients


# Test broadcaster: state starts at constant 0 and is advanced by
# `_state_incrementing_broadcast_next` on every call.
state_incrementing_broadcaster = computation_utils.StatefulBroadcastFn(
    initialize_fn=lambda: tf.constant(0),
    next_fn=_state_incrementing_broadcast_next)


def _build_test_measured_broadcast(
    model_weights_type: computation_types.StructType
) -> measured_process.MeasuredProcess:
    """Builds a test `MeasuredProcess` that has state and metrics."""
Exemple #7
0
def build_dp_aggregate(query, value_type_fn=_default_get_value_type_fn):
    """Builds a stateful aggregator for tensorflow_privacy DPQueries.

    The returned `StatefulAggregateFn` can be called with any nested structure
    for the values being statefully aggregated. However, it's necessary to
    provide a function argument (`value_type_fn`) which indicates the
    `tff.Type` of the nested structure that will be used. If using a
    `collections.OrderedDict` as the value's nested structure, the default for
    that argument suffices.

    Deprecated: emits a `DeprecationWarning` pointing at
    `tff.utils.build_dp_aggregate_process()` every time it is called.

    Args:
      query: A DPQuery to aggregate. For compatibility with
        tensorflow_federated, the global_state and sample_state of the query
        must be structures supported by tf.nest.
      value_type_fn: Python function that takes the value argument of next_fn
        and returns the value type. This will be used in determining the
        TensorSpecs that establish the initial sample state. If the value
        being aggregated is an `collections.OrderedDict`, the default for this
        argument can be used. This argument probably gets removed once
        b/123092620 is addressed (and the associated processing step gets
        replaced with a simple call to `value.type_signature.member`).

    Returns:
      A tuple of:
        - a `computation_utils.StatefulAggregateFn` that aggregates according
            to the query
        - the TFF type of the DP aggregator's global state
    """
    # stacklevel=2 attributes the warning to the caller of this function.
    # Python's default filters only display DeprecationWarning when it is
    # attributed to `__main__`, so without it the warning would be reported
    # against this library module and silently dropped for most users.
    warnings.warn(
        'Deprecation warning: tff.utils.build_dp_aggregate() is deprecated, use '
        'tff.utils.build_dp_aggregate_process() instead.',
        DeprecationWarning,
        stacklevel=2)

    @computations.tf_computation
    def initialize_fn():
        return query.initial_global_state()

    def next_fn(global_state, value, weight=None):
        """Defines next_fn for StatefulAggregateFn.

        `weight` is accepted for signature compatibility but discarded.
        """
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        #######################################
        # Define local tf_computations

        # TODO(b/129567727): Make most of these tf_computations polymorphic
        # so type manipulation isn't needed.

        global_state_type = initialize_fn.type_signature.result

        @computations.tf_computation(global_state_type)
        def derive_sample_params(global_state):
            return query.derive_sample_params(global_state)

        @computations.tf_computation(
            derive_sample_params.type_signature.result,
            value.type_signature.member)
        def preprocess_record(params, record):
            return query.preprocess_record(params, record)

        # TODO(b/123092620): We should have the expected container type here.
        value_type = value_type_fn(value)
        value_type = computation_types.to_type(value_type)

        tensor_specs = type_conversions.type_to_tf_tensor_specs(value_type)

        @computations.tf_computation
        def zero():
            return query.initial_sample_state(tensor_specs)

        sample_state_type = zero.type_signature.result

        @computations.tf_computation(sample_state_type,
                                     preprocess_record.type_signature.result)
        def accumulate(sample_state, preprocessed_record):
            return query.accumulate_preprocessed_record(
                sample_state, preprocessed_record)

        @computations.tf_computation(sample_state_type, sample_state_type)
        def merge(sample_state_1, sample_state_2):
            return query.merge_sample_states(sample_state_1, sample_state_2)

        @computations.tf_computation(merge.type_signature.result)
        def report(sample_state):
            # Identity: the merged sample state is reported unchanged.
            return sample_state

        @computations.tf_computation(sample_state_type, global_state_type)
        def post_process(sample_state, global_state):
            result, new_global_state = query.get_noised_result(
                sample_state, global_state)
            # Swap the order so the updated state comes first in the output.
            return new_global_state, result

        #######################################
        # Orchestration logic

        sample_params = intrinsics.federated_map(derive_sample_params,
                                                 global_state)
        client_sample_params = intrinsics.federated_broadcast(sample_params)
        preprocessed_record = intrinsics.federated_map(
            preprocess_record, (client_sample_params, value))
        agg_result = intrinsics.federated_aggregate(preprocessed_record,
                                                    zero(), accumulate, merge,
                                                    report)

        return intrinsics.federated_map(post_process,
                                        (agg_result, global_state))

    # TODO(b/140236959): Find a way to have this method return only one thing. The
    # best approach is probably to add (to StatefulAggregateFn) a property that
    # stores the type of the global state.
    aggregate_fn = computation_utils.StatefulAggregateFn(
        initialize_fn=initialize_fn, next_fn=next_fn)
    return (aggregate_fn, initialize_fn.type_signature.result)
 def federated_aggregate_test(values, weights):
   """Runs one weighted aggregation step from a freshly initialized state."""
   stateful_aggregate = computation_utils.StatefulAggregateFn(
       next_fn=agg_next_fn, initialize_fn=agg_initialize_fn)
   initial_state = stateful_aggregate.initialize()
   server_state = tff.federated_value(initial_state, tff.SERVER)
   return stateful_aggregate(server_state, values, weights)
def build_dp_aggregate(query,
                       value_type_fn=_default_get_value_type_fn,
                       from_anon_tuple_fn=_default_from_anon_tuple_fn):
    """Builds a stateful aggregator for tensorflow_privacy DPQueries.

  The returned StatefulAggregateFn can be called with any nested structure for
  the values being statefully aggregated. However, it's necessary to provide two
  functions as arguments which indicate the properties (the tff.Type and the
  AnonymousTuple conversion) of the nested structure that will be used. If using
  an OrderedDict as the value's nested structure, the defaults for the arguments
  suffice.

  The `weight` argument of the returned aggregator's `next_fn` is accepted but
  discarded; weighted aggregation is not supported.

  Args:
    query: A DPQuery to aggregate. For compatibility with tensorflow_federated,
      the global_state and sample_state of the query must be structures
      supported by tf.nest.
    value_type_fn: Python function that takes the value argument of next_fn and
      returns the value type. This will be used in determining the TensorSpecs
      that establish the initial sample state. If the value being aggregated is
      an OrderedDict, the default for this argument can be used. This argument
      probably gets removed once b/123092620 is addressed (and the associated
      processing step gets replaced with a simple call to
      value.type_signature.member).
    from_anon_tuple_fn: Python function that takes a client record and converts
      it to the container type that it was in before passing through TFF. (Right
      now, TFF computation causes the client record to be changed into an
      AnonymousTuple, and this method corrects for that). If the value being
      aggregated is an OrderedDict, the default for this argument can be used.
      This argument likely goes away once b/123092620 is addressed. The default
      behavior assumes that the client record (before being converted to
      AnonymousTuple) was an OrderedDict containing a flat structure of Tensors
      (as it is if using the tff.learning APIs like
      tff.learning.build_federated_averaging_process).

  Returns:
    A tuple of:
      - a `computation_utils.StatefulAggregateFn` that aggregates according to
          the query
      - the TFF type of the DP aggregator's global state
  """
    # The initial aggregator state is the query's initial global state.
    @tff.tf_computation
    def initialize_fn():
        return query.initial_global_state()

    def next_fn(global_state, value, weight=None):
        """Defines next_fn for StatefulAggregateFn."""
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        #######################################
        # Define local tf_computations

        # TODO(b/129567727): Make most of these tf_computations polymorphic
        # so type manipulation isn't needed.

        # Each computation below is typed from the result type of an earlier
        # one, so the definition order in this section matters.
        global_state_type = initialize_fn.type_signature.result

        @tff.tf_computation(global_state_type)
        def derive_sample_params(global_state):
            return query.derive_sample_params(global_state)

        @tff.tf_computation(derive_sample_params.type_signature.result,
                            value.type_signature.member)
        def preprocess_record(params, record):
            # TODO(b/123092620): Once TFF passes the expected container type (instead
            # of AnonymousTuple), we shouldn't need this.
            record = from_anon_tuple_fn(record)

            return query.preprocess_record(params, record)

        # TODO(b/123092620): We should have the expected container type here.
        value_type = value_type_fn(value)

        # TensorSpecs derived from the value type seed the query's initial
        # sample state in `zero` below.
        tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

        @tff.tf_computation
        def zero():
            return query.initial_sample_state(tensor_specs)

        sample_state_type = zero.type_signature.result

        @tff.tf_computation(sample_state_type,
                            preprocess_record.type_signature.result)
        def accumulate(sample_state, preprocessed_record):
            return query.accumulate_preprocessed_record(
                sample_state, preprocessed_record)

        @tff.tf_computation(sample_state_type, sample_state_type)
        def merge(sample_state_1, sample_state_2):
            return query.merge_sample_states(sample_state_1, sample_state_2)

        # Identity report step: the merged sample state is passed through
        # unchanged.
        @tff.tf_computation(merge.type_signature.result)
        def report(sample_state):
            return sample_state

        @tff.tf_computation(sample_state_type, global_state_type)
        def post_process(sample_state, global_state):
            result, new_global_state = query.get_noised_result(
                sample_state, global_state)
            # Order is swapped relative to `get_noised_result` so the updated
            # state comes first in the returned pair.
            return new_global_state, result

        #######################################
        # Orchestration logic

        # Derive the sample params from the global state, broadcast them, pair
        # them with each value for preprocessing, aggregate the preprocessed
        # records, then post-process with the global state.
        sample_params = tff.federated_apply(derive_sample_params, global_state)
        client_sample_params = tff.federated_broadcast(sample_params)
        preprocessed_record = tff.federated_map(preprocess_record,
                                                (client_sample_params, value))
        agg_result = tff.federated_aggregate(preprocessed_record, zero(),
                                             accumulate, merge, report)

        return tff.federated_apply(post_process, (agg_result, global_state))

    # TODO(b/140236959): Find a way to have this method return only one thing. The
    # best approach is probably to add (to StatefulAggregateFn) a property that
    # stores the type of the global state.
    aggregate_fn = computation_utils.StatefulAggregateFn(
        initialize_fn=initialize_fn, next_fn=next_fn)
    return (aggregate_fn, initialize_fn.type_signature.result)