def bootstrap_results(self, init_state):
  """Initializes the dual averaging adaptation state.

  Bootstraps the inner kernel at `init_state` and constructs the initial
  `DualAveragingStepSizeAdaptationResults`: zeroed per-part error sums and
  log averaging steps, plus the per-part log shrinkage targets.

  `shrinkage_target` may now be a single value (structure-broadcast to every
  step size part) or a list with one entry per step size part, matching the
  contract of `_bootstrap_from_inner_results`. A scalar or `None` value
  behaves exactly as before, so existing callers are unaffected.

  Args:
    init_state: Initial chain state (`Tensor` or nested structure of
      `Tensor`s) passed to the inner kernel's `bootstrap_results`.

  Returns:
    A `DualAveragingStepSizeAdaptationResults` namedtuple wrapping the inner
    kernel's bootstrap results.

  Raises:
    ValueError: If `shrinkage_target` is a list whose length is neither 1
      nor the number of step size parts.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'dual_averaging_step_size_adaptation',
                          'bootstrap_results')):
    inner_results = self.inner_kernel.bootstrap_results(init_state)
    step_size = self.step_size_getter_fn(inner_results)
    log_accept_prob = self.log_accept_prob_getter_fn(inner_results)
    state_parts = tf.nest.flatten(init_state)
    step_size_parts = tf.nest.flatten(step_size)

    # Structure-broadcast the shrinkage target across step size parts,
    # mirroring `_bootstrap_from_inner_results`.
    if self._parameters['shrinkage_target'] is None:
      shrinkage_target_parts = [None] * len(step_size_parts)
    else:
      shrinkage_target_parts = tf.nest.flatten(
          self._parameters['shrinkage_target'])
      if len(shrinkage_target_parts) not in [1, len(step_size_parts)]:
        raise ValueError(
            '`shrinkage_target` should be a Tensor or list of tensors of '
            'same length as `step_size`. Found len(`step_size`) = {} and '
            'len(shrinkage_target) = {}'.format(
                len(step_size_parts), len(shrinkage_target_parts)))
      if len(shrinkage_target_parts) < len(step_size_parts):
        # Guaranteed length-1 by the check above; replicate per part.
        shrinkage_target_parts *= len(step_size_parts)

    dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
    error_sum, log_averaging_step, log_shrinkage_target = [], [], []
    for state_part, step_size_part, shrinkage_target_part in zip(
        state_parts, step_size_parts, shrinkage_target_parts):
      # Reduce the chain dimensions of the acceptance probability away so
      # the result can broadcast against this step size part.
      num_reduce_dims = prefer_static.minimum(
          prefer_static.rank(log_accept_prob),
          prefer_static.rank(state_part) -
          prefer_static.rank(step_size_part))
      reduced_log_accept_prob = reduce_logmeanexp(
          log_accept_prob,
          axis=prefer_static.range(num_reduce_dims))
      # reduced_log_accept_prob must broadcast into step_size_part on the
      # left, so do an additional reduction over dimensions where their
      # shapes differ.
      reduce_indices = get_differing_dims(
          reduced_log_accept_prob, step_size_part)
      reduced_log_accept_prob = reduce_logmeanexp(
          reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
      error_sum.append(tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
      log_averaging_step.append(tf.zeros_like(step_size_part, dtype=dtype))
      if shrinkage_target_part is None:
        # Default: shrink toward 10x the initial step size (cf. the dual
        # averaging scheme of Hoffman & Gelman, 2014).
        log_shrinkage_target.append(
            float(np.log(10.)) + tf.math.log(step_size_part))
      else:
        log_shrinkage_target.append(
            tf.math.log(tf.cast(shrinkage_target_part, dtype)))

    return DualAveragingStepSizeAdaptationResults(
        inner_results=inner_results,
        step=tf.constant(0, dtype=tf.int32),
        target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                   log_accept_prob.dtype),
        log_shrinkage_target=log_shrinkage_target,
        exploration_shrinkage=tf.cast(
            self.parameters['exploration_shrinkage'], dtype),
        step_count_smoothing=tf.cast(
            self.parameters['step_count_smoothing'], dtype),
        decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
        error_sum=error_sum,
        log_averaging_step=log_averaging_step,
        new_step_size=step_size)
def _bootstrap_from_inner_results(self, init_state, inner_results):
  """Builds initial adaptation results from bootstrapped inner results.

  Flattens the state and step size structures, validates and
  structure-broadcasts `shrinkage_target` across the step size parts, and
  assembles a fresh `DualAveragingStepSizeAdaptationResults` with zeroed
  accumulators.

  Args:
    init_state: Initial chain state (`Tensor` or nested structure of
      `Tensor`s).
    inner_results: Bootstrap results of the inner kernel.

  Returns:
    A `DualAveragingStepSizeAdaptationResults` namedtuple.

  Raises:
    ValueError: If `shrinkage_target` is a list whose length is neither 1
      nor the number of step size parts.
  """
  step_size = self.step_size_getter_fn(inner_results)
  log_accept_prob = self.log_accept_prob_getter_fn(inner_results)
  state_parts = tf.nest.flatten(init_state)
  step_size_parts = tf.nest.flatten(step_size)

  shrinkage_target = self._parameters['shrinkage_target']
  if shrinkage_target is None:
    shrinkage_parts = [None] * len(step_size_parts)
  else:
    shrinkage_parts = tf.nest.flatten(shrinkage_target)
    if len(shrinkage_parts) not in [1, len(step_size_parts)]:
      raise ValueError(
          '`shrinkage_target` should be a Tensor or list of tensors of '
          'same length as `step_size`. Found len(`step_size`) = {} and '
          'len(shrinkage_target) = {}'.format(
              len(step_size_parts), len(shrinkage_parts)))
    if len(shrinkage_parts) < len(step_size_parts):
      # Length must be exactly 1 here; replicate it for every part.
      shrinkage_parts *= len(step_size_parts)

  dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
  error_sum = []
  log_averaging_step = []
  log_shrinkage_target = []
  for state_part, step_part, shrink_part in zip(
      state_parts, step_size_parts, shrinkage_parts):
    # First reduce away the leading (chain) dimensions of the acceptance
    # probability that the step size part does not carry.
    reduce_rank = ps.minimum(
        ps.rank(log_accept_prob),
        ps.rank(state_part) - ps.rank(step_part))
    reduced_lap = reduce_logmeanexp(
        log_accept_prob,
        axis=ps.range(reduce_rank),
        experimental_named_axis=self.experimental_reduce_chain_axis_names)
    # The result must broadcast into the step size part on the left, so
    # additionally reduce any dimensions where their shapes differ.
    mismatched_dims = get_differing_dims(reduced_lap, step_part)
    reduced_lap = reduce_logmeanexp(
        reduced_lap, axis=mismatched_dims, keepdims=True)

    error_sum.append(tf.zeros_like(reduced_lap, dtype=dtype))
    log_averaging_step.append(tf.zeros_like(step_part, dtype=dtype))
    if shrink_part is None:
      # Default shrinkage target: 10x the initial step size, in log space.
      part_target = float(np.log(10.)) + tf.math.log(step_part)
    else:
      part_target = tf.math.log(tf.cast(shrink_part, dtype))
    log_shrinkage_target.append(part_target)

  return DualAveragingStepSizeAdaptationResults(
      inner_results=inner_results,
      step=tf.constant(0, dtype=tf.int32),
      target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                 log_accept_prob.dtype),
      log_shrinkage_target=log_shrinkage_target,
      exploration_shrinkage=tf.cast(
          self.parameters['exploration_shrinkage'], dtype),
      step_count_smoothing=tf.cast(
          self.parameters['step_count_smoothing'], dtype),
      decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
      error_sum=error_sum,
      log_averaging_step=log_averaging_step,
      new_step_size=step_size,
      num_adaptation_steps=tf.cast(self.num_adaptation_steps,
                                   dtype=tf.int32))
def _one_step_part(self,
                   step_size,
                   state,
                   error_sum,
                   log_averaging_step,
                   log_shrinkage_target,
                   log_accept_prob_rank=None,
                   log_accept_prob=None,
                   target_accept_prob=None,
                   previous_kernel_results=None):
  """Compute new step sizes for each step size part.

  If step size part has smaller rank than the corresponding state part, then
  the difference is averaged away in the log accept prob.

  Example:

    state_part has shape      [2, 3, 4, 5]
    step_size_part has shape     [1, 4, 1]
    log_accept_prob has shape [2, 3, 4]

  Since step size has 1 rank fewer than the state, we reduce away the leading
  dimension of `log_accept_prob` to get a Tensor with shape [3, 4]. Next,
  since `log_accept_prob` must broadcast into step_size_part on the left, we
  reduce the dimensions where their shapes differ, to get a Tensor with shape
  [1, 4], which now is compatible with the leading dimensions of
  step_size_part.

  There is a subtlety here in that `step_size_parts` might be a length-1
  list, which means that we'll be "structure-broadcasting" it for all the
  state parts (see logic in, e.g., hmc.py). In this case we must assume that
  the lone step size provided broadcasts with the event dims of each state
  part. This means that either step size has no dimensions corresponding to
  chain dimensions, or all states are of the same shape. For the former, we
  want to reduce over all chain dimensions. For the later, we want to use
  the same logic as in the non-structure-broadcasted case.

  It turns out we can compute the reduction dimensions for both cases
  uniformly by taking the rank of any state part. This obviously works in
  the second case (where all state ranks are the same). In the first case,
  all state parts have the rank L + D_i + B, where L is the rank of
  log_accept_prob, D_i is the non-shared dimensions amongst all states, and
  B are the shared dimensions of all the states, which are equal to the step
  size. When we subtract B, we will always get a number >= L, which means
  we'll get the full reduction we want.

  Args:
    step_size: Previous step's step_size.
    state: Previous step's state value.
    error_sum: Previous step's error accumulator.
    log_averaging_step: Previous step's log_averaging_step.
    log_shrinkage_target: Floating point scalar `Tensor`. Logarithm of value
      the exploration step size is biased towards.
    log_accept_prob_rank: Rank of log_accept_prob.
    log_accept_prob: Floating point `Tensor`. Log of the acceptance
      probability (or probabilities) produced by the previous step; it is
      exponentiated below and compared against `target_accept_prob`.
    target_accept_prob: A floating point `Tensor` representing desired
      acceptance probability. Must be a positive number less than 1.
    previous_kernel_results: Results struct from previous step.

  Returns:
    new_step_size: Updated `step_size`.
    new_log_averaging_step: Updated `log_averaging_step`.
    new_error_sum: Updated `error_sum`.
  """
  # Average away the chain dimensions of log_accept_prob that the step size
  # does not carry (see the rank discussion in the docstring).
  num_reduce_dims = ps.minimum(
      log_accept_prob_rank,
      (ps.rank(state) - ps.rank(step_size)))
  reduced_log_accept_prob = self.reduce_fn(
      log_accept_prob,
      axis=ps.range(num_reduce_dims),
      keepdims=False,
      experimental_named_axis=self.experimental_reduce_chain_axis_names)
  # reduced_log_accept_prob must broadcast into step_size on the
  # left, so we do an additional reduction over dimensions where their
  # shapes differ.
  reduce_indices = get_differing_dims(reduced_log_accept_prob, step_size)
  reduced_log_accept_prob = self.reduce_fn(
      reduced_log_accept_prob,
      axis=reduce_indices,
      keepdims=True,
      experimental_named_axis=self.experimental_reduce_chain_axis_names)
  # Accumulate (target - actual) acceptance probability error.
  new_error_sum = (error_sum +
                   target_accept_prob -
                   tf.math.exp(reduced_log_accept_prob))
  # Right-pad the error sum with size-1 dims so it broadcasts against
  # log_shrinkage_target, which may have higher rank.
  num_ones_to_pad = ps.maximum(
      ps.rank(log_shrinkage_target) - ps.rank(new_error_sum), 0)
  new_error_sum_extend = tf.reshape(
      new_error_sum,
      shape=ps.pad(
          ps.shape(new_error_sum),
          paddings=[[0, num_ones_to_pad]],
          constant_values=1))

  step_count_smoothing = previous_kernel_results.step_count_smoothing
  # Float-typed step count for the dual averaging arithmetic below
  # (cf. the dual averaging scheme of Hoffman & Gelman, 2014).
  step = tf.cast(
      previous_kernel_results.step, step_count_smoothing.dtype) + 1.
  soft_t = step_count_smoothing + step

  new_log_step = (
      log_shrinkage_target -
      ((tf.cast(new_error_sum_extend, step.dtype) * tf.math.sqrt(step)) /
       (soft_t * previous_kernel_results.exploration_shrinkage)))

  eta = step**(-previous_kernel_results.decay_rate)
  new_log_averaging_step = (eta * new_log_step +
                            (1. - eta) * log_averaging_step)

  # - If still adapting, return an exploring step size,
  # - If just finished, return the averaging step size
  # - Otherwise, do not update
  num_adaptation_steps = previous_kernel_results.num_adaptation_steps
  # Recompute `step` as an integer so the comparisons against
  # `num_adaptation_steps` below are exact.
  step = previous_kernel_results.step + 1
  new_step_size = tf.where(
      step < num_adaptation_steps,
      tf.math.exp(new_log_step),
      tf.where(step > num_adaptation_steps,
               step_size,
               tf.math.exp(new_log_averaging_step)))
  # Freeze the averaging accumulator once adaptation has finished.
  new_log_averaging_step = tf.where(step > num_adaptation_steps,
                                    log_averaging_step,
                                    new_log_averaging_step)
  return new_step_size, new_log_averaging_step, new_error_sum