예제 #1
0
def test_static_env():
    """Drive ``CChessEnv`` and ``static_env`` through the same two moves and print both.

    After every step the observation of the stateful environment is printed
    next to the state string of the static one. The function then renders the
    environment, dumps the static board rows from index 9 down to 0, prints
    planes 7..9 of both plane encodings, and finally prints the legal moves of
    both sides plus whether the two move sets are equal.

    Nothing is asserted; the output is meant to be compared by eye.
    """
    from cchess_alphazero.environment.env import CChessEnv
    import cchess_alphazero.environment.static_env as senv
    from cchess_alphazero.environment.static_env import INIT_STATE
    from cchess_alphazero.environment.lookup_tables import flip_move

    env = CChessEnv()
    env.reset()

    def show_pair(static_state):
        # Stateful observation first, static state second, one per line.
        print("env:  " + env.observation)
        print("senv: " + static_state)

    show_pair(INIT_STATE)
    current = INIT_STATE

    # First move: both environments receive the same move string.
    env.step('0001')
    current = senv.step(current, '0001')
    print(senv.evaluate(current))
    show_pair(current)

    # Second move: the static environment gets the move via flip_move
    # (presumably because its state is kept from the mover's side -- confirm).
    env.step('7770')
    current = senv.step(current, flip_move('7770'))
    print(senv.evaluate(current))
    show_pair(current)

    env.render()
    grid = senv.state_to_board(current)
    for row_index in reversed(range(10)):
        print(grid[row_index])

    # Only planes 7, 8 and 9 of each encoding are shown.
    plane_window = slice(7, 10)
    print("env: ")
    print(env.input_planes()[plane_window])
    print("senv: ")
    print(senv.state_to_planes(current)[plane_window])

    print(f"env:  {env.board.legal_moves()}")
    print(f"senv: {senv.get_legal_moves(current)}")
    env_moves = set(env.board.legal_moves())
    static_moves = set(senv.get_legal_moves(current))
    print(env_moves == static_moves)
예제 #2
0
    def MCTS_search(self, state, history=None, is_root_node=False) -> None:
        """
        Monte Carlo Tree Search: descend from ``state`` until the walk has to stop.

        Each loop iteration handles one state. The descent stops when:

        * ``senv.done(state)`` reports game over -- ``update_tree`` is submitted
          to ``self.executor`` with the value returned by ``senv.done``;
        * ``state`` is not yet in ``self.tree`` -- the node is initialised,
          flagged ``waiting`` and handed to ``expand_and_evaluate``;
        * ``state`` already occurs earlier in ``history`` (a loop) --
          ``update_tree`` is submitted with value 0;
        * the node is still ``waiting`` -- ``history`` is appended to
          ``node.visit`` (presumably picked up again once the evaluation
          arrives; confirm against ``expand_and_evaluate``).

        Otherwise an action is chosen with ``select_action_q_and_u``, the
        configured virtual loss is applied to that edge and the walk continues
        from the resulting state.

        :param state: state to start from; used as key of ``self.tree`` and
            ``self.node_lock``.
        :param history: list extended in place with ``action, state`` pairs as
            the walk advances. When omitted, a new empty list is created for
            this call.
        :param is_root_node: forwarded unchanged to ``select_action_q_and_u``.
        :return: ``None``. Results are delivered through ``update_tree`` and
            ``expand_and_evaluate``, not through the return value.
        """
        if history is None:
            # A ``[]`` default would be created once and shared by every call
            # that omits ``history``, so paths from earlier searches would
            # leak into later ones.
            history = []

        while True:
            game_over, v, _ = senv.done(state)
            if game_over:
                self.executor.submit(self.update_tree, None, v, history)
                break

            with self.node_lock[state]:
                if state not in self.tree:
                    # Expand and Evaluate.
                    # Relies on ``self.tree`` creating the node on first access
                    # (e.g. a defaultdict) -- confirm where the tree is built.
                    self.tree[state].sum_n = 1
                    self.tree[state].legal_moves = senv.get_legal_moves(state)
                    self.tree[state].waiting = True
                    self.expand_and_evaluate(state, history)
                    break

                # ``history[:-1]`` leaves out the last entry, which is ``state``
                # itself once the walk has advanced at least one step.
                if state in history[:-1]:  # loop -> loss
                    self.executor.submit(self.update_tree, None, 0, history)
                    break

                # Select
                node = self.tree[state]
                if node.waiting:
                    # Node has been handed to expand_and_evaluate but is not
                    # ready yet: record this path on the node and stop.
                    node.visit.append(history)
                    break

                sel_action = self.select_action_q_and_u(state, is_root_node)

                virtual_loss = self.config.play.virtual_loss
                node.sum_n += 1

                # Apply the virtual loss to the chosen edge: more visits, less
                # accumulated value, and q recomputed from both.
                action_state = node.a[sel_action]
                action_state.n += virtual_loss
                action_state.w -= virtual_loss
                action_state.q = action_state.w / action_state.n

                # Successor states are computed once and cached on the edge.
                if action_state.next is None:
                    action_state.next = senv.step(state, sel_action)

            history.append(sel_action)
            state = action_state.next
            history.append(state)
