def _init_momentum(initial_transformed_position, *, batch_shape):
  """Initialize momentum so trace_fn can be concatenated.

  Every part of the initial position is paired with an all-ones variance of
  the same shape, and both are handed to
  `preconditioning_utils.make_momentum_distribution`.

  Args:
    initial_transformed_position: Iterable of state parts; forwarded as
      `state_parts`.
    batch_shape: Forwarded unchanged as `batch_shape` (keyword-only).

  Returns:
    The result of `preconditioning_utils.make_momentum_distribution`.
  """
  unit_variances = []
  for part in initial_transformed_position:
    unit_variances.append(ps.ones_like(part))
  return preconditioning_utils.make_momentum_distribution(
      state_parts=initial_transformed_position,
      batch_shape=batch_shape,
      running_variance_parts=unit_variances)
def _init_momentum(initial_transformed_position):
  """Initialize momentum so trace_fn can be concatenated.

  The size of the last axis of the position is used as the length of a single
  all-ones running-variance part, and the momentum distribution is built with
  `batch_ndims=1`.

  Args:
    initial_transformed_position: Initial position; it is flattened with
      `tf.nest.flatten` before being passed as `state_parts`.

  Returns:
    The result of `preconditioning_utils.make_momentum_distribution`.
  """
  trailing_dim = ps.shape(initial_transformed_position)[-1]
  flat_position = tf.nest.flatten(initial_transformed_position)
  unit_variance = ps.ones(trailing_dim)
  return preconditioning_utils.make_momentum_distribution(
      state_parts=flat_position,
      batch_ndims=1,
      running_variance_parts=[unit_variance])
