def test_ast_graph_optional_field_edges(self):
    """Check the edges generated for an optional AST field.

    A `return` with a value must be linked to its child in both directions,
    while a bare `return` must get a `value_out` edge that points back at the
    Return node itself, tagged `value_missing`.
    """
    module = gast.parse("return 1\nreturn")
    ast_graph, _ = py_ast_graphs.py_ast_to_graph(module)

    def tagged(target, in_edge):
        # Expected destination of a single outgoing edge.
        return graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(target),
            in_edge=graph_types.InEdgeType(in_edge))

    filled_return = "root_body_0_item__Return"
    filled_value = "root_body_0_item_value__Constant"
    bare_return = "root_body_1_item__Return"

    # Field present: parent -> child, then child -> parent.
    self.assertEqual(
        ast_graph[filled_return].out_edges["value_out"],
        [tagged(filled_value, "parent_in")])
    self.assertEqual(
        ast_graph[filled_value].out_edges["parent_out"],
        [tagged(filled_return, "value_in")])

    # Field absent: the node targets itself with the "missing" tag.
    self.assertEqual(
        ast_graph[bare_return].out_edges["value_out"],
        [tagged(bare_return, "value_missing")])
def test_ast_graph_conforms_to_schema(self):
    """Check that a graph built from varied syntax conforms to the schema."""
    # Sample program mixing several syntactic constructs, so that a large set
    # of the node types in the schema gets exercised.
    source = textwrap.dedent("""\
        def foo(n):
          if n <= 1:
            return 1
          else:
            return foo(n-1) + foo(n-2)

        def bar(m, n) -> int:
          x = n
          for i in range(m):
            if False:
              continue
            x = x + i
          while True:
            break
          return x

        x0 = 1 + 2 - 3 * 4 / 5
        x1 = (1 == 2) and (3 < 4) and (5 > 6)
        x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
        x2 = bar(13, 14 + 15)
        """)
    ast_graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))

    # The resulting graph has to match the schema.
    schema_util.assert_conforms_to_schema(ast_graph, py_ast_graphs.SCHEMA)
def test_ast_graph_nodes(self):
    """Check node IDs, node types, and forward mapping."""
    source = textwrap.dedent("""\
        pass
        def foo(n):
          if n <= 1:
            return 1
        """)
    module = gast.parse(source)
    ast_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(module)

    # pytype: disable=attribute-error
    if_stmt = module.body[1].body[0]
    # Each row: expected node id, expected node type, and the AST node that
    # the forward mapping should send to that id (None when the forward
    # mapping is not checked for the node).
    expectations = [
        ("root__Module", "Module", module),
        ("root_body_1__Module_body-seq-helper",
         "Module_body-seq-helper", None),
        ("root_body_1_item_body_0_item__If", "If", if_stmt),
        ("root_body_1_item_body_0_item_test_left__Name",
         "Name", if_stmt.test.left),
    ]
    for node_id, node_type, ast_node in expectations:
        self.assertIn(node_id, ast_graph)
        self.assertEqual(ast_graph[node_id].node_type, node_type)
        if ast_node is not None:
            self.assertEqual(ast_to_node_id[id(ast_node)], node_id)
def build_example(size):
    """Build an encoded example from a module holding `size` constants.

    The module body is made of the constants 0..size-1. Every odd-indexed
    constant contributes one edge to the constant just before it; the third
    component of each edge tuple is 1.

    Args:
      size: Number of constant nodes placed in the module body.

    Returns:
      Whatever `graph_bundle.convert_graph_with_edges` produces for the graph
      and edge list, using `py_ast_graphs.BUILDER`.
    """
    constants = [gast.Constant(value=i, kind=None) for i in range(size)]
    tree = gast.Module(body=constants, type_ignores=[])
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    def node_id_of(ast_node):
        # The forward mapping is keyed by the identity of the AST node.
        return ast_to_node_id[id(ast_node)]

    edges = [
        (node_id_of(tree.body[i]), node_id_of(tree.body[i - 1]), 1)
        for i in range(1, size, 2)
    ]
    return graph_bundle.convert_graph_with_edges(
        py_graph, edges, py_ast_graphs.BUILDER)
def test_convert_no_targets(self):
    """Check conversion of a graph when the list of extra edges is empty."""
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    py_graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))
    example = graph_bundle.convert_graph_with_edges(
        py_graph, [], builder=py_ast_graphs.BUILDER)

    # Even without any edge, the target indices must still form a valid
    # operator; it simply has no nonzero entries.
    edge_operator = example.edges
    self.assertEqual(edge_operator.input_indices.shape, (0, 1))
    self.assertEqual(edge_operator.output_indices.shape, (0, 2))
    self.assertEqual(edge_operator.values.shape, (0,))
