Example #1
 def _bijector_fn(x, **condition_kwargs):
   params = shift_and_log_scale_fn(x, **condition_kwargs)
   if tf.is_tensor(params):
     shift, log_scale = tf.unstack(params, num=2, axis=-1)
   else:
     shift, log_scale = params
   return affine_scalar.AffineScalar(shift=shift, log_scale=log_scale)
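A minimal usage sketch of the bijector this `_bijector_fn` returns, assuming TensorFlow Probability still exposes the since-deprecated `AffineScalar` under `tfp.bijectors`; the constants are made up:

    import tensorflow as tf
    import tensorflow_probability as tfp

    # With `log_scale`, AffineScalar computes Y = shift + exp(log_scale) * X,
    # so the multiplier is always positive.
    bij = tfp.bijectors.AffineScalar(shift=1.0, log_scale=0.5)
    y = bij.forward(tf.constant([0.0, 1.0, 2.0]))
    x = bij.inverse(y)  # recovers the original inputs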
Example #2
 def _transformed_beta(self,
                       low=None,
                       peak=None,
                       high=None,
                       temperature=None):
     low = tf.convert_to_tensor(self.low) if low is None else low
     peak = tf.convert_to_tensor(self.peak) if peak is None else peak
     high = tf.convert_to_tensor(self.high) if high is None else high
     temperature = (tf.convert_to_tensor(self.temperature)
                    if temperature is None else temperature)
     scale = high - low
     concentration1 = (1. + temperature * (peak - low) / scale)
     concentration0 = (1. + temperature * (high - peak) / scale)
     return transformed_distribution.TransformedDistribution(
         distribution=beta.Beta(concentration1=concentration1,
                                concentration0=concentration0,
                                allow_nan_stats=self.allow_nan_stats),
         bijector=affine_scalar.AffineScalar(
             shift=low,
             # Broadcasting scale on affine bijector to match batch dimension.
             # This prevents dimension mismatch for operations like cdf.
             # Note that `concentration1` incorporates the broadcast of all four
             # parameters.
             scale=tf.broadcast_to(scale,
                                   prefer_static.shape(concentration1))))
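For context, a self-contained sketch of the same transformed-Beta (PERT-style) construction written against the public TFP names; the four parameter values are made up, and `Shift`/`Scale` stand in here for the internal `AffineScalar`:

    import tensorflow_probability as tfp

    tfd, tfb = tfp.distributions, tfp.bijectors
    low, peak, high, temperature = 0., 3., 10., 4.

    scale = high - low
    concentration1 = 1. + temperature * (peak - low) / scale
    concentration0 = 1. + temperature * (high - peak) / scale

    # Beta on [0, 1], affinely mapped onto [low, high]: x -> low + scale * x.
    pert = tfd.TransformedDistribution(
        distribution=tfd.Beta(concentration1=concentration1,
                              concentration0=concentration0),
        bijector=tfb.Chain([tfb.Shift(low), tfb.Scale(scale)]))
    samples = pert.sample(3)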
Example #3
 def _bijector_fn(x, **condition_kwargs):
     if conditioning is not None:
         print(x, conditioning)
         x = tf.concat([conditioning, x], axis=-1)
         cond_depth = tf.compat.dimension_value(
             tensorshape_util.with_rank_at_least(
                 conditioning.shape, 1)[-1])
     else:
         cond_depth = 0
     params = shift_and_log_scale_fn(x, **condition_kwargs)
     if tf.is_tensor(params):
         shift, log_scale = tf.unstack(params, num=2, axis=-1)
     else:
         shift, log_scale = params
     shift = shift[..., cond_depth:]
     log_scale = log_scale[..., cond_depth:]
     return affine_scalar.AffineScalar(shift=shift,
                                       log_scale=log_scale)
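A small shape sketch (made-up sizes) of the slicing above: the conditioning tensor is prepended to `x`, and the output slots that correspond to those conditioning inputs are dropped again via `[..., cond_depth:]`:

    import tensorflow as tf

    conditioning = tf.zeros([8, 3])                    # batch 8, conditioning depth 3
    x = tf.zeros([8, 2])                               # event depth 2
    net_input = tf.concat([conditioning, x], axis=-1)  # shape [8, 5]
    cond_depth = conditioning.shape[-1]                # 3
    full_shift = tf.zeros([8, 5])                      # stand-in for the network output
    shift = full_shift[..., cond_depth:]               # shape [8, 2], aligned with `x`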
Example #4
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name="scale_tril"):
        """Instantiates the `ScaleTriL` bijector.

    Args:
      diag_bijector: `Bijector` instance, used to transform the output diagonal
        to be positive.
        Default value: `None` (i.e., `tfb.Softplus()`).
      diag_shift: Float value broadcastable and added to all diagonal entries
        after applying the `diag_bijector`. Setting a positive
        value forces the output diagonal entries to be positive, but
        prevents inverting the transformation for matrices with
        diagonal entries less than this value.
        Default value: `1e-5`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., arguments are not validated).
      name: Python `str` name given to ops managed by this object.
        Default value: `scale_tril`.
    """
        with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
            if diag_bijector is None:
                diag_bijector = softplus.Softplus(validate_args=validate_args)

            if diag_shift is not None:
                diag_shift = tf.convert_to_tensor(value=diag_shift,
                                                  dtype=diag_bijector.dtype,
                                                  name="diag_shift")
                diag_bijector = chain.Chain([
                    affine_scalar.AffineScalar(shift=diag_shift), diag_bijector
                ])

            super(ScaleTriL, self).__init__([
                transform_diagonal.TransformDiagonal(
                    diag_bijector=diag_bijector),
                fill_triangular.FillTriangular()
            ],
                                            validate_args=validate_args,
                                            name=name)
            self._use_tf_function = False  # So input bijectors cache EagerTensors.
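A usage sketch for the finished bijector, assuming the public `tfp.bijectors.ScaleTriL`; the input vector is arbitrary:

    import tensorflow as tf
    import tensorflow_probability as tfp

    b = tfp.bijectors.ScaleTriL(diag_shift=1e-5)
    flat = tf.constant([0.3, -0.1, 0.7])  # n*(n+1)/2 = 3 entries for a 2x2 matrix
    tril = b.forward(flat)                # lower-triangular with positive diagonal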
Example #5
 def _bijector_fn(x0, input_depth, **condition_kwargs):
     shift, log_scale = shift_and_log_scale_fn(
         x0, input_depth, **condition_kwargs)
     return affine_scalar.AffineScalar(shift=shift,
                                       log_scale=log_scale)
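A sketch of how a shift/log-scale callable like the one above is typically consumed by `tfp.bijectors.RealNVP`; the zero-output stand-in network and the sizes are made up, so the resulting flow is just the identity:

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfb = tfp.bijectors

    def shift_and_log_scale_fn(x0, output_units, **kwargs):
        # Untrained stand-in: zero shift and zero log-scale. A real model would
        # compute these from a neural network applied to `x0`.
        shape = tf.concat([tf.shape(x0)[:-1], [output_units]], axis=0)
        return tf.zeros(shape), tf.zeros(shape)

    nvp = tfb.RealNVP(num_masked=2,
                      shift_and_log_scale_fn=shift_and_log_scale_fn)
    y = nvp.forward(tf.zeros([8, 4]))  # first 2 dims pass through unchanged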
Example #6
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="SinhArcsinh"):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tfd.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.compat.v2.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale,
                                         name="scale",
                                         dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if skewness is None else skewness
            tailweight = tf.convert_to_tensor(value=tailweight,
                                              name="tailweight",
                                              dtype=dtype)
            skewness = tf.convert_to_tensor(value=skewness,
                                            name="skewness",
                                            dtype=dtype)

            batch_shape = distribution_util.get_broadcast_shape(
                loc, scale, tailweight, skewness)

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)
            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    loc = distribution_util.with_dependencies(asserts, loc)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness,
                                                  tailweight=tailweight)
            if has_default_skewness:
                f_noskew = f
            else:
                f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
                    skewness=skewness.dtype.as_numpy_dtype(0.),
                    tailweight=tailweight)

            # Make the AffineScalar bijector, Z --> loc + scale * (2 / F_0(2)) * Z
            c = 2 * scale / f_noskew.forward(
                tf.convert_to_tensor(value=2, dtype=dtype))
            affine = affine_scalar_bijector.AffineScalar(
                shift=loc, scale=c, validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              batch_shape=batch_shape,
                                              validate_args=validate_args,
                                              name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale
        self._tailweight = tailweight
        self._skewness = skewness
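A usage sketch of the finished distribution under its public name `tfp.distributions.SinhArcsinh`; parameter values are arbitrary:

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    d = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.5)
    samples = d.sample(5)
    log_probs = d.log_prob(samples)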
Example #7
 def _bijector_fn(x0, input_depth, **condition_kwargs):
     shift, log_scale = shift_and_log_scale_fn(
         x0, input_depth, **condition_kwargs)
     # ** First modification is here.
     return affine_scalar.AffineScalar(shift=shift, scale=log_scale)
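Note that this variant passes the raw network output as `scale` rather than `log_scale`. With `log_scale`, `AffineScalar` exponentiates the value (Y = shift + exp(log_scale) * X), so the multiplier is always positive; with `scale`, the output is used as the multiplier directly and may be zero or negative, which makes the transform non-invertible unless the network constrains it.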
Example #8
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='SinhArcsinh'):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if has_default_skewness else skewness
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)

            # Recall, with Z a random variable,
            #   Y := loc + scale * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
            #   C := 2 / F_0(2)
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            if distribution is None:
                batch_rank = tf.reduce_max([
                    distribution_util.prefer_static_rank(x)
                    for x in (self._skewness, self._tailweight, self._loc,
                              self._scale)
                ])
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution = normal.Normal(loc=tf.zeros(tf.ones(
                    batch_rank, tf.int32),
                                                          dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats,
                                             validate_args=validate_args)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=self._skewness,
                                                  tailweight=self._tailweight,
                                                  validate_args=validate_args)

            # Make the AffineScalar bijector, Z --> loc + scale * Z
            # (the 2 / F_0(2) factor is folded into `F` above).
            affine = affine_scalar_bijector.AffineScalar(
                shift=self._loc,
                scale=self._scale,
                validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              validate_args=validate_args,
                                              name=name)
            self._parameters = parameters
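Compared with Example #6, this newer constructor keeps the parameters as non-ref tensors via `tensor_util.convert_nonref_to_tensor` (so `tf.Variable` parameters are not frozen into constants at construction time), builds the affine step directly as `shift=loc, scale=scale` instead of computing `c = 2 * scale / F_0(2)` (per the updated comment, that normalization now lives inside `F`), and gives the default `Normal` base distribution a batch rank matching the broadcast parameters rather than passing an explicit `batch_shape` to `TransformedDistribution`.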