def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
  """Splitting on the aggregate drops it but keeps secure_sum_bitwidth.

  The computation under test calls both `federated_aggregate` and
  `federated_secure_sum_bitwidth`; only the aggregate is requested as a split
  point. After the split, the original `comp` must still contain both
  intrinsics, neither half may contain the aggregate, and the secure sum must
  appear in at least one of the two halves.
  """
  aggregate_call = (
      building_block_test_utils.create_whimsy_called_federated_aggregate(
          accumulate_parameter_name='a',
          merge_parameter_name='b',
          report_parameter_name='c'))
  make_secure_sum = (
      building_block_test_utils
      .create_whimsy_called_federated_secure_sum_bitwidth)
  secure_sum_call = make_secure_sum()
  comp = building_blocks.Lambda(
      'd', tf.int32, building_blocks.Struct([aggregate_call, secure_sum_call]))
  split_target = building_block_factory.create_null_federated_aggregate()
  kept_uri = secure_sum_call.function.uri
  removed_uri = split_target.function.uri

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, [split_target])

  # The input computation still holds both intrinsics after the split.
  for uri in (kept_uri, removed_uri):
    self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uri))
  # The selected intrinsic appears in neither half...
  for half in (before, after):
    self.assertFalse(tree_analysis.contains_called_intrinsic(half, removed_uri))
  # ...while the unselected one lands in at least one of them
  # (`any` short-circuits just like the equivalent `or`).
  self.assertTrue(
      any(
          tree_analysis.contains_called_intrinsic(half, kept_uri)
          for half in (before, after)))
def test_splits_on_selected_intrinsic_secure_sum_bitwidth(self):
  """A lambda whose body only calls secure_sum_bitwidth splits on that call."""
  secure_sum_call = (
      building_block_test_utils
      .create_whimsy_called_federated_secure_sum_bitwidth())
  comp = building_blocks.Lambda(
      'a', tf.int32, building_blocks.Struct([secure_sum_call]))
  self.assert_splits_on(
      comp,
      building_block_factory.create_null_federated_secure_sum_bitwidth())
def test_splits_on_two_intrinsics(self):
  """Splitting on both aggregate and secure_sum_bitwidth at once succeeds."""
  aggregate_call = (
      building_block_test_utils.create_whimsy_called_federated_aggregate(
          accumulate_parameter_name='a',
          merge_parameter_name='b',
          report_parameter_name='c'))
  secure_sum_call = (
      building_block_test_utils
      .create_whimsy_called_federated_secure_sum_bitwidth())
  comp = building_blocks.Lambda(
      'd', tf.int32, building_blocks.Struct([aggregate_call, secure_sum_call]))
  # One null call per intrinsic kind we want to split on.
  split_targets = [
      building_block_factory.create_null_federated_aggregate(),
      building_block_factory.create_null_federated_secure_sum_bitwidth(),
  ]
  self.assert_splits_on(comp, split_targets)
def test_returns_str_on_nested_secure_aggregation(self):
  """A secure_sum_bitwidth call over a struct-typed value is one aggregation."""
  nested_value_type = (tf.int32, tf.int32)
  make_secure_sum = (
      building_block_test_utils
      .create_whimsy_called_federated_secure_sum_bitwidth)
  self.assert_one_aggregation(make_secure_sum(nested_value_type))
building_block_test_utils.create_whimsy_called_federated_broadcast()), (None, building_block_test_utils.create_whimsy_called_federated_value( placements.CLIENTS)) ]) unit = computation_types.StructType([]) trivial_aggregate = building_block_test_utils.create_whimsy_called_federated_aggregate( value_type=unit) trivial_mean = building_block_test_utils.create_whimsy_called_federated_mean( unit) trivial_sum = building_block_test_utils.create_whimsy_called_federated_sum( unit) # TODO(b/120439632) Enable once federated_mean accepts structured weights. # trivial_weighted_mean = ... trivial_secure_sum = building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth( unit) class ContainsAggregationShared(parameterized.TestCase): @parameterized.named_parameters([ ('non_aggregation_intrinsics', non_aggregation_intrinsics), ('trivial_aggregate', trivial_aggregate), ('trivial_mean', trivial_mean), ('trivial_sum', trivial_sum), # TODO(b/120439632) Enable once federated_mean accepts structured weight. # ('trivial_weighted_mean', trivial_weighted_mean), ('trivial_secure_sum', trivial_secure_sum), ]) def test_returns_none(self, comp): self.assertEmpty(tree_analysis.find_unsecure_aggregation_in_tree(comp)) self.assertEmpty(tree_analysis.find_secure_aggregation_in_tree(comp))