def __init__(self,
             env,
             policies,
             policy_mapping_fn,
             preprocessors,
             obs_filters,
             clip_rewards,
             unroll_length,
             callbacks,
             horizon=None,
             pack=False,
             tf_sess=None,
             clip_actions=True):
    """Build a synchronous sampler around ``env``.

    ``env`` is converted with ``BaseEnv.to_base_env``. The remaining
    arguments are stored on the instance and/or forwarded positionally to
    ``_env_runner``, whose return value becomes ``self.rollout_provider``.

    Args:
        env: Any env accepted by ``BaseEnv.to_base_env``.
        policies: Policies, forwarded to ``_env_runner``.
        policy_mapping_fn: Agent-to-policy mapping, forwarded.
        preprocessors: Observation preprocessors, forwarded.
        obs_filters: Observation filters, forwarded.
        clip_rewards: Reward-clipping setting, forwarded.
        unroll_length: Fragment length, forwarded.
        callbacks: Episode callbacks, forwarded.
        horizon: Optional episode horizon, forwarded.
        pack: Episode-packing flag, forwarded.
        tf_sess: Optional TF session, forwarded.
        clip_actions: Action-clipping flag, forwarded.
    """
    base_env = BaseEnv.to_base_env(env)
    extra_batches = queue.Queue()

    self.base_env = base_env
    self.unroll_length = unroll_length
    self.horizon = horizon
    self.policies = policies
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.extra_batches = extra_batches

    # Extra batches produced by the runner are pushed onto our queue.
    self.rollout_provider = _env_runner(
        base_env, extra_batches.put, policies, policy_mapping_fn,
        unroll_length, horizon, preprocessors, obs_filters, clip_rewards,
        clip_actions, pack, callbacks, tf_sess)
    self.metrics_queue = queue.Queue()
def model_vector_env(env: EnvType) -> BaseEnv:
    """Wrap ``env`` in a vectorized model env and convert it to a BaseEnv.

    The worker settings (``worker_index``, ``make_env_fn``, ``num_envs``)
    are read from the global rollout worker returned by
    ``get_global_worker()``.

    Args:
        env (EnvType): The input environment (of any supported environment
            type) to be converted to a _VectorizedModelGymEnv (wrapped as an
            RLlib BaseEnv).

    Returns:
        BaseEnv: The BaseEnv converted input `env`.
    """
    worker = get_global_worker()
    make_env_fn = worker.make_env_fn
    num_envs = worker.num_envs

    # Only a non-zero worker index gets the _VectorizedModelGymEnv wrapper;
    # with index 0 the env is handed to BaseEnv.to_base_env unchanged.
    if worker.worker_index:
        env = _VectorizedModelGymEnv(
            make_env=make_env_fn,
            existing_envs=[env],
            num_envs=num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )

    return BaseEnv.to_base_env(
        env,
        make_env=make_env_fn,
        num_envs=num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )
def __init__(self,
             env,
             policies,
             policy_mapping_fn,
             preprocessors,
             obs_filters,
             clip_rewards,
             rollout_fragment_length,
             callbacks,
             horizon=None,
             pack=False,
             tf_sess=None,
             clip_actions=True,
             soft_horizon=False,
             no_done_at_end=False):
    """Build a synchronous sampler around ``env`` that also records perf stats.

    ``env`` is converted with ``BaseEnv.to_base_env``. A fresh ``PerfStats``
    object is created and handed, together with the remaining arguments,
    to ``_env_runner``, whose return value becomes
    ``self.rollout_provider``.

    Args:
        env: Any env accepted by ``BaseEnv.to_base_env``.
        policies: Policies, forwarded to ``_env_runner``.
        policy_mapping_fn: Agent-to-policy mapping, forwarded.
        preprocessors: Observation preprocessors, forwarded.
        obs_filters: Observation filters, forwarded.
        clip_rewards: Reward-clipping setting, forwarded.
        rollout_fragment_length: Fragment length, forwarded.
        callbacks: Episode callbacks, forwarded.
        horizon: Optional episode horizon, forwarded.
        pack: Episode-packing flag, forwarded.
        tf_sess: Optional TF session, forwarded.
        clip_actions: Action-clipping flag, forwarded.
        soft_horizon: Soft-horizon flag, forwarded.
        no_done_at_end: Whether to suppress done=True at episode end,
            forwarded.
    """
    base_env = BaseEnv.to_base_env(env)
    extra_batches = queue.Queue()
    perf_stats = PerfStats()

    self.base_env = base_env
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.policies = policies
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.extra_batches = extra_batches
    self.perf_stats = perf_stats

    self.rollout_provider = _env_runner(
        base_env, extra_batches.put, policies, policy_mapping_fn,
        rollout_fragment_length, horizon, preprocessors, obs_filters,
        clip_rewards, clip_actions, pack, callbacks, tf_sess, perf_stats,
        soft_horizon, no_done_at_end)
    self.metrics_queue = queue.Queue()
def __init__(self,
             env,
             policies,
             policies_to_train,
             policy_config,
             policy_mapping_fn,
             preprocessors,
             obs_filters,
             observation_filter,
             clip_rewards,
             unroll_length,
             callbacks,
             horizon=None,
             pack=False,
             tf_sess=None,
             clip_actions=True,
             blackhole_outputs=False,
             soft_horizon=False,
             no_done_at_end=False):
    """Set up a background-thread sampler (modified variant).

    Unlike the stock version, this variant additionally stores
    ``policies_to_train``, ``policy_config`` and ``observation_filter``.
    The sampler is configured as a daemon thread but is not started here.

    Raises:
        AssertionError: If any filter in ``obs_filters`` lacks a truthy
            ``is_concurrent`` attribute.
    """
    # ===MOD=== signature extended with policies_to_train, policy_config
    # and observation_filter.
    for obs_filter in obs_filters.values():
        assert getattr(obs_filter, "is_concurrent", False), \
            "Observation Filter must support concurrent updates."

    self.base_env = BaseEnv.to_base_env(env)
    # Thread.__init__ must run before Thread attributes such as ``daemon``
    # are assigned below.
    threading.Thread.__init__(self)

    # Output queue is bounded to 5 entries; the side-channel queues are not.
    self.queue = queue.Queue(5)
    self.extra_batches = queue.Queue()
    self.metrics_queue = queue.Queue()

    self.unroll_length = unroll_length
    self.horizon = horizon
    self.policies = policies
    # ===MOD=== extra settings carried by this modified sampler.
    self.policies_to_train = policies_to_train
    self.policy_config = policy_config
    self.observation_filter = observation_filter
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.clip_rewards = clip_rewards

    self.daemon = True
    self.pack = pack
    self.tf_sess = tf_sess
    self.callbacks = callbacks
    self.clip_actions = clip_actions
    self.blackhole_outputs = blackhole_outputs
    self.soft_horizon = soft_horizon
    self.no_done_at_end = no_done_at_end
    self.perf_stats = PerfStats()
    self.shutdown = False
def __init__(self,
             env,
             policies,
             policies_to_train,
             policy_config,
             policy_mapping_fn,
             preprocessors,
             obs_filters,
             observation_filter,
             clip_rewards,
             unroll_length,
             callbacks,
             horizon=None,
             pack=False,
             tf_sess=None,
             clip_actions=True,
             soft_horizon=False,
             no_done_at_end=False):
    """Build a synchronous sampler around ``env`` (modified variant).

    Compared to the stock sync sampler, ``policies_to_train``,
    ``policy_config`` and ``observation_filter`` are stored and also passed
    to ``_env_runner`` immediately after ``policies``.
    """
    base_env = BaseEnv.to_base_env(env)
    extra_batches = queue.Queue()
    perf_stats = PerfStats()

    self.base_env = base_env
    self.unroll_length = unroll_length
    self.horizon = horizon
    self.policies = policies
    # ===MOD=== extra settings carried by this modified sampler.
    self.policies_to_train = policies_to_train
    self.policy_config = policy_config
    self.observation_filter = observation_filter
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.extra_batches = extra_batches
    self.perf_stats = perf_stats

    # ===MOD=== the modified _env_runner takes policies_to_train,
    # policy_config and observation_filter right after ``policies``.
    self.rollout_provider = _env_runner(
        base_env, extra_batches.put, policies, policies_to_train,
        policy_config, observation_filter, policy_mapping_fn,
        unroll_length, horizon, preprocessors, obs_filters, clip_rewards,
        clip_actions, pack, callbacks, tf_sess, perf_stats, soft_horizon,
        no_done_at_end)
    self.metrics_queue = queue.Queue()