def bootstrap_results(self, init_state):
  """Builds initial kernel results with a variance-based momentum installed.

  Args:
    init_state: Initial chain state, passed to the inner kernel's
      `bootstrap_results` and flattened to obtain the state parts.

  Returns:
    A `DiagonalMassMatrixAdaptationResults` whose `inner_results` carry the
    new momentum distribution and whose `running_variance` is the list of
    initial running-variance parts.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'diagonal_mass_matrix_adaptation', 'bootstrap_results')
  with tf.name_scope(scope_name):
    # Normalize the user-supplied running variance to a list of parts.
    initial_variance = self.initial_running_variance
    variance_parts = (
        [initial_variance]
        if isinstance(initial_variance, sample_stats.RunningVariance)
        else list(initial_variance))
    diags = [part.variance() for part in variance_parts]

    # Step inner results.
    inner_results = self.inner_kernel.bootstrap_results(init_state)

    # Set the momentum; the batch shape is read off the target log prob.
    target_log_prob = unnest.get_innermost(inner_results, 'target_log_prob')
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = preconditioning_utils.make_momentum_distribution(
        tf.nest.flatten(init_state), batch_shape, diags)
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)

    # When the inner results expose `proposed_results`, keep them in sync.
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      inner_results = unnest.replace_innermost(
          inner_results,
          proposed_results=proposed._replace(
              momentum_distribution=momentum_distribution))

    return DiagonalMassMatrixAdaptationResults(
        inner_results=inner_results,
        running_variance=variance_parts)
def test_momentum_dists(self):
  """Checks a momentum distribution can be made, updated, and evaluated."""
  batch_shape = [13, 5]
  # Three state parts sharing the batch shape, with event shapes [3], [] and
  # [2, 4] respectively.
  state_parts = [
      tf.ones(batch_shape + event_shape)
      for event_shape in ([3], [], [2, 4])
  ]

  def _sum_over_batch(part):
    return tf.reduce_sum(part, (0, 1))

  momentum = pu.make_momentum_distribution(state_parts, batch_shape)
  reduced_parts = tf.nest.map_structure(_sum_over_batch, state_parts)
  momentum = pu.update_momentum_distribution(momentum, reduced_parts)
  self.evaluate(tf.nest.flatten(momentum, expand_composites=True))
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`.

  Args:
    init_state: Initial chain state; a non-nested value is wrapped in a
      single-element list before use.

  Returns:
    A `PreconditionedNUTSKernelResults` populated from the target log prob
    and its gradients at `init_state`, zero-initialized diagnostics, the
    momentum distribution, and a placeholder seed.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)
    log_prob_dtype = log_prob.dtype

    # Confirm that the step size is compatible with the state parts; the
    # returned value itself is not needed here.
    _prepare_step_size(self.step_size, log_prob_dtype, len(init_state))

    batch_shape = ps.shape(log_prob)
    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      momentum_distribution = pu.make_momentum_distribution(
          state_parts,
          batch_shape,
          shard_axis_names=self.experimental_shard_axis_names)
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, batch_shape)
    momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

    def _as_step_size(value):
      return tf.convert_to_tensor(
          value, dtype=log_prob_dtype, name='step_size')

    def _zeros(name, dtype=None):
      return tf.zeros_like(log_prob, dtype=dtype, name=name)

    return PreconditionedNUTSKernelResults(
        target_log_prob=log_prob,
        grads_target_log_prob=log_prob_grads,
        step_size=tf.nest.map_structure(_as_step_size, self.step_size),
        log_accept_ratio=_zeros('log_accept_ratio'),
        leapfrogs_taken=_zeros('leapfrogs_taken', TREE_COUNT_DTYPE),
        is_accepted=_zeros('is_accepted', tf.bool),
        reach_max_depth=_zeros('reach_max_depth', tf.bool),
        has_divergence=_zeros('has_divergence', tf.bool),
        energy=compute_hamiltonian(
            log_prob, momentum_parts, momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Callable passed to
      `mcmc_util.maybe_call_fn_and_grads`.
    state: Current state; converted to a list of state parts.
    step_size: One step size, or one per state part.
    momentum_distribution: Momentum distribution, or `None` to build the
      default from the state parts.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    maybe_expand: If true, state parts and step sizes are always returned as
      lists.
    state_gradients_are_stopped: If true, `tf.stop_gradient` is applied to
      every state part.
    experimental_shard_axis_names: Forwarded as `shard_axis_names` when the
      default momentum distribution is built.

  Returns:
    A list `[state_parts, step_sizes, momentum_distribution,
    target_log_prob, grads_target_log_prob]`. The first two entries are
    unwrapped to their single element when `maybe_expand` is false and
    `state` is not list-like.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts.
  """
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = list(map(tf.stop_gradient, state_parts))

  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None.
  batch_shape = ps.shape(target_log_prob)
  if momentum_distribution is None:
    momentum_distribution = pu.make_momentum_distribution(
        state_parts, batch_shape,
        shard_axis_names=experimental_shard_axis_names)
  momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
      momentum_distribution, batch_shape)

  # A lone step size is shared by every state part.
  num_parts = len(state_parts)
  if len(step_sizes) == 1:
    step_sizes = step_sizes * num_parts
  if len(step_sizes) != num_parts:
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  keep_as_list = maybe_expand or mcmc_util.is_list_like(state)
  if keep_as_list:
    returned_state, returned_step_sizes = state_parts, step_sizes
  else:
    returned_state, returned_step_sizes = state_parts[0], step_sizes[0]

  return [
      returned_state,
      returned_step_sizes,
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def bootstrap_results(self, init_state):
  """Builds initial kernel results, setting the momentum when appropriate.

  Args:
    init_state: Initial chain state, passed to the inner kernel's
      `bootstrap_results` and flattened to obtain the state parts.

  Returns:
    The results from `_bootstrap_from_inner_results`. When
    `num_estimation_steps` is `None`, their `inner_results` are replaced by a
    copy carrying a momentum distribution built from the running variance.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'diagonal_mass_matrix_adaptation', 'bootstrap_results')
  with tf.name_scope(scope_name):
    # Step inner results, then bootstrap our own results from them.
    bootstrapped_inner = self.inner_kernel.bootstrap_results(init_state)
    results = self._bootstrap_from_inner_results(
        init_state, bootstrapped_inner)

    # We only update the momentum at the end of adaptation phase, so we do
    # not need to set the momentum here.
    if self.num_estimation_steps is not None:
      return results

    # Set the momentum.
    diags = [part.variance() for part in results.running_variance]
    inner_results = results.inner_results
    target_log_prob = unnest.get_innermost(inner_results, 'target_log_prob')
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = preconditioning_utils.make_momentum_distribution(
        tf.nest.flatten(init_state),
        batch_shape,
        diags,
        shard_axis_names=self.experimental_shard_axis_names)
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)

    # When the inner results expose `proposed_results`, keep them in sync.
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      inner_results = unnest.replace_innermost(
          inner_results,
          proposed_results=proposed._replace(
              momentum_distribution=momentum_distribution))

    return results._replace(inner_results=inner_results)
def bootstrap_results(self, init_state):
  """Extends the parent's bootstrap results with a momentum distribution.

  Args:
    init_state: Initial chain state.

  Returns:
    An `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` holding
    every field of the parent's results plus `momentum_distribution`.
  """
  scope_name = mcmc_util.make_name(self.name, 'phmc', 'bootstrap_results')
  with tf.name_scope(scope_name):
    parent_results = super(
        UncalibratedPreconditionedHamiltonianMonteCarlo,
        self).bootstrap_results(init_state)

    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    target_log_prob = self.target_log_prob_fn(*state_parts)

    use_supplied_momentum = (
        self._store_parameters_in_results
        and self.momentum_distribution is not None)
    if use_supplied_momentum:
      momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
          self.momentum_distribution, ps.shape(target_log_prob))
    else:
      # Either parameters are not stored in the results or no distribution
      # was supplied: fall back to the default one.
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(target_log_prob))

    return UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
        **parent_results._asdict(),  # pylint: disable=protected-access
        momentum_distribution=momentum_distribution)
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`.

  Args:
    init_state: Initial chain state; a non-nested value is wrapped in a
      single-element list before use.

  Returns:
    A `PreconditionedNUTSKernelResults` populated from the target log prob
    and its gradients at `init_state`, zero-filled velocity/state memory,
    zero-initialized diagnostics, the momentum distribution, and a
    placeholder seed.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts in `init_state`.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    # Padding the step_size so it is compatible with the states.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
    if len(step_size) != len(init_state):
      raise ValueError(
          'Expected either one step size or {} (size of '
          '`init_state`), but found {}'.format(
              len(init_state), len(step_size)))
    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    current_target_log_prob, current_grads_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(
            self.target_log_prob_fn, state_parts))

    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(current_target_log_prob))
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(current_target_log_prob))
    # Draw with a fixed seed so the bootstrapped results (in particular
    # `energy`) are deterministic. The draw is otherwise only used for the
    # shapes and dtypes of the memory buffers below.
    momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

    # Leading dimension of every memory buffer; it is the same for all
    # buffers, so compute it once.
    memory_size = max(self._write_instruction) + 1

    def _init(shape_and_dtype):
      """Allocate zero-filled buffers for storing state and velocity."""
      return [
          ps.zeros(ps.concat([[memory_size], s], axis=0), dtype=d)
          for (s, d) in shape_and_dtype
      ]

    def get_shapes_and_dtypes(parts):
      return [(ps.shape(part), part.dtype) for part in parts]

    velocity_state_memory = VelocityStateSwap(
        velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
        state_swap=_init(get_shapes_and_dtypes(init_state)))

    return PreconditionedNUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        velocity_state_memory=velocity_state_memory,
        step_size=step_size,
        log_accept_ratio=tf.zeros_like(
            current_target_log_prob, name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(
            current_target_log_prob,
            dtype=TREE_COUNT_DTYPE,
            name='leapfrogs_taken'),
        is_accepted=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='is_accepted'),
        reach_max_depth=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='reach_max_depth'),
        has_divergence=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='has_divergence'),
        energy=compute_hamiltonian(
            current_target_log_prob, momentum_parts, momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Steps the inner kernel and folds the new state into the variances.

  Args:
    current_state: Current chain state.
    previous_kernel_results: Results from the previous step; provides the
      `running_variance` parts and the `inner_results`.
    seed: Optional seed; forwarded to the inner kernel only when not `None`.

  Returns:
    A pair `(new_state, new_kernel_results)`, where the results carry the
    inner kernel's new results and the updated running-variance parts.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'diagonal_mass_matrix_adaptation', 'one_step')
  with tf.name_scope(scope_name):
    variance_parts = previous_kernel_results.running_variance
    diags = [part.variance() for part in variance_parts]

    # Set the momentum.
    target_log_prob = unnest.get_innermost(
        previous_kernel_results, 'target_log_prob')
    batch_ndims = ps.rank(target_log_prob)
    new_momentum_distribution = (
        preconditioning_utils.make_momentum_distribution(
            tf.nest.flatten(current_state), batch_ndims, diags))
    inner_results = self.momentum_distribution_setter_fn(
        previous_kernel_results.inner_results, new_momentum_distribution)

    # Step the inner kernel.
    if seed is None:
      new_state, new_inner_results = self.inner_kernel.one_step(
          current_state, inner_results)
    else:
      new_state, new_inner_results = self.inner_kernel.one_step(
          current_state, inner_results, seed=seed)

    def _updated_variance(variance_part, diag, state_part):
      """Records `state_part` as new observations in `variance_part`."""
      # The variance calculation may be partially batched across chains
      # (some, all, or none of the chains may share the estimated mass
      # matrix).
      #
      # For example, say
      #
      #   state_part has shape [2, 3, 4] + [5, 6]  (batch + event)
      #   variance_part has shape [4] + [5, 6]
      #   log_prob has shape [2, 3, 4]
      #
      # i.e., there is a batch of chains of shape [2, 3, 4] and 4 mass
      # matrices, each shared across a [2, 3]-batch of chains. This division
      # is inferred from the shapes of the state part, the log_prob, and the
      # user-provided initial running variances.
      #
      # Until RunningVariance supports rank > 1 chunking, the states that go
      # into updating the variance estimates must be flattened. In the
      # example, `state_part` is reshaped to `[6, 4, 5, 6]` and fed to
      # `RunningVariance.update(state_part, axis=0)`, recording 6 new
      # observations. `RunningVariance.variance()` then has shape
      # `[4, 5, 6]`, and the resulting momentum distribution has batch shape
      # `[2, 3, 4]` and event_shape `[5, 6]`, matching the state_part.
      num_reduce_dims = ps.rank(state_part) - ps.rank(diag)
      full_shape = ps.shape(state_part)
      # This adds a leading 1 when num_reduce_dims == 0, and collapses all
      # the leading dimensions into a single one otherwise.
      collapsed_lead = ps.reduce_prod(full_shape[:num_reduce_dims])
      flattened_shape = ps.concat(
          [[collapsed_lead], full_shape[num_reduce_dims:]], axis=0)
      flattened_state = ps.reshape(state_part, flattened_shape)
      # `axis=0` removes the leading dimension introduced by the reshape, so
      # the updated variance part has the correct shape again.
      return variance_part.update(flattened_state, axis=0)

    new_variance_parts = [
        _updated_variance(variance_part, diag, state_part)
        for variance_part, diag, state_part in zip(
            variance_parts, diags, tf.nest.flatten(new_state))
    ]

    new_kernel_results = previous_kernel_results._replace(
        inner_results=new_inner_results,
        running_variance=new_variance_parts)
    return new_state, new_kernel_results