def suggest_move(self):
  """Suggests a single action after performing a number of MCTS simulations.

  Args:
    None

  Returns:
    An action to be taken in the environment.
  """
  start = time.time()
  if self.timed_match:
    while time.time() - start < self.seconds_per_move:
      self.tree_search()
  else:
    current_readouts = self.root.n
    with performance.timer('Searched %d' % self.num_mcts_sim, True):
      # Perform MCTS simulations in either of the following two scenarios:
      # (1) we have already taken the designated number of moves per search, or
      # (2) the root does not have any children yet.
      if (self._num_searches % self._num_moves_per_search == 0) or (
          not bool(self.root.children)):
        while self.root.n < current_readouts + self.num_mcts_sim:
          self.tree_search()
  # Increment the number of searches performed so far.
  self._num_searches += 1
  # Retrieve the suggested move based on the statistics collected in the tree.
  suggested_move = self.pick_move()
  return suggested_move
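# A minimal usage sketch of how `suggest_move` might be driven from an episode
# loop. `MCTSPlayer`, `make_env`, and `play_move` are illustrative names and
# are NOT defined in this file; only `suggest_move` / `pick_move` above are.
#
#   env = make_env()                      # hypothetical environment factory
#   player = MCTSPlayer(...)              # hypothetical player constructor
#   observ = env.reset()
#   done = False
#   while not done:
#     action = player.suggest_move()      # runs the MCTS simulations above
#     observ, reward, done, _ = env.step(action)
#     player.play_move(action)            # hypothetical: advance the root node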
def select_leaf(self):
  """Selects a leaf node by descending the tree from this node.

  Args:
    None

  Returns:
    A leaf node.
  """
  current = self
  while True:
    # If a node has never been evaluated (leaf node) or is terminal,
    # we have no basis to select a child.
    if (not current.is_expanded) or current.is_done():
      break
    with performance.timer('calculating action score', self._debug):
      action_scores = current.child_action_score
    with performance.timer('calculating argmax', self._debug):
      best_move = np.argmax(action_scores)
    with performance.timer('add child', self._debug):
      current = current.add_child_if_absent(best_move)
  return current
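# For reference, `child_action_score` in AlphaZero-style searches is typically
# a PUCT score (value estimate plus a prior-weighted exploration bonus). The
# property below is only a sketch under that assumption, not necessarily the
# formula used here; `child_Q`, `child_N`, `child_prior`, and `c_puct` are
# illustrative names.
#
#   @property
#   def child_action_score(self):
#     # Exploitation term: mean backed-up value of each child.
#     q = self.child_Q
#     # Exploration term: prior-weighted, decays with child visit counts.
#     u = self.c_puct * self.child_prior * (
#         np.sqrt(1 + self.n) / (1 + self.child_N))
#     return q + u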
def sample_actions(self, mcts_dist):
  """Samples a set of actions for a node.

  The sampled actions become the children of the given node.

  Args:
    mcts_dist: A distribution built from the RL policy (PPO) by passing the
      state of the current node to the policy network.

  Returns:
    A set of actions sampled from the policy distribution.
  """
  with performance.timer('sample multivariate normal distribution',
                         self._debug):
    sampled_actions = mcts_dist.rvs(size=self.max_num_actions)
  return sampled_actions
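# A self-contained sketch of the SciPy distribution used above:
# `multivariate_normal(mean, cov)` returns a frozen distribution whose `rvs`
# draws samples of shape [size, action_dim] and whose `pdf` evaluates the
# density at each sample. The dimensions and numbers below are illustrative.
#
#   import numpy as np
#   from scipy.stats import multivariate_normal
#
#   mean = np.zeros(3)                         # e.g. a 3-dimensional action
#   logstd = np.log(np.array([0.1, 0.2, 0.3]))
#   dist = multivariate_normal(mean=mean, cov=np.diag(np.exp(logstd) ** 2))
#   actions = dist.rvs(size=5)                 # 5 sampled actions, shape (5, 3)
#   densities = dist.pdf(actions)              # density per action, shape (5,)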
def tree_search(self, parallel_readouts=None):
  """Main tree search (MCTS simulations).

  Args:
    parallel_readouts: Number of parallel searches (threads) that are
      performed in the tree.

  Returns:
    The list of selected leaf nodes.
  """
  if parallel_readouts is None:
    parallel_readouts = self._parallel_readouts
  # The list that holds the selected leaf nodes for expand-and-evaluate.
  # If parallel_readouts is one, len(leaves) <= 1.
  leaves = []
  failsafe = 0
  while len(leaves) < parallel_readouts and failsafe < parallel_readouts * 2:
    failsafe += 1
    with performance.timer('select a leaf', self._debug):
      # Select a leaf for expansion in the tree.
      leaf = self.root.select_leaf()
    # If this is a terminal node, directly call backup_value. We pass zero as
    # the backup value on this node, meaning that the expected total reward
    # for an agent starting from this terminal node is zero. You may need to
    # change this based on the target environment.
    if leaf.is_done():
      # Everything should be a tensor, as we are working in batched mode.
      policy_out = self._call_policy(
          np.asarray([leaf.observ]), only_normalized=True)
      leaf.network_value = policy_out['value'][0]
      leaf.backup_value(value=0, up_to=self.root)
      continue
    # Append the leaf to the list for further evaluation.
    leaves.append(leaf)
  if parallel_readouts == 1:
    assert len(leaves) <= 1, 'Expected at most one leaf!'
  # Calculate the state-value and probabilities of the children.
  if leaves:
    # Call the policy on the leaf's observation (NOT state) and retrieve
    # the value for this observation from the value network.
    with performance.timer('call policy', self._debug):
      policy_out = self._call_policy(
          np.asarray([leaves[0].observ]), only_normalized=True)
    mean = policy_out['mean'][0]
    logstd = policy_out['logstd'][0]
    value = policy_out['value'][0]
    leaves[0].network_value = value
    with performance.timer('create multivariate normal distribution',
                           self._debug):
      # Based on the returned policy (mean and logstd), create a multivariate
      # normal distribution. This distribution is later used to sample
      # actions for the environment.
      mcts_dist = multivariate_normal(
          mean=mean, cov=np.diag(np.power(np.exp(logstd), 2)))
    sampled_actions = self.sample_actions(mcts_dist=mcts_dist)
    child_probs = mcts_dist.pdf(sampled_actions)
    # Update `move_to_action` for the selected leaf.
    for i, a in enumerate(sampled_actions):
      leaves[0].move_to_action[i] = a
    # In case the number of parallel environments is not sufficient, we
    # process the actions in chunks. Since the last chunk may have fewer
    # elements, we pad it so that all chunks have the same size. This
    # restriction is necessary because of the BatchEnv environment.
    first_iteration = True
    child_reward = np.zeros(0)
    child_observ = np.zeros(0)
    child_state_qpos = np.zeros(0)
    child_state_qvel = np.zeros(0)
    child_done = np.zeros(0)
    for mcts_env, mcts_action in zip(self.tree_env, sampled_actions):
      # Restore the leaf's simulator state, then step with the sampled action.
      mcts_env.reset()
      mcts_env.set_state(leaves[0].state[0], leaves[0].state[1])
      observ, reward, done, _ = mcts_env.step(mcts_action)
      state = mcts_env.sim.get_state()
      if first_iteration:
        child_reward = np.array([reward])
        child_observ = np.array([observ])
        child_state_qpos = np.array([state.qpos])
        child_state_qvel = np.array([state.qvel])
        child_done = np.array([done])
        first_iteration = False
      else:
        child_reward = np.concatenate((child_reward, np.array([reward])))
        child_observ = np.concatenate((child_observ, [observ]))
        child_state_qpos = np.concatenate((child_state_qpos, [state.qpos]))
        child_state_qvel = np.concatenate((child_state_qvel, [state.qvel]))
        child_done = np.concatenate((child_done, np.array([done])))
    # Update the rewards/observations/states for the selected leaf's children
    # and perform the backup step.
    leaves[0].child_reward = child_reward[:self.max_num_actions]
    leaves[0].move_to_observ = child_observ[:self.max_num_actions]
    leaves[0].move_to_state = [
        (qpos, qvel) for qpos, qvel in zip(
            child_state_qpos[:self.max_num_actions],
            child_state_qvel[:self.max_num_actions])
    ]
    leaves[0].move_to_done = child_done[:self.max_num_actions]
    # Update the values for all the children by calling the value network.
    # We set `only_normalized` to True because we only want to normalize the
    # observations when calling the policy/value network, without updating the
    # running mean and standard deviation. We only update the running mean and
    # standard deviation for observations in the trajectory. Note that this is
    # a design decision, based on the intuition that the MCTS search may visit
    # states that are not `good`. Using this approach, we avoid updating the
    # running mean/std for such observations.
    network_children = self._call_policy(
        leaves[0].move_to_observ, only_normalized=True)
    leaves[0].child_w = network_children['value']
    with performance.timer('incorporating rewards', self._debug):
      leaves[0].incorporate_results(
          child_probs=child_probs, node_value=value, up_to=self.root)
  return leaves
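# A minimal sketch of the state save/restore pattern used in the loop above,
# assuming mujoco_py-style environments (gym MuJoCo envs expose
# `set_state(qpos, qvel)` and `env.sim.get_state()`); `make_env` is an
# illustrative factory, not part of this file.
#
#   env = make_env()
#   env.reset()
#   snapshot = env.sim.get_state()            # holds qpos, qvel, time, ...
#   saved = (snapshot.qpos.copy(), snapshot.qvel.copy())
#
#   # ... later, replay from the saved point in a fresh (or reset) env:
#   env.reset()
#   env.set_state(saved[0], saved[1])
#   observ, reward, done, _ = env.step(env.action_space.sample())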