def play(network):
    """Play out one self-play match and return the finished MCTSPlayer.

    The returned player carries: the final position, the n x 362 tensor of
    floats representing the MCTS search probabilities, and the n-ary tensor
    of floats representing the original value-net estimate, where n is the
    number of moves in the game.
    """
    readouts = FLAGS.num_readouts  # defined in strategies.py
    # Disable resign in a small fraction of games (keeps calibration data
    # flowing for the resign threshold).
    resign_threshold = -1.0 if random.random() < FLAGS.resign_disable_pct else None
    player = MCTSPlayer(network, resign_threshold=resign_threshold)
    player.initialize_game()

    # Must run this once at the start to expand the root node.
    first_node = player.root.select_leaf()
    prob, val = network.run(first_node.position)
    first_node.incorporate_results(prob, val, first_node)

    while True:
        start = time.time()
        player.root.inject_noise()
        # We want "X additional readouts", rather than "up to X readouts".
        target_n = player.root.N + readouts
        while player.root.N < target_n:
            player.tree_search()

        if FLAGS.verbose >= 3:
            print(player.root.position)
            print(player.root.describe())

        if player.should_resign():
            player.set_result(-1 * player.root.position.to_play,
                              was_resign=True)
            break
        move = player.pick_move()
        player.play_move(move)
        if player.root.is_done():
            player.set_result(player.root.position.result(),
                              was_resign=False)
            break

        if (FLAGS.verbose >= 2) or (
                FLAGS.verbose >= 1 and player.root.position.n % 10 == 9):
            print("Q: {:.5f}".format(player.root.Q))
            dur = time.time() - start
            print("%d: %d readouts, %.3f s/100. (%.2f sec)" % (
                player.root.position.n, readouts,
                dur / readouts * 100.0, dur), flush=True)
        if FLAGS.verbose >= 3:
            print("Played >>",
                  coords.to_gtp(coords.from_flat(player.root.fmove)))

    if FLAGS.verbose >= 2:
        utils.dbg("%s: %.3f" % (player.result_string, player.root.Q))
        utils.dbg(player.root.position, player.root.position.score())

    return player
def test_tree_search_failsafe(self):
    """The failsafe must fire when MCTS repeatedly visits a finished state."""
    # A net that always wants to pass piles search onto the terminal
    # (double-pass) position, which is what triggers the failsafe.
    probs = np.array([.001] * (go.N * go.N + 1))
    probs[-1] = 1  # Make the dummy net always want to pass
    player = MCTSPlayer(DummyNet(fake_priors=probs))
    player.initialize_game(go.Position().pass_move())
    player.tree_search(parallel_readouts=8)
    self.assertNoPendingVirtualLosses(player.root)
def test_cold_start_parallel_tree_search(self):
    """Parallel tree search must not trip on a completely empty tree."""
    player = MCTSPlayer(DummyNet(fake_value=0.17))
    player.initialize_game()
    self.assertEqual(0, player.root.N)
    self.assertFalse(player.root.is_expanded)

    player.tree_search(parallel_readouts=4)
    self.assertNoPendingVirtualLosses(player.root)

    # Even though the root gets selected 4 times by tree search, its
    # final visit count should just be 1.
    self.assertEqual(1, player.root.N)
    # 0.085 = average(0, 0.17), since 0 is the prior on the root.
    self.assertAlmostEqual(0.085, player.root.Q)
def play(network, seconds_per_move=5, timed_match=False, search_n=100):
    """Play out one self-play game and return the MCTSPlayer.

    Args:
        network: policy/value network driving the tree search.
        seconds_per_move: per-move time budget used when timed_match is True.
            (default of 5 is a placeholder — TODO confirm against callers)
        timed_match: whether moves are bounded by time instead of readouts.
        search_n: number of *additional* readouts to run per move.

    Returns:
        The MCTSPlayer holding the finished game.
    """
    # BUG FIX: the original body referenced `seconds_per_move` and
    # `timed_match` without ever defining them (NameError when called);
    # they are now explicit keyword parameters with defaults.
    player = MCTSPlayer(network=network,
                        seconds_per_move=seconds_per_move,
                        timed_match=timed_match,
                        search_n=search_n,
                        player_mode=0)
    player.initialize_game()
    while True:
        current_n = player.root.N
        # We want "search_n additional readouts", not "up to search_n".
        while player.root.N < current_n + search_n:
            player.tree_search()
        move = player.pick_move()
        player.play_move(move)
        if player.root.is_done():
            break
    return player
def test_long_game_tree_search(self):
    """MCTS should deduce B wins via TT-scoring triggered by move limit."""
    player = MCTSPlayer(DummyNet())
    near_limit_pos = go.Position(
        board=TT_FTW_BOARD,
        n=flags.FLAGS.max_game_length - 2,
        komi=2.5,
        ko=None,
        recent=(go.PlayerMove(go.BLACK, (0, 1)),
                go.PlayerMove(go.WHITE, (0, 8))),
        to_play=go.BLACK)
    player.initialize_game(near_limit_pos)

    for _ in range(10):
        player.tree_search(parallel_readouts=8)
        self.assertNoPendingVirtualLosses(player.root)
    # A positive Q means search figured out Black is winning.
    self.assertGreater(player.root.Q, 0)
def test_extract_data_normal_end(self):
    """A game ended by two passes yields one datum per move played."""
    player = MCTSPlayer(DummyNet())
    player.initialize_game()
    # Two consecutive passes end the game normally.
    for _ in range(2):
        player.tree_search()
        player.play_move(None)
    self.assertTrue(player.root.is_done())
    player.set_result(player.root.position.result(), was_resign=False)

    data = list(player.extract_data())
    self.assertEqual(2, len(data))
    position, _, result = data[0]
    # White wins by komi.
    self.assertEqual(go.WHITE, result)
    self.assertEqual("W+{}".format(player.root.position.komi),
                     player.result_string)
