def test_reset(self):
    # A freshly reset environment should expose an initial state with the
    # standard preprocessed 1x84x84 observation and neutral reward/done/mask.
    environment = AtariEnvironment('Breakout')
    initial = environment.reset()
    self.assertEqual(initial.observation.shape, (1, 84, 84))
    self.assertEqual(initial.reward, 0)
    self.assertFalse(initial.done)
    self.assertEqual(initial.mask, 1)
 def test_step(self):
     # One action immediately after reset should produce an ordinary
     # (non-terminal) state with the preprocessed observation shape.
     environment = AtariEnvironment('Breakout')
     environment.reset()
     result = environment.step(1)
     self.assertEqual(result.observation.shape, (1, 84, 84))
     self.assertEqual(result.reward, 0)
     self.assertFalse(result.done)
     self.assertEqual(result.mask, 1)
     self.assertEqual(result['life_lost'], False)
 def test_step_until_done(self):
     # Keep stepping (capped at 1000 steps) until the episode terminates,
     # then check the terminal state: done is True and mask drops to 0.
     environment = AtariEnvironment('Breakout')
     environment.reset()
     for _ in range(1000):
         final = environment.step(1)
         if final.done:
             break
     self.assertEqual(final.observation.shape, (1, 84, 84))
     self.assertEqual(final.reward, 0)
     self.assertTrue(final.done)
     self.assertEqual(final.mask, 0)
     self.assertEqual(final['life_lost'], False)
