Example #1
 def test_degenerate_scalar_mask(self):
   d0 = tfd.Masked(tfd.Normal(0., 1.), validity_mask=False, validate_args=True)
   d1 = tfd.Masked(tfd.Normal(0., 1.), validity_mask=True, validate_args=True)
   stream = test_util.test_seed_stream()
   self.assertAllEqual(
       d0.sample(seed=stream()), d0.sample(seed=stream()))
   self.assertNotAllEqual(
       d1.sample(seed=stream()), d1.sample(seed=stream()))
   self.assertAllEqual(0., d0.log_prob(123.))
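
As a quick illustration, the scalar-mask behavior exercised above can be reproduced outside the test harness. A minimal sketch, assuming the usual TensorFlow Probability alias tfd = tfp.distributions:

import tensorflow_probability as tfp

tfd = tfp.distributions

# A fully masked-out distribution: every position is treated as invalid.
masked_out = tfd.Masked(tfd.Normal(0., 1.), validity_mask=False)
print(masked_out.log_prob(123.))   # 0.0 -- invalid positions contribute nothing
print(masked_out.log_prob(-4.2))   # 0.0, regardless of the value passed in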
Example #2
  def test_log_prob(self):
    d = tfd.Masked(tfd.Normal(tf.zeros([20]), 1), tf.sequence_mask(15, 20))
    x = np.linspace(-.1, .1, 20).astype(np.float32)
    self.assertAllClose(
        tf.pad(tfd.Normal(0, 1).log_prob(x[:15]), [[0, 5]]),
        d.log_prob(x))

    d = tfd.Masked(tfd.Normal(tf.zeros([20]), 1), (tf.range(20) % 2) > 0)
    x = np.linspace(-.1, .1, 20).astype(np.float32)
    self.assertAllClose(tf.zeros(10), d.log_prob(x)[::2])
    self.assertAllClose(tfd.Normal(0, 1).log_prob(x)[1::2], d.log_prob(x)[1::2])
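
The zeroed log-probabilities checked above are what make tfd.Masked convenient for padded, variable-length data: wrapped in tfd.Independent, only the valid positions contribute to each row's total log-probability. A minimal sketch, assuming a padded batch whose per-row valid lengths are given by a hypothetical lengths tensor:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

lengths = tf.constant([3, 5])               # number of valid entries per row
mask = tf.sequence_mask(lengths, maxlen=5)  # boolean mask, shape [2, 5]
dist = tfd.Independent(
    tfd.Masked(tfd.Normal(tf.zeros([2, 5]), 1.), mask),
    reinterpreted_batch_ndims=1)
x = tf.zeros([2, 5])                        # padded observations
print(dist.log_prob(x))  # shape [2]; padded positions add exactly 0 per row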
Example #3
 def test_kl(self):
   a = tfd.MultivariateNormalDiag([tf.range(3.)] * 4, tf.ones(3))
   b = tfd.MultivariateNormalDiag([tf.range(3.) + .5] * 4, tf.ones(3))
   kl = tfd.kl_divergence(tfd.Masked(a, tf.sequence_mask(3, 4)),
                          tfd.Masked(b, tf.sequence_mask(2, 4)))
   kl2 = tfd.kl_divergence(tfd.Masked(a, tf.sequence_mask(2, 4)),
                           tfd.Masked(b, tf.sequence_mask(3, 4)))
   self.assertAllClose(a.kl_divergence(b)[:2], kl[:2])
   self.assertAllEqual(float('nan'), kl[2])
   self.assertAllEqual(0., kl[3])
   self.assertAllEqual(float('nan'), kl2[2])
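
The per-position KL rules verified above (the underlying KL where both sides are valid, 0 where both are invalid, NaN on a mask mismatch) compose naturally with tfd.Independent when the two masks agree, which is the common case for padded sequences. A small sketch under that matching-mask assumption:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mask = tf.sequence_mask([2, 3], 4)   # rows with 2 and 3 valid positions
a = tfd.Independent(tfd.Masked(tfd.Normal(tf.zeros([2, 4]), 1.), mask),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Masked(tfd.Normal(tf.ones([2, 4]), 1.), mask),
                    reinterpreted_batch_ndims=1)
# KL(N(0, 1) || N(1, 1)) = 0.5 per valid position; masked positions add 0.
print(tfd.kl_divergence(a, b))  # approximately [1.0, 1.5]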
Example #4
 def test_log_prob_ratio(self):
   p = tfd.Masked(tfd.MultivariateNormalDiag(tf.zeros([4, 3000]),
                                             tf.ones([3000])),
                  tf.sequence_mask(2, 4))
   q = tfd.Masked(tfd.MultivariateNormalDiag(tf.zeros([4, 3000]),
                                             tf.ones([3000])),
                  [True, True, False, False])
   stream = test_util.test_seed_stream()
   x = p.sample(seed=stream())
   y = x + 1e-5 * q.sample(seed=stream())
   x, y = self.evaluate((x, y))  # Avoids different samples across evals.
   # Compute the reference in float64 so the 3000-term sums stay accurate.
   normal64 = tfd.Normal(tf.constant(0., tf.float64), 1.)
   self.assertAllClose(
       tf.reduce_sum(normal64.log_prob(tf.cast(x, tf.float64)) -
                     normal64.log_prob(tf.cast(y, tf.float64)),
                     axis=-1) * tf.sequence_mask(2, 4, dtype=tf.float64),
       tf.cast(tfde.log_prob_ratio(p, x, q, y), tf.float64))
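
For context, tfde above is tfp.experimental.distributions, whose log_prob_ratio(p, x, q, y) returns the same value as p.log_prob(x) - q.log_prob(y) but may compute it more accurately when x and y are very close; the test checks that Masked zeroes out the masked batch members of the result. A minimal sketch of direct usage, under the same aliasing assumption:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distributions

p = tfd.Masked(tfd.Normal(tf.zeros([4]), 1.), tf.sequence_mask(2, 4))
q = tfd.Masked(tfd.Normal(tf.zeros([4]), 1.), tf.sequence_mask(2, 4))
x = tf.fill([4], 0.5)
y = x + 1e-6
# Same value as p.log_prob(x) - q.log_prob(y); the last two entries are 0
# because those batch members are masked out in both p and q.
print(tfde.log_prob_ratio(p, x, q, y))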
Example #5
  def test_sampling(self):
    d = tfd.Masked(tfd.Normal(tf.zeros([20]), 1), tf.sequence_mask(15, 20),
                   safe_sample_fn=tfd.Distribution.mode)
    self.assertAllEqual(tf.zeros([5]),  # distribution mode
                        d.sample(seed=test_util.test_seed())[15:])

    # Gamma doesn't have a mode for concentration < 1.
    # Let's verify we don't trigger an assert.
    d = tfd.Masked(tfd.Gamma(.9, tf.ones([20]), allow_nan_stats=False),
                   tf.sequence_mask(13, 20),
                   safe_sample_fn=tfd.Distribution.mean)
    self.assertAllEqual(d.distribution.mean()[15:],
                        d.sample(seed=test_util.test_seed())[15:])

    d = tfd.Masked(tfd.Gamma(.9, tf.ones([20]), allow_nan_stats=False),
                   tf.sequence_mask(13, 20))
    self.evaluate(d.sample(seed=test_util.test_seed()))
Example #6
 def test_event_space_bijector_fldj(self):
   # Also test the forward log det jacobian of the default event space
   # bijector in the same fully-masked setting as test_event_space_bijector
   # below, for completeness.
   sub_d = tfd.ExpRelaxedOneHotCategorical(
       logits=[0., 0., 0.],
       temperature=[0.01, 0.01, 0.01, 0.01])
   d = tfd.Masked(sub_d, validity_mask=False, validate_args=True)
   bij = d.experimental_default_event_space_bijector()
   fldj = bij.forward_log_det_jacobian(tf.zeros(shape=[4, 2]))
   self.assertAllEqual(fldj, tf.zeros_like(fldj))
Example #7
 def test_batch_mask(self):
   d = tfd.Masked(tfd.Normal(tf.zeros([20]), 1),
                  tf.sequence_mask([15, 17], 20))
   self.assertEqual((2, 20), d.batch_shape)
   x = np.linspace(-.1, .1, 20).astype(np.float32)
   self.assertAllClose(
       tf.stack([
           tf.pad(tfd.Normal(0, 1).log_prob(x[:15]), [[0, 5]]),
           tf.pad(tfd.Normal(0, 1).log_prob(x[:17]), [[0, 3]])]),
       d.log_prob(x))
Example #8
 def test_event_space_bijector(self):
   # Test that the default event space bijector executes.  This is
   # non-trivial, because the event space bijector of this particular
   # component distribution cannot be relied upon to produce finite
   # values in the unconstrained space from samples of `sub_d`.
   sub_d = tfd.ExpRelaxedOneHotCategorical(
       logits=[0., 0., 0.],
       temperature=[0.01, 0.01, 0.01, 0.01],
       validate_args=True)
   d = tfd.Masked(sub_d, validity_mask=False, validate_args=True)
   bij = d.experimental_default_event_space_bijector()
   x = bij(tf.zeros(shape=[4, 2]))
   # The error tested for manifests as failed validations due to
   # invalid values.
   self.assertAllNotNan(self.evaluate(x))
Example #9
 def test_shape(self):
   d = tfd.Masked(tfd.Normal(tf.zeros([7, 20]), 1), tf.sequence_mask(15, 20))
   self.assertEqual((), d.event_shape)
   self.assertAllEqual([], d.event_shape_tensor())
   self.assertEqual((7, 20), d.batch_shape)
   self.assertAllEqual([7, 20], d.batch_shape_tensor())
Example #10
 def f(loc):
   d = tfd.Independent(
       tfd.Masked(tfd.LogNormal(loc, 1), [True, False, False]),
       reinterpreted_batch_ndims=1)
   return d.log_prob([1.1, 0., float('nan')])
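
This last fragment is presumably the integrand of a gradient check: the masked-out positions hold values (0 and NaN) outside LogNormal's support, and the intent is that they affect neither the log-probability nor its gradient with respect to loc. A hedged sketch of how f might be exercised, assuming the same tfd alias and eager execution:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Only the first (valid) position should contribute, so both the value and
# the gradient with respect to loc should come out finite.
value, grad = tfp.math.value_and_gradient(f, tf.constant(0.))
print(value, grad)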