def test_constructor_with_type_mismatch(self):
  """Checks the TypeErrors raised for mismatched initialize/next types."""
  init_comp = _build_initialize_comp(0)

  # The float32 state parameter of `next_fn` cannot accept what
  # `initialize_fn` produces.
  with self.assertRaisesRegex(
      TypeError, r'The return type of initialize_fn must be assignable.*'):

    @computations.federated_computation(tf.float32, tf.float32)
    def add_float32(current, val):
      return current + val

    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_float32)

  # A float32 result cannot be fed back as the int32 state.
  with self.assertRaisesRegex(
      TypeError,
      'The return type of next_fn must be assignable to the first parameter'):

    @computations.federated_computation(tf.int32)
    def add_bad_result(_):
      return 0.0

    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_bad_result)

  # Same mismatch, with the float32 as the first element of a tuple result.
  with self.assertRaisesRegex(
      TypeError,
      'The return type of next_fn must be assignable to the first parameter'):

    @computations.federated_computation(tf.int32)
    def add_bad_multi_result(_):
      return 0.0, 0

    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_bad_multi_result)

  # A scalar (non-tuple) result is rejected.
  with self.assertRaisesRegex(
      TypeError, 'MeasuredProcess must return a NamedTupleType'):

    @computations.federated_computation(tf.int32)
    def add_not_tuple_result(_):
      return 0

    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_not_tuple_result)

  # An unnamed 3-tuple does not carry the required field names.
  with self.assertRaisesRegex(
      TypeError,
      'must match type signature <state=A,result=B,measurements=C>'):

    @computations.federated_computation(tf.int32)
    def add_not_named_tuple_result(_):
      return 0, 0, 0

    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_not_named_tuple_result)
def test_constructor_with_initialize_bad_type(self):
  """Checks that an invalid `initialize_fn` is rejected."""
  # `None` is not a Computation at all.
  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    measured_process.MeasuredProcess(initialize_fn=None, next_fn=add_int32)

  # An `initialize_fn` that declares a parameter is not allowed.
  with self.assertRaises(iterative_process.InitializeFnHasArgsError):

    @computations.federated_computation(tf.int32)
    def init_with_param(unused_arg):
      del unused_arg  # Unused.
      return values.to_value(0)

    measured_process.MeasuredProcess(
        initialize_fn=init_with_param, next_fn=add_int32)
def test_next_return_namedtuple_raises(self):
  """A look-alike namedtuple is not accepted as a MeasuredProcessOutput."""
  # Same name and fields as the real class, but a different Python container.
  lookalike_output = collections.namedtuple(
      'MeasuredProcessOutput', ['state', 'result', 'measurements'])

  def _next(state):
    return lookalike_output(state, (), ())

  next_fn = computations.tf_computation(tf.int32)(_next)
  with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
    measured_process.MeasuredProcess(test_initialize_fn, next_fn)
