Example #1
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config

        if self.config["normalize_actions"]:
            inner = self.env_creator
            self.env_creator = (
                lambda env_config: NormalizeActionWrapper(inner(env_config)))

        Trainer._validate_config(self.config)
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        def get_scope():
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                return open("/dev/null")  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
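For context, a wrapper like NormalizeActionWrapper typically subclasses gym.ActionWrapper and rescales policy actions from [-1, 1] into the env's Box bounds. A minimal sketch under that assumption (an illustration, not RLlib's actual implementation):

import gym
import numpy as np


class NormalizeActionWrapper(gym.ActionWrapper):
    """Rescale actions from [-1, 1] into the wrapped env's Box bounds."""

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.Box)
        self._low = env.action_space.low
        self._high = env.action_space.high
        # The agent sees a normalized [-1, 1] action space.
        self.action_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=env.action_space.shape,
            dtype=np.float32)

    def action(self, action):
        # Map [-1, 1] linearly onto [low, high] before stepping the env.
        return self._low + (np.asarray(action) + 1.0) * 0.5 * (
            self._high - self._low)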
Example #2
def normalize(env):
    import gym  # soft dependency
    if not isinstance(env, gym.Env):
        raise ValueError(
            "Cannot apply NormalizeActionWrapper to env of type {}, "
            "which does not subclass gym.Env.".format(type(env)))
    return NormalizeActionWrapper(env)
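A quick usage sketch (Pendulum-v0 is only an illustrative continuous-action env id):

import gym

env = normalize(gym.make("Pendulum-v0"))  # returns a NormalizeActionWrapper
# normalize(object()) would raise ValueError, since it is not a gym.Env.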
Example #3
    def __init__(self, env_config):
        assert "num_agents" in env_config
        assert "env_name" in env_config
        num_agents = env_config["num_agents"]
        agent_ids = ["agent{}".format(i) for i in range(num_agents)]
        self._render_policy = env_config.get("render_policy")
        self.num_agents = num_agents
        self.agent_ids = agent_ids
        self.env_name = env_config["env_name"]
        self.env_maker = get_env_maker(
            self.env_name, require_render=bool(self._render_policy))

        if env_config.get("normalize_action", False):
            # Capture the current maker first; referencing self.env_maker
            # inside the lambda would make the lambda call itself recursively.
            inner = self.env_maker
            self.env_maker = lambda: NormalizeActionWrapper(inner())

        self.envs = {}
        if not isinstance(agent_ids, list):
            agent_ids = [agent_ids]
        for aid in agent_ids:
            if aid not in self.envs:
                self.envs[aid] = self.env_maker()
        self.dones = set()
        tmp_env = next(iter(self.envs.values()))
        self.observation_space = tmp_env.observation_space
        self.action_space = tmp_env.action_space
        self.reward_range = tmp_env.reward_range
        self.metadata = tmp_env.metadata
        self.spec = tmp_env.spec
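For illustration, assuming this __init__ belongs to a multi-agent env class (called MultiEnv here; the real class name is not shown in the snippet), construction might look like:

env = MultiEnv({
    "num_agents": 2,
    "env_name": "Pendulum-v0",   # any id that get_env_maker understands
    "normalize_action": True,    # wrap each per-agent env
})
print(env.agent_ids)             # ['agent0', 'agent1']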