コード例 #1
0
ファイル: chain_test.py プロジェクト: yfe404/probability
    def testMinEventNdimsChain(self):
        """Chains of rank-preserving bijectors report the expected min ndims.

        An all-`Exp` chain is expected to report 0 in both directions; every
        chain below that also contains a `ScaleMatvecDiag` is expected to
        report 1 in both directions.
        """
        def make_diag():
            return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

        def check(bijectors, expected_forward, expected_inverse):
            chain = tfb.Chain(bijectors)
            self.assertEqual(expected_forward, chain.forward_min_event_ndims)
            self.assertEqual(expected_inverse, chain.inverse_min_event_ndims)

        check([tfb.Exp(), tfb.Exp(), tfb.Exp()], 0, 0)
        check([make_diag(), make_diag(), make_diag()], 1, 1)
        check([tfb.Exp(), make_diag()], 1, 1)
        check([make_diag(), tfb.Exp()], 1, 1)
        check([make_diag(), tfb.Exp(), tfb.Softplus(), make_diag()], 1, 1)
コード例 #2
0
 def testName(self):
     """The Blockwise name is built from its component bijector names."""
     components = [
         tfb.Exp(),
         tfb.Softplus(),
         tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]),
     ]
     blockwise = tfb.Blockwise(bijectors=components, block_sizes=[2, 1, 3])
     expected_prefix = 'blockwise_of_exp_and_softplus_and_scale_matvec_diag'
     self.assertStartsWith(blockwise.name, expected_prefix)
コード例 #3
0
 def testRaisesWhenSingular(self):
     """A zero on the scale diagonal is rejected when `validate_args=True`.

     Uses `assertRaisesRegex`: the `assertRaisesRegexp` spelling is a
     deprecated unittest alias that was removed in Python 3.12.
     """
     # The expected message (including its double space) is kept verbatim;
     # presumably it matches the operator's error text exactly -- confirm
     # upstream before editing it.
     with self.assertRaisesRegex(
             Exception,
             'Singular operator:  Diagonal contained zero values'):
         bijector = tfb.ScaleMatvecDiag(
             # Has zero on the diagonal.
             scale_diag=[0., 1],
             validate_args=True)
         self.evaluate(bijector.forward([1., 1.]))
コード例 #4
0
 def testBatch(self, is_static):
     """A one-element batch of diagonal scales [2, 2] doubles / halves x.

     Args:
       is_static: passed through to `self.maybe_static` to choose how `x`
         is fed to the bijector.
     """
     # Batch of a single 2x2 matrix carrying 2s on its diagonal.
     bijector = tfb.ScaleMatvecDiag(scale_diag=[[2., 2]])
     x = self.maybe_static([[[1., 1]]], is_static)

     expected_forward = [[[2., 2]]]
     expected_inverse = [[[0.5, 0.5]]]
     expected_ildj = [-np.log(4)]

     self.assertAllClose(expected_forward, bijector.forward(x))
     self.assertAllClose(expected_inverse, bijector.inverse(x))
     ildj = bijector.inverse_log_det_jacobian(x, event_ndims=1)
     self.assertAllClose(expected_ildj, ildj)
コード例 #5
0
ファイル: chain_test.py プロジェクト: NeilGirdhar/probability
 def testEventNdimsIsOptional(self):
   """Log-det-Jacobians can be requested without passing `event_ndims`.

   Per the expected values below, the chain maps x = [0, log 2, log 3] to
   y = [1, 4, 9] (Exp, then scale by [1, 2, 3]), so the forward LDJ is
   log(6) + sum(x) and the inverse LDJ is its negation.
   """
   scale_diag = np.array([1., 2., 3.], dtype=np.float32)
   chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()])
   log_two = np.log(2., dtype=np.float32)
   log_three = np.log(3., dtype=np.float32)
   x = [0., log_two, log_three]
   y = [1., 4., 9.]

   log_six = np.log(6, dtype=np.float32)
   expected_fldj = log_six + np.sum(x)
   fldj = chain.forward_log_det_jacobian(x)
   self.assertAllClose(expected_fldj, self.evaluate(fldj))

   expected_ildj = -log_six - np.sum(x)
   ildj = chain.inverse_log_det_jacobian(y)
   self.assertAllClose(expected_ildj, self.evaluate(ildj))