def build_encoded_broadcast_process(value_type, encoders):
  """Builds `MeasuredProcess` for `value_type`, to be encoded by `encoders`.

  The returned `MeasuredProcess` has a next function with the TFF type
  signature:

  ```
  (<state_type@SERVER, value_type@SERVER> ->
   <state=state_type@SERVER, result=value_type@CLIENTS,
    measurements=()@SERVER>)
  ```

  Args:
    value_type: The type of values to be broadcasted by the `MeasuredProcess`.
      Either a `tff.TensorType` or a `tff.StructType`.
    encoders: A collection of `SimpleEncoder` objects to be used for encoding
      values of `value_type`. Must have the same structure as `value_type`.

  Returns:
    A `MeasuredProcess` of which `next_fn` encodes the input at `tff.SERVER`,
    broadcasts the encoded representation and decodes the encoded
    representation at `tff.CLIENTS`.

  Raises:
    ValueError: If `value_type` and `encoders` do not have the same structure.
    TypeError: If `value_type` is not a `tff.TensorType` or `tff.StructType`,
      if `encoders` are not instances of `SimpleEncoder`, or if `value_type`
      is not compatible with the expected input of the `encoders`.
  """
  py_typecheck.check_type(
      value_type, (computation_types.TensorType, computation_types.StructType))

  _validate_value_type_and_encoders(value_type, encoders,
                                    tensor_encoding.core.SimpleEncoder)

  initial_state_fn, state_type = _build_initial_state_tf_computation(
      encoders)

  @computations.federated_computation()
  def initial_state_comp():
    return intrinsics.federated_eval(initial_state_fn, placements.SERVER)

  encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
      state_type, value_type, encoders)

  # The value to broadcast enters at SERVER; the decoded copy ends at CLIENTS.
  @computations.federated_computation(
      initial_state_comp.type_signature.result,
      computation_types.FederatedType(value_type, placements.SERVER))
  def encoded_broadcast_comp(state, value):
    """Encoded broadcast federated_computation."""
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    # Encoding happens at SERVER and also advances the encoder state.
    new_state, encoded_value = intrinsics.federated_map(
        encode_fn, (state, value))
    client_encoded_value = intrinsics.federated_broadcast(encoded_value)
    client_value = intrinsics.federated_map(decode_fn, client_encoded_value)
    return measured_process.MeasuredProcessOutput(
        state=new_state, result=client_value, measurements=empty_metrics)

  return measured_process.MeasuredProcess(
      initialize_fn=initial_state_comp, next_fn=encoded_broadcast_comp)
