from typing import Any, Dict

import numpy as np
import torch

# Agent, Game, GameState, GameStateEncoder, Move, and MonteCarloTreeSearch
# are project-local classes, assumed importable from this package.


class AlphaZeroArgMaxAgent(Agent):
    """Plays the move favoured by a zero-temperature MCTS policy."""

    def __init__(self, game: Game, state_encoder: GameStateEncoder,
                 nn: torch.nn.Module, config: Dict[str, Any]) -> None:
        super().__init__()
        self.game = game
        self.mcts = MonteCarloTreeSearch(game=game, state_encoder=state_encoder,
                                         nn=nn, config=config)

    def select_move(self, state: GameState) -> Move:
        # With temperature=0 the policy concentrates all probability mass on
        # the most visited action, so sampling from it is effectively an argmax.
        policy = self.mcts.get_policy(state, temperature=0)
        move_index = np.random.choice(self.game.action_space_size, p=policy)
        return self.game.index_to_move(move_index)

    def reset(self) -> None:
        # Discard the accumulated search tree between games.
        self.mcts.reset()
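Why sampling still yields an argmax here: in typical AlphaZero implementations, a temperature of 0 turns the visit-count policy into a one-hot vector, so np.random.choice can only ever return the hot index. A small self-contained check; the 9-action policy below is fabricated for illustration:

import numpy as np

# Fabricated zero-temperature policy for a 3x3 board (9 actions):
# all probability mass on action 4, as get_policy(state, temperature=0) would produce.
policy = np.zeros(9)
policy[4] = 1.0

# Sampling from a one-hot distribution is deterministic.
samples = {np.random.choice(9, p=policy) for _ in range(1000)}
assert samples == {4}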
import os
import random

import torch

# TicTacToeGame, TicTacToePlayer, TicTacToeMove, TicTacToeStateEncoder,
# dual_resnet, MonteCarloTreeSearch, and config are assumed to come from
# the project's own modules.

config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_move(player: TicTacToePlayer) -> TicTacToeMove:
    # Expects two whitespace-separated integers, e.g. "1 2".
    x, y = input(f"{player.name} move: ").split()
    return TicTacToeMove(int(x), int(y))


if __name__ == '__main__':
    game = TicTacToeGame(config['game_size'])
    state_encoder = TicTacToeStateEncoder(config['device'])
    net = dual_resnet(game, config)
    mcts = MonteCarloTreeSearch(game=game, state_encoder=state_encoder,
                                nn=net, config=config)
    # map_location keeps a CUDA-trained checkpoint loadable on CPU-only machines.
    net.load_state_dict(torch.load(os.path.join('pretrained', 'ttt_dualres_comp.pth'),
                                   map_location=config['device']))
    # net.load_state_dict(torch.load(os.path.join(config['log_dir'], 'best.pth')))
    net.eval()

    agent = AlphaZeroArgMaxAgent(game, state_encoder, net, config)
    agent_role = random.choice([TicTacToePlayer.X, TicTacToePlayer.O])

    while not game.is_over:
        game.show_board()
        # print(f"current state score by eval func: {agent.eval_fn(game.state, agent.player)}")
        # The human enters a move whenever it is not the agent's turn.
        if game.current_player != agent_role:
            move = read_move(game.current_player)
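The script is cut off above, mid-loop. A hedged sketch of how it presumably continues: agent.select_move comes from AlphaZeroArgMaxAgent above, while game.play is an assumed name for the Game method that applies a move, since the real API is not shown in this excerpt.

        else:
            # On the agent's turn, take the move favoured by zero-temperature MCTS.
            move = agent.select_move(game.state)
        game.play(move)  # assumed: applies the move and switches current_player

    game.show_board()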