    def suggest_move(self):
        """Suggest a single action after performing a number of MCTS simulations.

        Args: None

        Returns:
          An action to be taken in the environment.
        """
        start = time.time()

        if self.timed_match:
            while time.time() - start < self.seconds_per_move:
                self.tree_search()
        else:
            current_readouts = self.root.n
            with performance.timer('Searched %d' % self.num_mcts_sim, True):
                # Perform MCTS searches in either of two scenarios:
                # (1) we've taken the designated number of moves per search
                #     since the last search, or
                # (2) the root does not have any children.
                if (self._num_searches % self._num_moves_per_search
                        == 0) or (not bool(self.root.children)):
                    while self.root.n < current_readouts + self.num_mcts_sim:
                        self.tree_search()
        # Increments the number of searches performed so far.
        self._num_searches += 1
        # Retrieve the suggested move based on the collected statistics in the tree.
        suggested_move = self.pick_move()
        return suggested_move
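
    # --- Illustrative sketch, not part of the original code -------------------
    # `pick_move` is not shown in this listing. In visit-count MCTS it commonly
    # returns either the most-visited root child or a sample drawn in proportion
    # to visit counts (temperature-based exploration). A minimal helper under
    # that assumption; `child_n` (per-child visit counts) and `temperature` are
    # hypothetical names.
    @staticmethod
    def _pick_move_from_visits_sketch(child_n, temperature=0.0):
        """Picks a root move index from per-child visit counts (illustrative)."""
        if temperature == 0.0:
            # Greedy: pick the most-visited child.
            return int(np.argmax(child_n))
        # Exploratory: sample in proportion to visit counts raised to 1/T.
        probs = np.power(np.asarray(child_n, dtype=np.float64), 1.0 / temperature)
        probs = probs / probs.sum()
        return int(np.random.choice(len(child_n), p=probs))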

    def select_leaf(self):
        """Selects a leaf node by descending the tree.

        Args: None

        Returns:
          A leaf node.
        """
        current = self
        while True:
            # if a node has never been evaluated (leaf node),
            # we have no basis to select a child.
            if (not current.is_expanded) or current.is_done():
                break

            with performance.timer('calculating action score', self._debug):
                action_scores = current.child_action_score
            with performance.timer('calculating argmax', self._debug):
                best_move = np.argmax(action_scores)
            with performance.timer('add child', self._debug):
                current = current.add_child_if_absent(best_move)

        return current
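
    # --- Illustrative sketch, not part of the original code -------------------
    # `child_action_score` used above is not shown in this listing. In
    # PUCT-style MCTS it is commonly the sum of an exploitation term Q and an
    # exploration bonus U driven by the child priors and visit counts. A
    # minimal helper under that assumption; the argument names and the default
    # `c_puct` are hypothetical.
    @staticmethod
    def _puct_child_action_score_sketch(child_q, child_n, child_prior, parent_n,
                                        c_puct=1.25):
        """Returns per-child scores Q + U used to pick the next move to descend."""
        exploration = c_puct * child_prior * np.sqrt(parent_n) / (1.0 + child_n)
        return child_q + exploration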

    def sample_actions(self, mcts_dist):
        """Sample a set of actions per node.

        The sampled actions become the children of the given node.

        Args:
          mcts_dist: This distribution is built using the RL policy (PPO) by
            passing the state of the current node to the policy network.

        Returns:
          A set of actions sampled from the policy distribution.
        """
        with performance.timer('sample multivariate normal distribution',
                               self._debug):
            sampled_actions = mcts_dist.rvs(size=self.max_num_actions)
        return sampled_actions
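
    # --- Illustrative sketch, not part of the original code -------------------
    # `mcts_dist` is expected to behave like a frozen
    # `scipy.stats.multivariate_normal`, built from the policy's mean and
    # logstd as done in `tree_search` below. A minimal builder under that
    # assumption (the method and argument names are hypothetical):
    @staticmethod
    def _build_action_distribution_sketch(mean, logstd):
        """Builds the diagonal Gaussian that `sample_actions` draws from."""
        return multivariate_normal(mean=mean,
                                   cov=np.diag(np.power(np.exp(logstd), 2)))
    # Example use:
    #   dist = _build_action_distribution_sketch(np.zeros(3), np.log(0.5) * np.ones(3))
    #   actions = dist.rvs(size=10)   # shape (10, 3)
    #   probs = dist.pdf(actions)     # densities later used as child priors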

    def tree_search(self, parallel_readouts=None):
        """Main tree search (MCTS simulations).

        Args:
          parallel_readouts: Number of parallel searches (threads) that are
            performed in the tree.

        Returns:
          The list of selected leaf nodes.
        """
        if parallel_readouts is None:
            parallel_readouts = self._parallel_readouts

        # The list that holds the selected leaf nodes for expand-and-evaluate.
        # If parallel_readouts is one, len(leaves) is at most one.
        leaves = []
        failsafe = 0
        while (len(leaves) < parallel_readouts
               and failsafe < parallel_readouts * 2):

            failsafe += 1

            with performance.timer('select a leaf', self._debug):
                # Select a leaf for expansion in the tree
                leaf = self.root.select_leaf()

            # If this is a terminal node, directly call backup_value.
            # We back up a value of zero from this node, i.e. the expected total
            # reward for an agent starting from this terminal node is zero. You
            # may need to change this based on the target environment.
            if leaf.is_done():
                # Everything should be a tensor, as we are working in batched mode.
                policy_out = self._call_policy(np.asarray([leaf.observ]),
                                               only_normalized=True)
                leaf.network_value = policy_out['value'][0]
                leaf.backup_value(value=0, up_to=self.root)
                continue

            # Append the leaf to the list for further evaluation.
            leaves.append(leaf)

        if parallel_readouts == 1:
            assert len(leaves) <= 1, 'Expected at most one leaf!'

        # Calculate the state-value and probabilities of the children
        if leaves:
            # Calls the policy on the leaf's observation (NOT STATE) and
            # retrieves the `value` for this observation from the value network.
            with performance.timer('call policy', self._debug):
                policy_out = self._call_policy(np.asarray([leaves[0].observ]),
                                               only_normalized=True)
            mean = policy_out['mean'][0]
            logstd = policy_out['logstd'][0]
            value = policy_out['value'][0]
            leaves[0].network_value = value
            with performance.timer('create multivariate normal distribution',
                                   self._debug):
                # Based on the returned policy (mean and logstd), we create
                # a multivariate normal distribution. This distribution is later
                # used to sample actions to take in the environment.
                mcts_dist = multivariate_normal(
                    mean=mean, cov=np.diag(np.power(np.exp(logstd), 2)))
            sampled_actions = self.sample_actions(mcts_dist=mcts_dist)
            child_probs = mcts_dist.pdf(sampled_actions)
            # Update `move_to_action` for the selected leaf
            for i, a in enumerate(sampled_actions):
                leaves[0].move_to_action[i] = a

            # In case the number of parallel environments is not sufficient,
            # we process the actions in chunks. Since the last chunk may have
            # fewer elements, we pad it so that all chunks have the same size.
            # This restriction is necessary because of the BatchEnv environment.
            # Step one tree environment per sampled action from the selected
            # leaf's state and collect the per-child results.
            child_reward = []
            child_observ = []
            child_state_qpos = []
            child_state_qvel = []
            child_done = []

            for mcts_env, mcts_action in zip(self.tree_env, sampled_actions):
                mcts_env.reset()
                mcts_env.set_state(leaves[0].state[0], leaves[0].state[1])
                observ, reward, done, _ = mcts_env.step(mcts_action)
                state = mcts_env.sim.get_state()

                child_reward.append(reward)
                child_observ.append(observ)
                child_state_qpos.append(state.qpos)
                child_state_qvel.append(state.qvel)
                child_done.append(done)

            child_reward = np.asarray(child_reward)
            child_observ = np.asarray(child_observ)
            child_state_qpos = np.asarray(child_state_qpos)
            child_state_qvel = np.asarray(child_state_qvel)
            child_done = np.asarray(child_done)

            # Updates the rewards/observs/states for the selected leaf's children
            # and performs the backup step.
            leaves[0].child_reward = child_reward[:self.max_num_actions]
            leaves[0].move_to_observ = child_observ[:self.max_num_actions]
            leaves[0].move_to_state = [
                (qpos, qvel)
                for qpos, qvel in zip(child_state_qpos[:self.max_num_actions],
                                      child_state_qvel[:self.max_num_actions])
            ]
            leaves[0].move_to_done = child_done[:self.max_num_actions]

            # Update the values for all the children by calling the value network.
            # We set `only_normalized` to True because we only want to normalize
            # the observations when calling the policy/value network, without
            # updating the running mean and standard deviation. The running mean
            # and standard deviation are updated only for observations in the
            # trajectory. Note that this is a design decision, based on the
            # intuition that the MCTS search may visit states that are not
            # `good`; this way we avoid updating the running mean/std for
            # observations that may not be good.
            network_children = self._call_policy(leaves[0].move_to_observ,
                                                 only_normalized=True)
            leaves[0].child_w = network_children['value']

            with performance.timer('incorporating rewards', self._debug):
                leaves[0].incorporate_results(child_probs=child_probs,
                                              node_value=value,
                                              up_to=self.root)
        return leaves
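

# --- Illustrative sketch, not part of the original code -----------------------
# The comments in `tree_search` describe `_call_policy(..., only_normalized=True)`
# as normalizing observations with running statistics *without* updating them,
# so that off-trajectory MCTS states do not skew the running mean/std. A
# minimal, self-contained normalizer with that behavior; the class and method
# names are hypothetical, not the original implementation, and it relies on the
# module's existing `np` (numpy) import.
class RunningObservationNormalizerSketch:

    def __init__(self, size, epsilon=1e-8):
        self.mean = np.zeros(size)
        self.var = np.ones(size)
        self.count = epsilon

    def normalize(self, observ, update=False):
        """Normalizes a batch of observations; updates statistics only if asked."""
        observ = np.asarray(observ, dtype=np.float64)
        if update:
            # Only trajectory observations update the running moments.
            batch_mean = observ.mean(axis=0)
            batch_var = observ.var(axis=0)
            batch_count = observ.shape[0]
            delta = batch_mean - self.mean
            total = self.count + batch_count
            self.mean = self.mean + delta * batch_count / total
            self.var = (self.var * self.count + batch_var * batch_count +
                        np.square(delta) * self.count * batch_count / total) / total
            self.count = total
        return (observ - self.mean) / np.sqrt(self.var + 1e-8)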