def test_measured_process_output_as_state_raises(self):
  """Using a MeasuredProcessOutput as the process state must be rejected."""

  def _empty_output():
    return MeasuredProcessOutput((), (), ())

  # The initial state is itself a MeasuredProcessOutput.
  initialize_fn = computations.tf_computation(_empty_output)

  @computations.tf_computation(initialize_fn.type_signature.result)
  def next_fn(state):
    del state  # Unused.
    return _empty_output()

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(initialize_fn, next_fn)
def test_construction_with_empty_state_does_not_raise(self):
  """Constructing a MeasuredProcess whose state is `()` must succeed."""
  initialize_fn = computations.tf_computation()(lambda: ())
  next_fn = computations.tf_computation(
      ())(lambda x: MeasuredProcessOutput(x, (), ()))
  try:
    measured_process.MeasuredProcess(initialize_fn, next_fn)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` (not a bare `except:`) so KeyboardInterrupt and
    # SystemExit still propagate, and report the underlying error.
    self.fail('Could not construct a MeasuredProcess with empty state: '
              f'{e}')
def test_federated_init_state_not_assignable(self):
  """A SERVER-placed initial state cannot feed a CLIENTS-placed parameter."""

  def _server_zero():
    return intrinsics.federated_value(0, placements.SERVER)

  initialize_fn = computations.federated_computation()(_server_zero)
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @computations.federated_computation(clients_int)
  def next_fn(state):
    return MeasuredProcessOutput(state, _server_zero(), _server_zero())

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(initialize_fn, next_fn)
def build_encoded_sum_process(value_type, encoders):
  """Builds `MeasuredProcess` for `value_type`, to be encoded by `encoders`.

  The returned `MeasuredProcess` has a next function with the TFF type
  signature:

  ```
  (<state_type@SERVER, {value_type}@CLIENTS> ->
   <state=state_type@SERVER, result=value_type@SERVER,
    measurements=()@SERVER>)
  ```

  Args:
    value_type: The type of values to be encoded by the `MeasuredProcess`.
      Either a `tff.TensorType` or a `tff.StructType`.
    encoders: A collection of `GatherEncoder` objects to be used for encoding
      values of `value_type`. Must have the same structure as `value_type`.

  Returns:
    A `MeasuredProcess` of which `next_fn` encodes the input at `tff.CLIENTS`,
    and computes their sum at `tff.SERVER`, automatically splitting the
    decoding part based on its commutativity with sum.

  Raises:
    ValueError: If `value_type` and `encoders` do not have the same structure.
    TypeError: If `value_type` is not a `tff.TensorType` or `tff.StructType`,
      if `encoders` are not instances of `GatherEncoder`, or if `value_type`
      is not compatible with the expected input of the `encoders`.
  """
  py_typecheck.check_type(
      value_type, (computation_types.TensorType, computation_types.StructType))

  _validate_value_type_and_encoders(value_type, encoders,
                                    tensor_encoding.core.GatherEncoder)

  initial_state_fn, state_type = _build_initial_state_tf_computation(
      encoders)

  @computations.federated_computation()
  def initial_state_comp():
    return intrinsics.federated_eval(initial_state_fn, placements.SERVER)

  nest_encoder = _build_tf_computations_for_gather(state_type, value_type,
                                                   encoders)
  encoded_sum_fn = _build_encoded_sum_fn(nest_encoder)

  @computations.federated_computation(
      initial_state_comp.type_signature.result,
      computation_types.FederatedType(value_type, placements.CLIENTS))
  def encoded_sum_comp(state, values):
    """Encoded sum federated_computation."""
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    state, result = encoded_sum_fn(state, values)
    # `MeasuredProcess` validation may require a `MeasuredProcessOutput`
    # container; an OrderedDict with the same field names can be rejected.
    return measured_process.MeasuredProcessOutput(
        state=state, result=result, measurements=empty_metrics)

  return measured_process.MeasuredProcess(
      initialize_fn=initial_state_comp, next_fn=encoded_sum_comp)
def test_constructor_with_next_result_param_type_mismatch(self):
  """A float32 `next_fn` result cannot replace the int32 state."""
  init_comp = _build_initialize_comp(0)

  @computations.federated_computation(tf.int32)
  def add_bad_result(_):
    return 0.0

  with self.assertRaises(iterative_process.NextMustReturnStateError):
    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_bad_result)
def test_not_finalizer_type_raises(self):
  """A plain MeasuredProcess is rejected where a finalizer is expected."""
  finalizer = test_finalizer()
  # Same computations, re-wrapped as a generic MeasuredProcess.
  not_a_finalizer = measured_process.MeasuredProcess(
      finalizer.initialize, finalizer.next)
  with self.assertRaisesRegex(TypeError, 'FinalizerProcess'):
    composers.compose_learning_process(
        test_init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        not_a_finalizer,
    )
def _wrap_in_measured_process(
    stateful_fn: Union[computation_utils.StatefulBroadcastFn,
                       computation_utils.StatefulAggregateFn],
    input_type: computation_types.Type
) -> measured_process.MeasuredProcess:
  """Converts a legacy stateful function to a `tff.templates.MeasuredProcess`.

  Args:
    stateful_fn: A `computation_utils.StatefulBroadcastFn` or
      `computation_utils.StatefulAggregateFn` to wrap.
    input_type: The unplaced type of the value passed to `stateful_fn`. It is
      placed at `tff.SERVER` for a broadcast function and at `tff.CLIENTS`
      for an aggregate function.

  Returns:
    A `MeasuredProcess` whose `next_fn` delegates to `stateful_fn` and reports
    empty `()` measurements at `tff.SERVER`. For an aggregate function,
    `next_fn` additionally takes a float32 weight placed at `tff.CLIENTS`.

  Raises:
    TypeError: If `stateful_fn` is neither of the supported types.
  """
  py_typecheck.check_type(stateful_fn,
                          (computation_utils.StatefulBroadcastFn,
                           computation_utils.StatefulAggregateFn))

  @computations.federated_computation()
  def initialize_comp():
    # `stateful_fn.initialize` may be a plain Python callable; wrap it in a
    # TF computation so it can be evaluated at the server.
    if not isinstance(stateful_fn.initialize, computation_base.Computation):
      initialize = computations.tf_computation(stateful_fn.initialize)
    else:
      initialize = stateful_fn.initialize
    return intrinsics.federated_eval(initialize, placements.SERVER)

  state_type = initialize_comp.type_signature.result

  # Both branches return a `MeasuredProcessOutput` rather than an OrderedDict;
  # `MeasuredProcess` validation may require that container.
  if isinstance(stateful_fn, computation_utils.StatefulBroadcastFn):

    @computations.federated_computation(
        state_type,
        computation_types.FederatedType(input_type, placements.SERVER),
    )
    def next_comp(state, value):
      empty_metrics = intrinsics.federated_value((), placements.SERVER)
      state, result = stateful_fn(state, value)
      return measured_process.MeasuredProcessOutput(
          state=state, result=result, measurements=empty_metrics)

  elif isinstance(stateful_fn, computation_utils.StatefulAggregateFn):

    @computations.federated_computation(
        state_type,
        computation_types.FederatedType(input_type, placements.CLIENTS),
        computation_types.FederatedType(tf.float32, placements.CLIENTS))
    def next_comp(state, value, weight):
      empty_metrics = intrinsics.federated_value((), placements.SERVER)
      state, result = stateful_fn(state, value, weight)
      return measured_process.MeasuredProcessOutput(
          state=state, result=result, measurements=empty_metrics)

  else:
    # Defensive: `check_type` above already rejects any other type.
    raise TypeError(
        'Received a {t}, expected either a computation_utils.StatefulAggregateFn or a '
        'computation_utils.StatefulBroadcastFn.'.format(
            t=type(stateful_fn)))

  return measured_process.MeasuredProcess(
      initialize_fn=initialize_comp, next_fn=next_comp)
def test_constructor_with_state_only(self):
  """A state-only `next_fn` yields empty result and measurements."""
  process = measured_process.MeasuredProcess(
      _build_initialize_comp(0), count_int32)
  state = process.initialize()
  num_rounds = 10
  for _ in range(num_rounds):
    round_output = process.next(state)
    state, result, measurements = attr.astuple(round_output)
    self.assertLen(result, 0)
    self.assertLen(measurements, 0)
  self.assertEqual(state, num_rounds)
def test_federated_next_state_not_assignable(self):
  """A `next_fn` that moves its state to CLIENTS must be rejected."""

  @computations.federated_computation()
  def initialize_fn():
    return intrinsics.federated_value(0, placements.SERVER)

  @computations.federated_computation(initialize_fn.type_signature.result)
  def next_fn(state):
    # Broadcasting places the returned state at CLIENTS, which no longer
    # matches the SERVER-placed state parameter.
    clients_state = intrinsics.federated_broadcast(state)
    return MeasuredProcessOutput(clients_state, (), ())

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(initialize_fn, next_fn)
def test_constructor_with_init_next_type_mismatch(self):
  """A float32-state `next_fn` cannot follow an int32 initializer."""
  init_comp = _build_initialize_comp(0)

  @computations.federated_computation(tf.float32, tf.float32)
  def add_float32(current, val):
    return current + val

  with self.assertRaises(
      iterative_process.NextMustAcceptStateFromInitializeError):
    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_float32)
def test_constructor_with_init_next_type_mismatch(self):
  """A float32-state `next_fn` cannot follow an int32 initializer."""
  init_comp = _build_initialize_comp(0)

  @computations.federated_computation(tf.float32, tf.float32)
  def add_float32(current, val):
    return current + val

  expected_message = r'The return type of initialize_fn must be assignable.*'
  with self.assertRaisesRegex(TypeError, expected_message):
    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_float32)
def test_is_valid_broadcast_process_bad_placement(self):
  """Processes with the wrong placements are not valid broadcasts."""
  server_unit = computation_types.FederatedType((), placements.SERVER)
  clients_unit = computation_types.FederatedType((), placements.CLIENTS)

  # Case 1: `result` of `next` stays at the server.
  @federated_computation.federated_computation()
  def stateless_init():
    return intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(server_unit, server_unit)
  def fake_broadcast(state, value):
    server_metric = intrinsics.federated_value(1.0, placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=state, result=value, measurements=server_metric)

  process = measured_process.MeasuredProcess(
      initialize_fn=stateless_init, next_fn=fake_broadcast)
  self.assertFalse(optimizer_utils.is_valid_broadcast_process(process))

  # Case 2: the second parameter of `next` is placed at the clients.
  @federated_computation.federated_computation()
  def stateless_init2():
    return intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(server_unit, clients_unit)
  def stateless_broadcast(state, value):
    server_metric = intrinsics.federated_value(1.0, placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=state, result=value, measurements=server_metric)

  process = measured_process.MeasuredProcess(
      initialize_fn=stateless_init2, next_fn=stateless_broadcast)
  self.assertFalse(optimizer_utils.is_valid_broadcast_process(process))
def test_constructor_with_next_result_param_type_mismatch(self):
  """A float32 `next_fn` result cannot replace the int32 state."""
  init_comp = _build_initialize_comp(0)

  @computations.federated_computation(tf.int32)
  def add_bad_result(_):
    return 0.0

  with self.assertRaisesRegex(
      TypeError,
      'The return type of next_fn must be assignable to the first parameter'):
    measured_process.MeasuredProcess(
        initialize_fn=init_comp, next_fn=add_bad_result)
def test_constructor_with_state_tuple_arg(self):
  """Runs `add_int32` for several rounds, feeding state back each round."""
  process = measured_process.MeasuredProcess(
      _build_initialize_comp(0), add_int32)
  state = process.initialize()
  num_rounds = 10
  for round_value in range(num_rounds):
    output = process.next(state, round_value)
    state = output.state
  # Checks below inspect only the output of the final round.
  self.assertEqual(output.state, sum(range(num_rounds)))
  self.assertEqual(output.result, round_value)
  expected_measurement = sum(range(num_rounds - 1)) / num_rounds
  self.assertAllClose(output.measurements, [expected_measurement])
def test_constructor_with_next_result_not_measuredprocessoutput(self):
  """Results that are not MeasuredProcessOutput are rejected."""
  init_comp = _build_initialize_comp(0)

  @computations.federated_computation(tf.int32)
  def add_not_tuple_result(_):
    return 0

  @computations.federated_computation(tf.int32)
  def add_not_named_tuple_result(_):
    return 0, 0, 0

  # Neither a scalar nor an unnamed 3-tuple qualifies.
  for bad_next_fn in (add_not_tuple_result, add_not_named_tuple_result):
    with self.assertRaisesRegex(
        TypeError, 'MeasuredProcess must return a MeasuredProcessOutput'):
      measured_process.MeasuredProcess(
          initialize_fn=init_comp, next_fn=bad_next_fn)
def _create_test_measured_process_state_at_clients():
  """Returns a test `MeasuredProcess` whose state is placed at CLIENTS."""

  @federated_computation.federated_computation(
      computation_types.at_clients(tf.int32),
      computation_types.at_clients(tf.int32))
  def next_fn(state, values):
    # State passes through unchanged; the result is the server-side sum.
    summed = intrinsics.federated_sum(values)
    measurement = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, measurement)

  @federated_computation.federated_computation
  def initialize_fn():
    return intrinsics.federated_value(0, placements.CLIENTS)

  return measured_process.MeasuredProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
def test_construction_with_unknown_dimension_does_not_raise(self):
  """Parameter types with a statically unknown dimension are accepted."""
  initialize_fn = computations.tf_computation()(
      lambda: tf.constant([], dtype=tf.string))

  @computations.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def next_fn(strings):
    # Appending grows the vector on every call, hence the `None` dimension.
    return MeasuredProcessOutput(
        tf.concat([strings, tf.constant(['abc'])], axis=0), (), ())

  try:
    measured_process.MeasuredProcess(initialize_fn, next_fn)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` (not a bare `except:`) so KeyboardInterrupt and
    # SystemExit still propagate, and report the underlying error.
    self.fail('Could not construct a MeasuredProcess with parameter types '
              f'with statically unknown shape: {e}')
