def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate one policy graph per spec in ``policy_dict``.

    Args:
        policy_dict (dict): Maps policy name to a
            (policy_cls, obs_space, act_space, config) tuple.
        policy_config (dict): Base config merged into each per-policy
            config.

    Returns:
        Tuple of (policy name -> policy graph, policy name -> preprocessor)
        dicts.
    """
    policies = {}
    prep_map = {}
    # Sort for a deterministic build order across workers.
    for pid, (policy_cls, obs_space, act_space,
              pol_conf) in sorted(policy_dict.items()):
        logger.debug("Creating policy graph for {}".format(pid))
        conf = merge_dicts(policy_config, pol_conf)
        if self.preprocessing_enabled:
            # The policy sees the preprocessor's (flattened) output space.
            prep = ModelCatalog.get_preprocessor_for_space(
                obs_space, conf.get("model"))
            prep_map[pid] = prep
            obs_space = prep.observation_space
        else:
            prep_map[pid] = NoPreprocessor(obs_space)
        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy graph. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")
        if tf:
            # Scope TF variables by policy name to avoid collisions.
            with tf.variable_scope(pid):
                policies[pid] = policy_cls(obs_space, act_space, conf)
        else:
            policies[pid] = policy_cls(obs_space, act_space, conf)
    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(policies))
        logger.info("Built preprocessor map: {}".format(prep_map))
    return policies, prep_map
def get_preprocessor(cls, env_name, obs_shape, options=None):
    """Returns a suitable processor for the given environment.

    Args:
        env_name (str): The name of the environment.
        obs_shape (tuple): The shape of the env observation space.
        options (dict): Optional config options; keys must be listed in
            MODEL_CONFIGS. Defaults to an empty dict.

    Returns:
        preprocessor (Preprocessor): Preprocessor for the env observations.
    """
    # Fix: the original default `options=dict()` is evaluated once and
    # shared across all calls (mutable-default pitfall). Use a None
    # sentinel instead; passing nothing behaves exactly as before.
    if options is None:
        options = {}
    ATARI_OBS_SHAPE = (210, 160, 3)
    ATARI_RAM_OBS_SHAPE = (128, )
    # Reject unknown config keys early with the full list of valid keys.
    for k in options:
        if k not in MODEL_CONFIGS:
            raise Exception("Unknown config key `{}`, all keys: {}".format(
                k, MODEL_CONFIGS))
    # A preprocessor explicitly registered for this env wins.
    if env_name in cls._registered_preprocessor:
        return cls._registered_preprocessor[env_name](options)
    if obs_shape == ATARI_OBS_SHAPE:
        print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
        return AtariPixelPreprocessor(options)
    elif obs_shape == ATARI_RAM_OBS_SHAPE:
        print("Assuming Atari ram env, using AtariRamPreprocessor.")
        return AtariRamPreprocessor(options)
    print("Non-atari env, not using any observation preprocessor.")
    return NoPreprocessor(options)
def _init_shape(self, obs_space, options):
    """Build one NoPreprocessor per sub-space of the Dict obs space.

    Side effects: populates ``self.preprocessors`` (one per sub-space,
    in dict iteration order) and, when an ``internal_state`` sub-space
    exists, caches a zero-filled ``self.dummy_internal_state``.

    Returns:
        1-tuple with the total flattened observation size.
    """
    assert isinstance(self._obs_space, gym.spaces.Dict)
    self.preprocessors = []
    total = 0
    for key, subspace in self._obs_space.spaces.items():
        logger.debug("Creating sub-preprocessor for {}".format(subspace))
        if key == 'internal_state':
            # Zero placeholder matching the internal-state sub-space.
            self.dummy_internal_state = np.zeros(
                shape=subspace.shape, dtype=np.float32)
        sub_prep = NoPreprocessor(subspace, self._options)
        self.preprocessors.append(sub_prep)
        total += sub_prep.size
    return (total, )
def _build_policy_map(
        self, policy_dict: MultiAgentPolicyConfigDict,
        policy_config: TrainerConfigDict
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
    """Create a Policy (and its Preprocessor) for every policy spec.

    Args:
        policy_dict: Maps policy id to a
            (policy_cls, obs_space, act_space, config) tuple.
        policy_config: Trainer-level config merged into each per-policy
            config.

    Returns:
        Tuple of (policy id -> Policy, policy id -> Preprocessor) dicts.
    """
    policies = {}
    prep_map = {}
    # Sort for a deterministic build order across workers.
    for pid, (policy_cls, obs_space, act_space,
              pol_conf) in sorted(policy_dict.items()):
        logger.debug("Creating policy for {}".format(pid))
        conf = merge_dicts(policy_config, pol_conf)
        # Stamp worker identity into each policy's config.
        conf["num_workers"] = self.num_workers
        conf["worker_index"] = self.worker_index
        if self.preprocessing_enabled:
            # The policy sees the preprocessor's (flattened) output space.
            prep = ModelCatalog.get_preprocessor_for_space(
                obs_space, conf.get("model"))
            prep_map[pid] = prep
            obs_space = prep.observation_space
        else:
            prep_map[pid] = NoPreprocessor(obs_space)
        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")
        if tf1 and tf1.executing_eagerly():
            if hasattr(policy_cls, "as_eager"):
                policy_cls = policy_cls.as_eager()
                if policy_config.get("eager_tracing"):
                    policy_cls = policy_cls.with_tracing()
            elif issubclass(policy_cls, TFPolicy):
                # A TF policy with no eager variant cannot run eagerly.
                raise ValueError("This policy does not support eager "
                                 "execution: {}".format(policy_cls))
            # Otherwise: could be some other (non-TF) type of policy.
        if tf1:
            # Scope TF variables by policy id to avoid collisions.
            with tf1.variable_scope(pid):
                policies[pid] = policy_cls(obs_space, act_space, conf)
        else:
            policies[pid] = policy_cls(obs_space, act_space, conf)
    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(policies))
        logger.info("Built preprocessor map: {}".format(prep_map))
    return policies, prep_map
def get_preprocessor(env_name, obs_shape):
    """Returns a suitable processor for the given environment.

    Args:
        env_name (str): The name of the environment.
        obs_shape (tuple): The shape of the env observation space.

    Returns:
        preprocessor (Preprocessor): Preprocessor for the env observations.
    """
    # Known Atari observation shapes; anything else passes through as-is.
    atari_pixel_shape = (210, 160, 3)
    atari_ram_shape = (128,)
    if obs_shape == atari_pixel_shape:
        print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
        return AtariPixelPreprocessor()
    if obs_shape == atari_ram_shape:
        print("Assuming Atari ram env, using AtariRamPreprocessor.")
        return AtariRamPreprocessor()
    print("Non-atari env, not using any observation preprocessor.")
    return NoPreprocessor()