예제 #3
0
    def MCTS_search(self, state, history=None, is_root_node=False) -> None:
        """
        Monte Carlo Tree Search: follow selected actions from ``state`` downwards.

        The loop processes one state per iteration and ends in one of four ways:

        1. ``senv.done(state)`` says the game is over: ``update_tree`` is
           submitted to ``self.executor`` with the value from ``senv.done``.
        2. ``state`` has no entry in ``self.tree``: the node gets ``sum_n = 1``,
           its legal moves and ``waiting = True``, then
           ``expand_and_evaluate`` is called.
        3. ``state`` appears earlier in ``history`` (a loop): ``update_tree``
           is submitted with value 0.
        4. The node is still marked ``waiting``: ``history`` is appended to
           ``node.visit`` (presumably resumed after the evaluation completes --
           confirm against ``expand_and_evaluate``).

        In every other case ``select_action_q_and_u`` picks an action, the
        virtual loss from ``self.config.play.virtual_loss`` is applied to the
        chosen edge, and the search steps to the next state.

        :param state: starting state; key into ``self.tree`` and
            ``self.node_lock``.
        :param history: list that is extended in place with the selected
            action and the resulting state at every step. A new empty list is
            used when the argument is omitted.
        :param is_root_node: passed through to ``select_action_q_and_u``.
        :return: ``None``; all exits are ``break`` statements, results travel
            via ``update_tree`` / ``expand_and_evaluate``.
        """
        # Build the default per call: a literal ``[]`` in the signature is
        # evaluated once, so every call without ``history`` would append to
        # the same ever-growing list.
        if history is None:
            history = []

        while True:
            game_over, v, _ = senv.done(state)
            if game_over:
                self.executor.submit(self.update_tree, None, v, history)
                break

            with self.node_lock[state]:
                if state not in self.tree:
                    # Expand and Evaluate.
                    # The attribute writes below rely on ``self.tree[state]``
                    # creating the node on first access -- confirm the tree
                    # is a defaultdict-like mapping.
                    self.tree[state].sum_n = 1
                    self.tree[state].legal_moves = senv.get_legal_moves(state)
                    self.tree[state].waiting = True
                    self.expand_and_evaluate(state, history)
                    break

                # The last history entry is the current state after the first
                # step, so it is excluded from the repetition check.
                if state in history[:-1]:  # loop -> loss
                    self.executor.submit(self.update_tree, None, 0, history)
                    break

                # Select
                node = self.tree[state]
                if node.waiting:
                    # Evaluation still pending: remember the path, stop here.
                    node.visit.append(history)
                    break

                sel_action = self.select_action_q_and_u(state, is_root_node)

                virtual_loss = self.config.play.virtual_loss
                node.sum_n += 1

                # Virtual loss: bump the visit count, lower the total value,
                # and refresh q so it reflects both.
                action_state = node.a[sel_action]
                action_state.n += virtual_loss
                action_state.w -= virtual_loss
                action_state.q = action_state.w / action_state.n

                # Compute the successor only the first time this edge is used.
                if action_state.next is None:
                    action_state.next = senv.step(state, sel_action)

            history.append(sel_action)
            state = action_state.next
            history.append(state)
