예제 #1
0
    def test_neg_log_likelihood_mv_gaussian_conv_filters_chol(self):
        """Check the Gaussian NLL computed through PrecisionConvCholFilters.

        The per-sample negative log-likelihood obtained from the convolutional
        Cholesky-filter representation is compared with the value given by
        the dense `MultivariateNormalFullCovariance` reference distribution.
        """
        params = self._random_normal_params(cov_rep.PrecisionConvCholFilters)
        x, mu, covariance, weights, filters, log_diag = params

        # Reference value from the dense full-covariance distribution
        reference_dist = tfd.MultivariateNormalFullCovariance(
            loc=mu, covariance_matrix=covariance)
        expected_nll = -reference_dist.log_prob(x)

        # The flat features are viewed as a square, single-channel image
        side = int(np.sqrt(self.features_size))
        image_shape = (self.batch_size, side, side, 1)

        precision = cov_rep.PrecisionConvCholFilters(
            weights_precision=tf.convert_to_tensor(weights),
            filters_precision=tf.convert_to_tensor(filters),
            sample_shape=image_shape)
        precision.log_diag_chol_precision = log_diag

        # Residual in flat form for the loss and in image form for the
        # covariance object
        residual = tf.convert_to_tensor(x - mu)
        residual_img = tf.reshape(residual, image_shape)

        quadratic_form = precision.x_precision_x(residual_img)
        log_det = precision.log_det_covariance()
        actual_nll = neg_log_likelihood_mv_gaussian(
            residual,
            x_precision_x=quadratic_form,
            log_det_cov=log_det,
            mean_batch=False)

        self._asset_allclose_tf_feed(actual_nll, expected_nll)
예제 #2
0
    def test_kl_divergence_mv_gaussian_conv_filters_chol(self):
        """Check the KL divergence between two PrecisionConvCholFilters Gaussians.

        The result of `kl_divergence_mv_gaussian_v2` is compared with the
        TensorFlow reference computed on dense full-covariance distributions,
        first per batch element (`mean_batch=False`) and then with the
        function's default batch handling against the batch mean of the
        reference.
        """
        first = self._random_normal_params(cov_rep.PrecisionConvCholFilters)
        _, mu1, dense_cov1, weights1, filters1, log_diag1 = first
        second = self._random_normal_params(cov_rep.PrecisionConvCholFilters)
        _, mu2, dense_cov2, weights2, filters2, log_diag2 = second

        # Reference KL divergence from the dense distributions
        reference1 = dist.MultivariateNormalFullCovariance(
            loc=mu1, covariance_matrix=dense_cov1)
        reference2 = dist.MultivariateNormalFullCovariance(
            loc=mu2, covariance_matrix=dense_cov2)
        expected_kl = dist.kl_divergence(reference1, reference2)

        mu1_tf, weights1_tf, filters1_tf = self._convert_to_tensor(
            mu1, weights1, filters1)
        mu2_tf, weights2_tf, filters2_tf = self._convert_to_tensor(
            mu2, weights2, filters2)

        # The flat features are viewed as a square, single-channel image
        side = int(np.sqrt(self.features_size))
        image_shape = (self.batch_size, side, side, 1)

        def build_precision(weights, filters, log_diag):
            # Wrap the weights/filters pair and attach the known log-diagonal
            cov = cov_rep.PrecisionConvCholFilters(weights_precision=weights,
                                                   filters_precision=filters,
                                                   sample_shape=image_shape)
            cov.log_diag_chol_precision = log_diag
            return cov

        sigma1 = build_precision(weights1_tf, filters1_tf, log_diag1)
        sigma2 = build_precision(weights2_tf, filters2_tf, log_diag2)

        kl_args = dict(sigma1=sigma1, sigma2=sigma2, mu1=mu1_tf, mu2=mu2_tf)

        # Per-sample comparison
        actual_kl = kl_divergence_mv_gaussian_v2(mean_batch=False, **kl_args)
        self._asset_allclose_tf_feed(expected_kl, actual_kl)

        # Comparison against the batch mean, using the default batch handling
        expected_kl = tf.reduce_mean(expected_kl)
        actual_kl = kl_divergence_mv_gaussian_v2(**kl_args)
        self._asset_allclose_tf_feed(expected_kl, actual_kl)
