Ejemplo n.º 1
0
def test_mcts_value_at_children_of_root():
    """Compare W summed over the root's children with N summed over terminal nodes.

    A fresh root is searched with ``mcts`` on the mock game, after which the
    total W of the root's direct children must equal the total N of the nodes
    returned by ``get_terminal_nodes``.
    """
    game = MockGame()
    tree_root = MCTSNode(0, player=1)
    # Precondition: a freshly built node has no recorded visits.
    assert tree_root.N == 0

    mcts(tree_root, game, game.mock_estimator, 100, 1.0)

    terminal_visits = 0
    for leaf in get_terminal_nodes(tree_root):
        terminal_visits += leaf.N

    # Expectation under test: each MCTS iteration adds 1 to W of one of the
    # children of the root.
    child_value_total = 0
    for child in tree_root.children.values():
        child_value_total += child.W

    assert child_value_total == terminal_visits
Ejemplo n.º 2
0
def test_mcts_action_count_at_root_children():
    """Check the total visit count N across the root's children after a search.

    Runs ``mcts`` from a fresh root on the mock game and asserts that the N
    values of the root's direct children sum to one less than the iteration
    argument passed to ``mcts``.
    """
    n_iterations = 100
    mock_game = MockGame()
    root = MCTSNode(0, player=1)

    mcts(root, mock_game, mock_game.mock_estimator, n_iterations, 1.0)

    # The children together record one visit fewer than the number of
    # iterations. NOTE(review): presumably the first iteration only expands
    # the root without descending into a child -- confirm against mcts().
    assert sum(child.N for child in root.children.values()) == n_iterations - 1
Ejemplo n.º 3
0
def test_can_run_mcts_on_fake_game():
    """Smoke test: a search over the mock game returns a result.

    ``mcts`` is driven with a ``MockGame`` instance and that game's
    ``mock_estimator``; the only requirement is that the returned value is
    not ``None``.
    """
    game = MockGame()
    search_result = mcts(MCTSNode(0, player=1), game, game.mock_estimator, 100, 1.0)

    assert search_result is not None
Ejemplo n.º 4
0
def test_mcts_action_count_at_root():
    """Check that the root's visit count N equals the number of iterations run.

    The root starts with ``N == 0``; after ``mcts`` is called with an
    iteration argument of ``n_iterations`` its N must equal that value.
    """
    n_iterations = 100
    mock_game = MockGame()
    root = MCTSNode(0, player=1)
    # Precondition: a freshly built node has no recorded visits.
    assert root.N == 0

    mcts(root, mock_game, mock_game.mock_estimator, n_iterations, 1.0)

    # Each iteration of MCTS should add 1 to N at the root.
    assert root.N == n_iterations
Ejemplo n.º 5
0
def test_print_mcts():
    """Smoke test: run a short search with distinct terminal values and print the tree.

    The mock game's ``TERMINAL_STATE_VALUES`` is temporarily replaced with 12
    increasing values, a 10-iteration search is run, and the resulting tree
    is passed to ``print_tree``. No assertions are made.
    """
    mock_game = MockGame()
    # TODO: This doesn't currently test anything.
    # TODO: Overriding TERMINAL_STATE_VALUES like this is probably a bad idea.
    mock_game.TERMINAL_STATE_VALUES = [0.01 * i for i in range(12)]
    try:
        root = MCTSNode(0, player=1)
        mcts(root, mock_game, mock_game.mock_estimator, 10, 1.0)
        print_tree(root)
    finally:
        # Reset the values even if the search or the printing raises, so the
        # override never outlives this test.
        # NOTE(review): (1, ) * 12 is presumably MockGame's default -- confirm.
        mock_game.TERMINAL_STATE_VALUES = (1, ) * 12
Ejemplo n.º 6
0
def test_mcts_can_play_fake_game(evaluator, expected):
    """Play the mock game to a terminal node by greedy action choice.

    Starting at a fresh root, repeatedly runs ``mcts``, picks the action with
    the highest returned probability, and steps to that child of the current
    node. The ``game_state`` of every node on the path (root included) must
    equal ``expected``.

    Args:
        evaluator: Estimator handed to ``mcts``.
        expected: List of game states expected along the played path.
    """
    game = MockGame()
    root = MCTSNode(0, player=1)

    path = [root]
    current = root
    while True:
        if current.is_terminal:
            break
        # NOTE(review): the search is always started from ``root``, not from
        # ``current`` -- confirm this is intended before changing it.
        probs = mcts(root, game, evaluator, 100, 1)

        best_action = max(probs, key=probs.get)
        current = current.children[best_action]
        path.append(current)

    visited_states = [step.game_state for step in path]
    assert visited_states == expected
Ejemplo n.º 7
0
def test_neural_net_estimator():
    """Smoke test: ``mcts`` runs with a ``MockNetEstimator`` as the estimator.

    No assertions are made; the test only fails if constructing the
    estimator or running the search raises.
    """
    mock_game = MockGame()
    nnet = MockNetEstimator(learning_rate=0.01)

    root = MCTSNode(0, player=1)
    mcts(root, mock_game, nnet, 100, 1.0)