    def test_sgvb(self):
        eps_samples = tf.convert_to_tensor(self._n1_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
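        # SGVB relies on the reparameterization trick: x = mu + sigma * eps
        # keeps mu and sigma on the gradient path of the samples.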
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_sgvb(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [qx_samples, log_qx]},
                axis=0)
            sgvb_cost = lower_bound.sgvb()
            sgvb_cost = tf.reduce_mean(sgvb_cost)
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(sgvb_grads)
                g2 = sess.run(true_grads)
                # print('sgvb_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_sgvb(0., 1., 0.04)
        _check_sgvb(2., 3., 0.02)

    def test_reinforce(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
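        # REINFORCE is a score-function estimator: gradients flow through
        # log q(x), not through the samples, hence the stop_gradient below.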
        qx_samples = tf.stop_gradient(eps_samples * sigma + mu)
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_reinforce(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = elbo(log_joint,
                               observed={},
                               latent={'x': [qx_samples, log_qx]},
                               axis=0)
            # TODO: Check grads when using variance reduction and a baseline
            reinforce_cost = lower_bound.reinforce(variance_reduction=False)
            reinforce_grads = tf.gradients(reinforce_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                sess.run(tf.global_variables_initializer())
                g1 = sess.run(reinforce_grads)
                g2 = sess.run(true_grads)
                # print('reinforce_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_reinforce(0., 1., 0.03)
        # asymptotically no variance (p=q)
        _check_reinforce(2., 3., 1e-6)

    def test_sgvb(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

        def _check_sgvb(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = elbo(log_joint,
                               observed={},
                               latent={'x': [qx_samples, log_qx]},
                               axis=0)
            sgvb_cost = lower_bound.sgvb()
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])
            true_cost = _kl_normal_normal(mu, sigma, x_mean, x_std)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(sgvb_grads)
                g2 = sess.run(true_grads)
                # print('sgvb_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_sgvb(0., 1., 0.04)
        # A threshold of 1e-6 would be achievable for sgvb if "sticking the
        # landing" were used (p = q).
        _check_sgvb(2., 3., 0.02)

    def test_importance(self):
        eps_samples = tf.convert_to_tensor(self._n01_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
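        # The importance-sampling estimator differentiates log q at fixed
        # samples, hence the stop_gradient below.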
        qx_samples = tf.stop_gradient(eps_samples * sigma + mu)
        q = Normal(mean=mu, std=sigma)
        log_qx = q.log_prob(qx_samples)

        def _check_importance(x_mean, x_std, threshold):
            def log_joint(observed):
                p = Normal(mean=x_mean, std=x_std)
                return p.log_prob(observed['x'])

            klpq_obj = klpq(log_joint,
                            observed={},
                            latent={'x': [qx_samples, log_qx]},
                            axis=0)
            cost = klpq_obj.importance()
            importance_grads = tf.gradients(cost, [mu, sigma])
            true_cost = _kl_normal_normal(x_mean, x_std, mu, sigma)
            true_grads = tf.gradients(true_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(importance_grads)
                g2 = sess.run(true_grads)
                # print('importance_grads:', g1)
                # print('true_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_importance(0., 1., 0.01)
        _check_importance(2., 3., 0.02)

        single_sample = tf.stop_gradient(tf.random_normal([]) * sigma + mu)
        single_log_q = q.log_prob(single_sample)

        def log_joint(observed):
            p = Normal(std=1.)
            return p.log_prob(observed['x'])

        single_sample_obj = klpq(log_joint,
                                 observed={},
                                 latent={'x': [single_sample, single_log_q]})

        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            # Trigger a warning.
            single_sample_obj.importance()
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertTrue("biased and inaccurate when you're using only "
                            "a single sample" in str(w[-1].message))

    def test_vimco(self):
        eps_samples = tf.convert_to_tensor(self._n3_samples)
        mu = tf.constant(2.)
        sigma = tf.constant(3.)
        qx_samples = eps_samples * sigma + mu
        norm = Normal(mean=mu, std=sigma)
        log_qx = norm.log_prob(qx_samples)

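        # A second, non-reparameterized sampling path for VIMCO: gradients may
        # only enter through log q, not through the samples themselves.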
        v_qx_samples = eps_samples * tf.stop_gradient(sigma) + \
            tf.stop_gradient(mu)
        v_log_qx = norm.log_prob(v_qx_samples)

        def _check_vimco(x_mean, x_std, threshold):
            def log_joint(observed):
                norm = Normal(mean=x_mean, std=x_std)
                return norm.log_prob(observed['x'])

            lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [qx_samples, log_qx]},
                axis=0)
            v_lower_bound = importance_weighted_objective(
                log_joint,
                observed={},
                latent={'x': [v_qx_samples, v_log_qx]},
                axis=0)

            vimco_cost = v_lower_bound.vimco()
            vimco_cost = tf.reduce_mean(vimco_cost)
            vimco_grads = tf.gradients(vimco_cost, [mu, sigma])
            sgvb_cost = tf.reduce_mean(lower_bound.sgvb())
            sgvb_grads = tf.gradients(sgvb_cost, [mu, sigma])

            with self.session(use_gpu=True) as sess:
                g1 = sess.run(vimco_grads)
                g2 = sess.run(sgvb_grads)
                # print('vimco_grads:', g1)
                # print('sgvb_grads:', g2)
                self.assertAllClose(g1, g2, threshold, threshold)

        _check_vimco(0., 1., 1e-2)
        _check_vimco(2., 3., 1e-6)

    def forward(self, bn, observed, initial_position):
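        """Run `self.iters` HMC transitions on the latent variables of `bn`
        that are not observed: resample standard-normal momenta, simulate
        `self.n_leapfrogs` leapfrog steps of size `self.step_size`, and
        accept or reject the proposal with a Metropolis-Hastings step.
        Returns a dict mapping latent names to the resulting samples.
        """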

        if initial_position:
            observed_ = {**initial_position, **observed}
        else:
            observed_ = observed
        bn.forward(observed_)

        q0 = [[k, v.tensor] for k, v in bn.nodes.items()
              if k not in observed.keys()]
        normals = [
            [k, Normal(mean=fluid.layers.zeros(shape=v.shape, dtype='float32'),
                       std=1)]
            for k, v in q0]

        for e in range(self.iters):
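            # Work on copies: q0 keeps the current state for the M-H
            # comparison, while q1/p1 are evolved by the leapfrog integrator.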
            q1 = [[k, paddle.assign(v)] for k, v in q0]
            p0 = [[k, v.sample()] for k, v in normals]
            p1 = [[k, paddle.assign(v)] for k, v in p0]

            ###### leapfrog integrator
            for s in range(self.n_leapfrogs):
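                # One leapfrog step: half momentum update, full position
                # update, then another half momentum update at the new
                # position.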
                observed_ = {**dict(q1), **observed}
                bn.forward(observed_)
                log_joint_ = bn.log_joint()
                q_v = [v for _, v in q1]
                q_grad = paddle.grad(log_joint_, q_v)

                for i, _ in enumerate(q_grad):
                    p1[i][1] = p1[i][1] + self.step_size * q_grad[i] / 2.0
                    q1[i][1] = q1[i][1] + self.step_size * p1[i][1]
                    p1[i][1] = p1[i][1].detach()
                    p1[i][1].stop_gradient = False
                    q1[i][1] = q1[i][1].detach()
                    q1[i][1].stop_gradient = False

                observed_ = {**dict(q1), **observed}
                q_v = [v for _, v in q1]
                bn.forward(observed_)
                log_joint_ = bn.log_joint()
                q_grad = paddle.grad(log_joint_, q_v)

                for i, _ in enumerate(q_grad):
                    p1[i][1] = p1[i][1] + self.step_size * q_grad[i] / 2.0
                    p1[i][1] = p1[i][1].detach()
                    p1[i][1].stop_gradient = False

            ###### reverse p1
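            # Negate the momenta so the proposal is time-reversible.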
            for i, _ in enumerate(p1):
                p1[i][1] = -1 * p1[i][1]

            ###### M-H step
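            # Compare Hamiltonians: log p(q) plus the momentum log-density for
            # the current state (q0, p0) and the proposal (q1, p1).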
            observed_ = {**dict(q0), **observed}
            bn.forward(observed_)
            log_prob_q0 = bn.log_joint()
            log_prob_p0 = None
            for i, _ in enumerate(p0):
                # Momentum log-density under the standard normal auxiliary
                # distribution (the additive constant cancels in the ratio),
                # summed over the dimensions not shared with log p(q).
                p_log_prob = -0.5 * p0[i][1] * p0[i][1]
                len_q = len(log_prob_q0.shape)
                len_p = len(p_log_prob.shape)
                assert len_p >= len_q
                if len_p > len_q:
                    dims = list(range(len_q - len_p, 0))
                    p_log_prob = fluid.layers.reduce_sum(p_log_prob, dims)
                log_prob_p0 = (p_log_prob if log_prob_p0 is None
                               else log_prob_p0 + p_log_prob)

            observed_ = {**dict(q1), **observed}
            bn.forward(observed_)
            log_prob_q1 = bn.log_joint()
            log_prob_p1 = None
            for i, _ in enumerate(p1):
                # Same momentum log-density for the proposed state.
                p_log_prob = -0.5 * p1[i][1] * p1[i][1]
                len_q = len(log_prob_q1.shape)
                len_p = len(p_log_prob.shape)
                assert len_p >= len_q
                if len_p > len_q:
                    dims = list(range(len_q - len_p, 0))
                    p_log_prob = fluid.layers.reduce_sum(p_log_prob, dims)
                log_prob_p1 = (p_log_prob if log_prob_p1 is None
                               else log_prob_p1 + p_log_prob)

            assert (log_prob_q0.shape == log_prob_p1.shape)

            acceptance = log_prob_q1 + log_prob_p1 - log_prob_q0 - log_prob_p0

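            # Accept where log(u) < the log acceptance ratio (drawn
            # elementwise), keeping the old state otherwise.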
            for i, _ in enumerate(q1):
                event = paddle.to_tensor(
                    np.log(np.random.rand(*q1[i][1].shape)), dtype='float32')
                #q0[i][1] = paddle.where(acceptance>=event, q1[i][1], q0[i][1])
                a = paddle.cast(acceptance > event, dtype='float32')
                q0[i][1] = paddle.assign(a * q1[i][1] + (1.0 - a) * q0[i][1])

        sample_ = dict(q0)
        return sample_