def test_run_encoded_mean(self):
  # With the identity encoder, the encoded weighted mean must equal the plain
  # weighted mean of the client values.
  x = np.array([0.0, 1.0, 2.0, -1.0])
  x_spec = tf.TensorSpec(x.shape, tf.as_dtype(x.dtype))
  x_type = tff.to_type(x_spec)
  identity_encoder = te.encoders.as_gather_encoder(te.encoders.identity(),
                                                   x_spec)
  mean_fn = encoding_utils.build_encoded_mean(x, identity_encoder)
  server_state = mean_fn.initialize()

  @tff.federated_computation(
      tff.FederatedType(mean_fn._initialize_fn.type_signature.result,
                        tff.SERVER),
      tff.FederatedType(x_type, tff.CLIENTS),
      tff.FederatedType(tff.to_type(tf.float32), tff.CLIENTS))
  def call_gather(state, value, weight):
    return mean_fn(state, value, weight)

  # Each case: (client values, client weights, expected weighted mean).
  cases = [
      ([x, x], [1.0, 1.0], 1 * x),
      ([x, x], [0.3, 0.7], 1 * x),
      ([x, 2 * x], [1.0, 2.0], 5 / 3 * x),
  ]
  for client_values, client_weights, expected_mean in cases:
    _, actual_mean = call_gather(server_state, client_values, client_weights)
    self.assertAllClose(expected_mean, actual_mean)
def build_encoded_mean(values, encoders):
  """Creates a `StatefulAggregateFn` that computes an encoded weighted mean.

  Each client value is multiplied by its client weight. The weighted values
  are combined with the encoded sum built from `encoders`, and the server then
  divides that sum by the sum of the client weights.

  Args:
    values: The values the resulting function will aggregate. They must be
      convertible to `tff.Value`.
    encoders: A collection of `GatherEncoder` objects, structured exactly like
      `values`, that determine how each value is encoded.

  Returns:
    A `StatefulAggregateFn` whose `next_fn` takes `(state, values, weight)`.
    It encodes the weighted input at `tff.CLIENTS` and returns the updated
    state together with the weighted mean at `tff.SERVER`. The decoding is
    split automatically according to whether it commutes with summation.

  Raises:
    ValueError: If `values` and `encoders` are not structured the same way.
    TypeError: If `encoders` are not `GatherEncoder` instances, or if `values`
      do not match the input the `encoders` expect.
  """
  tf.nest.assert_same_structure(values, encoders)
  tf.nest.map_structure(
      lambda enc, val: _validate_encoder(enc, val, te.core.GatherEncoder),
      encoders, values)

  values_type = tff_framework.type_from_tensors(values)
  initialize_fn = _build_initial_state_tf_computation(encoders)
  gather_computations = _build_tf_computations_for_gather(
      initialize_fn.type_signature.result, values_type, encoders)
  encoded_sum = _build_encoded_sum_fn(gather_computations)

  weight_type = tff.to_type(tf.float32)

  @tff.tf_computation(values_type, weight_type)
  def multiply_fn(value, weight):

    def _scale(tensor):
      return tensor * tf.cast(weight, tensor.dtype)

    return tf.nest.map_structure(_scale, value)

  @tff.tf_computation(values_type, weight_type)
  def divide_fn(value, denominator):

    def _normalize(tensor):
      return tensor / tf.cast(denominator, tensor.dtype)

    return tf.nest.map_structure(_normalize, value)

  # NOTE: the `values` parameter intentionally shadows the outer argument;
  # the parameter names of this function are kept unchanged.
  def encoded_mean_fn(state, values, weight):
    weighted = tff.federated_map(multiply_fn, [values, weight])
    new_state, weighted_sum = encoded_sum(state, weighted)
    total_weight = tff.federated_sum(weight)
    mean = tff.federated_map(divide_fn, [weighted_sum, total_weight])
    return new_state, mean

  return StatefulAggregateFn(
      initialize_fn=initialize_fn, next_fn=encoded_mean_fn)
def test_build_encoded_sum(self, value_constructor, encoder_constructor):
  # The encoded sum must keep its state at the server and deliver a value of
  # the input type at the server.
  x = value_constructor(np.random.rand(20))
  x_spec = tf.TensorSpec(x.shape, tf.as_dtype(x.dtype))
  x_type = tff.to_type(x_spec)
  gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), x_spec)
  sum_fn = encoding_utils.build_encoded_sum(x, gather_encoder)
  state_type = sum_fn._initialize_fn.type_signature.result
  wrapped = tff.federated_computation(
      sum_fn._next_fn,
      tff.FederatedType(state_type, tff.SERVER),
      tff.FederatedType(x_type, tff.CLIENTS),
      tff.FederatedType(tff.to_type(tf.float32), tff.CLIENTS))
  result_type = wrapped.type_signature.result

  self.assertIsInstance(sum_fn, StatefulAggregateFn)
  state_out = result_type[0]
  self.assertEqual(state_type, state_out.member)
  self.assertEqual(tff.SERVER, state_out.placement)
  sum_out = result_type[1]
  self.assertEqual(x_type, sum_out.member)
  self.assertEqual(tff.SERVER, sum_out.placement)
