Пример #1
0
    def test_kl_divergence_mv_gaussian_v2_full(self):
        """Compare full-covariance KL against TensorFlow's reference.

        Two random Gaussians are drawn. ``KL(p_a || p_b)`` is computed with
        ``dist.MultivariateNormalFullCovariance`` and with
        ``kl_divergence_mv_gaussian_v2`` on ``cov_rep.CovarianceFull``
        wrappers. The two are compared per sample (``mean_batch=False``), and
        then the default call is compared with the batch mean of the
        reference.
        """
        _, mean_a, cov_a = self._random_normal_params(cov_rep.CovarianceFull)
        _, mean_b, cov_b = self._random_normal_params(cov_rep.CovarianceFull)

        # Reference per-sample KL computed by TensorFlow's distributions.
        reference_kl = dist.kl_divergence(
            dist.MultivariateNormalFullCovariance(loc=mean_a,
                                                  covariance_matrix=cov_a),
            dist.MultivariateNormalFullCovariance(loc=mean_b,
                                                  covariance_matrix=cov_b))

        # The same parameters, wrapped in the project's covariance objects.
        mean_a_t, mean_b_t = self._convert_to_tensor(mean_a, mean_b)
        kl_args = {
            'sigma1': cov_rep.CovarianceFull(
                covariance=tf.convert_to_tensor(cov_a)),
            'sigma2': cov_rep.CovarianceFull(
                covariance=tf.convert_to_tensor(cov_b)),
            'mu1': mean_a_t,
            'mu2': mean_b_t,
        }

        # Without batch reduction the result must match sample-by-sample.
        self._asset_allclose_tf_feed(
            reference_kl,
            kl_divergence_mv_gaussian_v2(mean_batch=False, **kl_args))

        # The default call is compared against the batch mean.
        reference_mean = tf.reduce_mean(reference_kl)
        self._asset_allclose_tf_feed(reference_mean,
                                     kl_divergence_mv_gaussian_v2(**kl_args))
Пример #2
0
    def test_kl_divergence_mv_gaussian_v2_diag(self):
        """Compare diagonal-covariance KL against TensorFlow's reference.

        Two random diagonal Gaussians are drawn. ``KL(p1 || p2)`` is computed
        with ``dist.MultivariateNormalDiag`` and with
        ``kl_divergence_mv_gaussian_v2`` on ``cov_rep.CovarianceDiag``
        objects. The two are compared per sample (``mean_batch=False``), and
        then the default call is compared with the batch mean of the
        reference.
        """
        # sigma_sq1 / sigma_sq2 hold the diagonal variances: the TF reference
        # takes their square root as the scale (standard deviation).
        _, mu1, sigma_sq1 = self._random_normal_params(cov_rep.CovarianceDiag)
        _, mu2, sigma_sq2 = self._random_normal_params(cov_rep.CovarianceDiag)

        tf_mvnd1 = dist.MultivariateNormalDiag(loc=mu1,
                                               scale_diag=np.sqrt(sigma_sq1))
        tf_mvnd2 = dist.MultivariateNormalDiag(loc=mu2,
                                               scale_diag=np.sqrt(sigma_sq2))

        tf_kldiv = dist.kl_divergence(tf_mvnd1, tf_mvnd2)

        mu1_tf, mu2_tf = self._convert_to_tensor(mu1, mu2)
        # Use distinct names for the covariance objects (matching covar2)
        # instead of rebinding the raw variance array sigma_sq1.
        covar1 = cov_rep.CovarianceDiag(log_diag_covariance=tf.log(sigma_sq1))
        covar2 = cov_rep.CovarianceDiag(log_diag_covariance=tf.log(sigma_sq2))

        # Per-sample comparison (no batch reduction).
        covar_kldiv = kl_divergence_mv_gaussian_v2(sigma1=covar1,
                                                   sigma2=covar2,
                                                   mu1=mu1_tf,
                                                   mu2=mu2_tf,
                                                   mean_batch=False)

        self._asset_allclose_tf_feed(tf_kldiv, covar_kldiv)

        # The default call is compared against the batch mean.
        tf_kldiv = tf.reduce_mean(tf_kldiv)
        covar_kldiv = kl_divergence_mv_gaussian_v2(sigma1=covar1,
                                                   sigma2=covar2,
                                                   mu1=mu1_tf,
                                                   mu2=mu2_tf)

        self._asset_allclose_tf_feed(tf_kldiv, covar_kldiv)
