def _check_weights_shape(self, log_weights, sample_shape):
    """Checks `log_weights` against the proposal's sample-plus-batch shape.

    The expected shape is `sample_shape` concatenated with the result of
    `_get_joint_batch_shape` applied to the proposal distribution's
    `batch_shape_tensor()`. When that expected shape is statically known the
    comparison happens immediately; otherwise a runtime assertion is built,
    but only if `self.validate_args` is set.

    Args:
      log_weights: object with a static `.shape`, also accepted by `tf.shape`.
      sample_shape: leading (sample) portion of the expected shape.

    Returns:
      A list holding zero or one assertion ops.

    Raises:
      ValueError: if the expected shape is statically known and the static
        shape of `log_weights` is incompatible with it.
    """
    message = (
        'Shape of importance weights does not match the batch shape '
        'of `self.proposal_distribution`. This implies that '
        'the proposal is not producing independent samples for some '
        'batch dimension(s) expected by `self.target_log_prob_fn`.')

    proposal_batch_shape = _get_joint_batch_shape(
        self.proposal_distribution.batch_shape_tensor())
    sample_and_batch_shape = ps.concat(
        [sample_shape, proposal_batch_shape], axis=0)
    static_expected_shape = tf.get_static_value(sample_and_batch_shape)

    statically_mismatched = (
        static_expected_shape is not None
        and not tensorshape_util.is_compatible_with(
            log_weights.shape, static_expected_shape))
    if statically_mismatched:
        raise ValueError(
            message +
            ' Saw: weights shape {} vs proposal sample and batch '
            'shape {}.'.format(log_weights.shape, sample_and_batch_shape))

    # No static mismatch was found; fall back to a runtime check only when
    # argument validation was requested.
    if not self.validate_args:
        return []
    return [
        assert_util.assert_equal(tf.shape(log_weights),
                                 sample_and_batch_shape,
                                 message=message)
    ]
Esempio n. 2
0
def _maybe_broadcast(a, b):
  """Returns `(a, b)`, mutually broadcast unless that is provably unnecessary.

  Broadcasting is skipped only when both static shapes are fully defined and
  compatible with each other. In every other case each argument is added to a
  zeros tensor shaped like the other one.

  Args:
    a: first operand; must expose `.shape` and support `+`.
    b: second operand; must expose `.shape` and support `+`.

  Returns:
    The pair `(a, b)`, possibly replaced by their broadcast versions.
  """
  shapes_known_and_matching = (
      tensorshape_util.is_fully_defined(a.shape) and
      tensorshape_util.is_fully_defined(b.shape) and
      tensorshape_util.is_compatible_with(a.shape, b.shape))
  if shapes_known_and_matching:
    return a, b

  # `b` is expanded first, so the zeros used for `a` already carry the
  # broadcast shape.
  b = b + tf.zeros_like(a)
  a = a + tf.zeros_like(b)
  return a, b
Esempio n. 3
0
  def _cdf(self, counts):
    """CDF implementation: evaluates `_bdtr` at `counts`.

    `counts` and the probabilities are broadcast against each other first,
    unless both static shapes are fully defined and compatible.

    Args:
      counts: values at which to evaluate; must expose `.shape`.

    Returns:
      The result of `_bdtr(k=counts, n=total_count, p=probs)`.
    """
    probs = self._probs_parameter_no_checks()

    static_shapes_match = (
        tensorshape_util.is_fully_defined(counts.shape) and
        tensorshape_util.is_fully_defined(probs.shape) and
        tensorshape_util.is_compatible_with(counts.shape, probs.shape))
    if not static_shapes_match:
      # Adding zeros shaped like the other operand forces a common shape.
      probs = probs + tf.zeros_like(counts)
      counts = counts + tf.zeros_like(probs)

    total_count = tf.convert_to_tensor(self.total_count)
    return _bdtr(k=counts, n=total_count, p=probs)
