Example 1
 def _default_event_space_bijector(self):
     return chain_bijector.Chain([
         shift_bijector.Shift(shift=self.loc,
                              validate_args=self.validate_args),
         exp_bijector.Exp(validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
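Read right to left, this chain maps an unconstrained real `x` to `loc + exp(x)`, so values land in `(loc, inf)`, the support of the distribution. A minimal sketch of the same composition using the public `tfp.bijectors` API (the `loc` value is made up):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

loc = 2.0  # hypothetical location parameter
b = tfb.Chain([tfb.Shift(shift=loc), tfb.Exp()])

x = tf.constant([-1.0, 0.0, 3.0])
y = b.forward(x)       # loc + exp(x), always greater than loc
x_back = b.inverse(y)  # log(y - loc) recovers x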
Example 2
 def _transformed_beta(self,
                       low=None,
                       peak=None,
                       high=None,
                       temperature=None):
     low = tf.convert_to_tensor(self.low) if low is None else low
     peak = tf.convert_to_tensor(self.peak) if peak is None else peak
     high = tf.convert_to_tensor(self.high) if high is None else high
     temperature = (tf.convert_to_tensor(self.temperature)
                    if temperature is None else temperature)
     scale = high - low
     concentration1 = (1. + temperature * (peak - low) / scale)
     concentration0 = (1. + temperature * (high - peak) / scale)
     return transformed_distribution.TransformedDistribution(
         distribution=beta.Beta(concentration1=concentration1,
                                concentration0=concentration0,
                                allow_nan_stats=self.allow_nan_stats),
         bijector=chain_bijector.Chain([
             shift_bijector.Shift(shift=low),
             # Broadcasting scale on affine bijector to match batch dimension.
             # This prevents dimension mismatch for operations like cdf.
             # Note that `concentration1` incorporates the broadcast of all four
             # parameters.
             scale_bijector.Scale(
                 scale=tf.broadcast_to(scale, ps.shape(concentration1)))
         ]))
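This is the PERT-style construction: a Beta with concentrations `1 + temperature * (peak - low) / scale` and `1 + temperature * (high - peak) / scale`, pushed onto `[low, high]` by `Shift(low)` after `Scale(high - low)`. A rough standalone sketch with placeholder parameter values:

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

low, peak, high, temperature = 1., 7., 11., 4.
scale = high - low
concentration1 = 1. + temperature * (peak - low) / scale
concentration0 = 1. + temperature * (high - peak) / scale

pert = tfd.TransformedDistribution(
    distribution=tfd.Beta(concentration1=concentration1,
                          concentration0=concentration0),
    bijector=tfb.Chain([tfb.Shift(low), tfb.Scale(scale)]))

samples = pert.sample(5, seed=42)   # samples lie in (low, high)
log_probs = pert.log_prob(samples)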
Example 3
    def __init__(self,
                 loc,
                 scale,
                 concentration,
                 validate_args=False,
                 name='generalized_pareto'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, concentration],
                                            dtype_hint=tf.float32)

            self._loc = tensor_util.convert_nonref_to_tensor(loc)
            self._scale = tensor_util.convert_nonref_to_tensor(scale)
            self._concentration = tensor_util.convert_nonref_to_tensor(
                concentration)
            self._non_negative_concentration_bijector = chain_bijector.Chain(
                [
                    shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args),
                    softplus_bijector.Softplus(validate_args=validate_args)
                ],
                validate_args=validate_args)
            super(GeneralizedPareto,
                  self).__init__(validate_args=validate_args,
                                 forward_min_event_ndims=0,
                                 dtype=dtype,
                                 parameters=parameters,
                                 name=name)
Example 4
    def _make_mixture_dist(self, component_logits, locs, scales):
        """Builds a mixture of quantized logistic distributions.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are `[batch_size,
        height, width, num_logistic_mix]`.
      locs: 4D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      scales: 4D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.

    Returns:
      dist: A quantized logistic mixture `tfp.distribution` over the input data.
    """
        mixture_distribution = categorical.Categorical(logits=component_logits)

        # Convert distribution parameters for pixel values in
        # `[self._low, self._high]` for use with `QuantizedDistribution`
        locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
        scales *= 0.5 * (self._high - self._low)
        logistic_dist = quantized_distribution.QuantizedDistribution(
            distribution=transformed_distribution.TransformedDistribution(
                distribution=logistic.Logistic(loc=locs, scale=scales),
                bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
            low=self._low,
            high=self._high)

        dist = mixture_same_family.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            components_distribution=independent.Independent(
                logistic_dist, reinterpreted_batch_ndims=1))
        return independent.Independent(dist, reinterpreted_batch_ndims=2)
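A self-contained sketch of the same construction with toy shapes makes the batch/event bookkeeping easier to follow; the shapes and the `[0, 255]` pixel range below are assumptions for illustration only:

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

batch, height, width, num_mix, channels = 2, 4, 4, 5, 3
component_logits = tf.random.normal([batch, height, width, num_mix])
locs = 127.5 * (tf.random.normal([batch, height, width, num_mix, channels]) + 1.)
scales = tf.math.softplus(
    tf.random.normal([batch, height, width, num_mix, channels]))

quantized = tfd.QuantizedDistribution(
    distribution=tfd.TransformedDistribution(
        distribution=tfd.Logistic(loc=locs, scale=scales),
        bijector=tfb.Shift(shift=-0.5)),
    low=0., high=255.)

mixture = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=component_logits),
    components_distribution=tfd.Independent(
        quantized, reinterpreted_batch_ndims=1))
pixels = tfd.Independent(mixture, reinterpreted_batch_ndims=2)
# pixels.batch_shape == [batch]; pixels.event_shape == [height, width, channels]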
Example 5
 def _default_event_space_bijector(self):
   # TODO(b/145620027) Finalize choice of bijector.
   deferred_scale = DeferredTensor(self.scale, lambda x: x)
   return chain_bijector.Chain([
       shift_bijector.Shift(
           shift=deferred_scale, validate_args=self.validate_args),
       softplus_bijector.Softplus(validate_args=self.validate_args)
   ], validate_args=self.validate_args)
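Wrapping `self.scale` in a `DeferredTensor` postpones reading its value until the bijector is actually used, so the chain stays correct (and tape-safe) when `scale` is a `tf.Variable` that changes after construction. A hedged sketch of the same idea with the public `tfp.util` API:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

scale = tf.Variable(1.5)  # stands in for a mutable distribution parameter
deferred_scale = tfp.util.DeferredTensor(scale, lambda x: x)

b = tfb.Chain([tfb.Shift(shift=deferred_scale), tfb.Softplus()])

print(b.forward(0.))   # softplus(0.) + 1.5
scale.assign(3.0)
print(b.forward(0.))   # softplus(0.) + 3.0; the shift tracks the variable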
Example 6
 def _bijector_fn(x0, input_depth, **condition_kwargs):
   shift, log_scale = shift_and_log_scale_fn(x0, input_depth,
                                             **condition_kwargs)
   bijectors = []
   if shift is not None:
     bijectors.append(shift_lib.Shift(shift))
   if log_scale is not None:
     bijectors.append(scale_lib.Scale(log_scale=log_scale))
   return chain_lib.Chain(bijectors)
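Since `Chain` applies its list right to left, `Chain([Shift(shift), Scale(log_scale=log_scale)])` computes `y = x * exp(log_scale) + shift`, the usual affine transform in autoregressive and coupling flows. A tiny sketch with placeholder values:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

shift_vals = tf.constant([0.5, -1.0])
log_scale_vals = tf.constant([0.1, 0.3])
affine = tfb.Chain([tfb.Shift(shift_vals), tfb.Scale(log_scale=log_scale_vals)])

x = tf.constant([1.0, 2.0])
y = affine.forward(x)   # x * exp(log_scale_vals) + shift_vals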
Example 7
 def _default_event_space_bijector(self):
     return chain_bijector.Chain([
         shift_bijector.Shift(shift=self.loc,
                              validate_args=self.validate_args),
         scale_matvec_linear_operator.ScaleMatvecLinearOperator(
             scale=self.scale, validate_args=self.validate_args),
         softplus_bijector.Softplus(validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
Example 8
        def bijector_fn(inputs, ignored_input):
            """Decorated function to get the RealNVP bijector."""
            # Build this so we can handle a user passing a NN that returns a tensor
            # OR an NN that returns a bijector
            possible_output = layer(inputs)

            # We need to produce a bijector, but we do not know if the layer has done
            # so. We are setting this up to handle 2 possibilities:
            # 1) The layer outputs a bijector --> all is good
            # 2) The layer outputs a tensor --> we need to turn it into a bijector.
            if isinstance(possible_output, bijector.Bijector):
                output = possible_output
            elif isinstance(possible_output, tf.Tensor):
                input_shape = inputs.get_shape().as_list()
                output_shape = possible_output.get_shape().as_list()
                assert input_shape[:-1] == output_shape[:-1]
                c = input_shape[-1]

                # For layers which output a tensor, we have two possibilities:
                # 1) There are twice as many output channels as inputs --> the coupling
                #    is affine, meaning there is a scale followed by a shift.
                # 2) There are an equal number of input and output channels --> the
                #    coupling is additive, meaning there is just a shift
                if input_shape[-1] == output_shape[-1] // 2:
                    this_scale = scale.Scale(
                        scale_fn(possible_output[..., :c] + 2.))
                    this_shift = shift.Shift(possible_output[..., c:])
                    output = this_shift(this_scale)
                elif input_shape[-1] == output_shape[-1]:

                    output = shift.Shift(possible_output[..., :c])
                else:
                    raise ValueError(
                        'Shape inconsistent with input. Expected shape '
                        '{0} or {1} but tensor was shape {2}'.format(
                            input_shape,
                            tf.concat(
                                [input_shape[:-1], [2 * input_shape[-1]]], 0),
                            output_shape))
            else:
                raise ValueError(
                    'Expected a bijector or a tensor, but instead got '
                    '{}'.format(possible_output.__class__))
            return output
Example 9
 def _default_event_space_bijector(self):
     # TODO(b/145620027) Finalize choice of bijector.
     return chain_bijector.Chain([
         shift_bijector.Shift(shift=-np.pi,
                              validate_args=self.validate_args),
         scale_bijector.Scale(scale=2. * np.pi,
                              validate_args=self.validate_args),
         sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
Example 10
 def _default_event_space_bijector(self):
   if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
     scale = DeferredTensor(self.high, lambda x: x - self.low)
   else:
     scale = self.high - self.low
   return chain_bijector.Chain([
       shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
       scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
       sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
   ], validate_args=self.validate_args)
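Right to left: `Sigmoid` squashes reals into `(0, 1)`, `Scale(high - low)` stretches that to `(0, high - low)`, and `Shift(low)` moves it to `(low, high)`, the support of the distribution. A brief sketch with made-up bounds:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

low, high = -2.0, 5.0
b = tfb.Chain([tfb.Shift(low), tfb.Scale(high - low), tfb.Sigmoid()])

x = tf.constant([-10.0, 0.0, 10.0])
y = b.forward(x)   # every value falls strictly inside (low, high)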
Example 11
 def _default_event_space_bijector(self):
     low = tfp_util.DeferredTensor(self.low, lambda x: x)
     scale = tfp_util.DeferredTensor(self.high, lambda x: x - self.low)
     return chain_bijector.Chain([
         shift_bijector.Shift(shift=low, validate_args=self.validate_args),
         scale_bijector.Scale(scale=scale,
                              validate_args=self.validate_args),
         sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
Example 12
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name='fill_scale_tril'):
        """Instantiates the `FillScaleTriL` bijector.

    Args:
      diag_bijector: `Bijector` instance, used to transform the output diagonal
        to be positive. Must be an instance of `tf.__internal__.CompositeTensor`
        (including `tfb.AutoCompositeTensorBijector`).
        Default value: `None` (i.e., `tfb.Softplus()`).
      diag_shift: Float value broadcastable and added to all diagonal entries
        after applying the `diag_bijector`. Setting a positive
        value forces the output diagonal entries to be positive, but
        prevents inverting the transformation for matrices with
        diagonal entries less than this value.
        Default value: `1e-5`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., arguments are not validated).
      name: Python `str` name given to ops managed by this object.
        Default value: `fill_scale_tril`.

    Raises:
      TypeError, if `diag_bijector` is not an instance of
        `tf.__internal__.CompositeTensor`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if diag_bijector is None:
                diag_bijector = softplus.Softplus(validate_args=validate_args)
            if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
                raise TypeError('`diag_bijector` must be an instance of '
                                '`tf.__internal__.CompositeTensor`.')

            if diag_shift is not None:
                dtype = dtype_util.common_dtype([diag_bijector, diag_shift],
                                                tf.float32)
                diag_shift = tensor_util.convert_nonref_to_tensor(
                    diag_shift, name='diag_shift', dtype=dtype)
                diag_bijector = chain.Chain(
                    [shift.Shift(shift=diag_shift), diag_bijector])

            super(FillScaleTriL, self).__init__([
                transform_diagonal.TransformDiagonal(
                    diag_bijector=diag_bijector),
                fill_triangular.FillTriangular()
            ],
                                                validate_args=validate_args,
                                                validate_event_size=False,
                                                parameters=parameters,
                                                name=name)
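The resulting bijector first runs `FillTriangular`, packing a length `n(n+1)/2` vector into an `n x n` lower-triangular matrix, and then `TransformDiagonal`, which applies `Shift(diag_shift)` after `Softplus` so the diagonal is strictly positive. A quick sketch of using the finished public bijector (the input vector is arbitrary):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

b = tfb.FillScaleTriL(diag_shift=1e-5)

x = tf.constant([-0.3, 1.2, 0.7])   # length 3 == 2 * (2 + 1) / 2, so a 2x2 matrix
tril = b.forward(x)
# `tril` is lower triangular with a softplus-transformed (positive) diagonal.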
Example 13
    def __init__(self,
                 shift,
                 scale,
                 tailweight,
                 validate_args=False,
                 name="lambertw_tail"):
        """Construct a location scale heavy-tail Lambert W bijector.

    The parameters `shift`, `scale`, and `tailweight` must be shaped in a way
    that supports broadcasting (e.g. `shift + scale + tailweight` is a valid
    operation).

    Args:
      shift: Floating point tensor; the shift for centering (uncentering) the
        input (output) random variable(s).
      scale: Floating point tensor; the scaling (unscaling) of the input
        (output) random variable(s). Must contain only positive values.
      tailweight: Floating point tensor; the tail behaviors of the output random
        variable(s).  Must contain only non-negative values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `shift`, `scale`, and `tailweight` have different `dtype`.
    """
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name="tailweight", dtype=dtype)
            self._shift = tensor_util.convert_nonref_to_tensor(shift,
                                                               name="shift",
                                                               dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name="scale",
                                                               dtype=dtype)
            dtype_util.assert_same_float_dtype(
                (self._tailweight, self._shift, self._scale))

            self._shift_and_scale = chain.Chain(
                [tfb_shift.Shift(self._shift),
                 tfb_scale.Scale(self._scale)])
            # The 'bijectors' passed to the tfb.Chain super class are executed in
            # reverse(!) order.  Hence the ordering in the list must be (3,2,1), not (1,2,3).
            super(LambertWTail, self).__init__(bijectors=[
                self._shift_and_scale,
                _HeavyTailOnly(tailweight=self._tailweight),
                invert.Invert(self._shift_and_scale)
            ],
                                               validate_args=validate_args)
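Because `Chain` runs right to left, the composition above first standardizes the input with `Invert(shift_and_scale)`, applies the heavy-tail transform, and then restores the original location and scale. A short usage sketch, assuming the public `tfb.LambertWTail` export with the same keyword arguments (the values are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

b = tfb.LambertWTail(shift=1.0, scale=2.0, tailweight=0.5)

x = tf.constant([0.0, 1.0, 4.0])
y = b.forward(x)   # standardize, apply the heavy-tail transform, then shift/scale back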
Example 14
 def _negative_concentration_bijector(self):
   # Constructed dynamically so that `scale * reciprocal(concentration)` is
   # tape-safe.
   return chain_bijector.Chain([
       shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
       # TODO(b/146568897): Resolve numerical issues by implementing a new
       # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
       scale_bijector.Scale(
           scale=-(self.scale *
                   tf.math.reciprocal(self.concentration) * (1. - 1e-6)),
           validate_args=self.validate_args),
       sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
   ], validate_args=self.validate_args)
Example 15
 def _default_event_space_bijector(self):
   # TODO(b/146568897): Resolve numerical issues by implementing a new bijector
   # instead of multiplying `scale` by `(1. - 1e-6)`.
   if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
     scale = DeferredTensor(
         self.high,
         lambda x: (x - self.low) * (1. - 1e-6),
         shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
   else:
     scale = (self.high - self.low) * (1. - 1e-6)
   return chain_bijector.Chain([
       shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
       scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
       sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
   ], validate_args=self.validate_args)
Example 16
  def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
    parameters = dict(locals())

    self._initialized = tf.Variable(False, trainable=False)
    self._m = tf.Variable(tf.zeros(nchan, dtype))
    self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())
    self._bijector = invert.Invert(
        chain.Chain([
            scale.Scale(self._s),
            shift.Shift(self._m),
        ]))
    super(ActivationNormalization, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        parameters=parameters,
        name=name or 'ActivationNormalization')
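`TransformedVariable(tf.ones(nchan), exp.Exp())` stores the scale in log space while always exposing a positive tensor, and `Invert(Chain([Scale, Shift]))` runs the chain's inverse in the forward direction. A small sketch of those two pieces with the public API (the channel count is arbitrary):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

nchan = 3
m = tf.Variable(tf.zeros(nchan))
s = tfp.util.TransformedVariable(tf.ones(nchan), tfb.Exp())  # positive by construction

actnorm = tfb.Invert(tfb.Chain([tfb.Scale(s), tfb.Shift(m)]))

x = tf.random.normal([4, nchan])
y = actnorm.forward(x)   # x / s - m, the inverse of the inner chain s * (x + m)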
Example 17
def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones."""
  with tf.name_scope('as_trainable_family'):

    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(beta.Beta(
              concentration0=tf.ones(
                  distribution.event_shape_tensor(), dtype=distribution.dtype),
              concentration1=1.)))
    else:
      return distribution
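Calling a bijector on a distribution returns a `TransformedDistribution`, so `shift.Shift(distribution.low)(scale_lib.Scale(...)(beta.Beta(...)))` above is a Beta stretched onto the Uniform's support. A compact sketch of the same substitution with placeholder bounds:

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

uniform = tfd.Uniform(low=-1., high=3.)

surrogate = tfb.Shift(uniform.low)(
    tfb.Scale(uniform.high - uniform.low)(
        tfd.Beta(concentration0=1., concentration1=1.)))

# Same support as `uniform`; replacing the concentrations with variables
# would make the family trainable.
print(surrogate.sample(3, seed=7))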
Example 18
def build_affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    seed: Python integer to seed the random number generator for initial values.
      Default value: `None`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').

  Returns:
    surrogate_distribution: Trainable `tfd.JointDistribution` instance.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(
            name or 'build_affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(
                batch_shape):  # Base is a classic JointDistribution.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set(nest.flatten(base_distribution.dtype))
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            seed, operators_seed = samplers.split_seed(seed)
            operators = (trainable_linear_operators.
                         build_trainable_linear_operator_block(
                             operators,
                             block_dims=flat_event_size,
                             dtype=base_dtype,
                             batch_shape=batch_shape,
                             seed=operators_seed))

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))
        loc_bijector = joint_map.JointMap(
            tf.nest.map_structure(
                lambda s, seed: shift.Shift(  # pylint: disable=g-long-lambda
                    tf.Variable(
                        initial_unconstrained_loc_fn(ps.concat(
                            [batch_shape, [s]], axis=0),
                                                     dtype=base_dtype,
                                                     seed=seed))),
                flat_event_size,
                samplers.split_seed(seed, n=len(flat_event_size))),
            validate_args=validate_args)

        unflatten_and_reshape = chain.Chain([
            joint_map.JointMap(nest.map_structure(reshape.Reshape,
                                                  event_shape),
                               validate_args=validate_args),
            restructure.Restructure(
                nest.pack_sequence_as(event_shape, range(num_components)))
        ],
                                            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Example 19
    def __init__(self,
                 skewness,
                 tailweight,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    mean `loc`, and scale `scale`.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of skewness, tailweight, loc and scale are different
        dtypes.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'JohnsonSU') as name:
            dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                            tf.float32)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)

            norm_shift = invert_bijector.Invert(
                shift_bijector.Shift(shift=self._skewness,
                                     validate_args=validate_args))

            norm_scale = invert_bijector.Invert(
                scale_bijector.Scale(scale=self._tailweight,
                                     validate_args=validate_args))

            sinh = sinh_bijector.Sinh(validate_args=validate_args)

            scale = scale_bijector.Scale(scale=self._scale,
                                         validate_args=validate_args)

            shift = shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args)

            bijector = shift(scale(sinh(norm_scale(norm_shift))))

            batch_rank = ps.reduce_max([
                distribution_util.prefer_static_rank(x)
                for x in (self._skewness, self._tailweight, self._loc,
                          self._scale)
            ])

            super(JohnsonSU, self).__init__(
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=normal.Normal(loc=tf.zeros(ps.ones(
                    batch_rank, tf.int32),
                                                        dtype=dtype),
                                           scale=tf.ones([], dtype=dtype),
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats),
                bijector=bijector,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
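Composing bijectors by calling one on another applies them right to left, so the `bijector` above maps a standard normal draw `z` to `loc + scale * sinh((z - skewness) / tailweight)`, the Johnson's SU transform. A hedged sketch that rebuilds the same chain with the public `tfp.bijectors` API and checks it against the closed form (parameter values are placeholders):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

skewness, tailweight, loc, scale = -1.0, 2.0, 3.0, 1.5

bijector = tfb.Chain([
    tfb.Shift(loc),
    tfb.Scale(scale),
    tfb.Sinh(),
    tfb.Invert(tfb.Scale(tailweight)),
    tfb.Invert(tfb.Shift(skewness)),
])

z = tf.constant([0.0, 1.0])
y = bijector.forward(z)
expected = loc + scale * tf.sinh((z - skewness) / tailweight)
# `y` and `expected` agree elementwise.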
Example 20
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MultivariateNormalLinearOperator'):
        """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       dtype=dtype,
                                                       name='loc')
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)
        self._loc = loc
        self._scale = scale

        bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
            scale, validate_args=validate_args)
        if loc is not None:
            bijector = shift_bijector.Shift(
                shift=loc, validate_args=validate_args)(bijector)

        super(MultivariateNormalLinearOperator, self).__init__(
            distribution=normal.Normal(loc=tf.zeros([], dtype=dtype),
                                       scale=tf.ones([], dtype=dtype)),
            bijector=bijector,
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
            name=name)
        self._parameters = parameters
Example 21
 def generate_shift_bijector(s):
     x = yield trainable_state_util.Parameter(
         functools.partial(initial_unconstrained_loc_fn,
                           ps.concat([batch_shape, [s]], axis=0),
                           dtype=base_dtype))
     return shift.Shift(x)
Example 22
    global ASVI_SURROGATE_SUBSTITUTIONS
    if inspect.isclass(condition):
        condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
            distribution, cls)
    ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization.
# pylint: disable=g-long-lambda
register_asvi_substitution_rule(
    half_normal.HalfNormal, lambda dist: truncated_normal.TruncatedNormal(
        loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))
register_asvi_substitution_rule(
    uniform.Uniform, lambda dist: shift.Shift(dist.low)
    (scale_lib.Scale(dist.high - dist.low)
     (beta.Beta(concentration0=tf.ones_like(dist.mean()), concentration1=1.))))
register_asvi_substitution_rule(
    exponential.Exponential,
    lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))
register_asvi_substitution_rule(
    chi2.Chi2, lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))

# pylint: enable=g-long-lambda


# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.
def build_asvi_surrogate_posterior(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5,
Example 23
  def __init__(self,
               loc=None,
               scale=None,
               validate_args=False,
               allow_nan_stats=True,
               experimental_use_kahan_sum=False,
               name='MultivariateNormalLinearOperator'):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
        summation to aggregate independent underlying log_prob values. For best
        results, Kahan summation should also be applied when computing the
        log-determinant of the `LinearOperator` representing the scale matrix.
        Kahan summation improves on the precision of a naive float32 sum.
        This can be noticeable in particular for large dimensions in float32.
        See CPU caveat on `tfp.math.reduce_kahan_sum`.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
    parameters = dict(locals())
    self._experimental_use_kahan_sum = experimental_use_kahan_sum
    if scale is None:
      raise ValueError('Missing required `scale` parameter.')
    if not dtype_util.is_floating(scale.dtype):
      raise TypeError('`scale` parameter must have floating-point dtype.')

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
      # Since expand_dims doesn't preserve constant-ness, we obtain the
      # non-dynamic value if possible.
      loc = tensor_util.convert_nonref_to_tensor(
          loc, dtype=dtype, name='loc')
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale)
    self._loc = loc
    self._scale = scale

    bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      bijector = shift_bijector.Shift(
          shift=loc, validate_args=validate_args)(bijector)
    super(MultivariateNormalLinearOperator, self).__init__(
        # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
        # shape instead of tf.zeros.
        # We use `Sample` instead of `Independent` because `Independent`
        # requires concatenating `batch_shape` and `event_shape`, which loses
        # static `batch_shape` information when `event_shape` is not statically
        # known.
        distribution=sample.Sample(
            normal.Normal(
                loc=tf.zeros(batch_shape, dtype=dtype),
                scale=tf.ones([], dtype=dtype)),
            event_shape,
            experimental_use_kahan_sum=experimental_use_kahan_sum),
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
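The constructor realizes the MVN as an affine pushforward of i.i.d. standard normals: `Sample(Normal(0, 1), event_shape)` supplies the base, `ScaleMatvecLinearOperator` multiplies by the scale operator, and `Shift(loc)` adds the mean. A stripped-down sketch with a concrete lower-triangular scale (the numbers are placeholders):

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

loc = tf.constant([1., -1.])
scale = tf.linalg.LinearOperatorLowerTriangular([[1., 0.], [0.5, 2.]])

bijector = tfb.Shift(shift=loc)(tfb.ScaleMatvecLinearOperator(scale))

mvn = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[2]),
    bijector=bijector)

# Comparable to tfd.MultivariateNormalLinearOperator(loc=loc, scale=scale).
print(mvn.sample(3, seed=1))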
Example 24
  def __init__(self,
               loc=None,
               precision_factor=None,
               precision=None,
               validate_args=False,
               allow_nan_stats=True,
               name='MultivariateNormalPrecisionFactorLinearOperator'):
    """Initialize distribution.

    Precision is the inverse of the covariance matrix, and
    `precision_factor @ precision_factor.T = precision`.

    The `batch_shape` of this distribution is the broadcast of
    `loc.shape[:-1]` and `precision_factor.batch_shape`.

    The `event_shape` of this distribution is determined by `loc.shape[-1:]`,
    OR `precision_factor.shape[-1:]`, which must match.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      precision_factor: Required nonsingular `tf.linalg.LinearOperator` instance
        with same `dtype` and shape compatible with `loc`.
      precision: Optional square `tf.linalg.LinearOperator` instance with same
        `dtype` and shape compatible with `loc` and `precision_factor`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      if precision_factor is None:
        raise ValueError(
            'Argument `precision_factor` must be provided. Found `None`')

      dtype = dtype_util.common_dtype([loc, precision_factor, precision],
                                      dtype_hint=tf.float32)
      loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')

      self._loc = loc
      self._precision_factor = precision_factor
      self._precision = precision

      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, precision_factor)

      # Proof of factors (used throughout code):
      # Let,
      #   C = covariance,
      #   P = inv(covariance) = precision
      #   P = F @ F.T  (so F is the `precision_factor`).
      #
      # Then, the log prob term is
      #  x.T @ inv(C) @ x
      #  = x.T @ P @ x
      #  = x.T @ F @ F.T @ x
      #  = || F.T @ x ||**2
      # notice it involves F.T, which is why we set adjoint=True in various
      # places.
      #
      # Also, if w ~ Normal(0, I), then we can sample by setting
      #  x = inv(F.T) @ w + loc,
      # since then
      #  E[(x - loc) @ (x - loc).T]
      #  = E[inv(F.T) @ w @ w.T @ inv(F)]
      #  = inv(F.T) @ inv(F)
      #  = inv(F @ F.T)
      #  = inv(P)
      #  = C.

      if precision is not None:
        precision.shape.assert_is_compatible_with(precision_factor.shape)

      bijector = invert.Invert(
          scale_matvec_linear_operator.ScaleMatvecLinearOperator(
              scale=precision_factor,
              validate_args=validate_args,
              adjoint=True)
      )
      if loc is not None:
        shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
        bijector = shift(bijector)

      super(MultivariateNormalPrecisionFactorLinearOperator, self).__init__(
          distribution=mvn_diag.MultivariateNormalDiag(
              loc=tf.zeros(
                  ps.concat([batch_shape, event_shape], axis=0), dtype=dtype)),
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
Example 25
  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name='SinhArcsinh'):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight)
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                      tf.float32)
      self._loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      self._scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if has_default_skewness else skewness
      self._tailweight = tensor_util.convert_nonref_to_tensor(
          tailweight, name='tailweight', dtype=dtype)
      self._skewness = tensor_util.convert_nonref_to_tensor(
          skewness, name='skewness', dtype=dtype)

      # Recall, with Z a random variable,
      #   Y := loc + scale * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
      #   C := 2 / F_0(2)
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      if distribution is None:
        batch_shape = functools.reduce(
            ps.broadcast_shape,
            [ps.shape(x)
             for x in (self._skewness, self._tailweight,
                       self._loc, self._scale)])

        distribution = normal.Normal(
            loc=tf.zeros(batch_shape, dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats,
            validate_args=validate_args)

      # Make the SAS bijector, 'F'.
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=self._skewness, tailweight=self._tailweight,
          validate_args=validate_args)

      # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
      affine = shift_bijector.Shift(shift=self._loc)(
          scale_bijector.Scale(scale=self._scale))
      bijector = chain_bijector.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
Example 26
    def __init__(self,
                 low=None,
                 high=None,
                 hinge_softness=None,
                 validate_args=False,
                 name='soft_clip'):
        """Instantiates the SoftClip bijector.

    Args:
      low: Optional float `Tensor` lower bound. If `None`, the lower-bound
        constraint is omitted.
        Default value: `None`.
      high: Optional float `Tensor` upper bound. If `None`, the upper-bound
        constraint is omitted.
        Default value: `None`.
      hinge_softness: Optional nonzero float `Tensor`. Controls the softness
        of the constraint at the boundaries; values outside of the constraint
        set are mapped into intervals of width approximately
        `log(2) * hinge_softness` on the interior of each boundary. High
        softness reserves more space for values outside of the constraint set,
        leading to greater distortion of inputs *within* the constraint set,
        but improved numerical stability near the boundaries.
        Default value: `None` (`1.0`).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
        parameters = dict(locals())
        with tf.name_scope(name):
            dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                            dtype_hint=tf.float32)
            low = tensor_util.convert_nonref_to_tensor(low,
                                                       name='low',
                                                       dtype=dtype)
            high = tensor_util.convert_nonref_to_tensor(high,
                                                        name='high',
                                                        dtype=dtype)
            hinge_softness = tensor_util.convert_nonref_to_tensor(
                hinge_softness, name='hinge_softness', dtype=dtype)

            softplus_bijector = softplus.Softplus(
                hinge_softness=hinge_softness)
            negate = tf.convert_to_tensor(-1., dtype=dtype)

            components = []
            if low is not None and high is not None:
                # Support reference tensors (eg Variables) for `high` and `low` by
                # deferring all computation on them until needed.
                width = tfp_util.DeferredTensor(
                    pretransformed_input=high,
                    transform_fn=lambda high: high - low)
                negated_shrinkage_factor = tfp_util.DeferredTensor(
                    pretransformed_input=width,
                    transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
                        negate * w / softplus_bijector.forward(w),
                        dtype=dtype))

                # Implement the soft constraint from 'Mathematical Details' above:
                #  softclip(x) := -softplus(width - softplus(x - low)) *
                #                        (width) / (softplus(width)) + high
                components = [
                    shift.Shift(high),
                    scale.Scale(negated_shrinkage_factor), softplus_bijector,
                    shift.Shift(width),
                    scale.Scale(negate), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif low is not None:
                # Implement a soft lower bound:
                #  softlower(x) := softplus(x - low) + low
                components = [
                    shift.Shift(low), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif high is not None:
                # Implement a soft upper bound:
                #  softupper(x) := -softplus(high - x) + high
                components = [
                    shift.Shift(high),
                    scale.Scale(negate), softplus_bijector,
                    scale.Scale(negate),
                    shift.Shift(high)
                ]

            self._low = low
            self._high = high
            self._hinge_softness = hinge_softness
            self._chain = chain.Chain(components, validate_args=validate_args)

        super(SoftClip, self).__init__(forward_min_event_ndims=0,
                                       dtype=dtype,
                                       validate_args=validate_args,
                                       parameters=parameters,
                                       is_constant_jacobian=not components,
                                       name=name)
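In practice the finished bijector is used directly: it maps reals smoothly into `(low, high)`, acting approximately as the identity well inside the interval. A short usage sketch (the bounds and softness are made up):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

clip = tfb.SoftClip(low=-1., high=2., hinge_softness=0.5)

x = tf.constant([-10., 0., 10.])
y = clip.forward(x)      # values squashed smoothly into (-1, 2)
x_back = clip.inverse(y)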
Example 27
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VectorExponentialLinearOperator'):
        """Construct Vector Exponential distribution supported on a subset of `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if loc is None:
            loc = 0.0  # Implicit value for backwards compatibility.
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = loc if loc is None else tf.convert_to_tensor(
                loc, name='loc', dtype=scale.dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)
            self._loc = loc
            self._scale = scale
            super(VectorExponentialLinearOperator, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set the
                # batch shape instead of tf.ones.
                # We use `Sample` instead of `Independent` because `Independent`
                # requires concatenating `batch_shape` and `event_shape`, which loses
                # static `batch_shape` information when `event_shape` is not
                # statically known.
                distribution=sample.Sample(
                    exponential.Exponential(rate=tf.ones(batch_shape,
                                                         dtype=scale.dtype),
                                            allow_nan_stats=allow_nan_stats),
                    event_shape),
                bijector=shift_bijector.Shift(shift=loc)(
                    scale_matvec_linear_operator.ScaleMatvecLinearOperator(
                        scale=scale, validate_args=validate_args)),
                validate_args=validate_args,
                name=name)
            self._parameters = parameters