def testMinEventNdimsChain(self):
  """Checks `Chain` min-event-ndims for scalar, vector and mixed stacks."""

  def make_scale():
    return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

  def check(parts, expected_ndims):
    # In every case below forward and inverse ndims are expected to agree.
    chained = tfb.Chain(parts)
    self.assertEqual(expected_ndims, chained.forward_min_event_ndims)
    self.assertEqual(expected_ndims, chained.inverse_min_event_ndims)

  check([tfb.Exp(), tfb.Exp(), tfb.Exp()], 0)
  check([make_scale(), make_scale(), make_scale()], 1)
  check([tfb.Exp(), make_scale()], 1)
  check([make_scale(), tfb.Exp()], 1)
  check([make_scale(), tfb.Exp(), tfb.Softplus(), make_scale()], 1)
def testName(self):
  """The `Blockwise` name starts with the joined names of its parts."""
  parts = [
      tfb.Exp(),
      tfb.Softplus(),
      tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]),
  ]
  joined = tfb.Blockwise(bijectors=parts, block_sizes=[2, 1, 3])
  expected_prefix = 'blockwise_of_exp_and_softplus_and_scale_matvec_diag'
  self.assertStartsWith(joined.name, expected_prefix)
def testRaisesWhenSingular(self):
  """A zero on the scale diagonal raises when `validate_args=True`."""
  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed;
  # `assertRaisesRegex` is the supported spelling with the same behavior.
  with self.assertRaisesRegex(
      Exception, 'Singular operator: Diagonal contained zero values'):
    bijector = tfb.ScaleMatvecDiag(
        # Has zero on the diagonal.
        scale_diag=[0., 1],
        validate_args=True)
    self.evaluate(bijector.forward([1., 1.]))
def testBatch(self, is_static):
  """Forward, inverse and ILDJ with a batched diagonal scale."""
  # Corresponds to 1 2x2 matrix, with twos on the diagonal.
  bijector = tfb.ScaleMatvecDiag(scale_diag=[[2., 2]])
  inputs = self.maybe_static([[[1., 1]]], is_static)

  forward_out = bijector.forward(inputs)
  self.assertAllClose([[[2., 2]]], forward_out)

  inverse_out = bijector.inverse(inputs)
  self.assertAllClose([[[0.5, 0.5]]], inverse_out)

  ildj = bijector.inverse_log_det_jacobian(inputs, event_ndims=1)
  self.assertAllClose([-np.log(4)], ildj)
def testEventNdimsIsOptional(self):
  """Log-det-Jacobians of a `Chain` can be computed without `event_ndims`."""
  diag = np.array([1., 2., 3.], dtype=np.float32)
  chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=diag), tfb.Exp()])

  x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
  y = [1., 4., 9.]

  expected_fldj = np.log(6, dtype=np.float32) + np.sum(x)
  actual_fldj = self.evaluate(chain.forward_log_det_jacobian(x))
  self.assertAllClose(expected_fldj, actual_fldj)

  expected_ildj = -np.log(6, dtype=np.float32) - np.sum(x)
  actual_ildj = self.evaluate(chain.inverse_log_det_jacobian(y))
  self.assertAllClose(expected_ildj, actual_ildj)
def testMinEventNdimsShapeChangingRemoveDims(self):
  """Checks `Chain` min-event-ndims when parts are `ShapeChanging(3, 0)`."""

  def check(parts, forward_ndims, inverse_ndims):
    chained = tfb.Chain(parts)
    self.assertEqual(forward_ndims, chained.forward_min_event_ndims)
    self.assertEqual(inverse_ndims, chained.inverse_min_event_ndims)

  check([ShapeChanging(3, 0)], 3, 0)
  check(
      [ShapeChanging(3, 0), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])],
      3, 0)
  check(
      [tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), ShapeChanging(3, 0)],
      4, 1)
  check([ShapeChanging(3, 0), ShapeChanging(3, 0)], 6, 0)
def testMinEventNdimsChain(self):
  """Validates `Chain` min-event-ndims for several bijector stacks."""

  def diag():
    return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

  def cases():
    # Yielded lazily so each stack is built right before it is validated.
    yield [tfb.Exp(), tfb.Exp(), tfb.Exp()], 0
    yield [diag(), diag(), diag()], 1
    yield [tfb.Exp(), diag()], 1
    yield [diag(), tfb.Exp(), tfb.Softplus(), diag()], 1

  for stack, ndims in cases():
    self._validateChainMinEventNdims(
        bijectors=stack,
        forward_min_event_ndims=ndims,
        inverse_min_event_ndims=ndims)
