Example #1
0
    def _make_worker(self, cls, env_creator, policy, worker_index, config):
        """Construct one worker by calling ``cls`` with options from ``config``.

        Args:
            cls: Callable that builds the worker. It receives ``env_creator``
                and the policy spec positionally, followed by the keyword
                options assembled in this method.
            env_creator: Forwarded to ``cls`` untouched.
            policy: Default policy. It is passed on as-is unless
                ``config["multiagent"]["policies"]`` is non-empty; in that
                case the policies dict is passed instead and ``policy`` only
                fills entries whose first element is None.
            worker_index: Forwarded to ``cls`` and added to ``config["seed"]``
                (when a seed is set) to derive the worker's seed.
            config: Config dict the options are read from; it is also handed
                to ``cls`` as ``policy_config``.

        Returns:
            The object returned by ``cls``.
        """

        def session_creator():
            session_args = config["tf_session_args"]
            logger.debug("Creating TF session {}".format(session_args))
            return tf.Session(config=tf.ConfigProto(**session_args))

        # Candidate input factories. They read `config` lazily, i.e. only
        # when the worker eventually invokes the chosen one.
        def sampler_input(ioctx):
            return ioctx.default_sampler_input()

        def mixed_input(ioctx):
            mixed = MixedInput(config["input"], ioctx)
            return ShuffledInput(mixed, config["shuffle_buffer_size"])

        def json_input(ioctx):
            reader = JsonReader(config["input"], ioctx)
            return ShuffledInput(reader, config["shuffle_buffer_size"])

        source = config["input"]
        if isinstance(source, FunctionType):
            # A user-supplied function is used as the factory directly.
            input_creator = source
        elif source == "sampler":
            input_creator = sampler_input
        elif isinstance(source, dict):
            input_creator = mixed_input
        else:
            input_creator = json_input

        # Candidate output factories, selected the same way as the inputs.
        def noop_output(ioctx):
            return NoopOutput()

        def logdir_writer(ioctx):
            return JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        def path_writer(ioctx):
            return JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        sink = config["output"]
        if isinstance(sink, FunctionType):
            output_creator = sink
        elif sink is None:
            output_creator = noop_output
        elif sink == "logdir":
            output_creator = logdir_writer
        else:
            output_creator = path_writer

        # With the plain sampler as input, no input evaluation is passed on.
        input_evaluation = (
            [] if source == "sampler" else config["input_evaluation"])

        # Fill in the default policy if 'None' is specified in multiagent.
        # NOTE: entries are replaced in the config's own dict (in place).
        multiagent = config["multiagent"]
        policy_map = multiagent["policies"]
        if policy_map:
            _validate_multiagent_config(policy_map, allow_none_graph=True)
            for policy_id, spec in policy_map.items():
                if spec[0] is None:
                    policy_map[policy_id] = (policy, spec[1], spec[2], spec[3])
            policy_spec = policy_map
        else:
            policy_spec = policy

        worker_kwargs = {
            "policy_mapping_fn": multiagent["policy_mapping_fn"],
            "policies_to_train": multiagent["policies_to_train"],
            "tf_session_creator": (
                session_creator if config["tf_session_args"] else None),
            "batch_steps": config["sample_batch_size"],
            "batch_mode": config["batch_mode"],
            "episode_horizon": config["horizon"],
            "preprocessor_pref": config["preprocessor_pref"],
            "sample_async": config["sample_async"],
            "compress_observations": config["compress_observations"],
            "num_envs": config["num_envs_per_worker"],
            "observation_filter": config["observation_filter"],
            "clip_rewards": config["clip_rewards"],
            "clip_actions": config["clip_actions"],
            "env_config": config["env_config"],
            "model_config": config["model"],
            "policy_config": config,
            "worker_index": worker_index,
            "monitor_path": self._logdir if config["monitor"] else None,
            "log_dir": self._logdir,
            "log_level": config["log_level"],
            "callbacks": config["callbacks"],
            "input_creator": input_creator,
            "input_evaluation": input_evaluation,
            "output_creator": output_creator,
            "remote_worker_envs": config["remote_worker_envs"],
            "remote_env_batch_wait_ms": config["remote_env_batch_wait_ms"],
            "soft_horizon": config["soft_horizon"],
            "no_done_at_end": config["no_done_at_end"],
            # Offset the base seed by the worker index; no seed stays None.
            "seed": (None if config["seed"] is None else
                     config["seed"] + worker_index),
            "_fake_sampler": config.get("_fake_sampler", False),
        }
        return cls(env_creator, policy_spec, **worker_kwargs)
