class Arena():
    """
    An Arena class where any 2 agents can be pit against each other.
    """

    def __init__(self, player1, player2, game, display=None):
        """
        Input:
            player 1,2: two player objects. The arena calls reset(),
                        play(canonicalBoard) -> action and
                        makeOpponentMove(moves) on them.
            game: Game object
            display: a function that takes board as input and prints it
                     (e.g. display in othello/OthelloGame). Is necessary
                     for verbose mode.

        see othello/OthelloPlayers.py for an example. See pit.py for pitting
        human players/other baselines with each other.

        Side effect: creates the folder Debug/Arena/<timestamp>_<pid> which is
        later handed to game.printGameRecord().
        """
        self.player1 = player1
        self.player2 = player2
        self.game = game
        self.display = display
        self.debug = Debug()
        dt = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.folder = os.path.join("Debug", "Arena",
                                   dt + "_" + str(os.getpid()))
        # The name only has one-second resolution, so a second Arena created
        # by the same process within that second maps to the same folder;
        # exist_ok avoids a FileExistsError in that case.
        os.makedirs(self.folder, exist_ok=True)
        self.print_boards = False

    def playGame(self, verbose=False):
        """
        Executes one episode of a game.

        Input:
            verbose: if True, print the final result and show the final board
                     through self.display (which must then be set).

        Returns:
            curPlayer * game.getGameEnded(board, curPlayer) for the final
            position, i.e.
            either
                winner: player who won the game (1 if player1, -1 if player2)
            or
                draw result returned from the game that is neither 1, -1,
                nor 0; currently this is 0.0001

        Raises:
            AssertionError: if a player returns an action that is not marked
                            valid by game.getValidMoves.
        """
        # reset player's state
        self.player1.reset()
        self.player2.reset()

        # indexed with curPlayer + 1: -1 -> player2, 1 -> player1
        players = [self.player2, None, self.player1]
        curPlayer = 1
        board = self.game.getInitBoard()
        it = 0
        # moves executed since the player to move was last informed
        lastMove = []
        while self.game.getGameEnded(board, curPlayer) == 0:
            it += 1
            mover = players[curPlayer + 1]
            canonicalBoard = self.game.getCanonicalForm(board, curPlayer)

            # Hand the accumulated moves to the player about to move, unless
            # the board reports that the last move was part of a long capture.
            if lastMove and not board.last_long_capture:
                mover.makeOpponentMove(lastMove)
                del lastMove[:]

            action = mover.play(canonicalBoard)

            valids = self.game.getValidMoves(canonicalBoard, 1)
            if valids[action] == 0:
                print(action)
                canonicalBoard.display()
                # explicit raise instead of assert so the check survives -O
                raise AssertionError("Illegal move: " +
                                     str(Move.parse_action(action)))

            previousBoard = board
            action = canonicalBoard.transform_action_for_board(action, board)
            board, curPlayer = self.game.getNextState(board, curPlayer,
                                                      action)
            lastMove.append(board.executed_moves[-1])

            # debug next state
            board.id = previousBoard.id + "-?"
            if board.halfMoves > 100:
                print("info: board.id:", previousBoard.id, "=>", board.id,
                      ", pieces:", board.count_pieces(), ", halfMoves:",
                      board.halfMoves, ", no-progress:",
                      board.noProgressCount)
            if self.print_boards:
                s = self.game.stringRepresentation(previousBoard)
                s1 = self.game.stringRepresentation(board)
                self.debug.print_to_file(
                    board, s1, "Newly generated position, previous one: " + s)

        result = curPlayer * self.game.getGameEnded(board, curPlayer)
        if verbose:
            assert (self.display)
            print("Game over: Turn ", str(it), "Result ", str(result))
            self.display(board)

        # print game record to a file
        self.game.printGameRecord(board, curPlayer, self.folder)

        return result

    def playGames(self, num, verbose=False):
        """
        Plays num games in which player1 starts num/2 games and player2 starts
        num/2 games (int(num / 2) each, so an odd num plays one game less).

        Note: self.player1 and self.player2 are swapped between the two halves
        and are left swapped when this method returns.

        Returns:
            oneWon: games won by player1
            twoWon: games won by player2
            draws:  games won by nobody
        """
        eps_time = AverageMeter()
        bar = Bar('Arena.playGames', max=num)
        maxeps = int(num)
        half = int(num / 2)

        # First half: original seating, a result of 1 is a win for player1.
        firstWins, secondWins, draws = self._playSeries(
            half, 0, maxeps, eps_time, bar, verbose)
        oneWon = firstWins
        twoWon = secondWins

        self.player1, self.player2 = self.player2, self.player1

        # Second half: seats are swapped, so a result of -1 is now a win for
        # the player that was originally player1.
        firstWins, secondWins, moreDraws = self._playSeries(
            half, half, maxeps, eps_time, bar, verbose)
        oneWon += secondWins
        twoWon += firstWins
        draws += moreDraws

        bar.finish()

        return oneWon, twoWon, draws

    def _playSeries(self, count, played, maxeps, eps_time, bar, verbose):
        """
        Plays count games with the current seating and updates the progress
        bar after each one.

        Input:
            count: number of games to play
            played: number of episodes already played (for episode numbering)
            maxeps: total number of episodes shown in the progress suffix
            eps_time: meter collecting the duration of every episode
            bar: progress bar to advance
            verbose: forwarded to playGame

        Returns:
            (games with result 1, games with result -1, all other games)
        """
        firstWins = 0
        secondWins = 0
        draws = 0
        end = time.time()
        for eps in range(played + 1, played + count + 1):
            print("")
            print("Episode ", eps)
            gameResult = self.playGame(verbose=verbose)
            if gameResult == 1:
                firstWins += 1
            elif gameResult == -1:
                secondWins += 1
            else:
                draws += 1

            # bookkeeping + plot progress
            eps_time.update(time.time() - end)
            end = time.time()
            bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                eps=eps,
                maxeps=maxeps,
                et=eps_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            bar.next()
        return firstWins, secondWins, draws
