def testMinEventNdimsWithPartiallyDependentJointMap(self):
  """Wrapping a multipart bijector in a one-element JointMap keeps metadata.

  A Chain of `Split(2)` and `Invert(Split(2))` is compared against the same
  bijector placed inside a single-element `JointMap`, sandwiched between a
  list-nesting `Restructure` and its inverse. Both must report the same min
  event ndims and the same `_parts_interact`.
  """
  interacting = tfb.Chain([tfb.Split(2), tfb.Invert(tfb.Split(2))])
  nest_in_list = tfb.Restructure(input_structure=[0, 1],
                                 output_structure=[[0, 1]])
  via_joint_map = tfb.Chain(
      [tfb.Invert(nest_in_list), tfb.JointMap([interacting]), nest_in_list])
  for attr_name in ('forward_min_event_ndims',
                    'inverse_min_event_ndims',
                    '_parts_interact'):
    self.assertAllEqualNested(getattr(interacting, attr_name),
                              getattr(via_joint_map, attr_name))
def test_unknown_event_rank(self):
  """Unknown event rank: static shapes are unknown, dynamic ones resolve."""
  if tf.executing_eagerly():
    self.skipTest('Eager execution.')

  def assert_shapes(dist, expected_batch_shape, expected_event_shape):
    # Static shapes must be fully unknown while dynamic shapes resolve.
    self.assertEqual(dist.batch_shape, tf.TensorShape(None))
    self.assertEqual(dist.event_shape, tf.TensorShape(None))
    self.assertAllEqual(dist.batch_shape_tensor(), expected_batch_shape)
    self.assertAllEqual(dist.event_shape_tensor(), expected_event_shape)

  # `reinterpreted_batch_ndims` is a placeholder, so the split between batch
  # and event dimensions is not known statically.
  unknown_rank_dist = tfd.Independent(
      tfd.Normal(loc=tf.ones([2, 1, 3]), scale=2.),
      reinterpreted_batch_ndims=tf1.placeholder_with_default(1, shape=[]))

  td = tfd.TransformedDistribution(
      distribution=unknown_rank_dist,
      bijector=tfb.Scale(1.),
      validate_args=True)
  assert_shapes(td, [2, 1], [3])

  joint_td = tfd.TransformedDistribution(
      distribution=tfd.JointDistributionSequentialAutoBatched(
          [unknown_rank_dist, unknown_rank_dist]),
      bijector=tfb.Invert(tfb.Split(2)),
      validate_args=True)
  # Note that the current behavior is conservative; we could also correctly
  # return a batch shape of `[]` in this case.
  assert_shapes(joint_td, [], [2, 1, 6])
def testAssertRaisesWrongNumSplits(self):
  """Inverting a list whose length differs from `num_splits` must fail."""
  num_splits = 4
  # Only 3 parts are supplied to a bijector expecting 4.
  y = [np.random.rand(2, 3)] * 3
  bijector = tfb.Split(num_splits, validate_args=True)
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "don't have the same sequence length"):
    self.evaluate(bijector.inverse(y))
def testEventShape(self, num_or_size_splits, expected_split_sizes):
  """Checks static `forward_`/`inverse_event_shape` inference for Split."""
  split_sizes = self.build_input(num_or_size_splits)
  total_size = np.sum(expected_split_sizes)
  static_in = tf.TensorShape([total_size, 2])
  static_out = [tf.TensorShape([d, 2]) for d in expected_split_sizes]
  bijector = tfb.Split(num_or_size_splits=split_sizes, axis=-2,
                       validate_args=True)

  def forward_as_lists(event_shape):
    return [part.as_list()
            for part in bijector.forward_event_shape(event_shape)]

  def inverse_as_list(event_shapes):
    return bijector.inverse_event_shape(event_shapes).as_list()

  # NOTE(review): the split dimension is expected to come back as `None` even
  # for a fully static input; presumably `build_input` hides the split sizes
  # from static inference in this variant -- confirm.
  self.assertAllEqual(forward_as_lists(static_in), [[None, 2]] * 3)
  self.assertEqual(inverse_as_list(static_out), static_in.as_list())
  self.assertAllEqual(
      forward_as_lists(tf.TensorShape([total_size, None])),
      [[None, None]] * 3)
  self.assertAllEqual(inverse_as_list([[None, 3], [3, 3], [2, 3]]),
                      [None, 3])
  self.assertAllEqual(inverse_as_list([[2, 3], [None, 3], [2, 3]]),
                      [None, 3])
def _testAssertRaisesMismatchedOutputShapes(self):
  """Inverse rejects a part whose size contradicts a known split size.

  Only checked when the split sizes are statically known.
  """
  sizes = self.build_input([5, -1, 3])
  # The first part's trailing dimension is 6, but the split sizes ask for 5.
  parts = [np.random.rand(3, 1, n) for n in (6, 2, 3)]
  bijector = tfb.Split(sizes, validate_args=True)
  if tf.get_static_value(sizes) is None:
    return
  with self.assertRaisesError('does not match expected `split_size`'):
    self.evaluate(bijector.inverse(parts))
