Example #1
    def test_catch(self, policy_type: Text):
        # A small Catch instance: the paddle must end up under the falling ball.
        env = catch.Catch(rows=2, seed=1)
        num_actions = env.action_spec().num_values
        # Plan against a simulator model of the environment.
        model = simulator.Simulator(env)
        # Evaluation function: uniform prior over actions and a zero value estimate.
        eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.)

        timestep = env.reset()
        model.reset()

        # Select the tree policy under test: breadth-first search or PUCT.
        search_policy = search.bfs if policy_type == 'bfs' else search.puct

        # Run MCTS from the initial observation, planning with the simulator model.
        root = search.mcts(observation=timestep.observation,
                           model=model,
                           search_policy=search_policy,
                           evaluation=eval_fn,
                           num_simulations=100,
                           num_actions=num_actions)

        # Greedily pick the root child with the highest estimated value.
        values = np.array([c.value for c in root.children.values()])
        best_action = search.argmax(values)

        # The search should steer the paddle toward the ball:
        # move left (0), stay (1), or move right (2).
        if env._paddle_x > env._ball_x:
            self.assertEqual(best_action, 0)
        if env._paddle_x == env._ball_x:
            self.assertEqual(best_action, 1)
        if env._paddle_x < env._ball_x:
            self.assertEqual(best_action, 2)
Example #2
    def select_action(self, observation: types.Observation) -> types.Action:
        """Computes the agent's policy via MCTS."""
        if self._model.needs_reset:
            self._model.reset(observation)

        # Compute a fresh MCTS plan.
        root = search.mcts(
            observation,
            model=self._model,
            search_policy=search.puct,
            evaluation=self._forward,
            num_simulations=self._num_simulations,
            num_actions=self._num_actions,
            discount=self._discount,
        )

        # The agent's policy is softmax w.r.t. the *visit counts* as in AlphaZero.
        probs = search.visit_count_policy(root)
        action = np.int32(np.random.choice(self._actions, p=probs))

        # Save the policy probs so that we can add them to replay in `observe()`.
        self._probs = probs.astype(np.float32)

        return action
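
The probabilities returned by search.visit_count_policy come from the root children's visit counts, in the spirit of AlphaZero's visit-count policy. Below is a minimal sketch of one common choice, normalizing visit counts raised to 1 / temperature; the function name visit_count_probs and the temperature parameter are illustrative assumptions, not necessarily the library's exact implementation.

import numpy as np

def visit_count_probs(visit_counts, temperature=1.0):
    """Illustrative policy: probabilities proportional to N(s, a) ** (1 / T)."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    scaled = counts ** (1.0 / temperature)
    return scaled / scaled.sum()

# Example: an action visited 60 of 100 simulations dominates the policy;
# a lower temperature makes the distribution even more peaked.
probs = visit_count_probs([60, 30, 10], temperature=1.0)
action = int(np.random.choice(len(probs), p=probs))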