def trivial_aggregate():
  """Returns a `federated_aggregate` in which every piece is an empty tuple.

  The client-placed value, the zero, and the results of the accumulate, merge
  and report computations are all `()`.
  """
  # One Python function backs both the accumulate and the merge computations;
  # each is still wrapped separately.
  drop_both = lambda _a, _b: ()
  return intrinsics.federated_aggregate(
      intrinsics.federated_value((), placements.CLIENTS),
      (),
      computations.tf_computation(drop_both),
      computations.tf_computation(drop_both),
      computations.tf_computation(lambda _: ()))
def simple_aggregate():
  """Returns a `federated_aggregate` that adds up a client-placed `1`.

  Accumulation and merging both use addition starting from zero `0`; the
  report step returns its argument unchanged.
  """
  add = lambda a, b: a + b
  return intrinsics.federated_aggregate(
      intrinsics.federated_value(1, placement_literals.CLIENTS),
      0,
      computations.tf_computation(add),
      computations.tf_computation(add),
      computations.tf_computation(lambda a: a))
def test_construction_with_empty_state_does_not_raise(self):
  """Checks that an `EstimationProcess` can be built with `()` as its state.

  The initialize, next and report computations all use the empty tuple as the
  state; the test fails if constructing the process raises.
  """
  initialize_fn = computations.tf_computation()(lambda: ())
  next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
  report_fn = computations.tf_computation(())(lambda x: x)
  try:
    estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
  except Exception:  # pylint: disable=broad-except
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt and SystemExit still propagate instead of being
    # reported as a test failure.
    self.fail('Could not construct an EstimationProcess with empty state.')
def next_fn(state, value):
  """Increments the state and returns the incremented sum of `value`.

  Args:
    state: Value to which one is added via `federated_map`.
    value: Value that is passed through `federated_sum`; one is then added to
      every leaf of the summed structure.

  Returns:
    A `MeasuredProcessOutput` holding the new state, the mapped sum, and
    `MEASUREMENT_CONSTANT` placed at the server as measurements.
  """
  add_one = computations.tf_computation(lambda x: x + 1)
  new_state = intrinsics.federated_map(add_one, state)

  add_one_to_leaves = computations.tf_computation(
      lambda x: tf.nest.map_structure(lambda y: y + 1, x))
  summed = intrinsics.federated_sum(value)
  result = intrinsics.federated_map(add_one_to_leaves, summed)

  measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                            placements.SERVER)
  return measured_process.MeasuredProcessOutput(
      state=new_state, result=result, measurements=measurements)
def test_invoke_returns_value_with_correct_type(self):
  """Invoking a no-argument TF computation yields an `int32` `Value`."""
  ctx = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  constant_ten = computations.tf_computation(lambda: tf.constant(10))

  invoked = ctx.invoke(constant_ten, None)

  self.assertIsInstance(invoked, value_base.Value)
  self.assertEqual(str(invoked.type_signature), 'int32')
def test_roundtrip(self):
  """Converts a computation to broadcast form and back, checking each step."""
  add = computations.tf_computation(lambda x, y: x + y)
  int_at_server = computation_types.at_server(tf.int32)
  int_at_clients = computation_types.at_clients(tf.int32)

  @computations.federated_computation(int_at_server, int_at_clients)
  def add_server_number_plus_one(server_number, client_numbers):
    one = intrinsics.federated_value(1, placements.SERVER)
    server_context = intrinsics.federated_map(add, (one, server_number))
    client_context = intrinsics.federated_broadcast(server_context)
    return intrinsics.federated_map(add, (client_context, client_numbers))

  broadcast_form = form_utils.get_broadcast_form_for_computation(
      add_server_number_plus_one)

  # The labels are taken from the parameter names of the computation above.
  self.assertEqual(broadcast_form.server_data_label, 'server_number')
  self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

  server_fn = broadcast_form.compute_server_context
  self.assert_types_equivalent(
      server_fn.type_signature,
      computation_types.FunctionType(tf.int32, (tf.int32,)))
  self.assertEqual(2, server_fn(1)[0])

  client_fn = broadcast_form.client_processing
  self.assert_types_equivalent(
      client_fn.type_signature,
      computation_types.FunctionType(((tf.int32,), tf.int32), tf.int32))
  self.assertEqual(3, client_fn((1,), 2))

  round_trip_comp = form_utils.get_computation_for_broadcast_form(
      broadcast_form)
  self.assert_types_equivalent(round_trip_comp.type_signature,
                               add_server_number_plus_one.type_signature)
  # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
  self.assertEqual([5, 6, 7], round_trip_comp(2, [2, 3, 4]))