コード例 #6
0
ファイル: chain_test.py プロジェクト: yfe404/probability
    def testMinEventNdimsShapeChangingRemoveDims(self):
        """Chains built from the rank-removing `ShapeChanging(3, 0)` bijector.

        On its own, `ShapeChanging(3, 0)` yields forward/inverse
        min_event_ndims of 3/0. The remaining cases check how it composes
        with itself and with a rank-1 `ScaleMatvecDiag`.
        """
        def make_diag():
            return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

        def check(bijectors, expected_forward, expected_inverse):
            chain = tfb.Chain(bijectors)
            self.assertEqual(expected_forward, chain.forward_min_event_ndims)
            self.assertEqual(expected_inverse, chain.inverse_min_event_ndims)

        check([ShapeChanging(3, 0)], 3, 0)
        check([ShapeChanging(3, 0), make_diag()], 3, 0)
        check([make_diag(), ShapeChanging(3, 0)], 4, 1)
        check([ShapeChanging(3, 0), ShapeChanging(3, 0)], 6, 0)
コード例 #7
0
    def testMinEventNdimsChain(self):
        """Rank-preserving chains, validated via `_validateChainMinEventNdims`.

        An all-`Exp` chain is expected to report 0; every chain here that
        includes a `ScaleMatvecDiag` is expected to report 1. Forward and
        inverse values are equal in every case.
        """
        def make_diag():
            return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

        def expect(bijectors, ndims):
            self._validateChainMinEventNdims(
                bijectors=bijectors,
                forward_min_event_ndims=ndims,
                inverse_min_event_ndims=ndims)

        expect([tfb.Exp(), tfb.Exp(), tfb.Exp()], 0)
        expect([make_diag(), make_diag(), make_diag()], 1)
        expect([tfb.Exp(), make_diag()], 1)
        expect([make_diag(), tfb.Exp(), tfb.Softplus(), make_diag()], 1)
コード例 #8
0
ファイル: chain_test.py プロジェクト: NeilGirdhar/probability
  def testChainExpAffine(self):
    """Chain that scales by `scale_diag` first and applies Exp second.

    Per the expected values below, x = [0, log 2, log 3] maps to
    y = [1, 4, 27], so the forward LDJ is log(1 * 2 * 3) plus
    sum(scale_diag * x), and the inverse LDJ is its negation.
    """
    scale_diag = np.array([1., 2., 3.], dtype=np.float32)
    chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=scale_diag)])
    x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
    y = [1., 4., 27.]

    # Round trip through the chain.
    self.assertAllClose(y, self.evaluate(chain.forward(x)))
    self.assertAllClose(x, self.evaluate(chain.inverse(y)))

    log_six = np.log(6, dtype=np.float32)
    scaled_sum = np.sum(scale_diag * x)

    fldj = chain.forward_log_det_jacobian(x, event_ndims=1)
    self.assertAllClose(log_six + scaled_sum, self.evaluate(fldj))

    ildj = chain.inverse_log_det_jacobian(y, event_ndims=1)
    self.assertAllClose(-log_six - scaled_sum, self.evaluate(ildj))
コード例 #9
0
    def testMinEventNdimsShapeChangingAddDims(self):
        """Chains built from the default (rank-adding) `ShapeChanging()`.

        On its own, `ShapeChanging()` is expected to report forward/inverse
        min_event_ndims of 0/3.
        """
        def expect(bijectors, forward_ndims, inverse_ndims):
            self._validateChainMinEventNdims(
                bijectors=bijectors,
                forward_min_event_ndims=forward_ndims,
                inverse_min_event_ndims=inverse_ndims)

        expect([ShapeChanging()], 0, 3)
        expect([ShapeChanging(), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])],
               1, 4)
        expect([ShapeChanging(), ShapeChanging()], 0, 6)
