def test_non_server_placed_next_result_raises(self):
    """Rejects a `next` whose result slot holds its `CLIENTS_INT` argument."""
    initialize = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    # The second argument (typed `CLIENTS_INT`) is passed through as the
    # second field of the `MeasuredProcessOutput`.
    wrap_next = computations.federated_computation(SERVER_INT, CLIENTS_INT)
    step = wrap_next(
        lambda x, y: measured_process.MeasuredProcessOutput(x, y, x))

    self.assertRaises(
        TypeError, aggregation_process.AggregationProcess, initialize, step)
def test_single_param_next_raises(self):
    """Rejects a `next` that is declared with a single parameter type."""
    initialize = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    wrap_next = computations.federated_computation(SERVER_INT)
    one_arg_next = wrap_next(
        lambda x: measured_process.MeasuredProcessOutput(x, x, x))

    with self.assertRaises(TypeError):
        aggregation_process.AggregationProcess(initialize, one_arg_next)
def test_non_server_placed_init_state_raises(self):
    """Rejects an `initialize` whose value is created at `placements.CLIENTS`."""
    clients_init = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.CLIENTS))
    # NOTE(review): `next` is declared with a single parameter type here, so
    # the TypeError may originate from an arity check rather than from the
    # placement check -- confirm against `AggregationProcess`.
    wrap_next = computations.federated_computation(CLIENTS_INT)
    clients_next = wrap_next(lambda x: measured_process_output(x, x, x))

    with self.assertRaises(TypeError):
        aggregation_process.AggregationProcess(clients_init, clients_next)
def test_federated_aggregate_with_federated_zero_fails(self):
    """Rejects a `zero` for `federated_aggregate` that is server-placed."""

    @computations.federated_computation()
    def build_federated_zero():
        zero = intrinsics.federated_value(0, placements.SERVER)
        self.assertIsInstance(zero, value_base.Value)
        return zero

    @computations.tf_computation([tf.int32, tf.int32])
    def accumulate(accu, elem):
        return accu + elem

    # Second-stage operator: adds its two int32 arguments.
    @computations.tf_computation([tf.int32, tf.int32])
    def merge(x, y):
        return x + y

    # Final-stage operator: returns the accumulator unchanged.
    @computations.tf_computation(tf.int32)
    def report(accu):
        return accu

    def foo(x):
        return intrinsics.federated_aggregate(
            x, build_federated_zero(), accumulate, merge, report)

    expected_message = ('Expected `zero` to be assignable to type int32, '
                        'but was of incompatible type int32@SERVER')
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    with self.assertRaisesRegex(TypeError, expected_message):
        computations.federated_computation(foo, clients_int)
def test_non_clients_placed_next_value_param_raises(self):
    """Rejects a `next` whose value parameter is typed `SERVER_INT`."""
    initialize = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    # Both the state and the value parameter use the `SERVER_INT` type.
    server_only_next = computations.federated_computation(
        lambda x, y: measured_process_output(x, y, x), SERVER_INT, SERVER_INT)

    self.assertRaises(
        TypeError,
        aggregation_process.AggregationProcess,
        initialize,
        server_only_next,
    )
def test_federated_init_state_not_assignable(self):
    """Rejects SERVER-placed init state fed to a `next` taking CLIENTS state."""
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    init = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    step = computations.federated_computation(clients_int)(
        lambda state: state)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        iterative_process.IterativeProcess(init, step)
def test_next_value_type_mismatch_raises(self):
    """Rejects a `next` taking `CLIENTS_FLOAT` but returning only its state."""
    initialize = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    # The `CLIENTS_FLOAT` argument `y` is ignored; every output field is the
    # `SERVER_INT` argument `x`.
    wrap_next = computations.federated_computation(SERVER_INT, CLIENTS_FLOAT)
    mismatched_next = wrap_next(
        lambda x, y: measured_process_output(x, x, x))

    with self.assertRaises(TypeError):
        aggregation_process.AggregationProcess(initialize, mismatched_next)