def testStddev(self):
  """Checks `stddev`/`variance` of transformed Normals.

  An elementwise Shift/Scale chain yields `|base_stddev * scale|`, a `Split`
  yields the correspondingly split stddev, and `ScaleMatvecTriL` is expected
  to raise `NotImplementedError`.
  """

  def transform(distribution, bijector):
    return self._cls()(
        distribution=distribution, bijector=bijector, validate_args=True)

  base_stddev = 2.
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
  expected_stddev = tf.abs(base_stddev * scale)

  normal = transform(
      tfd.Normal(loc=tf.zeros_like(shift),
                 scale=base_stddev * tf.ones_like(scale),
                 validate_args=True),
      tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                validate_args=True))
  self.assertAllClose(expected_stddev, normal.stddev())
  self.assertAllClose(expected_stddev**2, normal.variance())

  split_normal = transform(
      tfd.Independent(normal, reinterpreted_batch_ndims=1), tfb.Split(3))
  self.assertAllCloseNested(
      tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
      split_normal.stddev())

  scaled_normal = transform(
      tfd.Independent(normal, reinterpreted_batch_ndims=1),
      tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]))
  with self.assertRaisesRegex(NotImplementedError,
                              'is a multivariate transformation'):
    scaled_normal.stddev()
def testInverseWithEventDimsOmitted(self):
  """`inverse_log_det_jacobian` works without an explicit `event_ndims`."""
  bij = tfb.Split(2)
  parts = [tf.ones((3, 4, 5)) for _ in range(2)]
  ildj = bij.inverse_log_det_jacobian(parts)
  self.assertAllEqual(0.0, self.evaluate(ildj))
def testEventShape(self, num_or_size_splits, expected_split_sizes):
  """Checks static event-shape inference for full and partial shapes."""
  split_arg = self.build_input(num_or_size_splits)
  total_size = np.sum(expected_split_sizes)
  static_in = tf.TensorShape([total_size, 2])
  static_out = [tf.TensorShape([d, 2]) for d in expected_split_sizes]
  bijector = tfb.Split(
      num_or_size_splits=split_arg, axis=-2, validate_args=True)

  def forward_as_lists(event_shape):
    return [part.as_list()
            for part in bijector.forward_event_shape(event_shape)]

  def inverse_as_list(event_shapes):
    return bijector.inverse_event_shape(event_shapes).as_list()

  # Test that forward_ and inverse_event_shape are correct when
  # event_shape_in/_out are statically known, even when the input shapes
  # are only partially specified.
  self.assertAllEqual(bijector.forward_event_shape(static_in), static_out)
  self.assertEqual(bijector.inverse_event_shape(static_out), static_in)

  # Shape is always known for splitting in eager mode, so the partially
  # specified cases below only run in graph mode.
  if tf.executing_eagerly():
    return

  self.assertAllEqual(
      forward_as_lists(tf.TensorShape([total_size, None])),
      [[d, None] for d in expected_split_sizes])

  size_source = (expected_split_sizes if bijector.split_sizes is None
                 else split_arg)
  static_split_sizes = tensorshape_util.constant_value_as_shape(
      size_source).as_list()
  static_total_size = None if None in static_split_sizes else total_size

  # An unknown first part: when it coincides with the `-1` element of a
  # not-fully-specified `split_sizes`, the total cannot be deduced.
  first_part_unknown = (
      [[None, 3]] + [[d, 3] for d in expected_split_sizes[1:]])
  self.assertAllEqual(inverse_as_list(first_part_unknown),
                      [static_total_size, 3])

  # An unknown third part that does not coincide with a `-1` split size, so
  # the total is still deducible.
  third_part_unknown = [[d, 3] for d in expected_split_sizes]
  third_part_unknown[2] = [None, 3]
  self.assertAllEqual(inverse_as_list(third_part_unknown), [total_size, 3])

  # An input shape of known rank only.
  if bijector.split_sizes is None:
    expected_parts = [[None, None]] * len(expected_split_sizes)
  else:
    expected_parts = [[d, None] for d in static_split_sizes]
  self.assertAllEqual(forward_as_lists(tf.TensorShape([None, None])),
                      expected_parts)
def test_single_part_str_repr_match_expected(self):
  """Checks `str` and `repr` of Exp, Scale and Split bijectors."""
  cases = (
      (tfb.Exp,
       'tfp.bijectors.Exp("exp", batch_shape=[], min_event_ndims=0)',
       "<tfp.bijectors.Exp 'exp' batch_shape=[] forward_min_event_ndims=0 "
       "inverse_min_event_ndims=0 dtype_x=? dtype_y=?>"),
      (lambda: tfb.Scale([1., 1.], name='myscale'),
       'tfp.bijectors.Scale("myscale", batch_shape=[2], min_event_ndims=0, '
       'dtype=float32)',
       "<tfp.bijectors.Scale 'myscale' batch_shape=[2] "
       "forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=float32 "
       "dtype_y=float32>"),
      (lambda: tfb.Split([3, 4, 2], name='s_p_l_i_t'),
       'tfp.bijectors.Split("s_p_l_i_t", batch_shape=[], '
       'forward_min_event_ndims=1, inverse_min_event_ndims=[1, 1, 1])',
       "<tfp.bijectors.Split 's_p_l_i_t' batch_shape=[] "
       "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
       "dtype_x=? dtype_y=[?, ?, ?]>"),
  )
  # Each bijector is built lazily so construction interleaves with checks.
  for make_bijector, expected_str, expected_repr in cases:
    bij = make_bijector()
    self.assertContainsInOrder([expected_str], str(bij))
    self.assertContainsInOrder([expected_repr], repr(bij))
