Example #1
def testStddev(self):
        base_stddev = 2.
        shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
        scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
        expected_stddev = tf.abs(base_stddev * scale)
        normal = self._cls()(
            distribution=tfd.Normal(loc=tf.zeros_like(shift),
                                    scale=base_stddev * tf.ones_like(scale),
                                    validate_args=True),
            bijector=tfb.Chain(
                [tfb.Shift(shift=shift),
                 tfb.Scale(scale=scale)],
                validate_args=True),
            validate_args=True)
        self.assertAllClose(expected_stddev, normal.stddev())
        self.assertAllClose(expected_stddev**2, normal.variance())

        split_normal = self._cls()(distribution=tfd.Independent(
            normal, reinterpreted_batch_ndims=1),
                                   bijector=tfb.Split(3),
                                   validate_args=True)
        self.assertAllCloseNested(
            tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
            split_normal.stddev())

        scaled_normal = self._cls()(distribution=tfd.Independent(
            normal, reinterpreted_batch_ndims=1),
                                    bijector=tfb.ScaleMatvecTriL([[1., 0.],
                                                                  [-1., 2.]]),
                                    validate_args=True)
        with self.assertRaisesRegex(NotImplementedError,
                                    'is a multivariate transformation'):
            scaled_normal.stddev()
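
The expected values follow from the affine change of variables: Chain applies its bijectors right-to-left, so y = shift + scale * x and stddev(y) = |scale| * base_stddev regardless of the shift. A minimal standalone numpy sketch of that identity (imports and data made explicit; not part of the test class):

import numpy as np

base_stddev = 2.
shift = np.array([[-1., 0., 1.], [-1., -2., -3.]], dtype=np.float32)
scale = np.array([[1., -2., 3.], [2., -3., 2.]], dtype=np.float32)
rng = np.random.default_rng(0)
x = base_stddev * rng.standard_normal((100_000,) + scale.shape)
y = shift + scale * x  # Scale first, then Shift, matching the Chain above.
np.testing.assert_allclose(y.std(axis=0), np.abs(scale) * base_stddev,
                           rtol=0.05)
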
Example #2
 def testCompareToBijector(self):
     """Demonstrates equivalence between TD, Bijector approach and AR dist."""
     sample_shape = np.int32([4, 5])
     batch_shape = np.int32([])
     event_size = np.int32(2)
     batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
     sample0 = tf.zeros(batch_event_shape)
     affine = tfb.ScaleMatvecTriL(
         scale_tril=self._random_scale_tril(event_size), validate_args=True)
     ar = tfd.Autoregressive(self._normal_fn(affine),
                             sample0,
                             validate_args=True)
     ar_flow = tfb.MaskedAutoregressiveFlow(
         is_constant_jacobian=True,
         shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
         validate_args=True)
     td = tfd.TransformedDistribution(
         # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
         # shape instead of tf.zeros.
         distribution=tfd.Sample(tfd.Normal(tf.zeros(batch_shape), 1.),
                                 [event_size]),
         bijector=ar_flow,
         validate_args=True)
     x_shape = np.concatenate([sample_shape, batch_shape, [event_size]],
                              axis=0)
     x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
     td_log_prob_, ar_log_prob_ = self.evaluate(
         [td.log_prob(x), ar.log_prob(x)])
     self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
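
This test leans on two suite helpers defined elsewhere, self._normal_fn and self._random_scale_tril. A hypothetical sketch of the latter, consistent with how it is used here (any lower-triangular matrix whose diagonal is bounded away from zero would do):

def _random_scale_tril(self, event_size):
  # Hypothetical helper: a random lower-triangular `scale_tril` with a
  # diagonal kept away from zero so the bijector stays invertible.
  tril = 2. * self._rng.random_sample(
      [event_size, event_size]).astype(np.float32) - 1.
  tril = np.tril(tril)
  diag = np.diag_indices(event_size)
  tril[diag] = 0.5 + np.abs(tril[diag])
  return tril
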
Example #3
 def testNoBatch(self, is_static):
     bijector = tfb.ScaleMatvecTriL(scale_tril=[[2., 0.], [2., 2.]])
     x = self.maybe_static([[1., 2.]], is_static)
     self.assertAllClose([[2., 6.]], bijector.forward(x))
     self.assertAllClose([[.5, .5]], bijector.inverse(x))
     self.assertAllClose(
         -np.abs(np.log(4.)),
         bijector.inverse_log_det_jacobian(x, event_ndims=1))
 def testRaisesWhenSingular(self):
   with self.assertRaisesRegex(
       Exception,
       '.*Singular operator:  Diagonal contained zero values.*'):
     bijector = tfb.ScaleMatvecTriL(
         # Has zero on the diagonal.
         scale_tril=[[0., 0.], [1., 1.]],
         validate_args=True)
     self.evaluate(bijector.forward([1., 1.]))
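
Both tests above can be checked by hand: ScaleMatvecTriL is a plain triangular matvec, its log-determinant is the sum of log-absolute-diagonals, and a zero on the diagonal makes the matrix singular (hence the raise under validate_args=True). A standalone numpy/scipy cross-check of testNoBatch:

import numpy as np
from scipy.linalg import solve_triangular

L = np.array([[2., 0.], [2., 2.]])
x = np.array([1., 2.])
np.testing.assert_allclose(L @ x, [2., 6.])                       # forward
np.testing.assert_allclose(solve_triangular(L, x, lower=True),
                           [.5, .5])                              # inverse
# inverse_log_det_jacobian = -log|det(L)| = -log(2 * 2) = -log(4).
assert np.isclose(-np.log(np.abs(np.diag(L)).prod()), -np.log(4.))
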
 def __init__(self, loc, chol_precision_tril, name=None):
   super(MVNCholPrecisionTriL, self).__init__(
       distribution=tfd.Independent(tfd.Normal(tf.zeros_like(loc),
                                               scale=tf.ones_like(loc)),
                                    reinterpreted_batch_ndims=1),
       bijector=tfb.Chain([
           tfb.Shift(shift=loc),
           tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,
                                          adjoint=True)),
       ]),
       name=name)
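
In this constructor, Invert(ScaleMatvecTriL(L, adjoint=True)) maps a standard-normal z to inv(L^T) @ z, so for precision = L @ L^T the implied covariance is inv(L^T) @ inv(L) = inv(precision). A quick numpy check of that algebra, with a made-up Cholesky factor:

import numpy as np

chol_precision = np.array([[1., 0.], [2., 8.]])  # made-up lower-triangular L
precision = chol_precision @ chol_precision.T
# Samples are loc + inv(L^T) @ z, so the covariance is
# inv(L^T) @ inv(L) = inv(L @ L^T) = inv(precision).
implied_cov = (np.linalg.inv(chol_precision.T)
               @ np.linalg.inv(chol_precision))
np.testing.assert_allclose(implied_cov, np.linalg.inv(precision), rtol=1e-6)
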
Example #6
  def testMVN(self, event_shape, shift, tril, dynamic_shape):
    if dynamic_shape and tf.executing_eagerly():
      self.skipTest('Eager execution does not support dynamic shape.')
    as_tensor = tf.convert_to_tensor
    if dynamic_shape:
      as_tensor = lambda v, name: tf1.placeholder_with_default(  # pylint: disable=g-long-lambda
          v, shape=None, name='dynamic_' + name)

    fake_mvn = tfd.TransformedDistribution(
        distribution=tfd.Sample(
            tfd.Normal(loc=as_tensor(0., name='loc'),
                       scale=as_tensor(1., name='scale'),
                       validate_args=True),
            sample_shape=as_tensor(np.int32(event_shape), name='event_shape')),
        bijector=tfb.Chain(
            [tfb.Shift(shift=as_tensor(shift, name='shift')),
             tfb.ScaleMatvecTriL(scale_tril=as_tensor(tril, name='scale_tril'))
             ]), validate_args=True)

    base_dist = fake_mvn.distribution
    expected_mean = tf.linalg.matvec(
        tril, tf.broadcast_to(base_dist.mean(), shift.shape)) + shift
    expected_cov = tf.linalg.matmul(
        tril,
        tf.matmul(
            tf.linalg.diag(tf.broadcast_to(base_dist.variance(), shift.shape)),
            tril,
            adjoint_b=True))
    expected_batch_shape = ps.shape(expected_mean)[:-1]

    if dynamic_shape:
      self.assertAllEqual(tf.TensorShape(None), fake_mvn.event_shape)
      self.assertAllEqual(tf.TensorShape(None), fake_mvn.batch_shape)
    else:
      self.assertAllEqual(event_shape, fake_mvn.event_shape)
      self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape)

    # Ensure sample works by checking first, second moments.
    num_samples = 7e3
    y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
    x = y[0:5, ...]
    self.assertAllClose(expected_mean, tf.reduce_mean(y, axis=0),
                        atol=0.1, rtol=0.1)
    self.assertAllClose(expected_cov, tfp.stats.covariance(y, sample_axis=0),
                        atol=0., rtol=0.1)

    self.assertAllEqual(event_shape, fake_mvn.event_shape_tensor())
    self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape_tensor())
    self.assertAllEqual(
        ps.concat([[5], expected_batch_shape, event_shape], axis=0),
        ps.shape(x))
    self.assertAllClose(expected_mean, fake_mvn.mean())
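
The expected moments in testMVN are the standard affine identities: for y = shift + tril @ x with i.i.d. base x ~ Normal(0, 1), mean(y) = tril @ mean(x) + shift and cov(y) = tril @ diag(var(x)) @ tril^T, which is what expected_mean and expected_cov compute above. A compact numpy restatement with a Monte Carlo cross-check, for one made-up case:

import numpy as np

shift = np.array([1., -1.])
tril = np.array([[2., 0.], [-1., 1.]])
expected_mean = tril @ np.zeros(2) + shift   # base mean is 0
expected_cov = tril @ np.eye(2) @ tril.T     # base variance is 1
rng = np.random.default_rng(0)
y = (tril @ rng.standard_normal((2, 100_000))).T + shift
np.testing.assert_allclose(y.mean(axis=0), expected_mean, atol=0.05)
np.testing.assert_allclose(np.cov(y.T), expected_cov, rtol=0.05)
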
Example #7
 def testSampleAndLogProbConsistency(self):
   batch_shape = np.int32([])
   event_size = 2
   batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
   sample0 = tf.zeros(batch_event_shape)
   affine = tfb.ScaleMatvecTriL(
       scale_tril=self._random_scale_tril(event_size), validate_args=True)
   ar = tfd.Autoregressive(
       self._normal_fn(affine), sample0, validate_args=True)
   self.run_test_sample_consistent_log_prob(
       self.evaluate,
       ar,
       num_samples=int(1e6),
       radius=1.,
       center=0.,
       rtol=0.01,
       seed=test_util.test_seed())
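
run_test_sample_consistent_log_prob is TFP's test_util helper; roughly, it checks that the fraction of samples falling within `radius` of `center` agrees with an estimate of the same probability mass built from log_prob. A much-simplified 1-D analogue of the idea, using scipy instead of the TFP helper:

import numpy as np
from scipy import stats

# Estimate P(|X| <= radius) by sampling and compare with the CDF.
dist = stats.norm(loc=0., scale=1.)
radius = 1.
samples = dist.rvs(size=1_000_000, random_state=0)
mc_estimate = np.mean(np.abs(samples) <= radius)
analytic = dist.cdf(radius) - dist.cdf(-radius)
np.testing.assert_allclose(mc_estimate, analytic, rtol=0.01)
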
  def testCovariance(self):
    base_scale_tril = np.array([[1., 0.], [-3., 0.2]], dtype=np.float32)
    base_cov = tf.matmul(base_scale_tril, base_scale_tril, adjoint_b=True)

    shift = np.array([[-1., 0.], [-1., -2.], [4., 5.]], dtype=np.float32)
    scale = np.array([[1., -2.], [2., -3.], [0.1, -2.]], dtype=np.float32)
    scale_matvec = np.array([[0.5, 0.], [-2., 0.7]], dtype=np.float32)
    normal = tfd.TransformedDistribution(
        distribution=tfd.MultivariateNormalTriL(
            loc=[0., 0.], scale_tril=base_scale_tril, validate_args=True),
        bijector=tfb.Chain([tfb.ScaleMatvecTriL(scale_matvec),
                            tfb.Shift(shift=shift),
                            tfb.Scale(scale=scale)],
                           validate_args=True),
        validate_args=True)

    overall_scale = tf.matmul(scale_matvec, tf.linalg.diag(scale))
    expected_cov = tf.matmul(overall_scale,
                             tf.matmul(base_cov, overall_scale, adjoint_b=True))
    self.assertAllClose(normal.covariance(), expected_cov)
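
The closed form above uses the fact that shifts drop out of covariances and Chain composes right-to-left, so the overall linear map is scale_matvec @ diag(scale). A standalone numpy spot-check of expected_cov for a single batch member (the batch dimension of shift/scale is dropped here for brevity):

import numpy as np

base_scale_tril = np.array([[1., 0.], [-3., 0.2]])
base_cov = base_scale_tril @ base_scale_tril.T
d = np.array([1., -2.])                  # Scale, applied first
M = np.array([[0.5, 0.], [-2., 0.7]])    # ScaleMatvecTriL, applied last
overall = M @ np.diag(d)                 # Shift drops out of the covariance.
expected_cov = overall @ base_cov @ overall.T
# Monte Carlo cross-check: y = M @ (shift + d * x) with x ~ N(0, base_cov).
rng = np.random.default_rng(0)
x = rng.multivariate_normal(np.zeros(2), base_cov, size=200_000)
y = (M @ (np.array([-1., 0.]) + d * x).T).T
np.testing.assert_allclose(np.cov(y.T), expected_cov, rtol=0.05)
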
    def _testMVN(self,
                 base_distribution_class,
                 base_distribution_kwargs,
                 event_shape=()):
        # Base distribution shapes must be compatible w/bijector; most bijectors are
        # batch_shape agnostic and only care about event_ndims.
        # In the case of `ScaleMatvecTriL`, if we got it wrong then it would fire an
        # exception due to incompatible dimensions.
        event_shape_var = tf.Variable(np.int32(event_shape),
                                      shape=tf.TensorShape(None),
                                      name='dynamic_event_shape')

        base_distribution_dynamic_kwargs = {
            k: tf.Variable(v,
                           shape=tf.TensorShape(None),
                           name='dynamic_{}'.format(k))
            for k, v in base_distribution_kwargs.items()
        }
        fake_mvn_dynamic = self._cls()(
            distribution=tfd.Sample(base_distribution_class(
                validate_args=True, **base_distribution_dynamic_kwargs),
                                    sample_shape=event_shape_var),
            bijector=tfb.Chain([
                tfb.Shift(shift=self._shift),
                tfb.ScaleMatvecTriL(scale_tril=self._tril)
            ]),
            validate_args=True)

        fake_mvn_static = self._cls()(
            distribution=tfd.Sample(base_distribution_class(
                validate_args=True, **base_distribution_kwargs),
                                    sample_shape=event_shape),
            bijector=tfb.Chain([
                tfb.Shift(shift=self._shift),
                tfb.ScaleMatvecTriL(scale_tril=self._tril)
            ]),
            validate_args=True)

        actual_mean = np.tile(self._shift,
                              [2, 1])  # ScaleMatvecTriL elided this tile.
        actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))

        def actual_mvn_log_prob(x):
            return np.concatenate([
                [  # pylint: disable=g-complex-comprehension
                    stats.multivariate_normal(actual_mean[i],
                                              actual_cov[i]).logpdf(x[:, i, :])
                ] for i in range(len(actual_cov))
            ]).T

        actual_mvn_entropy = np.concatenate([[
            stats.multivariate_normal(actual_mean[i], actual_cov[i]).entropy()
        ] for i in range(len(actual_cov))])

        self.assertAllEqual([3], fake_mvn_static.event_shape)
        self.assertAllEqual([2], fake_mvn_static.batch_shape)

        if not tf.executing_eagerly():
            self.assertAllEqual(tf.TensorShape(None),
                                fake_mvn_dynamic.event_shape)
            self.assertAllEqual(tf.TensorShape(None),
                                fake_mvn_dynamic.batch_shape)

        num_samples = 7e3
        for fake_mvn in [fake_mvn_static, fake_mvn_dynamic]:
            # Ensure sample works by checking first, second moments.
            y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
            x = y[0:5, ...]
            sample_mean = tf.reduce_mean(y, axis=0)
            centered_y = tf.transpose(a=y - sample_mean, perm=[1, 2, 0])
            sample_cov = tf.matmul(centered_y, centered_y,
                                   transpose_b=True) / num_samples
            self.evaluate([
                v.initializer
                for v in base_distribution_dynamic_kwargs.values()
            ] + [event_shape_var.initializer])
            [
                sample_mean_,
                sample_cov_,
                x_,
                fake_event_shape_,
                fake_batch_shape_,
                fake_log_prob_,
                fake_prob_,
                fake_mean_,
                fake_entropy_,
            ] = self.evaluate([
                sample_mean,
                sample_cov,
                x,
                fake_mvn.event_shape_tensor(),
                fake_mvn.batch_shape_tensor(),
                fake_mvn.log_prob(x),
                fake_mvn.prob(x),
                fake_mvn.mean(),
                fake_mvn.entropy(),
            ])

            self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
            self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)

            # Ensure all other functions work as intended.
            self.assertAllEqual([5, 2, 3], x_.shape)
            self.assertAllEqual([3], fake_event_shape_)
            self.assertAllEqual([2], fake_batch_shape_)
            self.assertAllClose(actual_mvn_log_prob(x_),
                                fake_log_prob_,
                                atol=0.,
                                rtol=1e-6)
            self.assertAllClose(np.exp(actual_mvn_log_prob(x_)),
                                fake_prob_,
                                atol=0.,
                                rtol=1e-5)
            self.assertAllClose(actual_mean, fake_mean_, atol=0., rtol=1e-6)
            self.assertAllClose(actual_mvn_entropy,
                                fake_entropy_,
                                atol=0.,
                                rtol=1e-6)
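
_testMVN assumes fixtures self._shift and self._tril set up elsewhere in the suite. Given the hard-coded assertions (event_shape [3], batch_shape [2], and the np.tile over a batch of 2), a hypothetical fixture consistent with the test would be:

def setUp(self):
  super().setUp()
  # Hypothetical fixture: a length-3 shift and a batch of two 3x3
  # lower-triangular scale factors, matching the [2] batch / [3] event
  # shapes asserted above.
  self._shift = np.array([-1., 0., 1.], dtype=np.float32)
  self._tril = np.array([[[1., 0., 0.],
                          [2., 1., 0.],
                          [3., 2., 1.]],
                         [[2., 0., 0.],
                          [3., 2., 0.],
                          [4., 3., 2.]]], dtype=np.float32)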