from agents.common.networks import CNNDeepmind_Multihead
from agents.eqrdqn.eqrdqn import EQRDQN
from agents.common.atari_wrappers import make_atari, wrap_deepmind

import pickle
import numpy as np
import matplotlib.pyplot as plt

notes = "This is a test run."

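# DeepMind-style Atari preprocessing (this repo's wrappers around the usual
# baselines ones): make_atari builds the frame-skipped NoFrameskip env and
# wrap_deepmind adds the 84x84 grayscale observations; episode_life=False
# keeps whole games as single episodes.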
env = make_atari("BreakoutNoFrameskip-v0", noop=False)
env = wrap_deepmind(env, episode_life=False)

nb_steps = 12500000

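# EQRDQN: quantile-regression DQN with 200 quantiles per action and a small
# prior term (0.0001). One gradient step every 4 environment steps, a target
# network refresh every 10k steps, and learning starts once 50k warm-up
# transitions have filled the 1M-capacity replay buffer.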
agent = EQRDQN(env,
               CNNDeepmind_Multihead,
               n_quantiles=200,
               kappa=0,
               prior=0.0001,
               replay_start_size=50000,
               replay_buffer_size=1000000,
               gamma=0.99,
               update_target_frequency=10000,
               minibatch_size=32,
               learning_rate=5e-5,
               adam_epsilon=0.01 / 32,
               update_frequency=4,
               log_folder_details="Breakout-EQRDQN",
               logging=True,
               notes=notes)
Example 2
from agents.common.networks.cnn_deepmind import CNNDeepmind
from agents.qrdqn.qrdqn import QRDQN
from agents.common.atari_wrappers import make_atari, wrap_deepmind

import pickle
import numpy as np
import matplotlib.pyplot as plt

notes = "This is a test run."

env_name = "BreakoutNoFrameskip-v4"
env = make_atari(env_name, noop=True)
env = wrap_deepmind(env, episode_life=True)

nb_steps = 12500000

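# Plain QRDQN baseline on the same game: identical quantile settings, but
# exploration is epsilon-greedy, annealed from 1.0 to 0.01 over the first
# 1M environment steps.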
agent = QRDQN(env,
              CNNDeepmind,
              n_quantiles=200,
              kappa=0,
              replay_start_size=50000,
              replay_buffer_size=1000000,
              gamma=0.99,
              update_target_frequency=10000,
              minibatch_size=32,
              learning_rate=5e-5,
              initial_exploration_rate=1,
              final_exploration_rate=0.01,
              final_exploration_step=1000000,
              adam_epsilon=0.01/32)
Example 3
import gym
import torch
import matplotlib.pyplot as plt
import numpy as np

from agents.ide.ide import IDE
from agents.common.networks.cnn_deepmind import CNNDeepmind_Multihead

from agents.common.atari_wrappers import make_atari, wrap_deepmind

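# Directory of a finished training run; it should contain the periodically
# saved network_<step>.pth checkpoints that the loop below loads.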
FOLDER = "???"

game_scores = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = make_atari("AsterixNoFrameskip", noop=True)
env = wrap_deepmind(env, clip_rewards=False, episode_life=False)

agent = IDE(env, CNNDeepmind_Multihead, n_quantiles=200)

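# Evaluate each checkpoint saved every 250k training steps (50 in total),
# running the agent for 125k environment steps per checkpoint and recording
# the per-episode scores.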
for i in range(50):

    filename = "network_" + str((i + 1) * 250000) + ".pth"

    agent.load(FOLDER + filename)

    score = 0
    scores = []
    total_timesteps = 0
    while total_timesteps < 125000:
        done = False
        obs = env.reset()