Ejemplo n.º 1
0
 def __init__(self, config, name=None):
     """Constructs the policy and its feed-forward body.

     The body is a Keras `Sequential` made of an optional soft-hot input
     layer, one ReLU `Dense` layer per hidden width, and a final linear
     `Dense` layer with a single unit.

     Args:
       config: Policy configuration. Fields read here: `hidden_widths`, `p`,
         `q` and `embed` (keyword arguments for `layers_lib.soft_hot_layer`,
         or a falsy value to skip that layer).
       name: Optional module name; falls back to the class name.
     """
     super(LNormPolicy, self).__init__(name=name or self.__class__.__name__)
     self._config = config
     widths = config.hidden_widths
     # Exponents copied from the config; they are not used by the
     # constructor itself.
     self.p = config.p
     self.q = config.q

     # Layers are instantiated in forward order.
     layer_stack = []
     if config.embed:
         layer_stack.append(layers_lib.soft_hot_layer(**config.embed))
     layer_stack.extend(
         tf.keras.layers.Dense(width, activation='relu') for width in widths)
     # Single linear output unit.
     layer_stack.append(tf.keras.layers.Dense(1, activation=None))
     self._body = tf.keras.Sequential(layer_stack)
 def __init__(self, config, name=None):
     """Constructs the policy from an exponential-family distribution spec.

     The body emits one raw output per entry of the distribution's `params`.
     An output column is passed through softplus when its `params` entry is
     truthy and is left unchanged otherwise.

     Args:
       config: Policy configuration. Fields read here: `hidden_widths`,
         `dist_name` (a key into `DISTS`) and `embed` (keyword arguments for
         `layers_lib.soft_hot_layer`, or a falsy value to skip that layer).
       name: Optional module name; falls back to the class name.
     """
     super(ExponentialFamilyPolicy,
           self).__init__(name=name or self.__class__.__name__)
     self._config = config
     widths = config.hidden_widths
     self._dist = DISTS[config.dist_name]
     # One flag per distribution parameter; truthy flags select the output
     # columns that get a softplus (i.e. parameters treated as non-negative).
     non_neg_flags = self._dist.params
     self._log_pdf = self._dist.log_pdf
     self._mode = self._dist.mode

     def _constrain(raw):
         """Applies softplus to the flagged last-axis columns of `raw`."""
         columns = []
         for idx, non_neg in enumerate(non_neg_flags):
             column = raw[..., idx]
             columns.append(tf.math.softplus(column) if non_neg else column)
         return tf.stack(columns, -1)

     # Layers are instantiated in forward order.
     layer_stack = []
     if config.embed:
         layer_stack.append(layers_lib.soft_hot_layer(**config.embed))
     layer_stack.extend(
         tf.keras.layers.Dense(width, activation='relu') for width in widths)
     # Linear head with one unit per distribution parameter, then the
     # per-column positivity constraint.
     layer_stack.append(
         tf.keras.layers.Dense(len(non_neg_flags), activation=None))
     layer_stack.append(tf.keras.layers.Lambda(_constrain))
     self._body = tf.keras.Sequential(layer_stack)