def _testAssertRaisesTooSmallInputShape(self):
  """Forward must fail when the input is too small for the split sizes.

  The input has trailing size 4 while the known split sizes already sum to 5.
  """
  sizes = self.build_input([-1, 2, 3])
  # `shape=None` presumably hides the static shape so the failure surfaces
  # from the runtime check -- confirm.
  x = tf.Variable(tf.zeros((2, 4)), shape=None)
  self.evaluate(x.initializer)
  bijector = tfb.Split(sizes, validate_args=True)
  with self.assertRaisesError('size of the input along `axis`'):
    self.evaluate(bijector.forward(x))
def testAssertRaisesWrongNumberOfOutputs(self):
  """Inverting more parts than there are split sizes must fail."""
  split_sizes = self.build_input([5, 3, -1])
  # Four parts for a three-way split.
  y = [np.random.rand(2, i) for i in [5, 3, 1, 2]]
  bijector = tfb.Split(split_sizes, validate_args=True)
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "don't have the same sequence length"):
    self.evaluate(bijector.inverse(y))
def test_batch_broadcast_vector_to_parts(self):
  """Batch shape from one split part propagates to every sampled part."""
  batch_shape = [4, 2]
  true_split_sizes = [1, 3, 2]
  base_event_size = sum(true_split_sizes)

  # Base dist with no batch shape (will require broadcasting).
  loc = tf.random.normal([base_event_size], seed=test_util.test_seed())
  scale_diag = tf.exp(
      tf.random.normal([base_event_size], seed=test_util.test_seed()))
  base_dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

  # Only the last part's bijector carries a batch shape.
  per_part = tfb.JointMap([
      tfb.Identity(),
      tfb.Identity(),
      tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]])),
  ])
  bijector = tfb.Chain([per_part, tfb.Split(true_split_sizes, axis=-1)])
  split_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertAllEqual(split_dist.batch_shape, batch_shape)

  # Because one branch of the split has batch shape, TD should feed batches
  # of base samples *into* the split, so the batch shape propagates to all
  # branches.
  sample = split_dist.sample(seed=test_util.test_seed())
  expected_shapes = [batch_shape + [size] for size in true_split_sizes]
  self.assertAllEqualNested(
      tf.nest.map_structure(lambda part: part.shape, sample),
      expected_shapes)
def testCovarianceNotImplemented(self):
  """`covariance` raises for each unsupported kind of bijector."""
  mvn = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.])
  # (expected message regex, bijector factory) pairs.
  cases = [
      # Non-affine bijector.
      ('`covariance` is not implemented', tfb.Exp),
      # Non-injective bijector.
      ('`covariance` is not implemented', tfb.AbsoluteValue),
      # Non-vector event shape.
      ('`covariance` is only implemented',
       lambda: tfb.Reshape(event_shape_out=[2, 1], event_shape_in=[2])),
      # Multipart bijector.
      ('`covariance` is only implemented', lambda: tfb.Split(2)),
  ]
  for message, make_bijector in cases:
    with self.assertRaisesRegex(NotImplementedError, message):
      tfd.TransformedDistribution(
          distribution=mvn, bijector=make_bijector()).covariance()
def testAssertRaisesWrongNumberOfOutputs(self):
  """Inverting more parts than there are split sizes must fail."""
  split_sizes = self.build_input([5, 3, -1])
  # Four parts for a three-way split.
  y = [np.random.rand(2, i) for i in [5, 3, 1, 2]]
  bijector = tfb.Split(split_sizes, validate_args=True)
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'does not match the number of splits'):
    self.evaluate(bijector.inverse(y))
def testAssertRaisesWrongNumSplits(self):
  """Inverting a list whose length differs from `num_splits` must fail."""
  num_splits = 4
  # Only 3 parts are supplied to a bijector expecting 4.
  y = [np.random.rand(2, 3)] * 3
  bijector = tfb.Split(num_splits, validate_args=True)
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'does not match the number of splits'):
    self.evaluate(bijector.inverse(y))
def testCompositeTensor(self):
  """A flattened-then-repacked Split matches the original inside tf.function."""
  bijector = tfb.Split(self.build_input([1, 2, 2]), validate_args=True)
  x = tf.ones([3, 2, 5])
  components = tf.nest.flatten(bijector, expand_composites=True)
  rebuilt = tf.nest.pack_sequence_as(
      bijector, components, expand_composites=True)

  @tf.function
  def forward_with(bij):
    return bij.forward(x)

  self.assertAllClose(bijector.forward(x), forward_with(rebuilt))
