def test_select_leaf(self):
    """After expanding the root, leaf selection should follow the D9 prior.

    Every move gets a flat prior of .02 except D9, which gets 0.4. Once the
    root's first leaf has been expanded with those priors, the root must have
    white to play, and the next leaf selected must be the D9 child.
    """
    d9 = kgs_to_flat('D9')
    # One entry per board intersection plus one for pass.
    num_moves = go.N * go.N + 1
    priors = np.full(num_moves, .02)
    priors[d9] = 0.4

    root = MCTSNode(SEND_TWO_RETURN_ONE)
    first_leaf = root.select_leaf()
    first_leaf.incorporate_results(priors, 0, root)

    self.assertEqual(root.position.to_play, go.WHITE)
    self.assertEqual(root.select_leaf(), root.children[d9])
def test_dont_pass_if_losing(self):
    """A player that is behind should search its way to D9 rather than pass.

    Runs 20 serial tree searches from the almost-finished fixture position.
    It then checks that D9 wins the visit count with a positive Q, that pass
    has a negative Q, and that no virtual losses remain applied.
    """
    player = initialize_almost_done_player()
    # Precondition: white is losing by half a point before any search.
    self.assertEqual(player.root.position.score(), -0.5)

    for _ in range(20):
        player.tree_search()

    d9 = kgs_to_flat('D9')
    # D9 is the only winning move, so the most-visited child should be D9.
    self.assertEqual(np.argmax(player.root.child_N), d9)
    # Its value estimate should come out positive.
    self.assertGreater(player.root.children[d9].Q, 0)
    self.assertGreaterEqual(player.root.N, 20)
    # The final child slot (pass) should be judged a losing choice.
    self.assertLess(player.root.child_Q[-1], 0)
    # Search must leave no virtual losses behind on the tree.
    self.assertNoPendingVirtualLosses(player.root)
def test_parallel_tree_search(self):
    """Parallel tree search (virtual losses) should still converge on D9.

    A single serial search first gives the root its children. Then five
    searches with num_parallel=4 run and must not raise. The resulting tree
    should prefer D9 with a positive Q, should not favour pass, and should
    have no leftover virtual losses.
    """
    player = initialize_almost_done_player()
    # Precondition: white is losing by half a point before any search.
    self.assertEqual(player.root.position.score(), -0.5)

    # Warm-up: one serial search so the root has populated children.
    player.tree_search(num_parallel=1)
    # With virtual losses, several leaves can be searched per call without
    # an error being thrown.
    for _ in range(5):
        player.tree_search(num_parallel=4)

    d9 = kgs_to_flat('D9')
    # D9 is the only winning move, so the most-visited child should be D9.
    self.assertEqual(np.argmax(player.root.child_N), d9)
    # Its value estimate should come out positive.
    self.assertGreater(player.root.children[d9].Q, 0)
    self.assertGreaterEqual(player.root.N, 20)
    # The final child slot (pass) should be judged a losing choice.
    self.assertLess(player.root.child_Q[-1], 0)
    # Parallel search must revert every virtual loss it applied.
    self.assertNoPendingVirtualLosses(player.root)