def test_ast_graph_unique_field_edges(self):
    """Check the edges generated for a single-valued (unique) AST field.

    For `print(1)`, the Expr statement and the Call stored in its `value`
    field must reference each other: `value_out`/`parent_in` going down and
    `parent_out`/`value_in` going back up.
    """
    module = gast.parse("print(1)")
    ast_graph, _ = py_ast_graphs.py_ast_to_graph(module)

    def tagged(target, in_edge):
        # Expected destination of a single outgoing edge.
        return graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(target),
            in_edge=graph_types.InEdgeType(in_edge))

    expr_id = "root_body_0_item__Expr"
    call_id = "root_body_0_item_value__Call"

    self.assertEqual(
        ast_graph[expr_id].out_edges["value_out"],
        [tagged(call_id, "parent_in")])
    self.assertEqual(
        ast_graph[call_id].out_edges["parent_out"],
        [tagged(expr_id, "value_in")])
def test_zeros_like_padded_example(self):
    """Check `zeros_like_padded_example` against an actually padded example.

    Every leaf generated by `graph_bundle.zeros_like_padded_example` must have
    the same shape and dtype as the matching leaf of a real example padded
    with the same `PaddingConfig`.
    """
    tree = gast.parse("pass")
    py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
    example = graph_bundle.convert_graph_with_edges(
        py_graph, [], builder=py_ast_graphs.BUILDER)

    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34),
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)

    padded_example = graph_bundle.pad_example(example, padding_config)
    generated = graph_bundle.zeros_like_padded_example(padding_config)

    def _check(x, y):
        # Compare one pair of corresponding leaves.
        x = np.asarray(x)
        y = np.asarray(y)
        self.assertEqual(x.shape, y.shape)
        self.assertEqual(x.dtype, y.dtype)

    # `jax.tree_multimap` was deprecated and then removed from JAX, so calling
    # it fails on current releases. `jax.tree_util.tree_map` accepts several
    # trees and walks them in lockstep, which is exactly what is needed here.
    jax.tree_util.tree_map(_check, generated, padded_example)
def test_invalid_graphs(self, ast, expected_error):
    """Check that converting `ast` raises a matching ValueError.

    Args:
      ast: Object handed to `py_ast_graphs.py_ast_to_graph`.
      expected_error: Regular expression the ValueError message must match.
    """
    self.assertRaisesRegex(
        ValueError, expected_error, py_ast_graphs.py_ast_to_graph, ast)
def test_ast_graph_sequence_field_edges(self):
    """Check the edges generated for a sequence-valued AST field.

    A sequence field wires three kinds of nodes together: the parent, one
    helper node per element, and the element (item) itself.
    """
    source = textwrap.dedent("""\
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        """)
    ast_graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))

    def tagged(target, in_edge):
        # Expected destination of a single outgoing edge.
        return graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(target),
            in_edge=graph_types.InEdgeType(in_edge))

    module_id = "root__Module"
    first_helper = "root_body_0__Module_body-seq-helper"
    second_helper = "root_body_1__Module_body-seq-helper"
    last_helper = "root_body_5__Module_body-seq-helper"
    first_item = "root_body_0_item__Expr"

    # Parent node: edges down to its children.
    module_edges = ast_graph[module_id].out_edges
    self.assertLen(module_edges["body_out_all"], 6)
    self.assertEqual(
        module_edges["body_out_first"], [tagged(first_helper, "parent_in")])
    self.assertEqual(
        module_edges["body_out_last"], [tagged(last_helper, "parent_in")])

    # First helper node: parent, item, and its neighbours in the sequence.
    # It has no predecessor, so `prev_out` targets the helper itself.
    helper_edges = ast_graph[first_helper].out_edges
    expected_helper_edges = [
        ("parent_out", tagged(module_id, "body_in")),
        ("item_out", tagged(first_item, "parent_in")),
        ("prev_out", tagged(first_helper, "prev_missing")),
        ("next_out", tagged(second_helper, "prev_in")),
    ]
    for out_edge, destination in expected_helper_edges:
        self.assertEqual(helper_edges[out_edge], [destination])

    # Item node: edge back up to its helper.
    self.assertEqual(
        ast_graph[first_item].out_edges["parent_out"],
        [tagged(first_helper, "item_in")])
