def build_encoded_sum(values, encoders):
  """Creates a `StatefulAggregateFn` that sums `values` encoded by `encoders`.

  Args:
    values: Values to be encoded by the `StatefulAggregateFn`. Must be
      convertible to `tff.Value`.
    encoders: A collection of `GatherEncoder` objects to be used for encoding
      `values`. Must have the same structure as `values`.

  Returns:
    A `StatefulAggregateFn` whose `next_fn` encodes the input at
    `tff.CLIENTS` and computes the sum at `tff.SERVER`, automatically
    splitting the decoding part based on its commutativity with sum.

  Raises:
    ValueError: If `values` and `encoders` do not have the same structure.
    TypeError: If `encoders` are not instances of `GatherEncoder`, or if
      `values` are not compatible with the expected input of the `encoders`.
  """
  tf.nest.assert_same_structure(values, encoders)

  def check_encoder(encoder, value):
    # Validation happens for its side effect (raising); the mapped result is
    # discarded.
    _validate_encoder(encoder, value, tensor_encoding.core.GatherEncoder)

  tf.nest.map_structure(check_encoder, encoders, values)

  value_type = type_utils.type_from_tensors(values)
  init_fn, state_type = _build_initial_state_tf_computation(encoders)
  gather_computations = _build_tf_computations_for_gather(
      state_type, value_type, encoders)
  return computation_utils.StatefulAggregateFn(
      initialize_fn=init_fn,
      next_fn=_build_encoded_sum_fn(gather_computations))
def test_execute_with_default_weight(self):
  """Checks type signature and result of an unweighted (default) aggregate."""
  client_value_type = computation_types.FederatedType(
      tf.float32, placements.CLIENTS)
  expected_state_type = computation_types.FederatedType(
      collections.OrderedDict([('call_count', tf.int32)]), placements.SERVER)
  expected_mean_type = computation_types.FederatedType(
      tf.float32, placements.SERVER)
  expected_type_signature = computation_types.FunctionType(
      parameter=client_value_type,
      result=computation_types.NamedTupleType(
          [expected_state_type, expected_mean_type]))

  stateful_fn = computation_utils.StatefulAggregateFn(
      initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)

  @computations.federated_computation(client_value_type)
  def federated_aggregate_test(args):
    server_state = intrinsics.federated_value(stateful_fn.initialize(),
                                              placements.SERVER)
    # No weight is passed, so the aggregator's default weighting applies.
    return stateful_fn(server_state, args)

  self.assertEqual(federated_aggregate_test.type_signature,
                   expected_type_signature)

  new_state, mean = federated_aggregate_test([1.0, 2.0, 3.0])
  # Equal weights: (1 + 2 + 3) / (1 + 1 + 1) == 2.
  self.assertAlmostEqual(mean, 2.0)
  self.assertDictEqual(new_state._asdict(), {'call_count': 1})
def build_encoded_mean(values, encoders):
  """Creates a `StatefulAggregateFn` computing a weighted mean of `values`.

  Each client value is scaled by its weight, encoded by `encoders`, summed,
  and finally divided by the sum of the weights at the server.

  Args:
    values: Values to be encoded by the `StatefulAggregateFn`. Must be
      convertible to `tff.Value`.
    encoders: A collection of `GatherEncoder` objects to be used for encoding
      `values`. Must have the same structure as `values`.

  Returns:
    A `StatefulAggregateFn` whose `next_fn(state, values, weight)` encodes the
    weighted input at `tff.CLIENTS` and computes the weighted mean at
    `tff.SERVER`, automatically splitting the decoding part based on its
    commutativity with sum. `weight` is a required `tf.float32` argument.

  Raises:
    ValueError: If `values` and `encoders` do not have the same structure.
    TypeError: If `encoders` are not instances of `GatherEncoder`, or if
      `values` are not compatible with the expected input of the `encoders`.

  Warns:
    DeprecationWarning: Always; use `tff.utils.build_encoded_mean_process()`.
  """
  warnings.warn(
      'Deprecation warning: tff.utils.build_encoded_mean() is deprecated, use '
      'tff.utils.build_encoded_mean_process() instead.', DeprecationWarning)
  tf.nest.assert_same_structure(values, encoders)

  def check_encoder(encoder, value):
    # Called only for the validation side effect; the result is discarded.
    _validate_encoder(encoder, value, tensor_encoding.core.GatherEncoder)

  tf.nest.map_structure(check_encoder, encoders, values)

  value_type = type_conversions.type_from_tensors(values)
  init_fn, state_type = _build_initial_state_tf_computation(encoders)
  gather_computations = _build_tf_computations_for_gather(
      state_type, value_type, encoders)
  encoded_sum_fn = _build_encoded_sum_fn(gather_computations)

  @computations.tf_computation(value_type, tf.float32)
  def multiply_fn(value, weight):

    def scale(tensor):
      return tensor * tf.cast(weight, tensor.dtype)

    return tf.nest.map_structure(scale, value)

  @computations.tf_computation(value_type, tf.float32)
  def divide_fn(value, denominator):

    def normalize(tensor):
      return tensor / tf.cast(denominator, tensor.dtype)

    return tf.nest.map_structure(normalize, value)

  def encoded_mean_fn(state, values, weight):
    # Weight client values before the (encoded) sum, then normalize the
    # server-side sum by the total weight.
    scaled_values = intrinsics.federated_map(multiply_fn, [values, weight])
    new_state, decoded_sum = encoded_sum_fn(state, scaled_values)
    total_weight = intrinsics.federated_sum(weight)
    weighted_mean = intrinsics.federated_map(divide_fn,
                                             [decoded_sum, total_weight])
    return new_state, weighted_mean

  return computation_utils.StatefulAggregateFn(
      initialize_fn=init_fn, next_fn=encoded_mean_fn)
