def _transpose_around_bijector_fn(self,
                                  bijector_fn,
                                  arg,
                                  src_event_ndims,
                                  dest_event_ndims=None,
                                  fn_reduces_event=False,
                                  **kwargs):
  """Applies `bijector_fn` with the `sample_shape` axes moved out of the way.

  The axes that correspond to `self.sample_shape` are transposed to the left
  of the batch axes, `bijector_fn` is applied, and those axes are then
  transposed back so that they sit between the batch axes and the remaining
  event axes again. Forward, inverse and the {F/I}LDJ computations are
  similar enough that they can all share this routine.

  Args:
    bijector_fn: Callable invoked as `bijector_fn(arg, **kwargs)` on the
      transposed argument.
    arg: `Tensor` to be transformed.
    src_event_ndims: Expected event rank of `arg`, not counting the axes of
      `self.sample_shape`.
    dest_event_ndims: Expected event rank of the result, not counting the
      axes of `self.sample_shape`. Ignored (treated as `0`) when
      `fn_reduces_event` is true.
    fn_reduces_event: Python `bool`. When true, the result is treated as
      having no remaining event axes, and it is left-padded with singleton
      axes if it has fewer axes than batch plus `sample_shape` axes.
    **kwargs: Forwarded unchanged to `bijector_fn`.

  Returns:
    The result of `bijector_fn`, transposed so the `sample_shape` axes follow
    the batch axes.
  """

  def _left_pad_with_ones(tensor, slack):
    # Prepends `max(0, -slack)` singleton axes so that the transposes below
    # always find the axes they expect.
    return tf.reshape(
        tensor,
        shape=ps.pad(
            ps.shape(tensor),
            paddings=[[ps.maximum(0, -slack), 0]],
            constant_values=1))

  n_batch = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                               self.distribution.batch_shape)
  n_extra_sample = ps.rank_from_shape(self.sample_shape)

  # (1) Expand arg's dims.
  slack = ps.rank(arg) - n_batch - n_extra_sample - src_event_ndims
  arg = _left_pad_with_ones(arg, slack)
  n_arg = ps.rank(arg)
  n_sample = ps.maximum(0, slack)

  # (2) Transpose arg's dims. Incoming layout is
  # [sample, batch, extra_sample, event]; the bijector sees
  # [sample, extra_sample, batch, event].
  batch_begin = n_sample
  extra_begin = batch_begin + n_batch
  event_begin = extra_begin + n_extra_sample
  to_bijector_perm = ps.concat([
      ps.range(0, n_sample),
      ps.range(extra_begin, event_begin),
      ps.range(batch_begin, extra_begin),
      ps.range(event_begin, n_arg),
  ], axis=0)
  arg = tf.transpose(arg, perm=to_bijector_perm)

  # (3) Apply underlying bijector.
  result = bijector_fn(arg, **kwargs)

  # (4) Transpose sample_shape from the sample to the event shape.
  if fn_reduces_event:
    dest_event_ndims = 0
  slack = ps.rank(result) - n_batch - n_extra_sample - dest_event_ndims
  if fn_reduces_event:
    # In some cases, fn may reduce event too far, i.e. ildj may return a
    # scalar `0.`, which won't work with the transpose we do below.
    result = _left_pad_with_ones(result, slack)
  n_result = ps.rank(result)
  n_sample = ps.maximum(0, slack)

  # Result layout is [sample, extra_sample, batch, event]; restore
  # [sample, batch, extra_sample, event].
  extra_begin = n_sample
  batch_begin = extra_begin + n_extra_sample
  event_begin = batch_begin + n_batch
  from_bijector_perm = ps.concat([
      ps.range(0, n_sample),
      ps.range(batch_begin, event_begin),
      ps.range(extra_begin, batch_begin),
      ps.range(event_begin, n_result),
  ], axis=0)
  return tf.transpose(result, perm=from_bijector_perm)
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Estimates the auto-correlation of `x` along a single axis.

  For a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  `x` is treated (along one axis) as a finite sub-sequence of a realization
  of (WSS) `X`. After extending `x` from length `L` to `inf` by zero padding,
  the estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Because `mu` is an imperfect estimate of `E{ X[0] }`, and the divisor is
  `len(x) - m` rather than `len(x) - m - 1`, the estimate carries a slight
  bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m]. Since F[x] is the DFT of x, this leads to a zero-padding and
  # FFT/IFFT based estimate of RXX, a special case of the Wiener-Khinchin
  # Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')
    dtype = x.dtype

    # The FFT op works on the rightmost dim, so rotate `axis` there.
    ndims = prefer_static.rank(x)
    if axis < 0:
      axis = ndims + axis
    shift = ndims - 1 - axis

    # With T 'time' steps along `axis`, series.shape = B + [T], where B is
    # the batch shape of the rotated tensor.
    series = distribution_util.rotate_transpose(x, shift)
    if center:
      series = series - tf.reduce_mean(series, axis=-1, keepdims=True)

    # num_steps = N / 2 from the explanation above: the length along `axis`.
    num_steps = prefer_static.shape(series)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * num_steps, which equals
    # 2**(ceil(Log_2(2 * num_steps))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    num_steps_f64 = tf.cast(num_steps, np.float64)
    log2_padded_len = tf.math.ceil(
        tf.math.log(num_steps_f64 * 2) / np.log(2.))
    padded_len = tf.pow(np.float64(2.), log2_padded_len)
    num_zeros = tf.cast(padded_len - num_steps_f64, np.int32)

    # padded.shape = series.shape[:-1] + [T + num_zeros] = B + [T + num_zeros]
    padded = distribution_util.pad(
        series, axis=-1, back=True, count=num_zeros)

    if not dtype_util.is_complex(dtype):
      if not dtype_util.is_floating(dtype):
        raise TypeError(
            'Argument x must have either float or complex dtype'
            ' found: {}'.format(dtype))
      real_zero = dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.)
      padded = tf.complex(padded, real_zero)

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    spectrum = tf.signal.fft(padded)
    power_spectrum = spectrum * tf.math.conj(spectrum)
    # lagged_products is R[m] from the explanation above, i.e. the inner
    # product sum_n X[n] * Conj(X[n - m]). The cast brings it back to a
    # real dtype if x was real to begin with.
    lagged_products = tf.cast(tf.signal.ifft(power_spectrum), dtype)

    # Decide whether the final static shape can be deduced, and settle
    # max_lags. `series` is the reference because its time dimension is
    # rightmost and it predates the padding / FFT shape manipulations.
    static_shape_known = tensorshape_util.is_fully_defined(series.shape)
    if max_lags is None:
      max_lags = num_steps - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      static_max_lags = tf.get_static_value(max_lags)
      if static_max_lags is not None and static_shape_known:
        max_lags = min(num_steps - 1, static_max_lags)
      else:
        static_shape_known = False
        max_lags = tf.minimum(num_steps - 1, max_lags)

    # Chop off the padding. Users may pass a huge max_lags; it has been
    # clipped above, so the kept length along the last axis is max_lags + 1.
    chopped = lagged_products[..., :max_lags + 1]

    # If possible, set shape.
    if static_shape_known:
      static_shape = tensorshape_util.as_list(series.shape)
      static_shape[-1] = min(num_steps, max_lags + 1)
      chopped.set_shape(static_shape)

    # R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]); the other
    # terms are zeros that exist only because of the zero padding. Dividing by
    # (N / 2 - m) therefore gives an unbiased estimate of
    # E[X[n] Conj(X[n - m])].
    real_dtype = dtype_util.real_dtype(dtype)
    num_steps_real = tf.cast(num_steps, real_dtype)
    max_lags_real = tf.cast(max_lags, real_dtype)
    num_terms = tf.cast(
        num_steps_real - tf.range(0., max_lags_real + 1.), dtype)
    estimate = chopped / num_terms

    if normalize:
      estimate = estimate / estimate[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(estimate, -shift)
def where_left_justified_mask(mask, vals1, vals2, name=None):
  """Like `tf.where`, but broadcasts the `mask` left-justified.

  Args:
    mask: Boolean condition; it is expanded with
      `left_justified_expand_dims_to` up to the larger of the ranks of
      `vals1` and `vals2` before selecting.
    vals1: Values taken where the expanded mask is true.
    vals2: Values taken where the expanded mask is false.
    name: Optional name scope; defaults to 'where_left_justified_mask'.

  Returns:
    The result of `tf.where` applied to the expanded mask and both values.
  """
  scope_name = name or 'where_left_justified_mask'
  with tf.name_scope(scope_name):
    widest_rank = ps.maximum(ps.rank(vals1), ps.rank(vals2))
    expanded_mask = left_justified_expand_dims_to(mask, widest_rank)
    selected = tf.where(expanded_mask, vals1, vals2)
  return selected
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         current_step_meta_info,
                         iter_,
                         energy_diff_sum_previous,
                         momentum_cumsum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory,
                         seed):
  """Base case in tree doubling.

  Takes a single `integrator` step from `prev_tree_state`, records the new
  momentum/state in the memory swap, checks for divergence and U-turns, and
  decides (per batch member) whether the new point replaces the current
  candidate.

  Args:
    directions: Forwarded to `has_not_u_turn_at_all_index`.
    integrator: Callable invoked as
      `integrator(momentum, state, target, target_grad_parts)` and returning
      the next momentum parts, state parts, target and target gradient parts.
    current_step_meta_info: Object providing `write_instruction`,
      `read_instruction` and `init_energy` (and `log_slice_sample` when
      `MULTINOMIAL_SAMPLE` is false).
    iter_: Index used to gather this step's read/write instructions; it is
      returned incremented by one.
    energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))`.
    momentum_cumsum_previous: List of running momentum sums, one per part.
    leapfrogs_taken: Running count of leapfrog steps; incremented where
      `continue_tree_previous` is true.
    prev_tree_state: State with `momentum`, `state`, `target` and
      `target_grad_parts` fields from which the step is taken.
    candidate_tree_state: Current candidate with `state`, `target`,
      `target_grad_parts`, `energy` and `weight` fields.
    continue_tree_previous: Boolean mask of batch members whose tree had not
      terminated before this step.
    not_divergent_previous: Boolean mask accumulated from earlier steps.
    momentum_state_memory: `MomentumStateSwap` holding `momentum_swap` and
      `state_swap` lists that are updated at the write index.
    seed: Seed that is split into the acceptance seed used here and the seed
      returned for the next iteration.

  Returns:
    A tuple `(iter_ + 1, next_seed, energy_diff_sum, momentum_cumsum,
    leapfrogs_taken, next_tree_state, next_candidate_tree_state,
    continue_tree_next, not_divergent, momentum_state_memory)` where
    `not_divergent` is `not_divergent_previous` combined with this step's
    divergence check.
  """
  # One half of the seed drives the acceptance draw below; the other half is
  # handed back to the caller for the next iteration.
  acceptance_seed, next_seed = samplers.split_seed(seed)
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    momentum_cumsum = [
        p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                  next_momentum_parts)
    ]
    # If the tree have not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(continue_tree_previous,
                               leapfrogs_taken + 1, leapfrogs_taken)

    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy

    # Under the generalized U-turn criterion the cumulative momentum is what
    # gets stored and checked; otherwise the position state is used.
    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum_previous
      state_to_check = momentum_cumsum
    else:
      state_to_write = next_state_parts
      state_to_check = next_state_parts

    batch_shape = ps.shape(next_target)
    has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

    read_index = read_instruction.gather([iter_])[0]
    no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
        read_index,
        directions,
        momentum_state_memory,
        next_momentum_parts,
        state_to_check,
        has_not_u_turn_init,
        log_prob_rank=ps.rank(next_target),
        shard_axis_names=self.experimental_shard_axis_names)

    # Get index to write state into memory swap
    write_index = write_instruction.gather([iter_])
    momentum_state_memory = MomentumStateSwap(
        momentum_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.momentum_swap,
                                next_momentum_parts)
        ],
        state_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.state_swap,
                                state_to_write)
        ])

    energy = compute_hamiltonian(
        next_target, next_momentum_parts,
        shard_axis_names=self.experimental_shard_axis_names)
    # A NaN energy is replaced by -inf so the comparisons below are
    # well-defined.
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    # `u` is log(1 - U) for a uniform draw U, compared against the log
    # acceptance threshold.
    u = tf.math.log1p(-samplers.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=acceptance_seed))
    is_sample_accepted = u <= log_accept_thresh

    # Per batch member, keep either the newly integrated point or the
    # previous candidate.
    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            bu.where_left_justified_mask(is_sample_accepted, s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=bu.where_left_justified_mask(
            is_sample_accepted, next_target, candidate_tree_state.target),
        target_grad_parts=[
            bu.where_left_justified_mask(is_sample_accepted, grad0, grad1)
            for grad0, grad1 in zip(
                next_target_grad_parts,
                candidate_tree_state.target_grad_parts)
        ],
        energy=bu.where_left_justified_mask(
            is_sample_accepted,
            current_energy,
            candidate_tree_state.energy),
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    # Divergence only counts for members whose tree was still growing.
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        ps.ones(batch_shape, dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
    energy_diff_sum = tf.where(
        continue_tree,
        energy_diff_sum_previous + exp_energy_diff,
        energy_diff_sum_previous)

    return (
        iter_ + 1,
        next_seed,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """

  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii) In both cases, the kernel results need to be updated in a
  #      non-trivial manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)
    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`, and no seed
    # expected by `kernel.one_step`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we warn and fall back to the previous behavior.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Only a TypeError whose message mentions 'argument' triggers the
      # fallback; any other TypeError is re-raised.
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    # Now that we've constructed the TransitionKernel instance:
    # - If we were given a seed, we sanitize it to stateless and pass along
    #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
    #   the error.  Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If not given a seed, we don't pass one along. This avoids breaking
    #   underlying kernels lacking a `seed` arg on `one_step`.
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      inner_seed, swap_seed, logu_seed = samplers.split_seed(
          seed, n=3, salt='remc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        **inner_kwargs)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Fall back only when the error message mentions `step_count`.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # `null_swaps` is the identity permutation, shaped like `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (untempered_pre_swap_replica_target_log_prob -
                   mcmc_util.index_remapping_gather(
                       untempered_pre_swap_replica_target_log_prob,
                       swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[] i and j are equal.
    log_accept_ratio = (energy_diff *
                        mcmc_util.left_justified_expand_dims_to(
                            inverse_temp_diff, replica_and_batch_rank))

    # Non-finite ratios are mapped to -inf, so those swaps are never
    # accepted by the comparison below.
    log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                log_accept_ratio,
                                tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(log_uniform,
                                    log_accept_ratio,
                                    name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Takes the swapped value where the swap was accepted, else keeps `x`.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    # When states do not carry a replica axis, only replica 0 is exposed.
    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.
        inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return states, post_swap_kernel_results
def _sample_n(self, n, seed=None, conditional_input=None, training=False):
  """Samples from the distribution, with optional conditional input.

  Pixels are sampled one at a time in raster order: the network is run on the
  current images, new channel values are drawn, and only the pixel at the
  current index is written back before moving to the next one.

  Args:
    n: `int`, number of samples desired.
    seed: `int`, seed for RNG. Setting a random seed enforces reproducibility
      of the samples between sessions (not within a single session).
    conditional_input: `Tensor` on which to condition the distribution (e.g.
      class labels), or `None`. If it has no sample dimensions in front of
      `self.conditional_shape`, a single sample dimension is prepended.
    training: `bool` or `None`. If `bool`, it controls the dropout layer,
      where `True` implies dropout is active. If `None`, it defers to Keras'
      handling of train/eval status.

  Returns:
    samples: a `Tensor` of shape `[n, height, width, num_channels]`.
  """
  if conditional_input is not None:
    conditional_input = tf.convert_to_tensor(
        conditional_input, dtype=self.dtype)
    conditional_event_rank = tensorshape_util.rank(self.conditional_shape)
    conditional_sample_rank = prefer_static.rank(
        conditional_input) - conditional_event_rank

    # If `conditional_input` has no sample dimensions, prepend a sample
    # dimension
    if conditional_sample_rank == 0:
      conditional_input = conditional_input[tf.newaxis, ...]
      conditional_sample_rank = 1

    # Read the shape only AFTER the optional expansion above. Reading it
    # before (as was done previously) leaves the shape without the prepended
    # sample axis while `conditional_sample_rank` is already 1, so the
    # sample/event slices below would be taken at the wrong offset.
    conditional_input_shape = prefer_static.shape(conditional_input)

    # Assert that the conditional event shape in the `PixelCnnNetwork` is the
    # same as that implied by `conditional_input`.
    conditional_event_shape = conditional_input_shape[
        conditional_sample_rank:]
    with tf.control_dependencies([
        tf.assert_equal(self.conditional_shape, conditional_event_shape)]):

      conditional_sample_shape = conditional_input_shape[
          :conditional_sample_rank]
      # Number of times the flattened conditional inputs must be repeated so
      # that there is one conditional row per requested sample.
      repeat = n // prefer_static.reduce_prod(conditional_sample_shape)
      h = tf.reshape(
          conditional_input,
          prefer_static.concat([(-1,), self.conditional_shape], axis=0))
      # Tile along the leading axis only; event axes get a multiple of 1.
      h = tf.tile(h,
                  prefer_static.pad(
                      [repeat], paddings=[[0, conditional_event_rank]],
                      constant_values=1))

  # Start from uniform noise in [-1, 1) and take a first full network pass.
  samples_0 = tf.random.uniform(
      prefer_static.concat([(n,), self.event_shape], axis=0),
      minval=-1., maxval=1., dtype=self.dtype, seed=seed)
  inputs = samples_0 if conditional_input is None else [samples_0, h]
  params_0 = self.network(inputs, training=training)
  samples_0 = self._sample_channels(*params_0, seed=seed)

  image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)

  def loop_body(index, samples):
    """Loop for iterative pixel sampling.

    Args:
      index: 0D `Tensor` of type `int32`. Index of the current pixel.
      samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
        pixel `[index]`, with dimensions `[batch_size, height, width,
        num_channels]`.

    Returns:
      samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
        and including pixel `[index]`, with dimensions `[batch_size, height,
        width, num_channels]`.
    """
    inputs = samples if conditional_input is None else [samples, h]
    params = self.network(inputs, training=training)
    samples_new = self._sample_channels(*params, seed=seed)

    # Update the current pixel. The batch axis is moved last so that
    # `[row, col]` indexes the leading axes for the scatter update.
    samples = tf.transpose(samples, [1, 2, 3, 0])
    samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
    row, col = index // image_width, index % image_width
    updates = samples_new[row, col, ...][tf.newaxis, ...]
    samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
    samples = tf.transpose(samples, [3, 0, 1, 2])

    return index + 1, samples

  index0 = tf.zeros([], dtype=tf.int32)

  # Construct the while loop for sampling
  total_pixels = image_height * image_width
  loop_cond = lambda ind, _: tf.less(ind, total_pixels)  # noqa: E731
  init_vars = (index0, samples_0)
  _, samples = tf.while_loop(loop_cond, loop_body, init_vars,
                             parallel_iterations=1)

  # Map samples from [-1, 1] onto [self._low, self._high] and round.
  transformed_samples = (self._low +
                         0.5 * (self._high - self._low) * (samples + 1.))
  return tf.round(transformed_samples)
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
  """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`. Inserting this into the detailed balance
  equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition at most
  one of `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the
  log proposal-ratio. UncalibratedHMC augments the state-space with momentum,
  z. Assuming a standard Gaussian distribution for momentums, the chain
  eventually converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, because we desire to
  use a general MH framework, we can place the momentum probability ratio
  inside the metropolis-correction factor thus getting an acceptance
  probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each
  leapfrog step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts). Iterated over part by part.
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts). Iterated over part by part.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.name_scope(name or 'compute_log_acceptance_correction'):

    def _total_squared_momentum(momentum_parts):
      # Sums z**2 over every non-chain axis of each part, then adds the
      # per-part results together.
      per_part = []
      for part in momentum_parts:
        squared = part**2.
        non_chain_axes = prefer_static.range(
            independent_chain_ndims, prefer_static.rank(part))
        per_part.append(tf.reduce_sum(squared, axis=non_chain_axes))
      return tf.add_n(per_part)

    current_kinetic = _total_squared_momentum(current_momentums)
    proposed_kinetic = _total_squared_momentum(proposed_momentums)
    kinetic_change = mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
    return 0.5 * kinetic_change
def op(x, kernel):
  """Applies a batched transposed 2-D convolution using one dense matmul.

  The sub-kernel entries of `kernel` are scattered into a single matrix
  (`kernel_mat`), the padded input is unrolled with `im2row_index`, and the
  two are multiplied. For `strides > 1` the result is rearranged with
  `tf.nn.depth_to_space` and finally cropped to the output size.

  NOTE(review): `fh`, `fw`, `sub_fh`, `sub_fw`, `filter_shape`, `strides`,
  `rank`, `padding`, `padding_vals`, `dilations`, `event_ind`, `dtype`,
  `truncate_top`, `truncate_left` and `validate_args` are not defined in this
  function; they are presumably closed over from an enclosing factory.

  Args:
    x: Input convertible to a `Tensor`; its last three dimensions are unpacked
      as `(xh, xw, c_in)` and all leading dimensions are treated as batch.
    kernel: Kernel convertible to a `Tensor`; its last dimension is used as
      `c_out` and all but its last two dimensions as the kernel batch shape.

  Returns:
    `Tensor` of shape
    `broadcast(batch_shape, kernel_batch) + [out_height, out_width, c_out]`
    (or the result of `_call_conv2d_transpose` on the fast path).
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  # Split the shape of `x` into leading batch dims and the trailing 3 dims.
  batch_shape, event_shape = ps.split(ps.shape(x),
                                      num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)

  kernel_shape = ps.shape(kernel)
  c_out = kernel_shape[-1]
  kernel_batch = kernel_shape[:-2]
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):
    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2 and
        all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(x,
                                    kernel=kernel,
                                    filter_shape=filter_shape,
                                    strides=(strides, ) * rank,
                                    padding=padding,
                                    dilations=dilations,
                                    c_out=c_out,
                                    batch_shape=batch_shape,
                                    event_shape=event_shape)

    # Pad only the height/width axes: `n` leading batch axes and the trailing
    # channel axis get zero padding.
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(padding_vals,
                      paddings=[[n, 1], [0, 0]],
                      constant_values=0)

    x_pad = tf.pad(x, paddings=paddings, constant_values=0)
    # Despite the name, this is the shape of `x_pad` *without* its last three
    # dimensions (i.e. its batch part).
    x_pad_shape = ps.shape(x_pad)[:-3]
    flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.reshape(x_pad, shape=flat_shape)

    # Gather indices (and resulting trailing shape `s`) that unroll the
    # flattened, padded input into rows.
    idx, s = im2row_index(
        (xh + tf.reduce_sum(padding_vals[0]),
         xw + tf.reduce_sum(padding_vals[1]),
         c_in),
        block_shape=(sub_fh, sub_fw),
        slice_step=(1, 1),
        dilations=dilations)

    x_ = tf.gather(flat_x, indices=idx, axis=-1)
    im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

    # Add channels to subkernel indices
    idx_event = event_ind * [[c_in, 1]]
    idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
        [ps.range(c_in), tf.zeros(
            (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
    idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                             block_shape=[c_in],
                                             crops=[[0, 0]]),
                           axis=0)
    idx_event_broadcast = tf.broadcast_to(
        idx_event,
        shape=ps.concat(
            [kernel_batch, ps.shape(idx_event)], axis=0))

    # Add cartesian product of batch indices, since scatter_nd can only be
    # applied to leading dimensions.
    idx_batch = tf.stack(tf.meshgrid(*[
        ps.range(b_, delta=1, dtype=dtype)
        for b_ in tf.unstack(kernel_batch)
    ], indexing='ij'),
                         axis=ps.size(kernel_batch))

    idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

    # Broadcast the batch indices against the number of event indices so they
    # can be concatenated along the last axis.
    idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
        (ps.shape(idx_event)[0], 1), dtype=dtype)
    idx_kernel = tf.concat(
        [idx_batch_broadcast, idx_event_broadcast], axis=-1)

    # Scatter the kernel entries into a matrix of shape
    # `kernel_batch + [sub_fh * sub_fw * c_in, strides**2, c_out]`.
    kernel_mat = tf.scatter_nd(
        idx_kernel,
        updates=kernel,
        shape=ps.cast(ps.concat([
            kernel_batch, [sub_fh * sub_fw * c_in, strides**2, c_out]
        ], axis=0),
                      dtype=dtype))

    # Merge the last two axes into one of size `strides**2 * c_out`.
    kernel_mat = tf.reshape(
        kernel_mat,
        shape=ps.concat(
            [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0))

    kernel_mat = kernel_mat[..., tf.newaxis, :, :]

    out = tf.matmul(im_x, kernel_mat)
    broadcast_batch_shape = ps.broadcast_shape(
        batch_shape, kernel_batch)

    if strides > 1:
      # `depth_to_space` is called on a tensor with a single flattened batch
      # axis, so collapse all batch dims first.
      tot_size = tf.reduce_prod(broadcast_batch_shape)
      flat_out = tf.reshape(out,
                            shape=ps.concat([[tot_size],
                                             ps.shape(out)[-3:]],
                                            axis=0))
      out = tf.nn.depth_to_space(flat_out, block_size=strides)

    # NOTE(review): `out_height`/`out_width` are only assigned for 'VALID'
    # and 'SAME'; presumably `padding` is validated by the enclosing scope.
    if padding == 'VALID':
      out_height = fh + strides * (xh - 1)
      out_width = fw + strides * (xw - 1)
    elif padding == 'SAME':
      out_height = xh * strides
      out_width = xw * strides

    # Crop to the output window and restore the broadcast batch shape.
    out = out[..., truncate_top:truncate_top + out_height,
              truncate_left:truncate_left + out_width, :]
    out = tf.reshape(
        out,
        shape=ps.concat([
            broadcast_batch_shape, [out_height, out_width, c_out]
        ], axis=0))
    return out
def op(x, kernel):
  """Applies a batched transposed 2-D convolution, one sub-kernel at a time.

  A `tf.while_loop` runs `sh * sw` iterations. Each iteration reads one set
  of sub-kernel indices from `kernels_ind`, gathers the matching kernel rows,
  unrolls a slice of the padded input with `im2row_index`, and writes the
  matmul result to a `TensorArray`. The concatenated results are rearranged
  with `tf.batch_to_space` and cropped to the output size.

  NOTE(review): `fh`, `fw`, `sh`, `sw`, `sub_fh`, `sub_fw`, `filter_shape`,
  `strides`, `padding`, `padding_vals`, `dilations`, `kernels_ind`,
  `truncate_top`, `truncate_left` and `validate_args` are not defined in this
  function; they are presumably closed over from an enclosing factory.

  Args:
    x: Input convertible to a `Tensor`; its last three dimensions are unpacked
      as `(xh, xw, c_in)` and all leading dimensions are treated as batch.
    kernel: Kernel convertible to a `Tensor`; its last dimension is used as
      `c_out` and all but its last two dimensions as the kernel batch shape.

  Returns:
    `Tensor` whose leading dimensions are
    `broadcast(batch_shape, kernel_batch)`, cropped to
    `[out_height, out_width]` on the two spatial axes (or the result of
    `_call_conv2d_transpose` on the fast path).
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  # Split the shape of `x` into leading batch dims and the trailing 3 dims.
  batch_shape, event_shape = ps.split(ps.shape(x),
                                      num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)

  kernel_shape = ps.shape(kernel)
  c_out = kernel_shape[-1]
  kernel_batch = kernel_shape[:-2]
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):
    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2 and
        all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(x, kernel, filter_shape, strides,
                                    padding, dilations, c_out, batch_shape,
                                    event_shape)

    # Pad only the height/width axes: `n` leading batch axes and the trailing
    # channel axis get zero padding.
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(padding_vals,
                      paddings=[[n, 1], [0, 0]],
                      constant_values=0)

    x_pad = tf.pad(x, paddings=paddings, constant_values=0)

    # Base spatial extents; each loop iteration adds its own sub-kernel size
    # back (see `eh`/`ew` below).
    ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
    ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

    def loop_body(i, outputs):
      """Computes the `i`-th sub-kernel product and writes it to `outputs`."""
      subkernel_ind = kernels_ind.read(i)
      fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
      eh = ex_h + fh_ - 1
      ew = ex_w + fw_ - 1

      # Expand each index into `c_in` consecutive indices
      # (`index * c_in + channel`) and flatten.
      subkernel_ind = ps.reshape(ps.reshape(
          subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] + ps.range(c_in),
                                 shape=[-1])

      k = tf.gather(kernel, subkernel_ind, axis=-2)
      ind, shape = im2row_index([eh, ew, c_in],
                                block_shape=(fh_, fw_),
                                slice_step=(1, 1),
                                dilations=dilations)
      # Use only the top-left `eh x ew` window of the padded input.
      x_i = x_pad[..., :eh, :ew, :]
      x_i_shape = ps.shape(x_i)
      flat_shape = ps.pad(x_i_shape[:-3],
                          paddings=[[0, 1]],
                          constant_values=-1)
      flat_x = tf.reshape(x_i, flat_shape)
      x_ = tf.gather(flat_x, ind, axis=-1)
      im_x = tf.reshape(
          x_, ps.concat([x_i_shape[:-3], shape], axis=0))
      outputs = outputs.write(
          i,
          tf.matmul(
              im_x,
              tf.reshape(
                  k,
                  ps.concat([
                      kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                  ], axis=0))))
      return i + 1, outputs

    outputs = tf.TensorArray(dtype=input_dtype,
                             infer_shape=False,
                             size=1,
                             dynamic_size=True)

    _, outputs = tf.while_loop(lambda i, _: i < sh * sw,
                               loop_body,
                               [0, outputs])

    y = outputs.concat()

    # `batch_to_space` is called on a tensor with a single flattened leading
    # axis, so collapse all but the last three dims first.
    m = tf.reduce_prod(ps.shape(y)[:-3])
    y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0))
    y2 = tf.batch_to_space(y_,
                           strides,
                           crops=tf.zeros([2, 2], dtype=tf.int64))
    broadcast_batch_shape = ps.broadcast_shape(
        batch_shape, kernel_batch)
    y2 = tf.reshape(
        y2,
        ps.concat([broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0))

    # NOTE(review): `out_height`/`out_width` are only assigned for 'VALID'
    # and 'SAME'; presumably `padding` is validated by the enclosing scope.
    if padding == 'VALID':
      out_height = fh + sh * (xh - 1)
      out_width = fw + sw * (xw - 1)
    elif padding == 'SAME':
      out_height = xh * sh
      out_width = xw * sw

    # Crop to the output window.
    return y2[..., truncate_top:truncate_top + out_height,
              truncate_left:truncate_left + out_width, :]
def op(x, kernel):
  """Applies a batched 2-D convolution of `x` with `kernel`.

  If `kernel` has static rank 2 (no batch dimensions), the input is flattened
  to a single batch axis and handed to `tf.nn.conv2d`. Otherwise the input is
  unrolled with `im2row_index` and multiplied against the (batched) kernel.

  NOTE(review): `filter_shape`, `strides`, `padding`, `dilations`, `dtype`
  and `validate_args` are not defined in this function; they are presumably
  closed over from an enclosing factory.

  Args:
    x: Input convertible to a `Tensor`; its last three dimensions are unpacked
      as `(xh, xw, c_in)` and all leading dimensions (possibly none) are
      treated as batch.
    kernel: Kernel convertible to a `Tensor`. On the rank-2 path it is
      reshaped to `[fh, fw, c_in, -1]`.

  Returns:
    `Tensor` holding the convolution output. On the rank-2 path its shape is
    `batch_shape` followed by the last three dimensions of the `conv2d`
    result.
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  # Split the shape of `x` into leading batch dims and the trailing 3 dims.
  batch_shape, event_shape = ps.split(ps.shape(x),
                                      num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)
  fh, fw = filter_shape

  assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                            channels_in=c_in,
                                            filter_height=fh,
                                            filter_width=fw,
                                            validate_args=validate_args)

  with tf.control_dependencies(assertions):
    if tf.get_static_value(ps.rank(kernel)) == 2:
      # Collapse all batch dims into one so `conv2d` always receives a rank-4
      # input, including when `x` has no batch dims at all. (Previously
      # `flat_x` was computed but the unflattened `x` was passed instead.)
      flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0))
      flat_y = tf.nn.conv2d(flat_x,
                            filters=tf.reshape(
                                kernel, shape=[fh, fw, c_in, -1]),
                            strides=strides,
                            padding=padding,
                            data_format='NHWC',
                            dilations=dilations)
      output_shape = ps.shape(flat_y)[-3:]
      return tf.reshape(flat_y,
                        shape=ps.concat([batch_shape, output_shape],
                                        axis=0))

    # Per-spatial-dimension padding amounts for height and width.
    pad_values = [
        _get_conv_padding(xdim,
                          filter_dim=k,
                          stride=s,
                          dilation=d,
                          padding=padding)
        for (xdim, k, s, d) in zip((xh, xw), filter_shape, strides,
                                   dilations)
    ]

    # Gather indices (and resulting trailing shape) that unroll the
    # flattened, padded input into rows.
    idx, shape = im2row_index(
        (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
        block_shape=filter_shape,
        slice_step=strides,
        dilations=dilations,
        dtype=dtype)

    if padding == 'SAME':
      # Pad only the height/width axes: `n` leading batch axes and the
      # trailing channel axis get zero padding.
      n = ps.maximum(0, ps.rank(x) - 3)
      paddings = ps.pad(pad_values,
                        paddings=[[n, 1], [0, 0]],
                        constant_values=0)
      x = tf.pad(x, paddings=paddings, constant_values=0)

    flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                       indices=idx,
                       axis=-1)
    im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
    # Insert an axis into `kernel` so its batch dims broadcast in the matmul.
    return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def op(x, kernel):
  """Applies a batched transposed 2-D convolution via zero-interleaving.

  The input is interleaved with zeros (once along the width axis, once along
  the height axis), padded, unrolled with `im2row_index(..., transpose=True)`
  and multiplied against the (batched) kernel.

  NOTE(review): `fh`, `fw`, `sh`, `sw`, `filter_shape`, `strides`, `padding`,
  `pad_values`, `dilations`, `dtype` and `validate_args` are not defined in
  this function; they are presumably closed over from an enclosing factory.

  Args:
    x: Input convertible to a `Tensor`; its last three dimensions are unpacked
      as `(xh, xw, c_in)` and all leading dimensions are treated as batch.
    kernel: Kernel convertible to a `Tensor`; its last dimension is passed as
      `c_out` on the `_call_conv2d_transpose` fast path.

  Returns:
    `Tensor` produced by `tf.matmul` of the unrolled input with `kernel` (or
    the result of `_call_conv2d_transpose` on the fast path).
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  # Split the shape of `x` into leading batch dims and the trailing 3 dims.
  batch_shape, event_shape = ps.split(ps.shape(x),
                                      num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)

  kernel_shape = ps.shape(kernel)
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):
    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2 and
        all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(x, kernel, filter_shape, strides,
                                    padding, dilations, kernel_shape[-1],
                                    batch_shape, event_shape)

    # Gather indices (and resulting trailing shape) for the interleaved and
    # padded input, whose spatial size is `(xh * sh + pad, xw * sw + pad)`.
    idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                               xw * sw + sum(pad_values[1]),
                               c_in),
                              block_shape=filter_shape,
                              slice_step=(1, 1),
                              dilations=dilations,
                              dtype=dtype,
                              transpose=True)

    # Pad only the height/width axes: `n` leading batch axes and the trailing
    # channel axis get zero padding.
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(pad_values,
                      paddings=[[n, 1], [0, 0]],
                      constant_values=0)

    # Interleave the rows and columns of the input with rows and columns of
    # zeros equal to the number of strides.
    # Width axis first: put `sw - 1` zeros in front of every pixel.
    x_half_dilated = tf.concat([
        tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0),
                 dtype=input_dtype),
        tf.reshape(x,
                   shape=ps.concat(
                       [batch_shape, (xh * xw, 1, c_in)], axis=0))
    ], axis=-2)
    y = tf.reshape(x_half_dilated,
                   shape=ps.concat(
                       [batch_shape, (xh, 1, xw * sw, c_in)], axis=0))

    # Height axis next: put `sh - 1` zero rows in front of every row, then
    # merge into a `(xh * sh, xw * sw, c_in)` image.
    x = tf.reshape(tf.concat([
        tf.zeros(ps.concat(
            [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                 dtype=input_dtype), y
    ], axis=-3),
                   shape=ps.concat(
                       [batch_shape, (xh * sh, xw * sw, c_in)], axis=0))

    x_pad = tf.pad(x, paddings=paddings, constant_values=0)
    flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                       indices=idx,
                       axis=-1)
    im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
    # Insert an axis into `kernel` so its batch dims broadcast in the matmul.
    return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Performs "hit and run" style slice sampling: a random direction is drawn
  with `_choose_random_direction`, and a one dimensional slice sampling update
  is then applied along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized)
      log-density under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `batch_rank` dimensions of
      the `Tensor`(s) index different chains.
    step_sizes: Python `list` of `Tensor`s. Provides a measure of the width
      of the density. Used to find the slice bounds. Must broadcast with the
      shape of `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the
      slice boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    batch_rank: Integer. The number of axes in the state that correspond to
      independent batches.
    seed: Tensor seed pair.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_next').

  Returns:
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates that an interval containing the slice for that batch was
      found successfully.
    direction: `Tensor` or Python list of `Tensor`s representing the direction
      along which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state.
      The upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state.
      The lower bounds of the slices along the sampling direction.
  """
  direction_seed, slice_seed = samplers.split_seed(seed)
  with tf.name_scope(name or 'sample_next'):
    # Step 1: pick a random direction. `direction` is a list of tensors; the
    # i'th tensor should have the same shape as the i'th state part.
    direction = _choose_random_direction(
        current_state_parts, batch_rank=batch_rank, seed=direction_seed)

    # Step 2: interpolate the step sizes for the chosen direction.
    #
    # The per-axis step sizes s_1, ..., s_k are the steps to use when moving
    # parallel to one of the axes. Consider the ellipsoid which intercepts the
    # i'th axis at s_i. For a unit direction (n_1, ..., n_k), the step size is
    # defined as the distance from the origin to the point where the line
    # through that direction meets the ellipsoid, which is
    #   1 / sqrt(n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...).
    #
    # Proof: the ellipsoid is Sum_i [x_i^2 / s_i^2] = 1. Points on the line
    # along the unit vector n are alpha * n, where alpha is the distance from
    # the origin. Substituting,
    #   alpha^2 (n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1,
    # hence alpha = sqrt(1 / (n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...)).
    inverse_square_terms = []
    for part_step_size, dirn_part in zip(step_sizes, direction):
      event_axes = ps.range(batch_rank, ps.rank(dirn_part))
      inverse_square_terms.append(
          tf.reduce_sum((dirn_part / part_step_size)**2, axis=event_axes))
    step_size = tf.math.rsqrt(tf.add_n(inverse_square_terms))

    # Rank of each state part (static when available).
    state_part_ranks = [ps.rank(part) for part in current_state_parts]

    def _step_along_direction(alpha):
      """Returns `x_0 + alpha * direction` for every state part.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Python list of `Tensor`s with the same shapes as
          `current_state_parts`, holding the state reached by moving `alpha`
          along the chosen direction from the current state.
      """
      moved_parts = []
      for state_part, dirn_part, part_rank in zip(
          current_state_parts, direction, state_part_ranks):
        padded_alpha = _right_pad(alpha, final_rank=part_rank)
        moved_parts.append(state_part + padded_alpha * dirn_part)
      return moved_parts

    def _log_prob_along_direction(alpha):
      """The target log density restricted to the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at `x_0 + alpha * direction`. Has the
        same shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    alpha_init = tf.zeros_like(
        current_target_log_prob, dtype=current_state_parts[0].dtype)
    (next_alpha,
     next_target_log_prob,
     bounds_satisfied,
     upper_bounds,
     lower_bounds) = ssu.slice_sampler_one_dim(
         _log_prob_along_direction,
         x_initial=alpha_init,
         max_doublings=max_doublings,
         step_size=step_size,
         seed=slice_seed)

    next_state_parts = _step_along_direction(next_alpha)
    return [
        next_state_parts,
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds,
    ]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one iteration of Slice Sampler.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type
      and shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  # The sanitized seed is also stored in the kernel results for diagnostics.
  seed = samplers.sanitize_seed(seed)
  with tf.name_scope(mcmc_util.make_name(self.name, 'slice', 'one_step')):
    with tf.name_scope('initialize'):
      prepared_args = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          maybe_expand=True)
      current_state_parts, step_sizes, current_target_log_prob = (
          prepared_args)

      max_doublings = ps.convert_to_shape_tensor(
          value=self.max_doublings, dtype=tf.int32, name='max_doublings')

    # The rank of the log-prob is the number of independent-chain axes.
    independent_chain_ndims = ps.rank(current_target_log_prob)

    (next_state_parts,
     next_target_log_prob,
     bounds_satisfied,
     direction,
     upper_bounds,
     lower_bounds) = _sample_next(
         self.target_log_prob_fn,
         current_state_parts,
         step_sizes,
         max_doublings,
         current_target_log_prob,
         independent_chain_ndims,
         seed=seed,
     )

    # Hand back the same structure the caller passed in: a list stays a list,
    # a lone `Tensor` is unwrapped.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    kernel_results = SliceSamplerKernelResults(
        target_log_prob=next_target_log_prob,
        bounds_satisfied=bounds_satisfied,
        direction=direction,
        upper_bounds=upper_bounds,
        lower_bounds=lower_bounds,
        seed=seed,
    )
    return [next_state, kernel_results]
def _event_size(tensor_structure, event_ndims):
  """Returns the number of elements in the event-portion of a structure.

  Args:
    tensor_structure: (Possibly nested) structure of tensors.
    event_ndims: Structure matching `tensor_structure`; for each tensor, the
      number of rightmost dimensions that make up its event.

  Returns:
    The sum, over all tensors in the structure, of the product of each
    tensor's trailing `event_ndims` dimensions.
  """

  def _trailing_shape(tensor, ndims):
    # The last `ndims` entries of the tensor's shape.
    return ps.slice(ps.shape(tensor), [ps.rank(tensor) - ndims], [ndims])

  event_shapes = nest.map_structure(
      _trailing_shape, tensor_structure, event_ndims)

  total = 0
  for event_shape in nest.flatten(event_shapes):
    total = total + ps.reduce_prod(event_shape)
  return total
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). In this method it is only
      inspected (via `mcmc_util.is_list_like`) to decide the structure of
      the returned state; the replica states come from
      `previous_kernel_results`.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)   If swapping temperatures, you *still* have to swap log_probs to
  #       determine acceptance, as well as states (for kernel results).
  #       So it's just as difficult to swap temperatures.
  # (ii)  If swapping temperatures, you have to take care to swap any user-
  #       supplied temperature related things (like step size).
  #       A-priori, we don't know what else will need to be swapped!
  # (iii) In both cases, the kernel results need to be updated in a
  #       non-trivial manner....so we either special-case, or use bootstrap.
  #
  # NOTE: `self._seed_stream()` is called several times below, so the order
  # of those calls must be preserved when editing this method.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    # Build the inner kernel against the tempered target and advance every
    # replica by one step.
    inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
        _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                         inverse_temperatures),
        self._seed_stream())
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = prefer_static.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = prefer_static.rank(
        pre_swap_replica_target_log_prob)
    num_replica = prefer_static.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    swaps = tf.cast(
        self.swap_proposal_fn(  # pylint: disable=not-callable
            num_replica,
            batch_shape=batch_shape,
            seed=self._seed_stream()),
        dtype=tf.int32)

    # The identity permutation, expanded on the right to the rank of `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (untempered_pre_swap_replica_target_log_prob -
                   mcmc_util.index_remapping_gather(
                       untempered_pre_swap_replica_target_log_prob,
                       swaps,
                       name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[i] and log_accept_ratio[j]
    # are equal.
    log_accept_ratio = (energy_diff *
                        mcmc_util.left_justified_expand_dims_to(
                            inverse_temp_diff, replica_and_batch_rank))

    # Non-finite ratios (NaN or +/-inf) are replaced with -inf, so those
    # swaps are never accepted.
    log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                log_accept_ratio,
                                tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        tf.random.uniform(shape=replica_and_batch_shape,
                          dtype=dtype,
                          seed=self._seed_stream()))
    # For a swapping pair (i, j) both entries map to min(i, j), so both read
    # the same uniform draw.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(log_uniform,
                                    log_accept_ratio,
                                    name='is_swap_accepted_mask')

    def _swap_tensor(x):
      """Returns `x` permuted by `swaps` wherever the swap was accepted."""
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position, expanded_null_swaps)

    # The returned chain state is replica 0 of every state part.
    post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
    )

    return states, post_swap_kernel_results
