Exemplo n.º 1
0
    def test_soft_maze_values(self):
        """Checks soft maze values, Q values, policy and gradients on a 3x5 maze."""
        grid = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ])
        maze = grid.astype(bool)
        goal = (0, 1)

        # Encode the maze as a graph, then turn its primitive edges into an
        # adjacency-style array with one slice per primitive (4 in total).
        maze_graph, coords = maze_schema.encode_maze(maze)
        num_nodes = len(maze_graph)
        primitive_edges = maze_task.maze_primitive_edges(maze_graph)
        example = graph_bundle.convert_graph_with_edges(
            maze_graph, primitive_edges, maze_task.BUILDER)
        edge_primitives = example.edges.apply_add(
            in_array=jnp.eye(4),
            out_array=jnp.zeros([num_nodes, num_nodes, 4]),
            in_dims=(0, ),
            out_dims=(0, 1))
        goal_index = coords.index(goal)

        # A very low temperature makes the soft values approximate (negated)
        # shortest-path lengths to the goal square.
        values, q_vals, policy = train_maze_lib.soft_maze_values(
            edge_primitives, target_state_index=goal_index, temperature=1e-7)

        # Laid out like the maze itself; the NaN entry sits at the (1, 1)
        # position that is 0 in the grid and is only picked up if `coords`
        # contains it.
        value_grid = np.array([
            [-1, 0, -1, -2, -3],
            [-2, np.nan, -2, -3, -4],
            [-3, -4, -3, -4, -5],
        ])
        expected_values = [value_grid[coord] for coord in coords]
        np.testing.assert_allclose(values, expected_values, atol=1e-6)

        # Q values and policy for the first node (the top left corner).
        np.testing.assert_allclose(q_vals[0], [-2, -1, -2, -3])
        np.testing.assert_allclose(policy[0], [0, 1, 0, 0])

        def soft_values_at_unit_temperature(primitives):
            # Same computation at a higher temperature, used for the
            # gradient checks below.
            return train_maze_lib.soft_maze_values(
                primitives, target_state_index=goal_index, temperature=1)

        # Reverse-mode gradient check on a single example.
        jax.test_util.check_grads(soft_values_at_unit_temperature,
                                  (edge_primitives, ), 1, "rev")

        # Same check through vmap, with a leading batch axis of size one.
        batched_primitives = edge_primitives[None, ...]
        jax.test_util.check_grads(jax.vmap(soft_values_at_unit_temperature),
                                  (batched_primitives, ), 1, "rev")
Exemplo n.º 2
0
    def test_primitive_edges(self):
        """Checks that the primitive edges leaving cell_0_0 are all present."""
        grid = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ])
        # The grid is transposed before encoding, so the maze is 5x3 here.
        maze = grid.T.astype(bool)

        maze_graph, _ = maze_schema.encode_maze(maze)
        actual_edges = maze_task.maze_primitive_edges(maze_graph)

        # Expected destination from cell_0_0 for each primitive index 0..3
        # (presumably one index per movement direction -- confirm against
        # maze_task). Only a subset of all edges is checked.
        source = "cell_0_0"
        destinations = ("cell_0_0", "cell_0_1", "cell_0_0", "cell_1_0")
        for primitive_index, destination in enumerate(destinations):
            self.assertIn((source, destination, primitive_index),
                          actual_edges)
Exemplo n.º 3
0
    def test_loss_fn(self):
        """Runs loss_fn with a constant mock model and checks its outputs."""
        grid = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ])
        maze = grid.astype(bool)

        # Encode the maze as a graph and bundle it with its primitive edges.
        maze_graph, _ = maze_schema.encode_maze(maze)
        primitive_edges = maze_task.maze_primitive_edges(maze_graph)
        example = graph_bundle.convert_graph_with_edges(
            maze_graph, primitive_edges, maze_task.BUILDER)

        def mock_model(example, dynamic_metadata):
            # Inputs are deliberately ignored: the mock always returns the
            # same constant-filled (3, 14, 14) array.
            del example, dynamic_metadata
            fill_value = 1 / 14
            return jnp.full((3, 14, 14), fill_value)

        model_inputs = (example, 1)
        loss, metrics = train_maze_lib.loss_fn(mock_model,
                                               model_inputs,
                                               num_goals=4)

        self.assertGreater(loss, 0)
        # Every reported metric must be a scalar.
        for metric_value in metrics.values():
            self.assertEqual(metric_value.shape, ())
Exemplo n.º 4
0
    def test_encode_maze(self):
        """Tests that an encoded maze is correct and matches the schema."""

        def make_node(node_type, links):
            """Builds a GraphNode from (out_edge, target_id, in_edge) triples."""
            out_edges = {}
            for out_edge, target_id, in_edge in links:
                out_edges[graph_types.OutEdgeType(out_edge)] = [
                    graph_types.InputTaggedNode(
                        graph_types.NodeId(target_id),
                        graph_types.InEdgeType(in_edge))
                ]
            return graph_types.GraphNode(graph_types.NodeType(node_type),
                                         out_edges)

        grid = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ])
        maze = grid.astype(bool)

        encoded_graph, coordinates = maze_schema.encode_maze(maze)

        # Coordinates: every (row, col) in row-major order except the
        # blocked (1, 1) entry.
        expected_coords = [(row, col)
                           for row in range(3)
                           for col in range(5)
                           if (row, col) != (1, 1)]
        self.assertEqual(coordinates, expected_coords)

        # Spot-check two nodes: a corner cell and a cell on the right edge.
        expected_nodes = (
            ("cell_0_0", "cell_xRxD", [
                ("R_out", "cell_0_1", "L_in"),
                ("D_out", "cell_1_0", "U_in"),
            ]),
            ("cell_1_4", "cell_LxUD", [
                ("L_out", "cell_1_3", "R_in"),
                ("U_out", "cell_0_4", "D_in"),
                ("D_out", "cell_2_4", "U_in"),
            ]),
        )
        for node_id, node_type, links in expected_nodes:
            self.assertEqual(encoded_graph[graph_types.NodeId(node_id)],
                             make_node(node_type, links))

        # The encoded graph must also conform to the maze schema.
        schema = maze_schema.build_maze_schema(2)
        schema_util.assert_conforms_to_schema(encoded_graph, schema)