def test_observe_types_and_fields(self):
    """Check type/field statistics collected from a small hand-built tree.

    A single "root" node holds two "foo" nodes in its "children" field. The
    "foo" nodes spread "bar" leaves over fields "a" to "d" with zero, one or
    two entries each, and the second "foo" omits field "c" altogether. As a
    result, each field of "foo" ends up with a different expected
    FieldObservations value.
    """
    make = generic_ast_graphs.GenericASTNode

    def bars(*node_ids):
        # Leaf "bar" nodes that have no fields of their own.
        return [make(node_id, "bar", {}) for node_id in node_ids]

    first_foo = make(1, "foo", {
        "a": bars(12),
        "b": bars(13),
        "c": [],
        "d": bars(14, 15),
    })
    second_foo = make(2, "foo", {
        "a": bars(22),
        "b": [],
        "d": bars(24),
    })
    tree = make(0, "root", {"children": [first_foo, second_foo]})

    observations = ast_spec_inference.observe_types_and_fields(tree)

    field_obs = ast_spec_inference.FieldObservations
    node_obs = ast_spec_inference.NodeObservations
    expected = ast_spec_inference.ASTObservations(
        example_count=1,
        node_types={
            "root": node_obs(
                count=1,
                count_root=1,
                fields={"children": field_obs(count_many=1)}),
            "foo": node_obs(
                count=2,
                count_root=0,
                fields={
                    "a": field_obs(count_one=2),
                    "b": field_obs(count_one=1),
                    "c": field_obs(),
                    "d": field_obs(count_one=1, count_many=1),
                }),
            # Leaves 12, 13, 14, 15, 22 and 24.
            "bar": node_obs(count=6, count_root=0, fields={}),
        })
    self.assertEqual(observations, expected)
def py_ast_to_generic(tree):
    """Convert a gast AST node to a generic representation.

    IDs are set based on the python `id` of the AST node. Only children that
    are AST nodes, or lists containing AST nodes, are processed. Any other
    field value (identifiers, constants, `None`, empty or AST-free lists) is
    dropped. Within a list field, entries that are not AST nodes are skipped.

    Args:
      tree: Node of the AST to convert.

    Returns:
      Generic representation of the AST.
    """
    fields = {}
    for field_name in tree._fields:
        value = getattr(tree, field_name)
        if isinstance(value, gast.AST):
            fields[field_name] = [py_ast_to_generic(value)]
        elif isinstance(value, list):
            # Some list fields mix AST nodes with `None` placeholders, e.g.
            # `arguments.kw_defaults` (keyword-only args without a default)
            # and `Dict.keys` (`**` unpacking). Classifying the list by its
            # first element either dropped such a list entirely or crashed on
            # `None._fields`, so keep exactly the AST entries instead.
            # NOTE(review): skipping `None` means indices here no longer line
            # up with sibling lists such as `kwonlyargs`; confirm downstream
            # consumers do not rely on positional alignment.
            ast_children = [
                child for child in value if isinstance(child, gast.AST)
            ]
            if ast_children:
                fields[field_name] = [
                    py_ast_to_generic(child) for child in ast_children
                ]
        # Any other value is not an AST child and is intentionally ignored.
    return generic_ast_graphs.GenericASTNode(
        node_id=id(tree), node_type=type(tree).__name__, fields=fields)
def test_compute_nth_child_edges(self):
    """Check the CHILD_INDEX_* edges from a root to its sequence helpers.

    The root's "children" sequence holds five leaves. With a limit of 10,
    every child gets an edge. With a limit of 2, only the first two do.
    """
    mini_ast_spec = {
        "root": generic_ast_graphs.ASTNodeSpec(
            fields={"children": generic_ast_graphs.FieldType.SEQUENCE},
            sequence_item_types={"children": "child"},
            has_parent=False),
        "leaf": generic_ast_graphs.ASTNodeSpec(),
    }
    leaves = [
        generic_ast_graphs.GenericASTNode(f"leaf{index}", "leaf", {})
        for index in range(5)
    ]
    mini_ast_node = generic_ast_graphs.GenericASTNode(
        "root", "root", {"children": leaves})
    mini_ast_graph, _ = generic_ast_graphs.ast_to_graph(
        mini_ast_node, mini_ast_spec)

    # Pairs of (children allowed, edges expected). The limit only matters
    # when it is below the actual number of children (5).
    for max_children, expected_count in ((10, 5), (2, 2)):
        expected_edges = [
            ("root__root",
             f"root_children_{index}__child-seq-helper",
             f"CHILD_INDEX_{index}")
            for index in range(expected_count)
        ]
        nth_child_edges = graph_edge_util.compute_nth_child_edges(
            mini_ast_graph, max_children)
        self.assertEqual(nth_child_edges, expected_edges)