def test_convert_example(self):
    """Check the encoded example built from a small function and two edges."""
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    tree = gast.parse(source)
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    func_def = tree.body[0]
    assign_stmt = func_def.body[0]
    return_stmt = func_def.body[1]
    # (source AST node, destination AST node, edge type)
    ast_edges = [
        (return_stmt, func_def, 1),
        (return_stmt.value, assign_stmt.targets[0], 2),
    ]
    converted_edges = [
        (ast_to_node_id[id(src)], ast_to_node_id[id(dst)], kind)
        for src, dst, kind in ast_edges
    ]
    example = graph_bundle.convert_graph_with_edges(
        py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

    expected_node_ids = [
        "root__Module",
        "root_body_0__Module_body-seq-helper",
        "root_body_0_item__FunctionDef",
        "root_body_0_item_args__arguments",
        "root_body_0_item_body_0__FunctionDef_body-seq-helper",
        "root_body_0_item_body_0_item__Assign",
        "root_body_0_item_body_0_item_targets__Name",
        "root_body_0_item_body_0_item_value__Constant",
        "root_body_0_item_body_1__FunctionDef_body-seq-helper",
        "root_body_0_item_body_1_item__Return",
        "root_body_0_item_body_1_item_value__Name",
    ]
    self.assertEqual(list(py_graph), expected_node_ids)

    self.assertEqual(
        example.graph_metadata,
        automaton_builder.EncodedGraphMetadata(
            num_nodes=11, num_input_tagged_nodes=27))
    self.assertEqual(example.node_types.shape, (11,))

    encoded_edges = example.edges
    np.testing.assert_array_equal(encoded_edges.input_indices, [[1], [2]])
    np.testing.assert_array_equal(
        encoded_edges.output_indices, [[9, 2], [10, 6]])
    np.testing.assert_array_equal(encoded_edges.values, [1, 1])

    automaton = example.automaton_graph
    self.assertEqual(automaton.initial_to_in_tagged.values.shape, (34,))
    self.assertEqual(automaton.initial_to_special.shape, (11,))
    self.assertEqual(automaton.in_tagged_to_in_tagged.values.shape, (103,))
    self.assertEqual(automaton.in_tagged_to_special.shape, (27,))

    # Make sure a transition matrix of the right size can be built from it.
    builder = py_ast_graphs.BUILDER
    routing_params = builder.initialize_routing_params(
        None, 1, 1, noise_factor=0)
    transition_matrix = builder.build_transition_matrix(
        routing_params, example.automaton_graph, example.graph_metadata)
    self.assertEqual(
        transition_matrix.initial_to_in_tagged.shape, (1, 11, 1, 27, 1))
    self.assertEqual(
        transition_matrix.initial_to_special.shape, (1, 11, 1, 3))
    self.assertEqual(
        transition_matrix.in_tagged_to_in_tagged.shape, (1, 27, 1, 27, 1))
    self.assertEqual(
        transition_matrix.in_tagged_to_special.shape, (1, 27, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_node_indices.shape, (27,))
def test_pad_example(self):
    """Check the shapes and contents of an example after padding."""
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    tree = gast.parse(source)
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    func_def = tree.body[0]
    assign_stmt = func_def.body[0]
    return_stmt = func_def.body[1]
    # (source AST node, destination AST node, edge type)
    ast_edges = [
        (return_stmt, func_def, 1),
        (return_stmt.value, assign_stmt.targets[0], 2),
    ]
    converted_edges = [
        (ast_to_node_id[id(src)], ast_to_node_id[id(dst)], kind)
        for src, dst, kind in ast_edges
    ]
    example = graph_bundle.convert_graph_with_edges(
        py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34),
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)
    padded_example = graph_bundle.pad_example(example, padding_config)

    # Padding leaves the metadata untouched.
    self.assertEqual(
        padded_example.graph_metadata,
        automaton_builder.EncodedGraphMetadata(
            num_nodes=11, num_input_tagged_nodes=27))

    # All the other fields grow to the padded sizes.
    self.assertEqual(padded_example.node_types.shape, (16,))

    padded_edges = padded_example.edges
    np.testing.assert_array_equal(
        padded_edges.input_indices, [[1], [2], [0], [0]])
    np.testing.assert_array_equal(
        padded_edges.output_indices, [[9, 2], [10, 6], [0, 0], [0, 0]])
    np.testing.assert_array_equal(padded_edges.values, [1, 1, 0, 0])

    automaton = padded_example.automaton_graph
    self.assertEqual(automaton.initial_to_in_tagged.values.shape, (64,))
    self.assertEqual(automaton.initial_to_special.shape, (16,))
    self.assertEqual(automaton.in_tagged_to_in_tagged.values.shape, (128,))
    self.assertEqual(automaton.in_tagged_to_special.shape, (34,))

    # The transition matrix is padded too once it gets built. The padded
    # static metadata is what goes to the transition matrix builder here,
    # because the encoded graph itself has been padded.
    builder = py_ast_graphs.BUILDER
    routing_params = builder.initialize_routing_params(
        None, 1, 1, noise_factor=0)
    transition_matrix = builder.build_transition_matrix(
        routing_params, padded_example.automaton_graph,
        padding_config.static_max_metadata)
    self.assertEqual(
        transition_matrix.initial_to_in_tagged.shape, (1, 16, 1, 34, 1))
    self.assertEqual(
        transition_matrix.initial_to_special.shape, (1, 16, 1, 3))
    self.assertEqual(
        transition_matrix.in_tagged_to_in_tagged.shape, (1, 34, 1, 34, 1))
    self.assertEqual(
        transition_matrix.in_tagged_to_special.shape, (1, 34, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_node_indices.shape, (34,))