def _playout(self, game_state: go.GameState):
    """Run a single playout from the root to the leaf, getting a value at
    the leaf and propagating it back through its parents.
    State is modified in-place, so a copy must be provided.
    """
    node = self._root
    while True:
        if node.is_leaf():
            break
        # Greedily select the next move according to the PUCT score.
        action, node = node.select(self._c_puct)
        if action == PASS_MOVE:
            move = PASS_MOVE
        else:
            move = (action // game_state.size, action % game_state.size)
        game_state.do_move(move)

    if not game_state.is_end_of_game and len(game_state.get_legal_moves(False)) > 0:
        # Evaluate the leaf with the network: (action, probability) tuples
        # plus a value for the current player.
        action_probs, leaf_value = self._policy_value_net.policy_value_fn(
            game_state)
        node.expand(action_probs)
    else:
        # For an end state, return the "true" leaf value.
        winner = game_state.get_winner()
        if winner is None:  # tie
            leaf_value = 0.0
        else:
            leaf_value = 1.0 if winner == game_state.current_player else -1.0

    # Update value and visit count of nodes in this traversal.
    node.update_recursive(-leaf_value)
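Because _playout mutates the state it is given, the caller must hand it a fresh copy for every simulation. A minimal driver sketch is below; the node attributes _children and _n_visits (and the softmax helper) are assumptions about the tree-node class, not code confirmed by this file.

import copy
import numpy as np

def softmax(x):
    probs = np.exp(x - np.max(x))
    return probs / np.sum(probs)

def get_move_probs(self, game_state: go.GameState, temp=1e-3, n_playout=400):
    """Run n_playout simulations, each on a deep copy of the state,
    then read move probabilities off the root's visit counts."""
    for _ in range(n_playout):
        self._playout(copy.deepcopy(game_state))
    # Visit counts at the root define the search probabilities.
    act_visits = [(act, child._n_visits)
                  for act, child in self._root._children.items()]
    acts, visits = zip(*act_visits)
    act_probs = softmax(np.log(np.array(visits) + 1e-10) / temp)
    return acts, act_probs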
from operator import itemgetter

def _evaluate_rollout(self, game_state: go.GameState, limit=1000):
    """Use the rollout policy to play until the end of the game,
    returning +1 if the current player wins, -1 if the opponent wins,
    and 0 if it is a tie.
    """
    player = game_state.current_player
    for _ in range(limit):
        if game_state.is_end_of_game or len(
                game_state.get_legal_moves(False)) == 0:
            break
        action_probs = _rollout_policy_fn(game_state)
        max_action = max(action_probs, key=itemgetter(1))[0]
        if max_action == PASS_MOVE:
            move = PASS_MOVE
        else:
            move = (max_action // game_state.size,
                    max_action % game_state.size)
        game_state.do_move(move)
    else:
        # The for-else runs only if the loop was never broken, i.e. the
        # game did not finish within the move limit.
        print("WARNING: rollout reached move limit")

    winner = game_state.get_winner()
    if winner is None:  # tie
        return 0.0
    else:
        return 1.0 if winner == player else -1.0
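In the pure-MCTS variant, _evaluate_rollout replaces the value network as the leaf evaluator: the leaf is expanded with the uniform _policy_value_fn (defined below) and then scored by a random rollout. A sketch of that playout, under the same node API as above rather than this file's actual code:

def _playout_pure(self, game_state: go.GameState):
    # Descend to a leaf exactly as in the network-guided version.
    node = self._root
    while not node.is_leaf():
        action, node = node.select(self._c_puct)
        if action == PASS_MOVE:
            move = PASS_MOVE
        else:
            move = (action // game_state.size, action % game_state.size)
        game_state.do_move(move)
    # Expand with uniform priors, then score the leaf by rollout
    # instead of a value network.
    action_probs, _ = _policy_value_fn(game_state)
    if not game_state.is_end_of_game:
        node.expand(action_probs)
    leaf_value = self._evaluate_rollout(game_state)
    node.update_recursive(-leaf_value)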
import numpy as np

def _rollout_policy_fn(game_state: go.GameState):
    """A coarse, fast version of policy_fn used in the rollout phase."""
    # Roll out randomly: encode each legal (row, col) move as a flat
    # index and attach a random weight to each.
    legal_moves = (np.array(game_state.get_legal_moves(False))
                   @ np.array([game_state.size, 1])).tolist()
    legal_moves = legal_moves + [PASS_MOVE]
    action_probs = np.random.rand(len(legal_moves))
    return zip(legal_moves, action_probs)
def _policy_value_fn(game_state: go.GameState):
    """A function that takes in a state and outputs a list of
    (action, probability) tuples and a score for the state.
    """
    # Return uniform probabilities and a score of 0 for pure MCTS.
    legal_moves = (np.array(game_state.get_legal_moves(False))
                   @ np.array([game_state.size, 1])).tolist()
    legal_moves = legal_moves + [PASS_MOVE]
    action_probs = np.ones(len(legal_moves)) / len(legal_moves)
    return zip(legal_moves, action_probs), 0
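Both helpers flatten (row, col) moves with the same trick: the matrix product with [size, 1] computes row * size + col for every pair at once. A quick, self-contained check:

import numpy as np

size = 9                                     # e.g. a 9x9 board
moves = np.array([(0, 0), (2, 5), (8, 8)])   # (row, col) pairs
flat = (moves @ np.array([size, 1])).tolist()
print(flat)  # [0, 23, 80] == [0*9+0, 2*9+5, 8*9+8]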
import numpy as np
import torch
from torch.autograd import Variable

def policy_value_fn(self, game_state: go.GameState):
    """
    input: the current game state
    output: a list of (action, probability) tuples for each available
    action and the score of the board state
    """
    legal_moves = (np.array(game_state.get_legal_moves(False))
                   @ np.array([game_state.size, 1])).tolist()
    # The policy head reserves index size**2 for the pass move.
    legal_moves_idx = legal_moves + [game_state.size ** 2]
    legal_moves = legal_moves + [PASS_MOVE]
    current_input = np.expand_dims(get_current_input(game_state), axis=0)
    if self.use_gpu:
        log_act_probs, value = self.policy_value_net(
            Variable(torch.from_numpy(current_input)).cuda().float())
        act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
    else:
        log_act_probs, value = self.policy_value_net(
            Variable(torch.from_numpy(current_input)).float())
        act_probs = np.exp(log_act_probs.data.numpy()).flatten()
    _act_probs = zip(legal_moves, act_probs[legal_moves_idx])
    value = value.data[0][0]
    return _act_probs, value
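Variable has been a deprecated no-op since PyTorch 0.4. An equivalent sketch for a current PyTorch (an assumption about your environment, not this file's actual code) wraps inference in torch.no_grad() and moves the input to the right device:

import numpy as np
import torch

def policy_value_fn(self, game_state: go.GameState):
    """Modernized variant: torch.no_grad() replaces Variable."""
    legal_moves = (np.array(game_state.get_legal_moves(False))
                   @ np.array([game_state.size, 1])).tolist()
    legal_moves_idx = legal_moves + [game_state.size ** 2]
    legal_moves = legal_moves + [PASS_MOVE]
    current_input = np.expand_dims(get_current_input(game_state), axis=0)
    device = "cuda" if self.use_gpu else "cpu"
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        log_act_probs, value = self.policy_value_net(
            torch.from_numpy(current_input).float().to(device))
    act_probs = np.exp(log_act_probs.cpu().numpy().flatten())
    return zip(legal_moves, act_probs[legal_moves_idx]), value.item()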