def test_select_node():
    """select_node should stop at a node that still has unexpanded moves."""
    root_board = cm.initialize_game_state()
    expanded_board = cm.initialize_game_state()
    expanded_board[0, 0] = cm.PLAYER1

    root = node.mcts_node(state=root_board)
    expanded_child = node.mcts_node(state=expanded_board, parent=root)
    root.open_moves = [0, 3, 4]
    root.children = [expanded_child]

    # The root still has open (unexpanded) moves, so selection is expected
    # to return the root itself rather than descend into its only child.
    assert agent.select_node(root) == root
def test_explore_node():
    """explore_node should consume one open move and return a different node."""
    start = node.mcts_node(state=cm.initialize_game_state())
    start.open_moves = [0, 3, 4]

    new_node = agent.explore_node(start)

    # Exactly one of the three open moves should have been used up.
    assert len(start.open_moves) == 2
    assert new_node != start
def test_child_selection():
    """select_next_node is expected to pick child_node01 over child_node02.

    child_node01 has 35 wins in 50 visits; child_node02 has 25 wins in 40
    visits; the parent has 100 visits.
    """
    board = cm.initialize_game_state()

    child_board01 = board.copy()
    child_board01[0, 0] = cm.PLAYER1
    child_board02 = board.copy()
    child_board02[0, 3] = cm.PLAYER1

    child_node01 = node.mcts_node(state=child_board01, player=cm.PLAYER1)
    child_node02 = node.mcts_node(state=child_board02, player=cm.PLAYER1)
    current_node = node.mcts_node(state=board, player=cm.PLAYER1)

    current_node.total_visits = 100
    child_node01.total_visits = 50
    child_node02.total_visits = 40
    child_node01.num_wins = 35
    child_node02.num_wins = 25

    current_node.children = [child_node01, child_node02]

    # Use ``==`` rather than calling ``__eq__`` directly: if ``mcts_node``
    # does not override ``__eq__``, ``a.__eq__(b)`` returns NotImplemented
    # for distinct objects, which is truthy, so the old assertion could
    # never fail.
    assert child_node01 == current_node.select_next_node()
def test_get_open_moves_connected4():
    """A board with four PLAYER1 pieces in a row should offer no open moves."""
    board = cm.initialize_game_state()
    # Fill row 0, columns 0..3 with PLAYER1 (a horizontal four-in-a-row).
    for column in range(4):
        board[0, column] = cm.PLAYER1

    finished = node.mcts_node(state=board, player=cm.PLAYER1)
    assert finished.get_open_moves().shape[0] == 0
def test_expansion():
    """expand_node should move one open move into a new child node."""
    board = cm.initialize_game_state()
    current_node = node.mcts_node(state=board, player=cm.PLAYER1)
    current_node.total_visits = 100
    current_node.open_moves = [0, 5]

    assert len(current_node.open_moves) == 2
    assert len(current_node.children) == 0

    child_node = current_node.expand_node(5)

    # Column 5 is removed from the open moves and becomes the only child.
    assert len(current_node.open_moves) == 1
    assert len(current_node.children) == 1
    assert current_node.open_moves[0] == 0
    # Use ``==`` rather than calling ``__eq__`` directly: a direct call can
    # return the truthy NotImplemented, which would make the check vacuous.
    assert child_node == current_node.children[0]
def MCTS(board: np.ndarray) -> cn.PlayerAction:
    """
    The monte carlo tree search algorithm.

    Repeats the four MCTS phases (selection, expansion, simulation,
    back-propagation) from the root until ``GLOBAL_TIME`` seconds have
    elapsed, then picks a move among the root's children.

    :param board: The board against which to calculate the simulation
    :return: Column in which player wants to make his move (chosen using MCTS).
        If any root child's state is a connect-four for ``child.player``,
        that child's move is returned immediately. Otherwise the visited
        child with the highest ``num_wins / total_visits`` ratio is chosen.
        Returns -1 when no child could be scored (e.g. the search loop never
        ran and the root has no children).
    """
    # Initialise the root node to start the algorithm.
    root = mcts_node.mcts_node(state=board, player=PLAYER)

    # GLOBAL_TIME is only read here, so no ``global`` declaration is needed.
    end = time.time() + GLOBAL_TIME
    while time.time() < end:
        # Step 1: selection - starting at the root on every iteration, find a
        # terminal node or a node that still has unexplored moves.
        node = select_node(root)
        # Step 2: expansion - explore the node that was selected.
        node = explore_node(node)
        # Step 3: simulation - play the game out from the explored node.
        win_game_flag, current_player = simulate_game(node)
        # Step 4: back-propagation of the simulation result.
        # NOTE(review): arguments are passed as (node, flag, player) here,
        # whereas test_back_propagation passes (node, player, flag); confirm
        # which order back_propagation actually expects.
        back_propagation(node, win_game_flag, current_player)

    max_score = float("-inf")
    selected_column = -1
    for child in root.children:
        if cn.connected_four(child.state, child.player):
            # A move that immediately produces a connect-four is taken at once.
            return child.move
        if child.total_visits == 0:
            # A child with no visits has no statistics; skipping it avoids a
            # ZeroDivisionError in the ratio below.
            continue
        score = child.num_wins / child.total_visits
        if score > max_score:
            selected_column = child.move
            max_score = score
    return selected_column
def test_simulate_game():
    """simulate_game from an empty board is expected to report a win."""
    start = node.mcts_node(state=cm.initialize_game_state(), player=cm.PLAYER1)
    # Only the win flag is checked; the returned player is ignored.
    # NOTE(review): if a simulated game can end in a draw, this assertion
    # could fail intermittently - confirm simulate_game's contract.
    win, _ = agent.simulate_game(start)
    assert win
def test_back_propagation():
    """back_propagation should record one visit on the node it is given."""
    leaf = node.mcts_node(state=cm.initialize_game_state())
    # NOTE(review): arguments are passed as (node, player, flag) here, but
    # MCTS calls back_propagation(node, win_game_flag, current_player); one
    # of the two call sites likely has the order swapped - confirm against
    # back_propagation's signature.
    agent.back_propagation(leaf, cm.PLAYER1, False)
    assert leaf.total_visits == 1
def test_get_open_moves():
    """A freshly initialised board should expose seven open moves."""
    fresh = node.mcts_node(state=cm.initialize_game_state(), player=cm.PLAYER1)
    open_moves = fresh.get_open_moves()
    assert open_moves.shape[0] == 7
def test_set_visit_win():
    """On a fresh node, set_visit_and_win(result=10) should leave 10 wins and 1 visit."""
    fresh = node.mcts_node(state=cm.initialize_game_state(), player=cm.PLAYER1)
    fresh.set_visit_and_win(result=10)

    assert fresh.num_wins == 10
    assert fresh.total_visits == 1