Example #1
0
 def _default_event_space_bijector(self):
   """Builds the default event-space bijector for this distribution.

   `Chain` applies its members last-to-first, so an input passes through
   `Softplus`, then `Log`, then a negation (`Scale(-1)`), then a final
   `Softplus`; the final `Softplus` makes every output strictly positive.
   """
   validate = self.validate_args
   members = [
       softplus_bijector.Softplus(validate_args=validate),
       scale_bijector.Scale(scale=-1., validate_args=validate),
       exp_bijector.Log(validate_args=validate),
       softplus_bijector.Softplus(validate_args=validate),
   ]
   return chain_bijector.Chain(members, validate_args=validate)
 def _transformed_beta(self,
                       low=None,
                       peak=None,
                       high=None,
                       temperature=None):
   """Builds the affinely transformed `Beta` that realizes this distribution.

   Any argument left as `None` is taken from the same-named attribute of
   `self` and converted with `tf.convert_to_tensor`. Explicitly passed values
   are used as-is.

   Args:
     low: Optional lower bound of the support.
     peak: Optional peak parameter.
     high: Optional upper bound of the support.
     temperature: Optional temperature parameter.

   Returns:
     A `TransformedDistribution` of
     `Beta(concentration1=1 + temperature * (peak - low) / (high - low),
           concentration0=1 + temperature * (high - peak) / (high - low))`
     pushed through the map `x -> low + (high - low) * x`.
   """
   def _or_attr(value, attr_name):
     # Only read `self.<attr_name>` when no override was supplied.
     if value is not None:
       return value
     return tf.convert_to_tensor(getattr(self, attr_name))

   low = _or_attr(low, 'low')
   peak = _or_attr(peak, 'peak')
   high = _or_attr(high, 'high')
   temperature = _or_attr(temperature, 'temperature')

   width = high - low
   conc1 = 1. + temperature * (peak - low) / width
   conc0 = 1. + temperature * (high - peak) / width
   base = beta.Beta(concentration1=conc1,
                    concentration0=conc0,
                    allow_nan_stats=self.allow_nan_stats)

   translate = shift_bijector.Shift(shift=low)
   # `conc1` already combines all four parameters. Broadcasting the width to
   # its shape gives the affine map the full batch shape, which avoids
   # dimension mismatches in ops such as `cdf`.
   stretch = scale_bijector.Scale(
       scale=tf.broadcast_to(width, ps.shape(conc1)))
   # `Chain` applies its members last-to-first: stretch, then translate.
   return transformed_distribution.TransformedDistribution(
       distribution=base,
       bijector=chain_bijector.Chain([translate, stretch]))
