def test_advance_root():
    """advance_root re-roots the tree and detaches the new root from its parent.

    Two cases are covered:
    * the key is an existing child of the root -> that very child becomes root;
    * the root has no children -> the new root's board has the move pushed.
    """
    # Existing child: the root is re-pointed at the child stored under key 1.
    kids = {key: mcts.Node() for key in range(5)}
    tree = mcts.MCTS(mcts.Node(children=kids), '', '', '')
    tree.advance_root(1)
    assert tree.root.parent is None
    assert tree.root == kids[1]

    # Childless root: advancing by a chess move yields a root whose board's
    # move stack consists of exactly that move.
    tree = mcts.MCTS(mcts.Node(), '', '', '')
    f_pawn_push = chess.Move.from_uci('f2f3')
    tree.advance_root(f_pawn_push)
    assert tree.root.parent is None
    assert tree.root.board.move_stack == [f_pawn_push]
def test_get_move():
    """get_move returns the key of the most-visited child of the root.

    A root without any children has no move to offer, so MCTSError is raised.
    """
    visit_counts = {'m1': 1, 'm2': 0, 'm3': 3, 'm4': 2}
    kids = {key: mcts.Node(visit=count) for key, count in visit_counts.items()}
    tree = mcts.MCTS(mcts.Node(children=kids), '', '', '')
    # 'm3' carries the largest visit count (3).
    assert tree.get_move() == 'm3'

    empty_tree = mcts.MCTS(mcts.Node(), '', '', '')
    with pytest.raises(mcts.MCTSError):
        empty_tree.get_move()
def test_expand():
    """expand() adds one child per legal move, with priors taken from the policy.

    It also returns one of the freshly created children, raises MCTSError when
    a node is expanded twice, and hands a terminal node back unchanged.
    """
    policy = mock.MagicMock()
    probs = torch.randn(1, 4672)
    policy.get_probs.return_value = probs
    tree = mcts.MCTS('', '', policy, '')

    # A fresh Node starts without children; after expanding it the test expects
    # exactly 20 (the original note: "we just expanded the root").
    leaf = mcts.Node()
    picked = tree.expand(leaf)
    assert len(leaf.children) == 20

    # Every child's prior is the policy probability at its engine-move index.
    for move, child in leaf.children.items():
        idx = get_engine_move_index(translate_to_engine_move(move, chess.WHITE))
        assert child.prior == probs[0, idx]
    assert picked in leaf.children.values()

    # A node that has already been expanded cannot be expanded again.
    with pytest.raises(mcts.MCTSError):
        tree.expand(leaf)

    # Expanding a terminal state just returns the node itself.
    terminal_fen = ('3b1q1q/1N2PRQ1/rR3KBr/B4PP1/2Pk1r1b/1P2P1N1/2P2P2/8 '
                    'b - - 0 1')
    terminal = mcts.Node(board=chess.Board(fen=terminal_fen))
    assert terminal == tree.expand(terminal)
def test_simulate_game_over():
    """simulate() on a finished game yields the game reward instead of an estimate.

    The position is fool's mate with white to move; the expected reward is -1.
    """
    leaf = mcts.Node()
    fools_mate = chess.Board(
        fen='rnb1kbnr/pppp1ppp/8/4p3/6Pq/5P2/PPPPP2P/RNBQKBNR w KQkq - 1 3')
    leaf.board = fools_mate
    assert mcts.MCTS(leaf, '', '', '').simulate(leaf) == -1
def test_simulate():
    """simulate() returns the value network's estimate for a leaf node.

    Once the node has a child it is no longer a leaf, and simulate() must
    raise MCTSError.
    """
    value_net = mock.MagicMock()
    value_net.get_value.return_value = -0.9
    leaf = mcts.Node()
    tree = mcts.MCTS(leaf, value_net, '', '')
    assert tree.simulate(leaf) == -0.9

    # Set up the non-leaf fixture *outside* pytest.raises so that only the
    # call under test sits in the context expected to raise.
    leaf.children[1] = mcts.Node()
    with pytest.raises(mcts.MCTSError):
        tree.simulate(leaf)
def test_backup():
    """backup() propagates the value and a visit to every node up to the root.

    The chain is node -> parent -> grandparent, with alternating side to move
    (see the turn comments). The assertions expect every level to end up with
    value 0.9 and visit 1.
    """
    # white turn
    node = mcts.Node()
    # black turn
    node.parent = mcts.Node(board=chess.Board(
        fen='rnbqkbnr/pppppppp/8/8/8/5P2/PPPPP1PP/RNBQKBNR b KQkq - 0 1'))
    # white turn
    node.parent.parent = mcts.Node()

    m = mcts.MCTS('', '', '', '')
    m.backup(node, 0.9)

    # Compare against None explicitly: `while walker:` would consult Node's
    # truthiness, and a falsy Node (e.g. one defining __len__) would silently
    # skip every assertion.
    visited = 0
    walker = node
    while walker is not None:
        assert walker.value == 0.9
        assert walker.visit == 1
        visited += 1
        walker = walker.parent
    # Guard against a vacuous pass: all three nodes in the chain were checked.
    assert visited == 3
def test_select():
    """select() descends the tree by highest UCB score and stops at a leaf.

    A childless root is itself a leaf and is returned directly; otherwise the
    child whose mocked ucb() is largest wins.
    """
    def make_child(score):
        # Child node whose ucb() is stubbed to return a fixed score.
        child = mcts.Node()
        child.ucb = mock.MagicMock(return_value=score)
        return child

    candidates = [make_child(score) for score in range(3)]
    root = mcts.Node()
    tree = mcts.MCTS(root, '', '', '')

    # No children yet: the root is the leaf to select.
    assert root == tree.select()

    # Scores are 0, 1, 2, so the last candidate has the highest UCB.
    root.children = dict(enumerate(candidates))
    assert tree.select() == candidates[-1]