Example #1
 def __call__(self, shape, dtype=None):
   if not self.built:
     self.build(shape, dtype)
   return generated_random_variables.Independent(
       generated_random_variables.Normal(
           loc=self.mean, scale=self.stddev).distribution,
       reinterpreted_batch_ndims=len(shape))
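These snippets follow the Edward2 pattern for trainable initializers: calling the object with a weight shape returns a random variable whose distribution is an Independent wrapper around an elementwise base distribution, so the entire weight tensor is treated as one event. A minimal sketch of the same construction written directly against TensorFlow Probability (the toy shape and parameters are assumptions for illustration):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Independent reinterprets the [3, 2] batch of Normals as a single event,
# so log_prob sums over all weight elements and returns a scalar.
shape = (3, 2)
weights = tfd.Independent(
    tfd.Normal(loc=tf.zeros(shape), scale=tf.ones(shape)),
    reinterpreted_batch_ndims=len(shape))
sample = weights.sample()            # Tensor of shape [3, 2]
log_prob = weights.log_prob(sample)  # scalar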
Example #2
 def __call__(self, shape, dtype=None):
   if not self.built:
     self.build(shape, dtype)
   return generated_random_variables.Independent(
       generated_random_variables.HalfCauchy(
           loc=self.loc, scale=self.scale).distribution,
       reinterpreted_batch_ndims=len(shape))
Example #3
  def call(self, inputs):
    if self.conditional_inputs is None and self.conditional_outputs is None:
      covariance_matrix = self.covariance_fn(inputs, inputs)
      # Tile locations so output has shape [units, batch_size]. Covariance will
      # broadcast to [units, batch_size, batch_size], and we perform
      # shape manipulations to get a random variable over [batch_size, units].
      loc = self.mean_fn(inputs)
      loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
    else:
      knn = self.covariance_fn(inputs, inputs)
      knm = self.covariance_fn(inputs, self.conditional_inputs)
      kmm = self.covariance_fn(self.conditional_inputs, self.conditional_inputs)
      kmm = tf.linalg.set_diag(
          kmm, tf.linalg.diag_part(kmm) + tf.keras.backend.epsilon())
      kmm_tril = tf.linalg.cholesky(kmm)
      kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
      knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

      # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
      # now, we do each separately and stack to obtain a locations Tensor of
      # shape [units, batch_size].
      loc = []
      for conditional_outputs_unit in tf.unstack(self.conditional_outputs,
                                                 axis=-1):
        center = conditional_outputs_unit - self.mean_fn(
            self.conditional_inputs)
        loc_unit = knm_operator.matvec(
            kmm_tril_operator.solvevec(kmm_tril_operator.solvevec(center),
                                       adjoint=True))
        loc.append(loc_unit)
      loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

      covariance_matrix = knn
      covariance_matrix -= knm_operator.matmul(
          kmm_tril_operator.solve(
              kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))

    covariance_matrix = tf.linalg.set_diag(
        covariance_matrix,
        tf.linalg.diag_part(covariance_matrix) + tf.keras.backend.epsilon())

    # Form a multivariate normal random variable with batch_shape units and
    # event_shape batch_size. Then make it be independent across the units
    # dimension. Then transpose its dimensions so it is [batch_size, units].
    random_variable = (
        generated_random_variables.MultivariateNormalFullCovariance(
            loc=loc, covariance_matrix=covariance_matrix))
    random_variable = generated_random_variables.Independent(
        random_variable.distribution, reinterpreted_batch_ndims=1)
    bijector = tfp.bijectors.Inline(
        forward_fn=lambda x: tf.transpose(x, perm=[1, 0]),
        inverse_fn=lambda y: tf.transpose(y, perm=[1, 0]),
        forward_event_shape_fn=lambda input_shape: input_shape[::-1],
        forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
        inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
        forward_min_event_ndims=2)
    random_variable = generated_random_variables.TransformedDistribution(
        random_variable.distribution, bijector=bijector)
    return random_variable
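The else branch above is the standard Gaussian process conditional: the posterior mean is m(x) + K_nm K_mm^{-1} (y - m(X_m)) and the posterior covariance is K_nn - K_nm K_mm^{-1} K_mn, with K_mm^{-1} applied through a Cholesky factor and triangular solves rather than an explicit inverse. A self-contained sketch of that math on plain tensors (the RBF kernel, zero mean function, and toy data are assumptions for illustration):

import tensorflow as tf

def rbf(a, b):
  # Squared-exponential kernel between two sets of points.
  sq = tf.reduce_sum((a[:, None, :] - b[None, :, :])**2, axis=-1)
  return tf.exp(-0.5 * sq)

xm = tf.random.normal([5, 1])  # conditioning inputs
ym = tf.random.normal([5])     # conditioning outputs (zero mean assumed)
xn = tf.random.normal([3, 1])  # query inputs

knn = rbf(xn, xn)
knm = rbf(xn, xm)
kmm = rbf(xm, xm) + 1e-6 * tf.eye(5)  # jitter, like the epsilon above
chol = tf.linalg.cholesky(kmm)

# cholesky_solve applies kmm^{-1} via two triangular solves.
loc = tf.squeeze(knm @ tf.linalg.cholesky_solve(chol, ym[:, None]), axis=-1)
v = tf.linalg.triangular_solve(chol, tf.transpose(knm), lower=True)
cov = knn - tf.transpose(v) @ v  # knn - knm @ kmm^{-1} @ knm^T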
Example #4
 def __call__(self, shape, dtype=None):
   if not self.built:
     self.build(shape, dtype)
   loc = self.loc
   if self.loc_constraint:
     loc = self.loc_constraint(loc)
   return generated_random_variables.Independent(
       generated_random_variables.Deterministic(loc=loc).distribution,
       reinterpreted_batch_ndims=len(shape))
Example #5
 def __call__(self, x):
   """Computes regularization given an ed.Normal random variable as input."""
   if not isinstance(x, random_variable.RandomVariable):
     raise ValueError('Input must be an ed.RandomVariable.')
   prior = generated_random_variables.Independent(
       generated_random_variables.Normal(
           loc=x.distribution.mean(), scale=self.stddev).distribution,
       reinterpreted_batch_ndims=len(x.distribution.event_shape))
   regularization = x.distribution.kl_divergence(prior.distribution)
   return self.scale_factor * regularization
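This regularizer is the analytic KL term of a variational objective, with the prior centered at the posterior's own mean so only the spread is penalized. The same quantity written directly with TensorFlow Probability (toy shapes and scales are assumptions):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

posterior = tfd.Independent(
    tfd.Normal(loc=tf.Variable(tf.zeros([4])), scale=0.1),
    reinterpreted_batch_ndims=1)
prior = tfd.Independent(
    tfd.Normal(loc=posterior.distribution.loc, scale=1.0),
    reinterpreted_batch_ndims=1)
# Closed-form KL between the two Independent Normals, differentiable
# with respect to the posterior parameters.
kl = tfd.kl_divergence(posterior, prior)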
Example #6
 def __call__(self, x):
   """Computes regularization using an unbiased Monte Carlo estimate."""
   prior = generated_random_variables.Independent(
       generated_random_variables.HalfCauchy(
           loc=tf.broadcast_to(self.loc, x.distribution.event_shape),
           scale=tf.broadcast_to(self.scale, x.distribution.event_shape)
       ).distribution,
       reinterpreted_batch_ndims=len(x.distribution.event_shape))
   negative_entropy = x.distribution.log_prob(x)
   cross_entropy = -prior.distribution.log_prob(x)
   return self.scale_factor * (negative_entropy + cross_entropy)
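Unlike the Normal case, the KL against a HalfCauchy prior has no closed form, so the snippet uses a single-sample Monte Carlo estimate: KL(q || p) = E_q[log q(x) - log p(x)], evaluated at the draw x that the layer produced. A minimal sketch of the estimator (the concrete distributions are assumptions for illustration):

import tensorflow_probability as tfp

tfd = tfp.distributions

posterior = tfd.LogNormal(loc=0., scale=1.)
prior = tfd.HalfCauchy(loc=0., scale=1.)

x = posterior.sample()                    # one draw from q
negative_entropy = posterior.log_prob(x)  # log q(x)
cross_entropy = -prior.log_prob(x)        # -log p(x)
kl_estimate = negative_entropy + cross_entropy  # unbiased for KL(q || p)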
Example #7
 def __call__(self, shape, dtype=None):
   if not self.built:
     self.build(shape, dtype)
   loc = self.loc
   if self.loc_constraint:
     loc = self.loc_constraint(loc)
   scale = self.scale
   if self.scale_constraint:
     scale = self.scale_constraint(scale)
   return generated_random_variables.Independent(
       generated_random_variables.LogNormal(loc=loc, scale=scale).distribution,
       reinterpreted_batch_ndims=len(shape))
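The loc_constraint/scale_constraint hooks mirror Keras weight constraints: they map the raw trainable variable into a valid parameter space each time the initializer is called, e.g. keeping a scale strictly positive. A sketch of such a constraint (this helper is an assumption, not the library's API):

import tensorflow as tf

def positive_scale_constraint(raw_scale, minimum=1e-5):
  # Softplus maps an unconstrained variable to a strictly positive scale.
  return tf.nn.softplus(raw_scale) + minimum

scale = positive_scale_constraint(tf.Variable(tf.zeros([3])))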
Example #8
 def call(self, inputs):
   """Computes regularization given an input ed.RandomVariable."""
   if not isinstance(inputs, random_variable.RandomVariable):
     raise ValueError('Input must be an ed.RandomVariable.')
   stddev = self.stddev
   if self.stddev_constraint:
     stddev = self.stddev_constraint(stddev)
   prior = generated_random_variables.Independent(
       generated_random_variables.Normal(
           loc=self.mean, scale=stddev).distribution,
       reinterpreted_batch_ndims=len(inputs.distribution.event_shape))
   regularization = inputs.distribution.kl_divergence(prior.distribution)
   return self.scale_factor * regularization
Example #9
 def __call__(self, shape, dtype=None):
   if not self.built:
     self.build(shape, dtype)
   loc = self.loc
   if self.loc_constraint:
     loc = self.loc_constraint(loc)
   return generated_random_variables.Independent(
       generated_random_variables.MixtureSameFamily(
           mixture_distribution=generated_random_variables.Categorical(
               probs=tf.broadcast_to(
                   [[1 / self.num_components] * self.num_components],
                   list(shape) + [self.num_components])).distribution,
           components_distribution=generated_random_variables.Deterministic(
               loc=loc).distribution).distribution,
       reinterpreted_batch_ndims=len(shape))
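This last initializer draws each weight as a uniform mixture of num_components point masses: a Categorical places probability 1/num_components on each component, the Deterministic components sit at the trainable locations, and Independent again collapses the weight dimensions into one event. A compact sketch with TensorFlow Probability (the toy shape and component count are assumptions):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

shape, num_components = (3,), 2
loc = tf.random.normal(list(shape) + [num_components])  # component locations
mixture = tfd.Independent(
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            probs=tf.fill(list(shape) + [num_components],
                          1. / num_components)),
        components_distribution=tfd.Deterministic(loc=loc)),
    reinterpreted_batch_ndims=len(shape))
sample = mixture.sample()  # shape [3]; each element picks one component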