def get_action_shape(action_space):
    """Returns action tensor dtype and shape for the action space.

    Args:
        action_space (Space): Action space of the target gym env.

    Returns:
        (dtype, shape): Dtype and shape of the actions tensor. The leading
            `None` dimension is the (unknown) batch dimension.

    Raises:
        NotImplementedError: If the action space type is not supported.
    """
    if isinstance(action_space, gym.spaces.Discrete):
        return (tf.int64, (None, ))
    elif isinstance(action_space, (gym.spaces.Box, Simplex)):
        return (tf.float32, (None, ) + action_space.shape)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        return (tf.as_dtype(action_space.dtype),
                (None, ) + action_space.shape)
    elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
        flat_action_space = flatten_space(action_space)
        size = 0
        all_discrete = True
        for sub_space in flat_action_space:
            if isinstance(sub_space, gym.spaces.Discrete):
                # A Discrete component occupies a single slot (its index).
                size += 1
            else:
                all_discrete = False
                # `np.prod` (the `np.product` alias was removed in NumPy 2.0).
                size += np.prod(sub_space.shape)
        size = int(size)
        # Integer dtype only if every component is Discrete.
        return (tf.int64 if all_discrete else tf.float32, (None, size))
    else:
        raise NotImplementedError(
            "Action space {} not supported".format(action_space))
def get_action_shape(action_space: gym.Space,
                     framework: str = "tf") -> (np.dtype, List[int]):
    """Returns action tensor dtype and shape for the action space.

    Args:
        action_space (Space): Action space of the target gym env.
        framework (str): The framework identifier. One of "tf" or "torch".

    Returns:
        (dtype, shape): Dtype and shape of the actions tensor. The leading
            `None` dimension is the (unknown) batch dimension.

    Raises:
        NotImplementedError: If the action space type is not supported.
    """
    # Any framework other than "torch" uses the tf dtype namespace.
    dl_lib = torch if framework == "torch" else tf
    if isinstance(action_space, Discrete):
        # NOTE(review): returns the space's own `dtype` attribute rather than
        # a `dl_lib` dtype here (and for MultiDiscrete below) — confirm
        # callers expect that.
        return action_space.dtype, (None, )
    elif isinstance(action_space, (Box, Simplex)):
        return dl_lib.float32, (None, ) + action_space.shape
    elif isinstance(action_space, MultiDiscrete):
        return action_space.dtype, (None, ) + action_space.shape
    elif isinstance(action_space, (Tuple, Dict)):
        flat_action_space = flatten_space(action_space)
        size = 0
        all_discrete = True
        for sub_space in flat_action_space:
            if isinstance(sub_space, Discrete):
                # A Discrete component occupies a single slot (its index).
                size += 1
            else:
                all_discrete = False
                # `np.prod` (the `np.product` alias was removed in NumPy 2.0).
                size += np.prod(sub_space.shape)
        size = int(size)
        # Integer dtype only if every component is Discrete.
        return dl_lib.int64 if all_discrete else dl_lib.float32, \
            (None, size)
    else:
        raise NotImplementedError(
            "Action space {} not supported".format(action_space))
def _get_multi_action_distribution(dist_class, action_space, config,
                                   framework):
    """Resolves a (custom) distribution class into a (class, input size) pair.

    If `dist_class` is not a multi-action distribution, it is returned as-is
    together with its `required_model_output_shape`. Otherwise one child
    distribution is looked up per flattened sub-space of `action_space`, and
    a `functools.partial` of `dist_class` is returned with the children and
    their input lengths pre-bound. The returned size is then the sum of the
    children's input lengths.
    """
    multi_action_bases = (MultiActionDistribution,
                          TorchMultiActionDistribution)
    if not issubclass(dist_class, multi_action_bases):
        return dist_class, dist_class.required_model_output_shape(
            action_space, config)

    # The custom distribution is a child of MultiActionDistr. Users who want
    # to ignore the suggested child distributions entirely can simply do so
    # in their custom class' constructor.
    flat_spaces = flatten_space(action_space)
    dists_and_sizes = tree.map_structure(
        lambda sub_space: ModelCatalog.get_action_dist(
            sub_space, config, framework=framework),
        flat_spaces,
    )
    child_dists = [pair[0] for pair in dists_and_sizes]
    input_lens = [int(pair[1]) for pair in dists_and_sizes]
    bound_dist = partial(
        dist_class,
        action_space=action_space,
        child_distributions=child_dists,
        input_lens=input_lens,
    )
    return bound_dist, int(sum(input_lens))
