def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
    self):
  """Splits a lambda over one secure sum and the same aggregate used twice.

  Both halves returned by `force_align_and_split_by_intrinsics` must be
  `building_blocks.Lambda` instances, and `contains_called_intrinsic` must
  report no call to the requested intrinsic URIs in either half.
  """
  aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  secure_sum_call = (
      compiler_test_utils.create_whimsy_called_federated_secure_sum())
  # The very same aggregate building block appears twice in the struct. The
  # lambda parameter is 'd', presumably so it cannot clash with the 'a'/'b'/'c'
  # parameter names given to the aggregate above.
  comp = building_blocks.Lambda(
      'd', tf.int32,
      building_blocks.Struct([
          secure_sum_call,
          aggregate_call,
          aggregate_call,
      ]))
  split_uris = [
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
      intrinsic_defs.FEDERATED_SECURE_SUM.uri,
  ]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, split_uris)

  # Same checks, same order: `before` first, then `after`.
  for half in (before, after):
    self.assertIsInstance(half, building_blocks.Lambda)
    self.assertFalse(tree_analysis.contains_called_intrinsic(half, split_uris))
def test_returns_trees_with_one_federated_secure_sum(self):
  """Splits a lambda whose body is a struct holding a single secure sum.

  Both halves returned by `force_align_and_split_by_intrinsics` must be
  `building_blocks.Lambda` instances with no remaining call to the
  secure-sum intrinsic, as reported by `contains_called_intrinsic`.
  """
  secure_sum_call = (
      compiler_test_utils.create_whimsy_called_federated_secure_sum())
  comp = building_blocks.Lambda('a', tf.int32,
                                building_blocks.Struct([secure_sum_call]))
  split_uris = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, split_uris)

  # Same checks, same order: `before` first, then `after`.
  for half in (before, after):
    self.assertIsInstance(half, building_blocks.Lambda)
    self.assertFalse(tree_analysis.contains_called_intrinsic(half, split_uris))
def test_returns_str_on_nested_secure_aggregation(self):
  """Secure sum over a nested `(tf.int32, tf.int32)` value is one aggregation.

  NOTE(review): the actual check lives in `self.assert_one_aggregation`, which
  is defined outside this view; presumably it expects the finder to report a
  single (string) result, as the test name suggests.
  """
  nested_value_type = (tf.int32, tf.int32)
  self.assert_one_aggregation(
      test_utils.create_whimsy_called_federated_secure_sum(nested_value_type))
# Module-level computations shared by the parameterized cases below. They are
# built once, in this order, at import time.

# Intrinsic calls that are not aggregations: a broadcast and a federated value
# placed at CLIENTS, wrapped in an unnamed-element struct.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_whimsy_called_federated_broadcast()),
    (None, test_utils.create_whimsy_called_federated_value(placements.CLIENTS)),
])

# Aggregation intrinsics applied to the empty struct type `unit`.
unit = computation_types.StructType([])
trivial_aggregate = test_utils.create_whimsy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_whimsy_called_federated_collect(unit)
trivial_mean = test_utils.create_whimsy_called_federated_mean(unit)
trivial_sum = test_utils.create_whimsy_called_federated_sum(unit)
# TODO(b/120439632): Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = test_utils.create_whimsy_called_federated_secure_sum(unit)


class ContainsAggregationShared(parameterized.TestCase):
  """Computations in which no unsecure aggregation should be found."""

  @parameterized.named_parameters(
      ('non_aggregation_intrinsics', non_aggregation_intrinsics),
      ('trivial_aggregate', trivial_aggregate),
      ('trivial_collect', trivial_collect),
      ('trivial_mean', trivial_mean),
      ('trivial_sum', trivial_sum),
      # TODO(b/120439632): Enable once federated_mean accepts structured
      # weights.
      # ('trivial_weighted_mean', trivial_weighted_mean),
      ('trivial_secure_sum', trivial_secure_sum),
  )
  def test_returns_none(self, comp):
    """`find_unsecure_aggregation_in_tree` gives an empty result for `comp`.

    NOTE(review): despite the name, the assertion is emptiness (`assertEmpty`),
    not an `is None` check.
    """
    unsecure_found = tree_analysis.find_unsecure_aggregation_in_tree(comp)
    self.assertEmpty(unsecure_found)