예제 #1
0
 def test_tree_search_failsafe(self):
     # The failsafe can fire when MCTS keeps revisiting a game state that
     # is already finished; check that no virtual losses are left pending.
     move_count = go.N * go.N + 1
     fake_priors = np.array([.001] * move_count)
     fake_priors[-1] = 1  # The dummy net should always want to pass
     net = DummyNet(fake_priors=fake_priors)
     player = MCTSPlayerMixin(net)
     # Begin from a fresh position on which pass_move() has been applied.
     player.initialize_game(go.Position().pass_move())
     player.tree_search(num_parallel=1)
     self.assertNoPendingVirtualLosses(player.root)
예제 #2
0
 def test_tree_search_failsafe(self):
     # Exercise the failsafe, which may trigger when the search lands on a
     # finished game state over and over again.
     prior_count = go.N * go.N + 1
     pass_heavy_priors = np.array([.001] * prior_count)
     pass_heavy_priors[-1] = 1  # Dummy net always wants to pass
     dummy_net = DummyNet(fake_priors=pass_heavy_priors)
     player = MCTSPlayerMixin(dummy_net)
     # The game is initialized from a position that already contains a pass.
     start_position = go.Position().pass_move()
     player.initialize_game(start_position)
     player.tree_search(num_parallel=1)
     # The search must not leave any virtual loss behind.
     self.assertNoPendingVirtualLosses(player.root)
예제 #3
0
 def test_cold_start_parallel_tree_search(self):
     # Parallel tree search must cope with a tree that is still empty.
     net = DummyNet(fake_value=0.17)
     player = MCTSPlayerMixin(net)
     player.initialize_game()
     # Before any search the root is unvisited and unexpanded.
     fresh_root = player.root
     self.assertEqual(fresh_root.N, 0)
     self.assertFalse(fresh_root.is_expanded)
     player.tree_search(num_parallel=4)
     searched_root = player.root
     self.assertNoPendingVirtualLosses(searched_root)
     # Tree search selects the root 4 times, yet the root must end up
     # with a visit count of exactly 1.
     self.assertEqual(searched_root.N, 1)
     # The prior on the root is 0, so Q = average(0, 0.17) = 0.085.
     self.assertAlmostEqual(searched_root.Q, 0.085)
예제 #4
0
 def test_cold_start_parallel_tree_search(self):
   # An empty tree must not trip up a parallel tree search.
   dummy_net = DummyNet(fake_value=0.17)
   player = MCTSPlayerMixin(utils_test.BOARD_SIZE, dummy_net)
   player.initialize_game()
   # Nothing has been searched yet: no visits, no expansion.
   initial_root = player.root
   self.assertEqual(initial_root.N, 0)
   self.assertFalse(initial_root.is_expanded)
   player.tree_search(num_parallel=4)
   root_after_search = player.root
   self.assertNoPendingVirtualLosses(root_after_search)
   # The root is selected 4 times by tree search, but only a single
   # visit may be recorded on it.
   self.assertEqual(root_after_search.N, 1)
   # 0 is the prior on the root, hence average(0, 0.17) = 0.085.
   self.assertAlmostEqual(root_after_search.Q, 0.085)
예제 #5
0
    def test_long_game_tree_search(self):
        # Search from a position whose move counter sits two moves short
        # of MAX_DEPTH.
        player = MCTSPlayerMixin(DummyNet())
        last_two_moves = (
            go.PlayerMove(go.BLACK, (0, 1)),
            go.PlayerMove(go.WHITE, (0, 8)),
        )
        endgame = go.Position(
            board=TT_FTW_BOARD,
            n=MAX_DEPTH - 2,
            komi=2.5,
            ko=None,
            recent=last_two_moves,
            to_play=go.BLACK,
        )
        player.initialize_game(endgame)

        # Repeatedly search this almost complete game; afterwards no
        # virtual losses may remain and the root value must be positive.
        for _ in range(10):
            player.tree_search(num_parallel=8)
        self.assertNoPendingVirtualLosses(player.root)
        self.assertGreater(player.root.Q, 0)
예제 #6
0
    def test_extract_data_normal_end(self):
        # Finish a game without resignation and check both the extracted
        # data and the result string.
        player = MCTSPlayerMixin(DummyNet())
        player.initialize_game()
        # Search, then play_move(None), twice in a row; after that the
        # root is expected to report the game as done.
        for _ in range(2):
            player.tree_search()
            player.play_move(None)
        self.assertTrue(player.root.is_done())
        final_position = player.root.position
        player.set_result(final_position.result(), was_resign=False)

        examples = list(player.extract_data())
        self.assertEqual(len(examples), 2)
        # Only the result field of the first example is inspected.
        _, _, result = examples[0]
        # White wins by komi
        self.assertEqual(result, go.WHITE)
        komi = player.root.position.komi
        self.assertEqual(player.result_string, "W+{}".format(komi))