def test_measured_process_output_as_state_raises(self):
  """A state that is a zipped MeasuredProcessOutput must be rejected."""

  def _no_value():
    return intrinsics.federated_value((), placements.SERVER)

  @computations.federated_computation()
  def initialize_fn():
    return intrinsics.federated_zip(
        MeasuredProcessOutput(_no_value(), _no_value(), _no_value()))

  @computations.federated_computation(
      initialize_fn.type_signature.result, CLIENTS_FLOAT)
  def next_fn(state, value):
    del state, value  # Unused.
    return MeasuredProcessOutput(_no_value(), _no_value(), _no_value())

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(initialize_fn, next_fn)
def test_federated_measured_process_output_raises(self):
  """A zipped (federated) MeasuredProcessOutput result is rejected."""

  @computations.federated_computation()
  def initialize_fn():
    return intrinsics.federated_value(0, placements.SERVER)

  def _empty():
    return intrinsics.federated_value((), placements.SERVER)

  # `federated_zip` places the FederatedType at the top of the hierarchy.
  @computations.federated_computation(initialize_fn.type_signature.result)
  def next_fn(state):
    return intrinsics.federated_zip(
        MeasuredProcessOutput(state, _empty(), _empty()))

  # A MeasuredProcessOutput containing three `FederatedType`s is different
  # from a `FederatedType` containing a MeasuredProcessOutput. Currently,
  # only the former is considered valid.
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(initialize_fn, next_fn)
def build_stateless_mean(
    *, model_delta_type: Union[computation_types.StructType,
                               computation_types.TensorType]
) -> measured_process.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_mean`.

  Args:
    model_delta_type: The unplaced type of the client values to average.

  Returns:
    A `MeasuredProcess` whose `next_fn` takes `(state, value, weight)` with
    `value` and a float32 `weight` at `tff.CLIENTS`. It passes `state` through
    unchanged, returns the weighted mean at `tff.SERVER`, and reports empty
    `()` measurements at `tff.SERVER`.
  """
  clients_value_type = computation_types.FederatedType(
      model_delta_type, placements.CLIENTS)
  clients_weight_type = computation_types.FederatedType(
      tf.float32, placements.CLIENTS)

  @computations.federated_computation(
      NONE_SERVER_TYPE, clients_value_type, clients_weight_type)
  def stateless_mean(state, value, weight):
    no_measurements = intrinsics.federated_value((), placements.SERVER)
    weighted_mean = intrinsics.federated_mean(value, weight=weight)
    return measured_process.MeasuredProcessOutput(
        state=state, result=weighted_mean, measurements=no_measurements)

  return measured_process.MeasuredProcess(
      initialize_fn=_empty_server_initialization, next_fn=stateless_mean)
def test_constructor_with_tensors_unknown_dimensions_succeeds(self):
  """State tensors with an unknown (`None`) dimension are accepted."""

  @computations.tf_computation
  def init():
    return tf.constant([], dtype=tf.string)

  @computations.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def next_fn(strings):
    # Appending grows the vector on every call, hence the `None` dimension.
    return MeasuredProcessOutput(
        state=tf.concat([strings, tf.constant(['abc'])], axis=0),
        result=(),
        measurements=())

  try:
    measured_process.MeasuredProcess(init, next_fn)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` (not a bare `except:`) so KeyboardInterrupt and
    # SystemExit still propagate, and report the underlying error.
    self.fail(
        'Could not construct a MeasuredProcess with parameter types '
        f'including unknown dimension tensors: {e}')
