def _single_search(self, game: Quoridor, c_puct, verbose=False) -> float:
    """Run one MCTS descent from ``game``'s current state and back up the result.

    An action is chosen from ``node.upper_conf(c_puct)`` via ``sample_action``
    with ``temperature=0.0`` (presumably a greedy pick -- confirm in
    ``sample_action``). The action is applied temporarily with
    ``game.temp_move``, and the resulting position is handled in one of
    three ways:

    * The move ended the game: the value is +1 if the winner is the player
      who made the move, otherwise -1.
    * The position has never been seen: a new ``TreeNode`` is built from
      ``self.pol_val_fun(game)``, registered in ``self._node_lookup`` and
      linked as a child. Its value is negated, because it is expressed from
      the child's point of view.
    * The position is already known: it is linked as a child (it may be
      reached through a different history) and the search recurses into it,
      negating the returned value.

    The chosen action's statistics are then updated with ``node.backup``.

    NOTE(review): statistics are only updated after the recursive call
    returns. If the game can revisit a position within a single descent,
    the "already known" branch could keep recursing around that cycle until
    RecursionError. Confirm whether repeated positions are possible or are
    handled elsewhere.

    Args:
        game: Game whose current state must already have an entry in
            ``self._node_lookup`` (a KeyError is raised otherwise).
        c_puct: Exploration constant forwarded to ``upper_conf``.
        verbose: If True, print a trace of the descent.

    Returns:
        The backed-up value of the chosen action, seen from the perspective of
        the player to move in ``game``'s current state.
    """
    parent = self._node_lookup[game.hash_key()]
    mover = parent._player
    action = sample_action(parent.upper_conf(c_puct), mover, temperature=0.0)
    if verbose:
        print("\tsingle_search starting @", parent, "\n\t\ttaking", action, end="")

    with game.temp_move(action):
        winner = game.get_winner()
        if winner is None:
            child_key = game.hash_key()
            child = self._node_lookup.get(child_key)
            if child is None:
                # Unseen position: expand it with the policy/value function.
                pol, val = self.pol_val_fun(game)
                child = TreeNode(game, pol, val)
                self._node_lookup[child_key] = child
                parent.add_child(action, child)
                if verbose:
                    print("--> leaf <{}> with value".format(str(child)), val)
                # 'val' is from the child's perspective; flip it for the mover.
                value = -val
            else:
                # Known position, possibly reached via a new path: make sure the
                # edge exists, then continue the descent (negamax sign flip).
                if verbose:
                    print("--> recursing to node", child)
                parent.add_child(action, child)
                value = -self._single_search(game, c_puct, verbose=verbose)
        else:
            # Terminal position: score it for the player who made the move.
            if verbose:
                print("--> winner is", winner)
            value = 1 if winner == mover else -1

    parent.backup(action, value)
    return value
def __init__(self, game_state: Quoridor, policy_output, value_output):
    """Create a search-tree node for ``game_state``.

    Args:
        game_state: Position this node represents. Its legal moves, current
            player and hash key are captured at construction time.
        policy_output: Policy output for this position, stored unchanged.
        value_output: Value estimate for this position, stored unchanged.
    """
    action_planes = (3, 9, 9)
    player = game_state.current_player

    # Per-action statistics over every possible action (not just legal ones).
    # They are later masked with the set of legal actions.
    self._counts = torch.zeros(action_planes)  # times each action was taken from here
    self._total_reward = torch.zeros(action_planes)  # presumably summed backed-up value per action
    self._policy = policy_output
    self._value = value_output
    self._legal_mask = encode_actions_to_planes(game_state.all_legal_moves(), player)
    self._player = player
    self._key = game_state.hash_key()
    self._children = dict()  # action -> child node
    self.__flagged = False  # private flag; its use is not visible in this block
def __init__(self, init_state: Quoridor, pol_val_fun):
    """Build a search tree rooted at ``init_state``.

    Args:
        init_state: Starting position. It is evaluated immediately and becomes
            the root node.
        pol_val_fun: Callable that evaluates a game state. Its result is
            unpacked into the ``TreeNode`` constructor (policy, value).
    """
    self.pol_val_fun = pol_val_fun
    evaluation = pol_val_fun(init_state)
    root = TreeNode(init_state, *evaluation)
    self._root = root
    # Nodes are indexed by the position's hash key, so transpositions share a node.
    self._node_lookup = {init_state.hash_key(): root}
    self._state = init_state