def test_federated_init_state_not_assignable(self):
  # `initialize` yields an int32 placed at SERVER, while `next` is declared to
  # take its state at CLIENTS; the two state types cannot be reconciled.
  server_zero_init = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  clients_state_type = FederatedType(tf.int32, placements.CLIENTS)
  identity_next = federated_computation.federated_computation(
      clients_state_type)(lambda state: state)

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(server_zero_init, identity_next)
def test_federated_init_state_not_assignable(self):
  # Shared helper: a SERVER-placed int32 zero, reused for init and for the
  # result/measurements fields of `next`.
  make_server_zero = lambda: intrinsics.federated_value(0, placements.SERVER)
  init_comp = federated_computation.federated_computation()(make_server_zero)

  # `next` consumes CLIENTS-placed state even though init produces SERVER state.
  clients_state_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)
  next_comp = federated_computation.federated_computation(clients_state_type)(
      lambda state: MeasuredProcessOutput(state, make_server_zero(),
                                          make_server_zero()))

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(init_comp, next_comp)
def test_federated_next_state_not_assignable(self):
  init_comp = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  server_state_type = init_comp.type_signature.result

  # Using `federated_broadcast` directly as `next` makes it return the state
  # at a different placement than the one it accepts.
  broadcast_next = federated_computation.federated_computation(
      server_state_type)(intrinsics.federated_broadcast)

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(init_comp, broadcast_next)
def _constant_process(value):
  """Creates an `EstimationProcess` that reports a constant value.

  The process keeps an empty `()` state placed at SERVER. `next` returns that
  state unchanged and ignores its CLIENTS argument. `report` ignores the state
  and places `value` at SERVER.

  Args:
    value: The constant that `report` places at SERVER.

  Returns:
    An `estimation_process.EstimationProcess`.
  """
  empty_state_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value((), placements.SERVER))
  state_type = empty_state_init.type_signature.result

  # The lambda parameter is deliberately named `value` (it shadows the outer
  # argument): parameter names become part of the computation's signature.
  passthrough_next = federated_computation.federated_computation(
      lambda state, value: state,
      state_type,
      computation_types.at_clients(NORM_TF_TYPE))
  constant_report = federated_computation.federated_computation(
      lambda state: intrinsics.federated_value(value, placements.SERVER),
      state_type)

  return estimation_process.EstimationProcess(empty_state_init,
                                              passthrough_next,
                                              constant_report)
def test_federated_report_state_not_assignable(self):
  init_comp = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  server_state_type = init_comp.type_signature.result
  next_comp = federated_computation.federated_computation(server_state_type)(
      lambda state: state)

  # `report` expects CLIENTS-placed state, which init/next never produce.
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  report_comp = federated_computation.federated_computation(clients_int)(
      lambda state: state)

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    estimation_process.EstimationProcess(init_comp, next_comp, report_comp)
def federated_output_computation_from_metrics(
    metrics: List[tf.keras.metrics.Metric]
) -> federated_computation.federated_computation:
  """Builds a federated computation that aggregates Keras metrics.

  The result can evaluate Keras and non-Keras models alike, as long as they
  report Keras metrics. The aggregation is delegated to
  `base_utils.federated_aggregate_keras_metric`. That helper is documented
  (see `tff.learning.federated_aggregate_keras_metric`) as:
  - summing each metric's internal variables across clients;
  - rebuilding metrics from the summed variables;
  - calling `metric.result()` on each.

  Args:
    metrics: The `tf.keras.metrics.Metric` objects to aggregate.

  Returns:
    A federated computation whose single parameter is the metrics' variable
    structure placed at CLIENTS, and whose result is produced by
    `base_utils.federated_aggregate_keras_metric`.
  """
  # Snapshot the current metric variables only to derive their TensorSpecs,
  # which define the CLIENTS-placed parameter type.
  variable_specs = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                         read_metric_variables(metrics))
  client_outputs_type = computation_types.at_clients(variable_specs)

  def federated_output(local_outputs):
    return base_utils.federated_aggregate_keras_metric(metrics, local_outputs)

  return federated_computation.federated_computation(federated_output,
                                                     client_outputs_type)