def __init__(self, config):
    """Configures an env whose action space mirrors its observation space.

    Args:
        config (dict): Optional keys:
            "space": the observation (and action) space; defaults to
                Tuple([Discrete(2), Dict({"a": Box(-1.0, 1.0, (2, ))})]).
            "episode_len": episode length; defaults to 100.
    """
    # The default is built up front (as `dict.get` would do anyway).
    fallback_space = Tuple(
        [Discrete(2), Dict({"a": Box(-1.0, 1.0, (2, ))})])
    space = config.get("space", fallback_space)
    self.observation_space = space
    # Actions are drawn from the very same space object.
    self.action_space = space
    self.flattened_action_space = flatten_space(space)
    self.episode_len = config.get("episode_len", 100)
def get_action_dim(action_space: gym.Space):
    """Returns the number of model outputs needed for the action space.

    Args:
        action_space (Space): Action space of the target gym env.

    Returns:
        int: `n` for a Discrete space; twice the flattened size for a Box
            space (sized for a diagonal Gaussian — presumably one mean and
            one std-like parameter per dimension); for Tuple/Dict spaces, the
            sum over all flattened sub-spaces (0 if there are none).

    Raises:
        NotImplementedError: If the action space type is not supported.
    """
    if isinstance(action_space, gym.spaces.Discrete):
        return int(action_space.n)
    elif isinstance(action_space, gym.spaces.Box):
        # Diagonal Gaussian: two parameters per action dimension.
        # `np.prod` (the `np.product` alias was removed in NumPy 2.0).
        return int(np.prod(action_space.shape)) * 2
    elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
        return sum(
            get_action_dim(sub_space)
            for sub_space in flatten_space(action_space))
    else:
        raise NotImplementedError(
            "Action space {} not supported".format(action_space))