예제 #4
0
    def MCTS_search(self,
                    state,
                    history=None,
                    is_root_node=False,
                    real_hist=None) -> None:
        """
        Monte Carlo Tree Search: descend from ``state`` until the walk has to stop.

        Each loop iteration handles one state. The descent stops when:

        * ``senv.done(state)`` reports game over -- ``update_tree`` is submitted
          to ``self.executor`` with twice the value returned by ``senv.done``;
        * ``state`` is not yet in ``self.tree`` -- the node is initialised,
          flagged ``waiting`` and handed to ``expand_and_evaluate``
          (with ``real_hist`` when this is the root call and ``real_hist`` is
          non-empty);
        * ``state`` already occurs earlier in ``history`` (a loop) --
          ``update_tree`` is submitted with -1, 1 or 0 depending on the move
          that was played from the first occurrence of ``state``;
        * the node is still ``waiting`` -- ``history`` is appended to
          ``node.visit`` (presumably picked up again once the evaluation
          arrives; confirm against ``expand_and_evaluate``).

        Otherwise an action is chosen with ``select_action_q_and_u``, the
        configured virtual loss is applied to that edge and the walk continues
        from ``senv.step(state, action)``.

        :param state: state to start from; used as key of ``self.tree`` and
            ``self.node_lock``.
        :param history: list extended in place with ``action, state`` pairs as
            the walk advances. When omitted, a new empty list is created for
            this call.
        :param is_root_node: forwarded to ``select_action_q_and_u`` and used
            to decide whether ``real_hist`` is passed on.
        :param real_hist: optional extra argument for ``expand_and_evaluate``;
            only used when ``is_root_node`` is true and the value is truthy.
        :return: ``None``. Results are delivered through ``update_tree`` and
            ``expand_and_evaluate``, not through the return value.
        """
        if history is None:
            # A ``[]`` default would be created once and shared by every call
            # that omits ``history``, so paths from earlier searches would
            # leak into later ones.
            history = []

        while True:
            game_over, v, _ = senv.done(state)
            if game_over:
                # Terminal values are doubled before being backed up
                # (the reason is not visible in this method).
                v = v * 2
                self.executor.submit(self.update_tree, None, v, history)
                break

            with self.node_lock[state]:
                if state not in self.tree:
                    # Expand and Evaluate.
                    # Relies on ``self.tree`` creating the node on first access
                    # (e.g. a defaultdict) -- confirm where the tree is built.
                    self.tree[state].sum_n = 1
                    self.tree[state].legal_moves = senv.get_legal_moves(state)
                    self.tree[state].waiting = True
                    if is_root_node and real_hist:
                        self.expand_and_evaluate(state, history, real_hist)
                    else:
                        self.expand_and_evaluate(state, history)
                    break

                # ``history[:-1]`` leaves out the last entry, which is ``state``
                # itself once the walk has advanced at least one step.
                if state in history[:-1]:  # loop
                    # The membership test above guarantees that the first
                    # occurrence is not the last element, so the entry after
                    # it (the move played from that occurrence) exists.
                    first_seen = history.index(state)
                    repeated_move = history[first_seen + 1]
                    if senv.will_check_or_catch(state, repeated_move):
                        loop_value = -1
                    elif senv.be_catched(state, repeated_move):
                        loop_value = 1
                    else:
                        loop_value = 0
                    self.executor.submit(self.update_tree, None, loop_value,
                                         history)
                    break

                # Select
                node = self.tree[state]
                if node.waiting:
                    # Node has been handed to expand_and_evaluate but is not
                    # ready yet: record this path on the node and stop.
                    node.visit.append(history)
                    break

                # NOTE(review): is_root_node is never cleared inside this
                # loop, so deeper states are selected with the same flag --
                # confirm this is intended.
                sel_action = self.select_action_q_and_u(state, is_root_node)

                virtual_loss = self.config.play.virtual_loss
                node.sum_n += 1

                # Apply the virtual loss to the chosen edge: more visits, less
                # accumulated value, and q recomputed from both.
                action_state = node.a[sel_action]
                action_state.n += virtual_loss
                action_state.w -= virtual_loss
                action_state.q = action_state.w / action_state.n

                # The successor is recomputed on every visit here; nothing is
                # cached on the edge.
                history.append(sel_action)
                state = senv.step(state, sel_action)
                history.append(state)