def build_example(size):
    """Build an encoded example from a synthetic module of `size` constants.

    The module body is `Constant(0) ... Constant(size - 1)`. For every odd
    index `i`, one edge triple is added whose first entry is the node id of
    `body[i]`, whose second entry is the node id of `body[i - 1]`, and whose
    third entry is the edge type 1.

    Args:
        size: Number of constant nodes placed in the module body.

    Returns:
        The result of `graph_bundle.convert_graph_with_edges` for the graph
        and the edges described above, using `py_ast_graphs.BUILDER`.
    """
    constants = [gast.Constant(value=i, kind=None) for i in range(size)]
    tree = gast.Module(body=constants, type_ignores=[])
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    def node_id_at(index):
        # Node ids are looked up by the identity of the AST node object.
        return ast_to_node_id[id(tree.body[index])]

    edges = [
        (node_id_at(odd), node_id_at(odd - 1), 1)
        for odd in range(1, size, 2)
    ]
    return graph_bundle.convert_graph_with_edges(
        py_graph, edges, py_ast_graphs.BUILDER)
def test_soft_maze_values(self):
    """Check soft values, Q-values, policy and gradients on a 3x5 maze."""
    layout = np.array([
        [1, 1, 1, 1, 1],
        [1, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ])
    maze = layout.astype(bool)

    # Turn the maze into an adjacency array: one [node, node] slice per
    # one-hot input (there are 4 of them, from the 4x4 identity).
    maze_graph, coords = maze_schema.encode_maze(maze)
    primitive_edges = maze_task.maze_primitive_edges(maze_graph)
    example = graph_bundle.convert_graph_with_edges(
        maze_graph, primitive_edges, maze_task.BUILDER)
    num_nodes = len(maze_graph)
    edge_primitives = example.edges.apply_add(
        in_array=jnp.eye(4),
        out_array=jnp.zeros([num_nodes, num_nodes, 4]),
        in_dims=(0,),
        out_dims=(0, 1))

    # The goal is the square at coordinates (0, 1).
    goal_index = coords.index((0, 1))

    # With a very low temperature the soft values approximate (negated)
    # shortest-path distances to the goal.
    values, q_vals, policy = train_maze_lib.soft_maze_values(
        edge_primitives, target_state_index=goal_index, temperature=1e-7)

    expected_values_at_coords = np.array([
        [-1, 0, -1, -2, -3],
        [-2, np.nan, -2, -3, -4],
        [-3, -4, -3, -4, -5],
    ])
    expected_values = [expected_values_at_coords[rc] for rc in coords]
    np.testing.assert_allclose(values, expected_values, atol=1e-6)

    # Q-values and policy for the first state (the top-left corner).
    np.testing.assert_allclose(q_vals[0], [-2, -1, -2, -3])
    np.testing.assert_allclose(policy[0], [0, 1, 0, 0])

    # Reverse-mode gradient check, using a higher temperature.
    warm_values_fn = functools.partial(
        train_maze_lib.soft_maze_values,
        target_state_index=goal_index,
        temperature=1)
    jax.test_util.check_grads(warm_values_fn, (edge_primitives,), 1, "rev")

    # Same gradient check with a leading batch axis of size one.
    batched_primitives = edge_primitives[None, ...]
    jax.test_util.check_grads(
        jax.vmap(warm_values_fn), (batched_primitives,), 1, "rev")
def test_convert_no_targets(self):
    """Converting a graph with an empty edge list yields empty edge arrays."""
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    py_graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))
    example = graph_bundle.convert_graph_with_edges(
        py_graph, [], builder=py_ast_graphs.BUILDER)

    # The edge operator must still be well-formed: zero entries, but with
    # the usual trailing dimensions on the index arrays.
    edge_operator = example.edges
    self.assertEqual(edge_operator.input_indices.shape, (0, 1))
    self.assertEqual(edge_operator.output_indices.shape, (0, 2))
    self.assertEqual(edge_operator.values.shape, (0,))
