예제 #1
0
 def _build_policy_map(self, policy_dict, policy_config):
     """Create one policy graph and one preprocessor per policy entry.

     Entries of ``policy_dict`` are visited in sorted order. Each value is
     unpacked as ``(cls, obs_space, act_space, conf)``; ``conf`` is merged
     into ``policy_config`` via ``merge_dicts`` before ``cls`` is called.

     Args:
         policy_dict: Mapping of policy name to a
             ``(cls, obs_space, act_space, conf)`` tuple.
         policy_config: Base config that every per-policy ``conf`` is
             merged into.

     Returns:
         A ``(policy_map, preprocessors)`` pair of dicts keyed by policy
         name.

     Raises:
         ValueError: If the observation space handed to the policy is
             still a ``gym.spaces.Dict`` or ``gym.spaces.Tuple``.
     """
     built_policies = {}
     built_preprocessors = {}
     for name, spec in sorted(policy_dict.items()):
         cls, obs_space, act_space, conf = spec
         logger.debug("Creating policy graph for {}".format(name))
         merged_conf = merge_dicts(policy_config, conf)

         # Choose the preprocessor; with preprocessing on, the policy sees
         # the preprocessor's observation space instead of the raw one.
         if self.preprocessing_enabled:
             prep = ModelCatalog.get_preprocessor_for_space(
                 obs_space, merged_conf.get("model"))
             obs_space = prep.observation_space
         else:
             prep = NoPreprocessor(obs_space)
         built_preprocessors[name] = prep

         if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
             raise ValueError(
                 "Found raw Tuple|Dict space as input to policy graph. "
                 "Please preprocess these observations with a "
                 "Tuple|DictFlatteningPreprocessor.")

         # When tf is available, build the policy inside a variable scope
         # named after the policy.
         if not tf:
             built_policies[name] = cls(obs_space, act_space, merged_conf)
         else:
             with tf.variable_scope(name):
                 built_policies[name] = cls(
                     obs_space, act_space, merged_conf)

     # Only the worker with index 0 logs the summary.
     if self.worker_index == 0:
         logger.info("Built policy map: {}".format(built_policies))
         logger.info(
             "Built preprocessor map: {}".format(built_preprocessors))
     return built_policies, built_preprocessors
예제 #2
0
파일: catalog.py 프로젝트: cybermaster/ray
    def get_preprocessor(cls, env_name, obs_shape, options=None):
        """Returns a suitable processor for the given environment.

        A preprocessor registered under `env_name` takes precedence;
        otherwise the choice is made from `obs_shape` alone.

        Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.
            options (dict): Optional options passed to the preprocessor
                constructor. Every key must be listed in MODEL_CONFIGS.
                Defaults to a new empty dict.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.

        Raises:
            Exception: If `options` contains a key not in MODEL_CONFIGS.
        """
        # Use a fresh dict per call rather than a shared mutable default:
        # the options object is handed to the preprocessor constructor, so
        # a single module-lifetime dict could leak changes between calls.
        if options is None:
            options = {}

        ATARI_OBS_SHAPE = (210, 160, 3)
        ATARI_RAM_OBS_SHAPE = (128, )

        # Reject unknown option keys before constructing anything.
        for k in options:
            if k not in MODEL_CONFIGS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, MODEL_CONFIGS))

        if env_name in cls._registered_preprocessor:
            return cls._registered_preprocessor[env_name](options)

        if obs_shape == ATARI_OBS_SHAPE:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            return AtariPixelPreprocessor(options)
        elif obs_shape == ATARI_RAM_OBS_SHAPE:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            return AtariRamPreprocessor(options)

        print("Non-atari env, not using any observation preprocessor.")
        return NoPreprocessor(options)
예제 #3
0
    def _init_shape(self, obs_space, options):
        """Build one NoPreprocessor per sub-space and return the total size.

        Works from ``self._obs_space`` and ``self._options``; the
        ``obs_space`` and ``options`` arguments are not read here.

        Side effects:
            Sets ``self.preprocessors`` to the list of sub-preprocessors,
            in the iteration order of ``self._obs_space.spaces``. If a
            sub-space is keyed ``'internal_state'``, also sets
            ``self.dummy_internal_state`` to a float32 zero array with
            that sub-space's shape.

        Returns:
            A 1-tuple holding the sum of all sub-preprocessor ``size``
            values.
        """
        assert isinstance(self._obs_space, gym.spaces.Dict)
        self.preprocessors = []
        total_size = 0
        for key, sub_space in self._obs_space.spaces.items():
            logger.debug(
                "Creating sub-preprocessor for {}".format(sub_space))

            if key == 'internal_state':
                # Zero-filled placeholder shaped like the internal state.
                self.dummy_internal_state = np.zeros(
                    shape=sub_space.shape, dtype=np.float32)

            sub_preprocessor = NoPreprocessor(sub_space, self._options)
            self.preprocessors.append(sub_preprocessor)
            total_size += sub_preprocessor.size
        return (total_size, )