def test_federated_next_state_not_assignable(self):
    """Rejects a `next` that is just `federated_broadcast` of the state."""
    init = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    state_type = init.type_signature.result
    broadcasting_next = computations.federated_computation(state_type)(
        intrinsics.federated_broadcast)

    self.assertRaises(
        errors.TemplateStateNotAssignableError,
        iterative_process.IterativeProcess,
        init,
        broadcasting_next,
    )
def test_federated_init_state_not_assignable(self):
    """Rejects SERVER-placed init state fed to a `next` taking CLIENTS state."""
    make_zero = lambda: intrinsics.federated_value(0, placements.SERVER)
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

    init = computations.federated_computation()(make_zero)
    # `next` passes its CLIENTS-typed state through and fills the other two
    # output fields with fresh server-placed zeros.
    step = computations.federated_computation(clients_int)(
        lambda state: MeasuredProcessOutput(state, make_zero(), make_zero()))

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        measured_process.MeasuredProcess(init, step)
def test_federated_report_state_not_assignable(self):
    """Rejects a `report` whose parameter type is CLIENTS-placed int32."""
    init = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    server_state_type = init.type_signature.result
    step = computations.federated_computation(server_state_type)(
        lambda state: state)

    clients_state_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)
    report = computations.federated_computation(clients_state_type)(
        lambda state: state)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(init, step, report)
def _constant_process(value):
    """Builds an `EstimationProcess` whose report is always `value`.

    Args:
      value: The constant that the report function places at
        `placements.SERVER` via `intrinsics.federated_value`.

    Returns:
      An `estimation_process.EstimationProcess` whose state is an empty tuple
      at the server, whose `next` returns the state unchanged, and whose
      `report` yields `value`.
    """
    initialize = computations.federated_computation(
        lambda: intrinsics.federated_value((), placements.SERVER))
    state_type = initialize.type_signature.result
    client_value_type = computation_types.at_clients(NORM_TF_TYPE)

    # The lambda's own `value` parameter shadows the outer `value` and is
    # ignored; only the state is returned.
    advance = computations.federated_computation(
        lambda state, value: state, state_type, client_value_type)
    report = computations.federated_computation(
        lambda state: intrinsics.federated_value(value, placements.SERVER),
        state_type)

    return estimation_process.EstimationProcess(initialize, advance, report)
def test_build_encoded_broadcast(self, value_constructor, encoder_constructor):
    """Checks result member types and placements of an encoded broadcast."""
    value = value_constructor(np.random.rand(20))
    value_spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
    value_type = computation_types.to_type(value_spec)
    encoder = te.encoders.as_simple_encoder(encoder_constructor(), value_spec)

    broadcast_fn = encoding_utils.build_encoded_broadcast(value, encoder)
    state_type = broadcast_fn._initialize_fn.type_signature.result

    # Wrap `_next_fn` with explicit server-placed parameter types so that its
    # result type can be inspected.
    server_state = computation_types.FederatedType(
        state_type, placements.SERVER)
    server_value = computation_types.FederatedType(
        value_type, placements.SERVER)
    wrapped_next = computations.federated_computation(
        broadcast_fn._next_fn, server_state, server_value)
    next_result = wrapped_next.type_signature.result

    self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)

    new_state = next_result[0]
    self.assertEqual(state_type, new_state.member)
    self.assertEqual(placements.SERVER, new_state.placement)

    broadcast_value = next_result[1]
    self.assertEqual(value_type, broadcast_value.member)
    self.assertEqual(placements.CLIENTS, broadcast_value.placement)
