Example #1
    def test_soft_maze_values(self):
        maze = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]).astype(bool)

        # Convert the maze into an adjacency matrix.
        maze_graph, coords = maze_schema.encode_maze(maze)
        primitive_edges = maze_task.maze_primitive_edges(maze_graph)
        example = graph_bundle.convert_graph_with_edges(
            maze_graph, primitive_edges, maze_task.BUILDER)
        edge_primitives = example.edges.apply_add(
            in_array=jnp.eye(4),
            out_array=jnp.zeros([len(maze_graph),
                                 len(maze_graph), 4]),
            in_dims=(0, ),
            out_dims=(0, 1))

        # Compute values for getting to a particular square, under a low temperature
        # (so that it's approximately shortest paths)
        values, q_vals, policy = train_maze_lib.soft_maze_values(
            edge_primitives,
            target_state_index=coords.index((0, 1)),
            temperature=1e-7)

        expected_values_at_coords = np.array([
            [-1, 0, -1, -2, -3],
            [-2, np.nan, -2, -3, -4],
            [-3, -4, -3, -4, -5],
        ])
        expected_values = [expected_values_at_coords[c] for c in coords]
        np.testing.assert_allclose(values, expected_values, atol=1e-6)

        # Check Q vals and policy at top left corner.
        np.testing.assert_allclose(q_vals[0], [-2, -1, -2, -3])
        np.testing.assert_allclose(policy[0], [0, 1, 0, 0])

        # Check reverse-mode gradient under a higher temperature.
        fun = functools.partial(train_maze_lib.soft_maze_values,
                                target_state_index=coords.index((0, 1)),
                                temperature=1)
        jax.test_util.check_grads(fun, (edge_primitives, ), 1, "rev")

        # Check gradient under batching.
        jax.test_util.check_grads(jax.vmap(fun),
                                  (edge_primitives[None, Ellipsis], ), 1,
                                  "rev")
Example #2
    def test_primitive_edges(self):
        maze = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]).T.astype(bool)

        maze_graph, _ = maze_schema.encode_maze(maze)
        primitive_edges = maze_task.maze_primitive_edges(maze_graph)

        subset_of_expected_edges = [
            ("cell_0_0", "cell_0_0", 0),
            ("cell_0_0", "cell_0_1", 1),
            ("cell_0_0", "cell_0_0", 2),
            ("cell_0_0", "cell_1_0", 3),
        ]
        for expected in subset_of_expected_edges:
            self.assertIn(expected, primitive_edges)
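
The expected tuples above suggest that maze_primitive_edges emits one edge per (cell, action) pair, with moves that are blocked or that leave the grid folded back into self-loops. Below is a minimal, self-contained reconstruction consistent with those assertions; the up/right/left/down action ordering and the cell_{row}_{col} naming are inferred from the expected tuples, not taken from the library.

MOVES = ((-1, 0), (0, 1), (0, -1), (1, 0))  # assumed order: up, right, left, down

def primitive_edges_sketch(maze):
    """maze: boolean array, True = passable. Returns (src, dst, action) tuples."""
    edges = []
    rows, cols = maze.shape
    for r in range(rows):
        for c in range(cols):
            if not maze[r, c]:
                continue
            for action, (dr, dc) in enumerate(MOVES):
                nr, nc = r + dr, c + dc
                blocked = not (0 <= nr < rows and 0 <= nc < cols and maze[nr, nc])
                dst = (r, c) if blocked else (nr, nc)  # blocked moves become self-loops
                edges.append((f"cell_{r}_{c}", f"cell_{dst[0]}_{dst[1]}", action))
    return edges

In the transposed maze above, cell_0_0 sits in a corner, so actions 0 and 2 fall off the grid and come back as self-loops, while actions 1 and 3 reach cell_0_1 and cell_1_0, matching the four tuples the test checks for.
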
Example #3
    def test_loss_fn(self):
        maze = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]).astype(bool)

        # Convert the maze into an adjacency matrix.
        maze_graph, _ = maze_schema.encode_maze(maze)
        primitive_edges = maze_task.maze_primitive_edges(maze_graph)
        example = graph_bundle.convert_graph_with_edges(
            maze_graph, primitive_edges, maze_task.BUILDER)

        def mock_model(example, dynamic_metadata):
            del example, dynamic_metadata
            return jnp.full([3, 14, 14], 1 / 14)

        loss, metrics = train_maze_lib.loss_fn(mock_model, (example, 1),
                                               num_goals=4)
        self.assertGreater(loss, 0)
        for metric in metrics.values():
            self.assertEqual(metric.shape, ())
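
The mock model returns a constant [3, 14, 14] array of 1/14, i.e. rows that form uniform distributions over 14 entries (the maze here has 14 open cells), so whatever the true targets are, the loss cannot be zero. A minimal illustration of that point (not gfsa's loss_fn; the targets and the negative-log-likelihood form are hypothetical) is that the NLL of any target node under a uniform distribution over 14 nodes is log(14).

import jax.numpy as jnp

predictions = jnp.full([3, 14, 14], 1 / 14)    # each row sums to 1
targets = jnp.zeros([3, 14], dtype=jnp.int32)  # hypothetical correct node per row
log_likelihood = jnp.log(
    jnp.take_along_axis(predictions, targets[..., None], axis=-1))
loss = -jnp.mean(log_likelihood)               # == log(14), strictly positive
metrics = {"loss": loss}                       # a scalar, shape ()

Since the test only asserts loss > 0 and that every metric is a scalar, a constant model like this is enough to exercise loss_fn's shape handling without training a real network.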