示例#1
0
    def test_ast_graph_optional_field_edges(self):
        """Check edges for an optional AST field, both present and absent.

        `return 1` has a `value` child, so the Return node and its Constant
        child must be linked in both directions. A bare `return` has no value,
        so its `value_out` edge must loop back to the Return node itself on
        the `value_missing` input.
        """
        module = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(module)

        def tagged(node_id, in_edge):
            # Expected single-entry out-edge list pointing at `node_id`.
            return [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(node_id),
                    in_edge=graph_types.InEdgeType(in_edge))
            ]

        first_return = "root_body_0_item__Return"
        first_value = "root_body_0_item_value__Constant"
        second_return = "root_body_1_item__Return"

        # Present optional field: parent points down to the child...
        self.assertEqual(graph[first_return].out_edges["value_out"],
                         tagged(first_value, "parent_in"))
        # ...and the child points back up to the parent.
        self.assertEqual(graph[first_value].out_edges["parent_out"],
                         tagged(first_return, "value_in"))
        # Absent optional field: the edge is a self-loop tagged as missing.
        self.assertEqual(graph[second_return].out_edges["value_out"],
                         tagged(second_return, "value_missing"))
示例#2
0
    def test_ast_graph_conforms_to_schema(self):
        """Check that the graph for a varied code sample satisfies the schema.

        The sample combines function definitions, if/else, for and while
        loops, continue/break, calls, a return annotation, and arithmetic,
        comparison and boolean operators. One conversion therefore exercises
        a large set of the node types in the schema.
        """
        sample = textwrap.dedent("""\
        def foo(n):
          if n <= 1:
            return 1
          else:
            return foo(n-1) + foo(n-2)

        def bar(m, n) -> int:
          x = n
          for i in range(m):
            if False:
              continue
            x = x + i
          while True:
            break
          return x

        x0 = 1 + 2 - 3 * 4 / 5
        x1 = (1 == 2) and (3 < 4) and (5 > 6)
        x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
        x2 = bar(13, 14 + 15)
        """)
        graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(sample))
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)
示例#3
0
    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and the AST-to-node forward mapping.

        Sequence-helper nodes have no AST counterpart, so only their presence
        and type are checked. Every other listed node is also looked up in
        `forward_map` by the `id()` of its AST node.
        """
        module = gast.parse(
            textwrap.dedent("""\
        pass
        def foo(n):
            if n <= 1:
              return 1
        """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(module)

        # pytype: disable=attribute-error
        if_stmt = module.body[1].body[0]
        expectations = [
            ("root__Module", "Module", module),
            ("root_body_1__Module_body-seq-helper", "Module_body-seq-helper",
             None),
            ("root_body_1_item_body_0_item__If", "If", if_stmt),
            ("root_body_1_item_body_0_item_test_left__Name", "Name",
             if_stmt.test.left),
        ]

        for node_id, node_type, ast_node in expectations:
            self.assertIn(node_id, graph)
            self.assertEqual(graph[node_id].node_type, node_type)
            if ast_node is not None:
                self.assertEqual(forward_map[id(ast_node)], node_id)
 def build_example(size):
     """Build an encoded example from a module of `size` bare constants.

     The module body holds `gast.Constant` nodes with values 0..size-1. They
     are placed directly in the body, not wrapped in `Expr` statements. Each
     odd-indexed constant gets one (source, dest, edge_type) edge of type 1
     pointing at the constant just before it. When `size` is odd, the last
     constant gets no edge.

     Args:
       size: Number of constants to place in the module body.

     Returns:
       The result of `graph_bundle.convert_graph_with_edges` for that graph
       and those edges.
     """
     module = gast.Module(
         body=[gast.Constant(value=value, kind=None) for value in range(size)],
         type_ignores=[])
     graph, node_id_of = py_ast_graphs.py_ast_to_graph(module)
     constants = module.body
     pairs = [
         (node_id_of[id(constants[odd])], node_id_of[id(constants[odd - 1])],
          1)
         for odd in range(1, size, 2)
     ]
     return graph_bundle.convert_graph_with_edges(
         graph, pairs, py_ast_graphs.BUILDER)
示例#5
0
    def test_convert_no_targets(self):
        """Converting with an empty edge list yields zero-length edge arrays.

        Even with no edges, the edge arrays keep their trailing dimensions
        (1 for input_indices, 2 for output_indices). The target operator is
        therefore still well-formed, just with no nonzero entries.
        """
        source = textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """)
        graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))
        edges = graph_bundle.convert_graph_with_edges(
            graph, [], builder=py_ast_graphs.BUILDER).edges

        self.assertEqual(edges.input_indices.shape, (0, 1))
        self.assertEqual(edges.output_indices.shape, (0, 2))
        self.assertEqual(edges.values.shape, (0, ))