Пример #3
0
def _kl_mvnd_mvnd(a, b, name=None):
    """Return the batched KL divergence ``KL(a || b)`` for multivariate Normals.

    Args:
        a: Left distribution. It must expose ``loc`` (mean) and ``cov_obj``
            (a covariance-representation object).
        b: Right distribution, with the same attributes as ``a``.
        name: Optional name forwarded to ``kl_divergence_mv_gaussian_v2``.

    Returns:
        The result of ``kl_divergence_mv_gaussian_v2`` called with
        ``mean_batch=False``, so no averaging over the batch is requested.
    """
    mean_a, mean_b = a.loc, b.loc
    sigma_a, sigma_b = a.cov_obj, b.cov_obj
    return kl_divergence_mv_gaussian_v2(mu1=mean_a, mu2=mean_b,
                                        sigma1=sigma_a, sigma2=sigma_b,
                                        mean_batch=False, name=name)
Пример #4
0
    def test_kl_divergence_mv_gaussian_conv_filters_chol(self):
        """Compare conv-filter precision KL against TensorFlow's reference.

        Two random Gaussians are parameterised by ``PrecisionConvCholFilters``.
        The reference KL uses the dense covariances returned by
        ``_random_normal_params``. The project KL uses precision objects built
        from the weights, filters and log-diagonal of the Cholesky precision.
        The two are compared per sample and as a batch mean.
        """
        _, mean_a, dense_cov_a, w_a, f_a, log_diag_a = \
            self._random_normal_params(cov_rep.PrecisionConvCholFilters)
        _, mean_b, dense_cov_b, w_b, f_b, log_diag_b = \
            self._random_normal_params(cov_rep.PrecisionConvCholFilters)

        # Reference per-sample KL computed from the dense covariances.
        reference_kl = dist.kl_divergence(
            dist.MultivariateNormalFullCovariance(
                loc=mean_a, covariance_matrix=dense_cov_a),
            dist.MultivariateNormalFullCovariance(
                loc=mean_b, covariance_matrix=dense_cov_b))

        mean_a_t, w_a, f_a = self._convert_to_tensor(mean_a, w_a, f_a)
        mean_b_t, w_b, f_b = self._convert_to_tensor(mean_b, w_b, f_b)

        # Samples are treated as square single-channel images:
        # (batch, side, side, 1).
        side = int(np.sqrt(self.features_size))
        image_shape = (self.batch_size, side, side, 1)

        def build_precision(weights, filters, log_diag):
            # The log-diagonal is assigned after construction, exactly as
            # returned by _random_normal_params (not converted to a tensor).
            prec = cov_rep.PrecisionConvCholFilters(weights_precision=weights,
                                                    filters_precision=filters,
                                                    sample_shape=image_shape)
            prec.log_diag_chol_precision = log_diag
            return prec

        kl_args = {
            'sigma1': build_precision(w_a, f_a, log_diag_a),
            'sigma2': build_precision(w_b, f_b, log_diag_b),
            'mu1': mean_a_t,
            'mu2': mean_b_t,
        }

        # Per-sample comparison (no batch reduction).
        self._asset_allclose_tf_feed(
            reference_kl,
            kl_divergence_mv_gaussian_v2(mean_batch=False, **kl_args))

        # The default call is compared against the batch mean.
        reference_mean = tf.reduce_mean(reference_kl)
        self._asset_allclose_tf_feed(reference_mean,
                                     kl_divergence_mv_gaussian_v2(**kl_args))
Пример #5
0
def _kl_mvnd_tfmvnd(a, b, name=None):
    """Return the batched KL divergence ``KL(a || b)`` when ``a`` is a TF MVN.

    Args:
        a: A ``tf.contrib.distributions.MultivariateNormal*`` distribution.
            Its ``loc`` and dense ``scale`` operator are used.
        b: Right distribution exposing ``loc`` and ``cov_obj``.
        name: Optional name forwarded to ``kl_divergence_mv_gaussian_v2``.

    Returns:
        The result of ``kl_divergence_mv_gaussian_v2`` called with
        ``mean_batch=False``.
    """
    # The dense scale of ``a`` is wrapped as a Cholesky factor of its
    # covariance. NOTE(review): this presumes ``a.scale`` is lower-triangular;
    # a non-triangular scale operator would be misread. Confirm for every
    # MultivariateNormal* variant passed here.
    chol_factor = a.scale.to_dense()
    sigma_a = cov_rep.CovarianceCholesky(chol_covariance=chol_factor)
    return kl_divergence_mv_gaussian_v2(mu1=a.loc, mu2=b.loc,
                                        sigma1=sigma_a, sigma2=b.cov_obj,
                                        mean_batch=False, name=name)