def test_write_dataset(self):
    """Checks that a DatasetWriter batches writes into a single output file."""
    # I/O context configured for JSON dataset output into the test dir.
    output_config = {
        "format": "json",
        "path": self.test_dir,
        "max_num_samples_per_file": 2,
    }
    ioctx = IOContext(
        self.test_dir,
        {"output": "dataset", "output_config": output_config},
        0,
        None,
    )
    writer = DatasetWriter(ioctx, compress_columns=["obs"])
    # Nothing has been written yet -> output dir must be empty.
    self.assertEqual(len(os.listdir(self.test_dir)), 0)
    # Write the sample batch twice; exactly one file must result.
    for _ in range(2):
        writer.write(SAMPLES)
    self.assertEqual(len(os.listdir(self.test_dir)), 1)
def _make_worker(
    self,
    *,
    cls: Callable,
    env_creator: EnvCreator,
    validate_env: Optional[Callable[[EnvType], None]],
    policy_cls: Type[Policy],
    worker_index: int,
    num_workers: int,
    recreated_worker: bool = False,
    config: AlgorithmConfigDict,
    spaces: Optional[
        Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
    ] = None,
) -> Union[RolloutWorker, ActorHandle]:
    """Creates one worker from `cls`, wiring up input/output and policies.

    Builds the input-reader and output-writer factories from `config`,
    resolves the multiagent policy mapping, and constructs the worker.

    Args:
        cls: Worker class to instantiate (a RolloutWorker class or a
            remote-actor wrapper of one).
        env_creator: Callable that builds the environment.
        validate_env: Optional callable used to validate a created env.
        policy_cls: Default Policy class, used for any PolicySpec whose
            `policy_class` is None (and when no multiagent policies given).
        worker_index: Index of the new worker; 0 selects the
            "extra_python_environs_for_driver" env vars, >0 the
            "...for_worker" ones.
        num_workers: Total number of workers, passed through to `cls`.
        recreated_worker: Whether this worker replaces a failed one;
            passed through to `cls`.
        config: Algorithm config dict driving all construction settings.
        spaces: Optional per-policy (obs space, action space) pairs,
            passed through to `cls`.

    Returns:
        The constructed worker (RolloutWorker or an ActorHandle to one).
    """

    # Lazily creates a TF1 session from `config["tf_session_args"]`;
    # only handed to the worker if tf_session_args is truthy (see below).
    def session_creator():
        logger.debug("Creating TF session {}".format(
            config["tf_session_args"]))
        return tf1.Session(config=tf1.ConfigProto(
            **config["tf_session_args"]))

    # Returns True iff `class_path` looks like a "module.Class" string whose
    # module can be found by the import machinery (no file paths).
    def valid_module(class_path):
        if (isinstance(class_path, str) and not os.path.isfile(class_path)
                and "." in class_path):
            # class_name is unused; only the module part needs to resolve.
            module_path, class_name = class_path.rsplit(".", 1)
            try:
                spec = importlib.util.find_spec(module_path)
                if spec is not None:
                    return True
            except (ModuleNotFoundError, ValueError):
                # NOTE(review): uses print() instead of `logger` — consider
                # logger.warning for consistency with the rest of the file.
                print(
                    f"module {module_path} not found while trying to get "
                    f"input {class_path}")
        return False

    # Build the InputReader factory. NOTE: branch order matters — e.g. the
    # dict check must run before the `"d4rl" in ...` containment test.
    # A callable returning an InputReader object to use.
    if isinstance(config["input"], FunctionType):
        input_creator = config["input"]
    # Use RLlib's Sampler classes (SyncSampler or AsynchSampler, depending
    # on `config.sample_async` setting).
    elif config["input"] == "sampler":
        input_creator = lambda ioctx: ioctx.default_sampler_input()
    # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
    elif config["input"] == "dataset":
        # Input dataset shards should have already been prepared.
        # We just need to take the proper shard here.
        input_creator = lambda ioctx: DatasetReader(
            ioctx, self._ds_shards[worker_index])
    # Dict: Mix of different input methods with different ratios.
    elif isinstance(config["input"], dict):
        input_creator = lambda ioctx: ShuffledInput(
            MixedInput(config["input"], ioctx), config[
                "shuffle_buffer_size"])
    # A pre-registered input descriptor (str).
    elif isinstance(config["input"], str) and registry_contains_input(
            config["input"]):
        input_creator = registry_get_input(config["input"])
    # D4RL input.
    elif "d4rl" in config["input"]:
        env_name = config["input"].split(".")[-1]
        input_creator = lambda ioctx: D4RLReader(env_name, ioctx)
    # Valid python module (class path) -> Create using `from_config`.
    elif valid_module(config["input"]):
        input_creator = lambda ioctx: ShuffledInput(
            from_config(config["input"], ioctx=ioctx))
    # JSON file or list of JSON files -> Use JsonReader (shuffled).
    else:
        input_creator = lambda ioctx: ShuffledInput(
            JsonReader(config["input"], ioctx), config[
                "shuffle_buffer_size"])

    # Build the OutputWriter factory analogously to the input side.
    # A callable -> use it directly as the factory.
    if isinstance(config["output"], FunctionType):
        output_creator = config["output"]
    # No output at all.
    elif config["output"] is None:
        output_creator = lambda ioctx: NoopOutput()
    # Ray Dataset output.
    elif config["output"] == "dataset":
        output_creator = lambda ioctx: DatasetWriter(
            ioctx, compress_columns=config["output_compress_columns"])
    # Write JSON files into the worker's log dir.
    elif config["output"] == "logdir":
        output_creator = lambda ioctx: JsonWriter(
            ioctx.log_dir,
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"],
        )
    # Anything else is treated as an output path for JsonWriter.
    else:
        output_creator = lambda ioctx: JsonWriter(
            config["output"],
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"],
        )

    # Assert everything is correct in "multiagent" config dict (if given).
    ma_policies = config["multiagent"]["policies"]
    if ma_policies:
        for pid, policy_spec in ma_policies.copy().items():
            assert isinstance(policy_spec, PolicySpec)
            # Class is None -> Use `policy_cls`.
            if policy_spec.policy_class is None:
                ma_policies[pid].policy_class = policy_cls
        policies = ma_policies
    # Create a policy_spec (MultiAgentPolicyConfigDict),
    # even if no "multiagent" setup given by user.
    else:
        policies = policy_cls

    # Driver (index 0) and remote workers may get different extra env vars.
    if worker_index == 0:
        extra_python_environs = config.get(
            "extra_python_environs_for_driver", None)
    else:
        extra_python_environs = config.get(
            "extra_python_environs_for_worker", None)

    worker = cls(
        env_creator=env_creator,
        validate_env=validate_env,
        policy_spec=policies,
        policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
        policies_to_train=config["multiagent"]["policies_to_train"],
        # Only pass a session creator if TF session args were configured.
        tf_session_creator=(session_creator
                            if config["tf_session_args"] else None),
        rollout_fragment_length=config["rollout_fragment_length"],
        count_steps_by=config["multiagent"]["count_steps_by"],
        batch_mode=config["batch_mode"],
        episode_horizon=config["horizon"],
        preprocessor_pref=config["preprocessor_pref"],
        sample_async=config["sample_async"],
        compress_observations=config["compress_observations"],
        num_envs=config["num_envs_per_worker"],
        observation_fn=config["multiagent"]["observation_fn"],
        observation_filter=config["observation_filter"],
        clip_rewards=config["clip_rewards"],
        normalize_actions=config["normalize_actions"],
        clip_actions=config["clip_actions"],
        env_config=config["env_config"],
        policy_config=config,
        worker_index=worker_index,
        num_workers=num_workers,
        recreated_worker=recreated_worker,
        log_dir=self._logdir,
        log_level=config["log_level"],
        callbacks=config["callbacks"],
        input_creator=input_creator,
        output_creator=output_creator,
        remote_worker_envs=config["remote_worker_envs"],
        remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
        soft_horizon=config["soft_horizon"],
        no_done_at_end=config["no_done_at_end"],
        # Offset the seed per worker so workers don't sample identically.
        seed=(config["seed"] + worker_index)
        if config["seed"] is not None else None,
        fake_sampler=config["fake_sampler"],
        extra_python_environs=extra_python_environs,
        spaces=spaces,
        disable_env_checking=config["disable_env_checking"],
    )

    return worker