def test_playout_with_pass(self):
    """Check that _playout copes with the end of the game (passing / no moves).

    The end of the game is mocked with a policy that delegates to
    ``dummy_policy`` while ``len(state.history) <= 4`` and returns no moves
    afterwards. The same policy is also used as the rollout policy.
    """
    def stop_early_policy(state):
        # Behave like the dummy policy early on, then report "no moves".
        return dummy_policy(state) if len(state.history) <= 4 else []

    self.mcts = MCTS(dummy_value, stop_early_policy, stop_early_policy, n_playout=2)
    self.mcts._playout(self.gs.copy(), 8)
    # Assert that the most likely root child, (18, 18), is still only visited once.
    # NOTE(review): the (18, 17) node is not checked here; add an assertion if intended.
    self.assertEqual(1, self.mcts._root._children[(18, 18)]._n_visits)
    # Assert that no expansions happened after reaching the "end" in 4 moves.
    # Presumably the policy yields moves for history lengths 0..4 (five states),
    # hence five expanded levels even though the requested depth is 8.
    self.assertEqual(5, self._count_expansions())
class TestMCTS(unittest.TestCase):
    """Tests for TreeNode expansion/selection, MCTS._DFS and MCTS.get_move.

    NOTE(review): if this class shares a module with another class named
    TestMCTS, the later definition rebinds the name and these tests will not
    be collected — confirm they live in separate modules.
    """

    def setUp(self):
        board = GameState()
        self.gs = board
        self.mcts = MCTS(self.gs, value_network, policy_network, rollout_policy, n_search=2)

    def test_treenode_selection(self):
        root = TreeNode(None, 1.0)
        root.expansion(policy_network(self.gs))
        best_action, best_child = root.selection()
        # according to the policy below
        self.assertEqual(best_action, (18, 18))
        self.assertIsNotNone(best_child)

    def test_mcts_DFS(self):
        root = TreeNode(None, 1.0)
        self.mcts._DFS(8, root, self.gs.copy())
        first_child = root.children[(18, 18)]
        self.assertEqual(1, first_child.nVisits, 'DFS visits incorrect')

    def test_mcts_getMove(self):
        # Smoke test: passes as long as neither call raises.
        chosen = self.mcts.get_move(self.gs)
        self.mcts.update_with_move(chosen)
def setUp(self):
    """Give every test a fresh GameState and an MCTS built from the dummy functions (n_playout=2)."""
    board = GameState()
    self.gs = board
    self.mcts = MCTS(dummy_value, dummy_policy, dummy_rollout, n_playout=2)
class TestMCTS(unittest.TestCase):
    """Tests for MCTS playouts, move selection and re-rooting.

    Uses the dummy value, policy and rollout functions.
    """

    def setUp(self):
        self.gs = GameState()
        self.mcts = MCTS(dummy_value, dummy_policy, dummy_rollout, n_playout=2)

    def _count_expansions(self):
        """Count how many levels below the root have been expanded.

        Starting at the root, follow the actions of ``dummy_policy(self.gs)``
        in decreasing probability, one action per depth. Stop at the first
        action that is not a child of the current node.

        Returns:
            int: number of steps taken along that chain.
        """
        node = self.mcts._root
        expansions = 0
        # Loop over actions in decreasing probability.
        for action, _ in sorted(dummy_policy(self.gs), key=itemgetter(1), reverse=True):
            if action not in node._children:
                break
            expansions += 1
            node = node._children[action]
        return expansions

    def test_playout(self):
        self.mcts._playout(self.gs.copy(), 8)
        # Assert that the most likely child was visited (according to the dummy policy below).
        self.assertEqual(1, self.mcts._root._children[(18, 18)]._n_visits)
        # Assert that the search depth expanded nodes 8 times.
        self.assertEqual(8, self._count_expansions())

    def test_playout_with_pass(self):
        # Test that playout handles the end of the game (i.e. passing/no moves). Mock this by
        # creating a policy that returns nothing once len(state.history) exceeds 4.
        def stop_early_policy(state):
            if len(state.history) <= 4:
                return dummy_policy(state)
            else:
                return []
        self.mcts = MCTS(dummy_value, stop_early_policy, stop_early_policy, n_playout=2)
        self.mcts._playout(self.gs.copy(), 8)
        # Assert that the most likely root child, (18, 18), is still only visited once.
        # NOTE(review): the (18, 17) node is not checked here; add an assertion if intended.
        self.assertEqual(1, self.mcts._root._children[(18, 18)]._n_visits)
        # Assert that no expansions happened after reaching the "end" in 4 moves
        # (presumably five states, history lengths 0..4, receive moves).
        self.assertEqual(5, self._count_expansions())

    def test_get_move(self):
        move = self.mcts.get_move(self.gs)
        self.mcts.update_with_move(move)
        # success if no errors

    def test_update_with_move(self):
        move = self.mcts.get_move(self.gs)
        self.gs.do_move(move)
        self.mcts.update_with_move(move)
        # Assert that the new root still has children; assertGreater reports the
        # actual count on failure, unlike assertTrue(len(...) > 0).
        self.assertGreater(len(self.mcts._root._children), 0)
        # Assert that the new root has no parent (the rest of the tree will be garbage collected).
        self.assertIsNone(self.mcts._root._parent)
        # Assert that the next best move according to the root is (18, 17), according to the
        # dummy policy below.
        self.assertEqual((18, 17), self.mcts._root.select()[0])
def setUp(self):
    """Give every test a fresh GameState and an MCTS built from the value, policy and rollout networks (n_search=2)."""
    board = GameState()
    self.gs = board
    self.mcts = MCTS(self.gs, value_network, policy_network, rollout_policy, n_search=2)