Example #1
def __init__(self, board_size: int) -> None:
    self.commands = {
        "name": self.name,
        "version": self.version,
        "protocol_version": self.protocol_version,
        "known_command": self.known_command,
        "list_commands": self.list_commands,
        "quit": self.quit,
        "boardsize": self.boardsize,
        "clear_board": self.clear_board,
        "play": self.play,
        "genmove": self.genmove,
        "showboard": self.showboard,
        # "set_time": self.set_time,
        "result": self.result,
    }
    self.board_size = board_size
    self.net = self.get_game_net(board_size).to(torch.device("cuda:1"))
    self.game_manager = self.net.manager
    self.mcts = MCTS(
        self.game_manager,
        num_simulations=100,
        rollout_policy=None,
        state_evaluator=cached_state_evaluator(self.net),
    )
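Every example on this page passes cached_state_evaluator(net) to MCTS as its state_evaluator. As orientation, the helper can be read as a memoizing wrapper around GameNet.evaluate. The sketch below is only an illustration under that assumption (a plain dictionary cache, hashable states); it is not the deep-mcts project's actual implementation.

from typing import Dict, List, Tuple, TypeVar

_S = TypeVar("_S")


def cached_state_evaluator(net: "GameNet[_S]") -> "StateEvaluator[_S]":
    # Illustrative sketch only; the real deep-mcts helper may differ.
    # Memoizes net.evaluate(state) so repeated MCTS visits to the same
    # state reuse a single forward pass (assumes states are hashable).
    cache: Dict[_S, Tuple[float, List[float]]] = {}

    def evaluate(state: _S) -> Tuple[float, List[float]]:
        if state not in cache:
            cache[state] = net.evaluate(state)
        return cache[state]

    return evaluate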
Example #2
def clear_board(self, args: List[str]) -> None:
    self.net = self.get_game_net(self.board_size)
    self.game_manager = self.net.manager
    self.mcts = MCTS(
        self.game_manager,
        num_simulations=100,
        rollout_policy=None,
        state_evaluator=cached_state_evaluator(self.net),
    )
Example #3
File: train.py Project: henribru/deep-mcts
def evaluate(
    game_net: GameNet[_S],
    previous_game_net: GameNet[_S],
    game_manager: GameManager[_S],
    config: TrainingConfiguration[_S],
) -> Tuple[AgentComparison, AgentComparison]:
    state_evaluator = cached_state_evaluator(game_net)
    previous_state_evaluator = cached_state_evaluator(previous_game_net)
    random_mcts_evaluation = compare_agents(
        (
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    state_evaluator=state_evaluator,
                )),
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations * 2,
                    lambda s: random.choice(game_manager.legal_actions(s)),
                    state_evaluator=None,
                )),
        ),
        config.evaluation_games,
        game_manager,
    )
    previous_evaluation = compare_agents(
        (
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    state_evaluator,
                ),
                epsilon=config.epsilon,
            ),
            MCTSAgent(
                MCTS(
                    game_manager,
                    config.num_simulations,
                    config.rollout_policy,
                    previous_state_evaluator,
                ),
                epsilon=config.epsilon,
            ),
        ),
        config.evaluation_games,
        game_manager,
    )
    return random_mcts_evaluation, previous_evaluation
Example #4
File: train.py Project: henribru/deep-mcts
def create_self_play_examples(
    process_number: int,
    game_net: GameNet[_S],
    last_trained_iteration: torch.Tensor,
    config: TrainingConfiguration[_S],
    games_queue: "multiprocessing.Queue[SelfPlayGame[_S]]",
) -> None:
    game_manager = game_net.manager
    last_cached_iteration = 0

    # Use a uniform evaluator as the starting point
    def uniform_state_evaluator(state: _S) -> Tuple[float, List[float]]:
        legal_actions = set(game_manager.legal_actions(state))
        return (
            0.5,
            [
                1 / len(legal_actions) if action in legal_actions else 0.0
                for action in range(game_manager.num_actions)
            ],
        )

    state_evaluator: StateEvaluator[_S] = uniform_state_evaluator
    for i in range(config.num_games):
        # Recreate the cache if the network has been trained since
        # we last created the cache
        last_trained_iteration_value = cast(int, last_trained_iteration.item())
        if last_trained_iteration_value > last_cached_iteration:
            state_evaluator = cached_state_evaluator(game_net)
            last_cached_iteration = last_trained_iteration_value
        mcts = MCTS(
            game_manager,
            config.num_simulations,
            config.rollout_policy,
            state_evaluator,
            config.sample_move_cutoff,
            config.dirichlet_alpha,
            config.dirichlet_factor,
        )
        examples = []
        for state, next_state, action, visit_distribution in mcts.self_play():
            examples.append((state, visit_distribution))
        # The network uses a range of [-1, 1]
        outcome = (
            cast(float,
                 game_manager.evaluate_final_state(next_state).value) * 2 - 1)
        games_queue.put([(
            state,
            visit_distribution,
            outcome if state.player == Player.max_player() else -outcome,
        ) for state, visit_distribution in examples])
        if i % 100 == 0 and process_number == 0:
            print(f"{time.strftime('%H:%M:%S')} {i}")
Example #5
def evaluate_models(
    model_dir: Path,
    net_class: Type[GameNet[_S]],
    manager: GameManager[_S],
    device: torch.device,
) -> None:
    model_file = max(
        (f for f in model_dir.iterdir() if f.name.endswith(".tar")),
        key=lambda f: int(f.name[5:-4]),
    )

    model = net_class.from_path_full(str(model_file), manager).to(device)
    state_evaluator = cached_state_evaluator(model)
    agents = [
        MCTSAgent(
            MCTS(
                manager,
                num_simulations=50,
                rollout_policy=(
                    lambda state: random.choice(manager.legal_actions(state))
                )
                if i > 0
                else None,
                state_evaluator=state_evaluator,
                rollout_share=i / 100,
            )
        )
        for i in range(0, 101, 20)
    ]

    result_dir = model_dir.parent.parent / "simple_rollouts"
    result_dir.mkdir(exist_ok=True)
    print(model_dir.name)
    results = tournament(agents, num_games=40, game_manager=manager)
    with open(result_dir / f"{model_dir.name}.json", "w") as f:
        json.dump(results, f)
Example #6
def evaluate_models(
    model_dir: Path,
    net_class: Type[GameNet[_S]],
    manager: GameManager[_S],
    device: torch.device,
) -> None:
    print(model_dir.name)
    model_file = max(
        (f for f in model_dir.iterdir() if f.name.endswith(".tar")),
        key=lambda f: int(f.name[5:-4]),
    )
    time_per_move = 1.0
    for with_state_evaluator, dir_name in [
        (True, "with_state_evaluator"),
        (False, "without_state_evaluator"),
    ]:
        manager_copy = manager.copy()
        model = net_class.from_path_full(str(model_file), manager_copy).to(device)
        cached_model = cached_state_evaluator(model)
        complex_agent = MCTSAgent(
            MCTS(
                manager_copy,
                num_simulations=float("inf"),  # type: ignore[arg-type]
                rollout_policy=lambda state: np.argmax(  # type: ignore[no-any-return]
                    cached_model(state)[1]
                ),
                state_evaluator=cached_model if with_state_evaluator else None,
                rollout_share=1.0,
                time_per_move=time_per_move,
            ),
            reset_fn=reset_caches,
        )
        manager_copy = manager.copy()
        state_evaluator: Optional[Callable[[_S], Tuple[float, Sequence[float]]]]
        if with_state_evaluator:
            model = net_class.from_path_full(str(model_file), manager_copy).to(device)
            state_evaluator = cached_state_evaluator(model)
        else:
            state_evaluator = None

        simple_agent = MCTSAgent(
            MCTS(
                manager_copy,
                num_simulations=float("inf"),  # type: ignore[arg-type]
                rollout_policy=lambda state: random.choice(
                    manager_copy.legal_actions(state)
                ),
                state_evaluator=state_evaluator,
                rollout_share=1.0,
                time_per_move=time_per_move,
            ),
            reset_fn=reset_caches,
        )
        agents = (complex_agent, simple_agent)
        result_dir = model_dir.parent.parent / "complex_rollouts" / dir_name
        result_dir.mkdir(exist_ok=True)
        results = compare_agents(agents, num_games=240, game_manager=manager)

        with open(result_dir / f"{model_dir.name}.json", "w") as f:
            json.dump(
                {
                    "results": results,
                    "complex_simulations": complex_agent.simulation_stats,
                    "simple_simulations": simple_agent.simulation_stats,
                },
                f,
            )
Example #7
class GTPInterface(ABC, Generic[_S]):
    commands: Dict[str, Callable[[List[str]], Optional[str]]]
    game_manager: GameManager[_S]
    mcts: MCTS[_S]
    board_size: int

    def __init__(self, board_size: int) -> None:
        self.commands = {
            "name": self.name,
            "version": self.version,
            "protocol_version": self.protocol_version,
            "known_command": self.known_command,
            "list_commands": self.list_commands,
            "quit": self.quit,
            "boardsize": self.boardsize,
            "clear_board": self.clear_board,
            "play": self.play,
            "genmove": self.genmove,
            "showboard": self.showboard,
            # "set_time": self.set_time,
            "result": self.result,
        }
        self.board_size = board_size
        self.net = self.get_game_net(board_size).to(torch.device("cuda:1"))
        self.game_manager = self.net.manager
        self.mcts = MCTS(
            self.game_manager,
            num_simulations=100,
            rollout_policy=None,
            state_evaluator=cached_state_evaluator(self.net),
        )

    def run_command(self, command: str) -> Optional[str]:
        command, *args = command.split()
        if command in self.commands:
            return self.commands[command](args)
        else:
            raise ValueError("invalid command")

    def name(self, args: List[str]) -> str:
        return "Deep MCTS"

    def version(self, args: List[str]) -> str:
        return "0.0.1"

    def protocol_version(self, args: List[str]) -> str:
        return "2"

    def known_command(self, args: List[str]) -> str:
        if len(args) != 1:
            raise ValueError("known_command takes 1 argument")
        command = args[0]
        known = command in self.commands
        return str(known).lower()

    def list_commands(self, args: List[str]) -> str:
        return "\n" + "\n".join(list(self.commands))

    def quit(self, args: List[str]) -> NoReturn:
        sys.exit(0)

    def boardsize(self, args: List[str]) -> None:
        # allow 2 arguments because HexGui passes the board size twice
        if len(args) not in [1, 2]:
            raise ValueError("boardsize takes 1 argument")
        try:
            board_size = int(args[0])
        except ValueError:
            raise ValueError("invalid board size")
        if board_size < 1:
            raise ValueError("invalid board size")
        self.board_size = board_size
        self.clear_board([])

    def clear_board(self, args: List[str]) -> None:
        self.net = self.get_game_net(self.board_size)
        self.game_manager = self.net.manager
        self.mcts = MCTS(
            self.game_manager,
            num_simulations=100,
            rollout_policy=None,
            state_evaluator=cached_state_evaluator(self.net),
        )

    def play(self, args: List[str]) -> None:
        if len(args) != 2:
            raise ValueError("play takes 2 arguments")
        player = self.parse_player(args[0])
        action = self.parse_move(args[1], self.board_size)
        if action not in self.game_manager.legal_actions(self.mcts.state):
            raise ValueError("illegal move")
        actual_player = self.mcts.state.player
        if actual_player != player:
            self.mcts.state = dataclasses.replace(self.mcts.state,
                                                  player=player)
        if __debug__ and action not in self.mcts.root.children:
            assert len(self.mcts.root.children) == 0
        self.mcts.root = self.mcts.root.children.get(action, Node())
        self.mcts.state = self.game_manager.generate_child_state(
            self.mcts.state, action)

    def genmove(self, args: List[str]) -> str:
        if len(args) != 1:
            raise ValueError("play takes 1 argument")
        player = self.parse_player(args[0])
        actual_player = self.mcts.state.player
        if actual_player != player:
            self.mcts.state = dataclasses.replace(self.mcts.state,
                                                  player=player)

        action_probabilities, _ = self.mcts.step()
        value, net_action_probabilities = self.net.evaluate(self.mcts.state)
        print(value, file=sys.stderr)
        print(
            self.game_manager.probabilities_grid(
                net_action_probabilities),  # type: ignore
            file=sys.stderr,
        )
        print(file=sys.stderr)
        print(self.game_manager.probabilities_grid(action_probabilities),
              file=sys.stderr)
        action = np.argmax(action_probabilities)
        self.mcts.root = self.mcts.root.children[action]
        self.mcts.state = self.mcts.game_manager.generate_child_state(
            self.mcts.state, action)
        return self.format_move(action, self.board_size)

    def showboard(self, args: List[str]) -> str:
        return f"\n{self.mcts.state}"

    def result(self, args: List[str]) -> str:
        return str(self.game_manager.evaluate_final_state(self.mcts.state))

    @staticmethod
    def parse_player(player: str) -> int:
        player = player.lower()
        if player == "w":
            player = "white"
        elif player == "b":
            player = "black"
        if player not in ["white", "black"]:
            raise ValueError("invalid player")
        return Player.FIRST if player == "black" else Player.SECOND

    @staticmethod
    @abstractmethod
    def parse_move(move: str, board_size: int) -> Action:
        ...

    @staticmethod
    def format_player(player: int) -> str:
        return "black" if player == Player.FIRST else "white"

    @staticmethod
    @abstractmethod
    def format_move(move: Action, board_size: int) -> str:
        ...

    @staticmethod
    @abstractmethod
    def get_game_net(board_size: int) -> GameNet[_S]:
        ...

    def start(self) -> None:
        while True:
            command = input("")
            if not command:
                continue
            try:
                result = self.run_command(command)
            except ValueError as e:
                print(f"? {e}\n")
            else:
                if result is None:
                    print("= \n")
                else:
                    print(f"= {result}\n")