Example #3
0
    def __init__(self,
                 concentration,
                 scale=None,
                 log_scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='ExpInverseGamma'):
        """Construct ExpInverseGamma with `concentration` and `scale` parameters.

        The parameters `concentration` and `scale` (or `log_scale`) must be
        shaped in a way that supports broadcasting (e.g.
        `concentration + scale` is a valid operation).

        Internally this is an `ExpGamma(concentration, rate=scale,
        log_rate=log_scale)` pushed through the negation `x -> -x`.

        Args:
          concentration: Floating point tensor, the concentration params of
            the distribution(s). Must contain only positive values.
          scale: Floating point tensor, the scale params of the
            distribution(s). Must contain only positive values. Mutually
            exclusive with `log_scale`.
          log_scale: Floating point tensor, the natural logarithm of the scale
            params of the distribution(s). Mutually exclusive with `scale`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `concentration`, `scale`, or `log_scale` are different
            dtypes.
        """
        # Must be the first statement so that it captures only the arguments.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([concentration, scale, log_scale],
                                            dtype_hint=tf.float32)
            concentration, scale, log_scale = [
                tensor_util.convert_nonref_to_tensor(value,
                                                     dtype=dtype,
                                                     name=label)
                for value, label in ((concentration, 'concentration'),
                                     (scale, 'scale'),
                                     (log_scale, 'log_scale'))
            ]
            # NOTE(review): presumably ExpGamma is the law of log(G) for
            # G ~ Gamma(concentration, rate=scale). Negating it then gives
            # log(1/G), the log of an InverseGamma(concentration, scale)
            # variate. Confirm.
            negate = scale_bijector.Scale(scale=-tf.ones([], dtype=dtype))
            base = ExpGamma(concentration=concentration,
                            rate=scale,
                            log_rate=log_scale,
                            validate_args=validate_args,
                            allow_nan_stats=allow_nan_stats)
            super(ExpInverseGamma, self).__init__(
                distribution=base,
                bijector=negate,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
Example #4
0
 def _bijector_fn(x0, input_depth, **condition_kwargs):
   """Returns the affine bijector produced by `shift_and_log_scale_fn`.

   `shift_and_log_scale_fn` is not a parameter here; it is resolved from the
   surrounding scope. Either of its two outputs may be `None`, in which case
   the corresponding bijector is left out of the chain.
   """
   offset, log_multiplier = shift_and_log_scale_fn(x0, input_depth,
                                                   **condition_kwargs)
   parts = [shift_lib.Shift(offset)] if offset is not None else []
   if log_multiplier is not None:
     parts += [scale_lib.Scale(log_scale=log_multiplier)]
   # `Chain` applies its members last-to-first, so any scaling happens before
   # any shift.
   return chain_lib.Chain(parts)
Example #5
0
 def _default_event_space_bijector(self):
   """Maps unconstrained reals into `(low, high)`: sigmoid, scale, then shift.

   Both the offset `low` and the width `high - low` are wrapped in
   `DeferredTensor`s rather than being evaluated once here.
   """
   offset = tfp_util.DeferredTensor(self.low, lambda lo: lo)
   width = tfp_util.DeferredTensor(self.high, lambda hi: hi - self.low)
   validate = self.validate_args
   # `Chain` applies its members last-to-first.
   pieces = [
       shift_bijector.Shift(shift=offset, validate_args=validate),
       scale_bijector.Scale(scale=width, validate_args=validate),
       sigmoid_bijector.Sigmoid(validate_args=validate),
   ]
   return chain_bijector.Chain(pieces, validate_args=validate)
Example #6
0
 def _default_event_space_bijector(self):
   """Maps unconstrained reals into `(low, high)`: sigmoid, scale, then shift.

   If `low` or `high` is a reference (per `tensor_util.is_ref`; presumably a
   variable), the width is wrapped in a `DeferredTensor` instead of being
   computed eagerly.
   """
   has_ref = tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high)
   width = (DeferredTensor(self.high, lambda hi: hi - self.low)
            if has_ref else self.high - self.low)
   validate = self.validate_args
   # `Chain` applies its members last-to-first.
   return chain_bijector.Chain([
       shift_bijector.Shift(shift=self.low, validate_args=validate),
       scale_bijector.Scale(scale=width, validate_args=validate),
       sigmoid_bijector.Sigmoid(validate_args=validate),
   ], validate_args=validate)
 def _default_event_space_bijector(self):
   """Maps unconstrained reals onto the open interval `(-pi, pi)`.

   The forward map is `x -> 2 * pi * sigmoid(x) - pi`, because `Chain`
   applies its members last-to-first.
   """
   # TODO(b/145620027) Finalize choice of bijector.
   validate = self.validate_args
   recenter = shift_bijector.Shift(shift=-np.pi, validate_args=validate)
   stretch = scale_bijector.Scale(scale=2. * np.pi, validate_args=validate)
   squash = sigmoid_bijector.Sigmoid(validate_args=validate)
   return chain_bijector.Chain([recenter, stretch, squash],
                               validate_args=validate)
    def __init__(self,
                 shift,
                 scale,
                 tailweight,
                 validate_args=False,
                 name="lambertw_tail"):
        """Construct a location scale heavy-tail Lambert W bijector.

        The parameters `shift`, `scale`, and `tailweight` must be shaped in a
        way that supports broadcasting (e.g. `shift + scale + tailweight` is a
        valid operation).

        The forward map standardizes the input with `(x - shift) / scale`. It
        then applies the heavy-tail transform controlled by `tailweight`, and
        finally maps back with `shift + scale * h`.

        Args:
          shift: Floating point tensor; the shift for centering (uncentering)
            the input (output) random variable(s).
          scale: Floating point tensor; the scaling (unscaling) of the input
            (output) random variable(s). Must contain only positive values.
          tailweight: Floating point tensor; the tail behaviors of the output
            random variable(s). Must contain only non-negative values.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `shift` and `scale` and `tailweight` have different
            `dtype`.
        """
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name="tailweight", dtype=dtype)
            self._shift = tensor_util.convert_nonref_to_tensor(shift,
                                                               name="shift",
                                                               dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name="scale",
                                                               dtype=dtype)
            dtype_util.assert_same_float_dtype(
                (self._tailweight, self._shift, self._scale))

            # x -> shift + scale * x (`Chain` applies its list last-to-first).
            self._shift_and_scale = chain.Chain(
                [tfb_shift.Shift(self._shift),
                 tfb_scale.Scale(self._scale)])
            # 'bijectors' argument in tfb.Chain super class are executed in
            # reverse(!) order. Hence the ordering in the list must be
            # (3,2,1), not (1,2,3).
            super(LambertWTail, self).__init__(
                bijectors=[
                    self._shift_and_scale,
                    _HeavyTailOnly(tailweight=self._tailweight),
                    invert.Invert(self._shift_and_scale),
                ],
                validate_args=validate_args,
                # Forward `name` so the documented argument actually names
                # this bijector instead of being silently dropped.
                name=name)
 def _negative_concentration_bijector(self):
   """Maps unconstrained reals onto a bounded interval that starts at `loc`.

   The forward map is
   `x -> loc - scale * (1 / concentration) * (1 - 1e-6) * sigmoid(x)`,
   because `Chain` applies its members last-to-first. Presumably this is used
   when `concentration < 0`, where that multiplier is positive; confirm
   against the caller.
   """
   # Constructed dynamically so that `scale * reciprocal(concentration)` is
   # tape-safe.
   validate = self.validate_args
   to_loc = shift_bijector.Shift(shift=self.loc, validate_args=validate)
   # TODO(b/146568897): Resolve numerical issues by implementing a new
   # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
   span = self.scale * tf.math.reciprocal(self.concentration) * (1. - 1e-6)
   stretch = scale_bijector.Scale(scale=-span, validate_args=validate)
   squash = sigmoid_bijector.Sigmoid(validate_args=validate)
   return chain_bijector.Chain([to_loc, stretch, squash],
                               validate_args=validate)
