def one_step(self, current_state, previous_kernel_results):
  """Draws fresh momenta and integrates them with leapfrog steps.

  Args:
    current_state: Current chain state. May be a single part or a list-like
      of parts (as decided by `mcmc_util.is_list_like`).
    previous_kernel_results: Results object whose `target_log_prob` and
      `grads_target_log_prob` fields are forwarded to `_prepare_args`.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is the
    list of state parts when `current_state` is list-like, otherwise just the
    first part. `kernel_results` is an
    `UncalibratedHamiltonianMonteCarloKernelResults` carrying the log
    acceptance correction plus the target log-prob and its gradients as
    returned by the final leapfrog step.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'one_step')
  scope_values = [
      self.step_size,
      self.num_leapfrog_steps,
      current_state,
      previous_kernel_results.target_log_prob,
      previous_kernel_results.grads_target_log_prob,
  ]
  with tf.name_scope(name=scope_name, values=scope_values):
    (
        initial_state_parts,
        step_sizes,
        initial_target_log_prob,
        initial_grad_parts,
    ) = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        self.step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    independent_chain_ndims = distributions_util.prefer_static_rank(
        initial_target_log_prob)

    # One normal draw per state part, shaped and typed like that part. Each
    # draw pulls its own seed from the seed stream.
    initial_momentum_parts = [
        tf.random_normal(
            shape=tf.shape(part),
            dtype=part.dtype.base_dtype,
            seed=self._seed_stream())
        for part in initial_state_parts
    ]

    def _has_steps_left(step, *unused_loop_state):
      """Loop condition: keep going until `num_leapfrog_steps` is reached."""
      return step < self.num_leapfrog_steps

    def _take_leapfrog_step(step, momentum_parts, state_parts,
                            target_log_prob, grad_parts):
      """Loop body: bump the counter and advance the integrator once."""
      # The counter increment is built first, then the integrator ops.
      next_step = step + 1
      advanced = _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=independent_chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=momentum_parts,
          current_state_parts=state_parts,
          current_target_log_prob=target_log_prob,
          current_target_log_prob_grad_parts=grad_parts,
          state_gradients_are_stopped=self.state_gradients_are_stopped)
      return [next_step] + list(advanced)

    # Do leapfrog integration.
    loop_outputs = tf.while_loop(
        cond=_has_steps_left,
        body=_take_leapfrog_step,
        loop_vars=[
            tf.zeros([], tf.int32, name='iter'),
            initial_momentum_parts,
            initial_state_parts,
            initial_target_log_prob,
            initial_grad_parts,
        ])
    # Drop the iteration counter; the rest is the integrator state.
    (
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_grad_parts,
    ) = loop_outputs[1:]

    # Hand back the same structure the caller passed in.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    kernel_results = UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=_compute_log_acceptance_correction(
            initial_momentum_parts,
            next_momentum_parts,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
    )
    return [next_state, kernel_results]
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Estimates auto correlation along a single axis using FFTs.

  For a `1-D` wide sense stationary (WSS) sequence `X`, with `E` denoting
  expectation and `Conj` complex conjugation, the auto correlation `RXX` is

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  `x` is treated (along `axis`) as a finite length-`L` piece of one
  realization of `X`. After conceptually extending `x` with zeros, the
  estimate for `m = 0, 1, ..., max_lags` is

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The estimation error scales like `1 / sqrt(len(x) - m)`, so callers often
  pick a small `max_lags` to keep every returned lag meaningful. Because `mu`
  only approximates `E{ X[0] }` and the divisor is `len(x) - m` rather than
  `len(x) - m - 1`, the estimate is slightly biased; the bias vanishes as
  `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. Axis along which correlation is computed; all other
      dimensions index batch members. Negative values count from the end.
    max_lags:  Positive `int` tensor. Largest lag `m` to return. Values of
      `x.shape[axis]` or more are clipped to `x.shape[axis] - 1`. `None`
      means `x.shape[axis] - 1`.
    center:  Python `bool`. When `False`, the mean estimate `mu` is not
      subtracted from `x[n]`.
    normalize:  Python `bool`. When `False`, the result is not divided by its
      lag-0 value.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # How it works:
  # Zero pad a length N / 2 sequence x at the end to length N and define
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # Then
  #   F[x]_k Conj(F[x]_k) = F[R]_k,  with
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}),
  # and R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. Since F[x] is
  # the DFT of x, RXX can be estimated with zero padding plus FFT/IFFT. This
  # is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # The FFT op works on the rightmost dim, so rotate `axis` into that slot.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # With x.shape[axis] = T, series.shape = B + [T] for batch shape B.
    series = util.rotate_transpose(x, shift)
    if center:
      series = series - math_ops.reduce_mean(series, axis=-1, keepdims=True)

    # Length along the correlation axis (N / 2 in the notes above); static
    # when available, otherwise dynamic.
    x_len = util.prefer_static_shape(series)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Pad up to 2**(ceil(Log_2(2 * x_len))), using
    # Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    log2_of_double_len = math_ops.log(x_len_float64 * 2) / np.log(2.)
    target_length = math_ops.pow(
        np.float64(2.), math_ops.ceil(log2_of_double_len))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # padded.shape = series.shape[:-1] + [T + pad_length]
    #              = B + [T + pad_length]
    padded = util.pad(series, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not (dtype.is_complex or dtype.is_floating):
      raise TypeError("Argument x must have either float or complex dtype"
                      " found: {}".format(dtype))
    if not dtype.is_complex:
      # Real input: attach a zero imaginary part so it can be fed to the FFT.
      padded = math_ops.complex(padded, dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    spectrum = spectral_ops.fft(padded)
    spectral_density = spectrum * math_ops.conj(spectrum)
    # This is R[m] from the notes above: sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)
    # Return to the caller's dtype (drops the imaginary part for real x).
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Decide whether the output's static shape can be set, and clip max_lags.
    # `series` is the reference: its time dimension is rightmost and it
    # predates the padding.
    know_static_shape = series.shape.is_fully_defined()
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      static_max_lags = tensor_util.constant_value(max_lags)
      if know_static_shape and static_max_lags is not None:
        max_lags = min(x_len - 1, static_max_lags)
      else:
        know_static_shape = False
        max_lags = math_ops.minimum(x_len - 1, max_lags)

    # Drop the padding. A huge user-supplied max_lags was clipped above.
    # shifted_product_chopped.shape = series.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    if know_static_shape:
      chopped_shape = series.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # R[m] sums N / 2 - m nonzero terms x[n] Conj(x[n - m]); the remaining
    # terms are zeros introduced by padding. Dividing by (N / 2 - m) gives an
    # unbiased estimate of E[X[n] Conj(X[n - m])].
    real_dtype = dtype.real_dtype
    x_len_real = math_ops.cast(x_len, real_dtype)
    max_lags_real = math_ops.cast(max_lags, real_dtype)
    terms_per_lag = x_len_real - math_ops.range(0., max_lags_real + 1.)
    terms_per_lag = math_ops.cast(terms_per_lag, dtype)
    result = shifted_product_chopped / terms_per_lag

    if normalize:
      # Scale so that the lag-0 entry becomes 1.
      result = result / result[..., :1]

    # Undo the initial rotation so dimensions line up with those of x.
    return util.rotate_transpose(result, -shift)
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Estimates auto correlation along a single axis using FFTs.

  For a `1-D` wide sense stationary (WSS) sequence `X`, with `E` denoting
  expectation and `Conj` complex conjugation, the auto correlation `RXX` is

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  `x` is treated (along `axis`) as a finite length-`L` piece of one
  realization of `X`. After conceptually extending `x` with zeros, the
  estimate for `m = 0, 1, ..., max_lags` is

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The estimation error scales like `1 / sqrt(len(x) - m)`, so callers often
  pick a small `max_lags` to keep every returned lag meaningful. Because `mu`
  only approximates `E{ X[0] }` and the divisor is `len(x) - m` rather than
  `len(x) - m - 1`, the estimate is slightly biased; the bias vanishes as
  `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. Axis along which correlation is computed; all other
      dimensions index batch members. Negative values count from the end.
    max_lags:  Positive `int` tensor. Largest lag `m` to return. Values of
      `x.shape[axis]` or more are clipped to `x.shape[axis] - 1`. `None`
      means `x.shape[axis] - 1`.
    center:  Python `bool`. When `False`, the mean estimate `mu` is not
      subtracted from `x[n]`.
    normalize:  Python `bool`. When `False`, the result is not divided by its
      lag-0 value.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # How it works:
  # Zero pad a length N / 2 sequence x at the end to length N and define
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # Then
  #   F[x]_k Conj(F[x]_k) = F[R]_k,  with
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}),
  # and R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. Since F[x] is
  # the DFT of x, RXX can be estimated with zero padding plus FFT/IFFT. This
  # is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # The FFT op works on the rightmost dim, so rotate `axis` into that slot.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # With x.shape[axis] = T, series.shape = B + [T] for batch shape B.
    series = util.rotate_transpose(x, shift)
    if center:
      series = series - math_ops.reduce_mean(series, axis=-1, keepdims=True)

    # Length along the correlation axis (N / 2 in the notes above); static
    # when available, otherwise dynamic.
    x_len = util.prefer_static_shape(series)[-1]

    # TODO (langmore) Investigate whether this zero padding helps or hurts.  At id:595 gh:596
    # the moment it is necessary so that all FFT implementations work.
    # Pad up to 2**(ceil(Log_2(2 * x_len))), using
    # Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    log2_of_double_len = math_ops.log(x_len_float64 * 2) / np.log(2.)
    target_length = math_ops.pow(
        np.float64(2.), math_ops.ceil(log2_of_double_len))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # padded.shape = series.shape[:-1] + [T + pad_length]
    #              = B + [T + pad_length]
    padded = util.pad(series, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not (dtype.is_complex or dtype.is_floating):
      raise TypeError("Argument x must have either float or complex dtype"
                      " found: {}".format(dtype))
    if not dtype.is_complex:
      # Real input: attach a zero imaginary part so it can be fed to the FFT.
      padded = math_ops.complex(padded, dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    spectrum = spectral_ops.fft(padded)
    spectral_density = spectrum * math_ops.conj(spectrum)
    # This is R[m] from the notes above: sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)
    # Return to the caller's dtype (drops the imaginary part for real x).
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Decide whether the output's static shape can be set, and clip max_lags.
    # `series` is the reference: its time dimension is rightmost and it
    # predates the padding.
    know_static_shape = series.shape.is_fully_defined()
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      static_max_lags = tensor_util.constant_value(max_lags)
      if know_static_shape and static_max_lags is not None:
        max_lags = min(x_len - 1, static_max_lags)
      else:
        know_static_shape = False
        max_lags = math_ops.minimum(x_len - 1, max_lags)

    # Drop the padding. A huge user-supplied max_lags was clipped above.
    # shifted_product_chopped.shape = series.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    if know_static_shape:
      chopped_shape = series.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # R[m] sums N / 2 - m nonzero terms x[n] Conj(x[n - m]); the remaining
    # terms are zeros introduced by padding. Dividing by (N / 2 - m) gives an
    # unbiased estimate of E[X[n] Conj(X[n - m])].
    real_dtype = dtype.real_dtype
    x_len_real = math_ops.cast(x_len, real_dtype)
    max_lags_real = math_ops.cast(max_lags, real_dtype)
    terms_per_lag = x_len_real - math_ops.range(0., max_lags_real + 1.)
    terms_per_lag = math_ops.cast(terms_per_lag, dtype)
    result = shifted_product_chopped / terms_per_lag

    if normalize:
      # Scale so that the lag-0 entry becomes 1.
      result = result / result[..., :1]

    # Undo the initial rotation so dimensions line up with those of x.
    return util.rotate_transpose(result, -shift)
def one_step(self, current_state, previous_kernel_results):
  """Runs the inner kernel once and accepts or rejects its proposal.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s holding the current
      state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s with the internal calculations from the previous
      call to this function (or from `bootstrap_results`). Its
      `accepted_results` member is what gets passed to the inner kernel.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s holding the next
      state(s) of the Markov chain(s).
    kernel_results: `MetropolisHastingsKernelResults` with the accepted
      results, the acceptance mask, the log acceptance ratio, and the raw
      proposal (state and results).

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  last_accepted = previous_kernel_results.accepted_results

  # Take one inner step.
  proposed_state, proposed_results = self.inner_kernel.one_step(
      current_state, last_accepted)

  if not (has_target_log_prob(proposed_results) and
          has_target_log_prob(last_accepted)):
    raise ValueError('"target_log_prob" must be a member of '
                     '`inner_kernel` results.')

  # Compute log(acceptance_ratio) as the sum of these terms.
  log_ratio_terms = [
      proposed_results.target_log_prob,
      -last_accepted.target_log_prob,
  ]
  try:
    correction = proposed_results.log_acceptance_correction
  except AttributeError:
    warnings.warn(
        'Supplied inner `TransitionKernel` does not have a '
        '`log_acceptance_correction`. Assuming its value is `0.`')
  else:
    log_ratio_terms.append(correction)
  log_accept_ratio = mcmc_util.safe_sum(
      log_ratio_terms, name='compute_log_accept_ratio')

  # Accept when u < min(1, accept_ratio) for u ~ Uniform[0,1), i.e. when
  # log(u) < log_accept_ratio. A proposal that raises the likelihood is always
  # taken; one that lowers it is taken at random.
  # Note:
  # - The seed is mutated so subsequent calls are not correlated.
  # - The seed is mutated BEFORE it is used, in case the caller gave the same
  #   seed to the inner kernel.
  self._seed = distributions_util.gen_new_seed(
      self.seed, salt='metropolis_hastings_one_step')
  proposed_log_prob = proposed_results.target_log_prob
  uniform_draw = tf.random_uniform(
      shape=tf.shape(proposed_log_prob),
      dtype=proposed_log_prob.dtype.base_dtype,
      seed=self.seed)
  log_uniform = tf.log(uniform_draw)
  is_accepted = log_uniform < log_accept_ratio

  independent_chain_ndims = distributions_util.prefer_static_rank(
      proposed_results.target_log_prob)

  next_state = mcmc_util.choose(
      is_accepted,
      proposed_state,
      current_state,
      independent_chain_ndims)

  # Field by field, take the proposal where accepted and the previous accepted
  # value elsewhere, then rebuild the same results type.
  chosen_fields = {
      field: mcmc_util.choose(
          is_accepted,
          getattr(proposed_results, field),
          getattr(last_accepted, field),
          independent_chain_ndims)
      for field in proposed_results._fields
  }
  accepted_results = type(proposed_results)(**chosen_fields)

  kernel_results = MetropolisHastingsKernelResults(
      accepted_results=accepted_results,
      is_accepted=is_accepted,
      log_accept_ratio=log_accept_ratio,
      proposed_state=proposed_state,
      proposed_results=proposed_results,
  )
  return [next_state, kernel_results]