def federated_output_computation_from_metrics(
    metrics: List[tf.keras.metrics.Metric]
) -> computations.federated_computation:
    """Builds a federated computation that aggregates Keras metrics.

    Usable for evaluating both Keras and non-Keras models with Keras metrics.
    Client metrics are combined by summing their internal variables, building
    new metrics from the summed variables, and calling `metric.result()` on
    each. See `tff.learning.federated_aggregate_keras_metric` for details.

    Args:
      metrics: A List of `tf.keras.metrics.Metric` to aggregate.

    Returns:
      A `tff.federated_computation` that takes the clients' metric variables
      and returns the aggregated metric results, as described above.
    """
    # Read the metrics' current variables once, only to derive their types.
    variable_specs = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, read_metric_variables(metrics))
    client_outputs_type = computation_types.at_clients(variable_specs)

    @computations.federated_computation(client_outputs_type)
    def federated_output(local_outputs):
        return base_utils.federated_aggregate_keras_metric(
            metrics, local_outputs)

    return federated_output
def test_raises_on_bad_process_next_single_param(self, make_factory):
    """Rejects a norm process whose `next` takes only the state."""
    one_param_next = computations.federated_computation(
        lambda state: state, _float_at_server)
    bad_norm = _test_norm_process(next_fn=one_param_next)

    self.assertRaisesRegex(
        TypeError, '.* must take two arguments.', make_factory, bad_norm)
def test_raises_computation_no_dataset_parameter(self):
    """Rejects composing with a computation whose parameter is `[tf.int32]`."""
    compositions = iterative_process_compositions
    no_dataset_comp = computations.federated_computation(
        lambda x: x, [tf.int32])

    with self.assertRaises(compositions.SequenceTypeNotFoundError):
        compositions.compose_dataset_computation_with_computation(
            int_dataset_computation, no_dataset_comp)
def test_federated_map_injected_zip_fails_different_placements(self):
    """Rejects `federated_map` over a SERVER value paired with a CLIENTS one."""

    def foo(x, y):
        exceeds_ten = computations.tf_computation(
            lambda x, y: x > 10, [tf.int32, tf.int32])
        return intrinsics.federated_map(exceeds_ten, [x, y])

    mixed_placement_types = [
        computation_types.FederatedType(tf.int32, placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
    ]
    expected_message = (
        'The value to be mapped must be a FederatedType or implicitly '
        'convertible to a FederatedType.')

    with self.assertRaisesRegex(TypeError, expected_message):
        computations.federated_computation(foo, mixed_placement_types)
def _bind_federated_value(unused_input, input_type, federated_output_value):
    """Invokes a computation that ignores its input and returns a fixed value.

    Args:
      unused_input: The argument passed to the wrapper computation; the
        wrapper does not use it.
      input_type: The member type of the wrapper's CLIENTS-placed parameter.
      federated_output_value: The value the wrapper computation returns.

    Returns:
      The result of calling the wrapper computation on `unused_input`.
    """
    clients_input_type = computation_types.FederatedType(
        input_type, placements.CLIENTS)
    constant_comp = computations.federated_computation(
        lambda _: federated_output_value, clients_input_type)
    return constant_comp(unused_input)
def test_raises_on_bad_process_next_two_outputs(self, make_factory):
    """Rejects a norm process whose `next` returns a pair of states."""
    two_output_next = computations.federated_computation(
        lambda state, val: (state, state),
        _float_at_server,
        _float_at_clients,
    )
    bad_norm = _test_norm_process(next_fn=two_output_next)

    self.assertRaisesRegex(
        TypeError, 'Result type .* state only.', make_factory, bad_norm)
def federated_output_computation(self):
    """Returns a federated computation that sums the clients' `num_over`."""

    def aggregate_metrics(client_metrics):
        total_over = intrinsics.federated_sum(client_metrics.num_over)
        return collections.OrderedDict([('num_over', total_over)])

    return computations.federated_computation(aggregate_metrics)
class ContainsAggregationShared(parameterized.TestCase):
    """Cases checked against both aggregation finders in `tree_analysis`."""

    @parameterized.named_parameters([
        ('trivial_tf', computations.tf_computation(lambda: ())),
        ('trivial_tff', computations.federated_computation(lambda: ())),
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('unused_aggregation', unused_aggregation),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_collect', trivial_collect),
        ('trivial_mean', trivial_mean),
        ('trivial_reduce', trivial_reduce),
        ('trivial_sum', trivial_sum),
        # TODO(b/120439632) Enable once federated_mean accepts structured weight.
        # ('trivial_weighted_mean', trivial_weighted_mean),
        ('trivial_secure_sum', trivial_secure_sum),
    ])
    def test_returns_none(self, comp):
        """Neither finder reports anything for these computations."""
        finders = (
            tree_analysis.find_unsecure_aggregation_in_tree,
            tree_analysis.find_secure_aggregation_in_tree,
        )
        for find_aggregation in finders:
            self.assertEmpty(find_aggregation(comp.to_building_block()))

    def test_throws_on_unresolvable_function_call(self):
        """Both finders raise `ValueError` on a call to an unknown function."""
        clients_int = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unknown_fn_type = computation_types.FunctionType((), clients_int)

        @computations.federated_computation(unknown_fn_type)
        def comp(unknown_func):
            return unknown_func(())

        finders = (
            tree_analysis.find_unsecure_aggregation_in_tree,
            tree_analysis.find_secure_aggregation_in_tree,
        )
        for find_aggregation in finders:
            with self.assertRaises(ValueError):
                find_aggregation(comp.to_building_block())

    def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
        self):
        """Functions without a federated output can't aggregate."""
        clients_int = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unknown_fn_type = computation_types.FunctionType(clients_int, tf.int32)

        @computations.federated_computation(unknown_fn_type)
        def comp(unknown_func):
            return unknown_func(
                intrinsics.federated_value(1, placement_literals.CLIENTS))

        finders = (
            tree_analysis.find_unsecure_aggregation_in_tree,
            tree_analysis.find_secure_aggregation_in_tree,
        )
        for find_aggregation in finders:
            self.assertEmpty(find_aggregation(comp.to_building_block()))
