def test_kl_reverse_multidim(self):
  with self.test_session() as sess:
    d = 5  # Dimension.

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5] * d)

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse,
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    exact_kl = kullback_leibler.kl_divergence(q, p)

    [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
        approx_kl, approx_kl_self_normalized, exact_kl])

    self.assertAllClose(approx_kl_, exact_kl_,
                        rtol=0.02, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.08, atol=0.)
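# The `_tridiag` helper called above is defined elsewhere in this test file
# and is not shown in this excerpt. As an illustrative sketch only (an
# assumption about its behavior, not the actual implementation), it
# presumably builds a d x d covariance matrix with `diag_value` on the main
# diagonal and `offdiag_value` on the first sub-/super-diagonals:
def _example_tridiag(d, diag_value, offdiag_value):
  """Sketch (assumed semantics): d x d tridiagonal covariance matrix."""
  diag = np.full([d], diag_value, dtype=np.float64)
  offdiag = np.full([d - 1], offdiag_value, dtype=np.float64)
  return np.diag(diag) + np.diag(offdiag, k=1) + np.diag(offdiag, k=-1)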
def test_kl_forward_multidim(self):
  with self.test_session() as sess:
    d = 5  # Dimension.

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.] * d)

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_forward,
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_forward(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=int(1e5),
        seed=1)

    exact_kl = kullback_leibler.kl_divergence(p, q)

    [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
        approx_kl, approx_kl_self_normalized, exact_kl])

    self.assertAllClose(approx_kl_, exact_kl_,
                        rtol=0.06, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.05, atol=0.)
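# For reference, the identity the forward-KL test above relies on: with
# u = p(x)/q(x) and x ~ q, KL[p||q] = E_q[u * log(u)], i.e. the Csiszar
# generator f(u) = u * log(u). The standalone NumPy sketch below (names and
# parameters here are illustrative and not part of the library under test)
# checks this against the closed-form KL between two 1-D Gaussians.
def _example_forward_kl_mc(num_draws=int(1e6), seed=1):
  """Sketch: Monte Carlo estimate of KL[p||q] via E_q[u * log(u)]."""
  rng = np.random.RandomState(seed)
  m_p, s_p = 0.5, 1.0  # p = Normal(0.5, 1).
  m_q, s_q = 0.0, 2.0  # q = Normal(0, 2); wider, so q "covers" p.
  x = rng.normal(m_q, s_q, size=num_draws)
  log_p = -0.5 * ((x - m_p) / s_p)**2 - np.log(s_p * np.sqrt(2. * np.pi))
  log_q = -0.5 * ((x - m_q) / s_q)**2 - np.log(s_q * np.sqrt(2. * np.pi))
  logu = log_p - log_q
  approx_kl = np.mean(np.exp(logu) * logu)  # E_q[u * log(u)].
  exact_kl = (np.log(s_q / s_p)
              + (s_p**2 + (m_p - m_q)**2) / (2. * s_q**2) - 0.5)
  return approx_kl, exact_kl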
def test_score_trick(self):
  with self.test_session() as sess:
    d = 5  # Dimension.
    num_draws = int(1e5)
    seed = 1

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    s = array_ops.constant(1.)
    q = mvn_diag_lib.MultivariateNormalDiag(
        scale_diag=array_ops.tile([s], [d]))

    approx_kl = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse,
        p=p,
        q=q,
        num_draws=num_draws,
        seed=seed)

    approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
        f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
        p=p,
        q=q,
        num_draws=num_draws,
        seed=seed)

    approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
        f=cd.kl_reverse,
        p=p,
        q=q,
        num_draws=num_draws,
        use_reparametrization=False,
        seed=seed)

    approx_kl_self_normalized_score_trick = (
        cd.monte_carlo_csiszar_f_divergence(
            f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
            p=p,
            q=q,
            num_draws=num_draws,
            use_reparametrization=False,
            seed=seed))

    exact_kl = kullback_leibler.kl_divergence(q, p)

    grad = lambda fs: gradients_impl.gradients(fs, s)[0]

    [
        approx_kl_,
        approx_kl_self_normalized_,
        approx_kl_score_trick_,
        approx_kl_self_normalized_score_trick_,
        exact_kl_,
    ] = sess.run([
        grad(approx_kl),
        grad(approx_kl_self_normalized),
        grad(approx_kl_score_trick),
        grad(approx_kl_self_normalized_score_trick),
        grad(exact_kl),
    ])

    self.assertAllClose(approx_kl_, exact_kl_,
                        rtol=0.06, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.05, atol=0.)

    self.assertAllClose(approx_kl_score_trick_, exact_kl_,
                        rtol=0.06, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
                        rtol=0.05, atol=0.)
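# The "score trick" exercised above (`use_reparametrization=False`) estimates
# the gradient without differentiating through the samples. For reference, it
# rests on the standard score-function (REINFORCE) identity (stated here as
# background, not as a claim about the library's internal implementation):
#
#   grad_s E_{q_s}[ f(log u) ]
#       = E_{q_s}[ grad_s f(log u) + f(log u) * grad_s log q_s(x) ],
#
# where u = p(x) / q_s(x) and s parameterizes q. A per-sample analogue of the
# score term (jacobian of log q times an f-based weight) also appears in the
# NumPy check of the VIMCO gradient in the next test.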
def test_vimco_and_gradient(self):
  with self.test_session() as sess:
    dims = 5  # Dimension.
    num_draws = int(20)
    num_batch_draws = int(3)
    seed = 1

    f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
    np_f = lambda logu: -logu

    p = mvn_full_lib.MultivariateNormalFullCovariance(
        covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))

    # Variance is very high when approximating forward KL, so we make
    # scale_diag larger than in test_kl_reverse_multidim. This ensures q
    # "covers" p and thus Var_q[p/q] is smaller.
    s = array_ops.constant(1.)
    q = mvn_diag_lib.MultivariateNormalDiag(
        scale_diag=array_ops.tile([s], [dims]))

    vimco = cd.csiszar_vimco(
        f=f,
        p_log_prob=p.log_prob,
        q=q,
        num_draws=num_draws,
        num_batch_draws=num_batch_draws,
        seed=seed)

    x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    x = array_ops.stop_gradient(x)
    logu = p.log_prob(x) - q.log_prob(x)
    f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])

    grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]

    def jacobian(x):
      # Warning: this function is slow and may not even finish if
      # prod(shape) is larger than, say, 100.
      shape = x.shape.as_list()
      assert all(s is not None for s in shape)
      x = array_ops.reshape(x, shape=[-1])
      r = [grad_sum(x[i]) for i in range(np.prod(shape))]
      return array_ops.reshape(array_ops.stack(r), shape=shape)

    [
        logu_,
        jacobian_logqx_,
        vimco_,
        grad_vimco_,
        f_log_sum_u_,
        grad_mean_f_log_sum_u_,
    ] = sess.run([
        logu,
        jacobian(q.log_prob(x)),
        vimco,
        grad_sum(vimco),
        f_log_sum_u,
        grad_sum(f_log_sum_u) / num_batch_draws,
    ])

    np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu_)

    # Test that the VIMCO loss is correct.
    self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
                        rtol=1e-5, atol=0.)

    # Test that the gradient of the VIMCO loss is correct.
    #
    # To make this computation we'll inject two gradients from TF:
    # - grad[mean(f(log(sum(p(x)/q(x)))))]
    # - jacobian[log(q(x))].
    #
    # We now justify why using these (and only these) TF values as
    # ground truth does not undermine the completeness of this test.
    #
    # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
    # correctness of the zeroth-order derivative (for each batch member).
    # Since `cd.csiszar_vimco_helper` itself does not manipulate any
    # gradient information, we can safely rely on TF.
    self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_,
                        rtol=1e-4, atol=0.)
    #
    # Regarding `jacobian_logqx_`, note that testing the gradient of
    # `q.log_prob` is outside the scope of this unit test, thus we may
    # safely use TF to find it.

    # The `mean` is across batches and the `sum` is across iid samples.
    np_grad_vimco = (
        grad_mean_f_log_sum_u_
        + np.mean(
            np.sum(
                jacobian_logqx_ * (np_f(np_log_avg_u)
                                   - np_f(np_log_sooavg_u)),
                axis=0),
            axis=0))

    self.assertAllClose(np_grad_vimco, grad_vimco_,
                        rtol=1e-5, atol=0.)
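# `self._numpy_csiszar_vimco_helper` used above is defined elsewhere in this
# test class. The sketch below is an illustrative assumption about its
# semantics (not the actual implementation): for `logu` with shape
# [num_draws, ...] it returns log(mean(u, axis=0)) together with the
# "swap-one-out" average, whose i-th entry is log(mean(u)) with u[i]
# replaced by the geometric mean of the remaining draws.
def _example_numpy_vimco_helper(logu):
  """Sketch (assumed semantics) of the VIMCO helper quantities."""
  n = logu.shape[0]
  u = np.exp(logu)
  log_avg_u = np.log(np.mean(u, axis=0))
  log_sooavg_u = []
  for i in range(n):
    loo = np.delete(u, i, axis=0)                      # Leave draw i out.
    geo_mean = np.exp(np.mean(np.log(loo), axis=0))    # Geometric mean of the rest.
    log_sooavg_u.append(np.log((np.sum(loo, axis=0) + geo_mean) / n))
  return log_avg_u, np.array(log_sooavg_u)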