def test_mask_invalid_moves(self):
    # https://lichess.org/editor/8/3k4/8/8/8/4Q3/2K5/8_w_-_-_0_1
    board = chess.Board('8/3k4/8/8/8/4Q3/2K5/8 w - - 0 1')
    # Mask the white king's and queen's moves
    legal_move_dict = build_legal_move_dict(board)
    queen_moves = legal_move_dict[chess.E3]
    queen_indices = {
        square_move_to_index(chess.E3, to_square) for to_square in queen_moves
    }
    queen_expected = torch.tensor(
        [1 if i in queen_indices else 0 for i in range(73)])
    king_moves = legal_move_dict[chess.C2]
    king_indices = {
        square_move_to_index(chess.C2, to_square) for to_square in king_moves
    }
    king_expected = torch.tensor(
        [1 if i in king_indices else 0 for i in range(73)])
    expected = torch.zeros(8, 8, 73)
    expected[utils.square_to_n_n(chess.E3)] = queen_expected
    expected[utils.square_to_n_n(chess.C2)] = king_expected
    expected /= expected.sum()
    policy = torch.ones(8, 8, 73)
    mask_invalid_moves(policy, board)
    self.assertEqual(policy.sum(), 1)
    self.assertTrue((expected == policy).all())
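# Hedged sketch (not from the source): the behavior test_mask_invalid_moves
# above exercises. mask_invalid_moves is expected to zero every index that does
# not correspond to a legal move and renormalize the whole 8x8x73 tensor so it
# sums to 1. A reference implementation could look like the following, reusing
# square_move_to_index and utils.square_to_n_n from this module.
import torch

def mask_invalid_moves_reference(policy, board):
    legal = torch.zeros_like(policy)
    for move in board.legal_moves:
        row, col = utils.square_to_n_n(move.from_square)
        index = square_move_to_index(move.from_square, move.to_square,
                                     move.promotion)
        legal[row, col, index] = 1
    # Operate in place so the caller's tensor is updated, matching the test's usage
    policy *= legal
    policy /= policy.sum()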
def test_mask_position(self):
    policy = torch.ones(8, 8, 73)
    # Alekhine's Defense https://lichess.org/editor/rnbqkb1r/ppp1pppp/3p4/3nP3/3P4/5N2/PPP2PPP/RNBQKB1R_b_KQkq_-_1_4
    board = chess.Board(
        'rnbqkb1r/ppp1pppp/3p4/3nP3/3P4/5N2/PPP2PPP/RNBQKB1R b KQkq - 1 4')
    # Mask the black knight's moves
    legal_move_dict = build_legal_move_dict(board)
    knight_moves = legal_move_dict[chess.D5]
    legal_indices = {
        square_move_to_index(chess.D5, to_square) for to_square in knight_moves
    }
    expected = torch.tensor(
        [1 if i in legal_indices else 0 for i in range(73)])
    row, col = utils.square_to_n_n(chess.D5)
    mask_position(row, col, policy, legal_move_dict)
    self.assertTrue((policy[row, col] == expected).all())
def encode_state(self, board, prev_state):
    '''
    Returns the input tensor to be processed by AlphaZero.
    '''
    T = self.T
    state = torch.zeros(M * T + L, N, N)
    # Encode history, dropping the oldest state
    state[:M * (T - 1), ...] = prev_state[M:M * T, ...]
    # Encode pieces
    t_offset = M * (T - 1)
    for square, piece in board.piece_map().items():
        piece_ind = self.piece_to_index(piece, board)
        row, col = square_to_n_n(square)
        state[t_offset + piece_ind, row, col] = 1
    # TODO Potentially encode repeat count for both players
    # Encode current player color
    state[M * T, ...] += board.turn
    # Encode total move count
    state[M * T + 1, ...] += len(board.move_stack)
    # Encode player castling rights
    state[M * T + 2, ...] += board.has_queenside_castling_rights(board.turn)
    state[M * T + 3, ...] += board.has_kingside_castling_rights(board.turn)
    # Encode opponent castling rights
    state[M * T + 4, ...] += board.has_queenside_castling_rights(not board.turn)
    state[M * T + 5, ...] += board.has_kingside_castling_rights(not board.turn)
    # TODO Potentially encode progress count
    return state
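# Hedged usage sketch (not from the source): one way encode_state might be
# driven during self-play, rolling the history planes forward each ply.
# `encoder` is a hypothetical object exposing encode_state and a T attribute;
# M, L, N are the plane-count constants used in encode_state above.
import chess
import torch

def play_and_encode(encoder, uci_moves, M, L, N):
    board = chess.Board()
    # No history exists before the first move, so start from an all-zero state
    prev_state = torch.zeros(M * encoder.T + L, N, N)
    states = []
    for uci in uci_moves:
        state = encoder.encode_state(board, prev_state)
        states.append(state)
        board.push_uci(uci)
        prev_state = state  # this ply's encoding becomes the next ply's history
    return states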
def get_mcts_policy(self, mcts_dist):
    '''
    Takes the MCTSDist and converts it to a gold distribution in the shape
    of the network's policy.
    '''
    policy = torch.zeros(8, 8, 73)
    for move in mcts_dist.move_data:
        row, col = square_to_n_n(move.from_square)
        index = square_move_to_index(move.from_square, move.to_square,
                                     move.promotion)
        # Since MCTS is very expensive on a single-core machine, a softmax
        # (commented out below) could be used instead of hard probabilities,
        # which are frequently zero because n_visits is sparse
        # score = move.n_visits
        score = move.n_visits**(1 / mcts_dist.temp)  # Original AlphaZero score
        policy[row, col, index] = score
    # policy = F.softmax(policy.flatten() / mcts_dist.temp, dim=0)  # Softer distribution than original AlphaZero
    # policy = policy.reshape(8, 8, 73)
    policy /= policy.sum()  # Original AlphaZero normalization
    return policy
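# Hedged sketch (not from the source): how the target returned by
# get_mcts_policy might pair with the network's policy head in a training step.
# `net_log_policy` is assumed to be the network's log-softmaxed 8x8x73 output
# for the same position.
import torch

def policy_loss(net_log_policy, target_policy):
    # Cross-entropy between the MCTS visit distribution and the network policy,
    # summed over the flattened 8x8x73 move space
    return -(target_policy * net_log_policy).sum()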
def prior_func(move):
    # Look up the network's prior probability for this move in the 8x8x73 policy
    row, col = square_to_n_n(move.from_square)
    index = square_move_to_index(move.from_square, move.to_square,
                                 move.promotion)
    return net_policy[row, col, index]
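# Hedged usage sketch (not from the source): prior_func appears to be a closure
# over a fixed `net_policy` tensor, called once per legal move to seed MCTS
# child priors. `legal_move_priors` is a hypothetical helper for inspecting them.
def legal_move_priors(board, prior_func):
    return {move.uci(): float(prior_func(move)) for move in board.legal_moves}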