def _one_step_part(
    self,
    step_size,
    state,
    error_sum,
    log_averaging_step,
    log_shrinkage_target,
    log_accept_prob_rank=None,
    log_accept_prob=None,
    target_accept_prob=None,
    previous_kernel_results=None):
  """Compute new step sizes for each step size part.

  If step size part has smaller rank than the corresponding state part, then
  the difference is averaged away in the log accept prob.

  Example:

    state_part has shape      [2, 3, 4, 5]
    step_size_part has shape     [1, 4, 1]
    log_accept_prob has shape [2, 3, 4]

  The step size has one fewer dimension than the state, so the leading
  dimension of `log_accept_prob` is reduced away, leaving a Tensor of shape
  [3, 4]. Then, because `log_accept_prob` must broadcast into step_size_part
  on the left, the dimensions where their shapes differ are reduced too,
  giving shape [1, 4], which is compatible with the leading dimensions of
  step_size_part.

  There is a subtlety: `step_size_parts` might be a length-1 list, in which
  case it is "structure-broadcast" across all the state parts (see logic in,
  e.g., hmc.py). In that case we must assume that the lone step size
  broadcasts with the event dims of each state part. So either the step size
  has no dimensions corresponding to chain dimensions, or all states have the
  same shape. In the former case we want to reduce over all chain dimensions;
  in the latter we want the same logic as in the non-structure-broadcasted
  case.

  Both cases can be handled uniformly by taking the rank of any state part.
  This obviously works in the second case (all state ranks are equal). In the
  first case every state part has rank L + D_i + B, where L is the rank of
  log_accept_prob, D_i are the dimensions not shared amongst all states, and
  B are the dimensions shared by all the states, which equal those of the
  step size. Subtracting B always leaves a number >= L, which yields the full
  reduction we want.

  Args:
    step_size: Previous step's step_size.
    state: Previous step's state value.
    error_sum: Previous step's error accumulator.
    log_averaging_step: Previous step's log_averaging_step.
    log_shrinkage_target: Floating point scalar `Tensor`. Logarithm of value
      the exploration step size is biased towards.
    log_accept_prob_rank: Rank of log_accept_prob.
    log_accept_prob: Floating point `Tensor` of log acceptance probabilities.
      It is reduced with `reduce_logmeanexp` as described above.
    target_accept_prob: A floating point `Tensor` representing desired
      acceptance probability. Must be a positive number less than 1.
    previous_kernel_results: Results struct from previous step.

  Returns:
    new_step_size: Updated `step_size`.
    new_log_averaging_step: Updated `log_averaging_step`.
    new_error_sum: Updated `error_sum`.
  """
  # First reduction: average away leading dims the step size does not have.
  rank_gap = prefer_static.rank(state) - prefer_static.rank(step_size)
  num_reduce_dims = prefer_static.minimum(log_accept_prob_rank, rank_gap)
  reduced_log_accept_prob = reduce_logmeanexp(
      log_accept_prob, axis=prefer_static.range(num_reduce_dims))

  # Second reduction: reduced_log_accept_prob must broadcast into step_size
  # on the left, so also reduce over dimensions where their shapes differ.
  differing_dims = get_differing_dims(reduced_log_accept_prob, step_size)
  reduced_log_accept_prob = reduce_logmeanexp(
      reduced_log_accept_prob, axis=differing_dims, keepdims=True)

  new_error_sum = (error_sum +
                   target_accept_prob -
                   tf.math.exp(reduced_log_accept_prob))

  # Append trailing singleton dims so the error sum has at least the rank of
  # `log_shrinkage_target`.
  num_ones_to_pad = prefer_static.maximum(
      prefer_static.rank(log_shrinkage_target) -
      prefer_static.rank(new_error_sum), 0)
  padded_shape = prefer_static.pad(
      prefer_static.shape(new_error_sum),
      paddings=[[0, num_ones_to_pad]],
      constant_values=1)
  new_error_sum_extend = tf.reshape(new_error_sum, shape=padded_shape)

  pkr = previous_kernel_results
  step_count_smoothing = pkr.step_count_smoothing
  step = tf.cast(pkr.step, step_count_smoothing.dtype) + 1.
  soft_t = step_count_smoothing + step

  # Exploration step size: shrink towards `log_shrinkage_target`.
  scaled_error = (tf.cast(new_error_sum_extend, step.dtype) *
                  tf.math.sqrt(step))
  shrinkage_denominator = soft_t * pkr.exploration_shrinkage
  new_log_step = (
      log_shrinkage_target - (scaled_error / shrinkage_denominator))

  # Running (weighted) average of the exploration log step sizes.
  eta = step**(-pkr.decay_rate)
  candidate_log_averaging_step = (eta * new_log_step +
                                  (1. - eta) * log_averaging_step)

  still_adapting = pkr.step < self.num_adaptation_steps
  past_adaptation = pkr.step > self.num_adaptation_steps

  # - If still adapting, return an exploring step size,
  # - If just finished, return the averaging step size
  # - Otherwise, do not update
  step_size_when_not_adapting = tf.where(
      past_adaptation,
      step_size,
      tf.math.exp(candidate_log_averaging_step))
  new_step_size = tf.where(
      still_adapting,
      tf.math.exp(new_log_step),
      step_size_when_not_adapting)

  # Freeze the averaged log step size once adaptation is over.
  new_log_averaging_step = tf.where(
      past_adaptation,
      log_averaging_step,
      candidate_log_averaging_step)
  return new_step_size, new_log_averaging_step, new_error_sum
