def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
    """Splits only on federated aggregate and checks where each intrinsic ends up.

    The computation under test is a lambda whose result is a struct holding a
    called federated aggregate and a called federated secure-sum-bitwidth.
    Only the null federated aggregate is passed as the intrinsic to split on,
    so the aggregate call must be absent from both returned computations
    while the secure sum call must still be present in at least one of them.
    """
    aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
        accumulate_parameter_name='a',
        merge_parameter_name='b',
        report_parameter_name='c')
    secure_sum_call = (
        compiler_test_utils.create_whimsy_called_federated_secure_sum_bitwidth())
    comp = building_blocks.Lambda(
        'd', tf.int32,
        building_blocks.Struct([aggregate_call, secure_sum_call]))
    null_aggregate = building_block_factory.create_null_federated_aggregate()
    secure_sum_uri = secure_sum_call.function.uri
    aggregate_uri = null_aggregate.function.uri

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, [null_aggregate])

    calls_intrinsic = tree_analysis.contains_called_intrinsic

    # Sanity check: the input computation calls both intrinsics.
    self.assertTrue(calls_intrinsic(comp, secure_sum_uri))
    self.assertTrue(calls_intrinsic(comp, aggregate_uri))

    # The intrinsic that was split on appears in neither returned computation.
    self.assertFalse(calls_intrinsic(before, aggregate_uri))
    self.assertFalse(calls_intrinsic(after, aggregate_uri))

    # The intrinsic that was not selected survives on at least one side.
    self.assertTrue(
        calls_intrinsic(before, secure_sum_uri) or
        calls_intrinsic(after, secure_sum_uri))
 def test_splits_on_selected_intrinsic_secure_sum_bitwidth(self):
   federated_secure_sum_bitwidth = compiler_test_utils.create_whimsy_called_federated_secure_sum_bitwidth(
   )
   called_intrinsics = building_blocks.Struct([federated_secure_sum_bitwidth])
   comp = building_blocks.Lambda('a', tf.int32, called_intrinsics)
   call = building_block_factory.create_null_federated_secure_sum_bitwidth()
   self.assert_splits_on(comp, call)
 def test_splits_on_two_intrinsics(self):
   federated_aggregate = compiler_test_utils.create_whimsy_called_federated_aggregate(
       accumulate_parameter_name='a',
       merge_parameter_name='b',
       report_parameter_name='c')
   federated_secure_sum_bitwidth = compiler_test_utils.create_whimsy_called_federated_secure_sum_bitwidth(
   )
   called_intrinsics = building_blocks.Struct([
       federated_aggregate,
       federated_secure_sum_bitwidth,
   ])
   comp = building_blocks.Lambda('d', tf.int32, called_intrinsics)
   self.assert_splits_on(comp, [
       building_block_factory.create_null_federated_aggregate(),
       building_block_factory.create_null_federated_secure_sum_bitwidth()
   ])
 def test_returns_str_on_nested_secure_aggregation(self):
     comp = test_utils.create_whimsy_called_federated_secure_sum_bitwidth(
         (tf.int32, tf.int32))
     self.assert_one_aggregation(comp)

# Module-level computations shared by the parameterized tests below.

# A struct of two unnamed elements: a called federated broadcast and a called
# federated value created with the CLIENTS placement.
_broadcast_call = test_utils.create_whimsy_called_federated_broadcast()
_clients_value_call = test_utils.create_whimsy_called_federated_value(
    placements.CLIENTS)
non_aggregation_intrinsics = building_blocks.Struct([
    (None, _broadcast_call),
    (None, _clients_value_call),
])

# `unit` is the empty struct type; every "trivial" computation below is built
# from it.
unit = computation_types.StructType([])
trivial_aggregate = test_utils.create_whimsy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_whimsy_called_federated_collect(unit)
trivial_mean = test_utils.create_whimsy_called_federated_mean(unit)
trivial_sum = test_utils.create_whimsy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = (
    test_utils.create_whimsy_called_federated_secure_sum_bitwidth(unit))


class ContainsAggregationShared(parameterized.TestCase):
    """Parameterized checks run over the module-level fixture computations."""

    @parameterized.named_parameters([
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_collect', trivial_collect),
        ('trivial_mean', trivial_mean),
        ('trivial_sum', trivial_sum),
        # TODO(b/120439632) Enable once federated_mean accepts structured weight.
        # ('trivial_weighted_mean', trivial_weighted_mean),
        ('trivial_secure_sum', trivial_secure_sum),
    ])
    def test_returns_none(self, comp):
        """Asserts that no unsecure aggregation is reported for `comp`.

        Args:
          comp: One of the module-level computations listed in the
            `named_parameters` decorator above.
        """
        # NOTE(review): the "trivial_*" cases are presumably expected to be
        # ignored because they are built from the empty struct type `unit` --
        # confirm against `find_unsecure_aggregation_in_tree`.
        self.assertEmpty(tree_analysis.find_unsecure_aggregation_in_tree(comp))