def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
     """Splitting on federated_aggregate leaves secure_sum_bitwidth in place.

     `comp` calls both a federated_aggregate and a secure_sum_bitwidth. After
     splitting on a null federated_aggregate, neither returned half may call
     federated_aggregate, while the secure_sum_bitwidth call must appear in at
     least one half. `comp` itself must still call both intrinsics afterwards
     (presumably guarding against in-place mutation by the split).
     """
     aggregate_call = building_block_test_utils.create_whimsy_called_federated_aggregate(
         accumulate_parameter_name='a',
         merge_parameter_name='b',
         report_parameter_name='c')
     bitwidth_call = (
         building_block_test_utils
         .create_whimsy_called_federated_secure_sum_bitwidth())
     comp = building_blocks.Lambda(
         'd', tf.int32, building_blocks.Struct([aggregate_call, bitwidth_call]))
     split_target = building_block_factory.create_null_federated_aggregate()
     bitwidth_uri = bitwidth_call.function.uri
     aggregate_uri = split_target.function.uri

     before, after = transformations.force_align_and_split_by_intrinsics(
         comp, [split_target])
     halves = (before, after)

     # The input computation still calls both intrinsics.
     for uri in (bitwidth_uri, aggregate_uri):
         self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uri))
     # The intrinsic that was split on is gone from both halves...
     for half in halves:
         self.assertFalse(
             tree_analysis.contains_called_intrinsic(half, aggregate_uri))
     # ...while the other intrinsic survives in at least one of them.
     self.assertTrue(
         any(tree_analysis.contains_called_intrinsic(half, bitwidth_uri)
             for half in halves))
 def test_splits_on_selected_intrinsic_secure_sum_bitwidth(self):
     """Checks `assert_splits_on` when secure_sum_bitwidth is the only call."""
     bitwidth_call = (
         building_block_test_utils
         .create_whimsy_called_federated_secure_sum_bitwidth())
     comp = building_blocks.Lambda(
         'a', tf.int32, building_blocks.Struct([bitwidth_call]))
     # Split on the null version of the same intrinsic.
     self.assert_splits_on(
         comp,
         building_block_factory.create_null_federated_secure_sum_bitwidth())
 def test_splits_on_two_intrinsics(self):
     """Checks `assert_splits_on` for federated_aggregate plus secure_sum_bitwidth."""
     aggregate_call = building_block_test_utils.create_whimsy_called_federated_aggregate(
         accumulate_parameter_name='a',
         merge_parameter_name='b',
         report_parameter_name='c')
     bitwidth_call = (
         building_block_test_utils
         .create_whimsy_called_federated_secure_sum_bitwidth())
     comp = building_blocks.Lambda(
         'd', tf.int32, building_blocks.Struct([aggregate_call, bitwidth_call]))
     # One null call per intrinsic present in `comp`, in the same order.
     split_calls = [
         building_block_factory.create_null_federated_aggregate(),
         building_block_factory.create_null_federated_secure_sum_bitwidth(),
     ]
     self.assert_splits_on(comp, split_calls)
Example #4
0
 def test_returns_str_on_nested_secure_aggregation(self):
     """Checks `assert_one_aggregation` for secure_sum_bitwidth over (int32, int32)."""
     nested_type = (tf.int32, tf.int32)
     comp = (
         building_block_test_utils
         .create_whimsy_called_federated_secure_sum_bitwidth(nested_type))
     self.assert_one_aggregation(comp)
Example #5
0
     building_block_test_utils.create_whimsy_called_federated_broadcast()),
    (None,
     building_block_test_utils.create_whimsy_called_federated_value(
         placements.CLIENTS))
])

# The empty struct type; every `trivial_*` fixture below aggregates a value
# of this type.
unit = computation_types.StructType([])

trivial_aggregate = (
    building_block_test_utils.create_whimsy_called_federated_aggregate(
        value_type=unit))
trivial_mean = (
    building_block_test_utils.create_whimsy_called_federated_mean(unit))
trivial_sum = (
    building_block_test_utils.create_whimsy_called_federated_sum(unit))
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = (
    building_block_test_utils
    .create_whimsy_called_federated_secure_sum_bitwidth(unit))


class ContainsAggregationShared(parameterized.TestCase):
    @parameterized.named_parameters([
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_mean', trivial_mean),
        ('trivial_sum', trivial_sum),
        # TODO(b/120439632) Enable once federated_mean accepts structured weight.
        # ('trivial_weighted_mean', trivial_weighted_mean),
        ('trivial_secure_sum', trivial_secure_sum),
    ])
    def test_returns_none(self, comp):
        """Both aggregation finders return an empty result for `comp`.

        `comp` is either `non_aggregation_intrinsics` or a called aggregation
        intrinsic whose value type is the empty struct `unit`.
        """
        unsecure_found = tree_analysis.find_unsecure_aggregation_in_tree(comp)
        self.assertEmpty(unsecure_found)
        secure_found = tree_analysis.find_secure_aggregation_in_tree(comp)
        self.assertEmpty(secure_found)