def _maybe_build_joint_distribution(structure_of_distributions):
  """Turns a (potentially nested) structure of dists into a single dist."""
  # Base case: if we already have a Distribution, return it.
  if dist_util.is_distribution_instance(structure_of_distributions):
    return structure_of_distributions

  # Otherwise, recursively convert all interior nested structures into JDs.
  outer_structure = tf.nest.map_structure(
      _maybe_build_joint_distribution,
      structure_of_distributions)
  if (hasattr(outer_structure, '_asdict') or
      isinstance(outer_structure, collections.abc.Mapping)):
    return joint_distribution_named.JointDistributionNamed(outer_structure)
  else:
    return joint_distribution_sequential.JointDistributionSequential(
        outer_structure)
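# Dispatch sketch for the helper above (toy inputs; assumes
# `tfd = tfp.distributions` is in scope):
#
#   # A list of distributions becomes a JointDistributionSequential:
#   jd = _maybe_build_joint_distribution(
#       [tfd.Normal(0., 1.), tfd.Poisson(3.)])
#
#   # A dict (or a namedtuple, detected via `_asdict`) becomes a
#   # JointDistributionNamed:
#   jd = _maybe_build_joint_distribution(
#       {'loc': tfd.Normal(0., 1.), 'rate': tfd.Poisson(3.)})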
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
  """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(),
                             structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to a
      single `Distribution` instance), this will be a `JointDistribution` with
      the corresponding structure.

  Raises:
    TypeError: if any leaves of the input structure are not
      `tfd.Distribution` instances.
  """
  # If input is already a Distribution, just return it.
  if dist_util.is_distribution_instance(structure_of_distributions):
    return structure_of_distributions

  # If this structure contains other structures (i.e., has elements at
  # depth > 1), recursively turn them into JDs.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  if max(tf.nest.flatten(element_depths)) > 1:
    # Stop traversing at any substructure all of whose elements lie at
    # depth > 1; each such substructure is converted to a JD by the
    # recursive call below.
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure,
        independent_joint_distribution_from_structure,
        structure_of_distributions)

  # Build a JD from the (now depth-one) structure.
  if (hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections.abc.Mapping)):
    return joint_distribution_named.JointDistributionNamed(
        structure_of_distributions, validate_args=validate_args)
  return joint_distribution_sequential.JointDistributionSequential(
      structure_of_distributions, validate_args=validate_args)
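# Usage sketch (illustrative structure; assumes `tfd = tfp.distributions`):
#
#   dist = independent_joint_distribution_from_structure(
#       {'x': tfd.Normal(0., 1.),
#        'y': [tfd.Poisson(5.), tfd.Bernoulli(probs=0.3)]})
#   s = dist.sample()
#   # `s` mirrors the input structure, {'x': Tensor, 'y': [Tensor, Tensor]},
#   # matching `tf.nest.map_structure(lambda d: d.sample(), structure)`.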
def joint_prior_on_parameters_and_state(parameter_prior,
                                        parameterized_initial_state_prior_fn,
                                        parameter_constraining_bijector,
                                        prior_is_constrained=True):
  """Constructs a joint dist. from p(parameters) and p(state | parameters)."""
  if prior_is_constrained:
    parameter_prior = transformed_distribution.TransformedDistribution(
        parameter_prior,
        invert.Invert(parameter_constraining_bijector),
        name='unconstrained_parameter_prior')

  return joint_distribution_named.JointDistributionNamed(
      ParametersAndState(
          unconstrained_parameters=parameter_prior,
          state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
              parameterized_initial_state_prior_fn(
                  parameter_constraining_bijector.forward(
                      unconstrained_parameters)))))
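# Usage sketch (a hypothetical model with one positive `scale` parameter,
# constrained via `tfb.Softplus`; assumes `tfd = tfp.distributions` and
# `tfb = tfp.bijectors`):
#
#   joint = joint_prior_on_parameters_and_state(
#       parameter_prior=tfd.LogNormal(0., 1.),
#       parameterized_initial_state_prior_fn=(
#           lambda scale: tfd.Normal(loc=0., scale=scale)),
#       parameter_constraining_bijector=tfb.Softplus())
#   draw = joint.sample()
#   # `draw` is a `ParametersAndState` namedtuple; parameters are stored in
#   # unconstrained space, while the state prior sees their constrained values.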
def augmented_fn(step, state_with_history, **kwargs):
  """Builds history-tracking dist. over `StateWithHistory` instances."""
  with tf.name_scope('augment_with_state_history'):
    new_state_dist = fn(step, state_with_history, **kwargs)

    def new_state_history_dist(state):
      with tf.name_scope('new_state_history_dist'):
        new_state_histories = tf.nest.map_structure(
            lambda h, s: tf.concat([h[:, 1:],  # pylint: disable=g-long-lambda
                                    s[:, tf.newaxis]], axis=1),
            state_with_history.state_history,
            state)
        return (
            joint_distribution_util
            .independent_joint_distribution_from_structure(
                _wrap_as_distributions(new_state_histories)))

  return joint_distribution_named.JointDistributionNamed(
      StateWithHistory(
          state=new_state_dist,
          state_history=new_state_history_dist))
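# Illustration of the history update in `new_state_history_dist`, with toy
# values (the oldest history entry is dropped and the new state appended):
#
#   h = tf.reshape(tf.range(6.), [2, 3])  # [num_particles=2, history_size=3]
#   s = tf.constant([10., 20.])           # new state, one value per particle
#   tf.concat([h[:, 1:], s[:, tf.newaxis]], axis=1)
#   # ==> [[1., 2., 10.],
#   #      [4., 5., 20.]]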
def params_and_state_transition_fn(step,
                                   params_and_state,
                                   perturbation_scale,
                                   **kwargs):
  """Transition function operating on a `ParametersAndState` namedtuple."""
  # Extract the state, to pass through to the transition fn.
  unconstrained_params, state = params_and_state
  if 'state_history' in kwargs:
    kwargs['state_history'] = kwargs['state_history'].state

  # Perturb each (unconstrained) parameter with normally-distributed noise.
  if not tf.nest.is_nested(perturbation_scale):
    perturbation_scale = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(perturbation_scale,  # pylint: disable=g-long-lambda
                                       name='perturbation_scale',
                                       dtype=x.dtype),
        unconstrained_params)
  perturbed_unconstrained_parameter_dists = tf.nest.map_structure(
      lambda x, p, s: independent.Independent(  # pylint: disable=g-long-lambda
          normal.Normal(loc=x, scale=p),
          reinterpreted_batch_ndims=prefer_static.rank_from_shape(s)),
      unconstrained_params,
      perturbation_scale,
      parameter_prior.event_shape_tensor())

  # For the joint transition, pass the perturbed parameters into the original
  # transition fn (after pushing them into constrained space).
  return joint_distribution_named.JointDistributionNamed(
      ParametersAndState(
          unconstrained_parameters=_maybe_build_joint_distribution(
              perturbed_unconstrained_parameter_dists),
          state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
              parameterized_transition_fn(
                  step,
                  state,
                  parameters=parameter_constraining_bijector.forward(
                      unconstrained_parameters),
                  **kwargs))))
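# Sketch of the perturbation distribution built above, for a single parameter
# with `event_shape=[2]` (toy values; assumes `tfd = tfp.distributions`):
#
#   x = tf.constant([0.5, -1.2])   # current unconstrained parameter value
#   perturbed = tfd.Independent(
#       tfd.Normal(loc=x, scale=0.01),
#       reinterpreted_batch_ndims=1)  # rank of the parameter's event shape
#   # Each transition step draws the next parameter value near the current one.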
def augment_prior_with_state_history(prior, history_size):
  """Augments a prior or proposal distribution's state space with history.

  The augmented state space is over `tfp.experimental.mcmc.StateWithHistory`
  namedtuples, which contain the original `state` as well as a
  `state_history`. The `state_history` is a structure of `Tensor`s matching
  `state`, of shape `concat([[num_particles, history_size],
  state.shape[1:]])`. In other words, previous states for each particle are
  indexed along `axis=1`, to the right of the particle indices.

  Args:
    prior: a (joint) distribution over the initial latent state, with
      optional batch shape `[b1, ..., bN]`.
    history_size: integer `Tensor` number of steps of history to pass.

  Returns:
    augmented_prior: a `tfd.JointDistributionNamed` instance whose samples
      are `tfp.experimental.mcmc.StateWithHistory` namedtuples.

  #### Example

  As a toy example, let's see how we'd use state history to experiment with
  stochastic 'Fibonacci sequences'. We'll assume that the sequence starts at
  a value sampled from a Poisson distribution.

  ```python
  initial_state_prior = tfd.Poisson(5.)
  initial_state_with_history_prior = (
      tfp.experimental.mcmc.augment_prior_with_state_history(
          initial_state_prior, history_size=2))
  ```

  Note that we've augmented the state space to include a state history of
  size two. The augmented state space is over instances of
  `tfp.experimental.mcmc.StateWithHistory`. Initially, the state history will
  simply tile the initial state: if
  `s = initial_state_with_history_prior.sample()`, then
  `s.state_history == [s.state, s.state]`.

  Next, we'll define a `transition_fn` that uses the history to sample the
  next integer in the sequence, also from a Poisson distribution.

  ```python
  @tfp.experimental.mcmc.augment_with_state_history
  def fibonacci_transition_fn(_, state_with_history):
    expected_next_element = tf.reduce_sum(
        state_with_history.state_history[:, -2:], axis=1)
    return tfd.Poisson(rate=expected_next_element)
  ```

  Our transition function must accept `state_with_history`, so that it can
  access the history, but it returns a distribution only over the next state.
  Decorating it with `augment_with_state_history` ensures that the state
  history is automatically propagated.

  Note: if we were using an `initial_state_proposal` and/or `proposal_fn`, we
  would need to wrap them similarly to the prior and transition function
  shown here.

  Combined with an observation function (which must also now be defined on
  the augmented `StateWithHistory` space), we can track stochastic Fibonacci
  sequences and, for example, infer the initial value of a sequence:

  ```python
  def observation_fn(_, state_with_history):
    return tfd.Poisson(rate=state_with_history.state)

  trajectories, _ = tfp.experimental.mcmc.infer_trajectories(
      observations=tf.convert_to_tensor([4., 11., 16., 23., 40., 69., 100.]),
      initial_state_prior=initial_state_with_history_prior,
      transition_fn=fibonacci_transition_fn,
      observation_fn=observation_fn,
      num_particles=1024)
  inferred_initial_states = trajectories.state[0]
  print(tf.unique_with_counts(inferred_initial_states))
  ```
  """

  def initialize_state_history(state):
    """Build an initial state history by replicating the initial state."""
    with tf.name_scope('initialize_state_history'):
      initial_state_histories = tf.nest.map_structure(
          lambda x: tf.broadcast_to(  # pylint: disable=g-long-lambda
              tf.expand_dims(x, ps.minimum(ps.rank(x), 1)),
              ps.concat([ps.shape(x)[:1],
                         [history_size],
                         ps.shape(x)[1:]], axis=0)),
          state)
      return (joint_distribution_util
              .independent_joint_distribution_from_structure(
                  _wrap_as_distributions(initial_state_histories)))

  return joint_distribution_named.JointDistributionNamed(
      StateWithHistory(
          state=prior,
          state_history=initialize_state_history))
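# Illustration of `initialize_state_history` for a rank-1 state with
# `history_size=2` (toy values):
#
#   state = tf.constant([3., 7.])  # [num_particles=2]
#   tf.broadcast_to(state[:, tf.newaxis], [2, 2])
#   # ==> [[3., 3.],
#   #      [7., 7.]]
#   # The history initially tiles the state, matching the docstring's
#   # `s.state_history == [s.state, s.state]`.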
def build_factored_surrogate_posterior(model,
                                       batch_shape=(),
                                       seed=None,
                                       name=None):
  """Build a variational posterior that factors over model parameters.

  The surrogate posterior consists of independent Normal distributions for
  each parameter with trainable `loc` and `scale`, transformed using the
  parameter's `bijector` to the appropriate support space for that parameter.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of initial
      states to optimize in parallel.
      Default value: `()`. (i.e., just run a single optimization).
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    variational_posterior: `tfd.JointDistributionNamed` defining a trainable
      surrogate posterior over model parameters. Samples from this
      distribution are Python `dict`s with Python `str` parameter names as
      keys.

  ### Examples

  Assume we've built a structural time-series model:

  ```python
  day_of_week = tfp.sts.Seasonal(
      num_seasons=7,
      observed_time_series=observed_time_series,
      name='day_of_week')
  local_linear_trend = tfp.sts.LocalLinearTrend(
      observed_time_series=observed_time_series,
      name='local_linear_trend')
  model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                      observed_time_series=observed_time_series)
  ```

  To fit the model to data, we define a surrogate posterior and fit it by
  optimizing a variational bound:

  ```python
  surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model=model)
  loss_curve = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(observed_time_series),
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(learning_rate=0.1),
      num_steps=200)
  posterior_samples = surrogate_posterior.sample(50)

  # In graph mode, we would need to write:
  # with tf.control_dependencies([loss_curve]):
  #   posterior_samples = surrogate_posterior.sample(50)
  ```

  For more control, we can also build and optimize a variational loss
  manually:

  ```python
  @tf.function(autograph=False)  # Ensure the loss is computed efficiently
  def loss_fn():
    return tfp.vi.monte_carlo_variational_loss(
        model.joint_log_prob(observed_time_series),
        surrogate_posterior,
        sample_size=10)

  optimizer = tf.optimizers.Adam(learning_rate=0.1)
  for step in range(200):
    with tf.GradientTape() as tape:
      loss = loss_fn()
    grads = tape.gradient(loss, surrogate_posterior.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, surrogate_posterior.trainable_variables))
    if step % 20 == 0:
      print('step {} loss {}'.format(step, loss))

  posterior_samples = surrogate_posterior.sample(50)
  ```
  """
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    seed = tfp_util.SeedStream(
        seed, salt='StructuralTimeSeries_build_factored_surrogate_posterior')
    variational_posterior = collections.OrderedDict()
    for param in model.parameters:
      variational_posterior[param.name] = _build_posterior_for_one_parameter(
          param, batch_shape=batch_shape, seed=seed())
    return joint_distribution_named_lib.JointDistributionNamed(
        variational_posterior)
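# Inspection sketch (assumes a `model` built as in the docstring examples;
# the parameter name shown is hypothetical and depends on the model):
#
#   surrogate = build_factored_surrogate_posterior(model)
#   s = surrogate.sample()
#   # `s` is an OrderedDict keyed by parameter name, e.g.
#   # {'observation_noise_scale': <Tensor>, ...}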