示例#1
0
def run():
    """Launch a model-predictive DQN experiment and a baseline DQN experiment.

    Command-line arguments:
        env: name of the Atari game (e.g. Pong).
        --device: device to run the agent on (default "cuda").
        --frames: number of training frames (default 2,000,000).

    Both experiments are constructed on the same environment object and
    with the same frame budget.
    """
    # parse arguments
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument(
        "--device",
        default="cuda",
        help=
        "The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    parser.add_argument("--frames",
                        type=int,
                        # argparse applies `type` only to string defaults, so
                        # a float default such as 2e6 would reach the code
                        # below as a float; make the default an int up front.
                        default=int(2e6),
                        help="The number of training frames")
    args = parser.parse_args()

    # create atari environment
    env = AtariEnvironment(args.env, device=args.device)

    # run the experiment
    Experiment(model_predictive_dqn(device=args.device),
               env,
               frames=args.frames)

    # run the baseline agent for comparison
    # NOTE(review): last_frame is four times the frame budget -- presumably
    # to account for frame skipping; confirm against the dqn preset.
    Experiment(dqn(device=args.device,
                   replay_buffer_size=1e5,
                   last_frame=(args.frames * 4)),
               env,
               frames=args.frames)
def main():
    """Hand a fixed set of Atari agents and environments to SlurmExperiment.

    Six agent presets from ``atari`` and five Atari games (each environment
    created on the 'cuda' device) are passed to ``SlurmExperiment`` along
    with a value of 10e6 and sbatch arguments selecting the '1080ti-long'
    partition.
    """
    game_names = ('BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders')

    # agent presets, instantiated with their default settings
    agents = [
        atari.a2c(),
        atari.c51(),
        atari.dqn(),
        atari.ddqn(),
        atari.ppo(),
        atari.rainbow(),
    ]

    # one environment per game
    envs = []
    for game in game_names:
        envs.append(AtariEnvironment(game, device='cuda'))

    slurm_options = {'partition': '1080ti-long'}
    SlurmExperiment(agents, envs, 10e6, sbatch_args=slurm_options)
示例#3
0
 def test_dqn_cuda(self):
     """Validate a DQN agent against the Breakout environment on CUDA."""
     # replay_start_size=64 is passed through to the dqn preset unchanged
     agent = dqn(replay_start_size=64, device=CUDA)
     breakout = AtariEnvironment("Breakout", device=CUDA)
     validate_agent(agent, breakout)