    def test_kl_reverse_multidim(self):

        with self.test_session() as sess:
            d = 5  # Dimension

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5] * d)

            approx_kl = cd.monte_carlo_csiszar_f_divergence(f=cd.kl_reverse,
                                                            p=p,
                                                            q=q,
                                                            num_draws=int(1e5),
                                                            seed=1)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=int(1e5),
                seed=1)

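            # With f = kl_reverse, the Csiszar f-divergence D_f[p||q] is the
            # reverse KL divergence KL[q||p], so the analytic value below uses
            # the (q, p) argument order.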
            exact_kl = kullback_leibler.kl_divergence(q, p)

            [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run(
                [approx_kl, approx_kl_self_normalized, exact_kl])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.02, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.08,
                                atol=0.)

    def test_kl_forward_multidim(self):

        with self.test_session() as sess:
            d = 5  # Dimension

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            # Variance is very high when approximating Forward KL, so we make
            # scale_diag larger than in test_kl_reverse_multidim. This ensures q
            # "covers" p and thus Var_q[p/q] is smaller.
            q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.] * d)

            approx_kl = cd.monte_carlo_csiszar_f_divergence(f=cd.kl_forward,
                                                            p=p,
                                                            q=q,
                                                            num_draws=int(1e5),
                                                            seed=1)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_forward(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=int(1e5),
                seed=1)

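            # With f = kl_forward, the Csiszar f-divergence D_f[p||q] is the
            # forward KL divergence KL[p||q], hence the (p, q) argument order
            # below.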
            exact_kl = kullback_leibler.kl_divergence(p, q)

            [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run(
                [approx_kl, approx_kl_self_normalized, exact_kl])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)

    def test_score_trick(self):

        with self.test_session() as sess:
            d = 5  # Dimension
            num_draws = int(1e5)
            seed = 1

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    d, diag_value=1, offdiag_value=0.5))

            # Use a wider q (scale_diag = 1) than in test_kl_reverse_multidim so
            # that q "covers" p, which keeps the variance of the Monte Carlo
            # estimates manageable.
            s = array_ops.constant(1.)
            q = mvn_diag_lib.MultivariateNormalDiag(
                scale_diag=array_ops.tile([s], [d]))

            approx_kl = cd.monte_carlo_csiszar_f_divergence(
                f=cd.kl_reverse, p=p, q=q, num_draws=num_draws, seed=seed)

            approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
                f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                p=p,
                q=q,
                num_draws=num_draws,
                seed=seed)

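            # use_reparametrization=False requests the score-function
            # ("score trick" / REINFORCE-style) gradient estimator instead of
            # the reparameterization trick; the estimates themselves should
            # still agree.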
            approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
                f=cd.kl_reverse,
                p=p,
                q=q,
                num_draws=num_draws,
                use_reparametrization=False,
                seed=seed)

            approx_kl_self_normalized_score_trick = (
                cd.monte_carlo_csiszar_f_divergence(
                    f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
                    p=p,
                    q=q,
                    num_draws=num_draws,
                    use_reparametrization=False,
                    seed=seed))

            exact_kl = kullback_leibler.kl_divergence(q, p)

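            # `grad` extracts d/ds of a scalar estimate; the assertions below
            # compare each approximate gradient against the gradient of the
            # exact KL.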
            grad = lambda fs: gradients_impl.gradients(fs, s)[0]

            [
                approx_kl_,
                approx_kl_self_normalized_,
                approx_kl_score_trick_,
                approx_kl_self_normalized_score_trick_,
                exact_kl_,
            ] = sess.run([
                grad(approx_kl),
                grad(approx_kl_self_normalized),
                grad(approx_kl_score_trick),
                grad(approx_kl_self_normalized_score_trick),
                grad(exact_kl),
            ])

            self.assertAllClose(approx_kl_, exact_kl_, rtol=0.06, atol=0.)

            self.assertAllClose(approx_kl_self_normalized_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)

            self.assertAllClose(approx_kl_score_trick_,
                                exact_kl_,
                                rtol=0.06,
                                atol=0.)

            self.assertAllClose(approx_kl_self_normalized_score_trick_,
                                exact_kl_,
                                rtol=0.05,
                                atol=0.)

    def test_vimco_and_gradient(self):

        with self.test_session() as sess:
            dims = 5  # Dimension
            num_draws = int(20)
            num_batch_draws = int(3)
            seed = 1

            f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
            np_f = lambda logu: -logu
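            # np_f is the numpy counterpart of f: with self_normalized=False,
            # cd.kl_reverse(logu) reduces to -logu.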

            p = mvn_full_lib.MultivariateNormalFullCovariance(
                covariance_matrix=self._tridiag(
                    dims, diag_value=1, offdiag_value=0.5))

            # Use a wide q (scale_diag = 1) so that q "covers" p, which keeps
            # the variance of the Monte Carlo estimates manageable.
            s = array_ops.constant(1.)
            q = mvn_diag_lib.MultivariateNormalDiag(
                scale_diag=array_ops.tile([s], [dims]))

            vimco = cd.csiszar_vimco(f=f,
                                     p_log_prob=p.log_prob,
                                     q=q,
                                     num_draws=num_draws,
                                     num_batch_draws=num_batch_draws,
                                     seed=seed)

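            # Rebuild the same Monte Carlo quantities by hand so that VIMCO's
            # output and gradient can be checked against a numpy reference.
            # cd.csiszar_vimco_helper(logu) returns (log_avg_u, log_sooavg_u);
            # only the log-average term is needed here.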
            x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
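            # Treat the samples as constants: the numpy reference adds the
            # score-function term explicitly, so pathwise gradients through `x`
            # must not leak into the TF quantities being differentiated.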
            x = array_ops.stop_gradient(x)
            logu = p.log_prob(x) - q.log_prob(x)
            f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])

            grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]

            def jacobian(x):
                # Warning: this function is slow and may not even finish if
                # prod(shape) is larger than, say, 100.
                shape = x.shape.as_list()
                assert all(dim is not None for dim in shape)
                x = array_ops.reshape(x, shape=[-1])
                r = [grad_sum(x[i]) for i in range(np.prod(shape))]
                return array_ops.reshape(array_ops.stack(r), shape=shape)

            [
                logu_,
                jacobian_logqx_,
                vimco_,
                grad_vimco_,
                f_log_sum_u_,
                grad_mean_f_log_sum_u_,
            ] = sess.run([
                logu,
                jacobian(q.log_prob(x)),
                vimco,
                grad_sum(vimco),
                f_log_sum_u,
                grad_sum(f_log_sum_u) / num_batch_draws,
            ])

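            # self._numpy_csiszar_vimco_helper is assumed to mirror
            # cd.csiszar_vimco_helper in numpy: the log of the average of
            # u = exp(logu) over the sample axis, and a "swap-one-out" variant
            # in which each sample's contribution is replaced by an average of
            # the others.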
            np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(
                logu_)

            # Test VIMCO loss is correct.
            self.assertAllClose(np_f(np_log_avg_u).mean(axis=0),
                                vimco_,
                                rtol=1e-5,
                                atol=0.)

            # Test gradient of VIMCO loss is correct.
            #
            # To make this computation we'll inject two gradients from TF:
            # - grad[mean(f(log(sum(p(x)/q(x)))))]
            # - jacobian[log(q(x))].
            #
            # We now justify why using these (and only these) TF values for
            # ground-truth does not undermine the completeness of this test.
            #
            # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
            # correctness of the zero-th order derivative (for each batch member).
            # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient
            # information, we can safely rely on TF.
            self.assertAllClose(np_f(np_log_avg_u),
                                f_log_sum_u_,
                                rtol=1e-4,
                                atol=0.)
            #
            # Regarding `jacobian_logqx_`, note that testing the gradient of
            # `q.log_prob` is outside the scope of this unit test, so we may
            # safely use TF to find it.

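            # The reference gradient below combines the pathwise term
            # grad[mean(f(log_avg_u))] with the score-function correction
            # sum_i grad[log q(x_i)] * (f(log_avg_u) - f(log_sooavg_u_i)),
            # i.e. the VIMCO control-variate form of the gradient.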
            # The `mean` is across batches and the `sum` is across iid samples.
            np_grad_vimco = grad_mean_f_log_sum_u_ + np.mean(
                np.sum(
                    jacobian_logqx_
                    * (np_f(np_log_avg_u) - np_f(np_log_sooavg_u)),
                    axis=0),
                axis=0)

            self.assertAllClose(np_grad_vimco, grad_vimco_, rtol=1e-5, atol=0.)