def testChainExpAffine(self):
  """Checks values and log-det-Jacobians of `Exp` chained after a scale."""
  scale_diag = np.array([1., 2., 3.], dtype=np.float32)
  scale = tfb.ScaleMatvecDiag(scale_diag=scale_diag)
  chain = tfb.Chain([tfb.Exp(), scale])

  x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
  y = [1., 4., 27.]

  self.assertAllClose(y, self.evaluate(chain.forward(x)))
  self.assertAllClose(x, self.evaluate(chain.inverse(y)))

  expected_fldj = np.log(6, dtype=np.float32) + np.sum(scale_diag * x)
  actual_fldj = self.evaluate(
      chain.forward_log_det_jacobian(x, event_ndims=1))
  self.assertAllClose(expected_fldj, actual_fldj)

  expected_ildj = -np.log(6, dtype=np.float32) - np.sum(scale_diag * x)
  actual_ildj = self.evaluate(
      chain.inverse_log_det_jacobian(y, event_ndims=1))
  self.assertAllClose(expected_ildj, actual_ildj)
def testMinEventNdimsShapeChangingAddDims(self):
  """Validates `Chain` min-event-ndims with default `ShapeChanging` parts."""

  def cases():
    # Yielded lazily so each stack is built right before it is validated.
    yield [ShapeChanging()], 0, 3
    yield (
        [ShapeChanging(), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])],
        1,
        4,
    )
    yield [ShapeChanging(), ShapeChanging()], 0, 6

  for stack, forward_ndims, inverse_ndims in cases():
    self._validateChainMinEventNdims(
        bijectors=stack,
        forward_min_event_ndims=forward_ndims,
        inverse_min_event_ndims=inverse_ndims)
def test_composition_str_and_repr_match_expected_dynamic_shape(self):
  """`str`/`repr` of a `Chain` contain the expected pieces, in order."""
  # A chain whose parts all take a single (non-nested) input.
  simple_chain = tfb.Chain([
      tfb.Exp(),
      tfb.Shift(self._tensor([1., 2.])),
      tfb.SoftmaxCentered(),
  ])
  simple_str_pieces = [
      'tfp.bijectors.Chain(',
      ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])'),
  ]
  self.assertContainsInOrder(simple_str_pieces, str(simple_chain))
  simple_repr_pieces = [
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
       'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
      '>, <tfp.bijectors.Shift',
      '>, <tfp.bijectors.SoftmaxCentered',
      '>]>',
  ]
  self.assertContainsInOrder(simple_repr_pieces, repr(simple_chain))

  # A chain ending in a dict-structured `JointMap`.
  structured_chain = tfb.Chain([
      tfb.JointMap({
          'a': tfb.Exp(),
          'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.])),
      }),
      tfb.Restructure({'a': 0, 'b': 1}, [0, 1]),
      tfb.Split(2),
      tfb.Invert(tfb.SoftmaxCentered()),
  ])
  structured_str_pieces = [
      'tfp.bijectors.Chain(',
      ('forward_min_event_ndims=1, '
       'inverse_min_event_ndims={a: 1, b: 1}, '
       'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
       'Restructure, Split, Invert(SoftmaxCentered)])'),
  ]
  self.assertContainsInOrder(structured_str_pieces, str(structured_chain))
  structured_repr_pieces = [
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 '
       "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
       "dtype_y={'a': ?, 'b': float32} "
       "bijectors=[<tfp.bijectors.JointMap "),
      '>, <tfp.bijectors.Restructure',
      '>, <tfp.bijectors.Split',
      '>, <tfp.bijectors.Invert',
      '>]>',
  ]
  self.assertContainsInOrder(structured_repr_pieces, repr(structured_chain))