def test_fails_stateful_aggregate_and_process(self):
  """Supplying both an aggregate fn and a process must be rejected."""
  model_weights_type = model_utils.weights_type_from_model(
      model_examples.LinearRegression)

  def mean_next_fn(state, value, weight=None):
    return state, intrinsics.federated_mean(value, weight)

  with self.assertRaises(optimizer_utils.DisjointArgumentError):
    stateful_mean = computation_utils.StatefulAggregateFn(
        initialize_fn=lambda: (), next_fn=mean_next_fn)
    mean_process = optimizer_utils.build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
    optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        stateful_delta_aggregate_fn=stateful_mean,
        aggregation_process=mean_process)
def test_execute_with_explicit_weights(self):
  """Checks the weighted mean and state after one aggregation round."""
  stateful_fn = computation_utils.StatefulAggregateFn(
      initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)
  client_float_type = computation_types.FederatedType(tf.float32,
                                                      placements.CLIENTS)

  @computations.federated_computation(client_float_type, client_float_type)
  def federated_aggregate_test(args, weights):
    server_state = intrinsics.federated_value(stateful_fn.initialize(),
                                              placements.SERVER)
    return stateful_fn(server_state, args, weights)

  client_values = [1.0, 2.0, 3.0]
  client_weights = [4.0, 1.0, 1.0]
  new_state, mean = federated_aggregate_test(client_values, client_weights)
  # Weighted mean: (1*4 + 2*1 + 3*1) / (4 + 1 + 1) == 9 / 6.
  self.assertAlmostEqual(mean, 1.5)
  self.assertDictEqual(new_state._asdict(), {'call_count': 1})