def test_build_encoded_broadcast(self, value_constructor, encoder_constructor):
  # The broadcast takes a server-placed value and returns a client-placed
  # value of the same member type.
  x = value_constructor(np.random.rand(20))
  x_spec = tf.TensorSpec(x.shape, tf.as_dtype(x.dtype))
  x_type = tff.to_type(x_spec)
  simple_encoder = te.core.SimpleEncoder(encoder_constructor(), x_spec)
  broadcast_fn = encoding_utils.build_encoded_broadcast(x, simple_encoder)
  signature = broadcast_fn._next_fn.type_signature

  self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)
  value_in = signature.parameter[1]
  value_out = signature.result[1]
  self.assertEqual(x_type, value_in.member)
  self.assertEqual(x_type, value_out.member)
  self.assertEqual(value_in.placement, tff.SERVER)
  self.assertEqual(value_out.placement, tff.CLIENTS)
def _extract_intrinsic_as_reference_to_top_level_lambda(comp, uri):
  """Extracts the intrinsic called with `uri` from `comp` as a reference.

  The single call to the intrinsic identified by `uri` is replaced inside
  `comp` by a reference with a freshly generated unique name. The extracted
  call is then passed to `_insert_comp_in_top_level_lambda` under that name,
  presumably so the top-level lambda binds it.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: The URI of the intrinsic to extract.

  Returns:
    A tuple `(transformed_comp, True)`. The second element is always `True`,
    because this function either applies the transformation or raises.

  Raises:
    ValueError: If `comp` does not contain exactly one call to the intrinsic
      with the given `uri` (zero or more than one), or if that intrinsic is
      not exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)

  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  if not _are_comps_bound_exclusively_by_top_level_lambda(comp, intrinsics):
    raise ValueError(
        'Expected a computation which binds all the references in the '
        'intrinsic with the uri: {}.'.format(uri))

  name_generator = tff_framework.unique_name_generator(comp)
  extracted_intrinsic = intrinsics[0]
  ref = tff_framework.Reference(
      six.next(name_generator),
      tff.to_type(extracted_intrinsic.type_signature))

  def _transform(inner_comp):
    # Swap the call to the intrinsic (verified above to be unique) for `ref`.
    if not tff_framework.is_called_intrinsic(inner_comp, uri):
      return inner_comp, False
    return ref, True

  comp, _ = tff_framework.transform_postorder(comp, _transform)
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_intrinsic)
  return comp, True
def test_build_encoded_broadcast(self, value_constructor, encoder_constructor):
  # The encoded broadcast keeps its state at the server and moves a value of
  # the input type from the server to the clients.
  x = value_constructor(np.random.rand(20))
  x_spec = tf.TensorSpec(x.shape, tf.as_dtype(x.dtype))
  x_type = tff.to_type(x_spec)
  simple_encoder = te.encoders.as_simple_encoder(encoder_constructor(), x_spec)
  broadcast_fn = encoding_utils.build_encoded_broadcast(x, simple_encoder)
  state_type = broadcast_fn._initialize_fn.type_signature.result
  wrapped = tff.federated_computation(
      broadcast_fn._next_fn,
      tff.FederatedType(state_type, tff.SERVER),
      tff.FederatedType(x_type, tff.SERVER))
  result_type = wrapped.type_signature.result

  self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)
  state_out = result_type[0]
  self.assertEqual(state_type, state_out.member)
  self.assertEqual(tff.SERVER, state_out.placement)
  value_out = result_type[1]
  self.assertEqual(x_type, value_out.member)
  self.assertEqual(tff.CLIENTS, value_out.placement)
def test_build_encode_decode_tf_computations_for_broadcast(
    self, encoder_constructor):
  spec = tf.TensorSpec((20,), tf.float32)
  simple_encoder = te.encoders.as_simple_encoder(encoder_constructor(), spec)
  _, state_type = encoding_utils._build_initial_state_tf_computation(
      simple_encoder)
  value_type = tff.to_type(spec)
  encode_fn, decode_fn = (
      encoding_utils._build_encode_decode_tf_computations_for_broadcast(
          state_type, value_type, simple_encoder))
  encode_sig = encode_fn.type_signature
  decode_sig = decode_fn.type_signature

  # The encoder state enters and leaves `encode_fn` with the same type.
  self.assertEqual(state_type, encode_sig.parameter[0])
  self.assertEqual(state_type, encode_sig.result[0])
  # Whatever `encode_fn` emits is exactly what `decode_fn` consumes.
  self.assertEqual(encode_sig.result[1], decode_sig.parameter)
  # Decoding round-trips back to the type that was encoded (`value_type`).
  self.assertEqual(value_type, encode_sig.parameter[1])
  self.assertEqual(value_type, decode_sig.result)
def test_build_tf_computations_for_sum(self, encoder_constructor):
  # Checks that the partial computations fit together: every relevant input
  # of one computation has the type of the output that feeds it.
  spec = tf.TensorSpec((20,), tf.float32)
  gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), spec)
  init_fn = encoding_utils._build_initial_state_tf_computation(gather_encoder)
  state_type = init_fn.type_signature.result
  value_type = tff.to_type(spec)
  computations = encoding_utils._build_tf_computations_for_gather(
      state_type, value_type, gather_encoder)

  # get_params: state -> (encode, decode-before-sum, decode-after-sum) params.
  get_params_sig = computations.get_params_fn.type_signature
  self.assertEqual(state_type, get_params_sig.parameter)
  encode_params_type = get_params_sig.result[0]
  decode_before_sum_params_type = get_params_sig.result[1]
  decode_after_sum_params_type = get_params_sig.result[2]

  # encode: consumes the value and the first two parameter groups.
  encode_sig = computations.encode_fn.type_signature
  self.assertEqual(value_type, encode_sig.parameter[0])
  self.assertEqual(encode_params_type, encode_sig.parameter[1])
  self.assertEqual(decode_before_sum_params_type, encode_sig.parameter[2])

  # zero: the accumulator carries the state update tensors emitted by encode.
  state_update_tensors_type = encode_sig.result[2]
  accumulator_type = computations.zero_fn.type_signature.result
  self.assertEqual(state_update_tensors_type,
                   accumulator_type.state_update_tensors)

  # accumulate: (accumulator, encode output) -> accumulator.
  accumulate_sig = computations.accumulate_fn.type_signature
  self.assertEqual(accumulator_type, accumulate_sig.parameter[0])
  self.assertEqual(encode_sig.result, accumulate_sig.parameter[1])
  self.assertEqual(accumulator_type, accumulate_sig.result)

  # merge: (accumulator, accumulator) -> accumulator.
  merge_sig = computations.merge_fn.type_signature
  self.assertEqual(accumulator_type, merge_sig.parameter[0])
  self.assertEqual(accumulator_type, merge_sig.parameter[1])
  self.assertEqual(accumulator_type, merge_sig.result)

  # report: accumulator -> accumulator.
  report_sig = computations.report_fn.type_signature
  self.assertEqual(accumulator_type, report_sig.parameter)
  self.assertEqual(accumulator_type, report_sig.result)

  # decode_after_sum: (summed values, decode-after-sum params) -> value.
  decode_sig = computations.decode_after_sum_fn.type_signature
  self.assertEqual(accumulator_type.values, decode_sig.parameter[0])
  self.assertEqual(decode_after_sum_params_type, decode_sig.parameter[1])
  self.assertEqual(value_type, decode_sig.result)

  # update_state: (state, state update tensors) -> state.
  update_sig = computations.update_state_fn.type_signature
  self.assertEqual(state_type, update_sig.parameter[0])
  self.assertEqual(state_update_tensors_type, update_sig.parameter[1])
  self.assertEqual(state_type, update_sig.result)
def __init__(self, initialize, prepare, work, zero, accumulate, merge, report,
             update):
  """Builds a MapReduce-style description of an iterative process.

  NOTE: Every argument must be a plain TensorFlow computation, that is, a
  `tff.Computation` produced by the `tff.tf_computation` decorator/wrapper.

  Args:
    initialize: Computation producing the initial server state.
    prepare: Computation taking the server state (the result type of
      `initialize`) and producing the input for the clients.
    work: Client-side computation. Its parameter is a two-tuple whose second
      element is the result of `prepare`. It returns a two-tuple whose first
      element is what `accumulate` folds in.
    zero: Computation producing the initial accumulator.
    accumulate: Computation of type `(accumulator, work_result[0]) ->
      accumulator` that adds one client update to an accumulator.
    merge: Computation combining two accumulators into one.
    report: Computation turning the top-level accumulator into the global
      update.
    update: Computation taking `(server_state, global_update)` and returning a
      two-tuple whose first element is the new server state and whose second
      is the server-side output.

  Raises:
    TypeError: If the Python or TFF types of the arguments are invalid or
      mutually incompatible.
    AssertionError: If TFF's internal representation of a computation is not
      a `computation_pb2.Computation` proto (an internal error that requires
      a code update).
  """
  named_comps = (
      ('initialize', initialize),
      ('prepare', prepare),
      ('work', work),
      ('zero', zero),
      ('accumulate', accumulate),
      ('merge', merge),
      ('report', report),
      ('update', update),
  )
  for label, candidate in named_comps:
    py_typecheck.check_type(candidate, tff.Computation, label)
    # TODO(b/130633916): Remove private access once an appropriate API for it
    # becomes available.
    proto = candidate._computation_proto  # pylint: disable=protected-access
    if not isinstance(proto, computation_pb2.Computation):
      # Raised explicitly (not via `assert`) so it also fires in non-debug
      # mode.
      raise AssertionError('Cannot find the embedded computation definition.')
    kind = proto.WhichOneof('computation')
    if kind != 'tensorflow':
      raise TypeError('Expected all computations supplied as arguments to '
                      'be plain TensorFlow, found {}.'.format(kind))

  server_state_type = initialize.type_signature.result

  # initialize -> prepare
  prepare_param = prepare.type_signature.parameter
  if prepare_param != server_state_type:
    raise TypeError(
        'The `prepare` computation expects an argument of type {}, '
        'which does not match the result type {} of `initialize`.'.format(
            prepare_param, server_state_type))

  # prepare -> work (second element of work's two-tuple parameter)
  work_param = work.type_signature.parameter
  if (not isinstance(work_param, tff.NamedTupleType) or
      len(work_param) != 2):
    raise TypeError(
        'The `work` computation expects an argument of type {} that is not '
        'a two-tuple.'.format(work_param))
  prepare_result = prepare.type_signature.result
  if work_param[1] != prepare_result:
    raise TypeError(
        'The `work` computation expects an argument tuple with type {} as '
        'the second element (the initial client state from the server), '
        'which does not match the result type {} of `prepare`.'.format(
            work_param[1], prepare_result))

  work_result = work.type_signature.result
  if (not isinstance(work_result, tff.NamedTupleType) or
      len(work_result) != 2):
    raise TypeError(
        'The `work` computation returns a result of type {} that is not a '
        'two-tuple.'.format(work_result))

  # zero + work -> accumulate
  accumulator_type = zero.type_signature.result
  expected_accumulate_type = tff.FunctionType(
      [accumulator_type, work_result[0]], accumulator_type)
  if accumulate.type_signature != expected_accumulate_type:
    raise TypeError(
        'The `accumulate` computation has type signature {}, which does '
        'not match the expected {} as implied by the type signatures of '
        '`zero` and `work`.'.format(accumulate.type_signature,
                                    expected_accumulate_type))

  # accumulate -> merge
  accumulated_type = accumulate.type_signature.result
  expected_merge_type = tff.FunctionType(
      [accumulated_type, accumulated_type], accumulated_type)
  if merge.type_signature != expected_merge_type:
    raise TypeError(
        'The `merge` computation has type signature {}, which does '
        'not match the expected {} as implied by the type signature '
        'of `accumulate`.'.format(merge.type_signature, expected_merge_type))

  # merge -> report
  report_param = report.type_signature.parameter
  merge_result = merge.type_signature.result
  if report_param != merge_result:
    raise TypeError(
        'The `report` computation expects an argument of type {}, '
        'which does not match the result type {} of `merge`.'.format(
            report_param, merge_result))

  # (initialize, report) -> update
  expected_update_parameter_type = tff.to_type(
      [server_state_type, report.type_signature.result])
  update_param = update.type_signature.parameter
  if update_param != expected_update_parameter_type:
    raise TypeError(
        'The `update` computation expects an argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signatures of `initialize` and `report`.'.format(
            update_param, expected_update_parameter_type))

  update_result = update.type_signature.result
  if (not isinstance(update_result, tff.NamedTupleType) or
      len(update_result) != 2):
    raise TypeError(
        'The `update` computation returns a result of type {} that is not '
        'a two-tuple.'.format(update_result))
  if update_result[0] != server_state_type:
    raise TypeError(
        'The `update` computation returns a result tuple with type {} as '
        'the first element (the updated state of the server), which does '
        'not match the result type {} of `initialize`.'.format(
            update_result[0], server_state_type))

  self._initialize = initialize
  self._prepare = prepare
  self._work = work
  self._zero = zero
  self._accumulate = accumulate
  self._merge = merge
  self._report = report
  self._update = update