def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
            self):
        """Splits one secure sum plus two aggregates on both intrinsic URIs.

        Both halves returned by `force_align_and_split_by_intrinsics` must be
        `building_blocks.Lambda`s that no longer contain a called intrinsic
        with either of the split URIs.
        """
        aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        secure_sum_call = (
            compiler_test_utils.create_whimsy_called_federated_secure_sum())
        split_uris = [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ]
        # The same aggregate call object fills both aggregate slots.
        comp = building_blocks.Lambda(
            'd', tf.int32,
            building_blocks.Struct(
                [secure_sum_call, aggregate_call, aggregate_call]))

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, split_uris)

        # Checked in the order before-type, before-contents, after-type,
        # after-contents.
        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(half, split_uris))
    def test_returns_trees_with_one_federated_secure_sum(self):
        """Splits a struct holding a single secure sum on the secure-sum URI.

        Both halves returned by `force_align_and_split_by_intrinsics` must be
        `building_blocks.Lambda`s free of any called secure-sum intrinsic.
        """
        secure_sum_call = (
            compiler_test_utils.create_whimsy_called_federated_secure_sum())
        comp = building_blocks.Lambda(
            'a', tf.int32, building_blocks.Struct([secure_sum_call]))
        split_uris = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, split_uris)

        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(half, split_uris))
# Example #3
0
 def test_returns_str_on_nested_secure_aggregation(self):
     """Checks a secure sum whose value is a (tf.int32, tf.int32) structure.

     The actual check is delegated to `assert_one_aggregation`, which is
     defined elsewhere (presumably it expects exactly one aggregation to be
     reported — confirm).
     """
     nested_value_type = (tf.int32, tf.int32)
     nested_secure_sum = test_utils.create_whimsy_called_federated_secure_sum(
         nested_value_type)
     self.assert_one_aggregation(nested_secure_sum)
# Example #4
0

# Called intrinsics that are not aggregations: a federated broadcast and a
# federated value placed at CLIENTS, collected into one Struct. The `(None, x)`
# entries are presumably (name, value) pairs with no name — confirm.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_whimsy_called_federated_broadcast()),
    (None, test_utils.create_whimsy_called_federated_value(placements.CLIENTS))
])

# The empty struct type; every `trivial_*` case below is a whimsy called
# aggregation intrinsic whose value type is this empty struct.
unit = computation_types.StructType([])
trivial_aggregate = test_utils.create_whimsy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_whimsy_called_federated_collect(unit)
trivial_mean = test_utils.create_whimsy_called_federated_mean(unit)
trivial_sum = test_utils.create_whimsy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = test_utils.create_whimsy_called_federated_secure_sum(unit)


class ContainsAggregationShared(parameterized.TestCase):
    """Cases that `find_unsecure_aggregation_in_tree` must not flag.

    Each parameterized case is either a struct of non-aggregation intrinsics
    or an aggregation intrinsic over the empty struct type; all are expected
    to produce an empty result.
    """

    # A single list argument is required here: absl's `named_parameters`
    # treats one non-tuple argument as the iterable of test cases.
    @parameterized.named_parameters([
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_collect', trivial_collect),
        ('trivial_mean', trivial_mean),
        ('trivial_sum', trivial_sum),
        # TODO(b/120439632) Enable once federated_mean accepts structured weight.
        # ('trivial_weighted_mean', trivial_weighted_mean),
        ('trivial_secure_sum', trivial_secure_sum),
    ])
    def test_returns_none(self, comp):
        """Asserts that no unsecure aggregation is found in `comp`."""
        unsecure_aggregations = tree_analysis.find_unsecure_aggregation_in_tree(
            comp)
        self.assertEmpty(unsecure_aggregations)