def test_loss_fn(self):
    """`loss_fn` gives a positive loss and scalar metrics for a uniform model."""
    layout = np.array([
        [1, 1, 1, 1, 1],
        [1, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ])
    maze = layout.astype(bool)

    # Encode the maze and attach its primitive edges.
    maze_graph, _ = maze_schema.encode_maze(maze)
    example = graph_bundle.convert_graph_with_edges(
        maze_graph,
        maze_task.maze_primitive_edges(maze_graph),
        maze_task.BUILDER)

    def mock_model(example, dynamic_metadata):
        # Ignores its inputs and returns a constant [3, 14, 14] array whose
        # entries are all 1/14.
        del example, dynamic_metadata
        return jnp.full([3, 14, 14], 1 / 14)

    loss, metrics = train_maze_lib.loss_fn(
        mock_model, (example, 1), num_goals=4)

    self.assertGreater(loss, 0)
    for metric_value in metrics.values():
        self.assertEqual(metric_value.shape, ())
def test_zeros_like_padded_example(self):
    """`zeros_like_padded_example` matches a padded example leaf-for-leaf.

    Pads a real example (built from the one-statement program `pass`) with a
    fixed padding config, generates an example from the same config alone,
    and checks that every pair of corresponding leaves has the same shape
    and dtype.
    """
    tree = gast.parse("pass")
    py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
    example = graph_bundle.convert_graph_with_edges(
        py_graph, [], builder=py_ast_graphs.BUILDER)

    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34),
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)

    padded_example = graph_bundle.pad_example(example, padding_config)
    generated = graph_bundle.zeros_like_padded_example(padding_config)

    def _check(x, y):
        # Compare array structure only; the values are expected to differ.
        x = np.asarray(x)
        y = np.asarray(y)
        self.assertEqual(x.shape, y.shape)
        self.assertEqual(x.dtype, y.dtype)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies `_check` to corresponding leaves.
    jax.tree_util.tree_map(_check, generated, padded_example)
