def describe(self):
    """Return a multi-line text summary of this node's search statistics.

    The summary contains, in order: this node's Q, the most visited path
    (from `most_visited_path`), a column header, and one row for each of
    the top 15 moves. Moves are ranked by visit count, with ties broken by
    action score, highest first.

    Returns:
      A single string made of the pieces described above.
    """
    num_moves = self.board_size * self.board_size + 1
    sort_order = sorted(
        range(num_moves),
        key=lambda i: (self.child_N[i], self.child_action_score[i]),
        reverse=True)

    # A node with no visits has a total of zero; clamp the denominator so
    # the visit distribution is all zeros rather than NaN from 0/0.
    total_visits = self.child_N.sum()
    soft_n = self.child_N / max(1, total_visits)
    p_delta = soft_n - self.child_prior
    # NOTE(review): a prior of exactly zero would give inf/nan here --
    # confirm priors are always strictly positive.
    p_rel = p_delta / self.child_prior

    # Only the first 15 rows are shown, so only those rows are formatted.
    rows = [
        '{!s:6}: {: .3f}, {: .3f}, {:.3f}, {:.3f}, {:.3f}, {:4d} {:.4f} {: .5f} {: .2f}'.format(
            coords.to_kgs(self.board_size,
                          coords.from_flat(self.board_size, key)),
            self.child_action_score[key],
            self.child_Q[key],
            self.child_U[key],
            self.child_prior[key],
            self.original_prior[key],
            int(self.child_N[key]),
            soft_n[key],
            p_delta[key],
            p_rel[key])
        for key in sort_order[:15]
    ]

    # Dump out some statistics
    output = [
        '{q:.4f}\n'.format(q=self.Q),
        self.most_visited_path(),
        'move: action Q U P P-Dir N soft-N p-delta p-rel\n',
        '\n'.join(rows),
    ]
    return ''.join(output)
def maybe_add_child(self, fcoord):
    """Return the child node for `fcoord`, creating it first if absent.

    Args:
      fcoord: flat move index used as the key into `self.children`.

    Returns:
      The node stored in `self.children[fcoord]`.
    """
    if fcoord in self.children:
        return self.children[fcoord]

    # No child yet: play the move on this node's position and wrap the
    # resulting position in a new node parented to this one.
    move = coords.from_flat(self.board_size, fcoord)
    next_position = self.position.play_move(move)
    self.children[fcoord] = MCTSNode(
        self.board_size, next_position, fmove=fcoord, parent=self)
    return self.children[fcoord]
def test_legal_moves(self):
    """Check single-move and bulk legality for a board and its colour flip.

    The same expectations are asserted twice: first with the loaded board
    and BLACK to play, then with the negated board and WHITE to play.
    """
    board = utils_test.load_board(''' .O.O.XOX. O..OOOOOX ......O.O OO.....OX XO.....X. .O....... OX.....OO XX...OOOX .....O.X. ''')
    size = utils_test.BOARD_SIZE

    for flipped, to_play in ((False, BLACK), (True, WHITE)):
        # Second pass: flip the colors and check that everything is still
        # (il)legal.
        stones = -board if flipped else board
        position = Position(size, board=stones, to_play=to_play)

        if not flipped:
            # The expected move sets are built once and shared by both passes.
            illegal_moves = coords_from_kgs_set('A9 E9 J9')
            legal_moves = coords_from_kgs_set('A4 G1 J1 H7') | {None}

        for move in illegal_moves:
            with self.subTest(type='illegal', move=move):
                self.assertFalse(position.is_move_legal(move))
        for move in legal_moves:
            with self.subTest(type='legal', move=move):
                self.assertTrue(position.is_move_legal(move))

        # check that the bulk legal test agrees with move-by-move illegal
        # test.
        for i, bulk_legal in enumerate(position.all_legal_moves()):
            with self.subTest(type='bulk', move=coords.from_flat(size, i)):
                self.assertEqual(
                    bulk_legal,
                    position.is_move_legal(coords.from_flat(size, i)))
def mvp_gg(self):
    """Return the most visited path as space-separated KGS coordinates.

    Starting from this node, repeatedly step into the child with the
    highest visit count. The walk stops at a node that has no children or
    whose largest child visit count is not greater than 1.

    Returns:
      The coordinates of the visited moves joined by single spaces; an
      empty string if no step was taken.
    """
    moves = []
    current = self
    while current.children:
        if not max(current.child_N) > 1:
            break
        best = np.argmax(current.child_N)
        current = current.children[best]
        point = coords.from_flat(self.board_size, current.fmove)
        moves.append('{}'.format(coords.to_kgs(self.board_size, point)))
    return ' '.join(moves)