def _fetch_atari_metrics(
        base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Collect full-episode (all lives) metrics from Atari monitor wrappers.

    Atari games have multiple logical episodes, one per life. However, for
    metrics reporting we count full episodes, all lives included.

    Args:
        base_env: The env whose unwrapped sub-environments are inspected.

    Returns:
        A list with one RolloutMetrics(episode_length, episode_reward) per
        result reported by each sub-env's MonitorEnv, or None if the env
        exposes no unwrapped sub-environments or any of them is not wrapped
        by a MonitorEnv.
    """
    unwrapped = base_env.get_unwrapped()
    if not unwrapped:
        return None

    # Resolve every monitor before draining any of them. Previously a
    # missing monitor on a later sub-env returned None only after earlier
    # monitors' next_episode_results() had already been consumed (presumably
    # destructively, given the "next" semantics -- verify), losing them.
    monitors = [get_wrapper_by_cls(u, MonitorEnv) for u in unwrapped]
    if not all(monitors):
        return None

    atari_out = []
    for monitor in monitors:
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out
def custom_model_vector_env(env):
    """Return a vectorized BaseEnv built around the current environment.

    The worker settings (``worker_index``, ``make_env_fn``, ``num_envs``)
    are read from the global rollout worker returned by
    ``get_global_worker()``. For a non-zero worker index, ``env`` is first
    wrapped in a ``_VectorizedModelGymEnv``.
    """
    worker = get_global_worker()
    make_env_fn = worker.make_env_fn
    num_envs = worker.num_envs

    if worker.worker_index:
        env = _VectorizedModelGymEnv(
            make_env=make_env_fn,
            existing_envs=[env],
            num_envs=num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )

    return BaseEnv.to_base_env(
        env,
        make_env=make_env_fn,
        num_envs=num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )
def __init__(self,
             env,
             policies,
             policy_mapping_fn,
             preprocessors,
             obs_filters,
             clip_rewards,
             unroll_length,
             callbacks,
             horizon=None,
             pack=False,
             tf_sess=None,
             clip_actions=True,
             blackhole_outputs=False):
    """Set up a background-thread sampler around ``env``.

    All settings are stored on the instance; the thread is marked as a
    daemon but not started here.

    Raises:
        AssertionError: If any filter in ``obs_filters`` lacks a truthy
            ``is_concurrent`` attribute.
    """
    for obs_filter in obs_filters.values():
        assert getattr(obs_filter, "is_concurrent", False), \
            "Observation Filter must support concurrent updates."

    self.base_env = BaseEnv.to_base_env(env)
    # Thread.__init__ must run before Thread attributes such as ``daemon``
    # are assigned below.
    threading.Thread.__init__(self)

    # Output queue is bounded to 5 entries; the side-channel queues are not.
    self.queue = queue.Queue(5)
    self.extra_batches = queue.Queue()
    self.metrics_queue = queue.Queue()

    self.unroll_length = unroll_length
    self.horizon = horizon
    self.policies = policies
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.clip_rewards = clip_rewards

    self.daemon = True
    self.pack = pack
    self.tf_sess = tf_sess
    self.callbacks = callbacks
    self.clip_actions = clip_actions
    self.blackhole_outputs = blackhole_outputs
    self.shutdown = False
def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: bool,
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: int = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        blackhole_outputs: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: "ObservationFunction" = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
        tf_sess=None,
):
    """Initializes an AsyncSampler object.

    Args:
        worker (RolloutWorker): The RolloutWorker that will use this Sampler
            for sampling. Its ``filters`` and ``policy_map`` are used here.
        env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
        clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
            actual float value for +/- value clipping. False for no
            clipping.
        rollout_fragment_length (int): The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        count_steps_by (str): Either "env_steps" or "agent_steps".
            Refers to the unit of `rollout_fragment_length`.
        callbacks (Callbacks): The Callbacks object to use when episode
            events happen during rollout.
        horizon (Optional[int]): Hard-reset the Env after this many
            timesteps (see also `soft_horizon`).
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        normalize_actions (bool): Whether to normalize actions to the
            action space's bounds.
        clip_actions (bool): Whether to clip actions according to the
            given action_space's bounds.
        blackhole_outputs (bool): Whether to collect samples, but then
            not further process or store them (throw away all samples).
        soft_horizon (bool): If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (Optional[ObservationFunction]): Optional
            multi-agent observation func to use for preprocessing
            observations.
        sample_collector_class (Optional[Type[SampleCollector]]): An
            optional SampleCollector sub-class to use to collect, store,
            and retrieve environment-, model-, and sampler data. Defaults
            to SimpleListCollector.
        render (bool): Whether to try to render the environment after each
            step.
        policies, policy_mapping_fn, preprocessors, obs_filters, tf_sess:
            Deprecated and otherwise unused; passing a non-None value emits
            a deprecation warning (at most once per process).

    Raises:
        AssertionError: If any of ``worker.filters`` lacks a truthy
            ``is_concurrent`` attribute.
    """
    # These arguments are superseded by state on `worker` (e.g.
    # `worker.policy_map`); warn about any that were still supplied.
    if log_once("deprecated_async_sampler_args"):
        deprecated_args = (
            ("policies", policies),
            ("policy_mapping_fn", policy_mapping_fn),
            ("preprocessors", preprocessors),
            ("obs_filters", obs_filters),
            ("tf_sess", tf_sess),
        )
        for old_name, old_value in deprecated_args:
            if old_value is not None:
                deprecation_warning(old=old_name)

    self.worker = worker
    for obs_filter in worker.filters.values():
        assert getattr(obs_filter, "is_concurrent", False), \
            "Observation Filter must support concurrent updates."

    self.base_env = BaseEnv.to_base_env(env)
    # Thread.__init__ must run before Thread attributes such as ``daemon``
    # are assigned below.
    threading.Thread.__init__(self)

    # Output queue is bounded to 5 entries; the side-channel queues are not.
    self.queue = queue.Queue(5)
    self.extra_batches = queue.Queue()
    self.metrics_queue = queue.Queue()

    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.clip_rewards = clip_rewards
    self.daemon = True
    self.multiple_episodes_in_batch = multiple_episodes_in_batch
    self.callbacks = callbacks
    self.normalize_actions = normalize_actions
    self.clip_actions = clip_actions
    self.blackhole_outputs = blackhole_outputs
    self.soft_horizon = soft_horizon
    self.no_done_at_end = no_done_at_end
    self.perf_stats = _PerfStats()
    self.shutdown = False
    self.observation_fn = observation_fn
    self.render = render

    collector_cls = sample_collector_class or SimpleListCollector
    self.sample_collector = collector_cls(
        worker.policy_map,
        clip_rewards,
        callbacks,
        multiple_episodes_in_batch,
        rollout_fragment_length,
        count_steps_by=count_steps_by)
def __init__(
        self,
        *,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        policy_spec: Union[type, Dict[
            str, Tuple[Optional[type], gym.Space, gym.Space,
                       PartialTrainerConfigDict]]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
        tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
        rollout_fragment_length: int = 100,
        batch_mode: str = "truncate_episodes",
        episode_horizon: int = None,
        preprocessor_pref: str = "deepmind",
        sample_async: bool = False,
        compress_observations: bool = False,
        num_envs: int = 1,
        observation_fn: "ObservationFunction" = None,
        observation_filter: str = "NoFilter",
        clip_rewards: bool = None,
        clip_actions: bool = True,
        env_config: EnvConfigDict = None,
        model_config: ModelConfigDict = None,
        policy_config: TrainerConfigDict = None,
        worker_index: int = 0,
        num_workers: int = 0,
        monitor_path: str = None,
        log_dir: str = None,
        log_level: str = None,
        callbacks: Type["DefaultCallbacks"] = None,
        input_creator: Callable[[
            IOContext
        ], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
        input_evaluation: List[str] = frozenset([]),
        output_creator: Callable[
            [IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
        remote_worker_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        seed: int = None,
        extra_python_environs: dict = None,
        fake_sampler: bool = False,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
        policy: Union[type, Dict[
            str, Tuple[Optional[type], gym.Space, gym.Space,
                       PartialTrainerConfigDict]]] = None,
):
    """Initialize a rollout worker.

    Builds (optionally) the env, the policy map, observation filters, the
    vectorized BaseEnv, offline-evaluation estimators, the sampler
    (sync or async) and the input/output readers/writers.

    Args:
        env_creator (Callable[[EnvContext], EnvType]): Function that returns
            a gym.Env given an EnvContext wrapped configuration.
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Optional callable to validate the generated environment (only
            on worker=0).
        policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
            gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
            or a dict of policy id strings to
            (Policy class, obs_space, action_space, config)-tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn can also be set (if not, will map all agents
            to DEFAULT_POLICY_ID).
        policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
            callable that maps agent ids to policy ids in multi-agent mode.
            This function will be called each time a new agent appears in
            an episode, to bind that agent to a policy for the duration of
            the episode. If not provided, will map all agents to
            DEFAULT_POLICY_ID.
        policies_to_train (Optional[List[PolicyID]]): Optional list of
            policies to train, or None for all policies.
        tf_session_creator (Optional[Callable[[], tf1.Session]]): A
            function that returns a TF session. This is optional and only
            useful with TFPolicy.
        rollout_fragment_length (int): The target number of env transitions
            to include in each sample batch returned from this worker.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `rollout_fragment_length * num_envs` in size.
                The batch will be exactly
                `rollout_fragment_length * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `rollout_fragment_length * num_envs` in size.
                Episodes will not be truncated, but multiple episodes may
                be packed within one batch to meet the batch size. Note
                that when `num_envs > 1`, episode steps will be buffered
                until the episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Stop episodes at this horizon (number of
            steps), if set.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_fn (ObservationFunction): Optional multi-agent
            observation function.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (EnvConfigDict): Config to pass to the env creator.
        model_config (ModelConfigDict): Config to use when creating the
            policy model.
        policy_config (TrainerConfigDict): Config to pass to the
            policy. In the multi-agent case, this config will be merged
            with the per-policy configs specified by `policy_spec`.
        worker_index (int): For remote workers, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        num_workers (int): For remote workers, how many workers altogether
            have been created?
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (DefaultCallbacks): Custom training callbacks.
        input_creator (Callable[[IOContext], InputReader]): Function that
            returns an InputReader object for loading previous generated
            experiences.
        input_evaluation (List[str]): How to evaluate the policy
            performance. This only makes sense to set when the input is
            reading offline data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
        output_creator (Callable[[IOContext], OutputWriter]): Function that
            returns an OutputWriter object for saving generated
            experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            are slow to step or reset.
        remote_env_batch_wait_ms (float): Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        seed (int): Set the seed of both np and tf to this value to
            ensure each remote worker has unique exploration behavior.
        extra_python_environs (dict): Extra python environment variables to
            set (values are converted with str()).
        fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        spaces (Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
            gym.spaces.Space]]]): An optional space dict mapping policy IDs
            to (obs_space, action_space)-tuples. This is used in case no
            Env is created on this RolloutWorker.
        policy: Obsoleted arg. Use `policy_spec` instead.

    Raises:
        ValueError: If policy_mapping_fn is not callable, batch_mode or an
            input_evaluation method is unknown, or multiple policies are
            used with an env that is not multi-agent capable.
        ImportError: If a TF graph policy is requested but tensorflow
            could not be imported.
        RuntimeError: If Ray assigned GPUs but the DL framework reports no
            GPU support.
    """
    # Deprecated arg.
    if policy is not None:
        deprecation_warning("policy", "policy_spec", error=False)
        policy_spec = policy
    assert policy_spec is not None, "Must provide `policy_spec` when " \
                                    "creating RolloutWorker!"

    # Snapshot of every local at this point, i.e. all constructor arguments
    # (with `policy_spec` already overridden by the deprecated `policy`).
    # Any local introduced above this line would be captured too.
    self._original_kwargs: dict = locals().copy()
    del self._original_kwargs["self"]

    # Register this instance as the process-global worker (presumably what
    # get_global_worker() returns -- verify).
    global _global_worker
    _global_worker = self

    # set extra environs first
    if extra_python_environs:
        for key, value in extra_python_environs.items():
            os.environ[key] = str(value)

    # Endless generator of sample() results, handed to
    # ParallelIteratorWorker (meaning of the positional `False` is defined
    # there -- TODO confirm).
    def gen_rollouts():
        while True:
            yield self.sample()

    ParallelIteratorWorker.__init__(self, gen_rollouts, False)

    policy_config: TrainerConfigDict = policy_config or {}
    if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
            # This eager check is necessary for certain all-framework tests
            # that use tf's eager_mode() context generator.
            and not tf1.executing_eagerly()):
        tf1.enable_eager_execution()

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    self.env_context = env_context
    self.policy_config: TrainerConfigDict = policy_config
    if callbacks:
        self.callbacks: "DefaultCallbacks" = callbacks()
    else:
        from ray.rllib.agents.callbacks import DefaultCallbacks
        self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
    self.worker_index: int = worker_index
    self.num_workers: int = num_workers
    model_config: ModelConfigDict = model_config or {}
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError("Policy mapping function not callable?")
    self.env_creator: Callable[[EnvContext], EnvType] = env_creator
    # Total fragment length across all vectorized sub-envs. Note the
    # sampler created below receives the un-multiplied local
    # `rollout_fragment_length` instead.
    self.rollout_fragment_length: int = rollout_fragment_length * num_envs
    self.batch_mode: str = batch_mode
    self.compress_observations: bool = compress_observations
    self.preprocessing_enabled: bool = True
    self.last_batch: SampleBatchType = None
    self.global_vars: dict = None
    self.fake_sampler: bool = fake_sampler

    # No Env will be used in this particular worker (not needed).
    if worker_index == 0 and num_workers > 0 and \
            policy_config["create_env_on_driver"] is False:
        self.env = None
    # Create an env for this worker.
    else:
        self.env = _validate_env(env_creator(env_context))
        if validate_env is not None:
            validate_env(self.env, self.env_context)

        if isinstance(self.env, (BaseEnv, MultiAgentEnv)):

            def wrap(env):
                return env  # we can't auto-wrap these env types

        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing.
            self.preprocessing_enabled = False

            # If clip_rewards not explicitly set to False, switch it
            # on here (clip between -1.0 and 1.0).
            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack"))
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        else:

            def wrap(env):
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        self.env: EnvType = wrap(self.env)

    # NOTE(review): `wrap` is only bound in the env-creating branch above;
    # calling make_env on a worker with self.env = None would raise
    # NameError.
    def make_env(vector_index):
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    worker_index=worker_index,
                    vector_index=vector_index,
                    remote=remote_worker_envs)))

    self.make_env_fn = make_env

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(
        policy_spec, self.env, spaces=spaces)
    self.policies_to_train: List[PolicyID] = policies_to_train or list(
        policy_dict.keys())
    self.policy_map: Dict[PolicyID, Policy] = None
    self.preprocessors: Dict[PolicyID, Preprocessor] = None

    # set numpy and python seed
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        if not hasattr(self.env, "seed"):
            logger.info("Env doesn't support env.seed(): {}".format(
                self.env))
        else:
            self.env.seed(seed)
        # NOTE(review): `assert` is used as a None-check here; under
        # `python -O` it is stripped, and a missing torch would then raise
        # AttributeError instead of being logged.
        try:
            assert torch is not None
            torch.manual_seed(seed)
        except AssertionError:
            logger.info("Could not seed torch")
    # Graph-mode TF policies are built inside a dedicated Graph/Session.
    if _has_tensorflow_graph(policy_dict) and not (
            tf1 and tf1.executing_eagerly()):
        if not tf1:
            raise ImportError("Could not import tensorflow")
        with tf1.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf1.Session(
                    config=tf1.ConfigProto(
                        gpu_options=tf1.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                # set graph-level seed
                if seed is not None:
                    tf1.set_random_seed(seed)
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    if (ray.is_initialized()
            and ray.worker._mode() != ray.worker.LOCAL_MODE):
        # Check available number of GPUs
        if not ray.get_gpu_ids():
            logger.debug("Creating policy evaluation worker {}".format(
                worker_index) +
                         " on CPU (please ignore any CUDA init errors)")
        elif (policy_config["framework"] in ["tf2", "tf", "tfe"] and
              not tf.config.experimental.list_physical_devices("GPU")) or \
                (policy_config["framework"] == "torch" and
                 not torch.cuda.is_available()):
            raise RuntimeError(
                "GPUs were assigned to this worker by Ray, but "
                "your DL framework ({}) reports GPU acceleration is "
                "disabled. This could be due to a bad CUDA- or {} "
                "installation.".format(policy_config["framework"],
                                       policy_config["framework"]))

    # Multi-agent mode: any policy set other than just the default policy.
    self.multiagent: bool = set(
        self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent and self.env is not None:
        if not ((isinstance(self.env, MultiAgentEnv)
                 or isinstance(self.env, ExternalMultiAgentEnv))
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policies {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                "ExternalMultiAgentEnv?".format(self.env))

    # One observation filter per policy, sized by its observation shape.
    self.filters: Dict[PolicyID, Filter] = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    self.num_envs: int = num_envs

    if self.env is None:
        self.async_env = None
    elif "custom_vector_env" in policy_config:
        custom_vec_wrapper = policy_config["custom_vector_env"]
        self.async_env = custom_vec_wrapper(self.env)
    else:
        # Always use vector env for consistency even if num_envs = 1.
        self.async_env: BaseEnv = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)

    # `truncate_episodes`: Allow a batch to contain more than one episode
    # (fragments) and always make the batch `rollout_fragment_length`
    # long.
    if self.batch_mode == "truncate_episodes":
        pack = True
    # `complete_episodes`: Never cut episodes and sampler will return
    # exactly one (complete) episode per poll.
    elif self.batch_mode == "complete_episodes":
        # Only the local passed to the sampler is changed;
        # self.rollout_fragment_length keeps its finite value.
        rollout_fragment_length = float("inf")
        pack = False
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context: IOContext = IOContext(log_dir, policy_config,
                                           worker_index, self)
    self.reward_estimators: List[OffPolicyEstimator] = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            # Simulation requires the background (async) sampler, whose
            # outputs are discarded via `blackhole_outputs` below.
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if self.env is None:
        self.sampler = None
    elif sample_async:
        self.sampler = AsyncSampler(
            worker=self,
            env=self.async_env,
            policies=self.policy_map,
            policy_mapping_fn=policy_mapping_fn,
            preprocessors=self.preprocessors,
            obs_filters=self.filters,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=policy_config.get(
                "_use_trajectory_view_api", False))
        # Start the Sampler thread.
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            worker=self,
            env=self.async_env,
            policies=self.policy_map,
            policy_mapping_fn=policy_mapping_fn,
            preprocessors=self.preprocessors,
            obs_filters=self.filters,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=policy_config.get(
                "_use_trajectory_view_api", False))

    self.input_reader: InputReader = input_creator(self.io_context)
    self.output_writer: OutputWriter = output_creator(self.io_context)

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))
return env.action_space.sample() def attach_preprocessors(self, preprocessors): for agent_name, agent in self.agents.items(): agent.preprocessor = preprocessors[agent.type] if __name__ == '__main__': from utils import * import pdb # from gazebo_tb_mpc_env import * # rospy.init_node('GazeboAgentNode') config = get_configuration('../tb_simple_world') # rosrate = rospy.Rate(1) # 10hz # config['rr'] = rosrate from gazebo_tb_agent import GazeboTurtleBotAgent config['agents']['waffle']['agent_class'] = GazeboTurtleBotAgent print(config) env = GazeboEnv(config) o = env.reset(env.get_random_state()) benv = BaseEnv.to_base_env(env) benv.poll() # t = time.time() # for i in range(100): # env.step(env.get_random_action()) # print(time.time() - t) # print([x['gt_pose'][:2] for x in o.values()]) # print([env.get_random_action()]) # print([a.get_random_action() for a in env.agents.values()]) pdb.set_trace()
def test_nested_tuple_async(self):
    """Run the nested-tuple check against a BaseEnv-wrapped NestedTupleEnv."""
    def make_base_env(_):
        return BaseEnv.to_base_env(NestedTupleEnv())

    self.do_test_nested_tuple(make_base_env)
def test_nested_dict_async(self):
    """Run the nested-dict check against a BaseEnv-wrapped NestedDictEnv."""
    def make_base_env(_):
        return BaseEnv.to_base_env(NestedDictEnv())

    self.do_test_nested_dict(make_base_env)
def testNestedDictAsync(self):
    """Run the nested-dict check against a BaseEnv-wrapped NestedDictEnv."""
    def make_base_env(_):
        return BaseEnv.to_base_env(NestedDictEnv())

    self.doTestNestedDict(make_base_env)
def env_maker(config, return_agents=False):
    """Build ``env_cls(config)`` and convert it to a BaseEnv.

    Args:
        config: Passed straight to ``env_cls``.
        return_agents: If True, also return the ``agents`` attribute of the
            first underlying env (``base_env.envs[0]``).

    Returns:
        The BaseEnv, or a ``(base_env, agents)`` tuple when
        ``return_agents`` is True.
    """
    base_env = BaseEnv.to_base_env(env_cls(config))
    if not return_agents:
        return base_env
    return base_env, base_env.envs[0].agents
def __init__(self,
             *,
             worker: "RolloutWorker",
             env: BaseEnv,
             policies: Dict[PolicyID, Policy],
             policy_mapping_fn: Callable[[AgentID], PolicyID],
             preprocessors: Dict[PolicyID, Preprocessor],
             obs_filters: Dict[PolicyID, Filter],
             clip_rewards: bool,
             rollout_fragment_length: int,
             callbacks: "DefaultCallbacks",
             horizon: int = None,
             pack_multiple_episodes_in_batch: bool = False,
             tf_sess=None,
             clip_actions: bool = True,
             soft_horizon: bool = False,
             no_done_at_end: bool = False,
             observation_fn: "ObservationFunction" = None,
             _use_trajectory_view_api: bool = False):
    """Initializes a SyncSampler object.

    Args:
        worker (RolloutWorker): The RolloutWorker that will use this Sampler
            for sampling.
        env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
        policies (Dict[str, Policy]): Mapping from policy ID to Policy obj.
        policy_mapping_fn (callable): Callable that takes an agent ID and
            returns the ID of the policy to use for that agent.
        preprocessors (Dict[str, Preprocessor]): Mapping from policy ID to
            Preprocessor object for the observations prior to filtering.
        obs_filters (Dict[str, Filter]): Mapping from policy ID to
            env Filter object.
        clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
            float value for +/- value clipping. False for no clipping.
        rollout_fragment_length (int): The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        callbacks (Callbacks): The Callbacks object to use when episode
            events happen during rollout.
        horizon (Optional[int]): Hard-reset the Env after this many
            timesteps (see also `soft_horizon`).
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
            framework=tf).
        clip_actions (bool): Whether to clip actions according to the
            given action_space's bounds.
        soft_horizon (bool): If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (Optional[ObservationFunction]): Optional
            multi-agent observation func to use for preprocessing
            observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.
    """
    base_env = BaseEnv.to_base_env(env)
    extra_batches = queue.Queue()
    perf_stats = _PerfStats()

    self.base_env = base_env
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.policies = policies
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.extra_batches = extra_batches
    self.perf_stats = perf_stats

    # Rollout generator consumed by calls to `get_data()`.
    self.rollout_provider = _env_runner(
        worker, base_env, extra_batches.put, policies, policy_mapping_fn,
        rollout_fragment_length, horizon, preprocessors, obs_filters,
        clip_rewards, clip_actions, pack_multiple_episodes_in_batch,
        callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end,
        observation_fn, _use_trajectory_view_api)
    self.metrics_queue = queue.Queue()
def _env_runner(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        extra_batch_callback: Callable[[SampleBatchType], None],
        policies: Dict[PolicyID, Policy],
        policy_mapping_fn: Callable[[AgentID], PolicyID],
        rollout_fragment_length: int,
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        clip_rewards: bool,
        clip_actions: bool,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        tf_sess: Optional["tf.Session"],
        perf_stats: _PerfStats,
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: Optional[SampleCollector] = None,
        render: bool = None,
) -> Iterable[SampleBatchType]:
    """This implements the common experience collection logic.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
            instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes. NOTE(review): not read in this function's body; kept
            for signature compatibility.
        horizon (int): Horizon of the episode. If falsy, the env's own
            `spec.max_episode_steps` is used when available, otherwise
            infinity. If larger than the env's own limit, that limit is
            raised to `horizon` on every sub-env.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
            NOTE(review): not read in this function's body; kept for
            signature compatibility.
        clip_actions (bool): Whether to clip actions to the space range.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for
            batching TF policy evaluations.
        perf_stats (_PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (Optional[SampleCollector]): An optional
            SampleCollector object to use.
        render (bool): Whether to try to render the environment after each
            step.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.

    Raises:
        ValueError: If `horizon` is larger than the env's own step limit and
            that limit cannot be overridden on the sub-envs.
    """
    # Lazily created image viewer; only used when `render` is set and the
    # env's `try_render()` returns an RGB image.
    simple_image_viewer: Optional["SimpleImageViewer"] = None

    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit.
        if max_episode_steps and horizon > max_episode_steps:
            # Try to override the env's own max-step setting with our horizon.
            # The limit is stored on each env object, so it must be updated
            # on every sub-env returned by `get_unwrapped()`, not only the
            # first one; otherwise the remaining vectorized sub-envs would
            # still end their episodes at the old limit.
            # If this won't work, throw an error.
            try:
                for sub_env in base_env.get_unwrapped():
                    sub_env.spec.max_episode_steps = horizon
                    sub_env._max_episode_steps = horizon
            except Exception as e:
                raise ValueError(
                    "Your `horizon` setting ({}) is larger than the Env's own "
                    "timestep limit ({}), which seems to be unsettable! Try "
                    "to increase the Env's built-in limit to be at least as "
                    "large as your wanted `horizon`.".format(
                        horizon, max_episode_steps)) from e
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    # No horizon/max_episode_steps -> Episodes may be infinitely long.
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return None

    def new_episode(env_id):
        episode = MultiAgentEpisode(
            policies,
            policy_mapping_fn,
            get_batch_builder,
            extra_batch_callback,
            env_id=env_id)
        # Call each policy's Exploration.on_episode_start method.
        # type: Policy
        for p in policies.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=getattr(p, "_sess", None))
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=policies,
            episode=episode,
            env_index=env_id,
        )
        return episode

    active_episodes: Dict[str, MultiAgentEpisode] = \
        NewEpisodeDefaultDict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # type: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = \
            _process_observations(
                worker=worker,
                base_env=base_env,
                policies=policies,
                active_episodes=active_episodes,
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,
                infos=infos,
                horizon=horizon,
                preprocessors=preprocessors,
                obs_filters=obs_filters,
                multiple_episodes_in_batch=multiple_episodes_in_batch,
                callbacks=callbacks,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector=sample_collector,
            )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (accross vectorized envs).
        t2 = time.time()
        # type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=policies,
            sample_collector=sample_collector,
            active_episodes=active_episodes,
            tf_sess=tf_sess,
        )
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=policies,
                clip_actions=clip_actions,
            )
        perf_stats.action_processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

        # Try to render the env, if required.
        if render:
            t5 = time.time()
            # Render can either return an RGB image (uint8 [w x h x 3] numpy
            # array) or take care of rendering itself (returning True).
            rendered = base_env.try_render()
            # Rendering returned an image -> Display it in a SimpleImageViewer.
            if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                # ImageViewer not defined yet, try to create one.
                if simple_image_viewer is None:
                    try:
                        from gym.envs.classic_control.rendering import \
                            SimpleImageViewer
                        simple_image_viewer = SimpleImageViewer()
                    # ModuleNotFoundError is a subclass of ImportError, so
                    # this also covers a missing module.
                    except ImportError:
                        render = False  # disable rendering
                        logger.warning(
                            "Could not import gym.envs.classic_control."
                            "rendering! Try `pip install gym[all]`.")
                if simple_image_viewer:
                    simple_image_viewer.imshow(rendered)
            perf_stats.env_render_time += time.time() - t5
def _process_observations(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        batch_builder_pool: List[MultiAgentSampleBatchBuilder],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        rollout_fragment_length: int,
        pack_multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    For every env that returned observations in the last `poll()`:
    updates episode length/rewards, detects episode termination (env done
    or horizon hit), preprocesses and filters each agent's observation,
    queues not-done agents for policy evaluation, records the previous
    transition into the episode's batch builder, cuts SampleBatches when
    required, and handles episode end (callbacks, builder recycling, env
    reset or soft reset).

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        batch_builder_pool (List[SampleBatchBuilder]): List of pooled
            SampleBatchBuilder object for recycling. The builder of every
            finished episode is appended here.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from env ID
            to the currently ongoing MultiAgentEpisode object in that env.
            Indexed with `env_id` below; missing keys are expected to yield
            a fresh episode (the runner passes a defaultdict).
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call. Each
            per-env dict must contain the aggregate "__all__" key.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode. Episodes are terminated once
            `episode.length >= horizon`; `float("inf")` disables this.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.
            NOTE(review): not read anywhere in this function's body.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.

    Raises:
        ValueError: If `observation_fn` does not return a dict, or if the
            env does not support reset while a finite horizon is set.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # Warn if buffered data grows beyond 10x the fragment length (at least
    # 1000); use a fixed 5000 when fragments are unbounded (inf).
    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
        # Check membership before indexing: indexing creates the episode.
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]
        # A brand-new episode has not taken a step yet, so only existing
        # episodes advance their counters and accumulate rewards.
        if not is_new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # True only if the episode was cut by the horizon rather than
            # ended by the env itself.
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
            # "__all__" is the reserved aggregate key used in `dones`; it must
            # never appear as an observation key.
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
            # Only agents that still act need a new action from their policy.
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            # Grab the previous observation before overwriting it; it is the
            # `obs` of the transition recorded below.
            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable (skipped on an agent's
            # first observation, or when the env's info sets
            # training_enabled=False).
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    # Record done=False (not terminal) under no_done_at_end,
                    # or when a soft horizon was hit.
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode.
        callbacks.on_episode_step(
            worker=worker, base_env=base_env, episode=episode)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if episode.batch_builder.has_pending_agent_data():
            # Sanity check, whether all agents have done=True, if done[__all__]
            # is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()

            # Reached end of episode and we are not allowed to pack the
            # next episode into the same SampleBatch -> Build the SampleBatch
            # and add it to "outputs".
            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:
                episode.batch_builder.postprocess_batch_so_far(episode)

        # Episode is done.
        if all_agents_done:
            # Handle episode termination: return the builder to the pool so
            # a later episode can reuse it (and keep packing into it).
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if hit_horizon and soft_horizon:
                # Soft horizon: keep the env running and reuse the current
                # observations as the "reset" observations.
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            else:
                # Hard reset: drop the episode so the next access creates a
                # fresh one.
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return.
                # If reset is async, we will get its result in some future poll
                episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy: Policy = _get_or_raise(policies, policy_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    # First step of the new episode: there is no previous
                    # action/reward yet, so pass a zero action shaped like a
                    # (flattened) sample of the action space and 0.0 reward.
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs
def _env_runner(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        extra_batch_callback: Callable[[SampleBatchType], None],
        policies: Dict[PolicyID, Policy],
        policy_mapping_fn: Callable[[AgentID], PolicyID],
        rollout_fragment_length: int,
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        clip_rewards: bool,
        clip_actions: bool,
        pack_multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        tf_sess: Optional["tf.Session"],
        perf_stats: _PerfStats,
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False,
) -> Iterable[SampleBatchType]:
    """Drive the env <-> policy loop forever, yielding collected data.

    Each iteration polls the env, turns observations into policy inputs
    (`_process_observations`), yields any finished SampleBatches / metrics,
    evaluates the policies in batch, converts the results into actions and
    sends them back to the env. Time spent in each phase is accumulated in
    `perf_stats`.

    Args:
        worker (RolloutWorker): The rollout worker running this loop.
        base_env (BaseEnv): The (vectorized) env to step.
        extra_batch_callback (fn): Receives extra batch data produced by
            episodes.
        policies (Dict[PolicyID, Policy]): Policy ID -> Policy instance.
        policy_mapping_fn (func): Maps an agent ID to a policy ID when the
            agent first appears; the agent stays bound to that policy for
            the rest of the episode.
        rollout_fragment_length (int): Episode steps per yielded
            `SampleBatch`; infinity yields complete episodes.
        horizon (int): Episode horizon. If falsy, the env's own
            `spec.max_episode_steps` is used when present, otherwise
            infinity.
        preprocessors (dict): Policy ID -> observation preprocessor (applied
            before filtering).
        obs_filters (dict): Policy ID -> observation filter.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        clip_actions (bool): Whether to clip actions to the space range.
        pack_multiple_episodes_in_batch (bool): Pack several episodes into
            one batch so batches are exactly `rollout_fragment_length` long.
        callbacks (DefaultCallbacks): User callbacks for episode events.
        tf_sess (Session|None): Optional TF session used for batched TF
            policy evaluation.
        perf_stats (_PerfStats): Accumulator for timing statistics.
        soft_horizon (bool): On hitting the horizon, compute rewards but do
            not reset the env.
        no_done_at_end (bool): Record done=False instead of done=True at the
            end of an episode.
        observation_fn (ObservationFunction): Optional multi-agent
            observation function applied before preprocessing.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            trajectory view API. Default: False.

    Yields:
        rollout (SampleBatch): Collected experience (and rollout metrics) as
            produced by `_process_observations`.

    Raises:
        ValueError: If `horizon` exceeds the env's own step limit.
    """
    # Look up the env's own episode step limit on the first sub-env. Any
    # failure simply means "no known limit".
    env_step_limit = None
    try:
        env_step_limit = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Resolve the effective horizon.
    if not horizon:
        if env_step_limit:
            # No horizon given -> fall back to the env's own limit.
            horizon = env_step_limit
            logger.debug(
                "No episode horizon specified, setting it to Env's limit ({}).".
                format(env_step_limit))
        else:
            # Neither a horizon nor an env limit -> unbounded episodes.
            horizon = float("inf")
            logger.debug("No episode horizon specified, assuming inf.")
    elif env_step_limit and horizon > env_step_limit:
        # A horizon beyond the env's built-in limit cannot be honored.
        raise ValueError(
            "Your `horizon` setting ({}) is larger than the Env's own "
            "timestep limit ({})! Try to increase the Env's limit via "
            "setting its `spec.max_episode_steps` property.".format(
                horizon, env_step_limit))

    # Recycled batch builders: finished episodes hand theirs back here so new
    # episodes can reuse them.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        # Prefer a pooled builder; create a fresh one only if none is left.
        return (batch_builder_pool.pop() if batch_builder_pool else
                MultiAgentSampleBatchBuilder(policies, clip_rewards,
                                             callbacks))

    def new_episode():
        fresh_episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                          get_batch_builder,
                                          extra_batch_callback)
        # Notify each policy's exploration component of the new episode.
        for pol in policies.values():
            exploration = getattr(pol, "exploration", None)
            if exploration is None:
                continue
            exploration.on_episode_start(
                policy=pol,
                environment=base_env,
                episode=fresh_episode,
                tf_sess=getattr(pol, "_sess", None))
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=policies,
            episode=fresh_episode)
        return fresh_episode

    # Unknown env ids get a brand-new episode on first access.
    active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)

    while True:
        perf_stats.iters += 1

        # Phase 1: poll all ready agents.
        poll_start = time.time()
        # type: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - poll_start

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Phase 2: preprocess observations and queue policy inputs.
        process_start = time.time()
        # type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = _process_observations(
            worker=worker,
            base_env=base_env,
            policies=policies,
            batch_builder_pool=batch_builder_pool,
            active_episodes=active_episodes,
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            dones=dones,
            infos=infos,
            horizon=horizon,
            preprocessors=preprocessors,
            obs_filters=obs_filters,
            rollout_fragment_length=rollout_fragment_length,
            pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
            callbacks=callbacks,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.processing_time += time.time() - process_start
        for produced in outputs:
            yield produced

        # Phase 3: batched policy evaluation across all vectorized envs.
        inference_start = time.time()
        # type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=policies,
            active_episodes=active_episodes,
            tf_sess=tf_sess,
            _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.inference_time += time.time() - inference_start

        # Phase 4: turn policy outputs into per-env actions.
        postprocess_start = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=policies,
                clip_actions=clip_actions,
                _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.processing_time += time.time() - postprocess_start

        # Phase 5: send actions back. Envs that took off-policy actions also
        # receive one; they are free to ignore it.
        send_start = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - send_start
def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: bool,
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: int = None,
        multiple_episodes_in_batch: bool = False,
        tf_sess=None,
        clip_actions: bool = True,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: "ObservationFunction" = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
):
    """Initializes a SyncSampler object.

    Args:
        worker (RolloutWorker): The RolloutWorker that will use this Sampler
            for sampling. Its `policy_map` is handed to the sample collector.
        env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
        clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
            float value for +/- value clipping. False for no clipping.
        rollout_fragment_length (int): The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        count_steps_by (str): Passed through to the sample collector as its
            `count_steps_by` argument (default: "env_steps").
        callbacks (Callbacks): The Callbacks object to use when episode
            events happen during rollout.
        horizon (Optional[int]): Episode horizon; episodes are cut (and the
            env hard-reset unless `soft_horizon` is set) once they reach
            this many steps. If None, the env's own step limit is used when
            available.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
            framework=tf).
        clip_actions (bool): Whether to clip actions according to the
            given action_space's bounds.
        soft_horizon (bool): If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (Optional[ObservationFunction]): Optional
            multi-agent observation func to use for preprocessing
            observations.
        sample_collector_class (Optional[Type[SampleCollector]]): An
            optional SampleCollector sub-class to use to collect, store, and
            retrieve environment-, model-, and sampler data. Defaults to
            SimpleListCollector.
        render (bool): Whether to try to render the environment after each
            step.
        policies: Deprecated and ignored; provided via `worker` instead.
        policy_mapping_fn: Deprecated and ignored; provided via `worker`.
        preprocessors: Deprecated and ignored; provided via `worker`.
        obs_filters: Deprecated and ignored; provided via `worker`.
    """
    # All of the following arguments are deprecated. They will instead be
    # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
    deprecated_args = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
        "preprocessors": preprocessors,
        "obs_filters": obs_filters,
    }
    passed_deprecated = [
        name for name, value in deprecated_args.items() if value is not None
    ]
    # Only consume the `log_once` key when a deprecated arg was actually
    # passed. Checking it unconditionally meant that a first SyncSampler
    # created without deprecated args used up the key, so later callers that
    # did pass deprecated args were never warned.
    if passed_deprecated and log_once("deprecated_sync_sampler_args"):
        for name in passed_deprecated:
            deprecation_warning(old=name)

    self.base_env = BaseEnv.to_base_env(env)
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.extra_batches = queue.Queue()
    self.perf_stats = _PerfStats()
    if not sample_collector_class:
        sample_collector_class = SimpleListCollector
    self.sample_collector = sample_collector_class(
        worker.policy_map,
        clip_rewards,
        callbacks,
        multiple_episodes_in_batch,
        rollout_fragment_length,
        count_steps_by=count_steps_by)
    self.render = render

    # Create the rollout generator to use for calls to `get_data()`.
    # NOTE(review): this positional call assumes an `_env_runner` signature
    # without `policies`/`policy_mapping_fn`/`preprocessors`/`obs_filters`;
    # the `_env_runner` definitions visible alongside this code still take
    # those parameters -- confirm which version is bound at runtime.
    self.rollout_provider = _env_runner(
        worker, self.base_env, self.extra_batches.put,
        self.rollout_fragment_length, self.horizon, clip_rewards,
        clip_actions, multiple_episodes_in_batch, callbacks, tf_sess,
        self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
        self.sample_collector, self.render)
    self.metrics_queue = queue.Queue()
def __init__(self,
             env_creator,
             policy,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             rollout_fragment_length=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_fn=None,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             num_workers=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation=frozenset([]),
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False,
             remote_env_batch_wait_ms=0,
             soft_horizon=False,
             no_done_at_end=False,
             seed=None,
             extra_python_environs=None,
             fake_sampler=False):
    """Initialize a rollout worker.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy (class|dict): Either a class implementing
            Policy, or a dictionary of policy id strings to
            (Policy, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that agent
            to a policy for the duration of the episode. If not given,
            every agent is mapped to DEFAULT_POLICY_ID.
        policies_to_train (list): Optional whitelist of policies to train,
            or None for all policies.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicy.
        rollout_fragment_length (int): The target number of env transitions
            to include in each sample batch returned from this worker.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `rollout_fragment_length * num_envs` in size.
                The batch will be exactly
                `rollout_fragment_length * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `rollout_fragment_length * num_envs` in size.
                Episodes will not be truncated, but multiple episodes may be
                packed within one batch to meet the batch size. Note that
                when `num_envs > 1`, episode steps will be buffered until
                the episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Stop episodes at this horizon (passed to the
            sampler as `horizon`).
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy. Forced to True when "simulation" is
            in `input_evaluation`.
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_fn (ObservationFunction): Optional multi-agent
            observation function.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy`.
        worker_index (int): For remote workers, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        num_workers (int): For remote workers, the total number of workers
            that have been created.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Level to set on the "ray.rllib" logger on
            creation.
        callbacks (type): DefaultCallbacks sub-class to instantiate for
            custom training callbacks. If not given, a plain
            DefaultCallbacks instance is used.
        input_creator (func): Function that returns an InputReader object
            for loading previous generated experiences.
        input_evaluation (list): How to evaluate the policy performance.
            This only makes sense to set when the input is reading offline
            data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
        output_creator (func): Function that returns an OutputWriter object
            for saving generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            are very CPU intensive (e.g., for StarCraft).
        remote_env_batch_wait_ms (float): Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        seed (int): Seed numpy, python `random`, the env, torch (if
            available) and the TF graph (if one is built) with this value,
            e.g. to ensure each remote worker has unique exploration
            behavior.
        extra_python_environs (dict): Extra environment variables to set in
            `os.environ` (values are converted with `str()`) before any
            other initialization happens.
        fake_sampler (bool): Use a fake (inf speed) sampler for testing.
    """
    # Must be the first statement so `locals()` holds exactly the
    # constructor arguments (no helper locals have been bound yet).
    self._original_kwargs = locals().copy()
    del self._original_kwargs["self"]
    global _global_worker
    _global_worker = self

    # set extra environs first
    if extra_python_environs:
        for key, value in extra_python_environs.items():
            os.environ[key] = str(value)

    def gen_rollouts():
        # Endless stream of sample batches for the ParallelIterator API.
        while True:
            yield self.sample()

    ParallelIteratorWorker.__init__(self, gen_rollouts, False)

    policy_config = policy_config or {}
    if (tf and policy_config.get("eager")
            and not policy_config.get("no_eager_on_workers")
            # This eager check is necessary for certain all-framework tests
            # that use tf's eager_mode() context generator.
            and not tf.executing_eagerly()):
        tf.enable_eager_execution()

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    self.policy_config = policy_config
    if callbacks:
        self.callbacks = callbacks()
    else:
        from ray.rllib.agents.callbacks import DefaultCallbacks
        self.callbacks = DefaultCallbacks()
    self.worker_index = worker_index
    self.num_workers = num_workers
    model_config = model_config or {}
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError("Policy mapping function not callable?")
    self.env_creator = env_creator
    self.rollout_fragment_length = rollout_fragment_length * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    self.preprocessing_enabled = True
    self.last_batch = None
    self.global_vars = None
    self.fake_sampler = fake_sampler

    self.env = _validate_env(env_creator(env_context))
    if isinstance(self.env, (MultiAgentEnv, BaseEnv)):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            not model_config.get("custom_preprocessor") and \
            preprocessor_pref == "deepmind":

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            env = wrap_deepmind(
                env,
                dim=model_config.get("dim"),
                framestack=model_config.get("framestack"))
            if monitor_path:
                from gym import wrappers
                env = wrappers.Monitor(env, monitor_path, resume=True)
            return env
    else:

        def wrap(env):
            if monitor_path:
                from gym import wrappers
                env = wrappers.Monitor(env, monitor_path, resume=True)
            return env

    self.env = wrap(self.env)

    def make_env(vector_index):
        # Creates (and wraps) each additional sub-env of the vector env.
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    vector_index=vector_index, remote=remote_worker_envs)))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    # set numpy and python seed
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        if not hasattr(self.env, "seed"):
            raise ValueError("Env doesn't support env.seed(): {}".format(
                self.env))
        self.env.seed(seed)
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn a missing torch into an
        # AttributeError rather than this log message.
        if torch is not None:
            torch.manual_seed(seed)
        else:
            logger.info("Could not seed torch")
    if _has_tensorflow_graph(policy_dict) and not (tf and
                                                   tf.executing_eagerly()):
        if not tf:
            raise ImportError("Could not import tensorflow")
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                # set graph-level seed
                if seed is not None:
                    tf.set_random_seed(seed)
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE):
            if not ray.get_gpu_ids():
                logger.debug(
                    "Creating policy evaluation worker {}".format(
                        worker_index) +
                    " on CPU (please ignore any CUDA init errors)")
            elif not tf.test.is_gpu_available():
                raise RuntimeError(
                    "GPUs were assigned to this worker by Ray, but "
                    "TensorFlow reports GPU acceleration is disabled. "
                    "This could be due to a bad CUDA or TF installation.")
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent:
        if not isinstance(self.env,
                          (MultiAgentEnv, ExternalMultiAgentEnv, BaseEnv)):
            raise ValueError(
                "Have multiple policies {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                "ExternalMultiAgentEnv?".format(self.env))

    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs,
        remote_env_batch_wait_ms=remote_env_batch_wait_ms)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        rollout_fragment_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.reward_estimators = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if sample_async:
        self.sampler = AsyncSampler(
            self,
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            rollout_fragment_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            self,
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            rollout_fragment_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn)

    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))
def testNestedTupleAsync(self): self.doTestNestedTuple(lambda _: BaseEnv.to_base_env(NestedTupleEnv()))
def __init__(self,
             env_creator,
             policy_graph,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation_method=None,
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False):
    """Initialize a policy evaluator.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy_graph (class|dict): Either a class implementing
            PolicyGraph, or a dictionary of policy id strings to
            (PolicyGraph, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that agent
            to a policy for the duration of the episode. If not given,
            every agent is mapped to DEFAULT_POLICY_ID.
        policies_to_train (list): Optional whitelist of policies to train,
            or None for all policies.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicyGraph.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `batch_steps * num_envs` in size. The batch will
                be exactly `batch_steps * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps * num_envs` in size. Episodes will
                not be truncated, but multiple episodes may be packed
                within one batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Stop episodes at this horizon (passed to the
            sampler as `horizon`).
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy. Forced to True when
            `input_evaluation_method` is "simulation".
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy_graph`.
        worker_index (int): For remote evaluators, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Level to set on the "ray.rllib" logger on
            creation.
        callbacks (dict): Dict of custom debug callbacks.
        input_creator (func): Function that returns an InputReader object
            for loading previous generated experiences.
        input_evaluation_method (str): How to evaluate the current policy.
            This only applies when the input is reading offline data.
            Options are:
              - None: don't evaluate the policy. The episode reward and
                other metrics will be NaN.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
            Any other value raises ValueError.
        output_creator (func): Function that returns an OutputWriter object
            for saving generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            are very CPU intensive (e.g., for StarCraft).
    """

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    self.policy_config = policy_config
    self.callbacks = callbacks or {}

    model_config = model_config or {}
    # Default: single-policy mode, every agent uses DEFAULT_POLICY_ID.
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError(
            "Policy mapping function not callable. If you're using Tune, "
            "make sure to escape the function with tune.function() "
            "to prevent it from being evaluated as an expression.")
    self.env_creator = env_creator
    # Each of the `num_envs` vectorized envs contributes `batch_steps`.
    self.sample_batch_size = batch_steps * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    self.preprocessing_enabled = True

    self.env = env_creator(env_context)
    if isinstance(self.env, MultiAgentEnv) or \
            isinstance(self.env, BaseEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            not model_config.get("custom_preprocessor") and \
            preprocessor_pref == "deepmind":

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        # `clip_rewards=None` resolves to clipping only on this Atari path.
        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            env = wrap_deepmind(
                env,
                dim=model_config.get("dim"),
                framestack=model_config.get("framestack"))
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env
    else:

        def wrap(env):
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env

    self.env = wrap(self.env)

    def make_env(vector_index):
        # Creates (and wraps) each additional sub-env of the vector env.
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    vector_index=vector_index, remote=remote_worker_envs)))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy_graph, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    if _has_tensorflow_graph(policy_dict):
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE
                and not ray.get_gpu_ids()):
            logger.info("Creating policy evaluation worker {}".format(
                worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
        # TF policies are built inside a dedicated graph + session so
        # they do not share the process-global default graph.
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    # Multi-agent mode == any policy id other than the default one.
    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent:
        if not (isinstance(self.env, MultiAgentEnv)
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policy graphs {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of MultiAgentEnv?".format(self.env))

    # One observation filter per policy, sized to its observation space.
    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        unroll_length = batch_steps
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        unroll_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    if input_evaluation_method == "simulation":
        logger.warning(
            "Requested 'simulation' input evaluation method: "
            "will discard all sampler outputs and keep only metrics.")
        sample_async = True
    elif input_evaluation_method is None:
        pass
    else:
        raise ValueError("Unknown evaluation method: {}".format(
            input_evaluation_method))

    if sample_async:
        self.sampler = AsyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs=input_evaluation_method == "simulation")
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions)

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug("Created evaluator with env {} ({}), policies {}".format(
        self.async_env, self.env, self.policy_map))
def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        policy_mapping_fn: Callable[[AgentID], PolicyID],
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        clip_rewards: bool,
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: int = None,
        multiple_episodes_in_batch: bool = False,
        tf_sess=None,
        clip_actions: bool = True,
        blackhole_outputs: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: "ObservationFunction" = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
):
    """Initializes a AsyncSampler object.

    Args:
        worker (RolloutWorker): The RolloutWorker that will use this
            Sampler for sampling.
        env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
        policies (Dict[str, Policy]): Mapping from policy ID to Policy obj.
        policy_mapping_fn (callable): Callable that takes an agent ID and
            returns the ID of the policy to use for that agent.
        preprocessors (Dict[str, Preprocessor]): Mapping from policy ID to
            Preprocessor object for the observations prior to filtering.
        obs_filters (Dict[str, Filter]): Mapping from policy ID to env
            Filter object. Every filter must have a truthy `is_concurrent`
            attribute, otherwise AssertionError is raised.
        clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
            float value for +/- value clipping. False for no clipping.
        rollout_fragment_length (int): The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        count_steps_by (str): Either "env_steps" or "agent_steps".
            Refers to the unit of `rollout_fragment_length`.
        callbacks (Callbacks): The Callbacks object to use when episode
            events happen during rollout.
        horizon (Optional[int]): Hard-reset the Env when this horizon is
            hit (see `soft_horizon` for the non-resetting variant).
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
            framework=tf).
        clip_actions (bool): Whether to clip actions according to the
            given action_space's bounds.
        blackhole_outputs (bool): Whether to collect samples, but then
            not further process or store them (throw away all samples).
        soft_horizon (bool): If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (Optional[ObservationFunction]): Optional
            multi-agent observation func to use for preprocessing
            observations.
        sample_collector_class (Optional[Type[SampleCollector]]): An
            optional SampleCollector sub-class to use to collect, store, and
            retrieve environment-, model-, and sampler data. Defaults to
            SimpleListCollector.
        render (bool): Whether to try to render the environment after each
            step.

    Raises:
        AssertionError: If any observation filter does not support
            concurrent updates.
    """
    # This sampler runs as its own `threading.Thread`, so filters must be
    # safe for concurrent updates. Validate explicitly rather than with
    # `assert` (stripped under `python -O`), keeping the AssertionError
    # type that callers may already expect.
    for obs_filter in obs_filters.values():
        if not getattr(obs_filter, "is_concurrent", False):
            raise AssertionError(
                "Observation Filter must support concurrent updates.")

    self.worker = worker
    self.base_env = BaseEnv.to_base_env(env)
    threading.Thread.__init__(self)
    self.queue = queue.Queue(5)
    self.extra_batches = queue.Queue()
    self.metrics_queue = queue.Queue()
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.policies = policies
    self.policy_mapping_fn = policy_mapping_fn
    self.preprocessors = preprocessors
    self.obs_filters = obs_filters
    self.clip_rewards = clip_rewards
    # Must come after Thread.__init__: the `daemon` property requires an
    # initialized Thread.
    self.daemon = True
    self.multiple_episodes_in_batch = multiple_episodes_in_batch
    self.tf_sess = tf_sess
    self.callbacks = callbacks
    self.clip_actions = clip_actions
    self.blackhole_outputs = blackhole_outputs
    self.soft_horizon = soft_horizon
    self.no_done_at_end = no_done_at_end
    self.perf_stats = _PerfStats()
    self.shutdown = False
    self.observation_fn = observation_fn
    self.render = render
    if not sample_collector_class:
        sample_collector_class = SimpleListCollector
    self.sample_collector = sample_collector_class(
        policies,
        clip_rewards,
        callbacks,
        multiple_episodes_in_batch,
        rollout_fragment_length,
        count_steps_by=count_steps_by)
def __init__(self,
             env_creator,
             policy_graph,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation=frozenset([]),
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False,
             async_remote_worker_envs=False):
    """Initialize a policy evaluator.

    Creates (and possibly wraps) the env, builds the policy map (inside a
    fresh TF graph/session when any policy is TF-based), creates one
    observation filter per policy, vectorizes the env into a `BaseEnv`,
    sets up off-policy estimators, starts a sync or async sampler, and
    creates the input reader and output writer.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy_graph (class|dict): Either a class implementing
            PolicyGraph, or a dictionary of policy id strings to
            (PolicyGraph, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that
            agent to a policy for the duration of the episode.
        policies_to_train (list): Optional whitelist of policies to train,
            or None for all policies.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicyGraph.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `batch_steps * num_envs` in size. The batch will
                be exactly `batch_steps * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps * num_envs` in size. Episodes will
                not be truncated, but multiple episodes may be packed
                within one batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Whether to stop episodes at this horizon.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy_graph`.
        worker_index (int): For remote evaluators, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (dict): Dict of custom debug callbacks.
        input_creator (func): Function that returns an InputReader object
            for loading previous generated experiences.
        input_evaluation (list): How to evaluate the policy performance.
            This only makes sense to set when the input is reading offline
            data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
        output_creator (func): Function that returns an OutputWriter object
            for saving generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            are very CPU intensive (e.g., for StarCraft).
        async_remote_worker_envs (bool): Similar to remote_worker_envs,
            but runs the envs asynchronously in the background.

    Raises:
        ValueError: If `policy_mapping_fn` is not callable; if more than
            the default policy is configured but the env is neither a
            MultiAgentEnv nor a BaseEnv; if `batch_mode` is unsupported;
            or if `input_evaluation` contains an unknown method.
        AssertionError: If `input_creator` / `output_creator` do not
            return an InputReader / OutputWriter instance.
    """

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    # The EnvContext carries env_config plus this worker's index to envs.
    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    self.policy_config = policy_config
    self.callbacks = callbacks or {}
    model_config = model_config or {}
    # Single-agent default: every agent maps to the default policy.
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError(
            "Policy mapping function not callable. If you're using Tune, "
            "make sure to escape the function with tune.function() "
            "to prevent it from being evaluated as an expression.")
    self.env_creator = env_creator
    # Target transitions per sample() call, summed across all sub-envs.
    self.sample_batch_size = batch_steps * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    # Switched off below when deepmind wrappers do the preprocessing.
    self.preprocessing_enabled = True

    self.env = _validate_env(env_creator(env_context))
    if isinstance(self.env, MultiAgentEnv) or \
            isinstance(self.env, BaseEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            not model_config.get("custom_preprocessor") and \
            preprocessor_pref == "deepmind":

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        # clip_rewards=None means "clip for Atari only" (see docstring).
        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            env = wrap_deepmind(
                env,
                dim=model_config.get("dim"),
                framestack=model_config.get("framestack"))
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env
    else:

        def wrap(env):
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env

    self.env = wrap(self.env)

    # Factory for additional sub-envs, passed to BaseEnv.to_base_env below
    # together with num_envs; each sub-env gets its own vector_index in its
    # EnvContext and the same wrapping as self.env.
    def make_env(vector_index):
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    vector_index=vector_index, remote=remote_worker_envs)))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy_graph, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    if _has_tensorflow_graph(policy_dict):
        # Under Ray (not local mode) with no GPU assigned to this worker.
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE
                and not ray.get_gpu_ids()):
            logger.info("Creating policy evaluation worker {}".format(
                worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
        # Build TF policies inside a new graph, with a session bound to it
        # (user-supplied, or one with GPU memory allow_growth enabled).
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    # Multi-agent mode means anything other than just the default policy.
    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent:
        if not (isinstance(self.env, MultiAgentEnv)
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policy graphs {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of MultiAgentEnv?".format(self.env))

    # One observation filter per policy, shaped like that policy's
    # observation space.
    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs,
        async_remote_envs=async_remote_worker_envs)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        unroll_length = batch_steps
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        unroll_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.reward_estimators = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            # Simulation needs a background sampler; its sample outputs are
            # discarded via blackhole_outputs below.
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if sample_async:
        # Background sampler; started immediately.
        self.sampler = AsyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions)

    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug("Created evaluator with env {} ({}), policies {}".format(
        self.async_env, self.env, self.policy_map))
def _process_observations(
        *,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            (sub-)env ID to the currently ongoing MultiAgentEpisode of that
            env. Missing keys are expected to produce a fresh episode on
            access (presumably a defaultdict -- confirm with the caller).
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        multiple_episodes_in_batch (bool): If True, finished episodes are
            only postprocessed, and (possibly truncated) multi-agent
            batches are built across episodes at the end of this call. If
            False, each finished episode is built into its own batch.
        callbacks (DefaultCallbacks): User callbacks to run on episode
            events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
            It is applied exactly once to every observation dict.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the
                sampler.

    Raises:
        ValueError: If `observation_fn` does not return a dict, or if a
            finite horizon is set but the env does not support reset().
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        # Accessing a missing key creates a fresh episode (see docstring).
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                # First observation of this agent in this episode.
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t": episode.length - 1,
                    "env_id": env_id,
                    "agent_index": episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions": episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    "rewards": rewards[env_id][agent_id],
                    # After taking action=a, did we reach terminal?
                    # Soft-horizon cuts and no_done_at_end record False.
                    "dones": (False if (no_done_at_end
                                        or (hit_horizon and soft_horizon))
                              else agent_done),
                    # Next observation.
                    "new_obs": filtered_obs,
                }
                # Add extra-action-fetches to collectors.
                pol = policies[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            # Only agents that are still alive need a new action.
            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos,
                    None if last_observation is None else
                    episode.rnn_state_for(agent_id),
                    None if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id][agent_id] or 0.0)
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(
                worker=worker,
                base_env=base_env,
                episode=episode,
                env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If, we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_id,
            )

            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                # These observations already went through `observation_fn`
                # above; applying it again would transform them twice.
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
                apply_obs_fn_on_reset = False
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
                apply_obs_fn_on_reset = True

            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Soft reset: same episode object. Hard reset: a fresh
                # episode is created by this access (key was deleted above).
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn and apply_obs_fn_on_reset:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=new_episode)
                    if not isinstance(resetted_obs, dict):
                        raise ValueError(
                            "observe() must return a dict of agent "
                            "observations")
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(
                        new_episode, agent_id, env_id, policy_id,
                        new_episode.length - 1, filtered_obs)

                    # Info and RNN state must come from the episode the
                    # observation belongs to, not the one that just ended.
                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        new_episode.last_info_for(agent_id) or {},
                        new_episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs
dones_are_same, msg = objects_are_the_same(new=new_dones, old=old_dones) if not dones_are_same: assert False, f"In the dones,\n{msg}" info_are_same, msg = objects_are_the_same(new=new_info, old=old_info) if not info_are_same: assert False, f"In the infos,\n{msg}" print("\n\nstarted") num_envs = 5 env = StrategoParallelEnv(num_envs=num_envs) old_env = BaseEnv.to_base_env(StrategoMultiAgentEnv(), lambda x: StrategoMultiAgentEnv(), num_envs=num_envs) old_env.poll() env.poll() print("envs have been made") for env_idx in range(num_envs): print(f"ENV INDEX:{env_idx}") old_env.try_reset(env_idx) print("a") old_state = old_env.envs[env_idx].state old_player = old_env.envs[env_idx].player print("b") env.try_reset(env_idx, force_state=old_state, force_player=old_player) print("c") print("envs have been reset and synced")