def test_transform_parts_to_vector(self, known_split_sizes):
  """Concatenates a multipart distribution into a vector-valued one."""
  batch_shape = [4, 2]
  true_split_sizes = [1, 3, 2]

  # Create a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.exp(
              tf.random.normal(batch_shape + [size], seed=seed()))),
      true_split_sizes)
  base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform to a vector-valued distribution by concatenating the parts.
  bijector = tfb.Invert(tfb.Split(known_split_sizes, axis=-1))

  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Overriding the batch shape'):
    tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])
  with self.assertRaisesRegex(ValueError, 'Overriding the event shape'):
    tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

  concat_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertAllEqual(concat_dist.event_shape, [sum(true_split_sizes)])
  self.assertAllEqual(self.evaluate(concat_dist.event_shape_tensor()),
                      [sum(true_split_sizes)])
  self.assertAllEqual(concat_dist.batch_shape, batch_shape)
  self.assertAllEqual(self.evaluate(concat_dist.batch_shape_tensor()),
                      batch_shape)

  # Since the Split bijector has (constant) unit Jacobian, the transformed
  # entropy and mean/mode should match the base entropy and (split) base
  # mean/mode.
  self.assertAllEqual(*self.evaluate(
      (base_dist.entropy(), concat_dist.entropy())))
  self.assertAllEqual(*self.evaluate(
      (concat_dist.mean(), bijector.forward(base_dist.mean()))))
  self.assertAllEqual(*self.evaluate(
      (concat_dist.mode(), bijector.forward(base_dist.mode()))))

  # Since the Split bijector's log-det-Jacobian is zero (its Jacobian
  # determinant is one), the transformed `log_prob` and `prob` should match
  # the base distribution.
  sample_shape = [3]
  x = base_dist.sample(sample_shape, seed=seed())
  y = bijector.forward(x)
  for attr in ('log_prob', 'prob'):
    base_attr = getattr(base_dist, attr)(x)
    concat_attr = getattr(concat_dist, attr)(y)
    self.assertAllClose(*self.evaluate((base_attr, concat_attr)))

  # Test that `.sample()` works and returns a result of the expected structure
  # and shape.
  y_sampled = concat_dist.sample(sample_shape, seed=seed())
  self.assertAllEqual(y.shape, y_sampled.shape)
