def _check_weights_shape(self, log_weights, sample_shape):
    """Checks `log_weights` against the proposal's sample-and-batch shape.

    The expected shape is `sample_shape` concatenated with the joint batch
    shape of `self.proposal_distribution`. If that shape is statically known
    and `log_weights.shape` is incompatible with it, an error is raised
    immediately. Otherwise, when `self.validate_args` is set, a runtime
    assertion comparing the dynamic shapes is returned.

    Args:
      log_weights: Tensor of importance log-weights whose shape is checked.
      sample_shape: Shape prepended to the proposal batch shape to form the
        expected shape of `log_weights`.

    Returns:
      A list of assertion ops; empty unless `self.validate_args` is truthy.

    Raises:
      ValueError: If the expected shape is statically known and
        `log_weights.shape` is not compatible with it.
    """
    message = (
        'Shape of importance weights does not match the batch shape '
        'of `self.proposal_distribution`. This implies that '
        'the proposal is not producing independent samples for some '
        'batch dimension(s) expected by `self.target_log_prob_fn`.')

    proposal_batch_shape = _get_joint_batch_shape(
        self.proposal_distribution.batch_shape_tensor())
    expected_shape = ps.concat([sample_shape, proposal_batch_shape], axis=0)
    static_expected_shape = tf.get_static_value(expected_shape)

    # Fail eagerly whenever the mismatch can be proven statically.
    statically_mismatched = (
        static_expected_shape is not None
        and not tensorshape_util.is_compatible_with(
            log_weights.shape, static_expected_shape))
    if statically_mismatched:
        raise ValueError(
            message + ' Saw: weights shape {} vs proposal sample and batch '
            'shape {}.'.format(log_weights.shape, expected_shape))

    if not self.validate_args:
        return []

    # Defer the comparison to run time.
    return [
        assert_util.assert_equal(
            tf.shape(log_weights), expected_shape, message=message)
    ]
def _maybe_broadcast(a, b):
    """Broadcasts `a` and `b` against each other unless it is provably a no-op.

    Args:
      a: First tensor.
      b: Second tensor.

    Returns:
      The pair `(a, b)`. The inputs are returned untouched when both static
      shapes are fully defined and compatible; otherwise each has been added
      to zeros shaped like the other, so both carry the broadcast shape.
    """
    shapes_known_equal = (
        tensorshape_util.is_fully_defined(a.shape)
        and tensorshape_util.is_fully_defined(b.shape)
        and tensorshape_util.is_compatible_with(a.shape, b.shape))
    if shapes_known_equal:
        # Nothing to do: the shapes are static and already agree.
        return a, b

    # Adding zeros expands each operand to the common broadcast shape.
    broadcast_b = b + tf.zeros_like(a)
    broadcast_a = a + tf.zeros_like(broadcast_b)
    return broadcast_a, broadcast_b
def _cdf(self, counts):
    """Evaluates the CDF at `counts` by delegating to `_bdtr`.

    `counts` and the probabilities are broadcast against each other first,
    unless both static shapes are fully defined and compatible.

    Args:
      counts: Tensor of counts at which to evaluate the CDF.

    Returns:
      The result of `_bdtr(k=counts, n=total_count, p=probs)`.
    """
    probs = self._probs_parameter_no_checks()

    shapes_known_equal = (
        tensorshape_util.is_fully_defined(counts.shape)
        and tensorshape_util.is_fully_defined(probs.shape)
        and tensorshape_util.is_compatible_with(counts.shape, probs.shape))
    if not shapes_known_equal:
        # Adding zeros brings both operands to the common broadcast shape.
        probs = probs + tf.zeros_like(counts)
        counts = counts + tf.zeros_like(probs)

    total_count = tf.convert_to_tensor(self.total_count)
    return _bdtr(k=counts, n=total_count, p=probs)
