예제 #1
0
    def get_action(self, root_id, board, turn, enemy_turn):
        """Choose a greedy move for the side to play and record its search stats.

        The side to play is the player when ``turn != enemy_turn`` and the
        enemy otherwise. That agent is queried with
        ``get_pi(root_id, board, turn, tau=0.01)``. Its ``pi`` and its
        ``get_visit()`` result are stored on ``self`` as ``player_pi`` /
        ``player_visit`` or ``enemy_pi`` / ``enemy_visit``.

        Returns:
            The ``(action, action_index)`` pair produced by
            ``utils.argmax_onehot(pi)``.
        """
        is_player_turn = turn != enemy_turn
        agent = self.player if is_player_turn else self.enemy

        pi = agent.get_pi(root_id, board, turn, tau=0.01)
        if is_player_turn:
            self.player_pi = pi
            self.player_visit = agent.get_visit()
        else:
            self.enemy_pi = pi
            self.enemy_visit = agent.get_visit()

        # Both sides pick the move greedily from their own pi.
        action, action_index = utils.argmax_onehot(pi)
        return action, action_index
예제 #2
0
    def get_pi(self, root_id, tau):
        """Run MCTS from ``root_id`` and return the root's visit-count policy.

        Each child of the root writes its visit count ``n`` at its action
        index in a flat float vector of length ``board_size**2``. The vector
        is then normalized to sum to 1.

        Args:
            root_id: Identifier of the search root. It is passed to
                ``_init_mcts``; children are addressed as
                ``self.root_id + (action_index,)``, so it is presumably a tuple
                of moves (TODO confirm).
            tau: Temperature. Only ``tau == 0`` is handled specially: the
                policy is collapsed to one-hot form via ``utils.argmax_onehot``.
                Any other value returns the normalized visit counts as-is.

        Returns:
            A float array of length ``board_size**2``. If no root child has
            been visited, an all-zero array is returned instead of NaNs.
        """
        self._init_mcts(root_id)
        self._mcts(self.root_id)

        visit = np.zeros(self.board_size**2, 'float')
        for action_index in self.tree[self.root_id]['child']:
            child_id = self.root_id + (action_index, )
            visit[action_index] = self.tree[child_id]['n']

        # Guard against 0/0. An unexpanded or unvisited root would otherwise
        # give an all-NaN policy and emit a RuntimeWarning.
        total = visit.sum()
        if total > 0:
            pi = visit / total
        else:
            pi = np.zeros_like(visit)

        if tau == 0:
            pi, _ = utils.argmax_onehot(pi)

        return pi
예제 #3
0
    def get_action(self, root_id, board, turn, enemy_turn):
        """Return the greedy move of whichever agent is to play.

        The player moves when ``turn != enemy_turn``; otherwise the enemy
        moves. Every agent is queried with ``tau=0``.

        Returns:
            The ``(action, action_index)`` pair produced by
            ``utils.argmax_onehot(pi)``.
        """
        agent = self.player if turn != enemy_turn else self.enemy

        # A ZeroAgent is called with just the root id here. Any other agent
        # also receives the board and the side to move.
        if not isinstance(agent, agents.ZeroAgent):
            pi = agent.get_pi(root_id, board, turn, tau=0)
        else:
            pi = agent.get_pi(root_id, tau=0)

        action, action_index = utils.argmax_onehot(pi)
        return action, action_index
