Example #1
  def testChainDynamicToStatic(self):
    if tf.executing_eagerly():
      return

    def xform_dynamic(x):
      return tf1.placeholder_with_default(x, shape=None)

    def xform_static(x):
      # Copy the Tensor, because otherwise the set_shape can pass information
      # into the past.
      x = tf.identity(x)
      tensorshape_util.set_shape(x, [1])
      return x

    def ldj(_):
      return tf.constant(1.)

    # The issue was that the sample's shape was going in-and-out of being fully
    # specified, causing internal consistency issues inside the bijector.
    chain = tfb.Chain([
        tfb.Inline(
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic),
        tfb.Inline(
            inverse_fn=xform_static,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_static),
        tfb.Inline(
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic)
    ])

    ildj = chain.inverse_log_det_jacobian(
        tf.zeros((2, 3), dtype=tf.float32), event_ndims=1)

    # The shape of `ildj` is known statically to be scalar; its value is
    # not statically known.
    self.assertTrue(tensorshape_util.is_fully_defined(ildj.shape))

    # `ldj_reduce_shape` uses `prefer_static` to get input shapes. That means
    # that we respect statically-known shape information where present.
    # In this case, the manually-assigned static shape is incorrect.
    self.assertEqual(self.evaluate(ildj), -7.)

    # Ditto.
    fldj = chain.forward_log_det_jacobian([0.], event_ndims=0)
    self.assertTrue(tensorshape_util.is_fully_defined(fldj.shape))
    self.assertEqual(self.evaluate(fldj), 3.)
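These snippets rely on TensorFlow Probability's usual test-suite aliases. The import block below is a hedged reconstruction (it is not part of the original examples), showing how the names `tf`, `tf1`, `np`, `tfb`, `tfd`, and `tensorshape_util` used throughout this page are typically bound:

# Assumed imports; not part of the original snippets.
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tensorshape_util

tfb = tfp.bijectors
tfd = tfp.distributions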
Example #2
    def testBijector(self):
        with self.test_session():
            exp = tfb.Exp()
            inline = tfb.Inline(
                forward_fn=tf.exp,
                inverse_fn=tf.log,
                inverse_log_det_jacobian_fn=lambda y: -tf.log(y),
                forward_log_det_jacobian_fn=lambda x: x,
                forward_min_event_ndims=0,
                name="exp")

            self.assertEqual(exp.name, inline.name)
            x = [[[1., 2.], [3., 4.], [5., 6.]]]
            y = np.exp(x)
            self.assertAllClose(y, self.evaluate(inline.forward(x)))
            self.assertAllClose(x, self.evaluate(inline.inverse(y)))
            self.assertAllClose(
                -np.sum(np.log(y), axis=-1),
                self.evaluate(inline.inverse_log_det_jacobian(y,
                                                              event_ndims=1)))
            self.assertAllClose(
                self.evaluate(
                    -inline.inverse_log_det_jacobian(y, event_ndims=1)),
                self.evaluate(inline.forward_log_det_jacobian(x,
                                                              event_ndims=1)))
Example #3
 def testIsIncreasing(self):
     inline = tfb.Inline(
         forward_fn=tf.exp,
         inverse_fn=tf.math.log,
         inverse_log_det_jacobian_fn=lambda y: -tf.math.log(y),
         forward_min_event_ndims=0,
         is_increasing=True,
         name='exp')
     self.assertAllEqual(True, inline._internal_is_increasing())
     inline = tfb.Inline(
         forward_fn=lambda x: tf.exp(x) * [1., -1],
         inverse_fn=lambda y: tf.math.log(y * [1., -1]),
         inverse_log_det_jacobian_fn=lambda y: -tf.math.log(y),
         forward_min_event_ndims=0,
         is_increasing=lambda: [True, False],
         name='exp')
     self.assertAllEqual([True, False], inline._internal_is_increasing())
Example #4
  def testChainDynamicToStatic(self):
    if tf.executing_eagerly():
      return

    def xform_dynamic(x):
      return tf.compat.v1.placeholder_with_default(x, shape=None)

    def xform_static(x):
      tensorshape_util.set_shape(x, [1])
      return x

    def ldj(_):
      return tf.constant(0.)

    # The issue was that the sample's shape was going in-and-out of being fully
    # specified, causing internal consistency issues inside the bijector.
    chain = tfb.Chain([
        tfb.Inline(
            inverse_log_det_jacobian_fn=ldj,
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic),
        tfb.Inline(
            inverse_log_det_jacobian_fn=ldj,
            inverse_fn=xform_static,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_static),
        tfb.Inline(
            inverse_log_det_jacobian_fn=ldj,
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic)
    ])

    ildj = chain.inverse_log_det_jacobian([0.], event_ndims=0)
    # The static shape information is lost on the account of the final bijector
    # being dynamic.
    self.assertFalse(tensorshape_util.is_fully_defined(ildj.shape))
    fldj = chain.forward_log_det_jacobian([0.], event_ndims=0)
    # Ditto.
    self.assertFalse(tensorshape_util.is_fully_defined(fldj.shape))
Example #5
    def testChainDynamicToStatic(self):
        if tf.executing_eagerly():
            return

        def xform_dynamic(x):
            return tf1.placeholder_with_default(x, shape=None)

        def xform_static(x):
            tensorshape_util.set_shape(x, [1])
            return x

        def ldj(_):
            return tf.constant(1.)

        # The issue was that the sample's shape was going in-and-out of being fully
        # specified, causing internal consistency issues inside the bijector.
        chain = tfb.Chain([
            tfb.Inline(inverse_fn=xform_dynamic,
                       forward_min_event_ndims=0,
                       forward_log_det_jacobian_fn=ldj,
                       forward_fn=xform_dynamic),
            tfb.Inline(inverse_fn=xform_static,
                       forward_min_event_ndims=0,
                       forward_log_det_jacobian_fn=ldj,
                       forward_fn=xform_static),
            tfb.Inline(inverse_fn=xform_dynamic,
                       forward_min_event_ndims=0,
                       forward_log_det_jacobian_fn=ldj,
                       forward_fn=xform_dynamic)
        ])

        ildj = chain.inverse_log_det_jacobian(tf.zeros((2, 3),
                                                       dtype=tf.float32),
                                              event_ndims=1)

        # The shape of `ildj` is known statically to be scalar; its value is
        # not statically known.
        self.assertTrue(tensorshape_util.is_fully_defined(ildj.shape))
        self.assertEqual(self.evaluate(ildj), -9.)

        # Ditto.
        fldj = chain.forward_log_det_jacobian([0.], event_ndims=0)
        self.assertTrue(tensorshape_util.is_fully_defined(fldj.shape))
        self.assertEqual(self.evaluate(fldj), 3.)
Example #6
 def testShapeGetters(self):
     bijector = tfb.Inline(
         forward_event_shape_tensor_fn=lambda x: tf.concat((x, [1]), 0),
         forward_event_shape_fn=lambda x: x.as_list() + [1],
         inverse_event_shape_tensor_fn=lambda x: x[:-1],
         inverse_event_shape_fn=lambda x: x[:-1],
         forward_min_event_ndims=0,
         name="shape_only")
     x = tf.TensorShape([1, 2, 3])
     y = tf.TensorShape([1, 2, 3, 1])
     self.assertAllEqual(y, bijector.forward_event_shape(x))
     self.assertAllEqual(
         y.as_list(),
         self.evaluate(bijector.forward_event_shape_tensor(x.as_list())))
     self.assertAllEqual(x, bijector.inverse_event_shape(y))
     self.assertAllEqual(
         x.as_list(),
         self.evaluate(bijector.inverse_event_shape_tensor(y.as_list())))
Example #7
  def testBijector(self):
    inline = tfb.Inline(
        forward_fn=tf.exp,
        inverse_fn=tf.math.log,
        inverse_log_det_jacobian_fn=lambda y: -tf.math.log(y),
        forward_min_event_ndims=0,
        name='exp')

    self.assertStartsWith(inline.name, 'exp')
    x = [[[1., 2.], [3., 4.], [5., 6.]]]
    y = np.exp(x)
    self.assertAllClose(y, self.evaluate(inline.forward(x)))
    self.assertAllClose(x, self.evaluate(inline.inverse(y)))
    self.assertAllClose(
        -np.sum(np.log(y), axis=-1),
        self.evaluate(inline.inverse_log_det_jacobian(y, event_ndims=1)))
    self.assertAllClose(
        self.evaluate(-inline.inverse_log_det_jacobian(y, event_ndims=1)),
        self.evaluate(inline.forward_log_det_jacobian(x, event_ndims=1)))
Example #8
    def testCastLogDetJacobian(self):
        """Test log_prob when Jacobian and log_prob dtypes do not match."""

        # Create an identity bijector whose jacobians have dtype int32
        int_identity = tfb.Inline(
            forward_fn=tf.identity,
            inverse_fn=tf.identity,
            inverse_log_det_jacobian_fn=(lambda y: tf.cast(0, tf.int32)),
            forward_log_det_jacobian_fn=(lambda x: tf.cast(0, tf.int32)),
            forward_min_event_ndims=0,
            is_constant_jacobian=True)
        normal = self._cls()(distribution=tfd.Normal(loc=0., scale=1.),
                             bijector=int_identity,
                             validate_args=True)

        y = normal.sample(seed=test_util.test_seed())
        self.evaluate(normal.log_prob(y))
        self.evaluate(normal.prob(y))
        self.evaluate(normal.mean())
        self.evaluate(normal.entropy())
Example #9
    def testKwargs(self):
        zeros = tf.zeros(1)

        bijectors = [
            tfb.Inline(  # pylint: disable=g-complex-comprehension
                forward_fn=mock.Mock(return_value=zeros),
                inverse_fn=mock.Mock(return_value=zeros),
                forward_log_det_jacobian_fn=mock.Mock(return_value=zeros),
                inverse_log_det_jacobian_fn=mock.Mock(return_value=zeros),
                forward_min_event_ndims=0,
                name='inner{}'.format(i)) for i in range(2)
        ]

        blockwise = tfb.Blockwise(bijectors)

        x = [1, 2]
        blockwise.forward(x, inner0={'arg': 1}, inner1={'arg': 2})
        blockwise.inverse(x, inner0={'arg': 3}, inner1={'arg': 4})
        blockwise.forward_log_det_jacobian(x,
                                           event_ndims=1,
                                           inner0={'arg': 5},
                                           inner1={'arg': 6})
        blockwise.inverse_log_det_jacobian(x,
                                           event_ndims=1,
                                           inner0={'arg': 7},
                                           inner1={'arg': 8})

        bijectors[0]._forward.assert_any_call(mock.ANY, arg=1)
        bijectors[1]._forward.assert_any_call(mock.ANY, arg=2)
        bijectors[0]._inverse.assert_any_call(mock.ANY, arg=3)
        bijectors[1]._inverse.assert_any_call(mock.ANY, arg=4)
        bijectors[0]._forward_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=5)
        bijectors[1]._forward_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=6)
        bijectors[0]._inverse_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=7)
        bijectors[1]._inverse_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=8)
Example #10
 # Stub bijector_fn: ignores its arguments and returns an Inline bijector
 # whose forward and inverse minimum event ndims differ (0 vs. 1).
 def bijector_fn(*args, **kwargs):
     del args, kwargs
     return tfb.Inline(forward_min_event_ndims=0,
                       inverse_min_event_ndims=1)
Example #11
def bijectors(draw,
              bijector_name=None,
              batch_shape=None,
              event_dim=None,
              enable_vars=False):
    """Strategy for drawing Bijectors.

  The emitted bijector may be a basic bijector or an `Invert` of a basic
  bijector, but not a compound like `Chain`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector_name: Optional Python `str`.  If given, the produced bijectors
      will all have this type.  If omitted, Hypothesis chooses one from
      the whitelist `TF2_FRIENDLY_BIJECTORS`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      bijector.  Hypothesis will pick one if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`,
      `tfp.util.TransformedVariable`}.

  Returns:
    bijectors: A strategy for drawing bijectors with the specified `batch_shape`
      (or an arbitrary one if omitted).
  """
    if bijector_name is None:
        bijector_name = draw(hps.sampled_from(TF2_FRIENDLY_BIJECTORS))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if bijector_name == 'Invert':
        underlying_name = draw(
            hps.sampled_from(sorted(set(TF2_FRIENDLY_BIJECTORS) - {'Invert'})))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars))
        return tfb.Invert(underlying, validate_args=True)
    if bijector_name == 'Inline':
        if enable_vars:
            scale = tf.Variable(1., name='scale')
        else:
            scale = 2.
        b = tfb.AffineScalar(scale=scale)

        inline = tfb.Inline(
            forward_fn=b.forward,
            inverse_fn=b.inverse,
            forward_log_det_jacobian_fn=lambda x: b.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
                x,
                event_ndims=b.forward_min_event_ndims),
            forward_min_event_ndims=b.forward_min_event_ndims,
            is_constant_jacobian=b.is_constant_jacobian,
        )
        inline.b = b
        return inline
    if bijector_name == 'DiscreteCosineTransform':
        dct_type = draw(hps.integers(min_value=2, max_value=3))
        return tfb.DiscreteCosineTransform(validate_args=True,
                                           dct_type=dct_type)
    if bijector_name == 'PowerTransform':
        power = draw(hps.floats(min_value=0., max_value=10.))
        return tfb.PowerTransform(validate_args=True, power=power)

    bijector_params = draw(
        broadcasting_params(bijector_name,
                            batch_shape,
                            event_dim=event_dim,
                            enable_vars=enable_vars))
    ctor = getattr(tfb, bijector_name)
    return ctor(validate_args=True, **bijector_params)
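A hedged usage sketch (not from the original source): once `bijectors` is decorated with `@hps.composite`, a Hypothesis test draws from it without passing `draw` explicitly. The test name and assertion below are illustrative assumptions:

# Hypothetical usage of the composite strategy defined above.
import hypothesis as hp
from hypothesis import strategies as hps

@hp.given(hps.data())
def test_draws_inline_bijector(data):
  # 'Inline' here wraps an AffineScalar, whose forward_min_event_ndims is 0.
  bijector = data.draw(bijectors(bijector_name='Inline'))
  assert bijector.forward_min_event_ndims == 0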
Example #12
def make_distribution_bijector(distribution,
                               name='make_distribution_bijector'):
    """Builds a bijector to approximately transform `N(0, 1)` into `distribution`.

  This represents a distribution as a bijector that
  transforms a (multivariate) standard normal distribution into the distribution
  of interest.

  Args:
    distribution: A `tfd.Distribution` instance; this may be a joint
      distribution.
    name: Python `str` name for ops created by this method.
  Returns:
    distribution_bijector: a `tfb.Bijector` instance such that
      `distribution_bijector(tfd.Normal(0., 1.))` is approximately equivalent
      to `distribution`.

  #### Examples

  This method may be used to convert structured variational distributions into
  MCMC preconditioners. Consider a model containing
  [funnel geometry](https://crackedbassoon.com/writing/funneling), which may
  be difficult for an MCMC algorithm to sample directly.

  ```python
  model_with_funnel = tfd.JointDistributionSequentialAutoBatched([
      tfd.Normal(loc=-1., scale=2., name='z'),
      lambda z: tfd.Normal(loc=[0., 0., 0.], scale=tf.exp(z), name='x'),
      lambda x: tfd.Poisson(log_rate=x, name='y')])
  pinned_model = tfp.experimental.distributions.JointDistributionPinned(
      model_with_funnel, y=[1, 3, 0])
  ```

  We can approximate the posterior in this model using a structured variational
  surrogate distribution, which will capture the funnel geometry, but cannot
  exactly represent the (non-Gaussian) posterior.

  ```python
  # Build and fit a structured surrogate posterior distribution.
  surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
    pinned_model)
  _ = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                     surrogate_posterior=surrogate_posterior,
                                     optimizer=tf.optimizers.Adam(0.01),
                                     num_steps=200)
  ```

  Creating a preconditioning bijector allows us to obtain higher-quality
  posterior samples, without any Gaussianity assumption, by using the surrogate
  to guide an MCMC sampler.

  ```python
  surrogate_posterior_bijector = (
    tfp.experimental.bijectors.make_distribution_bijector(surrogate_posterior))
  samples, _ = tfp.mcmc.sample_chain(
    kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
      tfp.mcmc.TransformedTransitionKernel(
        tfp.mcmc.NoUTurnSampler(pinned_model.unnormalized_log_prob,
                                step_size=0.1),
        bijector=surrogate_posterior_bijector),
      num_adaptation_steps=80),
    current_state=surrogate_posterior.sample(),
    num_burnin_steps=100,
    trace_fn=lambda _0, _1: [],
    num_results=500)
  ```

  #### Mathematical details

  The bijectors returned by this method generally follow the following
  principles, although the specific bijectors returned may vary without notice.

  Normal distributions are reparameterized by a location-scale transform.

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.Normal(loc=10., scale=5.))
  # ==> tfb.Shift(10.)(tfb.Scale(5.))

  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril))
  # ==> tfb.Shift(loc)(tfb.ScaleMatvecTriL(scale_tril))
  ```

  The distribution's `quantile` function is used, when available:

  ```python
  d = tfd.Cauchy(loc=loc, scale=scale)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Inline(forward_fn=d.quantile, inverse_fn=d.cdf)(tfb.NormalCDF())
  ```

  Otherwise, a quantile function is derived by inverting the CDF:

  ```python
  d = tfd.Gamma(concentration=alpha, rate=beta)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Invert(
  #  tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(fn=d.cdf))(
  #    tfb.NormalCDF())
  ```

  Transformed distributions are represented by chaining the
  transforming bijector with a preconditioning bijector for the base
  distribution:

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfb.Exp()(tfd.Normal(loc=10., scale=5.)))
  # ==> tfb.Exp()(tfb.Shift(10.)(tfb.Scale(5.)))
  ```

  Joint distributions are represented by a joint bijector, which converts each
  component distribution to a bijector with parameters conditioned
  on the previous variables in the model. The joint bijector's inputs and
  outputs follow the structure of the joint distribution.

  ```python
  jd = tfd.JointDistributionNamed(
      {'a': tfd.InverseGamma(concentration=2., scale=1.),
       'b': lambda a: tfd.Normal(loc=3., scale=tf.sqrt(a))})
  b = tfp.experimental.bijectors.make_distribution_bijector(jd)
  whitened_jd = tfb.Invert(b)(jd)
  x = whitened_jd.sample()
  # x <=> {'a': tfd.Normal(0., 1.).sample(), 'b': tfd.Normal(0., 1.).sample()}
  ```

  """

    with tf.name_scope(name):
        event_space_bijector = (
            distribution.experimental_default_event_space_bijector())
        if event_space_bijector is None:  # Fail if the distribution is discrete.
            raise NotImplementedError(
                'Cannot transform distribution {} to a standard normal '
                'distribution.'.format(distribution))

        # Recurse over joint distributions.
        if isinstance(distribution, joint_distribution.JointDistribution):
            return joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
                distribution,
                bijector_fn=make_distribution_bijector)

        # Recurse through transformed distributions.
        if isinstance(distribution,
                      transformed_distribution.TransformedDistribution):
            return distribution.bijector(
                make_distribution_bijector(distribution.distribution))

        # If we've annotated a specific bijector for this distribution, use that.
        if isinstance(distribution, tuple(preconditioning_bijector_fns)):
            return preconditioning_bijector_fns[type(distribution)](
                distribution)

        # Otherwise, if this distribution implements a CDF and inverse CDF, build
        # a bijector from those.
        implements_cdf = False
        implements_quantile = False
        input_spec = tf.zeros(shape=distribution.event_shape,
                              dtype=distribution.dtype)
        try:
            callable_util.get_output_spec(distribution.cdf, input_spec)
            implements_cdf = True
        except NotImplementedError:
            pass
        try:
            callable_util.get_output_spec(distribution.quantile, input_spec)
            implements_quantile = True
        except NotImplementedError:
            pass
        if implements_cdf and implements_quantile:
            # This path will only trigger for scalar distributions, since multivariate
            # distributions have non-invertible CDF and so cannot define a `quantile`.
            return tfb.Inline(forward_fn=distribution.quantile,
                              inverse_fn=distribution.cdf,
                              forward_min_event_ndims=ps.rank_from_shape(
                                  distribution.event_shape_tensor,
                                  distribution.event_shape))(tfb.NormalCDF())

        # If the events are scalar, try to invert the CDF numerically.
        if implements_cdf and tf.get_static_value(
                distribution.is_scalar_event()):
            return tfb.Invert(
                scalar_function_with_inferred_inverse.
                ScalarFunctionWithInferredInverse(
                    distribution.cdf,
                    domain_constraint_fn=(event_space_bijector)))(
                        tfb.NormalCDF())

        raise NotImplementedError('Could not automatically construct a '
                                  'bijector for distribution type '
                                  '{}; it does not implement an invertible '
                                  'CDF.'.format(distribution))
Example #13
def bijectors(draw,
              bijector_name=None,
              batch_shape=None,
              event_dim=None,
              enable_vars=False):
    """Strategy for drawing Bijectors.

  The emitted bijector may be a basic bijector or an `Invert` of a basic
  bijector, but not a compound like `Chain`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector_name: Optional Python `str`.  If given, the produced bijectors
      will all have this type.  If omitted, Hypothesis chooses one from
      the whitelist `TF2_FRIENDLY_BIJECTORS`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      bijector.  Hypothesis will pick one if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`,
      `tfp.util.TransformedVariable`}.

  Returns:
    bijectors: A strategy for drawing bijectors with the specified `batch_shape`
      (or an arbitrary one if omitted).
  """
    if bijector_name is None:
        bijector_name = draw(hps.sampled_from(TF2_FRIENDLY_BIJECTORS))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if bijector_name == 'Invert':
        underlying_name = draw(
            hps.sampled_from(sorted(set(TF2_FRIENDLY_BIJECTORS) - {'Invert'})))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars))
        return tfb.Invert(underlying, validate_args=True)
    if bijector_name == 'TransformDiagonal':
        underlying_name = draw(
            hps.sampled_from(sorted(TRANSFORM_DIAGONAL_WHITELIST)))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=(),
                      event_dim=event_dim,
                      enable_vars=enable_vars))
        return tfb.TransformDiagonal(underlying, validate_args=True)
    if bijector_name == 'Inline':
        if enable_vars:
            scale = tf.Variable(1., name='scale')
        else:
            scale = 2.
        b = tfb.AffineScalar(scale=scale)

        inline = tfb.Inline(
            forward_fn=b.forward,
            inverse_fn=b.inverse,
            forward_log_det_jacobian_fn=lambda x: b.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
                x,
                event_ndims=b.forward_min_event_ndims),
            forward_min_event_ndims=b.forward_min_event_ndims,
            is_constant_jacobian=b.is_constant_jacobian,
        )
        inline.b = b
        return inline
    if bijector_name == 'DiscreteCosineTransform':
        dct_type = draw(hps.integers(min_value=2, max_value=3))
        return tfb.DiscreteCosineTransform(validate_args=True,
                                           dct_type=dct_type)
    if bijector_name == 'PowerTransform':
        power = draw(hps.floats(min_value=0., max_value=10.))
        return tfb.PowerTransform(validate_args=True, power=power)
    if bijector_name == 'Permute':
        event_ndims = draw(hps.integers(min_value=1, max_value=2))
        axis = draw(hps.integers(min_value=-event_ndims, max_value=-1))
        # This is a permutation of dimensions within an axis.
        # (Contrast with `Transpose` below.)
        permutation = draw(hps.permutations(np.arange(event_dim)))
        return tfb.Permute(permutation, axis=axis)
    if bijector_name == 'Reshape':
        event_shape_out = draw(tfp_hps.shapes(min_ndims=1))
        # TODO(b/142135119): Wanted to draw general input and output shapes like the
        # following, but Hypothesis complained about filtering out too many things.
        # event_shape_in = draw(tfp_hps.shapes(min_ndims=1))
        # hp.assume(event_shape_out.num_elements() == event_shape_in.num_elements())
        event_shape_in = [event_shape_out.num_elements()]
        return tfb.Reshape(event_shape_out=event_shape_out,
                           event_shape_in=event_shape_in,
                           validate_args=True)
    if bijector_name == 'Transpose':
        event_ndims = draw(hps.integers(min_value=0, max_value=2))
        # This is a permutation of axes.
        # (Contrast with `Permute` above.)
        permutation = draw(hps.permutations(np.arange(event_ndims)))
        return tfb.Transpose(perm=permutation)

    bijector_params = draw(
        broadcasting_params(bijector_name,
                            batch_shape,
                            event_dim=event_dim,
                            enable_vars=enable_vars))
    ctor = getattr(tfb, bijector_name)
    return ctor(validate_args=True, **bijector_params)