Example no. 1
def sample(env, batch_size, prob):
    # Define the player
    player_id = 1
    # Set up the player here. We use the built-in SimpleAi, which picks its own actions.
    player = wimblepong.SimpleAi(env, player_id)

    samples = []
    i = 0

    # Run until enough samples have been collected
    while True:
        done = False
        while not done:
            action1 = player.get_action()
            ob1, rew1, done, info = env.step(action1)
            if args.housekeeping and np.random.uniform() > prob:
                samples.append([preprocess(ob1), env.ball.x, env.player2.y])
            if not args.headless:
                env.render()
            if done:
                plt.close()  # Hides game window

                print("episode {} over.".format(i))
                if i % 5 == 4:
                    # env.switch_sides()  # intentionally not switching sides
                    print("Current samples: ", len(samples))

                env.reset()
        if len(samples) >= batch_size:
            break
        i += 1
    print("Sampling done")

    random.shuffle(samples)
    return samples[:batch_size]
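Note: the preprocess helper called above is not included in this snippet. A minimal sketch, assuming it mirrors prepro() from Example no. 7 (normalize, 2x downsample, and grayscale the 200x200x3 observation); the author's actual version may differ:

import numpy as np

def preprocess(observation):
    # Assumed behaviour, modelled on prepro() in Example no. 7.
    s = observation / 255.0          # scale pixel values to [0, 1]
    s = s[::2, ::2].mean(axis=-1)    # 200x200x3 -> 100x100 grayscale
    return np.reshape(s, (1, 100, 100))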
Example no. 2
    def __init__(self):
        self.eps = np.finfo(np.float32).eps.item()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.env = gym.make("WimblepongVisualMultiplayer-v0")
        self.env.unwrapped.scale, self.env.unwrapped.fps = 1, 30
        self.opponent = wimblepong.SimpleAi(self.env, 2)
        self.env.set_names('player', self.opponent.get_name())

        self.desired_frame_size = (1, 50, 50)
Example no. 3
def test(render=False, checkpoint='model.mdl'):
    env = gym.make("WimblepongVisualMultiplayer-v0")
    env.unwrapped.scale, env.unwrapped.fps = 1, 30

    player = Agent()
    player.load_model(checkpoint)
    opponent = wimblepong.SimpleAi(env, 2)
    env.set_names(player.get_name(), opponent.get_name())

    steps, episode = 0, 0
    data, run_avg_plot = [], []
    nice_plot = []
    while episode <= 12:
        (observation, observation2) = env.reset()  # 200 x 200 x 3 each obs
        cur_state = prepro(observation)
        prev_state = None
        while True:
            if render:
                env.render()
            x = cur_state - prev_state if prev_state is not None else cur_state
            cv2.imshow('x ', np.reshape(x, (100, 100, 1)))
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()

            x = torch.tensor(x).to(player.train_device)
            action = player.get_action(x)
            action2 = opponent.get_action()
            (observation, observation2), (r1, r2), done, info = env.step(
                (action, action2))

            reward = np.sign(r1)
            prev_state = cur_state
            cur_state = prepro(observation)
            player.store_info(s=x, r=reward)

            if done:
                break

        R_sum = np.sum(player.rewards)
        print("Total reward for episode {}: {}".format(episode, R_sum))
        data.append(1 if R_sum > 0 else 0)
        nice_plot.append(1 if R_sum > 0 else 0)
        run_avg_plot.append(np.mean(data))
        player.reset()
        episode += 1

    plt.figure(figsize=(12, 10))
    plt.plot(data, label='Win rate')
    plt.plot(run_avg_plot, label='Win rate avg', linewidth=5.0)
    plt.legend()
    plt.show()
    plt.figure(figsize=(12, 10))
    print('Avg win rate:{}'.format(np.average(data)))
    print('nice_plot ', np.shape(nice_plot))
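Note: a minimal entry point for the test() function above; the checkpoint path shown is simply the function's default argument:

if __name__ == '__main__':
    # Evaluate the saved checkpoint with rendering enabled.
    test(render=True, checkpoint='model.mdl')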
Example no. 4
# check if gpu is available
print("Cuda:", torch.cuda.is_available())
print("Start Training")

# saving folders
model_path = "./train_weights/"
plot_data_path = "./plot_data/"

# Define players and opponents
player_id = 1
opponent1_id = 2
opponent2_id = 3

player = Agent()
simple_opponent = wimblepong.SimpleAi(env, opponent1_id)
complex_opponent = Agent()

# load existing model for player
if args.load_model_path:
    player.load_model(path=args.load_model_path)

# load existing model for complex opponent
if args.load_model_path_opponent:
    complex_opponent.load_model(path=args.load_model_path_opponent)

# initialize variables
episodes = 2000000
wins = 0
frames_seen = 0
scores = [0 for _ in range(100)]
Example no. 5
                    type=int,
                    help="Scale of the rendered game",
                    default=4)
args = parser.parse_args()

# Make the environment
env = gym.make("WimblepongVisualMultiplayer-v0")
env.unwrapped.scale = args.scale
env.unwrapped.fps = args.fps
# Number of episodes/games to play
episodes = 100000

# Define the player IDs for both SimpleAI agents
player_id = 1
opponent_id = 3 - player_id
opponent = wimblepong.SimpleAi(env, opponent_id)
player = Agent(input_shape=(1, 84, 84),
               num_actions=3,
               network_fn=DuelingDQN,
               network_fn_kwargs=None,
               minibatch_size=128,
               replay_memory_size=500000,
               stack_size=4,
               gamma=0.98,
               beta0=0.9,
               beta1=0.999,
               learning_rate=1e-4,
               device='cuda',
               normalize=False,
               noisy=False,
               prioritized=True)
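Note: the snippet stops after constructing the Agent. A minimal sketch of one interaction step, assuming the agent exposes get_name(), get_action(state, epsilon=...) and store_transition(state, action, next_state, reward, done) as in the training loops of Examples no. 9 and 12; the epsilon value here is arbitrary:

env.set_names(player.get_name(), opponent.get_name())

(state, opp_state) = env.reset()
done = False
while not done:
    # Epsilon-greedy action for the learner, scripted action for SimpleAi.
    action = player.get_action(state, epsilon=0.1)
    opp_action = opponent.get_action(opp_state)
    # Step the environment and get the rewards and new observations.
    (next_state, opp_next_state), (reward, _), done, info = env.step(
        (action, opp_action))
    # Store the transition for replay (interface as in Example no. 9).
    player.store_transition(state, action, next_state, reward, done)
    state, opp_state = next_state, opp_next_state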
Example no. 6
                    help="Scale of the rendered game",
                    default=1)
args = parser.parse_args()

# Make the environment
env = gym.make("WimblepongSimpleAI-v0")
env.unwrapped.scale = args.scale
env.unwrapped.fps = args.fps

# Number of episodes/games to play
episodes = 100000

# Define the player
player_id = 1
# Set up the player here. We use SimpleAi, but its actions are not used in this example
player = wimblepong.SimpleAi(env, player_id)

# Housekeeping
states = []
win1 = 0

for i in range(0, episodes):
    done = False
    while not done:
        # action1 is zero because in this example no agent is controlling player 1
        action1 = 0  #player.get_action()
        ob1, rew1, done, info = env.step(action1)
        if args.housekeeping:
            states.append(ob1)
        # Count the wins
        if rew1 == 10:
            win1 += 1
Example no. 7
import gym
import wimblepong
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import cv2

eps = np.finfo(np.float32).eps.item()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

env = gym.make("WimblepongVisualMultiplayer-v0")
env.unwrapped.scale, env.unwrapped.fps = 1, 30
opponent = wimblepong.SimpleAi(env, 2)
env.set_names('player', opponent.get_name())

desired_frame_size = (1, 100, 100)


def prepro(image):
    s = image / 255.0  # normalize
    s = s[::2, ::2].mean(axis=-1)  # 100x100
    s = np.expand_dims(s, axis=-1)  # 100x100x1
    state = np.reshape(s, desired_frame_size)  #(1, 100, 100)
    return state


def discount_rewards(r, gamma=0.99):
    discounted_r = np.zeros_like(r)
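    # NOTE: the snippet is cut off here. The lines below are a sketch of the
    # standard discounted-return recursion, G_t = r_t + gamma * G_{t+1},
    # not the author's original function body.
    running_add = 0.0
    for t in reversed(range(len(r))):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r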
parser.add_argument("--headless", action="store_true", help="Run in headless mode")
parser.add_argument("--fps", type=int, help="FPS for rendering", default=30)
parser.add_argument("--scale", type=int, help="Scale of the rendered game", default=1)
args = parser.parse_args()

# Make the environment
env = gym.make("WimblepongVisualMultiplayer-v0")
env.unwrapped.scale = args.scale
env.unwrapped.fps = args.fps
# Number of episodes/games to play
episodes = 100000

# Define the player IDs for both SimpleAI agents
player_id = 1
opponent_id = 3 - player_id
opponent = wimblepong.SimpleAi(env, opponent_id)
player = wimblepong.SimpleAi(env, player_id)

# Set the names for both SimpleAIs
env.set_names(player.get_name(), opponent.get_name())

win1 = 0
for i in range(0,episodes):
    done = False
    while not done:
        # Get the actions from both SimpleAIs
        action1 = player.get_action()
        action2 = opponent.get_action()
        # Step the environment and get the rewards and new observations
        (ob1, ob2), (rew1, rew2), done, info = env.step((action1, action2))
        #print(ob1, ob2, ob1.shape, ob2.shape)
Example no. 9
def training_loop(submit_config,
                  num_episodes,
                  target_epsilon,
                  beta_0,
                  reach_target_at_frame,
                  player_id,
                  start_training_at_frame,
                  target_update_freq,
                  model_update_freq,
                  save_every_n_ep,
                  log_freq,
                  agent_config,
                  network_fn_kwargs,
                  clip_reward=False,
                  run_description='',
                  render=False):
    """Training loop for Pong agents."""
    run_dir = submit_config.run_dir

    # Make the environment
    env = make_pong_environment()

    # Set up the agent.
    agent = Agent(**agent_config, network_fn_kwargs=network_fn_kwargs)

    # Setup the opponent.
    opponent_id = 2
    opponent = wimblepong.SimpleAi(env, opponent_id)

    # Set the names for both SimpleAIs
    env.set_names(agent.get_name(), opponent.get_name())

    # Setup directories for models and logging.
    model_dir = os.path.join(run_dir, 'models')
    log_dir = os.path.join(run_dir, 'logs')

    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Initialize summary writer.
    writer = SummaryWriter(log_dir=log_dir, comment=run_description)

    # Initialize file for logging KPIs.
    model_perf_file = os.path.join(run_dir, 'model_perf.txt')
    with open(model_perf_file, 'w') as f:
        f.write('ep,mean_rewards,mean_ep_length,mean_wr\n')

    # Housekeeping
    max_reward = 10.0 if not clip_reward else 1.0
    wins = 0
    frames_seen = 0
    game_results = []
    reward_sums = []
    ep_lengths = []

    for ep in range(0, num_episodes):
        # Reset the Pong environment
        (agent_state, opp_state) = env.reset()
        done = False
        step = 0
        losses = []
        reward_sum = 0.0

        # Compute new epsilon and beta.
        epsilon = epsilon_schedule(frames_seen, target_epsilon,
                                   reach_target_at_frame)
        beta = beta_schedule(frames_seen, beta_0, reach_target_at_frame)

        start = time.time()
        while not done:
            # Get actions from agent and opponent.
            agent_action = agent.get_action(agent_state, epsilon=epsilon)
            opp_action = opponent.get_action(opp_state)

            # Step the environment and get the rewards and new observations
            (agent_next_state,
             opp_next_state), (agent_reward, _), done, info = env.step(
                 (agent_action, opp_action))

            # Clip reward.
            if clip_reward:
                agent_reward = max(-1., min(1., agent_reward))

            # Store transitions.
            agent.store_transition(agent_state, agent_action, agent_next_state,
                                   agent_reward, done)

            # Check whether there are enough frames to start training.
            if frames_seen > start_training_at_frame:
                if frames_seen % model_update_freq == model_update_freq - 1:
                    # Update policy network.
                    loss = agent.compute_loss(beta=beta)

                    # Update EMA network.
                    agent.update_ema_policy()

                if frames_seen % target_update_freq == target_update_freq - 1:  # Update target network.
                    agent.update_target_network()

            # Count the wins. Won't work with discounting.
            if agent_reward == max_reward:
                wins += 1
                game_results.append(1)

            if agent_reward == -max_reward:
                game_results.append(0)

            if render:
                env.render()

            if frames_seen > start_training_at_frame:
                if frames_seen % model_update_freq == model_update_freq - 1:
                    losses.append(loss)
            else:
                losses.append(0)

            agent_state = agent_next_state
            opp_state = opp_next_state
            reward_sum += agent_reward
            step += 1
            frames_seen += 1

        reward_sums.append(reward_sum)
        ep_lengths.append(step)
        elapsed_time = time.time() - start
        print(
            'buf_count %i, episode %i, end frame %i, tot. frames %i, eps %0.2f, wins %i, losses %i, %gs'
            % (agent.memory.count, ep, step, frames_seen, epsilon, wins,
               ep + 1 - wins, elapsed_time))

        # Log progress.
        if ep % log_freq == 0:
            # Write scalars.
            writer.add_scalar('Progress/Epsilon', epsilon, frames_seen)
            writer.add_scalar('Progress/Frames', frames_seen, frames_seen)

            if ep < 100:  # Log results and rewards from last n games.
                last_n_results = game_results
                last_n_reward_sums = reward_sums
                last_n_ep_lengths = ep_lengths
            else:
                last_n_results = game_results[-100:]
                last_n_reward_sums = reward_sums[-100:]
                last_n_ep_lengths = ep_lengths[-100:]

            cur_win_rate = np.mean(last_n_results)
            mean_rewards = np.mean(last_n_reward_sums)
            mean_ep_length = np.mean(last_n_ep_lengths)
            writer.add_scalar('Progress/Cumulative-reward', mean_rewards, ep)
            writer.add_scalar('Progress/Win-rate', cur_win_rate, ep)
            writer.add_scalar('Episode/Average-episode-length', mean_ep_length,
                              ep)
            writer.add_scalar('Episode/Loss', np.mean(losses), ep)

            # Show random batch of states.
            (state_batch, _, _, _, _), _, _ = agent.memory.sample_batch(5)
            n, c, h, w = state_batch.shape
            state_batch = state_batch.reshape(n * c, h, w)[:, None, :, :]
            writer.add_images('ReplayBuffer/Sample states', state_batch, ep)

        # Reset agent's internal state.
        agent.reset()

        if ep % save_every_n_ep == 0:
            torch.save(
                agent.policy_net.state_dict(),
                os.path.join(
                    model_dir,
                    'agent_%s_ep%i.mdl' % (agent_config.network_name, ep)))
            torch.save(
                agent.policy_net_ema.state_dict(),
                os.path.join(
                    model_dir,
                    'ema_agent_%s_ep%i.mdl' % (agent_config.network_name, ep)))

            perf_str = '%i,%g,%g,%g\n' % (ep, mean_rewards, mean_ep_length,
                                          cur_win_rate)
            with open(model_perf_file, 'a') as f:
                f.write(perf_str)
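Note: epsilon_schedule() and beta_schedule() are used above but not shown. A minimal linear-annealing sketch, assuming epsilon decays from 1.0 to target_epsilon and the prioritized-replay exponent anneals from beta_0 to 1.0 over reach_target_at_frame frames; the author's actual schedules may differ:

def epsilon_schedule(frame, target_epsilon, reach_target_at_frame, start_epsilon=1.0):
    # Linear decay from start_epsilon down to target_epsilon.
    fraction = min(frame / reach_target_at_frame, 1.0)
    return start_epsilon + fraction * (target_epsilon - start_epsilon)


def beta_schedule(frame, beta_0, reach_target_at_frame):
    # Linear annealing of the importance-sampling exponent up to 1.0.
    fraction = min(frame / reach_target_at_frame, 1.0)
    return beta_0 + fraction * (1.0 - beta_0)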
Example no. 10
def self_play_training_loop(submit_config,
                            player_run_id,
                            player_model_id,
                            opponent_run_ids,
                            opponent_model_ids,
                            total_rounds,
                            num_episodes,
                            start_training_at_frame,
                            target_update_freq,
                            model_update_freq,
                            save_every_n_ep,
                            log_freq,
                            epsilon,
                            learning_rate,
                            run_description='',
                            render=False):
    """Self-play training loop for Pong agents."""
    submission_run_dir = submit_config.run_dir
    run_dir_root = submit_config.run_dir_root

    # Make the environment
    env = make_pong_environment()

    # Locate opponent run dirs.
    run_dirs = []
    for run_id in opponent_run_ids:
        run_dir = [
            os.path.join(run_dir_root, d) for d in os.listdir(run_dir_root)
            if str(run_id).zfill(5) in d
        ]
        if run_dir:
            run_dirs.append(run_dir[0])

    # Load agent configs, network configs and network weights.
    opposing_agents = []
    for run_dir, model_id in zip(run_dirs, opponent_model_ids):
        # Load run config.
        with open(os.path.join(run_dir, 'run_func_args.pkl'), 'rb') as f:
            run_config = pickle.load(f)

        # Initialize agent.
        agent_config = run_config.agent_config
        network_fn_kwargs = run_config.network_fn_kwargs
        agent = Agent(network_fn_kwargs=network_fn_kwargs, **agent_config)

        # Load model weights.
        model_dir = os.path.join(run_dir, 'models')
        model_path = os.path.join(model_dir, [
            f for f in os.listdir(model_dir) if str(model_id) in f
        ][0])
        agent.load_model(model_path)

        # Append to player list.
        opposing_agents.append(
            (agent.get_name() + '_%s' % (agent_config.network_name), agent))

    # Add SimpleAI to opponents.
    opposing_agents.append(('SimpleAI', wimblepong.SimpleAi(env, 2)))
    opponent_run_ids.append(-1)

    # Load agent that is trained.
    target_run_dir = [
        os.path.join(run_dir_root, d) for d in os.listdir(run_dir_root)
        if str(player_run_id).zfill(5) in d
    ][0]
    print('Loading target from: %s' % target_run_dir)

    # Load run config.
    with open(os.path.join(target_run_dir, 'run_func_args.pkl'), 'rb') as f:
        run_config = pickle.load(f)

    # Initialize agent.
    agent_config = run_config.agent_config
    agent_config.learning_rate = learning_rate
    network_fn_kwargs = run_config.network_fn_kwargs
    p1 = Agent(network_fn_kwargs=network_fn_kwargs, **agent_config)

    # Load model weights.
    model_dir = os.path.join(target_run_dir, 'models')
    model_path = os.path.join(model_dir, [
        f for f in os.listdir(model_dir) if str(player_model_id) in f
    ][0])
    p1.load_model(model_path)
    p1_name = p1.get_name() + '_%s' % (agent_config.network_name)

    # Setup directories for models and logging.
    model_dir = os.path.join(submission_run_dir, 'models')
    log_dir = os.path.join(submission_run_dir, 'logs')
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Initialize summary writer.
    writer = SummaryWriter(log_dir=log_dir, comment=run_description)
    perf_file = os.path.join(submission_run_dir, 'win_rates.txt')

    # Housekeeping.
    num_opponents = len(opposing_agents)
    max_reward = 10.0
    wr_against = {}
    total_frames_seen = 0

    for total_ep in range(total_rounds):
        # Pick an opponent uniformly at random.
        p2_idx = np.random.randint(num_opponents)
        p2_name, p2 = opposing_agents[p2_idx]

        # Setup players and housekeeping.
        env.set_names(p1_name, p2_name)
        frames_seen = 0
        p1_wins = 0
        p1_reward_sums = []
        p1_losses = []
        ep_lengths = []

        print('Training %s vs. %s...' % (p1_name, p2_name))
        for ep in range(num_episodes):
            # Reset the Pong environment.
            (p1_state, p2_state) = env.reset()
            p1_reward_sum = 0
            done = False
            step = 0

            while not done:
                # Get actions from agent and opponent.
                p1_action = p1.get_action(p1_state, epsilon=epsilon)
                p2_action = p2.get_action(
                    p2_state, epsilon=epsilon
                ) if p2_name != 'SimpleAI' else p2.get_action()

                # Step the environment and get the rewards and new observations
                (p1_next_state, p2_next_state), (p1_reward,
                                                 _), done, info = env.step(
                                                     (p1_action, p2_action))

                # Store transitions.
                p1.store_transition(p1_state, p1_action, p1_next_state,
                                    p1_reward, done)

                # Check whether there are enough frames to start training.
                if frames_seen >= start_training_at_frame:
                    if frames_seen % model_update_freq == model_update_freq - 1:
                        # Update policy networks.
                        loss_p1 = p1.compute_loss()

                        # Update EMA networks.
                        p1.update_ema_policy()
                else:
                    loss_p1 = 0

                if total_frames_seen % target_update_freq == target_update_freq - 1:  # Update target networks.
                    p1.update_target_network()

                # Count the wins. Won't work with discounting.
                if p1_reward == max_reward:
                    p1_wins += 1

                if render:
                    env.render()

                if frames_seen % model_update_freq == model_update_freq - 1:
                    p1_losses.append(loss_p1)

                p1_state = p1_next_state
                p2_state = p2_next_state
                p1_reward_sum += p1_reward
                step += 1
                frames_seen += 1
                total_frames_seen += 1

            p1_reward_sums.append(p1_reward_sum)
            ep_lengths.append(step)
            print(
                '%s vs. %s, episode %i/%i, end frame %i, frames %i, eps %0.2f, wins %i, losses %i'
                % (p1_name, p2_name, ep, num_episodes, step, frames_seen,
                   epsilon, p1_wins, ep + 1 - p1_wins))

            if ep % save_every_n_ep == 0:
                torch.save(
                    p1.policy_net.state_dict(),
                    os.path.join(model_dir,
                                 'agent_%s.mdl' % (agent_config.network_name)))
                torch.save(
                    p1.policy_net_ema.state_dict(),
                    os.path.join(
                        model_dir,
                        'ema_agent_%s.mdl' % (agent_config.network_name)))

                # Update WR against current opponent.
                key = 'wr_vs_run_id_%i' % opponent_run_ids[p2_idx]
                wr_against[key] = p1_wins / (ep + 1)

                perf_dict = {**{'round': [total_ep], 'ep': [ep]}, **wr_against}
                df = pd.DataFrame(data=perf_dict)
                df.to_csv(perf_file, header=True, index=False)

        print('WR against %s: %0.2f' % (p2_name, p1_wins / num_episodes))
        with open(os.path.join(submission_run_dir, 'wr.pkl'), 'wb') as f:
            pickle.dump(wr_against, f)
        print()

    # Dump final results.
    with open(os.path.join(submission_run_dir, 'wr.pkl'), 'wb') as f:
        pickle.dump(wr_against, f)
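Note: make_pong_environment() is called in Examples no. 9, 10, and 12 but never defined in these snippets. A minimal sketch, assuming it only wraps the gym.make call and the scale/fps settings used in the other examples:

import gym
import wimblepong  # registers the Wimblepong environments


def make_pong_environment(scale=1, fps=30):
    # Same construction as in Examples no. 2, 3 and 7.
    env = gym.make("WimblepongVisualMultiplayer-v0")
    env.unwrapped.scale = scale
    env.unwrapped.fps = fps
    return env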
Example no. 11
def train(render=False, checkpoint='model.mdl'):
    env = gym.make("WimblepongVisualMultiplayer-v0")
    env.unwrapped.scale, env.unwrapped.fps = 1, 30

    player = Agent()
    opponent = wimblepong.SimpleAi(env, 2)
    env.set_names('Undisputed', opponent.get_name())

    episode, max_games, highest_running_winrate = 0, 10000, 0
    scores, game_lengths, game_lengths_Avg, run_avg_plot = [], [], [], []
    steps = 1500
    while episode <= max_games:
        (observation, observation2) = env.reset()  # 200 x 200 x 3 each obs
        cur_state = prepro(observation)
        prev_state = None
        for _ in range(steps):
            if render:
                env.render()
            x = cur_state - prev_state if prev_state is not None else cur_state
            cv2.imshow('x ', np.reshape(x, (100, 100, 1)))
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()

            x = torch.tensor(x).to(player.train_device)
            action = player.get_action(x)
            action2 = opponent.get_action()
            (observation, observation2), (r1, r2), done, info = env.step(
                (action, action2))

            reward = float(np.sign(r1))
            prev_state = cur_state
            cur_state = prepro(observation)
            player.store_info(s=x, r=reward)

            if done:
                break

        R_sum = np.sum(player.rewards)
        print("Total reward for episode {}: {}".format(episode, R_sum))
        game_lengths.append(len(player.rewards))
        scores.append(1 if R_sum > 0 else 0)

        # Update policy network
        player.update_policy()
        episode += 1

        if episode > 100:
            run_avg = np.mean(np.array(scores)[-100:])
            game_length_avg = np.mean(np.array(game_lengths)[-100:])
        else:
            run_avg = np.mean(np.array(scores))
            game_length_avg = np.mean(np.array(game_lengths))

        run_avg_plot.append(run_avg)
        game_lengths_Avg.append(game_length_avg)
        if episode % 100 == 0:  # run_avg  > highest_running_winrate:
            highest_running_winrate = run_avg
            print('highest_running_winrate ', highest_running_winrate)
            print("model_" + str(highest_running_winrate) + '.mdl')
            torch.save(player.policy.state_dict(),
                       "model_" + str(highest_running_winrate) + '.mdl')
            print(
                'Saved policy----------------------------------------------------------------'
            )

        if episode % 100 == 0:
            plt.figure(figsize=(12, 10))
            plt.plot(run_avg_plot, label='avg win rate')
            plt.legend()
            plt.show()
            plt.figure(figsize=(12, 10))
            plt.plot(game_lengths_Avg, label='avg timesteps')
            plt.legend()
            plt.show()
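Note: player.update_policy() and the Agent internals are not shown in this example. A minimal REINFORCE-style sketch of such an update, assuming the agent stores per-step log-probabilities next to its rewards and reusing the discount_rewards()/eps idea from Example no. 7; the attribute names (log_probs, optimizer) are assumptions, not the author's code:

import numpy as np
import torch


def update_policy(agent, gamma=0.99):
    # Discounted, normalized returns (same eps trick as in Example no. 7).
    eps = np.finfo(np.float32).eps.item()
    returns = discount_rewards(np.array(agent.rewards, dtype=np.float32), gamma)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    returns = torch.tensor(returns, device=agent.train_device)

    # Policy-gradient loss: -sum_t log pi(a_t | s_t) * G_t.
    log_probs = torch.stack(agent.log_probs)
    loss = -(log_probs * returns).sum()

    agent.optimizer.zero_grad()
    loss.backward()
    agent.optimizer.step()

    # Clear the per-episode buffers.
    agent.rewards, agent.log_probs = [], []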
Example no. 12
def training_loop(num_episodes,
                  target_epsilon,
                  reach_target_at_frame,
                  player_id,
                  start_training_at_frame,
                  update_target_freq,
                  save_every_n_ep,
                  log_freq,
                  agent_config,
                  clip_reward=False,
                  run_description='',
                  render=False):
    """Training loop for Pong agents."""
    run_dir = os.path.dirname(os.path.abspath(__file__))

    # Make the environment
    env = make_pong_environment()

    # Set up the agent.
    agent = Agent(**agent_config)

    # Setup the opponent.
    opponent_id = 2
    opponent = wimblepong.SimpleAi(env, opponent_id)

    # Set the names for both SimpleAIs
    env.set_names(agent.get_name(), opponent.get_name())

    # Setup directories for models and logging.
    model_dir = os.path.join(run_dir, 'models')
    log_dir = os.path.join(run_dir, 'logs')

    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Initialize summary writer.
    writer = SummaryWriter(log_dir=log_dir, comment=run_description)

    # Housekeeping
    max_reward = 10.0 if not clip_reward else 1.0
    wins = 0
    frames_seen = 0
    game_results = []
    reward_sums = []

    for ep in range(0, num_episodes):
        # Reset the Pong environment
        (agent_state, opp_state) = env.reset()
        done = False
        step = 0
        actions_taken = []
        losses = []
        reward_sum = 0.0

        # Compute new epsilon.
        epsilon = epsilon_schedule(frames_seen, target_epsilon,
                                   reach_target_at_frame)

        while not done:
            # Get actions from agent and opponent.
            agent_action = agent.get_action(agent_state, epsilon=epsilon)
            opp_action = opponent.get_action(opp_state)

            # Step the environment and get the rewards and new observations
            (agent_next_state,
             opp_next_state), (agent_reward, _), done, info = env.step(
                 (agent_action, opp_action))

            # Clip reward.
            if clip_reward:
                agent_reward = max(-1., min(1., agent_reward))

            #if agent_reward == 0.0:
            #    agent_reward = 0.2
            #elif agent_reward == 10.0:
            #    agent_reward = 15.0
            #elif agent_reward == -10:
            #    agent_reward = -15.0

            # Store transitions.
            agent.store_transition(agent_state, agent_action, agent_next_state,
                                   agent_reward, done)

            # Check whether there are enough frames to start training.
            if frames_seen > start_training_at_frame:
                loss = agent.compute_loss()

                if frames_seen % update_target_freq == update_target_freq - 1:  # Update target network.
                    agent.update_target_network()

            # Count the wins. Won't work with discounting.
            if agent_reward == max_reward:
                wins += 1
                game_results.append(1)
            elif agent_reward == -max_reward:
                game_results.append(0)

            if render:
                env.render()

            if frames_seen > start_training_at_frame:
                losses.append(loss)
            else:
                losses.append(0)

            agent_state = agent_next_state
            opp_state = opp_next_state
            actions_taken.append(agent_action)
            reward_sum += agent_reward
            step += 1
            frames_seen += 1

        reward_sums.append(reward_sum)
        act_counts, _ = np.histogram(actions_taken, bins=[0, 1, 2, 3])
        actions = 'stay %i, up %i, down %i' % (act_counts[0], act_counts[1],
                                               act_counts[2])
        print(
            'buf_count %i, episode %i, end frame %i, tot. frames %i, eps %0.2f, %s, wins %i, losses %i'
            % (agent.memory.count, ep, step, frames_seen, epsilon, actions,
               wins, ep + 1 - wins))

        # Log progress.
        if ep % log_freq == 0:
            # Write scalars.
            writer.add_scalar('Episode/Loss', np.mean(losses), ep)
            writer.add_scalar('Episode/Episode-length', step, ep)
            writer.add_scalar('Progress/Epsilon', epsilon, frames_seen)
            writer.add_scalar('Progress/Frames', frames_seen, frames_seen)

            if ep < 100:  # Log results and rewards from last n games.
                last_n_results = game_results
                last_n_reward_sums = reward_sums
            else:
                last_n_results = game_results[-100:]
                last_n_reward_sums = reward_sums[-100:]

            cur_win_rate = np.mean(last_n_results)
            mean_rewards = np.mean(last_n_reward_sums)
            writer.add_scalar('Progress/Cumulative-reward', mean_rewards, ep)
            writer.add_scalar('Progress/Win-rate', cur_win_rate, ep)

            # Show random batch of states.
            (state_batch, _, _, _, _), _ = agent.memory.sample_batch(5)
            n, c, h, w = state_batch.shape
            state_batch = state_batch.reshape(n * c, h, w)[:, None, :, :]
            writer.add_images('ReplayBuffer/Sample states', state_batch, ep)

        # Reset agent's internal state.
        agent.reset()

        if ep % save_every_n_ep == 0:
            torch.save(agent.policy_net.state_dict(),
                       os.path.join(model_dir, 'agent_ep.mdl'))
Example no. 13
learner_validator_id = 1
validator_id = 3 - learner_validator_id

learner = FishAgent()
teacher = FishAgent()

filename = 'model.mdl'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.load(filename, map_location=device)
learner.policy.load_state_dict(weights, strict=False) #Load weights for learner
teacher.policy.load_state_dict(weights, strict=False) #Load weights for teacher
teacher.policy.eval()
teacher.test = True

learner_validator = FishAgent()
validator_1 = wimblepong.SimpleAi(env, validator_id)
validator_2 = SomeAgent()
validator_2.load_model()
validator_3 = SomeOtherAgent()
validator_3.load_model()
validator_4 = KarpathyAgent()
validator_4.load_model()

validators = [validator_1, validator_2, validator_3, validator_4]
validator = validator_1


env.set_names(learner.get_name(), teacher.get_name())

def validation_run(env, n_games=100):
    learner_validator.policy.load_state_dict(learner.policy.state_dict())
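    # NOTE: the snippet is cut off here. The continuation below is only a
    # sketch of how such a validation loop could look, reusing the two-player
    # step interface from the earlier examples; it is not the author's code.
    learner_validator.policy.eval()
    wins = 0
    for _ in range(n_games):
        (obs1, obs2) = env.reset()
        done = False
        while not done:
            action1 = learner_validator.get_action(obs1)
            action2 = validator.get_action(obs2)
            (obs1, obs2), (r1, r2), done, info = env.step((action1, action2))
        if r1 == 10:  # the winner of a game receives +10
            wins += 1
    return wins / n_games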