def test_soft_maze_values(self):
  """Checks soft value iteration on a small maze with one wall cell.

  Under a near-zero temperature the soft values should approximate negated
  shortest-path lengths to the target cell. Under temperature 1 the function
  is checked for correct reverse-mode gradients, both unbatched and under
  `jax.vmap`.
  """
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)
  # Tolerance for comparisons against the low-temperature approximation; the
  # values, Q-values and policy all come from the same soft computation, so
  # they share one tolerance.
  tolerance = 1e-6

  # Convert the maze into an adjacency matrix.
  maze_graph, coords = maze_schema.encode_maze(maze)
  primitive_edges = maze_task.maze_primitive_edges(maze_graph)
  example = graph_bundle.convert_graph_with_edges(
      maze_graph, primitive_edges, maze_task.BUILDER)
  num_nodes = len(maze_graph)
  # Scatter an identity matrix over 4 primitives into a
  # [num_nodes, num_nodes, 4] array (presumably a one-hot primitive per
  # edge -- TODO confirm against graph_bundle.apply_add).
  edge_primitives = example.edges.apply_add(
      in_array=jnp.eye(4),
      out_array=jnp.zeros([num_nodes, num_nodes, 4]),
      in_dims=(0,),
      out_dims=(0, 1))
  target_state_index = coords.index((0, 1))

  # Compute values for getting to a particular square, under a low temperature
  # (so that it's approximately shortest paths).
  values, q_vals, policy = train_maze_lib.soft_maze_values(
      edge_primitives,
      target_state_index=target_state_index,
      temperature=1e-7)

  # Expected value per grid position; the NaN marks the wall at (1, 1), which
  # has no entry in `coords` and is therefore never looked up.
  expected_values_at_coords = np.array([
      [-1, 0, -1, -2, -3],
      [-2, np.nan, -2, -3, -4],
      [-3, -4, -3, -4, -5],
  ])
  expected_values = [expected_values_at_coords[c] for c in coords]
  np.testing.assert_allclose(values, expected_values, atol=tolerance)

  # Check Q vals and policy at top left corner (coords[0]). Action 1 has the
  # best Q-value, so the low-temperature policy puts all its mass there.
  np.testing.assert_allclose(q_vals[0], [-2, -1, -2, -3], atol=tolerance)
  np.testing.assert_allclose(policy[0], [0, 1, 0, 0], atol=tolerance)

  # Check reverse-mode gradient under a higher temperature.
  fun = functools.partial(
      train_maze_lib.soft_maze_values,
      target_state_index=target_state_index,
      temperature=1)
  jax.test_util.check_grads(fun, (edge_primitives,), 1, modes=("rev",))

  # Check gradient under batching (leading batch axis of size 1).
  jax.test_util.check_grads(
      jax.vmap(fun), (edge_primitives[None, ...],), 1, modes=("rev",))
def test_primitive_edges(self):
  """Checks some of the primitive edges leaving the top-left corner cell.

  The maze literal is transposed, so the encoded grid has 5 rows and 3
  columns, a different shape from the other tests in this file.
  """
  grid = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).T.astype(bool)
  graph, _ = maze_schema.encode_maze(grid)
  edges = maze_task.maze_primitive_edges(graph)

  # Each expected edge is (source, destination, primitive index). Primitives
  # 0 and 2 from the corner map back onto the corner itself (presumably
  # because they would leave the grid -- not verified here).
  corner = "cell_0_0"
  destinations_by_primitive = [
      (corner, 0),
      ("cell_0_1", 1),
      (corner, 2),
      ("cell_1_0", 3),
  ]
  for destination, primitive in destinations_by_primitive:
    self.assertIn((corner, destination, primitive), edges)
def test_loss_fn(self):
  """Checks that loss_fn gives a positive loss and scalar-shaped metrics."""
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)

  # Convert the maze into an adjacency matrix.
  maze_graph, _ = maze_schema.encode_maze(maze)
  example = graph_bundle.convert_graph_with_edges(
      maze_graph,
      maze_task.maze_primitive_edges(maze_graph),
      maze_task.BUILDER)

  def constant_model(example, dynamic_metadata):
    # Ignores its inputs. Returns a [3, 14, 14] array filled with 1/14. The
    # maze has 14 open cells (15 minus the wall); the meaning of the leading
    # axis of size 3 is not visible here -- TODO confirm against loss_fn.
    del example, dynamic_metadata
    return jnp.full([3, 14, 14], 1 / 14)

  loss, metrics = train_maze_lib.loss_fn(
      constant_model, (example, 1), num_goals=4)

  self.assertGreater(loss, 0)
  for metric_value in metrics.values():
    self.assertEqual(metric_value.shape, ())
def test_encode_maze(self):
  """Tests that an encoded maze is correct and matches the schema."""
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)
  encoded_graph, coordinates = maze_schema.encode_maze(maze)

  # Check coordinates: every cell except the wall at (1, 1), row-major.
  expected_coords = [
      (row, col)
      for row in range(3)
      for col in range(5)
      if (row, col) != (1, 1)
  ]
  self.assertEqual(coordinates, expected_coords)

  gt = graph_types

  def edge_to(node_id, in_edge_type):
    # A single-element target list pointing at `node_id` via `in_edge_type`.
    return [
        gt.InputTaggedNode(gt.NodeId(node_id), gt.InEdgeType(in_edge_type))
    ]

  # Check a few nodes. Top-left corner: right and down neighbours only.
  self.assertEqual(
      encoded_graph[gt.NodeId("cell_0_0")],
      gt.GraphNode(
          gt.NodeType("cell_xRxD"),
          {
              gt.OutEdgeType("R_out"): edge_to("cell_0_1", "L_in"),
              gt.OutEdgeType("D_out"): edge_to("cell_1_0", "U_in"),
          }))
  # Right-edge cell in the middle row: left, up and down neighbours.
  self.assertEqual(
      encoded_graph[gt.NodeId("cell_1_4")],
      gt.GraphNode(
          gt.NodeType("cell_LxUD"),
          {
              gt.OutEdgeType("L_out"): edge_to("cell_1_3", "R_in"),
              gt.OutEdgeType("U_out"): edge_to("cell_0_4", "D_in"),
              gt.OutEdgeType("D_out"): edge_to("cell_2_4", "U_in"),
          }))

  # Check schema validity (argument 2 to build_maze_schema is passed through
  # unchanged from the original test; its meaning is not visible here).
  schema_util.assert_conforms_to_schema(
      encoded_graph, maze_schema.build_maze_schema(2))