Example #1
  def __init__(self,
               diag_bijector=None,
               diag_shift=1e-5,
               validate_args=False,
               name="scale_tril"):
    """Instantiates the `ScaleTriL` bijector.

    Args:
      diag_bijector: `Bijector` instance, used to transform the output diagonal
        to be positive.
        Default value: `None` (i.e., `tfb.Softplus()`).
      diag_shift: Float value broadcastable and added to all diagonal entries
        after applying the `diag_bijector`. Setting a positive
        value forces the output diagonal entries to be positive, but
        prevents inverting the transformation for matrices with
        diagonal entries less than this value.
        Default value: `1e-5`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., arguments are not validated).
      name: Python `str` name given to ops managed by this object.
        Default value: `scale_tril`.
    """
    with tf.name_scope(name, values=[diag_shift]) as name:
      if diag_bijector is None:
        diag_bijector = softplus.Softplus(validate_args=validate_args)

      if diag_shift is not None:
        diag_shift = tf.convert_to_tensor(
            diag_shift, dtype=diag_bijector.dtype, name="diag_shift")
        diag_bijector = chain.Chain([
            affine_scalar.AffineScalar(shift=diag_shift),
            diag_bijector
        ])

      super(ScaleTriL, self).__init__(
          [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
           fill_triangular.FillTriangular()],
          validate_args=validate_args,
          name=name)
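A minimal standalone usage sketch of the resulting bijector through the public API (the input values are illustrative): `forward` fills a length d*(d+1)/2 vector into a lower-triangular matrix and then makes its diagonal positive.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# For d = 2 the unconstrained input has d * (d + 1) / 2 = 3 entries.
scale_tril = tfb.ScaleTriL(diag_shift=1e-5)
x = tf.constant([0.5, -1.0, 2.0])
tril = scale_tril.forward(x)       # [2, 2] lower-triangular, positive diagonal
x_back = scale_tril.inverse(tril)  # recovers x up to numerical error
```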
Example #2
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m. See [1] for details.)

  Args:
    jd_model: A `tfd.JointDistribution` instance (or instance of a subclass)
      describing the model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    to_chain = [jd_model.experimental_default_event_space_bijector()]
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
    to_chain.append(flat_bijector)

    unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # This reshaping is required because Split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    reshapers = [
        reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
        for x in unconstrained_shapes
    ]
    to_chain.append(joint_map.JointMap(bijectors=reshapers))

    size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
    to_chain.append(split.Split(num_or_size_splits=size_splits))

    return invert.Invert(chain.Chain(to_chain))
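A minimal usage sketch, assuming the helper above is in scope and TFP is imported as shown; the toy joint model is purely illustrative.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

jd_model = tfd.JointDistributionNamed({
    'scale': tfd.Gamma(2., 2.),                 # support: positive reals
    'x': lambda scale: tfd.Normal(0., scale)})  # support: all reals

bij = _get_flat_unconstraining_bijector(jd_model)
sample = jd_model.sample(seed=42)  # structured sample with 'scale' > 0
flat = bij.forward(sample)         # flat unconstrained vector of length 2
restored = bij.inverse(flat)       # back to the original structure
```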
Example #3
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m. See [1] for details.)

  Args:
    jd_model: A `tfd.JointDistribution` instance (or instance of a subclass)
      describing the model.

  Returns:
    Two `tfb.Bijector`s: the first's `.forward` method flattens and unconstrains
    points, and the second may be used to initialize a step size.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    event_space_bij = jd_model.experimental_default_event_space_bijector()
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    unconstrained_shapes = event_space_bij(
        flat_bijector).inverse_event_shape_tensor(
            jd_model.event_shape_tensor())

    # This reshaping is required because Split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    unsplit = joint_map.JointMap(
        tf.nest.map_structure(
            lambda x: reshape.Reshape(event_shape_out=x, event_shape_in=[-1]),
            unconstrained_shapes))

    bij = invert.Invert(chain.Chain([event_space_bij, flat_bijector, unsplit]))
    step_size_bij = invert.Invert(flat_bijector)

    return bij, step_size_bij
Example #4
  def __init__(self,
               loc,
               scale,
               concentration,
               validate_args=False,
               name='generalized_pareto'):
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype(
          [loc, scale, concentration], dtype_hint=tf.float32)

      self._loc = tensor_util.convert_nonref_to_tensor(loc)
      self._scale = tensor_util.convert_nonref_to_tensor(scale)
      self._concentration = tensor_util.convert_nonref_to_tensor(concentration)
      self._non_negative_concentration_bijector = chain_bijector.Chain([
          shift_bijector.Shift(shift=self._loc, validate_args=validate_args),
          softplus_bijector.Softplus(validate_args=validate_args)
      ], validate_args=validate_args)
      super(GeneralizedPareto, self).__init__(
          validate_args=validate_args,
          forward_min_event_ndims=0,
          dtype=dtype,
          name=name)
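Because `Chain` applies its component bijectors right-to-left, the concentration bijector built above maps `x` to `loc + softplus(x)`, i.e. all of R onto `(loc, inf)`. A standalone sketch with the public API (the `loc` value is illustrative):

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

loc = 1.5
support_bijector = tfb.Chain([tfb.Shift(shift=loc), tfb.Softplus()])
support_bijector.forward(0.)  # ==> loc + softplus(0.) ~= 1.5 + 0.693
```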
Example #5
 def _transformed_beta(self, low=None, peak=None, high=None, temperature=None):
   low = tf.convert_to_tensor(self.low) if low is None else low
   peak = tf.convert_to_tensor(self.peak) if peak is None else peak
   high = tf.convert_to_tensor(self.high) if high is None else high
   temperature = (
       tf.convert_to_tensor(self.temperature)
       if temperature is None else temperature)
   scale = high - low
   concentration1 = (1. + temperature * (peak - low) / scale)
   concentration0 = (1. + temperature * (high - peak) / scale)
   return transformed_distribution.TransformedDistribution(
       distribution=beta.Beta(
           concentration1=concentration1,
           concentration0=concentration0,
           allow_nan_stats=self.allow_nan_stats),
       bijector=chain_bijector.Chain([
           shift_bijector.Shift(shift=low),
           # Broadcasting scale on affine bijector to match batch dimension.
           # This prevents dimension mismatch for operations like cdf.
           # Note that `concentration1` incorporates the broadcast of all four
           # parameters.
           scale_bijector.Scale(
               scale=tf.broadcast_to(
                   scale, prefer_static.shape(concentration1)))]))
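The same construction can be sketched with the public `tfd`/`tfb` API; the parameter values below are illustrative. Since `Chain` applies `Scale` before `Shift`, samples are `low + scale * beta_sample`, which lie in `(low, high)`.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

low, peak, high, temperature = 0., 2., 10., 4.
scale = high - low
concentration1 = 1. + temperature * (peak - low) / scale
concentration0 = 1. + temperature * (high - peak) / scale

pert = tfd.TransformedDistribution(
    distribution=tfd.Beta(concentration1=concentration1,
                          concentration0=concentration0),
    bijector=tfb.Chain([tfb.Shift(shift=low), tfb.Scale(scale=scale)]))
samples = pert.sample(3)  # values lie in (low, high)
```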
Example #6
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="SinhArcsinh"):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tfd.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.compat.v2.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale,
                                         name="scale",
                                         dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if skewness is None else skewness
            tailweight = tf.convert_to_tensor(value=tailweight,
                                              name="tailweight",
                                              dtype=dtype)
            skewness = tf.convert_to_tensor(value=skewness,
                                            name="skewness",
                                            dtype=dtype)

            batch_shape = distribution_util.get_broadcast_shape(
                loc, scale, tailweight, skewness)

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)
            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    loc = distribution_util.with_dependencies(asserts, loc)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness,
                                                  tailweight=tailweight)
            if has_default_skewness:
                f_noskew = f
            else:
                f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
                    skewness=skewness.dtype.as_numpy_dtype(0.),
                    tailweight=tailweight)

            # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
            c = 2 * scale / f_noskew.forward(
                tf.convert_to_tensor(value=2, dtype=dtype))
            affine = affine_scalar_bijector.AffineScalar(
                shift=loc, scale=c, validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              batch_shape=batch_shape,
                                              validate_args=validate_args,
                                              name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale
        self._tailweight = tailweight
        self._skewness = skewness
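A brief usage sketch of the resulting distribution through the public API (parameter values are illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.5)
samples = dist.sample(3)            # skewed, heavier-tailed than a Normal
log_probs = dist.log_prob(samples)
```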
Example #7
 def _default_event_space_bijector(self):
     return chain_bijector.Chain([
         reciprocal_bijector.Reciprocal(validate_args=self.validate_args),
         softplus_bijector.Softplus(validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
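Since `Chain` applies right-to-left, this default event-space bijector computes `y = 1 / softplus(x)`, sending every real `x` to a strictly positive value. A standalone sketch, assuming `tfb.Reciprocal` is exported by the installed TFP version:

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

b = tfb.Chain([tfb.Reciprocal(), tfb.Softplus()])
b.forward(0.)  # ==> 1 / softplus(0.) ~= 1.4427
```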
Example #8
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint variational posterior by splitting a normalizing flow.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it is
      split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  #### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(shallow_structure,
                                               ps.convert_to_shape_tensor,
                                               event_shape)

        if nest.is_nested(constraining_bijector):
            constraining_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity()
                    if b is None else b, constraining_bijector),
                validate_args=validate_args)

        if constraining_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                constraining_bijector.inverse_event_shape_tensor(event_shape))

        flat_base_event_shape = nest.flatten(unconstrained_event_shape)
        flat_base_event_size = nest.map_structure(tf.reduce_prod,
                                                  flat_base_event_shape)
        event_size = tf.reduce_sum(flat_base_event_size)

        base_distribution = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [event_size])

        # After transforming base distribution samples with `trainable_bijector`,
        # split them into vector-valued components.
        split_bijector = split.Split(flat_base_event_size,
                                     validate_args=validate_args)

        # Reshape the vectors to the correct posterior event shape.
        event_reshape = joint_map.JointMap(nest.map_structure(
            reshape.Reshape, unconstrained_event_shape),
                                           validate_args=validate_args)

        # Restructure the flat list of components to the correct posterior
        # structure.
        event_unflatten = restructure.Restructure(
            nest.pack_sequence_as(unconstrained_event_shape,
                                  range(len(flat_base_event_shape))))

        bijectors = [] if constraining_bijector is None else [
            constraining_bijector
        ]
        bijectors.extend([
            event_reshape, event_unflatten, split_bijector, trainable_bijector
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            base_distribution, bijector=bijector, validate_args=validate_args)
Example #9
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name
                       or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(
                batch_shape):  # Base is a classic JointDistribution.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        unflatten_and_reshape = chain.Chain([
            joint_map.JointMap(nest.map_structure(reshape.Reshape,
                                                  event_shape),
                               validate_args=validate_args),
            restructure.Restructure(
                nest.pack_sequence_as(event_shape, range(num_components)))
        ],
                                            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Example #10
  def __call__(self, value, name=None, **kwargs):
    """Applies or composes the `Bijector`, depending on input type.

    This is a convenience function which applies the `Bijector` instance in
    three different ways, depending on the input:

    1. If the input is a `tfd.Distribution` instance, return
       `tfd.TransformedDistribution(distribution=input, bijector=self)`.
    2. If the input is a `tfb.Bijector` instance, return
       `tfb.Chain([self, input])`.
    3. Otherwise, return `self.forward(input)`

    Args:
      value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
      name: Python `str` name given to ops created by this function.
      **kwargs: Additional keyword arguments passed into the created
        `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

    Returns:
      composition: A `tfd.TransformedDistribution` if the input was a
        `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
        a `Tensor` computed by `self.forward`.

    #### Examples

    ```python
    sigmoid = tfb.Reciprocal()(
        tfb.AffineScalar(shift=1.)(
          tfb.Exp()(
            tfb.AffineScalar(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.AffineScalar(shift=1.),
    #         tfb.Exp(),
    #         tfb.AffineScalar(scale=-1.),
    #      ])`  # ie, `tfb.Sigmoid()`

    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
    ```

    """

    # To avoid circular dependencies and keep the implementation local to the
    # `Bijector` class, we violate PEP8 guidelines and import here rather than
    # at the top of the file.
    from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
    from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
    from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

    # TODO(b/128841942): Handle Conditional distributions and bijectors.
    if type(value) is transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
      # We cannot accept subclasses with different constructors here, because
      # subclass constructors may accept constructor arguments TD doesn't know
      # how to handle. e.g. `TypeError: __init__() got an unexpected keyword
      # argument 'allow_nan_stats'` when doing
      # `tfb.Identity()(tfd.Chi(df=1., allow_nan_stats=True))`.
      new_kwargs = value.parameters
      new_kwargs.update(kwargs)
      new_kwargs['name'] = name or new_kwargs.get('name', None)
      new_kwargs['bijector'] = self(value.bijector)
      return transformed_distribution.TransformedDistribution(**new_kwargs)

    if isinstance(value, distribution.Distribution):
      return transformed_distribution.TransformedDistribution(
          distribution=value,
          bijector=self,
          name=name,
          **kwargs)

    if isinstance(value, chain.Chain):
      new_kwargs = kwargs.copy()
      new_kwargs['bijectors'] = [self] + ([] if value.bijectors is None
                                          else list(value.bijectors))
      if 'validate_args' not in new_kwargs:
        new_kwargs['validate_args'] = value.validate_args
      new_kwargs['name'] = name or value.name
      return chain.Chain(**new_kwargs)

    if isinstance(value, Bijector):
      return chain.Chain([self, value], name=name, **kwargs)

    return self.forward(value, name=name or 'forward', **kwargs)
Example #11
  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name='SinhArcsinh'):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                      tf.float32)
      self._loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      self._scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if has_default_skewness else skewness
      self._tailweight = tensor_util.convert_nonref_to_tensor(
          tailweight, name='tailweight', dtype=dtype)
      self._skewness = tensor_util.convert_nonref_to_tensor(
          skewness, name='skewness', dtype=dtype)

      # Recall, with Z a random variable,
      #   Y := loc + scale * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
      #   C := 2 / F_0(2)
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      if distribution is None:
        batch_shape = functools.reduce(
            ps.broadcast_shape,
            [ps.shape(x)
             for x in (self._skewness, self._tailweight,
                       self._loc, self._scale)])

        distribution = normal.Normal(
            loc=tf.zeros(batch_shape, dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats,
            validate_args=validate_args)

      # Make the SAS bijector, 'F'.
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=self._skewness, tailweight=self._tailweight,
          validate_args=validate_args)

      # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
      affine = shift_bijector.Shift(shift=self._loc)(
          scale_bijector.Scale(scale=self._scale))
      bijector = chain_bijector.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
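The `Shift(...)(Scale(...))` pattern above uses bijector `__call__` composition (see Example #10): calling one bijector on another returns a `Chain`. A minimal sketch with the public API:

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Calling one bijector on another composes them into a Chain:
#   Shift(2.)(Scale(3.)) ==> Chain([Shift(2.), Scale(3.)])
affine = tfb.Shift(shift=2.)(tfb.Scale(scale=3.))
affine.forward(1.)  # ==> 2. + 3. * 1. = 5.
```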
Example #12
    def __call__(self, value, name=None, **kwargs):
        """Applies or composes the `Bijector`, depending on input type.

    This is a convenience function which applies the `Bijector` instance in
    three different ways, depending on the input:

    1. If the input is a `tfd.Distribution` instance, return
       `tfd.TransformedDistribution(distribution=input, bijector=self)`.
    2. If the input is a `tfb.Bijector` instance, return
       `tfb.Chain([self, input])`.
    3. Otherwise, return `self.forward(input)`

    Args:
      value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
      name: Python `str` name given to ops created by this function.
      **kwargs: Additional keyword arguments passed into the created
        `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

    Returns:
      composition: A `tfd.TransformedDistribution` if the input was a
        `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
        a `Tensor` computed by `self.forward`.

    #### Examples

    ```python
    sigmoid = tfb.Reciprocal()(
        tfb.AffineScalar(shift=1.)(
          tfb.Exp()(
            tfb.AffineScalar(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.AffineScalar(shift=1.),
    #         tfb.Exp(),
    #         tfb.AffineScalar(scale=-1.),
    #      ])`  # ie, `tfb.Sigmoid()`

    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
    ```

    """

        # To avoid circular dependencies and keep the implementation local to the
        # `Bijector` class, we violate PEP8 guidelines and import here rather than
        # at the top of the file.
        from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
        from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
        from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

        if isinstance(value, transformed_distribution.TransformedDistribution):
            new_kwargs = value.parameters
            new_kwargs.update(kwargs)
            new_kwargs["name"] = name or new_kwargs.get("name", None)
            new_kwargs["bijector"] = self(value.bijector)
            return transformed_distribution.TransformedDistribution(
                **new_kwargs)

        if isinstance(value, distribution.Distribution):
            return transformed_distribution.TransformedDistribution(
                distribution=value, bijector=self, name=name, **kwargs)

        if isinstance(value, chain.Chain):
            new_kwargs = kwargs.copy()
            new_kwargs["bijectors"] = [self] + ([] if value.bijectors is None
                                                else list(value.bijectors))
            if "validate_args" not in new_kwargs:
                new_kwargs["validate_args"] = value.validate_args
            new_kwargs["name"] = name or value.name
            return chain.Chain(**new_kwargs)

        if isinstance(value, Bijector):
            return chain.Chain([self, value], name=name, **kwargs)

        return self._call_forward(value, name=name or "forward", **kwargs)
Example #13
    def __init__(self,
                 bijectors,
                 block_sizes=None,
                 validate_args=False,
                 maybe_changes_size=True,
                 name=None):
        """Creates the bijector.

    Args:
      bijectors: A non-empty list of bijectors.
      block_sizes: A 1-D integer `Tensor` with each element signifying the
        length of the block of the input vector to pass to the corresponding
        bijector. The length of `block_sizes` must be equal to the length of
        `bijectors`. If left as None, a vector of 1's is used.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      maybe_changes_size: Python `bool` indicating that this bijector might
        change the event size. If this is known to be false and set
        appropriately, then this will lead to improved static shape inference
        when the block sizes are not statically known.
      name: Python `str`, name given to ops managed by this object. Default:
        E.g., `Blockwise([Exp(), Softplus()]).name ==
        'blockwise_of_exp_and_softplus'`.

    Raises:
      NotImplementedError: If there is a bijector with `event_ndims` > 1.
      ValueError: If `bijectors` list is empty.
      ValueError: If the size of `block_sizes` does not equal the length of
        `bijectors`, or if `block_sizes` is not a vector.
    """
        parameters = dict(locals())
        if not name:
            name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
            name = name.replace('/', '')

        with tf.name_scope(name) as name:
            for b in bijectors:
                if (nest.is_nested(b.forward_min_event_ndims)
                        or nest.is_nested(b.inverse_min_event_ndims)):
                    raise ValueError('Bijectors must all be single-part.')
                elif isinstance(b.forward_min_event_ndims, int):
                    if b.forward_min_event_ndims != b.inverse_min_event_ndims:
                        raise ValueError(
                            'Rank-changing bijectors are not supported.')
                    elif b.forward_min_event_ndims > 1:
                        raise ValueError(
                            'Only scalar and vector event-shape '
                            'bijectors are supported at this time.')

            b_joint = joint_map.JointMap(list(bijectors), name='jointmap')

            block_sizes = (np.ones(len(bijectors), dtype=np.int32)
                           if block_sizes is None else _validate_block_sizes(
                               block_sizes, bijectors, validate_args))
            b_split = split.Split(block_sizes,
                                  name='split',
                                  validate_args=validate_args)

            if maybe_changes_size:
                i_block_sizes = _validate_block_sizes(
                    ps.concat(b_joint.forward_event_shape_tensor(
                        ps.split(block_sizes, len(bijectors))),
                              axis=0), bijectors, validate_args)
                maybe_changes_size = not tf.get_static_value(
                    ps.reduce_all(block_sizes == i_block_sizes))
            b_concat = invert.Invert((split.Split(i_block_sizes, name='isplit')
                                      if maybe_changes_size else b_split),
                                     name='concat')

            self._maybe_changes_size = maybe_changes_size
            self._chain = chain.Chain([b_concat, b_joint, b_split],
                                      validate_args=validate_args)
            super(_Blockwise, self).__init__(bijectors=self._chain.bijectors,
                                             validate_args=validate_args,
                                             validate_event_size=True,
                                             parameters=parameters,
                                             name=name)
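A brief usage sketch of the public `tfb.Blockwise` interface (values are illustrative); internally the forward pass is the `Chain([concat, JointMap, Split])` assembled above.

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Exp acts on the first entry, Softplus on the remaining two entries.
b = tfb.Blockwise([tfb.Exp(), tfb.Softplus()], block_sizes=[1, 2])
b.forward([0., 0., 0.])  # ==> [1., 0.693..., 0.693...]
```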
Example #14
    def __init__(self,
                 output_shape=(32, 32, 3),
                 num_glow_blocks=3,
                 num_steps_per_block=32,
                 coupling_bijector_fn=None,
                 exit_bijector_fn=None,
                 grab_after_block=None,
                 use_actnorm=True,
                 seed=None,
                 validate_args=False,
                 name='glow'):
        """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image). Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. This must divide equally into both H and W,
        otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
        parameters = dict(locals())
        # Make sure that the input shape is fully defined.
        if not tensorshape_util.is_fully_defined(output_shape):
            raise ValueError('Shape must be fully defined.')
        if tensorshape_util.rank(output_shape) != 3:
            raise ValueError('Shape ndims must be 3 for images.  Your shape is '
                             '{}'.format(tensorshape_util.rank(output_shape)))

        num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
        if (num_glow_blocks_ is None
                or int(num_glow_blocks_) != num_glow_blocks_
                or num_glow_blocks_ < 1):
            raise ValueError(
                'Argument `num_glow_blocks` must be a statically known '
                'positive `int` (saw: {}).'.format(num_glow_blocks))
        num_glow_blocks = int(num_glow_blocks_)

        output_shape = tensorshape_util.as_list(output_shape)
        h, w, c = output_shape
        n = num_glow_blocks
        nsteps = num_steps_per_block

        # Default Glow: Half of the channels are split off after each block,
        # and after the final block, no channels are split off.
        if grab_after_block is None:
            grab_after_block = tuple([0.5] * (n - 1) + [0.])

        # Thing we know must be true: h and w are evenly divisible by 2, n times.
        # Otherwise, the squeeze bijector will not work.
        if w % 2**n != 0:
            raise ValueError('Width must be divisible by 2 at least n times. '
                             'Saw: {} % {} != 0'.format(w, 2**n))
        if h % 2**n != 0:
            raise ValueError(
                'Height should be divisible by 2 at least n times.')
        if h // 2**n < 1:
            raise ValueError(
                'num_glow_blocks ({0}) is too large. The image height '
                '({1}) must be divisible by 2 no more than {2} '
                'times.'.format(num_glow_blocks, h,
                                int(np.log(h) / np.log(2.))))
        if w // 2**n < 1:
            raise ValueError(
                'num_glow_blocks ({0}) is too large. The image width '
                '({1}) must be divisible by 2 no more than {2} '
                'times.'.format(num_glow_blocks, w,
                                int(np.log(w) / np.log(2.))))

        # Other things we want to be true:
        # - The number of fractions we grab must equal the number of glow blocks.
        if len(grab_after_block) != num_glow_blocks:
            raise ValueError(
                'Length of grab_after_block ({0}) must match the number '
                'of blocks ({1}).'.format(len(grab_after_block),
                                          num_glow_blocks))

        self._blockwise_splits = self._get_blockwise_splits(
            output_shape, grab_after_block[::-1])

        # Now check on the values of blockwise splits
        if any([bs[0] < 1 for bs in self._blockwise_splits]):
            first_offender = [bs[0] < 1
                              for bs in self._blockwise_splits].index(True)
            raise ValueError(
                'At at least one exit, you are taking out all of your '
                'channels, and therefore have no inputs to later blocks.'
                ' Try setting grab_after_block to a lower value at index '
                '{}.'.format(first_offender))

        if any(np.isclose(gab, 0) for gab in grab_after_block):
            # Special case: if specifically exiting no channels, then the exit is
            # just an identity bijector.
            pass
        elif any([bs[1] < 1 for bs in self._blockwise_splits]):
            first_offender = [bs[1] < 1
                              for bs in self._blockwise_splits].index(True)
            raise ValueError(
                'At least one of your layers has < 1 output channels. '
                'This means you set grab_at_block too small. '
                'Try setting grab_after_block to a larger value at index '
                '{}.'.format(first_offender))

        # Let's start to build our bijector. We assume that the distribution is
        # 1-dimensional. First, let's reshape it to an image.
        glow_chain = [
            reshape.Reshape(event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
                            event_shape_in=[h * w * c])
        ]

        seedstream = SeedStream(seed=seed, salt='random_beta')

        for i in range(n):

            # This is the shape of the current tensor
            current_shape = (h // 2**n * 2**i, w // 2**n * 2**i,
                             c * 4**(i + 1))

            # This is the shape of the input to both the glow block and exit bijector.
            this_nchan = sum(self._blockwise_splits[i][0:2])
            this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

            glow_chain.append(
                invert.Invert(
                    ExitBijector(current_shape, self._blockwise_splits[i],
                                 exit_bijector_fn)))

            glow_block = GlowBlock(input_shape=this_input_shape,
                                   num_steps=nsteps,
                                   coupling_bijector_fn=coupling_bijector_fn,
                                   use_actnorm=use_actnorm,
                                   seedstream=seedstream)

            if self._blockwise_splits[i][2] == 0:
                # All channels are passed to the RealNVP
                glow_chain.append(glow_block)
            else:
                # Some channels are passed around the block.
                # This is done with the Blockwise bijector.
                glow_chain.append(
                    blockwise.Blockwise(
                        [glow_block, identity.Identity()], [
                            sum(self._blockwise_splits[i][0:2]),
                            self._blockwise_splits[i][2]
                        ]))

            # Finally, let's expand the channels into spatial features.
            glow_chain.append(
                Expand(input_shape=[
                    h // 2**n * 2**i,
                    w // 2**n * 2**i,
                    c * 4**n // 4**i,
                ]))

        glow_chain = glow_chain[::-1]
        # To finish off, we build a bijector that chains the components together
        # sequentially.
        super(Glow, self).__init__(bijectors=chain.Chain(
            glow_chain, validate_args=validate_args),
                                   validate_args=validate_args,
                                   parameters=parameters,
                                   name=name)
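A usage sketch, assuming the default network helpers `tfb.GlowDefaultNetwork` and `tfb.GlowDefaultExitNetwork` are exported by the installed TFP version; the forward pass maps a flat latent vector to an image-shaped tensor.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

glow = tfb.Glow(output_shape=(32, 32, 3),
                num_glow_blocks=3,
                num_steps_per_block=32,
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)

z = tf.random.normal([1, 32 * 32 * 3])
image = glow.forward(z)  # shape [1, 32, 32, 3]
```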
Example #15
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalLinearOperator"):
    """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness: Skewness parameter. Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight: Tailweight parameter. Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tfd.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
    parameters = dict(locals())

    with tf.name_scope(
        name,
        values=[
            loc, scale_diag, scale_identity_multiplier, skewness, tailweight
        ]) as name:
      dtype = dtype_util.common_dtype(
          [loc, scale_diag, scale_identity_multiplier, skewness, tailweight],
          tf.float32)
      loc = loc if loc is None else tf.convert_to_tensor(
          value=loc, name="loc", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)

      # Construct shapes and 'scale' out of the scale_* and loc kwargs.
      # scale_linop is only an intermediary to:
      #  1. get shapes from looking at loc and the two scale args.
      #  2. combine scale_diag with scale_identity_multiplier, which gives us
      #     'scale', which in turn gives us 'C'.
      scale_linop = distribution_util.make_diag_scale(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False,
          dtype=dtype)
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale_linop)
      # scale_linop.diag_part() is efficient since it is a diag type linop.
      scale_diag_part = scale_linop.diag_part()
      dtype = scale_diag_part.dtype

      if distribution is None:
        distribution = normal.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          scale_diag_part = distribution_util.with_dependencies(
              asserts, scale_diag_part)

      # Make the SAS bijector, 'F'.
      skewness = tf.convert_to_tensor(
          value=skewness, dtype=dtype, name="skewness")
      tailweight = tf.convert_to_tensor(
          value=tailweight, dtype=dtype, name="tailweight")
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness, tailweight=tailweight)
      if has_default_skewness:
        f_noskew = f
      else:
        f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight)

      # Make the Affine bijector, Z --> loc + C * Z.
      c = 2 * scale_diag_part / f_noskew.forward(
          tf.convert_to_tensor(value=2, dtype=dtype))
      affine = affine_bijector.Affine(
          shift=loc, scale_diag=c, validate_args=validate_args)

      bijector = chain_bijector.Chain([affine, f])

      super(VectorSinhArcsinhDiag, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          event_shape=event_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale_linop
    self._tailweight = tailweight
    self._skewness = skewness
  def _default_event_space_bijector(self):
    return chain_bijector.Chain([
        shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
        exp_bijector.Exp(validate_args=self.validate_args)
    ], validate_args=self.validate_args)
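Unwinding the constructor above: samples are produced as `Y = loc + C * F(Z)`, where `F` is a `SinhArcsinh` bijector and `C = 2 * scale / F_0(2)` rescales so that `tailweight` does not blow up the overall spread. A hedged sketch of the same composition using only public `tfp` symbols is below; the numeric values for `loc`, `scale_diag`, `skewness`, and `tailweight` are made up for illustration.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

loc = tf.constant([0.5, -1.0])
scale_diag = tf.constant([1.0, 2.0])
skewness = 0.3
tailweight = 1.5

# F and its zero-skewness counterpart F_0, used only to normalize the scale.
f = tfb.SinhArcsinh(skewness=skewness, tailweight=tailweight)
f_noskew = tfb.SinhArcsinh(skewness=0., tailweight=tailweight)
c = 2. * scale_diag / f_noskew.forward(2.)

# Z ~ Normal(0, 1) iid over the event, Y = loc + c * F(Z).
dist = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(0., 1.), sample_shape=[2]),
    bijector=tfb.Chain([tfb.Shift(loc), tfb.Scale(c), f]))

samples = dist.sample(3)  # shape [3, 2]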
Example #17
0
    def __init__(self,
                 low=None,
                 high=None,
                 hinge_softness=None,
                 validate_args=False,
                 name='soft_clip'):
        """Instantiates the SoftClip bijector.

    Args:
      low: Optional float `Tensor` lower bound. If `None`, the lower-bound
        constraint is omitted.
        Default value: `None`.
      high: Optional float `Tensor` upper bound. If `None`, the upper-bound
        constraint is omitted.
        Default value: `None`.
      hinge_softness: Optional nonzero float `Tensor`. Controls the softness
        of the constraint at the boundaries; values outside of the constraint
        set are mapped into intervals of width approximately
        `log(2) * hinge_softness` on the interior of each boundary. High
        softness reserves more space for values outside of the constraint set,
        leading to greater distortion of inputs *within* the constraint set,
        but improved numerical stability near the boundaries.
        Default value: `None` (`1.0`).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
        parameters = dict(locals())
        with tf.name_scope(name):
            dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                            dtype_hint=tf.float32)
            low = tensor_util.convert_nonref_to_tensor(low,
                                                       name='low',
                                                       dtype=dtype)
            high = tensor_util.convert_nonref_to_tensor(high,
                                                        name='high',
                                                        dtype=dtype)
            hinge_softness = tensor_util.convert_nonref_to_tensor(
                hinge_softness, name='hinge_softness', dtype=dtype)

            softplus_bijector = softplus.Softplus(
                hinge_softness=hinge_softness)
            negate = tf.convert_to_tensor(-1., dtype=dtype)

            components = []
            if low is not None and high is not None:
                # Support reference tensors (e.g. Variables) for `high` and `low` by
                # deferring all computation on them until needed.
                width = tfp_util.DeferredTensor(
                    pretransformed_input=high,
                    transform_fn=lambda high: high - low)
                negated_shrinkage_factor = tfp_util.DeferredTensor(
                    pretransformed_input=width,
                    transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
                        negate * w / softplus_bijector.forward(w),
                        dtype=dtype))

                # Implement the soft constraint from 'Mathematical Details' above:
                #  softclip(x) := -softplus(width - softplus(x - low)) *
                #                        (width) / (softplus(width)) + high
                components = [
                    shift.Shift(high),
                    scale.Scale(negated_shrinkage_factor), softplus_bijector,
                    shift.Shift(width),
                    scale.Scale(negate), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif low is not None:
                # Implement a soft lower bound:
                #  softlower(x) := softplus(x - low) + low
                components = [
                    shift.Shift(low), softplus_bijector,
                    shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
                ]
            elif high is not None:
                # Implement a soft upper bound:
                #  softupper(x) := -softplus(high - x) + high
                components = [
                    shift.Shift(high),
                    scale.Scale(negate), softplus_bijector,
                    scale.Scale(negate),
                    shift.Shift(high)
                ]

            self._low = low
            self._high = high
            self._hinge_softness = hinge_softness
            self._chain = chain.Chain(components, validate_args=validate_args)

        super(SoftClip, self).__init__(forward_min_event_ndims=0,
                                       dtype=dtype,
                                       validate_args=validate_args,
                                       parameters=parameters,
                                       is_constant_jacobian=not components,
                                       name=name)
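`SoftClip` is exposed publicly as `tfb.SoftClip`, so a short usage sketch can illustrate what the chain built above does: it squashes the whole real line smoothly into `(low, high)`, leaving values well inside the interval nearly unchanged. The bounds, softness, and inputs below are illustrative assumptions.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

clip = tfb.SoftClip(low=-1., high=1., hinge_softness=0.1)

x = tf.constant([-5.0, -0.5, 0.0, 0.5, 5.0])
y = clip.forward(x)       # every output lies strictly inside (-1, 1)
x_back = clip.inverse(y)  # recovers x up to floating-point error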