def testStddev(self):
  base_stddev = 2.
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
  expected_stddev = tf.abs(base_stddev * scale)
  normal = self._cls()(
      distribution=tfd.Normal(loc=tf.zeros_like(shift),
                              scale=base_stddev * tf.ones_like(scale),
                              validate_args=True),
      bijector=tfb.Chain([tfb.Shift(shift=shift),
                          tfb.Scale(scale=scale)],
                         validate_args=True),
      validate_args=True)
  self.assertAllClose(expected_stddev, normal.stddev())
  self.assertAllClose(expected_stddev**2, normal.variance())

  split_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.Split(3),
      validate_args=True)
  self.assertAllCloseNested(
      tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
      split_normal.stddev())

  scaled_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
      validate_args=True)
  with self.assertRaisesRegex(NotImplementedError,
                              'is a multivariate transformation'):
    scaled_normal.stddev()
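# Illustrative sketch (not part of the original suite) of the scalar rule
# `testStddev` relies on: for an elementwise affine pushforward
# Y = scale * X + shift, stddev(Y) = |scale| * stddev(X), independent of the
# shift. For example:
#
#   d = tfd.TransformedDistribution(
#       distribution=tfd.Normal(loc=0., scale=2.),
#       bijector=tfb.Chain([tfb.Shift(3.), tfb.Scale(-5.)]))
#   d.stddev()    # ==> 10. == abs(-5.) * 2.
#   d.variance()  # ==> 100.
#
# `ScaleMatvecTriL` mixes event dimensions, so no elementwise rule applies;
# hence the `NotImplementedError` asserted above.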
def testCompareToBijector(self):
  """Demonstrates equivalence between TD, Bijector approach and AR dist."""
  sample_shape = np.int32([4, 5])
  batch_shape = np.int32([])
  event_size = np.int32(2)
  batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
  sample0 = tf.zeros(batch_event_shape)
  affine = tfb.ScaleMatvecTriL(
      scale_tril=self._random_scale_tril(event_size), validate_args=True)
  ar = tfd.Autoregressive(
      self._normal_fn(affine), sample0, validate_args=True)
  ar_flow = tfb.MaskedAutoregressiveFlow(
      is_constant_jacobian=True,
      shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
      validate_args=True)
  td = tfd.TransformedDistribution(
      # TODO(b/137665504): Use batch-adding meta-distribution to set the
      # batch shape instead of tf.zeros.
      distribution=tfd.Sample(
          tfd.Normal(tf.zeros(batch_shape), 1.), [event_size]),
      bijector=ar_flow,
      validate_args=True)
  x_shape = np.concatenate([sample_shape, batch_shape, [event_size]], axis=0)
  x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
  td_log_prob_, ar_log_prob_ = self.evaluate([td.log_prob(x), ar.log_prob(x)])
  self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
def testNoBatch(self, is_static):
  bijector = tfb.ScaleMatvecTriL(scale_tril=[[2., 0.], [2., 2.]])
  x = self.maybe_static([[1., 2.]], is_static)
  self.assertAllClose([[2., 6.]], bijector.forward(x))
  self.assertAllClose([[.5, .5]], bijector.inverse(x))
  self.assertAllClose(
      -np.abs(np.log(4.)),
      bijector.inverse_log_det_jacobian(x, event_ndims=1))
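# Worked arithmetic behind the assertions above, with L = [[2., 0.], [2., 2.]]
# and x = [1., 2.]:
#
#   forward(x) = L @ x = [2*1 + 0*2, 2*1 + 2*2] = [2., 6.]
#   inverse(x) solves L @ y = x: y0 = 1/2, y1 = (2 - 2*(1/2))/2 = 1/2
#   inverse_log_det_jacobian(x) = -log|det(L)| = -log(2. * 2.) = -log(4.)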
def testRaisesWhenSingular(self):
  with self.assertRaisesRegex(
      Exception,
      '.*Singular operator: Diagonal contained zero values.*'):
    bijector = tfb.ScaleMatvecTriL(
        # Has zero on the diagonal.
        scale_tril=[[0., 0.], [1., 1.]],
        validate_args=True)
    self.evaluate(bijector.forward([1., 1.]))
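# Note: for a triangular `scale_tril` the determinant is the product of the
# diagonal, so a zero diagonal entry makes the operator non-invertible (no
# inverse, no finite log-det-Jacobian). With `validate_args=True` the check
# fires on first use of the operator, as asserted above.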
def __init__(self, loc, chol_precision_tril, name=None):
  super(MVNCholPrecisionTriL, self).__init__(
      distribution=tfd.Independent(
          tfd.Normal(tf.zeros_like(loc), scale=tf.ones_like(loc)),
          reinterpreted_batch_ndims=1),
      bijector=tfb.Chain([
          tfb.Shift(shift=loc),
          tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,
                                         adjoint=True)),
      ]),
      name=name)
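# Why this parameterization works (a sketch): with Z ~ Normal(0, I) and
# precision P = L @ L^T for L = `chol_precision_tril`, the chain applies its
# bijectors right to left and computes
#
#   Y = loc + inv(L^T) @ Z,
#
# so Cov(Y) = inv(L^T) @ inv(L) = inv(L @ L^T) = inv(P). That is, this is a
# multivariate normal parameterized by the Cholesky factor of its precision
# matrix rather than of its covariance.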
def testMVN(self, event_shape, shift, tril, dynamic_shape):
  if dynamic_shape and tf.executing_eagerly():
    self.skipTest('Eager execution does not support dynamic shape.')
  as_tensor = tf.convert_to_tensor
  if dynamic_shape:
    as_tensor = lambda v, name: tf1.placeholder_with_default(  # pylint: disable=g-long-lambda
        v, shape=None, name='dynamic_' + name)

  fake_mvn = tfd.TransformedDistribution(
      distribution=tfd.Sample(
          tfd.Normal(loc=as_tensor(0., name='loc'),
                     scale=as_tensor(1., name='scale'),
                     validate_args=True),
          sample_shape=as_tensor(np.int32(event_shape), name='event_shape')),
      bijector=tfb.Chain(
          [tfb.Shift(shift=as_tensor(shift, name='shift')),
           tfb.ScaleMatvecTriL(
               scale_tril=as_tensor(tril, name='scale_tril'))]),
      validate_args=True)

  base_dist = fake_mvn.distribution
  expected_mean = tf.linalg.matvec(
      tril, tf.broadcast_to(base_dist.mean(), shift.shape)) + shift
  expected_cov = tf.linalg.matmul(
      tril,
      tf.matmul(
          tf.linalg.diag(tf.broadcast_to(base_dist.variance(), shift.shape)),
          tril,
          adjoint_b=True))
  expected_batch_shape = ps.shape(expected_mean)[:-1]

  if dynamic_shape:
    self.assertAllEqual(tf.TensorShape(None), fake_mvn.event_shape)
    self.assertAllEqual(tf.TensorShape(None), fake_mvn.batch_shape)
  else:
    self.assertAllEqual(event_shape, fake_mvn.event_shape)
    self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape)

  # Ensure sample works by checking first, second moments.
  num_samples = 7e3
  y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
  x = y[0:5, ...]
  self.assertAllClose(expected_mean, tf.reduce_mean(y, axis=0),
                      atol=0.1, rtol=0.1)
  self.assertAllClose(expected_cov, tfp.stats.covariance(y, sample_axis=0),
                      atol=0., rtol=0.1)

  self.assertAllEqual(event_shape, fake_mvn.event_shape_tensor())
  self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape_tensor())
  self.assertAllEqual(
      ps.concat([[5], expected_batch_shape, event_shape], axis=0),
      ps.shape(x))
  self.assertAllClose(expected_mean, fake_mvn.mean())
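# Worked note on the closed forms above: if X has mean m and covariance C,
# the affine pushforward Y = tril @ X + shift has
#
#   E[Y]   = tril @ m + shift
#   Cov(Y) = tril @ C @ tril^T,
#
# where C = diag(variance) here because the base `tfd.Sample` of iid normals
# has a diagonal covariance.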
def testSampleAndLogProbConsistency(self):
  batch_shape = np.int32([])
  event_size = 2
  batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
  sample0 = tf.zeros(batch_event_shape)
  affine = tfb.ScaleMatvecTriL(
      scale_tril=self._random_scale_tril(event_size), validate_args=True)
  ar = tfd.Autoregressive(
      self._normal_fn(affine), sample0, validate_args=True)
  self.run_test_sample_consistent_log_prob(
      self.evaluate, ar, num_samples=int(1e6),
      radius=1., center=0., rtol=0.01,
      seed=test_util.test_seed())
def testCovariance(self):
  base_scale_tril = np.array([[1., 0.], [-3., 0.2]], dtype=np.float32)
  base_cov = tf.matmul(base_scale_tril, base_scale_tril, adjoint_b=True)

  shift = np.array([[-1., 0.], [-1., -2.], [4., 5.]], dtype=np.float32)
  scale = np.array([[1., -2.], [2., -3.], [0.1, -2.]], dtype=np.float32)
  scale_matvec = np.array([[0.5, 0.], [-2., 0.7]], dtype=np.float32)
  normal = tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalTriL(
          loc=[0., 0.], scale_tril=base_scale_tril, validate_args=True),
      bijector=tfb.Chain([tfb.ScaleMatvecTriL(scale_matvec),
                          tfb.Shift(shift=shift),
                          tfb.Scale(scale=scale)],
                         validate_args=True),
      validate_args=True)

  overall_scale = tf.matmul(scale_matvec, tf.linalg.diag(scale))
  expected_cov = tf.matmul(overall_scale,
                           tf.matmul(base_cov, overall_scale, adjoint_b=True))
  self.assertAllClose(normal.covariance(), expected_cov)
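# Sanity check on `overall_scale` (a sketch): `tfb.Chain` applies right to
# left, so y = scale_matvec @ (scale * x + shift), whose linear part is
# M = scale_matvec @ diag(scale). Shifts do not affect covariance, leaving
#
#   Cov(y) = M @ base_cov @ M^T,
#
# which is exactly the `expected_cov` asserted above.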
def _testMVN(self,
             base_distribution_class,
             base_distribution_kwargs,
             event_shape=()):
  # Base distribution shapes must be compatible w/ bijector; most bijectors
  # are batch_shape agnostic and only care about event_ndims.
  # In the case of `ScaleMatvecTriL`, if we got it wrong then it would fire
  # an exception due to incompatible dimensions.
  event_shape_var = tf.Variable(
      np.int32(event_shape),
      shape=tf.TensorShape(None),
      name='dynamic_event_shape')

  base_distribution_dynamic_kwargs = {
      k: tf.Variable(v, shape=tf.TensorShape(None),
                     name='dynamic_{}'.format(k))
      for k, v in base_distribution_kwargs.items()}

  fake_mvn_dynamic = self._cls()(
      distribution=tfd.Sample(
          base_distribution_class(validate_args=True,
                                  **base_distribution_dynamic_kwargs),
          sample_shape=event_shape_var),
      bijector=tfb.Chain([tfb.Shift(shift=self._shift),
                          tfb.ScaleMatvecTriL(scale_tril=self._tril)]),
      validate_args=True)

  fake_mvn_static = self._cls()(
      distribution=tfd.Sample(
          base_distribution_class(validate_args=True,
                                  **base_distribution_kwargs),
          sample_shape=event_shape),
      bijector=tfb.Chain([tfb.Shift(shift=self._shift),
                          tfb.ScaleMatvecTriL(scale_tril=self._tril)]),
      validate_args=True)

  actual_mean = np.tile(self._shift, [2, 1])  # ScaleMatvecTriL elided tile.
  actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))

  def actual_mvn_log_prob(x):
    return np.concatenate([[  # pylint: disable=g-complex-comprehension
        stats.multivariate_normal(actual_mean[i],
                                  actual_cov[i]).logpdf(x[:, i, :])
    ] for i in range(len(actual_cov))]).T

  actual_mvn_entropy = np.concatenate(
      [[stats.multivariate_normal(actual_mean[i], actual_cov[i]).entropy()]
       for i in range(len(actual_cov))])

  self.assertAllEqual([3], fake_mvn_static.event_shape)
  self.assertAllEqual([2], fake_mvn_static.batch_shape)

  if not tf.executing_eagerly():
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.event_shape)
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.batch_shape)

  num_samples = 7e3
  for fake_mvn in [fake_mvn_static, fake_mvn_dynamic]:
    # Ensure sample works by checking first, second moments.
    y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
    x = y[0:5, ...]
    sample_mean = tf.reduce_mean(y, axis=0)
    centered_y = tf.transpose(a=y - sample_mean, perm=[1, 2, 0])
    sample_cov = tf.matmul(
        centered_y, centered_y, transpose_b=True) / num_samples

    self.evaluate(
        [v.initializer
         for v in base_distribution_dynamic_kwargs.values()] +
        [event_shape_var.initializer])
    [
        sample_mean_,
        sample_cov_,
        x_,
        fake_event_shape_,
        fake_batch_shape_,
        fake_log_prob_,
        fake_prob_,
        fake_mean_,
        fake_entropy_,
    ] = self.evaluate([
        sample_mean,
        sample_cov,
        x,
        fake_mvn.event_shape_tensor(),
        fake_mvn.batch_shape_tensor(),
        fake_mvn.log_prob(x),
        fake_mvn.prob(x),
        fake_mvn.mean(),
        fake_mvn.entropy(),
    ])

    self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
    self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)

    # Ensure all other functions work as intended.
    self.assertAllEqual([5, 2, 3], x_.shape)
    self.assertAllEqual([3], fake_event_shape_)
    self.assertAllEqual([2], fake_batch_shape_)
    self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_,
                        atol=0., rtol=1e-6)
    self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_,
                        atol=0., rtol=1e-5)
    self.assertAllClose(actual_mean, fake_mean_, atol=0., rtol=1e-6)
    self.assertAllClose(actual_mvn_entropy, fake_entropy_,
                        atol=0., rtol=1e-6)