class MCTS():
    """
    This class handles the MCTS tree.
    """

    # (debug) print every board to a file
    print_boards = False

    def __init__(self, game, nnet, args):
        """
        Input:
            game: Game object
            nnet: network wrapper; predict(canonicalBoard) must return
                  (policy, value)
            args: settings object; this class reads numMCTSSims, cpuct,
                  epsilon and dirAlpha from it

        Side effect: RuntimeWarning is turned into an error process-wide via
        warnings.filterwarnings.
        """
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)

        self.Es = {}  # stores game.getGameEnded ended for board s
        # stores game.getValidMoves for board s - full list of size
        # game.getActionSize
        self.Vs = {}
        self.Ls = {}  # stores LegalMoves for board s - short list
        # stores number of nodes outgoing from node (board) s
        self.Nbnodes = {}

        self.numActionProbs = 0
        self.searchIndex = 0
        self.predictionMeter = AverageMeter()
        self.debug = Debug()
        self.maxHalfMovesForDebug = 150
        self.id = np.random.randint(1000)
        warnings.filterwarnings('error', category=RuntimeWarning)

    def getActionProb(self, canonicalBoard, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.

        Input:
            canonicalBoard: root position; gets an id assigned if it has none
            temp: temperature; 0 selects one of the most visited actions
                  (ties broken at random) with probability 1

        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)

        Raises:
            ZeroDivisionError: for temp != 0 when no action of the root has
                               been visited (the board is displayed first).
        """
        self.numActionProbs += 1
        if not canonicalBoard.id:
            canonicalBoard.id = "ActionProb_" + str(
                self.numActionProbs) + "_"

        for _ in range(self.args.numMCTSSims):
            self.search(canonicalBoard, True)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [
            self.Nsa.get((s, a), 0)
            for a in range(self.game.getActionSize())
        ]

        if temp == 0:
            maxi = max(counts)
            allBest = np.where(np.array(counts) == maxi)[0]
            bestA = np.random.choice(allBest)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x**(1. / temp) for x in counts]
        # Sum once instead of once per element.
        total = sum(counts)
        if total == 0:
            # debugging aid: show the position before the division fails
            canonicalBoard.display()
        total = float(total)
        probs = [x / total for x in counts]
        return probs

    def search(self, canonicalBoard, isRootNode):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa
        are updated.

        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other
        player.

        Input:
            canonicalBoard: position to search from
            isRootNode: if True and args.epsilon > 0, Dirichlet noise is mixed
                        into the priors of this node

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        self.searchIndex += 1

        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
            if self.Es[s] != 0 and self.print_boards:
                self.debug.print_to_file(
                    canonicalBoard, s, "Game over, result:" + str(self.Es[s]))
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            start = millis()
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            elapsed = millis() - start
            self.predictionMeter.update(elapsed)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # if all valid moves were masked
                # make all valid moves equally probable
                print("All valid moves were masked, do workaround. board.id=",
                      canonicalBoard.id)
                predicted = "?"
                self.debug.print_to_file(
                    canonicalBoard, "workarounds",
                    "All valid moves were masked, do workaround\npredicted_actions:"
                    + str(predicted))
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ls[s] = canonicalBoard.legal_moves
            self.Ns[s] = 0
            return -v

        canonicalBoard.legal_moves = self.Ls[s]
        valids = self.Vs[s]
        cur_best = -float('inf')
        allBest = []

        # add Dirichlet noise for root node. set epsilon=0 for Arena
        # competitions of trained models
        e = self.args.epsilon
        if isRootNode and e > 0:
            # NOTE(review): noise is sized by len(filter_legal_moves()) but
            # indexed below by the running count of valid actions; this
            # assumes both numbers agree - confirm.
            noise = np.random.dirichlet(
                [self.args.dirAlpha] *
                len(canonicalBoard.filter_legal_moves()))

        # pick the action with the highest upper confidence bound
        i = -1  # index of the current action among the valid ones
        for a in range(self.game.getActionSize()):
            if valids[a]:
                i += 1
                if (s, a) in self.Qsa:
                    q = self.Qsa[(s, a)]
                    n_s_a = self.Nsa[(s, a)]
                else:
                    # unvisited edge: Q = 0
                    q = 0
                    n_s_a = 0

                p = self.Ps[s][a]
                if isRootNode and e > 0:
                    p = (1 - e) * p + e * noise[i]

                u = q + self.args.cpuct * p * math.sqrt(
                    self.Ns[s]) / (1 + n_s_a)

                if u > cur_best:
                    cur_best = u
                    del allBest[:]
                    allBest.append(a)
                elif u == cur_best:
                    allBest.append(a)

        # break ties between equally good actions at random
        a = np.random.choice(allBest)
        try:
            assert a >= 0, "Illegal action=" + str(a)
            next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
            next_s = self.game.getCanonicalForm(next_s, next_player)

            next_s.id = self.next_board_id(canonicalBoard, s)
            if next_s.halfMoves > self.maxHalfMovesForDebug:
                self.maxHalfMovesForDebug = next_s.halfMoves
                print("info: board.id:", canonicalBoard.id, "=>", next_s.id,
                      ", pieces:", next_s.count_pieces(), ", halfMoves:",
                      next_s.halfMoves, ", no-progress:",
                      next_s.noProgressCount)
        except Exception:
            # Dump as much context as possible, then let the error propagate.
            # (Exception rather than a bare except, so Ctrl-C / SystemExit do
            # not trigger the dump.)
            print("Error detected")
            print(s)
            canonicalBoard.display()
            self.debug.print_legal_moves(canonicalBoard)
            # a < 0 occurs sometimes in MCTS.search()
            move = None if a < 0 else canonicalBoard.parse_action(a)
            print("execute_move:", move)
            print("executed_moves: ", canonicalBoard.executed_moves)
            # print to file
            dump_path = self.debug.print_to_file(canonicalBoard, s,
                                                 "Error detected")
            with open(dump_path, 'a+') as f:
                print("execute_move:", move, file=f)
            raise

        v = self.search(next_s, False)
        # trick for long_captures: when next_player is 1 the sign of the
        # value returned by the child is flipped
        v *= -next_player

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] +
                                v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def next_board_id(self, previousBoard, s_of_previousBoard):
        """
        Builds the id for a board generated from previousBoard.

        Input:
            previousBoard: board the new position was generated from
            s_of_previousBoard: string representation of previousBoard, used
                                as key of the per-board counter Nbnodes

        Returns:
            previousBoard.id + "-" + counter, where the counter starts at 0
            and grows by one with every call for the same key
        """
        if s_of_previousBoard not in self.Nbnodes:
            self.Nbnodes[s_of_previousBoard] = 0
        else:
            self.Nbnodes[s_of_previousBoard] += 1
        next_id = previousBoard.id + "-" + str(
            self.Nbnodes[s_of_previousBoard])
        return next_id

    def print_stats(self):
        """
        Prints the collected stats of the instance: table sizes, number of
        getActionProb/search calls and the prediction time meter.
        """
        # empty line for the case if the carriage was not returned
        print("")
        print("MCTS: Ps.size = ", len(self.Ps), ", Es.size = ", len(self.Es),
              ", Qsa.size = ", len(self.Qsa))
        print("MCTS: actionProbs = ", self.numActionProbs,
              ", searchCount = ", self.searchIndex)
        pm = self.predictionMeter
        print("MCTS: nnet.predictions: count=", pm.count,
              ", times(min/avg/max/total) = ", pm.min, pm.avg, pm.max, pm.sum)