def test_dont_pad_batch_dimension_when_input_has_no_sample_shape(self):
  """Padding for multiple chains leaves a series without sample shape as-is.

  The series below has shape `[num_timesteps, 1]` (no leading sample or batch
  dimensions), so requesting a chain batch shape must not change its shape.
  """
  batch_shape = [3, 2]
  prior_loc = self._build_tensor(np.random.randn(*batch_shape))
  scale_prior = tfd.Normal(loc=prior_loc, scale=1.)
  model = tfp.sts.LocalLevel(level_scale_prior=scale_prior)

  timesteps = 5
  series = self._build_tensor(np.random.randn(timesteps, 1))

  padded = sts_util.pad_batch_dimension_for_multiple_chains(
      series, model=model, chain_batch_shape=[8, 2])

  actual_shape = self._shape_as_list(padded)
  expected_shape = self._shape_as_list(series)
  self.assertAllEqual(actual_shape, expected_shape)
def build_factored_variational_loss(model,
                                    observed_time_series,
                                    init_batch_shape=(),
                                    seed=None,
                                    name=None):
  """Build a loss function for variational inference in STS models.

  Variational inference searches for the distribution within some family of
  approximate posteriors that minimizes a divergence between the approximate
  posterior `q(z)` and true posterior `p(z|observed_time_series)`. By converting
  inference to optimization, it's generally much faster than sampling-based
  inference algorithms such as HMC. The tradeoff is that the approximating
  family rarely contains the true posterior, so it may miss important aspects of
  posterior structure (in particular, dependence between variables) and should
  not be blindly trusted. Results may vary; it's generally wise to compare to
  HMC to evaluate whether inference quality is sufficient for your task at hand.

  This method constructs a loss function for variational inference using the
  Kullback-Leibler divergence `KL[q(z) || p(z|observed_time_series)]`, with an
  approximating family given by independent Normal distributions transformed to
  the appropriate parameter space for each parameter. Minimizing this loss (the
  negative ELBO) maximizes a lower bound on the log model evidence
  `log p(observed_time_series)`. This is equivalent to the 'mean-field' method
  implemented in [1], and is a standard approach. The resulting posterior
  approximations are unimodal; they will tend to underestimate posterior
  uncertainty when the true posterior contains multiple modes (the `KL[q||p]`
  divergence encourages choosing a single mode) or dependence between variables.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. May
      optionally be an instance of `tfp.sts.MaskedTimeSeries`, which includes
      a mask `Tensor` to specify timesteps with missing observations.
    init_batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of initial
      states to optimize in parallel.
      Default value: `()`. (i.e., just run a single optimization).
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_variational_loss').

  Returns:
    variational_loss: `float` `Tensor` of shape
      `concat([init_batch_shape, model.batch_shape])`, encoding a stochastic
      estimate of an upper bound on the negative model evidence `-log p(y)`.
      Minimizing this loss performs variational inference; the gap between the
      variational bound and the true (generally unknown) model evidence
      corresponds to the divergence `KL[q||p]` between the approximate and true
      posterior.
    variational_distributions: `collections.OrderedDict` giving
      the approximate posterior for each model parameter. The keys are
      Python `str` parameter names in order, corresponding to
      `[param.name for param in model.parameters]`. The values are
      `tfd.Distribution` instances with batch shape
      `concat([init_batch_shape, model.batch_shape])`; these will typically be
      of the form `tfd.TransformedDistribution(tfd.Normal(...),
      bijector=param.bijector)`.

  #### Examples

  Assume we've built a structural time-series model:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)
  ```

  To run variational inference, we simply construct the loss and optimize
  it:

  ```python
    (variational_loss,
     variational_distributions) = tfp.sts.build_factored_variational_loss(
       model=model, observed_time_series=observed_time_series)

    train_op = tf.train.AdamOptimizer(0.1).minimize(variational_loss)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      for step in range(200):
        _, loss_ = sess.run((train_op, variational_loss))

        if step % 20 == 0:
          print("step {} loss {}".format(step, loss_))

      posterior_samples_ = sess.run({
        param_name: q.sample(50)
        for param_name, q in variational_distributions.items()})
  ```

  As a more complex example, we might try to avoid local optima by optimizing
  from multiple initializations in parallel, and selecting the result with the
  lowest loss:

  ```python
    (variational_loss,
     variational_distributions) = tfp.sts.build_factored_variational_loss(
       model=model, observed_time_series=observed_time_series,
       init_batch_shape=[10])

    train_op = tf.train.AdamOptimizer(0.1).minimize(variational_loss)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      for step in range(200):
        _, loss_ = sess.run((train_op, variational_loss))

        if step % 20 == 0:
          print("step {} losses {}".format(step, loss_))

      # Draw multiple samples to reduce Monte Carlo error in the optimized
      # variational bounds.
      avg_loss = np.mean(
        [sess.run(variational_loss) for _ in range(25)], axis=0)
      best_posterior_idx = np.argmin(avg_loss, axis=0).astype(np.int32)
  ```

  #### References

  [1]: Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and
       David M. Blei. Automatic Differentiation Variational Inference. In
       _Journal of Machine Learning Research_, 2017.
       https://arxiv.org/abs/1603.00788
  """

  with tf.compat.v1.name_scope(
      name, 'build_factored_variational_loss',
      values=[observed_time_series]) as name:
    seed = tfd.SeedStream(
        seed, salt='StructuralTimeSeries_build_factored_variational_loss')

    # Initial locations for each surrogate posterior are drawn uniformly
    # (via `sample_uniform_initial_state`); each call draws a fresh seed from
    # the stream. Defined once, outside the loop, since it only depends on
    # its `param` argument and the enclosing (loop-invariant) names.
    def initial_loc_fn(param):
      return sample_uniform_initial_state(
          param, return_constrained=True,
          init_sample_shape=init_batch_shape,
          seed=seed())

    variational_distributions = collections.OrderedDict()
    variational_samples = []
    for param in model.parameters:
      q = _build_trainable_posterior(param, initial_loc_fn=initial_loc_fn)
      variational_distributions[param.name] = q
      # A single draw per parameter: the resulting loss is a one-sample
      # Monte Carlo estimate of the negative ELBO.
      variational_samples.append(q.sample(seed=seed()))

    # Multiple initializations (similar to HMC chains) manifest as an extra
    # param batch dimension, so we need to add corresponding batch dimension(s)
    # to `observed_time_series`.
    observed_time_series = sts_util.pad_batch_dimension_for_multiple_chains(
        observed_time_series, model, chain_batch_shape=init_batch_shape)

    # Construct the variational bound.
    log_prob_fn = model.joint_log_prob(observed_time_series)
    expected_log_joint = log_prob_fn(*variational_samples)
    # Sum over parameters of `-log q(sample)`: a single-sample estimate of the
    # entropy of the factored surrogate posterior.
    entropy = tf.reduce_sum(
        input_tensor=[
            -q.log_prob(sample) for (q, sample) in zip(
                variational_distributions.values(), variational_samples)
        ],
        axis=0)
    variational_loss = -(expected_log_joint + entropy)  # -ELBO

  return variational_loss, variational_distributions