def testBijectiveAndFinite(self):
  """Runs the bijective-and-finite check on a three-part `Blockwise`."""
  parts = [
      tfb.Exp(),
      tfb.Softplus(),
      tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]),
  ]
  blockwise = tfb.Blockwise(bijectors=parts, block_sizes=[2, 1, 3])

  x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
  x = tf1.placeholder_with_default(x, shape=x.shape)
  # Identity to break the caching.
  blockwise_y = tf.identity(blockwise.forward(x))

  x_value = self.evaluate(x)
  y_value = self.evaluate(blockwise_y)
  bijector_test_util.assert_bijective_and_finite(
      blockwise,
      x=x_value,
      y=y_value,
      eval_func=self.evaluate,
      event_ndims=1)
def testNoBatch(self, is_static):
  """Unbatched diagonal scale applied to one vector and to two vectors."""
  # Corresponds to scale = [[2., 0], [0, 1.]]
  bijector = tfb.ScaleMatvecDiag(scale_diag=[2., 1])

  # Each entry is (raw input, expected forward, expected inverse); every
  # vector undergoes matmul(sigma, x).
  scenarios = [
      # A single 2-vector.
      ([1., 1], [2., 1], [0.5, 1]),
      # x is a 2-batch of 2-vectors.
      # The first vector is [1, 1], the second is [-1, -1].
      ([[1., 1], [-1., -1]], [[2., 1], [-2., -1]], [[0.5, 1], [-0.5, -1]]),
  ]
  for raw_x, expected_forward, expected_inverse in scenarios:
    x = self.maybe_static(raw_x, is_static)
    self.assertAllClose(expected_forward, bijector.forward(x))
    self.assertAllClose(expected_inverse, bijector.inverse(x))
    self.assertAllClose(
        -np.log(2.),
        bijector.inverse_log_det_jacobian(x, event_ndims=1))
def testBijector(self):
  """`Invert` swaps forward/inverse and their log-det-Jacobians."""
  wrapped_bijectors = [
      tfb.Identity(),
      tfb.Exp(),
      tfb.ScaleMatvecDiag(scale_diag=[2., 3.]),
      tfb.Softplus(),
      tfb.SoftmaxCentered(),
  ]
  for fwd in wrapped_bijectors:
    rev = tfb.Invert(fwd)
    self.assertStartsWith(rev.name, 'invert_' + fwd.name)

    x = [[[1., 2.], [2., 3.]]]

    # Values: rev.forward == fwd.inverse and rev.inverse == fwd.forward.
    fwd_inverse = self.evaluate(fwd.inverse(x))
    rev_forward = self.evaluate(rev.forward(x))
    self.assertAllClose(fwd_inverse, rev_forward)

    fwd_forward = self.evaluate(fwd.forward(x))
    rev_inverse = self.evaluate(rev.inverse(x))
    self.assertAllClose(fwd_forward, rev_inverse)

    # Log-det-Jacobians are swapped in the same way.
    fwd_fldj = self.evaluate(fwd.forward_log_det_jacobian(x, event_ndims=1))
    rev_ildj = self.evaluate(rev.inverse_log_det_jacobian(x, event_ndims=1))
    self.assertAllClose(fwd_fldj, rev_ildj)

    fwd_ildj = self.evaluate(fwd.inverse_log_det_jacobian(x, event_ndims=1))
    rev_fldj = self.evaluate(rev.forward_log_det_jacobian(x, event_ndims=1))
    self.assertAllClose(fwd_ildj, rev_fldj)
def test_slice_transformed_distribution_with_chain(self):
  """Slicing a chain-transformed distribution slices its batch shape."""

  def shape_of(part):
    return part.shape

  base = tfd.MultivariateNormalDiag(
      loc=tf.zeros([4]), scale_diag=tf.ones([1, 4]))
  chain = tfb.Chain([
      tfb.JointMap([tfb.Identity(), tfb.Shift(tf.ones([4, 3, 2]))]),
      tfb.Split(2),
      tfb.ScaleMatvecDiag(tf.ones([5, 1, 3, 4])),
      tfb.Exp(),
  ])
  dist = tfd.TransformedDistribution(distribution=base, bijector=chain)

  self.assertAllEqual(dist.batch_shape_tensor(), [5, 4, 3])
  full_sample = dist.sample(seed=test_util.test_seed())
  self.assertAllEqualNested(
      tf.nest.map_structure(shape_of, full_sample),
      [[5, 4, 3, 2], [5, 4, 3, 2]])

  sliced = dist[tf.newaxis, ..., 0, :, :-1]
  self.assertAllEqual(sliced.batch_shape_tensor(), [1, 4, 2])
  sliced_sample = sliced.sample(seed=test_util.test_seed())
  self.assertAllEqualNested(
      tf.nest.map_structure(shape_of, sliced_sample),
      [[1, 4, 2, 2], [1, 4, 2, 2]])