def test_cardinality_free_data_descriptor_places_data(self):
  place_at_clients = federated_computation.federated_computation(
      lambda x: intrinsics.federated_value(x, placements.CLIENTS), tf.int32)
  ds = data_descriptor.CardinalityFreeDataDescriptor(
      place_at_clients, 1000, computation_types.TensorType(tf.int32))
  self.assertEqual(str(ds.type_signature), 'int32@CLIENTS')

  all_equal_clients_int = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)

  @federated_computation.federated_computation(all_equal_clients_int)
  def foo(x):
    return intrinsics.federated_sum(x)

  # This DataDescriptor does not specify its cardinality, so the number of
  # placed values comes from the executor's default client count; the sum is
  # therefore 1000 times that count.
  for num_clients, expected_sum in ((1, 1000), (3, 3000)):
    with executor_test_utils.install_executor(
        executor_test_utils.LocalTestExecutorFactory(
            default_num_clients=num_clients)):
      result = foo(ds)
    self.assertEqual(result, expected_sum)
def test_raises_computation_no_dataset_parameter(self):
  # The target computation's parameter type is built from `[tf.int32]`, which
  # contains no sequence for the dataset computation to be composed into.
  int_identity = federated_computation.federated_computation(
      lambda x: x, [tf.int32])
  compose = (
      iterative_process_compositions.compose_dataset_computation_with_computation)

  with self.assertRaises(
      iterative_process_compositions.SequenceTypeNotFoundError):
    compose(int_dataset_computation, int_identity)
def test_raises_on_bad_process_next_single_param(self, make_factory):
  # `next` accepts only the state, but the factory requires two arguments.
  single_arg_next = federated_computation.federated_computation(
      lambda state: state, _float_at_server)
  bad_norm_process = _test_norm_process(next_fn=single_arg_next)

  with self.assertRaisesRegex(TypeError, '.* must take two arguments.'):
    make_factory(bad_norm_process)
def _bind_federated_value(unused_input, input_type, federated_output_value):
  """Returns `federated_output_value` routed through a call on `unused_input`.

  Wraps a federated computation whose one parameter has type `input_type`
  placed at CLIENTS and whose body ignores that parameter and returns
  `federated_output_value`. It then invokes that computation on
  `unused_input`. NOTE(review): presumably used so the value is expressed as a
  function of the input inside a traced computation — confirm with callers.

  Args:
    unused_input: Argument passed to the wrapper; its value is ignored.
    input_type: Member type of the CLIENTS-placed wrapper parameter.
    federated_output_value: Value returned by the wrapper.

  Returns:
    The result of calling the wrapper on `unused_input`.
  """
  clients_input_type = computation_types.FederatedType(input_type,
                                                       placements.CLIENTS)
  return federated_computation.federated_computation(
      lambda _: federated_output_value, clients_input_type)(unused_input)
def test_raises_on_bad_process_next_two_outputs(self, make_factory):
  # `next` returns a `(state, state)` pair instead of the new state alone.
  pair_next = federated_computation.federated_computation(
      lambda state, val: (state, state), _float_at_server, _float_at_clients)
  bad_norm_process = _test_norm_process(next_fn=pair_next)

  with self.assertRaisesRegex(TypeError, 'Result type .* state only.'):
    make_factory(bad_norm_process)
def _test_float_next_fn(factor):
  """Builds a `next` computation that shifts the state by `factor * 1.0`.

  Args:
    factor: Amount added (as `factor * 1.0`) to the state on each call.

  Returns:
    A federated computation taking `(state, value)` with parameter types
    `_float_at_server` and `_float_at_clients`. It maps the shift over `state`
    and ignores `value`.
  """

  @tensorflow_computation.tf_computation
  def shift_one(x):
    # Multiplying by 1.0 makes the shift a float even for an integer factor.
    return x + factor * 1.0

  return federated_computation.federated_computation(
      lambda state, value: intrinsics.federated_map(shift_one, state),
      _float_at_server,
      _float_at_clients)
def test_raises_on_bad_process_next_not_float(self, make_factory):
  # The second (CLIENTS) argument of `next` is complex64 rather than float.
  complex_clients_type = computation_types.at_clients(tf.complex64)
  complex_value_next = federated_computation.federated_computation(
      lambda state, value: state, _float_at_server, complex_clients_type)
  bad_norm_process = _test_norm_process(next_fn=complex_value_next)

  with self.assertRaisesRegex(TypeError, 'Second argument .* assignable from'):
    make_factory(bad_norm_process)