def _log_prob(self, value, conditional_input=None, training=None):
  """Evaluate the log probability of images, optionally conditioned.

  Computes the log probability of a batch of data under the modeled
  distribution, or under the conditional distribution when
  `conditional_input` is supplied.

  Args:
    value: `Tensor` or Numpy array of image data. May have leading batch
      dimension(s), which must broadcast to the leading batch dimensions of
      `conditional_input`.
    conditional_input: `Tensor` on which to condition the distribution (e.g.
      class labels), or `None`. May have leading batch dimension(s), which
      must broadcast to the leading batch dimensions of `value`.
    training: `bool` or `None`. If `bool`, it controls the dropout layer,
      where `True` implies dropout is active. If `None`, it defaults to
      `tf.keras.backend.learning_phase()`.

  Returns:
    log_prob_values: `Tensor`.
  """
  # Everything to the left of the last three dimensions of `value` is treated
  # as the image batch shape.
  image_batch_shape = prefer_static.shape(value)[:-3]

  if conditional_input is not None:
    # Bring `value` and `conditional_input` to one common batch shape.
    conditional_input = tf.convert_to_tensor(conditional_input)
    conditional_batch_rank = (
        prefer_static.rank(conditional_input) -
        tensorshape_util.rank(self.conditional_shape))
    conditional_batch_shape = prefer_static.shape(
        conditional_input)[:conditional_batch_rank]
    output_batch_shape = prefer_static.broadcast_shape(
        image_batch_shape, conditional_batch_shape)

    full_conditional_shape = prefer_static.concat(
        [output_batch_shape, self.conditional_shape], axis=0)
    conditional_input = tf.broadcast_to(
        conditional_input, full_conditional_shape)

    full_value_shape = prefer_static.concat(
        [output_batch_shape, self.event_shape], axis=0)
    value = tf.broadcast_to(value, full_value_shape)

    # The Keras model takes a single flat batch dimension.
    flat_conditional_shape = prefer_static.concat(
        [(-1,), self.conditional_shape], axis=0)
    conditional_input = tf.reshape(conditional_input, flat_conditional_shape)
  else:
    output_batch_shape = image_batch_shape

  flat_value_shape = prefer_static.concat([(-1,), self.event_shape], axis=0)
  value = tf.reshape(value, flat_value_shape)

  # Affinely map pixel values from [low, high] onto [-1, 1].
  offset_value = value - self._low
  value_range = self._high - self._low
  transformed_value = (2. * offset_value / value_range) - 1.

  if conditional_input is None:
    network_inputs = transformed_value
  else:
    network_inputs = [transformed_value, conditional_input]
  params = self.network(network_inputs, training=training)

  num_channels = self.event_shape[-1]
  if num_channels == 1:
    component_logits, locs, scales = params
  else:
    # With several channels, the location parameters of the channels of one
    # pixel are tied together by a linear autoregressive dependency (scale
    # parameters within a pixel stay independent). For an R/G/B pixel the
    # `r`, `g`, and `b` saturation values are distributed as:
    #
    # r ~ Logistic(loc_r, scale_r)
    # g ~ Logistic(coef_rg * r + loc_g, scale_g)
    # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
    # TODO(emilyaf) Investigate using fill_triangular/matrix multiplication
    # on the coefficients instead of split/multiply/concat
    component_logits, locs, scales, coeffs = params
    num_coeffs = num_channels * (num_channels - 1) // 2
    per_channel_locs = tf.split(locs, num_channels, axis=-1)
    remaining_coeffs = iter(tf.split(coeffs, num_coeffs, axis=-1))
    per_channel_values = tf.split(transformed_value, num_channels, axis=-1)

    for i in range(num_channels):
      # Insert an axis before the channel axis so the channel value
      # broadcasts against the per-component parameters.
      per_channel_values[i] = per_channel_values[i][..., tf.newaxis, :]
      for j in range(i):
        # Coefficients are consumed in (i, j) lexicographic order.
        per_channel_locs[i] = (
            per_channel_locs[i] +
            per_channel_values[j] * next(remaining_coeffs))
    locs = tf.concat(per_channel_locs, axis=-1)

  dist = self._make_mixture_dist(component_logits, locs, scales)
  return tf.reshape(dist.log_prob(value), output_batch_shape)
