def testRollStatic(self):
  with self.assertRaisesRegexp(Exception, 'None'):
    distribution_util.rotate_transpose(None, 1)
  for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
    for shift in np.arange(-5, 5):
      y = distribution_util.rotate_transpose(x, shift)
      self.assertAllEqual(
          self._np_rotate_transpose(x, shift), self.evaluate(y))
      self.assertAllEqual(
          np.roll(x.shape, shift), tensorshape_util.as_list(y.shape))
def testRollStatic(self):
  if tf.executing_eagerly():
    error_message = r'Attempt to convert a value \(None\)'
  else:
    error_message = 'None values not supported.'
  with self.assertRaisesRegexp(ValueError, error_message):
    distribution_util.rotate_transpose(None, 1)
  for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
    for shift in np.arange(-5, 5):
      y = distribution_util.rotate_transpose(x, shift)
      self.assertAllEqual(
          self._np_rotate_transpose(x, shift), self.evaluate(y))
      self.assertAllEqual(
          np.roll(x.shape, shift), tensorshape_util.as_list(y.shape))
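# The tests above compare against a NumPy reference helper. A minimal sketch
# of such a helper (the real `_np_rotate_transpose` lives in the test harness;
# this assumed version simply rolls the axis order, which is exactly the
# semantics the assertions check via `np.roll(x.shape, shift)`):
def _np_rotate_transpose(x, shift):
  """NumPy reference: circularly rotate the axes of `x` by `shift`."""
  x = np.asarray(x)
  return np.transpose(x, np.roll(np.arange(x.ndim), shift))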
def undo_make_batch_of_event_sample_matrices(
    self, x, sample_shape, expand_batch_dim=True,
    name="undo_make_batch_of_event_sample_matrices"):
  """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

  Where:
  - `B_ = B if B or not expand_batch_dim else [1]`,
  - `E_ = E if E else [1]`,
  - `S_ = [tf.reduce_prod(S)]`.

  This function "reverses" `make_batch_of_event_sample_matrices`.

  Args:
    x: `Tensor` of shape `B_+E_+S_`.
    sample_shape: `Tensor` (1D, `int32`).
    expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
      such that `batch_ndims >= 1`.
    name: Python `str`. The name to give this op.

  Returns:
    x: `Tensor`. Input transposed/reshaped to `S+B+E`.
  """
  with self._name_scope(name, values=[x, sample_shape]):
    x = tf.convert_to_tensor(x, name="x")
    # x.shape: _B+_E+[prod(S)]
    sample_shape = tf.convert_to_tensor(sample_shape, name="sample_shape")
    x = distribution_util.rotate_transpose(x, shift=1)
    # x.shape: [prod(S)]+_B+_E
    if self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
      if self._batch_ndims_is_0 or self._event_ndims_is_0:
        squeeze_dims = []
        if self._event_ndims_is_0:
          squeeze_dims += [-1]
        if self._batch_ndims_is_0 and expand_batch_dim:
          squeeze_dims += [1]
        if squeeze_dims:
          x = tf.squeeze(x, axis=squeeze_dims)
          # x.shape: [prod(S)]+B+E
      _, batch_shape, event_shape = self.get_shape(x)
    else:
      s = (x.shape.as_list() if x.shape.is_fully_defined()
           else tf.shape(x))
      batch_shape = s[1:1 + self.batch_ndims]
      # Since sample_dims=1 and is left-most, we add 1 to the number of
      # batch_ndims to get the event start dim.
      event_start = tf.where(
          tf.logical_and(expand_batch_dim, self._batch_ndims_is_0),
          2, 1 + self.batch_ndims)
      event_shape = s[event_start:event_start + self.event_ndims]
    new_shape = tf.concat([sample_shape, batch_shape, event_shape], 0)
    x = tf.reshape(x, shape=new_shape)
    # x.shape: S+B+E
    return x
def testRollDynamic(self):
  for x_value in (np.ones(1, dtype=np.float32),
                  np.ones([2, 1], dtype=np.float32),
                  np.ones([3, 2, 1], dtype=np.float32)):
    for shift_value in np.arange(-5, 5).astype(np.int32):
      x = tf1.placeholder_with_default(x_value, shape=None)
      shift = tf1.placeholder_with_default(shift_value, shape=None)
      self.assertAllEqual(
          self._np_rotate_transpose(x_value, shift_value),
          self.evaluate(distribution_util.rotate_transpose(x, shift)))
def _sub_diag(nonmatrix):
  """Get the first sub-diagonal of a shape [N, N, ...] 'non-matrix'."""
  with tf.name_scope('sub_matrix'):
    # TODO(b/143702351) Once array_ops.matrix_diag_part_v3 is ready and
    # exposed, replace the call to matrix_diag_part_v2 below with
    # tf.linalg.matrix_diag. We can also stop special casing for
    # matrix_dim < 2 at that point.
    # Until then, an OpError is raised for 1x1 matrices without static shape.
    # In fact, non-static shape breaks matrix_diag_part_v2, so we must raise
    # this message now.
    # See http://b/138403336 for the TF issue tracker.
    if not tensorshape_util.is_fully_defined(nonmatrix.shape[:2]):
      raise ValueError(
          '`inverse_temperatures` did not have statically defined shape, '
          'which breaks tracking of is_swap_{proposed,accepted}. '
          'Please provide an `inverse_temperatures` with statically known '
          'shape.')

    # The sub-diagonal of a 1x1 matrix is not defined (throws exception), so
    # in this special case return an empty matrix.
    # TODO(b/143702351) Remove this special case handling once
    # matrix_diag_part_v3 is ready.
    matrix_dim = ps.size0(nonmatrix)
    if matrix_dim is not None and matrix_dim < 2:
      # Shape is [..., 0], so returned tensor is empty, thus contains no
      # values...and therefore the fact that we use 'ones' doesn't matter.
      shape = ps.pad(
          ps.shape(nonmatrix)[2:], paddings=[[0, 1]], constant_values=0)
      matrix_sub_diag = tf.cast(tf.ones(shape), nonmatrix.dtype)
    else:
      # Get first sub-diagonal. `padding_value` is not used (since matrix is
      # square), but is required for the API since this is raw gen_array_ops.
      matrix_sub_diag = tf.raw_ops.MatrixDiagPartV2(
          input=distribution_util.rotate_transpose(nonmatrix, shift=-2),
          k=ps.convert_to_shape_tensor(-1, dtype=tf.int32),
          padding_value=tf.cast(0.0, dtype=nonmatrix.dtype))

    return distribution_util.rotate_transpose(matrix_sub_diag, shift=1)
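# Minimal usage sketch (assumed values, not from the source): for an input of
# shape [N, N], `_sub_diag` returns the entries one step below the main
# diagonal; for a batched input of shape [N, N] + B it returns shape
# B + [N - 1], i.e. the sub-diagonal tracked per batch member.
m = tf.constant([[1., 0., 0.],
                 [2., 3., 0.],
                 [0., 4., 5.]])
sub = _sub_diag(m)  # ==> [2., 4.]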
def _observation_particles_cov_linop(
    predicted_observation_particles,
    ensemble_mean_observations,
    observation_cov,
):
  """LinearOperatorLowRankUpdate holding observation noise covariance.

  All arguments can be derived from `observation_particles_dist`. We pass them
  as arguments to have a simpler graph, and encourage calling `.sample` once.

  Args:
    predicted_observation_particles: Ensemble of state particles fed through
      the observation function. `observation_particles_dist.mean()`
    ensemble_mean_observations: Ensemble mean (mean across `axis=0`) of
      `predicted_observation_particles`.
    observation_cov: `LinearOperator` defining the observation noise
      covariance. `_linop_covariance(observation_particles_dist)`.

  Returns:
    LinearOperatorLowRankUpdate with covariance the sum of `observation_cov`
    and the ensemble covariance of `predicted_observation_particles`.
  """
  # In our usual docstring notation, let B be a batch shape, X be the ensemble
  # of states, and G(X) the deterministic observation transformation of X.
  # Then,
  #   predicted_observation_particles = G(X)  (an ensemble)
  #     shape = [n_ensemble] + B + [n_observations]
  #   ensemble_mean_observations =
  #     tf.reduce_mean(predicted_observation_particles, axis=0)

  # Create matrix U with shape B + [n_observations, n_ensemble] so that, with
  # Cov the ensemble covariance, Cov(G(X)) = UUᵀ.
  centered_observations = (
      predicted_observation_particles - ensemble_mean_observations)
  n_ensemble = tf.cast(
      tf.shape(centered_observations)[0], centered_observations.dtype)
  u = distribution_util.rotate_transpose(
      centered_observations / tf.sqrt(n_ensemble), -1)

  # cov_operator ~ Γ + Cov(G(X))
  return tf.linalg.LinearOperatorLowRankUpdate(
      base_operator=observation_cov,  # = Γ
      u=u,  # UUᵀ = Cov(G(X))
      is_self_adjoint=True,
      is_positive_definite=True)
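# Minimal sanity sketch (assumed values, not from the source): with an
# ensemble of shape [n_ensemble, n_observations], the returned operator's
# dense form is the base covariance plus the (biased) ensemble covariance.
obs_cov = tf.linalg.LinearOperatorIdentity(num_rows=3, dtype=tf.float32)
particles = tf.random.normal([100, 3])
op = _observation_particles_cov_linop(
    particles, tf.reduce_mean(particles, axis=0), obs_cov)
dense = op.to_dense()  # ≈ I + (1/100) Xᵀ X, with X the centered ensemble.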
def make_batch_of_event_sample_matrices(
    self, x, expand_batch_dim=True,
    name="make_batch_of_event_sample_matrices"):
  """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

  Where:
  - `B_ = B if B or not expand_batch_dim else [1]`,
  - `E_ = E if E else [1]`,
  - `S_ = [tf.reduce_prod(S)]`.

  Args:
    x: `Tensor`.
    expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
      such that `batch_ndims >= 1`.
    name: Python `str`. The name to give this op.

  Returns:
    x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
    sample_shape: `Tensor` (1D, `int32`).
  """
  with self._name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name="x")
    # x.shape: S+B+E
    sample_shape, batch_shape, event_shape = self.get_shape(x)
    event_shape = distribution_util.pick_vector(
        self._event_ndims_is_0, [1], event_shape)
    if expand_batch_dim:
      batch_shape = distribution_util.pick_vector(
          self._batch_ndims_is_0, [1], batch_shape)
    new_shape = tf.concat([[-1], batch_shape, event_shape], 0)
    x = tf.reshape(x, shape=new_shape)
    # x.shape: [prod(S)]+B_+E_
    x = distribution_util.rotate_transpose(x, shift=-1)
    # x.shape: B_+E_+[prod(S)]
    return x, sample_shape
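# Round-trip sketch (hypothetical `shaper` instance with batch_ndims=1 and
# event_ndims=1; shapes assumed for illustration): the two methods above
# invert each other, mapping S+B+E -> B_+E_+[prod(S)] -> S+B+E.
# x.shape == [4, 2, 3, 5]  # S=[4, 2], B=[3], E=[5]
# x_, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
# x_.shape == [3, 5, 8]    # B_+E_+[prod(S)], prod(S) == 8
# y = shaper.undo_make_batch_of_event_sample_matrices(x_, sample_shape)
# y.shape == [4, 2, 3, 5]  # recovers the original S+B+E layout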
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = ps.rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = distribution_util.rotate_transpose(x, shift)

    if center:
      x_rotated = x_rotated - tf.reduce_mean(
          x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation. The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = ps.shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts. At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = ps.cast(x_len, np.float64)
    target_length = ps.pow(
        np.float64(2.), ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = ps.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = distribution_util.pad(
        x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype_util.is_complex(dtype):
      if not dtype_util.is_floating(dtype):
        raise TypeError('Argument x must have either float or complex dtype,'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(
          x_rotated_pad,
          dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not tensorshape_util.is_fully_defined(x_rotated.shape):
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = tensorshape_util.as_list(x_rotated.shape)
      chopped_shape[-1] = min(x_len, max_lags + 1)
      tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
    max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
    denominator = x_len - ps.range(0., max_lags + 1.)
    denominator = ps.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
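# Usage sketch (assumed values): for zero-mean white noise, the normalized
# autocorrelation is exactly 1 at lag 0 (by construction of `normalize`) and
# near 0 at the remaining lags.
x = tf.random.normal([1024])
rxx = auto_correlation(x, max_lags=4)  # shape [5]; rxx[0] == 1.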
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'remc',
                                         'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = ps.convert_to_shape_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
          '(`seed`) argument. `TransitionKernel` instances now receive seeds '
          'via `one_step`.')

    replica_results = inner_kernel.bootstrap_results(replica_states)

    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = ps.shape(pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = bu.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(
        tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
        shift=2)

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=_set_swapped_fields_to_nan(replica_results),
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
        potential_energy=tf.zeros_like(pre_swap_replica_target_log_prob),
    )
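# Null-swap sketch (assumed values): tf.eye(num_replica, batch_shape=B) has
# shape B + [R, R]; rotate_transpose(..., shift=2) moves the two trailing
# replica dims to the front, giving the documented [R, R] + B layout. With an
# empty batch shape the rotation is a no-op:
eye = tf.eye(3, dtype=tf.bool)
null_swap = distribution_util.rotate_transpose(eye, shift=2)
# ==> 3x3 identity; only i<->i "swaps" are marked accepted, and
# _sub_diag(null_swap) ==> [False, False] (no adjacent swaps).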
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keep_dims=False,
               validate_args=False,
               preserve_gradients=True,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of
  the way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different
  percentiles.

  Compare to `numpy.percentile`.

  Args:
    x: Numeric `N-D` `Tensor` with `N > 0`. If `axis` is not `None`, `x` must
      have statically known number of dimensions.
    q: Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis: Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that indexes independent samples over which to return the desired
      percentile. If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'. This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional
          part of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keep_dims: Python `bool`. If `True`, the last dimension is kept with size
      1. If `False`, the last dimension is removed from the output shape.
    validate_args: Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients: Python `bool`. If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q`
      corresponds to a point in `x`.
    name: A Python string name to give this `Op`. Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`,
    or, if `axis` is `None`, a `rank(q)` `Tensor`. The first `rank(q)`
    dimensions index quantiles for different values of `q`.

  Raises:
    ValueError: If argument 'interpolation' is not an allowed type.
    ValueError: If interpolation type not compatible with `dtype`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum). By default, this is computed over every
  # dim.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile
  # (max) over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```
  """
  name = name or 'percentile'
  allowed_interpolations = {'linear', 'lower', 'higher', 'nearest',
                            'midpoint'}

  if interpolation is None:
    interpolation = 'nearest'
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError(
          'Argument `interpolation` must be in %s. Found %s' %
          (allowed_interpolations, interpolation))

  with tf1.name_scope(name, values=[x, q]):
    x = tf.convert_to_tensor(value=x, name='x')

    if interpolation in {'linear', 'midpoint'} and x.dtype.is_integer:
      raise TypeError('{} interpolation not allowed with dtype {}'.format(
          interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the
    # array is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          tf1.assert_rank_in(q, [0, 1]),
          tf1.assert_greater_equal(q, tf.cast(0., tf.float64)),
          tf1.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_above = 1. - q / 100.

    # Sort everything, not just the top 'k' entries, which allows multiple
    # calls to sort only once (under the hood) and use CSE.
    sorted_y = _sort_tensor(y)

    d = tf.cast(tf.shape(input=y)[-1], tf.float64)

    def _get_indices(interp_type):
      """Get values of y at the indices implied by interp_type."""
      # Note `lower` <--> ceiling. Confusing, huh? Due to the fact that
      # _sort_tensor sorts highest to lowest, tf.ceil corresponds to the
      # higher index, but the lower value of y!
      if interp_type == 'lower':
        indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
      elif interp_type == 'higher':
        indices = tf.floor((d - 1) * frac_at_q_or_above)
      elif interp_type == 'nearest':
        indices = tf.round((d - 1) * frac_at_q_or_above)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(
          tf.cast(indices, tf.int32), 0,
          tf.shape(input=y)[-1] - 1)

    if interpolation in ['nearest', 'lower', 'higher']:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('lower')
      exact_idx = (d - 1) * frac_at_q_or_above
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel). The
        # fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1,
                                  tf.shape(input=y)[-1] - 1)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('higher')
        fraction = tf.math.ceil((d - 1) * frac_at_q_or_above) - exact_idx

      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs.
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # Apparently tf.is_nan doesn't like other dtypes.
      nan_batch_members = tf.reduce_any(
          input_tensor=tf.math.is_nan(x), axis=axis)
      right_rank_matched_shape = tf.pad(
          tensor=tf.shape(input=nan_batch_members),
          paddings=[[0, tf.rank(input=q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, gathered_y.dtype.as_numpy_dtype)
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested.
    if keep_dims:
      if axis is None:
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keep_dims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
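# Usage sketch (assumed values): with interpolation='linear' and
# preserve_gradients=True, the median of [1., 2., 3., 4.] is 2.5 and the
# gradient w.r.t. `q` stays nonzero even when `q` lands on a data point.
x_ = tf.constant([1., 2., 3., 4.])
q_ = tf.constant(50.)
with tf.GradientTape() as tape:
  tape.watch(q_)
  p = percentile(x_, q_, interpolation='linear')  # ==> 2.5
dp_dq = tape.gradient(p, q_)  # Nonzero by construction.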
def find_bins(x,
              edges,
              extend_lower_interval=False,
              extend_upper_interval=False,
              dtype=None,
              name=None):
  """Bin values into discrete intervals.

  Given `edges = [c0, ..., cK]`, defining intervals
  `I0 = [c0, c1)`, `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
  this function returns `bins`, such that:
  `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Args:
    x: Numeric `N-D` `Tensor` with `N > 0`.
    edges: `Tensor` of same `dtype` as `x`. The first dimension indexes edges
      of intervals. Must either be `1-D` or have
      `x.shape[1:] == edges.shape[1:]`. If `rank(edges) > 1`, `edges[k]`
      designates a shape `edges.shape[1:]` `Tensor` of bin edges for the
      corresponding dimensions of `x`.
    extend_lower_interval: Python `bool`. If `True`, extend the lowest
      interval `I0` to `(-inf, c1]`.
    extend_upper_interval: Python `bool`. If `True`, extend the upper
      interval `I_{K-1}` to `[c_{K-1}, +inf)`.
    dtype: The output type (`int32` or `int64`). `Default value:` `x.dtype`.
      This affects the output values when `x` is below/above the intervals,
      which will be `-1`/`K+1` for `int` types and `NaN` for `float`s. At
      indices where `x` is `NaN`, the output values will be `0` for `int`
      types and `NaN` for floats.
    name: A Python string name to prepend to created ops.
      Default: 'find_bins'.

  Returns:
    bins: `Tensor` with same `shape` as `x` and `dtype`. Has whole number
      values. `bins[i] = k` means the `x[i]` falls into the `kth` bin, i.e.,
      `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Raises:
    ValueError: If `edges.shape[0]` is determined to be less than 2.

  #### Examples

  Cut a `1-D` array

  ```python
  x = [0., 5., 6., 10., 20.]
  edges = [0., 5., 10.]
  tfp.stats.find_bins(x, edges)
  ==> [0., 1., 1., 1., np.nan]
  ```

  Cut `x` into its deciles

  ```python
  x = tf.random_uniform(shape=(100, 200))
  decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
  bins = tfp.stats.find_bins(x, edges=decile_edges)
  bins.shape
  ==> (100, 200)
  tf.reduce_mean(bins == 0.)
  ==> approximately 0.1
  tf.reduce_mean(bins == 1.)
  ==> approximately 0.1
  ```
  """
  # TFP users may be surprised to see the "action" in the leftmost dim of
  # edges, rather than the rightmost (event) dim. Why?
  # 1. Most likely you created edges by getting quantiles over samples, and
  #    quantile/percentile return these edges in the leftmost (sample) dim.
  # 2. Say you have event_shape = [5], then we expect the bin will be
  #    different for all 5 events, so the index of the bin should not be in
  #    the event dim.
  with tf1.name_scope(name, default_name='find_bins', values=[x, edges]):
    in_type = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
    edges = tf.convert_to_tensor(value=edges, name='edges', dtype=in_type)
    x = tf.convert_to_tensor(value=x, name='x', dtype=in_type)

    if (tf.compat.dimension_value(edges.shape[0]) is not None and
        tf.compat.dimension_value(edges.shape[0]) < 2):
      raise ValueError(
          'First dimension of `edges` must have length > 1 to index 1 or '
          'more bins. Found: {}'.format(edges.shape))

    flattening_x = edges.shape.ndims == 1 and x.shape.ndims > 1

    if flattening_x:
      x_orig_shape = tf.shape(input=x)
      x = tf.reshape(x, [-1])

    if dtype is None:
      dtype = in_type
    dtype = tf.as_dtype(dtype)

    # Move first dims into the rightmost.
    x_permed = distribution_util.rotate_transpose(x, shift=-1)
    edges_permed = distribution_util.rotate_transpose(edges, shift=-1)

    # If...
    #   x_permed = [0, 1, 6., 10]
    #   edges = [0, 5, 10.]
    #   ==> almost_output = [0, 1, 2, 2]
    searchsorted_type = dtype if dtype in [tf.int32, tf.int64] else None
    almost_output_permed = tf.searchsorted(
        sorted_sequence=edges_permed,
        values=x_permed,
        side='right',
        out_type=searchsorted_type)
    # Move the rightmost dims back to the leftmost.
    almost_output = tf.cast(
        distribution_util.rotate_transpose(almost_output_permed, shift=1),
        dtype)

    # In above example, we want [0, 0, 1, 1], so correct this here.
    bins = tf.clip_by_value(
        almost_output - 1, tf.cast(0, dtype),
        tf.cast(tf.shape(input=edges)[0] - 2, dtype))

    if not extend_lower_interval:
      low_fill = np.nan if dtype.is_floating else -1
      bins = tf.where(
          x < tf.expand_dims(edges[0], 0), tf.cast(low_fill, dtype), bins)

    if not extend_upper_interval:
      up_fill = (
          np.nan if dtype.is_floating else tf.shape(input=edges)[0] - 1)
      bins = tf.where(
          x > tf.expand_dims(edges[-1], 0), tf.cast(up_fill, dtype), bins)

    if flattening_x:
      bins = tf.reshape(bins, x_orig_shape)

    return bins
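# Usage sketch (assumed values): out-of-interval values map to NaN for float
# dtypes unless the corresponding interval is extended.
x_ = tf.constant([-1., 0.5, 5., 12.])
edges_ = tf.constant([0., 5., 10.])
find_bins(x_, edges_)
# ==> [nan, 0., 1., nan]
find_bins(x_, edges_, extend_lower_interval=True, extend_upper_interval=True)
# ==> [0., 0., 1., 1.]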
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keep_dims=False,
               validate_args=False,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of
  the way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different
  percentiles.

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  percentile(x, q=30.)
  ==> 2.0

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum). By default, this is computed over every
  # dim.
  x = [[1., 2.],
       [3., 4.]]
  percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile
  # (max) over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  Compare to `numpy.percentile`.

  Args:
    x: Floating point `N-D` `Tensor` with `N > 0`. If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q: Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis: Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that holds independent samples over which to return the desired
      percentile. If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'lower', 'higher', 'nearest'}. Default: 'nearest'. This
      optional parameter specifies the interpolation method to use when the
      desired quantile lies between two data points `i < j`:
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
    keep_dims: Python `bool`. If `True`, the last dimension is kept with size
      1. If `False`, the last dimension is removed from the output shape.
    validate_args: Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    name: A Python string name to give this `Op`. Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`,
    or, if `axis` is `None`, a `rank(q)` `Tensor`. The first `rank(q)`
    dimensions index quantiles for different values of `q`.

  Raises:
    ValueError: If argument 'interpolation' is not an allowed type.
  """
  name = name or 'percentile'
  allowed_interpolations = {'lower', 'higher', 'nearest'}

  if interpolation is None:
    interpolation = 'nearest'
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError(
          'Argument `interpolation` must be in %s. Found %s' %
          (allowed_interpolations, interpolation))

  with tf.name_scope(name, values=[x, q]):
    x = tf.convert_to_tensor(x, name='x')
    # Double is needed here and below, else we get the wrong index if the
    # array is huge along axis.
    q = tf.to_double(q, name='q')
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = control_flow_ops.with_dependencies([
          tf.assert_rank_in(q, [0, 1]),
          tf.assert_greater_equal(q, tf.to_double(0.)),
          tf.assert_less_equal(q, tf.to_double(100.))
      ], q)

    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      axis = tf.convert_to_tensor(axis, name='axis')
      tf.assert_integer(axis)
      axis_ndims = _get_static_ndims(
          axis, expect_static=True, expect_ndims_no_more_than=1)
      axis_const = tensor_util.constant_value(axis)
      if axis_const is None:
        raise ValueError(
            'Expected argument `axis` to be statically available. Found: %s'
            % axis)
      axis = axis_const
      if axis_ndims == 0:
        axis = [axis]
      axis = [int(a) for a in axis]
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative(axis, x_ndims)
      # Move dims in axis to the end, since _sort_tensor, which calls top_k,
      # only sorts the last dim.
      y = _move_dims_to_flat_end(x, axis, x_ndims)

    frac_at_q_or_above = 1. - q / 100.
    d = tf.to_double(tf.shape(y)[-1])

    if interpolation == 'lower':
      indices = tf.ceil((d - 1) * frac_at_q_or_above)
    elif interpolation == 'higher':
      indices = tf.floor((d - 1) * frac_at_q_or_above)
    elif interpolation == 'nearest':
      indices = tf.round((d - 1) * frac_at_q_or_above)

    # If d is gigantic, then we would have d == d - 1, even in double... So
    # let's use max/min to avoid out of bounds errors.
    d = tf.shape(y)[-1]
    # d - 1 will be distinct from d in int32.
    indices = tf.clip_by_value(tf.to_int32(indices), 0, d - 1)

    # Sort everything, not just the top 'k' entries, which allows multiple
    # calls to sort only once (under the hood) and use CSE.
    sorted_y = _sort_tensor(y)

    # Gather the indices along the sorted (last) dimension.
    # If q is a vector, the last dim of gathered_y indexes different q[i].
    gathered_y = tf.gather(sorted_y, indices, axis=-1)

    if keep_dims:
      if axis is None:
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keep_dims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return util.rotate_transpose(gathered_y, tf.rank(q))
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'remc',
                                         'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = tf.convert_to_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an
    # old-style stateful seed to be passed to `self.make_kernel_fn`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we fall back to the previous behavior and warn.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The second (`seed`) argument to `ReplicaExchangeMC`s '
          '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
          'receive seeds via `bootstrap_results` and `one_step`. This '
          'fallback may become an error 2020-09-20.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    replica_results = inner_kernel.bootstrap_results(replica_states)

    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = ps.shape(pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = mcmc_util.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(
        tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
        shift=2)

    post_swap_replica_results = _make_post_swap_replica_results(
        replica_results,
        inverse_temperatures,
        inverse_temperatures,
        is_swap_accepted[0],
        lambda x: x,
    )

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'remc',
                                         'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = prefer_static.size0(inverse_temperatures)
    replica_shape = tf.convert_to_tensor([num_replica])

    replica_states = [
        tf.broadcast_to(  # pylint: disable=g-complex-comprehension
            x,
            prefer_static.concat([replica_shape, prefer_static.shape(x)],
                                 axis=0),
            name='replica_states')
        for x in init_state
    ]

    inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
        _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                         inverse_temperatures),
        self._seed_stream())

    replica_results = inner_kernel.bootstrap_results(replica_states)

    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = prefer_static.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = mcmc_util.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(
        tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
        shift=2)

    post_swap_replica_results = _make_post_swap_replica_results(
        replica_results,
        inverse_temperatures,
        inverse_temperatures,
        is_swap_accepted[0],
        lambda x: x,
    )

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
    )
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation. The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts. At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype,'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.contrib.util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).
    # The other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
def _get_exchanged_states(self, old_states, exchange_proposed,
                          exchange_proposed_n, sampled_replica_states,
                          sampled_replica_results):
  """Get list of TensorArrays holding exchanged states, and zeros."""
  with tf1.name_scope('get_exchanged_states'):

    target_log_probs = []
    for replica in range(self.num_replica):
      replica_log_prob = _get_field(sampled_replica_results[replica],
                                    'target_log_prob')
      inverse_temp = self.inverse_temperatures[replica]
      target_log_probs.append(replica_log_prob / inverse_temp)
    target_log_probs = tf.stack(target_log_probs, axis=0)

    dtype = target_log_probs.dtype
    num_state_parts = len(sampled_replica_states[0])
    # exchanged_states[k][i] is Tensor of (new) state part k, for replica i.
    # The `k` will be known statically, and `i` is a Tensor.
    # We will insert values into indices `i` for every replica with a
    # proposed exchange.
    exchanged_states = [
        tf.TensorArray(  # pylint: disable=g-complex-comprehension
            dtype,
            size=self.num_replica,
            dynamic_size=False,
            tensor_array_name='exchanged_states',
            # State part k has same shape, regardless of replica. So use 0.
            element_shape=sampled_replica_states[0][k].shape)
        for k in range(num_state_parts)
    ]

    # Two TensorArrays, for KernelResults only.
    if self._exchange_between_adjacent_only:
      # Since exchanges are between adjacent only, we track exchanges by the
      # index of the edge between replicas. E.g., if we have replicas
      # [0, 1, 2, 3], then edge index 0 is for exchanges between replicas 0
      # and 1.
      is_exchange_proposed_for_kr = tf.TensorArray(
          dtype=tf.bool,  # Initialized to False
          size=self.num_replica - 1,
          dynamic_size=False,
          tensor_array_name='is_exchange_proposed_for_kr',
          element_shape=tf.TensorShape([]))
      is_exchange_accepted_for_kr = tf.TensorArray(
          dtype=tf.bool,  # Initialized to False
          size=self.num_replica - 1,
          dynamic_size=False,
          tensor_array_name='is_exchange_accepted_for_kr',
          element_shape=target_log_probs[-1].shape)
    else:
      is_exchange_proposed_for_kr = tf.convert_to_tensor(np.nan)
      is_exchange_accepted_for_kr = tf.convert_to_tensor(np.nan)

    # Draw random variables here, to avoid sampling in the loop (and losing
    # reproducibility). This may mean we sample too many, but we will always
    # have enough.
    sample_shape = tf.concat(
        ([self.num_replica // 2], tf.shape(input=target_log_probs)[1:]),
        axis=0)
    log_uniforms = tf.math.log(
        tf.random.uniform(
            shape=sample_shape, dtype=dtype, seed=self._seed_stream()))

    def _swap(is_exchange_accepted, x, y):
      """Swap batches of x, y where accepted."""
      with tf1.name_scope('swap_where_exchange_accepted'):
        new_x = mcmc_util.choose(is_exchange_accepted, y, x)
        new_y = mcmc_util.choose(is_exchange_accepted, x, y)
      return new_x, new_y

    def cond(i, unused_exchanged_states, unused_is_exchange_proposed_for_kr,
             unused_is_exchange_accepted_for_kr):
      return i < exchange_proposed_n

    def body(i, exchanged_states, is_exchange_proposed_for_kr,
             is_exchange_accepted_for_kr):
      """Body of while loop for exchanging states."""
      # Propose exchange between replicas indexed by m and n.
      m, n = tf.unstack(exchange_proposed[i])

      # Construct log_accept_ratio: -temp_diff * target_log_prob_diff.
      # Note target_log_prob_diff = -EnergyDiff (common definition is in
      # terms of energy).
      temp_diff = self.inverse_temperatures[m] - self.inverse_temperatures[n]
      # Difference of target log probs may be +- Inf or NaN. We want the
      # product of this with the temperature difference to have "alt value"
      # of -Inf.
      log_accept_ratio = mcmc_util.safe_sum(
          [-temp_diff * target_log_probs[m], temp_diff * target_log_probs[n]])

      is_exchange_accepted = log_uniforms[i] < log_accept_ratio

      if self._exchange_between_adjacent_only:
        exchange_edge = tf.minimum(m, n)
        is_exchange_proposed_for_kr = is_exchange_proposed_for_kr.write(
            exchange_edge, True)
        is_exchange_accepted_for_kr = is_exchange_accepted_for_kr.write(
            exchange_edge, is_exchange_accepted)

      for k in range(num_state_parts):
        new_m, new_n = _swap(is_exchange_accepted, old_states[k].read(m),
                             old_states[k].read(n))
        exchanged_states[k] = exchanged_states[k].write(m, new_m)
        exchanged_states[k] = exchanged_states[k].write(n, new_n)

      return (i + 1, exchanged_states, is_exchange_proposed_for_kr,
              is_exchange_accepted_for_kr)

    # At this point, exchanged_states[k] is a length `num_replica`
    # TensorArray.
    (exchanged_states, is_exchange_proposed_for_kr,
     is_exchange_accepted_for_kr) = tf.while_loop(
         cond=cond,
         body=body,
         loop_vars=[
             tf.constant(0),
             exchanged_states,
             is_exchange_proposed_for_kr,
             is_exchange_accepted_for_kr,
         ])[1:]  # Remove `i`

    if self._exchange_between_adjacent_only:
      # Stack to give shape [num_replica - 1].
      is_exchange_proposed_for_kr = is_exchange_proposed_for_kr.stack()
      is_exchange_proposed_for_kr.set_shape([self.num_replica - 1])

      # Stack on axis=-1 to give shape batch_shape + [num_replica - 1].
      # ...TensorArray.stack stacks on axis=0, and doesn't take an axis
      # kwarg, so must rotate_transpose.
      is_exchange_accepted_for_kr = distribution_util.rotate_transpose(
          is_exchange_accepted_for_kr.stack(), shift=-1)
      is_exchange_accepted_for_kr.set_shape(
          target_log_probs[-1].shape.concatenate(self.num_replica - 1))

    return (exchanged_states, is_exchange_proposed_for_kr,
            is_exchange_accepted_for_kr)
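# Minimal sketch of the final stack-and-rotate step (assumed values, not from
# the source): TensorArray.stack always stacks on axis=0 and takes no axis
# kwarg, so rotate_transpose(..., shift=-1) moves that leading edge axis to
# the rightmost position, yielding batch_shape + [num_replica - 1].
ta = tf.TensorArray(tf.bool, size=3, element_shape=tf.TensorShape([2]))
for edge in range(3):
  ta = ta.write(edge, [edge == 0, edge == 1])
stacked = ta.stack()  # shape [3, 2]: edge axis leading.
rotated = distribution_util.rotate_transpose(stacked, shift=-1)  # shape [2, 3]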