model_output=self._model.report_local_outputs(), optimizer_output=collections.OrderedDict([('client_weight', client_weight)])) @computations.tf_computation(tf.int32) def _add_one(x): return x + 1 def _state_incrementing_mean_next(server_state, client_value, weight=None): new_state = intrinsics.federated_map(_add_one, server_state) return (new_state, intrinsics.federated_mean(client_value, weight=weight)) state_incrementing_mean = computation_utils.StatefulAggregateFn( lambda: tf.constant(0), _state_incrementing_mean_next) def _state_incrementing_broadcast_next(server_state, server_value): new_state = intrinsics.federated_map(_add_one, server_state) return (new_state, intrinsics.federated_broadcast(server_value)) state_incrementing_broadcaster = computation_utils.StatefulBroadcastFn( lambda: tf.constant(0), _state_incrementing_broadcast_next) def _build_test_measured_broadcast( model_weights_type: computation_types.StructType ) -> measured_process.MeasuredProcess: """Builds a test `MeasuredProcess` that has state and metrics."""
def build_dp_aggregate(query, value_type_fn=_default_get_value_type_fn):
  """Builds a stateful aggregator for tensorflow_privacy DPQueries.

  The returned `StatefulAggregateFn` can be called with any nested structure
  for the values being statefully aggregated. However, `value_type_fn` must be
  able to derive the `tff.Type` of the nested structure that will be used. If
  using a `collections.OrderedDict` as the value's nested structure, the
  default for this argument suffices.

  Weighted aggregation is not supported: the `weight` argument accepted by the
  aggregator's `next_fn` is ignored.

  Args:
    query: A DPQuery to aggregate. For compatibility with tensorflow_federated,
      the global_state and sample_state of the query must be structures
      supported by tf.nest.
    value_type_fn: Python function that takes the value argument of next_fn
      and returns the value type. The result is passed through
      `computation_types.to_type` and used in determining the TensorSpecs that
      establish the initial sample state. If the value being aggregated is a
      `collections.OrderedDict`, the default for this argument can be used.
      This argument probably gets removed once b/123092620 is addressed (and
      the associated processing step gets replaced with a simple call to
      `value.type_signature.member`).

  Returns:
    A tuple of:
      - a `computation_utils.StatefulAggregateFn` that aggregates according to
          the query
      - the TFF type of the DP aggregator's global state

  Warns:
    DeprecationWarning: Always; use `tff.utils.build_dp_aggregate_process()`
      instead.
  """
  warnings.warn(
      'Deprecation warning: tff.utils.build_dp_aggregate() is deprecated, use '
      'tff.utils.build_dp_aggregate_process() instead.', DeprecationWarning)

  @computations.tf_computation
  def initialize_fn():
    return query.initial_global_state()

  def next_fn(global_state, value, weight=None):
    """Defines next_fn for StatefulAggregateFn."""
    # Weighted aggregation is not supported.
    # TODO(b/140236959): Add an assertion that weight is None here, so the
    # contract of this method is better established. Will likely cause some
    # downstream breaks.
    del weight

    #######################################
    # Define local tf_computations

    # TODO(b/129567727): Make most of these tf_computations polymorphic
    # so type manipulation isn't needed.

    # These computations depend on the type of `value`, which is only known
    # here, so they are defined inside next_fn rather than once at build time.
    global_state_type = initialize_fn.type_signature.result

    @computations.tf_computation(global_state_type)
    def derive_sample_params(global_state):
      return query.derive_sample_params(global_state)

    @computations.tf_computation(
        derive_sample_params.type_signature.result, value.type_signature.member)
    def preprocess_record(params, record):
      return query.preprocess_record(params, record)

    # TODO(b/123092620): We should have the expected container type here.
    value_type = value_type_fn(value)
    value_type = computation_types.to_type(value_type)

    tensor_specs = type_conversions.type_to_tf_tensor_specs(value_type)

    @computations.tf_computation
    def zero():
      return query.initial_sample_state(tensor_specs)

    sample_state_type = zero.type_signature.result

    @computations.tf_computation(sample_state_type,
                                 preprocess_record.type_signature.result)
    def accumulate(sample_state, preprocessed_record):
      return query.accumulate_preprocessed_record(sample_state,
                                                  preprocessed_record)

    @computations.tf_computation(sample_state_type, sample_state_type)
    def merge(sample_state_1, sample_state_2):
      return query.merge_sample_states(sample_state_1, sample_state_2)

    @computations.tf_computation(merge.type_signature.result)
    def report(sample_state):
      # Identity: the merged sample state is reported unchanged.
      return sample_state

    @computations.tf_computation(sample_state_type, global_state_type)
    def post_process(sample_state, global_state):
      result, new_global_state = query.get_noised_result(
          sample_state, global_state)
      # next_fn's contract is (new_state, result), so swap the order.
      return new_global_state, result

    #######################################
    # Orchestration logic

    sample_params = intrinsics.federated_map(derive_sample_params, global_state)
    client_sample_params = intrinsics.federated_broadcast(sample_params)
    preprocessed_record = intrinsics.federated_map(
        preprocess_record, (client_sample_params, value))
    agg_result = intrinsics.federated_aggregate(preprocessed_record, zero(),
                                                accumulate, merge, report)

    return intrinsics.federated_map(post_process, (agg_result, global_state))

  # TODO(b/140236959): Find a way to have this method return only one thing. The
  # best approach is probably to add (to StatefulAggregateFn) a property that
  # stores the type of the global state.
  aggregate_fn = computation_utils.StatefulAggregateFn(
      initialize_fn=initialize_fn, next_fn=next_fn)
  return (aggregate_fn, initialize_fn.type_signature.result)
def federated_aggregate_test(values, weights):
  """Applies a fresh `StatefulAggregateFn` once, starting from its init state.

  The aggregator is built from `agg_initialize_fn`/`agg_next_fn`; its initial
  state is placed at `tff.SERVER` and then aggregated with `values` and
  `weights`.
  """
  stateful_fn = computation_utils.StatefulAggregateFn(
      initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)
  server_state = tff.federated_value(stateful_fn.initialize(), tff.SERVER)
  return stateful_fn(server_state, values, weights)
