def _make_worker(self, cls, env_creator, policy, worker_index, config):
    """Create a single rollout worker of type ``cls``.

    Resolves the experience input/output creators and the (possibly
    multi-agent) policy spec from ``config``, then instantiates the
    worker with all relevant settings.
    """

    def session_creator():
        # Deferred so the TF session is created inside the worker
        # process, not here.
        logger.debug("Creating TF session {}".format(
            config["tf_session_args"]))
        return tf.Session(config=tf.ConfigProto(
            **config["tf_session_args"]))

    # Where the worker reads experiences from: a custom callable, the
    # env sampler, a mix of sources, or JSON files (the latter two get
    # wrapped in a shuffle buffer).
    input_cfg = config["input"]
    if isinstance(input_cfg, FunctionType):
        input_creator = input_cfg
    elif input_cfg == "sampler":
        input_creator = (lambda ioctx: ioctx.default_sampler_input())
    elif isinstance(input_cfg, dict):
        input_creator = (
            lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))
    else:
        input_creator = (
            lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))

    # Where the worker writes experiences to: a custom callable, nowhere,
    # the worker's log dir, or an explicit output path.
    output_cfg = config["output"]
    if isinstance(output_cfg, FunctionType):
        output_creator = output_cfg
    elif output_cfg is None:
        output_creator = (lambda ioctx: NoopOutput())
    elif output_cfg == "logdir":
        output_creator = (lambda ioctx: JsonWriter(
            ioctx.log_dir,
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))
    else:
        output_creator = (lambda ioctx: JsonWriter(
            config["output"],
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))

    # Input evaluation only applies to offline (non-sampler) input.
    input_evaluation = ([] if config["input"] == "sampler" else
                        config["input_evaluation"])

    # Fill in the default policy if 'None' is specified in multiagent.
    multiagent_policies = config["multiagent"]["policies"]
    if multiagent_policies:
        _validate_multiagent_config(multiagent_policies,
                                    allow_none_graph=True)
        for pid, spec in multiagent_policies.items():
            if spec[0] is None:
                multiagent_policies[pid] = (policy, spec[1], spec[2],
                                            spec[3])
        policy = multiagent_policies

    return cls(
        env_creator,
        policy,
        policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
        policies_to_train=config["multiagent"]["policies_to_train"],
        tf_session_creator=(session_creator
                            if config["tf_session_args"] else None),
        batch_steps=config["sample_batch_size"],
        batch_mode=config["batch_mode"],
        episode_horizon=config["horizon"],
        preprocessor_pref=config["preprocessor_pref"],
        sample_async=config["sample_async"],
        compress_observations=config["compress_observations"],
        num_envs=config["num_envs_per_worker"],
        observation_filter=config["observation_filter"],
        clip_rewards=config["clip_rewards"],
        clip_actions=config["clip_actions"],
        env_config=config["env_config"],
        model_config=config["model"],
        policy_config=config,
        worker_index=worker_index,
        monitor_path=self._logdir if config["monitor"] else None,
        log_dir=self._logdir,
        log_level=config["log_level"],
        callbacks=config["callbacks"],
        input_creator=input_creator,
        input_evaluation=input_evaluation,
        output_creator=output_creator,
        remote_worker_envs=config["remote_worker_envs"],
        remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
        soft_horizon=config["soft_horizon"],
        no_done_at_end=config["no_done_at_end"],
        seed=(config["seed"] + worker_index)
        if config["seed"] is not None else None,
        _fake_sampler=config.get("_fake_sampler", False))