def most_visited_path(self):
    """Return a one-line description of the most visited path from here.

    Starting from this node, repeatedly step into the child with the
    highest visit count, emitting '<move> (<N>) ==> ' for each step. If
    the most visited move has no child node, 'GAME END' is emitted and the
    walk stops. The line ends with the Q of the last node reached.

    Returns:
      The path description as a single string terminated by a newline.
    """
    node = self
    output = []
    while node.children:
        next_kid = np.argmax(node.child_N)
        child = node.children.get(next_kid)
        if child is None:
            # Leave `node` on the last existing node so the Q lookup
            # below does not dereference None.
            output.append('GAME END')
            break
        node = child
        output.append('{} ({}) ==> '.format(
            coords.to_kgs(
                self.board_size,
                coords.from_flat(self.board_size, node.fmove)),
            node.N))
    output.append('Q: {:.5f}\n'.format(node.Q))
    return ''.join(output)
def pick_move(self):
    """Choose the move to play from the root's visit counts.

    Once the root position's move number exceeds `temp_threshold`, the
    most visited move is taken. Before that, a move is sampled with
    probability proportional to its visit count.

    Returns:
      The chosen move, converted from its flat index by
      `coords.from_flat`.
    """
    root = self.root
    if root.position.n > self.temp_threshold:
        # Late game: highest N is the most robust indicator.
        return coords.from_flat(self.board_size, np.argmax(root.child_N))

    # Early game: invert the normalised cumulative visit distribution at a
    # uniform random draw.
    cumulative = root.child_N.cumsum()
    cumulative /= cumulative[-1]
    draw = random.random()
    chosen = cumulative.searchsorted(draw)
    assert root.child_N[chosen] != 0
    return coords.from_flat(self.board_size, chosen)
def play(board_size, network, readouts, resign_threshold,
         simultaneous_leaves, verbosity=0):
    """Plays out a self-play match.

    Args:
      board_size: the go board size
      network: the DualNet model
      readouts: the number of readouts in MCTS
      resign_threshold: the threshold to resign at in the match
      simultaneous_leaves: the number of simultaneous leaves in MCTS
      verbosity: the verbosity of the self-play match

    Returns:
      The MCTSPlayer that played the game, after its result has been set
      (either by resignation or by the game reaching its end).
    """
    player = MCTSPlayer(board_size, network,
                        resign_threshold=resign_threshold,
                        verbosity=verbosity,
                        num_parallel=simultaneous_leaves)
    # Disable resign in 5% of games
    if random.random() < 0.05:
        player.resign_threshold = -1.0

    player.initialize_game()

    # Must run this once at the start, so that noise injection actually
    # affects the first move of the game.
    first_node = player.root.select_leaf()
    prob, val = network.run(first_node.position)
    first_node.incorporate_results(prob, val, first_node)

    while True:
        start = time.time()
        player.root.inject_noise()
        current_readouts = player.root.N
        # we want to do "X additional readouts", rather than "up to X
        # readouts".
        while player.root.N < current_readouts + readouts:
            player.tree_search()

        if verbosity >= 3:
            print(player.root.position)
            print(player.root.describe())

        if player.should_resign():
            player.set_result(-1 * player.root.position.to_play,
                              was_resign=True)
            break
        move = player.pick_move()
        player.play_move(move)
        if player.root.is_done():
            player.set_result(player.root.position.result(),
                              was_resign=False)
            break

        if (verbosity >= 2) or (
                verbosity >= 1 and player.root.position.n % 10 == 9):
            print("Q: {:.5f}".format(player.root.Q))
            dur = time.time() - start
            print("%d: %d readouts, %.3f s/100. (%.2f sec)" % (
                player.root.position.n, readouts,
                dur / readouts * 100.0, dur))
        if verbosity >= 3:
            # The coords helpers take the board size first, as at every
            # other call site that converts a flat move for display.
            print("Played >>",
                  coords.to_kgs(
                      board_size,
                      coords.from_flat(board_size, player.root.fmove)))

    if verbosity >= 2:
        print("%s: %.3f" % (player.result_string, player.root.Q),
              file=sys.stderr)
        print(player.root.position, player.root.position.score(),
              file=sys.stderr)

    return player
def heatmap(self, sort_order, node, prop):
    """Return up to 20 '<move> <value>' lines for the visited moves.

    Args:
      sort_order: iterable of flat move indices, in display order.
      node: object exposing `child_N` and, in its `__dict__`, the
        indexable attribute named by `prop`.
      prop: name of the per-move attribute on `node` to display.

    Returns:
      One line per move whose visit count is positive, limited to the
      first 20 such moves and joined with newlines.
    """
    # NOTE(review): coords.to_kgs / coords.from_flat are called without a
    # board size here, unlike other call sites -- confirm the signature.
    rows = []
    for key in sort_order:
        if node.child_N[key] > 0:
            rows.append('{!s:6} {}'.format(
                coords.to_kgs(coords.from_flat(key)),
                node.__dict__.get(prop)[key]))
    return '\n'.join(rows[:20])