def test_convergence_to_kl_using_sample_form_on_3dim_normal(self):
  # Test that the sample mean KL is the same as analytic when we use samples
  # to estimate every part of the KL divergence ratio.
  vector_shape = (2, 3)
  n_samples = 5000
  with self.test_session():
    q = mvn_diag_lib.MultivariateNormalDiag(
        loc=self._rng.rand(*vector_shape),
        scale_diag=self._rng.rand(*vector_shape))
    p = mvn_diag_lib.MultivariateNormalDiag(
        loc=self._rng.rand(*vector_shape),
        scale_diag=self._rng.rand(*vector_shape))

    # In this case, the log_ratio is the KL.
    sample_kl = -1 * entropy.elbo_ratio(
        log_p=p.log_prob,
        q=q,
        n=n_samples,
        form=entropy.ELBOForms.sample,
        seed=42)
    actual_kl = kullback_leibler_lib.kl_divergence(q, p)

    # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
    # pass.
    self.assertEqual((2,), sample_kl.get_shape())
    self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.05)
def test_multivariate_normal_prob_positive_product_of_components(self):
  # Test that importance sampling can correctly estimate the probability that
  # the product of components in a MultivariateNormal are > 0.
  n = 1000
  with self.cached_session():
    p = mvn_diag_lib.MultivariateNormalDiag(
        loc=[0.], scale_diag=[1.0, 1.0])
    q = mvn_diag_lib.MultivariateNormalDiag(
        loc=[0.5], scale_diag=[3., 3.])

    # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
    # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
    def indicator(x):
      x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
      return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)

    prob = mc.expectation_importance_sampler(
        f=indicator, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)

    # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
    # pass.
    # Convergence is +- 0.004 if n = 100k.
    self.assertEqual(p.batch_shape, prob.get_shape())
    self.assertAllClose(0.5, prob.eval(), rtol=0.05)
def test_non_vector_shape(self):
  dims = 2
  new_batch_shape = 2
  old_batch_shape = [2]

  new_batch_shape_ph = (
      constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
      else array_ops.placeholder_with_default(
          np.int32(new_batch_shape), shape=None))

  scale = np.ones(old_batch_shape + [dims], self.dtype)
  scale_ph = array_ops.placeholder_with_default(
      scale, shape=scale.shape if self.is_static_shape else None)
  mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

  if self.is_static_shape:
    with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
      batch_reshape_lib.BatchReshape(
          distribution=mvn,
          batch_shape=new_batch_shape_ph,
          validate_args=True)
  else:
    with self.cached_session():
      with self.assertRaisesOpError(r".*must be a vector.*"):
        batch_reshape_lib.BatchReshape(
            distribution=mvn,
            batch_shape=new_batch_shape_ph,
            validate_args=True).sample().eval()
def testSampleConsistentMeanCovariance(self):
  with self.test_session() as sess:
    gm = mixture_same_family_lib.MixtureSameFamily(
        mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
    self.run_test_sample_consistent_mean_covariance(sess.run, gm)
def test_bad_reshape_size(self):
  dims = 2
  new_batch_shape = [2, 3]
  old_batch_shape = [2]   # 2 != 2*3

  new_batch_shape_ph = (
      constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
      else array_ops.placeholder_with_default(
          np.int32(new_batch_shape), shape=None))

  scale = np.ones(old_batch_shape + [dims], self.dtype)
  scale_ph = array_ops.placeholder_with_default(
      scale, shape=scale.shape if self.is_static_shape else None)
  mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

  if self.is_static_shape:
    with self.assertRaisesRegexp(
        ValueError, (r"`batch_shape` size \(6\) must match "
                     r"`distribution\.batch_shape` size \(2\)")):
      batch_reshape_lib.BatchReshape(
          distribution=mvn,
          batch_shape=new_batch_shape_ph,
          validate_args=True)
  else:
    with self.cached_session():
      with self.assertRaisesOpError(r"Shape sizes do not match."):
        batch_reshape_lib.BatchReshape(
            distribution=mvn,
            batch_shape=new_batch_shape_ph,
            validate_args=True).sample().eval()
def test_non_positive_shape(self):
  dims = 2
  old_batch_shape = [4]
  if self.is_static_shape:
    # Unknown first dimension does not trigger size check. Note that
    # any dimension < 0 is treated statically as unknown.
    new_batch_shape = [-1, 0]
  else:
    new_batch_shape = [-2, -2]  # -2 * -2 = 4, same size as the old shape.

  new_batch_shape_ph = (
      constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
      else array_ops.placeholder_with_default(
          np.int32(new_batch_shape), shape=None))

  scale = np.ones(old_batch_shape + [dims], self.dtype)
  scale_ph = array_ops.placeholder_with_default(
      scale, shape=scale.shape if self.is_static_shape else None)
  mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

  if self.is_static_shape:
    with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
      batch_reshape_lib.BatchReshape(
          distribution=mvn,
          batch_shape=new_batch_shape_ph,
          validate_args=True)
  else:
    with self.cached_session():
      with self.assertRaisesOpError(r".*must be >=-1.*"):
        batch_reshape_lib.BatchReshape(
            distribution=mvn,
            batch_shape=new_batch_shape_ph,
            validate_args=True).sample().eval()
def test_kl_reverse_multidim(self):
  with self.test_session() as sess:
    d = 5  # Dimension

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5] * d)

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse, p=p, q=q, num_draws=int(1e5), seed=1)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    exact_kl = kullback_leibler.kl_divergence(q, p)

    [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run(
        [approx_kl, approx_kl_self_normalized, exact_kl])

    self.assertAllClose(approx_kl_, exact_kl_, rtol=0.02, atol=0.)
    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.08, atol=0.)
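# Note: the `self._tridiag(...)` calls in these KL tests refer to a helper
# defined elsewhere in this test class and not shown in this excerpt. As a
# hedged illustration only -- not the file's actual helper -- a minimal
# stand-alone equivalent is sketched below; `_example_tridiag` is a
# hypothetical name chosen here to avoid clashing with the real helper. It
# builds a (d, d) matrix with `diag_value` on the main diagonal and
# `offdiag_value` on the first off-diagonals, which is positive definite for
# the values used in these tests (diag 1, off-diag 0.5) and so usable as a
# covariance matrix.
def _example_tridiag(d, diag_value, offdiag_value):
  """Returns a (d, d) tridiagonal NumPy matrix with constant diagonals."""
  diag = diag_value * np.eye(d)
  offdiag = offdiag_value * np.eye(d, k=1)
  return diag + offdiag + offdiag.T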
def testVarianceConsistentCovariance(self):
  with self.test_session() as sess:
    gm = mixture_same_family_lib.MixtureSameFamily(
        mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
    cov_, var_ = sess.run([gm.covariance(), gm.variance()])
    self.assertAllClose(cov_.diagonal(), var_, atol=0.)
def testSampleAndLogProbMultivariateShapes(self):
  with self.test_session():
    gm = mixture_same_family_lib.MixtureSameFamily(
        mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
    x = gm.sample([4, 5], seed=42)
    log_prob_x = gm.log_prob(x)
    self.assertEqual([4, 5, 2], x.shape)
    self.assertEqual([4, 5], log_prob_x.shape)
def testSampleConsistentStats(self):
  loc = np.float32([[-1., 1], [1, -1]])
  scale = np.float32([1., 0.5])
  n_samp = 1e4
  with self.cached_session() as sess:
    ind = independent_lib.Independent(
        distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=loc, scale_identity_multiplier=scale),
        reinterpreted_batch_ndims=1)

    x = ind.sample(int(n_samp), seed=42)
    sample_mean = math_ops.reduce_mean(x, axis=0)
    sample_var = math_ops.reduce_mean(
        math_ops.squared_difference(x, sample_mean), axis=0)
    sample_std = math_ops.sqrt(sample_var)
    sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0)

    [
        sample_mean_, sample_var_, sample_std_, sample_entropy_,
        actual_mean_, actual_var_, actual_std_, actual_entropy_,
        actual_mode_,
    ] = sess.run([
        sample_mean, sample_var, sample_std, sample_entropy,
        ind.mean(), ind.variance(), ind.stddev(), ind.entropy(), ind.mode(),
    ])

    self.assertAllCloseAccordingToType(sample_mean_, actual_mean_, rtol=0.02)
    self.assertAllCloseAccordingToType(sample_var_, actual_var_, rtol=0.04)
    self.assertAllCloseAccordingToType(sample_std_, actual_std_, rtol=0.02)
    self.assertAllCloseAccordingToType(
        sample_entropy_, actual_entropy_, rtol=0.01)
    self.assertAllCloseAccordingToType(loc, actual_mode_, rtol=1e-6)
def testKLMultivariateToMultivariate(self):
  # (1, 1, 2) batch of MVNDiag
  mvn1 = mvn_diag_lib.MultivariateNormalDiag(
      loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
      scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))

  ind1 = independent_lib.Independent(
      distribution=mvn1, reinterpreted_batch_ndims=2)

  # (1, 1, 2) batch of MVNDiag
  mvn2 = mvn_diag_lib.MultivariateNormalDiag(
      loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
      scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))

  ind2 = independent_lib.Independent(
      distribution=mvn2, reinterpreted_batch_ndims=2)

  mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
  ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
  self.assertAllClose(
      self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])),
      self.evaluate(ind_kl))
def testSampleConsistentLogProb(self):
  with self.test_session() as sess:
    gm = mixture_same_family_lib.MixtureSameFamily(
        mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
    # Ball centered at component0's mean.
    self.run_test_sample_consistent_log_prob(
        sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
    # Larger ball centered at component1's mean.
    self.run_test_sample_consistent_log_prob(
        sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
def test_divergence_between_identical_distributions_is_zero(self):
  n = 1000
  vector_shape = (2, 3)
  with self.test_session():
    q = mvn_diag_lib.MultivariateNormalDiag(
        loc=self._rng.rand(*vector_shape),
        scale_diag=self._rng.rand(*vector_shape))
    for alpha in [0.25, 0.75]:
      negative_renyi_divergence = entropy.renyi_ratio(
          log_p=q.log_prob, q=q, n=n, alpha=alpha, seed=0)

      self.assertEqual((2,), negative_renyi_divergence.get_shape())
      self.assertAllClose(np.zeros(2), negative_renyi_divergence.eval())
def test_pad_mixture_dimensions_mixture_same_family(self):
  with self.test_session() as sess:
    gm = mixture_same_family.MixtureSameFamily(
        mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))

    x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.mixture_distribution, gm.event_shape.ndims)
    x_out, x_pad_out = sess.run([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def make_mvn(self, dims, new_batch_shape, old_batch_shape):
  new_batch_shape_ph = (
      constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
      else array_ops.placeholder_with_default(
          np.int32(new_batch_shape), shape=None))

  scale = np.ones(old_batch_shape + [dims], self.dtype)
  scale_ph = array_ops.placeholder_with_default(
      scale, shape=scale.shape if self.is_static_shape else None)
  mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
  reshape_mvn = batch_reshape_lib.BatchReshape(
      distribution=mvn,
      batch_shape=new_batch_shape_ph,
      validate_args=True)
  return mvn, reshape_mvn
def test_sample_kl_zero_when_p_and_q_are_the_same_distribution(self):
  n_samples = 50
  vector_shape = (2, 3)
  with self.test_session():
    q = mvn_diag_lib.MultivariateNormalDiag(
        loc=self._rng.rand(*vector_shape),
        scale_diag=self._rng.rand(*vector_shape))

    # In this case, the log_ratio is the KL.
    sample_kl = -1 * entropy.elbo_ratio(
        log_p=q.log_prob,
        q=q,
        n=n_samples,
        form=entropy.ELBOForms.sample,
        seed=42)

    self.assertEqual((2,), sample_kl.get_shape())
    self.assertAllClose(np.zeros(2), sample_kl.eval())
def testKLRaises(self):
  ind1 = independent_lib.Independent(
      distribution=normal_lib.Normal(
          loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=1)
  ind2 = independent_lib.Independent(
      distribution=normal_lib.Normal(
          loc=np.float32(-1), scale=np.float32(0.5)),
      reinterpreted_batch_ndims=0)

  with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
    kullback_leibler.kl_divergence(ind1, ind2)

  ind1 = independent_lib.Independent(
      distribution=normal_lib.Normal(
          loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=1)
  ind2 = independent_lib.Independent(
      distribution=mvn_diag_lib.MultivariateNormalDiag(
          loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=0)

  with self.assertRaisesRegexp(NotImplementedError,
                               "different event shapes"):
    kullback_leibler.kl_divergence(ind1, ind2)
def testSampleAndLogProbMultivariate(self):
  loc = np.float32([[-1., 1], [1, -1]])
  scale = np.float32([1., 0.5])
  with self.cached_session() as sess:
    ind = independent_lib.Independent(
        distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=loc, scale_identity_multiplier=scale),
        reinterpreted_batch_ndims=1)

    x = ind.sample([4, 5], seed=42)
    log_prob_x = ind.log_prob(x)
    x_, actual_log_prob_x = sess.run([x, log_prob_x])

    self.assertEqual([], ind.batch_shape)
    self.assertEqual([2, 2], ind.event_shape)
    self.assertEqual([4, 5, 2, 2], x.shape)
    self.assertEqual([4, 5], log_prob_x.shape)

    expected_log_prob_x = stats.norm(
        loc, scale[:, None]).logpdf(x_).sum(-1).sum(-1)
    self.assertAllCloseAccordingToType(expected_log_prob_x,
                                       actual_log_prob_x)
def test_kl_forward_multidim(self):
  with self.test_session() as sess:
    d = 5  # Dimension

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating Forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.] * d)

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_forward, p=p, q=q, num_draws=int(1e5), seed=1)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_forward(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    exact_kl = kullback_leibler.kl_divergence(p, q)

    [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run(
        [approx_kl, approx_kl_self_normalized, exact_kl])

    self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)
    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.05, atol=0.)
def test_vimco_and_gradient(self):
  with self.test_session() as sess:
    dims = 5  # Dimension
    num_draws = int(20)
    num_batch_draws = int(3)
    seed = 1

    f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
    np_f = lambda logu: -logu

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating Forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    s = array_ops.constant(1.)
    q = mvn_diag_lib.MultivariateNormalDiag(
        scale_diag=array_ops.tile([s], [dims]))

    vimco = cd.csiszar_vimco(
        f=f,
        p_log_prob=p.log_prob,
        q=q,
        num_draws=num_draws,
        num_batch_draws=num_batch_draws,
        seed=seed)

    x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    x = array_ops.stop_gradient(x)
    logu = p.log_prob(x) - q.log_prob(x)
    f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])

    grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]

    def jacobian(x):
      # Warning: this function is slow and may not even finish if prod(shape)
      # is larger than, say, 100.
      shape = x.shape.as_list()
      assert all(s is not None for s in shape)
      x = array_ops.reshape(x, shape=[-1])
      r = [grad_sum(x[i]) for i in range(np.prod(shape))]
      return array_ops.reshape(array_ops.stack(r), shape=shape)

    [
        logu_,
        jacobian_logqx_,
        vimco_,
        grad_vimco_,
        f_log_sum_u_,
        grad_mean_f_log_sum_u_,
    ] = sess.run([
        logu,
        jacobian(q.log_prob(x)),
        vimco,
        grad_sum(vimco),
        f_log_sum_u,
        grad_sum(f_log_sum_u) / num_batch_draws,
    ])

    np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu_)

    # Test VIMCO loss is correct.
    self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
                        rtol=1e-5, atol=0.)

    # Test gradient of VIMCO loss is correct.
    #
    # To make this computation we'll inject two gradients from TF:
    # - grad[mean(f(log(sum(p(x)/q(x)))))]
    # - jacobian[log(q(x))].
    #
    # We now justify why using these (and only these) TF values for
    # ground-truth does not undermine the completeness of this test.
    #
    # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
    # correctness of the zero-th order derivative (for each batch member).
    # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient
    # information, we can safely rely on TF.
    self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.)

    # Regarding `jacobian_logqx_`, note that testing the gradient of
    # `q.log_prob` is outside the scope of this unit-test thus we may safely
    # use TF to find it.

    # The `mean` is across batches and the `sum` is across iid samples.
    np_grad_vimco = (
        grad_mean_f_log_sum_u_
        + np.mean(np.sum(
            jacobian_logqx_ * (np_f(np_log_avg_u) - np_f(np_log_sooavg_u)),
            axis=0), axis=0))
    self.assertAllClose(np_grad_vimco, grad_vimco_, rtol=1e-5, atol=0.)
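# Note: `self._numpy_csiszar_vimco_helper(...)` above refers to a NumPy
# reference implementation defined elsewhere in this test class. As a hedged
# sketch only -- an assumption about its behavior based on the "swap-one-out"
# construction in the VIMCO paper (Mnih & Rezende, 2016), not the file's
# actual code -- it returns the log of the sample average of exp(logu) and,
# per sample j, the log of the average in which the j-th term is replaced by
# the geometric mean of the remaining terms. `_example_numpy_vimco_helper` is
# a hypothetical name used to avoid clashing with the real helper.
def _example_numpy_vimco_helper(logu):
  """Returns (log_avg_u, log_sooavg_u) for `logu` of shape [n, ...]."""
  n = logu.shape[0]
  u = np.exp(logu)
  log_avg_u = np.log(np.mean(u, axis=0))
  log_sooavg_u = []
  for j in range(n):
    # Swap the j-th sample for the geometric mean of the remaining samples.
    geo_mean_rest = np.exp(np.mean([logu[i] for i in range(n) if i != j],
                                   axis=0))
    sum_rest = np.sum([u[i] for i in range(n) if i != j], axis=0)
    log_sooavg_u.append(np.log((sum_rest + geo_mean_rest) / n))
  return log_avg_u, np.array(log_sooavg_u)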
def test_score_trick(self):
  with self.test_session() as sess:
    d = 5  # Dimension
    num_draws = int(1e5)
    seed = 1

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating Forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    s = array_ops.constant(1.)
    q = mvn_diag_lib.MultivariateNormalDiag(
        scale_diag=array_ops.tile([s], [d]))

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse, p=p, q=q, num_draws=num_draws, seed=seed)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=num_draws,
        seed=seed)

    approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse,
        p=p,
        q=q,
        num_draws=num_draws,
        use_reparametrization=False,
        seed=seed)

    approx_kl_self_normalized_score_trick = (
        cd.monte_carlo_csiszar_f_divergence(
            f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
            p=p,
            q=q,
            num_draws=num_draws,
            use_reparametrization=False,
            seed=seed))

    exact_kl = kullback_leibler.kl_divergence(q, p)

    grad = lambda fs: gradients_impl.gradients(fs, s)[0]

    [
        approx_kl_,
        approx_kl_self_normalized_,
        approx_kl_score_trick_,
        approx_kl_self_normalized_score_trick_,
        exact_kl_,
    ] = sess.run([
        grad(approx_kl),
        grad(approx_kl_self_normalized),
        grad(approx_kl_score_trick),
        grad(approx_kl_self_normalized_score_trick),
        grad(exact_kl),
    ])

    self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)
    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.05, atol=0.)
    self.assertAllClose(approx_kl_score_trick_, exact_kl_, rtol=0.06, atol=0.)
    self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
                        rtol=0.05, atol=0.)