def test_catch(self, policy_type: Text):
  # Build a tiny 2-row Catch environment and a simulator model of it.
  env = catch.Catch(rows=2, seed=1)
  num_actions = env.action_spec().num_values
  model = simulator.Simulator(env)
  # Uniform prior over actions and zero value estimate.
  eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.)

  timestep = env.reset()
  model.reset()

  search_policy = search.bfs if policy_type == 'bfs' else search.puct

  root = search.mcts(
      observation=timestep.observation,
      model=model,
      search_policy=search_policy,
      evaluation=eval_fn,
      num_simulations=100,
      num_actions=num_actions)

  values = np.array([c.value for c in root.children.values()])
  best_action = search.argmax(values)

  # The best action should move the paddle toward the ball.
  if env._paddle_x > env._ball_x:
    self.assertEqual(best_action, 0)
  if env._paddle_x == env._ball_x:
    self.assertEqual(best_action, 1)
  if env._paddle_x < env._ball_x:
    self.assertEqual(best_action, 2)
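# For context, a minimal harness to run the test above on its own: the imports it
# needs and a parameterized test class. This is a sketch rather than verbatim
# source; the module paths (acme.agents.tf.mcts, bsuite.environments.catch) and
# the class name SearchTest are assumptions and may need adjusting locally.
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from acme.agents.tf.mcts import search
from acme.agents.tf.mcts.models import simulator
from bsuite.environments import catch


class SearchTest(parameterized.TestCase):

  @parameterized.parameters('puct', 'bfs')
  def test_catch(self, policy_type):
    ...  # Body as above: run MCTS on a 2-row Catch and check the greedy action.


if __name__ == '__main__':
  absltest.main()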
def select_action(self, observation: types.Observation) -> types.Action:
  """Computes the agent's policy via MCTS."""
  if self._model.needs_reset:
    self._model.reset(observation)

  # Compute a fresh MCTS plan.
  root = search.mcts(
      observation,
      model=self._model,
      search_policy=search.puct,
      evaluation=self._forward,
      num_simulations=self._num_simulations,
      num_actions=self._num_actions,
      discount=self._discount,
  )

  # The agent's policy is softmax w.r.t. the *visit counts* as in AlphaZero.
  probs = search.visit_count_policy(root)
  action = np.int32(np.random.choice(self._actions, p=probs))

  # Save the policy probs so that we can add them to replay in `observe()`.
  self._probs = probs.astype(np.float32)

  return action
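# To make the visit-count step above concrete, here is a minimal sketch of how a
# visit-count policy can be derived from the root of the search tree. It is
# illustrative only: it assumes each child node exposes a `visit_count`
# attribute, and the `temperature` parameter is a hypothetical addition
# (temperature 1.0 reduces to plain normalization of the visit counts).
import numpy as np


def sketch_visit_count_policy(root, temperature: float = 1.0) -> np.ndarray:
  """Returns action probabilities proportional to visit_count ** (1 / temperature)."""
  visits = np.array(
      [child.visit_count for child in root.children.values()], dtype=np.float64)
  # Exponentiate by the inverse temperature, then normalize to a distribution.
  scaled = visits ** (1.0 / temperature)
  return scaled / scaled.sum()


# Example usage with the MCTS root computed in select_action:
#   probs = sketch_visit_count_policy(root)
#   action = np.random.choice(len(probs), p=probs)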