def worker(env_name, policy_cls, worker_kwargs):
    """Create a torch-framework RolloutWorker for the named environment.

    Args:
        env_name: registered environment name, resolved via `get_env_creator`
            and also recorded in the policy config under "env".
        policy_cls: policy class passed as the worker's `policy_spec`.
        worker_kwargs: extra keyword arguments forwarded to `RolloutWorker`.

    Returns:
        A configured RolloutWorker instance.
    """
    config = {"env": env_name, "framework": "torch"}
    return RolloutWorker(
        env_creator=get_env_creator(env_name),
        policy_spec=policy_cls,
        policy_config=config,
        **worker_kwargs,
    )
def worker(env_name, policy_cls, worker_kwargs):
    """Create a RolloutWorker for the named environment and policy class.

    Args:
        env_name: registered environment name, resolved via `get_env_creator`
            and recorded in the policy config under "env".
        policy_cls: policy class passed via the `policy` keyword.
        worker_kwargs: extra keyword arguments forwarded to `RolloutWorker`.

    Returns:
        A configured RolloutWorker instance.
    """
    env_creator = get_env_creator(env_name)
    return RolloutWorker(
        env_creator=env_creator,
        policy=policy_cls,
        policy_config={"env": env_name},
        **worker_kwargs,
    )
def worker(envs, env_name, policy_cls):
    """Create a RolloutWorker from an env-creator registry.

    The worker samples one step per fragment and batches whole episodes
    (`rollout_fragment_length=1`, `batch_mode="complete_episodes"`).

    Args:
        envs: mapping from environment name to env-creator callable.
        env_name: key into `envs`; also recorded in the policy config.
        policy_cls: policy class passed via the `policy` keyword.

    Returns:
        A configured RolloutWorker instance.
    """
    creator = envs[env_name]
    return RolloutWorker(
        env_creator=creator,
        policy=policy_cls,
        policy_config={"env": env_name},
        rollout_fragment_length=1,
        batch_mode="complete_episodes",
    )
def make_worker(*, env_config: dict, **kwargs) -> RolloutWorker:
    """Build a rollout worker for a linear feedback policy on LQG.

    After construction, each trainable policy's ``setup`` method is called
    with the worker's environment instance so the policy can bind to the
    created LQG. The worker samples one full horizon per call to
    ``.sample()`` (`rollout_fragment_length=env_config["horizon"]` with
    ``batch_mode="truncate_episodes"``).

    Args:
        env_config: configuration for the RandomVectorLQG environment;
            must contain "num_envs" and "horizon".
        **kwargs: extra keyword arguments forwarded to `RolloutWorker`.
            "validate_env" (a callable chained after the built-in checks)
            and "policy_config" are intercepted.

    Returns:
        A RolloutWorker instance.
    """
    user_validate_env = kwargs.pop("validate_env", None)

    def validate_env(env: EnvType, env_context: EnvContext):
        # Sanity-check the vectorized LQG env, then chain to any
        # caller-supplied validator.
        assert isinstance(env, RandomVectorLQG)
        assert env.num_envs == env_context["num_envs"]
        if user_validate_env:
            user_validate_env(env, env_context)

    # Copy before mutating: the original popped the caller's dict and
    # wrote "framework" into it in place, clobbering the caller's state.
    policy_config = dict(kwargs.pop("policy_config", {}))
    policy_config["framework"] = "torch"

    # Create and initialize
    worker = RolloutWorker(
        env_creator=get_env_creator("RandomVectorLQG"),
        validate_env=validate_env,
        env_config=env_config,
        num_envs=env_config["num_envs"],
        policy_spec=LQGPolicy,
        policy_config=policy_config,
        rollout_fragment_length=env_config["horizon"],
        batch_mode="truncate_episodes",
        _use_trajectory_view_api=False,
        **kwargs,
    )
    worker.foreach_trainable_policy(lambda p, _: p.setup(worker.env))
    return worker