def fit_with_hmc(model,
                 observed_time_series,
                 num_results=100,
                 num_warmup_steps=50,
                 num_leapfrog_steps=15,
                 initial_state=None,
                 initial_step_size=None,
                 chain_batch_shape=(),
                 num_variational_steps=150,
                 variational_optimizer=None,
                 seed=None,
                 name=None):
  """Draw posterior samples using Hamiltonian Monte Carlo (HMC).

  Markov chain Monte Carlo (MCMC) methods are considered the gold standard of
  Bayesian inference; under suitable conditions and in the limit of infinitely
  many draws they generate samples from the true posterior distribution. HMC [1]
  uses gradients of the model's log-density function to propose samples,
  allowing it to exploit posterior geometry. However, it is computationally more
  expensive than variational inference and relatively sensitive to tuning.

  This method attempts to provide a sensible default approach for fitting
  StructuralTimeSeries models using HMC. It first runs variational inference as
  a fast posterior approximation, and initializes the HMC sampler from the
  variational posterior, using the posterior standard deviations to set
  per-variable step sizes (equivalently, a diagonal mass matrix). During the
  warmup phase, it adapts the step size to target an acceptance rate of 0.75,
  which is thought to be in the desirable range for optimal mixing [2].

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. May
      optionally be an instance of `tfp.sts.MaskedTimeSeries`, which includes
      a mask `Tensor` to specify timesteps with missing observations.
    num_results: Integer number of Markov chain draws.
      Default value: `100`.
    num_warmup_steps: Integer number of steps to take before starting to
      collect results. The warmup steps are also used to adapt the step size
      towards a target acceptance rate of 0.75.
      Default value: `50`.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
      Default value: `15`.
    initial_state: Optional Python `list` of `Tensor`s, one for each model
      parameter, representing the initial state(s) of the Markov chain(s). These
      should have shape `concat([chain_batch_shape, param.prior.batch_shape,
      param.prior.event_shape])`. If `None`, the initial state is set
      automatically using a sample from a variational posterior.
      Default value: `None`.
    initial_step_size: Python `list` of `Tensor`s, one for each model parameter,
      representing the step size for the leapfrog integrator. Must
      broadcast with the shape of `initial_state`. Larger step sizes lead to
      faster progress, but too-large step sizes make rejection exponentially
      more likely. If `None`, the step size is set automatically using the
      standard deviation of a variational posterior.
      Default value: `None`.
    chain_batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of chains
      to run in parallel.
      Default value: `()` (i.e., a single chain).
    num_variational_steps: Python `int` number of steps to run the variational
      optimization to determine the initial state and step sizes.
      Default value: `150`.
    variational_optimizer: Optional `tf.train.Optimizer` instance to use in
      the variational optimization. If `None`, defaults to
      `tf.train.AdamOptimizer(0.1)`.
      Default value: `None`.
    seed: Python integer to seed the random number generator. When a seed is
      given, the Markov chain is run with a single parallel iteration so that
      results are reproducible.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'fit_with_hmc').

  Returns:
    samples: Python `list` of `Tensors` representing posterior samples of model
      parameters, with shapes `[concat([[num_results], chain_batch_shape,
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within the HMC sampler.

  #### Examples

  Assume we've built a structural time-series model:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)
  ```

  To draw posterior samples using HMC under default settings:

  ```python
  samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples_, kernel_results_ = sess.run((samples, kernel_results))

  print("acceptance rate: {}".format(
    np.mean(kernel_results_.inner_results.is_accepted, axis=0)))
  print("posterior means: {}".format(
    {param.name: np.mean(param_draws, axis=0)
     for (param, param_draws) in zip(model.parameters, samples_)}))
  ```

  We can also run multiple chains. This may help diagnose convergence issues
  and allows us to exploit vectorization to draw samples more quickly, although
  warmup still requires the same number of sequential steps.

  ```python
  from matplotlib import pylab as plt

  samples, kernel_results = tfp.sts.fit_with_hmc(
    model, observed_time_series, chain_batch_shape=[10])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples_, kernel_results_ = sess.run((samples, kernel_results))

  print("acceptance rate: {}".format(
    np.mean(kernel_results_.inner_results.is_accepted, axis=0)))

  # Plot the sampled traces for each parameter. If the chains have mixed, their
  # traces should all cover the same region of state space, frequently crossing
  # over each other.
  for (param, param_draws) in zip(model.parameters, samples_):
    if param.prior.event_shape.ndims > 0:
      print("Only plotting traces for scalar parameters, skipping {}".format(
        param.name))
      continue
    plt.figure(figsize=[10, 4])
    plt.title(param.name)
    plt.plot(param_draws)
    plt.ylabel(param.name)
    plt.xlabel("HMC step")

  # Combining the samples from multiple chains into a single dimension allows
  # us to easily pass sampled parameters to downstream forecasting methods.
  combined_samples_ = [np.reshape(param_draws,
                                  [-1] + list(param_draws.shape[2:]))
                       for param_draws in samples_]
  ```

  For greater flexibility, you may prefer to implement your own sampler using
  the TensorFlow Probability primitives in `tfp.mcmc`. The following recipe
  constructs a basic HMC sampler, using a `TransformedTransitionKernel` to
  incorporate constraints on the parameter space.

  ```python
  transformed_hmc_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=model.joint_log_prob(observed_time_series),
          step_size=step_size,
          num_leapfrog_steps=num_leapfrog_steps,
          step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(
            num_adaptation_steps=num_adaptation_steps),
          state_gradients_are_stopped=True,
          seed=seed),
      bijector=[param.bijector for param in model.parameters])

  # Initialize from a Uniform[-2, 2] distribution in unconstrained space.
  initial_state = [tfp.sts.sample_uniform_initial_state(
    param, return_constrained=True) for param in model.parameters]

  samples, kernel_results = tfp.mcmc.sample_chain(
    kernel=transformed_hmc_kernel,
    num_results=num_results,
    current_state=initial_state,
    num_burnin_steps=num_warmup_steps)
  ```

  #### References

  [1]: Radford Neal. MCMC Using Hamiltonian Dynamics. _Handbook of Markov Chain
       Monte Carlo_, 2011. https://arxiv.org/abs/1206.1901
  [2]  M.J. Betancourt, Simon Byrne, and Mark Girolami. Optimizing The
       Integrator Step Size for Hamiltonian Monte Carlo.
       https://arxiv.org/abs/1411.6669

  """
  with tf.compat.v1.name_scope(
      name, 'fit_with_hmc', values=[observed_time_series]) as name:
    # Remember whether the caller asked for a seeded run *before* `seed` is
    # rebound to a `SeedStream` object (which is never `None`). Checking
    # `seed is not None` after the rebind would always be true.
    seeded = seed is not None
    seed = tfd.SeedStream(seed, salt='StructuralTimeSeries_fit_with_hmc')

    # Initialize state and step sizes from a variational posterior if not
    # specified.
    if initial_step_size is None or initial_state is None:

      # To avoid threading variational distributions through the training
      # while loop, we build our own copy here. `make_template` ensures
      # that our variational distributions share the optimized parameters.
      def make_variational():
        return build_factored_variational_loss(
            model, observed_time_series,
            init_batch_shape=chain_batch_shape, seed=seed())

      make_variational = tf.compat.v1.make_template('make_variational',
                                                    make_variational)
      _, variational_distributions = make_variational()
      minimize_op = _minimize_in_graph(
          build_loss_fn=lambda: make_variational()[0],  # return just the loss.
          num_steps=num_variational_steps,
          optimizer=variational_optimizer)

      with tf.control_dependencies([minimize_op]):
        if initial_state is None:
          initial_state = [tf.stop_gradient(d.sample())
                           for d in variational_distributions.values()]

        # Set step sizes using the unconstrained variational distribution.
        if initial_step_size is None:
          initial_step_size = [
              transformed_q.distribution.stddev()
              for transformed_q in variational_distributions.values()]

    # Multiple chains manifest as an extra param batch dimension, so we need to
    # add a corresponding batch dimension to `observed_time_series`.
    observed_time_series = sts_util.pad_batch_dimension_for_multiple_chains(
        observed_time_series, model, chain_batch_shape=chain_batch_shape)

    # When the initial step size depends on a variational optimization, we
    # can't initialize step size variables before the optimization runs.
    # Instead we initialize with a dummy value of the appropriate
    # shape, then wrap the HMC chain with `control_dependencies` to ensure the
    # variational step sizes are assigned before HMC actually runs.
    # (`zip` keeps the original pairing with `initial_step_size`.)
    step_size = [
        tf.compat.v1.get_variable(
            initializer=tf.zeros_like(
                sample_uniform_initial_state(
                    param, init_sample_shape=chain_batch_shape,
                    return_constrained=False)),
            name='{}_step_size'.format(param.name),
            trainable=False,
            use_resource=True)
        for (param, _) in zip(model.parameters, initial_step_size)
    ]
    step_size_init_op = tf.group([
        tf.compat.v1.assign(ss, initial_ss)
        for (ss, initial_ss) in zip(step_size, initial_step_size)
    ])

    # Run HMC to sample from the posterior on parameters.
    with tf.control_dependencies([step_size_init_op]):
      samples, kernel_results = mcmc.sample_chain(
          num_results=num_results,
          current_state=initial_state,
          num_burnin_steps=num_warmup_steps,
          kernel=mcmc.TransformedTransitionKernel(
              inner_kernel=mcmc.HamiltonianMonteCarlo(
                  target_log_prob_fn=model.joint_log_prob(observed_time_series),
                  step_size=step_size,
                  num_leapfrog_steps=num_leapfrog_steps,
                  step_size_update_fn=mcmc.make_simple_step_size_update_policy(
                      num_adaptation_steps=int(num_warmup_steps * 0.8),
                      decrement_multiplier=0.1,
                      increment_multiplier=0.1),
                  state_gradients_are_stopped=True,
                  seed=seed()),
              bijector=[param.bijector for param in model.parameters]),
          # Serialize the chain when the caller supplied a seed so results are
          # reproducible; otherwise allow parallel iterations for speed.
          parallel_iterations=1 if seeded else 10)

    return samples, kernel_results
