Example 1
def _asvi_surrogate_for_joint_distribution(dist, build_nested_surrogate):
    """Builds a structured joint surrogate posterior for a joint model."""

    # Probabilistic program for ASVI surrogate posterior.
    prior_coroutine = dist._model_coroutine  # pylint: disable=protected-access

    def posterior_generator():
        # Run the prior model coroutine and create a surrogate for each
        # distribution in turn. This coroutine yields *both* the surrogate
        # parameters and the surrogate distributions themselves. It can therefore
        # run in a `trainable_state_util` context to define parameters (using
        # the `_get_coroutine_parameters` wrapper to hide the yielded distributions)
        # *and* as a JDCoroutine model (using the `_set_coroutine_parameters`
        # wrapper to hide the yielded parameters).
        prior_gen = prior_coroutine()
        dist = next(prior_gen)
        try:
            while True:
                surrogate_posterior = yield from build_nested_surrogate(
                    dist.distribution if isinstance(dist, Root) else dist)
                if isinstance(dist, Root):
                    surrogate_posterior = Root(surrogate_posterior)
                value_out = yield surrogate_posterior
                dist = prior_gen.send(value_out)
        except StopIteration:
            pass

    parameters = yield from _get_coroutine_parameters(
        posterior_generator, seed=samplers.zeros_seed())
    surrogate_posterior = (
        joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched(
            _set_coroutine_parameters(posterior_generator, parameters)))

    # Ensure that the surrogate posterior structure matches that of the prior.
    try:
        tf.nest.assert_same_structure(dist.dtype, surrogate_posterior.dtype)
    except TypeError:
        tokenize = lambda jd: jd._model_unflatten(  # pylint: disable=protected-access, g-long-lambda
            range(len(jd._model_flatten(jd.dtype)))  # pylint: disable=protected-access
        )
        surrogate_posterior = restructure.Restructure(
            output_structure=tokenize(dist),
            input_structure=tokenize(surrogate_posterior))(surrogate_posterior)
    return surrogate_posterior
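
# To make the dual-use coroutine pattern above concrete, here is a minimal,
# self-contained sketch (hypothetical names; not the library implementation)
# of a generator whose yields are consumed by two different drivers: a
# parameter driver that answers parameter requests, and a model driver that
# would consume the built surrogate objects.
def _make_scale():
    init = yield ('param', 'scale')        # Answered by a parameter driver.
    return init

def _surrogate_gen():
    scale = yield from _make_scale()       # Delegate the parameter request.
    yield ('dist', ('Normal', 0., scale))  # Consumed by a model driver.

def _run_parameter_driver(gen_fn):
    gen = gen_fn()
    msg = next(gen)
    try:
        while True:
            kind, _ = msg
            # Answer parameter requests with an initial value; a model driver
            # would instead record 'dist' yields and send sampled values.
            msg = gen.send(1.0 if kind == 'param' else None)
    except StopIteration:
        pass

_run_parameter_driver(_surrogate_gen)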
Example 2
def build_asvi_surrogate_posterior(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5,
                                   name=None):
  """Builds a structured surrogate posterior inspired by conjugate updating.

  ASVI, or Automatic Structured Variational Inference, was proposed by
  Ambrogioni et al. (2020) [1] as a method of automatically constructing a
  surrogate posterior with the same structure as the prior. It does this by
  reparameterizing the variational family of the surrogate posterior by
  structuring each parameter according to the equation
  ```none
  prior_weight * prior_parameter + (1 - prior_weight) * mean_field_parameter
  ```
  In this equation, `prior_parameter` is a vector of prior parameters and
  `mean_field_parameter` is a vector of trainable parameters with the same
  domain as `prior_parameter`. `prior_weight` is a vector of learnable
  parameters where `0. <= prior_weight <= 1.`. When `prior_weight =
  0`, the surrogate posterior will be a mean-field surrogate, and when
  `prior_weight = 1.`, the surrogate posterior will be the prior. This convex
  combination equation, inspired by conjugacy in exponential families, thus
  allows the surrogate posterior to balance between the structure of the prior
  and the structure of a mean-field approximation.
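
  As a minimal sketch of this parameterization for a single `loc` parameter
  (hypothetical variable names, with a sigmoid keeping the weight in
  `[0, 1]`; not the library's internal implementation):

  ```python
  prior_loc = tf.constant(0.)
  unconstrained_weight = tf.Variable(0.)  # Sigmoid maps this into [0, 1].
  mean_field_loc = tf.Variable(1.)

  prior_weight = tf.sigmoid(unconstrained_weight)
  surrogate_loc = (prior_weight * prior_loc
                   + (1. - prior_weight) * mean_field_loc)
  ```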

  Args:
    prior: tfd.JointDistribution instance of the prior.
    mean_field: Optional Python boolean. If `True`, creates a degenerate
      surrogate distribution in which all variables are independent,
      ignoring the prior dependence structure. Default value: `False`.
    initial_prior_weight: Optional float value (either static or tensor value)
      on the interval [0, 1]. A larger value creates an initial surrogate
      distribution with more dependence on the prior structure. Default value:
      `0.5`.
    name: Optional string. Default value: `build_asvi_surrogate_posterior`.

  Returns:
    surrogate_posterior: A `tfd.JointDistributionCoroutineAutoBatched` instance
    whose samples have shape and structure matching that of `prior`.

  Raises:
    TypeError: If the `prior` argument is a nested `JointDistribution`.

  #### Examples

  Consider a Brownian motion model expressed as a JointDistribution:

  ```python
  prior_loc = 0.
  innovation_noise = .1

  def model_fn():
    new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise)
    for _ in range(4):
      new = yield tfd.Normal(loc=new, scale=innovation_noise)

  prior = tfd.JointDistributionCoroutineAutoBatched(model_fn)
  ```

  Let's use variational inference to approximate the posterior. We'll build a
  surrogate posterior distribution by feeding in the prior distribution.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
      prior)
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]
  posterior_std = [tf.math.reduce_std(x) for x in samples]
  ```

  #### References
  [1]: Luca Ambrogioni, Max Hinne, Marcel van Gerven. Automatic structured
       variational inference. _arXiv preprint arXiv:2002.00643_, 2020.
       https://arxiv.org/abs/2002.00643

  """

  with tf.name_scope(name or 'build_asvi_surrogate_posterior'):
    param_dicts = _make_asvi_trainable_variables(
        prior=prior,
        mean_field=mean_field,
        initial_prior_weight=initial_prior_weight)
    def posterior_generator():

      prior_gen = prior._model_coroutine()  # pylint: disable=protected-access
      dist = next(prior_gen)

      i = 0
      try:
        while True:
          original_dist = dist.distribution if isinstance(dist, Root) else dist

          if isinstance(original_dist, joint_distribution.JointDistribution):
            # TODO(kateslin): Build inner JD surrogate in
            # _make_asvi_trainable_variables to avoid rebuilding variables.
            raise TypeError(
                'Argument `prior` cannot be a nested `JointDistribution`.')

          else:

            original_dist = _as_trainable_family(original_dist)

            try:
              actual_dist = original_dist.distribution
            except AttributeError:
              actual_dist = original_dist

            dist_params = actual_dist.parameters
            temp_params_dict = {}

            for param, value in dist_params.items():
              if param in (_NON_STATISTICAL_PARAMS +
                           _NON_TRAINABLE_PARAMS) or value is None:
                temp_params_dict[param] = value
              else:
                prior_weight = param_dicts[i][param].prior_weight
                mean_field_parameter = (
                    param_dicts[i][param].mean_field_parameter)
                if mean_field:
                  temp_params_dict[param] = mean_field_parameter
                else:
                  temp_params_dict[param] = prior_weight * value + (
                      1. - prior_weight) * mean_field_parameter

            if isinstance(original_dist, sample.Sample):
              inner_dist = type(actual_dist)(**temp_params_dict)

              surrogate_dist = independent.Independent(
                  inner_dist,
                  reinterpreted_batch_ndims=ps.rank_from_shape(
                      original_dist.sample_shape))
            else:
              surrogate_dist = type(actual_dist)(**temp_params_dict)

            if isinstance(original_dist,
                          transformed_distribution.TransformedDistribution):
              surrogate_dist = transformed_distribution.TransformedDistribution(
                  surrogate_dist, bijector=original_dist.bijector)

            if isinstance(original_dist, independent.Independent):
              surrogate_dist = independent.Independent(
                  surrogate_dist,
                  reinterpreted_batch_ndims=original_dist
                  .reinterpreted_batch_ndims)

            if isinstance(dist, Root):
              value_out = yield Root(surrogate_dist)
            else:
              value_out = yield surrogate_dist

          dist = prior_gen.send(value_out)
          i += 1
      except StopIteration:
        pass

    surrogate_posterior = (
        joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched(
            posterior_generator))

    # Ensure that the surrogate posterior structure matches that of the prior.
    try:
      nest.assert_same_structure(prior.dtype, surrogate_posterior.dtype)
    except TypeError:
      tokenize = lambda jd: jd._model_unflatten(  # pylint: disable=protected-access, g-long-lambda
          range(len(jd._model_flatten(jd.dtype)))  # pylint: disable=protected-access
      )
      surrogate_posterior = restructure.Restructure(
          output_structure=tokenize(prior),
          input_structure=tokenize(surrogate_posterior))(
              surrogate_posterior)

    surrogate_posterior.also_track = param_dicts
    return surrogate_posterior
Example 3
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint variational posterior by splitting a normalizing flow.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it is
      split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  #### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure, ps.convert_to_shape_tensor, event_shape)

        if nest.is_nested(constraining_bijector):
            constraining_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    constraining_bijector),
                validate_args=validate_args)

        if constraining_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                constraining_bijector.inverse_event_shape_tensor(event_shape))

        flat_base_event_shape = nest.flatten(unconstrained_event_shape)
        flat_base_event_size = nest.map_structure(tf.reduce_prod,
                                                  flat_base_event_shape)
        event_size = tf.reduce_sum(flat_base_event_size)

        base_distribution = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [event_size])

        # After transforming base distribution samples with `trainable_bijector`,
        # split them into vector-valued components.
        split_bijector = split.Split(flat_base_event_size,
                                     validate_args=validate_args)

        # Reshape the vectors to the correct posterior event shape.
        event_reshape = joint_map.JointMap(
            nest.map_structure(reshape.Reshape, unconstrained_event_shape),
            validate_args=validate_args)

        # Restructure the flat list of components to the correct posterior
        # structure.
        event_unflatten = restructure.Restructure(
            nest.pack_sequence_as(unconstrained_event_shape,
                                  range(len(flat_base_event_shape))))

        bijectors = [] if constraining_bijector is None else [
            constraining_bijector
        ]
        bijectors.extend([
            event_reshape, event_unflatten, split_bijector, trainable_bijector
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            base_distribution, bijector=bijector, validate_args=validate_args)
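
# For intuition about the bijector ordering above: `tfb.Chain` applies its
# list right-to-left, so a base sample is transformed by `trainable_bijector`
# first, then `Split` into vectors, then restructured and reshaped, and
# finally constrained. A standalone sketch of the split-and-reshape stages
# using public TFP names (an illustration, not this function's internals):
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

_sketch_chain = tfb.Chain([
    tfb.JointMap([tfb.Reshape([2]), tfb.Reshape([3])]),  # Applied last.
    tfb.Split([2, 3]),                                   # Applied first.
])
_sketch_dist = tfd.TransformedDistribution(
    tfd.Sample(tfd.Normal(0., 1.), 5), bijector=_sketch_chain)
_parts = _sketch_dist.sample()  # A list of two Tensors, shapes [2] and [3].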
Example 4
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.
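
  For intuition about step 2: the blockwise operator acts on the list of
  flattened latent blocks. A small sketch with fixed (non-trainable)
  operators and assumed block sizes `[2, 3]`:

  ```python
  operator = tf.linalg.LinearOperatorBlockLowerTriangular([
      [tf.linalg.LinearOperatorDiag(tf.ones([2]))],
      [tf.linalg.LinearOperatorFullMatrix(tf.zeros([3, 2])),
       tf.linalg.LinearOperatorDiag(tf.ones([3]))]])
  y = operator.matvec([tf.ones([2]), tf.ones([3])])  # Blockwise matvec.
  ```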

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Define a base distribution over the latent variables: a standard Normal
  # for each, with matching event shape. (One reasonable choice for this
  # example; the API accepts any distribution or nested structure of
  # distributions with compatible pretransformed event shapes.)
  base_distribution = tf.nest.map_structure(
      lambda shape: tfd.Sample(tfd.Normal(0., 1.), shape),
      target_model.event_shape)

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(
            name or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (
                joint_distribution_util
                .independent_joint_distribution_from_structure(
                    base_distribution, validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(batch_shape):
            # The base is a classic `JointDistribution` with structured batch
            # shape; reduce to a single broadcast batch shape.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        unflatten_and_reshape = chain.Chain(
            [joint_map.JointMap(
                nest.map_structure(reshape.Reshape, event_shape),
                validate_args=validate_args),
             restructure.Restructure(
                 nest.pack_sequence_as(event_shape, range(num_components)))],
            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,   # Allow the mean of the standard dist to shift from 0.
            linop_bijector  # Scale the standard dist by the LinearOperator.
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Example 5
def _asvi_surrogate_for_joint_distribution(dist,
                                           base_distribution_surrogate_fn,
                                           variables=None,
                                           seed=None):
    """Builds a structured joint surrogate posterior for a joint model."""

    # Probabilistic program for ASVI surrogate posterior.
    flat_variables = dist._model_flatten(variables) if variables else None  # pylint: disable=protected-access
    prior_coroutine = dist._model_coroutine  # pylint: disable=protected-access

    def posterior_generator(seed=seed):
        prior_gen = prior_coroutine()
        dist = next(prior_gen)
        i = 0
        try:
            while True:
                was_root = isinstance(dist, Root)
                if was_root:
                    dist = dist.distribution

                seed, init_seed = samplers.split_seed(seed)
                surrogate_posterior, variables = _asvi_surrogate_for_distribution(
                    dist,
                    base_distribution_surrogate_fn=base_distribution_surrogate_fn,
                    variables=flat_variables[i] if flat_variables else None,
                    seed=init_seed)

                if was_root:
                    surrogate_posterior = Root(surrogate_posterior)
                # If variables were not given---i.e., we're creating new
                # variables---then yield the new variables along with the surrogate
                # posterior. This assumes an execution context such as
                # `_extract_variables_from_coroutine_model` below that will capture and
                # save the variables.
                value_out = yield (surrogate_posterior if flat_variables else
                                   (surrogate_posterior, variables))
                dist = prior_gen.send(value_out)
                i += 1
        except StopIteration:
            pass

    if variables is None:
        # Run the generator to create variables, then call ourselves again
        # to construct the surrogate JD from these variables. Note that we can't
        # just create a JDC from the current `posterior_generator`, because it will
        # try to build new variables on every invocation; the recursive call will
        # define a new `posterior_generator` that knows about the variables we're
        # about to create.
        return _asvi_surrogate_for_joint_distribution(
            dist=dist,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=dist._model_unflatten(  # pylint: disable=protected-access
                _extract_variables_from_coroutine_model(posterior_generator,
                                                        seed=seed)))

    surrogate_posterior = (
        joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched(
            posterior_generator, name=_get_name(dist)))

    # Ensure that the surrogate posterior structure matches that of the prior.
    try:
        tf.nest.assert_same_structure(dist.dtype, surrogate_posterior.dtype)
    except TypeError:
        tokenize = lambda jd: jd._model_unflatten(  # pylint: disable=protected-access, g-long-lambda
            range(len(jd._model_flatten(jd.dtype)))  # pylint: disable=protected-access
        )
        surrogate_posterior = restructure.Restructure(
            output_structure=tokenize(dist),
            input_structure=tokenize(surrogate_posterior))(
                surrogate_posterior, name=_get_name(dist))
    return surrogate_posterior, variables
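
# The two-pass pattern above in miniature (hypothetical names, not the
# library code): when called without `variables`, run once to create them,
# then call again so the returned object closes over stable variables
# instead of recreating them on every invocation.
def _build(make_variable, variables=None):
    if variables is None:
        created = [make_variable()]  # Pass 1: create the variables.
        return _build(make_variable, variables=created)
    return lambda: variables[0]      # Pass 2: reuse the given variables.

_surrogate_fn = _build(lambda: 1.0)
assert _surrogate_fn() == 1.0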