def _cdf(self, counts):
    """Evaluates the CDF at `counts` by delegating to `_bdtr`.

    `counts` is first passed through `self._maybe_assert_valid_sample`. It is
    then broadcast against `self.probs`, unless both static shapes are fully
    defined and compatible.

    Args:
      counts: Tensor of counts at which to evaluate the CDF.

    Returns:
      The result of `_bdtr(k=counts, n=self.total_count, p=probs)`.
    """
    counts = self._maybe_assert_valid_sample(counts)
    probs = self.probs
    if not (tensorshape_util.is_fully_defined(counts.shape) and
            tensorshape_util.is_fully_defined(probs.shape) and
            tensorshape_util.is_compatible_with(counts.shape, probs.shape)):
        # If both shapes are well defined and equal, we skip broadcasting.
        #
        # Use out-of-place addition rather than `+=`: `probs` aliases
        # `self.probs` and `counts` may alias the caller's argument, so an
        # in-place update would modify (or fail to resize) those objects
        # whenever they are a mutable array type.
        probs = probs + tf.zeros_like(counts)
        counts = counts + tf.zeros_like(probs)
    return _bdtr(k=counts, n=self.total_count, p=probs)
def _value(self, dtype=None, name=None, as_ref=False):
    """Applies `transform_fn` to the stored input and validates the result.

    Args:
      dtype: Optional dtype forwarded to `tf.convert_to_tensor`.
      name: Optional name forwarded to `tf.convert_to_tensor`.
      as_ref: Accepted for signature compatibility; not used here.

    Returns:
      The transformed value converted to a tensor.

    Raises:
      TypeError: If the base dtype of the transformed value differs from
        `self.dtype`, or if its shape is incompatible with `self.shape`.
    """
    # pylint: disable=not-callable
    transformed = self.transform_fn(self.pretransformed_input)
    # pylint: enable=not-callable

    actual_dtype = dtype_util.base_dtype(transformed.dtype)
    if actual_dtype != self.dtype:
        raise TypeError(
            'Actual dtype ({}) does not match deferred dtype ({}).'.format(
                dtype_util.name(actual_dtype),
                dtype_util.name(self.dtype)))

    if tensorshape_util.is_compatible_with(transformed.shape, self.shape):
        return tf.convert_to_tensor(transformed, dtype=dtype, name=name)

    raise TypeError(
        'Actual shape ({}) is incompatible with deferred shape ({}).'.format(
            transformed.shape, self.shape))
def _assert_same_shape(x, y, message='Shapes do not match.',
                       validate_args=False):
    """Checks, statically when possible, that `x` and `y` share a shape.

    Args:
      x: First tensor.
      y: Second tensor.
      message: Text used for both the static error and the runtime assertion.
      validate_args: If truthy, a runtime assertion is produced when either
        static shape is not fully defined.

    Returns:
      A list holding at most one assertion op.

    Raises:
      ValueError: If the static shapes of `x` and `y` are incompatible.
    """
    if not tensorshape_util.is_compatible_with(x.shape, y.shape):
        raise ValueError(
            message + ' Saw shapes: {} vs {}.'.format(x.shape, y.shape))

    if not validate_args:
        return []

    # Fully static shapes were already verified by the check above.
    if (tensorshape_util.is_fully_defined(x.shape)
            and tensorshape_util.is_fully_defined(y.shape)):
        return []

    return [
        assert_util.assert_equal(tf.shape(x), tf.shape(y), message=message)
    ]