示例#6
0
  def test_ast_graph_unique_field_edges(self):
    """Check the edge pair linking a node to the child in a unique field.

    In `print(1)` the Expr statement's `value` field holds the Call node. The
    Expr must point down to the Call, and the Call must point back up.
    """
    graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse("print(1)"))
    expr_id = "root_body_0_item__Expr"
    call_id = "root_body_0_item_value__Call"

    expected_down = [
        graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(call_id),
            in_edge=graph_types.InEdgeType("parent_in"))
    ]
    self.assertEqual(graph[expr_id].out_edges["value_out"], expected_down)

    expected_up = [
        graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(expr_id),
            in_edge=graph_types.InEdgeType("value_in"))
    ]
    self.assertEqual(graph[call_id].out_edges["parent_out"], expected_up)
示例#7
0
    def test_zeros_like_padded_example(self):
        """Check that zeros_like_padded_example mirrors a real padded example.

        Builds a padded example for `pass` under a fixed PaddingConfig. Every
        leaf returned by `graph_bundle.zeros_like_padded_example` for the same
        config must have the same shape and dtype as the matching leaf of that
        padded example.
        """
        tree = gast.parse("pass")
        py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
        example = graph_bundle.convert_graph_with_edges(
            py_graph, [], builder=py_ast_graphs.BUILDER)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)

        padded_example = graph_bundle.pad_example(example, padding_config)
        generated = graph_bundle.zeros_like_padded_example(padding_config)

        def _check(generated_leaf, padded_leaf):
            # Only shape and dtype are compared, not the values themselves.
            generated_leaf = np.asarray(generated_leaf)
            padded_leaf = np.asarray(padded_leaf)
            self.assertEqual(generated_leaf.shape, padded_leaf.shape)
            self.assertEqual(generated_leaf.dtype, padded_leaf.dtype)

        # `jax.tree_multimap` was deprecated and then removed from JAX
        # (0.3.16). `jax.tree_util.tree_map` accepts several trees of matching
        # structure and calls `_check` on each pair of corresponding leaves.
        jax.tree_util.tree_map(_check, generated, padded_example)
示例#8
0
 def test_invalid_graphs(self, ast, expected_error):
     """Converting an invalid AST must raise a ValueError matching a regex.

     Args:
       ast: AST to convert. Presumably supplied by a parameterized-test
         decorator outside this view. Note that it shadows the stdlib `ast`
         module name inside this method.
       expected_error: Regex that the ValueError message must match.
     """
     self.assertRaisesRegex(ValueError, expected_error,
                            py_ast_graphs.py_ast_to_graph, ast)
