def __init__(self, registry, env_creator, config, logdir, start_sampler=True):
    env = ModelCatalog.get_preprocessor_as_wrapper(
        registry, env_creator(config["env_config"]), config["model"])
    self.env = env
    policy_cls = get_policy_cls(config)
    # TODO(rliaw): should change this to be just env.observation_space
    self.policy = policy_cls(registry, env.observation_space.shape,
                             env.action_space, config)
    self.config = config

    # Technically not needed when not remote
    self.obs_filter = get_filter(config["observation_filter"],
                                 env.observation_space.shape)
    self.rew_filter = get_filter(config["reward_filter"], ())
    self.filters = {
        "obs_filter": self.obs_filter,
        "rew_filter": self.rew_filter
    }
    self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                config["batch_size"])
    if start_sampler and self.sampler.async:
        self.sampler.start()
    self.logdir = logdir
def __init__(self, env_creator, config, logdir):
    self.env = env = create_and_wrap(env_creator, config["model"])
    policy_cls = get_policy_cls(config)
    # TODO(rliaw): should change this to be just env.observation_space
    self.policy = policy_cls(env.observation_space.shape, env.action_space)
    obs_filter = get_filter(config["observation_filter"],
                            env.observation_space.shape)
    self.rew_filter = get_filter(config["reward_filter"], ())
    self.sampler = AsyncSampler(env, self.policy, obs_filter,
                                config["batch_size"])
    self.logdir = logdir
def __init__(self, env_creator, config, logdir, start_sampler=True):
    self.env = env = create_and_wrap(env_creator, config["preprocessing"])
    policy_cls = get_policy_cls(config)
    # TODO(rliaw): should change this to be just env.observation_space
    self.policy = policy_cls(env.observation_space.shape,
                             env.action_space, config)
    self.config = config

    # Technically not needed when not remote
    self.obs_filter = get_filter(config["observation_filter"],
                                 env.observation_space.shape)
    self.rew_filter = get_filter(config["reward_filter"], ())
    self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                config["batch_size"])
    if start_sampler and self.sampler.async:
        self.sampler.start()
    self.logdir = logdir
def __init__(
        self, registry, env_creator, config, logdir, start_sampler=True):
    env = ModelCatalog.get_preprocessor_as_wrapper(
        registry, env_creator(config["env_config"]), config["model"])
    self.env = env
    policy_cls = get_policy_cls(config)
    # TODO(rliaw): should change this to be just env.observation_space
    self.policy = policy_cls(
        registry, env.observation_space.shape, env.action_space, config)
    self.config = config

    # Technically not needed when not remote
    self.obs_filter = get_filter(
        config["observation_filter"], env.observation_space.shape)
    self.rew_filter = get_filter(config["reward_filter"], ())
    self.filters = {"obs_filter": self.obs_filter,
                    "rew_filter": self.rew_filter}
    self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                config["batch_size"])
    if start_sampler and self.sampler.async:
        self.sampler.start()
    self.logdir = logdir
def __init__(self,
             env_creator,
             policy_graph,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="rllib",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             env_config=None,
             model_config=None,
             policy_config=None):
    """Initialize a policy evaluator.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an env
            config dict.
        policy_graph (class): A class implementing rllib.PolicyGraph or
            rllib.TFPolicyGraph.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicyGraph.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of exactly `batch_steps` in size. Episodes may be truncated
                in order to meet this size requirement. When
                `num_envs > 1`, episodes will be truncated to sequences of
                `batch_steps / num_envs` in length.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps` in size. Episodes will not be
                truncated, but multiple episodes may be packed within one
                batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Horizon at which to stop episodes, even if
            the env has not returned done.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations
            returned.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy.
    """
    env_config = env_config or {}
    policy_config = policy_config or {}
    model_config = model_config or {}
    self.env_creator = env_creator
    self.policy_graph = policy_graph
    self.batch_steps = batch_steps
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations

    self.env = env_creator(env_config)
    if isinstance(self.env, VectorEnv) or \
            isinstance(self.env, ServingEnv) or \
            isinstance(self.env, AsyncVectorEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            "custom_preprocessor" not in model_config and \
            preprocessor_pref == "deepmind":

        def wrap(env):
            return wrap_deepmind(env, dim=model_config.get("dim", 80))
    else:

        def wrap(env):
            return ModelCatalog.get_preprocessor_as_wrapper(
                env, model_config)

    self.env = wrap(self.env)

    def make_env():
        return wrap(env_creator(env_config))

    self.policy_map = {}

    if issubclass(policy_graph, TFPolicyGraph):
        with tf.Graph().as_default():
            if tf_session_creator:
                self.sess = tf_session_creator()
            else:
                self.sess = tf.Session(config=tf.ConfigProto(
                    gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.sess.as_default():
                policy = policy_graph(self.env.observation_space,
                                      self.env.action_space, policy_config)
    else:
        policy = policy_graph(self.env.observation_space,
                              self.env.action_space, policy_config)
    self.policy_map = {"default": policy}

    self.obs_filter = get_filter(observation_filter,
                                 self.env.observation_space.shape)
    self.filters = {"obs_filter": self.obs_filter}

    # Always use vector env for consistency even if num_envs = 1
    if not isinstance(self.env, AsyncVectorEnv):
        if isinstance(self.env, ServingEnv):
            self.vector_env = _ServingEnvToAsync(self.env)
        else:
            if not isinstance(self.env, VectorEnv):
                self.env = VectorEnv.wrap(make_env, [self.env],
                                          num_envs=num_envs)
            self.vector_env = _VectorEnvToAsync(self.env)
    else:
        self.vector_env = self.env

    if self.batch_mode == "truncate_episodes":
        if batch_steps % num_envs != 0:
            raise ValueError(
                "In 'truncate_episodes' batch mode, `batch_steps` must be "
                "evenly divisible by `num_envs`. Got {} and {}.".format(
                    batch_steps, num_envs))
        batch_steps = batch_steps // num_envs
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        batch_steps = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))
    if sample_async:
        self.sampler = AsyncSampler(self.vector_env,
                                    self.policy_map["default"],
                                    self.obs_filter, batch_steps,
                                    horizon=episode_horizon,
                                    pack=pack_episodes)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(self.vector_env,
                                   self.policy_map["default"],
                                   self.obs_filter, batch_steps,
                                   horizon=episode_horizon,
                                   pack=pack_episodes)
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        policy: Copy of graph used for policy. Used by sampler and gradients.
        obs_filter: Observation filter used in environment sampling.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """

    def __init__(
            self, registry, env_creator, config, logdir, start_sampler=True):
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(
            registry, env.observation_space.shape, env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(
            config["observation_filter"], env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {"obs_filter": self.obs_filter,
                        "rew_filter": self.rew_filter}
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        if start_sampler and self.sampler.async:
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        rollout = self.sampler.get_data()
        samples = process_rollout(
            rollout, self.rew_filter, gamma=self.config["gamma"],
            lambda_=self.config["lambda"], use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        gradient, info = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        self.policy.apply_gradients(grads)

    def get_weights(self):
        return self.policy.get_weights()

    def set_weights(self, params):
        self.policy.set_weights(params)

    def save(self):
        filters = self.get_filters(flush_after=True)
        weights = self.get_weights()
        return pickle.dumps({"filters": filters, "weights": weights})

    def restore(self, objs):
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        self.set_weights(objs["weights"])

    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters
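# Sketch of the driver-side round trip the methods above support (illustrative;
# `local_ev` and `remote_ev` are hypothetical A3CEvaluator instances, not code
# from the original file).
payload = local_ev.save()        # pickled dict with "filters" and "weights"
remote_ev.restore(payload)       # worker adopts the driver's filters/weights

samples = remote_ev.sample()     # GAE-processed rollout from the AsyncSampler
local_ev.apply_gradients(remote_ev.compute_gradients(samples))

# Periodically read filter snapshots back from the worker; flush_after=True
# clears the worker's accumulated buffer once it has been read.
worker_filters = remote_ev.get_filters(flush_after=True)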
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        policy: Copy of graph used for policy. Used by sampler and gradients.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """

    def __init__(self, env_creator, config, logdir, start_sampler=True):
        self.env = env = create_and_wrap(env_creator, config["preprocessing"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape,
                                 env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        if start_sampler and self.sampler.async:
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        """
        Returns:
            samples: Experience samples from the evaluator, post-processed
                with GAE.
        """
        rollout = self.sampler.get_data()
        samples = process_rollout(
            rollout, self.rew_filter, gamma=self.config["gamma"],
            lambda_=self.config["lambda"], use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        gradient, info = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        self.policy.apply_gradients(grads)

    def get_weights(self):
        return self.policy.get_weights()

    def set_weights(self, params):
        self.policy.set_weights(params)

    def update_filters(self, obs_filter=None, rew_filter=None):
        if rew_filter:
            # No special handling required since outside of threaded code
            self.rew_filter = rew_filter.copy()
        if obs_filter:
            self.sampler.update_obs_filter(obs_filter)

    def save(self):
        weights = self.get_weights()
        return pickle.dumps({"weights": weights})

    def restore(self, objs):
        objs = pickle.loads(objs)
        self.set_weights(objs["weights"])
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        policy: Copy of graph used for policy. Used by sampler and gradients.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """

    def __init__(self, env_creator, config, logdir):
        self.env = env = create_and_wrap(env_creator, config["model"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape,
                                 env.action_space)
        obs_filter = get_filter(config["observation_filter"],
                                env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, obs_filter,
                                    config["batch_size"])
        self.logdir = logdir

    def sample(self):
        """
        Returns:
            trajectory (PartialRollout): Experience samples from evaluator
        """
        rollout = self.sampler.get_data()
        return rollout

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradient(self):
        rollout = self.sampler.get_data()
        obs_filter = self.sampler.get_obs_filter(flush=True)

        traj = process_rollout(
            rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
        gradient, info = self.policy.compute_gradients(traj)
        info["obs_filter"] = obs_filter
        info["rew_filter"] = self.rew_filter
        return gradient, info

    def apply_gradient(self, grads):
        self.policy.apply_gradients(grads)

    def set_weights(self, params):
        self.policy.set_weights(params)

    def update_filters(self, obs_filter=None, rew_filter=None):
        if rew_filter:
            # No special handling required since outside of threaded code
            self.rew_filter = rew_filter.copy()
        if obs_filter:
            self.sampler.update_obs_filter(obs_filter)
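# Sketch of the gradient/filter exchange this variant implies (illustrative;
# `worker_ev` and `driver_ev` are hypothetical instances of the class above,
# and the surrounding training loop is an assumption).
grad, info = worker_ev.compute_gradient()   # also snapshots the obs filter
driver_ev.apply_gradient(grad)

# info carries the worker's filter state; a driver might fold that state into
# its own evaluator via update_filters() (merging logic omitted here).
driver_ev.update_filters(
    obs_filter=info["obs_filter"], rew_filter=info["rew_filter"])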
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        policy: Copy of graph used for policy. Used by sampler and gradients.
        obs_filter: Observation filter used in environment sampling.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """

    def __init__(self, registry, env_creator, config, logdir,
                 start_sampler=True):
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(registry, env.observation_space.shape,
                                 env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {
            "obs_filter": self.obs_filter,
            "rew_filter": self.rew_filter
        }
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        if start_sampler and self.sampler.async:
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        rollout = self.sampler.get_data()
        samples = process_rollout(
            rollout, self.rew_filter, gamma=self.config["gamma"],
            lambda_=self.config["lambda"], use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        gradient, info = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        self.policy.apply_gradients(grads)

    def get_weights(self):
        return self.policy.get_weights()

    def set_weights(self, params):
        self.policy.set_weights(params)

    def save(self):
        filters = self.get_filters(flush_after=True)
        weights = self.get_weights()
        return pickle.dumps({"filters": filters, "weights": weights})

    def restore(self, objs):
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        self.set_weights(objs["weights"])

    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters
def __init__(self,
             env_creator,
             policy_graph,
             policy_mapping_fn=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="rllib",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0):
    """Initialize a policy evaluator.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy_graph (class|dict): Either a class implementing PolicyGraph,
            or a dictionary of policy id strings to
            (PolicyGraph, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that agent
            to a policy for the duration of the episode.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicyGraph.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of exactly `batch_steps` in size. Episodes may be truncated
                in order to meet this size requirement. When
                `num_envs > 1`, episodes will be truncated to sequences of
                `batch_steps / num_envs` in length.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps` in size. Episodes will not be
                truncated, but multiple episodes may be packed within one
                batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Horizon at which to stop episodes, even if
            the env has not returned done.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations
            returned.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy_graph`.
        worker_index (int): For remote evaluators, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
    """
    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    self.policy_config = policy_config
    model_config = model_config or {}
    policy_mapping_fn = (policy_mapping_fn or
                         (lambda agent_id: DEFAULT_POLICY_ID))
    self.env_creator = env_creator
    self.policy_graph = policy_graph
    self.batch_steps = batch_steps
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations

    self.env = env_creator(env_context)
    if isinstance(self.env, VectorEnv) or \
            isinstance(self.env, ServingEnv) or \
            isinstance(self.env, MultiAgentEnv) or \
            isinstance(self.env, AsyncVectorEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            "custom_preprocessor" not in model_config and \
            preprocessor_pref == "deepmind":

        def wrap(env):
            return wrap_deepmind(env, dim=model_config.get("dim", 80))
    else:

        def wrap(env):
            return ModelCatalog.get_preprocessor_as_wrapper(
                env, model_config)

    self.env = wrap(self.env)

    def make_env():
        return wrap(env_creator(env_context))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy_graph, self.env)
    if _has_tensorflow_graph(policy_dict):
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(config=tf.ConfigProto(
                    gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map = self._build_policy_map(
                    policy_dict, policy_config)
    else:
        self.policy_map = self._build_policy_map(policy_dict, policy_config)

    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}

    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = AsyncVectorEnv.wrap_async(
        self.env, make_env=make_env, num_envs=num_envs)

    if self.batch_mode == "truncate_episodes":
        if batch_steps % num_envs != 0:
            raise ValueError(
                "In 'truncate_episodes' batch mode, `batch_steps` must be "
                "evenly divisible by `num_envs`. Got {} and {}.".format(
                    batch_steps, num_envs))
        batch_steps = batch_steps // num_envs
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        batch_steps = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))
    if sample_async:
        self.sampler = AsyncSampler(self.async_env, self.policy_map,
                                    policy_mapping_fn, self.filters,
                                    batch_steps, horizon=episode_horizon,
                                    pack=pack_episodes,
                                    tf_sess=self.tf_sess)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(self.async_env, self.policy_map,
                                   policy_mapping_fn, self.filters,
                                   batch_steps, horizon=episode_horizon,
                                   pack=pack_episodes,
                                   tf_sess=self.tf_sess)
def __init__(self,
             env_creator,
             policy_graph,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             preprocessor_pref="rllib",
             sample_async=False,
             compress_observations=False,
             observation_filter="NoFilter",
             registry=None,
             env_config=None,
             model_config=None,
             policy_config=None):
    """Initialize a policy evaluator.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an env
            config dict.
        policy_graph (class): A class implementing rllib.PolicyGraph or
            rllib.TFPolicyGraph.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicyGraph.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following choices:
            complete_episodes: each batch will be at least batch_steps
                in size, and will include one or more complete episodes.
            truncate_episodes: each batch will be around batch_steps
                in size, and include transitions from one episode only.
            pack_episodes: each batch will be exactly batch_steps in
                size, and may include transitions from multiple episodes.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations
            returned.
        observation_filter (str): Name of observation filter to use.
        registry (tune.Registry): User-registered objects. Pass in the
            value from tune.registry.get_registry() if you're having
            trouble resolving things like custom envs.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy.
    """
    registry = registry or get_registry()
    env_config = env_config or {}
    policy_config = policy_config or {}
    model_config = model_config or {}

    assert batch_mode in [
        "complete_episodes", "truncate_episodes", "pack_episodes"
    ]
    self.env_creator = env_creator
    self.policy_graph = policy_graph
    self.batch_steps = batch_steps
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations

    self.env = env_creator(env_config)
    is_atari = hasattr(self.env.unwrapped, "ale")
    if is_atari and "custom_preprocessor" not in model_config and \
            preprocessor_pref == "deepmind":
        self.env = wrap_deepmind(self.env, dim=model_config.get("dim", 80))
    else:
        self.env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, self.env, model_config)

    self.vectorized = hasattr(self.env, "vector_reset")
    self.policy_map = {}

    if issubclass(policy_graph, TFPolicyGraph):
        with tf.Graph().as_default():
            if tf_session_creator:
                self.sess = tf_session_creator()
            else:
                self.sess = tf.Session(config=tf.ConfigProto(
                    gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.sess.as_default():
                policy = policy_graph(self.env.observation_space,
                                      self.env.action_space, registry,
                                      policy_config)
    else:
        policy = policy_graph(self.env.observation_space,
                              self.env.action_space, registry,
                              policy_config)
    self.policy_map = {"default": policy}

    self.obs_filter = get_filter(observation_filter,
                                 self.env.observation_space.shape)
    self.filters = {"obs_filter": self.obs_filter}

    if self.vectorized:
        raise NotImplementedError("Vector envs not yet supported")
    else:
        if batch_mode not in [
                "pack_episodes", "truncate_episodes", "complete_episodes"
        ]:
            raise NotImplementedError("Batch mode not yet supported")
        pack = batch_mode == "pack_episodes"
        if batch_mode == "complete_episodes":
            batch_steps = 999999
        if sample_async:
            self.sampler = AsyncSampler(self.env,
                                        self.policy_map["default"],
                                        self.obs_filter, batch_steps,
                                        pack=pack)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(self.env,
                                       self.policy_map["default"],
                                       self.obs_filter, batch_steps,
                                       pack=pack)