예제 #4
0
    def get_action(self, root_id, board, turn, enemy_turn):
        """Return the greedy move of the agent to play, also running the monitor.

        The player moves when ``turn != enemy_turn``; otherwise the enemy
        moves. A ``ZeroAgent`` is queried with ``get_pi(root_id, tau=0)``, and
        any other agent with ``get_pi(root_id, board, turn, tau=0)``. When the
        player is to move and is not a ``ZeroAgent``, ``self.monitor`` also
        runs ``get_pi(root_id, tau=0)`` after the player. The monitor's return
        value is discarded; it is presumably called for its side effects
        (TODO confirm).

        Returns:
            The ``(action, action_index)`` pair from ``utils.argmax_onehot``.
            Per the original notes, ``action`` is a ``boardsize**2`` array
            that is one-hot at the chosen move, and ``action_index`` is that
            position.
        """
        is_player_turn = turn != enemy_turn
        agent = self.player if is_player_turn else self.enemy

        # isinstance() checks whether the agent object is a ZeroAgent.
        # Per the original notes, the enemy is normally a ZeroAgent and the
        # player is a human agent (assumption; verify against the caller).
        if isinstance(agent, agents.ZeroAgent):
            pi = agent.get_pi(root_id, tau=0)
        else:
            pi = agent.get_pi(root_id, board, turn, tau=0)
            if is_player_turn:
                # For the monitor.
                self.monitor.get_pi(root_id, tau=0)

        action, action_index = utils.argmax_onehot(pi)
        return action, action_index
예제 #5
0
    def get_pi(self, root_id, tau):
        """Search from ``root_id`` and return the normalized root visit counts.

        For every child of the root, the visit count ``n`` and the value
        ``p`` are scattered into flat float vectors of length
        ``board_size**2``. These vectors are kept on ``self.visit`` and
        ``self.policy`` for later inspection.

        Args:
            root_id: Identifier of the search root, forwarded to ``_init_mcts``.
            tau: Temperature. When it equals 0, the result is collapsed to
                one-hot form by ``utils.argmax_onehot``. Other values leave
                the normalized counts unchanged.

        Returns:
            ``visit / (visit.sum() + 1e-8)``, or its one-hot form when
            ``tau == 0``.
        """
        self._init_mcts(root_id)
        self._mcts(self.root_id)

        n_actions = self.board_size**2
        visit = np.zeros(n_actions, 'float')
        policy = np.zeros(n_actions, 'float')

        root = self.root_id
        for idx in self.tree[root]['child']:
            node = self.tree[root + (idx, )]
            visit[idx] = node['n']
            policy[idx] = node['p']

        self.visit, self.policy = visit, policy

        # Expansion normally makes the total non-zero. The epsilon only keeps
        # the division finite if it is not.
        denominator = visit.sum() + 1e-8
        pi = visit / denominator

        if tau == 0:
            pi, _ = utils.argmax_onehot(pi)
        return pi
예제 #6
0
    def get_pi(self, root_id, tau):
        """Run MCTS from ``root_id`` and return the root's visit-count policy.

        For every child of the root, the visit count ``n`` and the value ``p``
        are written at the child's action index into flat float vectors of
        length ``board_size**2``. These are stored on ``self.visit`` and
        ``self.policy``. The visit vector is normalized to produce ``pi``.

        Args:
            root_id: Identifier of the search root, forwarded to ``_init_mcts``.
            tau: Temperature. Only ``tau == 0`` is handled specially: the
                policy is reduced to one-hot form via ``utils.argmax_onehot``.
                Per the original notes, this keeps only one of the maximal
                positions.

        Returns:
            A float array of length ``board_size**2``. If no root child has
            been visited, an all-zero array is returned instead of NaNs.
        """
        # Run the tree search.
        self._init_mcts(root_id)
        self._mcts(self.root_id)

        # Initialize: actions that are not root children keep 0.
        visit = np.zeros(self.board_size**2, 'float')
        policy = np.zeros(self.board_size**2, 'float')

        for action_index in self.tree[self.root_id]['child']:
            child = self.tree[self.root_id + (action_index, )]
            # MCTS visit count of this child node.
            visit[action_index] = child['n']
            # Per the original notes: the policy value, higher for positions
            # judged more likely to win (presumably the prior; TODO confirm).
            policy[action_index] = child['p']

        self.visit = visit
        self.policy = policy

        # Guard against 0/0, which would give an all-NaN policy. Use a fresh
        # array so the returned pi never aliases self.visit.
        total = visit.sum()
        if total > 0:
            pi = visit / total
        else:
            pi = np.zeros_like(visit)

        if tau == 0:
            pi, _ = utils.argmax_onehot(pi)

        return pi