def test_multivariate_normal_prob_positive_product_of_components(self):
    """Importance sampling estimates P[X_1 * X_2 > 0] for a 2-D Gaussian.

    The target `p` is a spherical Gaussian centered at (0, 0). By symmetry,
    the product of its two components is positive with probability 1/2.
    Draws come from a wider, shifted proposal `q`.
    """
    num_draws = 1000
    with self.cached_session():
        target = mvn_diag_lib.MultivariateNormalDiag(
            loc=[0.], scale_diag=[1.0, 1.0])
        proposal = mvn_diag_lib.MultivariateNormalDiag(
            loc=[0.5], scale_diag=[3., 3.])

        def indicator(x):
            # sign() of the product lies in {-1, 0, 1}; (s + 1) / 2 maps a
            # positive product to 1 and a negative product to 0.
            product = math_ops.reduce_prod(x, axis=[-1])
            return 0.5 * (math_ops.sign(product) + 1.0)

        estimate = mc.expectation_importance_sampler(
            f=indicator,
            log_p=target.log_prob,
            sampling_dist_q=proposal,
            n=num_draws,
            seed=42)

        # The relative tolerance is about twice the minimum needed to pass
        # at this sample size. With n = 100k the estimate converges to
        # within roughly +-0.004.
        self.assertEqual(target.batch_shape, estimate.get_shape())
        self.assertAllClose(0.5, estimate.eval(), rtol=0.05)
def testSampleConsistentStats(self):
    """Monte Carlo moments and entropy of an Independent MVN match its API.

    Draws 10000 samples from `Independent(MultivariateNormalDiag)`. The
    empirical mean, variance, stddev and entropy are compared against the
    distribution's own `mean()`, `variance()`, `stddev()` and `entropy()`.
    `mode()` is compared against `loc` directly.
    """
    locs = np.float32([[-1., 1], [1, -1]])
    scales = np.float32([1., 0.5])
    num_samples = 10000
    with self.cached_session() as sess:
        ind = independent_lib.Independent(
            distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=locs, scale_identity_multiplier=scales),
            reinterpreted_batch_ndims=1)
        draws = ind.sample(num_samples, seed=42)

        # Empirical statistics, reduced over the sample axis (axis 0).
        emp_mean = math_ops.reduce_mean(draws, axis=0)
        emp_var = math_ops.reduce_mean(
            math_ops.squared_difference(draws, emp_mean), axis=0)
        emp_std = math_ops.sqrt(emp_var)
        # Entropy estimate: negative average log-density of the draws.
        emp_entropy = -math_ops.reduce_mean(ind.log_prob(draws), axis=0)

        empirical = [emp_mean, emp_var, emp_std, emp_entropy]
        analytic = [ind.mean(), ind.variance(), ind.stddev(), ind.entropy()]
        mode = ind.mode()
        # A single run, so every empirical statistic uses the same draws.
        empirical_, analytic_, mode_ = sess.run([empirical, analytic, mode])

        for got, want, rtol in zip(empirical_, analytic_,
                                   (0.02, 0.04, 0.02, 0.01)):
            self.assertAllClose(got, want, rtol=rtol, atol=0.)
        self.assertAllClose(locs, mode_, rtol=1e-6, atol=0.)
def testKLRaises(self):
    """KL between Independent distributions rejects mismatched event shapes."""

    def wrap(distribution, ndims):
        # Reinterpret the `ndims` rightmost batch dims of `distribution` as
        # event dims.
        return independent_lib.Independent(
            distribution=distribution, reinterpreted_batch_ndims=ndims)

    # Event shape [2] versus a scalar event shape: ValueError expected.
    left = wrap(
        normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
        1)
    right = wrap(
        normal_lib.Normal(loc=np.float32(-1), scale=np.float32(0.5)),
        0)
    with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
        kullback_leibler.kl_divergence(left, right)

    # Both wrappers have event shape [2]. One is built from a batch of scalar
    # Normals, the other from an MVN whose own event is [2], so the
    # underlying event shapes differ: NotImplementedError expected.
    left = wrap(
        normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
        1)
    right = wrap(
        mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
        0)
    with self.assertRaisesRegexp(NotImplementedError,
                                 "different event shapes"):
        kullback_leibler.kl_divergence(left, right)
def test_non_vector_shape(self):
    """A scalar `batch_shape` is rejected because it must be a vector."""
    dims = 2
    new_batch_shape = 2  # A scalar, not a vector.
    old_batch_shape = [2]

    if self.is_static_shape:
        batch_shape_arg = constant_op.constant(np.int32(new_batch_shape))
    else:
        batch_shape_arg = array_ops.placeholder_with_default(
            np.int32(new_batch_shape), shape=None)

    scale = np.ones(old_batch_shape + [dims], self.dtype)
    scale_ph = array_ops.placeholder_with_default(
        scale, shape=scale.shape if self.is_static_shape else None)
    mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

    if self.is_static_shape:
        # With static shapes the error surfaces at construction time.
        with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True)
        return

    # With dynamic shapes the error surfaces only when the graph runs.
    with self.cached_session():
        with self.assertRaisesOpError(r".*must be a vector.*"):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True).sample().eval()
