def testMVNConjugateLinearUpdateSupportsBatchShape(self):
  """Runs the MVN conjugate linear update on inputs with batch shape [3, 1].

  The prior mean and prior scale are unbatched, while the likelihood scale,
  linear transformation and observation all carry the batch shape. The
  update's outputs are handed to `_mvn_linear_update_test_helper` for
  checking.
  """
  seed_stream = test_util.test_seed_stream()
  num_latents = 2
  num_outputs = 4
  batch_shape = [3, 1]

  # Unbatched prior.
  prior_mean = tf.ones([num_latents])
  prior_scale = tf.eye(num_latents) * 5.

  # Batched lower-triangular likelihood scale, produced by pushing random
  # vectors through the ScaleTriL bijector. Each vector holds one entry per
  # element of the lower triangle of a [num_outputs, num_outputs] matrix.
  num_tril_entries = num_outputs * (num_outputs + 1) // 2
  scale_tril_bijector = tfb.ScaleTriL()
  unconstrained_tril = tf.random.normal(
      shape=batch_shape + [num_tril_entries], seed=seed_stream())
  likelihood_scale = tf.linalg.LinearOperatorLowerTriangular(
      scale_tril_bijector.forward(unconstrained_tril))

  # Batched linear map, and an observation generated by applying that map to
  # a random latent vector.
  transformation_shape = batch_shape + [num_outputs, num_latents]
  linear_transformation = 5. * tf.random.normal(
      transformation_shape, seed=seed_stream())
  true_latent = tf.random.normal(
      batch_shape + [num_latents], seed=seed_stream())
  observation = tf.linalg.matvec(linear_transformation, true_latent)

  update_result = tfd.mvn_conjugate_linear_update(
      prior_mean=prior_mean,
      prior_scale=prior_scale,
      linear_transformation=linear_transformation,
      likelihood_scale=likelihood_scale,
      observation=observation)
  posterior_mean, posterior_prec = update_result

  # The helper is given dense versions of the linear-operator arguments.
  self._mvn_linear_update_test_helper(
      prior_mean=prior_mean,
      prior_scale=prior_scale,
      linear_transformation=linear_transformation,
      likelihood_scale=likelihood_scale.to_dense(),
      observation=observation,
      candidate_posterior_mean=posterior_mean,
      candidate_posterior_prec=posterior_prec.to_dense())
def testMVNConjugateLinearUpdatePreservesStructuredLinops(self):
  """Runs the MVN conjugate linear update with structured linear operators.

  The prior scale and likelihood scale are scaled-identity operators and the
  linear transformation is an identity operator. The result is checked twice:
  once through `_mvn_linear_update_test_helper` on dense versions of the
  inputs, and once against the scalar normal conjugate posterior.
  """
  seed_stream = test_util.test_seed_stream()
  num_outputs = 4

  prior_scale = tf.linalg.LinearOperatorScaledIdentity(num_outputs, 4.)
  likelihood_scale = tf.linalg.LinearOperatorScaledIdentity(num_outputs, 0.2)
  linear_transformation = tf.linalg.LinearOperatorIdentity(num_outputs)
  observation = tf.random.normal([num_outputs], seed=seed_stream())

  # No `prior_mean` is passed here; the helper below is given a zero mean.
  update_result = tfd.mvn_conjugate_linear_update(
      prior_scale=prior_scale,
      linear_transformation=linear_transformation,
      likelihood_scale=likelihood_scale,
      observation=observation)
  posterior_mean, posterior_prec = update_result

  # TODO(davmre): enable next line once internal CI is updated to recent TF.
  # self.assertIsInstance(posterior_prec,
  #                       tf.linalg.LinearOperatorScaledIdentity)

  self._mvn_linear_update_test_helper(
      prior_mean=tf.zeros([num_outputs]),
      prior_scale=prior_scale.to_dense(),
      linear_transformation=linear_transformation.to_dense(),
      likelihood_scale=likelihood_scale.to_dense(),
      observation=observation,
      candidate_posterior_mean=posterior_mean,
      candidate_posterior_prec=posterior_prec.to_dense())

  # Also check the result against the scalar calculation.
  scalar_prior = tfd.Normal(loc=0., scale=prior_scale.diag_part())
  scalar_posterior_dist = tfd.normal_conjugates_known_scale_posterior(
      prior=scalar_prior,
      scale=likelihood_scale.diag_part(),
      s=observation,
      n=1)

  # Evaluate all four tensors together in one call. The scalar posterior's
  # precision is expanded into a diagonal matrix so it can be compared with
  # the dense posterior precision.
  tensors_to_evaluate = (
      posterior_mean,
      posterior_prec.to_dense(),
      scalar_posterior_dist.mean(),
      tf.linalg.diag(1. / scalar_posterior_dist.variance()),
  )
  evaluated = self.evaluate(tensors_to_evaluate)
  (posterior_mean_, posterior_prec_,
   scalar_posterior_mean_, scalar_posterior_prec_) = evaluated

  self.assertAllClose(posterior_mean_, scalar_posterior_mean_)
  self.assertAllClose(posterior_prec_, scalar_posterior_prec_)