예제 #1
0
파일: test_io.py 프로젝트: novahe/ray
    def test_custom_input_registry(self):
        """Register a custom input creator and fetch it back by name.

        Checks that the registered creator is found in the registry, that
        calling it yields a ShuffledInput wrapping the custom reader, and
        that the IOContext exposes the values it was constructed with.
        """
        input_name = "custom_input"
        cfg = {"input_config": {}}
        io_context = IOContext(self.test_dir, cfg, 0, None)

        class CustomInputReader(InputReader):
            """Minimal reader whose `next()` always returns 0."""

            def __init__(self, ioctx: IOContext):
                self.ioctx = ioctx

            def next(self):
                return 0

        def make_shuffled_reader(ioctx: IOContext):
            inner_reader = CustomInputReader(ioctx)
            return ShuffledInput(inner_reader)

        # Registration: the name must be discoverable afterwards.
        register_input(input_name, make_shuffled_reader)
        self.assertTrue(registry_contains_input(input_name))

        # Lookup: the registry must hand back something callable on a context.
        fetched_creator = registry_get_input(input_name)
        self.assertIsNotNone(fetched_creator)

        shuffled = fetched_creator(io_context)
        self.assertIsInstance(shuffled, ShuffledInput)
        self.assertEqual(shuffled.next(), 0)

        # The context attributes mirror the constructor arguments.
        self.assertEqual(io_context.log_dir, self.test_dir)
        self.assertEqual(io_context.config, cfg)
        self.assertEqual(io_context.worker_index, 0)
        self.assertIsNone(io_context.worker)
        self.assertEqual(io_context.input_config, {})
예제 #2
0
    def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
        """Initialize a MixedInput.

        Each key of `dist` is turned into one input source, stored in
        `self.choices`; its probability is stored at the same position in
        `self.p`.

        Args:
            dist (dict): dict mapping input sources to probabilities. A key
                may be the string "sampler" (uses
                `ioctx.default_sampler_input()`), a function (called with
                `ioctx`), the name of a registered input (its creator is
                called with `ioctx`), or anything else, which is handed to
                `JsonReader(key, ioctx)`. The probabilities must sum to 1.0.
            ioctx (IOContext): current IO context object.

        Raises:
            ValueError: If the probabilities do not sum to 1.0 (within
                floating-point tolerance).
        """
        # Local import keeps this block self-contained.
        import math

        # Compare with a tolerance: an exact `!= 1.0` check rejects valid
        # distributions such as 0.1 + 0.2 + 0.7, whose float sum is
        # 0.9999999999999999.
        if not math.isclose(sum(dist.values()), 1.0):
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                self.choices.append(ioctx.default_sampler_input())
            elif isinstance(k, FunctionType):
                self.choices.append(k(ioctx))
            elif isinstance(k, str) and registry_contains_input(k):
                input_creator = registry_get_input(k)
                self.choices.append(input_creator(ioctx))
            else:
                # Fallback: treat the key as a JsonReader input spec.
                self.choices.append(JsonReader(k, ioctx))
            self.p.append(v)
