Example #1
    def get_policy_output(self, model_out: TensorType) -> TensorType:
        """Returns policy outputs, given the output of self.__call__().

        For continuous action spaces, these will be the mean/stddev
        distribution inputs for the (SquashedGaussian) action distribution.
        For discrete action spaces, these will be the logits for a categorical
        distribution.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).

        Returns:
            TensorType: Distribution inputs for sampling actions.
        """
        # Model outputs may still be the original Tuple/Dict observations;
        # concatenate them here if so.
        if isinstance(self.action_model.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = tf.concat(model_out, axis=-1)
            elif isinstance(model_out, dict):
                model_out = tf.concat(
                    [
                        tf.expand_dims(val, 1) if len(val.shape) == 1 else val
                        for val in tree.flatten(model_out.values())
                    ],
                    axis=-1)
        out, _ = self.action_model({"obs": model_out}, [], None)
        return out
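
A minimal standalone sketch (not the code above) of the Dict-concatenation step, assuming TensorFlow 2.x; the tensor names and shapes are made up, the action-model call is omitted, and the dict values are iterated directly rather than via tree.flatten.

import numpy as np
import tensorflow as tf

# Toy stand-in for dict-valued model outputs (names/shapes are assumptions).
model_out = {
    "camera": tf.constant(np.random.rand(4, 16), dtype=tf.float32),
    "state": tf.constant(np.random.rand(4), dtype=tf.float32),
}
# Expand any 1-D tensor to rank 2, then join everything along the last axis,
# mirroring the dict branch of get_policy_output() above.
flat = [
    tf.expand_dims(val, 1) if len(val.shape) == 1 else val
    for val in model_out.values()
]
concatenated = tf.concat(flat, axis=-1)
print(concatenated.shape)  # (4, 17)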
Example #2
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        # Sort plain dicts into a deterministic key order so the components
        # line up with `self.preprocessors`.
        if not isinstance(observation, OrderedDict):
            observation = OrderedDict(sorted(observation.items()))
        assert len(observation) == len(self.preprocessors), \
            (len(observation), len(self.preprocessors))
        # Write each sub-observation into its slice of the flat array,
        # advancing the offset by that preprocessor's flattened size.
        for o, p in zip(observation.values(), self.preprocessors):
            p.write(o, array, offset)
            offset += p.size
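
A minimal standalone sketch of the offset-advancing write pattern above; ToyPreprocessor and its size/write interface are assumptions for illustration, not the real preprocessor classes.

from collections import OrderedDict
import numpy as np

class ToyPreprocessor:
    def __init__(self, size: int):
        self.size = size

    def write(self, obs, array: np.ndarray, offset: int) -> None:
        # Flatten the sub-observation into its reserved slice of the buffer.
        array[offset:offset + self.size] = np.asarray(obs).ravel()

preprocessors = [ToyPreprocessor(2), ToyPreprocessor(3)]
observation = OrderedDict(
    a=np.array([1.0, 2.0]), b=np.array([3.0, 4.0, 5.0]))

# Each sub-preprocessor fills its own slice of a shared flat buffer,
# with the offset advanced by that preprocessor's flattened size.
buffer = np.zeros(sum(p.size for p in preprocessors))
offset = 0
for o, p in zip(observation.values(), preprocessors):
    p.write(o, buffer, offset)
    offset += p.size
print(buffer)  # [1. 2. 3. 4. 5.]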