def test_returns_tree(self):
  """Checks the split of the no-aggregate sum example around its aggregation.

  The `next` computation of the example iterative process is split in two
  ways:

  * with `_create_before_and_after_aggregate_for_no_federated_aggregate`, and
  * with `force_align_and_split_by_intrinsics` around `federated_secure_sum`.

  The first split's "before" half must be a lambda returning a 2-tuple. That
  tuple's first element must have a fixed formatted representation. Its second
  element must print identically to the result of the pre-secure-sum half.

  The first split's "after" half must be a lambda returning a call. The call's
  function must match the post-secure-sum half once reference names are
  uniquified. The call's argument must have a fixed formatted representation.
  """
  iterative_process = (
      get_iterative_process_for_sum_example_with_no_federated_aggregate())
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)

  split_around_aggregate = (
      canonical_form_utils
      ._create_before_and_after_aggregate_for_no_federated_aggregate)
  before_aggregate, after_aggregate = split_around_aggregate(next_comp)

  secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
  before_secure_sum, after_secure_sum = (
      transformations.force_align_and_split_by_intrinsics(
          next_comp, [secure_sum_uri]))

  # Checks on the half that runs before aggregation.
  self.assertIsInstance(before_aggregate, building_blocks.Lambda)
  before_result = before_aggregate.result
  self.assertIsInstance(before_result, building_blocks.Tuple)
  self.assertLen(before_result, 2)

  # pyformat: disable
  expected_first_element = (
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>,\n'
      ' (_var1 -> <>),\n'
      ' (_var2 -> <>),\n'
      ' (_var3 -> <>)\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(before_result[0].formatted_representation(),
                   expected_first_element)
  self.assertEqual(before_result[1].formatted_representation(),
                   before_secure_sum.result.formatted_representation())

  # Checks on the half that runs after aggregation.
  self.assertIsInstance(after_aggregate, building_blocks.Lambda)
  after_result = after_aggregate.result
  self.assertIsInstance(after_result, building_blocks.Call)

  # Reference names are uniquified on both sides so that generated names do
  # not cause spurious mismatches in the printed form.
  actual_function, _ = tree_transformations.uniquify_reference_names(
      after_result.function)
  expected_function, _ = tree_transformations.uniquify_reference_names(
      after_secure_sum)
  self.assertEqual(actual_function.formatted_representation(),
                   expected_function.formatted_representation())

  # pyformat: disable
  expected_call_argument = (
      '<\n'
      ' _var4[0],\n'
      ' _var4[1][1]\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(after_result.argument.formatted_representation(),
                   expected_call_argument)
def test_returns_tree(self):
  """Checks the split of the no-aggregate sum example around its aggregation.

  The `next` computation of the example iterative process is split in two
  ways:

  * with `_create_before_and_after_aggregate_for_no_federated_aggregate`, and
  * with `force_align_and_split_by_intrinsics` around `federated_secure_sum`.

  The first split's "before" half must be a lambda returning a 2-element
  `Struct`. The struct's first element must have a fixed formatted
  representation. Its second element must be tree-equal to the result of the
  pre-secure-sum half.

  The first split's "after" half must be a lambda returning a call. The call's
  function must be tree-equal to the post-secure-sum half once reference names
  are uniquified. The call's argument must have a fixed formatted
  representation.
  """
  ip = get_iterative_process_for_sum_example_with_no_federated_aggregate()
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  before_aggregate, after_aggregate = (
      canonical_form_utils
      ._create_before_and_after_aggregate_for_no_federated_aggregate(next_tree))
  before_federated_secure_sum, after_federated_secure_sum = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_SECURE_SUM.uri]))

  self.assertIsInstance(before_aggregate, building_blocks.Lambda)
  self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
  self.assertLen(before_aggregate.result, 2)

  # pyformat: disable
  self.assertEqual(
      before_aggregate.result[0].formatted_representation(),
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>,\n'
      ' (_var1 -> <>),\n'
      ' (_var2 -> <>),\n'
      ' (_var3 -> <>)\n'
      '>')
  # pyformat: enable

  # `trees_equal` fails on computations that refer to unbound references, so
  # each side is wrapped in a block that binds every free name to the same
  # dummy value before the two are compared.
  dummy_data = building_blocks.Data('data', computation_types.AbstractType('T'))

  def _bind_unbound_references(comp):
    """Returns `comp` wrapped in a `Block` binding its free names to dummy data.

    Names are bound in sorted order so that two computations with the same
    free names always get their locals in the same order.
    NOTE(review): this assumes `get_map_of_unbound_references` returns an
    unordered collection (presumably a set), whose iteration order is not
    guaranteed. It also assumes `trees_equal` pairs block locals by position.
    Confirm both against their implementations.
    """
    unbound_names = transformation_utils.get_map_of_unbound_references(
        comp)[comp]
    return building_blocks.Block(
        [(name, dummy_data) for name in sorted(unbound_names)], comp)

  self.assertTrue(
      tree_analysis.trees_equal(
          _bind_unbound_references(before_aggregate.result[1]),
          _bind_unbound_references(before_federated_secure_sum.result)))

  self.assertIsInstance(after_aggregate, building_blocks.Lambda)
  self.assertIsInstance(after_aggregate.result, building_blocks.Call)
  # Reference names are uniquified on both sides so that generated names do
  # not cause spurious mismatches.
  actual_after_aggregate_tree, _ = (
      tree_transformations.uniquify_reference_names(
          after_aggregate.result.function))
  expected_after_aggregate_tree, _ = (
      tree_transformations.uniquify_reference_names(
          after_federated_secure_sum))
  self.assertTrue(
      tree_analysis.trees_equal(actual_after_aggregate_tree,
                                expected_after_aggregate_tree))

  # pyformat: disable
  self.assertEqual(
      after_aggregate.result.argument.formatted_representation(),
      '<\n'
      ' _var4[0],\n'
      ' _var4[1][1]\n'
      '>')
  # pyformat: enable