def fit_with_hmc(model,
                 observed_time_series,
                 num_results=100,
                 num_warmup_steps=50,
                 num_leapfrog_steps=15,
                 initial_state=None,
                 initial_step_size=None,
                 chain_batch_shape=(),
                 num_variational_steps=150,
                 variational_optimizer=None,
                 variational_sample_size=5,
                 seed=None,
                 name=None):
  """Draw posterior samples using Hamiltonian Monte Carlo (HMC).

  Markov chain Monte Carlo (MCMC) methods are considered the gold standard of
  Bayesian inference; under suitable conditions and in the limit of infinitely
  many draws they generate samples from the true posterior distribution. HMC [1]
  uses gradients of the model's log-density function to propose samples,
  allowing it to exploit posterior geometry. However, it is computationally more
  expensive than variational inference and relatively sensitive to tuning.

  This method attempts to provide a sensible default approach for fitting
  StructuralTimeSeries models using HMC. It first runs variational inference as
  a fast posterior approximation, and initializes the HMC sampler from the
  variational posterior, using the posterior standard deviations to set
  per-variable step sizes (equivalently, a diagonal mass matrix). During the
  warmup phase, it adapts the step size to target an acceptance rate of 0.75,
  which is thought to be in the desirable range for optimal mixing [2].

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. Any `NaN`s
      are interpreted as missing observations; missingness may also be
      explicitly specified by passing a `tfp.sts.MaskedTimeSeries` instance.
    num_results: Integer number of Markov chain draws.
      Default value: `100`.
    num_warmup_steps: Integer number of steps to take before starting to
      collect results. The warmup steps are also used to adapt the step size
      towards a target acceptance rate of 0.75.
      Default value: `50`.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
      Default value: `15`.
    initial_state: Optional Python `list` of `Tensor`s, one for each model
      parameter, representing the initial state(s) of the Markov chain(s). These
      should have shape `concat([chain_batch_shape, param.prior.batch_shape,
      param.prior.event_shape])`. If `None`, the initial state is set
      automatically using a sample from a variational posterior.
      Default value: `None`.
    initial_step_size: Python `list` of `Tensor`s, one for each model parameter,
      representing the step size for the leapfrog integrator. Must
      broadcast with the shape of `initial_state`. Larger step sizes lead to
      faster progress, but too-large step sizes make rejection exponentially
      more likely. If `None`, the step size is set automatically using the
      standard deviation of a variational posterior.
      Default value: `None`.
    chain_batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of chains
      to run in parallel.
      Default value: `()` (i.e., a single chain).
    num_variational_steps: Python `int` number of steps to run the variational
      optimization to determine the initial state and step sizes.
      Default value: `150`.
    variational_optimizer: Optional `tf.train.Optimizer` instance to use in
      the variational optimization. If `None`, defaults to
      `tf.train.AdamOptimizer(0.1)`.
      Default value: `None`.
    variational_sample_size: Python `int` number of Monte Carlo samples to use
      in estimating the variational divergence. Larger values may stabilize
      the optimization, but at higher cost per step in time and memory.
      Default value: `5`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. It controls
      the variational initialization, the variational optimization, the
      initial-state draw from the variational posterior, and the HMC chain.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'fit_with_hmc').

  Returns:
    samples: Python `list` of `Tensors` representing posterior samples of model
      parameters, with shapes `[concat([[num_results], chain_batch_shape,
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within the HMC sampler.

  #### Examples

  Assume we've built a structural time-series model:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)
  ```

  To draw posterior samples using HMC under default settings:

  ```python
  samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)

  print("acceptance rate: {}".format(
    np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0)))
  print("posterior means: {}".format(
    {param.name: np.mean(param_draws, axis=0)
     for (param, param_draws) in zip(model.parameters, samples)}))
  ```

  We can also run multiple chains. This may help diagnose convergence issues
  and allows us to exploit vectorization to draw samples more quickly, although
  warmup still requires the same number of sequential steps.

  ```python
  from matplotlib import pylab as plt

  samples, kernel_results = tfp.sts.fit_with_hmc(
    model, observed_time_series, chain_batch_shape=[10])

  print("acceptance rate: {}".format(
    np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0)))

  # Plot the sampled traces for each parameter. If the chains have mixed, their
  # traces should all cover the same region of state space, frequently crossing
  # over each other.
  for (param, param_draws) in zip(model.parameters, samples):
    if param.prior.event_shape.ndims > 0:
      print("Only plotting traces for scalar parameters, skipping {}".format(
        param.name))
      continue
    plt.figure(figsize=[10, 4])
    plt.title(param.name)
    plt.plot(param_draws.numpy())
    plt.ylabel(param.name)
    plt.xlabel("HMC step")

  # Combining the samples from multiple chains into a single dimension allows
  # us to easily pass sampled parameters to downstream forecasting methods.
  combined_samples = [np.reshape(param_draws,
                                 [-1] + list(param_draws.shape[2:]))
                      for param_draws in samples]
  ```

  For greater flexibility, you may prefer to implement your own sampler using
  the TensorFlow Probability primitives in `tfp.mcmc`. The following recipe
  constructs a basic HMC sampler, using a `TransformedTransitionKernel` to
  incorporate constraints on the parameter space.

  ```python
  transformed_hmc_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
          inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
              target_log_prob_fn=model.joint_distribution(
                  observed_time_series).log_prob,
              step_size=step_size,
              num_leapfrog_steps=num_leapfrog_steps,
              state_gradients_are_stopped=True,
              seed=seed),
          num_adaptation_steps=int(0.8 * num_warmup_steps)),
      bijector=[param.bijector for param in model.parameters])

  # Initialize from a Uniform[-2, 2] distribution in unconstrained space.
  initial_state = [tfp.sts.sample_uniform_initial_state(
    param, return_constrained=True) for param in model.parameters]

  samples, kernel_results = tfp.mcmc.sample_chain(
    kernel=transformed_hmc_kernel,
    num_results=num_results,
    current_state=initial_state,
    num_burnin_steps=num_warmup_steps)
  ```

  #### References

  [1]: Radford Neal. MCMC Using Hamiltonian Dynamics. _Handbook of Markov Chain
       Monte Carlo_, 2011. https://arxiv.org/abs/1206.1901
  [2]  M.J. Betancourt, Simon Byrne, and Mark Girolami. Optimizing The
       Integrator Step Size for Hamiltonian Monte Carlo.
       https://arxiv.org/abs/1411.6669

  """
  with tf.name_scope(name or 'fit_with_hmc') as name:
    init_seed, vi_seed, hmc_seed = samplers.split_seed(
        seed=seed, n=3, salt='StructuralTimeSeries_fit_with_hmc')

    # Multiple chains manifest as an extra param batch dimension, so the
    # observed series is padded to broadcast against it.
    observed_time_series = sts_util.pad_batch_dimension_for_multiple_chains(
        observed_time_series, model, chain_batch_shape=chain_batch_shape)
    target_log_prob_fn = model.joint_distribution(
        observed_time_series).log_prob

    # Initialize state and step sizes from a variational posterior if not
    # specified.
    if initial_step_size is None or initial_state is None:
      variational_posterior = build_factored_surrogate_posterior(
          model, batch_shape=chain_batch_shape, seed=init_seed)

      if variational_optimizer is None:
        variational_optimizer = tf1.train.AdamOptimizer(
            learning_rate=0.1)  # TODO(b/137299119) Replace with TF2 optimizer.
      loss_curve = vi.fit_surrogate_posterior(
          target_log_prob_fn,
          variational_posterior,
          sample_size=variational_sample_size,
          num_steps=num_variational_steps,
          optimizer=variational_optimizer,
          seed=vi_seed)

      with tf.control_dependencies([loss_curve]):
        if initial_state is None:
          # Seed the initial-state draw so the chain's starting point is
          # reproducible under `seed`. A fresh seed is split off `init_seed`
          # (rather than reusing it verbatim) so this draw is not correlated
          # with the surrogate's initialization; the other derived seeds are
          # left unchanged.
          _, initial_state_seed = samplers.split_seed(
              init_seed, salt='fit_with_hmc_initial_state')
          posterior_sample = variational_posterior.sample(
              seed=initial_state_seed)
          initial_state = [posterior_sample[p.name] for p in model.parameters]

        # Set step sizes using the unconstrained variational distribution.
        if initial_step_size is None:
          q_dists_by_name, _ = (
              variational_posterior.distribution.sample_distributions())
          initial_step_size = [
              q_dists_by_name[p.name].stddev() for p in model.parameters]

    # Run HMC to sample from the posterior on parameters.
    @tf.function(autograph=False)
    def run_hmc():
      return mcmc.sample_chain(
          num_results=num_results,
          current_state=initial_state,
          num_burnin_steps=num_warmup_steps,
          kernel=mcmc.DualAveragingStepSizeAdaptation(
              inner_kernel=mcmc.TransformedTransitionKernel(
                  inner_kernel=mcmc.HamiltonianMonteCarlo(
                      target_log_prob_fn=target_log_prob_fn,
                      step_size=initial_step_size,
                      num_leapfrog_steps=num_leapfrog_steps,
                      state_gradients_are_stopped=True),
                  bijector=[param.bijector for param in model.parameters]),
              num_adaptation_steps=int(num_warmup_steps * 0.8)),
          seed=hmc_seed)

    samples, kernel_results = run_hmc()

    return samples, kernel_results