def bootstrap_results(self, init_state):
  """Create the initial kernel results for dual-averaging adaptation.

  Bootstraps the inner kernel, then builds zero-valued adaptation
  accumulators and a log shrinkage target for every step-size part.

  Args:
    init_state: Initial chain state. It is forwarded to
      `self.inner_kernel.bootstrap_results` and flattened with
      `tf.nest.flatten` so that state parts can be paired with step-size
      parts.

  Returns:
    A `DualAveragingStepSizeAdaptationResults` whose `step` is 0.

  Raises:
    ValueError: If `shrinkage_target` flattens to a number of parts that is
      neither 1 nor the number of step-size parts.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'dual_averaging_step_size_adaptation', 'bootstrap_results')
  with tf.name_scope(scope_name):
    inner_results = self.inner_kernel.bootstrap_results(init_state)
    step_size = self.step_size_getter_fn(inner_results)
    log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

    state_parts = tf.nest.flatten(init_state)
    step_size_parts = tf.nest.flatten(step_size)
    num_step_sizes = len(step_size_parts)

    user_shrinkage_target = self._parameters['shrinkage_target']
    if user_shrinkage_target is None:
      shrinkage_target_parts = [None] * num_step_sizes
    else:
      shrinkage_target_parts = tf.nest.flatten(user_shrinkage_target)
      num_targets = len(shrinkage_target_parts)
      if num_targets not in [1, num_step_sizes]:
        raise ValueError(
            '`shrinkage_target` should be a Tensor or list of tensors of '
            'same length as `step_size`. Found len(`step_size`) = {} and '
            'len(shrinkage_target) = {}'.format(num_step_sizes, num_targets))
      if num_targets < num_step_sizes:
        # A single target is shared by every step-size part.
        shrinkage_target_parts = shrinkage_target_parts * num_step_sizes

    dtype = dtype_util.common_dtype(step_size_parts, tf.float32)

    error_sum = []
    log_averaging_step = []
    log_shrinkage_target = []
    for state_part, step_size_part, shrinkage_target_part in zip(
        state_parts, step_size_parts, shrinkage_target_parts):
      # First reduce over the leading dims the state has in excess of the
      # step size, then over any dims where the shapes still differ, so the
      # accumulator has a shape compatible with the step size.
      rank_gap = (prefer_static.rank(state_part) -
                  prefer_static.rank(step_size_part))
      num_reduce_dims = prefer_static.minimum(
          prefer_static.rank(log_accept_prob), rank_gap)
      reduced_log_accept_prob = reduce_logmeanexp(
          log_accept_prob, axis=prefer_static.range(num_reduce_dims))
      differing_dims = get_differing_dims(
          reduced_log_accept_prob, step_size_part)
      reduced_log_accept_prob = reduce_logmeanexp(
          reduced_log_accept_prob, axis=differing_dims, keepdims=True)

      error_sum.append(tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
      log_averaging_step.append(tf.zeros_like(step_size_part, dtype=dtype))

      if shrinkage_target_part is not None:
        log_target = tf.math.log(tf.cast(shrinkage_target_part, dtype))
      else:
        # Default target: ten times the initial step size, in log space.
        log_target = float(np.log(10.)) + tf.math.log(step_size_part)
      log_shrinkage_target.append(log_target)

    params = self.parameters
    return DualAveragingStepSizeAdaptationResults(
        inner_results=inner_results,
        step=tf.constant(0, dtype=tf.int32),
        target_accept_prob=tf.cast(
            params['target_accept_prob'], log_accept_prob.dtype),
        log_shrinkage_target=log_shrinkage_target,
        exploration_shrinkage=tf.cast(params['exploration_shrinkage'], dtype),
        step_count_smoothing=tf.cast(params['step_count_smoothing'], dtype),
        decay_rate=tf.cast(params['decay_rate'], dtype),
        error_sum=error_sum,
        log_averaging_step=log_averaging_step,
        new_step_size=step_size)
def one_step(self, current_state, previous_kernel_results):
  """Run one uncalibrated HMC proposal step.

  Draws fresh normal momenta, integrates them together with the state using
  `leapfrog_impl.SimpleLeapfrogIntegrator`, and records the resulting target
  log-prob, its gradients, and the log acceptance correction.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current state.
    previous_kernel_results: Kernel results from the previous step. Supplies
      the cached target log-prob and gradients and, when parameters are
      stored in results, `step_size` and `num_leapfrog_steps`.

  Returns:
    next_state: Proposed state, list-like iff `current_state` is list-like.
    kernel_results: `previous_kernel_results` with
      `log_acceptance_correction`, `target_log_prob` and
      `grads_target_log_prob` replaced.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    # Parameters come from the results struct when it carries them.
    param_source = (
        previous_kernel_results if self._store_parameters_in_results
        else self)
    step_size = param_source.step_size
    num_leapfrog_steps = param_source.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    # One standard-normal momentum per state part, each with its own seed.
    momentum_parts = [
        tf.random.normal(
            shape=tf.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=self._seed_stream())
        for part in state_parts
    ]

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (proposed_momentum_parts, proposed_state_parts, proposed_target_log_prob,
     proposed_grad_parts) = integrator(
         momentum_parts, state_parts, target_log_prob,
         target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      proposed_state_parts = [
          tf.stop_gradient(part) for part in proposed_state_parts]

    independent_chain_ndims = prefer_static.rank(target_log_prob)
    log_acceptance_correction = _compute_log_acceptance_correction(
        momentum_parts, proposed_momentum_parts, independent_chain_ndims)

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=log_acceptance_correction,
        target_log_prob=proposed_target_log_prob,
        grads_target_log_prob=proposed_grad_parts,
    )

    # Hand back the same structure (list vs single `Tensor`) that came in.
    if mcmc_util.is_list_like(current_state):
      next_state = proposed_state_parts
    else:
      next_state = proposed_state_parts[0]
    return next_state, new_kernel_results
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are
  broadcast with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `fill_value` is a string other than
      `'constant_extension'`.
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `x_ref_min.shape[-1]` is not known statically.
    ValueError:  If `axis` is not a scalar is determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    # A string `fill_value` must be the single supported keyword; anything
    # else is converted to a scalar `Tensor` of the common dtype.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., D, nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = x_ref_max.shape = [..., nd]
    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table, it's the
    # 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    # The range check below only runs when both `axis` and `rank(y_ref)` have
    # static values.
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    # Batch shape of each argument: everything left of `[D, nd]` for `x`, left
    # of `[nd]` for the bounds, and left of `axis` for `y_ref`.
    x_batch_shape = ps.shape(x)[:-2]
    x_ref_min_batch_shape = ps.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = ps.shape(x_ref_max)[:-1]
    y_ref_batch_shape = ps.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape,
                   x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return the shape `batch_shape + [1] * n_singletons`.

      NOTE(review): presumably `_broadcast_with` adds zeros of this shape to
      its first argument to broadcast the batch dims -- confirm against its
      definition.
      """
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    # `y_ref` keeps all of its non-batch dims (table dims plus value dims) as
    # singletons in the broadcast shape.
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
def _copy(v):
  """Return two copies of `v` stacked along a new leading axis."""
  # Shape vector [2, 1, ..., 1] with one trailing singleton per dim of `v`;
  # multiplying by ones of that shape broadcasts `v` to `[2] + v.shape`.
  doubling_shape = ps.pad(
      [2], paddings=[[0, ps.rank(v)]], constant_values=1)
  doubling_ones = ps.ones(doubling_shape, dtype=v.dtype)
  return v * doubling_ones
def window_tune_nuts_sampling(target_log_prob,
                              prior_samples,
                              constraining_bijectors=None,
                              init_state=None,
                              num_samples=500,
                              nchains=4,
                              init_nchains=1,
                              target_accept_prob=.8,
                              max_tree_depth=9,
                              use_scaled_init=True,
                              tuning_window_schedule=(
                                  75, 25, 25, 25, 25, 25, 50),
                              use_wide_window_expanding_mode=True,
                              seed=None,
                              parallel_iterations=10,
                              jit_compile=True,
                              use_input_signature=True,
                              reduce_retracing=False):
  """Sample from a density with NUTS and an expanding window tuning scheme.

  This function implements a turnkey MCMC sampling routine using NUTS and an
  expanding window tuning strategy similar to Stan [1]. It learns a pre-
  conditioner that scales and rotates the target distribution using a series
  of expanding windows - either in number of samples (same as in Stan,
  use_wide_window_expanding_mode=False) or in number of batches/chains
  (use_wide_window_expanding_mode=True).

  Currently, the function uses `prior_samples` to initialize MCMC chains
  uniformly at random between -1 and 1 scaled by the prior standard deviation
  (i.e., [-prior_std, prior_std]). The scaling is ignored if `use_scaled_init`
  is set to False. Alternatively, user can input the `init_state` directly.

  Currently, the tuning and sampling routine is run in Python, with each block
  of the tuning epoch (window 1, 2, and 3 in Stan [1]) run with two
  @tf.function compiled functions. The user can control the compilation
  options using the kwargs `jit_compile`, `use_input_signature`, and
  `reduce_retracing`. Setting all to True would compile to XLA and potentially
  avoid the small overhead of function recompilation (note that it is not yet
  the case in XLA right now). It is not yet clear whether doing it this way is
  better than just wrapping the full inference routine in tf.function with
  XLA.

  Internally, this calls `_sample_posterior`, which assumes a real-valued
  target density function and takes a Tensor with shape=(batch * dimension)
  as input. The tuning routine is a memory-less (i.e., no warm-up samples are
  saved) MCMC sampling with number of samples specified by a list-like
  `tuning_window_schedule`.

  Args:
    target_log_prob: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    prior_samples: Nested structure of `Tensor`s, each of shape `[batches,
      latent_part_event_shape]` and should be sample from the prior. They are
      used to generate an initial chain position if `init_state` is not
      supplied.
    constraining_bijectors: `tfp.bijectors.Bijector` or list of
      `tfp.bijectors.Bijector`s. These bijectors use `forward` to map the
      state on the real space to the constrained state expected by
      `target_log_prob`.
    init_state: (Optional) `Tensor` or Python `list` of `Tensor`s representing
      the initial state(s) of the Markov chain(s).
    num_samples: Integer number of the Markov chain draws after tuning.
    nchains: Integer number of the Markov chains after tuning.
    init_nchains: Integer number of the Markov chains in the first phase of
      tuning.
    target_accept_prob: Floating point scalar `Tensor`. Target acceptance
      probability for step size adaptation.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. See
      `tfp.mcmc.NoUTurnSampler` for more details.
    use_scaled_init: Boolean. If `True`, generate initial state within [-1, 1]
      scaled by prior sample standard deviation in the unconstrained real
      space. This kwarg is ignored if `init_state` is not None.
    tuning_window_schedule: List-like sequence of integers that specify the
      tuning schedule. Each integer number specifies the number of MCMC
      samples within a single warm-up window. The first and the last window
      tunes the step size (a scalar) only, while the intermediate windows tune
      both step size and the pre-conditioner. Moreover, the intermediate
      windows double the number of samples taken: for example, the default
      schedule (75, 25, 25, 25, 25, 25, 50) actually means it will take
      (75, 25 * 2**0, 25 * 2**1, 25 * 2**2, 25 * 2**3, 25 * 2**4, 50) samples.
    use_wide_window_expanding_mode: Boolean. Default to `True` that we double
      the number of chains from the previous stage for the intermediate
      windows. See `tuning_window_schedule` kwarg for more details.
    seed: Python integer to seed the random number generator.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
      Note that if you set the seed to have deterministic output you should
      also set `parallel_iterations` to 1.
    jit_compile: kwarg pass to tf.function decorator. If True, the
      function is always compiled by XLA.
    use_input_signature: If True, generate an input_signature kwarg to pass to
      tf.function decorator.
    reduce_retracing: kwarg pass to tf.function decorator. When True,
      tf.function may generate fewer graphs that are less specialized on input
      shapes.

  Returns:
    posterior_samples: A `Tensor` or Python list of `Tensor`s representing the
      posterior MCMC samples after tuning. It has the same structure as
      `prior_samples` but with the leading shape being (num_samples * nchains)
    diagnostic: A list of `Tensor` representing the diagnostics from NUTS:
      `target_log_prob`, `leapfrogs_taken`, `has_divergence`, `energy`,
      `log_accept_ratio`, `reach_max_depth`, `step_size`.
    conditioning_bijector: A tfp bijector that scales and rotates the target
      density function in latent unconstrained space as determined by
      adaptation.

  Raises:
    AssertionError: If `target_log_prob(*prior_samples)` is not rank 1.
    ValueError: If a flattened state passed to the internal split helper is
      statically known to be a scalar.

  ### Examples

  Sampling from a multivariate Student-T distribution.

  ```python
  DTYPE = np.float32

  nd = 50
  concentration = 1.

  prior_dist = tfd.Sample(tfd.Normal(tf.constant(0., DTYPE), 100.), nd)

  mu = tf.cast(np.linspace(-100, 100, nd), dtype=DTYPE)
  sigma = tf.cast(np.exp(np.linspace(-1, 1.5, nd)), dtype=DTYPE)
  corr_tril = tfd.CholeskyLKJ(
      dimension=nd, concentration=concentration).sample()
  scale_tril = tf.linalg.matmul(tf.linalg.diag(sigma), corr_tril)
  target_dist = tfd.MultivariateStudentTLinearOperator(
      df=5., loc=mu, scale=tf.linalg.LinearOperatorLowerTriangular(scale_tril))

  target_log_prob = lambda *x: (
      prior_dist.log_prob(*x) + target_dist.log_prob(*x))

  (
      [mcmc_samples], diagnostic, conditioning_bijector
  ) = window_tune_nuts_sampling(target_log_prob, [prior_dist.sample(2000)])

  loc_conditioner, scale_conditioner = conditioning_bijector.trainable_variables

  _, ax = plt.subplots(1, 2, figsize=(10, 5))
  ax[0].plot(mu, loc_conditioner.numpy(), 'o', label='conditioner mean')
  ax[0].plot(mu, tf.reduce_mean(
      mcmc_samples, axis=[0, 1]), 'o', label='estimated mean')
  ax[0].legend()

  sigma_sim = target_dist._stddev()
  ax[1].plot(sigma_sim, scale_conditioner.numpy(), 'o', label='conditioner std')
  ax[1].plot(sigma_sim, tf.math.reduce_std(
      mcmc_samples, axis=[0, 1]), 'o', label='estimated std');
  ax[1].legend()

  ax[0].plot([min(mu), max(mu)], [min(mu), max(mu)])
  ax[1].plot([min(sigma_sim), max(sigma_sim)],
             [min(sigma_sim), max(sigma_sim)])
  ```

  #### References

  [1]: Stan Reference Manual.
  https://mc-stan.org/docs/2_23/reference-manual/hmc-algorithm-parameters.html#automatic-parameter-tuning
  """

  log_prob_val = target_log_prob(*prior_samples)
  log_prob_rank = ps.rank(log_prob_val)
  assert log_prob_rank == 1

  if constraining_bijectors is not None:
    target_log_prob_unconstrained = make_transformed_log_prob(
        target_log_prob,
        constraining_bijectors,
        direction='forward',
        enable_bijector_caching=False)
    # constrain to unconstrain
    inverse_transform = make_transform_fn(constraining_bijectors, 'inverse')
    # unconstrain to constrain
    forward_transform = make_transform_fn(constraining_bijectors, 'forward')
  else:
    target_log_prob_unconstrained = target_log_prob
    inverse_transform = lambda x: x
    forward_transform = lambda y: y

  prior_samples_unconstrained = inverse_transform(prior_samples)
  init_state_unconstrained = None

  # If the input to target_log_prob_fn is a nested structure of Tensors, we
  # flatten and concatenate them into a 1D vector so that it is easier to work
  # with in mass matrix adaptation.
  if tf.nest.is_nested(prior_samples_unconstrained):
    # The event shapes drive the flatten/split/reshape of the *unconstrained*
    # state, so they must be read from the unconstrained samples: a bijector
    # may change the event shape, in which case the constrained shapes would
    # yield the wrong split sizes.
    free_rv_event_shape = [
        x.shape[log_prob_rank:] for x in prior_samples_unconstrained
    ]
    flat_event_splits = [s.num_elements() for s in free_rv_event_shape]

    # TODO(b/158878248): replace the two function below with `tfb.Split`.
    def split_and_reshape(x):
      """Split a flat `[..., total_ndims]` vector back into event parts."""
      assertions = []
      message = 'Input must have at least one dimension.'
      static_rank = tensorshape_util.rank(x.shape)
      if static_rank is not None:
        if static_rank == 0:
          raise ValueError(message)
      else:
        # Rank is unknown statically; defer the check to run time.
        assertions.append(
            assert_util.assert_rank_at_least(x, 1, message=message))
      with tf.control_dependencies(assertions):
        x = tf.nest.pack_sequence_as(
            free_rv_event_shape, tf.split(x, flat_event_splits, axis=-1))

        def _reshape_map_part(part, event_shape):
          static_rank = tf.get_static_value(
              ps.rank_from_shape(event_shape))
          if static_rank == 1:
            # Vector events already have the right shape after the split.
            return part
          new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
          return tf.reshape(part, ps.cast(new_shape, tf.int32))

        x = tf.nest.map_structure(_reshape_map_part, x, free_rv_event_shape)
      return x

    def concat_list_event(x):
      """Flatten each part's event dims and concatenate on the last axis."""

      def handle_part(x, shape):
        if len(shape) == 0:  # pylint: disable=g-explicit-length-test
          # Scalar events gain a trailing axis of size one.
          return x[..., tf.newaxis]
        return tf.reshape(x, list(x.shape)[:-len(shape)] + [-1])

      flat_parts = [
          handle_part(v, s) for v, s in zip(x, free_rv_event_shape)
      ]
      return tf.concat(flat_parts, axis=-1)

    def target_log_prob_unconstrained_concated(x):
      x = split_and_reshape(x)
      return target_log_prob_unconstrained(*x)

    prior_samples_unconstrained_concated = concat_list_event(
        prior_samples_unconstrained)
    if init_state is not None:
      init_state_unconstrained = concat_list_event(
          inverse_transform(init_state))
  else:
    target_log_prob_unconstrained_concated = target_log_prob_unconstrained
    prior_samples_unconstrained_concated = prior_samples_unconstrained
    split_and_reshape = lambda x: x
    if init_state is not None:
      init_state_unconstrained = inverse_transform(init_state)

  nuts_samples, diagnostic, conditioning_bijector = _sample_posterior(
      target_log_prob_unconstrained_concated,
      prior_samples_unconstrained_concated,
      init_state=init_state_unconstrained,
      num_samples=num_samples,
      nchains=nchains,
      init_nchains=init_nchains,
      target_accept_prob=target_accept_prob,
      max_tree_depth=max_tree_depth,
      use_scaled_init=use_scaled_init,
      tuning_window_schedule=tuning_window_schedule,
      use_wide_window_expanding_mode=use_wide_window_expanding_mode,
      seed=seed,
      parallel_iterations=parallel_iterations,
      jit_compile=jit_compile,
      use_input_signature=use_input_signature,
      reduce_retracing=reduce_retracing)
  # Undo the flattening, then map back to the constrained space.
  return forward_transform(
      split_and_reshape(nuts_samples)), diagnostic, conditioning_bijector