def test_raises_on_bad_norm_process_result(self, value, placement,
                                           make_factory):
  # `report` ignores its state and places `value` at `placement`; the
  # (presumably parameterized) combinations are expected to be rejected as
  # norm results.
  bad_report = federated_computation.federated_computation(
      lambda s: intrinsics.federated_value(value, placement), _float_at_server)
  bad_norm_process = _test_norm_process(report_fn=bad_report)

  with self.assertRaisesRegex(TypeError, r'Result type .* assignable to'):
    make_factory(bad_norm_process)
def test_non_server_placed_init_state_raises(self):
  # The initial state is placed at CLIENTS rather than SERVER.
  clients_zero_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value(0, placements.CLIENTS))

  @federated_computation.federated_computation(CLIENTS_INT, CLIENTS_FLOAT)
  def next_fn(state, val):
    summed = intrinsics.federated_sum(val)
    return MeasuredProcessOutput(state, summed, server_zero())

  with self.assertRaises(aggregation_process.AggregationPlacementError):
    aggregation_process.AggregationProcess(clients_zero_init, next_fn)
def test_init_tuple_of_federated_types_raises(self):
  # Init returns a plain tuple of two federated values, not one federated value.
  tuple_state_init = federated_computation.federated_computation()(
      lambda: (server_zero(), server_zero()))
  tuple_state_type = tuple_state_init.type_signature.result

  @federated_computation.federated_computation(tuple_state_type, CLIENTS_FLOAT)
  def next_fn(state, val):
    summed = intrinsics.federated_sum(val)
    return MeasuredProcessOutput(state, summed, ())

  with self.assertRaises(aggregation_process.AggregationNotFederatedError):
    aggregation_process.AggregationProcess(tuple_state_init, next_fn)
def test_federated_next_state_not_assignable(self):
  init_comp = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  server_state_type = init_comp.type_signature.result

  @federated_computation.federated_computation(server_state_type)
  def next_fn(state):
    # Broadcasting changes the state's placement, so the returned state no
    # longer matches the accepted one.
    broadcast_state = intrinsics.federated_broadcast(state)
    return MeasuredProcessOutput(broadcast_state, (), ())

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(init_comp, next_fn)
def test_init_tuple_of_federated_types_raises(self):
  # Init returns a plain tuple of two federated values, not one federated value.
  tuple_state_init = federated_computation.federated_computation()(
      lambda: (server_zero(), server_zero()))
  tuple_state_type = tuple_state_init.type_signature.result

  @federated_computation.federated_computation(tuple_state_type,
                                               MODEL_WEIGHTS_TYPE,
                                               SERVER_FLOAT)
  def next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    return MeasuredProcessOutput(state, finalized, server_zero())

  with self.assertRaises(errors.TemplateNotFederatedError):
    finalizers.FinalizerProcess(tuple_state_init, next_fn)
def test_construction_with_empty_state_does_not_raise(self):
  """Checks that `AggregationProcess` accepts this init/next pair."""
  initialize_fn = federated_computation.federated_computation()(server_zero)

  @federated_computation.federated_computation(SERVER_INT, CLIENTS_FLOAT)
  def next_fn(state, val):
    return MeasuredProcessOutput(
        state, intrinsics.federated_sum(val),
        intrinsics.federated_value(1, placements.SERVER))

  try:
    aggregation_process.AggregationProcess(initialize_fn, next_fn)
  # Catch `Exception`, not everything: a bare `except:` would also turn
  # KeyboardInterrupt/SystemExit into a test failure. The original error stays
  # attached as the failure's context.
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct an AggregationProcess with empty state.')
def _create_test_measured_process_state_at_clients():
  """Builds a `MeasuredProcess` whose state is an int32 placed at CLIENTS.

  Its pieces:
  - `initialize` places `0` at CLIENTS.
  - `next` returns the state unchanged.
  - The result is the `federated_sum` of the CLIENTS int32 values.
  - The measurements are the constant `1` placed at SERVER.
  """
  clients_int = computation_types.at_clients(tf.int32)

  @federated_computation.federated_computation(clients_int, clients_int)
  def next_fn(state, values):
    summed = intrinsics.federated_sum(values)
    measurements = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, measurements)

  clients_zero_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value(0, placements.CLIENTS))
  return measured_process.MeasuredProcess(
      initialize_fn=clients_zero_init, next_fn=next_fn)