Example #10
0
 def _default_event_space_bijector(self):
   """Maps unconstrained reals into `(low, high)` through a sigmoid.

   The width `high - low` is multiplied by `1 - 1e-6`. Presumably this keeps
   outputs numerically away from `high`; see the TODO.
   """
   # TODO(b/146568897): Resolve numerical issues by implementing a new bijector
   # instead of multiplying `scale` by `(1. - 1e-6)`.
   shrink = 1. - 1e-6
   if not (tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high)):
     width = (self.high - self.low) * shrink
   else:
     # Defer the width computation (via `DeferredTensor`) instead of
     # evaluating it once here.
     width = DeferredTensor(
         self.high,
         lambda hi: (hi - self.low) * shrink,
         shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
   validate = self.validate_args
   # `Chain` applies its members last-to-first.
   return chain_bijector.Chain([
       shift_bijector.Shift(shift=self.low, validate_args=validate),
       scale_bijector.Scale(scale=width, validate_args=validate),
       sigmoid_bijector.Sigmoid(validate_args=validate),
   ], validate_args=validate)
Example #11
0
  def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
    """Creates an activation-normalization bijector.

    Args:
      nchan: Passed to `tf.zeros` / `tf.ones` as the shape of the shift and
        scale variables (presumably the channel count).
      dtype: dtype of the shift and scale variables.
      validate_args: Python `bool`, forwarded to the base class.
      name: Optional name; defaults to `'ActivationNormalization'`.
    """
    # Must be the first statement so that it captures only the arguments.
    parameters = dict(locals())

    # Non-trainable flag; presumably flipped once the shift/scale have been
    # initialized from data. Confirm with the rest of the class.
    self._initialized = tf.Variable(False, trainable=False)
    # Shift starts at zero. Scale starts at one and is stored through an
    # `Exp` transform.
    self._m = tf.Variable(tf.zeros(nchan, dtype))
    self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())

    # The inner chain maps x -> s * (x + m), since `Chain` applies its list
    # last-to-first. The stored bijector is its inverse.
    shift_then_scale = chain.Chain([
        scale.Scale(self._s),
        shift.Shift(self._m),
    ])
    self._bijector = invert.Invert(shift_then_scale)

    super(ActivationNormalization, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        parameters=parameters,
        name=name or 'ActivationNormalization')