Example #2
0
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        config: TrainerConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, "ActorHandle"]:
        """Build a single worker by calling ``cls`` with keyword options.

        Args:
            cls: Callable that constructs the worker; invoked with keyword
                arguments only.
            env_creator: Forwarded to ``cls`` unchanged.
            validate_env: Forwarded to ``cls`` unchanged.
            policy_cls: Default policy class. Used as the policy spec when
                ``config["multiagent"]["policies"]`` is empty; otherwise it
                fills multi-agent entries whose first element is None.
            worker_index: Forwarded to ``cls``. Index 0 selects the
                ``extra_python_environs_for_driver`` config entry, any other
                index ``extra_python_environs_for_worker``. It is also added
                to ``config["seed"]`` when a seed is set.
            num_workers: Forwarded to ``cls`` unchanged.
            config: Config dict the options are read from; it is also handed
                to ``cls`` as ``policy_config``.
            spaces: Forwarded to ``cls`` unchanged.

        Returns:
            The object returned by ``cls``.
        """

        def session_creator():
            session_args = config["tf_session_args"]
            logger.debug("Creating TF session {}".format(session_args))
            return tf1.Session(config=tf1.ConfigProto(**session_args))

        # Candidate input factories. They read `config` lazily, i.e. only
        # when the worker eventually invokes the chosen one.
        def sampler_input(ioctx):
            return ioctx.default_sampler_input()

        def mixed_input(ioctx):
            mixed = MixedInput(config["input"], ioctx)
            return ShuffledInput(mixed, config["shuffle_buffer_size"])

        def json_input(ioctx):
            reader = JsonReader(config["input"], ioctx)
            return ShuffledInput(reader, config["shuffle_buffer_size"])

        source = config["input"]
        if isinstance(source, FunctionType):
            # A user-supplied function is used as the factory directly.
            input_creator = source
        elif source == "sampler":
            input_creator = sampler_input
        elif isinstance(source, dict):
            input_creator = mixed_input
        else:
            input_creator = json_input

        # Candidate output factories, selected the same way as the inputs.
        def noop_output(ioctx):
            return NoopOutput()

        def logdir_writer(ioctx):
            return JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        def path_writer(ioctx):
            return JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        sink = config["output"]
        if isinstance(sink, FunctionType):
            output_creator = sink
        elif sink is None:
            output_creator = noop_output
        elif sink == "logdir":
            output_creator = logdir_writer
        else:
            output_creator = path_writer

        # With the plain sampler as input, no input evaluation is passed on.
        input_evaluation = (
            [] if source == "sampler" else config["input_evaluation"])

        multiagent = config["multiagent"]
        policy_map = multiagent["policies"]
        if not policy_map:
            # No multi-agent policies: the spec is the policy class itself.
            policy_spec = policy_cls
        else:
            # Fill in the default policy_cls if 'None' is specified in
            # multiagent. NOTE: entries are replaced in the config's own
            # dict (in place).
            _validate_multiagent_config(policy_map, allow_none_graph=True)
            # TODO: (sven) Allow for setting observation and action spaces to
            #  None as well, in which case, spaces are taken from env.
            #  It's tedious to have to provide these in a multi-agent config.
            for policy_id, spec in policy_map.items():
                if spec[0] is None:
                    policy_map[policy_id] = (
                        policy_cls, spec[1], spec[2], spec[3])
            policy_spec = policy_map

        # Index 0 reads the driver's extra environs, all others the worker's.
        environs_key = ("extra_python_environs_for_driver"
                        if worker_index == 0 else
                        "extra_python_environs_for_worker")
        extra_python_environs = config.get(environs_key, None)

        worker_kwargs = {
            "env_creator": env_creator,
            "validate_env": validate_env,
            "policy_spec": policy_spec,
            "policy_mapping_fn": multiagent["policy_mapping_fn"],
            "policies_to_train": multiagent["policies_to_train"],
            "tf_session_creator": (
                session_creator if config["tf_session_args"] else None),
            "rollout_fragment_length": config["rollout_fragment_length"],
            "batch_mode": config["batch_mode"],
            "episode_horizon": config["horizon"],
            "preprocessor_pref": config["preprocessor_pref"],
            "sample_async": config["sample_async"],
            "compress_observations": config["compress_observations"],
            "num_envs": config["num_envs_per_worker"],
            "observation_fn": multiagent["observation_fn"],
            "observation_filter": config["observation_filter"],
            "clip_rewards": config["clip_rewards"],
            "clip_actions": config["clip_actions"],
            "env_config": config["env_config"],
            "model_config": config["model"],
            "policy_config": config,
            "worker_index": worker_index,
            "num_workers": num_workers,
            "monitor_path": self._logdir if config["monitor"] else None,
            "log_dir": self._logdir,
            "log_level": config["log_level"],
            "callbacks": config["callbacks"],
            "input_creator": input_creator,
            "input_evaluation": input_evaluation,
            "output_creator": output_creator,
            "remote_worker_envs": config["remote_worker_envs"],
            "remote_env_batch_wait_ms": config["remote_env_batch_wait_ms"],
            "soft_horizon": config["soft_horizon"],
            "no_done_at_end": config["no_done_at_end"],
            # Offset the base seed by the worker index; no seed stays None.
            "seed": (None if config["seed"] is None else
                     config["seed"] + worker_index),
            "fake_sampler": config["fake_sampler"],
            "extra_python_environs": extra_python_environs,
            "spaces": spaces,
        }
        return cls(**worker_kwargs)
    def _make_worker(self, cls, env_creator, policy, worker_index, config):
        """Construct one worker via ``cls`` and check its policy class.

        Args:
            cls: Callable that builds the worker. It receives ``env_creator``
                and the policy spec positionally, followed by the keyword
                options assembled in this method.
            env_creator: Forwarded to ``cls`` untouched.
            policy: Default policy. It is passed on as-is unless
                ``config["multiagent"]["policies"]`` is non-empty; in that
                case the policies dict is passed instead and ``policy`` only
                fills entries whose first element is None.
            worker_index: Forwarded to ``cls`` and added to ``config["seed"]``
                (when a seed is set) to derive the worker's seed.
            config: Config dict the options are read from; it is also handed
                to ``cls`` as ``policy_config``.

        Returns:
            The object returned by ``cls``.

        Raises:
            AssertionError: If the result is exactly a ``RolloutWorker`` and
                the class of ``worker.get_policy()`` does not agree with
                ``config["use_pytorch"]`` (only while assertions are
                enabled).
        """

        def session_creator():
            session_args = config["tf_session_args"]
            logger.debug("Creating TF session {}".format(session_args))
            return tf.Session(config=tf.ConfigProto(**session_args))

        # Candidate input factories. They read `config` lazily, i.e. only
        # when the worker eventually invokes the chosen one.
        def sampler_input(ioctx):
            return ioctx.default_sampler_input()

        def mixed_input(ioctx):
            mixed = MixedInput(config["input"], ioctx)
            return ShuffledInput(mixed, config["shuffle_buffer_size"])

        def json_input(ioctx):
            reader = JsonReader(config["input"], ioctx)
            return ShuffledInput(reader, config["shuffle_buffer_size"])

        source = config["input"]
        if isinstance(source, FunctionType):
            # A user-supplied function is used as the factory directly.
            input_creator = source
        elif source == "sampler":
            input_creator = sampler_input
        elif isinstance(source, dict):
            input_creator = mixed_input
        else:
            input_creator = json_input

        # Candidate output factories, selected the same way as the inputs.
        def noop_output(ioctx):
            return NoopOutput()

        def logdir_writer(ioctx):
            return JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        def path_writer(ioctx):
            return JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"])

        sink = config["output"]
        if isinstance(sink, FunctionType):
            output_creator = sink
        elif sink is None:
            output_creator = noop_output
        elif sink == "logdir":
            output_creator = logdir_writer
        else:
            output_creator = path_writer

        # With the plain sampler as input, no input evaluation is passed on.
        input_evaluation = (
            [] if source == "sampler" else config["input_evaluation"])

        # Fill in the default policy if 'None' is specified in multiagent.
        # NOTE: entries are replaced in the config's own dict (in place).
        multiagent = config["multiagent"]
        policy_map = multiagent["policies"]
        if policy_map:
            _validate_multiagent_config(policy_map, allow_none_graph=True)
            for policy_id, spec in policy_map.items():
                if spec[0] is None:
                    policy_map[policy_id] = (policy, spec[1], spec[2], spec[3])
            policy_spec = policy_map
        else:
            policy_spec = policy

        worker_kwargs = {
            "policy_mapping_fn": multiagent["policy_mapping_fn"],
            "policies_to_train": multiagent["policies_to_train"],
            "tf_session_creator": (
                session_creator if config["tf_session_args"] else None),
            "rollout_fragment_length": config["rollout_fragment_length"],
            "batch_mode": config["batch_mode"],
            "episode_horizon": config["horizon"],
            "preprocessor_pref": config["preprocessor_pref"],
            "sample_async": config["sample_async"],
            "compress_observations": config["compress_observations"],
            "num_envs": config["num_envs_per_worker"],
            "observation_filter": config["observation_filter"],
            "clip_rewards": config["clip_rewards"],
            "clip_actions": config["clip_actions"],
            "env_config": config["env_config"],
            "model_config": config["model"],
            "policy_config": config,
            "worker_index": worker_index,
            "num_workers": self._num_workers,
            "monitor_path": self._logdir if config["monitor"] else None,
            "log_dir": self._logdir,
            "log_level": config["log_level"],
            "callbacks": config["callbacks"],
            "input_creator": input_creator,
            "input_evaluation": input_evaluation,
            "output_creator": output_creator,
            "remote_worker_envs": config["remote_worker_envs"],
            "remote_env_batch_wait_ms": config["remote_env_batch_wait_ms"],
            "soft_horizon": config["soft_horizon"],
            "no_done_at_end": config["no_done_at_end"],
            # Offset the base seed by the worker index; no seed stays None.
            "seed": (None if config["seed"] is None else
                     config["seed"] + worker_index),
            "_fake_sampler": config.get("_fake_sampler", False),
        }
        worker = cls(env_creator, policy_spec, **worker_kwargs)

        # Check for correct policy class (only locally, remote Workers should
        # create the exact same Policy types). Anything that is not exactly
        # a RolloutWorker is returned unchecked.
        if type(worker) is not RolloutWorker:
            return worker

        policy_type = type(worker.get_policy())
        if config["use_pytorch"]:
            # Pytorch case: Policy must be a TorchPolicy.
            assert issubclass(policy_type, TorchPolicy), (
                "Worker policy must be subclass of `TorchPolicy`, "
                "but is {}!".format(policy_type.__name__))
        else:
            # non-Pytorch case:
            # Policy may be None AND must not be a TorchPolicy.
            assert (
                issubclass(policy_type, type(None))
                or (issubclass(policy_type, Policy)
                    and not issubclass(policy_type, TorchPolicy))
            ), (
                "Worker "
                "policy must be subclass of `Policy`, but NOT "
                "`TorchPolicy` (your class={})! If you have a torch "
                "Trainer, make sure to set `use_pytorch=True` in "
                "your Trainer's config)!".format(policy_type.__name__))

        return worker