def test_observe_types_and_fields(self):
    """Checks node-type counts and per-field arity counts gathered from one tree.

    The tree is a "root" with two "foo" children whose fields hold "bar"
    leaves in varying numbers (none, one, or two), plus one field ("c") that
    is present only on the first "foo".
    """
    node = generic_ast_graphs.GenericASTNode

    def bar(node_id):
        # A childless "bar" leaf carrying the given id.
        return node(node_id, "bar", {})

    first_foo = node(1, "foo", {
        "a": [bar(12)],
        "b": [bar(13)],
        "c": [],
        "d": [bar(14), bar(15)],
    })
    second_foo = node(2, "foo", {
        "a": [bar(22)],
        "b": [],
        "d": [bar(24)],
    })
    tree = node(0, "root", {"children": [first_foo, second_foo]})

    spec = ast_spec_inference
    observations = spec.observe_types_and_fields(tree)

    foo_fields = {
        # One child on both "foo" nodes.
        "a": spec.FieldObservations(count_one=2),
        # One child on the first "foo", empty on the second.
        "b": spec.FieldObservations(count_one=1),
        # Empty on the first "foo" and absent on the second; the test expects
        # default (empty) observations for it.
        "c": spec.FieldObservations(),
        # Two children on the first "foo", one on the second.
        "d": spec.FieldObservations(count_one=1, count_many=1),
    }
    expected_node_types = {
        "root": spec.NodeObservations(
            count=1,
            count_root=1,
            fields={"children": spec.FieldObservations(count_many=1)}),
        "foo": spec.NodeObservations(count=2, count_root=0, fields=foo_fields),
        # Six "bar" leaves in total: ids 12, 13, 14, 15, 22, 24.
        "bar": spec.NodeObservations(count=6, count_root=0, fields={}),
    }
    expected = spec.ASTObservations(
        example_count=1, node_types=expected_node_types)

    self.assertEqual(observations, expected)
Esempio n. 2
0
def py_ast_to_generic(tree):
  """Build a `GenericASTNode` that mirrors a gast AST node and its AST children.

  Every field listed in `tree._fields` is inspected:
    * a single AST node becomes a one-element child list;
    * a non-empty list whose first element is an AST node is converted
      element by element;
    * anything else (scalars, None, empty lists, lists of non-AST values) is
      left out of the resulting `fields` mapping.

  The node id is the Python `id()` of the source object. Python only
  guarantees `id()` uniqueness among objects that are alive at the same time.

  Args:
    tree: gast AST node to convert.

  Returns:
    A `generic_ast_graphs.GenericASTNode` whose type name is the class name
    of `tree`.
  """
  def convert_field(value):
    # Converted child list for this field, or None if it holds no AST nodes.
    if isinstance(value, gast.AST):
      return [py_ast_to_generic(value)]
    if isinstance(value, list) and value and isinstance(value[0], gast.AST):
      # NOTE(review): only the first element is type-checked. A later non-AST
      # entry (e.g. a None in `Dict.keys` for `**` unpacking -- verify against
      # gast) would be passed to py_ast_to_generic and fail on `_fields`, and
      # a leading None causes the whole list to be skipped.
      return [py_ast_to_generic(item) for item in value]
    return None

  converted = {}
  for name in tree._fields:
    children = convert_field(getattr(tree, name))
    if children is not None:
      converted[name] = children

  return generic_ast_graphs.GenericASTNode(
      node_id=id(tree), node_type=type(tree).__name__, fields=converted)
    def test_compute_nth_child_edges(self):
        """Checks CHILD_INDEX_i edges from a root to its sequence helpers.

        A root with five leaf children is converted to a graph. The test
        expects one edge per child when the limit exceeds the child count,
        and only the first `limit` edges otherwise.
        """
        node = generic_ast_graphs.GenericASTNode
        spec = {
            "root": generic_ast_graphs.ASTNodeSpec(
                fields={"children": generic_ast_graphs.FieldType.SEQUENCE},
                sequence_item_types={"children": "child"},
                has_parent=False),
            "leaf": generic_ast_graphs.ASTNodeSpec(),
        }
        leaves = [node(f"leaf{i}", "leaf", {}) for i in range(5)]
        root = node("root", "root", {"children": leaves})
        graph, _ = generic_ast_graphs.ast_to_graph(root, spec)

        # (limit passed to compute_nth_child_edges, number of edges expected)
        cases = ((10, 5), (2, 2))
        for limit, edge_count in cases:
            edges = graph_edge_util.compute_nth_child_edges(graph, limit)
            wanted = [
                ("root__root",
                 f"root_children_{i}__child-seq-helper",
                 f"CHILD_INDEX_{i}")
                for i in range(edge_count)
            ]
            self.assertEqual(edges, wanted)