def test_transform_vector_to_parts(self, known_split_sizes):
  """Splits a vector-valued distribution into a multipart distribution."""
  batch_shape = [4, 2]
  true_split_sizes = [1, 3, 2]
  base_event_size = sum(true_split_sizes)
  base_dist = tfd.MultivariateNormalDiag(
      loc=tf.random.normal(batch_shape + [base_event_size],
                           seed=test_util.test_seed()),
      scale_diag=tf.exp(
          tf.random.normal(batch_shape + [base_event_size],
                           seed=test_util.test_seed())))

  bijector = tfb.Split(known_split_sizes, axis=-1)
  split_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertRegex(
      str(split_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(split_dist.name))

  expected_event_shape = [np.array([s]) for s in true_split_sizes]
  output_event_shape = [np.array(s) for s in split_dist.event_shape]
  self.assertAllEqual(output_event_shape, expected_event_shape)
  self.assertAllEqual(self.evaluate(split_dist.event_shape_tensor()),
                      expected_event_shape)
  self.assertAllEqual(split_dist.batch_shape, batch_shape)
  self.assertAllEqual(self.evaluate(split_dist.batch_shape_tensor()),
                      batch_shape)

  # Since the Split bijector has (constant) unit Jacobian, the transformed
  # entropy and mean/mode should match the base entropy and (split) base
  # mean/mode.
  self.assertAllEqual(*self.evaluate((base_dist.entropy(),
                                      split_dist.entropy())))
  self.assertAllEqualNested(
      *self.evaluate((split_dist.mean(), bijector.forward(base_dist.mean()))))
  self.assertAllEqualNested(
      *self.evaluate((split_dist.mode(), bijector.forward(base_dist.mode()))))

  # Since the Split bijector's log-det-Jacobian is zero (its Jacobian
  # determinant is one), the transformed `log_prob` and `prob` should match
  # the base distribution.
  sample_shape = [3]
  x = base_dist.sample(sample_shape, seed=test_util.test_seed())
  y = bijector.forward(x)
  for attr in ('log_prob', 'prob'):
    split_attr = getattr(split_dist, attr)(y)
    base_attr = getattr(base_dist, attr)(x)
    self.assertAllClose(*self.evaluate((base_attr, split_attr)), rtol=1e-5)

  # Test that `.sample()` works and returns a result of the expected structure
  # and shape. (`part` avoids reusing the name of the sample tensor `x`.)
  y_sampled = split_dist.sample(sample_shape, seed=test_util.test_seed())
  self.assertAllEqual([part.shape for part in y],
                      [part.shape for part in y_sampled])
def test_invert_str_and_repr_match_expected(self):
  """Checks `str`/`repr` of an inverted Split, including the nested repr."""
  bij = tfb.Invert(tfb.Split([3, 4, 2]))
  expected_str = (
      'tfp.bijectors.Invert("invert_split", batch_shape=[], '
      'forward_min_event_ndims=[1, 1, 1], inverse_min_event_ndims=1, '
      'bijector=Split)')
  expected_repr = (
      "<tfp.bijectors.Invert 'invert_split' batch_shape=[] "
      "forward_min_event_ndims=[1, 1, 1] inverse_min_event_ndims=1 "
      "dtype_x=[?, ?, ?] dtype_y=? "
      "bijector=<tfp.bijectors.Split 'split' batch_shape=[] "
      "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
      "dtype_x=? dtype_y=[?, ?, ?]>>")
  self.assertContainsInOrder([expected_str], str(bij))
  self.assertContainsInOrder([expected_repr], repr(bij))
def test_composition_str_and_repr_match_expected_dynamic_shape(self):
  """Checks `str`/`repr` of Chains whose parameters have dynamic shape."""

  def check(bij, str_fragments, repr_fragments):
    self.assertContainsInOrder(str_fragments, str(bij))
    self.assertContainsInOrder(repr_fragments, repr(bij))

  simple_chain = tfb.Chain([
      tfb.Exp(),
      tfb.Shift(self._tensor([1., 2.])),
      tfb.SoftmaxCentered(),
  ])
  check(
      simple_chain,
      ['tfp.bijectors.Chain(',
       'min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])'],
      ['<tfp.bijectors.Chain ',
       ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
        'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
       '>, <tfp.bijectors.Shift',
       '>, <tfp.bijectors.SoftmaxCentered',
       '>]>'])

  structured_chain = tfb.Chain([
      tfb.JointMap({
          'a': tfb.Exp(),
          'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.])),
      }),
      tfb.Restructure({'a': 0, 'b': 1}, [0, 1]),
      tfb.Split(2),
      tfb.Invert(tfb.SoftmaxCentered()),
  ])
  check(
      structured_chain,
      ['tfp.bijectors.Chain(',
       ('forward_min_event_ndims=1, '
        'inverse_min_event_ndims={a: 1, b: 1}, '
        'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
        'Restructure, Split, Invert(SoftmaxCentered)])')],
      ['<tfp.bijectors.Chain ',
       ('batch_shape=? forward_min_event_ndims=1 '
        "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
        "dtype_y={'a': ?, 'b': float32} "
        "bijectors=[<tfp.bijectors.JointMap "),
       '>, <tfp.bijectors.Restructure',
       '>, <tfp.bijectors.Split',
       '>, <tfp.bijectors.Invert',
       '>]>'])
def _testBijector(  # pylint: disable=invalid-name
    self, num_or_size_splits, expected_split_sizes, shape_in, axis):
  """Sanity-checks Split's forward, inverse, event shapes and log-det-Jacobians.

  Args:
    num_or_size_splits: Split spec, passed through `self.build_input`.
    expected_split_sizes: Size of each output part along `axis`.
    shape_in: Input shape (indexable and copyable with `[:]`).
    axis: Split axis; `abs(axis)` is used as the event ndims.
  """
  split_arg = self.build_input(num_or_size_splits)
  bijector = tfb.Split(split_arg, axis=axis, validate_args=True)
  self.assertStartsWith(bijector.name, 'split')

  x = np.random.rand(*shape_in)
  y = tf.split(x, split_arg, axis=axis)
  self.assertAllClose(self.evaluate(y), self.evaluate(bijector.forward(x)),
                      atol=0., rtol=1e-2)
  self.assertAllClose(x, self.evaluate(bijector.inverse(y)),
                      atol=0., rtol=1e-4)

  def part_shape(size):
    # `shape_in` with its `axis` entry replaced by `size`.
    shape = shape_in[:]
    shape[axis] = size
    return self.build_input(shape)

  shape_out = [part_shape(d) for d in expected_split_sizes]
  recovered_shape_in = self.evaluate(
      bijector.inverse_event_shape_tensor(shape_out))
  self.assertAllEqual(recovered_shape_in, shape_in)

  shape_in_input = self.build_input(shape_in)
  computed_shape_out = self.evaluate(
      bijector.forward_event_shape_tensor(shape_in_input))
  self.assertAllEqual(computed_shape_out, self.evaluate(shape_out))

  # Both log-det-Jacobians are expected to be exactly zero.
  event_ndims = abs(axis)
  per_part_ndims = [event_ndims for _ in expected_split_sizes]
  ildj = bijector.inverse_log_det_jacobian(y, event_ndims=per_part_ndims)
  self.assertEqual(0., self.evaluate(ildj))
  fldj = bijector.forward_log_det_jacobian(x, event_ndims=event_ndims)
  self.assertEqual(0., self.evaluate(fldj))
def test_slice_transformed_distribution_with_chain(self):
  """Slicing a TD over a batched multipart Chain slices every part."""
  base = tfd.MultivariateNormalDiag(
      loc=tf.zeros([4]), scale_diag=tf.ones([1, 4]))
  bijector = tfb.Chain([
      tfb.JointMap([tfb.Identity(), tfb.Shift(tf.ones([4, 3, 2]))]),
      tfb.Split(2),
      tfb.ScaleMatvecDiag(tf.ones([5, 1, 3, 4])),
      tfb.Exp(),
  ])
  dist = tfd.TransformedDistribution(distribution=base, bijector=bijector)

  def sample_shapes(d):
    return tf.nest.map_structure(lambda part: part.shape,
                                 d.sample(seed=test_util.test_seed()))

  self.assertAllEqual(dist.batch_shape_tensor(), [5, 4, 3])
  self.assertAllEqualNested(sample_shapes(dist),
                            [[5, 4, 3, 2], [5, 4, 3, 2]])

  sliced = dist[tf.newaxis, ..., 0, :, :-1]
  self.assertAllEqual(sliced.batch_shape_tensor(), [1, 4, 2])
  self.assertAllEqualNested(sample_shapes(sliced),
                            [[1, 4, 2, 2], [1, 4, 2, 2]])
def testAssertRaisesUnknownNumSplits(self):
  """Construction fails when the number of split sizes isn't static."""
  # `shape=[None]` hides how many split sizes there are.
  split_sizes = tf1.placeholder_with_default([-1, 2, 1], shape=[None])
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, 'must have a statically-known number of elements'):
    tfb.Split(num_or_size_splits=split_sizes, validate_args=True)
