def pad_batch_dimension_for_multiple_chains(
    observed_time_series, model, chain_batch_shape):
  """Expand the observed time series with extra batch dimension(s)."""
  # Running with multiple chains introduces an extra batch dimension. In
  # general we also need to pad the observed time series with a matching batch
  # dimension.
  #
  # For example, suppose our model has batch shape [3, 4] and
  # the observed time series has shape `concat([[5], [3, 4], [100]])`,
  # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
  # respectively. The model will produce distributions with batch shape
  # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
  # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
  # between the sample and batch shapes will broadcast to `chain_batch_shape`.

  observed_time_series = maybe_expand_trailing_dim(
      observed_time_series)  # Guarantee `event_ndims=2`

  event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

  model_batch_ndims = (
      model.batch_shape.ndims if model.batch_shape.ndims is not None else
      tf.shape(input=model.batch_shape_tensor())[0])

  # Compute ndims from chain_batch_shape.
  chain_batch_shape = tf.convert_to_tensor(
      value=chain_batch_shape, name='chain_batch_shape', dtype=tf.int32)
  if not chain_batch_shape.shape.is_fully_defined():
    raise ValueError('Batch shape must have static rank. (given: {})'.format(
        chain_batch_shape))
  if chain_batch_shape.shape.ndims == 0:  # expand int `k` to `[k]`.
    chain_batch_shape = chain_batch_shape[tf.newaxis]
  chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

  def do_padding(observed_time_series_tensor):
    current_sample_shape = tf.shape(
        input=observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
    current_batch_and_event_shape = tf.shape(
        input=observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
    return tf.reshape(
        tensor=observed_time_series_tensor,
        shape=tf.concat([
            current_sample_shape,
            tf.ones([chain_batch_ndims], dtype=tf.int32),
            current_batch_and_event_shape], axis=0))

  # Padding is only needed if the observed time series has sample shape.
  observed_time_series = prefer_static.cond(
      (dist_util.prefer_static_rank(observed_time_series) >
       model_batch_ndims + event_ndims),
      lambda: do_padding(observed_time_series),
      lambda: observed_time_series)

  return observed_time_series
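# A minimal standalone sketch of the reshape mechanics above, using plain
# TensorFlow (no TFP model object; the shape values follow the docstring
# example and are assumptions for illustration only):
import tensorflow as tf

observed = tf.zeros([5, 3, 4, 100, 1])  # [sample, batch..., time, obs_size]
model_batch_ndims = 2                   # model batch shape [3, 4]
event_ndims = 2                         # [num_timesteps, observation_size]
chain_batch_ndims = 1                   # e.g. chain_batch_shape = [8]

sample_shape = tf.shape(observed)[:-(model_batch_ndims + event_ndims)]
batch_and_event = tf.shape(observed)[-(model_batch_ndims + event_ndims):]
padded = tf.reshape(
    observed,
    tf.concat([sample_shape,
               tf.ones([chain_batch_ndims], dtype=tf.int32),
               batch_and_event], axis=0))
print(padded.shape)  # (5, 1, 3, 4, 100, 1); the 1 broadcasts against chains.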
def _batch_transpose(mat):
  """Transpose a possibly batched matrix.

  Args:
    mat: A `tf.Tensor` of shape `[..., n, m]`.

  Returns:
    A tensor of shape `[..., m, n]` with matching batch dimensions.
  """
  n = distribution_util.prefer_static_rank(mat)
  perm = tf.range(n)
  perm = tf.concat([perm[:-2], [perm[-1], perm[-2]]], axis=0)
  return tf.transpose(a=mat, perm=perm)
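# A quick usage note: on a batch of matrices the helper above swaps the two
# trailing dimensions. In current TensorFlow the same operation is available
# directly as tf.linalg.matrix_transpose (a minimal sketch):
import tensorflow as tf

mat = tf.reshape(tf.range(24), [2, 3, 4])        # batch of two 3x4 matrices
transposed = tf.linalg.matrix_transpose(mat)     # swaps the last two dims
print(transposed.shape)                          # (2, 4, 3)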
def _mul_right(mat, vec):
  """Computes the product of a square matrix with a vector on the right.

  Note this accepts a generalized square matrix `M`, i.e. of shape `s + s`
  with `rank(s) >= 1`, a generalized vector `v` of shape `s`, and computes
  the product `M.v` (also of shape `s`). Furthermore, the shapes may be
  fully dynamic.

  Examples:

    v = tf.constant([0, 1])
    M = tf.constant([[0, 1], [2, 3]])
    _mul_right(M, v)
    # => [1, 3]

    v = tf.reshape(tf.range(6), shape=(2, 3))
    # => [[0, 1, 2],
    #     [3, 4, 5]]
    M = tf.reshape(tf.range(36), shape=(2, 3, 2, 3))
    _mul_right(M, v)
    # => [[ 55, 145, 235],
    #     [325, 415, 505]]

  Args:
    mat: A `tf.Tensor` of shape `s + s`.
    vec: A `tf.Tensor` of shape `s`.

  Returns:
    A tensor with the result of the product (also of shape `s`).
  """
  contraction_axes = tf.range(-distribution_util.prefer_static_rank(vec), 0)
  result = tf.tensordot(mat, vec,
                        axes=tf.stack([contraction_axes, contraction_axes]))
  # This last reshape is needed to help with inference about the shape
  # information, otherwise a partially-known shape would become completely
  # unknown.
  return tf.reshape(result, distribution_util.prefer_static_shape(vec))
def __init__(self,
             skewness,
             tailweight,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Johnson's SU distributions.

  The distributions have shape parameters `tailweight` and `skewness`,
  mean `loc`, and scale `scale`.

  The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
  in a way that supports broadcasting (e.g. `skewness + tailweight + loc +
  scale` is a valid operation).

  Args:
    skewness: Floating-point `Tensor`. Skewness of the distribution(s).
    tailweight: Floating-point `Tensor`. Tail weight of the distribution(s).
      `tailweight` must contain only positive values.
    loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
    scale: Floating-point `Tensor`. The scaling factor(s) for the
      distribution(s). Note that `scale` is not technically the standard
      deviation of this distribution but has semantics more similar to
      standard deviation than variance.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if any of `skewness`, `tailweight`, `loc` and `scale` are
      different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name or 'JohnsonSU') as name:
    dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                    tf.float32)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)

    norm_shift = invert_bijector.Invert(
        shift_bijector.Shift(shift=self._skewness,
                             validate_args=validate_args))

    norm_scale = invert_bijector.Invert(
        scale_bijector.Scale(scale=self._tailweight,
                             validate_args=validate_args))

    sinh = sinh_bijector.Sinh(validate_args=validate_args)

    scale = scale_bijector.Scale(scale=self._scale,
                                 validate_args=validate_args)

    shift = shift_bijector.Shift(shift=self._loc,
                                 validate_args=validate_args)

    bijector = shift(scale(sinh(norm_scale(norm_shift))))

    batch_rank = ps.reduce_max([
        distribution_util.prefer_static_rank(x)
        for x in (self._skewness, self._tailweight, self._loc, self._scale)
    ])

    super(JohnsonSU, self).__init__(
        # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
        # `batch_shape` and `batch_shape_tensor` when
        # TransformedDistribution's bijector can modify its `batch_shape`.
        distribution=normal.Normal(
            loc=tf.zeros(ps.ones(batch_rank, tf.int32), dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats),
        bijector=bijector,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
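# A hedged usage sketch, assuming this class is exposed as
# tfp.distributions.JohnsonSU (as in released TensorFlow Probability):
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
samples = dist.sample(3)            # shape [3]
log_probs = dist.log_prob(samples)  # log density at each sample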
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of Slice Sampler.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s) after taking exactly one step. Has the
      same type and shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  Raises:
    ValueError: if there isn't one `step_size`, or a list of `step_size`s
      with the same length as `current_state`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  with tf.compat.v1.name_scope(
      name=mcmc_util.make_name(self.name, 'slice', 'one_step'),
      values=[
          self.step_size, self.max_doublings, self._seed_stream,
          current_state, previous_kernel_results.target_log_prob
      ]):
    with tf.compat.v1.name_scope('initialize'):
      [
          current_state_parts,
          step_sizes,
          current_target_log_prob
      ] = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          maybe_expand=True)

      max_doublings = tf.convert_to_tensor(
          value=self.max_doublings, dtype=tf.int32, name='max_doublings')

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    [
        next_state_parts,
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ] = _sample_next(
        self.target_log_prob_fn,
        current_state_parts,
        step_sizes,
        max_doublings,
        current_target_log_prob,
        independent_chain_ndims,
        seed=self._seed_stream())

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    return [
        maybe_flatten(next_state_parts),
        SliceSamplerKernelResults(
            target_log_prob=next_target_log_prob,
            bounds_satisfied=bounds_satisfied,
            direction=direction,
            upper_bounds=upper_bounds,
            lower_bounds=lower_bounds),
    ]
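# A hedged driver sketch, assuming this kernel is tfp.mcmc.SliceSampler.
# Seed plumbing varies by TFP version (older versions take `seed` in the
# kernel constructor rather than in `sample_chain`):
import tensorflow_probability as tfp

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=lambda x: -x**2 / 2.,  # standard normal target
    step_size=1.0,
    max_doublings=5)
states = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=0.5,
    kernel=kernel,
    trace_fn=None,
    seed=42)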
def one_step(self, current_state, previous_kernel_results):
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
      values=[self.step_size,
              self.num_leapfrog_steps,
              current_state,
              previous_kernel_results.target_log_prob,
              previous_kernel_results.grads_target_log_prob]):
    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
        current_target_log_prob_grad_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        self.step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    current_momentum_parts = []
    for x in current_state_parts:
      current_momentum_parts.append(tf.random_normal(
          shape=tf.shape(x),
          dtype=x.dtype.base_dtype,
          seed=self._seed_stream()))

    def _leapfrog_one_step(*args):
      """Closure representing computation done during each leapfrog step."""
      return _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=independent_chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=args[0],
          current_state_parts=args[1],
          current_target_log_prob=args[2],
          current_target_log_prob_grad_parts=args[3],
          state_gradients_are_stopped=self.state_gradients_are_stopped)

    num_leapfrog_steps = tf.convert_to_tensor(
        self.num_leapfrog_steps, dtype=tf.int64, name='num_leapfrog_steps')

    [
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = tf.while_loop(
        cond=lambda i, *args: i < num_leapfrog_steps,
        body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
        loop_vars=[
            tf.zeros([], tf.int64, name='iter'),
            current_momentum_parts,
            current_state_parts,
            current_target_log_prob,
            current_target_log_prob_grad_parts
        ])[1:]

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    return [
        maybe_flatten(next_state_parts),
        UncalibratedHamiltonianMonteCarloKernelResults(
            log_acceptance_correction=_compute_log_acceptance_correction(
                current_momentum_parts, next_momentum_parts,
                independent_chain_ndims),
            target_log_prob=next_target_log_prob,
            grads_target_log_prob=next_target_log_prob_grad_parts,
        ),
    ]
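# The while_loop above repeatedly applies one leapfrog step. Below is a
# minimal standalone sketch of that update for a single state tensor in
# TF 2.x eager mode; it is NOT the library's _leapfrog_integrator_one_step
# (which handles lists of state parts and reuses cached gradients):
import tensorflow as tf

def leapfrog_step(target_log_prob_fn, state, momentum, step_size):
  with tf.GradientTape() as tape:
    tape.watch(state)
    tlp = target_log_prob_fn(state)
  grad = tape.gradient(tlp, state)
  momentum = momentum + 0.5 * step_size * grad  # half step on momentum
  state = state + step_size * momentum          # full step on position
  with tf.GradientTape() as tape:
    tape.watch(state)
    tlp = target_log_prob_fn(state)
  grad = tape.gradient(tlp, state)
  momentum = momentum + 0.5 * step_size * grad  # second half step
  return state, momentum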
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto
  correlation `RXX` may be defined as (with `E` expectation and `Conj`
  complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  #
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(value=x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(
          input_tensor=x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation. The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts. At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.),
        tf.math.ceil(tf.math.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(value=max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).
    # The other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
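# A hedged usage sketch, assuming this function is exposed as
# tfp.stats.auto_correlation (as in released TensorFlow Probability):
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant(np.sin(np.linspace(0., 20. * np.pi, 1000)), dtype=tf.float32)
rxx = tfp.stats.auto_correlation(x, max_lags=50)
# rxx.shape == (51,); with normalize=True, rxx[0] == 1.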
def testDynamicRankEndsUpBeingScalar(self):
  if tf.executing_eagerly():
    return
  x = tf1.placeholder_with_default(np.array(1, dtype=np.int32), shape=None)
  rank = distribution_util.prefer_static_rank(x)
  self.assertAllEqual(0, self.evaluate(rank))
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name='SinhArcsinh'):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Must have a batch shape to
      which the shapes of `loc`, `scale`, `skewness`, and `tailweight` all
      broadcast. Default is `tfd.Normal(batch_shape, 1.)`, where
      `batch_shape` is the broadcasted shape of the parameters. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      batch_rank = tf.reduce_max([
          distribution_util.prefer_static_rank(x)
          for x in (self._skewness, self._tailweight, self._loc, self._scale)
      ])
      # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
      # `batch_shape` and `batch_shape_tensor` when
      # TransformedDistribution's bijector can modify its `batch_shape`.
      distribution = normal.Normal(
          loc=tf.zeros(tf.ones(batch_rank, tf.int32), dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)).
    affine = affine_scalar_bijector.AffineScalar(
        shift=self._loc,
        scale=self._scale,
        validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
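# A hedged usage sketch, assuming this class is exposed as
# tfp.distributions.SinhArcsinh (as in released TensorFlow Probability):
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=2.)
samples = dist.sample(5)  # right-skewed, heavier-tailed than Normal(0, 1)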
def testScalarTensor(self):
  x = tf.constant(1.)
  rank = distribution_util.prefer_static_rank(x)
  if not tf.executing_eagerly():
    self.assertIsInstance(rank, np.ndarray)
  self.assertEqual(0, rank)
def testNonEmptyConstantTensor(self):
  x = tf.zeros([2, 3, 4])
  rank = distribution_util.prefer_static_rank(x)
  if not tf.executing_eagerly():
    self.assertIsInstance(rank, np.ndarray)
  self.assertEqual(3, rank)
def one_step(self, current_state, previous_kernel_results):
  with tf.compat.v2.name_scope(
      mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    if self._store_parameters_in_results:
      step_size = previous_kernel_results.step_size
      num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
    else:
      step_size = self.step_size
      num_leapfrog_steps = self.num_leapfrog_steps

    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        step_size,
        previous_kernel_results.target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    # Record each part's shape and element count so the flattened state can
    # be restored after integration.
    self.restoreShapes = []
    for x in current_state_parts:
      n = 1
      shape = x.shape
      for m in shape:
        n *= m
      self.restoreShapes.append([shape, n])

    # Flatten all state parts into a single vector, then split that vector
    # into a list of scalar tensors.
    current_state_parts = [
        tf.reshape(part, [-1]) for part in current_state_parts
    ]
    current_state_parts = tf.concat(current_state_parts, -1)
    current_state_parts = [
        current_state_parts[i] for i in range(current_state_parts.shape[0])
    ]

    current_momentum_parts = []
    for x in current_state_parts:
      current_momentum_parts.append(
          tf.random.normal(
              shape=tf.shape(input=x),
              dtype=self._momentum_dtype or x.dtype.base_dtype,
              seed=self._seed_stream()))

    next_state_parts, initial_kinetic, final_kinetic, final_target_log_prob = (
        self.run_integrator(step_sizes, num_leapfrog_steps,
                            current_momentum_parts, current_state_parts))

    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(x) for x in next_state_parts]

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    next_state_parts = maybe_flatten(next_state_parts)

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            initial_kinetic, final_kinetic, independent_chain_ndims),
        target_log_prob=final_target_log_prob)

    # Restore the flat list of scalars to the original part shapes.
    flat_state = next_state_parts
    next_state_parts = []
    index = 0
    for shape, n in self.restoreShapes:
      next_state_parts.append(
          tf.reshape(flat_state[index:index + n], shape))
      index += n

    return next_state_parts, new_kernel_results
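# A standalone sketch of the flatten/restore bookkeeping used above, assuming
# fully known static shapes (names here are illustrative, not the library's):
import tensorflow as tf

parts = [tf.zeros([2, 3]), tf.ones([4])]
restore = [(p.shape, p.shape.num_elements()) for p in parts]
flat = tf.concat([tf.reshape(p, [-1]) for p in parts], axis=-1)  # shape [10]

rebuilt = []
index = 0
for shape, n in restore:
  rebuilt.append(tf.reshape(flat[index:index + n], shape))
  index += n
# rebuilt[0].shape == (2, 3); rebuilt[1].shape == (4,)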
def one_step(self, current_state, previous_kernel_results):
  with tf.compat.v2.name_scope(
      mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    if self._store_parameters_in_results:
      step_size = previous_kernel_results.step_size
      num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
    else:
      step_size = self.step_size
      num_leapfrog_steps = self.num_leapfrog_steps

    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
        current_target_log_prob_grad_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    current_momentum_parts = []
    for x in current_state_parts:
      current_momentum_parts.append(
          tf.random.normal(
              shape=tf.shape(input=x),
              dtype=self._momentum_dtype or x.dtype.base_dtype,
              seed=self._seed_stream()))

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

    [
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = integrator(
        current_momentum_parts,
        current_state_parts,
        current_target_log_prob,
        current_target_log_prob_grad_parts)
    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(x) for x in next_state_parts]

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            current_momentum_parts, next_momentum_parts,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
    )

    return maybe_flatten(next_state_parts), new_kernel_results
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto
  correlation `RXX` may be defined as (with `E` expectation and `Conj`
  complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  #
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation. The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts. At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.),
        tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.contrib.util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).
    # The other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
def one_step(self, current_state, previous_kernel_results):
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'mala', 'one_step'),
      values=[
          self.step_size, current_state,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          previous_kernel_results.volatility,
          previous_kernel_results.diffusion_drift
      ]):
    with tf.name_scope('initialize'):
      # Prepare input arguments to be passed to `_euler_method`.
      [
          current_state_parts,
          step_size_parts,
          current_target_log_prob,
          _,  # grads_target_log_prob
          current_volatility_parts,
          _,  # grads_volatility
          current_drift_parts,
      ] = _prepare_args(
          self.target_log_prob_fn,
          self.volatility_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          previous_kernel_results.volatility,
          previous_kernel_results.grads_volatility,
          previous_kernel_results.diffusion_drift,
          self.parallel_iterations)

      random_draw_parts = []
      for s in current_state_parts:
        random_draw_parts.append(
            tf.random_normal(
                shape=tf.shape(s),
                dtype=s.dtype.base_dtype,
                seed=self._seed_stream()))

    # Number of independent chains run by the algorithm.
    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    # Generate the next state of the algorithm using the Euler-Maruyama
    # method.
    next_state_parts = _euler_method(random_draw_parts,
                                     current_state_parts,
                                     current_drift_parts,
                                     step_size_parts,
                                     current_volatility_parts)

    # Compute helper `UncalibratedLangevinKernelResults` to be processed by
    # `_compute_log_acceptance_correction` and in the next iteration of the
    # `one_step` function.
    [
        _,  # state_parts
        _,  # step_sizes
        next_target_log_prob,
        next_grads_target_log_prob,
        next_volatility_parts,
        next_grads_volatility,
        next_drift_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        self.volatility_fn,
        next_state_parts,
        step_size_parts,
        parallel_iterations=self.parallel_iterations)

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    # Decide whether to compute the acceptance ratio.
    log_acceptance_correction_compute = _compute_log_acceptance_correction(
        current_state_parts,
        next_state_parts,
        current_volatility_parts,
        next_volatility_parts,
        current_drift_parts,
        next_drift_parts,
        step_size_parts,
        independent_chain_ndims)
    log_acceptance_correction_skip = tf.zeros_like(next_target_log_prob)

    log_acceptance_correction = tf.cond(
        self.compute_acceptance,
        lambda: log_acceptance_correction_compute,
        lambda: log_acceptance_correction_skip)

    return [
        maybe_flatten(next_state_parts),
        UncalibratedLangevinKernelResults(
            log_acceptance_correction=log_acceptance_correction,
            target_log_prob=next_target_log_prob,
            grads_target_log_prob=next_grads_target_log_prob,
            volatility=maybe_flatten(next_volatility_parts),
            grads_volatility=next_grads_volatility,
            diffusion_drift=next_drift_parts),
    ]
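# A hedged driver sketch, assuming the calibrated wrapper around this
# uncalibrated kernel is tfp.mcmc.MetropolisAdjustedLangevinAlgorithm (as in
# released TFP; seed plumbing varies by version):
import tensorflow_probability as tfp

kernel = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
    target_log_prob_fn=lambda x: -x**2 / 2.,  # standard normal target
    step_size=0.5)
states = tfp.mcmc.sample_chain(
    num_results=200,
    current_state=0.1,
    kernel=kernel,
    trace_fn=None,
    seed=17)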
def one_step(self, current_state, previous_kernel_results):
  with tf.compat.v1.name_scope(
      name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
      values=[
          self.step_size, self.num_leapfrog_steps, current_state,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob
      ]):
    if self._store_parameters_in_results:
      step_size = previous_kernel_results.step_size
      num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
    else:
      step_size = self.step_size
      num_leapfrog_steps = self.num_leapfrog_steps

    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
        current_target_log_prob_grad_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    current_momentum_parts = []
    for x in current_state_parts:
      current_momentum_parts.append(
          tf.random.normal(
              shape=tf.shape(input=x),
              dtype=x.dtype.base_dtype,
              seed=self._seed_stream()))

    def _leapfrog_one_step(*args):
      """Closure representing computation done during each leapfrog step."""
      return _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=independent_chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=args[0],
          current_state_parts=args[1],
          current_target_log_prob=args[2],
          current_target_log_prob_grad_parts=args[3],
          state_gradients_are_stopped=self.state_gradients_are_stopped)

    # Use the (possibly stored) local value rather than
    # `self.num_leapfrog_steps` so `_store_parameters_in_results` is honored.
    num_leapfrog_steps = tf.convert_to_tensor(
        value=num_leapfrog_steps, dtype=tf.int32, name='num_leapfrog_steps')

    [
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = tf.while_loop(
        cond=lambda i, *args: i < num_leapfrog_steps,
        body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
        loop_vars=[
            tf.zeros([], tf.int32, name='iter'),
            current_momentum_parts,
            current_state_parts,
            current_target_log_prob,
            current_target_log_prob_grad_parts
        ])[1:]

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            current_momentum_parts, next_momentum_parts,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
    )

    return maybe_flatten(next_state_parts), new_kernel_results
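# For unit-mass HMC, the log-acceptance correction computed above reduces to
# a kinetic-energy difference. A hedged sketch for a single [chains, dims]
# momentum tensor (the real helper sums over all non-chain dimensions of
# every momentum part):
import tensorflow as tf

def sketch_log_acceptance_correction(current_momentum, next_momentum):
  # Kinetic energy K(p) = 0.5 * sum(p**2); correction = K(current) - K(next).
  kinetic = lambda p: 0.5 * tf.reduce_sum(tf.square(p), axis=-1)
  return kinetic(current_momentum) - kinetic(next_momentum)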
def testDynamicRankEndsUpBeingNonEmpty(self):
  if tf.executing_eagerly():
    return
  x = tf1.placeholder_with_default(np.zeros([2, 3], dtype=np.float64),
                                   shape=None)
  rank = distribution_util.prefer_static_rank(x)
  self.assertAllEqual(2, self.evaluate(rank))