Example #1
 def _add_right_batch_dim(obs, event_shape):
     ndims = prefer_static.rank_from_shape(prefer_static.shape(obs))
     event_ndims = prefer_static.rank_from_shape(event_shape)
     return tf.expand_dims(obs, ndims - event_ndims)
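
# A minimal, self-contained sketch of what `rank_from_shape` computes (using the
# internal `prefer_static` module imported throughout these examples): it returns
# the rank implied by a shape, preferring a static answer when one is available.
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

ps.rank_from_shape([3, 4, 2])              # ==> 3
x = tf.zeros([5, 2])
ps.rank_from_shape(ps.shape(x))            # ==> 2, computed statically
ps.rank_from_shape(tf.shape(x), x.shape)   # ==> 2, via the static TensorShape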
Example #2
    def joint_distribution(self,
                           observed_time_series=None,
                           num_timesteps=None,
                           trajectories_shape=None,
                           initial_step=0,
                           mask=None,
                           experimental_parallelize=False):
        """Constructs the joint distribution over parameters and observed values.

    Args:
      observed_time_series: Optional observed time series to model, as a
        `Tensor` or `tfp.sts.MaskedTimeSeries` instance having shape
        `concat([batch_shape, trajectories_shape, num_timesteps, 1])`. If
        an observed time series is provided, the `num_timesteps`,
        `trajectories_shape`, and `mask` arguments are ignored, and
        an unnormalized (pinned) distribution over parameter values is returned.
        Default value: `None`.
      num_timesteps: scalar `int` `Tensor` number of timesteps to model. This
        must be specified either directly or by passing an
        `observed_time_series`.
        Default value: `None`.
      trajectories_shape: `int` `Tensor` shape of sampled trajectories
        for each set of parameter values. If not specified (either directly
        or by passing an `observed_time_series`), defaults to a
        one-to-one correspondence between trajectories and parameter settings
        (implicitly `trajectories_shape=()`).
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: `0`.
      mask: Optional `bool` `Tensor` having shape
        `concat([batch_shape, trajectories_shape, num_timesteps])`, in which
        `True` entries indicate that the series value at the corresponding step
        is missing and should be ignored. This argument should be passed only
        if `observed_time_series` is not specified or does not already contain
        a missingness mask; it is an error to pass both this
        argument and an `observed_time_series` value containing a missingness
        mask.
        Default value: `None`.
      experimental_parallelize: If `True`, use parallel message passing
        algorithms from `tfp.experimental.parallel_filter` to perform time
        series operations in `O(log num_timesteps)` sequential steps. The
        overall FLOP and memory cost may be larger than for the sequential
        implementations by a constant factor.
        Default value: `False`.
    Returns:
      joint_distribution: joint distribution of model parameters and
        observed trajectories. If no `observed_time_series` was specified, this
        is an instance of `tfd.JointDistributionNamedAutoBatched` with a
        random variable for each model parameter (with names and order matching
        `self.parameters`), plus a final random variable `observed_time_series`
        representing one or more trajectories conditioned on the parameters. If
        `observed_time_series` was specified, the return value is given by
        `joint_distribution.experimental_pin(
        observed_time_series=observed_time_series)` where `joint_distribution`
        is as just described, so it defines an unnormalized posterior
        distribution over the parameters.

    #### Example:

    The joint distribution can generate prior samples of parameters and
    trajectories:

    ```python
    from matplotlib import pylab as plt
    import tensorflow as tf
    import tensorflow_probability as tfp

    # Sample and plot 100 trajectories from the prior.
    model = tfp.sts.LocalLinearTrend()
    prior_samples = model.joint_distribution().sample([100])
    plt.plot(
      tf.linalg.matrix_transpose(prior_samples['observed_time_series'][..., 0]))
    ```

    It also integrates with TFP inference APIs, providing a more flexible
    alternative to the STS-specific fitting utilities.

    ```python
    jd = model.joint_distribution(observed_time_series)

    # Variational inference.
    surrogate_posterior = (
      tfp.experimental.vi.build_factored_surrogate_posterior(
        event_shape=jd.event_shape,
        bijector=jd.experimental_default_event_space_bijector()))
    losses = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=jd.unnormalized_log_prob,
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(0.1),
      num_steps=200)
    parameter_samples = surrogate_posterior.sample(50)

    # No U-Turn Sampler.
    samples, kernel_results = tfp.experimental.mcmc.windowed_adaptive_nuts(
      n_draws=500, joint_dist=jd)
    ```

    """
        def state_space_model_likelihood(**param_vals):
            ssm = self.make_state_space_model(
                param_vals=param_vals,
                num_timesteps=num_timesteps,
                initial_step=initial_step,
                mask=mask,
                experimental_parallelize=experimental_parallelize)
            # Looping LGSSM methods are really expensive in eager mode; wrap them
            # to keep this from slowing things down in interactive use.
            ssm = tfe_util.JitPublicMethods(ssm, trace_only=True)
            if distribution_util.shape_may_be_nontrivial(trajectories_shape):
                return sample.Sample(ssm, sample_shape=trajectories_shape)
            return ssm

        batch_ndims = ps.rank_from_shape(self.batch_shape_tensor,
                                         self.batch_shape)
        if observed_time_series is not None:
            [observed_time_series, is_missing
             ] = sts_util.canonicalize_observed_time_series_with_mask(
                 observed_time_series)
            if is_missing is not None:
                if mask is not None:
                    raise ValueError(
                        'Passed non-None value for `mask`, but the observed '
                        'time series already contains an `is_missing` mask.')
                mask = is_missing
            num_timesteps = ps.shape(observed_time_series)[-2]
            trajectories_shape = ps.shape(observed_time_series)[batch_ndims:-2]

        joint_distribution = (
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched(
                model=collections.OrderedDict(
                    # Prior.
                    list(self._joint_prior_distribution().model.items()) +
                    # Likelihood.
                    [('observed_time_series', state_space_model_likelihood)]),
                use_vectorized_map=False,
                batch_ndims=batch_ndims))

        if observed_time_series is not None:
            return joint_distribution.experimental_pin(
                observed_time_series=observed_time_series)

        return joint_distribution
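
# A hedged usage sketch of the method above, with a simple STS model standing in
# for `self` and random data standing in for real observations.
import tensorflow as tf
import tensorflow_probability as tfp

model = tfp.sts.LocalLinearTrend()
jd = model.joint_distribution(num_timesteps=50, trajectories_shape=[10])
prior_draw = jd.sample(seed=42)
# prior_draw['observed_time_series'] has shape [10, 50, 1]: ten trajectories per
# draw of the model parameters.

observed_series = tf.random.normal([50, 1])  # stand-in for real data
pinned = model.joint_distribution(observed_time_series=observed_series)
# `pinned` is an unnormalized posterior over the parameters only; its
# `unnormalized_log_prob` can be passed directly to MCMC or VI routines.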
Example #3
        def posterior_generator():

            prior_gen = prior._model_coroutine()  # pylint: disable=protected-access
            dist = next(prior_gen)

            i = 0
            try:
                while True:
                    original_dist = dist.distribution if isinstance(
                        dist, Root) else dist

                    if isinstance(original_dist,
                                  joint_distribution.JointDistribution):
                        # TODO(kateslin): Build inner JD surrogate in
                        # _make_asvi_trainable_variables to avoid rebuilding variables.
                        raise TypeError(
                            'Argument `prior` cannot be a nested `JointDistribution`.'
                        )

                    else:

                        original_dist = _as_trainable_family(original_dist)

                        try:
                            actual_dist = original_dist.distribution
                        except AttributeError:
                            actual_dist = original_dist

                        dist_params = actual_dist.parameters
                        temp_params_dict = {}

                        for param, value in dist_params.items():
                            if param in (
                                    _NON_STATISTICAL_PARAMS +
                                    _NON_TRAINABLE_PARAMS) or value is None:
                                temp_params_dict[param] = value
                            else:
                                prior_weight = param_dicts[i][
                                    param].prior_weight
                                mean_field_parameter = param_dicts[i][
                                    param].mean_field_parameter
                                if mean_field:
                                    temp_params_dict[
                                        param] = mean_field_parameter
                                else:
                                    temp_params_dict[
                                        param] = prior_weight * value + (
                                            1. - prior_weight
                                        ) * mean_field_parameter

                        if isinstance(original_dist, sample.Sample):
                            inner_dist = type(actual_dist)(**temp_params_dict)

                            surrogate_dist = independent.Independent(
                                inner_dist,
                                reinterpreted_batch_ndims=ps.rank_from_shape(
                                    original_dist.sample_shape))
                        else:
                            surrogate_dist = type(actual_dist)(
                                **temp_params_dict)

                        if isinstance(
                                original_dist, transformed_distribution.
                                TransformedDistribution):
                            surrogate_dist = transformed_distribution.TransformedDistribution(
                                surrogate_dist,
                                bijector=original_dist.bijector)

                        if isinstance(original_dist, independent.Independent):
                            surrogate_dist = independent.Independent(
                                surrogate_dist,
                                reinterpreted_batch_ndims=original_dist.
                                reinterpreted_batch_ndims)

                        if isinstance(dist, Root):
                            value_out = yield Root(surrogate_dist)
                        else:
                            value_out = yield surrogate_dist

                    dist = prior_gen.send(value_out)
                    i += 1
            except StopIteration:
                pass
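
# The coroutine above implements the ASVI construction: each surrogate parameter
# is the convex combination
#   param = prior_weight * value + (1. - prior_weight) * mean_field_parameter
# of the prior's parameter value and a freely trained mean-field value.
# A hedged sketch of the corresponding public entry point (assuming `prior` is a
# non-nested JointDistribution over the latent variables):
surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(prior)
# `surrogate_posterior` can then be fit with `tfp.vi.fit_surrogate_posterior`.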
Example #4
 def test_rank_from_shape_scalar(self):
     self.assertEqual(1, ps.rank_from_shape(5))
     v = tf.Variable(4, shape=tf.TensorShape(None))
     self.evaluate(v.initializer)
     self.assertEqual(1, self.evaluate(ps.rank_from_shape(v)))
Example #5
def _is_scalar_from_shape_tensor(shape):
    """Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
    return prefer_static.equal(prefer_static.rank_from_shape(shape), 0)
Example #6
    def _distributional_transform(self, x, event_shape):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the 'flattened' event dimensions).
    The result is a sample from a product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution
      event_shape: The event shape of this distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                'Distributional transform does not support inputs of '
                'undefined rank.')

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message='`univariate_components` must have scalar event')
        ]):
            event_ndims = prefer_static.rank_from_shape(event_shape)
            x_padded = self._pad_sample_dims(
                x, event_ndims=event_ndims)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            event_size = prefer_static.cast(
                prefer_static.reduce_prod(event_shape), dtype=tf.int32)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, event_size]),
                    exclusive=True,
                    axis=-1),
                tf.shape(log_prob_x))  # [S, B, k, E]

            event_ndims = prefer_static.rank_from_shape(event_shape)
            logits_mix_prob = self.mixture_distribution.logits_parameter()
            logits_mix_prob = tf.reshape(
                logits_mix_prob,  # [k] or [B, k]
                prefer_static.concat([
                    prefer_static.shape(logits_mix_prob),
                    prefer_static.ones([event_ndims], dtype=tf.int32),
                ],
                                     axis=0))  # [k, [1]*e] or [B, k, [1]*e]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
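
# The transform above underlies implicit reparameterization gradients for
# mixtures. A hedged sketch (standard `tfd` alias assumed): with
# `reparameterize=True`, gradients flow from mixture samples back to the
# component parameters via the distributional transform.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

locs = tf.Variable([-1., 1.])
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=[0., 0.]),
    components_distribution=tfd.Normal(loc=locs, scale=0.5),
    reparameterize=True)
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(gm.sample(100, seed=42))
grad = tape.gradient(loss, locs)  # well-defined thanks to the transform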
Example #7
def _kl_independent(a, b, name='kl_independent'):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    'KL between Independents with different '
                    'event shapes not supported.')
        else:
            raise ValueError('Event shapes do not match.')
    else:
        p_event_shape_tensor = p.event_shape_tensor()
        q_event_shape_tensor = q.event_shape_tensor()
        # NOTE: We could optimize by passing the event_shape_tensor of p and q
        # to a.event_shape_tensor() and b.event_shape_tensor().
        a_event_shape_tensor = a.event_shape_tensor()
        b_event_shape_tensor = b.event_shape_tensor()
        with tf.control_dependencies([
                assert_util.assert_equal(a_event_shape_tensor,
                                         b_event_shape_tensor,
                                         message='Event shapes do not match.'),
                assert_util.assert_equal(p_event_shape_tensor,
                                         q_event_shape_tensor,
                                         message='Event shapes do not match.'),
        ]):
            num_reduce_dims = (
                ps.rank_from_shape(a_event_shape_tensor, a.event_shape) -
                ps.rank_from_shape(p_event_shape_tensor, p.event_shape))
            reduce_dims = ps.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
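
# A small sanity check of the identity used above (standard `tfd` alias
# assumed): the KL of two Independent Normals equals the sum of the
# componentwise KLs over the reinterpreted dimensions.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(loc=tf.zeros([3]), scale=1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=tf.ones([3]), scale=2.),
                    reinterpreted_batch_ndims=1)
kl = tfd.kl_divergence(a, b)
kl_by_hand = tf.reduce_sum(
    tfd.kl_divergence(tfd.Normal(tf.zeros([3]), 1.),
                      tfd.Normal(tf.ones([3]), 2.)))
# kl and kl_by_hand agree up to floating point.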
Example #8
 def _get_default_reinterpreted_batch_ndims(self, distribution):
     """Computes the default value for reinterpreted_batch_ndim __init__ arg."""
     ndims = prefer_static.rank_from_shape(distribution.batch_shape_tensor,
                                           distribution.batch_shape)
     return prefer_static.maximum(0, ndims - 1)
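
# Sketch of the default computed above (assuming `import tensorflow as tf` and
# `tfd = tfp.distributions`): `Independent` absorbs all but the leftmost batch
# dimension into the event, i.e. the `ndims - 1` rule.
d = tfd.Independent(tfd.Normal(loc=tf.zeros([3, 4]), scale=1.))
# d.batch_shape == [3], d.event_shape == [4]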
Example #9
    def testParameterProperties(self, bijector_name, data):
        if tf.config.functions_run_eagerly() or not tf.executing_eagerly():
            self.skipTest(
                'To reduce test weight, parameter properties tests run in '
                'eager mode only.')

        non_trainable_params = (
            'bijector',  # Several.
            'forward_fn',  # Inline.
            'inverse_fn',  # Inline.
            'forward_min_event_ndims',  # Inline.
            'inverse_min_event_ndims',  # Inline.
            'event_shape_out',  # Reshape.
            'event_shape_in',  # Reshape.
            'perm',  # Transpose.
            'rightmost_transposed_ndims',  # Transpose.
            'diag_bijector',  # TransformDiagonal.
        )
        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            validate_args=True,
            allowed_bijectors=TF2_FRIENDLY_BIJECTORS)

        # Extract the full shape of an output from this bijector.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        ys = bijector.forward(xs)
        output_shape = prefer_static.shape(ys)
        sample_and_batch_ndims = (prefer_static.rank_from_shape(output_shape) -
                                  bijector.inverse_min_event_ndims)

        try:
            params = type(bijector).parameter_properties()
            params64 = type(bijector).parameter_properties(dtype=tf.float64)
        except NotImplementedError as e:
            self.skipTest(str(e))

        seeds = samplers.split_seed(test_util.test_seed(), n=len(params))
        new_parameters = {}
        for i, (param_name, param) in enumerate(params.items()):
            # Check that the shape_fn is consistent with event_ndims.
            try:
                param_shape = param.shape_fn(sample_shape=output_shape)
            except NotImplementedError:
                self.skipTest('No shape function implemented for bijector {} '
                              'parameter {}.'.format(bijector_name,
                                                     param_name))
            self.assertGreaterEqual(
                param.event_ndims,
                prefer_static.rank_from_shape(param_shape) -
                sample_and_batch_ndims)

            if param.is_preferred:
                try:
                    param_bijector = param.default_constraining_bijector_fn()
                except NotImplementedError:
                    self.skipTest(
                        'No constraining bijector implemented for {} '
                        'parameter {}.'.format(bijector_name, param_name))
                unconstrained_shape = (
                    param_bijector.inverse_event_shape_tensor(param_shape))
                unconstrained_param = samplers.normal(unconstrained_shape,
                                                      seed=seeds[i])
                new_parameters[param_name] = param_bijector.forward(
                    unconstrained_param)

                # Check that passing a float64 `eps` works with float64 parameters.
                b_float64 = params64[
                    param_name].default_constraining_bijector_fn()
                b_float64(tf.cast(unconstrained_param, tf.float64))

        # Copy over any non-trainable parameters.
        new_parameters.update({
            k: v
            for (k, v) in bijector.parameters.items()
            if k in non_trainable_params
        })

        # Sanity check that we got valid parameters.
        new_parameters['validate_args'] = True
        new_bijector = type(bijector)(**new_parameters)
        self.evaluate(
            tf.group(*[v.initializer for v in new_bijector.variables]))
        xs = self._draw_domain_tensor(new_bijector, data, event_dim)
        self.evaluate(new_bijector.forward(xs))
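
# A hedged sketch of the API exercised by this test (assuming a TFP version in
# which bijectors expose `parameter_properties`): each entry describes one
# constructor argument, including its per-parameter event ndims and whether it
# is the preferred parameterization.
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

props = tfb.Softplus.parameter_properties(dtype=tf.float32)
for param_name, prop in props.items():
  print(param_name, prop.event_ndims, prop.is_preferred)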
Example #10
class MarkovChainBijectorTest(test_util.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs':
                     tfd.MultivariateNormalDiag(loc=x['probs'],
                                                scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a':
                     tfd.Gamma(tf.zeros([5]), 1.),
                     'b':
                     lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                            event_shape_out=[2, 3, 2])
                                (tfd.Independent(tfd.Normal(
                                    loc=tf.zeros([5, 4, 3]),
                                    scale=a[..., tf.newaxis, tf.newaxis]),
                                                 reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a':
                     tfd.Normal(loc=x['a'], scale=1.),
                     'b':
                     lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.
             MarkovChain(initial_state_prior=tfb.Split(2)
                         (tfd.MultivariateNormalDiag(0., [1., 2.])),
                         transition_fn=lambda _, x: tfb.Split(2)
                         (tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                         num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:], x,
            bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())
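
# A hedged standalone sketch of the distribution under test (assuming a TFP
# version that provides `tfd.MarkovChain`): a Gaussian random walk and its
# default event-space bijector, as exercised by `test_default_bijector` above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

chain = tfd.MarkovChain(
    initial_state_prior=tfd.Normal(loc=0., scale=1.),
    transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.),
    num_steps=5)
y = chain.sample(seed=42)                                   # shape [5]
bijector = chain.experimental_default_event_space_bijector()
x = bijector.inverse(y)                                     # unconstrained pullback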
Example #11
 def _batch_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         batch_shape = self.distribution.batch_shape_tensor()
         batch_ndims = prefer_static.rank_from_shape(
             batch_shape, self.distribution.batch_shape)
         return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]
Example #12
 def _has_nonzero_rank(self, override_shape):
     return prefer_static.logical_not(
         prefer_static.equal(prefer_static.rank_from_shape(override_shape),
                             self._zero))
Example #13
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if tf.nest.is_nested(core_ndims)
          else tf.nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = tf.nest.flatten(core_ndims_structure)
      flat_args = nest.flatten_up_to(
          core_ndims_structure, args, check_types=False)

      # Filter to only the `Tensor`-valued args (taken to be those with `None`
      # values for `core_ndims`). Other args will be passed through to `fn`
      # unmodified.
      (vectorized_arg_core_ndims,
       vectorized_args,
       fn_of_vectorized_args) = _lock_in_non_vectorized_args(
           fn,
           arg_structure=core_ndims_structure,
           flat_core_ndims=flat_core_ndims,
           flat_args=flat_args)

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      vectorized_arg_shapes = [tf.shape(arg) for arg in vectorized_args]
      batch_ndims = [
          ps.rank_from_shape(arg_shape) - nd
          for (arg_shape, nd) in zip(
              vectorized_arg_shapes, vectorized_arg_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 ps.rank_from_shape,
                                 vectorized_arg_shapes),
                             vectorized_arg_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(
            batch_ndims, vectorized_args, vectorized_arg_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        empty_shape = np.zeros([0], dtype=np.int32)
        batch_shapes, core_shapes = empty_shape, empty_shape
        if vectorized_arg_shapes:
          batch_shapes, core_shapes = zip(*[
              (arg_shape[:nd], arg_shape[nd:])
              for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)])
        broadcast_batch_shape = (
            functools.reduce(ps.broadcast_shape, batch_shapes, []))

      # Flatten all of the batch dimensions into one.
      n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        result = fn(*args)
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_vectorized_args = [
            tf.broadcast_to(part, ps.concat(
                [broadcast_batch_shape, core_shape], axis=0))
            for (part, core_shape) in zip(vectorized_args, core_shapes)]
        vectorized_args_with_flattened_batch_dim = [
            tf.reshape(part, ps.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(
                broadcast_vectorized_args, core_shapes)]
        batched_result = tf.vectorized_map(
            fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

        # Unflatten any `Tensor`s in the result.
        unflatten = lambda x: tf.reshape(x, ps.concat([  # pylint: disable=g-long-lambda
            broadcast_batch_shape, ps.shape(x)[1:]], axis=0))
        result = tf.nest.map_structure(
            lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result,
            expand_composites=True)
    return result
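
# The helper above is internal; the same idea can be sketched with raw TF ops:
# broadcast the inputs to a shared batch shape, flatten that batch into a single
# leading dimension for `tf.vectorized_map`, then restore the batch shape.
import tensorflow as tf

def scalar_fn(args):
  x, y = args
  return x * y  # core computation on scalar "parts"

x = tf.random.normal([2, 3])
y = tf.random.normal([3])
batch_shape = tf.broadcast_static_shape(x.shape, y.shape)  # [2, 3]
n = batch_shape.num_elements()
flat_result = tf.vectorized_map(
    scalar_fn,
    (tf.reshape(tf.broadcast_to(x, batch_shape), [n]),
     tf.reshape(tf.broadcast_to(y, batch_shape), [n])))
result = tf.reshape(flat_result, batch_shape)              # equals x * y here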
Example #14
 def move_particles_to_rightmost_batch_dim(x, event_shape):
     ndims = prefer_static.rank_from_shape(prefer_static.shape(x))
     event_ndims = prefer_static.rank_from_shape(event_shape)
     return dist_util.move_dimension(x, 0, ndims - event_ndims - 1)
Example #15
 def _event_ndims(self):
     return prefer_static.rank_from_shape(
         self.components_distribution.event_shape_tensor,
         self.components_distribution.event_shape)
Example #16
    def testAutoVectorization(self, bijector_name, data):

        # TODO(b/150161911): reconcile numeric behavior of eager and graph mode.
        if tf.executing_eagerly():
            return

        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            batch_shape=[],  # Avoid conflict with vmap sample dimension.
            validate_args=False,  # Work around lack of `If` support in vmap.
            allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
                               set(AUTOVECTORIZATION_IS_BROKEN)))
        atol = AUTOVECTORIZATION_ATOL[bijector_name]
        rtol = AUTOVECTORIZATION_RTOL[bijector_name]

        # Forward
        n = 3
        xs = self._draw_domain_tensor(bijector,
                                      data,
                                      event_dim,
                                      sample_shape=[n])
        ys = bijector.forward(xs)
        vectorized_ys = tf.vectorized_map(bijector.forward, xs)
        self.assertAllClose(*self.evaluate((ys, vectorized_ys)),
                            atol=atol,
                            rtol=rtol)

        # FLDJ
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=prefer_static.rank_from_shape(xs.shape) -
                         1))
        fldj_fn = functools.partial(bijector.forward_log_det_jacobian,
                                    event_ndims=event_ndims)
        vectorized_fldj = tf.vectorized_map(fldj_fn, xs)
        fldj = tf.broadcast_to(fldj_fn(xs), tf.shape(vectorized_fldj))
        self.assertAllClose(*self.evaluate((fldj, vectorized_fldj)),
                            atol=atol,
                            rtol=rtol)

        # Inverse
        ys = self._draw_codomain_tensor(bijector,
                                        data,
                                        event_dim,
                                        sample_shape=[n])
        xs = bijector.inverse(ys)
        vectorized_xs = tf.vectorized_map(bijector.inverse, ys)
        self.assertAllClose(*self.evaluate((xs, vectorized_xs)),
                            atol=atol,
                            rtol=rtol)

        # ILDJ
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=prefer_static.rank_from_shape(ys.shape) -
                         1))
        ildj_fn = functools.partial(bijector.inverse_log_det_jacobian,
                                    event_ndims=event_ndims)
        vectorized_ildj = tf.vectorized_map(ildj_fn, ys)
        ildj = tf.broadcast_to(ildj_fn(ys), tf.shape(vectorized_ildj))
        self.assertAllClose(*self.evaluate((ildj, vectorized_ildj)),
                            atol=atol,
                            rtol=rtol)
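
# Sketch of the property checked above: for an elementwise bijector, mapping
# `forward` over the leading sample dimension with `tf.vectorized_map` agrees
# with a single batched call.
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Exp()
xs = tf.random.normal([3, 2])
ys = bij.forward(xs)
ys_vmapped = tf.vectorized_map(bij.forward, xs)  # approximately equal to ys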
Example #17
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = prefer_static.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = prefer_static.rank(x) - event_ndims - 1
        mask_batch_ndims = prefer_static.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = prefer_static.shape(mask)
        mask = tf.reshape(
            mask,
            shape=prefer_static.concat([
                mask_shape[:-1],
                prefer_static.ones([pad_ndims], dtype=tf.int32),
                mask_shape[-1:],
                prefer_static.ones([event_ndims], dtype=tf.int32),
            ],
                                       axis=0))

        if x.dtype in [
                tf.bfloat16, tf.float16, tf.float32, tf.float64, tf.complex64,
                tf.complex128
        ]:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
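
# A user-level sketch of the sampling path implemented above (assuming
# `tfd = tfp.distributions`): each of the `n` draws picks one component via the
# mixture distribution, then that component's sample is selected by the one-hot
# mask-and-sum.
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]))
samples = gm.sample(5, seed=42)  # shape [5]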
Example #18
def build_factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
        not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or `initial_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda rate, concentration: model.log_prob([rate, concentration, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()},          # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        unconstrained_trainable_distributions = (
            nest_util.map_structure_with_named_args(
                tfe_util.make_trainable,
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                seed=tf.nest.pack_sequence_as(
                    event_shape,
                    samplers.split_seed(seed,
                                        n=len(tf.nest.flatten(event_shape)))),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)
Example #19
 def _batch_shape_tensor(self):
     batch_shape = self.distribution.batch_shape_tensor()
     batch_ndims = ps.rank_from_shape(batch_shape,
                                      self.distribution.batch_shape)
     return batch_shape[:batch_ndims -
                        self._get_reinterpreted_batch_ndims(batch_shape)]
Example #20
def build_factored_surrogate_posterior(
    event_shape=None,
    constraining_bijectors=None,
    initial_unconstrained_loc=_sample_uniform_initial_loc,
    initial_unconstrained_scale=1e-2,
    trainable_distribution_fn=_build_trainable_normal_dist,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    constraining_bijectors: Optional `tfb.Bijector` instance, or nested
      structure of such instances, defining support(s) of the posterior
      variables. The structure must match that of `event_shape` and may
      contain `None` values. A posterior variable will
      be modeled as `tfd.TransformedDistribution(underlying_dist,
      constraining_bijector)` if a corresponding constraining bijector is
      specified, otherwise it is modeled as supported on the
      unconstrained real line.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `constraining_bijectors`, which
      may optionally be prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_unconstrained_loc` is not explicitly specified.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or `initial_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                            tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda rate, concentration: model.log_prob([rate, concentration, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  constraining_bijectors={'concentration': tfb.Softplus(),   # Rate is positive.
                          'rate': tfb.Softplus()}   # Concentration is positive.
  initial_unconstrained_loc = tf.nest.map_fn(
    lambda b, x: b.inverse(x) if b is not None else x,
    constraining_bijectors, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    constraining_bijectors=constraining_bijectors,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)
    flat_event_shapes = tf.nest.flatten(event_shape)

    # For simplicity, we'll work with flattened lists of state parts and
    # repack the structure at the end.
    if constraining_bijectors is not None:
      flat_bijectors = tf.nest.flatten(constraining_bijectors)
    else:
      flat_bijectors = [None for _ in flat_event_shapes]
    flat_unconstrained_event_shapes = [
        b.inverse_event_shape_tensor(s) if b is not None else s
        for s, b in zip(flat_event_shapes, flat_bijectors)]

    # Construct initial locations for the internal unconstrained dists.
    if callable(initial_unconstrained_loc):  # Sample random initialization.
      flat_unconstrained_locs = [initial_unconstrained_loc(
          shape=s, seed=seed()) for s in flat_unconstrained_event_shapes]
    else:  # Use provided initialization.
      flat_unconstrained_locs = nest.flatten_up_to(
          shallow_structure, initial_unconstrained_loc, check_types=False)

    if nest.is_nested(initial_unconstrained_scale):
      flat_unconstrained_scales = nest.flatten_up_to(
          shallow_structure, initial_unconstrained_scale, check_types=False)
    else:
      flat_unconstrained_scales = [
          initial_unconstrained_scale for _ in flat_unconstrained_locs]

    # Extract the rank of each event, so that we build distributions with the
    # correct event shapes.
    flat_unconstrained_event_ndims = [prefer_static.rank_from_shape(s)
                                      for s in flat_unconstrained_event_shapes]

    # Build the component surrogate posteriors.
    flat_component_dists = []
    for initial_loc, initial_scale, event_ndims, bijector in zip(
        flat_unconstrained_locs,
        flat_unconstrained_scales,
        flat_unconstrained_event_ndims,
        flat_bijectors):
      unconstrained_dist = trainable_distribution_fn(
          initial_loc=initial_loc, initial_scale=initial_scale,
          event_ndims=event_ndims, validate_args=validate_args)
      flat_component_dists.append(
          bijector(unconstrained_dist) if bijector is not None
          else unconstrained_dist)
    component_distributions = tf.nest.pack_sequence_as(
        event_shape, flat_component_dists)

    # Return a `Distribution` object whose events have the specified structure.
    return (
        joint_distribution_util.independent_joint_distribution_from_structure(
            component_distributions, validate_args=validate_args))
Example #21
def _get_reduction_axes(x, nd):
    """Enumerates the final `nd` axis indices of `x`."""
    x_rank = prefer_static.rank_from_shape(prefer_static.shape(x))
    return prefer_static.range(x_rank - 1, x_rank - nd - 1, -1)
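A small usage sketch (with hypothetical shapes) of `_get_reduction_axes`: it returns the indices of the trailing `nd` axes of `x`, in descending order, which can be fed directly to a reduction op.

```python
import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])          # rank 4
axes = _get_reduction_axes(x, nd=2)
# axes -> [3, 2]: the last two axis indices of `x`.
tf.reduce_sum(x, axis=axes)         # result has shape [2, 3]
```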
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
      shared across all members of the input structure. If this is specified,
      the returned joint distribution will be an autobatched distribution with
      the given batch rank, and all other dimensions absorbed into the event.
    validate_args: Python `bool`. Whether the joint distribution should validate
      input with asserts. This imposes a runtime cost. If `validate_args` is
      `False`, and the inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.
  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        dist = structure_of_distributions
        if batch_ndims is not None:
            excess_ndims = ps.rank_from_shape(
                dist.batch_shape_tensor()) - batch_ndims
            if tf.get_static_value(
                    excess_ndims) != 0:  # Static value may be None.
                dist = independent.Independent(
                    dist, reinterpreted_batch_ndims=excess_ndims)
        return dist

    # If this structure contains other structures (i.e., has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            functools.partial(independent_joint_distribution_from_structure,
                              batch_ndims=batch_ndims,
                              validate_args=validate_args),
            structure_of_distributions)

    jdnamed = joint_distribution_named.JointDistributionNamed
    jdsequential = joint_distribution_sequential.JointDistributionSequential
    # Use an autobatched JD if a specific batch rank was requested.
    if batch_ndims is not None:
        jdnamed = functools.partial(
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)
        jdsequential = functools.partial(
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions, collections.abc.Mapping)):
        return jdnamed(structure_of_distributions, validate_args=validate_args)
    return jdsequential(structure_of_distributions,
                        validate_args=validate_args)
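A minimal usage sketch of `independent_joint_distribution_from_structure` (assuming the standard `tfd = tfp.distributions` alias): a dict of independent distributions becomes a single distribution whose samples share the dict structure.

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = independent_joint_distribution_from_structure(
    {'loc': tfd.Normal(0., 1.), 'scale': tfd.Exponential(1.)})
x = dist.sample()         # -> dict with keys 'loc' and 'scale'
lp = dist.log_prob(x)     # scalar log-density over the whole structure
```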
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = tf.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            # We let TransformedDistribution access _graph_parents since this class
            # is more like a baseclass than derived.
            graph_parents=(
                distribution._graph_parents +  # pylint: disable=protected-access
                bijector.graph_parents),
            name=name)
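A minimal sketch of the constructor above (standard TFP imports assumed): a log-normal distribution built by pushing a Normal base distribution through the `Exp` bijector.

```python
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())
samples = log_normal.sample(3)            # strictly positive samples
log_probs = log_normal.log_prob(samples)  # includes the Jacobian correction
```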
def _resample_using_log_points(log_probs, sample_shape, log_points, name=None):
    """Resample from `log_probs` using supplied points in interval `[0, 1]`."""

    # We divide up the unit interval [0, 1] according to the provided
    # probability distributions using `cumulative_logsumexp`.
    # At the end of each division we place a 'marker'.
    # We use points on the unit interval supplied by caller.
    # We sort the combination of points and markers. The number
    # of points between the markers defining a division gives the number
    # of samples we require in that division.
    # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
    # We divide up `[0, 1]` using 3 markers:
    #
    #     |     |          |
    # 0.  0.2   0.5        1.0  <- markers
    #
    # Suppose we are given four points: [0.1, 0.25, 0.9, 0.75]
    # After sorting the combination we get:
    #
    # 0.1  0.25     0.75 0.9    <- points
    #  *  | *   |    *    *|
    # 0.   0.2 0.5         1.0  <- markers
    #
    # We have one sample in the first category, one in the second and
    # two in the last.
    #
    # All of these computations are carried out in batched form.

    with tf.name_scope(name or 'resample_using_log_points') as name:
        points_shape = ps.shape(log_points)
        batch_shape, [num_markers] = ps.split(ps.shape(log_probs),
                                              num_or_size_splits=[-1, 1])

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = ps.concat([sample_shape, batch_shape], axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = ps.concat([working_shape, [num_markers]], axis=0)

        markers = ps.concat([
            tf.ones(markers_shape, dtype=tf.int32),
            tf.zeros(points_shape, dtype=tf.int32)
        ],
                            axis=-1)
        log_marker_positions = tf.broadcast_to(
            log_cumsum_exp(log_probs, axis=-1), markers_shape)
        log_markers_and_points = ps.concat([log_marker_positions, log_points],
                                           axis=-1)
        # Stable sort is used to ensure that no points get sorted between
        # markers that have zero distance between them. This ensures that
        # there will never be a sample drawn whose probability is intended
        # to be zero even when a point falls on the edge of the
        # corresponding zero-width bucket.
        indices = tf.argsort(log_markers_and_points, axis=-1, stable=True)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(ps.rank_from_shape(sample_shape) +
                        ps.rank_from_shape(batch_shape)))
        markers_and_samples = ps.cast(tf.cumsum(sorted_markers, axis=-1),
                                      dtype=tf.int32)
        markers_and_samples = tf.math.minimum(markers_and_samples,
                                              num_markers - np.int32(1))

        # Collect up samples, omitting markers.
        samples_mask = tf.equal(sorted_markers, 0)

        # The following block of code is equivalent to
        # `samples = markers_and_samples[samples_mask]` however boolean mask
        # indices are not supported by XLA.
        # Instead we use `argsort` to pick out the top `num_samples`
        # elements of `markers_and_samples` when sorted using `samples_mask`
        # as key.
        num_samples = points_shape[-1]
        sample_locations = tf.argsort(ps.cast(samples_mask, dtype=tf.int32),
                                      direction='DESCENDING',
                                      stable=True)
        samples = tf.gather_nd(markers_and_samples,
                               sample_locations[..., :num_samples, tf.newaxis],
                               batch_dims=(ps.rank_from_shape(sample_shape) +
                                           ps.rank_from_shape(batch_shape)))

        return tf.reshape(samples, points_shape)
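A toy NumPy sketch of the marker/point bookkeeping described in the comment above, using plain probabilities instead of log-space values:

```python
import numpy as np

markers = np.array([0.2, 0.5, 1.0])      # cumulative sums of probs [0.2, 0.3, 0.5]
points = np.array([0.1, 0.25, 0.9, 0.75])
counts, _ = np.histogram(points, bins=np.concatenate([[0.], markers]))
# counts -> [1, 1, 2]: one sample from category 0, one from category 1 and two
# from category 2, matching the worked example in the comment.
```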
Example #25
  def loop_tree_doubling(self, step_size, momentum_state_memory,
                         current_step_meta_info, iter_, initial_step_state,
                         initial_step_metastate):
    """Main loop for tree doubling."""
    with tf.name_scope('loop_tree_doubling'):
      batch_shape = ps.shape(current_step_meta_info.init_energy)
      direction = tf.cast(
          tf.random.uniform(
              shape=batch_shape,
              minval=0,
              maxval=2,
              dtype=tf.int32,
              seed=self._seed_stream()),
          dtype=tf.bool)

      tree_start_states = tf.nest.map_structure(
          lambda v: tf.where(  # pylint: disable=g-long-lambda
              mcmc_util.left_justified_expand_dims_like(direction, v[1]),
              v[1], v[0]),
          initial_step_state)

      directions_expanded = [
          mcmc_util.left_justified_expand_dims_like(direction, state)
          for state in tree_start_states.state
      ]

      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn,
          step_sizes=[
              tf.where(d, ss, -ss)
              for d, ss in zip(directions_expanded, step_size)
          ],
          num_steps=self.unrolled_leapfrog_steps)

      [
          candidate_tree_state,
          tree_final_states,
          final_not_divergence,
          continue_tree_final,
          energy_diff_tree_sum,
          momentum_subtree_cumsum,
          leapfrogs_taken
      ] = self._build_sub_tree(
          directions_expanded,
          integrator,
          current_step_meta_info,
          # num_steps_at_this_depth = 2**iter_ = 1 << iter_
          tf.bitwise.left_shift(1, iter_),
          tree_start_states,
          initial_step_metastate.continue_tree,
          initial_step_metastate.not_divergence,
          momentum_state_memory)

      last_candidate_state = initial_step_metastate.candidate_state

      energy_diff_sum = (
          energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
      if MULTINOMIAL_SAMPLE:
        tree_weight = tf.where(
            continue_tree_final,
            candidate_tree_state.weight,
            tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
        weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
        log_accept_thresh = tree_weight - last_candidate_state.weight
      else:
        tree_weight = tf.where(
            continue_tree_final,
            candidate_tree_state.weight,
            tf.zeros([], dtype=TREE_COUNT_DTYPE))
        weight_sum = tree_weight + last_candidate_state.weight
        log_accept_thresh = tf.math.log(
            tf.cast(tree_weight, tf.float32) /
            tf.cast(last_candidate_state.weight, tf.float32))
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
      u = tf.math.log1p(-tf.random.uniform(
          shape=batch_shape,
          dtype=log_accept_thresh.dtype,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      choose_new_state = is_sample_accepted & continue_tree_final

      new_candidate_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      choose_new_state, s0),
                  s0, s1)
              for s0, s1 in zip(candidate_tree_state.state,
                                last_candidate_state.state)
          ],
          target=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  choose_new_state,
                  candidate_tree_state.target),
              candidate_tree_state.target, last_candidate_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      choose_new_state, grad0),
                  grad0, grad1)
              for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
                                      last_candidate_state.target_grad_parts)
          ],
          energy=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  choose_new_state, candidate_tree_state.target),
              candidate_tree_state.energy, last_candidate_state.energy),
          weight=weight_sum)

      for new_candidate_state_temp, old_candidate_state_temp in zip(
          new_candidate_state.state, last_candidate_state.state):
        tensorshape_util.set_shape(new_candidate_state_temp,
                                   old_candidate_state_temp.shape)

      for new_candidate_grad_temp, old_candidate_grad_temp in zip(
          new_candidate_state.target_grad_parts,
          last_candidate_state.target_grad_parts):
        tensorshape_util.set_shape(new_candidate_grad_temp,
                                   old_candidate_grad_temp.shape)

      # Update left right information of the trajectory, and check trajectory
      # level U turn
      tree_otherend_states = tf.nest.map_structure(
          lambda v: tf.where(  # pylint: disable=g-long-lambda
              mcmc_util.left_justified_expand_dims_like(direction, v[1]),
              v[0], v[1]), initial_step_state)

      new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
          tf.stack([  # pylint: disable=g-complex-comprehension
              tf.where(
                  mcmc_util.left_justified_expand_dims_like(direction, left),
                  right, left),
              tf.where(
                  mcmc_util.left_justified_expand_dims_like(direction, left),
                  left, right),
          ], axis=0)
          for left, right in zip(tf.nest.flatten(tree_final_states),
                                 tf.nest.flatten(tree_otherend_states))
      ])

      momentum_tree_cumsum = []
      for p0, p1 in zip(
          initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
        momentum_part_temp = p0 + p1
        tensorshape_util.set_shape(momentum_part_temp, p0.shape)
        momentum_tree_cumsum.append(momentum_part_temp)

      for new_state_temp, old_state_temp in zip(
          tf.nest.flatten(new_step_state),
          tf.nest.flatten(initial_step_state)):
        tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

      if GENERALIZED_UTURN:
        state_diff = momentum_tree_cumsum
      else:
        state_diff = [s[1] - s[0] for s in new_step_state.state]

      no_u_turns_trajectory = has_not_u_turn(
          state_diff,
          [m[0] for m in new_step_state.momentum],
          [m[1] for m in new_step_state.momentum],
          log_prob_rank=ps.rank_from_shape(batch_shape))

      new_step_metastate = TreeDoublingMetaState(
          candidate_state=new_candidate_state,
          is_accepted=choose_new_state | initial_step_metastate.is_accepted,
          momentum_sum=momentum_tree_cumsum,
          energy_diff_sum=energy_diff_sum,
          continue_tree=continue_tree_final & no_u_turns_trajectory,
          not_divergence=final_not_divergence,
          leapfrog_count=(initial_step_metastate.leapfrog_count +
                          leapfrogs_taken))

      return iter_ + 1, new_step_state, new_step_metastate
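A scalar sketch of the subtree acceptance rule used in the `MULTINOMIAL_SAMPLE` branch above: the new candidate is accepted with probability `min(1, w_new / w_old)` by comparing a log-uniform draw against the log weight ratio (the values below are made up).

```python
import numpy as np

log_w_new, log_w_old = np.log(2.), np.log(5.)
log_accept_thresh = log_w_new - log_w_old   # log(0.4)
u = np.log1p(-np.random.uniform())          # log(1 - U) is the log of a Uniform(0, 1)
is_accepted = u <= log_accept_thresh        # True with probability 0.4
```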
def resample_deterministic_minimum_error(
        log_probs,
        event_size,
        sample_shape,
        seed=None,
        name='resample_deterministic_minimum_error'):
    """Deterministic minimum error resampler for sequential Monte Carlo.

    The return value of this function is similar to sampling with

    ```python
    expanded_sample_shape = tf.concat([sample_shape, [event_size]], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```

    but with values chosen deterministically so that the empirical distribution
    is as close as possible to the specified distribution.
    (Note that the empirical distribution can only exactly equal the requested
    distribution if multiplying every probability by `event_size` gives
    an integer. So in general this is a biased "sampler".)
    It is intended to provide a good representative sample, suitable for use
    with some Sequential Monte Carlo algorithms.

  This function is based on Algorithm #3 in [Maskell et al. (2006)][1].

  Args:
    log_probs: a tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws. Because
      this resampler is deterministic it simply replicates the draw you
      would get for `sample_shape=[1]`.
    seed: This argument is unused but is present so that this function shares
      its interface with the other resampling functions.
      Default value: `None`.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_deterministic_minimum_error'`).

  Returns:
    resampled_indices: a tensor of samples.

  #### References
  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    del seed

    with tf.name_scope(name or 'resample_deterministic_minimum_error'):
        sample_shape = tf.convert_to_tensor(sample_shape, dtype_hint=tf.int32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)
        probs = tf.math.exp(log_probs)
        prob_shape = ps.shape(probs)
        pdf_size = prob_shape[-1]
        # If we could draw fractional numbers of samples we would
        # choose `ideal_numbers` for the number of each element.
        ideal_numbers = event_size * probs
        # We approximate the ideal numbers by truncating to integers
        # and then repair the counts starting with the one with the
        # largest fractional error and working our way down.
        first_approximation = tf.floor(ideal_numbers)
        missing_fractions = ideal_numbers - first_approximation
        first_approximation = ps.cast(first_approximation, dtype=tf.int32)
        fraction_order = tf.argsort(missing_fractions, axis=-1)
        # We sort the integer parts and fractional parts together.
        batch_dims = ps.rank_from_shape(prob_shape) - 1
        first_approximation = tf.gather_nd(first_approximation,
                                           fraction_order[..., tf.newaxis],
                                           batch_dims=batch_dims)
        missing_fractions = tf.gather_nd(missing_fractions,
                                         fraction_order[..., tf.newaxis],
                                         batch_dims=batch_dims)
        sample_defect = event_size - tf.reduce_sum(
            first_approximation, axis=-1, keepdims=True)
        unpermuted = tf.broadcast_to(tf.range(pdf_size), prob_shape)
        increments = tf.cast(unpermuted >= pdf_size - sample_defect,
                             dtype=first_approximation.dtype)
        counts = first_approximation + increments
        samples = _samples_from_counts(fraction_order, counts, event_size)
        result_shape = tf.concat([sample_shape, prob_shape[:-1], [event_size]],
                                 axis=0)
        # Replicate sample up to batch size.
        # TODO(dpiponi): rather than replicating, spread the "error" over
        # multiple samples with a minimum-discrepancy sequence.
        resampled = tf.broadcast_to(samples, result_shape)
        return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
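A usage sketch with made-up weights: when `event_size` times the probabilities is integral, the deterministic resampler should reproduce the target counts exactly.

```python
import tensorflow as tf

log_probs = tf.math.log([0.25, 0.25, 0.5])
indices = resample_deterministic_minimum_error(
    log_probs, event_size=4, sample_shape=[])
# 4 * probs = [1., 1., 2.], so index 2 should appear twice and indices 0 and 1
# once each (up to ordering) among the 4 returned indices.
```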
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
    """Generates parameter dictionaries given a prior distribution and list."""
    with tf.name_scope('make_asvi_trainable_variables'):
        param_dicts = []
        prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
        for dist in prior_dists:
            original_dist = dist.distribution if isinstance(dist,
                                                            Root) else dist

            substituted_dist = _as_trainable_family(original_dist)

            # Grab the base distribution if it exists
            try:
                actual_dist = substituted_dist.distribution
            except AttributeError:
                actual_dist = substituted_dist

            new_params_dict = {}

            #  Build trainable ASVI representation for each distribution's parameters.
            parameter_properties = actual_dist.parameter_properties(
                dtype=actual_dist.dtype)

            if isinstance(original_dist, sample.Sample):
                posterior_batch_shape = ps.concat([
                    actual_dist.batch_shape_tensor(),
                    original_dist.sample_shape
                ],
                                                  axis=0)
            else:
                posterior_batch_shape = actual_dist.batch_shape_tensor()

            for param, value in actual_dist.parameters.items():

                if param in (_NON_STATISTICAL_PARAMS +
                             _NON_TRAINABLE_PARAMS) or value is None:
                    continue

                actual_event_shape = parameter_properties[param].shape_fn(
                    actual_dist.event_shape_tensor())
                try:
                    bijector = parameter_properties[
                        param].default_constraining_bijector_fn()
                except NotImplementedError:
                    bijector = tfb.Identity()

                if mean_field:
                    prior_weight = None
                else:
                    unconstrained_ones = tf.ones(shape=ps.concat([
                        posterior_batch_shape,
                        bijector.inverse_event_shape_tensor(actual_event_shape)
                    ],
                                                                 axis=0),
                                                 dtype=actual_dist.dtype)

                    prior_weight = tfp_util.TransformedVariable(
                        initial_prior_weight * unconstrained_ones,
                        bijector=tfb.Sigmoid(),
                        name='prior_weight/{}/{}'.format(dist.name, param))

                # If the prior distribution was a tfd.Sample wrapping a base
                # distribution, we want to give every single sample in the prior its
                # own lambda and alpha value (rather than having a single lambda and
                # alpha).
                if isinstance(original_dist, sample.Sample):
                    value = tf.reshape(
                        value,
                        ps.concat([
                            actual_dist.batch_shape_tensor(),
                            ps.ones(
                                ps.rank_from_shape(
                                    original_dist.sample_shape)),
                            actual_event_shape
                        ],
                                  axis=0))
                    value = tf.broadcast_to(
                        value,
                        ps.concat([posterior_batch_shape, actual_event_shape],
                                  axis=0))
                new_params_dict[param] = ASVIParameters(
                    prior_weight=prior_weight,
                    mean_field_parameter=tfp_util.TransformedVariable(
                        value,
                        bijector=bijector,
                        name='mean_field_parameter/{}/{}'.format(
                            dist.name, param)))

            param_dicts.append(new_params_dict)
    return param_dicts
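A short sketch of how these trainable parameters are typically consumed through the public wrapper (assuming `tfp.experimental.vi.build_asvi_surrogate_posterior` is available in the installed TFP version):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

prior = tfd.JointDistributionNamedAutoBatched({
    'scale': tfd.Exponential(1.),
    'loc': tfd.Normal(0., 10.)})
surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(prior)
# surrogate_posterior.trainable_variables holds the prior_weight and
# mean_field_parameter variables built as above.
```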
Example #28
 def _transpose_around_bijector_fn(self,
                                   bijector_fn,
                                   arg,
                                   src_event_ndims,
                                   dest_event_ndims=None,
                                   fn_reduces_event=False,
                                   **kwargs):
     # This function moves the axes corresponding to `self.sample_shape` to the
     # left of the batch shape, then applies `bijector_fn`, then moves the axes
     # corresponding to `self.sample_shape` back to the event part of the shape.
     #
     # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank
     # (omitting `self.sample_shape`) before and after applying `bijector_fn`.
     #
     # This function arose because forward and inverse ended up being quite
     # similar. It was then only a small generalization to also support {F/I}LDJ.
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     arg_ndims = ps.rank(arg)
     # (1) Expand arg's dims.
     d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims
     arg = tf.reshape(arg,
                      shape=ps.pad(ps.shape(arg),
                                   paddings=[[ps.maximum(0, -d), 0]],
                                   constant_values=1))
     arg_ndims = ps.rank(arg)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose arg's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           arg_ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     arg = tf.transpose(arg, perm=perm)
     # (3) Apply underlying bijector.
     result = bijector_fn(arg, **kwargs)
     # (4) Transpose sample_shape from the sample to the event shape.
     result_ndims = ps.rank(result)
     if fn_reduces_event:
         dest_event_ndims = 0
     d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims
     if fn_reduces_event:
         # In some cases, fn may reduce event too far, i.e. ildj may return a
         # scalar `0.`, which won't work with the transpose we do below.
         result = tf.reshape(result,
                             shape=ps.pad(ps.shape(result),
                                          paddings=[[ps.maximum(0, -d), 0]],
                                          constant_values=1))
         result_ndims = ps.rank(result)
     sample_ndims = ps.maximum(0, d)
     sample_dims = ps.range(0, sample_ndims)
     extra_sample_dims = ps.range(sample_ndims,
                                  sample_ndims + extra_sample_ndims)
     batch_dims = ps.range(sample_ndims + extra_sample_ndims,
                           sample_ndims + extra_sample_ndims + batch_ndims)
     event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims,
                           result_ndims)
     perm = ps.concat(
         [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0)
     return tf.transpose(result, perm=perm)
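A toy illustration (made-up shapes) of the axis shuffle in step (2) above: the `sample_shape` axes are moved in front of the batch axes before the underlying bijector is applied, and moved back afterwards.

```python
import tensorflow as tf

x = tf.zeros([7, 3, 4, 2])              # [sample, batch, extra_sample, event]
y = tf.transpose(x, perm=[0, 2, 1, 3])  # -> shape [7, 4, 3, 2]: extra_sample
                                        #    now sits to the left of batch.
```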
Example #29
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if nest.is_nested(core_ndims)
          else nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = nest.flatten(core_ndims_structure)
      parts = tf.nest.flatten(nest.map_structure_up_to(
          core_ndims_structure, tf.convert_to_tensor, args, check_types=False))
      if len(parts) != len(flat_core_ndims):
        raise ValueError('Number of args does not match `core_ndims` '
                         '({} vs {}). Saw argument parts {}; core '
                         'ndims {}.'.format(len(parts), len(flat_core_ndims),
                                            parts, flat_core_ndims))

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      part_shapes = [tf.shape(part) for part in parts]
      batch_ndims = [
          prefer_static.rank_from_shape(part_shape) - nd
          for (part_shape, nd) in zip(part_shapes, flat_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 prefer_static.rank_from_shape, part_shapes),
                             flat_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(batch_ndims, parts, flat_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        batch_shapes, core_shapes = zip(*[
            (part_shape[:nd], part_shape[nd:])
            for (part_shape, nd) in zip(part_shapes, batch_ndims)])
        broadcast_batch_shape = functools.reduce(
            prefer_static.broadcast_shape, batch_shapes, [])

      # Flatten all of the batch dimensions into one.
      n = tf.cast(prefer_static.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        result = fn(*args)
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_parts = [
            tf.broadcast_to(part, prefer_static.concat([broadcast_batch_shape,
                                                        core_shape], axis=0))
            for (part, core_shape) in zip(parts, core_shapes)]
        parts_with_flattened_batch_dim = [
            tf.reshape(part, prefer_static.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(broadcast_parts, core_shapes)]

        # Run the vectorized computation
        batched_result = tf.vectorized_map(lambda args: fn(*args),
                                           nest.pack_sequence_as(
                                               args,
                                               parts_with_flattened_batch_dim))

        # Unflatten the result
        result = nest.map_structure(
            lambda x: tf.reshape(x, prefer_static.concat([  # pylint: disable=g-long-lambda
                broadcast_batch_shape, prefer_static.shape(x)[1:]], axis=0)),
            batched_result)
    return result
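A minimal sketch (made-up shapes) of the broadcast-and-flatten step that feeds `tf.vectorized_map`, which expects all inputs to share a single leading dimension:

```python
import tensorflow as tf

a = tf.zeros([2, 1, 3])   # batch shape [2, 1], core shape [3]
b = tf.zeros([5, 4])      # batch shape [5],    core shape [4]
# Broadcast the batch shapes [2, 1] and [5] to [2, 5], then flatten to n = 10.
a_flat = tf.reshape(tf.broadcast_to(a, [2, 5, 3]), [10, 3])
b_flat = tf.reshape(tf.broadcast_to(b, [2, 5, 4]), [10, 4])
out = tf.vectorized_map(
    lambda ab: tf.reduce_sum(ab[0]) + tf.reduce_sum(ab[1]), (a_flat, b_flat))
result = tf.reshape(out, [2, 5])   # restore the broadcast batch shape
```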
Example #30
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
    """Recursively creates ASVI surrogates, and creates new variables if needed.

  Args:
    dist: a `tfd.Distribution` instance.
    base_distribution_surrogate_fn: Callable to build a surrogate posterior
      for a 'base' (non-meta and non-joint) distribution, with signature
      `surrogate_posterior, variables = base_distribution_fn(
      dist, sample_shape=None, variables=None, seed=None)`.
    sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
      `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
      independent sample dimensions, i.e., it will have event shape
      `concat([sample_shape, dist.event_shape], axis=0)`.
      Default value: `None`.
    variables: Optional nested structure of `tf.Variable`s returned from a
      previous call to `_asvi_surrogate_for_distribution`. If `None`,
      new variables will be created; otherwise, constructs a surrogate posterior
      backed by the passed-in variables.
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
  Returns:
    surrogate_posterior: Instance of `tfd.Distribution` representing a trainable
      surrogate posterior distribution, with the same structure and `name` as
      `dist`.
    variables: Nested structure of `tf.Variable` trainable parameters for the
      surrogate posterior. If `dist` is a base distribution, this is
      a `dict` of `ASVIParameters` instances. If `dist` is a joint
      distribution, this is a `dist.dtype` structure of such `dict`s.
  """
    # Pass args to any nested surrogates.
    build_nested_surrogate = functools.partial(
        _asvi_surrogate_for_distribution,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        sample_shape=sample_shape,
        seed=seed)

    # Apply any substitutions, while attempting to preserve the original name.
    dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))

    # Handle wrapper ("meta") distributions.
    if isinstance(dist, markov_chain.MarkovChain):
        return _asvi_surrogate_for_markov_chain(
            dist=dist,
            variables=variables,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            sample_shape=sample_shape,
            seed=seed)
    if isinstance(dist, sample.Sample):
        dist_sample_shape = distribution_util.expand_to_vector(
            dist.sample_shape)
        nested_surrogate, variables = build_nested_surrogate(  # pylint: disable=redundant-keyword-arg
            dist=dist.distribution,
            variables=variables,
            sample_shape=(dist_sample_shape if sample_shape is None
                          else ps.concat([sample_shape, dist_sample_shape],
                                         axis=0)))
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape),
            name=_get_name(dist))
    # Treat distributions that subclass TransformedDistribution with their own
    # parameters (e.g., Gumbel, Weibull, MultivariateNormal*, etc) as their
    # own type of base distribution, rather than as explicit TDs.
    elif type(dist) == transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = transformed_distribution.TransformedDistribution(
            nested_surrogate, bijector=dist.bijector, name=_get_name(dist))
    elif isinstance(dist, independent.Independent):
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
            name=_get_name(dist))
    elif hasattr(dist, '_model_coroutine'):
        surrogate_posterior, variables = _asvi_surrogate_for_joint_distribution(
            dist,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=variables,
            seed=seed)
    elif (hasattr(dist, 'distribution') and
          # Transformed dists not handled above are treated as base distributions.
          not isinstance(dist,
                         transformed_distribution.TransformedDistribution)):
        raise ValueError('Meta-distribution `{}` is not yet supported by this '
                         'implementation of ASVI. Contact the TensorFlow '
                         'Probability maintainers if you need this '
                         'functionality.'.format(type(dist)))
    else:
        surrogate_posterior, variables = base_distribution_surrogate_fn(
            dist=dist,
            sample_shape=sample_shape,
            variables=variables,
            seed=seed)
    return surrogate_posterior, variables
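A sketch (standard TF/TFP imports assumed) of the `tfd.Sample` branch above: the surrogate gives each replicate its own parameters and then absorbs the sample dimensions back into the event with `tfd.Independent`.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

prior = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])           # event_shape [3]
per_replicate = tfd.Normal(loc=tf.Variable(tf.zeros([3])), scale=1.)
surrogate = tfd.Independent(per_replicate, reinterpreted_batch_ndims=1)
# surrogate.event_shape == prior.event_shape == [3]
```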