Пример #1
0
    def test_get_score(self):
        """Check `Donut.get_score` against a hand-built reference log-prob.

        `donut.vae.variational` is wrapped by a local `Capture` helper so the
        q-net built inside `get_score` can be inspected.
        `donut.vae.reconstruct` is replaced by a `Mock` wrapping
        ``x + reduce_sum(x)``, so the number of MCMC reconstruction calls can
        be counted.

        Each case builds the score tensor and a reference tensor (`r_prob`)
        that re-uses the latent `z` from `get_score`'s own q-net. Both are
        then evaluated in a single `sess.run`.
        """
        class Capture(object):
            """Rebind `vae.variational` so the latest q-net is recorded."""

            def __init__(self, vae):
                self._vae = vae
                self._vae_variational = vae.variational
                vae.variational = self._variational
                self.q_net = None  # overwritten on every vae.variational call

            def _variational(self, x, z=None, n_z=None):
                self.q_net = self._vae_variational(x, z=z, n_z=n_z)
                return self.q_net

        tf.set_random_seed(1234)
        donut = Donut(
            h_for_p_x=lambda x: x,
            h_for_q_z=lambda x: x,
            x_dims=5,
            z_dims=3,
        )
        capture = Capture(donut.vae)
        donut.vae.reconstruct = Mock(
            wraps=lambda x: x + tf.reduce_sum(x)  # only called by MCMC
        )
        x = tf.reshape(tf.range(20, dtype=tf.float32), [4, 5])
        y = tf.constant([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0],
                         [0, 0, 0, 1, 0]],
                        dtype=tf.int32)
        _ = donut.get_score(x, y)  # ensure variables created

        def r_prob(x, z, n_z=None, x_in=None):
            """Reference log p(x|z); averaged over axis 0 when `n_z` is set."""
            if x_in is None:
                x_in = x
            # This call goes through `Capture._variational`, so it replaces
            # `capture.q_net` with this reference q-net.
            q_net = donut.vae.variational(x_in, z=z, n_z=n_z)
            p_net = donut.vae.model(z=q_net['z'], x=x, n_z=n_z)
            p = p_net['x'].log_prob(group_ndims=0)
            if n_z is not None:
                p = tf.reduce_mean(p, axis=0)
            return p

        def build_case(x_in=None, **score_kwargs):
            """Build ``(score, reference, z)`` for one `get_score` call.

            Resets the `reconstruct` mock, then builds
            ``donut.get_score(x, last_point_only=False, **score_kwargs)``.
            It grabs `z` from the captured q-net *before* `r_prob` overwrites
            `capture.q_net`, so `z` is the latent produced by `get_score`.
            """
            donut.vae.reconstruct.reset_mock()
            score = donut.get_score(x, last_point_only=False, **score_kwargs)
            z = capture.q_net['z']
            reference = r_prob(
                x, z=z, n_z=score_kwargs.get('n_z'), x_in=x_in)
            return score, reference, z

        with self.test_session() as sess:
            ensure_variables_initialized()

            # test y is None
            score, reference, _ = build_case(y=None, mcmc_iteration=1)
            np.testing.assert_allclose(*sess.run([score, reference]))
            self.assertEqual(donut.vae.reconstruct.call_count, 0)

            # test mcmc_iteration is None
            score, reference, _ = build_case(y=y, mcmc_iteration=None)
            np.testing.assert_allclose(*sess.run([score, reference]))
            self.assertEqual(donut.vae.reconstruct.call_count, 0)

            # test mcmc once
            # Expected MCMC input under the mocked `reconstruct`.
            # NOTE(review): this keeps `x` where y == 1 and uses the
            # reconstruction where y == 0. Confirm this matches the mask
            # convention of Donut's masked reconstruction. Because `r_prob`
            # passes `z` explicitly, `x_in` presumably does not affect the
            # compared value, so an inversion here would go unnoticed.
            x2 = tf.where(
                tf.cast(y, dtype=tf.bool),
                x,
                x + tf.reduce_sum(x),
            )
            score, reference, _ = build_case(
                x_in=x2, y=y, mcmc_iteration=1)
            np.testing.assert_allclose(*sess.run([score, reference]))
            self.assertEqual(donut.vae.reconstruct.call_count, 1)

            # test mcmc with n_z > 1
            score, reference, z = build_case(
                x_in=x2, y=y, n_z=7, mcmc_iteration=1)
            np.testing.assert_allclose(*sess.run([score, reference]))
            # `z` is get_score's own latent, not the one r_prob rebuilt.
            self.assertEqual(z.get_shape(), tf.TensorShape([7, 4, 3]))
            self.assertEqual(donut.vae.reconstruct.call_count, 1)
