예제 #1
0
    def test_log_prob_extreme(self):
        """Compare ``log_prob`` with the naive helper on extreme log-scales.

        A grid of candidate log-scales is filtered down to the entries for
        which ``safe_sigmoid(t) - safe_sigmoid(-t)`` falls below
        ``cdf_delta``, where ``t = bin_size / (2 * exp(log_scale))``.  The
        output of ``DiscretizedLogistic.log_prob`` at ``x = 0`` for those
        log-scales is then checked against the value returned by
        ``naive_discretized_logistic_pdf``.
        """
        check_close = functools.partial(
            np.testing.assert_allclose, rtol=1e-5, atol=1e-6)

        bin_size = 1 / 256.
        x = np.asarray(0., dtype=np.float64)
        mean = np.asarray(0., dtype=np.float64)

        # Keep only the candidate log-scales whose sigmoid difference over
        # +/- bin_size / (2 * scale) is smaller than `cdf_delta`.
        cdf_delta = 1e-8
        candidates = np.linspace(0, 20, 101)
        half_width = bin_size / (2 * np.exp(candidates))
        cdf_gap = safe_sigmoid(half_width) - safe_sigmoid(-half_width)
        selected = np.where(cdf_gap < cdf_delta)[0]
        # The check below is only meaningful if several candidates qualify.
        self.assertGreater(np.size(selected), 2)
        log_scale = candidates[selected]

        with self.test_session() as sess:
            # Evaluate the distribution on the selected extreme log-scales.
            dist = DiscretizedLogistic(
                mean=mean, log_scale=log_scale, bin_size=bin_size,
                epsilon=1e-7
            )
            actual = sess.run(dist.log_prob(x, group_ndims=0))
            expected = naive_discretized_logistic_pdf(
                x, mean, log_scale, bin_size, None, None,
                biased_edges=True, group_ndims=0)
            check_close(actual, expected)
예제 #2
0
    def test_log_prob(self):
        """Compare ``log_prob`` with the naive helper over several settings.

        Random ``x``, ``mean`` and ``log_scale`` arrays are drawn from a
        fixed seed.  For each configuration of ``min_val`` / ``max_val`` /
        ``biased_edges`` (and, in one case, ``discretize_given=False``), a
        ``DiscretizedLogistic`` is built and the value of its ``log_prob``
        is checked against ``naive_discretized_logistic_pdf`` called with
        the same settings.  The last configuration is checked with both
        ``group_ndims=0`` and ``group_ndims=2``.
        """
        check_close = functools.partial(
            np.testing.assert_allclose, rtol=1e-5, atol=1e-6)

        # The draw order (x, mean, log_scale) determines the sampled values.
        np.random.seed(1234)
        x = np.random.normal(size=[7, 3, 2, 5, 4]).astype(np.float64)
        # Make sure the samples fall below -1.1, above 2.1, and inside
        # (-0.5, 1.5), i.e. on both sides of and between the bounds used
        # further down.
        self.assertLess(np.min(x), -1.1)
        self.assertGreater(np.max(x), 2.1)
        self.assertGreater(np.sum(np.logical_and(x > -0.5, x < 1.5)), 0)

        mean = 3 * np.random.uniform(size=[2, 1, 4]).astype(np.float64) - 1
        log_scale = np.random.normal(size=[3, 1, 5, 1]).astype(np.float64)
        bin_size = 1 / 256.
        min_val, max_val = -1., 2.

        # Each entry: (min_val, max_val, biased_edges, extra keyword
        # arguments forwarded to both the distribution and the naive helper,
        # group_ndims values to verify).
        cases = (
            # no value bounds, biased_edges = False
            (None, None, False, {}, (0,)),
            # bounded, biased_edges = False
            (min_val, max_val, False, {}, (0,)),
            # bounded, biased_edges = False, discretize_given = False
            (min_val, max_val, False, {'discretize_given': False}, (0,)),
            # bounded, biased_edges = True
            (min_val, max_val, True, {}, (0, 2)),
        )

        with self.test_session() as sess:
            for lower, upper, biased, extra, ndims_to_check in cases:
                dist = DiscretizedLogistic(
                    mean=mean, log_scale=log_scale, bin_size=bin_size,
                    min_val=lower, max_val=upper, biased_edges=biased,
                    dtype=tf.float64, **extra
                )
                for ndims in ndims_to_check:
                    actual = sess.run(dist.log_prob(x, group_ndims=ndims))
                    expected = naive_discretized_logistic_pdf(
                        x, mean, log_scale, bin_size, lower, upper,
                        biased_edges=biased, group_ndims=ndims, **extra)
                    check_close(actual, expected)