def test_neg_log_likelihood_mv_gaussian_conv_filters_chol(self):
    """Check the NLL computed through PrecisionConvCholFilters against TFP.

    The reference value comes from a dense
    ``tfd.MultivariateNormalFullCovariance``. The tested value comes from
    ``neg_log_likelihood_mv_gaussian`` fed with the quadratic form and the
    log-determinant provided by a ``cov_rep.PrecisionConvCholFilters`` object.
    """
    x, mu, covariance, weights, filters, log_diag = self._random_normal_params(
        cov_rep.PrecisionConvCholFilters)

    # Reference: dense multivariate normal parameterised by the full covariance.
    reference_dist = tfd.MultivariateNormalFullCovariance(
        loc=mu, covariance_matrix=covariance)
    expected_nll = -reference_dist.log_prob(x)

    # The flattened features are treated as a square, single-channel image.
    side = int(np.sqrt(self.features_size))
    image_shape = (self.batch_size, side, side, 1)

    cov_obj = cov_rep.PrecisionConvCholFilters(
        weights_precision=tf.convert_to_tensor(weights),
        filters_precision=tf.convert_to_tensor(filters),
        sample_shape=image_shape)
    cov_obj.log_diag_chol_precision = log_diag

    # Residual in flat form for the likelihood and in image form for the
    # convolution-based quadratic form.
    residual = tf.convert_to_tensor(x - mu)
    residual_img = tf.reshape(residual, image_shape)
    actual_nll = neg_log_likelihood_mv_gaussian(
        residual,
        x_precision_x=cov_obj.x_precision_x(residual_img),
        log_det_cov=cov_obj.log_det_covariance(),
        mean_batch=False)

    self._asset_allclose_tf_feed(actual_nll, expected_nll)
def test_kl_divergence_mv_gaussian_conv_filters_chol(self):
    """Compare kl_divergence_mv_gaussian_v2 on PrecisionConvCholFilters with TFP.

    The comparison is made twice:
    - per batch element (``mean_batch=False``);
    - with the default ``mean_batch``, against the batch mean of the reference
      KL divergence.
    """
    _, mu1, np_cov1, weights1, filters1, log_diag1 = self._random_normal_params(
        cov_rep.PrecisionConvCholFilters)
    _, mu2, np_cov2, weights2, filters2, log_diag2 = self._random_normal_params(
        cov_rep.PrecisionConvCholFilters)

    # Reference KL divergence between two dense TFP multivariate normals.
    ref_kl = dist.kl_divergence(
        dist.MultivariateNormalFullCovariance(loc=mu1, covariance_matrix=np_cov1),
        dist.MultivariateNormalFullCovariance(loc=mu2, covariance_matrix=np_cov2))

    mu1_tf, weights1, filters1 = self._convert_to_tensor(mu1, weights1, filters1)
    mu2_tf, weights2, filters2 = self._convert_to_tensor(mu2, weights2, filters2)

    # The flattened features are treated as a square, single-channel image.
    side = int(np.sqrt(self.features_size))
    image_shape = (self.batch_size, side, side, 1)

    def build_cov(w, f, log_d):
        # Filters-based Cholesky-precision object with its log-diagonal set.
        cov = cov_rep.PrecisionConvCholFilters(weights_precision=w,
                                               filters_precision=f,
                                               sample_shape=image_shape)
        cov.log_diag_chol_precision = log_d
        return cov

    cov1 = build_cov(weights1, filters1, log_diag1)
    cov2 = build_cov(weights2, filters2, log_diag2)

    per_sample_kl = kl_divergence_mv_gaussian_v2(
        sigma1=cov1, sigma2=cov2, mu1=mu1_tf, mu2=mu2_tf, mean_batch=False)
    self._asset_allclose_tf_feed(ref_kl, per_sample_kl)

    ref_mean_kl = tf.reduce_mean(ref_kl)
    mean_kl = kl_divergence_mv_gaussian_v2(
        sigma1=cov1, sigma2=cov2, mu1=mu1_tf, mu2=mu2_tf)
    self._asset_allclose_tf_feed(ref_mean_kl, mean_kl)
def _sample_n_sparse(self, n, kw, seed=None):
    """Draw one sample and wrap it in a ``cov_rep.PrecisionConvCholFilters``.

    Each pixel gets a weight vector of length ``kw**2`` with three parts:
    - ``kw**2 // 2`` leading zeros;
    - a centre entry equal to ``sqrt`` of a Gamma draw (rate 0.5);
    - ``kw**2 // 2`` standard-normal draws.

    The non-zero entries are multiplied by ``sqrt(exp(self.log_diag_scale))``.
    An identity basis is used, so the sampled weights alone define the filters.

    :param n: Number of samples; only ``n == 1`` is supported.
    :param kw: Kernel width; the number of basis elements is ``kw**2``. The
        zeros/centre/right layout only adds up to ``kw**2`` entries when
        ``kw`` is odd.
    :param seed: Optional seed, forwarded to a ``SeedStream`` with salt "Wishart".
    :return: ``cov_rep.PrecisionConvCholFilters`` with
        ``weights_precision`` of shape ``batch_shape + [iw, iw, kw**2]`` and an
        identity ``filters_precision`` of shape ``[kw**2, kw, kw, 1, 1]``.
    """
    assert n == 1
    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape
    # Events are flattened square images with iw * iw pixels.
    iw = int(np.sqrt(event_shape[0].value))  # Image width
    nb = kw**2  # Number of basis
    nb_half = nb // 2 + 1
    nch = 1  # Number of channels in the image

    stream = seed_stream.SeedStream(seed=seed, salt="Wishart")

    shape = tf.concat([batch_shape, [iw, iw, nb_half - 1]], 0)
    # Random sample for the off diagonal values as a dense tensor
    x_right = tf.random_normal(shape=shape, dtype=self.dtype, seed=stream())
    # Zeros for the entries before the centre, so every pixel gets a full
    # kw * kw kernel. tf.zeros defaults to float32, so pass the distribution
    # dtype explicitly or the concat below fails for non-float32 dtypes.
    x_left = tf.zeros(shape, dtype=self.dtype)

    # Random sample for the diagonal of the matrix
    x_diag = tf.random_gamma(shape=[n],
                             alpha=self._multi_gamma_sequence(
                                 0.5 * self.df, self.p),
                             beta=0.5,
                             dtype=self.dtype,
                             seed=stream())

    # Concatenate the diagonal and off-diagonal elements
    x_diag = tf.reshape(x_diag, (-1, iw, iw, nch))
    x = tf.concat([tf.sqrt(x_diag), x_right], axis=3)

    # Scale the sampled matrix using the distribution Scale matrix
    diag_scale = tf.exp(self.log_diag_scale)
    diag_scale = tf.reshape(diag_scale, (-1, iw, iw, nch))
    x *= tf.sqrt(diag_scale)  # For a diagonal matrix, sqrt equals its Cholesky

    # Prepend the zeros so the centre entry sits at index nb // 2
    x = tf.concat([x_left, x], axis=3)

    # Identity basis, so the sampled matrix is defined by x alone. Otherwise we
    # would need an optimisation to find a basis and weights reconstructing x.
    # Match the weights dtype (tf.eye also defaults to float32).
    identity_basis = tf.eye(num_rows=nb, dtype=self.dtype)
    identity_basis = tf.reshape(identity_basis, (nb, kw, kw, nch, nch))

    sample_shape = tf.concat([batch_shape, [iw, iw, nch]], axis=0)
    x_sparse = cov_rep.PrecisionConvCholFilters(
        weights_precision=x,
        filters_precision=identity_basis,
        sample_shape=sample_shape)
    return x_sparse
def _create_covariance_instance(self):
    """Rebuild ``self.cov_object`` without explicit filters.

    Runs the parent setup first. It then replaces ``self.cov_object`` with a
    ``cov_rep.PrecisionConvCholFilters`` that receives ``filters_precision=None``,
    leaving the covariance object to construct its own filters.
    """
    super()._create_covariance_instance()

    # Test without giving the filters, let the model build it
    shape = (self.batch_size, self.img_size, self.img_size, self.num_ch)
    self.cov_object = cov_rep.PrecisionConvCholFilters(
        weights_precision=self.tf_weights,
        filters_precision=None,
        sample_shape=shape,
        inversion_method=self.inversion_method,
    )
def _create_single_sqrt_wishart_pair(self, add_sparse_gamma=False):
    """Build a square-root Wishart reference and the matching custom distributions.

    A random diagonal scale and random degrees of freedom parameterise two
    distributions:
    - a TFP ``Wishart`` transformed by ``Invert(CholeskyOuterProduct())``;
    - a ``SqrtGammaGaussian``, plus a ``SparseSqrtGammaGaussian`` when requested.

    It also draws a random Cholesky factor of a precision matrix, once as a
    dense numpy array and once as a ``PrecisionConvCholFilters`` object.

    :param add_sparse_gamma: When True, also build and return a
        ``SparseSqrtGammaGaussian`` with the same parameters.
    :return: ``(x, x_cov_obj, sqrt_wishart_tfd, sqrt_gamma_gaussian)``, with
        ``sparse_sqrt_gamma_gaussian`` appended if ``add_sparse_gamma`` is True.
    """
    np_dtype = self.dtype.as_numpy_dtype
    n_feat = self.features_size

    # Random positive diagonal for the Wishart scale matrix, placed on the
    # diagonal of an otherwise-zero (batch, n_feat, n_feat) array.
    diag_precision_prior = np.abs(
        np.random.normal(size=(self.batch_size, n_feat))).astype(np_dtype)
    precision_prior = np.zeros((self.batch_size, n_feat, n_feat), dtype=np_dtype)
    diag = np.arange(n_feat)
    precision_prior[:, diag, diag] = diag_precision_prior
    log_diag_precision_prior = np.log(diag_precision_prior)

    # One degrees-of-freedom value per batch element, drawn in [n_feat, 10 * n_feat).
    df = np.random.uniform(
        low=n_feat, high=n_feat * 10, size=self.batch_size).astype(np_dtype)

    # Reference: the Cholesky factor of a Wishart sample, obtained via bijectors.
    sqrt_wishart_tfd = tfd.TransformedDistribution(
        distribution=tfd.Wishart(scale=precision_prior, df=df),
        bijector=tfb.Invert(tfb.CholeskyOuterProduct()))

    # Custom square root Wishart distribution(s) with the same parameters
    sqrt_gamma_gaussian = SqrtGammaGaussian(
        df=df, log_diag_scale=log_diag_precision_prior)
    if add_sparse_gamma:
        sparse_sqrt_gamma_gaussian = SparseSqrtGammaGaussian(
            df=df, log_diag_scale=log_diag_precision_prior)

    # Random point at which to evaluate the densities: the Cholesky factor of
    # the precision (inverse covariance).
    _, _, x_covariance, x_weights, x_basis, log_diag = self._random_normal_params(
        cov_rep.PrecisionConvCholFilters)
    x = np.linalg.cholesky(np.linalg.inv(x_covariance))

    # The custom square root Wishart works with PrecisionConvCholFilters and
    # measures the pdf of the Cholesky factor of the precision.
    img_w = int(np.sqrt(n_feat))
    sample_shape = tf.TensorShape((self.batch_size, img_w, img_w, 1))
    x_cov_obj = cov_rep.PrecisionConvCholFilters(
        weights_precision=tf.constant(x_weights),
        filters_precision=tf.constant(x_basis),
        sample_shape=sample_shape)
    x_cov_obj.log_diag_chol_precision = log_diag

    result = (x, x_cov_obj, sqrt_wishart_tfd, sqrt_gamma_gaussian)
    if add_sparse_gamma:
        result += (sparse_sqrt_gamma_gaussian,)
    return result
def _create_covariance_instance(self):
    """Create a filters-based Cholesky-precision covariance and numpy references.

    Steps:
    - Draw random numpy weights and basis.
    - Build ``self.cov_object`` as a ``cov_rep.PrecisionConvCholFilters`` from
      ``self.tf_weights`` / ``self.tf_basis``.
    - Derive the dense numpy reference quantities used by the tests: filters,
      log-diagonal of the Cholesky factor, the Cholesky factor itself, and
      precision = L @ L^T.
    """
    img_shape = (self.batch_size, self.img_size, self.img_size, self.num_ch)
    self.np_weights = self._create_random_weights()
    self.np_basis = self._create_random_basis()
    self.equivalent_sample_method = cov_rep.SampleMethod.CHOLESKY
    # NOTE(review): tf_weights / tf_basis presumably derive from np_weights /
    # np_basis; keep this construction before np_basis is reshaped below.
    self.cov_object = cov_rep.PrecisionConvCholFilters(
        weights_precision=self.tf_weights,
        filters_precision=self.tf_basis,
        sample_shape=img_shape,
        inversion_method=self.inversion_method)

    self.np_filters = self._filters_from_weights_basis(
        self.np_weights, self.np_basis, self.filter_size)
    # The centre tap of each filter is taken as the diagonal of the Cholesky factor.
    center_c = (self.filter_size**2) // 2
    self.np_log_diag_chol_precision = np.log(self.np_filters[:, :, center_c])

    # Pass the target shape positionally: the ``newshape`` keyword of
    # np.reshape is deprecated since NumPy 2.1.
    self.np_basis = np.reshape(self.np_basis,
                               self.basis_shape + (self.num_ch, self.num_ch))

    self.np_chol_precision = self._matrix_from_filters(
        self.np_filters, self.filter_size)
    self.np_precision = np.matmul(
        self.np_chol_precision, self.np_chol_precision.transpose([0, 2, 1]))
    self._create_np_precision_cholesky()

    # For sampling with the network, for the covariance it should be equivalent
    # to sampling with the sqrt(covariance), and for the precision it will
    # default to sampling with the cholesky
    self.np_precision_net_sample_matrix = self.np_chol_precision
    self.np_covariance_net_sample_matrix = self.np_covariance_chol_sample_matrix
