def __call__(self, x):
  """Computes regularization given an input ed.RandomVariable.

  The prior is an Independent LogNormal. Its `loc` and `scale` are
  `self.loc` / `self.scale` broadcast to the event shape of `x.distribution`.
  All event dimensions are reinterpreted as part of a single event.

  Args:
    x: an ed.RandomVariable whose distribution is regularized.

  Returns:
    `self.scale_factor` times `x.distribution.kl_divergence(prior)`.

  Raises:
    ValueError: if `x` is not an ed.RandomVariable.
  """
  if not isinstance(x, random_variable.RandomVariable):
    raise ValueError('Input must be an ed.RandomVariable.')

  posterior = x.distribution
  event_shape = posterior.event_shape

  # Broadcast the (possibly scalar) prior parameters to the event shape so
  # the prior matches the posterior's event structure.
  prior_loc = tf.broadcast_to(self.loc, event_shape)
  prior_scale = tf.broadcast_to(self.scale, event_shape)
  base_prior = generated_random_variables.LogNormal(
      loc=prior_loc, scale=prior_scale)

  # Fold every event dimension into one event so the KL is summed over the
  # event rather than returned per element.
  prior = generated_random_variables.Independent(
      base_prior.distribution,
      reinterpreted_batch_ndims=len(event_shape))

  divergence = posterior.kl_divergence(prior.distribution)
  return self.scale_factor * divergence
def __call__(self, shape, dtype=None):
  """Returns an Independent LogNormal random variable over `shape`.

  If the initializer has not been built yet, `self.build(shape, dtype)` is
  called first. `self.loc` and `self.scale` are passed through their
  constraints (when set) before parameterizing the LogNormal.

  Args:
    shape: shape forwarded to `self.build` when not yet built. Its length
      sets the number of reinterpreted batch dimensions.
    dtype: forwarded to `self.build` when not yet built. Otherwise unused
      here.

  Returns:
    A `generated_random_variables.Independent` wrapping the LogNormal
    distribution.
  """
  if not self.built:
    self.build(shape, dtype)

  def _apply_constraint(value, constraint):
    # Only apply a constraint that is set (truthy); otherwise pass through.
    return constraint(value) if constraint else value

  constrained_loc = _apply_constraint(self.loc, self.loc_constraint)
  constrained_scale = _apply_constraint(self.scale, self.scale_constraint)

  log_normal = generated_random_variables.LogNormal(
      loc=constrained_loc, scale=constrained_scale)
  # Treat all `len(shape)` dimensions as one event.
  return generated_random_variables.Independent(
      log_normal.distribution,
      reinterpreted_batch_ndims=len(shape))