class BijectorBatchShapesTest(test_util.TestCase):
  """Checks that bijector batch-shape APIs agree with each other and outputs."""

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('scale_matvec', lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('reshape', lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self, bijector_fn,
                                             override_x_event_ndims=None):
    """Static/dynamic batch shapes agree and match forward/inverse outputs."""
    bijector = bijector_fn()
    if override_x_event_ndims is not None:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)
    else:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims

    def assert_parts_broadcast_to(parts, event_ndims, batch_shape):
      # Each part's batch shape must broadcast to exactly `batch_shape`.
      for part, ndims in zip(tf.nest.flatten(parts),
                             tf.nest.flatten(event_ndims)):
        part_batch_shape = ps.shape(part)[:ps.rank(part) - ndims]
        self.assertAllEqual(
            batch_shape, ps.broadcast_shape(batch_shape, part_batch_shape))

    def as_int64(structure):
      return tf.nest.map_structure(np.int64, structure)

    # All ways of calculating the batch shape should yield the same result.
    static_from_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    static_from_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(static_from_x, static_from_y)

    dynamic_from_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    dynamic_from_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(dynamic_from_x, dynamic_from_y)
    self.assertAllEqual(dynamic_from_x, static_from_x)

    # Check that we're robust to integer type.
    dynamic64_from_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=as_int64(x_event_ndims))
    dynamic64_from_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=as_int64(y_event_ndims))
    self.assertAllEqual(dynamic64_from_x, dynamic64_from_y)
    self.assertAllEqual(dynamic64_from_x, static_from_x)

    # Pushing a value through the bijector should return Tensor(s) with the
    # expected batch shape, which should also be the shape of the fldj.
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    assert_parts_broadcast_to(ys, y_event_ndims, static_from_y)
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(static_from_y, ps.shape(fldj))

    # Also check the inverse case. (`tf.identity` presumably bypasses the
    # bijector cache so the inverse is actually computed.)
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    assert_parts_broadcast_to(xs, x_event_ndims, static_from_x)
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(static_from_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Defaults to min event ndims; rejects passing both ndims arguments."""
    bijector = bijector_fn()
    # If no `event_ndims` is passed, should assume min_event_ndims.
    self.assertAllEqual(bijector.experimental_batch_shape(), [1])
    self.assertAllEqual(bijector.experimental_batch_shape_tensor(), [1])
    for batch_shape_fn in (bijector.experimental_batch_shape,
                           bijector.experimental_batch_shape_tensor):
      with self.assertRaisesRegex(
          ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
        batch_shape_fn(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Broadcasting parameters to a larger batch shape updates each param."""
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
    batch_shape = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    self.assertAllEqual(
        broadcast_bijector.experimental_batch_shape_tensor(
            x_event_ndims=x_event_ndims),
        ps.broadcast_shape(batch_shape, new_batch_shape))

    # Check that all params have the expected batch shape.
    broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def expected_param_batch_shape(param, shape):
      if (isinstance(param, tfb.Invert)
          and not param.bijector._params_event_ndims()):
        # Can't broadcast a bijector that doesn't itself have params.
        return shape
      return ps.broadcast_shape(shape, new_batch_shape)

    param_values = {name: getattr(bijector, name)
                    for name in param_batch_shapes}
    self.assertAllEqualNested(
        broadcast_param_batch_shapes,
        tf.nest.map_structure(expected_param_batch_shape,
                              param_values, param_batch_shapes))
class BatchShapeInferenceTests(test_util.TestCase):
  """Tests for batch-shape inference and parameter broadcasting."""

  @parameterized.named_parameters(
      {'testcase_name': '_trivial',
       'value_fn': lambda: tfd.Normal(loc=0., scale=1.),
       'expected_batch_shape': []},
      {'testcase_name': '_simple_tensor_broadcasting',
       'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=[0., 0.],
           scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
       'expected_batch_shape': [2]},
      {'testcase_name': '_rank_deficient_tensor_broadcasting',
       'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=0.,
           scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
       'expected_batch_shape': [2]},
      {'testcase_name': '_mixture_same_family',
       'value_fn': lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
           mixture_distribution=tfd.Categorical(
               logits=[[[1., 2., 3.],
                        [4., 5., 6.]]]),
           components_distribution=tfd.Normal(loc=0.,
                                              scale=[[[1., 2., 3.],
                                                      [4., 5., 6.]]])),
       'expected_batch_shape': [1, 2]},
      {'testcase_name': '_deeply_nested',
       'value_fn': lambda: tfd.Independent(  # pylint: disable=g-long-lambda
           tfd.Independent(
               tfd.Independent(
                   tfd.Independent(
                       tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                       reinterpreted_batch_ndims=2),
                   reinterpreted_batch_ndims=0),
               reinterpreted_batch_ndims=1),
           reinterpreted_batch_ndims=1),
       'expected_batch_shape': [1, 1, 1, 1]})
  def test_batch_shape_inference_is_correct(
      self, value_fn, expected_batch_shape):
    """Dynamic batch shape equals, and static is compatible with, expected."""
    value = value_fn()  # Defer construction until we're in the right graph.
    self.assertAllEqual(expected_batch_shape, value.batch_shape_tensor())
    static_batch_shape = value.batch_shape
    self.assertIsInstance(static_batch_shape, tf.TensorShape)
    self.assertTrue(
        static_batch_shape.is_compatible_with(expected_batch_shape))

  def assert_all_parameters_have_full_batch_shape(
      self, dist, expected_batch_shape):
    """Asserts `dist` and every one of its parameters has the batch shape."""
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())
    per_param = batch_shape_lib.batch_shape_parts(dist)
    for name in per_param:
      self.assertAllEqual(expected_batch_shape, per_param[name])

  @parameterized.named_parameters(
      {'testcase_name': '_trivial',
       'dist_fn': lambda: tfd.Normal(loc=0., scale=1.)},
      {'testcase_name': '_simple_tensor_broadcasting',
       'dist_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=[0., 0.],
           scale_diag=[[1., 1.], [1., 1.]])},
      {'testcase_name': '_rank_deficient_tensor_broadcasting',
       'dist_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=0.,
           scale_diag=[[1., 1.], [1., 1.]])},
      {'testcase_name': '_deeply_nested',
       'dist_fn': lambda: tfd.Independent(  # pylint: disable=g-long-lambda
           tfd.Independent(
               tfd.Independent(
                   tfd.Independent(
                       tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                       reinterpreted_batch_ndims=2),
                   reinterpreted_batch_ndims=0),
               reinterpreted_batch_ndims=1),
           reinterpreted_batch_ndims=1)},
      {'testcase_name': '_transformed_dist_simple',
       'dist_fn': lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
           tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
           tfb.Scale(scale=[2., 3., 4.]))},
      {'testcase_name': '_transformed_dist_with_chain',
       'dist_fn': lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
           tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
           tfb.Shift(-4.)(tfb.Scale(scale=[2., 3., 4.])))},
      {'testcase_name': '_transformed_dist_multipart_nested',
       'dist_fn': lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
           tfd.TransformedDistribution(
               tfd.TransformedDistribution(
                   tfd.MultivariateNormalDiag(tf.zeros([4, 6]), tf.ones([6])),
                   tfb.Split([3, 3])),
               tfb.JointMap([tfb.Identity(), tfb.Reshape([3, 1])])),
           tfb.JointMap([tfb.Scale(scale=[2., 3., 4.]), tfb.Shift(1.)]))}
  )
  def test_batch_broadcasting(self, dist_fn):
    """Broadcast parameters carry the full batch shape, own or expanded."""
    dist = dist_fn()

    # Broadcast to the distribution's own batch shape.
    self_broadcast = dist._broadcast_parameters_with_batch_shape(
        dist.batch_shape)
    self.assert_all_parameters_have_full_batch_shape(
        self_broadcast,
        expected_batch_shape=self_broadcast.batch_shape_tensor())

    # Broadcast to a batch shape with two extra leading dimensions.
    expanded_batch_shape = ps.concat([[7, 4], dist.batch_shape], axis=0)
    expanded_dist = dist.copy(
        **batch_shape_lib.broadcast_parameters_with_batch_shape(
            dist, expanded_batch_shape))
    self.assert_all_parameters_have_full_batch_shape(
        expanded_dist, expected_batch_shape=expanded_batch_shape)
