def test_copy_maintains_shared_sets(self):
    """Check that copy() keeps group/liberty set objects shared between two points.

    Two black moves are played at (4, 4) and (4, 5). Both points must reference
    the very same set object (identity, not just equality), first in the
    original state and then in the state returned by copy().
    """
    stones = ((4, 4), (4, 5))
    state = GameState(7)
    for point in stones:
        state.do_move(point, go.BLACK)

    # The original state must reference *the same object* from both points.
    self.assertTrue(state.group_sets[4][5] is state.group_sets[4][4])
    self.assertTrue(state.liberty_sets[4][5] is state.liberty_sets[4][4])

    # The sharing must still hold inside the copy.
    duplicate = state.copy()
    self.assertTrue(duplicate.group_sets[4][5] is duplicate.group_sets[4][4])
    self.assertTrue(duplicate.liberty_sets[4][5] is duplicate.liberty_sets[4][4])
class TestMCTS(unittest.TestCase):
    """Tests for MCTS playouts, move selection and tree reuse.

    Each test starts from a fresh GameState and an MCTS instance built from the
    dummy value, policy and rollout functions, with n_playout=2.
    """

    def setUp(self):
        """Create the game state and the searcher shared by every test."""
        self.gs = GameState()
        self.mcts = MCTS(dummy_value, dummy_policy, dummy_rollout, n_playout=2)

    def _count_expansions(self):
        """Helper function to count the number of expansions past the root using the dummy policy.

        The actions returned by dummy_policy(self.gs) are sorted by decreasing
        probability. Starting at the root, the i-th action in that order is
        looked up among the children of the node reached so far; the walk stops
        at the first action that is missing.

        Returns:
            The number of levels that could be descended this way.
        """
        node = self.mcts._root
        expansions = 0
        # Loop over actions in decreasing probability.
        # Index into the (action, probability) pair instead of unpacking it in the
        # lambda signature: tuple parameters are a syntax error on Python 3.
        ranked = sorted(dummy_policy(self.gs), key=lambda action_prob: action_prob[1],
                        reverse=True)
        for action, _ in ranked:
            if action in node._children:
                expansions += 1
                node = node._children[action]
            else:
                break
        return expansions

    def test_playout(self):
        """Run a single playout with depth argument 8 and check visits and expansion count."""
        self.mcts._playout(self.gs.copy(), 8)
        # Assert that the most likely child was visited (according to the dummy policy below).
        self.assertEqual(1, self.mcts._root._children[(18, 18)]._n_visits)
        # Assert that the search depth expanded nodes 8 times.
        self.assertEqual(8, self._count_expansions())

    def test_playout_with_pass(self):
        """Run a playout whose policy stops proposing moves early."""
        # Test that playout handles the end of the game (i.e. passing/no moves). Mock this by
        # creating a policy that returns nothing after 4 moves.
        def stop_early_policy(state):
            if len(state.history) <= 4:
                return dummy_policy(state)
            else:
                return []
        self.mcts = MCTS(dummy_value, stop_early_policy, stop_early_policy, n_playout=2)
        self.mcts._playout(self.gs.copy(), 8)
        # Assert that the (18, 18) child is still only visited once.
        self.assertEqual(1, self.mcts._root._children[(18, 18)]._n_visits)
        # Assert that no expansions happened after reaching the "end": the policy only
        # returns moves while the history holds at most 4 entries.
        self.assertEqual(5, self._count_expansions())

    def test_get_move(self):
        """Smoke test: get_move followed by update_with_move must not raise."""
        move = self.mcts.get_move(self.gs)
        self.mcts.update_with_move(move)
        # success if no errors

    def test_update_with_move(self):
        """Check the state of the tree root after update_with_move."""
        move = self.mcts.get_move(self.gs)
        self.gs.do_move(move)
        self.mcts.update_with_move(move)
        # Assert that the new root still has children.
        self.assertGreater(len(self.mcts._root._children), 0)
        # Assert that the new root has no parent (the rest of the tree will be garbage collected).
        self.assertIsNone(self.mcts._root._parent)
        # Assert that the next best move according to the root is (18, 17), according to the
        # dummy policy below.
        self.assertEqual((18, 17), self.mcts._root.select()[0])