def reduce_sum(x, m, shard_axes):
  """Sum `x` over axes `[log_prob_rank, rank(m))`, then across shards.

  Args:
    x: `Tensor` to reduce.
    m: `Tensor` whose rank sets the upper end of the reduced axis range.
    shard_axes: Axis names forwarded to `distribute_lib.psum`, or `None` to
      skip the cross-shard sum.

  Returns:
    The reduced `Tensor`.
  """
  summed_axes = ps.range(log_prob_rank, ps.rank(m))
  local_sum = tf.reduce_sum(x, axis=summed_axes)
  if shard_axes is None:
    return local_sum
  return distribute_lib.psum(local_sum, shard_axes)
def _sample_posterior(target_log_prob_unconstrained,
                      prior_samples_unconstrained,
                      init_state=None,
                      num_samples=500,
                      nchains=4,
                      init_nchains=1,
                      target_accept_prob=.8,
                      max_tree_depth=9,
                      use_scaled_init=True,
                      tuning_window_schedule=(75, 25, 25, 25, 25, 25, 50),
                      use_wide_window_expanding_mode=True,
                      seed=None,
                      parallel_iterations=10,
                      jit_compile=True,
                      use_input_signature=False,
                      reduce_retracing=False):
  """MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme.

  Runs three kinds of warm-up window (step size only, step size plus diagonal
  pre-conditioner, step size only) and then draws `num_samples` NUTS samples
  from the tuned kernel.

  Args:
    target_log_prob_unconstrained: Callable used as the `target_log_prob_fn`
      of the NUTS kernel; takes a `[batch, total_ndims]` `Tensor`.
    prior_samples_unconstrained: Rank-2 `Tensor` of prior samples in the
      unconstrained space. Its last dimension defines `total_ndims` and its
      mean/std are used to scale the random initial state.
    init_state: Optional `[chains, total_ndims]` initial state. If `None`, an
      initial state is drawn uniformly in `[-1, 1]` for `init_nchains` chains.
    num_samples: Number of draws per chain after tuning.
    nchains: Number of chains used for the last tuning window and sampling.
    init_nchains: Number of chains used when `init_state` is generated.
    target_accept_prob: Target acceptance probability for the dual-averaging
      step size adaptation.
    max_tree_depth: Maximum NUTS tree depth.
    use_scaled_init: If `True`, scale and shift the random initial state by
      the prior sample std and mean.
    tuning_window_schedule: Sequence of window lengths; the first and last
      entries are step-size-only windows, the entries in between tune the
      pre-conditioner as well.
    use_wide_window_expanding_mode: If `True`, double the number of chains
      after each intermediate window; otherwise double the number of steps.
    seed: Seed for the `SeedStream` used by this routine.
    parallel_iterations: Forwarded to `tf.while_loop`, the NUTS kernel and
      `sample_chain`.
    jit_compile: Forwarded to the `tf.function` decorators.
    use_input_signature: If `True`, pass explicit input signatures to the
      tuning `tf.function`s.
    reduce_retracing: Forwarded to the tuning `tf.function` decorators.

  Returns:
    nuts_samples: Samples returned by `tfp.mcmc.sample_chain`.
    diagnostic: Tuple of traced NUTS quantities (see `trace_fn`).
    conditioning_bijector: The shift-and-scale bijector whose variables were
      updated during tuning.
  """

  seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
  rv_rank = ps.rank(prior_samples_unconstrained)
  assert rv_rank == 2
  total_ndims = ps.shape(prior_samples_unconstrained)[-1]
  dtype = prior_samples_unconstrained.dtype

  # TODO(b/158878248): explore option to for user to control the
  # parameterization of conditioning_bijector.
  # TODO(b/158878248): right now, we use 2 tf.Variable to initialize a scaling
  # bijector, and update the underlying value at the end of each warmup window.
  # It might be faster to rewrite it into a functional style (with a small
  # additional compilation cost).
  loc_conditioner = tf.Variable(
      tf.zeros([total_ndims], dtype=dtype), name='loc_conditioner')
  scale_conditioner = tf.Variable(
      tf.ones([total_ndims], dtype=dtype), name='scale_conditioner')

  # Start with Identity Covariance matrix
  scale = tf.linalg.LinearOperatorDiag(
      diag=scale_conditioner,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
  conditioning_bijector = tfb.Shift(shift=loc_conditioner)(
      tfb.ScaleMatvecLinearOperator(scale))

  if init_state is None:
    # Start at uniform random [-1, 1] around the prior mean in latent space
    init_state_uniform = tf.random.uniform(
        [init_nchains, total_ndims], dtype=dtype, seed=seed_stream()) * 2. - 1.
    if use_scaled_init:
      prior_z_mean = tf.math.reduce_mean(prior_samples_unconstrained, axis=0)
      prior_z_std = tf.math.reduce_std(prior_samples_unconstrained, axis=0)
      init_state = init_state_uniform * prior_z_std + prior_z_mean
    else:
      init_state = init_state_uniform

  # The denominator is the O(N^0.25) scaling from Beskos et al. 2010. The
  # numerator corresponds to the trajectory length. Candidate values include:
  # 1, 1.57 (pi / 2). We use a conservatively small value here (0.25).
  init_step_size = tf.constant(0.25 / (total_ndims**0.25), dtype=dtype)

  hmc_inner = tfp.mcmc.TransformedTransitionKernel(
      tfp.mcmc.NoUTurnSampler(
          target_log_prob_fn=target_log_prob_unconstrained,
          step_size=init_step_size,
          max_tree_depth=max_tree_depth,
          parallel_iterations=parallel_iterations,
      ), conditioning_bijector)

  hmc_step_size_tuning = tfp.mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=hmc_inner,
      num_adaptation_steps=max(tuning_window_schedule),
      target_accept_prob=target_accept_prob)

  if use_input_signature:
    input_signature = [
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
    ]
  else:
    input_signature = None

  # TODO(b/158878248): move the nested function definitions to module top-level.
  @tf.function(input_signature=input_signature,
               autograph=False,
               jit_compile=jit_compile,
               reduce_retracing=reduce_retracing)
  def fast_adaptation_interval(num_steps, previous_state):
    """Step size only adaptation interval.

    This corresponds to window 1 and window 3 in the Stan HMC parameter
    tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_state: Initial state of the tuning interval.

    Returns:
      last_state: Last state of the tuning interval.
      last_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
    """

    def body_fn(i, state, pkr):
      next_state, next_pkr = hmc_step_size_tuning.one_step(state, pkr)
      return i + 1, next_state, next_pkr

    current_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
    _, last_state, last_pkr = tf.while_loop(
        lambda i, *_: i < num_steps,
        body_fn,
        loop_vars=(0, previous_state, current_pkr),
        maximum_iterations=num_steps,
        parallel_iterations=parallel_iterations)
    return last_state, last_pkr

  def body_fn_window2(i, previous_state, previous_pkr, previous_mean,
                      previous_cov):
    """Take one MCMC step and update the step size and mass matrix."""
    next_state, next_pkr = hmc_step_size_tuning.one_step(
        previous_state, previous_pkr)
    # Running (Welford-style) update of the mean and of the un-normalised
    # covariance accumulator.
    n_next = i + 1
    delta_pre = previous_state - previous_mean
    next_mean = previous_mean + delta_pre / tf.cast(
        n_next, delta_pre.dtype)
    delta_post = previous_state - next_mean
    delta_cov = tf.expand_dims(delta_post, -1) * tf.expand_dims(
        delta_pre, -2)
    next_cov = previous_cov + delta_cov

    # Keep static shapes stable across loop iterations.
    next_mean.set_shape(previous_mean.shape)
    next_cov.set_shape(previous_cov.shape)
    return n_next, next_state, next_pkr, next_mean, next_cov

  if use_input_signature:
    input_signature = [
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
        tf.TensorSpec(shape=[None, total_ndims, total_ndims], dtype=dtype),
    ]
  else:
    input_signature = None

  # TODO(b/158878248): move the nested function definitions to module top-level.
  @tf.function(input_signature=input_signature,
               autograph=False,
               jit_compile=jit_compile,
               reduce_retracing=reduce_retracing)
  def slow_adaptation_interval(num_steps, previous_n, previous_state,
                               previous_mean, previous_cov):
    """Interval that tunes the mass matrix and step size simultaneously.

    This corresponds to window 2 in the Stan HMC parameter tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_n: Previous number of tuning steps we have run.
      previous_state: Initial state of the tuning interval.
      previous_mean: Current estimated posterior mean.
      previous_cov: Current estimated posterior covariance matrix.

    Returns:
      total_n: Total number of tuning steps we have run.
      next_state: Last state of the tuning interval.
      next_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
      next_mean: estimated posterior mean after tuning.
      next_cov: estimated posterior covariance matrix after tuning.
    """
    previous_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
    total_n, next_state, next_pkr, next_mean, next_cov = tf.while_loop(
        lambda i, *_: i < num_steps + previous_n,
        body_fn_window2,
        loop_vars=(previous_n, previous_state, previous_pkr, previous_mean,
                   previous_cov),
        maximum_iterations=num_steps,
        parallel_iterations=parallel_iterations)
    float_n = tf.cast(total_n, next_cov.dtype)
    cov = next_cov / (float_n - 1.)

    # Regularization
    scaled_cov = (float_n / (float_n + 5.)) * cov
    shrinkage = 1e-3 * (5. / (float_n + 5.))
    next_cov = scaled_cov + shrinkage

    return total_n, next_state, next_pkr, next_mean, next_cov

  def trace_fn(_, pkr):
    return (
        pkr.inner_results.target_log_prob,
        pkr.inner_results.leapfrogs_taken,
        pkr.inner_results.has_divergence,
        pkr.inner_results.energy,
        pkr.inner_results.log_accept_ratio,
        pkr.inner_results.reach_max_depth,
        pkr.inner_results.step_size,
    )

  @tf.function(autograph=False, jit_compile=jit_compile)
  def run_chain(num_results, current_state, previous_kernel_results):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=0,
        current_state=current_state,
        previous_kernel_results=previous_kernel_results,
        kernel=hmc_inner,
        trace_fn=trace_fn,
        parallel_iterations=parallel_iterations,
        seed=seed_stream())

  # Main sampling with tuning routine.
  num_steps_tuning_window_schedule0 = tuning_window_schedule[0]

  # Window 1 to tune step size
  logging.info('Tuning Window 1...')
  next_state, _ = fast_adaptation_interval(num_steps_tuning_window_schedule0,
                                           init_state)

  next_mean = tf.zeros_like(init_state)
  next_cov = tf.zeros(
      ps.concat([ps.shape(init_state), ps.shape(init_state)[-1:]], axis=-1),
      dtype=dtype)

  mean_updater = tf.zeros([total_ndims], dtype=dtype)
  diag_updater = tf.ones([total_ndims], dtype=dtype)

  # Window 2 to tune mass matrix.
  total_n = 0
  for i, num_steps in enumerate(tuning_window_schedule[1:-1]):
    logging.info('Tuning Window 2 - %s...', i)
    if not use_wide_window_expanding_mode:
      num_steps = num_steps * 2**i
    # Push the latest pre-conditioner estimate into the bijector variables
    # before running the next interval.
    with tf.control_dependencies([
        loc_conditioner.assign(mean_updater, read_value=False),
        scale_conditioner.assign(diag_updater, read_value=False)
    ]):
      (total_n, next_state_, _, next_mean_,
       next_cov_) = slow_adaptation_interval(num_steps, total_n, next_state,
                                             next_mean, next_cov)
      diag_part = tf.linalg.diag_part(next_cov_)
      if ps.rank(next_state) > 1:
        # Pool the per-chain estimates into a single pre-conditioner.
        mean_updater = tf.reduce_mean(next_mean_, axis=0)
        diag_updater = tf.math.sqrt(tf.reduce_mean(diag_part, axis=0))
      else:
        mean_updater = next_mean_
        diag_updater = tf.math.sqrt(diag_part)

      if use_wide_window_expanding_mode:
        # Double the number of chains for the next window.
        next_mean = tf.concat([next_mean_, next_mean_], axis=0)
        next_cov = tf.concat([next_cov_, next_cov_], axis=0)
        next_state = tf.concat([next_state_, next_state_], axis=0)
      else:
        next_mean, next_cov, next_state = next_mean_, next_cov_, next_state_

  num_steps_tuning_window_schedule3 = tuning_window_schedule[-1]
  num_batches = ps.size0(next_state)
  if nchains > num_batches:
    # Ceiling division guarantees at least `nchains` rows after repeating.
    # The former `(nchains + 1) // num_batches` came up short whenever
    # `num_batches >= 3` and `nchains` was not one less than, or equal to, a
    # multiple of `num_batches` (e.g. nchains=7, num_batches=3 gave 6 chains).
    num_repeats = (nchains + num_batches - 1) // num_batches
    final_init_state = tf.repeat(
        next_state, num_repeats, axis=0)[:nchains]
  else:
    final_init_state = next_state[:nchains]

  with tf.control_dependencies([
      loc_conditioner.assign(mean_updater, read_value=False),
      scale_conditioner.assign(diag_updater, read_value=False)
  ]):
    # Window 3 step size tuning
    logging.info('Tuning Window 3...')
    final_tuned_state, final_pkr = fast_adaptation_interval(
        num_steps_tuning_window_schedule3, final_init_state)

    # Final samples
    logging.info('Sampling...')
    nuts_samples, diagnostic = run_chain(num_samples, final_tuned_state,
                                         final_pkr.inner_results)

  return nuts_samples, diagnostic, conditioning_bijector
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask,
                                    swap_tensor_fn):
  """Builds kernel results that are consistent with the post-swap states.

  Fields that cannot be updated unambiguously after a swap
  (`proposed_results`, `proposed_state`) are overwritten with `NaN`. The
  accepted results' `target_log_prob` (and, for HMC, `grads_target_log_prob`)
  are swapped and then rescaled so that each entry is tempered by the inverse
  temperature of the index it ends up at.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      `inner_kernel.one_step` before swapping. Must be
      `MetropolisHastingsKernelResults`.
    inverse_temperatures: `Tensor` of inverse temperatures.
    swapped_inverse_temperatures: `Tensor` of inverse temperatures, permuted
      by swaps. Its `dtype` is used for the constants created here.
    is_swap_accepted_mask: Shape `[num_replica] + batch_shape` boolean
      `Tensor` telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      `swap_tensor_fn(x)` performs swaps where they are accepted, and does
      not swap otherwise.

  Returns:
    new_replica_results: Same type as `pre_swap_replica_results`.

  Raises:
    NotImplementedError: If type of [nested] results is not handled.
  """
  results_are_mh = isinstance(
      pre_swap_replica_results,
      metropolis_hastings.MetropolisHastingsKernelResults)
  if not results_are_mh:
    # TODO(b/143702650) Handle other kernels.
    message = (
        '`pre_swap_replica_results` currently works only for '
        'MetropolisHastingsKernelResults. Found: {}. '
        'Please file a request with the TensorFlow Probability team.'
    ).format(type(pre_swap_replica_results))
    raise NotImplementedError(message)

  dtype = swapped_inverse_temperatures.dtype

  # The proposal fields have no unambiguous post-swap value, and nothing
  # downstream needs one, so blank them out.
  results = pre_swap_replica_results._replace(
      proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
      proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype))

  target_rank = ps.rank(results.log_accept_ratio)

  # Once `swap_tensor_fn` has moved a value, that value is still scaled by
  # the inverse temperature of the slot it came from. Multiplying by this
  # ratio rescales it to the inverse temperature of the slot it now occupies;
  # entries whose swap was rejected are multiplied by one.
  raw_ratio = inverse_temperatures / swapped_inverse_temperatures
  expanded_ratio = mcmc_util.left_justified_expand_dims_to(
      raw_ratio, target_rank)
  unit_ratio = tf.convert_to_tensor(1.0, dtype=dtype)
  retemper_ratio = tf.where(is_swap_accepted_mask, expanded_ratio, unit_ratio)

  def _swap_then_retemper(value):
    """Swaps every part of `value`, then rescales it by `retemper_ratio`."""
    parts, is_multipart = mcmc_util.prepare_state_parts(value)
    ratio = mcmc_util.left_justified_expand_dims_like(
        retemper_ratio, parts[0])
    retempered = [swap_tensor_fn(part) * ratio for part in parts]
    return retempered if is_multipart else retempered[0]

  accepted = results.accepted_results
  if isinstance(accepted,
                hmc.UncalibratedHamiltonianMonteCarloKernelResults):
    updated_accepted = accepted._replace(
        target_log_prob=_swap_then_retemper(accepted.target_log_prob),
        grads_target_log_prob=_swap_then_retemper(
            accepted.grads_target_log_prob))
  elif isinstance(accepted,
                  random_walk_metropolis.UncalibratedRandomWalkResults):
    updated_accepted = accepted._replace(
        target_log_prob=_swap_then_retemper(accepted.target_log_prob))
  else:
    # TODO(b/143702650) Handle other kernels.
    raise NotImplementedError(
        'Only HMC and RWMH Kernels are handled at this time. Please file a '
        'request with the TensorFlow Probability team.')

  return results._replace(accepted_results=updated_accepted)
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
  """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axis hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as the `x`, and rank equal to
      `rank(x) - len(sample_axis) + 2 * len(event_axis)`.

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

  with tf.name_scope(name or 'covariance'):
    x = tf.convert_to_tensor(x, name='x')
    # Covariance *only* uses the centered versions of x (and y).
    x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

    if y is None:
      # `x` is already centered, so it can be reused directly.
      y = x
    else:
      y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
      # If x and y have different shape, sample_axis and event_axis will likely
      # be wrong for one of them!
      tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
      y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

    if event_axis is None:
      # Scalar events: an elementwise product averaged over the sample axes.
      return tf.reduce_mean(
          x * tf.math.conj(y), axis=sample_axis, keepdims=keepdims)

    if sample_axis is None:
      raise ValueError(
          'sample_axis was None, which means all axis hold events, and this '
          'overlaps with event_axis ({})'.format(event_axis))

    event_axis = _make_positive_axis(event_axis, ps.rank(x))
    sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

    # If we get lucky and axis is statically defined, we can do some checks.
    if _is_list_like(event_axis) and _is_list_like(sample_axis):
      event_axis = tuple(map(int, event_axis))
      sample_axis = tuple(map(int, sample_axis))
      if set(event_axis).intersection(sample_axis):
        raise ValueError(
            'sample_axis ({}) and event_axis ({}) overlapped'.format(
                sample_axis, event_axis))
      if (np.diff(np.array(sorted(event_axis))) > 1).any():
        raise ValueError(
            'event_axis must be contiguous. Found: {}'.format(event_axis))
      # Every axis that is neither a sample nor an event axis is a batch axis.
      batch_axis = sorted(
          set(range(tensorshape_util.rank(x.shape))).difference(
              sample_axis + event_axis))
    else:
      batch_axis = ps.setdiff1d(
          ps.range(0, ps.rank(x)), ps.concat((sample_axis, event_axis), 0))

    event_axis = ps.cast(event_axis, dtype=tf.int32)
    sample_axis = ps.cast(sample_axis, dtype=tf.int32)
    batch_axis = ps.cast(batch_axis, dtype=tf.int32)

    # Permute x/y until shape = B + E + S
    perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
    x_permed = tf.transpose(a=x, perm=perm_for_xy)
    y_permed = tf.transpose(a=y, perm=perm_for_xy)

    batch_ndims = ps.size(batch_axis)
    batch_shape = ps.shape(x_permed)[:batch_ndims]
    event_ndims = ps.size(event_axis)
    event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
    sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
    sample_ndims = ps.size(sample_shape)
    n_samples = ps.reduce_prod(sample_shape)
    n_events = ps.reduce_prod(event_shape)

    # Flatten the event axes into one dim and the sample axes into another,
    # in a single reshape, giving shape = batch_shape + [n_events, n_samples].
    flat_shape = ps.concat((batch_shape, [n_events], [n_samples]), 0)
    x_permed_flat = tf.reshape(x_permed, flat_shape)
    y_permed_flat = tf.reshape(y_permed, flat_shape)

    # After matmul, cov.shape = batch_shape + [n_events, n_events]
    cov = tf.matmul(
        x_permed_flat, y_permed_flat, adjoint_b=True) / ps.cast(
            n_samples, x.dtype)

    # Insert some singletons to make
    #   cov.shape = batch_shape + event_shape**2 + [1,...,1]
    # This is just like x_permed.shape, except the sample_axis is all 1's, and
    # the [n_events] became event_shape**2.
    cov = tf.reshape(
        cov,
        ps.concat(
            (
                batch_shape,
                # event_shape**2 used here because it is the same length as
                # event_shape, and has the same number of elements as one
                # batch of covariance.
                event_shape**2,
                ps.ones([sample_ndims], tf.int32)),
            0))
    # Permuting by the argsort inverts the permutation, making
    # cov.shape have ones in the position where there were samples, and
    # [n_events * n_events] in the event position.
    cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

    # Now expand event_shape**2 into event_shape + event_shape.
    # We here use (for the first time) the fact that we require event_axis to be
    # contiguous.
    e_start = event_axis[0]
    e_len = 1 + event_axis[-1] - event_axis[0]
    cov = tf.reshape(
        cov,
        ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                   ps.shape(cov)[e_start + e_len:]), 0))

    # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
    # require our axis args to be constants.
    if not keepdims:
      # Sample axes located after the event block moved right by `e_len` when
      # the event block was doubled above.
      squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                              sample_axis + e_len)
      cov = _squeeze(cov, axis=squeeze_axis)

    return cov
def left_justified_expand_dims_like(x, reference, name=None):
  """Right pads `x` with `rank(reference) - rank(x)` ones.

  Thin wrapper that forwards `x` and the rank of `reference` to
  `left_justified_expand_dims_to`.

  Args:
    x: Value whose shape is to be padded.
    reference: Value whose rank is the target rank.
    name: Optional name scope for the ops created here.
      Default value: `None` (i.e., `'left_justified_expand_dims_like'`).

  Returns:
    Whatever `left_justified_expand_dims_to(x, rank(reference))` returns.
  """
  scope_name = name or 'left_justified_expand_dims_like'
  with tf.name_scope(scope_name):
    target_rank = ps.rank(reference)
    return left_justified_expand_dims_to(x, target_rank)
class MarkovChainBijectorTest(test_util.TestCase):
  """Checks the default event-space bijector of `tfd.MarkovChain`."""

  # NOTE: the lambdas below are kept verbatim; in particular the argument
  # names of the lambdas inside `model={...}` must not be renamed.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      dict(testcase_name='deterministic_prior',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
      dict(testcase_name='deterministic_transition',
           prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='fully_deterministic',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='mvn_diag',
           prior_fn=(
               lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                  scale_diag=[1.])),
           transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
      dict(testcase_name='docstring_dirichlet',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.Dirichlet([1., 1.])}),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               {
                   'probs': tfd.MultivariateNormalDiag(
                       loc=x['probs'], scale_diag=[0.1, 0.1])
               },
               batch_ndims=ps.rank(x['probs']))),
      dict(testcase_name='uniform_step',
           prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
           transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
      dict(testcase_name='joint_distribution',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=2,
               model={
                   'a': tfd.Gamma(tf.zeros([5]), 1.),
                   'b': lambda a: (
                       tfb.Reshape(event_shape_in=[4, 3],
                                   event_shape_out=[2, 3, 2])
                       (tfd.Independent(
                           tfd.Normal(
                               loc=tf.zeros([5, 4, 3]),
                               scale=a[..., tf.newaxis, tf.newaxis]),
                           reinterpreted_batch_ndims=2)))
               }),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=ps.rank_from_shape(x['a'].shape),
               model={
                   'a': tfd.Normal(loc=x['a'], scale=1.),
                   'b': lambda a: tfd.Deterministic(
                       x['b'] + a[
                           ..., tf.newaxis, tf.newaxis, tf.newaxis])
               })),
      dict(testcase_name='nested_chain',
           prior_fn=lambda: tfd.MarkovChain(
               initial_state_prior=tfb.Split(2)
               (tfd.MultivariateNormalDiag(0., [1., 2.])),
               transition_fn=lambda _, x: tfb.Split(2)
               (tfd.MultivariateNormalDiag(x[0], [1., 2.])),
               num_steps=6),
           transition_fn=(
               lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                   [
                       tfd.MultivariateNormalDiag(x[0], [1.]),
                       tfd.MultivariateNormalDiag(x[1], [1.])
                   ],
                   batch_ndims=ps.rank(x[0])))))
  # pylint: enable=g-long-lambda
  def test_default_bijector(self, prior_fn, transition_fn):

    def _uncached(structure):
      # Fresh identity copies keep the bijector cache from short-circuiting
      # the computation under test.
      return tf.nest.map_structure(tf.identity, structure)

    chain = tfd.MarkovChain(
        initial_state_prior=prior_fn(),
        transition_fn=transition_fn,
        num_steps=7)
    sample_seed = test_util.test_seed()
    constrained = self.evaluate(chain.sample(seed=sample_seed))
    bijector = chain.experimental_default_event_space_bijector()

    # The bijector must report the same batch shape as the chain.
    self.assertAllEqual(chain.batch_shape_tensor(),
                        bijector.experimental_batch_shape_tensor())

    # inverse followed by forward must reproduce the sample.
    unconstrained = bijector.inverse(constrained)
    roundtrip = bijector.forward(_uncached(unconstrained))
    self.assertAllCloseNested(constrained, roundtrip)

    chain_event_ndims = tf.nest.map_structure(
        ps.rank_from_shape, chain.event_shape_tensor())
    self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                              chain_event_ndims)

    # The inverse and forward log-det-Jacobians must be negatives of each
    # other.
    ildj = bijector.inverse_log_det_jacobian(
        _uncached(constrained), event_ndims=chain_event_ndims)
    if not bijector.is_constant_jacobian:
      self.assertAllEqual(ildj.shape, chain.batch_shape)
    fldj = bijector.forward_log_det_jacobian(
        _uncached(unconstrained),
        event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
    self.assertAllClose(ildj, -fldj)

    # Verify that event shapes are passed through and flattened/unflattened
    # correctly.
    def _trailing_shape(t, nd):
      return t.shape[ps.rank(t) - nd:]

    static_inverse_shapes = bijector.inverse_event_shape(chain.event_shape)
    unconstrained_event_shapes = tf.nest.map_structure(
        _trailing_shape, unconstrained, bijector.forward_min_event_ndims)
    self.assertAllEqualNested(static_inverse_shapes,
                              unconstrained_event_shapes)
    static_forward_shapes = bijector.forward_event_shape(
        static_inverse_shapes)
    self.assertAllEqualNested(static_forward_shapes, chain.event_shape)

    # Verify that the outputs of other methods have the correct structure.
    dynamic_inverse_shapes = bijector.inverse_event_shape_tensor(
        chain.event_shape_tensor())
    self.assertAllEqualNested(dynamic_inverse_shapes,
                              unconstrained_event_shapes)
    dynamic_forward_shapes = bijector.forward_event_shape_tensor(
        dynamic_inverse_shapes)
    self.assertAllEqualNested(dynamic_forward_shapes,
                              chain.event_shape_tensor())
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  The inner kernel is stepped once on every replica, a swap permutation is
  proposed, each proposed swap is accepted or rejected per batch member, and
  the accepted swaps are applied to the replica states.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only used here to decide
      whether a list or a single `Tensor` is returned.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling. It is sanitized and
      split three ways: for the inner kernel, for the swap proposal, and for
      the uniform draws used to accept/reject swaps.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    TypeError: If calling `make_kernel_fn` raises a `TypeError` whose message
      mentions 'argument' (re-raised with a hint that `make_kernel_fn` no
      longer receives a `seed`). Other `TypeError`s propagate unchanged.
  """

  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Only argument-related TypeErrors are translated into the hint below;
      # anything else is re-raised untouched.
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
          'argument. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        seed=inner_seed)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Axis 0 indexes replicas; the remaining axes are the batch.
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    # From here on `inverse_temperatures` has the full replica-and-batch
    # shape; the value stored in the returned results is the original one
    # taken from `previous_kernel_results`.
    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Backwards-compatibility path for proposal functions that predate the
      # `step_count` argument: warn, then retry without it.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # `null_swaps` is the identity permutation `range(num_replica)`, expanded
    # to the rank of `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs for use in the swap acceptance ratio.
    if self.tempered_log_prob_fn is None:
      # Efficient way of re-evaluating target_log_prob_fn on the
      # pre_swap_replica_states.
      untempered_energy_ignoring_ulp = (
          # Since untempered_log_prob_fn is None, we may assume
          # inverse_temperatures > 0 (else the target is improper).
          pre_swap_replica_target_log_prob / inverse_temperatures)
    else:
      # The untempered_log_prob_fn does not factor into the acceptance ratio.
      # Proof: Suppose the tempered target is
      #   p_k(x) = f(x)^{beta_k} g(x),
      # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
      # a 1 <--> 2 swap is...
      #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
      # which depends only on f(x), since terms involving g(x) cancel.
      untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
          *pre_swap_replica_states)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    # Note: The untempered_log_prob_fn (if provided) is not included in
    # untempered_pre_swap_replica_target_log_prob, and hence does not factor
    # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
    energy_diff = (
        untempered_energy_ignoring_ulp -
        mcmc_util.index_remapping_gather(
            untempered_energy_ignoring_ulp,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, the entries of log_accept_ratio at i and j are
    # equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Any non-finite ratio (NaN or +/-inf) is replaced by -inf.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype, seed=logu_seed))
    # Both members of a swap pair read the draw stored at the smaller of
    # their two indices.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Takes the swapped value where the swap was accepted, else keeps `x`.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)
    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Only replica 0 of each state part is exposed as the chain state.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _set_swapped_fields_to_nan(
        _swap_log_prob_and_maybe_grads(
            pre_swap_replica_results,
            post_swap_replica_states,
            inner_kernel))

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      # Mirror the caller's non-list input by returning a single `Tensor`.
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case its a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=seed,
    )

    return states, post_swap_kernel_results
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Runs one doubling iteration: draws a random direction per batch member,
  builds a sub-tree of `2**iter_` steps from the chosen end of the current
  trajectory, decides whether the sub-tree's candidate replaces the current
  candidate, and updates both trajectory ends and the loop meta state.

  Args:
    step_size: Iterable of step sizes, zipped one-to-one with the parts of
      `initial_step_state.state`. Negated for batch members whose drawn
      direction is `False`.
    momentum_state_memory: Passed through unchanged to `_build_sub_tree`.
    current_step_meta_info: Per-step info; its `init_energy` determines the
      batch shape used for the random draws. Also passed to
      `_build_sub_tree`.
    iter_: Iteration counter; the sub-tree is built with `1 << iter_` steps.
    initial_step_state: Nested structure with `state` and `momentum` fields
      whose leaves hold the two trajectory ends at indices 0 and 1 of their
      leading axis (presumably left/right -- TODO confirm).
    initial_step_metastate: `TreeDoublingMetaState`-like structure providing
      `candidate_state`, `continue_tree`, `not_divergence`, `is_accepted`,
      `momentum_sum`, `energy_diff_sum` and `leapfrog_count`.

  Returns:
    A tuple `(iter_ + 1, new_step_state, new_step_metastate)`, where
    `new_step_state` has the structure of `initial_step_state` and
    `new_step_metastate` is a `TreeDoublingMetaState`.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(current_step_meta_info.init_energy)
    # One boolean direction per batch member (uniform over {0, 1}).
    # NOTE: this is the first of two `self._seed_stream()` draws in this
    # method; their order must be preserved.
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # Start the new sub-tree from end 1 where `direction` is True, else from
    # end 0.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(
            direction, prefer_static.rank(state))
        for state in tree_start_states.state]

    # The sign of each step size follows the drawn direction.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      # Weights are combined with `log_add_exp`, so the acceptance threshold
      # is a plain difference of the two weights.
      weight_sum = log_add_exp(
          tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are added directly, so the threshold is the log of their
      # ratio.
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # A NaN threshold is replaced by 0.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # Second `self._seed_stream()` draw: the random value compared against
    # the threshold.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    # The new candidate is only taken if the sub-tree also finished with its
    # continue flag still set.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        # NOTE(review): the mask for `energy` is expanded to the rank of
        # `target`, not of `energy` -- presumably both share the batch
        # shape; confirm.
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Copy the static shapes of the previous candidate onto the new one.
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      new_candidate_state_temp.set_shape(old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      new_candidate_grad_temp.set_shape(old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    # `tree_otherend_states` is the end the sub-tree did NOT start from (the
    # complement of `tree_start_states`).
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[0], v[1]), initial_step_state)

    # Re-stack the two ends: where `direction` is True the sub-tree's final
    # state becomes index 1 and the untouched end stays at index 0; otherwise
    # the roles are reversed.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        tf.stack([  # pylint: disable=g-complex-comprehension
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                r, l),
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                l, r),
        ], axis=0)
        for l, r in zip(tf.nest.flatten(tree_final_states),
                        tf.nest.flatten(tree_otherend_states))
    ])

    # Accumulate the sub-tree's momentum sum into the running total, keeping
    # the static shape of the running total.
    momentum_tree_cumsum = []
    for p0, p1 in zip(
        initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      momentum_part_temp.set_shape(p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      new_state_temp.set_shape(old_state_temp.shape)

    if GENERALIZED_UTURN:
      # Use the summed momentum as the "state difference" in the U-turn
      # check.
      state_diff = momentum_tree_cumsum
    else:
      # Use the difference between the two trajectory ends.
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=prefer_static.rank_from_shape(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate