Пример #1
0
    def get_action(self, state, player_id):
        """Run `self.sim_num` UCT simulations, then commit to the chosen move.

        Each simulation deep-copies `self.root_state` and does four things.
        It descends through fully expanded nodes via `uct_select_children`.
        It expands one random untried action. It plays uniformly random
        legal moves until the game ends. It then backs the outcome up the
        path, negating it at every level.

        Afterwards the most-visited root child is detached from its parent
        and becomes the new root, and its action is applied to
        `self.root_state`. The subtree is therefore reused on the next call.

        Args:
            state: the current game state. Must match `self.root_state`
                (compared via `compact_str()`).
            player_id: the player to move. Must differ from
                `self.root_node.player_id`.

        Returns:
            The action stored on the most-visited child of the root.
        """
        assert (self.root_state.compact_str() == state.compact_str())
        assert (self.root_node.player_id != player_id)

        for _ in range(self.sim_num):
            leaf = self.root_node
            sim_state = copy.deepcopy(self.root_state)

            # Selection: follow UCT through nodes with no untried actions.
            while not sim_state.is_end() and leaf.is_fully_expanded():
                leaf = leaf.uct_select_children()
                sim_state.do_action(leaf.player_id, leaf.action)

            # Expansion: try one random unexpanded action from this node.
            if not sim_state.is_end():
                move = random.choice(leaf.unexpanded_actions)
                mover = get_next_player_id(leaf.player_id)
                sim_state.do_action(mover, move)
                new_child = Node(state=sim_state,
                                 player_id=mover,
                                 action=move,
                                 parent=leaf)
                leaf.add_child(new_child)
                leaf = new_child

            # Rollout: players alternate random legal moves until the end.
            mover = leaf.player_id
            while not sim_state.is_end():
                mover = get_next_player_id(mover)
                sim_state.do_action(
                    mover, random.choice(sim_state.get_legal_actions(mover)))

            # Backpropagation: value from the leaf mover's view, sign flipped
            # at each level going up.
            outcome = sim_state.get_result()
            if outcome <= 0:
                value = 0
            elif outcome == leaf.player_id:
                value = 1
            else:
                value = -1
            while leaf:
                leaf.update(value)
                value = -value
                leaf = leaf.parent

        #visualize_tree(self.root_node, self.root_state, depth = 3)

        best = self.root_node.get_most_visited_child()
        best.parent = None
        self.root_node = best
        self.root_state.do_action(best.player_id, best.action)
        return best.action
Пример #2
0
 def expand(self, actions, Ps):
     """Append one child node per action, each carrying its policy value.

     Args:
         actions: the actions to create children for.
         Ps: per-action probabilities, paired positionally with `actions`.
             Iteration stops at the shorter of the two, as with `zip`.
     """
     pairs = zip(actions, Ps)
     for move, prior in pairs:
         # Each child's player_id is the player after this node's player_id.
         opponent = get_next_player_id(self.player_id)
         self.children.append(
             Node(player_id=opponent, action=move, P=prior, parent=self))
Пример #3
0
 def __init__(self, root_state, player_id, sim_num):
     """Create a search tree rooted at `root_state`.

     The root node is tagged with `get_next_player_id(player_id)` rather
     than `player_id` itself. In a two-player game that is presumably the
     side that made the previous move, so root children are moves by
     `player_id`.

     Args:
         root_state: the game state to search from. Read by the root Node
             and stored as a private deep copy.
         player_id: the player to move in `root_state`.
         sim_num: number of simulations per search.
     """
     self.sim_num = sim_num
     previous_player = get_next_player_id(player_id)
     self.root_node = Node(state=root_state,
                           player_id=previous_player,
                           action=None,
                           parent=None)
     self.root_state = copy.deepcopy(root_state)
Пример #4
0
    def __init__(self, state, player_id, action, parent):
        """Create a tree node reached by `player_id` playing `action`.

        Args:
            state: the game state at this node. It is only queried here for
                the legal actions of the next player.
            player_id: the player that took `action`.
            action: the action that player took (None for a root).
            parent: the parent node, or None for a root.
        """
        # The edge leading into this node.
        self.player_id = player_id
        self.action = action
        self.parent = parent
        # Statistics start at zero. Presumably N is the visit count, W the
        # accumulated value and Q the mean value; update() is not shown here.
        self.N, self.W, self.Q = 0, 0, 0
        # Actions of the next player that do not yet have a child node.
        next_player = get_next_player_id(player_id)
        self.unexpanded_actions = state.get_legal_actions(next_player)
        self.children = []
 def do_action(self, player_id, action):
     """Play `action` for `player_id`, then update colors around it.

     First delegates to the parent class's `do_action`. Then it calls
     `change_color` from the played cell once in each of the eight compass
     directions, always in the same order. `change_color` is not shown; it
     presumably flips opponent stones bracketed in that direction.

     Args:
         player_id: the player making the move.
         action: an (i, j) board coordinate.
     """
     super().do_action(player_id, action)
     opponent_player_id = get_next_player_id(player_id)
     row, col = action
     for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1),
                          (-1, -1), (-1, 1), (1, -1), (1, 1)):
         self.change_color(row, col, d_row, d_col,
                           player_id, opponent_player_id)
 def get_result(self):
     """Return the game outcome, or -1 while the game is still running.

     The game counts as over once `last_action` is set and the player after
     `last_player_id` has no legal action. The side with more stones on
     the board then wins.

     Returns:
         -1 if `last_action` is None, or if the next player still has a
         legal action.
         1 if cells equal to 1 (black) outnumber cells equal to 2 (white).
         2 if white outnumbers black.
         0 on a tie.
     """
     if self.last_action is None:
         return -1
     if self.get_legal_actions(get_next_player_id(self.last_player_id)):
         return -1
     # NOTE(review): this ends the game as soon as the next player is stuck,
     # even if the last mover could still play. Standard Othello lets the
     # stuck player pass instead. Confirm this matches the engine's rules.
     black_num = sum(1 for row in self.board for v in row if v == 1)
     white_num = sum(1 for row in self.board for v in row if v == 2)
     if black_num > white_num:
         return 1
     if black_num < white_num:
         return 2
     return 0
Пример #7
0
def search(root_node, root_state, sim_num):
    """Run `sim_num` UCT simulations from `root_node` and report visit counts.

    Each simulation deep-copies `root_state` and does four things. It
    descends through fully expanded nodes via `uct_select_children`. It
    expands one random untried action. It plays random legal moves until
    the game ends. It then backs the outcome up the path, negating it at
    each level.

    Args:
        root_node: root of the search tree. Children are added and node
            statistics are updated in place.
        root_state: game state matching `root_node`. It is only deep-copied
            and is never modified here.
        sim_num: number of simulations to run.

    Returns:
        dict mapping each root child's action to that child's `N`.
    """
    for _ in range(sim_num):
        cursor = root_node
        game = copy.deepcopy(root_state)

        # Selection.
        while not game.is_end() and cursor.is_fully_expanded():
            cursor = cursor.uct_select_children()
            game.do_action(cursor.player_id, cursor.action)

        # Expansion.
        if not game.is_end():
            untried = random.choice(cursor.unexpanded_actions)
            actor = get_next_player_id(cursor.player_id)
            game.do_action(actor, untried)
            fresh = Node(state=game,
                         player_id=actor,
                         action=untried,
                         parent=cursor)
            cursor.add_child(fresh)
            cursor = fresh

        # Random rollout to the end of the game.
        actor = cursor.player_id
        while not game.is_end():
            actor = get_next_player_id(actor)
            game.do_action(actor, random.choice(game.get_legal_actions(actor)))

        # Backpropagation: +1/-1 relative to the leaf mover, 0 otherwise,
        # with the sign flipped at every level.
        outcome = game.get_result()
        reward = 0
        if outcome > 0:
            reward = 1 if outcome == cursor.player_id else -1
        walker = cursor
        while walker:
            walker.update(reward)
            reward = -reward
            walker = walker.parent

    return {child.action: child.N for child in root_node.children}