def _testAssertRaisesNegativeSplitSizes(self):
  """Split rejects a negative split size such as -2.

  The error may surface either at construction or when `forward` runs, so
  both happen inside the assertion context.
  """
  sizes = self.build_input([-2, 3, 5])
  with self.assertRaisesError('must be either non-negative integers or'):
    self.evaluate(
        tfb.Split(sizes, validate_args=True).forward(tf.zeros((4, 10))))
def _testAssertRaisesMultipleUnknownSplitSizes(self):
  """Split rejects split sizes that contain two `-1` entries.

  The error may surface either at construction or when `forward` runs, so
  both happen inside the assertion context.
  """
  sizes = self.build_input([-1, 4, -1, 8])
  with self.assertRaisesError('must have at most one'):
    self.evaluate(
        tfb.Split(sizes, validate_args=True).forward(tf.zeros((3, 14))))
class MarkovChainBijectorTest(test_util.TestCase):
  """Tests `tfd.MarkovChain.experimental_default_event_space_bijector`."""

  # Each parameter case supplies:
  #   prior_fn: zero-argument callable building the distribution that is
  #     passed to `tfd.MarkovChain` as `initial_state_prior`.
  #   transition_fn: two-argument callable passed to `tfd.MarkovChain` as
  #     `transition_fn`; every case here ignores the first argument and builds
  #     the next-state distribution from the previous state `x`.
  # The cases cover deterministic priors and/or transitions, vector-event
  # states (MultivariateNormalDiag), joint-distribution states, and a
  # MarkovChain used as the prior of another MarkovChain (`nested_chain`).
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      dict(testcase_name='deterministic_prior',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
      dict(testcase_name='deterministic_transition',
           prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='fully_deterministic',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='mvn_diag',
           prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                        scale_diag=[1.])),
           transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
      dict(testcase_name='docstring_dirichlet',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.Dirichlet([1., 1.])}),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               {
                   'probs':
                       tfd.MultivariateNormalDiag(loc=x['probs'],
                                                  scale_diag=[0.1, 0.1])
               },
               # Treat all leading dims of the previous state as batch dims.
               batch_ndims=ps.rank(x['probs']))),
      dict(testcase_name='uniform_step',
           prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
           transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
      dict(testcase_name='joint_distribution',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=2,
               model={
                   'a': tfd.Gamma(tf.zeros([5]), 1.),
                   # 'b' has a [4, 3] Normal event reshaped to [2, 3, 2],
                   # with its scale taken from 'a'.
                   'b': lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                               event_shape_out=[2, 3, 2])
                                   (tfd.Independent(tfd.Normal(
                                       loc=tf.zeros([5, 4, 3]),
                                       scale=a[..., tf.newaxis, tf.newaxis]),
                                                    reinterpreted_batch_ndims=2)))
               }),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=ps.rank_from_shape(x['a'].shape),
               model={
                   'a': tfd.Normal(loc=x['a'], scale=1.),
                   'b': lambda a: tfd.Deterministic(x['b'] + a[
                       ..., tf.newaxis, tf.newaxis, tf.newaxis])
               })),
      dict(testcase_name='nested_chain',
           # The prior is itself a MarkovChain whose states are split into
           # two parts by `tfb.Split(2)`.
           prior_fn=lambda: tfd.MarkovChain(
               initial_state_prior=tfb.Split(2)(
                   tfd.MultivariateNormalDiag(0., [1., 2.])),
               transition_fn=lambda _, x: tfb.Split(2)(
                   tfd.MultivariateNormalDiag(x[0], [1., 2.])),
               num_steps=6),
           transition_fn=(
               lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                   [
                       tfd.MultivariateNormalDiag(x[0], [1.]),
                       tfd.MultivariateNormalDiag(x[1], [1.])
                   ],
                   batch_ndims=ps.rank(x[0])))))
  # pylint: enable=g-long-lambda
  def test_default_bijector(self, prior_fn, transition_fn):
    """Checks the chain's default event-space bijector for one case.

    Builds a 7-step `MarkovChain` from the case's `prior_fn` and
    `transition_fn`, draws a sample `y`, and verifies that the bijector
    returned by `experimental_default_event_space_bijector()` is consistent
    with the chain: batch shape, inverse-then-forward round trip, min event
    ndims, inverse/forward log-det-Jacobians, and event-shape conversions
    (both static shapes and shape tensors).

    Args:
      prior_fn: Zero-argument callable returning the initial-state prior.
      transition_fn: Transition callable passed through to `tfd.MarkovChain`.
    """
    chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                            transition_fn=transition_fn,
                            num_steps=7)
    y = self.evaluate(chain.sample(seed=test_util.test_seed()))
    bijector = chain.experimental_default_event_space_bijector()

    self.assertAllEqual(chain.batch_shape_tensor(),
                        bijector.experimental_batch_shape_tensor())

    # Round trip: `inverse` followed by `forward` should recover the sample.
    x = bijector.inverse(y)
    yy = bijector.forward(tf.nest.map_structure(
        tf.identity, x))  # Bypass bijector cache.
    self.assertAllCloseNested(y, yy)

    # Event rank of each part of the chain's (possibly nested) event.
    chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                              chain.event_shape_tensor())
    self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                              chain_event_ndims)

    ildj = bijector.inverse_log_det_jacobian(
        tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
        event_ndims=chain_event_ndims)
    # NOTE(review): presumably a constant-Jacobian bijector may return an
    # ildj that is not broadcast to the full batch shape, hence the guard —
    # confirm.
    if not bijector.is_constant_jacobian:
      self.assertAllEqual(ildj.shape, chain.batch_shape)
    fldj = bijector.forward_log_det_jacobian(
        tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
        event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
    # Forward and inverse log-det-Jacobians must be negatives of each other.
    self.assertAllClose(ildj, -fldj)

    # Verify that event shapes are passed through and flattened/unflattened
    # correctly.
    inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
    # For each part of `x`, keep the trailing `nd` dims of its shape, where
    # `nd` is that part's forward min event ndims.
    x_event_shapes = tf.nest.map_structure(
        lambda t, nd: t.shape[ps.rank(t) - nd:],
        x,
        bijector.forward_min_event_ndims)
    self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
    forward_event_shapes = bijector.forward_event_shape(
        inverse_event_shapes)
    self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

    # Verify that the outputs of other methods have the correct structure.
    inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
        chain.event_shape_tensor())
    self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
    forward_event_shape_tensors = bijector.forward_event_shape_tensor(
        inverse_event_shape_tensors)
    self.assertAllEqualNested(forward_event_shape_tensors,
                              chain.event_shape_tensor())
def testAssertRaisesNumSplitsNonDivisible(self):
  """`tfb.Split` must fail when the split axis is not divisible evenly.

  The input has size 5 along axis -2, which cannot be divided into
  `num_splits = 3` equal parts, so `forward` is expected to raise.
  """
  num_splits = 3
  # Only the shape matters for this check, so use deterministic zeros rather
  # than unseeded random values (same shape and float64 dtype as before).
  x = np.zeros((4, 5, 6))
  bijector = tfb.Split(num_splits, axis=-2, validate_args=True)
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from `unittest` in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'number of splits'):
    self.evaluate(bijector.forward(x))
def testAssertRaisesNonVectorSplitSizes(self):
  """`tfb.Split` must reject `split_sizes` that is neither an integer nor 1-D.

  `[[1, 2, 2]]` is a rank-2 value, so construction is expected to raise.
  """
  split_sizes = self.build_input([[1, 2, 2]])
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from `unittest` in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'must be an integer or 1-D'):
    tfb.Split(split_sizes, validate_args=True)