def move_dimension(x, source_idx, dest_idx):
  """Relocates one dimension of a tensor, keeping the others in order.

  This is a restricted form of `tf.transpose()`: only a single dimension
  changes position.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is supported).
    dest_idx: Integer index into `x.shape` (negative indexing is supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original index
     `source_idx` has been moved to new index `dest_idx`, with all other
     dimensions retained in their original order.

  Example:

  ```python
  x = tf.compat.v1.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)  # no-op
  x_perm = move_dimension(x, 0, 3)  # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)  # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  # Range ops below use int32 for a Python int index, else the index's dtype.
  dtype = (dtypes.int32 if isinstance(source_idx, int)
           else dtypes.as_dtype(source_idx.dtype))

  # Resolve negative indices. `ndims` may be dynamic, in which case the
  # resolved indices become dynamic too.
  if source_idx < 0:
    source_idx = ndims + source_idx
  if dest_idx < 0:
    dest_idx = ndims + dest_idx

  def _perm_moving_toward_front():
    """Permutation for a source that sits at or after the destination."""
    pieces = [
        math_ops.range(0, dest_idx, dtype=dtype),
        [source_idx],
        math_ops.range(dest_idx, source_idx, dtype=dtype),
        math_ops.range(source_idx + 1, ndims, dtype=dtype),
    ]
    return util.prefer_static_value(array_ops.concat(pieces, axis=0))

  def _perm_moving_toward_back():
    """Permutation for a source that sits before the destination."""
    pieces = [
        math_ops.range(0, source_idx, dtype=dtype),
        math_ops.range(source_idx + 1, dest_idx + 1, dtype=dtype),
        [source_idx],
        math_ops.range(dest_idx + 1, ndims, dtype=dtype),
    ]
    return util.prefer_static_value(array_ops.concat(pieces, axis=0))

  def _transposed():
    """Transposes `x` with whichever permutation fits the move direction."""
    perm = smart_cond.smart_cond(
        source_idx < dest_idx,
        _perm_moving_toward_back,
        _perm_moving_toward_front)
    return array_ops.transpose(x, perm=perm)

  def _unchanged():
    """Returns `x` untouched."""
    return x

  # Equal source and destination is handled separately: hand `x` back as is.
  indices_match = math_ops.equal(source_idx, dest_idx)
  return smart_cond.smart_cond(indices_match, _unchanged, _transposed)
def move_dimension(x, source_idx, dest_idx):
  """Relocates one dimension of a tensor, keeping the others in order.

  This is a restricted form of `tf.transpose()`: only a single dimension
  changes position.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is supported).
    dest_idx: Integer index into `x.shape` (negative indexing is supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original index
     `source_idx` has been moved to new index `dest_idx`, with all other
     dimensions retained in their original order.

  Example:

  ```python
  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)  # no-op
  x_perm = move_dimension(x, 0, 3)  # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)  # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  # Range ops below use int32 for a Python int index, else the index's dtype.
  dtype = (tf.int32 if isinstance(source_idx, int)
           else tf.as_dtype(source_idx.dtype))

  # Resolve negative indices. `ndims` may be dynamic, in which case the
  # resolved indices become dynamic too.
  if source_idx < 0:
    source_idx = ndims + source_idx
  if dest_idx < 0:
    dest_idx = ndims + dest_idx

  def _perm_moving_toward_front():
    """Permutation for a source that sits at or after the destination."""
    pieces = [
        tf.range(0, dest_idx, dtype=dtype),
        [source_idx],
        tf.range(dest_idx, source_idx, dtype=dtype),
        tf.range(source_idx + 1, ndims, dtype=dtype),
    ]
    return util.prefer_static_value(tf.concat(pieces, axis=0))

  def _perm_moving_toward_back():
    """Permutation for a source that sits before the destination."""
    pieces = [
        tf.range(0, source_idx, dtype=dtype),
        tf.range(source_idx + 1, dest_idx + 1, dtype=dtype),
        [source_idx],
        tf.range(dest_idx + 1, ndims, dtype=dtype),
    ]
    return util.prefer_static_value(tf.concat(pieces, axis=0))

  def _transposed():
    """Transposes `x` with whichever permutation fits the move direction."""
    perm = smart_cond.smart_cond(
        source_idx < dest_idx,
        _perm_moving_toward_back,
        _perm_moving_toward_front)
    return tf.transpose(x, perm=perm)

  def _unchanged():
    """Returns `x` untouched."""
    return x

  # Equal source and destination is handled separately: hand `x` back as is.
  indices_match = tf.equal(source_idx, dest_idx)
  return smart_cond.smart_cond(indices_match, _unchanged, _transposed)
def one_step(self, current_state, previous_kernel_results):
  """Draws fresh momenta and runs the leapfrog integrator once.

  Args:
    current_state: Current chain state. May be a single part or a list-like
      of parts (as decided by `mcmc_util.is_list_like`).
    previous_kernel_results: Results object whose `target_log_prob` and
      `grads_target_log_prob` fields are forwarded to `_prepare_args`.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is the
    list of state parts when `current_state` is list-like, otherwise just the
    first part. `kernel_results` is an
    `UncalibratedHamiltonianMonteCarloKernelResults` carrying the log
    acceptance correction plus the target log-prob and its gradients as
    returned by `_leapfrog_integrator`.
  """
  scope_values = [
      self.step_size,
      self.num_leapfrog_steps,
      self.seed,
      current_state,
      previous_kernel_results.target_log_prob,
      previous_kernel_results.grads_target_log_prob,
  ]
  with tf.name_scope(self.name, 'hmc_kernel', scope_values):
    with tf.name_scope('initialize'):
      (
          initial_state_parts,
          step_sizes,
          initial_target_log_prob,
          initial_grads_target_log_prob,
      ) = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          maybe_expand=True)

      initial_momentums = []
      for part in initial_state_parts:
        # Note:
        # - The seed is mutated so subsequent calls are not correlated.
        # - The seed is mutated BEFORE it is used, in case the caller gave
        #   the same seed to an outer kernel, e.g., `MetropolisHastings`.
        self._seed = distributions_util.gen_new_seed(
            self.seed, salt='hmc_kernel_momentums')
        momentum = tf.random_normal(
            shape=tf.shape(part),
            dtype=part.dtype.base_dtype,
            seed=self.seed)
        initial_momentums.append(momentum)

      num_leapfrog_steps = tf.convert_to_tensor(
          self.num_leapfrog_steps,
          dtype=tf.int32,
          name='num_leapfrog_steps')

    independent_chain_ndims = distributions_util.prefer_static_rank(
        initial_target_log_prob)

    (
        next_momentums,
        next_state_parts,
        next_target_log_prob,
        next_grads_target_log_prob,
    ) = _leapfrog_integrator(
        initial_momentums,
        self.target_log_prob_fn,
        initial_state_parts,
        step_sizes,
        num_leapfrog_steps,
        initial_target_log_prob,
        initial_grads_target_log_prob)

    # Hand back the same structure the caller passed in.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    kernel_results = UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=_compute_log_acceptance_correction(
            initial_momentums,
            next_momentums,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grads_target_log_prob,
    )
    return [next_state, kernel_results]
def one_step(self, current_state, previous_kernel_results):
  """Proposes the next state with one Euler-Maruyama step.

  Args:
    current_state: Current chain state. May be a single part or a list-like
      of parts (as decided by `mcmc_util.is_list_like`).
    previous_kernel_results: Results object whose `target_log_prob`,
      `grads_target_log_prob`, `volatility` and `grads_volatility` fields are
      forwarded to `_prepare_args`.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is the
    list of state parts when `current_state` is list-like, otherwise just the
    first part. `kernel_results` is an `UncalibratedLangevinKernelResults`
    holding the log acceptance correction together with the target log-prob,
    its gradients, the volatility, its gradients and the drift, all
    recomputed by `_prepare_args` at the proposed state.
  """
  scope_name = mcmc_util.make_name(self.name, 'mala', 'one_step')
  scope_values = [
      self.step_size,
      current_state,
      previous_kernel_results.target_log_prob,
      previous_kernel_results.grads_target_log_prob,
      previous_kernel_results.volatility,
      previous_kernel_results.diffusion_drift,
  ]

  def _match_input_structure(parts):
    """Returns `parts` as is for list-like input, else its first element."""
    if mcmc_util.is_list_like(current_state):
      return parts
    return parts[0]

  with tf.name_scope(name=scope_name, values=scope_values):
    with tf.name_scope('initialize'):
      # Gather everything `_euler_method` needs. The gradient outputs of this
      # call are not used here.
      (
          state_parts,
          step_size_parts,
          current_target_log_prob,
          _,  # grads_target_log_prob
          volatility_parts,
          _,  # grads_volatility
          drift_parts,
      ) = _prepare_args(
          self.target_log_prob_fn,
          self.volatility_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          previous_kernel_results.volatility,
          previous_kernel_results.grads_volatility)

      # One normal draw per state part, shaped and typed like that part. Each
      # draw pulls its own seed from the seed stream.
      random_draw_parts = [
          tf.random_normal(
              shape=tf.shape(part),
              dtype=part.dtype.base_dtype,
              seed=self._seed_stream())
          for part in state_parts
      ]

    # Number of independent chains run by the algorithm.
    independent_chain_ndims = distributions_util.prefer_static_rank(
        current_target_log_prob)

    # Generate the next state of the algorithm using Euler-Maruyama method.
    next_state_parts = _euler_method(
        random_draw_parts,
        state_parts,
        drift_parts,
        step_size_parts,
        volatility_parts)

    # Re-evaluate the target and volatility quantities at the proposal. They
    # feed `_compute_log_acceptance_correction` and are returned in the kernel
    # results for the next call to `one_step`.
    (
        _,  # state_parts
        _,  # step_sizes
        next_target_log_prob,
        next_grads_target_log_prob,
        next_volatility_parts,
        next_grads_volatility,
        next_drift_parts,
    ) = _prepare_args(
        self.target_log_prob_fn,
        self.volatility_fn,
        next_state_parts,
        step_size_parts)

    next_state = _match_input_structure(next_state_parts)
    log_acceptance_correction = _compute_log_acceptance_correction(
        state_parts,
        next_state_parts,
        volatility_parts,
        next_volatility_parts,
        drift_parts,
        next_drift_parts,
        step_size_parts,
        independent_chain_ndims)
    kernel_results = UncalibratedLangevinKernelResults(
        log_acceptance_correction=log_acceptance_correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grads_target_log_prob,
        volatility=_match_input_structure(next_volatility_parts),
        grads_volatility=next_grads_volatility,
        diffusion_drift=next_drift_parts)
    return [next_state, kernel_results]
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
  """Runs one iteration of Hamiltonian Monte Carlo.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
  algorithm that takes a series of gradient-informed steps to produce
  a Metropolis proposal. This function applies one step of HMC to
  randomly update the variable `x`.

  This function can update multiple chains in parallel. It assumes that all
  leftmost dimensions of `current_state` index independent chain states (and
  are therefore updated independently). The output of `target_log_prob_fn()`
  should sum log-probabilities across all event dimensions. Slices along the
  rightmost dimensions may have different target distributions; for example,
  `current_state[0, :]` could have a different target distribution from
  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

  #### Examples:

  ##### Simple chain with warm-up.

  ```python
  tfd = tf.contrib.distributions

  # Tuning acceptance rates:
  dtype = np.float32
  target_accept_rate = 0.631
  num_warmup_iter = 500
  num_chain_iter = 500

  x = tf.get_variable(name="x", initializer=dtype(1))
  step_size = tf.get_variable(name="step_size", initializer=dtype(1))

  target = tfd.Normal(loc=dtype(0), scale=dtype(1))

  new_x, other_results = hmc.kernel(
      target_log_prob_fn=target.log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=3)

  x_update = x.assign(new_x)

  step_size_update = step_size.assign_add(
      step_size * tf.where(
          other_results.acceptance_probs > target_accept_rate,
          0.01, -0.01))

  warmup = tf.group([x_update, step_size_update])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  # Warm up the sampler and adapt the step size
  for _ in xrange(num_warmup_iter):
    sess.run(warmup)

  # Collect samples without adapting step size
  samples = np.zeros([num_chain_iter])
  for i in xrange(num_chain_iter):
    _, x_, target_log_prob_, grad_ = sess.run([
        x_update,
        x,
        other_results.current_target_log_prob,
        other_results.current_grads_target_log_prob])
    samples[i] = x_

  print(samples.mean(), samples.std())
  ```

  ##### Sample from more complicated posterior.

  I.e.,

  ```none
    W ~ MVN(loc=0, scale=sigma * eye(dims))
    for i=1...num_samples:
        X[i] ~ MVN(loc=0, scale=eye(dims))
      eps[i] ~ Normal(loc=0, scale=1)
        Y[i] = X[i].T * W + eps[i]
  ```

  ```python
  tfd = tf.contrib.distributions

  def make_training_data(num_samples, dims, sigma):
    dt = np.asarray(sigma).dtype
    zeros = tf.zeros(dims, dtype=dt)
    x = tfd.MultivariateNormalDiag(
        loc=zeros).sample(num_samples, seed=1)
    w = tfd.MultivariateNormalDiag(
        loc=zeros,
        scale_identity_multiplier=sigma).sample(seed=2)
    noise = tfd.Normal(
        loc=dt(0),
        scale=dt(1)).sample(num_samples, seed=3)
    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
    return y, x, w

  def make_prior(sigma, dims):
    # p(w | sigma)
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([dims], dtype=sigma.dtype),
        scale_identity_multiplier=sigma)

  def make_likelihood(x, w):
    # p(y | x, w)
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(x, w, axes=[[1], [0]]))

  # Setup assumptions.
  dtype = np.float32
  num_samples = 150
  dims = 10
  num_iters = int(5e3)

  true_sigma = dtype(0.5)
  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

  # Estimate of `log(true_sigma)`.
  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
  sigma = tf.exp(log_sigma)

  # State of the Markov chain.
  weights = tf.get_variable(
      name="weights",
      initializer=np.random.randn(dims).astype(dtype))

  prior = make_prior(sigma, dims)

  def joint_log_prob_fn(w):
    # f(w) = log p(w, y | x)
    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

  weights_update = weights.assign(
      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                 current_state=weights,
                 step_size=0.1,
                 num_leapfrog_steps=5)[0])

  with tf.control_dependencies([weights_update]):
    loss = -prior.log_prob(weights)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

  sess.graph.finalize()  # No more graph building.

  tf.global_variables_initializer().run()

  sigma_history = np.zeros(num_iters, dtype)
  weights_history = np.zeros([num_iters, dims], dtype)

  for i in xrange(num_iters):
    _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
    weights_history[i, :] = weights_
    sigma_history[i] = sigma_

  true_weights_ = sess.run(true_weights)

  # Should converge to something close to true_sigma.
  plt.plot(sigma_history);
  plt.ylabel("sigma");
  plt.xlabel("iteration");
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to
      specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing gradient of `current_target_log_prob` at the
      `current_state` and wrt the `current_state`. Must have same shape as
      `current_state`. The only reason to specify this argument is to reduce
      TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_kernel").

  Returns:
    accepted_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s) after this step. Has same shape as
      `current_state`.
    kernel_results: `KernelResults` instance of internal calculations used to
      advance the chain, with fields `acceptance_probs`,
      `current_grads_target_log_prob`, `current_target_log_prob`,
      `energy_change`, `is_accepted`, `proposed_grads_target_log_prob`,
      `proposed_state`, `proposed_target_log_prob` and `random_positive`.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
  """
  with ops.name_scope(name, "hmc_kernel", [
      current_state, step_size, num_leapfrog_steps, seed,
      current_target_log_prob, current_grads_target_log_prob
  ]):
    with ops.name_scope("initialize"):
      [
          current_state_parts, step_sizes, current_target_log_prob,
          current_grads_target_log_prob
      ] = _prepare_args(target_log_prob_fn, current_state, step_size,
                        current_target_log_prob,
                        current_grads_target_log_prob, maybe_expand=True)
      independent_chain_ndims = distributions_util.prefer_static_rank(
          current_target_log_prob)

      # Chain the seed from one state part to the next. Seeding every part
      # with the same derived value would give same-shaped parts identical
      # momentum draws whenever `seed` is not `None`. The first part still
      # receives `gen_new_seed(seed, salt)`, so single-part states are seeded
      # exactly as before.
      current_momentums = []
      momentum_seed = seed
      for s in current_state_parts:
        momentum_seed = distributions_util.gen_new_seed(
            momentum_seed, salt="hmc_kernel_momentums")
        current_momentums.append(
            random_ops.random_normal(
                shape=array_ops.shape(s),
                dtype=s.dtype.base_dtype,
                seed=momentum_seed))

    [
        proposed_momentums,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_grads_target_log_prob,
    ] = _leapfrog_integrator(current_momentums, target_log_prob_fn,
                             current_state_parts, step_sizes,
                             num_leapfrog_steps, current_target_log_prob,
                             current_grads_target_log_prob)

    energy_change = _compute_energy_change(current_target_log_prob,
                                           current_momentums,
                                           proposed_target_log_prob,
                                           proposed_momentums,
                                           independent_chain_ndims)

    # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
    # ==> -log(u) >= max(e, 0)
    # ==> -log(u) >= e
    # (Perhaps surprisingly, we don't have a better way to obtain a random
    # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
    # maxval=np.inf)` won't work.)
    random_uniform = random_ops.random_uniform(
        shape=array_ops.shape(energy_change),
        dtype=energy_change.dtype,
        seed=seed)
    random_positive = -math_ops.log(random_uniform)
    is_accepted = random_positive >= energy_change

    # Per chain, keep the proposal where accepted and the current value
    # otherwise.
    accepted_target_log_prob = array_ops.where(is_accepted,
                                               proposed_target_log_prob,
                                               current_target_log_prob)

    accepted_state_parts = [
        _choose(is_accepted, proposed_state_part, current_state_part,
                independent_chain_ndims)
        for current_state_part, proposed_state_part in zip(
            current_state_parts, proposed_state_parts)
    ]

    accepted_grads_target_log_prob = [
        _choose(is_accepted, proposed_grad, grad, independent_chain_ndims)
        for proposed_grad, grad in zip(proposed_grads_target_log_prob,
                                       current_grads_target_log_prob)
    ]

    def maybe_flatten(x):
      # Return a bare element when the caller passed a non-list state.
      return x if _is_list_like(current_state) else x[0]

    return [
        maybe_flatten(accepted_state_parts),
        KernelResults(
            acceptance_probs=math_ops.exp(
                math_ops.minimum(-energy_change, 0.)),
            current_grads_target_log_prob=accepted_grads_target_log_prob,
            current_target_log_prob=accepted_target_log_prob,
            energy_change=energy_change,
            is_accepted=is_accepted,
            proposed_grads_target_log_prob=proposed_grads_target_log_prob,
            proposed_state=maybe_flatten(proposed_state_parts),
            proposed_target_log_prob=proposed_target_log_prob,
            random_positive=random_positive,
        ),
    ]
def one_step(self, current_state, previous_kernel_results):
  """Advances the chain(s) by a single Slice Sampler iteration.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s holding the
      current state(s) of the Markov chain(s). The leading `r` dimensions
      index independent chains, where
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` of `Tensor`s produced
      by an earlier call to this function (or by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s with the state(s) of
      the Markov chain(s) after exactly one step. Same type and shape as
      `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
  scope_values = [
      self.step_size,
      self.max_doublings,
      self._seed_stream,
      current_state,
      previous_kernel_results.target_log_prob,
  ]
  with tf.name_scope(name=scope_name, values=scope_values):
    with tf.name_scope('initialize'):
      state_parts, step_sizes, target_log_prob = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          maybe_expand=True)
      max_doublings = tf.convert_to_tensor(
          self.max_doublings, dtype=tf.int32, name='max_doublings')

    independent_chain_ndims = distributions_util.prefer_static_rank(
        target_log_prob)

    sampled = _sample_next(
        self.target_log_prob_fn,
        state_parts,
        step_sizes,
        max_doublings,
        target_log_prob,
        independent_chain_ndims,
        seed=self._seed_stream())
    (
        next_state_parts,
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds,
    ) = sampled

    # Mirror the caller's structure: list in, list out; tensor in, tensor out.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    kernel_results = SliceSamplerKernelResults(
        target_log_prob=next_target_log_prob,
        bounds_satisfied=bounds_satisfied,
        direction=direction,
        upper_bounds=upper_bounds,
        lower_bounds=lower_bounds)
    return [next_state, kernel_results]
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
  """Runs one iteration of Hamiltonian Monte Carlo.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
  algorithm that takes a series of gradient-informed steps to produce
  a Metropolis proposal. This function applies one step of HMC to
  randomly update the variable `x`.

  This function can update multiple chains in parallel. It assumes that all
  leftmost dimensions of `current_state` index independent chain states (and
  are therefore updated independently). The output of `target_log_prob_fn()`
  should sum log-probabilities across all event dimensions. Slices along the
  rightmost dimensions may have different target distributions; for example,
  `current_state[0, :]` could have a different target distribution from
  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

  #### Examples:

  ##### Simple chain with warm-up.

  ```python
  tfd = tf.contrib.distributions

  # Tuning acceptance rates:
  dtype = np.float32
  target_accept_rate = 0.631
  num_warmup_iter = 500
  num_chain_iter = 500

  x = tf.get_variable(name="x", initializer=dtype(1))
  step_size = tf.get_variable(name="step_size", initializer=dtype(1))

  target = tfd.Normal(loc=dtype(0), scale=dtype(1))

  new_x, other_results = hmc.kernel(
      target_log_prob_fn=target.log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=3)

  x_update = x.assign(new_x)

  step_size_update = step_size.assign_add(
      step_size * tf.where(
          other_results.acceptance_probs > target_accept_rate,
          0.01, -0.01))

  warmup = tf.group([x_update, step_size_update])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  # Warm up the sampler and adapt the step size
  for _ in xrange(num_warmup_iter):
    sess.run(warmup)

  # Collect samples without adapting step size
  samples = np.zeros([num_chain_iter])
  for i in xrange(num_chain_iter):
    _, x_, target_log_prob_, grad_ = sess.run([
        x_update,
        x,
        other_results.current_target_log_prob,
        other_results.current_grads_target_log_prob])
    samples[i] = x_

  print(samples.mean(), samples.std())
  ```

  ##### Sample from more complicated posterior.

  I.e.,

  ```none
    W ~ MVN(loc=0, scale=sigma * eye(dims))
    for i=1...num_samples:
        X[i] ~ MVN(loc=0, scale=eye(dims))
      eps[i] ~ Normal(loc=0, scale=1)
        Y[i] = X[i].T * W + eps[i]
  ```

  ```python
  tfd = tf.contrib.distributions

  def make_training_data(num_samples, dims, sigma):
    dt = np.asarray(sigma).dtype
    zeros = tf.zeros(dims, dtype=dt)
    x = tfd.MultivariateNormalDiag(
        loc=zeros).sample(num_samples, seed=1)
    w = tfd.MultivariateNormalDiag(
        loc=zeros,
        scale_identity_multiplier=sigma).sample(seed=2)
    noise = tfd.Normal(
        loc=dt(0),
        scale=dt(1)).sample(num_samples, seed=3)
    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
    return y, x, w

  def make_prior(sigma, dims):
    # p(w | sigma)
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([dims], dtype=sigma.dtype),
        scale_identity_multiplier=sigma)

  def make_likelihood(x, w):
    # p(y | x, w)
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(x, w, axes=[[1], [0]]))

  # Setup assumptions.
  dtype = np.float32
  num_samples = 150
  dims = 10
  num_iters = int(5e3)

  true_sigma = dtype(0.5)
  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

  # Estimate of `log(true_sigma)`.
  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
  sigma = tf.exp(log_sigma)

  # State of the Markov chain.
  weights = tf.get_variable(
      name="weights",
      initializer=np.random.randn(dims).astype(dtype))

  prior = make_prior(sigma, dims)

  def joint_log_prob_fn(w):
    # f(w) = log p(w, y | x)
    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

  weights_update = weights.assign(
      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                 current_state=weights,
                 step_size=0.1,
                 num_leapfrog_steps=5)[0])

  with tf.control_dependencies([weights_update]):
    loss = -prior.log_prob(weights)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

  sess.graph.finalize()  # No more graph building.

  tf.global_variables_initializer().run()

  sigma_history = np.zeros(num_iters, dtype)
  weights_history = np.zeros([num_iters, dims], dtype)

  for i in xrange(num_iters):
    _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
    weights_history[i, :] = weights_
    sigma_history[i] = sigma_

  true_weights_ = sess.run(true_weights)

  # Should converge to something close to true_sigma.
  plt.plot(sigma_history);
  plt.ylabel("sigma");
  plt.xlabel("iteration");
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to
      specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing gradient of `current_target_log_prob` at the
      `current_state` and wrt the `current_state`. Must have same shape as
      `current_state`. The only reason to specify this argument is to reduce
      TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_kernel").

  Returns:
    accepted_state: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
  """
  scope_values = [
      current_state, step_size, num_leapfrog_steps, seed,
      current_target_log_prob, current_grads_target_log_prob,
  ]
  with ops.name_scope(name, "hmc_kernel", scope_values):
    with ops.name_scope("initialize"):
      (
          state_parts,
          step_sizes,
          start_log_prob,
          start_grads,
      ) = _prepare_args(
          target_log_prob_fn,
          current_state,
          step_size,
          current_target_log_prob,
          current_grads_target_log_prob,
          maybe_expand=True)
      independent_chain_ndims = distributions_util.prefer_static_rank(
          start_log_prob)

      # Draw one momentum per state part. The seed is re-derived after every
      # draw, so the parts and the later uniform draw use distinct seeds.
      next_seed = seed
      start_momentums = []
      for part in state_parts:
        momentum = random_ops.random_normal(
            shape=array_ops.shape(part),
            dtype=part.dtype.base_dtype,
            seed=next_seed)
        start_momentums.append(momentum)
        next_seed = distributions_util.gen_new_seed(
            next_seed, salt="hmc_kernel_momentums")

      leapfrog_steps = ops.convert_to_tensor(
          num_leapfrog_steps,
          dtype=dtypes.int32,
          name="num_leapfrog_steps")

    (
        proposed_momentums,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_grads_target_log_prob,
    ) = _leapfrog_integrator(
        start_momentums,
        target_log_prob_fn,
        state_parts,
        step_sizes,
        leapfrog_steps,
        start_log_prob,
        start_grads)

    energy_change = _compute_energy_change(
        start_log_prob,
        start_momentums,
        proposed_target_log_prob,
        proposed_momentums,
        independent_chain_ndims)

    # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
    # ==> -log(u) >= max(e, 0)
    # ==> -log(u) >= e
    # (Perhaps surprisingly, we don't have a better way to obtain a random
    # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
    # maxval=np.inf)` won't work.)
    uniform_draw = random_ops.random_uniform(
        shape=array_ops.shape(energy_change),
        dtype=energy_change.dtype,
        seed=next_seed)
    random_positive = -math_ops.log(uniform_draw)
    is_accepted = random_positive >= energy_change

    # Per chain, keep the proposal where accepted and the start value
    # otherwise.
    accepted_target_log_prob = array_ops.where(
        is_accepted, proposed_target_log_prob, start_log_prob)

    accepted_state_parts = []
    for start_part, proposed_part in zip(state_parts, proposed_state_parts):
      accepted_state_parts.append(
          _choose(is_accepted, proposed_part, start_part,
                  independent_chain_ndims))

    accepted_grads_target_log_prob = []
    for proposed_grad, start_grad in zip(proposed_grads_target_log_prob,
                                         start_grads):
      accepted_grads_target_log_prob.append(
          _choose(is_accepted, proposed_grad, start_grad,
                  independent_chain_ndims))

    def _unwrap(parts):
      # Return a bare element when the caller passed a non-list state.
      if _is_list_like(current_state):
        return parts
      return parts[0]

    accepted_state = _unwrap(accepted_state_parts)
    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
    kernel_results = KernelResults(
        acceptance_probs=acceptance_probs,
        current_grads_target_log_prob=accepted_grads_target_log_prob,
        current_target_log_prob=accepted_target_log_prob,
        energy_change=energy_change,
        is_accepted=is_accepted,
        proposed_grads_target_log_prob=proposed_grads_target_log_prob,
        proposed_state=_unwrap(proposed_state_parts),
        proposed_target_log_prob=proposed_target_log_prob,
        random_positive=random_positive,
    )
    return [accepted_state, kernel_results]
