def expand(parent: Node) -> Node:
    """Grow the tree by one node: play a random unexplored move of `parent`.

    The resulting child is attached to `parent`, pre-scored via the RAVE
    heuristic, and returned so the caller can continue the descent.
    """
    move = choice(tuple(parent.unexplored_moves))
    next_state = MancalaEnv.clone(parent.state)
    next_state.perform_move(move)
    child = Node(state=next_state, move=move, parent=parent)
    parent.put_child(child)
    # go down the tree
    MonteCarloTreePolicy._rave_expand(child)
    return child
def _rave_expand(parent: Node):
    """Assign `parent.value` from a softmax over its unexplored moves' scores.

    Each unexplored move is simulated on a cloned state and scored with the
    static evaluation; slots for moves never scored stay at -1e80 so their
    softmax weight is effectively zero. The peak of the resulting
    distribution becomes the node's value.
    """
    scores = [-1e80] * (parent.state.board.holes + 1)
    for move in parent.unexplored_moves.copy():
        scratch = MancalaEnv.clone(parent.state)
        scratch.perform_move(move)
        scores[move.index] = evaluation.get_score(
            state=scratch, parent_side=parent.state.side_to_move)
    raw = np.asarray(scores, dtype=np.float64).flatten()
    # Shift by the max before exponentiating for numerical stability.
    shifted = np.exp(raw - np.max(raw))
    softmax = shifted / np.sum(shifted)
    parent.value = max(softmax)
def expand(self, node: AlphaNode):
    """Fully expand `node` with one child per legal move, weighted by network priors.

    Returns the child with the maximum action value so the search continues
    down the most promising branch.
    """
    # Tactical workaround the pie move
    pie_move = Move(node.state.side_to_move, 0)
    if pie_move in node.unexplored_moves:
        node.unexplored_moves.remove(pie_move)
    dist, value = self.network.evaluate_state(node.state)
    for index, prior in enumerate(dist):
        candidate = Move(node.state.side_to_move, index + 1)
        if not node.state.is_legal(candidate):
            continue
        next_state = MancalaEnv.clone(node.state)
        next_state.perform_move(candidate)
        node.put_child(AlphaNode(state=next_state, prior=prior,
                                 move=candidate, parent=node))
    # go down the tree
    return node_utils.select_child_with_maximum_action_value(node)
def search(self, state: MancalaEnv) -> Move:
    """Run time-bounded MCTS from `state` and return the most robust move.

    Short-circuits immediately when only one legal move exists. Otherwise
    repeats select/simulate/backpropagate until `self.calculation_time`
    elapses, then picks the robust child of the root.
    """
    # short circuit last move (hoisted: avoid computing legal moves twice)
    legal_moves = state.get_legal_moves()
    if len(legal_moves) == 1:
        return legal_moves[0]
    game_state_root = Node(state=MancalaEnv.clone(state))
    start_time = datetime.datetime.utcnow()
    games_played = 0
    while datetime.datetime.utcnow() - start_time < self.calculation_time:
        node = self.tree_policy.select(game_state_root)
        final_state = self.default_policy.simulate(node)
        self.rollout_policy.backpropagate(node, final_state)
        # Debugging information
        games_played += 1
        # Lazy %-args: stringifying the node (and especially the whole root
        # tree) every iteration is expensive; with lazy args the formatting
        # only happens when DEBUG logging is actually enabled.
        logging.debug("%s; Game played %i", node, games_played)
        logging.debug("%s", game_state_root)
    chosen_child = node_utils.select_robust_child(game_state_root)
    logging.info("Choosing: %s", chosen_child)
    return chosen_child.move
def _make_temp_child(parent: Node, move: Move) -> MancalaEnv:
    """Return a throwaway state: `parent.state` cloned with `move` applied.

    The parent's own state is never mutated.
    """
    scratch = MancalaEnv.clone(parent.state)
    scratch.perform_move(move)
    return scratch
def __init__(self, state: MancalaEnv, action_taken: Move):
    """Record a (state, action) pair, taking defensive copies of both.

    Cloning ensures later mutations of the caller's objects cannot leak
    into this snapshot.
    """
    self.state = MancalaEnv.clone(state)
    self.action_taken = Move.clone(action_taken)
def test_cloning_immutability(self):
    """A clone must not observe moves performed on the original afterwards."""
    snapshot = MancalaEnv.clone(self.game)
    # Mutate the original game after taking the clone.
    self.game.perform_move(Move(Side.SOUTH, 3))
    # The clone still reflects the pre-move board and side to move.
    self.assertEqual(snapshot.board.get_seeds(Side.SOUTH, 3), 7)
    self.assertEqual(snapshot.side_to_move, Side.SOUTH)