示例#1
0
def worker(env_name, policy_cls, worker_kwargs):
    """Create a `RolloutWorker` for the named environment and policy class.

    Args:
        env_name: name handed to `get_env_creator` and stored under the
            "env" key of the policy config.
        policy_cls: passed to the worker as its `policy_spec`.
        worker_kwargs: mapping of extra keyword arguments forwarded to
            the `RolloutWorker` constructor.

    Returns:
        The constructed `RolloutWorker`.
    """
    creator = get_env_creator(env_name)
    config = {"env": env_name, "framework": "torch"}
    return RolloutWorker(
        env_creator=creator,
        policy_spec=policy_cls,
        policy_config=config,
        **worker_kwargs,
    )
示例#2
0
def worker(env_name, policy_cls, worker_kwargs):
    """Create a `RolloutWorker` for the named environment and policy class.

    Args:
        env_name: name handed to `get_env_creator` and stored under the
            "env" key of the policy config.
        policy_cls: passed to the worker through its `policy` argument.
        worker_kwargs: mapping of extra keyword arguments forwarded to
            the `RolloutWorker` constructor.

    Returns:
        The constructed `RolloutWorker`.
    """
    creator = get_env_creator(env_name)
    config = {"env": env_name}
    return RolloutWorker(
        env_creator=creator,
        policy=policy_cls,
        policy_config=config,
        **worker_kwargs,
    )
示例#3
0
def worker(envs, env_name, policy_cls):
    """Create a `RolloutWorker` that samples in "complete_episodes" mode.

    Args:
        envs: object indexed by `env_name`; the looked-up value is used as
            the worker's `env_creator`.
        env_name: key into `envs`, also stored under the "env" key of the
            policy config.
        policy_cls: passed to the worker through its `policy` argument.

    Returns:
        The constructed `RolloutWorker`.
    """
    creator = envs[env_name]
    # Sampling settings are fixed for this worker.
    sampling = {
        "rollout_fragment_length": 1,
        "batch_mode": "complete_episodes",
    }
    return RolloutWorker(
        env_creator=creator,
        policy=policy_cls,
        policy_config={"env": env_name},
        **sampling,
    )
示例#4
0
def make_worker(*, env_config: dict, **kwargs) -> RolloutWorker:
    """Build rollout worker for a linear feedback policy on LQG.

    The worker is created with the "RandomVectorLQG" environment and an
    `LQGPolicy` configured for the torch framework. After construction,
    `setup` is called on every trainable policy with the worker's
    environment as an argument.
    The rollout fragment length is set to the environment horizon, so the
    worker is intended to sample complete trajectories on each call to
    .sample().

    Args:
        env_config: the configuration for the LQG environment. Must contain
            the "num_envs" and "horizon" keys.
        **kwargs: extra keyword arguments forwarded to `RolloutWorker`.
            Two of them are handled specially:

            * ``validate_env``: optional callable invoked with
              ``(env, env_context)`` after the built-in environment checks.
            * ``policy_config``: optional dict of policy options. It is
              copied before "framework" is forced to "torch", so the
              caller's dict is left untouched.

    Returns:
        A RolloutWorker instance
    """
    user_validate_env = kwargs.pop("validate_env", None)

    def validate_env(env: EnvType, env_context: EnvContext):
        # Built-in sanity checks run first, then the caller's hook (if any).
        assert isinstance(env, RandomVectorLQG)
        assert env.num_envs == env_context["num_envs"]

        if user_validate_env:
            user_validate_env(env, env_context)

    # Copy the caller's config so that forcing the framework below does not
    # mutate a dict the caller may reuse; an explicit None is treated as empty.
    policy_config = dict(kwargs.pop("policy_config", None) or {})
    policy_config["framework"] = "torch"

    # Create and initialize
    worker = RolloutWorker(
        env_creator=get_env_creator("RandomVectorLQG"),
        validate_env=validate_env,
        env_config=env_config,
        num_envs=env_config["num_envs"],
        policy_spec=LQGPolicy,
        policy_config=policy_config,
        rollout_fragment_length=env_config["horizon"],
        batch_mode="truncate_episodes",
        _use_trajectory_view_api=False,
        **kwargs
    )
    # Hand the freshly created environment to each trainable policy.
    worker.foreach_trainable_policy(lambda p, _: p.setup(worker.env))
    return worker
示例#5
0
文件: apex.py 项目: alipay/ray
 def remote_worker_sample_and_store(worker: RolloutWorker,
                                    replay_actors: List[ReplayActor]):
     """Sample a batch on the worker and push it to a random replay actor.

     This function is run as a remote function on sampling workers and
     should only ever be used with the RolloutWorker's apply function.
     It gathers samples and triggers storing them to a replay actor from
     the rollout worker itself, instead of returning the object refs for
     the samples to the driver process and doing that operation there.

     Args:
         worker: the rollout worker whose `sample()` produces the batch.
         replay_actors: candidate replay actors; one is picked with
             `random.choice`.

     Returns:
         A dict with the batch's "agent_steps" and "env_steps" counts.
     """
     sampled = worker.sample()
     # The value returned by `.remote()` is intentionally not kept.
     random.choice(replay_actors).add_batch.remote(sampled)
     return {
         "agent_steps": sampled.agent_steps(),
         "env_steps": sampled.env_steps(),
     }