Example #1
    def test_mask_invalid_moves(self):
        # https://lichess.org/editor/8/3k4/8/8/8/4Q3/2K5/8_w_-_-_0_1
        board = chess.Board('8/3k4/8/8/8/4Q3/2K5/8 w - - 0 1')

        # Mask the white King and Queen's moves
        legal_move_dict = build_legal_move_dict(board)

        queen_moves = legal_move_dict[chess.E3]
        queen_indices = {
            square_move_to_index(chess.E3, to_square)
            for to_square in queen_moves
        }
        queen_expected = torch.tensor(
            [1 if i in queen_indices else 0 for i in range(73)])

        king_moves = legal_move_dict[chess.C2]
        king_indices = {
            square_move_to_index(chess.C2, to_square)
            for to_square in king_moves
        }
        king_expected = torch.tensor(
            [1 if i in king_indices else 0 for i in range(73)])

        expected = torch.zeros(8, 8, 73)
        expected[utils.square_to_n_n(chess.E3)] = queen_expected
        expected[utils.square_to_n_n(chess.C2)] = king_expected
        expected /= expected.sum()

        policy = torch.ones(8, 8, 73)
        mask_invalid_moves(policy, board)

        self.assertEqual(policy.sum(), 1)
        self.assertTrue((expected == policy).all())
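
This test depends on utils.square_to_n_n, which is not shown above. A minimal sketch of the assumed mapping, using python-chess's 0-63 square numbering with row as the rank and col as the file; the project's real helper may orient the board differently:

import chess


def square_to_n_n(square):
    # 0..63 square index -> (row, col); e.g. chess.E3 (square 20) -> (2, 4)
    return chess.square_rank(square), chess.square_file(square)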
Example #2
    def test_mask_position(self):
        policy = torch.ones(8, 8, 73)
        # Alekhine's defense https://lichess.org/editor/rnbqkb1r/ppp1pppp/3p4/3nP3/3P4/5N2/PPP2PPP/RNBQKB1R_b_KQkq_-_1_4
        board = chess.Board(
            'rnbqkb1r/ppp1pppp/3p4/3nP3/3P4/5N2/PPP2PPP/RNBQKB1R b KQkq - 1 4')

        # Mask the black knight's moves
        legal_move_dict = build_legal_move_dict(board)
        knight_moves = legal_move_dict[chess.D5]
        legal_indices = {
            square_move_to_index(chess.D5, to_square)
            for to_square in knight_moves
        }
        expected = torch.tensor(
            [1 if i in legal_indices else 0 for i in range(73)])

        row, col = utils.square_to_n_n(chess.D5)
        mask_position(row, col, policy, legal_move_dict)

        self.assertTrue((policy[row, col] == expected).all())
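
Both masking tests also call build_legal_move_dict, which is not included in the excerpt. A hypothetical reconstruction, assuming it simply groups the board's legal moves by their origin square (the real helper may additionally carry promotion information):

from collections import defaultdict


def build_legal_move_dict(board):
    # Map each origin square to the target squares its piece can legally reach.
    legal = defaultdict(list)
    for move in board.legal_moves:
        legal[move.from_square].append(move.to_square)
    return legal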
Example #3
    def encode_state(self, board, prev_state):
        '''
            Returns the input tensor to be processed by AlphaZero.
        '''
        T = self.T
        state = torch.zeros(M * T + L, N, N)

        # Encode history, dropping the oldest state
        state[:M * (T - 1), ...] = prev_state[M:M * T, ...]

        # Encode pieces
        t_offset = M * (T - 1)
        for square, piece in board.piece_map().items():
            piece_ind = self.piece_to_index(piece, board)
            row, col = square_to_n_n(square)
            state[t_offset + piece_ind, row, col] = 1

        # TODO Potentially encode repeat count for both players

        # Encode current player color
        state[M * T, ...] += board.turn

        # Encode total move count
        state[M * T + 1, ...] += len(board.move_stack)

        # Encode player castling
        state[M * T + 2, ...] += board.has_queenside_castling_rights(
            board.turn)
        state[M * T + 3, ...] += board.has_kingside_castling_rights(board.turn)

        # Encode opponent castling
        state[M * T + 4, ...] += board.has_queenside_castling_rights(
            not board.turn)
        state[M * T + 5, ...] += board.has_kingside_castling_rights(
            not board.turn)

        # TODO Potentially encode progress count

        return state
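
encode_state references the constants M, T, L, N and a piece_to_index helper that are not part of the excerpt. A hedged sketch of plausible definitions, assuming the AlphaZero paper's layout (an 8x8 board, T=8 history steps, M=14 planes per step with two reserved for repetition counts, and L=7 constant planes); the project's actual values and plane ordering may differ:

N, T, M, L = 8, 8, 14, 7  # assumed AlphaZero-style constants


def piece_to_index(piece, board):
    # Planes 0-5 hold the side to move's pawn..king, planes 6-11 the
    # opponent's; the remaining per-step planes would hold repetition counts.
    offset = 0 if piece.color == board.turn else 6
    return offset + piece.piece_type - 1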
Example #4
    def get_mcts_policy(self, mcts_dist):
        '''
            Takes the MCTSDist and converts it to a gold
            distribution in the shape of the network's policy.
        '''
        policy = torch.zeros(8, 8, 73)

        for move in mcts_dist.move_data:
            row, col = square_to_n_n(move.from_square)
            index = square_move_to_index(move.from_square, move.to_square,
                                         move.promotion)
            # MCTS is very expensive on a single-core machine, so n_visits are
            # sparse and hard probabilities are frequently zero; the softmax
            # alternative commented out below gives a softer distribution.
            # score = move.n_visits
            score = move.n_visits**(1 / mcts_dist.temp)  # Original AlphaZero score

            policy[row, col, index] = score

        # policy = F.softmax(policy.flatten() / mcts_dist.temp, dim=0) # Softer distribution than original AlphaZero
        # policy = policy.reshape(8,8,73)
        policy /= policy.sum()  # Original AlphaZero normalization

        return policy
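
The 8x8x73 tensor returned here is the training target for the policy head. A hedged usage sketch of how such a target is typically consumed, assuming the network emits raw logits of the same shape (policy_logits and policy_loss are illustrative names, not from the project):

import torch


def policy_loss(policy_logits, mcts_policy):
    # Cross-entropy between the network's policy and the MCTS visit distribution.
    log_probs = torch.log_softmax(policy_logits.flatten(), dim=0)
    return -(mcts_policy.flatten() * log_probs).sum()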
Example #5
    def prior_func(move):
        row, col = square_to_n_n(move.from_square)
        index = square_move_to_index(move.from_square, move.to_square,
                                     move.promotion)
        return net_policy[row, col, index]