Exemplo n.º 4
0
class TestAtariPresets(unittest.TestCase):
    """Smoke-tests every Atari preset: build it, act once, and round-trip save/load."""

    def setUp(self):
        self.env = AtariEnvironment('Breakout')
        self.env.reset()

    def tearDown(self):
        # Remove the checkpoint written by validate_preset, if one was created.
        if os.path.exists('test_preset.pt'):
            os.remove('test_preset.pt')

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpq(self):
        # NOTE(review): method name says "vpq" but it validates the vpg
        # preset; name kept unchanged for test-discovery compatibility.
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build the preset on CPU and exercise its train and test agents."""
        preset = builder.device('cpu').env(self.env).build()
        # exercise the training agent
        train_agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        train_agent.act(self.env.state)
        # exercise the deployment (test-mode) agent
        deploy_agent = preset.test_agent()
        deploy_agent.act(self.env.state)
        # round-trip through disk, then exercise the reloaded test agent
        preset.save('test_preset.pt')
        reloaded = torch.load('test_preset.pt')
        deploy_agent = reloaded.test_agent()
        deploy_agent.act(self.env.state)
def main():
    """CLI entry point: parse arguments and run one Atari training experiment."""
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong).")
    parser.add_argument(
        "agent",
        help="Name of the agent (e.g. dqn). See presets for available agents.")
    parser.add_argument(
        "--device",
        default="cuda",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
    )
    # FIX: argparse does not coerce defaults through type=int, so the old
    # default of 40e6 was a float; use an int literal.
    parser.add_argument("--frames",
                        type=int,
                        default=40_000_000,
                        help="The number of training frames.")
    # BUG FIX: type=bool is almost always wrong -- bool("False") is True, so
    # "--render False" still enabled rendering. Use a store_true flag instead
    # (this also matches the other example entry points in this file).
    parser.add_argument("--render",
                        action="store_true",
                        default=False,
                        help="Render the environment.")
    parser.add_argument("--logdir",
                        default='runs',
                        help="The base logging directory.")
    args = parser.parse_args()

    env = AtariEnvironment(args.env, device=args.device)
    # Look up the agent preset constructor by name on the atari presets module.
    agent = getattr(atari, args.agent)

    run_experiment(agent(device=args.device, last_frame=args.frames),
                   env,
                   args.frames,
                   render=args.render,
                   logdir=args.logdir)
Exemplo n.º 6
0
def main():
    """Queue Slurm experiments for all agents of several preset families."""
    # run on gpu
    device = 'cuda'

    def get_agents(preset):
        """Instantiate every agent builder exported by the given preset module."""
        # BUG FIX: the original iterated classic_control.__all__ no matter
        # which preset module was passed, so get_agents(atari) looked up
        # classic-control agent names on the atari module. Use the module's
        # own __all__ instead.
        agents = [
            getattr(preset, agent_name)
            for agent_name in preset.__all__
        ]
        return [agent(device=device) for agent in agents]

    SlurmExperiment(get_agents(atari),
                    AtariEnvironment('Breakout', device=device),
                    2e7,
                    sbatch_args={'partition': '1080ti-long'})

    SlurmExperiment(get_agents(classic_control),
                    GymEnvironment('CartPole-v0', device=device),
                    100000,
                    sbatch_args={'partition': '1080ti-short'})

    SlurmExperiment(get_agents(continuous),
                    GymEnvironment('LunarLanderContinuous-v2', device=device),
                    500000,
                    sbatch_args={'partition': '1080ti-short'})
Exemplo n.º 7
0
def run():
    """Run the model-predictive DQN experiment plus a vanilla DQN baseline."""
    # parse arguments
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument(
        "--device",
        default="cuda",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    # FIX: argparse does not coerce defaults through type=int, so the old
    # default of 2e6 was a float; use an int literal.
    parser.add_argument("--frames",
                        type=int,
                        default=2_000_000,
                        help="The number of training frames")
    args = parser.parse_args()

    # create atari environment
    env = AtariEnvironment(args.env, device=args.device)

    # run the experiment
    Experiment(model_predictive_dqn(device=args.device),
               env,
               frames=args.frames)

    # run the baseline agent for comparison; it trains over 4x the frame
    # budget so its epsilon/lr schedules stretch accordingly
    Experiment(dqn(device=args.device,
                   replay_buffer_size=1e5,
                   last_frame=(args.frames * 4)),
               env,
               frames=args.frames)
def main():
    """Launch an a2c Slurm experiment across three Atari games."""
    device = 'cuda'
    game_names = ['Pong', 'Breakout', 'SpaceInvaders']
    envs = [AtariEnvironment(name, device) for name in game_names]
    SlurmExperiment(a2c(device=device),
                    envs,
                    1e6,
                    sbatch_args={'partition': '1080ti-short'})
Exemplo n.º 9
0
 def test_rainbow_model_cpu(self):
     # Regression test for the nature_rainbow network on CPU: a fixed
     # frame-stacked input must keep producing these reference outputs.
     # NOTE(review): model weights are randomly initialized here, so this
     # presumably relies on seeding done elsewhere in the suite -- confirm.
     env = AtariEnvironment('Breakout')
     model = nature_rainbow(env)
     env.reset()
     # Stack the same raw frame 4 times to mimic a 4-frame stacked input.
     x = torch.cat([env.state.raw] * 4, dim=1).float()
     out = model(x)
     # Reference values compared to 3 decimal places.
     tt.assert_almost_equal(
         out,
         torch.tensor([[
             0.0676, -0.0235, 0.0690, -0.0713, -0.0287, 0.0053, -0.0463,
             0.0495, -0.0222, -0.0504, 0.0064, -0.0204, 0.0168, 0.0127,
             -0.0113, -0.0586, -0.0544, 0.0114, -0.0077, 0.0666, -0.0663,
             -0.0420, -0.0698, -0.0314, 0.0272, 0.0361, -0.0537, 0.0301,
             0.0036, -0.0472, -0.0499, 0.0114, 0.0182, 0.0008, -0.0132,
             -0.0803, -0.0087, -0.0017, 0.0598, -0.0627, 0.0859, 0.0117,
             0.0105, 0.0309, -0.0370, -0.0111, -0.0262, 0.0338, 0.0141,
             -0.0385, 0.0547, 0.0648, -0.0370, 0.0107, -0.0629, -0.0163,
             0.0282, -0.0670, 0.0161, -0.0244, -0.0030, 0.0038, -0.0208,
             0.0005, 0.0125, 0.0608, -0.0089, 0.0026, 0.0562, -0.0678,
             0.0841, -0.0265, -0.0461, -0.0124, 0.0276, 0.0364, 0.0195,
             -0.0309, -0.0337, -0.0603, -0.0252, -0.0356, 0.0221, 0.0184,
             -0.0154, -0.0136, -0.0277, 0.0283, 0.0495, 0.0185, -0.0357,
             0.0305, -0.0052, -0.0432, -0.0135, -0.0554, -0.0094, 0.0272,
             0.1030, 0.0049, 0.0012, -0.0140, 0.0146, -0.0979, 0.0487,
             0.0122, -0.0204, 0.0496, -0.0055, -0.0015, -0.0170, 0.0053,
             0.0104, -0.0742, 0.0742, -0.0381, 0.0104, -0.0065, -0.0564,
             0.0453, -0.0057, -0.0029, -0.0722, 0.0094, -0.0561, 0.0284,
             0.0402, 0.0233, -0.0716, -0.0424, 0.0165, -0.0505, 0.0006,
             0.0219, -0.0601, 0.0656, -0.0175, -0.0524, 0.0355, 0.0007,
             -0.0042, -0.0443, 0.0871, -0.0403, -0.0031, 0.0171, -0.0359,
             -0.0520, -0.0344, 0.0239, 0.0099, 0.0004, 0.0235, 0.0238,
             -0.0153, 0.0501, -0.0052, 0.0162, 0.0313, -0.0121, 0.0009,
             -0.0366, -0.0628, 0.0386, -0.0671, 0.0480, -0.0595, 0.0568,
             -0.0604, -0.0540, 0.0403, -0.0187, 0.0649, 0.0029, -0.0003,
             0.0020, -0.0056, 0.0471, -0.0145, -0.0126, -0.0395, -0.0455,
             -0.0437, 0.0056, 0.0331, 0.0004, 0.0127, -0.0022, -0.0502,
             0.0362, 0.0624, -0.0012, -0.0515, 0.0303, -0.0357, -0.0420,
             0.0321, -0.0162, 0.0007, -0.0272, 0.0227, 0.0187, -0.0459,
             0.0496
         ]]),
         decimal=3)
Exemplo n.º 10
0
def main():
    """CLI entry point: run a configurable Atari benchmark experiment.

    Hyperparameters may be overridden on the command line as KEY=VALUE pairs;
    each value is coerced to the type of the preset's default for that key.
    """
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong).")
    parser.add_argument(
        "agent",
        help="Name of the agent (e.g. dqn). See presets for available agents.")
    parser.add_argument(
        "--device",
        default="cuda",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
    )
    # FIX: argparse does not coerce defaults through type=int, so the old
    # default of 40e6 was a float; use an int literal.
    parser.add_argument("--frames",
                        type=int,
                        default=40_000_000,
                        help="The number of training frames.")
    parser.add_argument("--render",
                        action="store_true",
                        default=False,
                        help="Render the environment.")
    parser.add_argument("--logdir",
                        default='runs',
                        help="The base logging directory.")
    parser.add_argument(
        "--writer",
        default='tensorboard',
        help="The backend used for tracking experiment metrics.")
    parser.add_argument('--hyperparameters', default=[], nargs='*')
    args = parser.parse_args()

    env = AtariEnvironment(args.env, device=args.device)

    agent = getattr(atari, args.agent)
    agent = agent.device(args.device)

    # parse hyperparameters
    hyperparameters = {}
    for hp in args.hyperparameters:
        # BUG FIX: split only on the first '=' so values containing '='
        # (or scientific notation strings, tags, etc.) are preserved.
        key, value = hp.split('=', 1)
        # coerce the string value to the type of the preset's default
        hyperparameters[key] = type(agent.default_hyperparameters[key])(value)
    agent = agent.hyperparameters(**hyperparameters)

    run_experiment(
        agent,
        env,
        args.frames,
        render=args.render,
        logdir=args.logdir,
        writer=args.writer,
    )
Exemplo n.º 11
0
def watch_atari():
    """Load a saved Atari agent from a directory and watch it play."""
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument("dir",
                        help="Directory where the agent's model was saved.")
    parser.add_argument(
        "--device",
        default="cpu",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    arguments = parser.parse_args()
    environment = AtariEnvironment(arguments.env, device=arguments.device)
    load_and_watch(arguments.dir, environment)
Exemplo n.º 12
0
def main():
    """Queue a grid of preset agents x Atari games as one Slurm experiment."""
    builders = (atari.a2c, atari.c51, atari.dqn, atari.ddqn, atari.ppo,
                atari.rainbow)
    agents = [build() for build in builders]
    games = ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders']
    envs = [AtariEnvironment(game, device='cuda') for game in games]
    SlurmExperiment(agents,
                    envs,
                    10e6,
                    sbatch_args={'partition': '1080ti-long'})
Exemplo n.º 13
0
def main():
    """Watch a saved Atari model play back at a configurable frame rate."""
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument("filename", help="File where the model was saved.")
    parser.add_argument(
        "--device",
        default="cuda",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    # BUG FIX: without type=int a user-supplied --fps arrives as a string
    # while the default stays an int; coerce it consistently.
    parser.add_argument(
        "--fps",
        type=int,
        default=60,
        help="Playback speed",
    )
    args = parser.parse_args()
    env = AtariEnvironment(args.env, device=args.device)
    load_and_watch(args.filename, env, fps=args.fps)
Exemplo n.º 14
0
def main():
    """Load a GreedyAgent from a directory and watch it play an Atari game."""
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument("dir", help="Directory where the agent's model was saved.")
    parser.add_argument(
        "--device",
        default="cpu",
        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    # BUG FIX: without type=int a user-supplied --fps arrives as a string
    # while the default stays an int; coerce it consistently.
    parser.add_argument(
        "--fps",
        type=int,
        default=60,
        help="Playback speed",
    )
    args = parser.parse_args()
    env = AtariEnvironment(args.env, device=args.device)
    # Wrap the loaded greedy agent with the standard DeepMind preprocessing body.
    agent = DeepmindAtariBody(GreedyAgent.load(args.dir, env))
    watch(agent, env, fps=args.fps)
Exemplo n.º 15
0
def run_atari():
    """Parse CLI arguments and run the named Atari preset as an Experiment."""
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument(
        "agent", help="Name of the agent (e.g. dqn). See presets for available agents."
    )
    parser.add_argument(
        "--device",
        default="cuda",
        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    # FIX: argparse does not coerce defaults through type=int, so the old
    # default of 40e6 was a float; use an int literal.
    parser.add_argument(
        "--frames", type=int, default=40_000_000, help="The number of training frames"
    )
    args = parser.parse_args()

    env = AtariEnvironment(args.env, device=args.device)
    # Look up the agent preset constructor by name on the atari presets module.
    agent = getattr(atari, args.agent)

    Experiment(agent(device=args.device, last_frame=args.frames), env, frames=args.frames)
def run_atari():
    """Parse CLI arguments and run the named Atari preset, labeled by name."""
    parser = argparse.ArgumentParser(description='Run an Atari benchmark.')
    parser.add_argument('env', help='Name of the Atari game (e.g. Pong)')
    parser.add_argument(
        'agent',
        help="Name of the agent (e.g. dqn). See presets for available agents.")
    parser.add_argument(
        '--device',
        default='cuda',
        help=
        'The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)')
    # FIX: argparse does not coerce defaults through type=int, so the old
    # default of 100e6 was a float; use an int literal.
    parser.add_argument('--frames',
                        type=int,
                        default=100_000_000,
                        help='The number of training frames')
    args = parser.parse_args()

    env = AtariEnvironment(args.env, device=args.device)
    agent_name = args.agent
    # Look up the agent preset constructor by name on the atari presets module.
    agent = getattr(atari, agent_name)

    experiment = Experiment(env, frames=args.frames)
    experiment.run(agent(device=args.device), label=agent_name)
    def test_runs(self):
        """Smoke test: ParallelAtariBody drives several envs for 200 ticks."""
        # fix both RNGs so the run is reproducible
        np.random.seed(0)
        torch.random.manual_seed(0)
        n = 4
        envs = []
        for _ in range(n):
            environment = AtariEnvironment('Breakout')
            environment.reset()
            envs.append(environment)
        agent = MockAgent(n, max_action=4)
        body = ParallelAtariBody(agent, envs, noop_max=30)

        for _ in range(200):
            current_states = [environment.state for environment in envs]
            current_rewards = torch.tensor(
                [environment.reward for environment in envs]).float()
            chosen_actions = body.act(current_states, current_rewards)
            # a None action means the body is still issuing no-ops for that env
            for index, environment in enumerate(envs):
                if chosen_actions[index] is not None:
                    environment.step(chosen_actions[index])
Exemplo n.º 18
0
 def setUp(self):
     # One ordinary environment plus a two-copy parallel environment,
     # both reset and ready for the tests.
     self.env = AtariEnvironment('Breakout')
     self.env.reset()
     copies = [AtariEnvironment('Breakout'), AtariEnvironment('Breakout')]
     self.parallel_env = DuplicateEnvironment(copies)
     self.parallel_env.reset()
 def test_ppo(self):
     # PPO runs with four parallel environments on the CPU.
     preset = ppo.device(CPU).hyperparameters(n_envs=4)
     validate_agent(preset, AtariEnvironment("Breakout", device=CPU))
 def test_rainbow_cuda(self):
     # Small replay_start_size keeps this GPU smoke test fast.
     preset = rainbow.device(CUDA).hyperparameters(replay_start_size=64)
     validate_agent(preset, AtariEnvironment("Breakout", device=CUDA))
 def test_vac_cuda(self):
     # Vanilla actor-critic with four parallel environments on the GPU.
     preset = vac.device(CUDA).hyperparameters(n_envs=4)
     validate_agent(preset, AtariEnvironment("Breakout", device=CUDA))
Exemplo n.º 22
0
 def test_vsarsa(self):
     # Vanilla SARSA smoke test on CPU with four parallel environments.
     environment = AtariEnvironment("Breakout", device=CPU)
     validate_agent(vsarsa(device=CPU, n_envs=4), environment)
Exemplo n.º 23
0
 def test_vpg_cuda(self):
     # Vanilla policy gradient smoke test on the GPU.
     environment = AtariEnvironment("Breakout", device=CUDA)
     validate_agent(vpg(device=CUDA), environment)
Exemplo n.º 24
0
 def test_vpg(self):
     # Vanilla policy gradient smoke test on the CPU.
     environment = AtariEnvironment("Breakout", device=CPU)
     validate_agent(vpg(device=CPU), environment)
Exemplo n.º 25
0
 def test_vac_cuda(self):
     # Vanilla actor-critic smoke test on the GPU with 4 parallel envs.
     environment = AtariEnvironment("Breakout", device=CUDA)
     validate_agent(vac(device=CUDA, n_envs=4), environment)
Exemplo n.º 26
0
 def test_rainbow_cuda(self):
     # Small replay_start_size keeps this GPU smoke test fast.
     environment = AtariEnvironment("Breakout", device=CUDA)
     validate_agent(rainbow(replay_start_size=64, device=CUDA), environment)
Exemplo n.º 27
0
 def test_dqn(self):
     # Small replay_start_size keeps this CPU smoke test fast.
     environment = AtariEnvironment("Breakout", device=CPU)
     validate_agent(dqn(replay_start_size=64, device=CPU), environment)
Exemplo n.º 28
0
 def test_ddqn_cuda(self):
     # Small replay_start_size keeps this GPU smoke test fast.
     environment = AtariEnvironment("Breakout", device=CUDA)
     validate_agent(ddqn(replay_start_size=64, device=CUDA), environment)
'''
Quick demo of a2c running on slurm.
Note that it only runs for 1 million frames.
For real experiments, you will surely need a modified version of this script.
'''
# NOTE(review): `envs` imported here is unused and immediately shadowed by
# the `envs` list below -- looks like a leftover import; confirm before removing.
from gym import envs
from all.experiments import SlurmExperiment
from all.presets.atari import a2c
from all.environments import AtariEnvironment

# Quick demo of a2c running on slurm.
# Note that it only runs for 1 million frames.
# For real experiments, you will surely need a modified version of this script.
device = 'cuda'
# one environment per game, all on the same CUDA device
envs = [AtariEnvironment(env, device) for env in ['Pong', 'Breakout', 'SpaceInvaders']]
# constructing the experiment submits sbatch jobs on the given partition
SlurmExperiment(a2c, envs, 1e6, hyperparameters={'device': device}, sbatch_args={
    'partition': '1080ti-short'
})
Exemplo n.º 30
0
class TestAtariPresets(unittest.TestCase):
    """Builds every Atari preset, runs its agents once, and checks save/load.

    Parallel presets are exercised against a DuplicateEnvironment as well as
    the ordinary single environment.
    """

    def setUp(self):
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
        duplicates = [AtariEnvironment('Breakout'), AtariEnvironment('Breakout')]
        self.parallel_env = DuplicateEnvironment(duplicates)
        self.parallel_env.reset()

    def tearDown(self):
        # Remove the checkpoint written during validation, if one was created.
        if os.path.exists('test_preset.pt'):
            os.remove('test_preset.pt')

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpq(self):
        # NOTE(review): method name says "vpq" but it validates the vpg
        # preset; name kept unchanged for test-discovery compatibility.
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build the preset on CPU and dispatch to the matching validator."""
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        """Exercise the train/test agents and a save/load round trip."""
        # exercise the training agent
        train_agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        train_agent.act(self.env.state)
        # exercise the deployment (test-mode) agent
        deploy_agent = preset.test_agent()
        deploy_agent.act(self.env.state)
        # round-trip through disk, then exercise the reloaded test agent
        preset.save('test_preset.pt')
        restored = torch.load('test_preset.pt')
        deploy_agent = restored.test_agent()
        deploy_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        """Like the standard check, but also drives the parallel environment."""
        # exercise the training agent against the batched state array
        train_agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        train_agent.act(self.parallel_env.state_array)
        # exercise the test agent on a single environment...
        single_agent = preset.test_agent()
        single_agent.act(self.env.state)
        # ...and on the parallel environment
        batched_agent = preset.test_agent()
        batched_agent.act(self.parallel_env.state_array)
        # round-trip through disk, then exercise the reloaded test agent
        preset.save('test_preset.pt')
        restored = torch.load('test_preset.pt')
        single_agent = restored.test_agent()
        single_agent.act(self.env.state)