def test_non_positive_shape(self):
    """Non-positive `batch_shape` values raise a 'must be >=-1' error."""
    dims = 2
    old_batch_shape = [4]

    if self.is_static_shape:
        # An unknown first dimension does not trigger the size check.
        # Statically, any dimension < 0 is treated as unknown.
        new_batch_shape = [-1, 0]
    else:
        # -2 * -2 = 4, the same size as the old shape.
        new_batch_shape = [-2, -2]

    if self.is_static_shape:
        batch_shape_arg = constant_op.constant(np.int32(new_batch_shape))
    else:
        batch_shape_arg = array_ops.placeholder_with_default(
            np.int32(new_batch_shape), shape=None)

    scale = np.ones(old_batch_shape + [dims], self.dtype)
    scale_ph = array_ops.placeholder_with_default(
        scale, shape=scale.shape if self.is_static_shape else None)
    mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

    if self.is_static_shape:
        # With static shapes the error surfaces at construction time.
        with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True)
        return

    # With dynamic shapes the error surfaces only when the graph runs.
    with self.cached_session():
        with self.assertRaisesOpError(r".*must be >=-1.*"):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True).sample().eval()
def test_bad_reshape_size(self):
    """A `batch_shape` whose size differs from the distribution's is rejected."""
    dims = 2
    new_batch_shape = [2, 3]
    old_batch_shape = [2]  # Size 2, versus 2 * 3 = 6 requested.

    if self.is_static_shape:
        batch_shape_arg = constant_op.constant(np.int32(new_batch_shape))
    else:
        batch_shape_arg = array_ops.placeholder_with_default(
            np.int32(new_batch_shape), shape=None)

    scale = np.ones(old_batch_shape + [dims], self.dtype)
    scale_ph = array_ops.placeholder_with_default(
        scale, shape=scale.shape if self.is_static_shape else None)
    mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

    if self.is_static_shape:
        # With static shapes the error surfaces at construction time.
        with self.assertRaisesRegexp(
                ValueError,
                (r"`batch_shape` size \(6\) must match "
                 r"`distribution\.batch_shape` size \(2\)")):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True)
        return

    # With dynamic shapes the error surfaces only when the graph runs.
    with self.cached_session():
        with self.assertRaisesOpError(r"Shape sizes do not match."):
            batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_arg,
                validate_args=True).sample().eval()
def testSampleConsistentMeanCovariance(self):
    """Checks sample mean/covariance consistency of a two-component MVN mixture.

    The work is delegated to the shared
    `run_test_sample_consistent_mean_covariance` helper.
    """
    with self.cached_session() as sess:
        mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
        components = mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=mixture,
            components_distribution=components)
        self.run_test_sample_consistent_mean_covariance(sess.run, gm)
def testVarianceConsistentCovariance(self):
    """`variance()` of an MVN mixture equals the diagonal of `covariance()`."""
    with self.cached_session() as sess:
        mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
        components = mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=mixture,
            components_distribution=components)
        covariance_, variance_ = sess.run([gm.covariance(), gm.variance()])
        self.assertAllClose(covariance_.diagonal(), variance_, atol=0.)
def testSampleAndLogProbMultivariateShapes(self):
    """Static shapes of samples and log-probs drawn from an MVN mixture."""
    sample_shape = [4, 5]
    with self.cached_session():
        mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
        components = mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=mixture,
            components_distribution=components)
        x = gm.sample(sample_shape, seed=42)
        log_prob_x = gm.log_prob(x)
        # Samples carry the event size (2) after the sample shape; log_prob
        # reduces the event away.
        self.assertEqual(sample_shape + [2], x.shape)
        self.assertEqual(sample_shape, log_prob_x.shape)
