Example #1
Draws HMC samples over a two-part state (a scalar Normal and an Independent-wrapped Gamma vector) and checks that the sample means and variances match the target distributions.
def testStateParts(self):
  with self.test_session(graph=ops.Graph()) as sess:
    dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
    dist_y = independent_lib.Independent(
        gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                        rate=self.dtype([0.5, 0.75])),
        reinterpreted_batch_ndims=1)
    def target_log_prob(x, y):
      return dist_x.log_prob(x) + dist_y.log_prob(y)
    x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
    samples, _ = hmc.sample_chain(
        num_results=int(2e3),
        target_log_prob_fn=target_log_prob,
        current_state=x0,
        step_size=0.85,
        num_leapfrog_steps=3,
        num_burnin_steps=int(250),
        seed=49)
    actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
    actual_vars = [_reduce_variance(s, axis=0) for s in samples]
    expected_means = [dist_x.mean(), dist_y.mean()]
    expected_vars = [dist_x.variance(), dist_y.variance()]
    [
        actual_means_,
        actual_vars_,
        expected_means_,
        expected_vars_,
    ] = sess.run([
        actual_means,
        actual_vars,
        expected_means,
        expected_vars,
    ])
    self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
    self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
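For comparison, a minimal standalone sketch of the same multi-part-state pattern, written against the tensorflow_probability MCMC API (an assumption; the test above uses an older hmc.sample_chain signature that takes step_size and num_leapfrog_steps directly):

import tensorflow_probability as tfp
tfd = tfp.distributions

dist_x = tfd.Normal(loc=0., scale=1.)
dist_y = tfd.Independent(
    tfd.Gamma(concentration=[1., 2.], rate=[0.5, 0.75]),
    reinterpreted_batch_ndims=1)

def target_log_prob(x, y):
  # One argument per state part: the chain state below is a two-element list.
  return dist_x.log_prob(x) + dist_y.log_prob(y)

samples = tfp.mcmc.sample_chain(
    num_results=2000,
    current_state=[dist_x.sample(seed=1), dist_y.sample(seed=2)],
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.85,
        num_leapfrog_steps=3),
    num_burnin_steps=250,
    trace_fn=None)  # trace_fn=None returns only the chain states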
Example #2
Verifies that the KL divergence between two Independent-wrapped Normal vectors equals the sum of the componentwise Normal KL divergences.
    def testKLScalarToMultivariate(self):
        normal1 = normal_lib.Normal(loc=np.float32([-1., 1]),
                                    scale=np.float32([0.1, 0.5]))
        ind1 = independent_lib.Independent(distribution=normal1,
                                           reinterpreted_batch_ndims=1)

        normal2 = normal_lib.Normal(loc=np.float32([-3., 3]),
                                    scale=np.float32([0.3, 0.3]))
        ind2 = independent_lib.Independent(distribution=normal2,
                                           reinterpreted_batch_ndims=1)

        normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
        ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
        self.assertAllClose(
            self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)),
            self.evaluate(ind_kl))
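The identity under test (the KL of Independent-wrapped distributions is the base KL summed over the reinterpreted dimensions) can be reproduced with a short sketch, assuming the tensorflow_probability API rather than the contrib modules used above:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

normal1 = tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5])
normal2 = tfd.Normal(loc=[-3., 3.], scale=[0.3, 0.3])

elementwise_kl = tfd.kl_divergence(normal1, normal2)  # shape [2]
ind_kl = tfd.kl_divergence(
    tfd.Independent(normal1, reinterpreted_batch_ndims=1),
    tfd.Independent(normal2, reinterpreted_batch_ndims=1))  # scalar
# ind_kl should equal tf.reduce_sum(elementwise_kl).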
Example #3
With reinterpreted_batch_ndims=0, Independent is a no-op wrapper, so the KL divergence is identical to that of the underlying Normals.
    def testKLIdentity(self):
        normal1 = normal_lib.Normal(loc=np.float32([-1., 1]),
                                    scale=np.float32([0.1, 0.5]))
        # This is functionally just a wrapper around normal1,
        # and doesn't change any outputs.
        ind1 = independent_lib.Independent(distribution=normal1,
                                           reinterpreted_batch_ndims=0)

        normal2 = normal_lib.Normal(loc=np.float32([-3., 3]),
                                    scale=np.float32([0.3, 0.3]))
        # This is functionally just a wrapper around normal2,
        # and doesn't change any outputs.
        ind2 = independent_lib.Independent(distribution=normal2,
                                           reinterpreted_batch_ndims=0)

        normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
        ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
        self.assertAllClose(self.evaluate(normal_kl), self.evaluate(ind_kl))
Example #4
A factory that wraps a Deterministic or Normal distribution in Independent, reinterpreting the full batch shape as a single event.
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Creates multivariate `Deterministic` or `Normal` distribution."""
  loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
  if scale is None:
    dist = deterministic_lib.Deterministic(loc=loc)
  else:
    dist = normal_lib.Normal(loc=loc, scale=scale)
  reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
  return independent_lib.Independent(
      dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
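The key step above is deriving reinterpreted_batch_ndims from the full batch rank, so the returned distribution treats every parameter dimension as part of one event. A small sketch of the effect, with illustrative shapes (assumes the tensorflow_probability API):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3, 4]), scale=tf.ones([3, 4]))
base.batch_shape   # (3, 4): log_prob returns a [3, 4] tensor
ind = tfd.Independent(base, reinterpreted_batch_ndims=2)
ind.batch_shape    # ():     log_prob now returns one scalar per sample
ind.event_shape    # (3, 4)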
Example #5
The KL divergence between Independent-wrapped batches of MultivariateNormalDiag equals the base KL summed over the two reinterpreted batch dimensions.
    def testKLMultivariateToMultivariate(self):
        # (1, 1, 2) batch of MVNDiag
        mvn1 = mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
            scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
        ind1 = independent_lib.Independent(distribution=mvn1,
                                           reinterpreted_batch_ndims=2)

        # (1, 1, 2) batch of MVNDiag
        mvn2 = mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
            scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))

        ind2 = independent_lib.Independent(distribution=mvn2,
                                           reinterpreted_batch_ndims=2)

        mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
        ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
        self.assertAllClose(
            self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])),
            self.evaluate(ind_kl))
Example #6
Checks that sample mean, variance, stddev, and entropy of an Independent-wrapped MultivariateNormalDiag agree with the analytic values, and that the mode equals loc.
    def testSampleConsistentStats(self):
        loc = np.float32([[-1., 1], [1, -1]])
        scale = np.float32([1., 0.5])
        n_samp = 1e4
        with self.cached_session() as sess:
            ind = independent_lib.Independent(
                distribution=mvn_diag_lib.MultivariateNormalDiag(
                    loc=loc, scale_identity_multiplier=scale),
                reinterpreted_batch_ndims=1)

            x = ind.sample(int(n_samp), seed=42)
            sample_mean = math_ops.reduce_mean(x, axis=0)
            sample_var = math_ops.reduce_mean(math_ops.squared_difference(
                x, sample_mean),
                                              axis=0)
            sample_std = math_ops.sqrt(sample_var)
            sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0)

            [
                sample_mean_,
                sample_var_,
                sample_std_,
                sample_entropy_,
                actual_mean_,
                actual_var_,
                actual_std_,
                actual_entropy_,
                actual_mode_,
            ] = sess.run([
                sample_mean,
                sample_var,
                sample_std,
                sample_entropy,
                ind.mean(),
                ind.variance(),
                ind.stddev(),
                ind.entropy(),
                ind.mode(),
            ])

            self.assertAllCloseAccordingToType(sample_mean_,
                                               actual_mean_,
                                               rtol=0.02)
            self.assertAllCloseAccordingToType(sample_var_,
                                               actual_var_,
                                               rtol=0.04)
            self.assertAllCloseAccordingToType(sample_std_,
                                               actual_std_,
                                               rtol=0.02)
            self.assertAllCloseAccordingToType(sample_entropy_,
                                               actual_entropy_,
                                               rtol=0.01)
            self.assertAllCloseAccordingToType(loc, actual_mode_, rtol=1e-6)
Example #7
Exercises the error cases: mismatched event shapes raise a ValueError, and underlying distributions with different event shapes raise a NotImplementedError.
    def testKLRaises(self):
        ind1 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                                           reinterpreted_batch_ndims=1)
        ind2 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32(-1), scale=np.float32(0.5)),
                                           reinterpreted_batch_ndims=0)

        with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
            kullback_leibler.kl_divergence(ind1, ind2)

        ind1 = independent_lib.Independent(distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                                           reinterpreted_batch_ndims=1)
        ind2 = independent_lib.Independent(
            distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
            reinterpreted_batch_ndims=0)

        with self.assertRaisesRegexp(NotImplementedError,
                                     "different event shapes"):
            kullback_leibler.kl_divergence(ind1, ind2)
Example #8
Wraps a Bernoulli over MNIST-shaped logits and checks batch/event shapes and log_prob against a NumPy reference, with both static and dynamic shapes.
    def _testMnistLike(self, static_shape):
        sample_shape = [4, 5]
        batch_shape = [10]
        image_shape = [28, 28, 1]
        logits = 3 * self._rng.random_sample(batch_shape + image_shape).astype(
            np.float32) - 1

        def expected_log_prob(x, logits):
            return (x * logits -
                    np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)

        with self.test_session() as sess:
            logits_ph = array_ops.placeholder(
                dtypes.float32, shape=logits.shape if static_shape else None)
            ind = independent_lib.Independent(
                distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
            x = ind.sample(sample_shape)
            log_prob_x = ind.log_prob(x)
            [
                x_,
                actual_log_prob_x,
                ind_batch_shape,
                ind_event_shape,
                x_shape,
                log_prob_x_shape,
            ] = sess.run([
                x,
                log_prob_x,
                ind.batch_shape_tensor(),
                ind.event_shape_tensor(),
                array_ops.shape(x),
                array_ops.shape(log_prob_x),
            ],
                         feed_dict={logits_ph: logits})

            if static_shape:
                ind_batch_shape = ind.batch_shape
                ind_event_shape = ind.event_shape
                x_shape = x.shape
                log_prob_x_shape = log_prob_x.shape

            self.assertAllEqual(batch_shape, ind_batch_shape)
            self.assertAllEqual(image_shape, ind_event_shape)
            self.assertAllEqual(sample_shape + batch_shape + image_shape,
                                x_shape)
            self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
            self.assertAllClose(expected_log_prob(x_, logits),
                                actual_log_prob_x,
                                rtol=1e-6,
                                atol=0.)
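Note that reinterpreted_batch_ndims is omitted when constructing the Independent above; it then defaults to reinterpreting all but the leftmost batch dimension, which is why the batch shape collapses to [10] and the event shape becomes the image shape. A minimal sketch of that default (assuming the tensorflow_probability API):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

pixels = tfd.Independent(tfd.Bernoulli(logits=tf.zeros([10, 28, 28, 1])))
pixels.batch_shape   # (10,)
pixels.event_shape   # (28, 28, 1)
# log_prob of a [10, 28, 28, 1] sample is a length-10 vector, one value per image.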
Example #9
Samples from an Independent-wrapped Normal vector and checks log_prob against scipy.stats by summing the componentwise log-densities.
    def testSampleAndLogProbUnivariate(self):
        loc = np.float32([-1., 1])
        scale = np.float32([0.1, 0.5])
        with self.cached_session() as sess:
            ind = independent_lib.Independent(distribution=normal_lib.Normal(
                loc=loc, scale=scale),
                                              reinterpreted_batch_ndims=1)

            x = ind.sample([4, 5], seed=42)
            log_prob_x = ind.log_prob(x)
            x_, actual_log_prob_x = sess.run([x, log_prob_x])

            self.assertEqual([], ind.batch_shape)
            self.assertEqual([2], ind.event_shape)
            self.assertEqual([4, 5, 2], x.shape)
            self.assertEqual([4, 5], log_prob_x.shape)

            expected_log_prob_x = stats.norm(loc, scale).logpdf(x_).sum(-1)
            self.assertAllCloseAccordingToType(expected_log_prob_x,
                                               actual_log_prob_x)
Example #10
The same check for an Independent wrapping a batch of MultivariateNormalDiag, which yields a [2, 2] event shape.
    def testSampleAndLogProbMultivariate(self):
        loc = np.float32([[-1., 1], [1, -1]])
        scale = np.float32([1., 0.5])
        with self.cached_session() as sess:
            ind = independent_lib.Independent(
                distribution=mvn_diag_lib.MultivariateNormalDiag(
                    loc=loc, scale_identity_multiplier=scale),
                reinterpreted_batch_ndims=1)

            x = ind.sample([4, 5], seed=42)
            log_prob_x = ind.log_prob(x)
            x_, actual_log_prob_x = sess.run([x, log_prob_x])

            self.assertEqual([], ind.batch_shape)
            self.assertEqual([2, 2], ind.event_shape)
            self.assertEqual([4, 5, 2, 2], x.shape)
            self.assertEqual([4, 5], log_prob_x.shape)

            expected_log_prob_x = stats.norm(
                loc, scale[:, None]).logpdf(x_).sum(-1).sum(-1)
            self.assertAllCloseAccordingToType(expected_log_prob_x,
                                               actual_log_prob_x)
Example #11
Builds a distribution whose scale is the exponential of an affine bijector applied to the input samples, wrapped in Independent so the last dimension forms the event.
def _fn(samples):
    scale = math_ops.exp(affine_bijector.forward(samples))
    return independent_lib.Independent(
        normal_lib.Normal(loc=0., scale=scale, validate_args=True),
        reinterpreted_batch_ndims=1)
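As written, _fn closes over affine_bijector from an enclosing scope. A self-contained sketch of the same pattern, using a stand-in bijector (a hypothetical choice) and the tensorflow_probability API:

import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

affine_bijector = tfb.Shift(shift=[0.1, -0.2])  # stand-in for the bijector above

def _fn(samples):
    # Map samples through the bijector, exponentiate to get positive scales,
    # and treat the resulting vector of Normals as a single multivariate event.
    scale = tf.exp(affine_bijector.forward(samples))
    return tfd.Independent(
        tfd.Normal(loc=0., scale=scale, validate_args=True),
        reinterpreted_batch_ndims=1)

dist = _fn(tf.constant([[0.5, 1.5]]))  # batch_shape [1], event_shape [2]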