def __init__(self,
             cat,
             components,
             validate_args=False,
             allow_nan_stats=True,
             use_static_graph=False,
             name='Mixture'):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, support, and
    continuity properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph
    construction time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the
        probabilities of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a
        runtime error if batch or event ranks are inconsistent between cat
        and any of the distributions. This is only checked if the ranks
        cannot be determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations,
        but at the expense of sampling all underlying distributions in the
        mixture. (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    # Must stay the first statement so that only the constructor arguments
    # are captured.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
        # NOTE: the `%` operands below are wrapped in 1-tuples. A bare tuple
        # operand would be unpacked as the format arguments and either
        # raise "not all arguments converted" or print the wrong value.
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                'cat must be a Categorical distribution, but saw: %s' %
                (cat,))
        if not components:
            raise ValueError('components must be a non-empty list or tuple')
        if not isinstance(components, (list, tuple)):
            raise TypeError(
                'components must be a list or tuple, but saw: %s' %
                (components,))
        if not all(
                isinstance(c, distribution.Distribution)
                for c in components):
            raise TypeError(
                'all entries in components must be Distribution instances'
                ' but saw: %s' % (components,))

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError(
                'All components must have the same dtype, but saw '
                'dtypes: %s' % ([(d.name, d.dtype) for d in components],))

        # Merge the static shapes of `cat` and every component so that the
        # most specific statically known shape is retained.
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(
                    static_batch_shape, d.batch_shape):
                raise ValueError(
                    'components[{}] batch shape must be compatible with cat '
                    'shape and other component batch shapes'.format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                'Expected to know rank(event_shape) from components, but '
                'none of the components provide a static number of ndims')

        # The number of classes is read from the last static dimension of
        # whichever parameter (`probs` or `logits`) `cat` was built with.
        # pylint: disable=protected-access
        cat_dist_param = cat._probs if cat._logits is None else cat._logits
        # pylint: enable=protected-access
        static_num_components = tf.compat.dimension_value(
            cat_dist_param.shape[-1])
        if static_num_components is None:
            raise ValueError(
                'Could not infer number of classes from cat and unable '
                'to compare this value to the number of components passed in.'
            )
        if static_num_components != len(components):
            raise ValueError(
                'cat.num_classes != len(components): %d vs. %d' %
                (static_num_components, len(components)))

        self._cat = cat
        self._components = list(components)
        self._num_components = static_num_components
        self._static_event_shape = static_event_shape
        self._static_batch_shape = static_batch_shape

        self._use_static_graph = use_static_graph

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
def __init__(self,
             cat,
             components,
             validate_args=False,
             allow_nan_stats=True,
             use_static_graph=False,
             name="Mixture"):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph
    construction time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the
        probabilities of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a
        runtime error if batch or event ranks are inconsistent between cat
        and any of the distributions. This is only checked if the ranks
        cannot be determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations,
        but at the expense of sampling all underlying distributions in the
        mixture. (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    # Must stay the first statement so that only the constructor arguments
    # are captured.
    parameters = dict(locals())

    # NOTE: the `%` operands below are wrapped in 1-tuples. A bare tuple
    # operand would be unpacked as the format arguments and either raise
    # "not all arguments converted" or print the wrong value.
    if not isinstance(cat, categorical.Categorical):
        raise TypeError(
            "cat must be a Categorical distribution, but saw: %s" % (cat,))
    if not components:
        raise ValueError("components must be a non-empty list or tuple")
    if not isinstance(components, (list, tuple)):
        raise TypeError(
            "components must be a list or tuple, but saw: %s" %
            (components,))
    if not all(
            isinstance(c, distribution.Distribution) for c in components):
        raise TypeError(
            "all entries in components must be Distribution instances"
            " but saw: %s" % (components,))

    dtype = components[0].dtype
    if not all(d.dtype == dtype for d in components):
        raise TypeError(
            "All components must have the same dtype, but saw "
            "dtypes: %s" % ([(d.name, d.dtype) for d in components],))

    # Merge the static shapes of `cat` and every component so that the most
    # specific statically known shape is retained.
    static_event_shape = components[0].event_shape
    static_batch_shape = cat.batch_shape
    for di, d in enumerate(components):
        if not tensorshape_util.is_compatible_with(
                static_batch_shape, d.batch_shape):
            raise ValueError(
                "components[{}] batch shape must be compatible with cat "
                "shape and other component batch shapes".format(di))
        static_event_shape = tensorshape_util.merge_with(
            static_event_shape, d.event_shape)
        static_batch_shape = tensorshape_util.merge_with(
            static_batch_shape, d.batch_shape)
    if tensorshape_util.rank(static_event_shape) is None:
        raise ValueError(
            "Expected to know rank(event_shape) from components, but "
            "none of the components provide a static number of ndims")

    # Ensure that all batch and event ndims are consistent.
    with tf.name_scope(name) as name:
        num_components = cat._num_categories()  # pylint: disable=protected-access
        static_num_components = tf.get_static_value(num_components)
        if static_num_components is None:
            raise ValueError(
                "Could not infer number of classes from cat and unable "
                "to compare this value to the number of components passed in."
            )
        # Possibly convert from numpy 0-D array.
        static_num_components = int(static_num_components)
        if static_num_components != len(components):
            raise ValueError(
                "cat.num_classes != len(components): %d vs. %d" %
                (static_num_components, len(components)))

        cat_batch_shape = cat.batch_shape_tensor()
        cat_batch_rank = tf.size(cat_batch_shape)
        if validate_args:
            batch_shapes = [d.batch_shape_tensor() for d in components]
            batch_ranks = [tf.size(bs) for bs in batch_shapes]
            check_message = ("components[%d] batch shape must match cat "
                             "batch shape")
            # Rank assertions come first, followed by the shape assertions.
            rank_assertions = [
                assert_util.assert_equal(
                    cat_batch_rank, rank, message=check_message % di)
                for di, rank in enumerate(batch_ranks)
            ]
            shape_assertions = [
                assert_util.assert_equal(
                    cat_batch_shape, shape, message=check_message % di)
                for di, shape in enumerate(batch_shapes)
            ]
            self._assertions = rank_assertions + shape_assertions
        else:
            self._assertions = []

        self._cat = cat
        self._components = list(components)
        self._num_components = static_num_components
        self._static_event_shape = static_event_shape
        self._static_batch_shape = static_batch_shape

        # No extra static check is needed for `use_static_graph`: the number
        # of components was already required to be statically known above.
        self._use_static_graph = use_static_graph

    # We let the Mixture distribution access _graph_parents since its arguably
    # more like a baseclass.
    # Copy the list first: extending `cat`'s own list in place would leak the
    # components' graph parents into `cat`.
    graph_parents = list(self._cat._graph_parents)  # pylint: disable=protected-access
    for c in self._components:
        graph_parents.extend(c._graph_parents)  # pylint: disable=protected-access

    super(Mixture, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)
def __init__(self,
             parameter_prior,
             parameterized_initial_state_prior_fn,
             parameterized_transition_fn,
             parameterized_observation_fn,
             parameterized_initial_state_proposal_fn=None,
             parameterized_proposal_fn=None,
             parameter_constraining_bijector=None,
             name=None):
    """Builds an iterated filter for parameter estimation in sequential models.

    Iterated filtering is a parameter estimation method in which parameters
    are included in an augmented state space, with dynamics that introduce
    parameter perturbations, and a filtering
    algorithm such as particle filtering is run several times with
    perturbations of decreasing size. This class implements the IF2 algorithm
    of [Ionides et al., 2015][1], for which, under appropriate conditions
    (including a uniform prior) the final parameter distribution approaches a
    point mass at the maximum likelihood estimate. If a non-uniform prior is
    provided, the final parameter distribution will (under appropriate
    conditions) approach a point mass at the maximum a posteriori (MAP) value.

    This class augments the state space of a sequential model to include
    parameter perturbations, and provides utilities to run particle filtering
    on that augmented model. Alternately, the augmented components may be
    passed directly into a filtering algorithm of the user's choice.

    Args:
      parameter_prior: prior `tfd.Distribution` over parameters (may be a
        joint distribution).
      parameterized_initial_state_prior_fn: `callable` with signature
        `initial_state_prior = parameterized_initial_state_prior_fn(parameters)`
        where `parameters` has the form of a sample from `parameter_prior`,
        and `initial_state_prior` is a distribution over the initial state.
      parameterized_transition_fn: `callable` with signature
        `next_state_dist = parameterized_transition_fn(
        step, state, parameters, **kwargs)`.
      parameterized_observation_fn: `callable` with signature
        `observation_dist = parameterized_observation_fn(
        step, state, parameters, **kwargs)`.
      parameterized_initial_state_proposal_fn: optional `callable` with
        signature `initial_state_proposal =
        parameterized_initial_state_proposal_fn(parameters)` where
        `parameters` has the form of a sample from `parameter_prior`, and
        `initial_state_proposal` is a distribution over the initial state.
        If not provided, `parameterized_initial_state_prior_fn` is stored
        in its place.
        Default value: `None`.
      parameterized_proposal_fn: optional `callable` with signature
        `next_state_dist = parameterized_proposal_fn(
        step, state, parameters, **kwargs)`.
        Default value: `None`.
      parameter_constraining_bijector: optional `tfb.Bijector` instance
        such that `parameter_constraining_bijector.forward(x)` returns valid
        parameters for any real-valued `x` of the same structure and shape
        as `parameters`. If `None`, the default bijector of the provided
        `parameter_prior` will be used.
        Default value: `None`.
      name: `str` name for ops constructed by this object.
        Default value: `'IteratedFilter'`.

    Raises:
      ValueError: if a test sample from the augmented initial-state prior
        does not have a consistent particle dimension across all of its
        parts.

    #### Example

    We'll walk through applying iterated filtering to a toy
    Susceptible-Infected-Recovered (SIR) model, a [compartmental model](
    https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model)
    of infectious disease. Note that the model we use here is extremely
    simplified and is intended as a pedagogical example; it should not be
    interpreted to describe disease spread in the real world.

    We begin by specifying a prior distribution over the parameters to be
    inferred, thus defining the structure of the parameter space and the
    support of the parameters (which will imply a default constraining
    bijector). Here we'll use uniform priors over ranges that we expect to
    contain the parameters:

    ```python
    parameter_prior = tfd.JointDistributionNamed({
        'infection_rate': tfd.Uniform(low=0., high=3.),
        'recovery_rate': tfd.Uniform(low=0., high=3.),
    })
    ```

    The model specification itself is identical to that used by
    `tfp.experimental.mcmc.infer_trajectories`, except that each component
    accepts an additional `parameters` keyword argument. We start by
    specifying a parameterized prior on initial states. In this case, our
    state includes the current number of susceptible and infected individuals
    (the third compartment, recovered individuals, is implicitly defined
    to include the remaining population). We'll also include, as auxiliary
    variables, the daily counts of new infections and new recoveries; these
    will help ensure that people shift consistently across compartments.

    ```python
    population_size = 1000
    initial_state_prior_fn = lambda parameters: tfd.JointDistributionNamed({
        'new_infections': tfd.Poisson(parameters['infection_rate']),
        'new_recoveries': tfd.Deterministic(
            tf.broadcast_to(0., tf.shape(parameters['recovery_rate']))),
        'susceptible': (lambda new_infections:
                        tfd.Deterministic(population_size - new_infections)),
        'infected': (lambda new_infections:
                     tfd.Deterministic(new_infections))})
    ```

    **Note**: the state prior must have the same batch shape as the
    passed-in parameters; equivalently, it must sample a full state for each
    parameter particle. If any part of the state prior does not depend
    on the parameters, you must manually ensure that it has the appropriate
    batch shape. For example, in the definition of `new_recoveries` above,
    applying `broadcast_to` with the shape of a parameter ensures that
    the batch shape is maintained.

    Next, we specify a transition model. This takes the state at the
    previous day, along with parameters, and returns a distribution
    over the state for the current day.

    ```python
    def parameterized_infection_dynamics(_, previous_state, parameters):
      new_infections = tfd.Poisson(
          parameters['infection_rate'] * previous_state['infected'] *
          previous_state['susceptible'] / population_size)
      new_recoveries = tfd.Poisson(
          previous_state['infected'] * parameters['recovery_rate'])
      return tfd.JointDistributionNamed({
          'new_infections': new_infections,
          'new_recoveries': new_recoveries,
          'susceptible': lambda new_infections: tfd.Deterministic(
              tf.maximum(0., previous_state['susceptible'] - new_infections)),
          'infected': lambda new_infections, new_recoveries: tfd.Deterministic(
              tf.maximum(0.,
                         (previous_state['infected'] +
                          new_infections - new_recoveries)))})
    ```

    Finally, assume that every day we get to observe noisy counts of new
    infections and recoveries.

    ```python
    def parameterized_infection_observations(_, state, parameters):
      del parameters  # Not used.
      return tfd.JointDistributionNamed({
          'new_infections': tfd.Poisson(state['new_infections'] + 0.1),
          'new_recoveries': tfd.Poisson(state['new_recoveries'] + 0.1)})
    ```

    Combining these components, an `IteratedFilter` augments
    the state space to include parameters that may change over time.

    ```python
    iterated_filter = tfp.experimental.sequential.IteratedFilter(
        parameter_prior=parameter_prior,
        parameterized_initial_state_prior_fn=initial_state_prior_fn,
        parameterized_transition_fn=parameterized_infection_dynamics,
        parameterized_observation_fn=parameterized_infection_observations)
    ```

    We may then run the filter to estimate parameters from a series
    of observations:

    ```python
    # Simulated with `infection_rate=1.2` and `recovery_rate=0.1`.
    observed_values = {
        'new_infections': tf.convert_to_tensor([
            2., 7., 14., 24., 45., 93., 160., 228., 252., 158., 17.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
        'new_recoveries': tf.convert_to_tensor([
            0., 0., 3., 4., 3., 8., 12., 31., 49., 73., 85., 65., 71.,
            58., 42., 65., 36., 31., 32., 27., 31., 20., 19., 19., 14., 27.])
    }
    parameter_particles = iterated_filter.estimate_parameters(
        observations=observed_values,
        num_iterations=20,
        num_particles=4096,
        initial_perturbation_scale=1.0,
        cooling_schedule=(
            tfp.experimental.sequential.geometric_cooling_schedule(
                0.001, k=20)),
        seed=test_util.test_seed())
    print('Mean of parameter particles from final iteration: {}'.format(
        tf.nest.map_structure(lambda x: tf.reduce_mean(x[-1], axis=0),
                              parameter_particles)))
    print('Standard deviation of parameter particles from '
          'final iteration: {}'.format(
              tf.nest.map_structure(lambda x: tf.math.reduce_std(x[-1], axis=0),
                                    parameter_particles)))
    ```

    For more control, we could alternately choose to run filtering iterations
    on the augmented model manually, using the filter of our choice.
    For example, manually invoking `infer_trajectories` would allow us
    to inspect the parameter and state values at all timesteps, and their
    corresponding log-probabilities (here `perturbation_scale` and
    `initial_unconstrained_parameters` are values supplied by the user):

    ```python
    trajectories, lps = tfp.experimental.mcmc.infer_trajectories(
        observations=observed_values,
        initial_state_prior=iterated_filter.joint_initial_state_prior,
        transition_fn=functools.partial(
            iterated_filter.joint_transition_fn,
            perturbation_scale=perturbation_scale),
        observation_fn=iterated_filter.joint_observation_fn,
        proposal_fn=iterated_filter.joint_proposal_fn,
        initial_state_proposal=iterated_filter.joint_initial_state_proposal(
            initial_unconstrained_parameters),
        num_particles=4096)
    ```

    #### References:

    [1] Edward L. Ionides, Dao Nguyen, Yves Atchade, Stilian Stoev, and Aaron A.
    King. Inference for dynamic and latent variable models via iterated,
    perturbed Bayes maps. _Proceedings of the National Academy of Sciences_
    112, no. 3: 719-724, 2015.
    https://www.pnas.org/content/pnas/112/3/719.full.pdf
    """
    name = name or 'IteratedFilter'
    with tf.name_scope(name):
        self._parameter_prior = parameter_prior
        self._parameterized_initial_state_prior_fn = (
            parameterized_initial_state_prior_fn)

        if parameter_constraining_bijector is None:
            parameter_constraining_bijector = (
                parameter_prior.experimental_default_event_space_bijector())
        self._parameter_constraining_bijector = (
            parameter_constraining_bijector)

        # Augment the prior to include both parameters and states.
        self._joint_initial_state_prior = joint_prior_on_parameters_and_state(
            parameter_prior,
            parameterized_initial_state_prior_fn,
            parameter_constraining_bijector,
            prior_is_constrained=True)

        # Check that prior samples have a consistent number of particles.
        # TODO(davmre): remove the need for dummy shape dependencies,
        # and this check, by using `JointDistributionNamedAutoBatched` with
        # auto-vectorization enabled in `joint_prior_on_parameters_and_state`.
        num_particles_canary = 13
        canary_seed = samplers.zeros_seed()

        def _get_shape_1(x):
            """Returns the static size of axis 1 of `x` as a `TensorShape`."""
            if hasattr(x, 'state'):
                x = x.state
            return tf.TensorShape(x.shape[1:2])

        # Sample shape [0, num_particles_canary] particles (size will be
        # zero) then trim off the leading 0 and (possibly) any event shape.
        # We expect shape [num_particles_canary] to remain.
        prior_static_sample_shapes = tf.nest.map_structure(
            _get_shape_1,
            self._joint_initial_state_prior.sample(
                [0, num_particles_canary], seed=canary_seed))
        if not all(
                tensorshape_util.is_compatible_with(
                    s[:1], [num_particles_canary])
                for s in tf.nest.flatten(prior_static_sample_shapes)):
            raise ValueError(
                'The specified prior does not generate consistent '
                'shapes when sampled. Please verify that all parts of '
                '`initial_state_prior_fn` have batch shape matching '
                'that of the parameters. This may require creating '
                '"dummy" dependencies on parameters; for example: '
                '`tf.broadcast_to(value, tf.shape(parameter))`. (in a '
                f'test sample with {num_particles_canary} particles, '
                'we expected '
                'all values to have shape compatible with '
                f'[{num_particles_canary}, ...]; '
                f'saw shapes {prior_static_sample_shapes})')

        # Augment the transition and observation fns to cover both
        # parameters and states.
        self._joint_transition_fn = augment_transition_fn_with_parameters(
            parameter_prior,
            parameterized_transition_fn,
            parameter_constraining_bijector)
        self._joint_observation_fn = augment_observation_fn_with_parameters(
            parameterized_observation_fn,
            parameter_constraining_bijector)

        # If given a proposal for the initial state, augment it into a joint
        # proposal over parameters and states.
        joint_initial_state_proposal = None
        if parameterized_initial_state_proposal_fn:
            joint_initial_state_proposal = (
                joint_prior_on_parameters_and_state(
                    parameter_prior,
                    parameterized_initial_state_proposal_fn,
                    parameter_constraining_bijector))
        else:
            # Without an explicit proposal, the prior fn is stored instead.
            parameterized_initial_state_proposal_fn = (
                parameterized_initial_state_prior_fn)
        self._joint_initial_state_proposal = joint_initial_state_proposal
        self._parameterized_initial_state_proposal_fn = (
            parameterized_initial_state_proposal_fn)

        # If given a conditional proposal fn (for non-initial states),
        # augment it to be joint over states and parameters.
        self._joint_proposal_fn = None
        if parameterized_proposal_fn:
            self._joint_proposal_fn = augment_transition_fn_with_parameters(
                parameter_prior,
                parameterized_proposal_fn,
                parameter_constraining_bijector)

        self._batch_ndims = tf.nest.map_structure(
            ps.rank_from_shape,
            parameter_prior.batch_shape_tensor())
        self._name = name