예제 #1
0
    def testMVNConjugateLinearUpdateSupportsBatchShape(self):
        """Checks `mvn_conjugate_linear_update` with batched likelihood inputs.

        The prior (`prior_mean`, `prior_scale`) is unbatched. The linear
        transformation, likelihood scale, and observation all carry a leading
        `batch_shape` of `[3, 1]`. The returned posterior mean and precision
        are validated by `self._mvn_linear_update_test_helper`.
        """
        strm = test_util.test_seed_stream()
        num_latents = 2
        num_outputs = 4
        batch_shape = [3, 1]
        # Number of entries in the lower triangle of a
        # `num_outputs x num_outputs` matrix. `n * (n + 1)` is always even, so
        # integer floor division is exact; there is no float round-trip.
        num_tril_entries = num_outputs * (num_outputs + 1) // 2

        prior_mean = tf.ones([num_latents])
        prior_scale = tf.eye(num_latents) * 5.
        # Random batch of lower-triangular likelihood scales, built by pushing
        # an unconstrained random vector through the `ScaleTriL` bijector.
        likelihood_scale = tf.linalg.LinearOperatorLowerTriangular(
            tfb.ScaleTriL().forward(
                tf.random.normal(shape=batch_shape + [num_tril_entries],
                                 seed=strm())))
        linear_transformation = tf.random.normal(
            batch_shape + [num_outputs, num_latents], seed=strm()) * 5.
        true_latent = tf.random.normal(batch_shape + [num_latents],
                                       seed=strm())
        # Observations are generated without added noise from a random latent.
        observation = tf.linalg.matvec(linear_transformation, true_latent)
        posterior_mean, posterior_prec = tfd.mvn_conjugate_linear_update(
            prior_mean=prior_mean,
            prior_scale=prior_scale,
            linear_transformation=linear_transformation,
            likelihood_scale=likelihood_scale,
            observation=observation)

        # The helper takes dense matrices, so densify the LinearOperators.
        self._mvn_linear_update_test_helper(
            prior_mean=prior_mean,
            prior_scale=prior_scale,
            linear_transformation=linear_transformation,
            likelihood_scale=likelihood_scale.to_dense(),
            observation=observation,
            candidate_posterior_mean=posterior_mean,
            candidate_posterior_prec=posterior_prec.to_dense())
예제 #2
0
    def testMVNConjugateLinearUpdatePreservesStructuredLinops(self):
        """Checks `mvn_conjugate_linear_update` on structured LinearOperators.

        The prior scale and likelihood scale are scaled identities, and the
        linear transformation is the identity. Each coordinate is therefore an
        independent scalar Normal-Normal update. The posterior is checked
        twice:

        * against `self._mvn_linear_update_test_helper` using dense matrices;
        * against `tfd.normal_conjugates_known_scale_posterior`, applied
          elementwise.
        """
        seed_stream = test_util.test_seed_stream()
        dim = 4

        prior_scale = tf.linalg.LinearOperatorScaledIdentity(dim, 4.)
        likelihood_scale = tf.linalg.LinearOperatorScaledIdentity(dim, 0.2)
        linear_transformation = tf.linalg.LinearOperatorIdentity(dim)
        observation = tf.random.normal([dim], seed=seed_stream())

        # `prior_mean` is not passed here. Both reference checks below use a
        # zero prior mean, so this relies on the default being zero.
        posterior_mean, posterior_prec = tfd.mvn_conjugate_linear_update(
            prior_scale=prior_scale,
            linear_transformation=linear_transformation,
            likelihood_scale=likelihood_scale,
            observation=observation)
        # TODO(davmre): enable next line once internal CI is updated to recent TF.
        # self.assertIsInstance(posterior_prec,
        #                       tf.linalg.LinearOperatorScaledIdentity)

        dense_posterior_prec = posterior_prec.to_dense()

        # Reference check 1: the general dense-matrix helper.
        self._mvn_linear_update_test_helper(
            prior_mean=tf.zeros([dim]),
            prior_scale=prior_scale.to_dense(),
            linear_transformation=linear_transformation.to_dense(),
            likelihood_scale=likelihood_scale.to_dense(),
            observation=observation,
            candidate_posterior_mean=posterior_mean,
            candidate_posterior_prec=dense_posterior_prec)

        # Reference check 2: the univariate closed form, one coordinate at a
        # time. Its precision is the reciprocal of the per-coordinate variance,
        # placed on a diagonal.
        scalar_posterior = tfd.normal_conjugates_known_scale_posterior(
            prior=tfd.Normal(loc=0., scale=prior_scale.diag_part()),
            scale=likelihood_scale.diag_part(),
            s=observation,
            n=1)
        scalar_posterior_mean = scalar_posterior.mean()
        scalar_posterior_prec = tf.linalg.diag(1. / scalar_posterior.variance())

        actual_mean, actual_prec, expected_mean, expected_prec = self.evaluate(
            (posterior_mean, dense_posterior_prec,
             scalar_posterior_mean, scalar_posterior_prec))
        self.assertAllClose(actual_mean, expected_mean)
        self.assertAllClose(actual_prec, expected_prec)