Example no. 1
  def test_convergence_to_kl_using_sample_form_on_3dim_normal(self):
    # Test that the sample-mean KL matches the analytic KL when we use samples
    # to estimate every part of the KL divergence ratio.
    vector_shape = (2, 3)
    n_samples = 5000

    with self.test_session():
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))
      p = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))

      # In this case, the log_ratio is the KL.
      sample_kl = -1 * entropy.elbo_ratio(
          log_p=p.log_prob,
          q=q,
          n=n_samples,
          form=entropy.ELBOForms.sample,
          seed=42)
      actual_kl = kullback_leibler_lib.kl_divergence(q, p)

      # Relative tolerance (rtol) chosen 2 times as large as the minimum needed
      # to pass.
      self.assertEqual((2,), sample_kl.get_shape())
      self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.05)
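
A minimal NumPy/SciPy sketch (an illustration, not part of the test suite) of the estimator exercised above: with x drawn from q, the sample mean of log q(x) - log p(x) converges to the analytic KL[q || p]. The parameter values below are arbitrary assumptions.

import numpy as np
from scipy import stats

rng = np.random.RandomState(0)
mu_q, sig_q = np.array([0., 0., 0.]), np.array([1., 1., 1.])
mu_p, sig_p = np.array([1., 0., -1.]), np.array([2., 2., 2.])

x = rng.normal(mu_q, sig_q, size=(5000, 3))        # x ~ q
log_q = stats.norm(mu_q, sig_q).logpdf(x).sum(-1)  # diagonal MVN factorizes
log_p = stats.norm(mu_p, sig_p).logpdf(x).sum(-1)
sample_kl = np.mean(log_q - log_p)                 # Monte Carlo KL[q || p]

# Closed-form KL between diagonal Gaussians, summed over dimensions.
exact_kl = np.sum(np.log(sig_p / sig_q)
                  + (sig_q**2 + (mu_q - mu_p)**2) / (2. * sig_p**2) - 0.5)
print(sample_kl, exact_kl)  # typically within a few percent at n = 5000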
Example no. 2
    def test_multivariate_normal_prob_positive_product_of_components(self):
        # Test that importance sampling can correctly estimate the probability that
            # the product of components in a MultivariateNormal is > 0.
        n = 1000
        with self.cached_session():
            p = mvn_diag_lib.MultivariateNormalDiag(loc=[0.],
                                                    scale_diag=[1.0, 1.0])
            q = mvn_diag_lib.MultivariateNormalDiag(loc=[0.5],
                                                    scale_diag=[3., 3.])

            # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
            # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
            def indicator(x):
                x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
                return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)

            prob = mc.expectation_importance_sampler(f=indicator,
                                                     log_p=p.log_prob,
                                                     sampling_dist_q=q,
                                                     n=n,
                                                     seed=42)

            # Relative tolerance (rtol) chosen 2 times as large as the minimum
            # needed to pass.
            # Convergence is +- 0.004 if n = 100k.
            self.assertEqual(p.batch_shape, prob.get_shape())
            self.assertAllClose(0.5, prob.eval(), rtol=0.05)
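
For reference, a self-contained NumPy/SciPy sketch (an illustration, not the library's implementation, which adds numerical-stability measures) of the importance-sampling identity the test relies on: E_p[f(X)] = E_q[f(X) p(X) / q(X)] with X ~ q.

import numpy as np
from scipy import stats

rng = np.random.RandomState(42)
n = 100000
x = rng.normal(0.5, 3.0, size=(n, 2))            # x ~ q = N(0.5, 3^2 I)
log_p = stats.norm(0.0, 1.0).logpdf(x).sum(-1)   # target p = N(0, I)
log_q = stats.norm(0.5, 3.0).logpdf(x).sum(-1)   # proposal q
indicator = 0.5 * (np.sign(x.prod(-1)) + 1.0)    # 1{x1 * x2 > 0}
prob = np.mean(indicator * np.exp(log_p - log_q))
print(prob)  # approaches 0.5 as n grows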
Example no. 3
    def test_non_vector_shape(self):
        dims = 2
        new_batch_shape = 2
        old_batch_shape = [2]

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be a vector.*"):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example no. 4
 def testSampleConsistentMeanCovariance(self):
   with self.test_session() as sess:
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
         components_distribution=mvn_diag_lib.MultivariateNormalDiag(
             loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
     self.run_test_sample_consistent_mean_covariance(sess.run, gm)
Example no. 5
    def test_bad_reshape_size(self):
        dims = 2
        new_batch_shape = [2, 3]
        old_batch_shape = [2]  # 2 != 2*3

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            with self.assertRaisesRegexp(
                    ValueError, (r"`batch_shape` size \(6\) must match "
                                 r"`distribution\.batch_shape` size \(2\)")):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r"Shape sizes do not match."):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example no. 6
    def test_non_positive_shape(self):
        dims = 2
        old_batch_shape = [4]
        if self.is_static_shape:
            # Unknown first dimension does not trigger size check. Note that
            # any dimension < 0 is treated statically as unknown.
            new_batch_shape = [-1, 0]
        else:
            new_batch_shape = [-2,
                               -2]  # -2 * -2 = 4, same size as the old shape.

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be >=-1.*"):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example no. 7
    def test_kl_reverse_multidim(self):

        with self.test_session() as sess:
            d = 5  # Dimension

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5] * d)

            approx_kl = cd.monte_carlo_csiszar_f_divergence(f=cd.kl_reverse,
                                                            p=p,
                                                            q=q,
                                                            num_draws=int(1e5),
                                                            seed=1)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=int(1e5),
                seed=1)

            exact_kl = kullback_leibler.kl_divergence(q, p)

            [approx_kl_, approx_kl_self_normalized_, exact_kl_
             ] = sess.run([approx_kl, approx_kl_self_normalized, exact_kl])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.02, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.08,
                                atol=0.)
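
A hedged NumPy sketch of the estimator behind monte_carlo_csiszar_f_divergence: D_f[p||q] = E_q[f(p(X)/q(X))], approximated by averaging f over draws from q. With cd.kl_reverse, f(u) = -log(u), so the estimate converges to KL[q||p]. For simplicity this sketch assumes an isotropic p instead of the tridiagonal covariance used in the test.

import numpy as np
from scipy import stats

rng = np.random.RandomState(1)
m, d = 100000, 5
x = rng.normal(0.0, 0.5, size=(m, d))              # x ~ q = N(0, 0.5^2 I)
logu = (stats.norm(0.0, 1.0).logpdf(x).sum(-1)     # log p(x), with p = N(0, I)
        - stats.norm(0.0, 0.5).logpdf(x).sum(-1))  # ... minus log q(x)
approx_kl = np.mean(-logu)                         # f(u) = -log(u)
exact_kl = d * (np.log(1.0 / 0.5) + 0.5**2 / 2.0 - 0.5)  # diagonal-Gaussian KL
print(approx_kl, exact_kl)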
Example no. 8
 def testVarianceConsistentCovariance(self):
   with self.test_session() as sess:
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
         components_distribution=mvn_diag_lib.MultivariateNormalDiag(
             loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
     cov_, var_ = sess.run([gm.covariance(), gm.variance()])
     self.assertAllClose(cov_.diagonal(), var_, atol=0.)
Example no. 9
 def testSampleAndLogProbMultivariateShapes(self):
   with self.test_session():
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
         components_distribution=mvn_diag_lib.MultivariateNormalDiag(
             loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
     x = gm.sample([4, 5], seed=42)
     log_prob_x = gm.log_prob(x)
     self.assertEqual([4, 5, 2], x.shape)
     self.assertEqual([4, 5], log_prob_x.shape)
Example no. 10
    def testSampleConsistentStats(self):
        loc = np.float32([[-1., 1], [1, -1]])
        scale = np.float32([1., 0.5])
        n_samp = 1e4
        with self.cached_session() as sess:
            ind = independent_lib.Independent(
                distribution=mvn_diag_lib.MultivariateNormalDiag(
                    loc=loc, scale_identity_multiplier=scale),
                reinterpreted_batch_ndims=1)

            x = ind.sample(int(n_samp), seed=42)
            sample_mean = math_ops.reduce_mean(x, axis=0)
            sample_var = math_ops.reduce_mean(math_ops.squared_difference(
                x, sample_mean),
                                              axis=0)
            sample_std = math_ops.sqrt(sample_var)
            sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0)

            [
                sample_mean_,
                sample_var_,
                sample_std_,
                sample_entropy_,
                actual_mean_,
                actual_var_,
                actual_std_,
                actual_entropy_,
                actual_mode_,
            ] = sess.run([
                sample_mean,
                sample_var,
                sample_std,
                sample_entropy,
                ind.mean(),
                ind.variance(),
                ind.stddev(),
                ind.entropy(),
                ind.mode(),
            ])

            self.assertAllCloseAccordingToType(sample_mean_,
                                               actual_mean_,
                                               rtol=0.02)
            self.assertAllCloseAccordingToType(sample_var_,
                                               actual_var_,
                                               rtol=0.04)
            self.assertAllCloseAccordingToType(sample_std_,
                                               actual_std_,
                                               rtol=0.02)
            self.assertAllCloseAccordingToType(sample_entropy_,
                                               actual_entropy_,
                                               rtol=0.01)
            self.assertAllCloseAccordingToType(loc, actual_mode_, rtol=1e-6)
Example no. 11
    def testKLMultivariateToMultivariate(self):
        # (1, 1, 2) batch of MVNDiag
        mvn1 = mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
            scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
        ind1 = independent_lib.Independent(distribution=mvn1,
                                           reinterpreted_batch_ndims=2)

        # (1, 1, 2) batch of MVNDiag
        mvn2 = mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
            scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))

        ind2 = independent_lib.Independent(distribution=mvn2,
                                           reinterpreted_batch_ndims=2)

        mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
        ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
        self.assertAllClose(
            self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])),
            self.evaluate(ind_kl))
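
A small NumPy check (illustration only) of the identity this assertion exercises: for Independent distributions, the KL between the joints is the sum of the per-component KLs over the reinterpreted batch dimensions. The closed-form diagonal-Gaussian KL below reuses the loc/scale values from the test, with the leading singleton batch dimensions dropped.

import numpy as np

def kl_diag_gauss(mu_q, sig_q, mu_p, sig_p):
    # Closed-form KL between diagonal Gaussians, summed over the event dimension.
    return np.sum(np.log(sig_p / sig_q)
                  + (sig_q**2 + (mu_q - mu_p)**2) / (2. * sig_p**2) - 0.5,
                  axis=-1)

loc1 = np.float32([[-1., 1, 3.], [2., 4., 3.]])
scale1 = np.float32([[0.2, 0.1, 5.], [2., 3., 4.]])
loc2 = np.float32([[-2., 3, 2.], [1., 3., 2.]])
scale2 = np.float32([[0.1, 0.5, 3.], [1., 2., 1.]])

per_component_kl = kl_diag_gauss(loc1, scale1, loc2, scale2)  # shape (2,), like mvn_kl
print(per_component_kl.sum())  # equals the value of ind_kl above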
Example no. 12
 def testSampleConsistentLogProb(self):
   with self.test_session() as sess:
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
         components_distribution=mvn_diag_lib.MultivariateNormalDiag(
             loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
     # Ball centered at component0's mean.
     self.run_test_sample_consistent_log_prob(
         sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
     # Larger ball centered at component1's mean.
     self.run_test_sample_consistent_log_prob(
         sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
Example no. 13
  def test_divergence_between_identical_distributions_is_zero(self):
    n = 1000
    vector_shape = (2, 3)
    with self.test_session():
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))
      for alpha in [0.25, 0.75]:

        negative_renyi_divergence = entropy.renyi_ratio(
            log_p=q.log_prob, q=q, n=n, alpha=alpha, seed=0)

        self.assertEqual((2,), negative_renyi_divergence.get_shape())
        self.assertAllClose(np.zeros(2), negative_renyi_divergence.eval())
Example no. 14
    def test_pad_mixture_dimensions_mixture_same_family(self):
        with self.test_session() as sess:
            gm = mixture_same_family.MixtureSameFamily(
                mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
                components_distribution=mvn_diag.MultivariateNormalDiag(
                    loc=[[-1., 1], [1, -1]],
                    scale_identity_multiplier=[1.0, 0.5]))

            x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
            x_pad = distribution_util.pad_mixture_dimensions(
                x, gm, gm.mixture_distribution, gm.event_shape.ndims)
            x_out, x_pad_out = sess.run([x, x_pad])

        self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
Example no. 15
    def make_mvn(self, dims, new_batch_shape, old_batch_shape):
        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
        reshape_mvn = batch_reshape_lib.BatchReshape(
            distribution=mvn,
            batch_shape=new_batch_shape_ph,
            validate_args=True)
        return mvn, reshape_mvn
Example no. 16
  def test_sample_kl_zero_when_p_and_q_are_the_same_distribution(self):
    n_samples = 50

    vector_shape = (2, 3)
    with self.test_session():
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=self._rng.rand(*vector_shape),
          scale_diag=self._rng.rand(*vector_shape))

      # In this case, the log_ratio is the KL.
      sample_kl = -1 * entropy.elbo_ratio(
          log_p=q.log_prob,
          q=q,
          n=n_samples,
          form=entropy.ELBOForms.sample,
          seed=42)

      self.assertEqual((2,), sample_kl.get_shape())
      self.assertAllClose(np.zeros(2), sample_kl.eval())
Example no. 17
    def testKLRaises(self):
        ind1 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                                           reinterpreted_batch_ndims=1)
        ind2 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32(-1), scale=np.float32(0.5)),
                                           reinterpreted_batch_ndims=0)

        with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
            kullback_leibler.kl_divergence(ind1, ind2)

        ind1 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                                           reinterpreted_batch_ndims=1)
        ind2 = independent_lib.Independent(
            distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
            reinterpreted_batch_ndims=0)

        with self.assertRaisesRegexp(NotImplementedError,
                                     "different event shapes"):
            kullback_leibler.kl_divergence(ind1, ind2)
Example no. 18
    def testSampleAndLogProbMultivariate(self):
        loc = np.float32([[-1., 1], [1, -1]])
        scale = np.float32([1., 0.5])
        with self.cached_session() as sess:
            ind = independent_lib.Independent(
                distribution=mvn_diag_lib.MultivariateNormalDiag(
                    loc=loc, scale_identity_multiplier=scale),
                reinterpreted_batch_ndims=1)

            x = ind.sample([4, 5], seed=42)
            log_prob_x = ind.log_prob(x)
            x_, actual_log_prob_x = sess.run([x, log_prob_x])

            self.assertEqual([], ind.batch_shape)
            self.assertEqual([2, 2], ind.event_shape)
            self.assertEqual([4, 5, 2, 2], x.shape)
            self.assertEqual([4, 5], log_prob_x.shape)

            expected_log_prob_x = stats.norm(
                loc, scale[:, None]).logpdf(x_).sum(-1).sum(-1)
            self.assertAllCloseAccordingToType(expected_log_prob_x,
                                               actual_log_prob_x)
Example no. 19
    def test_kl_forward_multidim(self):

        with self.test_session() as sess:
            d = 5  # Dimension

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            # Variance is very high when approximating Forward KL, so we make
            # scale_diag larger than in test_kl_reverse_multidim. This ensures q
            # "covers" p and thus Var_q[p/q] is smaller.
            q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.] * d)

            approx_kl = cd.monte_carlo_csiszar_f_divergence(f=cd.kl_forward,
                                                            p=p,
                                                            q=q,
                                                            num_draws=int(1e5),
                                                            seed=1)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_forward(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=int(1e5),
                seed=1)

            exact_kl = kullback_leibler.kl_divergence(p, q)

            [approx_kl_, approx_kl_self_normalized_, exact_kl_
             ] = sess.run([approx_kl, approx_kl_self_normalized, exact_kl])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)
Example no. 20
    def test_vimco_and_gradient(self):

        with self.test_session() as sess:
            dims = 5  # Dimension
            num_draws = int(20)
            num_batch_draws = int(3)
            seed = 1

            f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
            np_f = lambda logu: -logu

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=tridiag(
                    dims, diag_value=1, offdiag_value=0.5))

            # Variance is very high when approximating Forward KL, so we make
            # scale_diag larger than in test_kl_reverse_multidim. This ensures q
            # "covers" p and thus Var_q[p/q] is smaller.
            s = array_ops.constant(1.)
            q = mvn_diag_lib.MultivariateNormalDiag(
                scale_diag=array_ops.tile([s], [dims]))

            vimco = cd.csiszar_vimco(f=f,
                                     p_log_prob=p.log_prob,
                                     q=q,
                                     num_draws=num_draws,
                                     num_batch_draws=num_batch_draws,
                                     seed=seed)

            x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
            x = array_ops.stop_gradient(x)
            logu = p.log_prob(x) - q.log_prob(x)
            f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])

            grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]

            def jacobian(x):
                # Warning: this function is slow and may not even finish if prod(shape)
                # is larger than, say, 100.
                shape = x.shape.as_list()
                assert all(s is not None for s in shape)
                x = array_ops.reshape(x, shape=[-1])
                r = [grad_sum(x[i]) for i in range(np.prod(shape))]
                return array_ops.reshape(array_ops.stack(r), shape=shape)

            [
                logu_,
                jacobian_logqx_,
                vimco_,
                grad_vimco_,
                f_log_sum_u_,
                grad_mean_f_log_sum_u_,
            ] = sess.run([
                logu,
                jacobian(q.log_prob(x)),
                vimco,
                grad_sum(vimco),
                f_log_sum_u,
                grad_sum(f_log_sum_u) / num_batch_draws,
            ])

            np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(
                logu_)

            # Test VIMCO loss is correct.
            self.assertAllClose(np_f(np_log_avg_u).mean(axis=0),
                                vimco_,
                                rtol=1e-5,
                                atol=0.)

            # Test gradient of VIMCO loss is correct.
            #
            # To make this computation we'll inject two gradients from TF:
            # - grad[mean(f(log(sum(p(x)/q(x)))))]
            # - jacobian[log(q(x))].
            #
            # We now justify why using these (and only these) TF values for
            # ground-truth does not undermine the completeness of this test.
            #
            # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
            # correctness of the zero-th order derivative (for each batch member).
            # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient
            # information, we can safely rely on TF.
            self.assertAllClose(np_f(np_log_avg_u),
                                f_log_sum_u_,
                                rtol=1e-4,
                                atol=0.)
            #
            # Regarding `jacobian_logqx_`, note that testing the gradient of
            # `q.log_prob` is outside the scope of this unit-test thus we may safely
            # use TF to find it.

            # The `mean` is across batches and the `sum` is across iid samples.
            np_grad_vimco = (grad_mean_f_log_sum_u_ + np.mean(np.sum(
                jacobian_logqx_ * (np_f(np_log_avg_u) - np_f(np_log_sooavg_u)),
                axis=0),
                                                              axis=0))

            self.assertAllClose(np_grad_vimco, grad_vimco_, rtol=1e-5, atol=0.)
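
Restated as a formula (a paraphrase of the NumPy expression for np_grad_vimco above, with s the scalar scale of q): writing u_{i,b} = p(x_{i,b}) / q(x_{i,b}), L_b = \log \frac{1}{N}\sum_i u_{i,b}, and L_b^{(-i)} for the swap-one-out average returned by csiszar_vimco_helper,

\nabla_s \,\mathrm{VIMCO} \;\approx\; \nabla_s\, \mathbb{E}_b\!\left[ f(L_b) \right] \;+\; \mathbb{E}_b\!\left[ \sum_{i=1}^{N} \frac{\partial \log q(x_{i,b})}{\partial s} \bigl( f(L_b) - f(L_b^{(-i)}) \bigr) \right],

where the first term is grad_mean_f_log_sum_u_ and the second is the mean-over-batches, sum-over-samples term built from jacobian_logqx_.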
Example no. 21
    def test_score_trick(self):

        with self.test_session() as sess:
            d = 5  # Dimension
            num_draws = int(1e5)
            seed = 1

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            # Variance is very high when approximating Forward KL, so we make
            # scale_diag larger than in test_kl_reverse_multidim. This ensures q
            # "covers" p and thus Var_q[p/q] is smaller.
            s = array_ops.constant(1.)
            q = mvn_diag_lib.MultivariateNormalDiag(
                scale_diag=array_ops.tile([s], [d]))

            approx_kl = cd.monte_carlo_csiszar_f_divergence(
                f=cd.kl_reverse, p=p, q=q, num_draws=num_draws, seed=seed)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=num_draws,
                seed=seed)

            approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
                f=cd.kl_reverse,
                p=p,
                q=q,
                num_draws=num_draws,
                use_reparametrization=False,
                seed=seed)

            approx_kl_self_normalized_score_trick = (
                cd.monte_carlo_csiszar_f_divergence(
                    f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                    p=p,
                    q=q,
                    num_draws=num_draws,
                    use_reparametrization=False,
                    seed=seed))

            exact_kl = kullback_leibler.kl_divergence(q, p)

            grad = lambda fs: gradients_impl.gradients(fs, s)[0]

            [
                approx_kl_,
                approx_kl_self_normalized_,
                approx_kl_score_trick_,
                approx_kl_self_normalized_score_trick_,
                exact_kl_,
            ] = sess.run([
                grad(approx_kl),
                grad(approx_kl_self_normalized),
                grad(approx_kl_score_trick),
                grad(approx_kl_self_normalized_score_trick),
                grad(exact_kl),
            ])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)

            self.assertAllClose(approx_kl_score_trick_,
                                exact_kl_,
                                rtol=0.06,
                                atol=0.)

            self.assertAllClose(approx_kl_self_normalized_score_trick_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)
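
As a footnote, a 1-D NumPy toy (an assumption for illustration, unrelated to the library code) contrasting the two gradient estimators this test compares, for d/ds E_{q_s}[f(X)] with q_s = N(0, s^2) and f(x) = x^2, whose exact gradient is 2s:

import numpy as np

rng = np.random.RandomState(0)
s, n = 1.5, 200000
eps = rng.normal(size=n)
x = s * eps                          # x ~ q_s = N(0, s^2)

# Reparameterization (pathwise) estimator: differentiate f(s * eps) w.r.t. s.
reparam_grad = np.mean(2.0 * (s * eps) * eps)

# Score-function ("score trick") estimator: E[f(x) * d/ds log q_s(x)],
# with d/ds log q_s(x) = x^2 / s^3 - 1 / s.
score_grad = np.mean(x**2 * (x**2 / s**3 - 1.0 / s))

print(reparam_grad, score_grad)      # both near the exact value 2 * s = 3.0;
                                     # the score-function estimate is noisier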