コード例 #1
0
    def get_action_shape(action_space):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor. The shape
                always starts with a `None` entry. For Tuple/Dict spaces the
                flattened components are packed into one `(None, size)`
                tensor, which is int64 only if every component is Discrete
                and float32 otherwise.

        Raises:
            NotImplementedError: If `action_space` is none of the space
                types handled below.
        """

        if isinstance(action_space, gym.spaces.Discrete):
            return (tf.int64, (None, ))
        elif isinstance(action_space, (gym.spaces.Box, Simplex)):
            return (tf.float32, (None, ) + action_space.shape)
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            return (tf.as_dtype(action_space.dtype),
                    (None, ) + action_space.shape)
        elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for component in flat_action_space:
                if isinstance(component, gym.spaces.Discrete):
                    # A Discrete component occupies a single slot.
                    size += 1
                else:
                    all_discrete = False
                    # `np.prod` instead of the `np.product` alias, which was
                    # deprecated and then removed in NumPy 2.0.
                    size += np.prod(component.shape)
            # `np.prod` yields a NumPy scalar; the shape needs a plain int.
            size = int(size)
            return (tf.int64 if all_discrete else tf.float32, (None, size))
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))
コード例 #2
0
 def __init__(self, config):
     """Configures the spaces and the episode length from `config`.

     Args:
         config (dict): Settings read via `.get()`. Recognized keys:
             "space": Space used for both observations and actions.
                 Defaults to
                 `Tuple([Discrete(2), Dict({"a": Box(-1.0, 1.0, (2, ))})])`.
             "episode_len": Stored as `self.episode_len`. Defaults to 100.
     """
     # The fallback is built unconditionally, just like an inline `.get()`
     # default argument would be.
     fallback_space = Tuple([
         Discrete(2),
         Dict({"a": Box(-1.0, 1.0, (2, ))}),
     ])
     self.observation_space = config.get("space", fallback_space)
     # Actions use the very same space object as observations.
     self.action_space = self.observation_space
     self.flattened_action_space = flatten_space(self.action_space)
     self.episode_len = config.get("episode_len", 100)
コード例 #3
0
    def get_action_dist(action_space,
                        config,
                        dist_type=None,
                        framework="tf",
                        **kwargs):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config. Falls back to
                MODEL_DEFAULTS if falsy.
            dist_type (Optional[str]): Identifier of the action distribution
                (None or "deterministic" for Box spaces), or an
                ActionDistribution subclass to use directly.
            framework (str): One of "tf" or "torch".
            kwargs (dict): Optional kwargs intended for the Distribution's
                constructor. NOTE: they are accepted but not used anywhere
                in this method's body.

        Returns:
            dist_class (ActionDistribution): Python class of the distribution.
                For Tuple/Dict and MultiDiscrete spaces this is a `partial`
                with the space-specific constructor arguments pre-bound.
            dist_dim (int): The size of the input vector to the distribution.

        Raises:
            UnsupportedSpaceException: If a Box action space has more than
                one dimension.
            NotImplementedError: If the action space / `dist_type`
                combination is not supported, or a Simplex space is
                requested with framework="torch".
        """
        dist = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            action_dist_name = config["custom_action_dist"]
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
            dist = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
            elif dist_type == "deterministic":
                dist = Deterministic if framework == "tf" else \
                    TorchDeterministic
            else:
                # Without this branch `dist` would stay None and the final
                # return would fail with an opaque AttributeError on None.
                raise NotImplementedError("Unsupported args: {} {}".format(
                    action_space, dist_type))
        # Discrete Space -> Categorical.
        elif isinstance(action_space, gym.spaces.Discrete):
            dist = Categorical if framework == "tf" else TorchCategorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            # Resolve a (dist_class, input_len) pair for each flattened
            # sub-space by recursing into this method.
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            # The multi-action input is the concatenation of all child
            # inputs, hence the summed length.
            return partial((TorchMultiActionDistribution if framework
                            == "torch" else MultiActionDistribution),
                           action_space=action_space,
                           child_distributions=child_dists,
                           input_lens=input_lens), int(sum(input_lens))
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            dist = MultiCategorical if framework == "tf" else \
                TorchMultiCategorical
            return partial(dist, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist, dist.required_model_output_shape(action_space, config)