Exemplo n.º 1
0
    def test_multivariate_normal_prob_positive_product_of_components(self):
        """Importance sampling estimates P(X_1 * X_2 > 0) for X ~ p.

        The target p is a 2-D diagonal Gaussian centered at the origin (its
        one-element `loc` broadcasts against the two-element `scale_diag`),
        so the expected value of the indicator is 1/2.
        """
        num_draws = 1000
        with self.cached_session():
            target = mvn_diag_lib.MultivariateNormalDiag(
                loc=[0.], scale_diag=[1.0, 1.0])
            proposal = mvn_diag_lib.MultivariateNormalDiag(
                loc=[0.5], scale_diag=[3., 3.])

            def product_is_positive(x):
                # sign(x_1 * x_2) in {-1, 0, 1} is mapped onto {0, 0.5, 1}.
                product = math_ops.reduce_prod(x, axis=[-1])
                return 0.5 * (math_ops.sign(product) + 1.0)

            estimate = mc.expectation_importance_sampler(
                f=product_is_positive,
                log_p=target.log_prob,
                sampling_dist_q=proposal,
                n=num_draws,
                seed=42)

            # rtol is twice the minimum needed for this test to pass; with
            # n = 100k draws the estimate converges to within about +-0.004.
            self.assertEqual(target.batch_shape, estimate.get_shape())
            self.assertAllClose(0.5, estimate.eval(), rtol=0.05)
Exemplo n.º 2
0
  def testSampleConsistentStats(self):
    """Monte Carlo statistics of Independent(MVNDiag) match the exact ones.

    Draws 10,000 samples and compares the empirical mean, variance, stddev
    and entropy with the distribution's analytic values; the mode is
    compared directly with `loc`.
    """
    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    num_samples = int(1e4)
    with self.cached_session() as sess:
      ind = independent_lib.Independent(
          distribution=mvn_diag_lib.MultivariateNormalDiag(
              loc=loc, scale_identity_multiplier=scale),
          reinterpreted_batch_ndims=1)

      draws = ind.sample(num_samples, seed=42)
      mean_hat = math_ops.reduce_mean(draws, axis=0)
      var_hat = math_ops.reduce_mean(
          math_ops.squared_difference(draws, mean_hat), axis=0)
      std_hat = math_ops.sqrt(var_hat)
      # Entropy estimate: the negated average log-density over the draws.
      entropy_hat = -math_ops.reduce_mean(ind.log_prob(draws), axis=0)

      empirical = [mean_hat, var_hat, std_hat, entropy_hat]
      analytic = [ind.mean(), ind.variance(), ind.stddev(), ind.entropy(),
                  ind.mode()]
      (mean_hat_, var_hat_, std_hat_, entropy_hat_,
       mean_, var_, std_, entropy_, mode_) = sess.run(empirical + analytic)

      comparisons = [
          (mean_hat_, mean_, 0.02),
          (var_hat_, var_, 0.04),
          (std_hat_, std_, 0.02),
          (entropy_hat_, entropy_, 0.01),
      ]
      for estimated, exact, rtol in comparisons:
        self.assertAllClose(estimated, exact, rtol=rtol, atol=0.)
      self.assertAllClose(loc, mode_, rtol=1e-6, atol=0.)