Пример #8
0
 def __init__(self,
              model,
              root_state,
              player_id,
              sim_num,
              is_training,
              dirichlet_factor=None,
              dirichlet_alpha=None):
     """Create a model-guided search tree rooted at `root_state`.

     Args:
         model: evaluator object, stored as `self.model`.
         root_state: the game state to search from. A deep copy is kept.
         player_id: the player to move. The root node is tagged with
             `get_next_player_id(player_id)`.
         sim_num: number of simulations per search.
         is_training: training-mode flag, stored as-is.
         dirichlet_factor: mixing weight for root noise (may be None).
         dirichlet_alpha: concentration for root noise (may be None).
     """
     self.model = model
     self.sim_num = sim_num
     self.is_training = is_training
     # NOTE(review): these appear to be read only in training mode, at the
     # root; they presumably must be set whenever is_training is true.
     self.dirichlet_factor, self.dirichlet_alpha = (dirichlet_factor,
                                                    dirichlet_alpha)
     root_player = get_next_player_id(player_id)
     self.root_node = Node(player_id=root_player)
     self.root_state = copy.deepcopy(root_state)
 def get_legal_actions(self, player_id):
     """List the empty cells (value 0) where `player_id` may play.

     A cell qualifies when, in at least one of the eight directions,
     `get_target_pos` returns a position more than one step from the cell
     along that direction. The row coordinate is compared unless the
     direction is purely horizontal, in which case the column is used.
     Directions are checked in a fixed order, and checking stops at the
     first hit.

     Returns:
         list of (row, col) tuples in row-major order.
     """
     super().get_legal_actions(player_id)
     opponent_player_id = get_next_player_id(player_id)
     directions = ((-1, 0), (1, 0), (0, -1), (0, 1),
                   (-1, -1), (-1, 1), (1, -1), (1, 1))

     def reaches_past_neighbor(row, col):
         # Signed distance along the step; > 1 means at least one cell lies
         # between (row, col) and the reported target.
         for d_row, d_col in directions:
             target = self.get_target_pos(row, col, d_row, d_col,
                                          player_id, opponent_player_id)
             if d_row:
                 distance = (target[0] - row) * d_row
             else:
                 distance = (target[1] - col) * d_col
             if distance > 1:
                 return True
         return False

     return [(row, col)
             for row in range(self.board_shape[0])
             for col in range(self.board_shape[1])
             if self.board[row][col] == 0 and reaches_past_neighbor(row, col)]
 def do_action(self, player_id, action):
     """Move a piece for `player_id` and pass the turn.

     `action` is indexed as (src_row, src_col, dst_row, dst_col). The piece
     on the source square overwrites whatever is on the destination square,
     and the source square becomes NULL.

     Side effects:
         * Appends a deep copy of the pre-move board to `board_history`,
           keeping only the 8 most recent entries.
         * If the destination was not NULL, resets `left_no_kill_action_num`
           to MAX_NO_KILL_ACTION_NUM; otherwise decrements it.
         * Sets `cur_player_id` to the next player and decrements
           `left_action_num`.

     Raises:
         AssertionError: (only with assertions enabled) if it is not
             `player_id`'s turn, the source square is not theirs, or the
             destination holds one of their own pieces.
     """
     # Player 1's pieces are stored above NULL, player 2's below NULL.
     assert self.cur_player_id == player_id
     assert (player_id == 1 and self.board[action[0]][action[1]] > NULL
             or player_id == 2 and self.board[action[0]][action[1]] < NULL)
     assert (player_id == 1 and self.board[action[2]][action[3]] <= NULL
             or player_id == 2 and self.board[action[2]][action[3]] >= NULL)

     history = self.board_history
     history.append(copy.deepcopy(self.board))
     if len(history) > 8:
         del history[0]

     src_row, src_col, dst_row, dst_col = (action[0], action[1],
                                           action[2], action[3])
     board = self.board
     is_capture = board[dst_row][dst_col] != NULL
     if is_capture:
         self.left_no_kill_action_num = MAX_NO_KILL_ACTION_NUM
     else:
         self.left_no_kill_action_num -= 1

     piece = board[src_row][src_col]
     board[dst_row][dst_col] = piece
     board[src_row][src_col] = NULL

     self.cur_player_id = get_next_player_id(player_id)
     self.left_action_num -= 1