def __init__(self,
             loc,
             weights_precision,
             filters_precision,
             log_diag_chol_precision,
             sample_shape,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalCholFilters"):
    """
    Multivariate normal distribution for gray-scale images.

    Assumes a batch of images with shape [batch, img_w, img_h, 1].
    It models the distribution as N(mu, inv(L L.T)), where L is the Cholesky
    decomposition of the inverse of the covariance matrix (the precision).

    :param loc: The mean of the distribution [batch, img_w * img_h]
    :param weights_precision: Weight factors [batch, img_w, img_h, nb]
    :param filters_precision: Basis matrix (optionally it can be None) [nb, fs, fs, 1, 1]
    :param log_diag_chol_precision: The log values of the diagonal of L [batch, img_w * img_h]
        NOTE(review): the first use example below passes a [batch, img_w, img_h, 1]
        tensor instead; confirm which shape PrecisionConvCholFilters expects.
    :param sample_shape: A list or tensor indicating the shape [batch, img_w, img_h, 1]
    :param validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
    :param allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
    :param name: Python `str` name prefixed to Ops created by this class.

    There are two modes of operation:

    1) Without basis functions, selected by filters_precision = None.
    Internally an identity matrix is used as filters_precision.
    nb must be a perfect square, and weights_precision must satisfy
        weights_precision[..., 0:nb2] = 0
        weights_precision[..., nb2] must be positive
    where nb2 = nb // 2.
    Example of the sparsity pattern for nb = 9, looking at a slice [0, 0, :, :]
    | 0 0 0 0 d x x x x|
    | 0 0 0 0 d x x x x|
    ...
    | 0 0 0 0 d x x x x|
    where 'd' must be positive.

    Use example
        batch = 10
        img_w, img_h = 5, 5
        fs = 3
        nb2 = (fs**2) // 2
        loc = tf.zeros((batch, img_w * img_h))
        zeros = tf.zeros((batch, img_w, img_h, nb2))
        weights_precision_right = tf.random_normal((batch, img_w, img_h, nb2))
        log_diag_chol_precision = tf.random_normal((batch, img_w, img_h, 1))
        diag_chol_precision = tf.exp(log_diag_chol_precision)
        weights_precision = tf.concat([zeros, diag_chol_precision, weights_precision_right], axis=3)
        mvg_dist = MultivariateNormalPrecCholFilters(loc, weights_precision, None,
                                                     log_diag_chol_precision,
                                                     (batch, img_w, img_h, 1))

    2) With a basis matrix, where weights_precision and filters_precision are given.
    weights_precision must be positive. In filters_precision, the top half and
    the left half of the centre row must be zero, and the centre values must be
    positive.
    Example for fs = 3, looking at a slice [0, :, :, 0, 0]
    | 0 0 0 |
    | 0 d x |
    | x x x |

    Use example
        batch = 10
        img_w, img_h = 5, 5
        fs = 3
        nb = 4
        fs2 = (fs ** 2) // 2
        loc = tf.zeros((batch, img_w * img_h))
        log_weights_precision = tf.random_normal((batch, img_w, img_h, nb))
        weights_precision = tf.exp(log_weights_precision)
        left_filters = tf.zeros((nb, fs2, 1, 1))
        log_center_filters = tf.random_normal((nb, 1, 1, 1))
        right_filters = tf.random_normal((nb, fs2, 1, 1))
        center_filters = tf.exp(log_center_filters)
        filters_precision = tf.concat([left_filters, center_filters, right_filters], axis=1)
        filters_precision = tf.reshape(filters_precision, (nb, fs, fs, 1, 1))
        log_center_filters = tf.reshape(log_center_filters, (1, 1, 1, -1))
        log_diag_chol_precision = tf.reduce_logsumexp(log_center_filters + log_weights_precision, axis=3)
        log_diag_chol_precision = tf.reshape(log_diag_chol_precision, (batch, img_w * img_h))
        mvg_dist = MultivariateNormalPrecCholFilters(loc, weights_precision, filters_precision,
                                                     log_diag_chol_precision,
                                                     (batch, img_w, img_h, 1))

    In all cases, positiveness can be enforced by using the exp operation.

    TODO: Add operations to validate args
    """
    # Snapshot the constructor arguments into a fresh dict. In a function frame,
    # locals() can return the frame's own mapping. Later introspection (tracers,
    # debuggers, another locals() call) may then add names bound below to it:
    # cov_obj, the converted tensors, and even `parameters` itself.
    parameters = dict(locals())
    cov_obj = None
    with tf.name_scope(name=name):
        weights_precision = tf.convert_to_tensor(weights_precision)
        log_diag_chol_precision = tf.convert_to_tensor(log_diag_chol_precision)
        if filters_precision is not None:
            filters_precision = tf.convert_to_tensor(filters_precision)

        cov_obj = cov_rep.PrecisionConvCholFilters(
            weights_precision=weights_precision,
            filters_precision=filters_precision,
            sample_shape=sample_shape)
        cov_obj.log_diag_chol_precision = log_diag_chol_precision

    super().__init__(loc=loc,
                     cov_obj=cov_obj,
                     validate_args=validate_args,
                     allow_nan_stats=allow_nan_stats,
                     name=name)
    self._parameters = parameters