예제 #4
0
 def _build_policy_map(
         self, policy_dict: MultiAgentPolicyConfigDict,
         policy_config: TrainerConfigDict
 ) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
     """Create one policy and one preprocessor per policy entry.

     Entries of ``policy_dict`` are visited in sorted order. Each value is
     unpacked as ``(cls, obs_space, act_space, conf)``; ``conf`` is merged
     into ``policy_config`` via ``merge_dicts``, and this worker's
     ``num_workers`` and ``worker_index`` are written into the result.

     Args:
         policy_dict: Mapping of policy id to a
             ``(cls, obs_space, act_space, conf)`` tuple.
         policy_config: Base config that every per-policy ``conf`` is
             merged into. Its ``"eager_tracing"`` entry decides whether
             an eager policy class is wrapped with tracing.

     Returns:
         A ``(policy_map, preprocessors)`` pair of dicts keyed by policy
         id.

     Raises:
         ValueError: If the observation space handed to the policy is
             still a ``gym.spaces.Dict`` or ``gym.spaces.Tuple``, or if
             eager execution is active and ``cls`` is a ``TFPolicy``
             subclass without an ``as_eager`` attribute.
     """
     built_policies = {}
     built_preprocessors = {}
     for name, spec in sorted(policy_dict.items()):
         cls, obs_space, act_space, conf = spec
         logger.debug("Creating policy for {}".format(name))
         merged_conf = merge_dicts(policy_config, conf)
         merged_conf["num_workers"] = self.num_workers
         merged_conf["worker_index"] = self.worker_index

         # Choose the preprocessor; with preprocessing on, the policy sees
         # the preprocessor's observation space instead of the raw one.
         if self.preprocessing_enabled:
             prep = ModelCatalog.get_preprocessor_for_space(
                 obs_space, merged_conf.get("model"))
             obs_space = prep.observation_space
         else:
             prep = NoPreprocessor(obs_space)
         built_preprocessors[name] = prep

         if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
             raise ValueError(
                 "Found raw Tuple|Dict space as input to policy. "
                 "Please preprocess these observations with a "
                 "Tuple|DictFlatteningPreprocessor.")

         if tf1 and tf1.executing_eagerly():
             if hasattr(cls, "as_eager"):
                 cls = cls.as_eager()
                 if policy_config.get("eager_tracing"):
                     cls = cls.with_tracing()
             elif issubclass(cls, TFPolicy):
                 raise ValueError("This policy does not support eager "
                                  "execution: {}".format(cls))
             # Any other class is left untouched: it could be some other
             # type of policy.

         # When tf1 is available, build the policy inside a variable scope
         # named after the policy.
         if not tf1:
             built_policies[name] = cls(obs_space, act_space, merged_conf)
         else:
             with tf1.variable_scope(name):
                 built_policies[name] = cls(
                     obs_space, act_space, merged_conf)

     # Only the worker with index 0 logs the summary.
     if self.worker_index == 0:
         logger.info("Built policy map: {}".format(built_policies))
         logger.info(
             "Built preprocessor map: {}".format(built_preprocessors))
     return built_policies, built_preprocessors
예제 #5
0
파일: catalog.py 프로젝트: the-sea/ray
    def get_preprocessor(env_name, obs_shape):
        """Pick an observation preprocessor from the observation shape.

        The choice depends only on `obs_shape`; `env_name` is accepted but
        not consulted. The selected branch is announced with `print`.

        Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.

        Returns:
            preprocessor (Preprocessor): An AtariPixelPreprocessor for
                shape (210, 160, 3), an AtariRamPreprocessor for shape
                (128,), and a NoPreprocessor otherwise.
        """
        pixel_shape = (210, 160, 3)
        ram_shape = (128,)

        if obs_shape == pixel_shape:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            return AtariPixelPreprocessor()

        if obs_shape == ram_shape:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            return AtariRamPreprocessor()

        # Neither known Atari shape matched.
        print("Non-atari env, not using any observation preprocessor.")
        return NoPreprocessor()