def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones.

  * `HalfNormal(scale)` becomes `TruncatedNormal(loc=0, scale=scale, low=0,
    high=10 * scale)`.
  * `Uniform(low, high)` becomes `Beta(1, 1)`, stretched by `high - low` and
    shifted by `low`, with one `concentration0` entry per element of
    `distribution.mean()`.
  * Any other distribution is returned unchanged.

  Args:
    distribution: The prior distribution to substitute.

  Returns:
    The substitute distribution, or `distribution` itself when no rule
    applies.
  """
  with tf.name_scope('as_trainable_family'):
    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    if isinstance(distribution, uniform.Uniform):
      # Size `concentration0` like `distribution.mean()` so that every batch
      # member gets its own concentration. `Uniform` is a scalar
      # distribution, so sizing it from `event_shape_tensor()` would give a
      # single concentration shared across the whole batch.
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(
              beta.Beta(concentration0=tf.ones_like(distribution.mean()),
                        concentration1=1.)))
    return distribution
Example #13
0
        def bijector_fn(inputs, ignored_input):
            """Decorated function to get the RealNVP bijector.

            `layer` may return either a bijector, which is used as-is, or a
            tensor, which is turned into a coupling bijector:

            * twice as many channels as `inputs`: affine coupling. The first
              half (plus 2) goes through `scale_fn` to give the scale, and the
              second half is the shift.
            * the same number of channels: additive coupling (shift only).

            Args:
              inputs: Tensor passed to `layer`; its last axis is treated as
                the channel axis.
              ignored_input: Unused.

            Returns:
              A `bijector.Bijector`.

            Raises:
              ValueError: if `layer` returns something other than a bijector
                or a `tf.Tensor`, or a tensor whose shape is inconsistent with
                `inputs`.
            """
            possible_output = layer(inputs)

            # Case 1: the layer already produced a bijector.
            if isinstance(possible_output, bijector.Bijector):
                return possible_output

            if not isinstance(possible_output, tf.Tensor):
                raise ValueError(
                    'Expected a bijector or a tensor, but instead got '
                    '{}'.format(possible_output.__class__))

            # Case 2: the layer produced a tensor; build the bijector from it.
            input_shape = inputs.get_shape().as_list()
            output_shape = possible_output.get_shape().as_list()
            c = input_shape[-1]
            out_c = output_shape[-1]

            def _shape_error():
                # Build the expected affine shape with plain Python lists so
                # the message also works when a dimension is unknown (None).
                expected_affine = input_shape[:-1] + [
                    None if c is None else 2 * c]
                return ValueError(
                    'Shape inconsistent with input. Expected shape '
                    '{0} or {1} but tensor was shape {2}'.format(
                        input_shape, expected_affine, output_shape))

            # Leading (non-channel) dimensions must match. Use an explicit
            # raise; an `assert` would be stripped under `python -O`.
            if input_shape[:-1] != output_shape[:-1]:
                raise _shape_error()

            # Compare with `2 * c` rather than `out_c // 2` so that an odd
            # channel count (2c + 1) is rejected instead of producing a
            # c + 1 channel shift.
            if c is not None and out_c == 2 * c:
                # Affine coupling: a scale followed by a shift.
                this_scale = scale.Scale(
                    scale_fn(possible_output[..., :c] + 2.))
                this_shift = shift.Shift(possible_output[..., c:])
                return this_shift(this_scale)
            if out_c == c:
                # Additive coupling: just a shift.
                return shift.Shift(possible_output[..., :c])
            raise _shape_error()
Example #14
0
    def __init__(self,
                 skewness,
                 tailweight,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    location `loc`, and scale `scale`. A sample is
    `loc + scale * sinh((z - skewness) / tailweight)` with `z` drawn from a
    standard Normal.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The location(s) of the distribution(s).
        This equals the mean only when `skewness` is zero.
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of skewness, tailweight, loc and scale are different
        dtypes.
    """
        # Must be the first statement so that it captures only the arguments.
        parameters = dict(locals())
        with tf.name_scope(name or 'JohnsonSU') as name:
            dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                            tf.float32)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)

            # The five pieces below compose (innermost first) to
            #   z -> loc + scale * sinh((z - skewness) / tailweight).

            # z -> z - skewness
            norm_shift = invert_bijector.Invert(
                shift_bijector.Shift(shift=self._skewness,
                                     validate_args=validate_args))

            # -> / tailweight
            norm_scale = invert_bijector.Invert(
                scale_bijector.Scale(scale=self._tailweight,
                                     validate_args=validate_args))

            # -> sinh(.)
            sinh = sinh_bijector.Sinh(validate_args=validate_args)

            # -> scale * (.)  (rebinds the name `scale`; the parameter tensor
            # is already stored in `self._scale`)
            scale = scale_bijector.Scale(scale=self._scale,
                                         validate_args=validate_args)

            # -> loc + (.)
            shift = shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args)

            # Calling a bijector on another composes them; the argument is
            # applied first, so `norm_shift` runs first and `shift` runs last.
            bijector = shift(scale(sinh(norm_scale(norm_shift))))

            # Largest rank among the four parameters. The base Normal's `loc`
            # is given this rank (every dimension of size 1).
            batch_rank = ps.reduce_max([
                distribution_util.prefer_static_rank(x)
                for x in (self._skewness, self._tailweight, self._loc,
                          self._scale)
            ])

            super(JohnsonSU, self).__init__(
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=normal.Normal(loc=tf.zeros(ps.ones(
                    batch_rank, tf.int32),
                                                        dtype=dtype),
                                           scale=tf.ones([], dtype=dtype),
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats),
                bijector=bijector,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
Example #15
0
    global ASVI_SURROGATE_SUBSTITUTIONS
    if inspect.isclass(condition):
        condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
            distribution, cls)
    ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization.