def test_convert_example(self):
    """Convert a small function with two extra edges and check the encoding.

    Verifies the node ordering of the graph, the example metadata, the edge
    index arrays, the sizes of the automaton graph components, and the
    shapes of the transition matrix built from the example.
    """
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    tree = gast.parse(source)
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    # Name the AST nodes that take part in the hand-written edges.
    func_def = tree.body[0]
    assign_stmt, return_stmt = func_def.body
    ast_edges = [
        (return_stmt, func_def, 1),
        (return_stmt.value, assign_stmt.targets[0], 2),
    ]

    def to_node_id(ast_node):
        # Node ids are keyed by the identity of the AST node object.
        return ast_to_node_id[id(ast_node)]

    converted_edges = [
        (to_node_id(source_node), to_node_id(dest_node), edge_type)
        for source_node, dest_node, edge_type in ast_edges
    ]
    example = graph_bundle.convert_graph_with_edges(
        py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

    expected_node_ids = [
        "root__Module",
        "root_body_0__Module_body-seq-helper",
        "root_body_0_item__FunctionDef",
        "root_body_0_item_args__arguments",
        "root_body_0_item_body_0__FunctionDef_body-seq-helper",
        "root_body_0_item_body_0_item__Assign",
        "root_body_0_item_body_0_item_targets__Name",
        "root_body_0_item_body_0_item_value__Constant",
        "root_body_0_item_body_1__FunctionDef_body-seq-helper",
        "root_body_0_item_body_1_item__Return",
        "root_body_0_item_body_1_item_value__Name",
    ]
    self.assertEqual(list(py_graph), expected_node_ids)

    expected_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=11, num_input_tagged_nodes=27)
    self.assertEqual(example.graph_metadata, expected_metadata)
    self.assertEqual(example.node_types.shape, (11,))

    edge_operator = example.edges
    np.testing.assert_array_equal(edge_operator.input_indices, [[1], [2]])
    np.testing.assert_array_equal(edge_operator.output_indices,
                                  [[9, 2], [10, 6]])
    np.testing.assert_array_equal(edge_operator.values, [1, 1])

    automaton_graph = example.automaton_graph
    self.assertEqual(automaton_graph.initial_to_in_tagged.values.shape, (34,))
    self.assertEqual(automaton_graph.initial_to_special.shape, (11,))
    self.assertEqual(automaton_graph.in_tagged_to_in_tagged.values.shape,
                     (103,))
    self.assertEqual(automaton_graph.in_tagged_to_special.shape, (27,))

    # The transition matrix must be buildable from the example, and its
    # shapes must follow the (unpadded) metadata: 11 nodes, 27 tagged nodes.
    builder = py_ast_graphs.BUILDER
    routing_params = builder.initialize_routing_params(
        None, 1, 1, noise_factor=0)
    transition_matrix = builder.build_transition_matrix(
        routing_params, automaton_graph, example.graph_metadata)

    self.assertEqual(transition_matrix.initial_to_in_tagged.shape,
                     (1, 11, 1, 27, 1))
    self.assertEqual(transition_matrix.initial_to_special.shape,
                     (1, 11, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_to_in_tagged.shape,
                     (1, 27, 1, 27, 1))
    self.assertEqual(transition_matrix.in_tagged_to_special.shape,
                     (1, 27, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_node_indices.shape, (27,))
def test_pad_example(self):
    """Pad a converted example and check every padded size.

    The graph metadata must be left untouched, while node types, edges, the
    automaton graph and the transition matrix built from it take the sizes
    given by the padding config.
    """
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    tree = gast.parse(source)
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    # Name the AST nodes that take part in the hand-written edges.
    func_def = tree.body[0]
    assign_stmt, return_stmt = func_def.body
    ast_edges = [
        (return_stmt, func_def, 1),
        (return_stmt.value, assign_stmt.targets[0], 2),
    ]

    def to_node_id(ast_node):
        # Node ids are keyed by the identity of the AST node object.
        return ast_to_node_id[id(ast_node)]

    converted_edges = [
        (to_node_id(source_node), to_node_id(dest_node), edge_type)
        for source_node, dest_node, edge_type in ast_edges
    ]
    example = graph_bundle.convert_graph_with_edges(
        py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34),
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)
    padded_example = graph_bundle.pad_example(example, padding_config)

    # Padding leaves the metadata describing the real graph unchanged.
    unpadded_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=11, num_input_tagged_nodes=27)
    self.assertEqual(padded_example.graph_metadata, unpadded_metadata)

    # All array-valued fields grow to the configured maximum sizes.
    self.assertEqual(padded_example.node_types.shape, (16,))

    padded_edges = padded_example.edges
    np.testing.assert_array_equal(padded_edges.input_indices,
                                  [[1], [2], [0], [0]])
    np.testing.assert_array_equal(padded_edges.output_indices,
                                  [[9, 2], [10, 6], [0, 0], [0, 0]])
    np.testing.assert_array_equal(padded_edges.values, [1, 1, 0, 0])

    padded_graph = padded_example.automaton_graph
    self.assertEqual(padded_graph.initial_to_in_tagged.values.shape, (64,))
    self.assertEqual(padded_graph.initial_to_special.shape, (16,))
    self.assertEqual(padded_graph.in_tagged_to_in_tagged.values.shape,
                     (128,))
    self.assertEqual(padded_graph.in_tagged_to_special.shape, (34,))

    # The transition matrix is padded as well. The builder is given the
    # padded static metadata because the encoded graph itself was padded.
    builder = py_ast_graphs.BUILDER
    routing_params = builder.initialize_routing_params(
        None, 1, 1, noise_factor=0)
    transition_matrix = builder.build_transition_matrix(
        routing_params, padded_graph, padding_config.static_max_metadata)

    self.assertEqual(transition_matrix.initial_to_in_tagged.shape,
                     (1, 16, 1, 34, 1))
    self.assertEqual(transition_matrix.initial_to_special.shape,
                     (1, 16, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_to_in_tagged.shape,
                     (1, 34, 1, 34, 1))
    self.assertEqual(transition_matrix.in_tagged_to_special.shape,
                     (1, 34, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_node_indices.shape, (34,))