def get_policy_output(self, model_out: TensorType) -> TensorType:
    """Returns policy outputs, given the output of self.__call__().

    For continuous action spaces, these will be the mean/stddev
    distribution inputs for the (SquashedGaussian) action distribution.
    For discrete action spaces, these will be the logits for a
    categorical distribution.

    Args:
        model_out (TensorType): Feature outputs from the model layers
            (result of doing `self.__call__(obs)`).

    Returns:
        TensorType: Distribution inputs for sampling actions.
    """
    # Model outs may come as original Tuple/Dict observations; concat
    # them here if this is the case.
    if isinstance(self.action_model.obs_space, Box):
        if isinstance(model_out, (list, tuple)):
            model_out = tf.concat(model_out, axis=-1)
        elif isinstance(model_out, dict):
            # `tree.flatten` treats a `dict_values` view as a single
            # leaf, so materialize it as a list first.
            model_out = tf.concat(
                [
                    tf.expand_dims(val, 1) if len(val.shape) == 1 else val
                    for val in tree.flatten(list(model_out.values()))
                ],
                axis=-1)
    out, _ = self.action_model({"obs": model_out}, [], None)
    return out
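
# A minimal, self-contained sketch (illustrative, not part of the original
# class; assumes TensorFlow 2.x and `dm-tree` are installed) of the
# Dict-concat step above: 1-D tensors get a trailing feature axis via
# tf.expand_dims before all parts are joined along the last axis.
import tensorflow as tf
import tree  # dm-tree

model_out = {
    "pos": tf.ones([4, 3]),  # (batch=4, 3) feature block
    "vel": tf.zeros([4]),    # (batch=4,) -> expanded to (4, 1)
}
flat = tf.concat(
    [
        tf.expand_dims(val, 1) if len(val.shape) == 1 else val
        for val in tree.flatten(list(model_out.values()))
    ],
    axis=-1)
assert flat.shape == (4, 4)  # 3 + 1 features, batch dim preserved
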
def write(self, observation: TensorType, array: np.ndarray,
          offset: int) -> None:
    # Dict components are flattened in sorted-key order, so normalize
    # the incoming observation to that ordering first.
    if not isinstance(observation, OrderedDict):
        observation = OrderedDict(sorted(observation.items()))
    assert len(observation) == len(self.preprocessors), \
        (len(observation), len(self.preprocessors))
    # Let each sub-preprocessor write its component into `array`,
    # advancing the write offset by that component's flattened size.
    for o, p in zip(observation.values(), self.preprocessors):
        p.write(o, array, offset)
        offset += p.size
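
# A minimal, self-contained sketch (illustrative names, not RLlib API) of
# the offset-based write pattern above: each component of a dict
# observation is copied into a preallocated flat buffer at a running
# offset, in sorted-key order.
from collections import OrderedDict

import numpy as np

observation = OrderedDict(sorted({
    "vel": np.array([3.0, 4.0]),
    "pos": np.array([1.0, 2.0, 3.0]),
}.items()))
array = np.zeros(sum(o.size for o in observation.values()))

offset = 0
for o in observation.values():
    array[offset:offset + o.size] = o.ravel()
    offset += o.size

# "pos" sorts before "vel", so its values come first in the flat buffer.
assert np.array_equal(array, [1.0, 2.0, 3.0, 3.0, 4.0])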