Пример #2
0
    def test_prediction(self):
        """Check output shapes and argument validation of `DonutPredictor`.

        Covers the following cases:

        * the default predictor, with and without a `missing` mask;
        * the `ValueError` messages for a non 1-D `values` and for a
          `missing` whose shape disagrees with `values`;
        * ``last_point_only=False``;
        * passing an extra `feed_dict`.
        """
        np.random.seed(1234)
        tf.set_random_seed(1234)

        # test last_point_only == True
        donut = Donut(h_for_p_x=lambda x: x,
                      h_for_q_z=lambda x: x,
                      x_dims=5,
                      z_dims=3)
        _ = donut.get_score(tf.zeros([4, 5], dtype=tf.float32),
                            tf.zeros(
                                [4, 5],
                                dtype=tf.int32))  # ensure variables created
        pred = DonutPredictor(donut, n_z=2, batch_size=4)
        self.assertIs(pred.model, donut)

        with self.test_session():
            ensure_variables_initialized()

            # test without missing
            # Expected lengths equal n_values - x_dims + 1 (x_dims == 5).
            for n_values, expected_shape in ((5, (1, )),
                                             (8, (4, )),
                                             (10, (6, ))):
                res = pred.get_score(
                    values=np.arange(n_values, dtype=np.float32))
                self.assertEqual(res.shape, expected_shape)

            # test with missing
            res = pred.get_score(values=np.arange(10, dtype=np.float32),
                                 missing=np.random.binomial(1, .5, size=10))
            self.assertEqual(res.shape, (6, ))

        # test errors
        with self.test_session():
            with pytest.raises(ValueError,
                               match='`values` must be a 1-D array'):
                _ = pred.get_score(
                    np.arange(10, dtype=np.float32).reshape([-1, 1]))
            with pytest.raises(ValueError,
                               match='The shape of `missing` does not agree '
                               'with the shape of `values`'):
                _ = pred.get_score(np.arange(10, dtype=np.float32),
                                   np.arange(9, dtype=np.int32))

        # test last_point_only == False
        pred = DonutPredictor(donut,
                              n_z=2,
                              batch_size=4,
                              last_point_only=False)

        with self.test_session():
            ensure_variables_initialized()

            # test without missing
            res = pred.get_score(values=np.arange(10, dtype=np.float32))
            self.assertEqual(res.shape, (6, 5))

            # test with missing
            res = pred.get_score(values=np.arange(10, dtype=np.float32),
                                 missing=np.random.binomial(1, .5, size=10))
            self.assertEqual(res.shape, (6, 5))

        # test set feed_dict
        # `is_training` is a fresh placeholder that `donut` was not built
        # with; feeding it must not disturb scoring.
        is_training = tf.placeholder(tf.bool, shape=())
        pred = DonutPredictor(donut,
                              n_z=2,
                              batch_size=4,
                              feed_dict={is_training: False})

        with self.test_session():
            ensure_variables_initialized()
            res = pred.get_score(values=np.arange(10, dtype=np.float32))
            # Same config as the first predictor, so the same shape applies.
            self.assertEqual(res.shape, (6, ))