Esempio n. 4
0
    def _cdf(self, counts):
        """CDF implementation: evaluates `_bdtr` at the validated `counts`.

        `counts` is first passed through `self._maybe_assert_valid_sample`.
        Unless both static shapes are fully defined and compatible, `counts`
        and `self.probs` are then broadcast against each other.

        Args:
          counts: values at which to evaluate.

        Returns:
          The result of `_bdtr(k=counts, n=self.total_count, p=probs)`.
        """
        counts = self._maybe_assert_valid_sample(counts)
        probs = self.probs
        if not (tensorshape_util.is_fully_defined(counts.shape)
                and tensorshape_util.is_fully_defined(probs.shape)
                and tensorshape_util.is_compatible_with(
                    counts.shape, probs.shape)):
            # Broadcast out-of-place. `probs` aliases `self.probs`, so an
            # augmented assignment (`+=`) could modify the distribution's own
            # parameter, or fail to broadcast, when the value is a mutable
            # array type. The same applies to `counts`.
            probs = probs + tf.zeros_like(counts)
            counts = counts + tf.zeros_like(probs)

        return _bdtr(k=counts, n=self.total_count, p=probs)
 def _value(self, dtype=None, name=None, as_ref=False):
     """Applies `transform_fn` to the stored input and returns a tensor.

     The transformed value is checked against this object's declared `dtype`
     and `shape` before being converted.

     Args:
       dtype: optional dtype forwarded to `tf.convert_to_tensor`.
       name: optional name forwarded to `tf.convert_to_tensor`.
       as_ref: accepted for interface compatibility; not used here.

     Returns:
       `tf.convert_to_tensor` applied to the transformed value.

     Raises:
       TypeError: if the base dtype of the transformed value differs from
         `self.dtype`, or its shape is incompatible with `self.shape`.
     """
     transformed = self.transform_fn(self.pretransformed_input)  # pylint: disable=not-callable

     actual_dtype = dtype_util.base_dtype(transformed.dtype)
     if actual_dtype != self.dtype:
         raise TypeError(
             'Actual dtype ({}) does not match deferred dtype ({}).'.format(
                 dtype_util.name(actual_dtype),
                 dtype_util.name(self.dtype)))

     shape_ok = tensorshape_util.is_compatible_with(
         transformed.shape, self.shape)
     if not shape_ok:
         raise TypeError(
             'Actual shape ({}) is incompatible with deferred shape ({}).'.
             format(transformed.shape, self.shape))

     return tf.convert_to_tensor(transformed, dtype=dtype, name=name)
Esempio n. 6
0
def _assert_same_shape(x, y,
                       message='Shapes do not match.',
                       validate_args=False):
  """Checks that `x` and `y` have the same shape, statically where possible.

  Args:
    x: first value; must expose a static `.shape`.
    y: second value; must expose a static `.shape`.
    message: text used as the prefix of the raised error and as the message
      of the runtime assertion.
    validate_args: if true, and at least one static shape is not fully
      defined, a runtime equality assertion on the dynamic shapes is returned.

  Returns:
    A list holding zero or one assertion ops.

  Raises:
    ValueError: if the static shapes are incompatible.
  """
  x_shape, y_shape = x.shape, y.shape
  if not tensorshape_util.is_compatible_with(x_shape, y_shape):
    raise ValueError(message +
                     ' Saw shapes: {} vs {}.'.format(x_shape, y_shape))

  if not validate_args:
    return []

  fully_static = (tensorshape_util.is_fully_defined(x_shape) and
                  tensorshape_util.is_fully_defined(y_shape))
  if fully_static:
    # The static compatibility check above already settled the question.
    return []

  return [
      assert_util.assert_equal(tf.shape(x), tf.shape(y), message=message)
  ]