示例#9
0
    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        A sequence field connects three kinds of nodes: the parent, one helper
        node per element, and the element (item) itself. Helper nodes are also
        chained to their neighbors through prev/next edges.
        """
        root = gast.parse(
            textwrap.dedent("""\
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        def tagged(node_id, in_edge):
            # Expected single-entry out-edge list pointing at `node_id`.
            return [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(node_id),
                    in_edge=graph_types.InEdgeType(in_edge))
            ]

        module_id = "root__Module"
        first_helper = "root_body_0__Module_body-seq-helper"
        second_helper = "root_body_1__Module_body-seq-helper"
        last_helper = "root_body_5__Module_body-seq-helper"
        first_item = "root_body_0_item__Expr"

        # Parent edges: one per element, plus first/last shortcuts.
        parent_edges = graph[module_id].out_edges
        self.assertLen(parent_edges["body_out_all"], 6)
        self.assertEqual(parent_edges["body_out_first"],
                         tagged(first_helper, "parent_in"))
        self.assertEqual(parent_edges["body_out_last"],
                         tagged(last_helper, "parent_in"))

        # First helper's edges. It has no predecessor, so `prev_out` loops
        # back to itself on `prev_missing`.
        helper_edges = graph[first_helper].out_edges
        self.assertEqual(helper_edges["parent_out"],
                         tagged(module_id, "body_in"))
        self.assertEqual(helper_edges["item_out"],
                         tagged(first_item, "parent_in"))
        self.assertEqual(helper_edges["prev_out"],
                         tagged(first_helper, "prev_missing"))
        self.assertEqual(helper_edges["next_out"],
                         tagged(second_helper, "prev_in"))

        # The item points back up to its helper, not to the module.
        self.assertEqual(graph[first_item].out_edges["parent_out"],
                         tagged(first_helper, "item_in"))