예제 #3
0
    def _sample_n_sparse(self, n, kw, seed=None):
        """Draw a single sample as a sparse `PrecisionConvCholFilters` object.

        The sample is built per pixel as `[zeros | sqrt(gamma) | normal]`
        along the last axis, scaled by the square root of the diagonal scale,
        and wrapped with an identity basis so that the weights alone define
        the sampled matrix.

        :param n: Number of samples, only `n == 1` is supported
        :param kw: Kernel width, the number of basis functions is `kw**2`
            (assumes `kw` is odd so the pieces add up to `kw**2` channels,
            TODO confirm against callers)
        :param seed: Optional seed forwarded to the random ops
        :return: A `cov_rep.PrecisionConvCholFilters` holding the sample
        """
        assert n == 1

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape
        iw = int(np.sqrt(event_shape[0].value))  # Image width
        nb = kw**2  # Number of basis
        nb_half = nb // 2 + 1
        nch = 1  # Number of channels in the image

        stream = seed_stream.SeedStream(seed=seed, salt="Wishart")

        shape = tf.concat([batch_shape, [iw, iw, nb_half - 1]], 0)

        # Random sample for the off diagonal values as a dense tensor
        x_right = tf.random_normal(shape=shape,
                                   dtype=self.dtype,
                                   seed=stream())

        # The upper triangular values needed to get a square kernel per pixel.
        # The dtype must be given explicitly: tf.zeros defaults to float32 and
        # the concatenation below fails for any other distribution dtype.
        x_left = tf.zeros(shape, dtype=self.dtype)

        # Random sample for the diagonal of the matrix
        x_diag = tf.random_gamma(shape=[n],
                                 alpha=self._multi_gamma_sequence(
                                     0.5 * self.df, self.p),
                                 beta=0.5,
                                 dtype=self.dtype,
                                 seed=stream())

        # Concatenate the diagonal and off-diagonal elements
        x_diag = tf.reshape(x_diag, (-1, iw, iw, nch))
        x = tf.concat([tf.sqrt(x_diag), x_right], axis=3)

        # Scale the sampled matrix using the distribution Scale matrix
        diag_scale = tf.exp(self.log_diag_scale)
        diag_scale = tf.reshape(diag_scale, (-1, iw, iw, nch))
        x *= tf.sqrt(diag_scale)  # Square root is equivalent to Cholesky

        # Concatenate with the zeros
        x = tf.concat([x_left, x], axis=3)

        # Create identity basis so that the sampled matrix is only defined by x, if this were not the case
        # we would have to do some optimization to find the basis and weights that reconstruct x.
        # It uses the distribution dtype so it matches the sampled weights.
        identity_basis = tf.eye(num_rows=nb, dtype=self.dtype)
        identity_basis = tf.reshape(identity_basis, (nb, kw, kw, nch, nch))

        sample_shape = tf.concat([batch_shape, [iw, iw, nch]], axis=0)
        x_sparse = cov_rep.PrecisionConvCholFilters(
            weights_precision=x,
            filters_precision=identity_basis,
            sample_shape=sample_shape)

        return x_sparse
예제 #4
0
    def _create_covariance_instance(self):
        """Create the covariance object without passing explicit filters.

        The parent setup runs first, then `self.cov_object` is replaced by an
        instance built with `filters_precision=None` so that the model is
        responsible for building the filters.
        """
        super()._create_covariance_instance()

        sample_shape = (self.batch_size, self.img_size, self.img_size,
                        self.num_ch)

        # No filters are given on purpose, the model has to build them
        constructor_args = dict(weights_precision=self.tf_weights,
                                filters_precision=None,
                                sample_shape=sample_shape,
                                inversion_method=self.inversion_method)
        self.cov_object = cov_rep.PrecisionConvCholFilters(**constructor_args)
