def test_static_env():
    from cchess_alphazero.environment.env import CChessEnv
    import cchess_alphazero.environment.static_env as senv
    from cchess_alphazero.environment.static_env import INIT_STATE
    from cchess_alphazero.environment.lookup_tables import flip_move
    env = CChessEnv()
    env.reset()
    print("env: " + env.observation)
    print("senv: " + INIT_STATE)
    state = INIT_STATE
    env.step('0001')
    state = senv.step(state, '0001')
    print(senv.evaluate(state))
    print("env: " + env.observation)
    print("senv: " + state)
    env.step('7770')
    state = senv.step(state, flip_move('7770'))
    print(senv.evaluate(state))
    print("env: " + env.observation)
    print("senv: " + state)
    env.render()
    board = senv.state_to_board(state)
    for i in range(9, -1, -1):
        print(board[i])
    print("env: ")
    print(env.input_planes()[0+7:3+7])
    print("senv: ")
    print(senv.state_to_planes(state)[0+7:3+7])
    print(f"env: {env.board.legal_moves()}")
    print(f"senv: {senv.get_legal_moves(state)}")
    print(set(env.board.legal_moves()) == set(senv.get_legal_moves(state)))
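# Hedged usage sketch (not part of the repository): the same static_env API that the
# test above compares against CChessEnv -- INIT_STATE, get_legal_moves, step, done --
# is enough to drive a stand-alone random playout. The 200-ply cap and the draw value
# for over-long games are arbitrary assumptions made only for this sketch.
import random
import cchess_alphazero.environment.static_env as senv
from cchess_alphazero.environment.static_env import INIT_STATE

def random_playout_sketch(max_plies=200):
    state = INIT_STATE
    for ply in range(max_plies):
        game_over, value, _ = senv.done(state)
        if game_over:
            return value, ply           # value as reported by senv.done for the final state
        move = random.choice(senv.get_legal_moves(state))
        state = senv.step(state, move)  # the static env returns a new state string
    return 0, max_plies                 # treat an over-long game as a draw in this sketch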
def MCTS_search(self, state, history=[], is_root_node=False) -> float:
    """ Monte Carlo Tree Search """
    while True:
        # logger.debug(f"start MCTS, state = {state}, history = {history}")
        game_over, v, _ = senv.done(state)
        if game_over:
            self.executor.submit(self.update_tree, None, v, history)
            break
        with self.node_lock[state]:
            if state not in self.tree:
                # Expand and Evaluate
                self.tree[state].sum_n = 1
                self.tree[state].legal_moves = senv.get_legal_moves(state)
                self.tree[state].waiting = True
                # logger.debug(f"expand_and_evaluate {state}, sum_n = {self.tree[state].sum_n}, history = {history}")
                self.expand_and_evaluate(state, history)
                break
            if state in history[:-1]:  # loop -> loss
                # logger.debug(f"loop -> loss, state = {state}, history = {history[:-1]}")
                self.executor.submit(self.update_tree, None, 0, history)
                break
            # Select
            node = self.tree[state]
            if node.waiting:
                node.visit.append(history)
                # logger.debug(f"wait for prediction state = {state}")
                break
            sel_action = self.select_action_q_and_u(state, is_root_node)
            virtual_loss = self.config.play.virtual_loss
            self.tree[state].sum_n += 1
            # logger.debug(f"node = {state}, sum_n = {node.sum_n}")
            action_state = self.tree[state].a[sel_action]
            action_state.n += virtual_loss
            action_state.w -= virtual_loss
            action_state.q = action_state.w / action_state.n
            # logger.debug(f"apply virtual_loss = {virtual_loss}, as.n = {action_state.n}, w = {action_state.w}, q = {action_state.q}")
            if action_state.next is None:
                action_state.next = senv.step(state, sel_action)
                # logger.debug(f"step action {sel_action}, next = {action_state.next}")
            history.append(sel_action)
            state = action_state.next
            history.append(state)
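# Hedged sketch of the backup that the deferred update_tree call is expected to perform
# (not the repository's actual implementation). It assumes the alternating
# [state, action, state, action, ..., leaf_state] layout that MCTS_search builds above,
# the per-edge fields n / w / q used there, and that the leaf value v is from the
# perspective of the side to move at the leaf. Sign and bookkeeping conventions in the
# real code may differ.
from dataclasses import dataclass, field

@dataclass
class EdgeStatsSketch:              # mirrors action_state.n / .w / .q above
    n: float = 0.0
    w: float = 0.0
    q: float = 0.0

@dataclass
class NodeSketch:                   # mirrors self.tree[state] above
    sum_n: int = 0
    a: dict = field(default_factory=dict)

def backup_sketch(tree, history, v, virtual_loss):
    path = list(history)
    path.pop()                      # drop the leaf state; the rest pairs up as (state, action)
    while path:
        action = path.pop()
        state = path.pop()
        v = -v                      # value seen by the player to move at `state`
        edge = tree[state].a[action]
        edge.n += 1 - virtual_loss  # the real visit replaces the virtual one from selection
        edge.w += v + virtual_loss  # undo the `w -= virtual_loss` applied during selection
        edge.q = edge.w / edge.n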
def MCTS_search(self, state, history=[], is_root_node=False, real_hist=None) -> float:
    """ Monte Carlo Tree Search """
    while True:
        # logger.debug(f"start MCTS, state = {state}, history = {history}")
        game_over, v, _ = senv.done(state)
        if game_over:
            v = v * 2
            self.executor.submit(self.update_tree, None, v, history)
            break
        with self.node_lock[state]:
            if state not in self.tree:
                # Expand and Evaluate
                self.tree[state].sum_n = 1
                self.tree[state].legal_moves = senv.get_legal_moves(state)
                self.tree[state].waiting = True
                # logger.debug(f"expand_and_evaluate {state}, sum_n = {self.tree[state].sum_n}, history = {history}")
                if is_root_node and real_hist:
                    self.expand_and_evaluate(state, history, real_hist)
                else:
                    self.expand_and_evaluate(state, history)
                break
            if state in history[:-1]:  # loop
                for i in range(len(history) - 1):
                    if history[i] == state:
                        if senv.will_check_or_catch(state, history[i + 1]):
                            self.executor.submit(self.update_tree, None, -1, history)
                        elif senv.be_catched(state, history[i + 1]):
                            self.executor.submit(self.update_tree, None, 1, history)
                        else:
                            # logger.debug(f"loop -> loss, state = {state}, history = {history[:-1]}")
                            self.executor.submit(self.update_tree, None, 0, history)
                        break
                break
            # Select
            node = self.tree[state]
            if node.waiting:
                node.visit.append(history)
                # logger.debug(f"wait for prediction state = {state}")
                break
            sel_action = self.select_action_q_and_u(state, is_root_node)
            virtual_loss = self.config.play.virtual_loss
            self.tree[state].sum_n += 1
            # logger.debug(f"node = {state}, sum_n = {node.sum_n}")
            action_state = self.tree[state].a[sel_action]
            action_state.n += virtual_loss
            action_state.w -= virtual_loss
            action_state.q = action_state.w / action_state.n
            # logger.debug(f"apply virtual_loss = {virtual_loss}, as.n = {action_state.n}, w = {action_state.w}, q = {action_state.q}")
            # if action_state.next is None:
            history.append(sel_action)
            state = senv.step(state, sel_action)
            history.append(state)
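# Hedged sketch of the PUCT rule that select_action_q_and_u is expected to apply in
# both MCTS_search variants above (the standard AlphaZero selection formula, not
# necessarily the repository's exact code). The prior `p` on each edge and the c_puct
# constant are assumptions not shown in the snippets; root Dirichlet noise is omitted.
import math

def select_action_q_and_u_sketch(node, c_puct=1.5):
    # node.legal_moves, node.sum_n and node.a[mov].q / .n are the fields used above;
    # node.a[mov].p is the assumed network prior for the move.
    best_move, best_score = None, -float('inf')
    sqrt_total = math.sqrt(node.sum_n)
    for mov in node.legal_moves:
        edge = node.a[mov]
        u = c_puct * edge.p * sqrt_total / (1 + edge.n)  # exploration bonus
        score = edge.q + u                               # exploit + explore
        if score > best_score:
            best_move, best_score = mov, score
    return best_move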