import os

from ray.rllib import rollout
from ray.tune.registry import get_trainable_cls


def run_rollout(args, parser):
    """Restore a Trainer from a checkpoint and roll it out for the
    requested number of steps/episodes."""
    config = args.config
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)

    # Load state from checkpoint.
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    num_episodes = int(args.episodes)

    # Determine the video output directory.
    # Older argument sets may not define --video-dir; fall back to --monitor.
    use_arg_monitor = not hasattr(args, "video_dir")
    if use_arg_monitor:
        print("There is no such attribute: args.video_dir")

    video_dir = None
    if not use_arg_monitor:
        if args.monitor:
            video_dir = os.path.join("./logs", "video")
        elif args.video_dir:
            video_dir = os.path.expanduser(args.video_dir)

    # Do the actual rollout.
    with rollout.RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info) as saver:
        if use_arg_monitor:
            rollout.rollout(
                agent,
                args.env,
                num_steps,
                num_episodes,
                saver,
                args.no_render,
                args.monitor)
        else:
            rollout.rollout(
                agent, args.env,
                num_steps,
                num_episodes,
                saver,
                args.no_render, video_dir)
Example #2
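    # Train, save, and restore a PGTrainer on an env with a nested Dict
    # observation space, then check that rollout() works on the restored agent.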
    def test_rollout_dict_space(self):
        register_env("nested", lambda _: NestedDictEnv())
        agent = PGTrainer(env="nested", config={"framework": "tf"})
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
        agent2 = PGTrainer(env="nested", config={"framework": "tf"})
        agent2.restore(path)
        agent2.train()

        # Test rollout works on restore
        rollout(agent2, "nested", 100)
Example #3
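    # The same nested-Dict rollout test, written against an older RLlib API
    # (camelCase test name, no "framework" config key).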
    def testRolloutDictSpace(self):
        register_env("nested", lambda _: NestedDictEnv())
        agent = PGTrainer(env="nested")
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
        agent2 = PGTrainer(env="nested")
        agent2.restore(path)
        agent2.train()

        # Test rollout works on restore
        rollout(agent2, "nested", 100)
Example #4
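    # The same test once more, using the still-older PGAgent class name.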
    def testRolloutDictSpace(self):
        register_env("nested", lambda _: NestedDictEnv())
        agent = PGAgent(env="nested")
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
        agent2 = PGAgent(env="nested")
        agent2.restore(path)
        agent2.train()

        # Test rollout works on restore
        rollout(agent2, "nested", 100)
Example #5
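    # Cross-test every trainer/env pair: write a rollout to a temp file via
    # RolloutSaver, then read it back with RayFileRolloutReader.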
    def test_ray_rollout_reader(self):
        for trainer_fn in ALL_TRAINERS:
            for env_fn in ALL_ENVS:
                with self.subTest(trainer=trainer_fn, env=env_fn):
                    trainer = trainer_fn(
                        config={
                            "env": env_fn,
                            "framework": "torch",
                            "train_batch_size": 128,
                            "rollout_fragment_length": 128,
                            "create_env_on_driver": True,
                        })
                    with tempfile.TemporaryDirectory() as temp_dir:
                        rollout_file = Path(temp_dir) / "rollout"
                        rollout(
                            trainer,
                            env_fn,
                            num_steps=100,
                            num_episodes=10,
                            saver=RolloutSaver(outfile=str(rollout_file)),
                        )
                        rollout_reader = RayFileRolloutReader(rollout_file)
                        for trajectory in rollout_reader:
                            self.assertIsInstance(trajectory.timesteps, list)
Example #6
from ray.rllib import rollout

# rollout.rollout() requires arguments; the call pattern used in the first
# example above is:
# rollout.rollout(agent, env_name, num_steps, num_episodes, saver,
#                 no_render, video_dir)
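# A minimal end-to-end sketch, assembled from the call patterns in the
# examples above; the "CartPole-v0" env name and the PGTrainer config here
# are assumptions, not part of the original snippet.
import ray
from ray.rllib import rollout
from ray.rllib.agents.pg import PGTrainer

ray.init()
agent = PGTrainer(env="CartPole-v0", config={"framework": "tf"})
agent.train()                # one training iteration
path = agent.save()          # write a checkpoint
agent.restore(path)          # restore it, as the rollout scripts above do
rollout.rollout(agent, "CartPole-v0", num_steps=100, num_episodes=1)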
Example #7
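    # Load a pickled MARL config, register the multi-agent env, restore the
    # trainer from a checkpoint, and roll it out in render mode.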
                        help="Max step per episode (default is 400)")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Configuration filename for MARl algorithms.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint for model loading")
    parser.add_argument("--out", type=str, default="out.pkl")
    args = parser.parse_args()

    with open(args.config, 'rb') as f:
        configs = pickle.load(f)

    ray.init()
    env_config = configs['ray_config']['env_config']
    env_config['render'] = True  # turn on render mode
    env_name = env_config['env_setup']['env_name']

    register_env(
        env_name,
        lambda kwargs: multiprocess_wrapper.MultiAgentClient(**kwargs))

    cls = get_agent_class(configs["run"])
    agents = cls(env=env_name, config=configs['ray_config'])
    agents.restore(args.checkpoint)

    rollout(agents, env_name, args.max_step, args.out)
    agents.stop()
Example #8
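        # A2C training config for a custom CNN env; the trainer below
        # checkpoints and runs a short rendered rollout every 10 iterations.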
        'env_config': env_config,
        'num_workers': 2,
        'log_level': 'ERROR',
        'framework': 'tf',
        'model': model_config,
    }

    register_env('DirectCnnEnv-v0',
                 lambda env_config: DirectCnnEnv(env_config))

    a2c_trainer = a3c.A2CTrainer(config=config, env='DirectCnnEnv-v0')

    policy = a2c_trainer.get_policy()
    cur_model = policy.model.base_model
    cur_model.summary()

    for i in tqdm(range(1000)):
        result = a2c_trainer.train()
        print(f"{result['episode_reward_max']:.4f}  |  "
              f"{result['episode_reward_mean']:.4f}  |  "
              f"{result['episode_reward_min']:.4f}")

        if i % 10 == 0:
            checkpoint = a2c_trainer.save()

            print("checkpoint saved at", checkpoint)
            rollout.rollout(a2c_trainer,
                            env_name='DirectCnnEnv-v0',
                            num_steps=1,
                            no_render=False)