def testKLMultivariateToMultivariate(self):
    """KL of Independent-wrapped MVNs equals the summed per-batch MVN KLs."""

    def make_pair(loc, scale_diag):
        # `loc` has shape (1, 1, 2, 3): a (1, 1, 2) batch of 3-dim MVNDiags.
        mvn = mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32(loc), scale_diag=np.float32(scale_diag))
        # Fold the two rightmost batch dims into the event.
        wrapped = independent_lib.Independent(
            distribution=mvn, reinterpreted_batch_ndims=2)
        return mvn, wrapped

    mvn1, ind1 = make_pair(
        [[[[-1., 1, 3.], [2., 4., 3.]]]],
        [[[0.2, 0.1, 5.], [2., 3., 4.]]])
    mvn2, ind2 = make_pair(
        [[[[-2., 3, 2.], [1., 3., 2.]]]],
        [[[0.1, 0.5, 3.], [1., 2., 1.]]])

    mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
    # Summing the per-MVN KLs over the two reinterpreted dims should
    # reproduce the Independent KL.
    summed_mvn_kl = math_ops.reduce_sum(mvn_kl, axis=[-1, -2])
    self.assertAllClose(
        self.evaluate(summed_mvn_kl),
        self.evaluate(ind_kl))
def testSampleConsistentLogProb(self):
    """Checks sample/log_prob consistency of a two-component MVN mixture.

    The shared `run_test_sample_consistent_log_prob` helper is called on one
    ball around each component's mean. Both balls have radius 1.
    """
    with self.cached_session() as sess:
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
            components_distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
        # Unit-radius ball centered at component 0's mean.
        self.run_test_sample_consistent_log_prob(
            sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
        # Unit-radius ball (same size as above) centered at component 1's mean.
        self.run_test_sample_consistent_log_prob(
            sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
def test_pad_mixture_dimensions_mixture_same_family(self):
    """pad_mixture_dimensions on a MixtureSameFamily of 2-D MVNs.

    The input has shape [2, 2]. The padded result must have shape [2, 2, 1]
    and the same values once both are flattened.
    """
    with self.cached_session() as sess:
        mixture = categorical.Categorical(probs=[0.3, 0.7])
        components = mvn_diag.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5])
        gm = mixture_same_family.MixtureSameFamily(
            mixture_distribution=mixture,
            components_distribution=components)

        x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
        x_pad = distribution_util.pad_mixture_dimensions(
            x, gm, gm.mixture_distribution, gm.event_shape.ndims)
        x_out, x_pad_out = sess.run([x, x_pad])

        # Here padding only adds a trailing size-1 axis; values are unchanged.
        self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def make_mvn(self, dims, new_batch_shape, old_batch_shape):
    """Builds an all-ones-scale MVNDiag and a validated BatchReshape of it.

    Args:
      dims: trailing (event) size of the `scale_diag` array.
      new_batch_shape: target batch shape. It is passed as a constant when
        `self.is_static_shape` is true, otherwise as a shape-less placeholder.
      old_batch_shape: list of leading dims of `scale_diag`.

    Returns:
      The tuple `(mvn, reshape_mvn)`.
    """
    if self.is_static_shape:
        batch_shape_arg = constant_op.constant(np.int32(new_batch_shape))
    else:
        batch_shape_arg = array_ops.placeholder_with_default(
            np.int32(new_batch_shape), shape=None)

    scale = np.ones(old_batch_shape + [dims], self.dtype)
    # Hide the static shape of `scale` in the dynamic-shape configuration.
    if self.is_static_shape:
        scale_static_shape = scale.shape
    else:
        scale_static_shape = None
    scale_ph = array_ops.placeholder_with_default(
        scale, shape=scale_static_shape)

    mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
    reshape_mvn = batch_reshape_lib.BatchReshape(
        distribution=mvn,
        batch_shape=batch_shape_arg,
        validate_args=True)
    return mvn, reshape_mvn
def testSampleAndLogProbMultivariate(self):
    """Independent(MVNDiag) sample shapes and log-probs versus scipy.

    With `reinterpreted_batch_ndims=1` the whole 2x2 block is a single event.
    The log-prob is checked against a sum of univariate normal log-densities.
    """
    locs = np.float32([[-1., 1], [1, -1]])
    scales = np.float32([1., 0.5])
    with self.cached_session() as sess:
        ind = independent_lib.Independent(
            distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=locs, scale_identity_multiplier=scales),
            reinterpreted_batch_ndims=1)
        draws = ind.sample([4, 5], seed=42)
        log_prob_draws = ind.log_prob(draws)
        draws_, actual_log_prob = sess.run([draws, log_prob_draws])

        # Static shapes: no batch dims, and a [2, 2] event.
        self.assertEqual([], ind.batch_shape)
        self.assertEqual([2, 2], ind.event_shape)
        self.assertEqual([4, 5, 2, 2], draws.shape)
        self.assertEqual([4, 5], log_prob_draws.shape)

        # Reference: row i of the event uses scale multiplier scales[i].
        # Sum the univariate log-densities over both event axes.
        reference = stats.norm(locs, scales[:, None]).logpdf(draws_)
        expected_log_prob = reference.sum(-1).sum(-1)
        self.assertAllClose(expected_log_prob, actual_log_prob,
                            rtol=1e-6, atol=0.)