Example #1
    def get_action_shape(action_space):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.
        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """

        if isinstance(action_space, gym.spaces.Discrete):
            return (tf.int64, (None, ))
        elif isinstance(action_space, (gym.spaces.Box, Simplex)):
            return (tf.float32, (None, ) + action_space.shape)
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            return (tf.as_dtype(action_space.dtype),
                    (None, ) + action_space.shape)
        elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], gym.spaces.Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return (tf.int64 if all_discrete else tf.float32, (None, size))
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))
Example #2
    def get_action_shape(action_space: gym.Space,
                         framework: str = "tf") -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            framework (str): The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """
        dl_lib = torch if framework == "torch" else tf

        if isinstance(action_space, Discrete):
            return action_space.dtype, (None, )
        elif isinstance(action_space, (Box, Simplex)):
            return dl_lib.float32, (None, ) + action_space.shape
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None, ) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return dl_lib.int64 if all_discrete else dl_lib.float32, \
                (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))
Example #3
 def _get_multi_action_distribution(dist_class, action_space, config,
                                    framework):
     # In case the custom distribution is a child of MultiActionDistr.
     # If users want to completely ignore the suggested child
     # distributions, they should simply do so in their custom class'
     # constructor.
     if issubclass(dist_class,
                   (MultiActionDistribution, TorchMultiActionDistribution)):
         flat_action_space = flatten_space(action_space)
         child_dists_and_in_lens = tree.map_structure(
             lambda s: ModelCatalog.get_action_dist(
                 s, config, framework=framework),
             flat_action_space,
         )
         child_dists = [e[0] for e in child_dists_and_in_lens]
         input_lens = [int(e[1]) for e in child_dists_and_in_lens]
         return (
             partial(
                 dist_class,
                 action_space=action_space,
                 child_distributions=child_dists,
                 input_lens=input_lens,
             ),
             int(sum(input_lens)),
         )
     return dist_class, dist_class.required_model_output_shape(
         action_space, config)
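A hedged usage sketch (bare `Tuple`/`Discrete`/`Box` are gym.spaces as in Example #4, and `MODEL_DEFAULTS` stands in for a full model config as in Example #6):

    # For a Tuple space the helper returns a functools.partial over the
    # multi-action distribution class plus the summed child input lengths
    # (here 2 Categorical logits + 4 DiagGaussian params = 6).
    space = Tuple([Discrete(2), Box(-1.0, 1.0, (2, ))])
    dist_cls, input_len = _get_multi_action_distribution(
        MultiActionDistribution, space, MODEL_DEFAULTS, framework="tf")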
Example #4
 def __init__(self, config):
     self.observation_space = config.get(
         "space", Tuple([Discrete(2),
                         Dict({"a": Box(-1.0, 1.0, (2, ))})]))
     self.action_space = self.observation_space
     self.flattened_action_space = flatten_space(self.action_space)
     self.episode_len = config.get("episode_len", 100)
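A short sketch of what this constructor sets up (the enclosing env class is not shown in the snippet; `NestedSpaceEnv` is a placeholder name):

    env = NestedSpaceEnv({"episode_len": 50})  # placeholder class name
    # flatten_space unrolls the nested default space into a flat list,
    # [Discrete(2), Box(-1.0, 1.0, (2,))], so downstream code can iterate
    # over leaf spaces without recursing into Tuple/Dict containers.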
Example #5
def get_action_dim(action_space: gym.Space):
    """Returns the action dim for the action space.

    Args:
        action_space (Space): Action space of the target gym env.
    Returns:
        int: Size of the model output needed to parameterize an action
            distribution over the space.
    """

    if isinstance(action_space, gym.spaces.Discrete):
        return action_space.n
    elif isinstance(action_space, gym.spaces.Box):
        # Diagonal Gaussian: one mean and one log-std per action dimension.
        return np.product(action_space.shape) * 2
    elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
        # Flatten the nested space and sum the dims of all sub-spaces.
        flat_action_space = flatten_space(action_space)
        return int(sum(get_action_dim(s) for s in flat_action_space))
    else:
        raise NotImplementedError(
            "Action space {} not supported".format(action_space))
Example #6
    def get_action_dist(action_space,
                        config,
                        dist_type=None,
                        framework="tf",
                        **kwargs):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[str]): Identifier of the action distribution
                interpreted as a hint.
            framework (str): One of "tf", "tfe", or "torch".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
        """

        dist = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            action_dist_name = config["custom_action_dist"]
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
            dist = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist = TorchDiagGaussian if framework == "torch" \
                    else DiagGaussian
            elif dist_type == "deterministic":
                dist = TorchDeterministic if framework == "torch" \
                    else Deterministic
        # Discrete Space -> Categorical.
        elif isinstance(action_space, gym.spaces.Discrete):
            dist = TorchCategorical if framework == "torch" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                (TorchMultiActionDistribution
                 if framework == "torch" else MultiActionDistribution),
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            dist = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist, dist.required_model_output_shape(action_space, config)
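A hedged usage sketch (passing `config=None` falls back to `MODEL_DEFAULTS`, per the first lines of the method):

    # A 1D Box resolves to DiagGaussian; its required model output size is
    # 2 * action_dim (a mean and a log-std per dimension), so dist_dim == 6.
    dist_cls, dist_dim = ModelCatalog.get_action_dist(
        gym.spaces.Box(-1.0, 1.0, (3, )), config=None)
    # A Discrete space resolves to Categorical with dist_dim == n.
    dist_cls, dist_dim = ModelCatalog.get_action_dist(
        gym.spaces.Discrete(5), config=None, framework="tf")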
Example #7
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Dict, Tuple)), \
            "`obs_space.original_space` must be [Dict|Tuple]!"

        self.processed_obs_space = self.original_space if \
            model_config.get("_disable_preprocessor_api") else obs_space

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, self.original_space, action_space,
                              num_outputs, model_config, name)

        self.flattened_input_space = flatten_space(self.original_space)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # self.cnn_type = self.model_config["custom_model_config"].get(
        #     "conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.flattened_input_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters":
                    model_config["conv_filters"] if "conv_filters"
                    in model_config else get_filter_config(obs_space.shape),
                    "conv_activation":
                    model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                # if self.cnn_type == "atari":
                cnn = ModelCatalog.get_model_v2(component,
                                                action_space,
                                                num_outputs=None,
                                                model_config=config,
                                                framework="torch",
                                                name="cnn_{}".format(i))
                # TODO (sven): add IMPALA-style option.
                # else:
                #    cnn = TorchImpalaVisionNet(
                #        component,
                #        action_space,
                #        num_outputs=None,
                #        model_config=config,
                #        name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete|MultiDiscrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            elif isinstance(component, MultiDiscrete):
                self.one_hot[i] = True
                concat_size += sum(component.nvec)
            # Everything else (1D Box).
            else:
                self.flatten[i] = int(np.product(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="torch",
            name="post_fc_stack")

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            self.num_outputs = concat_size
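A hedged instantiation sketch (the enclosing TorchModelV2 subclass is not named in the snippet; `ComplexInputNet` is a placeholder, and `MODEL_DEFAULTS` stands in for a full model config):

    obs_space = Dict({
        "img": Box(0.0, 1.0, (84, 84, 3)),  # 3D component -> gets its own CNN
        "pos": Box(-1.0, 1.0, (2, )),       # 1D component -> flattened
    })
    model = ComplexInputNet(obs_space, Discrete(4), num_outputs=4,
                            model_config=MODEL_DEFAULTS, name="complex")
    # concat_size = cnn.num_outputs + 2; the post-FC stack then feeds the
    # logits and value heads built at the end of __init__.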
Example #8
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Dict, Tuple)), \
            "`obs_space.original_space` must be [Dict|Tuple]!"

        self.processed_obs_space = self.original_space if \
            model_config.get("_disable_preprocessor_api") else obs_space
        super().__init__(self.original_space, action_space, num_outputs,
                         model_config, name)

        self.flattened_input_space = flatten_space(self.original_space)

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.flattened_input_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters":
                    model_config["conv_filters"] if "conv_filters"
                    in model_config else get_filter_config(obs_space.shape),
                    "conv_activation":
                    model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                cnn = ModelCatalog.get_model_v2(component,
                                                action_space,
                                                num_outputs=None,
                                                model_config=config,
                                                framework="tf",
                                                name="cnn_{}".format(i))
                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
            # Discrete|MultiDiscrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            elif isinstance(component, MultiDiscrete):
                self.one_hot[i] = True
                concat_size += sum(component.nvec)
            # Everything else (1D Box).
            else:
                self.flatten[i] = int(np.product(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="tf",
            name="post_fc_stack")

        # Actions and value heads.
        self.logits_and_value_model = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            concat_layer = tf.keras.layers.Input(
                (self.post_fc_stack.num_outputs, ))
            logits_layer = tf.keras.layers.Dense(
                num_outputs,
                activation=tf.keras.activations.linear,
                name="logits")(concat_layer)

            # Create the value branch model.
            value_layer = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(concat_layer)
            self.logits_and_value_model = tf.keras.models.Model(
                concat_layer, [logits_layer, value_layer])
        else:
            self.num_outputs = self.post_fc_stack.num_outputs
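The TF variant bundles both heads into a single Keras model; a hedged sketch of the head call (the feature tensor here is synthetic):

    feats = tf.zeros((1, model.post_fc_stack.num_outputs))
    logits, value = model.logits_and_value_model(feats)
    # logits: (batch, num_outputs); value: (batch, 1). When num_outputs is
    # falsy, no heads are built and self.num_outputs falls back to the
    # post-FC stack's output size instead.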