def test_raises_on_bad_process_next_three_params(self, factory_cons):
    """Rejects a norm process whose `next` takes three parameters."""
    three_param_types = (
        _float_at_server, _float_at_clients, _float_at_clients)
    three_param_next = computations.federated_computation(
        lambda state, value1, value2: state, *three_param_types)
    bad_norm = _test_norm_process(next_fn=three_param_next)

    with self.assertRaisesRegex(TypeError, '.* must take two arguments.'):
        factory_cons(bad_norm)
def test_raises_on_bad_norm_process_result(self, value, placement,
                                           make_factory):
    """Rejects a norm process whose report yields `value` at `placement`."""
    bad_report = computations.federated_computation(
        lambda s: intrinsics.federated_value(value, placement),
        _float_at_server,
    )
    bad_norm = _test_norm_process(report_fn=bad_report)

    self.assertRaisesRegex(
        TypeError, r'Result type .* assignable to', make_factory, bad_norm)
def test_raises_on_bad_process_next_not_float(self, make_factory):
    """Rejects a norm process whose `next` value parameter is complex64."""
    complex_at_clients = computation_types.at_clients(tf.complex64)
    complex_next = computations.federated_computation(
        lambda state, value: state,
        _float_at_server,
        complex_at_clients,
    )
    bad_norm = _test_norm_process(next_fn=complex_next)

    expected_pattern = 'Second argument .* assignable from'
    with self.assertRaisesRegex(TypeError, expected_pattern):
        make_factory(bad_norm)
def test_federated_next_state_not_assignable(self):
    """Rejects a `next` that returns the broadcast state as its new state."""
    init = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    server_state_type = init.type_signature.result

    @computations.federated_computation(server_state_type)
    def next_fn(state):
        broadcast_state = intrinsics.federated_broadcast(state)
        return MeasuredProcessOutput(broadcast_state, (), ())

    self.assertRaises(
        errors.TemplateStateNotAssignableError,
        measured_process.MeasuredProcess,
        init,
        next_fn,
    )
