def __call__(self, shape, dtype=None):
  """Builds (if necessary) and returns a trainable normal random variable."""
  if not self.built:
    self.build(shape, dtype)
  return generated_random_variables.Independent(
      generated_random_variables.Normal(
          loc=self.mean, scale=self.stddev).distribution,
      reinterpreted_batch_ndims=len(shape))
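# Usage sketch (illustrative, not part of the library), assuming this is the
# `__call__` of edward2's trainable normal initializer,
# ed.initializers.TrainableNormal: the return value acts as a weight sample
# of the requested shape, and `.distribution` exposes the underlying
# tfd.Independent for downstream KL terms.
import tensorflow as tf
import edward2 as ed

initializer = ed.initializers.TrainableNormal()
weight = initializer(shape=(3, 2))               # ed.RandomVariable
sample = tf.convert_to_tensor(weight)            # Tensor of shape (3, 2)
log_prob = weight.distribution.log_prob(sample)  # scalar: event rank is 2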
def __call__(self, shape, dtype=None):
  """Builds (if necessary) and returns a trainable half-Cauchy random variable."""
  if not self.built:
    self.build(shape, dtype)
  return generated_random_variables.Independent(
      generated_random_variables.HalfCauchy(
          loc=self.loc, scale=self.scale).distribution,
      reinterpreted_batch_ndims=len(shape))
def call(self, inputs):
  """Computes the GP output, conditioning on (inputs, outputs) if provided."""
  if self.conditional_inputs is None and self.conditional_outputs is None:
    covariance_matrix = self.covariance_fn(inputs, inputs)
    # Tile locations so output has shape [units, batch_size]. Covariance will
    # broadcast to [units, batch_size, batch_size], and we perform
    # shape manipulations to get a random variable over [batch_size, units].
    loc = self.mean_fn(inputs)
    loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
  else:
    knn = self.covariance_fn(inputs, inputs)
    knm = self.covariance_fn(inputs, self.conditional_inputs)
    kmm = self.covariance_fn(self.conditional_inputs, self.conditional_inputs)
    kmm = tf.linalg.set_diag(
        kmm, tf.linalg.diag_part(kmm) + tf.keras.backend.epsilon())
    kmm_tril = tf.linalg.cholesky(kmm)
    kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
    knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

    # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
    # now, we do each separately and stack to obtain a locations Tensor of
    # shape [units, batch_size].
    loc = []
    for conditional_outputs_unit in tf.unstack(self.conditional_outputs,
                                               axis=-1):
      center = conditional_outputs_unit - self.mean_fn(
          self.conditional_inputs)
      loc_unit = knm_operator.matvec(
          kmm_tril_operator.solvevec(kmm_tril_operator.solvevec(center),
                                     adjoint=True))
      loc.append(loc_unit)
    loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

    covariance_matrix = knn
    covariance_matrix -= knm_operator.matmul(
        kmm_tril_operator.solve(
            kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))

  covariance_matrix = tf.linalg.set_diag(
      covariance_matrix,
      tf.linalg.diag_part(covariance_matrix) + tf.keras.backend.epsilon())

  # Form a multivariate normal random variable with batch_shape units and
  # event_shape batch_size. Then make it be independent across the units
  # dimension. Then transpose its dimensions so it is [batch_size, units].
  random_variable = (
      generated_random_variables.MultivariateNormalFullCovariance(
          loc=loc, covariance_matrix=covariance_matrix))
  random_variable = generated_random_variables.Independent(
      random_variable.distribution, reinterpreted_batch_ndims=1)
  bijector = tfp.bijectors.Inline(
      forward_fn=lambda x: tf.transpose(x, perm=[1, 0]),
      inverse_fn=lambda y: tf.transpose(y, perm=[1, 0]),
      forward_event_shape_fn=lambda input_shape: input_shape[::-1],
      forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
      inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
      forward_min_event_ndims=2)
  random_variable = generated_random_variables.TransformedDistribution(
      random_variable.distribution, bijector=bijector)
  return random_variable
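# Numerical check of the identity used above (illustrative, not part of the
# layer): two triangular solves with the Cholesky factor L (kmm = L L^T)
# implement multiplication by kmm^{-1}, which gives the standard GP posterior
#   loc        = m(x) + knm kmm^{-1} (y - m(X_m))
#   covariance = knn  - knm kmm^{-1} knm^T
import tensorflow as tf

a = tf.random.normal([4, 4])
kmm = tf.matmul(a, a, transpose_b=True) + 4. * tf.eye(4)  # SPD matrix
chol = tf.linalg.cholesky(kmm)
op = tf.linalg.LinearOperatorLowerTriangular(chol)
v = tf.random.normal([4])
lhs = op.solvevec(op.solvevec(v), adjoint=True)  # L^{-T} (L^{-1} v)
rhs = tf.linalg.matvec(tf.linalg.inv(kmm), v)    # kmm^{-1} v
# lhs and rhs agree up to numerical error.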
def __call__(self, shape, dtype=None):
  """Builds (if necessary) and returns a point-mass (deterministic) random variable."""
  if not self.built:
    self.build(shape, dtype)
  loc = self.loc
  if self.loc_constraint:
    loc = self.loc_constraint(loc)
  return generated_random_variables.Independent(
      generated_random_variables.Deterministic(loc=loc).distribution,
      reinterpreted_batch_ndims=len(shape))
def __call__(self, x): """Computes regularization given an ed.Normal random variable as input.""" if not isinstance(x, random_variable.RandomVariable): raise ValueError('Input must be an ed.RandomVariable.') prior = generated_random_variables.Independent( generated_random_variables.Normal(loc=x.distribution.mean(), scale=self.stddev).distribution, reinterpreted_batch_ndims=len(x.distribution.event_shape)) regularization = x.distribution.kl_divergence(prior.distribution) return self.scale_factor * regularization
def __call__(self, x): """Computes regularization using an unbiased Monte Carlo estimate.""" prior = generated_random_variables.Independent( generated_random_variables.HalfCauchy( loc=tf.broadcast_to(self.loc, x.distribution.event_shape), scale=tf.broadcast_to(self.scale, x.distribution.event_shape) ).distribution, reinterpreted_batch_ndims=len(x.distribution.event_shape)) negative_entropy = x.distribution.log_prob(x) cross_entropy = -prior.distribution.log_prob(x) return self.scale_factor * (negative_entropy + cross_entropy)
def __call__(self, shape, dtype=None):
  """Builds (if necessary) and returns a trainable log-normal random variable."""
  if not self.built:
    self.build(shape, dtype)
  loc = self.loc
  if self.loc_constraint:
    loc = self.loc_constraint(loc)
  scale = self.scale
  if self.scale_constraint:
    scale = self.scale_constraint(scale)
  return generated_random_variables.Independent(
      generated_random_variables.LogNormal(
          loc=loc, scale=scale).distribution,
      reinterpreted_batch_ndims=len(shape))
def call(self, inputs): """Computes regularization given an input ed.RandomVariable.""" if not isinstance(inputs, random_variable.RandomVariable): raise ValueError('Input must be an ed.RandomVariable.') stddev = self.stddev if self.stddev_constraint: stddev = self.stddev_constraint(stddev) prior = generated_random_variables.Independent( generated_random_variables.Normal( loc=self.mean, scale=stddev).distribution, reinterpreted_batch_ndims=len(inputs.distribution.event_shape)) regularization = inputs.distribution.kl_divergence(prior.distribution) return self.scale_factor * regularization
def __call__(self, shape, dtype=None):
  """Builds (if necessary) and returns a uniform mixture of point masses."""
  if not self.built:
    self.build(shape, dtype)
  loc = self.loc
  if self.loc_constraint:
    loc = self.loc_constraint(loc)
  return generated_random_variables.Independent(
      generated_random_variables.MixtureSameFamily(
          mixture_distribution=generated_random_variables.Categorical(
              probs=tf.broadcast_to(
                  [[1. / self.num_components] * self.num_components],
                  list(shape) + [self.num_components])).distribution,
          components_distribution=generated_random_variables.Deterministic(
              loc=loc).distribution).distribution,
      reinterpreted_batch_ndims=len(shape))
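# Per weight element, the initializer above builds a uniform mixture of
# `num_components` point masses: each element takes one of K learned values
# with probability 1/K. A standalone analogue with raw tfp (shapes
# illustrative, not part of the library):
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
num_components = 3
loc = tf.random.normal([5, 4, num_components])  # per-element component locs
probs = tf.fill([5, 4, num_components], 1. / num_components)
mixture = tfd.Independent(
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=probs),
        components_distribution=tfd.Deterministic(loc=loc)),
    reinterpreted_batch_ndims=2)
sample = mixture.sample()  # Tensor of shape (5, 4)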