def testExplicitBlocks(self, dynamic_shape, batch_shape):
  """Compares `Blockwise` with explicit block sizes to a by-hand version.

  Args:
    dynamic_shape: If true, the block sizes and the input are fed through
      placeholders with unknown static shape.
    batch_shape: List of leading batch dimensions to tile the input to.
  """
  block_sizes = tf.convert_to_tensor(value=[2, 1, 3])
  block_sizes = tf1.placeholder_with_default(
      block_sizes,
      shape=([None] * len(block_sizes.shape)
             if dynamic_shape else block_sizes.shape))

  exp = tfb.Exp()
  sp = tfb.Softplus()
  aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.])
  blockwise = tfb.Blockwise(
      bijectors=[exp, sp, aff],
      block_sizes=block_sizes,
      maybe_changes_size=False)

  x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
  # Each step prepends a new leading axis, so walk `batch_shape` from its
  # last entry to its first. Walking it front-to-back would build the batch
  # dimensions in reverse order and contradict the `batch_shape + [6]`
  # expectations below for any multi-dimensional batch shape.
  for s in reversed(batch_shape):
    x = tf.expand_dims(x, 0)
    x = tf.tile(x, [s] + [1] * (tensorshape_util.rank(x.shape) - 1))
  x = tf1.placeholder_with_default(
      x, shape=None if dynamic_shape else x.shape)

  # Identity to break the caching.
  blockwise_y = tf.identity(blockwise.forward(x))
  blockwise_fldj = blockwise.forward_log_det_jacobian(x, event_ndims=1)
  blockwise_x = blockwise.inverse(blockwise_y)
  blockwise_ildj = blockwise.inverse_log_det_jacobian(
      blockwise_y, event_ndims=1)

  # Static shapes are only checked when they are known.
  if not dynamic_shape:
    self.assertEqual(blockwise_y.shape, batch_shape + [6])
    self.assertEqual(blockwise_fldj.shape, batch_shape + [])
    self.assertEqual(blockwise_x.shape, batch_shape + [6])
    self.assertEqual(blockwise_ildj.shape, batch_shape + [])
  self.assertAllEqual(
      self.evaluate(tf.shape(blockwise_y)), batch_shape + [6])
  self.assertAllEqual(
      self.evaluate(tf.shape(blockwise_fldj)), batch_shape + [])
  self.assertAllEqual(
      self.evaluate(tf.shape(blockwise_x)), batch_shape + [6])
  self.assertAllEqual(
      self.evaluate(tf.shape(blockwise_ildj)), batch_shape + [])

  # Reference results: apply each part to its own slice of the last axis.
  expl_y = tf.concat([
      exp.forward(x[..., :2]),
      sp.forward(x[..., 2:3]),
      aff.forward(x[..., 3:]),
  ], axis=-1)
  expl_fldj = sum([
      exp.forward_log_det_jacobian(x[..., :2], event_ndims=1),
      sp.forward_log_det_jacobian(x[..., 2:3], event_ndims=1),
      aff.forward_log_det_jacobian(x[..., 3:], event_ndims=1)
  ])
  expl_x = tf.concat([
      exp.inverse(expl_y[..., :2]),
      sp.inverse(expl_y[..., 2:3]),
      aff.inverse(expl_y[..., 3:])
  ], axis=-1)
  expl_ildj = sum([
      exp.inverse_log_det_jacobian(expl_y[..., :2], event_ndims=1),
      sp.inverse_log_det_jacobian(expl_y[..., 2:3], event_ndims=1),
      aff.inverse_log_det_jacobian(expl_y[..., 3:], event_ndims=1)
  ])

  self.assertAllClose(self.evaluate(expl_y), self.evaluate(blockwise_y))
  self.assertAllClose(self.evaluate(expl_fldj),
                      self.evaluate(blockwise_fldj))
  self.assertAllClose(self.evaluate(expl_x), self.evaluate(blockwise_x))
  self.assertAllClose(self.evaluate(expl_ildj),
                      self.evaluate(blockwise_ildj))