def test_non_server_placed_init_state_raises(self):
    """Expects `AggregationPlacementError` for state created at CLIENTS."""
    clients_init = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.CLIENTS))

    @computations.federated_computation(CLIENTS_INT, CLIENTS_FLOAT)
    def next_fn(state, val):
        summed = intrinsics.federated_sum(val)
        measurements = server_zero()
        return MeasuredProcessOutput(state, summed, measurements)

    expected_error = aggregation_process.AggregationPlacementError
    with self.assertRaises(expected_error):
        aggregation_process.AggregationProcess(clients_init, next_fn)
def test_init_tuple_of_federated_types_raises(self):
    """Expects `AggregationNotFederatedError` for a tuple-valued init state."""
    tuple_init = computations.federated_computation()(
        lambda: (server_zero(), server_zero()))
    tuple_state_type = tuple_init.type_signature.result

    @computations.federated_computation(tuple_state_type, CLIENTS_FLOAT)
    def next_fn(state, val):
        summed = intrinsics.federated_sum(val)
        return MeasuredProcessOutput(state, summed, ())

    expected_error = aggregation_process.AggregationNotFederatedError
    with self.assertRaises(expected_error):
        aggregation_process.AggregationProcess(tuple_init, next_fn)
def test_init_fn_with_client_placed_state_raises(self):
    """Expects `LearningProcessPlacementError` for state created at CLIENTS."""
    clients_init = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.CLIENTS))
    clients_state_type = clients_init.type_signature.result

    @computations.federated_computation(clients_state_type,
                                        ClientIntSequenceType)
    def next_fn(state, client_values):
        return LearningProcessOutput(state, client_values)

    self.assertRaises(
        learning_process.LearningProcessPlacementError,
        learning_process.LearningProcess,
        clients_init,
        next_fn,
        test_report_fn,
    )
def test_non_server_placed_init_state_raises(self):
    """Expects `TemplatePlacementError` for finalizer state made at CLIENTS."""
    clients_init = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.CLIENTS))

    @computations.federated_computation(CLIENTS_INT, MODEL_WEIGHTS_TYPE,
                                        SERVER_FLOAT)
    def next_fn(state, weights, update):
        finalized = test_finalizer_result(weights, update)
        measurements = server_zero()
        return MeasuredProcessOutput(state, finalized, measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        finalizers.FinalizerProcess(clients_init, next_fn)
def test_non_server_placed_init_state_raises(self):
    """Expects `TemplatePlacementError` for distributor state at CLIENTS."""
    clients_init = computations.federated_computation(
        lambda: intrinsics.federated_value(0, placements.CLIENTS))

    @computations.federated_computation(CLIENTS_INT, SERVER_FLOAT)
    def next_fn(state, val):
        distributed = intrinsics.federated_broadcast(val)
        measurements = server_zero()
        return MeasuredProcessOutput(state, distributed, measurements)

    self.assertRaises(
        errors.TemplatePlacementError,
        distributors.DistributionProcess,
        clients_init,
        next_fn,
    )
def test_init_tuple_of_federated_types_raises(self):
    """Expects `TemplateNotFederatedError` for a tuple-valued init state."""
    tuple_init = computations.federated_computation()(
        lambda: (server_zero(), server_zero()))
    tuple_state_type = tuple_init.type_signature.result

    @computations.federated_computation(tuple_state_type, SERVER_FLOAT)
    def next_fn(state, val):
        distributed = intrinsics.federated_broadcast(val)
        measurements = server_zero()
        return MeasuredProcessOutput(state, distributed, measurements)

    with self.assertRaises(errors.TemplateNotFederatedError):
        distributors.DistributionProcess(tuple_init, next_fn)