# pylint: disable=g-long-lambda

# HalfNormal(scale) -> TruncatedNormal(0, scale) restricted to [0, 10 * scale].
register_asvi_substitution_rule(
    half_normal.HalfNormal,
    lambda dist: truncated_normal.TruncatedNormal(
        loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))

# Uniform(low, high) -> Beta(1, 1) stretched by `high - low` and shifted by
# `low`, with one `concentration0` entry per element of `dist.mean()`.
register_asvi_substitution_rule(
    uniform.Uniform,
    lambda dist: shift.Shift(dist.low)(
        scale_lib.Scale(dist.high - dist.low)(
            beta.Beta(concentration0=tf.ones_like(dist.mean()),
                      concentration1=1.))))

# Exponential(rate) -> Gamma(concentration=1, rate).
register_asvi_substitution_rule(
    exponential.Exponential,
    lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))

# Chi2(df) -> Gamma(concentration=df / 2, rate=1 / 2).
register_asvi_substitution_rule(
    chi2.Chi2,
    lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))

# pylint: enable=g-long-lambda


# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.
def build_asvi_surrogate_posterior(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5,
                                   seed=None,
Example #16
0
  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name='SinhArcsinh'):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight)
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must be the first statement so that it captures only the arguments.
    parameters = dict(locals())

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                      tf.float32)
      self._loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      self._scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if has_default_skewness else skewness
      self._tailweight = tensor_util.convert_nonref_to_tensor(
          tailweight, name='tailweight', dtype=dtype)
      self._skewness = tensor_util.convert_nonref_to_tensor(
          skewness, name='skewness', dtype=dtype)

      # Recall, with Z a random variable,
      #   Y := loc + scale * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
      #   C := 2 / F_0(2)
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      if distribution is None:
        # Default base: a standard Normal carrying the broadcast batch shape
        # of all four parameters.
        batch_shape = functools.reduce(
            ps.broadcast_shape,
            [ps.shape(x)
             for x in (self._skewness, self._tailweight,
                       self._loc, self._scale)])

        distribution = normal.Normal(
            loc=tf.zeros(batch_shape, dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats,
            validate_args=validate_args)

      # Make the SAS bijector, 'F'.
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=self._skewness, tailweight=self._tailweight,
          validate_args=validate_args)

      # Make the affine bijector Z --> loc + scale * Z. It applies no `C`
      # factor itself. NOTE(review): per the formula above, `C` must then come
      # from `f`; confirm the SinhArcsinh bijector applies it.
      affine = shift_bijector.Shift(shift=self._loc)(
          scale_bijector.Scale(scale=self._scale))
      # `Chain` applies its members last-to-first: `f`, then `affine`.
      bijector = chain_bijector.Chain([affine, f])

      # Pass `parameters` to the base class, as the other
      # TransformedDistribution subclasses do, instead of patching
      # `self._parameters` after construction.
      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          validate_args=validate_args,
          parameters=parameters,
          name=name)
