def test_mcts_value_at_children_of_root():
    """Summed W of the root's children must equal summed N of the terminal nodes.

    Runs ``mcts`` on a fresh root (whose N starts at 0) with the mock game's
    own estimator, then compares the two totals.
    """
    game = MockGame()
    tree_root = MCTSNode(0, player=1)
    assert tree_root.N == 0

    mcts(tree_root, game, game.mock_estimator, 100, 1.0)

    # Visit counts accumulated on the nodes reported by get_terminal_nodes.
    terminal_visit_total = sum(leaf.N for leaf in get_terminal_nodes(tree_root))

    # Each iteration of MCTS should add 1 to W of one of the children of
    # the root, so the children's W total has to match the count above.
    child_value_total = sum(child.W for child in tree_root.children.values())
    assert child_value_total == terminal_visit_total
def test_mcts_action_count_at_root_children():
    """Visit counts (N) over the root's children must total 99 after the run.

    The value returned by ``mcts`` is not needed here; only the counts
    recorded on the tree are inspected.
    """
    mock_game = MockGame()
    root = MCTSNode(0, player=1)
    mcts(root, mock_game, mock_game.mock_estimator, 100, 1.0)
    # This checks the root's children, not the root itself: their visit
    # counts sum to one less than the 100 passed to mcts.
    # NOTE(review): presumably one iteration ends at the root without
    # descending into a child -- confirm against the mcts implementation.
    assert sum(child.N for child in root.children.values()) == 99
def test_can_run_mcts_on_fake_game():
    """
    Smoke test: MCTS can be run against the mock game using the mock
    game's own estimator, and it hands back a result that is not None.
    """
    game = MockGame()
    start_node = MCTSNode(0, player=1)
    result = mcts(start_node, game, game.mock_estimator, 100, 1.0)
    assert result is not None
def test_mcts_action_count_at_root():
    """The root's visit count N must go from 0 to 100 over the MCTS run."""
    game = MockGame()
    tree_root = MCTSNode(0, player=1)
    assert tree_root.N == 0

    # Only the effect on the tree matters here; the return value is ignored.
    mcts(tree_root, game, game.mock_estimator, 100, 1.0)

    # Each iteration of MCTS should add 1 to N at the root.
    assert tree_root.N == 100
def test_print_mcts():
    """Smoke test: build a small tree with ``mcts`` and print it with ``print_tree``.

    There are no assertions; the test only fails if something raises.
    """
    mock_game = MockGame()
    # TODO: This doesn't currently test anything.
    # TODO: This next line is probably a bad idea
    mock_game.TERMINAL_STATE_VALUES = [0.01 * i for i in range(12)]
    try:
        root = MCTSNode(0, player=1)
        mcts(root, mock_game, mock_game.mock_estimator, 10, 1.0)
        print_tree(root)
    finally:
        # Put the values back even when mcts or print_tree raises, so a
        # failure here cannot leave the override in place.
        mock_game.TERMINAL_STATE_VALUES = (1, ) * 12
def test_mcts_can_play_fake_game(evaluator, expected):
    """Play the mock game to a terminal node and check the states visited.

    At every step the action with the highest value in the mapping returned
    by ``mcts`` is followed. ``evaluator`` is passed straight to ``mcts``;
    ``expected`` is the list of ``game_state`` values along the path, root
    first. Both arguments are presumably supplied by parametrization or
    fixtures defined outside this block.
    """
    game = MockGame()
    root = MCTSNode(0, player=1)
    current = root
    path = [current]

    while True:
        if current.is_terminal:
            break
        # NOTE(review): the search is always started from ``root``, not from
        # ``current`` -- confirm this is intended before changing it.
        probs = mcts(root, game, evaluator, 100, 1)
        best_action = max(probs, key=probs.get)
        current = current.children[best_action]
        path.append(current)

    visited_states = [step.game_state for step in path]
    assert visited_states == expected
def test_neural_net_estimator():
    """Smoke test: ``mcts`` runs with a ``MockNetEstimator`` as its estimator.

    Nothing is asserted; the test passes as long as the call does not raise.
    """
    game = MockGame()
    estimator = MockNetEstimator(learning_rate=0.01)
    search_root = MCTSNode(0, player=1)
    mcts(search_root, game, estimator, 100, 1.0)