def test_non_server_placed_init_state_raises(self):
  # The initial state is placed at CLIENTS rather than SERVER.
  clients_zero_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value(0, placements.CLIENTS))

  @federated_computation.federated_computation(CLIENTS_INT,
                                               MODEL_WEIGHTS_TYPE,
                                               SERVER_FLOAT)
  def next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    return MeasuredProcessOutput(state, finalized, server_zero())

  with self.assertRaises(errors.TemplatePlacementError):
    finalizers.FinalizerProcess(clients_zero_init, next_fn)
def _create_test_aggregation_process(state_type, state_init, values_type):
  """Builds an `AggregationProcess` for tests.

  Its pieces:
  - The state has type `state_type` at SERVER and starts at `state_init`.
  - `next` returns the state unchanged.
  - The result sums the CLIENTS-placed `values_type` values.
  - The measurements are the constant `1` at SERVER.

  Args:
    state_type: Member type of the SERVER-placed state.
    state_init: Initial state value, placed at SERVER by `initialize`.
    values_type: Member type of the CLIENTS-placed values to sum.

  Returns:
    An `aggregation_process.AggregationProcess`.
  """
  server_state_type = computation_types.at_server(state_type)
  clients_values_type = computation_types.at_clients(values_type)

  @federated_computation.federated_computation(server_state_type,
                                               clients_values_type)
  def next_fn(state, values):
    summed = intrinsics.federated_sum(values)
    measurements = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, measurements)

  server_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value(state_init, placements.SERVER))
  return aggregation_process.AggregationProcess(
      initialize_fn=server_init, next_fn=next_fn)
def test_non_server_placed_init_state_raises(self):
  # The initial state is placed at CLIENTS rather than SERVER.
  clients_zero_init = federated_computation.federated_computation(
      lambda: intrinsics.federated_value(0, placements.CLIENTS))
  clients_state_type = clients_zero_init.type_signature.result

  @federated_computation.federated_computation(clients_state_type,
                                               MODEL_WEIGHTS_TYPE,
                                               CLIENTS_FLOAT_SEQUENCE)
  def next_fn(state, weights, data):
    client_result = test_client_result(weights, data)
    return MeasuredProcessOutput(state, client_result, server_zero())

  with self.assertRaises(errors.TemplatePlacementError):
    client_works.ClientWorkProcess(clients_zero_init, next_fn)
def test_init_tuple_of_federated_types_raises(self):
  # Init returns a plain tuple of two federated values, not one federated value.
  tuple_state_init = federated_computation.federated_computation()(
      lambda: (server_zero(), server_zero()))
  tuple_state_type = tuple_state_init.type_signature.result

  @federated_computation.federated_computation(tuple_state_type,
                                               MODEL_WEIGHTS_TYPE,
                                               CLIENTS_FLOAT_SEQUENCE)
  def next_fn(state, weights, data):
    client_result = test_client_result(weights, data)
    return MeasuredProcessOutput(state, client_result, server_zero())

  with self.assertRaises(errors.TemplateNotFederatedError):
    client_works.ClientWorkProcess(tuple_state_init, next_fn)
