示例#1
0
def worker(env_name, policy_cls, worker_kwargs):
    """Construct a ``RolloutWorker`` for ``env_name`` using the torch framework.

    Args:
        env_name: name handed to ``get_env_creator`` and also stored under
            the ``"env"`` key of the policy config.
        policy_cls: object passed to the worker as ``policy_spec``.
        worker_kwargs: mapping of additional keyword arguments forwarded
            unchanged to ``RolloutWorker``.

    Returns:
        The newly created ``RolloutWorker``.
    """
    creator = get_env_creator(env_name)
    config = {"env": env_name, "framework": "torch"}
    return RolloutWorker(
        env_creator=creator,
        policy_spec=policy_cls,
        policy_config=config,
        **worker_kwargs,
    )
示例#2
0
def worker(env_name, policy_cls, worker_kwargs):
    """Construct a ``RolloutWorker`` for the environment named ``env_name``.

    Args:
        env_name: name handed to ``get_env_creator`` and also stored under
            the ``"env"`` key of the policy config.
        policy_cls: object passed to the worker through its ``policy``
            keyword.
        worker_kwargs: mapping of additional keyword arguments forwarded
            unchanged to ``RolloutWorker``.

    Returns:
        The newly created ``RolloutWorker``.
    """
    creator = get_env_creator(env_name)
    config = {"env": env_name}
    return RolloutWorker(
        env_creator=creator,
        policy=policy_cls,
        policy_config=config,
        **worker_kwargs,
    )
示例#3
0
def worker(envs, env_name, policy_cls):
    """Construct a ``RolloutWorker`` with fixed sampling settings.

    The worker is created with ``rollout_fragment_length=1`` and
    ``batch_mode="complete_episodes"``.

    Args:
        envs: container indexed by environment name; ``envs[env_name]`` is
            used as the worker's ``env_creator``.
        env_name: key into ``envs``, also stored under the ``"env"`` key of
            the policy config.
        policy_cls: object passed to the worker through its ``policy``
            keyword.

    Returns:
        The newly created ``RolloutWorker``.
    """
    settings = dict(
        env_creator=envs[env_name],
        policy=policy_cls,
        policy_config={"env": env_name},
        rollout_fragment_length=1,
        batch_mode="complete_episodes",
    )
    return RolloutWorker(**settings)
示例#4
0
def make_worker(*, env_config: dict, **kwargs) -> RolloutWorker:
    """Build rollout worker for a linear feedback policy on LQG.

    After the worker is created, `setup` is called on every trainable policy
    with the worker's environment (`worker.env`) as its argument.
    The worker is configured with a rollout fragment length equal to
    `env_config["horizon"]` and the "truncate_episodes" batch mode.

    Args:
        env_config: the configuration for the LQG environment. Must contain
            the keys "num_envs" and "horizon".
        **kwargs: extra keyword arguments forwarded to `RolloutWorker`.
            Two of them are handled specially:

            * `validate_env`: optional callable `(env, env_context)` that is
              invoked after the built-in environment checks.
            * `policy_config`: optional dict of policy settings. It is
              copied (the caller's dict is left untouched) and its
              "framework" entry is always set to "torch".

    Returns:
        A RolloutWorker instance
    """
    user_validate_env = kwargs.pop("validate_env", None)

    def validate_env(env: EnvType, env_context: EnvContext):
        # Built-in sanity checks run first, then the user-supplied hook.
        assert isinstance(env, RandomVectorLQG)
        assert env.num_envs == env_context["num_envs"]

        if user_validate_env:
            user_validate_env(env, env_context)

    # Copy before modifying: the original dict belongs to the caller and must
    # not gain a "framework" key as a side effect of building a worker.
    policy_config = dict(kwargs.pop("policy_config", None) or {})
    policy_config["framework"] = "torch"

    # Create and initialize
    worker = RolloutWorker(
        env_creator=get_env_creator("RandomVectorLQG"),
        validate_env=validate_env,
        env_config=env_config,
        num_envs=env_config["num_envs"],
        policy_spec=LQGPolicy,
        policy_config=policy_config,
        rollout_fragment_length=env_config["horizon"],
        batch_mode="truncate_episodes",
        _use_trajectory_view_api=False,
        **kwargs
    )
    worker.foreach_trainable_policy(lambda p, _: p.setup(worker.env))
    return worker