def run_model(args):
    """Evaluate a policy against a fixed opponent in TankEnv, printing a running
    [Player1 wins, Player2 wins, ties] score after every finished game."""
    env = TankEnv(args.game_path,
                  opp_fp_and_elo=[(args.opp, 1000)],
                  game_port=args.base_port,
                  my_port=args.my_port,
                  image_based=args.image_based,
                  level_path=args.level_path,
                  rand_opp=args.rand_opp,
                  p=args.env_p,
                  opp_p=args.opp_env_p)
    model = None
    if args.p1:
        model = PPO.load(args.p1)
    elif args.p1same:
        model = PPO.load(args.opp)

    score = [0, 0, 0]
    print("Score: [Player1 Wins, Player2 Wins, Ties]")

    obs = env.reset()
    if args.image_based and (args.ai_view or args.rev_ai_view):
        fig = plt.gcf()
        fig.show()
        fig.canvas.draw()
    while True:
        if args.image_based and (args.ai_view or args.rev_ai_view):
            if not args.rev_ai_view:
                plt.imshow(obs, origin="lower")
            else:
                plt.imshow(env.opp_state, origin="lower")
            fig.canvas.draw()
        if model:
            action, _ = model.predict(obs)
        elif args.rand_p1:
            action = np.random.rand(5) * 2 - 1
        else:
            action = np.zeros(5, dtype=np.float32)
        obs, reward, done, info = env.step(action)
        if done:
            score[info["winner"]] += 1
            print("Score:", score)
            obs = env.reset()
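
# A minimal sketch of driving run_model() directly. This is an assumed usage
# example, not the repo's real CLI: run_model() only needs an object exposing the
# attributes it reads, so a plain argparse.Namespace with placeholder values works.
# The helper name and every value below are hypothetical.
def demo_run_model(game_path, opponent_checkpoint):
    from argparse import Namespace
    run_model(Namespace(
        game_path=game_path,           # path to the game build (placeholder)
        opp=opponent_checkpoint,       # opponent PPO checkpoint (placeholder)
        base_port=50000, my_port=50001,
        image_based=False, level_path=None, rand_opp=False,
        env_p=3, opp_env_p=3,          # assumed difficulty settings
        p1=None, p1same=True,          # reuse the opponent checkpoint as player 1
        rand_p1=False, ai_view=False, rev_ai_view=False))
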
class AIMatchmaker(gym.Env):
    """Gym environment wrapper around TankEnv that picks a new opponent every
    episode and updates the learning agent's ELO rating from game outcomes."""
    metadata = {'render.modes': None}

    def __init__(self,
                 all_stats,
                 all_opps,
                 all_elos,
                 game_path,
                 model_dir,
                 base_port=50000,
                 my_port=50001,
                 image_based=False,
                 level_path=None,
                 env_p=3,
                 starting_elo=None,
                 K=16,
                 D=5.,
                 time_reward=-0.003,
                 matchmaking_mode=0,
                 elo_log_interval=10000,
                 win_loss_ratio=[0, 0]):
        super(AIMatchmaker, self).__init__()

        self.all_stats = combine_winrates(all_stats)
        self.all_opps = all_opps
        self.all_elos = all_elos
        self.model_dir = model_dir

        self.agent_elo = starting_elo if starting_elo is not None else self.all_elos[0]
        self.env = TankEnv(game_path,
                           opp_fp_and_elo=[],
                           game_port=base_port,
                           my_port=my_port,
                           image_based=image_based,
                           level_path=level_path,
                           p=env_p,
                           time_reward=time_reward)
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

        self.K = K
        self.D = D
        self.my_port = my_port
        self.mm = matchmaking_mode

        self.uncounted_games = np.array([0, 0], dtype=np.uint32)
        self.counted_game_sets = 0
        self.win_loss_ratio = np.array(win_loss_ratio, dtype=np.uint32)

        self.started = False
        self.next_opp()

        self.elo_log_interval = elo_log_interval
        self.num_steps = 0
        self.elo_log = []

    def next_opp(self):
        weights = np.zeros((len(self.all_elos)), dtype=np.float32)
        if self.mm == 1:
            # ELO-based matchmaking: opponents whose ELO is closer to the agent's ELO
            # are preferred (but not guaranteed)
            weights += np.array(
                [weight_func(elo - self.agent_elo, self.D) for elo in self.all_elos],
                dtype=np.float32)

        if any(self.win_loss_ratio):
            # Discount complete sets of games that already satisfy the target win:loss ratio
            while all(self.uncounted_games >= self.win_loss_ratio):
                self.uncounted_games -= self.win_loss_ratio
                self.counted_game_sets += 1

            tmp = self.uncounted_games >= self.win_loss_ratio
            if tmp[0] and not tmp[1]:
                # Need more losses
                if self.mm == 1:
                    # Zero out weights for opponents whose ELO is <= the agent's ELO
                    for i, elo in enumerate(self.all_elos):
                        if elo <= self.agent_elo:
                            weights[i] = 0
                    # Fall back to the opponent with the highest ELO if the agent's ELO is at or above every opponent's
                    if sum(weights) == 0:
                        weights[self.all_elos.index(max(self.all_elos))] = 1
                else:
                    # Equal probability for opponents whose ELO is > the agent's ELO
                    for i, elo in enumerate(self.all_elos):
                        if elo > self.agent_elo:
                            weights[i] = 1
                    # Fall back to the opponent with the highest ELO if the agent's ELO is at or above every opponent's
                    if sum(weights) == 0:
                        weights[self.all_elos.index(max(self.all_elos))] = 1
            elif not tmp[0] and tmp[1]:
                # Need more wins
                if self.mm == 1:
                    # Zero out weights for opponents whose ELO is >= the agent's ELO
                    for i, elo in enumerate(self.all_elos):
                        if elo >= self.agent_elo:
                            weights[i] = 0
                    # Fall back to the opponent with the lowest ELO if the agent's ELO is at or below every opponent's
                    if sum(weights) == 0:
                        weights[self.all_elos.index(min(self.all_elos))] = 1
                else:
                    # Equal probability for opponents whose ELO is < the agent's ELO
                    for i, elo in enumerate(self.all_elos):
                        if elo < self.agent_elo:
                            weights[i] = 1
                    # Fall back to the opponent with the lowest ELO if the agent's ELO is at or below every opponent's
                    if sum(weights) == 0:
                        weights[self.all_elos.index(min(self.all_elos))] = 1

        self.current_opp_idx = choice_with_normalization(
            list(range(len(self.all_elos))), weights)
        self.current_opp = self.all_opps[self.current_opp_idx]
        self.current_opp_elo = self.all_elos[self.current_opp_idx]
        #print("thread", self.my_port, "current opp elo:", self.current_opp_elo, "agent elo:", self.agent_elo, flush=True)
        self.env.load_new_opp(0, opp_fp(self.model_dir, self.current_opp),
                              self.current_opp_elo)

    def get_agent_elo(self):
        return self.agent_elo

    def reset(self):
        if self.started:
            last_winner = self.env.last_winner
            if last_winner == 0:
                win_rate = 1.
                self.uncounted_games[0] += 1
            elif last_winner == 1:
                win_rate = 0.
                self.uncounted_games[1] += 1
            else:
                win_rate = .5

            agent_elo_change, _ = elo_change(self.agent_elo,
                                             self.current_opp_elo, self.K,
                                             win_rate)
            self.agent_elo += int(agent_elo_change)
            #print("THREAD", self.my_port, "CURRENT AGENT ELO:", self.agent_elo, flush=True)
        else:
            self.started = True

        self.next_opp()
        return self.env.reset()

    def step(self, action):
        if self.num_steps % self.elo_log_interval == 0:
            self.elo_log.append(self.agent_elo)
        self.num_steps += 1
        return self.env.step(action)

    def render(self, mode='console'):
        raise NotImplementedError()

    def close(self):
        self.env.close()
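
# The helpers used above (combine_winrates, weight_func, choice_with_normalization,
# elo_change, opp_fp) are defined elsewhere in the repo. For reference, here is a
# plausible sketch of the two rating-related ones, matching the call signatures
# above but assuming the standard Elo update internally; the real implementations
# may differ.
def elo_change_sketch(rating_a, rating_b, k, score_a):
    # Expected score of player A under the standard Elo model.
    expected_a = 1.0 / (1.0 + 10.0 ** ((rating_b - rating_a) / 400.0))
    change_a = k * (score_a - expected_a)
    # Elo is zero-sum between the two players, so B's change mirrors A's.
    return change_a, -change_a

def choice_with_normalization_sketch(items, weights):
    # Sample one item proportionally to its weight; fall back to a uniform choice
    # when every weight is zero (e.g. matchmaking mode 0 with no ratio constraint).
    total = float(np.sum(weights))
    if total == 0:
        return np.random.choice(items)
    return np.random.choice(items, p=np.asarray(weights, dtype=np.float64) / total)
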
if args.image_based:
    model = PPO("CnnPolicy", env, n_steps=64)
else:
    model = PPO("MlpPolicy", env, n_steps=64)
    
print(model.policy)
  
try:
    if args.train:
        model.learn(total_timesteps=args.num_steps)
    else:
        obs = env.reset()
        if args.image_based and args.ai_view:
            fig = plt.gcf()
            fig.show()
            fig.canvas.draw()
        for _ in tqdm(range(args.num_steps)):
            if args.image_based and args.ai_view:
                plt.imshow(obs, origin="lower", interpolation='none')
                fig.canvas.draw()
            if model:
                action, _ = model.predict(obs)
            elif args.rand_p1:
                action = np.random.rand(5) * 2 - 1
            else:
                action = np.zeros(5, dtype=np.float32)
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
finally:
    env.close()
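
# For context, a hedged sketch of how the env fed to PPO above could be built from
# the AIMatchmaker wrapper defined earlier. The helper name and argument values are
# assumptions; this script's actual env construction is not shown in this excerpt.
def make_matchmaking_env(args, all_stats, all_opps, all_elos):
    return AIMatchmaker(all_stats, all_opps, all_elos,
                        args.game_path, args.model_dir,
                        base_port=args.base_port,
                        my_port=args.base_port + 1,
                        image_based=args.image_based,
                        level_path=args.level_path,
                        matchmaking_mode=1)  # prefer opponents with nearby ELOs
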
                game_port=args.base_port+port, 
                my_port=args.base_port+port+1,
                level_path=args.level_path,
                image_based=pop_stats[p_idx]["image_based"],
                p=pop_stats[p_idx]["env_p"],
                verbose=True
                )
                
            print("Worker", args.worker_idx, "got here", 4, flush=True)
                
            # Roll out args.N episodes against each opponent in the population,
            # recording image-observation trajectories for later analysis
            for i, opp in enumerate(tqdm(pop, file=sys.stdout)):
                env.load_new_opp(0, curr_model_path(args.local_pop_dir, opp, pop_stats[pop.index(opp)]), 0)
                for j in range(args.N):
                    obs = env.reset()
                    # Re-sample spawns until the agent starts on the requested side of the level
                    side = -1 if args.from_right else 1
                    while env.raw_state[0] * side > 0:
                        obs = env.reset()
                        
                    done = False
                    for k in range(args.max_len + 1):
                        traj_set[i, j, k, :, :, :] = obs
                        if done or k == args.max_len:
                            info_set[i, j] = k
                            break
                        else:
                            action, _ = p_model.predict(obs)
                            obs, _, done, _ = env.step(action)

            np.savez_compressed(args.model_dir + p + "/" + args.save_name, traj=traj_set, info=info_set)
        finally:
            env.close()
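
# Reading the archive back: np.savez_compressed stores the arrays under the keyword
# names used above ("traj" and "info") and appends ".npz" when the save name has no
# extension. The helper name is hypothetical; shapes assume image observations.
def load_trajectories(npz_path):
    data = np.load(npz_path)
    traj = data["traj"]  # (num_opponents, args.N, args.max_len + 1, H, W, C) observations
    info = data["info"]  # (num_opponents, args.N) index of the final recorded step per episode
    return traj, info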