def test_next_return_namedtuple_raises(self):
  """A `next` returning a same-named, same-field namedtuple is rejected."""
  lookalike_output = collections.namedtuple(
      'MeasuredProcessOutput', ['state', 'result', 'measurements'])
  with_int_state = computations.tf_computation(tf.int32)
  namedtuple_next_fn = with_int_state(
      lambda state: lookalike_output(state, (), ()))

  with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
    measured_process.MeasuredProcess(test_initialize_fn, namedtuple_next_fn)
def test_sequence_reduce(self):
  """Checks `sequence_reduce` types for unplaced, server and client input."""
  int_sequence = computation_types.SequenceType(tf.int32)
  add_numbers = computations.tf_computation(
      lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
      [tf.int32, tf.int32])

  @computations.federated_computation(int_sequence)
  def reduce_unplaced(x):
    reduced = intrinsics.sequence_reduce(x, 0, add_numbers)
    self.assertIsInstance(reduced, value_base.Value)
    return reduced

  self.assert_type(reduce_unplaced, '(int32* -> int32)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.SERVER))
  def reduce_at_server(x):
    # The zero carries the same placement as the sequence being reduced.
    server_zero = intrinsics.federated_value(0, placements.SERVER)
    reduced = intrinsics.sequence_reduce(x, server_zero, add_numbers)
    self.assertIsInstance(reduced, value_base.Value)
    return reduced

  self.assert_type(reduce_at_server, '(int32*@SERVER -> int32@SERVER)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.CLIENTS))
  def reduce_at_clients(x):
    clients_zero = intrinsics.federated_value(0, placements.CLIENTS)
    reduced = intrinsics.sequence_reduce(x, clients_zero, add_numbers)
    self.assertIsInstance(reduced, value_base.Value)
    return reduced

  self.assert_type(reduce_at_clients, '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_raises_zeroing_norm_fn_bad_arg(self):
  """An int32-parameter `zeroing_norm_fn` is rejected with a `TypeError`."""
  int_param = computations.tf_computation(tf.int32)
  zeroing_norm_fn = int_param(lambda x: x + 3)

  expected_message = 'Argument of `zeroing_norm_fn`'
  with self.assertRaisesRegex(TypeError, expected_message):
    clipping_factory.ZeroingClippingFactory(
        2.0, zeroing_norm_fn, mean_factory.MeanFactory())
def initialize_comp():
  """Evaluates the initializer of `stateful_fn` at the server.

  An initializer that is not already a `computation_base.Computation` is
  first wrapped with `computations.tf_computation`.
  """
  initialize = stateful_fn.initialize
  if not isinstance(initialize, computation_base.Computation):
    initialize = computations.tf_computation(initialize)
  return intrinsics.federated_eval(initialize, placements.SERVER)
def test_sequence_reduce(self):
  """Checks `sequence_reduce` types for unplaced, server and client input.

  All three variants pass the plain Python `0` as the zero of the reduction.
  """
  add_numbers = computations.tf_computation(tf.add, [tf.int32, tf.int32])

  def reduce_and_check(sequence):
    """Reduces `sequence` with `add_numbers` and checks the result kind."""
    reduced = intrinsics.sequence_reduce(sequence, 0, add_numbers)
    self.assertIsInstance(reduced, value_base.Value)
    return reduced

  @computations.federated_computation(
      computation_types.SequenceType(tf.int32))
  def foo1(x):
    return reduce_and_check(x)

  self.assert_type(foo1, '(int32* -> int32)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.SERVER))
  def foo2(x):
    return reduce_and_check(x)

  self.assert_type(foo2, '(int32*@SERVER -> int32@SERVER)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.CLIENTS))
  def foo3(x):
    return reduce_and_check(x)

  self.assert_type(foo3, '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_sequence_reduce(self):
  """Checks the printed type signatures of three `sequence_reduce` wrappers."""
  add_numbers = computations.tf_computation(tf.add, [tf.int32, tf.int32])

  @computations.federated_computation(
      computation_types.SequenceType(tf.int32))
  def reduce_unplaced(x):
    return intrinsics.sequence_reduce(x, 0, add_numbers)

  unplaced_signature = str(reduce_unplaced.type_signature)
  self.assertEqual(unplaced_signature, '(int32* -> int32)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.SERVER, True))
  def reduce_at_server(x):
    return intrinsics.sequence_reduce(x, 0, add_numbers)

  server_signature = str(reduce_at_server.type_signature)
  self.assertEqual(server_signature, '(int32*@SERVER -> int32@SERVER)')

  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int32), placements.CLIENTS))
  def reduce_at_clients(x):
    return intrinsics.sequence_reduce(x, 0, add_numbers)

  clients_signature = str(reduce_at_clients.type_signature)
  self.assertEqual(clients_signature, '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_one_param(self):
  """A one-argument `tf.function` wrapped with an explicit int32 parameter."""

  @tf.function
  def add_one(x):
    return x + 1

  wrapped = computations.tf_computation(add_one, tf.int32)
  self.assertEqual(wrapped(1), 2)
def test_next_state_not_assignable_tuple_result(self):
  """A float32-state `next` returning a tuple is rejected for this init."""
  two_floats = computations.tf_computation(tf.float32, tf.float32)
  float_next_fn = two_floats(
      lambda state, x: (tf.cast(state, tf.float32), x))

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(test_initialize_fn, float_next_fn)
def apply(transform_fn: computation_base.Computation,
          arg_process: EstimationProcess):
  """Wraps `arg_process` so that its estimate is passed through `transform_fn`.

  The returned process reuses `arg_process.initialize` and `arg_process.next`
  unchanged; only the estimate computation is replaced.

  Args:
    transform_fn: A `computation_base.Computation` applied to the estimate
      produced by `arg_process`.
    arg_process: The `EstimationProcess` whose estimate is transformed.

  Returns:
    An `EstimationProcess` whose estimate is
    `transform_fn(arg_process.get_estimate(state))`.

  Raises:
    errors.TemplateStateNotAssignableError: If the result type of
      `arg_process.get_estimate` is not assignable to the parameter type of
      `transform_fn`.
  """
  py_typecheck.check_type(transform_fn, computation_base.Computation)
  py_typecheck.check_type(arg_process, EstimationProcess)

  estimate_type = arg_process.get_estimate.type_signature.result
  parameter_type = transform_fn.type_signature.parameter
  if not parameter_type.is_assignable_from(estimate_type):
    message = (
        'The return type of `get_estimate` of `arg_process` must be '
        'assignable to the input argument of `transform_fn`, but '
        f'`get_estimate` returns type:\n{estimate_type}\n'
        'and the argument of `transform_fn` is:\n'
        f'{parameter_type}')
    raise errors.TemplateStateNotAssignableError(message)

  # Compose the original estimate with the transformation, parameterized by
  # the state type of the wrapped process.
  transformed_estimate_fn = computations.tf_computation(
      lambda state: transform_fn(arg_process.get_estimate(state)),
      arg_process.state_type)

  return EstimationProcess(
      initialize_fn=arg_process.initialize,
      next_fn=arg_process.next,
      get_estimate_fn=transformed_estimate_fn)
def next_fn(strings, val):
  """Appends `'abc'` to the string state and sums `val`.

  Args:
    strings: State to which a constant `['abc']` is concatenated along axis 0.
    val: Value passed through `federated_sum` to form the result.

  Returns:
    A `MeasuredProcessOutput` of the new state, the sum, and the constant `1`
    placed at the server.
  """
  append_abc = computations.tf_computation()(
      lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
  new_state = intrinsics.federated_map(append_abc, strings)
  summed = intrinsics.federated_sum(val)
  one_at_server = intrinsics.federated_value(1, placements.SERVER)
  return MeasuredProcessOutput(new_state, summed, one_at_server)
def test_namedtuple_param(self):
  """A namedtuple argument reaches the wrapped function as that namedtuple."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo(t):
    self.assertIsInstance(t, MyType)
    return t.x + t.y

  # Explicit type
  explicit_comp = computations.tf_computation(foo, MyType(tf.int32, tf.int32))
  self.assertEqual(explicit_comp(MyType(1, 2)), 3)

  # Polymorphic
  polymorphic_comp = computations.tf_computation(foo)
  self.assertEqual(polymorphic_comp(MyType(1, 2)), 3)
class ContainsAggregationShared(parameterized.TestCase):
  """Tests run against both the unsecure and the secure aggregation finders."""

  def _assert_no_aggregation_found(self, comp):
    """Asserts that both finders return something empty for `comp`."""
    for find_fn in (tree_analysis.find_unsecure_aggregation_in_tree,
                    tree_analysis.find_secure_aggregation_in_tree):
      self.assertEmpty(find_fn(comp.to_building_block()))

  @parameterized.named_parameters([
      ('trivial_tf', computations.tf_computation(lambda: ())),
      ('trivial_tff', computations.federated_computation(lambda: ())),
      ('non_aggregation_intrinsics', non_aggregation_intrinsics),
      ('unused_aggregation', unused_aggregation),
      ('trivial_aggregate', trivial_aggregate),
      ('trivial_collect', trivial_collect),
      ('trivial_mean', trivial_mean),
      ('trivial_reduce', trivial_reduce),
      ('trivial_sum', trivial_sum),
      # TODO(b/120439632) Enable once federated_mean accepts structured weight.
      # ('trivial_weighted_mean', trivial_weighted_mean),
      ('trivial_secure_sum', trivial_secure_sum),
  ])
  def test_returns_none(self, comp):
    """Neither finder reports anything for the listed computations."""
    self._assert_no_aggregation_found(comp)

  def test_throws_on_unresolvable_function_call(self):
    """Calling an unknown function with federated output raises `ValueError`."""
    ints_at_clients = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS)
    unknown_func_type = computation_types.FunctionType((), ints_at_clients)

    @computations.federated_computation(unknown_func_type)
    def comp(unknown_func):
      return unknown_func(())

    for find_fn in (tree_analysis.find_unsecure_aggregation_in_tree,
                    tree_analysis.find_secure_aggregation_in_tree):
      with self.assertRaises(ValueError):
        find_fn(comp.to_building_block())

  def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
      self):
    """Functions without a federated output can't aggregate."""
    ints_at_clients = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS)
    unknown_func_type = computation_types.FunctionType(ints_at_clients,
                                                       tf.int32)

    @computations.federated_computation(unknown_func_type)
    def comp(unknown_func):
      return unknown_func(
          intrinsics.federated_value(1, placement_literals.CLIENTS))

    self._assert_no_aggregation_found(comp)
def next_fn(state, weights, updates):
  """Adds `updates` to the trainable weights and leaves the state untouched.

  Args:
    state: Returned unchanged as the state of the output.
    weights: Object whose `trainable` attribute is added to `updates`.
    updates: Value added to `weights.trainable`.

  Returns:
    A `MeasuredProcessOutput` whose result is the zipped `ModelWeights` built
    from the sum and an empty tuple, with `empty_at_server()` as measurements.
  """
  add = computations.tf_computation(lambda x, y: x + y)
  summed = intrinsics.federated_map(add, (weights.trainable, updates))
  zipped_weights = intrinsics.federated_zip(
      model_utils.ModelWeights(summed, ()))
  return measured_process.MeasuredProcessOutput(
      state=state, result=zipped_weights, measurements=empty_at_server())
def test_py_and_tf_args(self):
  """Binds the Python-only `add` flag before wrapping a `tf.function`."""

  @tf.function(autograph=False)
  def foo(x, y, add=True):
    if add:
      return x + y
    return x - y

  # Note: tf.Functions support mixing tensorflow and Python arguments,
  # usually with the semantics you would expect. Currently, TFF does not
  # support this kind of mixing, even for Polymorphic TFF functions.
  # However, you can work around this by explicitly binding any Python
  # arguments on a tf.Function:
  tf_poly_add = computations.tf_computation(lambda x, y: foo(x, y, True))
  tf_poly_sub = computations.tf_computation(lambda x, y: foo(x, y, False))

  self.assertEqual(tf_poly_add(2, 1), 3)
  self.assertEqual(tf_poly_add(2., 1.), 3.)
  self.assertEqual(tf_poly_sub(2, 1), 1)
def test_explicit_tuple_param(self):
  """A `tf.function` over one tuple, wrapped with an explicit tuple type."""
  # See also test_polymorphic_tuple_input

  @tf.function
  def foo(t):
    first, second = t[0], t[1]
    return first + second

  pair_of_ints = (tf.int32, tf.int32)
  tf_comp = computations.tf_computation(foo, pair_of_ints)
  self.assertEqual(tf_comp((1, 2)), 3)
def foo(temperatures, threshold):
  """Sums an int32 indicator of `temperature > threshold`.

  Args:
    temperatures: Value paired with the broadcast threshold and mapped through
      the comparison.
    threshold: Value that is broadcast before being paired with
      `temperatures`.

  Returns:
    The `federated_sum` of the mapped `int32` comparison results.
  """
  # `tf.to_int32` is deprecated and is not part of the TF2 namespace;
  # `tf.cast(..., tf.int32)` is the documented replacement.
  exceeds_threshold = computations.tf_computation(
      lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
      [tf.float32, tf.float32])
  return intrinsics.federated_sum(
      intrinsics.federated_map(
          exceeds_threshold,
          [temperatures, intrinsics.federated_broadcast(threshold)]))
def test_map_estimate_not_assignable(self):
  """Mapping with an int32-parameter function raises for this process."""
  with_int_param = computations.tf_computation(tf.int32)
  map_fn = with_int_param(lambda estimate: estimate)
  process = estimation_process.EstimationProcess(
      test_initialize_fn, test_next_fn, test_report_fn)

  with self.assertRaises(estimation_process.EstimateNotAssignableError):
    process.map(map_fn)
def next_comp(state, value):
  """Maps `_add_one` over the state and broadcasts `value`.

  Args:
    state: Value mapped through `_add_one` to produce the new state.
    value: Value that is broadcast as the result and also used to compute the
      measurements.

  Returns:
    A `MeasuredProcessOutput` with the new state, the broadcast value, and a
    measurement equal to the global norm of the flattened value plus 3.0.
  """
  new_state = intrinsics.federated_map(_add_one, state)
  broadcast_value = intrinsics.federated_broadcast(value)
  # Arbitrary metrics for testing.
  norm_plus_three = computations.tf_computation(
      lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
  measurements = intrinsics.federated_map(norm_plus_three, value)
  return measured_process.MeasuredProcessOutput(
      state=new_state, result=broadcast_value, measurements=measurements)
def foo(temperatures, threshold):
  """Sums an int32 indicator of `temperature > threshold`.

  Also asserts, through the enclosing test's `self`, that the sum is a
  `value_base.Value` before returning it.
  """
  is_greater = computations.tf_computation(
      lambda x, y: tf.cast(tf.greater(x, y), tf.int32))
  paired = [temperatures, intrinsics.federated_broadcast(threshold)]
  indicators = intrinsics.federated_map(is_greater, paired)
  total = intrinsics.federated_sum(indicators)
  self.assertIsInstance(total, value_base.Value)
  return total
def test_non_federated_init_next_raises(self):
  """A `DistributionProcess` built from plain TF computations is rejected."""
  initialize_fn = computations.tf_computation(lambda: 0)

  @computations.tf_computation(tf.int32, tf.float32)
  def next_fn(state, val):
    no_measurements = ()
    return MeasuredProcessOutput(state, val, no_measurements)

  expected_error = errors.TemplateNotFederatedError
  with self.assertRaises(expected_error):
    distributors.DistributionProcess(initialize_fn, next_fn)
def test_non_federated_init_next_raises(self):
  """An `AggregationProcess` built from plain TF computations is rejected."""
  initialize_fn = computations.tf_computation(lambda: 0)

  @computations.tf_computation(tf.int32, tf.float32)
  def next_fn(state, val):
    no_measurements = ()
    return MeasuredProcessOutput(state, val, no_measurements)

  expected_error = aggregation_process.AggregationNotFederatedError
  with self.assertRaises(expected_error):
    aggregation_process.AggregationProcess(initialize_fn, next_fn)
def test_federated_map_with_client_dataset_reduce(self):
  """Mapping a dataset reduction over client sequences yields client int32s."""
  sequence_at_clients = computation_types.at_clients(
      computation_types.SequenceType(tf.int32), all_equal=True)
  client_data = _mock_data_of_type(sequence_at_clients)

  sum_elements = computations.tf_computation(
      lambda ds: ds.reduce(np.int32(0), lambda x, y: x + y))
  val = intrinsics.federated_map(sum_elements, client_data)

  self.assert_value(val, '{int32}@CLIENTS')
def test_nested_tuple_input_explicit_types(self):
  """Two (nested) tuple arguments with explicitly declared types."""

  @tf.function(autograph=False)
  def foo(tuple1, tuple2):
    # Same left-to-right order of additions as a single chained expression.
    total = tuple1[0] + tuple1[1][0]
    total = total + tuple1[1][1]
    total = total + tuple2[0]
    return total + tuple2[1]

  first_type = (tf.int32, (tf.int32, tf.int32))
  second_type = (tf.int32, tf.int32)
  tf_comp = computations.tf_computation(foo, [first_type, second_type])
  self.assertEqual(tf_comp((1, (2, 3)), (0, 0)), 6)
def test_tensorflow_computation_with_lambda_and_selection(self):
  """Applies a TF `x + 1` computation twice via a federated computation."""
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)

  @computations.federated_computation(tf.int32, int_to_int)
  def apply_twice(x, f):
    once = f(x)
    return f(once)

  add_one = computations.tf_computation(lambda x: x + 1, tf.int32)
  self.assertEqual(apply_twice(5, add_one), 7)