예제 #7
0
    def test_extract_data_normal_end(self):
        # A game that ends normally (no resignation) must yield two data
        # entries and a komi-based result string for White.
        player = MCTSPlayerMixin(DummyNet())
        player.initialize_game()
        moves_to_play = (None, None)
        for move in moves_to_play:
            player.tree_search()
            player.play_move(move)
        self.assertTrue(player.root.is_done())
        board_result = player.root.position.result()
        player.set_result(board_result, was_resign=False)

        extracted = list(player.extract_data())
        self.assertEqual(len(extracted), 2)
        # Each entry unpacks into three fields; the last one is the result.
        _, _, result = extracted[0]
        # White wins by komi
        self.assertEqual(result, go.WHITE)
        expected_string = "W+{}".format(player.root.position.komi)
        self.assertEqual(player.result_string, expected_string)
    def test_long_game_tree_search(self):
        # The starting position is two moves short of the configured
        # maximum game length.
        player = MCTSPlayerMixin(DummyNet())
        moves_so_far = flags.FLAGS.max_game_length - 2
        recent_moves = (
            go.PlayerMove(go.BLACK, (0, 1)),
            go.PlayerMove(go.WHITE, (0, 8)),
        )
        endgame = go.Position(
            board=TT_FTW_BOARD,
            n=moves_so_far,
            komi=2.5,
            ko=None,
            recent=recent_moves,
            to_play=go.BLACK,
        )
        player.initialize_game(endgame)

        # MCTS should be able to deduce that B wins, because hitting the
        # move limit triggers TT-scoring.
        for _ in range(10):
            player.tree_search(parallel_readouts=8)
        self.assertNoPendingVirtualLosses(player.root)
        self.assertGreater(player.root.Q, 0)
예제 #9
0
    def test_long_game_tree_search(self):
        # Build an endgame position two moves away from MAX_DEPTH and let
        # the player search it.
        player = MCTSPlayerMixin(DummyNet())
        black_move = go.PlayerMove(go.BLACK, (0, 1))
        white_move = go.PlayerMove(go.WHITE, (0, 8))
        endgame = go.Position(board=TT_FTW_BOARD,
                              n=MAX_DEPTH - 2,
                              komi=2.5,
                              ko=None,
                              recent=(black_move, white_move),
                              to_play=go.BLACK)
        player.initialize_game(endgame)

        # Ten rounds of 8-way parallel search on the almost complete game
        # must leave no virtual losses and a positive root value.
        for _ in range(10):
            player.tree_search(num_parallel=8)
        self.assertNoPendingVirtualLosses(player.root)
        self.assertGreater(player.root.Q, 0)
예제 #10
0
    def test_only_check_game_end_once(self):
        # If the previous move was a pass and we must decide whether to
        # pass as well, passing should be the first thing the search
        # checks - and it should not be checked again after that.

        # b, w and b each play a stone in turn ...
        position = go.Position()
        for coord in ((3, 3), (3, 4), (4, 3)):
            position = position.play_move(coord)
        # ... then w passes - if B passes too, B would lose by komi.
        white_passed_pos = position.pass_move()

        player = MCTSPlayerMixin(DummyNet())
        player.initialize_game(white_passed_pos)
        player.tree_search()  # initialize the root
        player.tree_search()  # explore a child - should be a pass move.
        pass_move = go.N * go.N
        root = player.root
        self.assertEqual(root.children[pass_move].N, 1)
        self.assertEqual(root.child_N[pass_move], 1)
        player.tree_search()
        # One more search must not add any visit to the pass node.
        self.assertEqual(player.root.child_N[pass_move], 1)
예제 #11
0
  def test_only_check_game_end_once(self):
    # After the opponent has passed, deciding whether to pass too should
    # be the first thing the search looks at, but it should be looked at
    # only once.

    board_size = utils_test.BOARD_SIZE
    # Stones are placed by b, w, b in that order.
    opening = ((3, 3), (3, 4), (4, 3))
    current = go.Position(board_size,)
    for coord in opening:
      current = current.play_move(coord)
    # w passes - if B passes too, B would lose by komi.
    white_passed_pos = current.pass_move()

    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
    player.initialize_game(white_passed_pos)
    player.tree_search()  # initialize the root
    player.tree_search()  # explore a child - should be a pass move.
    pass_move = utils_test.BOARD_SIZE * utils_test.BOARD_SIZE
    root = player.root
    self.assertEqual(root.children[pass_move].N, 1)
    self.assertEqual(root.child_N[pass_move], 1)
    player.tree_search()
    # The pass node's visit count must still be 1 after another search.
    self.assertEqual(player.root.child_N[pass_move], 1)
예제 #12
0
    def test_extract_data_resign_end(self):
        # A resignation must override the result that the board position
        # alone would give.
        player = MCTSPlayerMixin(DummyNet())
        player.initialize_game()
        # Search before each of the two moves, and once more afterwards.
        for move in ((0, 0), None):
            player.tree_search()
            player.play_move(move)
        player.tree_search()
        # On the board, Black is ahead
        board_winner = player.root.position.result()
        self.assertEqual(board_winner, go.BLACK)
        # Black resigns anyway
        player.set_result(go.WHITE, was_resign=True)

        examples = list(player.extract_data())
        _, _, result = examples[0]
        # The extracted result must name White as the winner
        self.assertEqual(result, go.WHITE)
        self.assertEqual(player.result_string, "W+R")
예제 #13
0
  def test_extract_data_resign_end(self):
    # The recorded result follows the resignation, not the board score.
    net = DummyNet()
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, net)
    player.initialize_game()
    planned_moves = [(0, 0), None]
    for move in planned_moves:
      player.tree_search()
      player.play_move(move)
    player.tree_search()
    # The board position itself favours Black
    self.assertEqual(player.root.position.result(), go.BLACK)
    # Nevertheless Black resigns, which hands the game to White
    player.set_result(go.WHITE, was_resign=True)

    first_example = list(player.extract_data())[0]
    _, _, result = first_example
    # White must be reported as the winner
    self.assertEqual(result, go.WHITE)
    self.assertEqual(player.result_string, 'W+R')