def testImplicitBlocks(self):
  """Without `block_sizes`, this three-part `Blockwise` reports [1, 1, 1]."""
  parts = [
      tfb.Exp(),
      tfb.Softplus(),
      tfb.ScaleMatvecDiag(scale_diag=[2.]),
  ]
  blockwise = tfb.Blockwise(bijectors=parts)
  inferred_sizes = self.evaluate(blockwise.block_sizes)
  self.assertAllEqual(inferred_sizes, [1, 1, 1])
class BijectorBatchShapesTest(test_util.TestCase):
  """Tests bijector batch-shape computation and parameter broadcasting."""

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('scale_matvec', lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('reshape', lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    """Batch shape agrees across x/y, static/tensor, and actual outputs."""

    def assert_parts_fit(parts, parts_event_ndims, batch_shape):
      # Every part's batch shape must broadcast into `batch_shape` without
      # changing it.
      pairs = zip(tf.nest.flatten(parts), tf.nest.flatten(parts_event_ndims))
      for part, nd in pairs:
        part_batch_shape = ps.shape(part)[:ps.rank(part) - nd]
        self.assertAllEqual(
            batch_shape,
            ps.broadcast_shape(batch_shape, part_batch_shape))

    bijector = bijector_fn()
    if override_x_event_ndims is not None:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)
    else:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims

    # All ways of calculating the batch shape should yield the same result.
    batch_shape_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    batch_shape_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(batch_shape_x, batch_shape_y)

    tensor_from_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    tensor_from_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(tensor_from_x, tensor_from_y)
    self.assertAllEqual(tensor_from_x, batch_shape_x)

    # Check that we're robust to integer type.
    tensor_from_x64 = bijector.experimental_batch_shape_tensor(
        x_event_ndims=tf.nest.map_structure(np.int64, x_event_ndims))
    tensor_from_y64 = bijector.experimental_batch_shape_tensor(
        y_event_ndims=tf.nest.map_structure(np.int64, y_event_ndims))
    self.assertAllEqual(tensor_from_x64, tensor_from_y64)
    self.assertAllEqual(tensor_from_x64, batch_shape_x)

    # Pushing a value through the bijector should return a Tensor(s) with
    # the expected batch shape...
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    assert_parts_fit(ys, y_event_ndims, batch_shape_y)
    # ... which should also be the shape of the fldj.
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(batch_shape_y, ps.shape(fldj))

    # Also check the inverse case.
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    assert_parts_fit(xs, x_event_ndims, batch_shape_x)
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Checks default event ndims and rejection of both ndims at once."""
    bijector = bijector_fn()
    shape_fns = (
        bijector.experimental_batch_shape,
        bijector.experimental_batch_shape_tensor,
    )

    # If no `event_ndims` is passed, should assume min_event_ndims.
    for shape_fn in shape_fns:
      self.assertAllEqual(shape_fn(), [1])

    # Passing both `x_event_ndims` and `y_event_ndims` is an error.
    for shape_fn in shape_fns:
      with self.assertRaisesRegex(
          ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
        shape_fn(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Broadcasting parameters yields the broadcast batch shape."""
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims

    original_batch_shape = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    original_param_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    broadcast_batch_shape = (
        broadcast_bijector.experimental_batch_shape_tensor(
            x_event_ndims=x_event_ndims))
    self.assertAllEqual(
        broadcast_batch_shape,
        ps.broadcast_shape(original_batch_shape, new_batch_shape))

    # Check that all params have the expected batch shape.
    broadcast_param_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def expected_shape_for(param_value, param_shape):
      inner_has_no_params = (
          isinstance(param_value, tfb.Invert)
          and not param_value.bijector._params_event_ndims())
      if inner_has_no_params:
        # Can't broadcast a bijector that doesn't itself have params.
        return param_shape
      return ps.broadcast_shape(param_shape, new_batch_shape)

    param_values = {
        name: getattr(bijector, name) for name in original_param_shapes
    }
    expected_param_shapes = tf.nest.map_structure(
        expected_shape_for, param_values, original_param_shapes)
    self.assertAllEqualNested(broadcast_param_shapes, expected_param_shapes)