def get_mcts_player(network, pos):
    """Build an MCTSPlayer at `pos` and run one batch of readouts on it.

    Args:
        network: policy/value network used for leaf evaluation.
        pos: position to initialize the search tree at.

    Returns:
        The MCTSPlayer after FLAGS.num_readouts additional readouts.
    """
    # BUG FIX: `readouts` was referenced below without being defined
    # (NameError when called); mirror the self-play loop and take it
    # from FLAGS. Also dropped an unused `start = time.time()`.
    readouts = FLAGS.num_readouts
    # Disable resign in a small fraction of games.
    if random.random() < FLAGS.resign_disable_pct:
        resign_threshold = -1.0
    else:
        resign_threshold = None
    player = MCTSPlayer(network, resign_threshold=resign_threshold)
    player.initialize_game(position=pos)

    # Must run this once at the start to expand the root node.
    first_node = player.root.select_leaf()
    prob, val = network.run(first_node.position)
    first_node.incorporate_results(prob, val, first_node)

    player.root.inject_noise()
    current_readouts = player.root.N
    # We want to do "X additional readouts", rather than "up to X readouts".
    while player.root.N < current_readouts + readouts:
        player.tree_search()
    return player
def play(network):
    """Self-play loop driven purely by readouts (no noise, no resign).

    Args:
        network: model exposing predict(state) -> (priors, value).

    Returns:
        The MCTSPlayer holding the finished game. (CONSISTENCY FIX: the
        other play() variants in this file return the player; this one
        previously returned None.)
    """
    readouts = FLAGS.num_readouts
    player = MCTSPlayer(network)
    player.initialize_game()

    # Expand the root node once before the main loop.
    first_node = player.root.select_leaf()
    prob, val = network.predict(first_node.position.state)
    first_node.incorporate_results(prob, val, first_node)

    while True:
        current_readouts = player.root.N
        # "readouts additional readouts", not "up to readouts".
        while player.root.N < current_readouts + readouts:
            player.tree_search()
        move = player.pick_move()
        player.play_move(move)
        tf.logging.info('playing move: %d hamming distance: %d' %
                        (move, state_diff(player.root.position.state)))
        if player.root.is_done():
            tf.logging.info('done')
            break
    return player
def test_only_check_game_end_once(self):
    # When presented with a situation where the last move was a pass,
    # and we have to decide whether to pass, it should be the first thing
    # we check, but not more than that.
    white_passed_pos = (
        go.Position()
        .play_move((3, 3))   # b plays
        .play_move((3, 4))   # w plays
        .play_move((4, 3))   # b plays
        .pass_move())        # w passes - if B passes too, B would lose by komi.

    player = MCTSPlayer(DummyNet())
    player.initialize_game(white_passed_pos)

    player.tree_search()  # initialize the root
    player.tree_search()  # explore a child - should be a pass move.
    pass_move = go.N * go.N
    self.assertEqual(1, player.root.children[pass_move].N)
    self.assertEqual(1, player.root.child_N[pass_move])

    player.tree_search()
    # check that we didn't visit the pass node any more times.
    self.assertEqual(player.root.child_N[pass_move], 1)
def play(network):
    """Search toward a target state, returning the final hamming distance.

    Plays at most `initial_dist` moves, where each move is given
    FLAGS.num_readouts readouts or FLAGS.time_per_move seconds, whichever
    runs out first. Aborts early (returning the current distance) if the
    search repeats a move or a move increases the distance; returns 0 on
    a finished game.
    """
    readouts = FLAGS.num_readouts
    player = MCTSPlayer(network)
    player.initialize_game()

    # Expand the root node once before searching.
    first_node = player.root.select_leaf()
    prob, val = network.predict(first_node.position.state)
    first_node.incorporate_results(prob, val, first_node)

    prev_move = -1
    initial_dist = state_diff(player.root.position.state)
    for _ in range(initial_dist):
        base_n = player.root.N
        deadline_start = time.time()
        # Budget each move by both readout count and wall-clock time.
        while (player.root.N < base_n + readouts
               and time.time() - deadline_start < FLAGS.time_per_move):
            player.tree_search()

        move = player.pick_move()
        if move == prev_move:
            tf.logging.info('lastmove == move')
            return state_diff(player.root.position.state)

        dist_before = state_diff(player.root.position.state)
        player.play_move(move)
        dist_after = state_diff(player.root.position.state)
        if dist_after > dist_before:
            tf.logging.info('move increasing distance')
            return dist_after

        tf.logging.info('playing move: %d hamming distance: %d' %
                        (move, state_diff(player.root.position.state)))
        if player.root.is_done():
            tf.logging.info('done')
            return 0
        prev_move = move

    return state_diff(player.root.position.state)
def test_extract_data_resign_end(self):
    """A resignation overrides the on-board result in extracted data."""
    player = MCTSPlayer(DummyNet())
    player.initialize_game()
    for move in [(0, 0), None]:
        player.tree_search()
        player.play_move(move)
    player.tree_search()

    # Black is winning on the board...
    self.assertEqual(go.BLACK, player.root.position.result())
    # ...but if Black resigns,
    player.set_result(go.WHITE, was_resign=True)

    data = list(player.extract_data())
    position, _, result = data[0]
    # the recorded result should name White as the winner.
    self.assertEqual(go.WHITE, result)
    self.assertEqual("W+R", player.result_string)