Exemplo n.º 3
0
  def testKLRaises(self):
    """kl_divergence raises for Independent pairs with incompatible shapes."""

    def normal_with_vector_event():
      # A batch of two Normals reinterpreted as one 2-element event.
      return independent_lib.Independent(
          distribution=normal_lib.Normal(
              loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
          reinterpreted_batch_ndims=1)

    # Case 1: the two Independent distributions report different event shapes.
    lhs = normal_with_vector_event()
    rhs = independent_lib.Independent(
        distribution=normal_lib.Normal(
            loc=np.float32(-1), scale=np.float32(0.5)),
        reinterpreted_batch_ndims=0)
    with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
      kullback_leibler.kl_divergence(lhs, rhs)

    # Case 2: the wrapped distributions differ in event shape (scalar-event
    # Normal vs. MultivariateNormalDiag), which is not implemented.
    lhs = normal_with_vector_event()
    rhs = independent_lib.Independent(
        distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=0)
    with self.assertRaisesRegexp(NotImplementedError, "different event shapes"):
      kullback_leibler.kl_divergence(lhs, rhs)
    def test_non_vector_shape(self):
        """BatchReshape must reject a scalar (non-vector) `batch_shape`.

        With static shapes the error surfaces as a ValueError at construction
        time; with dynamic shapes it surfaces as an op error when sampling.
        """
        dims = 2
        new_batch_shape = 2
        old_batch_shape = [2]
        static = self.is_static_shape

        if static:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_t = array_ops.placeholder_with_default(
            scale, shape=scale.shape if static else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

        def make_reshape():
            return batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_t,
                validate_args=True)

        if not static:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be a vector.*"):
                    make_reshape().sample().eval()
            return

        with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
            make_reshape()
    def test_non_positive_shape(self):
        """BatchReshape must reject non-positive `batch_shape` entries.

        Statically the requested shape is [-1, 0]. Any negative dimension is
        treated as unknown, so the size check is skipped and the 0 entry is
        left to be rejected. Dynamically it is [-2, -2], whose product (4)
        equals the old batch size, so the negative entries are what fail.
        """
        dims = 2
        old_batch_shape = [4]
        static = self.is_static_shape

        if static:
            new_batch_shape = [-1, 0]
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            new_batch_shape = [-2, -2]  # -2 * -2 = 4 == size of old shape.
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_t = array_ops.placeholder_with_default(
            scale, shape=scale.shape if static else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

        def make_reshape():
            return batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_t,
                validate_args=True)

        if not static:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be >=-1.*"):
                    make_reshape().sample().eval()
            return

        with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
            make_reshape()
    def test_bad_reshape_size(self):
        """BatchReshape must reject a `batch_shape` of the wrong total size.

        The requested shape [2, 3] holds 6 elements, while the wrapped
        distribution's batch shape [2] holds only 2.
        """
        dims = 2
        new_batch_shape = [2, 3]
        old_batch_shape = [2]
        static = self.is_static_shape

        if static:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_t = array_ops.placeholder_with_default(
            scale, shape=scale.shape if static else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

        def make_reshape():
            return batch_reshape_lib.BatchReshape(
                distribution=mvn,
                batch_shape=batch_shape_t,
                validate_args=True)

        if not static:
            with self.cached_session():
                with self.assertRaisesOpError(r"Shape sizes do not match."):
                    make_reshape().sample().eval()
            return

        static_message = (r"`batch_shape` size \(6\) must match "
                          r"`distribution\.batch_shape` size \(2\)")
        with self.assertRaisesRegexp(ValueError, static_message):
            make_reshape()
 def testSampleConsistentMeanCovariance(self):
   """Runs the shared mean/covariance consistency check on a 2-MVN mixture."""
   with self.cached_session() as sess:
     mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
     components = mvn_diag_lib.MultivariateNormalDiag(
         loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=mixture, components_distribution=components)
     self.run_test_sample_consistent_mean_covariance(sess.run, gm)
 def testVarianceConsistentCovariance(self):
   """The mixture's variance matches the diagonal of its covariance."""
   with self.cached_session() as sess:
     mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
     components = mvn_diag_lib.MultivariateNormalDiag(
         loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=mixture, components_distribution=components)
     covariance_t = gm.covariance()
     variance_t = gm.variance()
     cov_, var_ = sess.run([covariance_t, variance_t])
     self.assertAllClose(cov_.diagonal(), var_, atol=0.)
 def testSampleAndLogProbMultivariateShapes(self):
   """A [4, 5] draw yields [4, 5, 2] samples and [4, 5] log-probs."""
   with self.cached_session():
     mixture = categorical_lib.Categorical(probs=[0.3, 0.7])
     components = mvn_diag_lib.MultivariateNormalDiag(
         loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=mixture, components_distribution=components)
     samples = gm.sample([4, 5], seed=42)
     log_probs = gm.log_prob(samples)
     self.assertEqual([4, 5, 2], samples.shape)
     self.assertEqual([4, 5], log_probs.shape)
Exemplo n.º 10
0
  def testKLMultivariateToMultivariate(self):
    """KL of Independent(MVNDiag)s equals the summed per-member MVNDiag KL.

    Both base distributions are (1, 1, 2) batches; wrapping them with
    reinterpreted_batch_ndims=2 folds the two rightmost batch axes into the
    event, so the expected KL reduces the elementwise KL over axes -1, -2.
    """

    def as_independent(mvn):
      return independent_lib.Independent(
          distribution=mvn, reinterpreted_batch_ndims=2)

    # (1, 1, 2) batch of MVNDiag.
    mvn1 = mvn_diag_lib.MultivariateNormalDiag(
        loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
        scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
    ind1 = as_independent(mvn1)

    # (1, 1, 2) batch of MVNDiag.
    mvn2 = mvn_diag_lib.MultivariateNormalDiag(
        loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
        scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))
    ind2 = as_independent(mvn2)

    elementwise_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
    independent_kl = kullback_leibler.kl_divergence(ind1, ind2)
    summed_kl = math_ops.reduce_sum(elementwise_kl, axis=[-1, -2])

    expected = self.evaluate(summed_kl)
    actual = self.evaluate(independent_kl)
    self.assertAllClose(expected, actual)
 def testSampleConsistentLogProb(self):
   """Runs the shared sample/log_prob consistency check on a 2-MVN mixture.

   The check is applied to two balls of the same radius (1.), one centered
   at each component's mean.
   """
   with self.cached_session() as sess:
     gm = mixture_same_family_lib.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
         components_distribution=mvn_diag_lib.MultivariateNormalDiag(
             loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
     # Unit-radius ball centered at component 0's mean (scale 1.).
     self.run_test_sample_consistent_log_prob(
         sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
     # Same radius, centered at component 1's mean. Component 1's scale is
     # 0.5, so this ball is larger relative to that component's spread.
     self.run_test_sample_consistent_log_prob(
         sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
Exemplo n.º 12
0
    def test_pad_mixture_dimensions_mixture_same_family(self):
        """pad_mixture_dimensions turns a [2, 2] tensor into [2, 2, 1].

        The padded result must hold the same values in the same order.
        """
        with self.cached_session() as sess:
            mixture = categorical.Categorical(probs=[0.3, 0.7])
            components = mvn_diag.MultivariateNormalDiag(
                loc=[[-1., 1], [1, -1]],
                scale_identity_multiplier=[1.0, 0.5])
            gm = mixture_same_family.MixtureSameFamily(
                mixture_distribution=mixture,
                components_distribution=components)

            x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
            event_ndims = gm.event_shape.ndims
            x_pad = distribution_util.pad_mixture_dimensions(
                x, gm, gm.mixture_distribution, event_ndims)
            x_out, x_pad_out = sess.run([x, x_pad])

        # Only the shape changes; flattening both gives identical values.
        self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
    def make_mvn(self, dims, new_batch_shape, old_batch_shape):
        """Build a unit-scale MVNDiag and a validated BatchReshape of it.

        Args:
          dims: event size of the MultivariateNormalDiag.
          new_batch_shape: target batch shape handed to BatchReshape.
          old_batch_shape: list giving the MVN's own batch shape.

        Returns:
          A `(mvn, reshape_mvn)` tuple. When `self.is_static_shape` is false,
          both shape inputs are fed through placeholders with unknown shape.
        """
        static = self.is_static_shape
        if static:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_t = array_ops.placeholder_with_default(
            scale, shape=scale.shape if static else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)
        return mvn, batch_reshape_lib.BatchReshape(
            distribution=mvn,
            batch_shape=batch_shape_t,
            validate_args=True)
Exemplo n.º 14
0
  def testSampleAndLogProbMultivariate(self):
    """Shapes and log_prob of Independent(MVNDiag) match a scipy reference."""
    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    with self.cached_session() as sess:
      base = mvn_diag_lib.MultivariateNormalDiag(
          loc=loc, scale_identity_multiplier=scale)
      ind = independent_lib.Independent(
          distribution=base, reinterpreted_batch_ndims=1)

      samples = ind.sample([4, 5], seed=42)
      log_prob = ind.log_prob(samples)
      samples_, log_prob_ = sess.run([samples, log_prob])

      # The single batch axis of `base` is folded into the event.
      self.assertEqual([], ind.batch_shape)
      self.assertEqual([2, 2], ind.event_shape)
      self.assertEqual([4, 5, 2, 2], samples.shape)
      self.assertEqual([4, 5], log_prob.shape)

      # Reference: row i of `loc` uses scale[i]; sum log-densities over both
      # event axes to get the joint log-probability per draw.
      pointwise = stats.norm(loc, scale[:, None]).logpdf(samples_)
      expected_log_prob = pointwise.sum(-1).sum(-1)
      self.assertAllClose(expected_log_prob, log_prob_, rtol=1e-6, atol=0.)