コード例 #10
0
    def test_composition_str_and_repr_match_expected_dynamic_shape(self):
        """`str` and `repr` of Chains whose parameters come from `self._tensor`.

        Both expected reprs contain `batch_shape=?`. The second chain has a
        dict-structured side, so its `inverse_min_event_ndims` and `dtype_y`
        render as dicts.
        """
        flat_chain = tfb.Chain([
            tfb.Exp(),
            tfb.Shift(self._tensor([1., 2.])),
            tfb.SoftmaxCentered(),
        ])
        expected_str_parts = [
            'tfp.bijectors.Chain(',
            'min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])',
        ]
        self.assertContainsInOrder(expected_str_parts, str(flat_chain))
        expected_repr_parts = [
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 '
             'inverse_min_event_ndims=1 '
             'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
            '>, <tfp.bijectors.Shift',
            '>, <tfp.bijectors.SoftmaxCentered',
            '>]>',
        ]
        self.assertContainsInOrder(expected_repr_parts, repr(flat_chain))

        structured_chain = tfb.Chain([
            tfb.JointMap({
                'a': tfb.Exp(),
                'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.])),
            }),
            tfb.Restructure({'a': 0, 'b': 1}, [0, 1]),
            tfb.Split(2),
            tfb.Invert(tfb.SoftmaxCentered()),
        ])
        expected_str_parts = [
            'tfp.bijectors.Chain(',
            ('forward_min_event_ndims=1, '
             'inverse_min_event_ndims={a: 1, b: 1}, '
             'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
             'Restructure, Split, Invert(SoftmaxCentered)])'),
        ]
        self.assertContainsInOrder(expected_str_parts, str(structured_chain))
        expected_repr_parts = [
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 '
             "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
             "dtype_y={'a': ?, 'b': float32} "
             'bijectors=[<tfp.bijectors.JointMap '),
            '>, <tfp.bijectors.Restructure',
            '>, <tfp.bijectors.Split',
            '>, <tfp.bijectors.Invert',
            '>]>',
        ]
        self.assertContainsInOrder(expected_repr_parts, repr(structured_chain))
コード例 #11
0
    def testBijectiveAndFinite(self):
        """Blockwise(Exp, Softplus, ScaleMatvecDiag) passes the bijectivity check."""
        blockwise = tfb.Blockwise(
            bijectors=[
                tfb.Exp(),
                tfb.Softplus(),
                tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]),
            ],
            block_sizes=[2, 1, 3])

        values = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
        x = tf1.placeholder_with_default(values, shape=values.shape)
        # Routing through tf.identity breaks the bijector's caching of this
        # forward result.
        y = tf.identity(blockwise.forward(x))

        bijector_test_util.assert_bijective_and_finite(
            blockwise,
            x=self.evaluate(x),
            y=self.evaluate(y),
            eval_func=self.evaluate,
            event_ndims=1)
コード例 #12
0
    def testNoBatch(self, is_static):
        """Unbatched ScaleMatvecDiag([2, 1]) on a single vector and on a batch.

        Args:
          is_static: passed through to `self.maybe_static` to choose how `x`
            is fed to the bijector.
        """
        # Dense equivalent: [[2., 0], [0, 1.]]; forward is matmul(scale, x).
        bijector = tfb.ScaleMatvecDiag(scale_diag=[2., 1])

        cases = (
            # A single 2-vector.
            ([1., 1], [2., 1], [0.5, 1]),
            # A 2-batch of 2-vectors, [1, 1] and [-1, -1]; each is scaled on
            # its own.
            ([[1., 1], [-1., -1]],
             [[2., 1], [-2., -1]],
             [[0.5, 1], [-0.5, -1]]),
        )
        for raw_x, expected_forward, expected_inverse in cases:
            x = self.maybe_static(raw_x, is_static)
            self.assertAllClose(expected_forward, bijector.forward(x))
            self.assertAllClose(expected_inverse, bijector.inverse(x))
            self.assertAllClose(
                -np.log(2.),
                bijector.inverse_log_det_jacobian(x, event_ndims=1))
