def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
                                allow_none_graph: bool = False) -> None:
    """Check that a multi-agent policy spec dict is well-formed.

    Every key must be a ``str`` and every value a 4-element tuple/list of
    ``(policy_cls or None, obs_space, action_space, config)``.

    Args:
        policy: Mapping from policy ID to its 4-element spec.
        allow_none_graph: If True, a spec whose first element is ``None``
            passes the class check; otherwise the first element must be a
            subclass of ``Policy``.

    Raises:
        ValueError: On the first key or spec element that fails validation.
    """
    for k, v in policy.items():
        if not isinstance(k, str):
            raise ValueError("policy keys must be strs, got {}".format(
                type(k)))
        if not isinstance(v, (tuple, list)) or len(v) != 4:
            raise ValueError(
                "policy values must be tuples/lists of "
                "(cls or None, obs_space, action_space, config), got {}".
                format(v))

        none_cls_accepted = allow_none_graph and v[0] is None
        # issubclass() raises TypeError when its first argument is not a
        # class (e.g. None or a str). Check isinstance(..., type) first so
        # such values produce the descriptive ValueError below instead.
        if not none_cls_accepted and not (isinstance(v[0], type)
                                          and issubclass(v[0], Policy)):
            raise ValueError("policy tuple value 0 must be a rllib.Policy "
                             "class or None, got {}".format(v[0]))
        if not isinstance(v[1], gym.Space):
            raise ValueError(
                "policy tuple value 1 (observation_space) must be a "
                "gym.Space, got {}".format(type(v[1])))
        if not isinstance(v[2], gym.Space):
            raise ValueError("policy tuple value 2 (action_space) must be a "
                             "gym.Space, got {}".format(type(v[2])))
        if not isinstance(v[3], dict):
            raise ValueError("policy tuple value 3 (config) must be a dict, "
                             "got {}".format(type(v[3])))
def _build_policy_map(
        self, policy_dict: MultiAgentPolicyConfigDict,
        policy_config: TrainerConfigDict
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
    """Instantiate a policy and a preprocessor for each policy spec.

    Specs are processed in sorted order of ``policy_dict.items()``.

    Args:
        policy_dict: Mapping from policy name to a
            ``(cls, obs_space, act_space, conf)`` spec.
        policy_config: Base config; combined with each spec's ``conf``
            through ``merge_dicts``. Its ``eager_tracing`` entry is
            consulted when TF eager execution is active.

    Returns:
        A ``(policy_map, preprocessors)`` pair, both keyed by policy name.

    Raises:
        ValueError: If the observation space handed to a policy is a gym
            ``Dict`` or ``Tuple`` space, or if TF runs eagerly and a
            ``TFPolicy`` subclass has no ``as_eager`` attribute.
    """
    built_policies = {}
    built_preprocessors = {}

    for name, spec in sorted(policy_dict.items()):
        cls, obs_space, act_space, conf = spec
        logger.debug("Creating policy for {}".format(name))

        # Per-policy config, stamped with this worker's identity.
        merged_conf = merge_dicts(policy_config, conf)
        merged_conf["num_workers"] = self.num_workers
        merged_conf["worker_index"] = self.worker_index

        if self.preprocessing_enabled:
            prep = ModelCatalog.get_preprocessor_for_space(
                obs_space, merged_conf.get("model"))
            # The policy sees the preprocessor's output space.
            obs_space = prep.observation_space
        else:
            prep = NoPreprocessor(obs_space)
        built_preprocessors[name] = prep

        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")

        if tf1 and tf1.executing_eagerly():
            if hasattr(cls, "as_eager"):
                cls = cls.as_eager()
                if policy_config.get("eager_tracing"):
                    cls = cls.with_tracing()
            elif issubclass(cls, TFPolicy):
                raise ValueError("This policy does not support eager "
                                 "execution: {}".format(cls))
            # Otherwise: could be some other type of policy; leave as is.

        if tf1:
            # Construct the policy inside a variable scope named after it.
            with tf1.variable_scope(name):
                built_policies[name] = cls(obs_space, act_space, merged_conf)
        else:
            built_policies[name] = cls(obs_space, act_space, merged_conf)

    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(built_policies))
        logger.info("Built preprocessor map: {}".format(built_preprocessors))

    return built_policies, built_preprocessors