示例#10
0
    def test_convert_example(self):
        """Check the full encoding of a small function with two extra edges.

        Covers:
          - the node ordering produced by py_ast_to_graph;
          - the encoded graph metadata;
          - the sparse edge arrays;
          - the sizes of the automaton graph arrays;
          - the shapes of a transition matrix built from the example.
        """
        tree = gast.parse(
            textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """))
        function_def = tree.body[0]
        assign_stmt = function_def.body[0]
        return_stmt = function_def.body[1]

        py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
        # (source, dest, edge_type) triples, translated from AST nodes to
        # graph node ids.
        converted_edges = [
            (ast_to_node_id[id(source)], ast_to_node_id[id(dest)], edge_type)
            for source, dest, edge_type in (
                (return_stmt, function_def, 1),
                (return_stmt.value, assign_stmt.targets[0], 2),
            )
        ]
        example = graph_bundle.convert_graph_with_edges(
            py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

        expected_node_ids = [
            "root__Module",
            "root_body_0__Module_body-seq-helper",
            "root_body_0_item__FunctionDef",
            "root_body_0_item_args__arguments",
            "root_body_0_item_body_0__FunctionDef_body-seq-helper",
            "root_body_0_item_body_0_item__Assign",
            "root_body_0_item_body_0_item_targets__Name",
            "root_body_0_item_body_0_item_value__Constant",
            "root_body_0_item_body_1__FunctionDef_body-seq-helper",
            "root_body_0_item_body_1_item__Return",
            "root_body_0_item_body_1_item_value__Name",
        ]
        self.assertEqual(list(py_graph), expected_node_ids)

        self.assertEqual(
            example.graph_metadata,
            automaton_builder.EncodedGraphMetadata(
                num_nodes=11, num_input_tagged_nodes=27))
        self.assertEqual(example.node_types.shape, (11, ))

        # Each edge's type lands in input_indices. Its (source, dest)
        # positions in the node order above land in output_indices, e.g. the
        # Return (index 9) -> FunctionDef (index 2) edge of type 1.
        edges = example.edges
        np.testing.assert_array_equal(edges.input_indices, [[1], [2]])
        np.testing.assert_array_equal(edges.output_indices,
                                      [[9, 2], [10, 6]])
        np.testing.assert_array_equal(edges.values, [1, 1])

        automaton_graph = example.automaton_graph
        for array, shape in [
            (automaton_graph.initial_to_in_tagged.values, (34, )),
            (automaton_graph.initial_to_special, (11, )),
            (automaton_graph.in_tagged_to_in_tagged.values, (103, )),
            (automaton_graph.in_tagged_to_special, (27, )),
        ]:
            self.assertEqual(array.shape, shape)

        # The transition matrix must be buildable at the unpadded size.
        routing_params = py_ast_graphs.BUILDER.initialize_routing_params(
            None, 1, 1, noise_factor=0)
        transition_matrix = py_ast_graphs.BUILDER.build_transition_matrix(
            routing_params, automaton_graph, example.graph_metadata)

        for array, shape in [
            (transition_matrix.initial_to_in_tagged, (1, 11, 1, 27, 1)),
            (transition_matrix.initial_to_special, (1, 11, 1, 3)),
            (transition_matrix.in_tagged_to_in_tagged, (1, 27, 1, 27, 1)),
            (transition_matrix.in_tagged_to_special, (1, 27, 1, 3)),
            (transition_matrix.in_tagged_node_indices, (27, )),
        ]:
            self.assertEqual(array.shape, shape)
示例#11
0
    def test_pad_example(self):
        """Check that pad_example pads every array up to the PaddingConfig.

        The unpadded example has 11 nodes, 27 input-tagged nodes and 2 edges.
        It is padded to 16 nodes, 34 input-tagged nodes, 64 initial
        transitions, 128 in-tagged transitions and 4 edges. The graph
        metadata keeps the original counts, and unused edge slots are filled
        with zeros.
        """
        tree = gast.parse(
            textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """))
        function_def = tree.body[0]
        assign_stmt = function_def.body[0]
        return_stmt = function_def.body[1]

        py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
        # (source, dest, edge_type) triples, translated from AST nodes to
        # graph node ids.
        converted_edges = [
            (ast_to_node_id[id(source)], ast_to_node_id[id(dest)], edge_type)
            for source, dest, edge_type in (
                (return_stmt, function_def, 1),
                (return_stmt.value, assign_stmt.targets[0], 2),
            )
        ]
        example = graph_bundle.convert_graph_with_edges(
            py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

        padded_metadata = automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34)
        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=padded_metadata,
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)
        padded_example = graph_bundle.pad_example(example, padding_config)

        # Metadata still reports the unpadded sizes.
        self.assertEqual(
            padded_example.graph_metadata,
            automaton_builder.EncodedGraphMetadata(
                num_nodes=11, num_input_tagged_nodes=27))

        # Every array is padded. Unused edge slots are zero-filled.
        self.assertEqual(padded_example.node_types.shape, (16, ))

        padded_edges = padded_example.edges
        np.testing.assert_array_equal(padded_edges.input_indices,
                                      [[1], [2], [0], [0]])
        np.testing.assert_array_equal(padded_edges.output_indices,
                                      [[9, 2], [10, 6], [0, 0], [0, 0]])
        np.testing.assert_array_equal(padded_edges.values, [1, 1, 0, 0])

        padded_automaton = padded_example.automaton_graph
        for array, shape in [
            (padded_automaton.initial_to_in_tagged.values, (64, )),
            (padded_automaton.initial_to_special, (16, )),
            (padded_automaton.in_tagged_to_in_tagged.values, (128, )),
            (padded_automaton.in_tagged_to_special, (34, )),
        ]:
            self.assertEqual(array.shape, shape)

        # The transition matrix also becomes padded once it is built. The
        # padded static metadata is passed to the builder because the encoded
        # graph itself has been padded.
        routing_params = py_ast_graphs.BUILDER.initialize_routing_params(
            None, 1, 1, noise_factor=0)
        transition_matrix = py_ast_graphs.BUILDER.build_transition_matrix(
            routing_params, padded_automaton,
            padding_config.static_max_metadata)

        for array, shape in [
            (transition_matrix.initial_to_in_tagged, (1, 16, 1, 34, 1)),
            (transition_matrix.initial_to_special, (1, 16, 1, 3)),
            (transition_matrix.in_tagged_to_in_tagged, (1, 34, 1, 34, 1)),
            (transition_matrix.in_tagged_to_special, (1, 34, 1, 3)),
            (transition_matrix.in_tagged_node_indices, (34, )),
        ]:
            self.assertEqual(array.shape, shape)