コード例 #13
0
ファイル: invert_test.py プロジェクト: stjordanis/probability
 def testBijector(self):
     """Invert(fwd) swaps fwd's forward/inverse maps and its two LDJs."""
     forward_bijectors = (
         tfb.Identity(),
         tfb.Exp(),
         tfb.ScaleMatvecDiag(scale_diag=[2., 3.]),
         tfb.Softplus(),
         tfb.SoftmaxCentered(),
     )
     for fwd in forward_bijectors:
         rev = tfb.Invert(fwd)
         self.assertStartsWith(rev.name, 'invert_' + fwd.name)

         x = [[[1., 2.], [2., 3.]]]
         run = self.evaluate
         self.assertAllClose(run(fwd.inverse(x)), run(rev.forward(x)))
         self.assertAllClose(run(fwd.forward(x)), run(rev.inverse(x)))

         fwd_fldj = fwd.forward_log_det_jacobian(x, event_ndims=1)
         rev_ildj = rev.inverse_log_det_jacobian(x, event_ndims=1)
         self.assertAllClose(run(fwd_fldj), run(rev_ildj))

         fwd_ildj = fwd.inverse_log_det_jacobian(x, event_ndims=1)
         rev_fldj = rev.forward_log_det_jacobian(x, event_ndims=1)
         self.assertAllClose(run(fwd_ildj), run(rev_fldj))
コード例 #14
0
ファイル: slicing_test.py プロジェクト: gisilvs/probability
  def test_slice_transformed_distribution_with_chain(self):
    """Slicing a TransformedDistribution with a Chain bijector slices both.

    The expected shapes show a [5, 4, 3] batch shape before slicing and
    [1, 4, 2] after, with each of the two sampled parts ending in a
    size-2 event.
    """
    base = tfd.MultivariateNormalDiag(
        loc=tf.zeros([4]), scale_diag=tf.ones([1, 4]))
    chain = tfb.Chain([
        tfb.JointMap([tfb.Identity(), tfb.Shift(tf.ones([4, 3, 2]))]),
        tfb.Split(2),
        tfb.ScaleMatvecDiag(tf.ones([5, 1, 3, 4])),
        tfb.Exp(),
    ])
    dist = tfd.TransformedDistribution(distribution=base, bijector=chain)

    def sample_shapes(d):
      sample = d.sample(seed=test_util.test_seed())
      return tf.nest.map_structure(lambda part: part.shape, sample)

    self.assertAllEqual(dist.batch_shape_tensor(), [5, 4, 3])
    self.assertAllEqualNested(sample_shapes(dist),
                              [[5, 4, 3, 2], [5, 4, 3, 2]])

    sliced = dist[tf.newaxis, ..., 0, :, :-1]
    self.assertAllEqual(sliced.batch_shape_tensor(), [1, 4, 2])
    self.assertAllEqualNested(sample_shapes(sliced),
                              [[1, 4, 2, 2], [1, 4, 2, 2]])
