def test_throws_on_unresolvable_function_call(self):
  """Both finders must reject a call whose target is an opaque `Data` ref."""
  client_int_type = computation_types.at_clients(tf.int32)
  unknown_fn = building_blocks.Data(
      'unknown_func', computation_types.FunctionType(None, client_int_type))
  comp = building_blocks.Call(unknown_fn)
  # The target of the call cannot be identified, so neither analysis can say
  # whether it aggregates; each is expected to raise rather than guess.
  for finder in (tree_analysis.find_unsecure_aggregation_in_tree,
                 tree_analysis.find_secure_aggregation_in_tree):
    with self.assertRaises(ValueError):
      finder(comp)
def test_throws_on_unresolvable_function_call(self):
  """Both finders must reject a call to a function-typed parameter."""
  fn_type = computation_types.FunctionType(
      (), computation_types.FederatedType(tf.int32, placement_literals.CLIENTS))

  # The computation's only parameter is itself a function, so the call in the
  # body targets a reference whose implementation is unknown to the analysis.
  @computations.federated_computation(fn_type)
  def comp(unknown_func):
    return unknown_func(())

  for finder in (tree_analysis.find_unsecure_aggregation_in_tree,
                 tree_analysis.find_secure_aggregation_in_tree):
    with self.assertRaises(ValueError):
      finder(comp.to_building_block())
def test_returns_none(self, comp):
  """Neither unsecure nor secure aggregation calls are found in `comp`."""
  for finder in (tree_analysis.find_unsecure_aggregation_in_tree,
                 tree_analysis.find_secure_aggregation_in_tree):
    # A fresh building block is produced for each check, as before.
    self.assertEmpty(finder(comp.to_building_block()))
def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
    self):
  """An unknown call producing a non-federated value is not an aggregation."""
  client_int_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)
  fn_type = computation_types.FunctionType(client_int_type, tf.int32)
  comp = building_blocks.Call(
      building_blocks.Data('unknown_func', fn_type),
      building_blocks.Data('client_data', client_int_type))
  # Unlike the federated-output case, no error is expected: both finders
  # simply report nothing.
  for finder in (tree_analysis.find_unsecure_aggregation_in_tree,
                 tree_analysis.find_secure_aggregation_in_tree):
    self.assertEmpty(finder(comp))
def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
    self):
  """A function-typed parameter with non-federated output yields no calls."""
  fn_type = computation_types.FunctionType(
      computation_types.FederatedType(tf.int32, placement_literals.CLIENTS),
      tf.int32)

  @computations.federated_computation(fn_type)
  def comp(unknown_func):
    client_value = intrinsics.federated_value(1, placement_literals.CLIENTS)
    return unknown_func(client_value)

  for finder in (tree_analysis.find_unsecure_aggregation_in_tree,
                 tree_analysis.find_secure_aggregation_in_tree):
    self.assertEmpty(finder(comp.to_building_block()))
def assert_not_contains_unsecure_aggregation(comp):
  """Asserts that `comp` contains no unsecure aggregation calls.

  Args:
    comp: A `tff.Computation`, often a function annotated with
      `tff.federated_computation` or `tff.tf_computation`. Note that
      polymorphic functions (those without the types of their arguments
      explicitly specified) will not yet be `tff.Computation`s.

  Raises:
    AssertionError: If `comp` contains an unsecure aggregation call.
    ValueError: If `comp` contains a call whose target function cannot be
      identified. This may result from calls to references or other indirect
      structures.
  """
  # `comp` must be a `ComputationImpl`; presumably `check_type` raises
  # `TypeError` otherwise -- confirm against `py_typecheck`.
  py_typecheck.check_type(comp, computation_impl.ComputationImpl)
  calls = tree_analysis.find_unsecure_aggregation_in_tree(
      comp.to_building_block())
  # Implicit truthiness instead of `len(calls) != 0`; no lint suppression
  # needed.
  if calls:
    _raise_expected_none(calls, 'unsecure')
def test_returns_one_on_unsecure_aggregation(self, comp):
  """Exactly one unsecure aggregation call is found in `comp`."""
  unsecure_calls = tree_analysis.find_unsecure_aggregation_in_tree(comp)
  self.assertLen(unsecure_calls, 1)
def test_returns_none_on_secure_aggregation(self):
  """A secure sum is not reported by the unsecure-aggregation finder."""
  found = tree_analysis.find_unsecure_aggregation_in_tree(simple_secure_sum)
  self.assertEmpty(found)