Exemplo n.º 1
0
 def __call__(self, x):
   """Returns the scaled KL divergence from `x` to a LogNormal prior.

   The prior is an Independent LogNormal whose `loc`/`scale` are
   `self.loc`/`self.scale` broadcast to `x.distribution.event_shape`.

   Args:
     x: ed.RandomVariable whose distribution is being regularized.

   Returns:
     `self.scale_factor * KL(x.distribution || prior)`.

   Raises:
     ValueError: if `x` is not an ed.RandomVariable.
   """
   if not isinstance(x, random_variable.RandomVariable):
     raise ValueError('Input must be an ed.RandomVariable.')
   event_shape = x.distribution.event_shape
   # NOTE(review): assumes self.loc / self.scale are broadcastable to the
   # event shape of `x` (e.g. scalars) — confirm against the constructor.
   base = generated_random_variables.LogNormal(
       loc=tf.broadcast_to(self.loc, event_shape),
       scale=tf.broadcast_to(self.scale, event_shape))
   # Reinterpret every broadcast dimension as an event dimension so the
   # prior's event shape lines up with that of `x`.
   prior = generated_random_variables.Independent(
       base.distribution,
       reinterpreted_batch_ndims=len(event_shape))
   divergence = x.distribution.kl_divergence(prior.distribution)
   return self.scale_factor * divergence
Exemplo n.º 2
0
 def __call__(self, shape, dtype=None):
   """Returns an Independent LogNormal random variable over `shape`.

   On first use this calls `self.build(shape, dtype)`. It then applies the
   optional `loc_constraint` / `scale_constraint` to the trainable `loc` and
   `scale`, and reinterprets all `len(shape)` dimensions as event dimensions.

   Args:
     shape: Shape forwarded to `self.build` on the first call. Its length sets
       `reinterpreted_batch_ndims` of the returned Independent.
     dtype: Optional dtype forwarded to `self.build` on the first call.

   Returns:
     The result of `generated_random_variables.Independent` wrapping the
     LogNormal distribution.
   """
   if not self.built:
     self.build(shape, dtype)
   # A falsy constraint (e.g. None) means the raw parameter is used as-is.
   loc = self.loc_constraint(self.loc) if self.loc_constraint else self.loc
   scale = (self.scale_constraint(self.scale) if self.scale_constraint
            else self.scale)
   base = generated_random_variables.LogNormal(loc=loc, scale=scale)
   return generated_random_variables.Independent(
       base.distribution, reinterpreted_batch_ndims=len(shape))