def test_construction_with_empty_state_does_not_raise(self):
  """Checks that `FinalizerProcess` accepts an empty `()` SERVER state."""
  initialize_fn = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value((), placements.SERVER))

  @federated_computation.federated_computation(
      initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
  def next_fn(state, weights, update):
    return MeasuredProcessOutput(
        state, test_finalizer_result(weights, update),
        intrinsics.federated_value(1, placements.SERVER))

  try:
    finalizers.FinalizerProcess(initialize_fn, next_fn)
  # Catch `Exception`, not everything: a bare `except:` would also turn
  # KeyboardInterrupt/SystemExit into a test failure. The original error stays
  # attached as the failure's context.
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct a FinalizerProcess with empty state.')
def test_federated_mapped_process_as_expected(self):
  init_comp = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  server_state_type = init_comp.type_signature.result
  next_comp = federated_computation.federated_computation(server_state_type)(
      lambda state: state)
  report_comp = federated_computation.federated_computation(server_state_type)(
      lambda state: intrinsics.federated_map(test_report_fn, state))
  process = estimation_process.EstimationProcess(init_comp, next_comp,
                                                 report_comp)

  estimate_type = report_comp.type_signature.result
  map_fn = federated_computation.federated_computation(estimate_type)(
      lambda estimate: intrinsics.federated_map(test_map_fn, estimate))
  mapped_process = process.map(map_fn)

  # Mapping keeps initialize/next and report's parameter type intact; only
  # report's result type changes to that of `map_fn`.
  original_report_sig = process.report.type_signature
  mapped_report_sig = mapped_process.report.type_signature
  self.assertIsInstance(mapped_process, estimation_process.EstimationProcess)
  self.assertEqual(process.initialize, mapped_process.initialize)
  self.assertEqual(process.next, mapped_process.next)
  self.assertEqual(original_report_sig.parameter, mapped_report_sig.parameter)
  self.assertEqual(map_fn.type_signature.result, mapped_report_sig.result)
def _encoded_init_fn(encoders):
  """Creates `init_fn` for the process returned by `EncodedSumFactory`.

  The initial state mirrors the structure of `encoders`: each encoder is
  replaced by its `initial_state()`. That computation is evaluated at SERVER.

  Args:
    encoders: A (possibly nested) collection of `GatherEncoder` objects.

  Returns:
    A no-arg `tff.Computation` returning initial state for `EncodedSumFactory`.
  """
  initial_state_tf = tensorflow_computation.tf_computation(
      lambda: tf.nest.map_structure(
          lambda encoder: encoder.initial_state(), encoders))
  return federated_computation.federated_computation(
      lambda: intrinsics.federated_eval(initial_state_tf, placements.SERVER))
def test_construction_with_empty_state_does_not_raise(self):
  """Checks that `ClientWorkProcess` accepts an empty `()` SERVER state."""
  initialize_fn = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value((), placements.SERVER))

  @federated_computation.federated_computation(
      initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
      CLIENTS_FLOAT_SEQUENCE)
  def next_fn(state, weights, data):
    return MeasuredProcessOutput(
        state, test_client_result(weights, data),
        intrinsics.federated_value(1, placements.SERVER))

  try:
    client_works.ClientWorkProcess(initialize_fn, next_fn)
  # Catch `Exception`, not everything: a bare `except:` would also turn
  # KeyboardInterrupt/SystemExit into a test failure. The original error stays
  # attached as the failure's context.
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct a ClientWorkProcess with empty state.')
def test_federated_measured_process_output_raises(self):
  init_comp = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  server_state_type = init_comp.type_signature.result
  server_empty = lambda: intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(server_state_type)
  def next_fn(state):
    output = MeasuredProcessOutput(state, server_empty(), server_empty())
    # `federated_zip` lifts the placement to the top of the type hierarchy.
    return intrinsics.federated_zip(output)

  # A MeasuredProcessOutput whose three fields are `FederatedType`s differs
  # from a `FederatedType` wrapping a MeasuredProcessOutput. Currently only
  # the former is considered valid.
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    measured_process.MeasuredProcess(init_comp, next_fn)
def test_federated(self):
  place_at_clients = federated_computation.federated_computation(
      lambda x: intrinsics.federated_value(x, placements.CLIENTS), tf.int32)
  # This descriptor passes an explicit cardinality of 3.
  ds = data_descriptor.DataDescriptor(place_at_clients, 1000,
                                      computation_types.TensorType(tf.int32),
                                      3)
  self.assertEqual(str(ds.type_signature), 'int32@CLIENTS')

  all_equal_clients_int = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)

  @federated_computation.federated_computation(all_equal_clients_int)
  def foo(x):
    return intrinsics.federated_sum(x)

  default_factory = executor_test_utils.LocalTestExecutorFactory()
  with executor_test_utils.install_executor(default_factory):
    result = foo(ds)
  self.assertEqual(result, 3000)