예제 #5
0
    def _create_single_sqrt_wishart_pair(self, add_sparse_gamma=False):
        """Build matching reference and custom square root Wishart distributions.

        :param add_sparse_gamma: When True a `SparseSqrtGammaGaussian` with the
            same parameters is also created and appended to the result
        :return: Tuple `(x, x_cov_obj, sqrt_wishart_tfd, sqrt_gamma_gaussian)`,
            followed by `sparse_sqrt_gamma_gaussian` if `add_sparse_gamma` is
            True. `x` is the Cholesky factor of the inverse of a random
            covariance and `x_cov_obj` is the `PrecisionConvCholFilters`
            built from the weights and basis returned with that covariance.
        """
        np_dtype = self.dtype.as_numpy_dtype
        num_features = self.features_size

        # Random positive diagonal used as the scale matrix of the Wishart
        scale_diag = np.abs(
            np.random.normal(size=(self.batch_size, num_features)))
        scale_diag = scale_diag.astype(np_dtype)

        # Dense diagonal scale matrices, one per batch element
        scale = np.zeros(shape=(self.batch_size, num_features, num_features),
                         dtype=np_dtype)
        diag_idx = np.arange(num_features)
        scale[:, diag_idx, diag_idx] = scale_diag
        log_scale_diag = np.log(scale_diag)

        # Random degrees of freedom, whose values must be larger than features_size
        df = np.random.uniform(low=num_features,
                               high=num_features * 10,
                               size=self.batch_size).astype(np_dtype)

        # Reference: square root Wishart obtained by transforming a Wishart
        # with the inverse of the Cholesky outer product bijector
        wishart = tfd.Wishart(scale=scale, df=df)
        to_cholesky = tfb.Invert(tfb.CholeskyOuterProduct())
        sqrt_wishart_tfd = tfd.TransformedDistribution(distribution=wishart,
                                                       bijector=to_cholesky)

        # Custom distributions created with the same parameters
        sqrt_gamma_gaussian = SqrtGammaGaussian(df=df,
                                                log_diag_scale=log_scale_diag)
        result = (sqrt_wishart_tfd, sqrt_gamma_gaussian)
        if add_sparse_gamma:
            result += (SparseSqrtGammaGaussian(df=df,
                                               log_diag_scale=log_scale_diag), )

        # Random Cholesky matrix used to evaluate the probability density functions
        params = self._random_normal_params(cov_rep.PrecisionConvCholFilters)
        x_covariance, x_weights, x_basis, log_diag = params[2:]
        x = np.linalg.cholesky(np.linalg.inv(x_covariance))

        # The custom square root Wishart works with PrecisionConvCholFilters,
        # it measures the pdf of the Cholesky of the precision
        side = int(np.sqrt(num_features))
        sample_shape = tf.TensorShape((self.batch_size, side, side, 1))
        x_cov_obj = cov_rep.PrecisionConvCholFilters(
            weights_precision=tf.constant(x_weights),
            filters_precision=tf.constant(x_basis),
            sample_shape=sample_shape)
        x_cov_obj.log_diag_chol_precision = log_diag

        return (x, x_cov_obj) + result
예제 #6
0
    def _create_covariance_instance(self):
        """Create the covariance object under test and its NumPy references.

        Random weights and basis are generated, the `PrecisionConvCholFilters`
        object is built from `self.tf_weights` and `self.tf_basis`, and the
        equivalent NumPy filters, Cholesky of the precision and precision
        matrices are stored on the test instance.
        """
        sample_shape = (self.batch_size, self.img_size, self.img_size,
                        self.num_ch)

        self.np_weights = self._create_random_weights()
        self.np_basis = self._create_random_basis()

        self.equivalent_sample_method = cov_rep.SampleMethod.CHOLESKY

        constructor_args = dict(weights_precision=self.tf_weights,
                                filters_precision=self.tf_basis,
                                sample_shape=sample_shape,
                                inversion_method=self.inversion_method)
        self.cov_object = cov_rep.PrecisionConvCholFilters(**constructor_args)

        # The filters must be computed before the basis is reshaped below
        self.np_filters = self._filters_from_weights_basis(
            self.np_weights, self.np_basis, self.filter_size)

        # The center tap of each filter is read as the diagonal of the
        # Cholesky of the precision
        center_idx = (self.filter_size**2) // 2
        diag_chol_precision = self.np_filters[:, :, center_idx]
        self.np_log_diag_chol_precision = np.log(diag_chol_precision)

        # Append the two channel axes to the basis
        full_basis_shape = self.basis_shape + (self.num_ch, self.num_ch)
        self.np_basis = np.reshape(self.np_basis, full_basis_shape)

        self.np_chol_precision = self._matrix_from_filters(
            self.np_filters, self.filter_size)

        # precision = L L^T, computed per batch element
        chol = self.np_chol_precision
        self.np_precision = np.matmul(chol, np.transpose(chol, axes=(0, 2, 1)))

        self._create_np_precision_cholesky()

        # For sampling with the network, for the covariance it should be equivalent to sampling with the
        # sqrt(covariance), and for the precision it will default to sampling with the cholesky
        self.np_precision_net_sample_matrix = self.np_chol_precision
        self.np_covariance_net_sample_matrix = self.np_covariance_chol_sample_matrix