Esempio n. 7
0
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name='Mixture'):
        """Initialize a Mixture distribution.

        A `Mixture` is defined by a `Categorical` (`cat`, representing the
        mixture probabilities) and a list of `Distribution` objects
        all having matching dtype, batch shape, event shape, support, and
        continuity properties (the components).

        The `num_classes` of `cat` must be possible to infer at graph
        construction time and match `len(components)`.

        Args:
          cat: A `Categorical` distribution instance, representing the
            probabilities of `components`.
          components: A list or tuple of `Distribution` instances.
            Each instance must have the same type, be defined on the same
            domain, and have matching `event_shape` and `batch_shape`.
          validate_args: Python `bool`, default `False`. Forwarded to the base
            class constructor. All checks performed in this initializer are
            static and happen regardless of this flag.
          allow_nan_stats: Boolean, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          use_static_graph: Calls to `sample` will not rely on dynamic tensor
            indexing, allowing for some static graph compilation optimizations,
            but at the expense of sampling all underlying distributions in the
            mixture. (Possibly useful when running on TPUs).
            Default value: `False` (i.e., use dynamic indexing).
          name: A name for this distribution (optional).

        Raises:
          TypeError: If cat is not a `Categorical`, or `components` is not
            a list or tuple, or the elements of `components` are not
            instances of `Distribution`, or do not have matching `dtype`.
          ValueError: If `components` is an empty list or tuple, or its
            elements do not have a statically known event rank.
            If `cat.num_classes` cannot be inferred at graph creation time,
            or the constant value of `cat.num_classes` is not equal to
            `len(components)`, or all `components` and `cat` do not have
            matching static batch shapes, or all components do not
            have matching static event shapes.
        """
        # Must stay the first statement so that only the constructor
        # arguments (and `self`) are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:

            # NOTE: the `%` operands below are wrapped in 1-tuples. A bare
            # tuple operand would be unpacked by `%`, producing either an
            # unrelated "not all arguments converted" error or a message that
            # shows only the tuple's single element.
            if not isinstance(cat, categorical.Categorical):
                raise TypeError(
                    'cat must be a Categorical distribution, but saw: %s' %
                    (cat,))
            if not components:
                raise ValueError(
                    'components must be a non-empty list or tuple')
            if not isinstance(components, (list, tuple)):
                raise TypeError(
                    'components must be a list or tuple, but saw: %s' %
                    (components,))
            if not all(
                    isinstance(c, distribution.Distribution)
                    for c in components):
                raise TypeError(
                    'all entries in components must be Distribution instances'
                    ' but saw: %s' % (components,))

            dtype = components[0].dtype
            if not all(d.dtype == dtype for d in components):
                raise TypeError(
                    'All components must have the same dtype, but saw '
                    'dtypes: %s' % ([(d.name, d.dtype) for d in components],))

            # Fold every component's static shapes into running merged
            # shapes; the batch shape starts from `cat`'s batch shape.
            static_event_shape = components[0].event_shape
            static_batch_shape = cat.batch_shape
            for di, d in enumerate(components):
                if not tensorshape_util.is_compatible_with(
                        static_batch_shape, d.batch_shape):
                    raise ValueError(
                        'components[{}] batch shape must be compatible with cat '
                        'shape and other component batch shapes'.format(di))
                static_event_shape = tensorshape_util.merge_with(
                    static_event_shape, d.event_shape)
                static_batch_shape = tensorshape_util.merge_with(
                    static_batch_shape, d.batch_shape)
            if tensorshape_util.rank(static_event_shape) is None:
                raise ValueError(
                    'Expected to know rank(event_shape) from components, but '
                    'none of the components provide a static number of ndims')

            # The number of classes is read from the static size of the last
            # axis of whichever parameterization `cat` was built with.
            # pylint: disable=protected-access
            cat_dist_param = cat._probs if cat._logits is None else cat._logits
            # pylint: enable=protected-access
            static_num_components = tf.compat.dimension_value(
                cat_dist_param.shape[-1])
            if static_num_components is None:
                raise ValueError(
                    'Could not infer number of classes from cat and unable '
                    'to compare this value to the number of components passed in.'
                )
            if static_num_components != len(components):
                raise ValueError(
                    'cat.num_classes != len(components): %d vs. %d' %
                    (static_num_components, len(components)))

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape
            self._use_static_graph = use_static_graph

            super(Mixture, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
Esempio n. 8
0
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name="Mixture"):
        """Initialize a Mixture distribution.

        A `Mixture` is defined by a `Categorical` (`cat`, representing the
        mixture probabilities) and a list of `Distribution` objects
        all having matching dtype, batch shape, event shape, and continuity
        properties (the components).

        The `num_classes` of `cat` must be possible to infer at graph
        construction time and match `len(components)`.

        Args:
          cat: A `Categorical` distribution instance, representing the
            probabilities of `components`.
          components: A list or tuple of `Distribution` instances.
            Each instance must have the same type, be defined on the same
            domain, and have matching `event_shape` and `batch_shape`.
          validate_args: Python `bool`, default `False`. If `True`, runtime
            assertions are built that each component's batch rank and batch
            shape equal those of `cat`; they are stored on `self._assertions`.
            If `False`, `self._assertions` is empty.
          allow_nan_stats: Boolean, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          use_static_graph: Calls to `sample` will not rely on dynamic tensor
            indexing, allowing for some static graph compilation optimizations,
            but at the expense of sampling all underlying distributions in the
            mixture. (Possibly useful when running on TPUs).
            Default value: `False` (i.e., use dynamic indexing).
          name: A name for this distribution (optional).

        Raises:
          TypeError: If cat is not a `Categorical`, or `components` is not
            a list or tuple, or the elements of `components` are not
            instances of `Distribution`, or do not have matching `dtype`.
          ValueError: If `components` is an empty list or tuple, or its
            elements do not have a statically known event rank.
            If `cat.num_classes` cannot be inferred at graph creation time,
            or the constant value of `cat.num_classes` is not equal to
            `len(components)`, or all `components` and `cat` do not have
            matching static batch shapes, or all components do not
            have matching static event shapes.
        """
        # Must stay the first statement so that only the constructor
        # arguments (and `self`) are captured.
        parameters = dict(locals())

        # NOTE: the `%` operands below are wrapped in 1-tuples. A bare tuple
        # operand would be unpacked by `%`, producing either an unrelated
        # "not all arguments converted" error or a message that shows only
        # the tuple's single element.
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                "cat must be a Categorical distribution, but saw: %s" %
                (cat,))
        if not components:
            raise ValueError("components must be a non-empty list or tuple")
        if not isinstance(components, (list, tuple)):
            raise TypeError("components must be a list or tuple, but saw: %s" %
                            (components,))
        if not all(
                isinstance(c, distribution.Distribution) for c in components):
            raise TypeError(
                "all entries in components must be Distribution instances"
                " but saw: %s" % (components,))

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError("All components must have the same dtype, but saw "
                            "dtypes: %s" % ([(d.name, d.dtype)
                                             for d in components],))

        # Fold every component's static shapes into running merged shapes;
        # the batch shape starts from `cat`'s batch shape.
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                       d.batch_shape):
                raise ValueError(
                    "components[{}] batch shape must be compatible with cat "
                    "shape and other component batch shapes".format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                "Expected to know rank(event_shape) from components, but "
                "none of the components provide a static number of ndims")

        # Ensure that all batch and event ndims are consistent.
        with tf.name_scope(name) as name:
            num_components = cat._num_categories()  # pylint: disable=protected-access
            static_num_components = tf.get_static_value(num_components)
            if static_num_components is None:
                raise ValueError(
                    "Could not infer number of classes from cat and unable "
                    "to compare this value to the number of components passed in."
                )
            # Possibly convert from numpy 0-D array.
            static_num_components = int(static_num_components)
            if static_num_components != len(components):
                raise ValueError(
                    "cat.num_classes != len(components): %d vs. %d" %
                    (static_num_components, len(components)))

            cat_batch_shape = cat.batch_shape_tensor()
            cat_batch_rank = tf.size(cat_batch_shape)
            if validate_args:
                batch_shapes = [d.batch_shape_tensor() for d in components]
                batch_ranks = [tf.size(bs) for bs in batch_shapes]
                check_message = ("components[%d] batch shape must match cat "
                                 "batch shape")
                # Rank assertions for every component come first, followed by
                # the shape assertions.
                self._assertions = [
                    assert_util.assert_equal(cat_batch_rank,
                                             batch_rank,
                                             message=check_message % di)
                    for di, batch_rank in enumerate(batch_ranks)
                ]
                self._assertions += [
                    assert_util.assert_equal(cat_batch_shape,
                                             batch_shape,
                                             message=check_message % di)
                    for di, batch_shape in enumerate(batch_shapes)
                ]
            else:
                self._assertions = []

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape

            # `static_num_components` is always a concrete `int` at this
            # point (a `None` value raised above), so `use_static_graph`
            # needs no further validation.
            self._use_static_graph = use_static_graph
        # We let the Mixture distribution access _graph_parents since its arguably
        # more like a baseclass.
        # Copy before extending: extending `cat`'s own list in place would
        # silently add the components' graph parents to `cat` itself.
        graph_parents = list(self._cat._graph_parents)  # pylint: disable=protected-access
        for c in self._components:
            graph_parents.extend(c._graph_parents)  # pylint: disable=protected-access

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=graph_parents,
            name=name)