예제 #3
0
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: EnvCreator,
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        recreated_worker: bool = False,
        config: AlgorithmConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Create a single worker by calling `cls` with values from `config`.

        The method derives an input creator from `config["input"]`, an
        output creator from `config["output"]`, and a policy spec from
        `config["multiagent"]["policies"]` (falling back to `policy_cls`),
        then forwards these together with the remaining settings to `cls`.

        Args:
            cls: Callable invoked with keyword arguments to build the worker.
            env_creator: Forwarded to `cls` unchanged.
            validate_env: Forwarded to `cls` unchanged.
            policy_cls: Policy class used as the policy spec when no
                multi-agent policies are configured, and filled into any
                PolicySpec whose `policy_class` is None.
            worker_index: Index of this worker. Index 0 selects the
                "extra_python_environs_for_driver" config entry; any other
                index selects "extra_python_environs_for_worker". Also added
                to `config["seed"]` when a seed is set.
            num_workers: Forwarded to `cls` unchanged.
            recreated_worker: Forwarded to `cls` unchanged.
            config: Config dict that all settings are read from.
            spaces: Forwarded to `cls` unchanged.

        Returns:
            Whatever `cls(...)` returns.
        """

        def make_tf_session():
            session_args = config["tf_session_args"]
            logger.debug("Creating TF session {}".format(session_args))
            return tf1.Session(config=tf1.ConfigProto(**session_args))

        def is_importable_class_path(candidate):
            # Only dotted strings that are not existing files can qualify.
            if not isinstance(candidate, str):
                return False
            if os.path.isfile(candidate) or "." not in candidate:
                return False
            module_path, _ = candidate.rsplit(".", 1)
            try:
                return importlib.util.find_spec(module_path) is not None
            except (ModuleNotFoundError, ValueError):
                print(
                    f"module {module_path} not found while trying to get "
                    f"input {candidate}")
            return False

        def pick_input_creator():
            # The order of the checks below matters: first match wins.
            source = config["input"]
            # A callable returning an InputReader object to use.
            if isinstance(source, FunctionType):
                return source
            # RLlib's own sampler input.
            if source == "sampler":
                return lambda ioctx: ioctx.default_sampler_input()
            # Ray Dataset input: take this worker's shard from
            # `self._ds_shards`.
            if source == "dataset":
                return lambda ioctx: DatasetReader(
                    ioctx, self._ds_shards[worker_index])
            # Dict: mix of different input methods with different ratios.
            if isinstance(source, dict):
                return lambda ioctx: ShuffledInput(
                    MixedInput(config["input"], ioctx),
                    config["shuffle_buffer_size"])
            # A pre-registered input descriptor (str).
            if isinstance(source, str) and registry_contains_input(source):
                return registry_get_input(config["input"])
            # D4RL input: the env name is the part after the last dot.
            if "d4rl" in source:
                env_name = source.split(".")[-1]
                return lambda ioctx: D4RLReader(env_name, ioctx)
            # Importable python class path -> create using `from_config`.
            if is_importable_class_path(source):
                return lambda ioctx: ShuffledInput(
                    from_config(config["input"], ioctx=ioctx))
            # Anything else is handed to JsonReader (shuffled).
            return lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx),
                config["shuffle_buffer_size"])

        def pick_output_creator():
            target = config["output"]
            if isinstance(target, FunctionType):
                return target
            if target is None:
                return lambda ioctx: NoopOutput()
            if target == "dataset":
                return lambda ioctx: DatasetWriter(
                    ioctx, compress_columns=config["output_compress_columns"])
            # "logdir" writes to the context's log dir ...
            if target == "logdir":
                return lambda ioctx: JsonWriter(
                    ioctx.log_dir,
                    ioctx,
                    max_file_size=config["output_max_file_size"],
                    compress_columns=config["output_compress_columns"],
                )
            # ... any other value is passed to JsonWriter as-is.
            return lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )

        input_creator = pick_input_creator()
        output_creator = pick_output_creator()

        # Check the "multiagent" policies (if given) and fill in missing
        # policy classes; without any, the policy class itself is the spec.
        multiagent_cfg = config["multiagent"]
        policies = multiagent_cfg["policies"]
        if policies:
            for pid, spec in policies.copy().items():
                assert isinstance(spec, PolicySpec)
                # Class is None -> Use `policy_cls`.
                if spec.policy_class is None:
                    policies[pid].policy_class = policy_cls
        else:
            policies = policy_cls

        environs_key = ("extra_python_environs_for_driver"
                        if worker_index == 0 else
                        "extra_python_environs_for_worker")
        extra_python_environs = config.get(environs_key, None)

        return cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=multiagent_cfg["policy_mapping_fn"],
            policies_to_train=multiagent_cfg["policies_to_train"],
            tf_session_creator=(make_tf_session
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=multiagent_cfg["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=multiagent_cfg["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            recreated_worker=recreated_worker,
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Each worker gets its own seed, offset by its index.
            seed=(None if config["seed"] is None else
                  config["seed"] + worker_index),
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
            disable_env_checking=config["disable_env_checking"],
        )
예제 #4
0
    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Create a single worker by calling `cls` with values from `config`.

        Derives an input creator from `config["input"]`, an output creator
        from `config["output"]`, and a policy spec from
        `config["multiagent"]["policies"]` (falling back to `policy_cls`),
        then forwards these together with the remaining settings to `cls`.

        Args:
            cls: Callable invoked with keyword arguments to build the worker.
            env_creator: Forwarded to `cls` unchanged.
            validate_env: Forwarded to `cls` unchanged.
            policy_cls: Policy class used as the policy spec when no
                multi-agent policies are configured, and filled into any
                spec whose `policy_class` is None.
            worker_index: Index of this worker. Index 0 selects the
                "extra_python_environs_for_driver" config entry; any other
                index selects "extra_python_environs_for_worker". Also added
                to `config["seed"]` when a seed is set.
            num_workers: Forwarded to `cls` unchanged.
            config: Config dict that all settings are read from.
            spaces: Forwarded to `cls` unchanged.

        Returns:
            Whatever `cls(...)` returns.
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            """Return True if `class_path` looks like an importable class path."""
            # Local import keeps this block self-contained.
            import os

            # An existing file is never treated as a class path. Without the
            # `isfile` check, a data file such as "output.json" would be
            # probed as module "output" and could be misrouted to
            # `from_config` (or trigger a spurious "not found" message).
            if (isinstance(class_path, str) and not os.path.isfile(class_path)
                    and "." in class_path):
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        # Pick the input creator; the order of the checks matters (first
        # match wins).
        # A function is used directly as the creator.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        # Dict: mix of different inputs, wrapped in a ShuffledInput.
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        # Name of a registered input.
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        # D4RL input: the env name is the part after the last dot.
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        # Importable class path -> create using `from_config`.
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(from_config(
                config["input"], ioctx=ioctx)))
        # Anything else is handed to JsonReader (shuffled).
        else:
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))

        # Pick the output creator.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        # "logdir" writes to the IO context's log dir.
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        # Any other value is passed to JsonWriter as-is.
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # With sampler input, no input evaluation is passed on.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                # NOTE(review): the assert admits plain list/tuple, but
                # `.policy_class` / `._replace` below need a namedtuple-like
                # PolicySpec - confirm specs are normalized before this call.
                assert isinstance(policy_spec, (PolicySpec, list, tuple))
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls)
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Each worker gets its own seed, offset by its index.
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker
예제 #5
0
파일: worker_set.py 프로젝트: ckw017/ray
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        config: TrainerConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Create a single worker by calling `cls` with values from `config`.

        Derives an input creator from `config["input"]`, an output creator
        from `config["output"]`, and a policy spec from
        `config["multiagent"]["policies"]` (falling back to `policy_cls`),
        then forwards these together with the remaining settings to `cls`.

        Args:
            cls: Callable invoked with keyword arguments to build the worker.
            env_creator: Forwarded to `cls` unchanged.
            validate_env: Forwarded to `cls` unchanged.
            policy_cls: Policy class used as the policy spec when no
                multi-agent policies are configured, and substituted into
                any policy tuple whose first element is None.
            worker_index: Index of this worker. Index 0 selects the
                "extra_python_environs_for_driver" config entry; any other
                index selects "extra_python_environs_for_worker". Also added
                to `config["seed"]` when a seed is set.
            num_workers: Forwarded to `cls` unchanged.
            config: Config dict that all settings are read from.
            spaces: Forwarded to `cls` unchanged.

        Returns:
            Whatever `cls(...)` returns.
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(config=tf1.ConfigProto(
                **config["tf_session_args"]))

        def valid_module(class_path):
            """Return True if `class_path` looks like an importable class path."""
            # Local import keeps this block self-contained.
            import os

            # An existing file is never treated as a class path. Without the
            # `isfile` check, a data file such as "output.json" would be
            # probed as module "output" and could be misrouted to
            # `from_config` (or trigger a spurious "not found" message).
            if (isinstance(class_path, str) and not os.path.isfile(class_path)
                    and "." in class_path):
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        # Pick the input creator; the order of the checks matters (first
        # match wins).
        # A function is used directly as the creator.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        # Dict: mix of different inputs, wrapped in a ShuffledInput.
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        # Name of a registered input.
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        # D4RL input: the env name is the part after the last dot.
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        # Importable class path -> create using `from_config`.
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(
                from_config(config["input"], ioctx=ioctx)))
        # Anything else is handed to JsonReader (shuffled).
        else:
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))

        # Pick the output creator.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        # "logdir" writes to the IO context's log dir.
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        # Any other value is passed to JsonWriter as-is.
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # With sampler input, no input evaluation is passed on.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Fill in the default policy_cls if 'None' is specified in multiagent.
        if config["multiagent"]["policies"]:
            tmp = config["multiagent"]["policies"]
            _validate_multiagent_config(tmp, allow_none_graph=True)
            # TODO: (sven) Allow for setting observation and action spaces to
            #  None as well, in which case, spaces are taken from env.
            #  It's tedious to have to provide these in a multi-agent config.
            # Only existing keys are reassigned, so the dict does not change
            # size while being iterated.
            for k, v in tmp.items():
                if v[0] is None:
                    tmp[k] = (policy_cls, v[1], v[2], v[3])
            policy_spec = tmp
        # Otherwise, policy spec is simply the policy class itself.
        else:
            policy_spec = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policy_spec,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Each worker gets its own seed, offset by its index.
            seed=(config["seed"] +
                  worker_index) if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker