Example #1
  def _make_mixture_dist(self, component_logits, locs, scales):
    """Builds a mixture of quantized logistic distributions.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are `[batch_size,
        height, width, num_logistic_mix]`.
      locs: 5D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      scales: 5D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.

    Returns:
      dist: A quantized logistic mixture `tfp.distribution` over the input data.
    """
    mixture_distribution = categorical.Categorical(logits=component_logits)

    # Convert distribution parameters for pixel values in
    # `[self._low, self._high]` for use with `QuantizedDistribution`
    locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
    scales *= 0.5 * (self._high - self._low)
    logistic_dist = quantized_distribution.QuantizedDistribution(
        distribution=transformed_distribution.TransformedDistribution(
            distribution=logistic.Logistic(loc=locs, scale=scales),
            bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
        low=self._low, high=self._high)

    dist = mixture_same_family.MixtureSameFamily(
        mixture_distribution=mixture_distribution,
        components_distribution=independent.Independent(
            logistic_dist, reinterpreted_batch_ndims=1))
    return independent.Independent(dist, reinterpreted_batch_ndims=2)
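The wrapping at the end of this example is easiest to see in isolation. Below is a minimal sketch (not part of the original snippet) of how `Independent` moves batch dimensions into the event, using the public `tfp.distributions` API:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# A Normal over a [3, 4] grid: batch_shape=[3, 4], event_shape=[].
base = tfd.Normal(loc=tf.zeros([3, 4]), scale=1.)

# Reinterpret the rightmost batch dimension as part of the event:
# batch_shape=[3], event_shape=[4]; log_prob now sums over that axis.
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)
print(ind.batch_shape, ind.event_shape)  # (3,) (4,)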
Example #2
    def _fn(dtype, shape, name, trainable, add_variable_fn):
        """Creates multivariate `Deterministic` or `Normal` distribution.

        Args:
          dtype: Type of parameter's event.
          shape: Python `list`-like representing the parameter's event shape.
          name: Python `str` name prepended to any created (or existing)
            `tf.Variable`s.
          trainable: Python `bool` indicating all created `tf.Variable`s should
            be added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
          add_variable_fn: `tf.get_variable`-like `callable` used to create (or
            access existing) `tf.Variable`s.

        Returns:
          Multivariate `Deterministic` or `Normal` distribution.
        """
        loc, scale = loc_scale_fn(dtype, shape, name, trainable,
                                  add_variable_fn)
        if scale is None:
            dist = deterministic_lib.Deterministic(loc=loc)
        else:
            dist = normal_lib.Normal(loc=loc, scale=scale)
        batch_ndims = tf.size(input=dist.batch_shape_tensor())
        return independent_lib.Independent(
            dist, reinterpreted_batch_ndims=batch_ndims)
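Functions with this `(dtype, shape, name, trainable, add_variable_fn)` signature are the builder callables that `tfp.layers` expects for variational layers. A usage sketch (assuming the public `tfp.layers` API; the layer and units below are illustrative, not from this snippet):

import tensorflow_probability as tfp

# `default_mean_field_normal_fn` returns a builder like `_fn` above;
# the layer calls it to construct the kernel's posterior distribution.
layer = tfp.layers.DenseFlipout(
    units=10,
    kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn())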
Example #3
def _build_posterior_for_one_parameter(param, batch_shape, seed):
    """Built a transformed-normal variational dist over a parameter's support."""

    # Build a trainable Normal distribution.
    initial_loc = sample_uniform_initial_state(param,
                                               init_sample_shape=batch_shape,
                                               return_constrained=False,
                                               seed=seed)
    loc = tf.Variable(initial_value=initial_loc, name=param.name + '_loc')
    scale = tfp_util.TransformedVariable(
        tf.fill(tf.shape(initial_loc),
                value=tf.constant(0.02, initial_loc.dtype),
                name=param.name + '_scale'), softplus_lib.Softplus())
    posterior_dist = normal_lib.Normal(loc=loc, scale=scale)

    # Ensure the `event_shape` of the variational distribution matches the
    # parameter.
    if (param.prior.event_shape.ndims is None
            or param.prior.event_shape.ndims > 0):
        posterior_dist = independent_lib.Independent(
            posterior_dist,
            reinterpreted_batch_ndims=param.prior.event_shape.ndims)

    # Transform to constrained parameter space.
    posterior_dist = transformed_distribution_lib.TransformedDistribution(
        posterior_dist, param.bijector, name='{}_posterior'.format(param.name))
    return posterior_dist
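The `TransformedVariable` plus `Softplus` construction above is the standard trick for keeping a trainable scale positive. A standalone sketch (illustrative values) using the public names:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# `scale` acts like a tensor whose value is softplus of an underlying
# unconstrained variable, so gradient steps can never make it negative.
scale = tfp.util.TransformedVariable(0.02, bijector=tfb.Softplus())
print(tf.convert_to_tensor(scale))  # ~0.02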
Example #4
def _wrap_as_distributions(structure):
  return tf.nest.map_structure(
      lambda x: independent.Independent(  # pylint: disable=g-long-lambda
          deterministic.Deterministic(x),
          # Particles are a batch dimension.
          reinterpreted_batch_ndims=tf.rank(x) - 1),
      structure)
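To see what `_wrap_as_distributions` produces, consider a single `[num_particles, d]` tensor: it becomes a point mass with the particle axis as batch and the state axis as event. A sketch (shapes made up for illustration):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

x = tf.zeros([7, 3])  # 7 particles, 3-dimensional state.
dist = tfd.Independent(tfd.Deterministic(x),
                       reinterpreted_batch_ndims=tf.rank(x) - 1)
print(dist.batch_shape_tensor())  # [7]
print(dist.event_shape_tensor())  # [3]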
Example #5
def inner(dtype,
          shape,
          name,
          trainable,
          add_variable_fn,
          loc=loc,
          scale=scale):
    """Creates multivariate standard `Normal` distribution.

    Args:
      dtype: Type of parameter's event.
      shape: Python `list`-like representing the parameter's event shape.
      name: Python `str` name prepended to any created (or existing)
        `tf.Variable`s.
      trainable: Python `bool` indicating all created `tf.Variable`s should be
        added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
      add_variable_fn: `tf.get_variable`-like `callable` used to create (or
        access existing) `tf.Variable`s.

    Returns:
      Multivariate standard `Normal` distribution.
    """
    del name, trainable, add_variable_fn  # unused
    if loc is None:
        loc = tf.zeros(shape, dtype)
    dist = normal_lib.Normal(loc=loc, scale=dtype.as_numpy_dtype(scale))
    batch_ndims = tf.size(dist.batch_shape_tensor())
    return independent_lib.Independent(
        dist, reinterpreted_batch_ndims=batch_ndims)
Example #6
 def _fn(dtype, shape, name, trainable, add_variable_fn):
     loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
     if scale is None:
         dist = deterministic_lib.Deterministic(loc=loc)
     else:
         dist = normal_lib.Normal(loc=loc, scale=scale)
     batch_ndims = tf2.size(dist.batch_shape_tensor())
     return independent_lib.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
Example #7
def _asvi_surrogate_for_sample(dist,
                               build_nested_surrogate,
                               sample_shape=None):
    """Builds the surrogate for a `tfd.Sample`-wrapped distribution."""
    dist_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
    nested_surrogate = yield from build_nested_surrogate(
        dist=dist.distribution,
        sample_shape=(dist_sample_shape if sample_shape is None else ps.concat(
            [sample_shape, dist_sample_shape], axis=0)))
    return independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape))
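The `Independent` wrapper here gives the surrogate the same event shape that `tfd.Sample` gives the wrapped distribution. A small sketch of that correspondence (illustrative shapes):

import tensorflow_probability as tfp

tfd = tfp.distributions

# `Sample` draws 5 i.i.d. copies, so its event_shape is [5] ...
wrapped = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[5])

# ... and a surrogate built from a [5]-batch of Normals matches it by
# reinterpreting that batch dimension as the event.
surrogate = tfd.Independent(tfd.Normal(loc=[0.] * 5, scale=1.),
                            reinterpreted_batch_ndims=1)
print(wrapped.event_shape, surrogate.event_shape)  # (5,) (5,)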
Example #8
def build_trainable_location_scale_distribution(initial_loc,
                                                initial_scale,
                                                event_ndims,
                                                distribution_fn=normal.Normal,
                                                validate_args=False,
                                                name=None):
    """Builds a variational distribution from a location-scale family.

    Args:
      initial_loc: Float `Tensor` initial location.
      initial_scale: Float `Tensor` initial scale.
      event_ndims: Integer `Tensor` number of event dimensions in `initial_loc`.
      distribution_fn: Optional constructor for a `tfd.Distribution` instance
        in a location-scale family. This should have signature `dist =
        distribution_fn(loc, scale, validate_args)`.
        Default value: `tfd.Normal`.
      validate_args: Python `bool`. Whether to validate input with asserts. This
        imposes a runtime cost. If `validate_args` is `False`, and the inputs
        are invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e.,
          'build_trainable_location_scale_distribution').

    Returns:
      posterior_dist: A `tfd.Distribution` instance.
    """
    with tf.name_scope(name or 'build_trainable_location_scale_distribution'):
        dtype = dtype_util.common_dtype([initial_loc, initial_scale],
                                        dtype_hint=tf.float32)
        initial_loc = initial_loc * tf.ones(tf.shape(initial_scale),
                                            dtype=dtype)
        initial_scale = initial_scale * tf.ones_like(initial_loc)

        loc = tf.Variable(initial_value=initial_loc, name='loc')
        scale = tfp_util.TransformedVariable(initial_scale,
                                             softplus.Softplus(),
                                             name='scale')
        posterior_dist = distribution_fn(loc=loc,
                                         scale=scale,
                                         validate_args=validate_args)

        # Ensure the distribution has the desired number of event dimensions.
        static_event_ndims = tf.get_static_value(event_ndims)
        if static_event_ndims is None or static_event_ndims > 0:
            posterior_dist = independent.Independent(
                posterior_dist,
                reinterpreted_batch_ndims=event_ndims,
                validate_args=validate_args)

    return posterior_dist
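Stripped of its argument handling, the pattern this helper implements is a pair of trainable parameters feeding an `Independent` Normal. A sketch with made-up shapes, using only public TFP names:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

loc = tf.Variable(tf.zeros([4]), name='loc')
scale = tfp.util.TransformedVariable(tf.ones([4]), tfb.Softplus(),
                                     name='scale')

# event_ndims=1: treat the length-4 axis as the event.
posterior = tfd.Independent(tfd.Normal(loc=loc, scale=scale),
                            reinterpreted_batch_ndims=1)
print(posterior.trainable_variables)  # (loc, unconstrained scale)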
Example #9
def create_fixed_gaussian_prior(dtype, shape, name, trainable, add_variable_fn):
    prior_loc = add_variable_fn(
        'kernel_prior_loc', shape=shape, dtype=dtype, trainable=False,
        initializer=prior_kernel_loc_initializer)
    prior_untr_scale = add_variable_fn(
        'kernel_prior_untransformed_scale', shape=(), dtype=dtype,
        trainable=False, initializer=prior_kernel_untr_scale_initializer)

    dist = normal_lib.Normal(
        loc=prior_loc, scale=tf.nn.softplus(prior_untr_scale))

    batch_ndims = tf.size(input=dist.batch_shape_tensor())

    return independent_lib.Independent(
        dist, reinterpreted_batch_ndims=batch_ndims)
Example #10
 def _fn(dtype, shape, name, trainable, add_variable_fn):
     loc_init = tf.compat.v1.constant_initializer(loc)
     scale_init = tf.compat.v1.constant_initializer(scale)
     new_loc = add_variable_fn(name=name + '_loc',
                               shape=shape,
                               initializer=loc_init,
                               regularizer=None,
                               constraint=None,
                               dtype=dtype,
                               trainable=isPosterior)
     new_scale = add_variable_fn(name=name + '_scale',
                                 shape=shape,
                                 initializer=scale_init,
                                 regularizer=None,
                                 constraint=None,
                                 dtype=dtype,
                                 trainable=isPosterior)
     dist = normal_lib.Normal(loc=new_loc, scale=new_scale)
     batch_ndims = tf.size(input=dist.batch_shape_tensor())
     return independent_lib.Independent(
         dist, reinterpreted_batch_ndims=batch_ndims)
Example #11
  def params_and_state_transition_fn(step,
                                     params_and_state,
                                     perturbation_scale,
                                     **kwargs):
    """Transition function operating on a `ParamsAndState` namedtuple."""
    # Extract the state, to pass through to the observation fn.
    unconstrained_params, state = params_and_state
    if 'state_history' in kwargs:
      kwargs['state_history'] = kwargs['state_history'].state

    # Perturb each (unconstrained) parameter with normally-distributed noise.
    if not tf.nest.is_nested(perturbation_scale):
      perturbation_scale = tf.nest.map_structure(
          lambda x: tf.convert_to_tensor(perturbation_scale,  # pylint: disable=g-long-lambda
                                         name='perturbation_scale',
                                         dtype=x.dtype),
          unconstrained_params)
    perturbed_unconstrained_parameter_dists = tf.nest.map_structure(
        lambda x, p, s: independent.Independent(  # pylint: disable=g-long-lambda
            normal.Normal(loc=x, scale=p),
            reinterpreted_batch_ndims=prefer_static.rank_from_shape(s)),
        unconstrained_params,
        perturbation_scale,
        parameter_prior.event_shape_tensor())

    # For the joint transition, pass the perturbed parameters
    # into the original transition fn (after pushing them into constrained
    # space).
    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=_maybe_build_joint_distribution(
                perturbed_unconstrained_parameter_dists),
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_transition_fn(
                    step,
                    state,
                    parameters=parameter_constraining_bijector.forward(
                        unconstrained_parameters),
                    **kwargs))))
Example #12
    def variational_loss(self,
                         observations,
                         observation_index_points=None,
                         kl_weight=1.,
                         name='variational_loss'):
        """Variational loss for the VGP.

        Given `observations` and `observation_index_points`, compute the
        negative variational lower bound as specified in [Hensman, 2013][1].

        Args:
          observations: `float` `Tensor` representing collection, or batch of
            collections, of observations corresponding to
            `observation_index_points`. Shape has the form `[b1, ..., bB, e]`,
            which must be broadcastable with the batch and example shapes of
            `observation_index_points`. The batch shape `[b1, ..., bB]` must be
            broadcastable with the shapes of all other batched parameters
            (`kernel.batch_shape`, `observation_index_points`, etc.).
          observation_index_points: `float` `Tensor` representing finite (batch
            of) vector(s) of points where observations are defined. Shape has
            the form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of
            feature dimensions and must equal `kernel.feature_ndims`, and `e1`
            is the number (size) of index points in each batch (we denote it
            `e1` to distinguish it from the number of inducing index points,
            denoted `e2` below). If set to `None`, uses `index_points` as the
            origin for observations.
            Default value: `None`.
          kl_weight: Amount by which to scale the KL divergence loss between
            prior and posterior.
            Default value: `1.`
          name: Python `str` name prefixed to Ops created by this class.
            Default value: `'variational_loss'`.

        Returns:
          loss: Scalar tensor representing the negative variational lower bound.
            Can be directly used in a `tf.Optimizer`.

        Raises:
          ValueError: if `mean_fn` is not `None` and is not callable.

        #### References

        [1]: Hensman, J., Fusi, N., Lawrence, N. "Gaussian Processes for Big
             Data", 2013. https://arxiv.org/abs/1309.6835
        """

        with tf.name_scope(name or 'variational_gp_loss'):
            if observation_index_points is None:
                observation_index_points = self._index_points
            observation_index_points = tf.convert_to_tensor(
                observation_index_points,
                dtype=self._dtype,
                name='observation_index_points')
            observations = tf.convert_to_tensor(observations,
                                                dtype=self._dtype,
                                                name='observations')
            kl_weight = tf.convert_to_tensor(kl_weight,
                                             dtype=self._dtype,
                                             name='kl_weight')

            # The variational loss is a negative ELBO. The ELBO can be broken
            # down into three terms:
            #  1. a likelihood term,
            #  2. a trace term arising from the covariance of the posterior
            #     predictive, and
            #  3. a KL divergence term between the variational posterior over
            #     inducing-point values and its prior (computed below as
            #     `kl_term`).

            kzx = self.kernel.matrix(self._inducing_index_points,
                                     observation_index_points)

            kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx)
            loc = (self._mean_fn(observation_index_points) +
                   kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True))

            likelihood = independent.Independent(
                normal.Normal(
                    loc=loc,
                    scale=tf.sqrt(self._observation_noise_variance + self._jitter),
                    name='NormalLikelihood'),
                reinterpreted_batch_ndims=1)
            obs_ll = likelihood.log_prob(observations)

            chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular(
                self._chol_kzz)
            chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx)
            kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True)

            kxx_diag = self.kernel.apply(observation_index_points,
                                         observation_index_points,
                                         example_ndims=1)
            ktilde_trace_term = (
                tf.reduce_sum(kxx_diag, axis=-1) -
                tf.reduce_sum(chol_kzz_inv_kzx**2, axis=[-2, -1]))

            # Tr(SB)
            # where S = A A.T, A = variational_inducing_observations_scale
            # and B = Kzz^-1 Kzx Kzx.T Kzz^-1
            #
            # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1)
            #            = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A)
            #            = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2
            other_trace_term = tf.reduce_sum(
                (self._variational_inducing_observations_posterior.scale.
                 matmul(kzz_inv_kzx)**2),
                axis=[-2, -1])

            trace_term = (.5 * (ktilde_trace_term + other_trace_term) /
                          self._observation_noise_variance)

            kl_term = kl_weight * self.surrogate_posterior_kl_divergence_prior()

            lower_bound = (obs_ll - trace_term - kl_term)

            return -tf.reduce_mean(lower_bound)
Example #13
 def _batched_isotropic_normal_like(state_part):
     event_ndims = ps.rank(state_part) - batch_rank
     return independent.Independent(
         normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
         reinterpreted_batch_ndims=event_ndims)
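Concretely (a sketch; `batch_rank` is assumed to be supplied by the surrounding code): for a `[2, 3, 4]` state part and `batch_rank = 1`, this yields a standard normal with batch shape `[2]` and event shape `[3, 4]`:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_rank = 1
state_part = tf.zeros([2, 3, 4])
event_ndims = tf.rank(state_part) - batch_rank  # 2
dist = tfd.Independent(tfd.Normal(tf.zeros_like(state_part), 1.),
                       reinterpreted_batch_ndims=event_ndims)
print(dist.batch_shape_tensor())  # [2]
print(dist.event_shape_tensor())  # [3 4]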
Example #14
def _asvi_surrogate_for_independent(dist, build_nested_surrogate):
    """Builds the surrogate for a `tfd.Independent`-wrapped distribution."""
    nested_surrogate = yield from build_nested_surrogate(dist.distribution)
    return independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims)
Example #15
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
        shared across all members of the input structure. If this is specified,
        the returned joint distribution will be an autobatched distribution with
        the given batch rank, and all other dimensions absorbed into the event.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct behavior
        is not guaranteed.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
        If `structure_of_distributions` was indeed a structure (as opposed to
        a single `Distribution` instance), this will be a `JointDistribution`
        with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not `tfd.Distribution`
        instances.
    """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        dist = structure_of_distributions
        if batch_ndims is not None:
            excess_ndims = ps.rank_from_shape(
                dist.batch_shape_tensor()) - batch_ndims
            if tf.get_static_value(
                    excess_ndims) != 0:  # Static value may be None.
                dist = independent.Independent(
                    dist, reinterpreted_batch_ndims=excess_ndims)
        return dist

    # If this structure contains other structures (i.e., has elements at
    # depth > 1), recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            functools.partial(independent_joint_distribution_from_structure,
                              batch_ndims=batch_ndims,
                              validate_args=validate_args),
            structure_of_distributions)

    jdnamed = joint_distribution_named.JointDistributionNamed
    jdsequential = joint_distribution_sequential.JointDistributionSequential
    # Use an autobatched JD if a specific batch rank was requested.
    if batch_ndims is not None:
        jdnamed = functools.partial(
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)
        jdsequential = functools.partial(
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict') or isinstance(
            structure_of_distributions, collections.abc.Mapping)):
        return jdnamed(structure_of_distributions, validate_args=validate_args)
    return jdsequential(structure_of_distributions,
                        validate_args=validate_args)
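For a flat dict of leaf distributions with no `batch_ndims`, the function above reduces to a plain `JointDistributionNamed`. A sketch of the equivalent behavior with public names:

import tensorflow_probability as tfp

tfd = tfp.distributions

structure = {'loc': tfd.Normal(0., 1.), 'scale': tfd.Gamma(1., 1.)}

# Sampling returns a dict with the same keys as `structure`, matching the
# `tf.nest.map_structure(lambda d: d.sample(), ...)` contract above.
joint = tfd.JointDistributionNamed(structure)
print(joint.sample())  # {'loc': ..., 'scale': ...}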
Example #16
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
    """Recursively creates ASVI surrogates, and creates new variables if needed.

    Args:
      dist: a `tfd.Distribution` instance.
      base_distribution_surrogate_fn: Callable to build a surrogate posterior
        for a 'base' (non-meta and non-joint) distribution, with signature
        `surrogate_posterior, variables = base_distribution_fn(
        dist, sample_shape=None, variables=None, seed=None)`.
      sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
        `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
        independent sample dimensions, i.e., it will have event shape
        `concat([sample_shape, dist.event_shape], axis=0)`.
        Default value: `None`.
      variables: Optional nested structure of `tf.Variable`s returned from a
        previous call to `_asvi_surrogate_for_distribution`. If `None`, new
        variables will be created; otherwise, constructs a surrogate posterior
        backed by the passed-in variables.
        Default value: `None`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      surrogate_posterior: Instance of `tfd.Distribution` representing a
        trainable surrogate posterior distribution, with the same structure and
        `name` as `dist`.
      variables: Nested structure of `tf.Variable` trainable parameters for the
        surrogate posterior. If `dist` is a base distribution, this is a `dict`
        of `ASVIParameters` instances. If `dist` is a joint distribution, this
        is a `dist.dtype` structure of such `dict`s.
    """
    # Pass args to any nested surrogates.
    build_nested_surrogate = functools.partial(
        _asvi_surrogate_for_distribution,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        sample_shape=sample_shape,
        seed=seed)

    # Apply any substitutions, while attempting to preserve the original name.
    dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))

    # Handle wrapper ("meta") distributions.
    if isinstance(dist, markov_chain.MarkovChain):
        return _asvi_surrogate_for_markov_chain(
            dist=dist,
            variables=variables,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            sample_shape=sample_shape,
            seed=seed)
    if isinstance(dist, sample.Sample):
        dist_sample_shape = distribution_util.expand_to_vector(
            dist.sample_shape)
        nested_surrogate, variables = build_nested_surrogate(  # pylint: disable=redundant-keyword-arg
            dist=dist.distribution,
            variables=variables,
            sample_shape=(dist_sample_shape if sample_shape is None
                          else ps.concat([sample_shape, dist_sample_shape],
                                         axis=0)))
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape),
            name=_get_name(dist))
    # Treat distributions that subclass TransformedDistribution with their own
    # parameters (e.g., Gumbel, Weibull, MultivariateNormal*, etc) as their
    # own type of base distribution, rather than as explicit TDs.
    elif type(dist) == transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = transformed_distribution.TransformedDistribution(
            nested_surrogate, bijector=dist.bijector, name=_get_name(dist))
    elif isinstance(dist, independent.Independent):
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
            name=_get_name(dist))
    elif hasattr(dist, '_model_coroutine'):
        surrogate_posterior, variables = _asvi_surrogate_for_joint_distribution(
            dist,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=variables,
            seed=seed)
    elif (hasattr(dist, 'distribution') and
          # Transformed dists not handled above are treated as base distributions.
          not isinstance(dist,
                         transformed_distribution.TransformedDistribution)):
        raise ValueError('Meta-distribution `{}` is not yet supported by this '
                         'implementation of ASVI. Contact '
                         '`[email protected]` if you need this '
                         'functionality.'.format(type(dist)))
    else:
        surrogate_posterior, variables = base_distribution_surrogate_fn(
            dist=dist,
            sample_shape=sample_shape,
            variables=variables,
            seed=seed)
    return surrogate_posterior, variables
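This recursion is the internal workhorse behind TFP's ASVI construction; end users typically reach it through `tfp.experimental.vi.build_asvi_surrogate_posterior`. A hedged usage sketch (the prior below is invented for illustration):

import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.JointDistributionNamed({
    'scale': tfd.Gamma(1., 1.),
    'x': lambda scale: tfd.Normal(0., scale),
})
surrogate = tfp.experimental.vi.build_asvi_surrogate_posterior(prior)
# `surrogate` mirrors the structure of `prior` and carries trainable
# variables for use with tfp.vi.fit_surrogate_posterior.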
Example #17
    def posterior_generator():

      prior_gen = prior._model_coroutine()  # pylint: disable=protected-access
      dist = next(prior_gen)

      i = 0
      try:
        while True:
          original_dist = dist.distribution if isinstance(dist, Root) else dist

          if isinstance(original_dist, joint_distribution.JointDistribution):
            # TODO(kateslin): Build inner JD surrogate in
            # _make_asvi_trainable_variables to avoid rebuilding variables.
            raise TypeError(
                'Argument `prior` cannot be a nested `JointDistribution`.')

          else:

            original_dist = _as_trainable_family(original_dist)

            try:
              actual_dist = original_dist.distribution
            except AttributeError:
              actual_dist = original_dist

            dist_params = actual_dist.parameters
            temp_params_dict = {}

            for param, value in dist_params.items():
              if param in (_NON_STATISTICAL_PARAMS +
                           _NON_TRAINABLE_PARAMS) or value is None:
                temp_params_dict[param] = value
              else:
                prior_weight = param_dicts[i][param].prior_weight
                mean_field_parameter = param_dicts[i][
                    param].mean_field_parameter
                if mean_field:
                  temp_params_dict[param] = mean_field_parameter
                else:
                  temp_params_dict[param] = prior_weight * value + (
                      1. - prior_weight) * mean_field_parameter

            if isinstance(original_dist, sample.Sample):
              inner_dist = type(actual_dist)(**temp_params_dict)

              surrogate_dist = independent.Independent(
                  inner_dist,
                  reinterpreted_batch_ndims=ps.rank_from_shape(
                      original_dist.sample_shape))
            else:
              surrogate_dist = type(actual_dist)(**temp_params_dict)

            if isinstance(original_dist,
                          transformed_distribution.TransformedDistribution):
              surrogate_dist = transformed_distribution.TransformedDistribution(
                  surrogate_dist, bijector=original_dist.bijector)

            if isinstance(original_dist, independent.Independent):
              surrogate_dist = independent.Independent(
                  surrogate_dist,
                  reinterpreted_batch_ndims=original_dist
                  .reinterpreted_batch_ndims)

            if isinstance(dist, Root):
              value_out = yield Root(surrogate_dist)
            else:
              value_out = yield surrogate_dist

          dist = prior_gen.send(value_out)
          i += 1
      except StopIteration:
        pass
Example #18
def extended_kalman_filter_one_step(
    state, observation, transition_fn, observation_fn,
    transition_jacobian_fn, observation_jacobian_fn, name=None):
  """A single step of the EKF.

  Args:
    state: A `KalmanFilterState` namedtuple (as returned by this function)
      containing the filtered mean and covariance of the state estimate at the
      previous step, with optional batch dimensions `b1, ..., bN`.
    observation: A `Tensor` of shape
      `concat([[num_timesteps, b1, ..., bN], [event_size]])` with scalar
      `event_size` and optional batch dimensions `b1, ..., bN`.
    transition_fn: a Python `callable` that accepts (batched) vectors of length
      `state_size`, and returns a `tfd.Distribution` instance, typically a
      `MultivariateNormal`, representing the state transition and covariance.
    observation_fn: a Python `callable` that accepts a (batched) vector of
      length `state_size` and returns a `tfd.Distribution` instance, typically
      a `MultivariateNormal` representing the observation model and covariance.
    transition_jacobian_fn: a Python `callable` that accepts a (batched) vector
      of length `state_size` and returns a (batched) matrix of shape
      `[state_size, state_size]`, representing the Jacobian of `transition_fn`.
    observation_jacobian_fn: a Python `callable` that accepts a (batched) vector
      of length `state_size` and returns a (batched) matrix of shape
      `[event_size, state_size]`, representing the Jacobian of `observation_fn`.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'extended_kalman_filter_one_step'`).
  Returns:
    updated_state: `KalmanFilterState` object containing the updated state
      estimate.
  """
  with tf.name_scope(name or 'extended_kalman_filter_one_step') as name:

    # If observations are scalar, we can avoid some matrix ops.
    observation_size_is_static_and_scalar = (observation.shape[-1] == 1)

    current_state = state.filtered_mean
    current_covariance = state.filtered_cov
    current_jacobian = transition_jacobian_fn(current_state)
    state_prior = transition_fn(current_state)

    predicted_cov = (tf.matmul(
        current_jacobian,
        tf.matmul(current_covariance, current_jacobian, transpose_b=True)) +
                     state_prior.covariance())
    predicted_mean = state_prior.mean()

    observation_dist = observation_fn(predicted_mean)
    observation_mean = observation_dist.mean()
    observation_cov = observation_dist.covariance()

    predicted_jacobian = observation_jacobian_fn(predicted_mean)
    tmp_obs_cov = tf.matmul(predicted_jacobian, predicted_cov)
    residual_covariance = tf.matmul(
        predicted_jacobian, tmp_obs_cov, transpose_b=True) + observation_cov

    if observation_size_is_static_and_scalar:
      gain_transpose = tmp_obs_cov / residual_covariance
    else:
      chol_residual_cov = tf.linalg.cholesky(residual_covariance)
      gain_transpose = tf.linalg.cholesky_solve(chol_residual_cov, tmp_obs_cov)

    filtered_mean = predicted_mean + tf.matmul(
        gain_transpose,
        (observation - observation_mean)[..., tf.newaxis],
        transpose_a=True)[..., 0]

    tmp_term = -tf.matmul(predicted_jacobian, gain_transpose, transpose_a=True)
    tmp_term = tf.linalg.set_diag(tmp_term, tf.linalg.diag_part(tmp_term) + 1.)
    filtered_cov = (
        tf.matmul(
            tmp_term, tf.matmul(predicted_cov, tmp_term), transpose_a=True) +
        tf.matmul(gain_transpose,
                  tf.matmul(observation_cov, gain_transpose), transpose_a=True))

    if observation_size_is_static_and_scalar:
      # A plain Normal would have event shape `[]`; wrapping with Independent
      # ensures `event_shape=[1]` as required.
      predictive_dist = independent.Independent(
          normal.Normal(loc=observation_mean,
                        scale=tf.sqrt(residual_covariance[..., 0])),
          reinterpreted_batch_ndims=1)

    else:
      predictive_dist = mvn_tril.MultivariateNormalTriL(
          loc=observation_mean,
          scale_tril=chol_residual_cov)

    log_marginal_likelihood = predictive_dist.log_prob(observation)

    return linear_gaussian_ssm.KalmanFilterState(
        filtered_mean=filtered_mean,
        filtered_cov=filtered_cov,
        predicted_mean=predicted_mean,
        predicted_cov=predicted_cov,
        observation_mean=observation_mean,
        observation_cov=observation_cov,
        log_marginal_likelihood=log_marginal_likelihood,
        timestep=state.timestep + 1)
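The scalar-observation branch above leans on `Independent` to give the `Normal` an event shape of `[1]`, so `log_prob(observation)` drops the trailing axis exactly as the `MultivariateNormalTriL` branch does. A minimal check (illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

loc = tf.zeros([1])
plain = tfd.Normal(loc=loc, scale=1.)  # event_shape=[], batch_shape=[1]
wrapped = tfd.Independent(plain, reinterpreted_batch_ndims=1)
print(plain.log_prob([0.]).shape)    # (1,) -- per-component log-probs
print(wrapped.log_prob([0.]).shape)  # ()   -- one log-prob per observation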