예제 #7
0
파일: mvg.py 프로젝트: suyanzhou626/tf_mvg
    def __init__(self,
                 loc,
                 weights_precision,
                 filters_precision,
                 log_diag_chol_precision,
                 sample_shape,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalCholFilters"):
        """
        Multivariate normal distribution for gray-scale images. Assumes a batch of images
            with shape [batch, img_w, img_h, 1]

            It models the distribution as N(mu, inv(L L.T)), where L is the Cholesky decomposition of the
            inverse of the covariance matrix.

        :param loc: The mean of the distribution [batch, img_w * img_h]
        :param weights_precision: Weight factors [batch, img_w, img_h, nb]
        :param filters_precision: Basis matrix (optionally it can be None) [nb, fs, fs, 1, 1]
        :param log_diag_chol_precision: The log values of the diagonal of L [batch, img_w * img_h]
        :param sample_shape:  A list or tensor indicating the shape [batch, img_w, img_h, 1]
        :param validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        :param allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        :param name: Python `str` name prefixed to Ops created by this class.

        There are two modes of operation

        1) Without basis functions, which is set by filters_precision = None. This internally creates a
        filters_precision of identity matrix

        nb must be a squared number and weights precision must follow
        weights_precision[..., 0:nb2] = 0
        weights_precision[..., nb2] must be positive
         where nb2 = nb // 2

        Example of sparsity pattern for nb = 9, and looking at a slice [0, 0, :, :]

        | 0 0 0 0 d x x x x|
        | 0 0 0 0 d x x x x|
                ...
        | 0 0 0 0 d x x x x|

        where 'd' must be positive.

        Use example

            batch = 10
            img_w, img_h = 5, 5
            fs = 3
            nb2 = (fs**2) // 2

            loc = tf.zeros((batch, img_w * img_h))

            zeros = tf.zeros((batch, img_w, img_h, nb2))
            weights_precision_right = tf.random_normal((batch, img_w, img_h, nb2))
            log_diag_chol_precision = tf.random_normal((batch, img_w, img_h, 1))

            diag_chol_precision = tf.exp(log_diag_chol_precision)

            weights_precision = tf.concat([zeros, diag_chol_precision, weights_precision_right], axis=3)

            # Flatten to the documented [batch, img_w * img_h] shape
            log_diag_chol_precision = tf.reshape(log_diag_chol_precision, (batch, img_w * img_h))

            mvg_dist = MultivariateNormalPrecCholFilters(loc, weights_precision, None, log_diag_chol_precision,
                                                        (batch, img_w, img_h, 1))

        2) With a basis matrix, where weights_precision and filters_precision are given

        weights_precision must be positive

        filters_precision top half and left half of the center row must be zero
        and the center values must be positive.

        Example for fs = 3, and looking at a slice [0, :, :, 0, 0]

        | 0 0 0 |
        | 0 d x |
        | x x x |


        Use example

            batch = 10
            img_w, img_h = 5, 5
            fs = 3
            nb = 4
            fs2 = (fs ** 2) // 2

            loc = tf.zeros((batch, img_w * img_h))

            log_weights_precision = tf.random_normal((batch, img_w, img_h, nb))
            weights_precision = tf.exp(log_weights_precision)

            left_filters = tf.zeros((nb, fs2, 1, 1))
            log_center_filters = tf.random_normal((nb, 1, 1, 1))
            right_filters = tf.random_normal((nb, fs2, 1, 1))

            center_filters = tf.exp(log_center_filters)

            filters_precision = tf.concat([left_filters, center_filters, right_filters], axis=1)
            filters_precision = tf.reshape(filters_precision, (nb, fs, fs, 1, 1))

            log_center_filters = tf.reshape(log_center_filters, (1, 1, 1, -1))

            log_diag_chol_precision = tf.reduce_logsumexp(log_center_filters + log_weights_precision, axis=3)
            log_diag_chol_precision = tf.reshape(log_diag_chol_precision, (batch, img_w * img_h))

            mvg_dist = MultivariateNormalPrecCholFilters(loc, weights_precision, filters_precision,
                                                         log_diag_chol_precision, (batch, img_w, img_h, 1))

        Enforcing positiveness could be done in all cases by employing the exp operation.

        TODO: Add operations to validate args
        """
        # Snapshot of the constructor arguments. A copy is taken because the
        # mapping returned by locals() may be refreshed later (e.g. while
        # debugging) and would then include the locals created below.
        parameters = dict(locals())

        with tf.name_scope(name=name):
            weights_precision = tf.convert_to_tensor(weights_precision)
            log_diag_chol_precision = tf.convert_to_tensor(
                log_diag_chol_precision)
            # None selects the mode without basis functions, so it is passed through as is
            if filters_precision is not None:
                filters_precision = tf.convert_to_tensor(filters_precision)

            cov_obj = cov_rep.PrecisionConvCholFilters(
                weights_precision=weights_precision,
                filters_precision=filters_precision,
                sample_shape=sample_shape)
            # The log-diagonal is given by the caller and set directly on the object
            cov_obj.log_diag_chol_precision = log_diag_chol_precision

        super().__init__(loc=loc,
                         cov_obj=cov_obj,
                         validate_args=validate_args,
                         allow_nan_stats=allow_nan_stats,
                         name=name)
        # Record the arguments of this constructor after the parent has run
        self._parameters = parameters