Ejemplo n.º 1
0
def receive_gameSet():
    """Start a new 9x9 game in which the AI places the first stone.

    Reads a JSON body whose ``hard_idx`` selects the difficulty (an index
    into the list of checkpoint iterations below), loads that policy model,
    lets the MCTS player choose its opening move and returns it as JSON.

    Response keys:
        ai_moved:   the AI move as returned by ``board.move_to_location``,
                    converted to plain ints so ``jsonify`` can serialise it.
        states_loc: 9x9 nested list; 0 = empty, 2 = the AI's stone.
        message:    always ``None`` here.
    """
    receive_data = request.get_json()
    print(receive_data)

    board = Board(width=9, height=9, n_in_row=5)
    board.init_board(1)

    # Difficulty index -> training iteration of the checkpoint to load.
    hard_idx = receive_data['hard_idx']
    hards = [2500, 5000, 7500, 10000, 12500, 15000, 17500, 20000]
    model_file = f'./model/policy_9_{hards[hard_idx]}.model'
    # Close the model file deterministically instead of leaking the handle.
    # NOTE(review): pickle can execute code on load; acceptable only because
    # model_file is a server-side path built from the fixed list above.
    with open(model_file, 'rb') as f:
        policy_param = pickle.load(f, encoding='bytes')
    best_policy = PolicyValueNetNumpy(9, 9, policy_param)
    mcts_player = MCTSPlayer(best_policy.policy_value_fn,
                             c_puct=5,
                             n_playout=400)

    # The AI moves first, so it places exactly one stone now.
    ai_move = mcts_player.get_action(board)
    ai_loc = board.move_to_location(ai_move)

    states_loc = [[0] * 9 for _ in range(9)]
    states_loc[ai_loc[0]][ai_loc[1]] = 2

    data = {
        'ai_moved': list(map(int, ai_loc)),
        'states_loc': states_loc,
        'message': None
    }
    return jsonify(data)
Ejemplo n.º 2
0
def save_heatmap_for_board_and_model(model_name, width, height,
                                     input_plains_num, c_puct, n_playout,
                                     board, board_name, save_to_local,
                                     save_to_tensorboard, writer, i,
                                     heatmap_save_path, use_gpu, **kwargs):
    """Render and save the MCTS heatmap of checkpoint ``i`` of ``model_name``.

    Loads ``current_policy_{i}.model`` from the model's directory, runs one
    ``MCTSPlayer.get_action`` on ``board`` with ``return_fig=True``, and
    saves the resulting figure locally and/or to TensorBoard.

    Args:
        model_name: sub-directory holding the checkpoints; also used in the
            player name and the log line.
        width, height, input_plains_num, use_gpu: forwarded to
            ``PolicyValueNet``.
        c_puct, n_playout: MCTS parameters forwarded to ``MCTSPlayer``.
        board: position to evaluate.
        board_name: used in the PNG file name and the TensorBoard tag.
        save_to_local: if true, save the current matplotlib figure to
            ``heatmap_save_path + board_name + ".png"`` (plain string
            concatenation, so the path should end with a separator).
        save_to_tensorboard: if true, log the heatmap image via ``writer``.
        writer: object providing ``add_image(tag, img_tensor, global_step)``.
        i: checkpoint index; also used as the TensorBoard ``global_step``.
        heatmap_save_path: prefix for locally saved PNGs.
        **kwargs: optional ``models_dir`` overriding the default checkpoint
            root directory; any other keys are ignored.
    """
    # The checkpoint root used to be hard-coded to one user's home directory;
    # keep it as the default but let callers override it.
    models_dir = kwargs.get('models_dir',
                            '/home/lirontyomkin/AlphaZero_Gomoku/models')
    model_file = f'{models_dir}/{model_name}/current_policy_{i}.model'

    policy = PolicyValueNet(width,
                            height,
                            model_file=model_file,
                            input_plains_num=input_plains_num,
                            use_gpu=use_gpu)
    player = MCTSPlayer(policy.policy_value_fn,
                        c_puct=c_puct,
                        n_playout=n_playout,
                        name=f"{model_name}_{i}")

    try:
        _, heatmap_buf = player.get_action(board, return_prob=0,
                                           return_fig=True)

        if save_to_local:
            # NOTE(review): this saves matplotlib's *current* figure,
            # presumably the one get_action just rendered — confirm.
            plt.savefig(heatmap_save_path + f"{board_name}.png",
                        bbox_inches='tight')

        if save_to_tensorboard:
            # Decode the rendered image only when it is needed, and release
            # the PIL handle once it has been converted to a tensor.
            with PIL.Image.open(heatmap_buf) as image:
                image_tensor = ToTensor()(image)
            writer.add_image(tag=f'Heatmap on {board_name}',
                             img_tensor=image_tensor,
                             global_step=i)
    finally:
        # Free all figures even if saving/logging raised, so calling this once
        # per checkpoint does not accumulate open figures.
        plt.close('all')

    now_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    print(f"Done saving: {model_name}_{i} on board {board_name}, {now_time}")
Ejemplo n.º 3
0
 def get_net_player_next_action(
     player,
     i,
     shared_board_states,
     shared_board_availables,
     shared_board_last_move,
     shared_board_current_player,
     game_continue,
     winner,
     play_lock,
     net_lock,
 ):
     """Worker loop that plays the moves of player ``i`` using an MCTS player.

     Builds a private ``Board`` with the same dimensions as ``self.board``,
     loads a ``PolicyValueNet`` from the model file ``player`` and then polls
     the shared game state every 0.2 s. Whenever the shared current player
     equals ``i`` it copies the shared state into the local board, chooses a
     move with MCTS, applies it, checks for the end of the game, writes the
     move back to the shared state and flips the shared current player to
     ``1 - value``.

     NOTE(review): this is a nested function -- it reads ``self``,
     ``start_player`` and ``is_shown`` from an enclosing scope that is not
     visible here (and ``Board`` / ``time`` from module scope).

     Args:
         player: model file path passed to ``PolicyValueNet``.
         i: index of the player this worker moves for; presumably 0 or 1,
             given the ``1 - current_player`` turn hand-off -- confirm.
         shared_board_states: mapping move -> player index, shared between
             workers (presumably a manager dict -- confirm with caller).
         shared_board_availables: shared list of still-available moves.
         shared_board_last_move: shared value holding the last move played.
         shared_board_current_player: shared value holding whose turn it is.
         game_continue: shared flag; the loop runs while its value is 1 and
             is set to 0 here when the game ends.
         winner: shared value receiving the ``game_end()`` winner
             (-1 is reported as a tie).
         play_lock: lock held for the whole read-move-write cycle.
         net_lock: lock held while constructing the network (presumably to
             avoid concurrent model construction -- confirm).
     """
     from policy_value_net_tensorflow import PolicyValueNet
     from mcts_alphaZero import MCTSPlayer
     local_board = Board(width=self.board.width,
                         height=self.board.height,
                         n_in_row=self.board.n_in_row)
     local_board.init_board(start_player)
     with net_lock:
         policy = PolicyValueNet(local_board.width,
                                 local_board.height,
                                 model_file=player)
     mcts_player = MCTSPlayer(policy.policy_value_fn,
                              c_puct=5,
                              n_playout=400,
                              is_selfplay=0)
     while game_continue.value == 1:
         if shared_board_current_player.value == i:
             with play_lock:
                 # Copy the shared game state into the local board; this has
                 # to be redone before every move.
                 for k, v in shared_board_states.items():
                     local_board.states[k] = v
                 local_board.availables = []
                 for availables in shared_board_availables:
                     local_board.availables.append(availables)
                 local_board.last_move = shared_board_last_move.value
                 local_board.current_player = shared_board_current_player.value
                 # Synchronisation done; choose and apply our move locally.
                 move = mcts_player.get_action(local_board)
                 local_board.do_move(move)
                 #print('player {} do move {}'.format(i, move))
                 if is_shown:
                     self.graphic(local_board)
                 end, win = local_board.game_end()
                 if end:
                     if win != -1:
                         print("Game end. Winner is", win)
                     else:
                         print("Game end. Tie")
                     game_continue.value = 0
                     winner.value = win
                 # Publish the move back to the shared state and pass the
                 # turn (this also runs when the move ended the game).
                 shared_board_states[
                     move] = shared_board_current_player.value
                 shared_board_availables.remove(move)
                 shared_board_last_move.value = move
                 shared_board_current_player.value = 1 - shared_board_current_player.value
         time.sleep(0.2)