コード例 #15
0
    def testExplicitBlocks(self, dynamic_shape, batch_shape):
        """Blockwise with explicit `block_sizes` matches per-block application.

        Args:
          dynamic_shape: Python `bool`; when True, `block_sizes` and `x` are fed
            through placeholders without a fully-defined static shape.
          batch_shape: Python `list` of batch dimensions to place in front of
            the 6-element event of `x`.
        """
        block_sizes = tf.convert_to_tensor(value=[2, 1, 3])
        block_sizes = tf1.placeholder_with_default(
            block_sizes,
            shape=([None] * len(block_sizes.shape)
                   if dynamic_shape else block_sizes.shape))
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.])
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff],
                                  block_sizes=block_sizes,
                                  maybe_changes_size=False)

        x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
        # Each iteration prepends one leading dimension, so walk `batch_shape`
        # from its innermost dimension outwards. Iterating forwards would give
        # x a reversed batch shape whenever `batch_shape` has several dims.
        for s in reversed(batch_shape):
            x = tf.expand_dims(x, 0)
            x = tf.tile(x, [s] + [1] * (tensorshape_util.rank(x.shape) - 1))
        x = tf1.placeholder_with_default(
            x, shape=None if dynamic_shape else x.shape)

        # Identity to break the caching.
        blockwise_y = tf.identity(blockwise.forward(x))
        blockwise_fldj = blockwise.forward_log_det_jacobian(x, event_ndims=1)
        blockwise_x = blockwise.inverse(blockwise_y)
        blockwise_ildj = blockwise.inverse_log_det_jacobian(blockwise_y,
                                                            event_ndims=1)

        # Static shapes are only checkable when nothing was made dynamic.
        if not dynamic_shape:
            self.assertEqual(blockwise_y.shape, batch_shape + [6])
            self.assertEqual(blockwise_fldj.shape, batch_shape + [])
            self.assertEqual(blockwise_x.shape, batch_shape + [6])
            self.assertEqual(blockwise_ildj.shape, batch_shape + [])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_y)),
                            batch_shape + [6])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_fldj)),
                            batch_shape + [])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_x)),
                            batch_shape + [6])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_ildj)),
                            batch_shape + [])

        # Reference results: apply each bijector to its own slice of the event
        # (sizes 2, 1, 3) and concatenate / sum the pieces.
        expl_y = tf.concat([
            exp.forward(x[..., :2]),
            sp.forward(x[..., 2:3]),
            aff.forward(x[..., 3:]),
        ], axis=-1)
        expl_fldj = sum([
            exp.forward_log_det_jacobian(x[..., :2], event_ndims=1),
            sp.forward_log_det_jacobian(x[..., 2:3], event_ndims=1),
            aff.forward_log_det_jacobian(x[..., 3:], event_ndims=1)
        ])
        expl_x = tf.concat([
            exp.inverse(expl_y[..., :2]),
            sp.inverse(expl_y[..., 2:3]),
            aff.inverse(expl_y[..., 3:])
        ], axis=-1)
        expl_ildj = sum([
            exp.inverse_log_det_jacobian(expl_y[..., :2], event_ndims=1),
            sp.inverse_log_det_jacobian(expl_y[..., 2:3], event_ndims=1),
            aff.inverse_log_det_jacobian(expl_y[..., 3:], event_ndims=1)
        ])

        self.assertAllClose(self.evaluate(expl_y), self.evaluate(blockwise_y))
        self.assertAllClose(self.evaluate(expl_fldj),
                            self.evaluate(blockwise_fldj))
        self.assertAllClose(self.evaluate(expl_x), self.evaluate(blockwise_x))
        self.assertAllClose(self.evaluate(expl_ildj),
                            self.evaluate(blockwise_ildj))
コード例 #16
0
 def testImplicitBlocks(self):
     """With `block_sizes` omitted, these three bijectors get sizes [1, 1, 1]."""
     blockwise = tfb.Blockwise(
         bijectors=[tfb.Exp(),
                    tfb.Softplus(),
                    tfb.ScaleMatvecDiag(scale_diag=[2.])])
     inferred_sizes = self.evaluate(blockwise.block_sizes)
     self.assertAllEqual(inferred_sizes, [1, 1, 1])