def one_step(self, current_state, previous_kernel_results):
  """Advances the chain(s) by a single Slice Sampler iteration.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s of fully defined
      static shape holding the current state(s) of the Markov chain(s). The
      leading `r` dimensions index independent chains, where
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` of `Tensor`s produced
      by an earlier call to this function (or by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s with the state(s) of
      the Markov chain(s) after exactly one step. Same type and shape as
      `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
    ValueError: if `current_state` does not have a fully defined static
      shape.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'slice', 'one_step'),
      values=[
          self.step_size,
          self.max_doublings,
          self._seed_stream,
          current_state,
          previous_kernel_results.target_log_prob,
      ]):
    with tf.name_scope('initialize'):
      prepared_args = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          maybe_expand=True)
      state_parts, step_sizes, target_log_prob = prepared_args
      doubling_limit = tf.convert_to_tensor(
          self.max_doublings, dtype=tf.int32, name='max_doublings')

    independent_chain_ndims = distributions_util.prefer_static_rank(
        target_log_prob)

    (
        next_state_parts,
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds,
    ) = _sample_next(
        self.target_log_prob_fn,
        state_parts,
        step_sizes,
        doubling_limit,
        target_log_prob,
        independent_chain_ndims,
        seed=self._seed_stream())

    # Return a list only when the caller supplied a list-like state.
    returns_list = mcmc_util.is_list_like(current_state)
    next_state = next_state_parts if returns_list else next_state_parts[0]

    return [
        next_state,
        SliceSamplerKernelResults(
            target_log_prob=next_target_log_prob,
            bounds_satisfied=bounds_satisfied,
            direction=direction,
            upper_bounds=upper_bounds,
            lower_bounds=lower_bounds),
    ]