def test_constructor_with_next_struct_of_different_placedresult(self):
  """A MeasuredProcessOutput whose fields have mixed placements is valid."""

  @computations.federated_computation
  def initialize_comp():
    return intrinsics.federated_value(0, placements.SERVER)

  # Each field keeps its own placement: `result` at CLIENTS, `state` and
  # `measurements` at SERVER.
  @computations.federated_computation(initialize_comp.type_signature.result)
  def next_comp(state):
    clients_result = intrinsics.federated_value(0, placements.CLIENTS)
    server_measurements = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=state, result=clients_result, measurements=server_measurements)

  try:
    measured_process.MeasuredProcess(
        initialize_fn=initialize_comp, next_fn=next_comp)
  except Exception as e:  # pylint: disable=broad-except
    self.fail(f'Failed to construct MeasuredProcess: {e}')
def build_stateless_broadcaster(
    *, model_weights_type: Union[computation_types.StructType,
                                 computation_types.TensorType]
) -> measured_process.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`.

  Args:
    model_weights_type: The unplaced type of the value to broadcast.

  Returns:
    A `MeasuredProcess` whose `next_fn` takes `(state, value)`, both placed at
    `tff.SERVER`. It passes `state` through unchanged, broadcasts `value` to
    `tff.CLIENTS`, and reports empty `()` measurements at `tff.SERVER`.
  """
  server_state_type = computation_types.FederatedType((), placements.SERVER)
  server_value_type = computation_types.FederatedType(
      model_weights_type, placements.SERVER)

  @computations.federated_computation(server_state_type, server_value_type)
  def stateless_broadcast(state, value):
    no_measurements = intrinsics.federated_value((), placements.SERVER)
    broadcast_value = intrinsics.federated_broadcast(value)
    return measured_process.MeasuredProcessOutput(
        state=state, result=broadcast_value, measurements=no_measurements)

  return measured_process.MeasuredProcess(
      initialize_fn=_empty_server_initialization, next_fn=stateless_broadcast)
def test_federated_evaluation_fails_stateful_broadcast(self):
  """Federated evaluation rejects a broadcast process that carries state."""

  # A stateful measured process that does nothing useful; its float32
  # server state is what makes it "stateful".
  @computations.federated_computation
  def init_fn():
    zeros_fn = computations.tf_computation(
        lambda: tf.zeros(shape=[], dtype=tf.float32))
    return intrinsics.federated_eval(zeros_fn, placements.SERVER)

  @computations.federated_computation(
      computation_types.at_server(tf.float32),
      computation_types.at_clients(tf.int32))
  def next_fn(state, value):
    # The state doubles as the measurements.
    return measured_process.MeasuredProcessOutput(
        state=state, result=value, measurements=state)

  broadcaster = measured_process.MeasuredProcess(init_fn, next_fn)
  with self.assertRaisesRegex(ValueError, 'stateful broadcast'):
    federated_evaluation.build_federated_evaluation(
        TestModelQuant, broadcast_process=broadcaster)
def test_constructor_with_next_federated_same_placed_struct_result(self):
  """A zipped (federated) struct result cannot be fed back as the state.

  All output fields share the SERVER placement and are zipped, so the
  `FederatedType` sits at the top of the result type hierarchy.
  """

  @computations.federated_computation
  def initialize_comp():
    return intrinsics.federated_value(0, placements.SERVER)

  @computations.federated_computation(initialize_comp.type_signature.result)
  def next_comp(state):
    # No trailing comma after the call: one would wrap the zipped value in a
    # 1-tuple, so the FederatedType would no longer be at the top of the
    # hierarchy and the test would not exercise the intended case.
    return intrinsics.federated_zip(
        measured_process.MeasuredProcessOutput(
            state=state,
            result=intrinsics.federated_value(0, placements.SERVER),
            measurements=intrinsics.federated_value((), placements.SERVER)))

  with self.assertRaises(iterative_process.NextMustReturnStateError):
    measured_process.MeasuredProcess(
        initialize_fn=initialize_comp, next_fn=next_comp)
def test_is_stateful_process_true(self):
  """A process with a float32 server state is reported as stateful."""

  @computations.federated_computation()
  def stateful_init():
    return intrinsics.federated_value(2.0, placements.SERVER)

  server_float = computation_types.FederatedType(tf.float32, placements.SERVER)
  server_unit = computation_types.FederatedType((), placements.SERVER)

  @computations.federated_computation(server_float, server_unit)
  def stateful_broadcast(state, value):
    server_metric = intrinsics.federated_value(1.0, placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=state,
        result=intrinsics.federated_broadcast(value),
        measurements=server_metric)

  process = measured_process.MeasuredProcess(
      initialize_fn=stateful_init, next_fn=stateful_broadcast)
  self.assertTrue(optimizer_utils.is_stateful_process(process))