def build_dp_aggregate(query,
                       value_type_fn=_default_get_value_type_fn,
                       from_anon_tuple_fn=_default_from_anon_tuple_fn):
  """Builds a stateful aggregator for tensorflow_privacy DPQueries.

  The returned StatefulAggregateFn can be called with any nested structure for
  the values being statefully aggregated. However, it's necessary to provide
  two functions as arguments which indicate the properties (the tff.Type and
  the AnonymousTuple conversion) of the nested structure that will be used. If
  using an OrderedDict as the value's nested structure, the defaults for the
  arguments suffice.

  Weighted aggregation is not supported: the `weight` argument accepted by the
  aggregator's next_fn is ignored.

  Args:
    query: A DPQuery to aggregate. For compatibility with tensorflow_federated,
      the global_state and sample_state of the query must be structures
      supported by tf.nest.
    value_type_fn: Python function that takes the value argument of next_fn
      and returns the value type. This will be used in determining the
      TensorSpecs that establish the initial sample state. If the value being
      aggregated is an OrderedDict, the default for this argument can be used.
      This argument probably gets removed once b/123092620 is addressed (and
      the associated processing step gets replaced with a simple call to
      value.type_signature.member).
    from_anon_tuple_fn: Python function that takes a client record and
      converts it to the container type that it was in before passing through
      TFF. (Right now, TFF computation causes the client record to be changed
      into an AnonymousTuple, and this method corrects for that). If the value
      being aggregated is an OrderedDict, the default for this argument can be
      used. This argument likely goes away once b/123092620 is addressed. The
      default behavior assumes that the client record (before being converted
      to AnonymousTuple) was an OrderedDict containing a flat structure of
      Tensors (as it is if using the tff.learning APIs like
      tff.learning.build_federated_averaging_process).

  Returns:
    A tuple of:
      - a `computation_utils.StatefulAggregateFn` that aggregates according to
          the query
      - the TFF type of the DP aggregator's global state
  """

  @tff.tf_computation
  def initialize_fn():
    return query.initial_global_state()

  def next_fn(global_state, value, weight=None):
    """Defines next_fn for StatefulAggregateFn."""
    # Weighted aggregation is not supported.
    # TODO(b/140236959): Add an assertion that weight is None here, so the
    # contract of this method is better established. Will likely cause some
    # downstream breaks.
    del weight

    #######################################
    # Define local tf_computations

    # TODO(b/129567727): Make most of these tf_computations polymorphic
    # so type manipulation isn't needed.

    # These computations depend on the type of `value`, which is only known
    # here, so they are defined inside next_fn. Each one's type signature is
    # derived from the previous ones, so definition order matters.
    global_state_type = initialize_fn.type_signature.result

    @tff.tf_computation(global_state_type)
    def derive_sample_params(global_state):
      return query.derive_sample_params(global_state)

    @tff.tf_computation(derive_sample_params.type_signature.result,
                        value.type_signature.member)
    def preprocess_record(params, record):
      # TODO(b/123092620): Once TFF passes the expected container type (instead
      # of AnonymousTuple), we shouldn't need this.
      record = from_anon_tuple_fn(record)
      return query.preprocess_record(params, record)

    # TODO(b/123092620): We should have the expected container type here.
    value_type = value_type_fn(value)

    tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

    @tff.tf_computation
    def zero():
      return query.initial_sample_state(tensor_specs)

    sample_state_type = zero.type_signature.result

    @tff.tf_computation(sample_state_type,
                        preprocess_record.type_signature.result)
    def accumulate(sample_state, preprocessed_record):
      return query.accumulate_preprocessed_record(
          sample_state, preprocessed_record)

    @tff.tf_computation(sample_state_type, sample_state_type)
    def merge(sample_state_1, sample_state_2):
      return query.merge_sample_states(sample_state_1, sample_state_2)

    @tff.tf_computation(merge.type_signature.result)
    def report(sample_state):
      # Identity: the merged sample state is reported unchanged.
      return sample_state

    @tff.tf_computation(sample_state_type, global_state_type)
    def post_process(sample_state, global_state):
      result, new_global_state = query.get_noised_result(
          sample_state, global_state)
      # next_fn's contract is (new_state, result), so swap the order.
      return new_global_state, result

    #######################################
    # Orchestration logic

    # NOTE(review): `tff.federated_apply` presumably applies the computation to
    # the SERVER-placed global state; confirm against the TFF version in use.
    sample_params = tff.federated_apply(derive_sample_params, global_state)
    client_sample_params = tff.federated_broadcast(sample_params)
    preprocessed_record = tff.federated_map(preprocess_record,
                                            (client_sample_params, value))
    agg_result = tff.federated_aggregate(preprocessed_record, zero(),
                                         accumulate, merge, report)

    return tff.federated_apply(post_process, (agg_result, global_state))

  # TODO(b/140236959): Find a way to have this method return only one thing. The
  # best approach is probably to add (to StatefulAggregateFn) a property that
  # stores the type of the global state.
  aggregate_fn = computation_utils.StatefulAggregateFn(
      initialize_fn=initialize_fn, next_fn=next_fn)
  return (aggregate_fn, initialize_fn.type_signature.result)