コード例 #17
0
class BijectorBatchShapesTest(test_util.TestCase):
  """Tests `experimental_batch_shape` / `experimental_batch_shape_tensor`.

  Covers consistency between x-side and y-side queries, static vs. tensor
  results, and parameter broadcasting to a new batch shape.
  """

  # Each case: (name, bijector factory, optional x-side event_ndims override).
  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale',
       lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
      ('scale_matvec',
       lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
      ('reshape',
       lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    """Every way of computing the batch shape agrees with the outputs.

    Args:
      bijector_fn: Zero-argument callable (a lambda or a bijector class)
        returning the bijector under test.
      override_x_event_ndims: Optional (possibly nested) x-side `event_ndims`.
        When `None`, the bijector's forward/inverse min event ndims are used;
        otherwise the y-side ndims are derived with `forward_event_ndims`.
    """
    bijector = bijector_fn()
    if override_x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims
    else:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)

    # All ways of calculating the batch shape should yield the same result.
    batch_shape_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    batch_shape_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(batch_shape_x, batch_shape_y)

    batch_shape_tensor_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    batch_shape_tensor_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_tensor_y)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_x)

    # Check that we're robust to integer type.
    batch_shape_tensor_x64 = bijector.experimental_batch_shape_tensor(
        x_event_ndims=tf.nest.map_structure(np.int64, x_event_ndims))
    batch_shape_tensor_y64 = bijector.experimental_batch_shape_tensor(
        y_event_ndims=tf.nest.map_structure(np.int64, y_event_ndims))
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_tensor_y64)
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_x)

    # Pushing a value through the bijector should return a Tensor(s) with
    # the expected batch shape...
    # Each input part is all-ones with rank exactly `nd`, so it carries no
    # batch dimensions of its own.
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    for y_part, nd in zip(tf.nest.flatten(ys), tf.nest.flatten(y_event_ndims)):
      # Leading (non-event) dims of this output part; they must broadcast
      # into batch_shape_y without enlarging it.
      part_batch_shape = ps.shape(y_part)[:ps.rank(y_part) - nd]
      self.assertAllEqual(batch_shape_y,
                          ps.broadcast_shape(batch_shape_y, part_batch_shape))

    # ... which should also be the shape of the fldj.
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(batch_shape_y, ps.shape(fldj))

    # Also check the inverse case.
    # NOTE(review): tf.identity presumably sidesteps bijector caching so the
    # inverse is actually recomputed -- confirm.
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    for x_part, nd in zip(tf.nest.flatten(xs), tf.nest.flatten(x_event_ndims)):
      part_batch_shape = ps.shape(x_part)[:ps.rank(x_part) - nd]
      self.assertAllEqual(batch_shape_x,
                          ps.broadcast_shape(batch_shape_x, part_batch_shape))
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_x, ps.shape(ildj))

  # Each case: (name, bijector factory); both expect batch shape [1] under
  # the default event ndims.
  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Default event ndims, and rejection of both x- and y-side ndims.

    Args:
      bijector_fn: Zero-argument callable returning the bijector under test.
    """
    bijector = bijector_fn()

    # If no `event_ndims` is passed, should assume min_event_ndims.
    self.assertAllEqual(bijector.experimental_batch_shape(), [1])
    self.assertAllEqual(bijector.experimental_batch_shape_tensor(), [1])

    # Passing both x- and y-side ndims must raise, for both static and
    # tensor variants.
    with self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape(x_event_ndims=0, y_event_ndims=0)

    with  self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape_tensor(x_event_ndims=0, y_event_ndims=0)

  # Each case: (name, bijector factory, optional x-side event_ndims override).
  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid', lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert', lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Broadcasting parameters to a new batch shape updates every batch shape.

    Args:
      bijector_fn: Zero-argument callable returning the bijector under test.
      x_event_ndims: Optional (possibly nested) x-side `event_ndims`; defaults
        to the bijector's `forward_min_event_ndims`.
    """
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
    batch_shape = bijector.experimental_batch_shape(x_event_ndims=x_event_ndims)
    # Per-parameter batch shapes of the original (unbroadcast) bijector.
    param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    # The overall batch shape must be the broadcast of old and new shapes.
    self.assertAllEqual(broadcast_batch_shape,
                        ps.broadcast_shape(batch_shape, new_batch_shape))

    # Check that all params have the expected batch shape.
    broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def _maybe_broadcast_param_batch_shape(p, s):
      # p: original parameter value; s: its batch shape before broadcasting.
      if isinstance(p, tfb.Invert) and not p.bijector._params_event_ndims():
        return s  # Can't broadcast a bijector that doesn't itself have params.
      return ps.broadcast_shape(s, new_batch_shape)
    expected_broadcast_param_batch_shapes = tf.nest.map_structure(
        _maybe_broadcast_param_batch_shape,
        {param: getattr(bijector, param) for param in param_batch_shapes},
        param_batch_shapes)
    self.assertAllEqualNested(broadcast_param_batch_shapes,
                              expected_broadcast_param_batch_shapes)