Пример #11
0
    def get_legal_Ps(self):
        """Run `self.sim_num` simulations and return visit-based probabilities.

        Each simulation deep-copies `self.root_state` and walks down from
        `self.root_node` with `uct_select_children` until `is_leaf()`. It
        then obtains a value V for that leaf in one of two ways:

        * Terminal state (`get_result() >= 0`): V is 0 for a draw, +1 if the
          result equals `node.player_id` (the player whose move produced the
          state), and -1 otherwise.
        * Non-terminal state: V comes from `self.model.evaluate` and is
          negated. The leaf is expanded with the model's probabilities for
          the legal actions of the next player. Dirichlet noise is mixed in
          at the root when training.

        V is then backed up to the root, negated at every level.

        Returns:
            numpy array with one entry per child of `self.root_node`, in
            `children` order, proportional to N**t (t = 1 when training,
            t = 3 otherwise).
        """
        for _ in range(self.sim_num):
            node = self.root_node
            state = copy.deepcopy(self.root_state)

            while not node.is_leaf():
                node = node.uct_select_children()
                state.do_action(node.player_id, node.action)

            result = state.get_result()
            if result >= 0:
                # Terminal leaf: score from node.player_id's point of view by
                # comparing with the actual result. Treating every decisive
                # result as +1 assumes the last mover always wins. That is
                # false for games decided by a count, where get_result
                # returns whichever side has more stones.
                if result == 0:
                    V = 0
                elif result == node.player_id:
                    V = 1
                else:
                    V = -1
            else:
                player_id = get_next_player_id(node.player_id)
                state_m = state.to_state_m()
                # evaluate() returns a 4-tuple; only the probability vector
                # and the value are used here.
                _, Ps_m, _, V = self.model.evaluate(state_m)
                # Presumably V is from the view of the player to move
                # (player_id); negate it to node.player_id's view.
                V = -V
                actions = state.get_legal_actions(player_id)
                action_indexes = [
                    state.action_to_action_index(action) for action in actions
                ]
                legal_Ps_m = Ps_m[action_indexes]
                # Root-only exploration noise, applied while training.
                if self.is_training and not node.parent:
                    legal_Ps_m = legal_Ps_m * (
                        1 - self.dirichlet_factor) + np.random.dirichlet(
                            np.ones_like(legal_Ps_m) *
                            self.dirichlet_alpha) * self.dirichlet_factor
                node.expand(actions, legal_Ps_m)

            while node:
                node.update(V)
                V = -V
                node = node.parent

        #visualize_tree(self.root_node, self.root_state, depth = 3)

        # t = 3 outside training sharpens the distribution toward the most
        # visited moves.
        t = 1 if self.is_training else 3
        Ns_m = np.array([c.N for c in self.root_node.children], dtype=np.int64)
        Ns_m **= t
        # NOTE(review): if no root child has been visited (e.g. sim_num <= 1
        # on a fresh tree), this is 0/0 and yields NaN entries.
        legal_Ps_m = Ns_m / np.sum(Ns_m)

        return legal_Ps_m
Пример #12
0
 def get_cur_player_id(self):
     """Return the player to move: the one after `self.last_player_id`."""
     previous = self.last_player_id
     return get_next_player_id(previous)