def test_f64_state(self, method, method_kwargs):
    states, _ = callable_util.get_output_spec(lambda: method(  # pylint: disable=g-long-lambda
        5,
        tfd.Normal(tf.constant(0., tf.float64), 1.),
        n_chains=2,
        num_adaptation_steps=100,
        seed=test_util.test_seed(),
        **method_kwargs))

    self.assertEqual(tf.float64, states.dtype)
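
# A minimal sketch (hypothetical standalone usage, not from the test file) of
# the behavior the assertion above relies on: `get_output_spec` traces the
# callable as a graph function without executing it and, assuming it returns a
# structure of `tf.TensorSpec`s as these tests suggest, reports the dtype and
# shape that the outputs would have.
#
#   spec = callable_util.get_output_spec(
#       lambda: tf.cast(tf.range(3), tf.float64) * 2.)
#   spec.dtype  # ==> tf.float64
#   spec.shape  # ==> TensorShape([3])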
    def __init__(
            self,
            log_prob_increment,
            validate_args=False,
            allow_nan_stats=False,  # pylint: disable=unused-argument
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,  # pylint: disable=unused-argument
            log_prob_ratio_fn=None,
            name='IncrementLogProb',
            **kwargs):
        """Construct a `IncrementLogProb` distribution-like object.

    Args:
      log_prob_increment: Float Tensor or callable returning a float Tensor. Log
        probability/density to increment by.
      validate_args: This argument is ignored, but is present because it is used
        in certain situations where `Distribution`s are expected.
      allow_nan_stats: This argument is ignored, but is present because it is
        used in certain situations where `Distribution`s are expected.
      reparameterization_type: This argument is ignored, but is present because
        it is used in certain situations where `Distribution`s are expected.
      log_prob_ratio_fn: Optional callable with signature `(p_kwargs, q_kwargs)
        -> log_prob_ratio`, used to implement a custom `p_log_prob_increment -
        q_log_prob_increment` computation.
      name: Python `str` name prefixed to Ops created by this class.
      **kwargs: Passed to `log_prob_increment` if it is callable.
    """
        self._parameters = dict(locals())

        with tf.name_scope(name) as name:
            if callable(log_prob_increment):
                log_prob_increment_fn = lambda: tensor_util.convert_nonref_to_tensor(  # pylint: disable=g-long-lambda
                    log_prob_increment(**kwargs))
                spec = callable_util.get_output_spec(log_prob_increment_fn)
            else:
                if kwargs:
                    raise ValueError(
                        '`kwargs` is only valid when `log_prob_increment` is callable.'
                    )
                log_prob_increment = tensor_util.convert_nonref_to_tensor(
                    log_prob_increment)
                log_prob_increment_fn = lambda: log_prob_increment
                spec = log_prob_increment

            self._log_prob_increment_fn = log_prob_increment_fn
            self._log_prob_increment = log_prob_increment
            self._dtype = spec.dtype
            self._batch_shape = spec.shape
            self._name = name
            self._validate_args = validate_args
            self._log_prob_ratio_fn = log_prob_ratio_fn
            self._kwargs = kwargs
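
# A minimal usage sketch (hypothetical, not part of the original source). An
# `IncrementLogProb` object behaves like a distribution whose sole effect is
# to add `log_prob_increment` to a joint model's log density; per the
# constructor above, it accepts either a concrete value or a deferred callable:
#
#   inc = IncrementLogProb(tf.constant(0.5))  # Static increment.
#   # Deferred: `x=2.` is forwarded via `**kwargs` when the callable runs.
#   inc_deferred = IncrementLogProb(lambda x: -0.5 * x**2, x=2.)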
def test_get_output_spec_from_tensor_specs(self):
    args = (tf.TensorSpec([], dtype=tf.float32),
            (tf.TensorSpec([1, 1], dtype=tf.float32),
             tf.TensorSpec([2], dtype=tf.float64)))
    additional_args = (tf.TensorSpec([2, 1], dtype=tf.int32),)
    # Trace using both positional and keyword args.
    results = callable_util.get_output_spec(
        _return_args_from_infinite_loop,
        *args,
        additional_loop_vars=additional_args)
    self.assertAllEqualNested(
        tf.nest.map_structure(lambda x: x.shape, args + additional_args),
        tf.nest.map_structure(lambda x: x.shape, results))
    self.assertAllAssertsNested(
        self.assertEqual,
        tf.nest.map_structure(lambda x: x.dtype, args + additional_args),
        tf.nest.map_structure(lambda x: x.dtype, results))

def test_get_output_spec_loop(self):
    args = (np.array(0., dtype=np.float64),
            (tf.convert_to_tensor(0.),
             tf.convert_to_tensor([1., 1.], dtype=tf.float64)))
    additional_args = (tf.convert_to_tensor([[3], [4]], dtype=tf.int32),)
    # Trace using both positional and keyword args.
    results = callable_util.get_output_spec(
        _return_args_from_infinite_loop,
        *args,
        additional_loop_vars=additional_args)
    self.assertAllEqualNested(
        tf.nest.map_structure(lambda x: tf.convert_to_tensor(x).shape,
                              args + additional_args),
        tf.nest.map_structure(lambda x: x.shape, results))
    self.assertAllAssertsNested(
        self.assertEqual,
        tf.nest.map_structure(lambda x: tf.convert_to_tensor(x).dtype,
                              args + additional_args),
        tf.nest.map_structure(lambda x: x.dtype, results))
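
# `_return_args_from_infinite_loop` is defined elsewhere in the test file; the
# sketch below is an assumption about its shape, not the verbatim helper. The
# point of the tests above is that `get_output_spec` can report output specs
# even for a function that would never terminate if actually executed.
def _return_args_from_infinite_loop(*args, additional_loop_vars=()):
    loop_vars = args + additional_loop_vars
    return tf.while_loop(
        cond=lambda *_: True,  # The loop never exits if actually run.
        body=lambda *a: a,     # Loop variables pass through unchanged.
        loop_vars=loop_vars)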
def make_distribution_bijector(distribution,
                               name='make_distribution_bijector'):
    """Builds a bijector to approximately transform `N(0, 1)` into `distribution`.

  This represents a distribution as a bijector that
  transforms a (multivariate) standard normal distribution into the distribution
  of interest.

  Args:
    distribution: A `tfd.Distribution` instance; this may be a joint
      distribution.
    name: Python `str` name for ops created by this method.
  Returns:
    distribution_bijector: a `tfb.Bijector` instance such that
      `distribution_bijector(tfd.Normal(0., 1.))` is approximately equivalent
      to `distribution`.

  #### Examples

  This method may be used to convert structured variational distributions into
  MCMC preconditioners. Consider a model containing
  [funnel geometry](https://crackedbassoon.com/writing/funneling), which may
  be difficult for an MCMC algorithm to sample directly.

  ```python
  model_with_funnel = tfd.JointDistributionSequentialAutoBatched([
      tfd.Normal(loc=-1., scale=2., name='z'),
      lambda z: tfd.Normal(loc=[0., 0., 0.], scale=tf.exp(z), name='x'),
      lambda x: tfd.Poisson(log_rate=x, name='y')])
  pinned_model = tfp.experimental.distributions.JointDistributionPinned(
      model_with_funnel, y=[1, 3, 0])
  ```

  We can approximate the posterior in this model using a structured variational
  surrogate distribution, which will capture the funnel geometry, but cannot
  exactly represent the (non-Gaussian) posterior.

  ```python
  # Build and fit a structured surrogate posterior distribution.
  surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
    pinned_model)
  _ = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                     surrogate_posterior=surrogate_posterior,
                                     optimizer=tf.optimizers.Adam(0.01),
                                     num_steps=200)
  ```

  Creating a preconditioning bijector allows us to obtain higher-quality
  posterior samples, without any Gaussianity assumption, by using the surrogate
  to guide an MCMC sampler.

  ```python
  surrogate_posterior_bijector = (
    tfp.experimental.bijectors.make_distribution_bijector(surrogate_posterior))
  samples, _ = tfp.mcmc.sample_chain(
    kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
      tfp.mcmc.TransformedTransitionKernel(
        tfp.mcmc.NoUTurnSampler(pinned_model.unnormalized_log_prob,
                                step_size=0.1),
        bijector=surrogate_posterior_bijector),
      num_adaptation_steps=80),
    current_state=surrogate_posterior.sample(),
    num_burnin_steps=100,
    trace_fn=lambda _0, _1: [],
    num_results=500)
  ```

  #### Mathematical details

  The bijectors returned by this method generally follow these principles,
  although the specific bijectors returned may vary without notice.

  Normal distributions are reparameterized by a location-scale transform.

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.Normal(loc=10., scale=5.))
  # ==> tfb.Shift(10.)(tfb.Scale(5.))

  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril))
  # ==> tfb.Shift(loc)(tfb.ScaleMatvecTriL(scale_tril))
  ```

  The distribution's `quantile` function is used, when available:

  ```python
  d = tfd.Cauchy(loc=loc, scale=scale)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Inline(forward_fn=d.quantile, inverse_fn=d.cdf)(tfb.NormalCDF())
  ```

  Otherwise, a quantile function is derived by inverting the CDF:

  ```python
  d = tfd.Gamma(concentration=alpha, rate=beta)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Invert(
  #  tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(fn=d.cdf))(
  #    tfb.NormalCDF())
  ```

  Transformed distributions are represented by chaining the
  transforming bijector with a preconditioning bijector for the base
  distribution:

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfb.Exp()(tfd.Normal(loc=10., scale=5.)))
  # ==> tfb.Exp()(tfb.Shift(10.)(tfb.Scale(5.)))
  ```

  Joint distributions are represented by a joint bijector, which converts each
  component distribution to a bijector with parameters conditioned
  on the previous variables in the model. The joint bijector's inputs and
  outputs follow the structure of the joint distribution.

  ```python
  jd = tfd.JointDistributionNamed(
      {'a': tfd.InverseGamma(concentration=2., scale=1.),
       'b': lambda a: tfd.Normal(loc=3., scale=tf.sqrt(a))})
  b = tfp.experimental.bijectors.make_distribution_bijector(jd)
  whitened_jd = tfb.Invert(b)(jd)
  x = whitened_jd.sample()
  # x <=> {'a': tfd.Normal(0., 1.).sample(), 'b': tfd.Normal(0., 1.).sample()}
  ```

  """

    with tf.name_scope(name):
        event_space_bijector = (
            distribution.experimental_default_event_space_bijector())
        if event_space_bijector is None:  # Fail if the distribution is discrete.
            raise NotImplementedError(
                'Cannot transform distribution {} to a standard normal '
                'distribution.'.format(distribution))

        # Recurse over joint distributions.
        if isinstance(distribution, joint_distribution.JointDistribution):
            return joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
                distribution,
                bijector_fn=make_distribution_bijector)

        # Recurse through transformed distributions.
        if isinstance(distribution,
                      transformed_distribution.TransformedDistribution):
            return distribution.bijector(
                make_distribution_bijector(distribution.distribution))

        # If we've annotated a specific bijector for this distribution, use that.
        if isinstance(distribution, tuple(preconditioning_bijector_fns)):
            return preconditioning_bijector_fns[type(distribution)](
                distribution)

        # Otherwise, if this distribution implements a CDF and inverse CDF, build
        # a bijector from those.
        implements_cdf = False
        implements_quantile = False
        input_spec = tf.zeros(shape=distribution.event_shape,
                              dtype=distribution.dtype)
        try:
            callable_util.get_output_spec(distribution.cdf, input_spec)
            implements_cdf = True
        except NotImplementedError:
            pass
        try:
            callable_util.get_output_spec(distribution.quantile, input_spec)
            implements_quantile = True
        except NotImplementedError:
            pass
        if implements_cdf and implements_quantile:
            # This path will only trigger for scalar distributions, since multivariate
            # distributions have non-invertible CDF and so cannot define a `quantile`.
            return tfb.Inline(forward_fn=distribution.quantile,
                              inverse_fn=distribution.cdf,
                              forward_min_event_ndims=ps.rank_from_shape(
                                  distribution.event_shape_tensor,
                                  distribution.event_shape))(tfb.NormalCDF())

        # If the events are scalar, try to invert the CDF numerically.
        if implements_cdf and tf.get_static_value(
                distribution.is_scalar_event()):
            return tfb.Invert(
                scalar_function_with_inferred_inverse.
                ScalarFunctionWithInferredInverse(
                    distribution.cdf,
                    domain_constraint_fn=(event_space_bijector)))(
                        tfb.NormalCDF())

        raise NotImplementedError('Could not automatically construct a '
                                  'bijector for distribution type '
                                  '{}; it does not implement an invertible '
                                  'CDF.'.format(distribution))
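
# A short usage sketch (a hypothetical example, consistent with the docstring
# above, not part of the library): the returned bijector maps standard-normal
# draws to approximate draws from the target distribution.
def _sketch_preconditioning_usage():
    d = tfd.Gamma(concentration=2., rate=1.)  # The docstring's CDF example.
    b = make_distribution_bijector(d)
    z = tf.random.normal([1000])
    return b.forward(z)  # Approximately Gamma(2., 1.)-distributed samples.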
def bracket_root(objective_fn,
                 dtype=tf.float32,
                 num_points=512,
                 name='bracket_root'):
    """Finds bounds that bracket a root of the objective function.

  This method attempts to return an interval bracketing a root of the objective
  function. It evaluates the objective in parallel at `num_points`
  locations, at exponentially increasing distance from the origin, and returns
  the first pair of adjacent points `[low, high]` such that the objective is
  finite and has a different sign at the two points. If no such pair was
  observed, it returns the trivial interval
  `[np.finfo(dtype).min, np.finfo(dtype).max]` containing all float values of
  the specified `dtype`. If the objective has multiple
  roots, the returned interval will contain at least one (but perhaps not all)
  of the roots.

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      continuous function that accepts a scalar `Tensor` of type `dtype` and
      returns a `Tensor` of shape `batch_shape`.
    dtype: Optional float `dtype` of inputs to `objective_fn`.
      Default value: `tf.float32`.
    num_points: Optional Python `int` number of points at which to evaluate
      the objective.
      Default value: `512`.
    name: Python `str` name given to ops created by this method.
  Returns:
    low: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Lower bound
      on a root of `objective_fn`.
    high: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Upper bound
      on a root of `objective_fn`.
  """
    with tf.name_scope(name):
        # Build a logarithmic grid of `num_points` values whose magnitudes run
        # from `exp(-10)` up to the largest finite value of `dtype`, mirrored
        # about zero (no infinities are generated).
        dtype_info = np.finfo(dtype_util.as_numpy_dtype(dtype))
        xs_positive = tf.exp(
            tf.linspace(tf.cast(-10., dtype), tf.math.log(dtype_info.max),
                        num_points // 2))
        xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive],
                       axis=0)

        # Evaluate the objective at all points. The objective function may return
        # a batch of values (e.g., `objective(x) = x - batch_of_roots`).
        if NUMPY_MODE:
            objective_output_spec = objective_fn(tf.zeros([], dtype=dtype))
        else:
            objective_output_spec = callable_util.get_output_spec(
                objective_fn, tf.convert_to_tensor(0., dtype=dtype))
        batch_ndims = tensorshape_util.rank(objective_output_spec.shape)
        if batch_ndims is None:
            raise ValueError('Cannot infer tensor rank of objective values.')
        xs_pad_shape = ps.pad([num_points],
                              paddings=[[0, batch_ndims]],
                              constant_values=1)
        ys = objective_fn(tf.reshape(xs, xs_pad_shape))

        # Find the smallest point where the objective is finite.
        is_finite = tf.math.is_finite(ys)
        ys_transposed = distribution_util.move_dimension(  # For batch gather.
            ys, 0, -1)
        first_finite_value = tf.gather(
            ys_transposed,
            tf.argmax(is_finite, axis=0),  # Index of smallest finite point.
            batch_dims=batch_ndims,
            axis=-1)
        # Select the next point where the objective has a different sign.
        sign_change_idx = tf.argmax(
            tf.not_equal(tf.math.sign(ys), tf.math.sign(first_finite_value))
            & is_finite,
            axis=0)
        # If the sign never changes, we can't bracket a root.
        bracketing_failed = tf.equal(sign_change_idx, 0)
        # If the objective's sign is zero, we've found an actual root.
        root_found = tf.equal(
            tf.gather(tf.math.sign(ys_transposed),
                      sign_change_idx,
                      batch_dims=batch_ndims,
                      axis=-1), 0.)
        return _structure_broadcasting_where(
            bracketing_failed,
            # If we didn't detect a sign change, fall back to the trivial interval.
            (dtype_info.min, dtype_info.max),
            # Otherwise, return the points around the sign change, unless we
            # actually evaluated a root, in which case, return the zero-width
            # bracket at that root.
            (tf.gather(
                xs,
                tf.where(bracketing_failed | root_found, sign_change_idx,
                         sign_change_idx - 1)), tf.gather(xs, sign_change_idx)
             ))
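
# A short usage sketch (hypothetical, not part of the library): bracket the
# root of a simple affine objective. The sign change occurs at 1.5, so the
# returned interval satisfies `low <= 1.5 <= high`.
def _sketch_bracket_root_usage():
    low, high = bracket_root(lambda x: x - 1.5)
    return low, high  # Adjacent grid points straddling the root at 1.5.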
def test_get_output_spec_oom(self):
    result = callable_util.get_output_spec(_compute_oom)
    self.assertEqual((int(1e9), int(1e9)), result.shape)
    self.assertEqual(tf.float32, result.dtype)
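
# `_compute_oom` is defined elsewhere in the test file; the sketch below is an
# assumption, not the verbatim helper. Any function whose traced output is a
# [1e9, 1e9] float32 tensor fits: the test passes because `get_output_spec`
# only traces the computation, so the huge tensor is never materialized.
def _compute_oom():
    x = tf.ones([int(1e9)])
    return x[:, tf.newaxis] * x[tf.newaxis, :]  # Shape [1e9, 1e9] if ever run.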