    def test_importance(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
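        # Reparameterized draws from q = N(mu, sigma^2). stop_gradient blocks
        # the pathwise derivative, so gradients reach mu and sigma only
        # through the log-density log_qx below.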
        qx_samples = tf.stop_gradient(eps_samples * sigma + mu)
        q = Normal(mean=mu, std=sigma)
        log_qx = q.log_prob(qx_samples)

        def _check_importance(x_mean, x_std, threshold):
            def log_joint(observed):
                p = Normal(mean=x_mean, std=x_std)
                return p.log_prob(observed['x'])

            klpq_obj = klpq(log_joint,
                            observed={},
                            latent={'x': [qx_samples, log_qx]},
                            axis=0)
            cost = klpq_obj.importance()
            importance_grads = tf.gradients(cost, [mu, sigma])
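            # Ground truth: the analytic KL(p || q), whose gradients the
            # importance-sampling surrogate should approximate.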
            true_cost = _kl_normal_normal(x_mean, x_std, mu, sigma)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(importance_grads)
                g2 = sess.run(true_grads)
                # print('importance_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_importance(0., 1., 0.01)
        _check_importance(2., 3., 0.02)

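        # With only one sample the importance estimator is biased;
        # importance() should emit a warning saying so.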
        single_sample = tf.stop_gradient(tf.random_normal([]) * sigma + mu)
        single_log_q = q.log_prob(single_sample)

        def log_joint(observed):
            p = Normal(std=1.)
            return p.log_prob(observed['x'])

        single_sample_obj = klpq(log_joint,
                                 observed={},
                                 latent={'x': [single_sample, single_log_q]})

        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            # Trigger a warning.
            single_sample_obj.importance()
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertTrue("biased and inaccurate when you're using only "
                            "a single sample" in str(w[-1].message))

    def test_sgvb(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
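        # No stop_gradient here: SGVB differentiates through the
        # reparameterized samples themselves (pathwise gradients).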
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_sgvb(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [qx_samples, log_qx]},
                axis=0)
            sgvb_cost = lower_bound.sgvb()
            sgvb_cost = tf.reduce_mean(sgvb_cost)
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(sgvb_grads)
                g2 = sess.run(true_grads)
                # print('sgvb_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_sgvb(0., 1., 0.04)
        _check_sgvb(2., 3., 0.02)

    def test_reinforce(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
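        # REINFORCE is a score-function estimator: the samples are treated as
        # constants and gradients arrive only through log_qx.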
        qx_samples = tf.stop_gradient(eps_samples * sigma + mu)
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_reinforce(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = elbo(log_joint,
                               observed={},
                               latent={'x': [qx_samples, log_qx]},
                               axis=0)
            # TODO: Check grads when use variance reduction and baseline
            reinforce_cost = lower_bound.reinforce(variance_reduction=False)
            reinforce_grads = tf.gradients(reinforce_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                sess.run(tf.global_variables_initializer())
                g1 = sess.run(reinforce_grads)
                g2 = sess.run(true_grads)
                # print('reinforce_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_reinforce(0., 1., 0.03)
        # asymptotically no variance (p=q)
        _check_reinforce(2., 3., 1e-6)

    def test_sgvb(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_sgvb(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = elbo(log_joint,
                               observed={},
                               latent={'x': [qx_samples, log_qx]},
                               axis=0)
            sgvb_cost = lower_bound.sgvb()
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(sgvb_grads)
                g2 = sess.run(true_grads)
                # print('sgvb_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_sgvb(0., 1., 0.04)
        # 1e-6 would be good for sgvb if sticking the landing is used. (p=q)
        _check_sgvb(2., 3., 0.02)

    def test_vimco(self):
        eps_samples = tf.convert_to_tensor(self._n3_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

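        # For VIMCO the sampling path is cut by stop_gradient on mu and sigma,
        # so gradients flow to the parameters only through v_log_qx; its
        # gradients are then compared against SGVB's on the same bound.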
        v_qx_samples = eps_samples * tf.stop_gradient(sigma) + \
            tf.stop_gradient(mu)
        v_log_qx = norm.log_prob(v_qx_samples)

        def _check_vimco(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [qx_samples, log_qx]},
                axis=0)
            v_lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [v_qx_samples, v_log_qx]},
                axis=0)

            vimco_cost = v_lower_bound.vimco()
            vimco_cost = tf.reduce_mean(vimco_cost)
            vimco_grads = tf.gradients(vimco_cost, [mu, sigma])
            sgvb_cost = tf.reduce_mean(lower_bound.sgvb())
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(vimco_grads)
                g2 = sess.run(sgvb_grads)
                # print('vimco_grads:', g1)
                # print('sgvb_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_vimco(0., 1., 1e-2)
        _check_vimco(2., 3., 1e-6)
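

# `_kl_normal_normal` is defined elsewhere in this file. Judging from its use
# above, it computes the analytic KL divergence between two univariate
# normals. A minimal sketch under that assumption (the name below is
# illustrative, not the file's actual helper):
def _kl_normal_normal_sketch(mean_p, std_p, mean_q, std_q):
    # KL(N(mean_p, std_p^2) || N(mean_q, std_q^2)) =
    #     log(std_q / std_p)
    #     + (std_p ** 2 + (mean_p - mean_q) ** 2) / (2 * std_q ** 2)
    #     - 1 / 2
    return tf.log(std_q / std_p) + \
        (std_p ** 2 + (mean_p - mean_q) ** 2) / (2. * std_q ** 2) - 0.5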