Ejemplo n.º 4
0
class GameStrategy_MZhang():
    """15x15 five-in-a-row AI built from a policy-value network and MCTS.

    Keeps a private ``Board`` mirroring the game: each ``play_one_piece``
    call first replays the opponent's last move on it, then plays the AI's
    reply and returns it in the external board's coordinates.
    """

    def __init__(self, startplayer=0,
                 model_file='models/resnet/output318/current_policy.model'):
        """Load the network and create the MCTS player, board and game.

        Args:
            startplayer: forwarded to ``Board.init_board``.
            model_file: network checkpoint passed to ``PolicyValueNet``;
                defaults to the path that used to be hard-coded here.
        """
        self.height = 15
        self.width = 15
        policy_value_net = PolicyValueNet(self.height,
                                          self.width,
                                          model_file=model_file,
                                          output='output/')
        self.mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
                                      c_puct=1,
                                      n_playout=1000)
        self.board = Board(width=self.width, height=self.height, n_in_row=5)
        self.board.init_board(startplayer)
        self.game = Game(self.board)
        p1, p2 = self.board.players
        print('players:', p1, p2)
        # The AI is registered as the board's first player id.
        self.mcts_player.set_player_ind(p1)

    def play_one_piece(self, user, gameboard):
        """Mirror the opponent's last move, then choose and play the AI move.

        Args:
            user: identifier of the requesting user; only logged.
            gameboard: external board providing ``move_history`` and
                ``get_lastmove()``, which returns ``(user, n, row, col)``;
                a first element of -1 is treated as "no move to replay".

        Returns:
            ``(row, col)`` of the AI's move in the external board's
            coordinates.
        """
        print('user:', user, 'gameboard:', gameboard.move_history)
        lastm = gameboard.get_lastmove()
        if lastm[0] != -1:
            _, _, row, col = lastm
            # Flip the row index into the internal Board's convention (the
            # inverse flip is applied to the result below); moves are
            # row * width + col.
            # NOTE(review): no check that this square is still empty.
            mv = (self.height - row - 1) * self.width + col
            self.board.do_move(mv)

        print('board:', self.board.states.items())
        move = self.mcts_player.get_action(self.board)
        self.board.do_move(move)
        self.game.graphic(self.board, *self.board.players)
        outmv = (self.height - move // self.width - 1, move % self.width)

        return outmv
Ejemplo n.º 5
0
def run(states, sensible_moves, currentPlayer, lastMove):
    """Return the MCTS player's next move for a given 8x8 position.

    Builds an 8x8 five-in-a-row ``Board``, overwrites its state with the
    caller-supplied position and asks an ``MCTSPlayer`` driven by the
    module-level ``policy_param`` network weights for a move.
    """
    board_w = board_h = 8
    board = Board(width=board_w, height=board_h, n_in_row=5)
    board.init_board()

    # Replace the freshly initialised state with the caller's position.
    board.states, board.availables = states, sensible_moves
    board.current_player, board.last_move = currentPlayer, lastMove

    net = PolicyValueNetNumpy(board_w, board_h, policy_param)
    searcher = MCTSPlayer(net.policy_value_fn, c_puct=5, n_playout=400)
    return searcher.get_action(board)
Ejemplo n.º 6
0
def run(states, sensible_moves, currentPlayer, lastMove):
    """Pick the next move for a position on an 8x8, five-in-a-row board.

    Args:
        states: board state mapping assigned to ``board.states``.
        sensible_moves: available moves assigned to ``board.availables``.
        currentPlayer: player to move, assigned to ``board.current_player``.
        lastMove: previous move, assigned to ``board.last_move``.

    Returns:
        The move chosen by ``MCTSPlayer.get_action``.
    """
    size = 8       # board is size x size
    win_length = 5  # stones in a row needed to win
    board = Board(width=size, height=size, n_in_row=win_length)
    board.init_board()

    # Restore the caller's game position on top of the fresh board.
    board.states = states
    board.availables = sensible_moves
    board.current_player = currentPlayer
    board.last_move = lastMove

    # MCTS guided by the numpy policy-value network built from the
    # module-level policy_param weights.
    network = PolicyValueNetNumpy(size, size, policy_param)
    return MCTSPlayer(network.policy_value_fn,
                      c_puct=5,
                      n_playout=400).get_action(board)
    def start_play_with_UI(self,
                           AI: mcts_alphaZero.MCTSPlayer,
                           start_player=0):
        """Run an interactive human-vs-AI game loop driven by a ``GUI``.

        Player index 0 is the human (moves come from ``UI.get_input()``) and
        index 1 is ``AI``. ``start_player`` selects who moves first; the
        "SwitchPlayer" command alternates it.

        Commands handled (``inp[0]`` of ``UI.get_input()``):
            "move":         place a stone; ``inp[1]`` is either an int move
                            index or a location converted by
                            ``UI.loc_2_move``. Ignored once the game ended.
            "RestartGame":  clear the board, keep the score.
            "ResetScore":   reset the score only.
            "quit":         terminate the process.
            "SwitchPlayer": swap who starts, clear the board and the score.
        Any other input is ignored.

        The loop has no normal exit; it ends only via the "quit" command.
        """
        import sys

        AI.reset_player()
        self.board.init_board()
        current_player = SP = start_player
        UI = GUI(self.board.width)
        end = False
        while True:
            print("current_player", current_player)

            if current_player == 0:
                UI.show_messages("Your turn")
            else:
                UI.show_messages("AI's turn")

            if current_player == 1 and not end:
                move, _ = AI.get_action(self.board,
                                        is_selfplay=False,
                                        print_probs_value=1)
            else:
                inp = UI.get_input()
                if inp[0] == "move" and not end:
                    if isinstance(inp[1], int):
                        move = inp[1]
                    else:
                        # Anything else is a board location; convert it.
                        move = UI.loc_2_move(inp[1])
                elif inp[0] == "RestartGame":
                    end = False
                    current_player = SP
                    self.board.init_board()
                    UI.restart_game()
                    AI.reset_player()
                    continue
                elif inp[0] == "ResetScore":
                    UI.reset_score()
                    continue
                elif inp[0] == "quit":
                    # sys.exit() is always available, unlike the exit()
                    # helper that the site module may or may not install.
                    sys.exit()
                elif inp[0] == "SwitchPlayer":
                    end = False
                    self.board.init_board()
                    UI.restart_game(False)
                    UI.reset_score()
                    AI.reset_player()
                    SP = (SP + 1) % 2
                    current_player = SP
                    continue
                else:
                    # Unknown command, or a move after the game has ended.
                    continue
            if not end:
                UI.render_step(move, self.board.current_player)
                self.board.do_move(move)
                current_player = (current_player + 1) % 2
                end, winner = self.board.game_end()
                if end:
                    if winner != -1:
                        print("Game end. Winner is player", winner)
                        UI.add_score(winner)
                    else:
                        print("Game end. Tie")
                    print(UI.score)
                    print()
Ejemplo n.º 8
0
def player_moved():
    """Apply the human's move and, if the game goes on, reply with the AI's.

    Expects a JSON body with:
        states_loc:   the current board (assigned to ``board.states_loc``),
                      or null/absent for an empty board.
        player_moved: the human's stone location, converted with
                      ``board.location_to_move``.
        hard_idx:     difficulty index selecting the model checkpoint.

    Returns JSON with ``ai_moved`` (the AI location, or ``None`` if the
    human's move ended the game), ``forbidden`` (``board.forbidden_locations``
    after ``set_forbidden()``), ``message`` (``"tie"``, the winner id, or
    ``None``) and, when the AI moved, the updated ``states_loc``.
    """
    receive_data = request.get_json()
    print(receive_data)

    board = Board(width=9, height=9, n_in_row=5)
    board.init_board(0)

    # Restore the position sent by the client, if any.
    states_loc = receive_data.get('states_loc')
    if states_loc is not None:
        board.states_loc = states_loc
        board.states_loc_to_states()

    # Apply the stone the player just placed.
    player_loc = receive_data['player_moved']
    player_move = board.location_to_move(player_loc)
    board.do_move(player_move)
    board.set_forbidden()  # update forbidden positions

    print(np.array(board.states_loc))
    print(board.states)

    # Did the player's move end the game?
    end, winner = board.game_end()
    if end:
        message = "tie" if winner == -1 else winner
        data = {
            'ai_moved': None,
            'forbidden': board.forbidden_locations,
            'message': message
        }
        return jsonify(data)

    # Load the player for the requested difficulty and let the AI reply.
    hard_idx = receive_data['hard_idx']
    hards = [2500, 5000, 7500, 10000, 12500, 15000, 17500, 20000]
    model_file = f'./model/policy_9_{hards[hard_idx]}.model'
    # Close the model file deterministically instead of leaking the handle.
    # NOTE(review): pickle can execute code on load; acceptable only because
    # model_file is a server-side path built from the fixed list above.
    with open(model_file, 'rb') as f:
        policy_param = pickle.load(f, encoding='bytes')
    best_policy = PolicyValueNetNumpy(9, 9, policy_param)
    mcts_player = MCTSPlayer(best_policy.policy_value_fn,
                             c_puct=5,
                             n_playout=400)

    ai_move = mcts_player.get_action(board)
    ai_loc = board.move_to_location(ai_move)
    board.do_move(ai_move)
    board.set_forbidden()  # update forbidden positions

    print(np.array(board.states_loc))

    # Did the AI's move end the game?
    message = None
    end, winner = board.game_end()
    if end:
        message = "tie" if winner == -1 else winner

    data = {
        'ai_moved': list(map(int, ai_loc)),
        'states_loc': board.states_loc,
        'forbidden': board.forbidden_locations,
        'message': message
    }
    return jsonify(data)
Ejemplo n.º 9
0
Archivo: app.py Proyecto: sewon0918/pj4
    result.init_board(start_player=0)
    if input1 != ' ':
        for i in range(len(input1)):
            result.states[input1[i]] = 1
            result.availables.remove(input1[i])
    if input2 != ' ':
        for j in range(len(input2)):
            result.states[input2[j]] = 2
            result.availables.remove(input2[j])
    result.current_player = 1
    return result


# Parse the app's input, rebuild the board and print the AI's next move.
parsed_input1, parsed_input2, ai = parse(input_from_app)
width = height = length_list[ai]
board = makemap(parsed_input1, parsed_input2)
model_file = model_file_list[ai]

# Checkpoints pickled under Python 2 typically fail to load under Python 3
# with UnicodeDecodeError unless encoding='bytes' is given, so retry with it.
# Only that error is caught: a bare ``except:`` would also swallow
# KeyboardInterrupt/SystemExit and mask unrelated loading failures.
try:
    with open(model_file, 'rb') as f:
        policy_param = pickle.load(f)
except UnicodeDecodeError:
    with open(model_file, 'rb') as f:
        policy_param = pickle.load(f, encoding='bytes')  # To support python3

best_policy = PolicyValueNetNumpy(width, height, policy_param)
mcts_player = MCTSPlayer(
    best_policy.policy_value_fn, c_puct=5,
    n_playout=400)  # set larger n_playout for better performance
print(mcts_player.get_action(board))
sys.stdout.flush()
Ejemplo n.º 10
0
class OthelloFrame(wx.Frame):
    """wxPython window for playing Othello against an MCTS player.

    Draws an ``N`` x ``N`` grid, treats left clicks as the human's moves and
    runs ``MCTSPlayer.get_action`` for the AI, on a background thread when
    the AI replies to a human move. Buttons: Replay (reset), ● (AI replies to
    the human's moves), ○ (AI moves immediately, then keeps replying),
    Hint (the AI plays the current move).
    """

    # Immutable per-class defaults; instances rebind these as the game runs.
    current_move = 0
    has_set_ai_player = False
    is_banner_displayed = False

    block_length = int((WIN_HEIGHT - 90) / N)
    piece_radius = (block_length >> 1) - 3
    inner_circle_radius = piece_radius - 4
    half_button_width = (BUTTON_WIDTH - BUTTON_WIDTH_MARGIN) >> 1

    mcts_player = None

    # Labels for the largest (15x15) board; __init__ stores per-instance
    # slices sized to self.n.
    row_name_list = [
        '15', '14', '13', '12', '11', '10', ' 9', ' 8', ' 7', ' 6', ' 5', ' 4',
        ' 3', ' 2', ' 1'
    ]
    column_name_list = [
        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'O',
        'P'
    ]

    def __init__(self):
        """Compute the layout, create the widgets and try to load the AI.

        If the model file cannot be read (``IOError``), the colour and Hint
        buttons stay disabled and only human-vs-human play is possible.
        """
        self.n = N
        self.board = Board(self.n)
        # Unstarted placeholder so that thread.is_alive() is False initially.
        self.thread = threading.Thread()
        # Per-instance containers. These used to be class attributes, so all
        # frames shared them and __init__ appended to the same lists (a second
        # frame would draw every grid line twice and share the move record).
        self.line_list = []
        self.row_list = []
        self.column_list = []
        self.chess_record = []
        self.states = []
        self.current_players = []
        self.mcts_probabilities = []
        self.row_name_list = self.row_name_list[15 - self.n:15]
        self.column_name_list = self.column_name_list[0:self.n]
        self.grid_length = self.block_length * (self.n - 1)
        self.grid_position_x = ((WIN_WIDTH - self.grid_length) >> 1) + 15
        self.grid_position_y = (WIN_HEIGHT - self.grid_length -
                                HEIGHT_OFFSET) >> 1
        self.button_position_x = (self.grid_position_x + ROW_LIST_MARGIN -
                                  BUTTON_WIDTH) >> 1
        self.second_button_position_x = self.button_position_x + self.half_button_width + BUTTON_WIDTH_MARGIN

        # Grid lines (one vertical and one horizontal per step) plus the
        # positions of the row and column labels.
        for i in range(0, self.grid_length + 1, self.block_length):
            self.line_list.append(
                (i + self.grid_position_x, self.grid_position_y,
                 i + self.grid_position_x,
                 self.grid_position_y + self.grid_length - 1))
            self.line_list.append(
                (self.grid_position_x, i + self.grid_position_y,
                 self.grid_position_x + self.grid_length - 1,
                 i + self.grid_position_y))
            self.row_list.append((self.grid_position_x + ROW_LIST_MARGIN,
                                  i + self.grid_position_y - 8))
            self.column_list.append(
                (i + self.grid_position_x,
                 self.grid_position_y + self.grid_length + COLUMN_LIST_MARGIN))

        # wx positions are integer pixels; the vertical offset was a float.
        wx.Frame.__init__(self,
                          None,
                          title="Othello Zero",
                          pos=((wx.DisplaySize()[0] - WIN_WIDTH) >> 1,
                               int((wx.DisplaySize()[1] - WIN_HEIGHT) / 2.5)),
                          size=(WIN_WIDTH, WIN_HEIGHT),
                          style=wx.CLOSE_BOX)
        button_font = wx.Font(14, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                              wx.FONTWEIGHT_NORMAL, False)
        image_font = wx.Font(25, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                             wx.FONTWEIGHT_NORMAL, False)

        self.replay_button = wx.Button(
            self,
            label="Replay",
            pos=(self.button_position_x,
                 self.grid_position_y + BUTTON_HEIGHT_MARGIN),
            size=(BUTTON_WIDTH, BUTTON_HEIGHT))
        self.black_button = wx.Button(
            self,
            label="●",
            pos=(self.button_position_x,
                 self.grid_position_y + 2 * BUTTON_HEIGHT_MARGIN),
            size=(self.half_button_width, BUTTON_HEIGHT))
        self.white_button = wx.Button(
            self,
            label="○",
            pos=(self.second_button_position_x,
                 self.grid_position_y + 2 * BUTTON_HEIGHT_MARGIN),
            size=(self.half_button_width, BUTTON_HEIGHT))
        self.ai_hint_button = wx.Button(
            self,
            label="Hint",
            pos=(self.button_position_x,
                 self.grid_position_y + 3 * BUTTON_HEIGHT_MARGIN),
            size=(BUTTON_WIDTH, BUTTON_HEIGHT))
        self.black_text = wx.StaticText(
            self,
            label="●",
            pos=(self.button_position_x + 25,
                 self.grid_position_y + 4 * BUTTON_HEIGHT_MARGIN + 20),
            size=wx.Size(100, 30))
        self.black_number = wx.StaticText(
            self,
            label="",
            pos=(self.button_position_x + 50,
                 self.grid_position_y + 4 * BUTTON_HEIGHT_MARGIN + 20),
            size=wx.Size(100, 30))
        self.white_text = wx.StaticText(
            self,
            label="○",
            pos=(self.button_position_x + 85,
                 self.grid_position_y + 4 * BUTTON_HEIGHT_MARGIN + 20),
            size=wx.Size(100, 30))
        self.white_number = wx.StaticText(
            self,
            label="",
            pos=(self.button_position_x + 110,
                 self.grid_position_y + 4 * BUTTON_HEIGHT_MARGIN + 20),
            size=wx.Size(100, 30))
        self.replay_button.SetFont(button_font)
        self.ai_hint_button.SetFont(button_font)
        self.black_button.SetFont(image_font)
        self.white_button.SetFont(image_font)
        self.replay_button.Disable()
        try:
            self.policy_value_net = PolicyValueNet(
                self.n, model_file='./models/current_policy.model300')
            self.mcts_player = MCTSPlayer(
                self.policy_value_net.policy_value_func,
                c_puct=5,
                n_play_out=400)
            self.black_button.Enable()
            self.white_button.Enable()
            self.ai_hint_button.Enable()
        except IOError as _:
            self.black_button.Disable()
            self.white_button.Disable()
            self.ai_hint_button.Disable()
        self.initialize_user_interface()

    def on_replay_button_click(self, _):
        """Reset board, move record and recorded data (unless the AI is thinking)."""
        if not self.thread.is_alive():
            self.board.initialize()
            self.current_move = 0
            self.has_set_ai_player = False
            self.chess_record.clear()
            self.states = []
            self.current_players = []
            self.mcts_probabilities = []
            self.draw_board()
            self.draw_chess()
            self.replay_button.Disable()
            if self.mcts_player is not None:
                self.black_button.Enable()
                self.white_button.Enable()
                self.ai_hint_button.Enable()

    def on_black_button_click(self, _):
        """Enable the AI opponent; it will reply after each human move."""
        self.black_button.Disable()
        self.white_button.Disable()
        self.has_set_ai_player = True

    def on_white_button_click(self, _):
        """Enable the AI opponent and let it make the first move now."""
        self.black_button.Disable()
        self.white_button.Disable()
        self.has_set_ai_player = True
        self.thread = threading.Thread(target=self.ai_next_move, args=())
        self.thread.start()

    def on_ai_hint_button_click(self, _):
        """Let the AI play the current move (synchronously, on the GUI thread)."""
        if not self.thread.is_alive():
            self.black_button.Disable()
            self.white_button.Disable()
            self.ai_next_move()

    def on_paint(self, _):
        """Repaint the whole window."""
        dc = wx.PaintDC(self)
        dc.SetBackground(wx.Brush(wx.WHITE_BRUSH))
        dc.Clear()
        self.draw_board()
        self.draw_chess()
        self.update_number()

    def ai_next_move(self):
        """Ask the MCTS player for a move, apply it and redraw.

        NOTE(review): when started from on_click/on_white_button_click this
        runs on a worker thread yet draws with wx.ClientDC and updates
        labels; wxPython expects GUI calls on the main thread (e.g. via
        wx.CallAfter) -- consider refactoring.
        """
        move, move_probabilities = self.mcts_player.get_action(self.board)
        x, y = self.board.move_to_location(move)
        self.board.add_move(x, y)
        self.flatten(move_probabilities)
        self.draw_move(y, x)
        self.update_number()

    def flatten(self, move_probabilities):
        """Record the current state, the MCTS probabilities and the player."""
        self.states.append(self.board.get_current_state())
        self.mcts_probabilities.append(move_probabilities)
        self.current_players.append(self.board.get_current_player())

    def disable_buttons(self):
        """Disable the Hint button once the board reports a winner."""
        if self.board.has_winner() != -1:
            self.ai_hint_button.Disable()

    def initialize_user_interface(self):
        """Create a fresh board, bind the event handlers and show the frame."""
        self.board = Board(self.n)
        self.Bind(wx.EVT_PAINT, self.on_paint)
        self.Bind(wx.EVT_LEFT_UP, self.on_click)
        self.Bind(wx.EVT_BUTTON, self.on_replay_button_click,
                  self.replay_button)
        self.Bind(wx.EVT_BUTTON, self.on_black_button_click, self.black_button)
        self.Bind(wx.EVT_BUTTON, self.on_white_button_click, self.white_button)
        self.Bind(wx.EVT_BUTTON, self.on_ai_hint_button_click,
                  self.ai_hint_button)
        self.Centre()
        self.Show(True)

    def repaint_board(self):
        """Redraw grid and pieces, removing the result banner."""
        self.draw_board()
        self.draw_chess()
        self.is_banner_displayed = False

    def draw_board(self):
        """Draw the empty grid, its labels and the marker dots."""
        dc = wx.ClientDC(self)
        dc.SetPen(wx.Pen(wx.WHITE))
        dc.SetBrush(wx.Brush(wx.WHITE))
        dc.DrawRectangle(self.grid_position_x - self.block_length,
                         self.grid_position_y - self.block_length,
                         self.grid_length + self.block_length * 2,
                         self.grid_length + self.block_length * 2)
        dc.SetPen(wx.Pen(wx.BLACK, width=2))
        dc.DrawLineList(self.line_list)
        dc.SetFont(
            wx.Font(13, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                    wx.FONTWEIGHT_NORMAL, False))
        dc.DrawTextList(self.row_name_list, self.row_list)
        dc.DrawTextList(self.column_name_list, self.column_list)
        dc.SetBrush(wx.Brush(wx.BLACK))
        if self.n % 2 == 1:
            dc.DrawCircle(
                self.grid_position_x + self.block_length * (self.n >> 1),
                self.grid_position_y + self.block_length * (self.n >> 1), 4)
        if self.n == 15:
            dc.DrawCircle(self.grid_position_x + self.block_length * 3,
                          self.grid_position_y + self.block_length * 3, 4)
            dc.DrawCircle(self.grid_position_x + self.block_length * 3,
                          self.grid_position_y + self.block_length * 11, 4)
            dc.DrawCircle(self.grid_position_x + self.block_length * 11,
                          self.grid_position_y + self.block_length * 3, 4)
            dc.DrawCircle(self.grid_position_x + self.block_length * 11,
                          self.grid_position_y + self.block_length * 11, 4)

    def draw_possible_moves(self, possible_move):
        """Draw translucent circles for ``(move, probability)`` pairs.

        Opacity is ``int(p * 230)``, clamped below at 14 so unlikely moves
        remain visible.
        """
        dc = wx.ClientDC(self)
        for move, p in possible_move:
            y, x = self.board.move_to_location(move)
            dc.SetBrush(
                wx.Brush(
                    wx.Colour(28, 164, 252, alpha=max(14, int(p * 230)))))
            dc.SetPen(wx.Pen(wx.Colour(28, 164, 252, alpha=230)))
            dc.DrawCircle(self.grid_position_x + x * self.block_length,
                          self.grid_position_y + y * self.block_length,
                          self.piece_radius)

    def draw_chess(self):
        """Draw every piece and highlight the most recent move."""
        dc = wx.ClientDC(self)
        self.disable_buttons()
        for x, y in np.ndindex(self.board.chess[0:self.n, 0:self.n].shape):
            if self.board.chess[y, x] > 0:
                dc.SetBrush(
                    wx.Brush(wx.BLACK if self.board.chess[y, x] ==
                             1 else wx.WHITE))
                dc.DrawCircle(self.grid_position_x + x * self.block_length,
                              self.grid_position_y + y * self.block_length,
                              self.piece_radius)
        if self.current_move > 0:
            x, y = self.chess_record[self.current_move - 1]
            dc.SetBrush(
                wx.Brush(wx.BLACK if self.board.chess[y,
                                                      x] == 1 else wx.WHITE))
            dc.SetPen(
                wx.Pen(wx.WHITE if self.board.chess[y, x] == 1 else wx.BLACK))
            x = self.grid_position_x + x * self.block_length
            y = self.grid_position_y + y * self.block_length
            dc.DrawCircle(x, y, self.inner_circle_radius)

    def draw_move(self, x: int, y: int) -> bool:
        """Record and draw a move at column ``x``, row ``y``.

        Returns:
            True if the board now reports a winner and at least one AI state
            has been recorded (the result banner is then shown), else False.
        """
        self.current_move += 1
        self.chess_record.append((x, y))
        self.draw_chess()
        winner = self.board.has_winner()
        if winner != -1 and len(self.states) > 0:
            self.disable_buttons()
            self.draw_banner(winner)
            return True
        return False

    def draw_banner(self, result: int):
        """Draw the result banner: 1 = black wins, 2 = white wins, else draw."""
        w = 216
        if result == 1:
            string = "BLACK WIN"
        elif result == 2:
            string = "WHITE WIN"
        else:
            string = "DRAW"
            w = 97
        x = (self.grid_position_x + ((self.grid_length - w) >> 1))
        dc = wx.ClientDC(self)
        dc.SetBrush(wx.Brush(wx.WHITE))
        dc.DrawRectangle(
            self.grid_position_x + ((self.grid_length - BANNER_WIDTH) >> 1),
            self.grid_position_y + ((self.grid_length - BANNER_HEIGHT) >> 1),
            BANNER_WIDTH, BANNER_HEIGHT)
        dc.SetPen(wx.Pen(wx.BLACK))
        dc.SetFont(
            wx.Font(40, wx.FONTFAMILY_MODERN, wx.FONTSTYLE_NORMAL,
                    wx.FONTWEIGHT_NORMAL, False))
        dc.DrawText(string, x,
                    (self.grid_position_y + ((self.grid_length - 40) >> 1)))
        self.is_banner_displayed = True

    def update_number(self):
        """Show the current black and white piece counts."""
        black, white = self.board.get_color_number()
        self.black_number.SetLabel(str(black))
        self.white_number.SetLabel(str(white))

    def on_click(self, e):
        """Handle a left click: play the human's move, then start the AI reply.

        Clicks are ignored while the AI thread is running. After the game has
        a winner, a click only clears the result banner.
        """
        if self.thread.is_alive():
            return
        if self.board.winner != -1:
            if self.is_banner_displayed:
                self.repaint_board()
            return

        # Snap the pixel position to the nearest intersection.
        x, y = e.GetPosition()
        x = x - self.grid_position_x + (self.block_length >> 1)
        y = y - self.grid_position_y + (self.block_length >> 1)
        if x <= 0 or y <= 0:
            return
        x = int(x / self.block_length)
        y = int(y / self.block_length)
        if not (0 <= x < self.n and 0 <= y < self.n):
            return
        if self.board.chess[y, x] != 0:
            return
        move = self.board.location_to_move(y, x)
        if move not in self.board.get_available_moves(
                self.board.get_current_player()):
            return

        if self.mcts_player is not None:
            self.black_button.Disable()
            self.white_button.Disable()
        self.board.add_move(y, x)
        has_end = self.draw_move(x, y)
        self.replay_button.Enable()
        self.update_number()
        if self.has_set_ai_player and not has_end:
            self.thread = threading.Thread(target=self.ai_next_move, args=())
            self.thread.start()
Ejemplo n.º 11
0
class Chess_Board_Canvas(Tkinter.Canvas):
    """Tkinter canvas that draws a 15x15 Gomoku board and drives games on it.

    Entry points:

    * ``click1`` -- mouse handler for a human-vs-heuristic-AI game;
    * ``click2`` -- switch the canvas from training to play mode;
    * ``click3`` -- load the neural-network player and start an automated
      game (heuristic ``self.AI`` as black vs ``MCTSPlayer`` as white);
    * ``train_agents`` -- automated heuristic-vs-heuristic game.
    """

    # Number of intersections per board side.
    BOARD_SIZE = 15
    # Value-table entries at or above this are treated as a winning point.
    WIN_THRESHOLD = 90000

    def __init__(self, master=None, height=0, width=0):
        Tkinter.Canvas.__init__(self, master, height=height, width=width)
        # Move recorder; it also exposes ``state`` and the heuristic value
        # tables ``value[0..2]`` read by the agents below.
        self.step_record_chess_board = Record.Step_Record_Chess_Board()
        # Board dimensions in intersections.
        self.height = self.BOARD_SIZE
        self.width = self.BOARD_SIZE
        self.init_chess_board_points()  # map board coordinates to pixels
        self.init_chess_board_canvas()  # draw the board
        self.board = MCTS.Board()
        self.n_in_row = 5
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        # NOTE: the same self.board object is handed to both heuristic
        # agents, and its ``state`` is overwritten before each of their
        # moves, so it acts as shared mutable state.
        self.AI = MCTS.MonteCarlo(self.board, 1)
        self.AI_1 = MCTS.MonteCarlo(self.board, 0)
        # Turn flag toggled between 1 and 0 by the game loops; which side it
        # denotes differs per loop (see click1 / train_*).
        self.clicked = 1
        # First stone is placed by the user (later to be replaced by a
        # random selection).
        self.init = True
        self.train_or_play = True  # True - train, False - play
        self.step = 0
        self.text_id = None  # canvas id of the "Step: N" label

    # ------------------------------------------------------------------
    # Board drawing
    # ------------------------------------------------------------------
    def init_chess_board_points(self):
        """Create the 15x15 grid of ``Point.Point`` objects.

        Each point maps a board coordinate (i, j) to its pixel coordinates
        (``pixel_x`` / ``pixel_y``). The grid is stored in
        ``self.chess_board_points`` and indexed as ``[i][j]``.
        """
        self.chess_board_points = [
            [Point.Point(i, j) for j in range(self.BOARD_SIZE)]
            for i in range(self.BOARD_SIZE)
        ]

    def init_chess_board_canvas(self):
        """Draw the grid lines, a red 2..12 square and the intersection dots."""
        points = self.chess_board_points
        last = self.BOARD_SIZE - 1

        for i in range(self.BOARD_SIZE):  # vertical lines
            self.create_line(points[i][0].pixel_x, points[i][0].pixel_y,
                             points[i][last].pixel_x, points[i][last].pixel_y)

        for j in range(self.BOARD_SIZE):  # horizontal lines
            self.create_line(points[0][j].pixel_x, points[0][j].pixel_y,
                             points[last][j].pixel_x, points[last][j].pixel_y)

        # Red border square with corners (2, 2) and (12, 12).
        border_segments = (((2, 2), (2, 12)), ((12, 2), (12, 12)),
                           ((2, 12), (12, 12)), ((2, 2), (12, 2)))
        for (a, b), (c, d) in border_segments:
            self.create_line(points[a][b].pixel_x, points[a][b].pixel_y,
                             points[c][d].pixel_x, points[c][d].pixel_y,
                             fill="red")

        r = 1  # radius of the intersection dots
        for row in points:
            for p in row:
                self.create_oval(p.pixel_x - r, p.pixel_y - r,
                                 p.pixel_x + r, p.pixel_y + r)

    def _draw_stone(self, x, y, color):
        """Draw a stone of radius 10 px centred on board point (x, y)."""
        point = self.chess_board_points[x][y]
        self.create_oval(point.pixel_x - 10, point.pixel_y - 10,
                         point.pixel_x + 10, point.pixel_y + 10,
                         fill=color)

    def _update_step_label(self):
        """Replace the "Step: N" label at the bottom of the canvas."""
        if self.text_id:
            print(self.text_id)
            self.delete(self.text_id)
        self.text_id = self.create_text(150, 475, text='Step: %d' % self.step)

    # ------------------------------------------------------------------
    # Helpers shared by the game loops
    # ------------------------------------------------------------------
    @classmethod
    def _winning_threat_exists(cls, value, layer):
        """Return True if the scan below lands on a winning point of ``layer``.

        Scans the board row by row. The first cell of a row whose
        ``value[layer]`` entry reaches ``WIN_THRESHOLD`` is taken and ends
        that row. Otherwise the cell with the (last) largest ``value[0]``
        entry is kept. The chosen cell is then tested against the threshold.
        """
        x = y = 0
        max_value = 0
        for i in range(cls.BOARD_SIZE):
            for j in range(cls.BOARD_SIZE):
                if value[layer][i][j] >= cls.WIN_THRESHOLD:
                    x, y = i, j
                    max_value = 99999
                    break
                elif value[0][i][j] >= max_value:
                    x, y = i, j
                    max_value = value[0][i][j]
        return value[layer][x][y] >= cls.WIN_THRESHOLD

    def _xy_to_move(self, x, y):
        """Convert canvas board coordinates (x, y) to a ``board2`` move index.

        ``board2`` numbers rows from the opposite edge, hence the y flip.
        """
        return (self.height - y - 1) * self.width + x

    def _move_to_xy(self, move):
        """Inverse of ``_xy_to_move``: ``board2`` move index -> (x, y)."""
        row, x = divmod(move, self.width)
        return x, self.height - row - 1

    # ------------------------------------------------------------------
    # Human vs AI
    # ------------------------------------------------------------------
    def click1(self, event):
        """Mouse listener for a game played between a human and the AI.

        The human stone goes on the empty intersection whose squared pixel
        distance to the click is at most 200 (about 14 px). ``self.AI`` then
        answers with a white stone. The listener is ignored in training mode.
        """
        if self.train_or_play:
            print("In self training, mouse event is not available")
            return
        if self.clicked != 1:
            return

        record = self.step_record_chess_board
        result = 0
        for i in range(self.BOARD_SIZE):
            for j in range(self.BOARD_SIZE):
                point = self.chess_board_points[i][j]
                square_distance = (math.pow(event.x - point.pixel_x, 2) +
                                   math.pow(event.y - point.pixel_y, 2))
                # Legal only if close enough and the point is empty.
                if (square_distance > 200
                        or record.checkState(i, j) is not None):
                    continue
                self.clicked = 0
                to_play = record.who_to_play()
                if to_play == 1:  # odd move: black
                    self._draw_stone(i, j, 'black')
                    Tkinter.Canvas.update(self)
                elif to_play == 2:  # even move: white
                    self._draw_stone(i, j, 'white')

                result = 0
                if record.value[1][i][j] >= self.WIN_THRESHOLD:
                    result = 1
                    self.clicked = 1
                record.insert_record(i, j)  # place the stone (at most 225)

                if result == 1:
                    self.create_text(240, 475, text='the black wins')
                    self.unbind('<Button-1>')  # stop accepting left clicks

        # AI reply, driven by the heuristic value tables.
        if self.clicked != 1:
            if self._winning_threat_exists(record.value, 2):
                result = 2
            self.board.state = np.copy(record.state)
            self.AI.value = record.value[0]
            self.AI.update(self.board.state)
            x, y = self.AI.bestAction()
            record.insert_record(x, y)
            self._draw_stone(x, y, 'white')
            if result == 2:
                self.create_text(240, 475, text='the white wins')
                self.unbind('<Button-1>')  # stop accepting left clicks
            self.clicked = 1

    def click2(self):
        """Switch to Human vs AI mode (requires a trained network)."""
        self.train_or_play = False  # this will lock the "ai vs human" button
        return

    # ------------------------------------------------------------------
    # Neural-network agent
    # ------------------------------------------------------------------
    def loadAI(self, init_model):
        """Load the policy-value network and build the NN player and board.

        :param init_model: path of the model file to load. A falsy value
            (``click3`` passes ``False``) selects the default
            ``'best_policy_12000.pt'``.
        """
        if not init_model:
            init_model = 'best_policy_12000.pt'

        # Checked by train_nn_agents; nothing in this class sets it to True.
        self.result = False

        self.policy_value_net = PolicyValueNet(self.width,
                                               self.height,
                                               model_file=init_model,
                                               use_gpu=False)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=0)
        self.board2 = Board2(width=self.width,
                             height=self.height,
                             n_in_row=self.n_in_row)
        self.board2.init_board(1)
        self.game = Game(self.board2)

    def click3(self):
        """Start an automated game: heuristic agent (black) vs NN agent (white).

        No mouse input is used. ``train_nn_agents`` re-schedules itself on
        the Tk event loop until one side wins.
        """
        self.train_or_play = True  # this will lock the "ai vs human" button
        self.loadAI(False)
        print(self.width, self.height)
        self.train_nn_agents()

    def train_nn_agents(self):
        """Play one round: heuristic ``self.AI`` (black), then the NN (white).

        Re-schedules itself every 10 ms until a win is detected.
        """
        self.step += 1
        print("============agents at steps %d ============" % self.step)
        record = self.step_record_chess_board
        # Defaults so the white branch is well defined even when it runs
        # without a black move earlier in this call.
        x = y = 0
        result = 0

        if self.clicked == 1:  # Black stone, heuristic Gomoku agent
            print("Black begins at step %d, %d>>>>>>>>>>>>>>>>>>>>>" %
                  (self.step, record.who_to_play()))
            if self._winning_threat_exists(record.value, 1):
                print("win black in black0!!!!!!!!!!!!!!!!")
                result = 1

            self.board.state = np.copy(record.state)
            self.AI.value = record.value[0]
            self.AI.update(self.board.state)
            x, y = self.AI.bestAction()

            self._draw_stone(x, y, 'black')
            move2 = self._xy_to_move(x, y)
            self.board2.current_player = 1
            self.board2.do_move(move2)
            print("Black, Gomoku takes action: ", move2, x, y,
                  (self.step, record.who_to_play()))
            record.insert_record(x, y)  # this switches to the other player
            black_won = self.board.isWin(record.state, (x, y), 1)
            if result == 1 or black_won:
                self.create_text(240, 475, text='the black wins, b11')
                return
            self.clicked = 0

        if self.clicked != 1:  # White stone, AlphaZero agent
            print("White begins at step %d, %d >>>>>>>>>>>>>>>>>>>>>" %
                  (self.step, record.who_to_play()))
            self.clicked = 1
            cur_play = record.who_to_play()  # 1 = black, 2 = white

            # (x, y) is black's move from this call, or (0, 0) if none.
            if record.value[2][x][y] >= self.WIN_THRESHOLD:
                print("win white in white!!!!!!!!!!!!!!!!")
                result = 2

            # Sync the NN board with the recorder and let the NN move.
            self.board2.update_state(np.copy(record.state))
            self.board2.current_player = cur_play
            self.mcts_player.reset_player()
            self.mcts_player.set_player_ind(cur_play)
            action = self.mcts_player.get_action(self.board2)
            self.board2.do_move(action)
            x, y = self._move_to_xy(action)
            record.insert_record(x, y)

            self._draw_stone(x, y, 'white')
            if self.result:
                print("white wins")
                return
            white_won = self.board.isWin(record.state, (x, y), 2)
            if result == 2 or white_won:
                self.create_text(240, 475, text='the white wins, w22')
                return
            self.clicked = 1

        self._update_step_label()
        self.after(10, self.train_nn_agents)

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Requires ``loadAI`` to have run (uses ``self.game`` and
        ``self.mcts_player``). Augmented samples are appended to
        ``self.data_buffer``. Nothing else in this class initialises the
        buffer, so it is created on first use as a deque bounded to 10000
        samples.
        """
        from collections import deque

        if getattr(self, 'data_buffer', None) is None:
            self.data_buffer = deque(maxlen=10000)
        for _ in range(n_games):
            _, play_data = self.game.start_self_play(self.mcts_player, temp=1)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            self.data_buffer.extend(self.get_equi_data(play_data))

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        :param play_data: iterable of ``(state, mcts_prob, winner_z)``.
        :returns: list with 8 samples per input sample (4 rotations, each
            with and without a left-right flip).
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(mcts_prob.reshape(self.height, self.width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    # ------------------------------------------------------------------
    # Heuristic vs heuristic
    # ------------------------------------------------------------------
    def train_agents(self):
        """Play one round of heuristic agent vs heuristic agent.

        ``self.AI_1`` (fed ``value[1]``) plays black, and ``self.AI`` (fed
        ``value[0]``) plays white. Re-schedules itself every 10 ms until a
        win is detected.
        """
        self.step += 1
        record = self.step_record_chess_board
        # Initialised up front. On the first call ``self.clicked`` is 1, so
        # the black branch is skipped and the white branch reads ``result``.
        result = 0

        if self.clicked != 1:  # Black stone
            if self._winning_threat_exists(record.value, 1):
                # value[1] marks black's winning points (as in click1 and
                # train_nn_agents), so this is a black win.
                result = 1
            # Only the next four statements connect the agent to the board:
            # copy the state, hand over the 15x15 value table, let the agent
            # update, and ask for its best action.
            self.board.state = np.copy(record.state)
            self.AI_1.value = record.value[1]
            self.AI_1.update(self.board.state)
            x, y = self.AI_1.bestAction()

            record.insert_record(x, y)
            self._draw_stone(x, y, 'black')
            if result == 1:
                self.create_text(240, 475, text='the black wins')
                return
            self.clicked = 1

        if self.clicked == 1:  # White stone
            if self._winning_threat_exists(record.value, 2):
                result = 2

            self.board.state = np.copy(record.state)
            self.AI.value = record.value[0]
            self.AI.update(self.board.state)
            x, y = self.AI.bestAction()

            record.insert_record(x, y)
            self._draw_stone(x, y, 'white')
            if result == 2:
                self.create_text(240, 475, text='the white wins')
                return
            self.clicked = 0

        self._update_step_label()
        self.after(10, self.train_agents)
Ejemplo n.º 12
0
class GomokuFrame(wx.Frame):
    moves = 0
    current_move = 0
    has_set_ai_player = False
    is_banner_displayed = False
    is_analysis_displayed = False

    block_length = int((WIN_HEIGHT - 90) / STANDARD_LENGTH)
    piece_radius = (block_length >> 1) - 3
    inner_circle_radius = piece_radius - 4
    half_button_width = (BUTTON_WIDTH - BUTTON_WIDTH_MARGIN) >> 1

    mcts_player = None

    line_list = []
    row_list = []
    column_list = []
    chess_record = []
    row_name_list = [
        '15', '14', '13', '12', '11', '10', ' 9', ' 8', ' 7', ' 6', ' 5', ' 4',
        ' 3', ' 2', ' 1'
    ]
    column_name_list = [
        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'O',
        'P'
    ]

    def __init__(self, n: int):
        """Create the Gomoku window for an ``n`` x ``n`` board.

        Computes the grid geometry and builds the control buttons. It then
        tries to load ``best_policy.model``. The side, hint and analysis
        buttons are enabled only if that file can be opened.

        :param n: number of lines per board side.
        """
        self.n = n
        self.board = Board(n)
        # Unstarted placeholder, so ``is_alive()`` is False until an AI move
        # thread replaces it.
        self.thread = threading.Thread()
        # Per-instance containers. The class attributes with these names are
        # single lists shared by every GomokuFrame, so appending to them (as
        # the loop below does) would leak state between frames.
        self.line_list = []
        self.row_list = []
        self.column_list = []
        self.chess_record = []
        self.row_name_list = self.row_name_list[STANDARD_LENGTH -
                                                n:STANDARD_LENGTH]
        self.column_name_list = self.column_name_list[0:n]
        self.grid_length = self.block_length * (n - 1)
        self.grid_position_x = ((WIN_WIDTH - self.grid_length) >> 1) + 15
        self.grid_position_y = (WIN_HEIGHT - self.grid_length -
                                HEIGHT_OFFSET) >> 1
        self.button_position_x = (self.grid_position_x + ROW_LIST_MARGIN -
                                  BUTTON_WIDTH) >> 1
        self.second_button_position_x = (self.button_position_x +
                                         self.half_button_width +
                                         BUTTON_WIDTH_MARGIN)

        for i in range(0, self.grid_length + 1, self.block_length):
            self.line_list.append(
                (i + self.grid_position_x, self.grid_position_y,
                 i + self.grid_position_x,
                 self.grid_position_y + self.grid_length - 1))
            self.line_list.append(
                (self.grid_position_x, i + self.grid_position_y,
                 self.grid_position_x + self.grid_length - 1,
                 i + self.grid_position_y))
            self.row_list.append((self.grid_position_x + ROW_LIST_MARGIN,
                                  i + self.grid_position_y - 8))
            self.column_list.append(
                (i + self.grid_position_x,
                 self.grid_position_y + self.grid_length + COLUMN_LIST_MARGIN))

        display_width, display_height = wx.DisplaySize()
        wx.Frame.__init__(self,
                          None,
                          title="Gomoku Zero",
                          # wx expects integer coordinates.
                          pos=((display_width - WIN_WIDTH) >> 1,
                               int((display_height - WIN_HEIGHT) / 2.5)),
                          size=(WIN_WIDTH, WIN_HEIGHT),
                          style=wx.CLOSE_BOX)
        # The 4th positional argument of wx.Font is the weight.
        button_font = wx.Font(14, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                              wx.FONTWEIGHT_NORMAL, False)
        image_font = wx.Font(25, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                             wx.FONTWEIGHT_NORMAL, False)
        self.undo_button = wx.Button(self,
                                     label="Undo",
                                     pos=(self.button_position_x,
                                          self.grid_position_y),
                                     size=(self.half_button_width,
                                           BUTTON_HEIGHT))
        self.redo_button = wx.Button(self,
                                     label="Redo",
                                     pos=(self.second_button_position_x,
                                          self.grid_position_y),
                                     size=(self.half_button_width,
                                           BUTTON_HEIGHT))
        self.replay_button = wx.Button(
            self,
            label="Replay",
            pos=(self.button_position_x,
                 self.grid_position_y + BUTTON_HEIGHT_MARGIN),
            size=(BUTTON_WIDTH, BUTTON_HEIGHT))
        self.black_button = wx.Button(
            self,
            label="●",
            pos=(self.button_position_x,
                 self.grid_position_y + 2 * BUTTON_HEIGHT_MARGIN),
            size=(self.half_button_width, BUTTON_HEIGHT))
        self.white_button = wx.Button(
            self,
            label="○",
            pos=(self.second_button_position_x,
                 self.grid_position_y + 2 * BUTTON_HEIGHT_MARGIN),
            size=(self.half_button_width, BUTTON_HEIGHT))
        self.ai_hint_button = wx.Button(
            self,
            label="Hint",
            pos=(self.button_position_x,
                 self.grid_position_y + 3 * BUTTON_HEIGHT_MARGIN),
            size=(BUTTON_WIDTH, BUTTON_HEIGHT))
        self.analysis_button = wx.Button(
            self,
            label="Analysis",
            pos=(self.button_position_x,
                 self.grid_position_y + 4 * BUTTON_HEIGHT_MARGIN),
            size=(BUTTON_WIDTH, BUTTON_HEIGHT))
        for button in (self.undo_button, self.redo_button, self.replay_button,
                       self.ai_hint_button, self.analysis_button):
            button.SetFont(button_font)
        for button in (self.black_button, self.white_button):
            button.SetFont(image_font)
        for button in (self.undo_button, self.redo_button,
                       self.replay_button):
            button.Disable()

        try:
            # NOTE: pickle.load can execute arbitrary code; only load
            # trusted model files.
            with open('best_policy.model', 'rb') as model_file:
                policy_param = pickle.load(model_file, encoding='bytes')
            self.mcts_player = MCTSPlayer(PolicyValueNet(
                self.n, net_params=policy_param).policy_value_func,
                                          c_puct=5,
                                          n_play_out=400)
            ai_available = True
        except IOError:
            ai_available = False
        for button in (self.black_button, self.white_button,
                       self.ai_hint_button, self.analysis_button):
            if ai_available:
                button.Enable()
            else:
                button.Disable()
        self.initialize_user_interface()

    def on_undo_button_click(self, _):
        """Take back the last move, or the last two when playing the AI.

        Ignored while the AI thread is running. The second take-back only
        happens if a move is still left, so undoing the AI's single opening
        move cannot drive ``current_move`` below zero.
        """
        if self.thread.is_alive():
            return
        self.current_move -= 1
        self.board.winner = 0
        self.board.remove_move()
        if (self.has_set_ai_player and self.current_move > 0
                and (self.board.winner == 0
                     or self.board.get_move_number() < self.n * self.n)):
            # Also take back the paired move (presumably the AI's reply).
            self.current_move -= 1
            self.board.remove_move()
        self.redo_button.Enable()
        self.repaint_board()
        if self.current_move == 0:
            self.undo_button.Disable()
            self.replay_button.Disable()
        if self.mcts_player is not None:
            self.ai_hint_button.Enable()
            self.analysis_button.Enable()

    def on_redo_button_click(self, _):
        """Replay the next recorded move, or the next two when playing the AI.

        Ignored while the AI thread is running. The second replay only
        happens if another recorded move exists (``current_move < moves``),
        which avoids indexing past the end of ``chess_record``.
        """
        if self.thread.is_alive():
            return
        x, y = self.chess_record[self.current_move]
        self.current_move += 1
        self.board.add_move(y, x)
        if (self.has_set_ai_player and self.current_move < self.moves
                and (self.board.winner == 0
                     or self.board.get_move_number() < self.n * self.n)):
            x, y = self.chess_record[self.current_move]
            self.current_move += 1
            self.board.add_move(y, x)
        self.undo_button.Enable()
        self.replay_button.Enable()
        self.repaint_board()
        if self.current_move == self.moves:
            self.redo_button.Disable()
        if self.current_move == self.n * self.n:
            self.ai_hint_button.Disable()
        if self.mcts_player is not None:
            self.analysis_button.Enable()

    def on_replay_button_click(self, _):
        """Reset the board, counters and buttons for a fresh game.

        Does nothing while the AI thread is running. The side/hint/analysis
        buttons are re-enabled only when an AI player was loaded.
        """
        if self.thread.is_alive():
            return
        self.board.initialize()
        self.moves = self.current_move = 0
        self.has_set_ai_player = False
        self.chess_record.clear()
        self.draw_board()
        for history_button in (self.undo_button, self.redo_button,
                               self.replay_button):
            history_button.Disable()
        if self.mcts_player is None:
            return
        for ai_button in (self.black_button, self.white_button,
                          self.ai_hint_button, self.analysis_button):
            ai_button.Enable()

    def on_black_button_click(self, _):
        """Human takes black: lock the side choice and enable the AI opponent."""
        for side_button in (self.black_button, self.white_button):
            side_button.Disable()
        self.has_set_ai_player = True

    def on_white_button_click(self, _):
        """Human takes white: lock the side choice and let the AI open.

        The AI's first move is computed on a new background thread, stored
        in ``self.thread``.
        """
        for side_button in (self.black_button, self.white_button):
            side_button.Disable()
        self.has_set_ai_player = True
        worker = threading.Thread(target=self.ai_next_move, args=())
        self.thread = worker
        worker.start()

    def on_ai_hint_button_click(self, _):
        """Let the AI play the current move synchronously (a "hint").

        Ignored while the AI thread is running. The side buttons are
        disabled before the move is made.
        """
        if self.thread.is_alive():
            return
        for side_button in (self.black_button, self.white_button):
            side_button.Disable()
        self.ai_next_move()

    def on_analysis_button_click(self, _):
        """Overlay the AI's candidate moves on the board.

        The search runs on a deep copy of ``self.mcts_player`` so the live
        player object is not modified by it. ``get_action(..., return_probability=2)``
        is unpacked into parallel ``moves`` / ``probability`` sequences; only
        entries with a positive probability are drawn. If any are drawn, the
        overlay flag is set and the analysis button is disabled.
        """
        if self.thread.is_alive():
            return
        scratch_player = copy.deepcopy(self.mcts_player)
        moves, probability = scratch_player.get_action(
            self.board, return_probability=2)
        candidates = [(moves[idx], prob)
                      for idx, prob in enumerate(probability) if prob > 0]
        if not candidates:
            return
        self.draw_possible_moves(candidates)
        self.is_analysis_displayed = True
        self.analysis_button.Disable()

    def on_paint(self, _):
        """Handle EVT_PAINT: clear the window and redraw the grid and all stones.

        The handler previously redrew only the empty grid after clearing, so
        every stone vanished whenever the window needed repainting (exposed,
        resized, restored). ``repaint_board()`` redraws grid and stones and
        resets the banner/analysis flags. Those overlays were just wiped by
        ``Clear()``, so the reset keeps the flags accurate.
        """
        dc = wx.PaintDC(self)
        dc.SetBackground(wx.Brush(wx.WHITE_BRUSH))
        dc.Clear()
        self.repaint_board()

    def ai_next_move(self):
        """Have the MCTS player choose a move, apply it and draw it.

        Called both on a worker thread (after the user picks white, or after a
        user move in an AI game) and synchronously from the hint button.

        NOTE(review): on the worker-thread path this touches wx widgets and
        device contexts directly. wx GUI calls are expected to happen on the
        main thread; consider marshalling the drawing part via wx.CallAfter.
        Confirm ordering with the is_alive() input guards before changing.
        """
        chosen = self.mcts_player.get_action(self.board)
        row, col = self.board.move_to_location(chosen)
        self.board.add_move(row, col)
        if self.is_analysis_displayed:
            # Remove the candidate-move overlay before drawing the new stone.
            self.repaint_board()
        self.analysis_button.Enable()
        # chess_record / draw_move use (column, row) order.
        self.draw_move(col, row)

    def disable_buttons(self):
        """Disable the AI-hint and analysis buttons once the game has ended.

        ``board.has_ended()`` is only consulted after more than 8 moves
        (presumably because no shorter game can be finished; confirm against
        the board's win length).
        """
        if self.moves <= 8:
            return
        finished, _winner = self.board.has_ended()
        if not finished:
            return
        self.ai_hint_button.Disable()
        self.analysis_button.Disable()

    def initialize_user_interface(self):
        """Create the game board, bind all event handlers, then centre and show the frame."""
        self.board = Board(self.n)
        self.Bind(wx.EVT_PAINT, self.on_paint)
        self.Bind(wx.EVT_LEFT_UP, self.on_click)
        # (handler, source button) pairs, bound in this order.
        button_bindings = (
            (self.on_undo_button_click, self.undo_button),
            (self.on_redo_button_click, self.redo_button),
            (self.on_replay_button_click, self.replay_button),
            (self.on_black_button_click, self.black_button),
            (self.on_white_button_click, self.white_button),
            (self.on_ai_hint_button_click, self.ai_hint_button),
            (self.on_analysis_button_click, self.analysis_button),
        )
        for handler, source in button_bindings:
            self.Bind(wx.EVT_BUTTON, handler, source)
        self.Centre()
        self.Show(True)

    def repaint_board(self):
        """Redraw the grid and all stones, discarding any banner or analysis overlay."""
        for paint_step in (self.draw_board, self.draw_chess):
            paint_step()
        self.is_banner_displayed = self.is_analysis_displayed = False

    def draw_board(self):
        """Paint the empty board: white background, grid lines, labels and star points.

        Stones are not drawn here (draw_chess does that). The white fill covers
        the grid plus a one-block margin, so anything previously drawn there
        is painted over.
        """
        dc = wx.ClientDC(self)
        dc.SetPen(wx.Pen(wx.WHITE))
        dc.SetBrush(wx.Brush(wx.WHITE))
        dc.DrawRectangle(self.grid_position_x - self.block_length,
                         self.grid_position_y - self.block_length,
                         self.grid_length + self.block_length * 2,
                         self.grid_length + self.block_length * 2)
        dc.SetPen(wx.Pen(wx.BLACK, width=2))
        dc.DrawLineList(self.line_list)
        # The 4th positional argument of wx.Font is the font *weight*; the old
        # code passed False (== 0), which is not a valid wx.FontWeight.
        dc.SetFont(
            wx.Font(13, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL,
                    wx.FONTWEIGHT_NORMAL))
        dc.DrawTextList(self.row_name_list, self.row_list)
        dc.DrawTextList(self.column_name_list, self.column_list)
        dc.SetBrush(wx.Brush(wx.BLACK))
        if self.n % 2 == 1:
            # Centre point on odd-sized boards.
            centre = self.n >> 1
            dc.DrawCircle(self.grid_position_x + self.block_length * centre,
                          self.grid_position_y + self.block_length * centre, 4)
        if self.n == STANDARD_LENGTH:
            # Four extra star points at grid indices 3 and 11.
            for col, row in ((3, 3), (3, 11), (11, 3), (11, 11)):
                dc.DrawCircle(self.grid_position_x + self.block_length * col,
                              self.grid_position_y + self.block_length * row,
                              4)

    def draw_possible_moves(self, possible_move):
        """Overlay translucent blue markers on candidate moves.

        Args:
            possible_move: iterable of ``(move, p)`` pairs. ``move`` is converted
                with ``board.move_to_location`` (row first). ``p`` scales the
                fill opacity as ``int(p * 230)``, floored at 14 so weak
                candidates remain visible.
        """
        dc = wx.ClientDC(self)
        for move, p in possible_move:
            row, col = self.board.move_to_location(move)
            fill_alpha = max(14, int(p * 230))
            dc.SetBrush(wx.Brush(wx.Colour(28, 164, 252, alpha=fill_alpha)))
            dc.SetPen(wx.Pen(wx.Colour(28, 164, 252, alpha=230)))
            centre_x = self.grid_position_x + col * self.block_length
            centre_y = self.grid_position_y + row * self.block_length
            dc.DrawCircle(centre_x, centre_y, self.piece_radius)

    def draw_chess(self):
        """Draw every stone on the board and mark the most recent move.

        Button states are refreshed first via ``disable_buttons()``.
        ``board.chess`` is indexed with a 4-cell offset on both axes; the
        playable n x n area is ``chess[4:n + 4, 4:n + 4]``. Cells equal to 1
        get a black stone, other positive values a white stone.
        """
        dc = wx.ClientDC(self)
        self.disable_buttons()
        playable_shape = self.board.chess[4:self.n + 4, 4:self.n + 4].shape
        for col, row in np.ndindex(playable_shape):
            cell = self.board.chess[row + 4, col + 4]
            if cell <= 0:
                continue
            dc.SetBrush(wx.Brush(wx.BLACK if cell == 1 else wx.WHITE))
            dc.DrawCircle(self.grid_position_x + col * self.block_length,
                          self.grid_position_y + row * self.block_length,
                          self.piece_radius)
        if self.current_move <= 0:
            return
        # Small contrasting dot on the latest stone; chess_record holds (col, row).
        last_col, last_row = self.chess_record[self.current_move - 1]
        odd_move = self.current_move % 2 == 1
        dc.SetBrush(wx.Brush(wx.BLACK if odd_move else wx.WHITE))
        dc.SetPen(wx.Pen(wx.WHITE if odd_move else wx.BLACK))
        dc.DrawCircle(self.grid_position_x + last_col * self.block_length,
                      self.grid_position_y + last_row * self.block_length,
                      self.inner_circle_radius)

    def draw_move(self, x: int, y: int) -> bool:
        """Record a stone at column ``x`` / row ``y`` and redraw the board.

        Placing a move after an undo discards the undone (redo) tail of the
        history and disables Redo. Once more than 8 moves exist, the board is
        checked for a finished game, and the result banner is drawn if so.

        Returns:
            True if the game has ended after this move, otherwise False.
        """
        if self.current_move == 0:
            self.undo_button.Enable()
            self.replay_button.Enable()
        surplus = self.moves - self.current_move
        while surplus > 0:
            self.chess_record.pop()
            self.redo_button.Disable()
            surplus -= 1
        self.current_move += 1
        self.moves = self.current_move
        self.chess_record.append((x, y))
        self.draw_chess()
        if self.moves <= 8:
            return False
        finished, winner = self.board.has_ended()
        if finished:
            self.disable_buttons()
            self.draw_banner(winner)
        return finished

    def draw_banner(self, result: int):
        """Draw a white result banner with a centred label over the board.

        Args:
            result: 1 shows "BLACK WIN", 2 shows "WHITE WIN", any other value
                shows "DRAW".
        """
        if result == 1:
            string = "BLACK WIN"
        elif result == 2:
            string = "WHITE WIN"
        else:
            string = "DRAW"
        dc = wx.ClientDC(self)
        dc.SetBrush(wx.Brush(wx.WHITE))
        dc.DrawRectangle(
            self.grid_position_x + ((self.grid_length - BANNER_WIDTH) >> 1),
            self.grid_position_y + ((self.grid_length - BANNER_HEIGHT) >> 1),
            BANNER_WIDTH, BANNER_HEIGHT)
        # Text colour comes from the text foreground, not the pen.
        dc.SetTextForeground(wx.BLACK)
        # 4th positional wx.Font argument is the weight; False (== 0) is invalid.
        dc.SetFont(
            wx.Font(40, wx.FONTFAMILY_MODERN, wx.FONTSTYLE_NORMAL,
                    wx.FONTWEIGHT_NORMAL))
        # Measure the rendered label rather than assuming fixed pixel widths
        # (previously 216 / 97) and a height equal to the point size, so it stays
        # centred with whatever font metrics the platform provides.
        text_w, text_h = dc.GetTextExtent(string)
        dc.DrawText(string,
                    self.grid_position_x + ((self.grid_length - text_w) >> 1),
                    self.grid_position_y + ((self.grid_length - text_h) >> 1))
        self.is_banner_displayed = True

    def on_click(self, e):
        """Handle a left-button release on the board.

        Ignored while the AI worker thread is running. If the game is over, a
        click only dismisses the result banner. Otherwise the click is snapped
        to the nearest intersection. If that cell is on the board and empty,
        the stone is placed and drawn. In an AI game that has not ended, the
        AI reply is then started on a worker thread.
        """
        if self.thread.is_alive():
            return
        if self.board.winner != 0:
            if self.is_banner_displayed:
                self.repaint_board()
            return
        if self.is_analysis_displayed:
            self.repaint_board()
        px, py = e.GetPosition()
        half_block = self.block_length >> 1
        # Shift by half a block so each intersection owns the square centred on
        # it, then floor-divide. Floor division (unlike int() truncation) sends
        # points left of / above the grid to negative indices, so the range
        # check alone rejects them, and an offset of exactly 0 maps to index 0
        # (the old strict "> 0" guard wrongly dropped that edge pixel).
        x = int((px - self.grid_position_x + half_block) // self.block_length)
        y = int((py - self.grid_position_y + half_block) // self.block_length)
        if not (0 <= x < self.n and 0 <= y < self.n):
            return
        # board.chess is indexed (row, col) with a 4-cell offset.
        if self.board.chess[y + 4, x + 4] != 0:
            return
        if self.mcts_player is not None:
            self.analysis_button.Enable()
            self.black_button.Disable()
            self.white_button.Disable()
        self.board.add_move(y, x)
        has_end = self.draw_move(x, y)
        if self.has_set_ai_player and not has_end:
            # The AI replies on a worker thread; the is_alive() guards in the
            # handlers block further input until it finishes.
            self.thread = threading.Thread(target=self.ai_next_move, args=())
            self.thread.start()