Esempio n. 9
0
    def __init__(self,
                 parameter_prior,
                 parameterized_initial_state_prior_fn,
                 parameterized_transition_fn,
                 parameterized_observation_fn,
                 parameterized_initial_state_proposal_fn=None,
                 parameterized_proposal_fn=None,
                 parameter_constraining_bijector=None,
                 name=None):
        """Builds an iterated filter for parameter estimation in sequential models.

        Iterated filtering is a parameter estimation method in which parameters
        are included in an augmented state space, with dynamics that introduce
        parameter perturbations, and a filtering
        algorithm such as particle filtering is run several times with
        perturbations of decreasing size. This class implements the IF2
        algorithm of [Ionides et al., 2015][1], for which, under appropriate
        conditions (including a uniform prior) the final parameter distribution
        approaches a point mass at the maximum likelihood estimate. If a
        non-uniform prior is provided, the final parameter distribution will
        (under appropriate conditions) approach a point mass at the maximum a
        posteriori (MAP) value.

        This class augments the state space of a sequential model to include
        parameter perturbations, and provides utilities to run particle
        filtering on that augmented model. Alternately, the augmented components
        may be passed directly into a filtering algorithm of the user's choice.

        Args:
          parameter_prior: prior `tfd.Distribution` over parameters (may be a
            joint distribution).
          parameterized_initial_state_prior_fn: `callable` with signature
            `initial_state_prior =
            parameterized_initial_state_prior_fn(parameters)`
            where `parameters` has the form of a sample from `parameter_prior`,
            and `initial_state_prior` is a distribution over the initial state.
          parameterized_transition_fn: `callable` with signature
            `next_state_dist = parameterized_transition_fn(
            step, state, parameters, **kwargs)`.
          parameterized_observation_fn: `callable` with signature
            `observation_dist = parameterized_observation_fn(
            step, state, parameters, **kwargs)`.
          parameterized_initial_state_proposal_fn: optional `callable` with
            signature `initial_state_proposal =
            parameterized_initial_state_proposal_fn(parameters)` where
            `parameters` has the form of a sample from `parameter_prior`, and
            `initial_state_proposal` is a distribution over the initial state.
            If omitted, no joint initial-state proposal is built and
            `parameterized_initial_state_prior_fn` is stored in its place.
            Default value: `None`.
          parameterized_proposal_fn: optional `callable` with signature
            `next_state_dist = parameterized_proposal_fn(
            step, state, parameters, **kwargs)`.
            Default value: `None`.
          parameter_constraining_bijector: optional `tfb.Bijector` instance
            such that `parameter_constraining_bijector.forward(x)` returns valid
            parameters for any real-valued `x` of the same structure and shape
            as `parameters`. If `None`, the default bijector of the provided
            `parameter_prior` will be used.
            Default value: `None`.
          name: `str` name for ops constructed by this object.
            Default value: `None` (i.e., `'IteratedFilter'`).

        Raises:
          ValueError: if a test sample from the augmented initial-state prior
            contains a part whose particle dimension is not compatible with
            the number of particles requested.

        #### Example

        We'll walk through applying iterated filtering to a toy
        Susceptible-Infected-Recovered (SIR) model, a [compartmental model](
        https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model)
        of infectious disease. Note that the model we use here is extremely
        simplified and is intended as a pedagogical example; it should not be
        interpreted to describe disease spread in the real world.

        We begin by specifying a prior distribution over the parameters to be
        inferred, thus defining the structure of the parameter space and the
        support of the parameters (which will imply a default constraining
        bijector). Here we'll use uniform priors over ranges that we expect to
        contain the parameters:

        ```python
        parameter_prior = tfd.JointDistributionNamed({
            'infection_rate': tfd.Uniform(low=0., high=3.),
            'recovery_rate': tfd.Uniform(low=0., high=3.),
        })
        ```

        The model specification itself is identical to that used by
        `tfp.experimental.mcmc.infer_trajectories`, except that each component
        accepts an additional `parameters` keyword argument. We start by
        specifying a parameterized prior on initial states. In this case, our
        state includes the current number of susceptible and infected
        individuals (the third compartment, recovered individuals, is implicitly
        defined to include the remaining population). We'll also include, as
        auxiliary variables, the daily counts of new infections and new
        recoveries; these will help ensure that people shift consistently across
        compartments.

        ```python
        population_size = 1000
        initial_state_prior_fn = lambda parameters: tfd.JointDistributionNamed({
            'new_infections': tfd.Poisson(parameters['infection_rate']),
            'new_recoveries': tfd.Deterministic(
                tf.broadcast_to(0., tf.shape(parameters['recovery_rate']))),
            'susceptible': (lambda new_infections:
                            tfd.Deterministic(population_size - new_infections)),
            'infected': (lambda new_infections:
                         tfd.Deterministic(new_infections))})
        ```

        **Note**: the state prior must have the same batch shape as the
        passed-in parameters; equivalently, it must sample a full state for each
        parameter particle. If any part of the state prior does not depend
        on the parameters, you must manually ensure that it has the appropriate
        batch shape. For example, in the definition of `new_recoveries` above,
        applying `broadcast_to` with the shape of a parameter ensures that
        the batch shape is maintained.

        Next, we specify a transition model. This takes the state at the
        previous day, along with parameters, and returns a distribution
        over the state for the current day.

        ```python
        def parameterized_infection_dynamics(_, previous_state, parameters):
          new_infections = tfd.Poisson(
              parameters['infection_rate'] * previous_state['infected'] *
              previous_state['susceptible'] / population_size)
          new_recoveries = tfd.Poisson(
              previous_state['infected'] * parameters['recovery_rate'])
          return tfd.JointDistributionNamed({
              'new_infections': new_infections,
              'new_recoveries': new_recoveries,
              'susceptible': lambda new_infections: tfd.Deterministic(
                tf.maximum(0., previous_state['susceptible'] - new_infections)),
              'infected': lambda new_infections, new_recoveries: tfd.Deterministic(
                tf.maximum(0.,
                           (previous_state['infected'] +
                            new_infections - new_recoveries)))})
        ```

        Finally, assume that every day we get to observe noisy counts of new
        infections and recoveries.

        ```python
        def parameterized_infection_observations(_, state, parameters):
          del parameters  # Not used.
          return tfd.JointDistributionNamed({
              'new_infections': tfd.Poisson(state['new_infections'] + 0.1),
              'new_recoveries': tfd.Poisson(state['new_recoveries'] + 0.1)})
        ```

        Combining these components, an `IteratedFilter` augments
        the state space to include parameters that may change over time.

        ```python
        iterated_filter = tfp.experimental.sequential.IteratedFilter(
          parameter_prior=parameter_prior,
          parameterized_initial_state_prior_fn=initial_state_prior_fn,
          parameterized_transition_fn=parameterized_infection_dynamics,
          parameterized_observation_fn=parameterized_infection_observations)
        ```

        We may then run the filter to estimate parameters from a series
        of observations:

        ```python
        # Simulated with `infection_rate=1.2` and `recovery_rate=0.1`.
        observed_values = {
          'new_infections': tf.convert_to_tensor([
             2., 7., 14., 24., 45., 93., 160., 228., 252., 158.,  17.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
          'new_recoveries': tf.convert_to_tensor([
             0., 0., 3., 4., 3., 8., 12., 31., 49., 73., 85., 65., 71.,
             58., 42., 65., 36., 31., 32., 27., 31., 20., 19., 19., 14., 27.])
        }
        parameter_particles = iterated_filter.estimate_parameters(
            observations=observed_values,
            num_iterations=20,
            num_particles=4096,
            initial_perturbation_scale=1.0,
            cooling_schedule=(
                tfp.experimental.sequential.geometric_cooling_schedule(
                    0.001, k=20)),
            seed=test_util.test_seed())
        print('Mean of parameter particles from final iteration: {}'.format(
          tf.nest.map_structure(lambda x: tf.reduce_mean(x[-1], axis=0),
                                parameter_particles)))
        print('Standard deviation of parameter particles from '
              'final iteration: {}'.format(
              tf.nest.map_structure(lambda x: tf.math.reduce_std(x[-1], axis=0),
                                    parameter_particles)))
        ```

        For more control, we could alternately choose to run filtering
        iterations on the augmented model manually, using the filter of our
        choice. For example, manually invoking `infer_trajectories` would allow
        us to inspect the parameter and state values at all timesteps, and their
        corresponding log-probabilities:

        ```python
        trajectories, lps = tfp.experimental.mcmc.infer_trajectories(
          observations=observed_values,
          initial_state_prior=iterated_filter.joint_initial_state_prior,
          transition_fn=functools.partial(
              iterated_filter.joint_transition_fn,
              perturbation_scale=perturbation_scale),
          observation_fn=iterated_filter.joint_observation_fn,
          proposal_fn=iterated_filter.joint_proposal_fn,
          initial_state_proposal=iterated_filter.joint_initial_state_proposal(
              initial_unconstrained_parameters),
          num_particles=4096)
        ```

        #### References:

        [1] Edward L. Ionides, Dao Nguyen, Yves Atchade, Stilian Stoev, and
        Aaron A. King. Inference for dynamic and latent variable models via
        iterated, perturbed Bayes maps. _Proceedings of the National Academy of
        Sciences_ 112, no. 3: 719-724, 2015.
        https://www.pnas.org/content/pnas/112/3/719.full.pdf
        """
        name = name or 'IteratedFilter'
        with tf.name_scope(name):
            self._parameter_prior = parameter_prior
            self._parameterized_initial_state_prior_fn = (
                parameterized_initial_state_prior_fn)

            if parameter_constraining_bijector is None:
                parameter_constraining_bijector = (
                    parameter_prior.experimental_default_event_space_bijector(
                    ))
            self._parameter_constraining_bijector = parameter_constraining_bijector

            # Augment the prior to include both parameters and states.
            self._joint_initial_state_prior = joint_prior_on_parameters_and_state(
                parameter_prior,
                parameterized_initial_state_prior_fn,
                parameter_constraining_bijector,
                prior_is_constrained=True)

            # Check that prior samples have a consistent number of particles.
            # TODO(davmre): remove the need for dummy shape dependencies,
            # and this check, by using `JointDistributionNamedAutoBatched` with
            # auto-vectorization enabled in `joint_prior_on_parameters_and_state`.

            num_particles_canary = 13
            canary_seed = samplers.zeros_seed()

            def _get_shape_1(x):
                """Returns the static size of axis 1 of `x` as a `TensorShape`."""
                # Some sample parts expose their tensor under a `state`
                # attribute; unwrap those before reading the shape.
                if hasattr(x, 'state'):
                    x = x.state
                return tf.TensorShape(x.shape[1:2])

            prior_static_sample_shapes = tf.nest.map_structure(
                # Sample shape [0, num_particles_canary] particles (size will be zero)
                # then trim off the leading 0 and (possibly) any event shape.
                # We expect shape [num_particles_canary] to remain.
                _get_shape_1,
                self._joint_initial_state_prior.sample(
                    [0, num_particles_canary], seed=canary_seed))
            if not all(
                    tensorshape_util.is_compatible_with(
                        s[:1], [num_particles_canary])
                    for s in tf.nest.flatten(prior_static_sample_shapes)):
                raise ValueError(
                    'The specified prior does not generate consistent '
                    'shapes when sampled. Please verify that all parts of '
                    '`initial_state_prior_fn` have batch shape matching '
                    'that of the parameters. This may require creating '
                    '"dummy" dependencies on parameters; for example: '
                    '`tf.broadcast_to(value, tf.shape(parameter))`. (in a '
                    f'test sample with {num_particles_canary} particles, we expected '
                    'all values to have shape compatible with '
                    f'[{num_particles_canary}, ...]; '
                    f'saw shapes {prior_static_sample_shapes})')

            # Augment the transition and observation fns to cover both
            # parameters and states.
            self._joint_transition_fn = augment_transition_fn_with_parameters(
                parameter_prior, parameterized_transition_fn,
                parameter_constraining_bijector)
            self._joint_observation_fn = augment_observation_fn_with_parameters(
                parameterized_observation_fn, parameter_constraining_bijector)

            # If given a proposal for the initial state, augment it into a joint
            # proposal over parameters and states.
            joint_initial_state_proposal = None
            if parameterized_initial_state_proposal_fn:
                joint_initial_state_proposal = joint_prior_on_parameters_and_state(
                    parameter_prior, parameterized_initial_state_proposal_fn,
                    parameter_constraining_bijector)
            else:
                # No proposal supplied: store the prior fn in its place while
                # the joint proposal itself stays `None`.
                parameterized_initial_state_proposal_fn = (
                    parameterized_initial_state_prior_fn)
            self._joint_initial_state_proposal = joint_initial_state_proposal
            self._parameterized_initial_state_proposal_fn = (
                parameterized_initial_state_proposal_fn)

            # If given a conditional proposal fn (for non-initial states), augment
            # it to be joint over states and parameters.
            self._joint_proposal_fn = None
            if parameterized_proposal_fn:
                self._joint_proposal_fn = augment_transition_fn_with_parameters(
                    parameter_prior, parameterized_proposal_fn,
                    parameter_constraining_bijector)

            self._batch_ndims = tf.nest.map_structure(
                ps.rank_from_shape, parameter_prior.batch_shape_tensor())
            self._name = name