def get_action_dist(action_space,
                    config,
                    dist_type=None,
                    framework="tf",
                    **kwargs):
    """Returns a distribution class and size for the given action space.

    Args:
        action_space (Space): Action space of the target gym env.
        config (Optional[dict]): Optional model config.
        dist_type (Optional[str]): Identifier of the action distribution
            interpreted as a hint. May also be an ActionDistribution
            subclass, which is then used directly.
        framework (str): One of "tf", "tfe", or "torch".
        kwargs (dict): Accepted for API compatibility; not used by this
            implementation.

    Returns:
        Tuple:
            - dist_class (ActionDistribution): Python class of the
                distribution.
            - dist_dim (int): The size of the input vector to the
                distribution.

    Raises:
        UnsupportedSpaceException: For Box spaces with more than one
            dimension.
        NotImplementedError: For Simplex spaces under torch, for an
            unsupported `dist_type` hint on a Box space, or for unsupported
            action spaces.
    """
    dist = None
    config = config or MODEL_DEFAULTS
    # Custom distribution given.
    if config.get("custom_action_dist"):
        action_dist_name = config["custom_action_dist"]
        logger.debug(
            "Using custom action distribution {}".format(action_dist_name))
        dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
    # Dist_type is given directly as a class. `isinstance(..., type)` also
    # accepts classes that use a metaclass (e.g. ABCMeta), which the old
    # `type(dist_type) is type` check wrongly rejected.
    elif isinstance(dist_type, type) and \
            issubclass(dist_type, ActionDistribution) and \
            dist_type not in (
            MultiActionDistribution, TorchMultiActionDistribution):
        dist = dist_type
    # Box space -> DiagGaussian OR Deterministic.
    elif isinstance(action_space, gym.spaces.Box):
        if len(action_space.shape) > 1:
            raise UnsupportedSpaceException(
                "Action space has multiple dimensions "
                "{}. ".format(action_space.shape) +
                "Consider reshaping this into a single dimension, "
                "using a custom action distribution, "
                "using a Tuple action space, or the multi-agent API.")
        # TODO(sven): Check for bounds and return SquashedNormal, etc..
        if dist_type is None:
            dist = TorchDiagGaussian if framework == "torch" \
                else DiagGaussian
        elif dist_type == "deterministic":
            dist = TorchDeterministic if framework == "torch" \
                else Deterministic
        else:
            # Previously `dist` stayed None here and the function crashed
            # below with an opaque AttributeError.
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))
    # Discrete Space -> Categorical.
    elif isinstance(action_space, gym.spaces.Discrete):
        dist = TorchCategorical if framework == "torch" else Categorical
    # Tuple/Dict Spaces -> MultiAction.
    elif dist_type in (MultiActionDistribution,
                       TorchMultiActionDistribution) or \
            isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
        flat_action_space = flatten_space(action_space)
        # One child distribution (and its input size) per flattened
        # sub-space, resolved recursively.
        child_dists_and_in_lens = tree.map_structure(
            lambda s: ModelCatalog.get_action_dist(
                s, config, framework=framework), flat_action_space)
        child_dists = [e[0] for e in child_dists_and_in_lens]
        input_lens = [int(e[1]) for e in child_dists_and_in_lens]
        return partial(
            (TorchMultiActionDistribution
             if framework == "torch" else MultiActionDistribution),
            action_space=action_space,
            child_distributions=child_dists,
            input_lens=input_lens), int(sum(input_lens))
    # Simplex -> Dirichlet.
    elif isinstance(action_space, Simplex):
        if framework == "torch":
            # TODO(sven): implement
            raise NotImplementedError(
                "Simplex action spaces not supported for torch.")
        dist = Dirichlet
    # MultiDiscrete -> MultiCategorical.
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        dist = TorchMultiCategorical if framework == "torch" else \
            MultiCategorical
        # Input size is the total number of categories across sub-actions.
        return partial(dist, input_lens=action_space.nvec), \
            int(sum(action_space.nvec))
    # Unknown type -> Error.
    else:
        raise NotImplementedError("Unsupported args: {} {}".format(
            action_space, dist_type))

    return dist, dist.required_model_output_shape(action_space, config)
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Builds a torch model over a Dict/Tuple observation space.

    Each flattened component of the original observation space gets its own
    input path:
    - 3D components: a CNN built via `ModelCatalog.get_model_v2`.
    - Discrete/MultiDiscrete components: marked for one-hot encoding.
    - everything else: flattened.
    The resulting features (total size `concat_size`) feed an optional
    post-concat FC stack. If `num_outputs` is truthy, logits and value heads
    are added on top. Otherwise `self.num_outputs` is set to the FC stack's
    output size.

    Args:
        obs_space: Observation space; its `original_space` attribute is used
            if present.
        action_space: Action space of the env.
        num_outputs: Size of the logits head, or falsy for no heads.
        model_config (dict): Model config; reads "conv_filters",
            "conv_activation", "post_fcnet_hiddens", "post_fcnet_activation"
            and "_disable_preprocessor_api".
        name (str): Name of the model.
    """
    self.original_space = obs_space.original_space if \
        hasattr(obs_space, "original_space") else obs_space
    assert isinstance(self.original_space, (Dict, Tuple)), \
        "`obs_space.original_space` must be [Dict|Tuple]!"

    self.processed_obs_space = self.original_space if \
        model_config.get("_disable_preprocessor_api") else obs_space

    nn.Module.__init__(self)
    TorchModelV2.__init__(self, self.original_space, action_space,
                          num_outputs, model_config, name)

    self.flattened_input_space = flatten_space(self.original_space)

    # Build the CNN(s) given obs_space's image components.
    # TODO (sven): add IMPALA-style CNN option.
    self.cnns = {}
    self.one_hot = {}
    self.flatten = {}
    concat_size = 0
    for i, component in enumerate(self.flattened_input_space):
        # Image space.
        if len(component.shape) == 3:
            config = {
                # Default filters are derived from THIS image component's
                # shape, not from the (possibly flattened) whole obs space.
                "conv_filters": model_config["conv_filters"]
                if "conv_filters" in model_config else
                get_filter_config(component.shape),
                "conv_activation": model_config.get("conv_activation"),
                "post_fcnet_hiddens": [],
            }
            cnn = ModelCatalog.get_model_v2(
                component,
                action_space,
                num_outputs=None,
                model_config=config,
                framework="torch",
                name="cnn_{}".format(i))
            concat_size += cnn.num_outputs
            self.cnns[i] = cnn
            # Register explicitly: modules stored in a plain dict are not
            # tracked by nn.Module automatically.
            self.add_module("cnn_{}".format(i), cnn)
        # Discrete|MultiDiscrete inputs -> One-hot encode.
        elif isinstance(component, Discrete):
            self.one_hot[i] = True
            concat_size += component.n
        elif isinstance(component, MultiDiscrete):
            self.one_hot[i] = True
            concat_size += int(sum(component.nvec))
        # Everything else (1D Box).
        else:
            self.flatten[i] = int(np.prod(component.shape))
            concat_size += self.flatten[i]

    # Optional post-concat FC-stack.
    post_fc_stack_config = {
        "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
        "fcnet_activation": model_config.get("post_fcnet_activation",
                                             "relu")
    }
    self.post_fc_stack = ModelCatalog.get_model_v2(
        Box(float("-inf"),
            float("inf"),
            shape=(concat_size, ),
            dtype=np.float32),
        self.action_space,
        None,
        post_fc_stack_config,
        framework="torch",
        name="post_fc_stack")

    # Actions and value heads.
    self.logits_layer = None
    self.value_layer = None
    self._value_out = None

    if num_outputs:
        # Action-distribution head.
        self.logits_layer = SlimFC(
            in_size=self.post_fc_stack.num_outputs,
            out_size=num_outputs,
            activation_fn=None,
        )
        # Create the value branch model.
        self.value_layer = SlimFC(
            in_size=self.post_fc_stack.num_outputs,
            out_size=1,
            activation_fn=None,
            initializer=torch_normc_initializer(0.01))
    else:
        # Without heads the model emits the post-FC-stack features, whose
        # size differs from `concat_size` when post_fcnet_hiddens is set.
        self.num_outputs = self.post_fc_stack.num_outputs
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Builds a tf model over a Dict/Tuple observation space.

    Each flattened component of the original observation space gets its own
    input path:
    - 3D components: a CNN built via `ModelCatalog.get_model_v2`.
    - Discrete/MultiDiscrete components: marked for one-hot encoding.
    - everything else: flattened.
    The resulting features (total size `concat_size`) feed an optional
    post-concat FC stack. If `num_outputs` is truthy, a keras model with
    logits and value heads is built on top. Otherwise `self.num_outputs` is
    set to the FC stack's output size.

    Args:
        obs_space: Observation space; its `original_space` attribute is used
            if present.
        action_space: Action space of the env.
        num_outputs: Size of the logits head, or falsy for no heads.
        model_config (dict): Model config; reads "conv_filters",
            "conv_activation", "post_fcnet_hiddens", "post_fcnet_activation"
            and "_disable_preprocessor_api".
        name (str): Name of the model.
    """
    self.original_space = obs_space.original_space if \
        hasattr(obs_space, "original_space") else obs_space
    assert isinstance(self.original_space, (Dict, Tuple)), \
        "`obs_space.original_space` must be [Dict|Tuple]!"

    self.processed_obs_space = self.original_space if \
        model_config.get("_disable_preprocessor_api") else obs_space

    super().__init__(self.original_space, action_space, num_outputs,
                     model_config, name)

    self.flattened_input_space = flatten_space(self.original_space)

    # Build the CNN(s) given obs_space's image components.
    self.cnns = {}
    self.one_hot = {}
    self.flatten = {}
    concat_size = 0
    for i, component in enumerate(self.flattened_input_space):
        # Image space.
        if len(component.shape) == 3:
            config = {
                # Default filters are derived from THIS image component's
                # shape, not from the (possibly flattened) whole obs space.
                "conv_filters": model_config["conv_filters"]
                if "conv_filters" in model_config else
                get_filter_config(component.shape),
                "conv_activation": model_config.get("conv_activation"),
                "post_fcnet_hiddens": [],
            }
            cnn = ModelCatalog.get_model_v2(
                component,
                action_space,
                num_outputs=None,
                model_config=config,
                framework="tf",
                name="cnn_{}".format(i))
            concat_size += cnn.num_outputs
            self.cnns[i] = cnn
        # Discrete|MultiDiscrete inputs -> One-hot encode.
        elif isinstance(component, Discrete):
            self.one_hot[i] = True
            concat_size += component.n
        elif isinstance(component, MultiDiscrete):
            self.one_hot[i] = True
            concat_size += int(sum(component.nvec))
        # Everything else (1D Box).
        else:
            self.flatten[i] = int(np.prod(component.shape))
            concat_size += self.flatten[i]

    # Optional post-concat FC-stack.
    post_fc_stack_config = {
        "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
        "fcnet_activation": model_config.get("post_fcnet_activation",
                                             "relu")
    }
    self.post_fc_stack = ModelCatalog.get_model_v2(
        Box(float("-inf"),
            float("inf"),
            shape=(concat_size, ),
            dtype=np.float32),
        self.action_space,
        None,
        post_fc_stack_config,
        framework="tf",
        name="post_fc_stack")

    # Actions and value heads.
    self.logits_and_value_model = None
    self._value_out = None
    if num_outputs:
        # Action-distribution head.
        concat_layer = tf.keras.layers.Input(
            (self.post_fc_stack.num_outputs, ))
        logits_layer = tf.keras.layers.Dense(
            num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(concat_layer)

        # Create the value branch model.
        value_layer = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(concat_layer)
        # One keras model maps the post-FC features to (logits, value).
        self.logits_and_value_model = tf.keras.models.Model(
            concat_layer, [logits_layer, value_layer])
    else:
        self.num_outputs = self.post_fc_stack.num_outputs