Ejemplo n.º 1
0
def test_advance_root():
    """advance_root should make the chosen child the new, parentless root.

    Covers an already-expanded root (the existing child object must be
    reused) and an unexpanded root advanced by a real chess move (the new
    root's board must have exactly that move applied).
    """
    # Root already expanded: advancing by key 1 must promote children[1].
    children = {i: mcts.Node() for i in range(5)}
    root_with_children = mcts.Node(children=children)
    # The remaining constructor arguments are placeholders for collaborators
    # this test does not exercise.
    m = mcts.MCTS(root_with_children, '', '', '')
    m.advance_root(1)
    assert m.root.parent is None
    # Identity, not equality: the five children are constructed identically,
    # so `==` could not tell them apart if Node compares by value.
    assert m.root is children[1]

    # Root not expanded: advancing by a chess move must yield a root whose
    # board's move stack is exactly that move.
    m = mcts.MCTS(mcts.Node(), '', '', '')
    move = chess.Move.from_uci('f2f3')
    m.advance_root(move)
    assert m.root.parent is None
    assert m.root.board.move_stack == [move]
Ejemplo n.º 2
0
def test_get_move():
    """get_move picks the root child with the highest visit count ('m3' here)
    and raises MCTSError when the root has no children at all."""
    visit_counts = {'m1': 1, 'm2': 0, 'm3': 3, 'm4': 2}
    root = mcts.Node(
        children={move: mcts.Node(visit=count)
                  for move, count in visit_counts.items()})
    search = mcts.MCTS(root, '', '', '')
    assert search.get_move() == 'm3'

    # A root without children leaves nothing to choose from.
    empty_search = mcts.MCTS(mcts.Node(), '', '', '')
    with pytest.raises(mcts.MCTSError):
        empty_search.get_move()
Ejemplo n.º 3
0
def test_expand():
    """expand should populate a leaf's children with policy priors and return
    one of them, refuse to re-expand, and return terminal nodes unchanged."""
    mock_policy = mock.MagicMock()
    # 4672 is the width of the fake policy output; presumably it matches the
    # engine's move-index space -- TODO confirm against get_engine_move_index.
    probs = torch.randn(1, 4672)
    mock_policy.get_probs.return_value = probs
    m = mcts.MCTS('', '', mock_policy, '')
    # no children at this point
    n = mcts.Node()

    random_child = m.expand(n)
    # should have children now. 20 to be exact since we just expanded the root
    assert len(n.children) == 20
    for move, c in n.children.items():
        engine_move = translate_to_engine_move(move, chess.WHITE)
        index = get_engine_move_index(engine_move)
        # Each child's prior must be the policy output at that move's index.
        assert c.prior == probs[0, index]
    # The returned child must be one of the actual child objects, not merely
    # equal to one of them.
    assert any(random_child is c for c in n.children.values())

    # can't expand if it already has been expanded
    with pytest.raises(mcts.MCTSError):
        m.expand(n)

    # if expanding a terminal state, just return the node
    n = mcts.Node(board=chess.Board(
        fen='3b1q1q/1N2PRQ1/rR3KBr/B4PP1/2Pk1r1b/1P2P1N1/2P2P2/8 '
        'b - - 0 1'))
    # Guard the fixture: if this position were not terminal, the assertion
    # below would be exercising the ordinary expansion path instead.
    assert n.board.is_game_over()
    assert m.expand(n) is n
Ejemplo n.º 4
0
def test_simulate_game_over():
    """simulate on a finished game returns the terminal reward; no value
    model is supplied, so the reward cannot come from one."""
    # Fool's mate: white is to move and has been checkmated; the expected
    # reward is -1.
    fools_mate_fen = (
        'rnb1kbnr/pppp1ppp/8/4p3/6Pq/5P2/PPPPP2P/RNBQKBNR w KQkq - 1 3')
    node = mcts.Node()
    node.board = chess.Board(fen=fools_mate_fen)
    search = mcts.MCTS(node, '', '', '')
    assert search.simulate(node) == -1
Ejemplo n.º 5
0
def test_simulate():
    """simulate on a childless node returns the value model's estimate, and
    raises MCTSError when the node already has children."""
    mock_value = mock.MagicMock()
    mock_value.get_value.return_value = -0.9
    n = mcts.Node()
    m = mcts.MCTS(n, mock_value, '', '')
    value = m.simulate(n)
    assert value == -0.9

    # Attach the child *before* entering pytest.raises so that only
    # simulate() itself can produce the expected MCTSError; previously a
    # failure in this setup line would also have satisfied the check.
    n.children[1] = mcts.Node()
    with pytest.raises(mcts.MCTSError):
        m.simulate(n)
Ejemplo n.º 6
0
def test_backup():
    """backup must record the value and one visit on the node and on every
    ancestor.

    The three-node chain alternates side to move (white, black, white); every
    node is expected to end up with value 0.9 and visit count 1.
    """
    # white turn
    node = mcts.Node()
    # black turn
    node.parent = mcts.Node(board=chess.Board(
        fen='rnbqkbnr/pppppppp/8/8/8/5P2/PPPPP1PP/RNBQKBNR b KQkq - 0 1'))
    # white turn
    node.parent.parent = mcts.Node()
    m = mcts.MCTS('', '', '', '')
    m.backup(node, 0.9)

    # Use an explicit None check (whether Node is always truthy is not
    # visible here) and count the nodes checked, so the loop cannot pass
    # vacuously by stopping early.
    checked = 0
    walker = node
    while walker is not None:
        assert walker.value == 0.9
        assert walker.visit == 1
        checked += 1
        walker = walker.parent
    assert checked == 3
Ejemplo n.º 7
0
def test_select():
    """select returns the root when it is a leaf; otherwise it descends to
    the child with the largest UCB score."""
    children = []
    for i in range(3):
        n = mcts.Node()
        # Stub the UCB score so child i scores i; the last child scores highest.
        n.ucb = mock.MagicMock(return_value=i)
        children.append(n)
    root = mcts.Node()
    m = mcts.MCTS(root, '', '', '')

    # if root is already a leaf, return that
    assert m.select() is root

    # traverse the tree, picking the node with the biggest ucb
    root.children = {i: children[i] for i in range(3)}
    # Identity check: the children differ only in their stubbed ucb
    # attribute, so value equality on Node (if defined) might not tell them
    # apart.
    assert m.select() is children[-1]