def __init__(self, config, name=None):
  super(LNormPolicy, self).__init__(name=name or self.__class__.__name__)
  self._config = config
  hidden_widths = config.hidden_widths
  # L-norm exponents read from the config.
  self.p = config.p
  self.q = config.q
  # Optional embedding front end ahead of the dense stack.
  if config.embed:
    transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]
  else:
    transformation_layers = []
  # MLP body: ReLU hidden layers followed by a single linear output unit.
  self._body = tf.keras.Sequential(transformation_layers + [
      tf.keras.layers.Dense(w, activation='relu') for w in hidden_widths
  ] + [tf.keras.layers.Dense(1, activation=None)])
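
# Usage sketch (an assumption, not from the original repo): `SimpleNamespace`
# stands in for whatever config object the repo uses, populated with only the
# fields LNormPolicy.__init__ reads. `embed` is left None so the optional
# layers_lib.soft_hot_layer front end is skipped and the body reduces to a
# plain ReLU MLP with a scalar linear head.
from types import SimpleNamespace

import tensorflow as tf

_lnorm_config = SimpleNamespace(hidden_widths=[64, 64], p=2, q=2, embed=None)
_lnorm_policy = LNormPolicy(_lnorm_config)
# Push a dummy batch through the body directly; the network emits one scalar
# score per input row.
_scores = _lnorm_policy._body(tf.zeros([8, 16]))  # shape: [8, 1]
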
def __init__(self, config, name=None):
  super(ExponentialFamilyPolicy, self).__init__(
      name=name or self.__class__.__name__)
  self._config = config
  hidden_widths = config.hidden_widths
  # Look up the distribution spec: `params` flags which network outputs must
  # be non-negative; `log_pdf` and `mode` are callables over those parameters.
  self._dist = DISTS[config.dist_name]
  params = self._dist.params
  self._log_pdf = self._dist.log_pdf
  self._mode = self._dist.mode
  # Optional embedding front end ahead of the dense stack.
  if config.embed:
    transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]
  else:
    transformation_layers = []
  # MLP body: ReLU hidden layers, a linear head with one unit per
  # distribution parameter, then softplus applied to each parameter that is
  # constrained to be non-negative.
  self._body = tf.keras.Sequential(transformation_layers + [
      tf.keras.layers.Dense(w, activation='relu') for w in hidden_widths
  ] + [
      tf.keras.layers.Dense(len(params), activation=None),
      tf.keras.layers.Lambda(lambda x: tf.stack(  # pylint: disable=g-long-lambda
          [
              tf.math.softplus(x[Ellipsis, i]) if is_non_neg else x[Ellipsis, i]
              for i, is_non_neg in enumerate(params)
          ], -1))
  ])
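
# Sketch of the distribution interface the constructor above appears to
# assume (names and structure inferred from usage, not taken from the repo):
# a DISTS entry exposes `params`, a sequence of booleans that is True where
# the corresponding parameter must be non-negative (and is therefore squashed
# through softplus), plus `log_pdf` and `mode` callables. A hypothetical
# Gaussian entry, keyed by something like config.dist_name == 'normal':
import collections
import math

import tensorflow as tf

_Dist = collections.namedtuple('_Dist', ['params', 'log_pdf', 'mode'])

_normal = _Dist(
    # Parameters are (mean, stddev); only the stddev must be non-negative.
    params=(False, True),
    # Gaussian log density evaluated at x for stacked parameters p, where
    # p[..., 0] is the mean and p[..., 1] the stddev.
    log_pdf=lambda x, p: (  # pylint: disable=g-long-lambda
        -0.5 * tf.square((x - p[Ellipsis, 0]) / p[Ellipsis, 1])
        - tf.math.log(p[Ellipsis, 1])
        - 0.5 * math.log(2.0 * math.pi)),
    # The mode of a Gaussian is its mean.
    mode=lambda p: p[Ellipsis, 0],
)
# Such an entry would then be registered, e.g. DISTS['normal'] = _normal.
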