def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        config: TrainerConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
) -> Union[RolloutWorker, "ActorHandle"]:
    """Create a (local or remote) rollout worker of type `cls`.

    Args:
        cls: The worker class to instantiate (per the return annotation,
            either RolloutWorker itself or an actor-wrapped variant).
        env_creator: Callable mapping an EnvContext to a new env instance.
        validate_env: Optional callable used to validate a created env.
        policy_cls: Default Policy class; substituted wherever the
            multiagent policies dict leaves the class slot as None.
        worker_index: 0 for the driver, >0 for remote workers (per the
            `extra_python_environs_for_driver/worker` split below).
        num_workers: Total number of workers, forwarded to the worker.
        config: The Trainer's full config dict (also forwarded verbatim
            as `policy_config`).
        spaces: Optional per-policy (obs_space, action_space) pairs,
            forwarded to the worker unchanged.

    Returns:
        The constructed worker (an ActorHandle when `cls` is remote).
    """

    def session_creator():
        # Deferred so the TF1 session is created in the worker process.
        logger.debug("Creating TF session {}".format(
            config["tf_session_args"]))
        return tf1.Session(config=tf1.ConfigProto(
            **config["tf_session_args"]))

    # Resolve the experience input source: a custom callable, the env
    # sampler, mixed sources, or JSON files (the latter two are wrapped
    # in a shuffle buffer of size `shuffle_buffer_size`).
    if isinstance(config["input"], FunctionType):
        input_creator = config["input"]
    elif config["input"] == "sampler":
        input_creator = (lambda ioctx: ioctx.default_sampler_input())
    elif isinstance(config["input"], dict):
        input_creator = (
            lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))
    else:
        input_creator = (
            lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))

    # Resolve the experience output sink: a custom callable, a no-op,
    # the worker's own log dir, or an explicit output path.
    if isinstance(config["output"], FunctionType):
        output_creator = config["output"]
    elif config["output"] is None:
        output_creator = (lambda ioctx: NoopOutput())
    elif config["output"] == "logdir":
        output_creator = (lambda ioctx: JsonWriter(
            ioctx.log_dir,
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))
    else:
        output_creator = (lambda ioctx: JsonWriter(
            config["output"],
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))

    # Input evaluation only applies to offline (non-sampler) input.
    if config["input"] == "sampler":
        input_evaluation = []
    else:
        input_evaluation = config["input_evaluation"]

    # Fill in the default policy_cls if 'None' is specified in multiagent.
    # NOTE: this mutates config["multiagent"]["policies"] in place.
    if config["multiagent"]["policies"]:
        tmp = config["multiagent"]["policies"]
        _validate_multiagent_config(tmp, allow_none_graph=True)
        # TODO: (sven) Allow for setting observation and action spaces to
        #  None as well, in which case, spaces are taken from env.
        #  It's tedious to have to provide these in a multi-agent config.
        for k, v in tmp.items():
            if v[0] is None:
                tmp[k] = (policy_cls, v[1], v[2], v[3])
        policy_spec = tmp
    # Otherwise, policy spec is simply the policy class itself.
    else:
        policy_spec = policy_cls

    # Driver (index 0) and remote workers can be given different extra
    # environment variables.
    if worker_index == 0:
        extra_python_environs = config.get(
            "extra_python_environs_for_driver", None)
    else:
        extra_python_environs = config.get(
            "extra_python_environs_for_worker", None)

    worker = cls(
        env_creator=env_creator,
        validate_env=validate_env,
        policy_spec=policy_spec,
        policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
        policies_to_train=config["multiagent"]["policies_to_train"],
        # Only create a TF session if there are session args configured.
        tf_session_creator=(session_creator
                            if config["tf_session_args"] else None),
        rollout_fragment_length=config["rollout_fragment_length"],
        batch_mode=config["batch_mode"],
        episode_horizon=config["horizon"],
        preprocessor_pref=config["preprocessor_pref"],
        sample_async=config["sample_async"],
        compress_observations=config["compress_observations"],
        num_envs=config["num_envs_per_worker"],
        observation_fn=config["multiagent"]["observation_fn"],
        observation_filter=config["observation_filter"],
        clip_rewards=config["clip_rewards"],
        clip_actions=config["clip_actions"],
        env_config=config["env_config"],
        model_config=config["model"],
        policy_config=config,
        worker_index=worker_index,
        num_workers=num_workers,
        monitor_path=self._logdir if config["monitor"] else None,
        log_dir=self._logdir,
        log_level=config["log_level"],
        callbacks=config["callbacks"],
        input_creator=input_creator,
        input_evaluation=input_evaluation,
        output_creator=output_creator,
        remote_worker_envs=config["remote_worker_envs"],
        remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
        soft_horizon=config["soft_horizon"],
        no_done_at_end=config["no_done_at_end"],
        # Offset the seed per worker so workers don't sample identically.
        seed=(config["seed"] + worker_index)
        if config["seed"] is not None else None,
        fake_sampler=config["fake_sampler"],
        extra_python_environs=extra_python_environs,
        spaces=spaces,
    )

    return worker
