def evaluate(
    game_net: GameNet[_S],
    previous_game_net: GameNet[_S],
    game_manager: GameManager[_S],
    config: TrainingConfiguration[_S],
) -> Tuple[AgentComparison, AgentComparison]:
    state_evaluator = cached_state_evaluator(game_net)
    previous_state_evaluator = cached_state_evaluator(previous_game_net)
    # Compare the current net against a pure-rollout MCTS that gets
    # twice the simulation budget but no state evaluator.
    random_mcts_evaluation = compare_agents(
        (
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    state_evaluator=state_evaluator,
                )
            ),
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations * 2,
                    lambda s: random.choice(game_manager.legal_actions(s)),
                    state_evaluator=None,
                )
            ),
        ),
        config.evaluation_games,
        game_manager,
    )
    # Compare the current net against the previous iteration of the net.
    previous_evaluation = compare_agents(
        (
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    state_evaluator,
                ),
                epsilon=config.epsilon,
            ),
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    previous_state_evaluator,
                ),
                epsilon=config.epsilon,
            ),
        ),
        config.evaluation_games,
        game_manager,
    )
    return random_mcts_evaluation, previous_evaluation
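
# `cached_state_evaluator` is used throughout this section but defined
# elsewhere in the project. A minimal sketch of what it plausibly does,
# assuming states are hashable and that `GameNet.evaluate(state)` returns a
# (value, action_probabilities) pair as it does in `genmove` below:
def cached_state_evaluator(game_net: GameNet[_S]) -> StateEvaluator[_S]:
    cache: Dict[_S, Tuple[float, List[float]]] = {}

    def evaluate(state: _S) -> Tuple[float, List[float]]:
        # Memoize network evaluations so repeated visits to the same state
        # during search hit the cache instead of re-running the network.
        if state not in cache:
            cache[state] = game_net.evaluate(state)
        return cache[state]

    return evaluate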
def create_self_play_examples(
    process_number: int,
    game_net: GameNet[_S],
    last_trained_iteration: torch.Tensor,
    config: TrainingConfiguration[_S],
    games_queue: "multiprocessing.Queue[SelfPlayGame[_S]]",
) -> None:
    game_manager = game_net.manager
    last_cached_iteration = 0

    # Use a uniform evaluator as the starting point
    def uniform_state_evaluator(state: _S) -> Tuple[float, List[float]]:
        legal_actions = set(game_manager.legal_actions(state))
        return (
            0.5,
            [
                1 / len(legal_actions) if action in legal_actions else 0.0
                for action in range(game_manager.num_actions)
            ],
        )

    state_evaluator: StateEvaluator[_S] = uniform_state_evaluator
    for i in range(config.num_games):
        # Recreate the cache if the network has been trained since
        # we last created the cache
        last_trained_iteration_value = cast(int, last_trained_iteration.item())
        if last_trained_iteration_value > last_cached_iteration:
            state_evaluator = cached_state_evaluator(game_net)
            last_cached_iteration = last_trained_iteration_value
        mcts = MCTS(
            game_manager,
            config.num_simulations,
            config.rollout_policy,
            state_evaluator,
            config.sample_move_cutoff,
            config.dirichlet_alpha,
            config.dirichlet_factor,
        )
        examples = []
        for state, next_state, action, visit_distribution in mcts.self_play():
            examples.append((state, visit_distribution))
        # The network uses a range of [-1, 1]
        outcome = (
            cast(float, game_manager.evaluate_final_state(next_state).value) * 2 - 1
        )
        games_queue.put(
            [
                (
                    state,
                    visit_distribution,
                    # Flip the sign of the outcome for the minimizing player
                    outcome if state.player == Player.max_player() else -outcome,
                )
                for state, visit_distribution in examples
            ]
        )
        if i % 100 == 0 and process_number == 0:
            print(f"{time.strftime('%H:%M:%S')} {i}")
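
# A minimal sketch of how these self-play workers might be launched, assuming
# a "spawn" multiprocessing context and a shared-memory tensor for the
# training-iteration counter. The surrounding training loop is not shown in
# this section, and the worker count here is a hypothetical choice.
import torch.multiprocessing as mp

def start_self_play_workers(
    game_net: GameNet[_S],
    config: TrainingConfiguration[_S],
    num_workers: int = 4,
) -> "Tuple[List[mp.Process], mp.Queue]":
    ctx = mp.get_context("spawn")
    games_queue: "mp.Queue[SelfPlayGame[_S]]" = ctx.Queue()
    # Shared-memory counter the trainer would increment after each training
    # step, telling the workers to rebuild their evaluation caches.
    last_trained_iteration = torch.tensor(0)
    last_trained_iteration.share_memory_()
    processes = []
    for process_number in range(num_workers):
        process = ctx.Process(
            target=create_self_play_examples,
            args=(process_number, game_net, last_trained_iteration, config, games_queue),
        )
        process.start()
        processes.append(process)
    return processes, games_queue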
def evaluate_models(
    model_dir: Path,
    net_class: Type[GameNet[_S]],
    manager: GameManager[_S],
    device: torch.device,
) -> None:
    # Pick the most recently saved model (highest iteration number in the filename)
    model_file = max(
        (f for f in model_dir.iterdir() if f.name.endswith(".tar")),
        key=lambda f: int(f.name[5:-4]),
    )
    model = net_class.from_path_full(str(model_file), manager).to(device)
    state_evaluator = cached_state_evaluator(model)
    # Agents with rollout shares of 0%, 20%, ..., 100%
    agents = [
        MCTSAgent(
            MCTS(
                manager,
                num_simulations=50,
                rollout_policy=(
                    lambda state: random.choice(manager.legal_actions(state))
                )
                if i > 0
                else None,
                state_evaluator=state_evaluator,
                rollout_share=i / 100,
            )
        )
        for i in range(0, 101, 20)
    ]
    result_dir = model_dir.parent.parent / "simple_rollouts"
    result_dir.mkdir(exist_ok=True)
    print(model_dir.name)
    results = tournament(agents, num_games=40, game_manager=manager)
    with open(result_dir / f"{model_dir.name}.json", "w") as f:
        json.dump(results, f)
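
# `tournament` is not shown in this section. A plausible round-robin shape for
# it, assuming it reuses `compare_agents` for each pairing (the result keys
# and structure here are assumptions, not the project's actual format):
def round_robin_tournament(
    agents: Sequence[MCTSAgent[_S]],
    num_games: int,
    game_manager: GameManager[_S],
) -> Dict[str, AgentComparison]:
    results = {}
    for i, first in enumerate(agents):
        for j, second in enumerate(agents):
            if i < j:
                # Each pairing is evaluated with the same head-to-head
                # helper used by evaluate() above.
                results[f"{i}-{j}"] = compare_agents(
                    (first, second), num_games, game_manager
                )
    return results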
def evaluate_models(
    model_dir: Path,
    net_class: Type[GameNet[_S]],
    manager: GameManager[_S],
    device: torch.device,
) -> None:
    print(model_dir.name)
    model_file = max(
        (f for f in model_dir.iterdir() if f.name.endswith(".tar")),
        key=lambda f: int(f.name[5:-4]),
    )
    time_per_move = 1.0
    for with_state_evaluator, dir_name in [
        (True, "with_state_evaluator"),
        (False, "without_state_evaluator"),
    ]:
        manager_copy = manager.copy()
        model = net_class.from_path_full(str(model_file), manager_copy).to(device)
        cached_model = cached_state_evaluator(model)
        # "Complex" rollouts: greedily follow the network's policy
        complex_agent = MCTSAgent(
            MCTS(
                manager_copy,
                num_simulations=float("inf"),  # type: ignore[arg-type]
                rollout_policy=lambda state: np.argmax(  # type: ignore[no-any-return]
                    cached_model(state)[1]
                ),
                state_evaluator=cached_model if with_state_evaluator else None,
                rollout_share=1.0,
                time_per_move=time_per_move,
            ),
            reset_fn=reset_caches,
        )
        manager_copy = manager.copy()
        state_evaluator: Optional[Callable[[_S], Tuple[float, Sequence[float]]]]
        if with_state_evaluator:
            model = net_class.from_path_full(str(model_file), manager_copy).to(device)
            state_evaluator = cached_state_evaluator(model)
        else:
            state_evaluator = None
        # "Simple" rollouts: uniformly random legal moves
        simple_agent = MCTSAgent(
            MCTS(
                manager_copy,
                num_simulations=float("inf"),  # type: ignore[arg-type]
                rollout_policy=lambda state: random.choice(
                    manager_copy.legal_actions(state)
                ),
                state_evaluator=state_evaluator,
                rollout_share=1.0,
                time_per_move=time_per_move,
            ),
            reset_fn=reset_caches,
        )
        agents = (complex_agent, simple_agent)
        result_dir = model_dir.parent.parent / "complex_rollouts" / dir_name
        result_dir.mkdir(exist_ok=True)
        results = compare_agents(agents, num_games=240, game_manager=manager)
        with open(result_dir / f"{model_dir.name}.json", "w") as f:
            json.dump(
                {
                    "results": results,
                    "complex_simulations": complex_agent.simulation_stats,
                    "simple_simulations": simple_agent.simulation_stats,
                },
                f,
            )
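
# `num_simulations=float("inf")` combined with `time_per_move` suggests the
# search loop is bounded by wall-clock time rather than a simulation count.
# A minimal sketch of that pattern; the `simulate` single-simulation step is
# hypothetical, standing in for the project's actual MCTS internals.
import time

def run_simulations(mcts: MCTS[_S], state: _S) -> int:
    deadline = time.perf_counter() + mcts.time_per_move
    simulations = 0
    while simulations < mcts.num_simulations and time.perf_counter() < deadline:
        mcts.simulate(state)  # hypothetical: one selection/rollout/backup pass
        simulations += 1
    return simulations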
class GTPInterface(ABC, Generic[_S]):
    commands: Dict[str, Callable[[List[str]], Optional[str]]]
    game_manager: GameManager[_S]
    mcts: MCTS[_S]
    board_size: int

    def __init__(self, board_size: int) -> None:
        self.commands = {
            "name": self.name,
            "version": self.version,
            "protocol_version": self.protocol_version,
            "known_command": self.known_command,
            "list_commands": self.list_commands,
            "quit": self.quit,
            "boardsize": self.boardsize,
            "clear_board": self.clear_board,
            "play": self.play,
            "genmove": self.genmove,
            "showboard": self.showboard,
            # "set_time": self.set_time,
            "result": self.result,
        }
        self.board_size = board_size
        # Note: the net is pinned to a specific GPU here
        self.net = self.get_game_net(board_size).to(torch.device("cuda:1"))
        self.game_manager = self.net.manager
        self.mcts = MCTS(
            self.game_manager,
            num_simulations=100,
            rollout_policy=None,
            state_evaluator=cached_state_evaluator(self.net),
        )

    def run_command(self, command: str) -> Optional[str]:
        command, *args = command.split()
        if command in self.commands:
            return self.commands[command](args)
        else:
            raise ValueError("invalid command")

    def name(self, args: List[str]) -> str:
        return "Deep MCTS"

    def version(self, args: List[str]) -> str:
        return "0.0.1"

    def protocol_version(self, args: List[str]) -> str:
        return "2"

    def known_command(self, args: List[str]) -> str:
        if len(args) != 1:
            raise ValueError("known_command takes 1 argument")
        command = args[0]
        known = command in self.commands
        return str(known).lower()

    def list_commands(self, args: List[str]) -> str:
        return "\n" + "\n".join(list(self.commands))

    def quit(self, args: List[str]) -> NoReturn:
        sys.exit(0)

    def boardsize(self, args: List[str]) -> None:
        # Allow 2 arguments because HexGui passes the board size twice
        if len(args) not in [1, 2]:
            raise ValueError("boardsize takes 1 argument")
        try:
            board_size = int(args[0])
        except ValueError:
            raise ValueError("invalid board size")
        if board_size < 1:
            raise ValueError("invalid board size")
        self.board_size = board_size
        self.clear_board([])

    def clear_board(self, args: List[str]) -> None:
        self.net = self.get_game_net(self.board_size)
        self.game_manager = self.net.manager
        self.mcts = MCTS(
            self.game_manager,
            num_simulations=100,
            rollout_policy=None,
            state_evaluator=cached_state_evaluator(self.net),
        )

    def play(self, args: List[str]) -> None:
        if len(args) != 2:
            raise ValueError("play takes 2 arguments")
        player = self.parse_player(args[0])
        action = self.parse_move(args[1], self.board_size)
        if action not in self.game_manager.legal_actions(self.mcts.state):
            raise ValueError("illegal move")
        actual_player = self.mcts.state.player
        if actual_player != player:
            self.mcts.state = dataclasses.replace(self.mcts.state, player=player)
        if __debug__ and action not in self.mcts.root.children:
            assert len(self.mcts.root.children) == 0
        self.mcts.root = self.mcts.root.children.get(action, Node())
        self.mcts.state = self.game_manager.generate_child_state(
            self.mcts.state, action
        )

    def genmove(self, args: List[str]) -> str:
        if len(args) != 1:
            raise ValueError("genmove takes 1 argument")
        player = self.parse_player(args[0])
        actual_player = self.mcts.state.player
        if actual_player != player:
            self.mcts.state = dataclasses.replace(self.mcts.state, player=player)
        action_probabilities, _ = self.mcts.step()
        # Print the raw network evaluation and the search distribution
        # to stderr for debugging
        value, net_action_probabilities = self.net.evaluate(self.mcts.state)
        print(value, file=sys.stderr)
        print(
            self.game_manager.probabilities_grid(net_action_probabilities),  # type: ignore
            file=sys.stderr,
        )
        print(file=sys.stderr)
        print(
            self.game_manager.probabilities_grid(action_probabilities),
            file=sys.stderr,
        )
        action = int(np.argmax(action_probabilities))
        self.mcts.root = self.mcts.root.children[action]
        self.mcts.state = self.mcts.game_manager.generate_child_state(
            self.mcts.state, action
        )
        return self.format_move(action, self.board_size)

    def showboard(self, args: List[str]) -> str:
        return f"\n{self.mcts.state}"

    def result(self, args: List[str]) -> str:
        return str(self.game_manager.evaluate_final_state(self.mcts.state))

    @staticmethod
    def parse_player(player: str) -> int:
        player = player.lower()
        if player == "w":
            player = "white"
        elif player == "b":
            player = "black"
        if player not in ["white", "black"]:
            raise ValueError("invalid player")
        return Player.FIRST if player == "black" else Player.SECOND

    @staticmethod
    @abstractmethod
    def parse_move(move: str, board_size: int) -> Action:
        ...

    @staticmethod
    def format_player(player: int) -> str:
        return "black" if player == Player.FIRST else "white"

    @staticmethod
    @abstractmethod
    def format_move(move: Action, board_size: int) -> str:
        ...

    @staticmethod
    @abstractmethod
    def get_game_net(board_size: int) -> GameNet[_S]:
        ...

    def start(self) -> None:
        while True:
            command = input("")
            if not command:
                continue
            try:
                result = self.run_command(command)
            except ValueError as e:
                print(f"? {e}\n")
            else:
                if result is None:
                    print("= \n")
                else:
                    print(f"= {result}\n")
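
# A concrete subclass only needs to provide the three abstract static methods.
# A minimal sketch for Hex, assuming actions are flat row-major indices
# (row * board_size + column) and GTP-style coordinates such as "a1"; the
# ConvolutionalHexNet constructor is a hypothetical stand-in for the real
# network-loading code.
class HexGTPInterface(GTPInterface["HexState"]):
    @staticmethod
    def parse_move(move: str, board_size: int) -> Action:
        # "a1" -> column 0, row 0
        move = move.lower()
        column = ord(move[0]) - ord("a")
        row = int(move[1:]) - 1
        if not (0 <= column < board_size and 0 <= row < board_size):
            raise ValueError("invalid move")
        return row * board_size + column

    @staticmethod
    def format_move(move: Action, board_size: int) -> str:
        row, column = divmod(move, board_size)
        return f"{chr(ord('a') + column)}{row + 1}"

    @staticmethod
    def get_game_net(board_size: int) -> GameNet["HexState"]:
        return ConvolutionalHexNet(board_size)  # hypothetical constructor


if __name__ == "__main__":
    HexGTPInterface(board_size=5).start()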