Example #1
  def testKLRaises(self):
    ind1 = independent_lib.Independent(
        distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]),
            scale=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=1)
    ind2 = independent_lib.Independent(
        distribution=normal_lib.Normal(
            loc=np.float32(-1),
            scale=np.float32(0.5)),
        reinterpreted_batch_ndims=0)

    with self.assertRaisesRegexp(
        ValueError, "Event shapes do not match"):
      kullback_leibler.kl_divergence(ind1, ind2)

    ind1 = independent_lib.Independent(
        distribution=normal_lib.Normal(
            loc=np.float32([-1., 1]),
            scale=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=1)
    ind2 = independent_lib.Independent(
        distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([-1., 1]),
            scale_diag=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=0)

    with self.assertRaisesRegexp(
        NotImplementedError, "different event shapes"):
      kullback_leibler.kl_divergence(ind1, ind2)
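
For context, here is a minimal sketch of how reinterpreted_batch_ndims folds batch dimensions into the event shape (written against the public tensorflow_probability API as an assumption; the snippets here use internal independent_lib/normal_lib aliases). Mismatched event shapes produced this way are exactly what makes the two kl_divergence calls above raise.

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

normal = tfd.Normal(loc=np.float32([-1., 1.]), scale=np.float32([0.1, 0.5]))
print(normal.batch_shape, normal.event_shape)  # (2,) and ()

# Reinterpreting one batch dimension moves it into the event shape.
ind = tfd.Independent(normal, reinterpreted_batch_ndims=1)
print(ind.batch_shape, ind.event_shape)        # () and (2,)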
Example #2
  def testSampleConsistentStats(self):
    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    n_samp = 1e4
    with self.cached_session() as sess:
      ind = independent_lib.Independent(
          distribution=mvn_diag_lib.MultivariateNormalDiag(
              loc=loc,
              scale_identity_multiplier=scale),
          reinterpreted_batch_ndims=1)

      x = ind.sample(int(n_samp), seed=42)
      sample_mean = math_ops.reduce_mean(x, axis=0)
      sample_var = math_ops.reduce_mean(
          math_ops.squared_difference(x, sample_mean), axis=0)
      sample_std = math_ops.sqrt(sample_var)
      sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0)

      [
          sample_mean_, sample_var_, sample_std_, sample_entropy_,
          actual_mean_, actual_var_, actual_std_, actual_entropy_,
          actual_mode_,
      ] = sess.run([
          sample_mean, sample_var, sample_std, sample_entropy,
          ind.mean(), ind.variance(), ind.stddev(), ind.entropy(), ind.mode(),
      ])

      self.assertAllClose(sample_mean_, actual_mean_, rtol=0.02, atol=0.)
      self.assertAllClose(sample_var_, actual_var_, rtol=0.04, atol=0.)
      self.assertAllClose(sample_std_, actual_std_, rtol=0.02, atol=0.)
      self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
      self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
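
Note that scale_identity_multiplier belongs to the older contrib-era MultivariateNormalDiag and has since been deprecated and removed from TensorFlow Probability; a rough modern equivalent of the construction above (an assumption, not part of the original test) spells the same scale out as an explicit scale_diag.

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

loc = np.float32([[-1., 1.], [1., -1.]])
scale = np.float32([1., 0.5])

# scale_identity_multiplier=scale meant "scale[i] * identity" for batch member i,
# i.e. a constant diagonal per batch member.
ind = tfd.Independent(
    tfd.MultivariateNormalDiag(
        loc=loc,
        scale_diag=scale[:, None] * np.ones_like(loc)),
    reinterpreted_batch_ndims=1)
print(ind.mean())     # equals loc
print(ind.entropy())  # scalar: entropy summed over the reinterpreted batch dim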
Example #3
  def testKLScalarToMultivariate(self):
    normal1 = normal_lib.Normal(
        loc=np.float32([-1., 1]),
        scale=np.float32([0.1, 0.5]))
    ind1 = independent_lib.Independent(
        distribution=normal1, reinterpreted_batch_ndims=1)

    normal2 = normal_lib.Normal(
        loc=np.float32([-3., 3]),
        scale=np.float32([0.3, 0.3]))
    ind2 = independent_lib.Independent(
        distribution=normal2, reinterpreted_batch_ndims=1)

    normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
    self.assertAllClose(
        self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)),
        self.evaluate(ind_kl))
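
The assertion above rests on the identity KL(prod_i p_i || prod_i q_i) = sum_i KL(p_i || q_i) for factorized distributions. A quick NumPy/SciPy sanity check of that identity (independent of the TensorFlow code above):

import numpy as np
from scipy import stats

loc1, scale1 = np.float32([-1., 1.]), np.float32([0.1, 0.5])
loc2, scale2 = np.float32([-3., 3.]), np.float32([0.3, 0.3])

# Closed-form KL between univariate normals, per coordinate.
kl_coord = (np.log(scale2 / scale1)
            + (scale1**2 + (loc1 - loc2)**2) / (2. * scale2**2) - 0.5)

# Monte Carlo estimate of the KL between the two factorized joints.
rng = np.random.RandomState(42)
x = rng.normal(loc1, scale1, size=(200000, 2))
log_ratio = (stats.norm(loc1, scale1).logpdf(x)
             - stats.norm(loc2, scale2).logpdf(x)).sum(-1)
print(kl_coord.sum(), log_ratio.mean())  # the two numbers agree closely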
Example #4
  def testKLMultivariateToMultivariate(self):
    # (1, 1, 2) batch of MVNDiag
    mvn1 = mvn_diag_lib.MultivariateNormalDiag(
        loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
        scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
    ind1 = independent_lib.Independent(
        distribution=mvn1, reinterpreted_batch_ndims=2)

    # (1, 1, 2) batch of MVNDiag
    mvn2 = mvn_diag_lib.MultivariateNormalDiag(
        loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
        scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))

    ind2 = independent_lib.Independent(
        distribution=mvn2, reinterpreted_batch_ndims=2)

    mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
    self.assertAllClose(
        self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])),
        self.evaluate(ind_kl))
Example #5
  def testKLIdentity(self):
    normal1 = normal_lib.Normal(
        loc=np.float32([-1., 1]),
        scale=np.float32([0.1, 0.5]))
    # This is functionally just a wrapper around normal1,
    # and doesn't change any outputs.
    ind1 = independent_lib.Independent(
        distribution=normal1, reinterpreted_batch_ndims=0)

    normal2 = normal_lib.Normal(
        loc=np.float32([-3., 3]),
        scale=np.float32([0.3, 0.3]))
    # This is functionally just a wrapper around normal2,
    # and doesn't change any outputs.
    ind2 = independent_lib.Independent(
        distribution=normal2, reinterpreted_batch_ndims=0)

    normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
    self.assertAllClose(
        self.evaluate(normal_kl), self.evaluate(ind_kl))
Example #6
  def _testMnistLike(self, static_shape):
    sample_shape = [4, 5]
    batch_shape = [10]
    image_shape = [28, 28, 1]
    logits = 3 * self._rng.random_sample(
        batch_shape + image_shape).astype(np.float32) - 1

    def expected_log_prob(x, logits):
      return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)

    with self.cached_session() as sess:
      logits_ph = array_ops.placeholder(
          dtypes.float32, shape=logits.shape if static_shape else None)
      ind = independent_lib.Independent(
          distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
      x = ind.sample(sample_shape, seed=42)
      log_prob_x = ind.log_prob(x)
      [
          x_,
          actual_log_prob_x,
          ind_batch_shape,
          ind_event_shape,
          x_shape,
          log_prob_x_shape,
      ] = sess.run([
          x,
          log_prob_x,
          ind.batch_shape_tensor(),
          ind.event_shape_tensor(),
          array_ops.shape(x),
          array_ops.shape(log_prob_x),
      ], feed_dict={logits_ph: logits})

      if static_shape:
        ind_batch_shape = ind.batch_shape
        ind_event_shape = ind.event_shape
        x_shape = x.shape
        log_prob_x_shape = log_prob_x.shape

      self.assertAllEqual(batch_shape, ind_batch_shape)
      self.assertAllEqual(image_shape, ind_event_shape)
      self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
      self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
      self.assertAllClose(expected_log_prob(x_, logits),
                          actual_log_prob_x,
                          rtol=1e-6, atol=0.)
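
expected_log_prob uses the logits form of the Bernoulli log-pmf: with p = sigmoid(logits), the textbook x*log(p) + (1-x)*log(1-p) simplifies to x*logits - log1p(exp(logits)). A small NumPy check of that algebra (not part of the original test):

import numpy as np

rng = np.random.RandomState(0)
logits = rng.randn(5).astype(np.float32)
x = rng.randint(0, 2, size=5).astype(np.float32)

p = 1. / (1. + np.exp(-logits))                     # sigmoid(logits)
direct = x * np.log(p) + (1. - x) * np.log1p(-p)    # textbook Bernoulli log-pmf
via_logits = x * logits - np.log1p(np.exp(logits))  # form used in the test
print(np.allclose(direct, via_logits))              # True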
Example #7
  def testSampleAndLogProbUnivariate(self):
    loc = np.float32([-1., 1])
    scale = np.float32([0.1, 0.5])
    with self.cached_session() as sess:
      ind = independent_lib.Independent(
          distribution=normal_lib.Normal(loc=loc, scale=scale),
          reinterpreted_batch_ndims=1)

      x = ind.sample([4, 5], seed=42)
      log_prob_x = ind.log_prob(x)
      x_, actual_log_prob_x = sess.run([x, log_prob_x])

      self.assertEqual([], ind.batch_shape)
      self.assertEqual([2], ind.event_shape)
      self.assertEqual([4, 5, 2], x.shape)
      self.assertEqual([4, 5], log_prob_x.shape)

      expected_log_prob_x = stats.norm(loc, scale).logpdf(x_).sum(-1)
      self.assertAllClose(expected_log_prob_x, actual_log_prob_x,
                          rtol=1e-5, atol=0.)
Example #8
  def testSampleAndLogProbMultivariate(self):
    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    with self.cached_session() as sess:
      ind = independent_lib.Independent(
          distribution=mvn_diag_lib.MultivariateNormalDiag(
              loc=loc,
              scale_identity_multiplier=scale),
          reinterpreted_batch_ndims=1)

      x = ind.sample([4, 5], seed=42)
      log_prob_x = ind.log_prob(x)
      x_, actual_log_prob_x = sess.run([x, log_prob_x])

      self.assertEqual([], ind.batch_shape)
      self.assertEqual([2, 2], ind.event_shape)
      self.assertEqual([4, 5, 2, 2], x.shape)
      self.assertEqual([4, 5], log_prob_x.shape)

      expected_log_prob_x = stats.norm(loc, scale[:, None]).logpdf(
          x_).sum(-1).sum(-1)
      self.assertAllClose(expected_log_prob_x, actual_log_prob_x,
                          rtol=1e-6, atol=0.)
  def _fn(samples):
    scale = math_ops.exp(affine_bijector.forward(samples))
    return independent_lib.Independent(
        normal_lib.Normal(loc=0., scale=scale, validate_args=True),
        reinterpreted_batch_ndims=1)
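
The trailing _fn helper references an affine_bijector that is defined elsewhere in the original snippet. Purely as a hypothetical completion (the stand-in bijector below is an assumption, not the original one), it could be exercised like this with the public TFP API:

import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

affine_bijector = tfb.Chain([tfb.Shift(0.5), tfb.Scale(2.)])  # hypothetical stand-in

def _fn(samples):
  scale = tf.exp(affine_bijector.forward(samples))
  return tfd.Independent(
      tfd.Normal(loc=0., scale=scale, validate_args=True),
      reinterpreted_batch_ndims=1)

dist = _fn(tf.constant([[0.1, -0.2], [0.3, 0.0]]))
print(dist.batch_shape, dist.event_shape)  # (2,) and (2,)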