def _make_worker(self, cls, env_creator, policy, worker_index, config):
    """Construct one rollout worker of type ``cls`` and sanity-check it.

    Mirrors the config into worker constructor arguments and, for a
    local (non-actor) worker, verifies that the policy class it built
    matches the configured framework (torch vs. non-torch).
    """

    def session_creator():
        # Deferred so the TF session is created inside the worker
        # process, not here.
        logger.debug("Creating TF session {}".format(
            config["tf_session_args"]))
        return tf.Session(config=tf.ConfigProto(
            **config["tf_session_args"]))

    # Where the worker reads experiences from: a custom callable, the
    # env sampler, a mix of sources, or JSON files (the latter two get
    # wrapped in a shuffle buffer).
    input_cfg = config["input"]
    if isinstance(input_cfg, FunctionType):
        input_creator = input_cfg
    elif input_cfg == "sampler":
        input_creator = (lambda ioctx: ioctx.default_sampler_input())
    elif isinstance(input_cfg, dict):
        input_creator = (
            lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))
    else:
        input_creator = (
            lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                        config["shuffle_buffer_size"]))

    # Where the worker writes experiences to: a custom callable, nowhere,
    # the worker's log dir, or an explicit output path.
    output_cfg = config["output"]
    if isinstance(output_cfg, FunctionType):
        output_creator = output_cfg
    elif output_cfg is None:
        output_creator = (lambda ioctx: NoopOutput())
    elif output_cfg == "logdir":
        output_creator = (lambda ioctx: JsonWriter(
            ioctx.log_dir,
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))
    else:
        output_creator = (lambda ioctx: JsonWriter(
            config["output"],
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))

    # Input evaluation only applies to offline (non-sampler) input.
    input_evaluation = ([] if config["input"] == "sampler" else
                        config["input_evaluation"])

    # Fill in the default policy if 'None' is specified in multiagent.
    multiagent_policies = config["multiagent"]["policies"]
    if multiagent_policies:
        _validate_multiagent_config(multiagent_policies,
                                    allow_none_graph=True)
        for pid, spec in multiagent_policies.items():
            if spec[0] is None:
                multiagent_policies[pid] = (policy, spec[1], spec[2],
                                            spec[3])
        policy = multiagent_policies

    worker = cls(
        env_creator,
        policy,
        policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
        policies_to_train=config["multiagent"]["policies_to_train"],
        tf_session_creator=(session_creator
                            if config["tf_session_args"] else None),
        rollout_fragment_length=config["rollout_fragment_length"],
        batch_mode=config["batch_mode"],
        episode_horizon=config["horizon"],
        preprocessor_pref=config["preprocessor_pref"],
        sample_async=config["sample_async"],
        compress_observations=config["compress_observations"],
        num_envs=config["num_envs_per_worker"],
        observation_filter=config["observation_filter"],
        clip_rewards=config["clip_rewards"],
        clip_actions=config["clip_actions"],
        env_config=config["env_config"],
        model_config=config["model"],
        policy_config=config,
        worker_index=worker_index,
        num_workers=self._num_workers,
        monitor_path=self._logdir if config["monitor"] else None,
        log_dir=self._logdir,
        log_level=config["log_level"],
        callbacks=config["callbacks"],
        input_creator=input_creator,
        input_evaluation=input_evaluation,
        output_creator=output_creator,
        remote_worker_envs=config["remote_worker_envs"],
        remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
        soft_horizon=config["soft_horizon"],
        no_done_at_end=config["no_done_at_end"],
        seed=(config["seed"] + worker_index)
        if config["seed"] is not None else None,
        _fake_sampler=config.get("_fake_sampler", False))

    # Check for correct policy class (only locally, remote Workers should
    # create the exact same Policy types). An exact `type is` check is
    # used on purpose: only a plain local RolloutWorker can be inspected.
    if type(worker) is RolloutWorker:
        built_policy_cls = type(worker.get_policy())
        if config["use_pytorch"]:
            # Pytorch case: Policy must be a TorchPolicy.
            assert issubclass(built_policy_cls, TorchPolicy), \
                "Worker policy must be subclass of `TorchPolicy`, " \
                "but is {}!".format(built_policy_cls.__name__)
        else:
            # non-Pytorch case:
            # Policy may be None AND must not be a TorchPolicy.
            assert issubclass(built_policy_cls, type(None)) or \
                (issubclass(built_policy_cls, Policy) and
                 not issubclass(built_policy_cls, TorchPolicy)), \
                "Worker policy must be subclass of `Policy`, but NOT " \
                "`TorchPolicy` (your class={})! If you have a torch " \
                "Trainer, make sure to set `use_pytorch=True` in " \
                "your Trainer's config)!".format(built_policy_cls.__name__)
    return worker