def testCdfDescendingChained(self):
  bij1 = tfb.Shift(shift=1.)(tfb.Scale(scale=[1., -2.]))
  bij2 = tfb.Shift(shift=1.)(tfb.Scale(scale=[[3.], [-5.]]))
  bij3 = tfb.Shift(shift=1.)(tfb.Scale(scale=[[[7.]], [[-11.]]]))
  for chain in bij2(bij1), bij3(bij2(bij1)):
    td = self._cls()(
        distribution=tfd.Normal(loc=0., scale=tf.ones([2, 2, 2])),
        bijector=chain,
        validate_args=True)
    nd = tfd.Normal(loc=1., scale=3., validate_args=True)
    self.assertAllEqual(tf.ones(td.batch_shape, dtype=tf.bool),
                        td.cdf(nd.quantile(.4)) < td.cdf(nd.quantile(.6)),
                        msg=chain.name)
def testStddev(self):
  base_stddev = 2.
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
  expected_stddev = tf.abs(base_stddev * scale)
  normal = self._cls()(
      distribution=tfd.Normal(loc=tf.zeros_like(shift),
                              scale=base_stddev * tf.ones_like(scale),
                              validate_args=True),
      bijector=tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                         validate_args=True),
      validate_args=True)
  self.assertAllClose(expected_stddev, normal.stddev())
  self.assertAllClose(expected_stddev**2, normal.variance())

  split_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.Split(3),
      validate_args=True)
  self.assertAllCloseNested(
      tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
      split_normal.stddev())

  scaled_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
      validate_args=True)
  with self.assertRaisesRegex(NotImplementedError,
                              'is a multivariate transformation'):
    scaled_normal.stddev()
def testMatchWithAffineTransform(self):
  direct_bj = tfb.Tanh()
  indirect_bj = tfb.Chain([
      tfb.Shift(tf.cast(-1.0, dtype=tf.float64)),
      tfb.Scale(tf.cast(2.0, dtype=tf.float64)),
      tfb.Sigmoid(),
      tfb.Scale(tf.cast(2.0, dtype=tf.float64))
  ])

  x = np.linspace(-3.0, 3.0, 100)
  y = np.tanh(x)
  self.assertAllClose(self.evaluate(direct_bj.forward(x)),
                      self.evaluate(indirect_bj.forward(x)))
  self.assertAllClose(self.evaluate(direct_bj.inverse(y)),
                      self.evaluate(indirect_bj.inverse(y)))
  self.assertAllClose(
      self.evaluate(direct_bj.inverse_log_det_jacobian(y, event_ndims=0)),
      self.evaluate(indirect_bj.inverse_log_det_jacobian(y, event_ndims=0)))
  self.assertAllClose(
      self.evaluate(direct_bj.forward_log_det_jacobian(x, event_ndims=0)),
      self.evaluate(indirect_bj.forward_log_det_jacobian(x, event_ndims=0)))
def testBijectorWithDeepStructure(self):
  bij = tfb.JointMap({
      'a': tfb.Exp(),
      'bc': tfb.JointMap([tfb.Scale(2.), tfb.Shift(3.)])
  })

  a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
  b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
  c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

  inputs = {'a': a, 'bc': [b, c]}  # Could be inputs to forward or inverse.
  event_ndims = {'a': 1, 'bc': [0, 0]}

  self.assertStartsWith(bij.name, 'jointmap_of_exp_and_jointmap_of_')
  self.assertAllCloseNested({'a': np.exp(a), 'bc': [b * 2., c + 3]},
                            self.evaluate(bij.forward(inputs)))
  self.assertAllCloseNested({'a': np.log(a), 'bc': [b / 2., c - 3]},
                            self.evaluate(bij.inverse(inputs)))

  fldj = self.evaluate(bij.forward_log_det_jacobian(inputs, event_ndims))
  self.assertEqual((1, 2), fldj.shape)
  self.assertAllClose(np.sum(a, axis=-1) + np.log(2), fldj)

  ildj = self.evaluate(bij.inverse_log_det_jacobian(inputs, event_ndims))
  self.assertEqual((1, 2), ildj.shape)
  self.assertAllClose(-np.log(a).sum(axis=-1) - np.log(2), ildj)
def testQuantileDescending(self):
  td = self._cls()(
      distribution=tfd.Normal(loc=0., scale=[1., 1.]),
      bijector=tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.])),
      validate_args=True)
  self.assertAllEqual(tf.ones(td.batch_shape, dtype=tf.bool),
                      td.quantile(.8) < td.quantile(.9))
def testBijector(self):
  low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
  high = 12.
  bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
  equivalent_bijector = tfb.Chain(
      [tfb.Shift(shift=low), tfb.Scale(scale=high - low), tfb.Sigmoid()])

  x = [[[1., 2., -5., -0.3]]]
  y = self.evaluate(equivalent_bijector.forward(x))
  self.assertAllClose(y, self.evaluate(bijector.forward(x)))
  self.assertAllClose(x, self.evaluate(bijector.inverse(y)[..., :1, :]),
                      rtol=1e-5)
  self.assertAllClose(
      self.evaluate(
          equivalent_bijector.inverse_log_det_jacobian(y, event_ndims=1)),
      self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=1)),
      rtol=1e-5)
  self.assertAllClose(
      self.evaluate(
          equivalent_bijector.forward_log_det_jacobian(x, event_ndims=1)),
      self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=1)))
def test_batch_broadcast_vector_to_parts(self):
  batch_shape = [4, 2]
  true_split_sizes = [1, 3, 2]
  base_event_size = sum(true_split_sizes)
  # Base dist with no batch shape (will require broadcasting).
  base_dist = tfd.MultivariateNormalDiag(
      loc=tf.random.normal([base_event_size], seed=test_util.test_seed()),
      scale_diag=tf.exp(
          tf.random.normal([base_event_size], seed=test_util.test_seed())))

  # Bijector with batch shape in one part.
  bijector = tfb.Chain([
      tfb.JointMap([
          tfb.Identity(),
          tfb.Identity(),
          tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]]))
      ]),
      tfb.Split(true_split_sizes, axis=-1)
  ])
  split_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertAllEqual(split_dist.batch_shape, batch_shape)

  # Because one branch of the split has batch shape, TD should feed batches
  # of base samples *into* the split, so the batch shape propagates to all
  # branches.
  xs = split_dist.sample(seed=test_util.test_seed())
  self.assertAllEqualNested(
      tf.nest.map_structure(lambda x: x.shape, xs),
      [batch_shape + [d] for d in true_split_sizes])
def testBatch(self, is_static):
  shift = tfb.Shift([[2., -.5], [1., -3.]])
  x = self.maybe_static([1., 1.], is_static)
  self.assertAllClose([[3., .5], [2., -2.]], shift.forward(x))
  self.assertAllClose([[-1., 1.5], [0., 4.]], shift.inverse(x))
  self.assertAllClose(0., shift.inverse_log_det_jacobian(x, event_ndims=1))
def testNoBatch(self, is_static):
  shift = tfb.Shift([1., -1.])
  x = self.maybe_static([1., 1.], is_static)
  self.assertAllClose([2., 0.], shift.forward(x))
  self.assertAllClose([0., 2.], shift.inverse(x))
  self.assertAllClose(0., shift.inverse_log_det_jacobian(x, event_ndims=1))
def testCdfDescending(self):
  td = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=[1., 1.]),
      bijector=tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.])),
      validate_args=True)
  nd = tfd.Normal(loc=1., scale=2., validate_args=True)
  self.assertAllEqual(tf.ones(td.batch_shape, dtype=tf.bool),
                      td.cdf(nd.quantile(.8)) < td.cdf(nd.quantile(.9)))
def testLDJRatio(self):
  q = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(2.), 'c': tfb.Shift(3.)})
  p = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(3.), 'c': tfb.Shift(4.)})

  a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
  b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
  c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

  x = {'a': a, 'b': b, 'c': c}
  y = {'a': a + 1, 'b': b + 1, 'c': c + 1}
  event_ndims = {'a': 1, 'b': 0, 'c': 0}

  fldj_ratio_true = (p.forward_log_det_jacobian(x, event_ndims) -
                     q.forward_log_det_jacobian(y, event_ndims))
  fldj_ratio = ldj_ratio.forward_log_det_jacobian_ratio(
      p, x, q, y, event_ndims)
  self.assertAllClose(fldj_ratio_true, fldj_ratio)

  ildj_ratio_true = (p.inverse_log_det_jacobian(x, event_ndims) -
                     q.inverse_log_det_jacobian(y, event_ndims))
  ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
      p, x, q, y, event_ndims)
  self.assertAllClose(ildj_ratio_true, ildj_ratio)

  event_ndims = {'a': 1, 'b': 2, 'c': 0}

  fldj_ratio_true = (p.forward_log_det_jacobian(x, event_ndims) -
                     q.forward_log_det_jacobian(y, event_ndims))
  fldj_ratio = ldj_ratio.forward_log_det_jacobian_ratio(
      p, x, q, y, event_ndims)
  self.assertAllClose(fldj_ratio_true, fldj_ratio)

  ildj_ratio_true = (p.inverse_log_det_jacobian(x, event_ndims) -
                     q.inverse_log_det_jacobian(y, event_ndims))
  ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
      p, x, q, y, event_ndims)
  self.assertAllClose(ildj_ratio_true, ildj_ratio)
def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
  """Affine bijection between [0, 1]^2 <--> [-2.5, 2.5] x [-1.0, 2.0]."""
  if dtype is None:
    dtype = default_float()
  scale = tfb.Scale(tf.convert_to_tensor([5.0, 3.0], dtype=dtype))
  shift = tfb.Shift(tf.convert_to_tensor([-0.5, -1 / 3], dtype=dtype))
  return tfb.Chain([scale, shift])
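
# Hedged sanity-check sketch (added for illustration; `_check_default_bijector_bounds`
# is not part of the original source). `Chain` applies its list right-to-left, so the
# forward map is (x + shift) * scale, sending the unit square onto the stated box.
def _check_default_bijector_bounds():
  bij = tfb.Chain([tfb.Scale([5.0, 3.0]), tfb.Shift([-0.5, -1.0 / 3.0])])
  corners = tf.constant([[0.0, 0.0], [1.0, 1.0]])
  # Expect approximately [[-2.5, -1.0], [2.5, 2.0]].
  print(bij.forward(corners))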
def _default_event_space_bijector(self):
  """The bijector maps a zero-dimensional null Tensor input to `self.loc`."""
  # The shape of the pulled back null tensor will be `self.loc.shape + (0,)`.
  # First we pad to a tensor of zeros with shape `self.loc.shape + (1,)`.
  pad_zero = tfb.Pad([(1, 0)])
  # Next, we squeeze to a tensor of zeros with shape matching `self.loc`.
  zeros_squeezed = tfb.Reshape([], event_shape_in=[1])(pad_zero)
  # Finally, we shift the zeros by `self.loc`.
  return tfb.Shift(self.loc)(zeros_squeezed)
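
# Hedged usage sketch (illustrative; the helper name and `loc` value are
# hypothetical): feeding the composed bijector a tensor whose trailing
# dimension is 0 recovers `loc` exactly, which is the defining property of the
# event-space bijector above.
def _check_null_input_maps_to_loc():
  loc = tf.constant([1.5, -2.0])
  bij = tfb.Shift(loc)(tfb.Reshape([], event_shape_in=[1])(tfb.Pad([(1, 0)])))
  null_input = tf.zeros([2, 0])  # i.e. loc.shape + (0,)
  print(bij.forward(null_input))  # Expect [1.5, -2.0].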
def testSharedCaching(self):
  for fwd in [tfb.Exp(), tfb.Shift(2.)]:
    x = tf.constant([0.5, -1.], dtype=tf.float32)
    inv = tfb.Invert(fwd)
    y = fwd.forward(x)
    self.assertIs(inv.forward(y), x)
    self.assertIs(inv.inverse(x), y)
def testMode(self):
  dist = self._cls()(
      tfd.Beta(concentration1=[5., 10.],
               concentration0=15.,
               validate_args=True),
      tfb.Shift(2., validate_args=True)(tfb.Scale(10., validate_args=True)),
      validate_args=True)
  self.assertAllClose(2. + 10. * dist.distribution.mode(),
                      self.evaluate(dist.mode()),
                      atol=0.,
                      rtol=1e-6)
def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
  """Affine bijection between [0, 1]^4 <--> [-20, 120] x [0, 1]^3."""
  if dtype is None:
    dtype = default_float()
  scale = tfb.Scale(tf.convert_to_tensor([140] + 3 * [1], dtype=dtype))
  shift = tfb.Shift(tf.convert_to_tensor([-1 / 7] + 3 * [0], dtype=dtype))
  return tfb.Chain([scale, shift])
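
# Hedged sanity-check sketch (illustrative helper, not in the original): only
# the first coordinate is rescaled, mapping [0, 1] onto [-20, 120] via
# 140 * (x - 1/7); the remaining three coordinates pass through unchanged.
def _check_default_bijector_bounds_4d():
  bij = tfb.Chain([tfb.Scale([140.0, 1.0, 1.0, 1.0]),
                   tfb.Shift([-1.0 / 7.0, 0.0, 0.0, 0.0])])
  corners = tf.constant([[0.0] * 4, [1.0] * 4])
  # Expect first column ~[-20, 120]; other columns stay [0, 1].
  print(bij.forward(corners))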
def __init__(self, loc, chol_precision_tril, name=None):
  super(MVNCholPrecisionTriL, self).__init__(
      distribution=tfd.Independent(
          tfd.Normal(tf.zeros_like(loc), scale=tf.ones_like(loc)),
          reinterpreted_batch_ndims=1),
      bijector=tfb.Chain([
          tfb.Shift(shift=loc),
          tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,
                                         adjoint=True)),
      ]),
      name=name)
def _bijector_fn(x):
  if tensorshape_util.rank(x.shape) == 1:
    x = x[tf.newaxis, ...]
    reshape_output = lambda x: x[0]
  else:
    reshape_output = lambda x: x
  shift, logit_gate = tf.unstack(layer(x), axis=-1)
  shift = reshape_output(shift)
  logit_gate = reshape_output(logit_gate)
  gate = tf.nn.sigmoid(logit_gate)
  return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate))
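
# Hedged note with a worked example (added; `_check_gated_affine` is a
# hypothetical helper): since Shift is applied after Scale in the composition
# above, the returned bijector computes y = gate * x + (1 - gate) * shift, a
# convex combination of the input and a data-dependent shift.
def _check_gated_affine():
  gate, shift = 0.25, 4.0
  bij = tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate))
  print(bij.forward(2.0))  # 0.25 * 2.0 + 0.75 * 4.0 = 3.5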
def bijector_fn(x):
  """Banana transform."""
  batch_shape = ps.shape(x)[:-1]
  shift = tf.concat(
      [
          tf.zeros(ps.concat([batch_shape, [1]], axis=0)),
          curvature * (tf.square(x[..., :1]) - 100),
          tf.zeros(ps.concat([batch_shape, [ndims - 2]], axis=0)),
      ],
      axis=-1)
  return tfb.Shift(shift)
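
# Hedged sketch (illustrative values; not in the original): the shift bends the
# second coordinate by curvature * (x_0**2 - 100), which produces the classic
# "banana" shape when composed with a wide Gaussian in the first coordinate.
def _check_banana_shift():
  curvature, ndims = 0.03, 3  # assumed values for illustration
  x = tf.constant([[20.0, 1.0, 1.0]])
  shift = tf.concat([
      tf.zeros([1, 1]),
      curvature * (tf.square(x[..., :1]) - 100),
      tf.zeros([1, ndims - 2]),
  ], axis=-1)
  # Expect [[20., 1. + 0.03 * 300., 1.]] = [[20., 10., 1.]].
  print(tfb.Shift(shift).forward(x))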
def testMVN(self, event_shape, shift, tril, dynamic_shape):
  if dynamic_shape and tf.executing_eagerly():
    self.skipTest('Eager execution does not support dynamic shape.')
  as_tensor = tf.convert_to_tensor
  if dynamic_shape:
    as_tensor = lambda v, name: tf1.placeholder_with_default(  # pylint: disable=g-long-lambda
        v, shape=None, name='dynamic_' + name)

  fake_mvn = tfd.TransformedDistribution(
      distribution=tfd.Sample(
          tfd.Normal(loc=as_tensor(0., name='loc'),
                     scale=as_tensor(1., name='scale'),
                     validate_args=True),
          sample_shape=as_tensor(np.int32(event_shape), name='event_shape')),
      bijector=tfb.Chain([
          tfb.Shift(shift=as_tensor(shift, name='shift')),
          tfb.ScaleMatvecTriL(scale_tril=as_tensor(tril, name='scale_tril'))
      ]),
      validate_args=True)

  base_dist = fake_mvn.distribution
  expected_mean = tf.linalg.matvec(
      tril, tf.broadcast_to(base_dist.mean(), shift.shape)) + shift
  expected_cov = tf.linalg.matmul(
      tril,
      tf.matmul(
          tf.linalg.diag(tf.broadcast_to(base_dist.variance(), shift.shape)),
          tril,
          adjoint_b=True))
  expected_batch_shape = ps.shape(expected_mean)[:-1]

  if dynamic_shape:
    self.assertAllEqual(tf.TensorShape(None), fake_mvn.event_shape)
    self.assertAllEqual(tf.TensorShape(None), fake_mvn.batch_shape)
  else:
    self.assertAllEqual(event_shape, fake_mvn.event_shape)
    self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape)

  # Ensure sample works by checking first, second moments.
  num_samples = 7e3
  y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
  x = y[0:5, ...]
  self.assertAllClose(expected_mean, tf.reduce_mean(y, axis=0),
                      atol=0.1, rtol=0.1)
  self.assertAllClose(expected_cov, tfp.stats.covariance(y, sample_axis=0),
                      atol=0., rtol=0.1)

  self.assertAllEqual(event_shape, fake_mvn.event_shape_tensor())
  self.assertAllEqual(expected_batch_shape, fake_mvn.batch_shape_tensor())
  self.assertAllEqual(
      ps.concat([[5], expected_batch_shape, event_shape], axis=0),
      ps.shape(x))
  self.assertAllClose(expected_mean, fake_mvn.mean())
def get_trainable_shift_bijector(flat_event_size,
                                 init_loc_unconstrained,
                                 dtype=DEFAULT_FLOAT_DTYPE_TF):
  return tfb.JointMap(
      tf.nest.map_structure(
          lambda s, init: tfb.Shift(
              tf.Variable(
                  tf.random.uniform((s,), minval=-2.0, maxval=2.0, dtype=dtype)
                  if init is None else init)),
          flat_event_size,
          init_loc_unconstrained,
      ))
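
# Hedged usage sketch (sizes, init values, and the helper name are
# illustrative): one trainable Shift is built per flattened event part; parts
# with `None` init draw uniform starting values, others use the given tensor.
def _example_trainable_shift():
  bij = get_trainable_shift_bijector(
      flat_event_size=[2, 3],
      init_loc_unconstrained=[None, tf.zeros([3])],
      dtype=tf.float32)
  # Forward of zeros returns the current (trainable) shifts for each part.
  return bij.forward([tf.zeros([2]), tf.zeros([3])])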
def _bijector_fn(x, output_units):
  if tensorshape_util.rank(x.shape) == 1:
    x = x[tf.newaxis, ...]
    reshape_output = lambda x: x[0]
  else:
    reshape_output = lambda x: x
  out = tf1.layers.dense(inputs=x, units=2 * output_units)
  shift, logit_gate = tf.split(out, 2, axis=-1)
  shift = reshape_output(shift)
  logit_gate = reshape_output(logit_gate)
  gate = tf.nn.sigmoid(logit_gate)
  # Gated affine transform: y = gate * x + (1 - gate) * shift.
  return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate))
def testTransformedNormalNormalKL(self):
  batch_size = 6
  mu_a = np.array([3.0] * batch_size).astype(np.float32)
  sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5]).astype(np.float32)
  mu_b = np.array([-3.0] * batch_size).astype(np.float32)
  sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]).astype(np.float32)

  n_a = tfd.Normal(loc=mu_a, scale=sigma_a, validate_args=True)
  n_b = tfd.Normal(loc=mu_b, scale=sigma_b, validate_args=True)

  kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
      (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

  bij1 = tfb.Shift(shift=1.)(tfb.Scale(scale=2.))
  bij2 = (tfb.Shift(shift=np.array(2., dtype=np.float32))
          (tfb.Scale(scale=np.array(3., dtype=np.float32))))
  bij3 = tfb.Tanh()
  for chain in bij2(bij1), bij3(bij2(bij1)):
    td_a = tfd.TransformedDistribution(
        distribution=n_a, bijector=chain, validate_args=True)
    td_b = tfd.TransformedDistribution(
        distribution=n_b, bijector=copy.copy(chain), validate_args=True)

    kl = tfd.kl_divergence(td_a, td_b)
    kl_val = self.evaluate(kl)

    x = td_a.sample(int(1e5), seed=test_util.test_seed())
    kl_sample = tf.reduce_mean(td_a.log_prob(x) - td_b.log_prob(x), axis=0)
    kl_sample_ = self.evaluate(kl_sample)

    self.assertEqual(kl.shape, (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
    self.assertAllClose(kl_expected, kl_sample_, atol=0.0, rtol=1e-2)
def testMean(self):
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
  fake_mvn = self._cls()(
      tfd.MultivariateNormalDiag(
          loc=tf.zeros_like(shift),
          scale_diag=tf.ones_like(diag),
          validate_args=True),
      tfb.Chain([
          tfb.Shift(shift=shift),
          tfb.ScaleMatvecLinearOperator(
              scale=tf.linalg.LinearOperatorDiag(diag, is_non_singular=True))
      ], validate_args=True),
      validate_args=True)
  self.assertAllClose(shift, self.evaluate(fake_mvn.mean()))
def test_composition_str_and_repr_match_expected_dynamic_shape(self):
  bij = tfb.Chain([
      tfb.Exp(),
      tfb.Shift(self._tensor([1., 2.])),
      tfb.SoftmaxCentered()
  ])
  self.assertContainsInOrder([
      'tfp.bijectors.Chain(',
      ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])')
  ], str(bij))
  self.assertContainsInOrder([
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
       'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
      '>, <tfp.bijectors.Shift', '>, <tfp.bijectors.SoftmaxCentered', '>]>'
  ], repr(bij))

  bij = tfb.Chain([
      tfb.JointMap({
          'a': tfb.Exp(),
          'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.]))
      }),
      tfb.Restructure({'a': 0, 'b': 1}, [0, 1]),
      tfb.Split(2),
      tfb.Invert(tfb.SoftmaxCentered()),
  ])
  self.assertContainsInOrder([
      'tfp.bijectors.Chain(',
      ('forward_min_event_ndims=1, '
       'inverse_min_event_ndims={a: 1, b: 1}, '
       'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
       'Restructure, Split, Invert(SoftmaxCentered)])')
  ], str(bij))
  self.assertContainsInOrder([
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 '
       "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
       "dtype_y={'a': ?, 'b': float32} "
       "bijectors=[<tfp.bijectors.JointMap "),
      '>, <tfp.bijectors.Restructure', '>, <tfp.bijectors.Split',
      '>, <tfp.bijectors.Invert', '>]>'
  ], repr(bij))
def testEntropy(self):
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
  actual_mvn_entropy = np.concatenate(
      [[stats.multivariate_normal(shift[i], np.diag(diag[i]**2)).entropy()]
       for i in range(len(diag))])
  fake_mvn = self._cls()(
      tfd.MultivariateNormalDiag(
          loc=tf.zeros_like(shift),
          scale_diag=tf.ones_like(diag),
          validate_args=True),
      tfb.Chain([
          tfb.Shift(shift=shift),
          tfb.ScaleMatvecLinearOperator(
              scale=tf.linalg.LinearOperatorDiag(diag, is_non_singular=True))
      ], validate_args=True),
      validate_args=True)
  self.assertAllClose(actual_mvn_entropy, self.evaluate(fake_mvn.entropy()))
def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones."""
  with tf.name_scope('as_trainable_family'):
    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      return tfb.Shift(distribution.low)(
          tfb.Scale(distribution.high - distribution.low)(beta.Beta(
              concentration0=tf.ones(distribution.event_shape_tensor(),
                                     dtype=distribution.dtype),
              concentration1=1.)))
    else:
      return distribution
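
# Hedged example (illustrative; `_check_trainable_family_support` is not in the
# original): a Uniform(low=-3., high=2.) prior becomes
# Shift(-3.)(Scale(5.)(Beta)), which keeps the support [-3., 2.] while exposing
# trainable concentration parameters.
def _check_trainable_family_support():
  u = uniform.Uniform(low=-3., high=2.)
  t = _as_trainable_family(u)
  samples = t.sample(1000, seed=42)
  # All samples lie in the original support.
  tf.debugging.assert_greater_equal(samples, -3.)
  tf.debugging.assert_less_equal(samples, 2.)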
def testNumericallySuperiorToEquivalentChain(self):
  x = np.array([-5., 3., 17., 23.]).astype(np.float32)
  low = -0.08587775
  high = 0.12498104
  bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
  equivalent_bijector = tfb.Chain(
      [tfb.Shift(shift=low), tfb.Scale(scale=high - low), tfb.Sigmoid()])

  self.assertAllLessEqual(self.evaluate(bijector.forward(x)), high)
  # The mathematically equivalent `Chain` bijector can return values greater
  # than the intended upper bound of `high`.
  self.assertTrue(
      (self.evaluate(equivalent_bijector.forward(x)) > high).any())
def stochastic_volatility_prior_fn(num_timesteps):
  """Generative process for the stochastic volatility model."""
  persistence_of_volatility = yield Root(
      tfb.Shift(-1.)(tfb.Scale(2.)(
          tfd.Beta(concentration1=20.,
                   concentration0=1.5,
                   name='persistence_of_volatility'))))
  mean_log_volatility = yield Root(
      tfd.Cauchy(loc=0., scale=5., name='mean_log_volatility'))
  white_noise_shock_scale = yield Root(
      tfd.HalfCauchy(loc=0., scale=2., name='white_noise_shock_scale'))

  _ = yield tfd.JointDistributionCoroutine(
      functools.partial(autoregressive_series_fn,
                        num_timesteps=num_timesteps,
                        mean=mean_log_volatility,
                        noise_scale=white_noise_shock_scale,
                        persistence=persistence_of_volatility),
      name='log_volatility')
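
# Hedged note (illustrative check, not in the original): Shift(-1.)(Scale(2.))
# maps the Beta support (0, 1) onto (-1, 1), so the sampled persistence keeps
# the AR(1) log-volatility process stationary.
def _check_persistence_support():
  dist = tfb.Shift(-1.)(tfb.Scale(2.)(
      tfd.Beta(concentration1=20., concentration0=1.5)))
  samples = dist.sample(1000, seed=42)
  tf.debugging.assert_greater(samples, -1.)
  tf.debugging.assert_less(samples, 1.)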
def testCovariance(self):
  base_scale_tril = np.array([[1., 0.], [-3., 0.2]], dtype=np.float32)
  base_cov = tf.matmul(base_scale_tril, base_scale_tril, adjoint_b=True)

  shift = np.array([[-1., 0.], [-1., -2.], [4., 5.]], dtype=np.float32)
  scale = np.array([[1., -2.], [2., -3.], [0.1, -2.]], dtype=np.float32)
  scale_matvec = np.array([[0.5, 0.], [-2., 0.7]], dtype=np.float32)
  normal = tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalTriL(
          loc=[0., 0.], scale_tril=base_scale_tril, validate_args=True),
      bijector=tfb.Chain([
          tfb.ScaleMatvecTriL(scale_matvec),
          tfb.Shift(shift=shift),
          tfb.Scale(scale=scale)
      ], validate_args=True),
      validate_args=True)

  overall_scale = tf.matmul(scale_matvec, tf.linalg.diag(scale))
  expected_cov = tf.matmul(overall_scale,
                           tf.matmul(base_cov, overall_scale, adjoint_b=True))
  self.assertAllClose(normal.covariance(), expected_cov)