Example #17
0
    def __init__(self,
                 low=None,
                 high=None,
                 hinge_softness=None,
                 validate_args=False,
                 name='soft_clip'):
        """Instantiates the SoftClip bijector.

        Depending on which bounds are given, the forward map is:

        * both bounds: `-softplus(width - softplus(x - low)) * width /
          softplus(width) + high`, with `width = high - low`;
        * `low` only: `softplus(x - low) + low`;
        * `high` only: `-softplus(high - x) + high`;
        * neither: an empty `Chain` (no transformation).

        Here `softplus` is the `Softplus` bijector with the given
        `hinge_softness`.

        Args:
          low: Optional float `Tensor` lower bound. If `None`, the lower-bound
            constraint is omitted.
            Default value: `None`.
          high: Optional float `Tensor` upper bound. If `None`, the
            upper-bound constraint is omitted.
            Default value: `None`.
          hinge_softness: Optional nonzero float `Tensor`. Controls the
            softness of the constraint at the boundaries; values outside of
            the constraint set are mapped into intervals of width
            approximately `log(2) * hinge_softness` on the interior of each
            boundary. High softness reserves more space for values outside of
            the constraint set, leading to greater distortion of inputs
            *within* the constraint set, but improved numerical stability
            near the boundaries.
            Default value: `None` (`1.0`).
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.
        """
        # Must be the first statement so that it captures only the arguments.
        parameters = dict(locals())
        with tf.name_scope(name):
            dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                            dtype_hint=tf.float32)
            low = tensor_util.convert_nonref_to_tensor(low,
                                                       name='low',
                                                       dtype=dtype)
            high = tensor_util.convert_nonref_to_tensor(high,
                                                        name='high',
                                                        dtype=dtype)
            hinge_softness = tensor_util.convert_nonref_to_tensor(
                hinge_softness, name='hinge_softness', dtype=dtype)

            softplus_bijector = softplus.Softplus(
                hinge_softness=hinge_softness)
            negate = tf.convert_to_tensor(-1., dtype=dtype)

            # Each `components` list is read by `Chain` last-to-first, so it
            # spells the formula from the outermost operation inwards.
            components = []
            if low is not None and high is not None:
                # Support reference tensors (eg Variables) for `high` and `low`
                # by deferring all computation on them until needed.
                width = tfp_util.DeferredTensor(
                    pretransformed_input=high,
                    transform_fn=lambda high: high - low)
                negated_shrinkage_factor = tfp_util.DeferredTensor(
                    pretransformed_input=width,
                    transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
                        negate * w / softplus_bijector.forward(w),
                        dtype=dtype))

                # Implement the soft constraint from 'Mathematical Details'
                # above:
                #  softclip(x) := -softplus(width - softplus(x - low)) *
                #                        (width) / (softplus(width)) + high
                components = [
                    shift.Shift(high),
                    scale.Scale(negated_shrinkage_factor), softplus_bijector,
                    shift.Shift(width),
                    scale.Scale(negate), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif low is not None:
                # Implement a soft lower bound:
                #  softlower(x) := softplus(x - low) + low
                components = [
                    shift.Shift(low), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif high is not None:
                # Implement a soft upper bound:
                #  softupper(x) := -softplus(high - x) + high
                # Innermost first: x - high, negate (-> high - x), softplus,
                # negate, + high. The innermost shift must be by `-high`.
                # Shifting by `+high` would compute `high - softplus(-x - high)`,
                # which tends to `x + 2 * high` (not `x`) far below the bound.
                components = [
                    shift.Shift(high),
                    scale.Scale(negate), softplus_bijector,
                    scale.Scale(negate),
                    shift.Shift(tfp_util.DeferredTensor(high, lambda x: -x))
                ]

            self._low = low
            self._high = high
            self._hinge_softness = hinge_softness
            self._chain = chain.Chain(components, validate_args=validate_args)

        super(SoftClip, self).__init__(forward_min_event_ndims=0,
                                       dtype=dtype,
                                       validate_args=validate_args,
                                       parameters=parameters,
                                       is_constant_jacobian=not components,
                                       name=name)
                                       name=name)