def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape): """Slices a single parameter of a distribution. Args: param: A `Tensor`, the original parameter to slice. param_event_ndims: `int` event parameterization rank for this parameter. slices: A `tuple` of normalized slices. dist_batch_shape: The distribution's batch shape `Tensor`. Returns: new_param: A `Tensor`, batch-sliced according to slices. """ # Extend param shape with ones on the left to match dist_batch_shape. param_shape = ps.shape(param) insert_ones = ps.ones( [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)], dtype=param_shape.dtype) new_param_shape = ps.concat([insert_ones, param_shape], axis=0) full_batch_param = tf.reshape(param, new_param_shape) param_slices = [] # We separately track the batch axis from the parameter axis because we want # them to align for positive indexing, and be offset by param_event_ndims for # negative indexing. param_dim_idx = 0 batch_dim_idx = 0 for slc in slices: if slc is tf.newaxis: param_slices.append(slc) continue if slc is Ellipsis: if batch_dim_idx < 0: raise ValueError('Found multiple `...` in slices {}'.format(slices)) param_slices.append(slc) # Switch over to negative indexing for the broadcast check. num_remaining_non_newaxis_slices = sum( [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]]) batch_dim_idx = -num_remaining_non_newaxis_slices param_dim_idx = batch_dim_idx - param_event_ndims continue # Find the batch dimension sizes for both parameter and distribution. param_dim_size = new_param_shape[param_dim_idx] batch_dim_size = dist_batch_shape[batch_dim_idx] is_broadcast = batch_dim_size > param_dim_size # Slices are denoted by start:stop:step. if isinstance(slc, slice): start, stop, step = slc.start, slc.stop, slc.step if start is not None: start = ps.where(is_broadcast, 0, start) if stop is not None: stop = ps.where(is_broadcast, 1, stop) if step is not None: step = ps.where(is_broadcast, 1, step) param_slices.append(slice(start, stop, step)) else: # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2] param_slices.append(ps.where(is_broadcast, 0, slc)) param_dim_idx += 1 batch_dim_idx += 1 param_slices.extend([ALL_SLICE] * param_event_ndims) return full_batch_param.__getitem__(tuple(param_slices))
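# Illustrative sketch (not part of the function above), using numpy only: a
# parameter of shape [3] under a distribution batch shape of [4, 3] broadcasts
# over axis 0, so a user slice like `dist[2:]` must be remapped to `0:1` on
# that axis to avoid overflowing the size-1 dimension.
import numpy as np

param = np.arange(3.)                    # shape [3]; dist_batch_shape is [4, 3].
full_batch_param = param.reshape(1, 3)   # left-padded to the full batch rank.
sliced = full_batch_param[0:1, :]        # `2:` rewritten to `0:1` on the broadcast axis.
print(sliced.shape)                      # (1, 3) -- still broadcasts against the sliced batch [2, 3].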
def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype): current_period = np.int32(np.arange(num_steps) / duration) % period return np.transpose([ ps.where(current_period == p, # pylint: disable=g-complex-comprehension ps.ones([], dtype=dtype), ps.zeros([], dtype=dtype)) for p in range(period)])
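# Hypothetical worked example in plain numpy: the design matrix for
# num_steps=6, duration=1, period=3 is a block of one-hot season indicators.
import numpy as np

current_period = (np.arange(6) // 1) % 3
design = np.transpose([(current_period == p).astype(np.float32)
                       for p in range(3)])
print(design)
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]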
def expand_dims_(x, axis, name=None):
  """Implementation of `expand_dims`."""
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    new_axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(new_axis)
    is_neg_axis = new_axis < 0
    k = prefer_static.reduce_sum(
        prefer_static.cast(is_neg_axis, new_axis.dtype))
    new_axis = prefer_static.where(is_neg_axis, new_axis + nx, new_axis)
    new_axis = prefer_static.sort(new_axis)
    axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _interleave(a, b, axis): """Interleaves two `Tensor`s along the given axis.""" # [a b c ...] [d e f ...] -> [a d b e c f ...] num_elems_a = ps.shape(a)[axis] num_elems_b = ps.shape(b)[axis] # Note that interleaving implies rank(a)==rank(b). axis = ps.where(axis >= 0, axis, ps.rank(a) + axis) axis = (int(axis) # Avoid ndarray values. if tf.get_static_value(axis) is not None else axis) def _interleave_with_b(a): return tf.reshape( # Work around lack of support for Tensor axes in `tf.stack` by using # `concat` and `expand_dims` instead. tf.concat([tf.expand_dims(a, axis=axis + 1), tf.expand_dims(b, axis=axis + 1)], axis=axis + 1), ps.concat( [ ps.shape(a)[:axis], [2 * num_elems_b], ps.shape(a)[axis + 1:] ], axis=0)) return ps.cond( ps.equal(num_elems_a, num_elems_b + 1), lambda: tf.concat([ # pylint: disable=g-long-lambda _interleave_with_b(_slice_along_axis(a, None, -1, axis=axis)), _slice_along_axis(a, -1, None, axis=axis)], axis=axis), lambda: _interleave_with_b(a))
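# Minimal numpy sketch of the interleaving trick above (equal-length case):
# stack `a` and `b` along a new trailing axis, then collapse it, giving
# [a0 b0 a1 b1 ...]. The function above additionally handles len(a) == len(b) + 1.
import numpy as np

a = np.array([1, 3, 5])
b = np.array([2, 4, 6])
interleaved = np.stack([a, b], axis=1).reshape(-1)
print(interleaved)   # [1 2 3 4 5 6]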
def expand_dims(x, axis, name=None): """Like `tf.expand_dims` but accepts a vector of axes to expand.""" with tf.name_scope(name or 'expand_dims'): x = tf.convert_to_tensor(x, name='x') axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis') nx = prefer_static.rank(x) na = prefer_static.size(axis) is_neg_axis = axis < 0 k = prefer_static.reduce_sum( prefer_static.cast(is_neg_axis, axis.dtype)) axis = prefer_static.where(is_neg_axis, axis + nx, axis) axis = prefer_static.sort(axis) axis_neg, axis_pos = prefer_static.split(axis, [k, -1]) idx = prefer_static.argsort(prefer_static.concat([ axis_pos, prefer_static.range(nx), axis_neg, ], axis=0), stable=True) shape = prefer_static.pad(prefer_static.shape(x), paddings=[[na - k, k]], constant_values=1) shape = prefer_static.gather(shape, idx) return tf.reshape(x, shape)
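# Sanity-check sketch, using numpy only as a reference: for a single axis the
# result agrees with np.expand_dims / tf.expand_dims; the function above
# generalizes this to a vector of axes expanded in one reshape.
import numpy as np

x = np.zeros([2, 3])
assert np.expand_dims(x, 0).shape == (1, 2, 3)
assert np.expand_dims(x, -1).shape == (2, 3, 1)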
def _filter_one_step(step, observation, previous_particles, log_weights, transition_fn, observation_fn, proposal_fn, resample_criterion_fn, seed=None): """Advances the particle filter by a single time step.""" with tf.name_scope('filter_one_step'): seed = SeedStream(seed, 'filter_one_step') num_particles = prefer_static.shape(log_weights)[-1] proposed_particles, proposal_log_weights = _propose_with_log_weights( step=step - 1, particles=previous_particles, transition_fn=transition_fn, proposal_fn=proposal_fn, seed=seed) observation_log_weights = _compute_observation_log_weights( step, proposed_particles, observation, observation_fn) unnormalized_log_weights = (log_weights + proposal_log_weights + observation_log_weights) step_log_marginal_likelihood = tf.math.reduce_logsumexp( unnormalized_log_weights, axis=-1) log_weights = (unnormalized_log_weights - step_log_marginal_likelihood[..., tf.newaxis]) # Adaptive resampling: resample particles iff the specified criterion. do_resample = tf.convert_to_tensor( resample_criterion_fn(unnormalized_log_weights))[ ..., tf.newaxis] # Broadcast over particles. # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to use # the resampled values for each batch element according to # `do_resample`. If there were no batching, we might prefer to use # `tf.cond` to avoid the resampling computation on steps where it's not # needed---but we're ultimately interested in adaptive resampling # for statistical (not computational) purposes, so this isn't a dealbreaker. resampled_particles, resample_indices = _resample(proposed_particles, log_weights, seed=seed) dummy_indices = tf.broadcast_to(prefer_static.range(num_particles), prefer_static.shape(resample_indices)) uniform_weights = (prefer_static.zeros_like(log_weights) - prefer_static.log(num_particles)) (resampled_particles, resample_indices, log_weights) = tf.nest.map_structure( lambda r, p: prefer_static.where(do_resample, r, p), (resampled_particles, resample_indices, uniform_weights), (proposed_particles, dummy_indices, log_weights)) return ParticleFilterStepResults( particles=resampled_particles, log_weights=log_weights, parent_indices=resample_indices, step_log_marginal_likelihood=step_log_marginal_likelihood)
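# Numerical sketch of the weight update above: normalizing the unnormalized
# log-weights by their logsumexp yields weights that sum to one, and the
# normalizer itself is the step's log marginal likelihood estimate.
import numpy as np

unnormalized_log_weights = np.log([0.1, 0.2, 0.4])
step_log_marginal_likelihood = np.log(np.sum(np.exp(unnormalized_log_weights)))
log_weights = unnormalized_log_weights - step_log_marginal_likelihood
print(np.allclose(np.sum(np.exp(log_weights)), 1.0))   # True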
def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None): if event_shape is None: event_shape = self.event_shape_tensor() batch_shape = self.batch_shape_tensor() batch_rank = ps.rank_from_shape(batch_shape) underlying_batch_shape = self.distribution.batch_shape_tensor() underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape) # Continuing the example from `_augment_sample_shape`, suppose we have: # - sample shape of `[n]`, # - underlying distribution batch shape of `[2, 1]`, # - final broadcast batch shape of `[4, 2, 3]`. # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we # ultimately want to have shape `[n, 4, 2, 3] + event_shape`. # First, we reshape to expand out the batch elements: # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`, # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and # `[4, 1, 3]` is the shape of the elements being added by broadcasting. underlying_bcast_shp = ps.concat([ ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)], dtype=underlying_batch_shape.dtype), underlying_batch_shape ], axis=0) is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp) x_with_doubled_batch = tf.reshape( x, ps.concat([ sample_shape, ps.where(is_dim_bcast, batch_shape, 1), underlying_bcast_shp, event_shape ], axis=0)) # Next, construct the permutation that interleaves the batch dimensions, # resulting in samples with shape # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`. # Note that each interleaved pair of batch dimensions contains exactly one # dim of size `1` and one of size `>= 1`. sample_ndims = ps.rank_from_shape(sample_shape) x_with_interleaved_batch = tf.transpose( x_with_doubled_batch, perm=ps.concat([ ps.range(sample_ndims), sample_ndims + ps.reshape( ps.stack([ ps.range(batch_rank), ps.range(batch_rank) + batch_rank ], axis=-1), [-1]), sample_ndims + 2 * batch_rank + ps.range(ps.rank_from_shape(event_shape)) ], axis=0)) # Final reshape to remove the spurious `1` dimensions. return tf.reshape( x_with_interleaved_batch, ps.concat([sample_shape, batch_shape, event_shape], axis=0))
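# Shape-only numpy sketch of the doubled-batch trick described in the comments
# above, using the same example shapes: broadcast batch [4, 2, 3], underlying
# batch [2, 1], empty event shape, and n = 5 samples.
import numpy as np

n = 5
x = np.zeros([n, 12, 2, 1])                # 12 == 4 * 1 * 3 broadcast copies.
x = x.reshape([n, 4, 1, 3, 1, 2, 1])       # sample + broadcast dims + underlying dims.
x = x.transpose([0, 1, 4, 2, 5, 3, 6])     # interleave [4, 1, 3] with [1, 2, 1].
x = x.reshape([n, 4, 2, 3])                # final broadcast batch shape.
print(x.shape)                             # (5, 4, 2, 3)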
def body_fn(vecs, i): # Slice out the vector w.r.t. which we're orthogonalizing the rest. vecs_ndims = ps.rank(vecs) select_axis = (ps.range(vecs_ndims) == vecs_ndims - 1) start = ps.where(select_axis, i, ps.zeros([vecs_ndims], i.dtype)) size = ps.where(select_axis, 1, ps.shape(vecs)) u = tf.math.l2_normalize(tf.slice(vecs, start, size), axis=-2) # TODO(b/171730305): XLA can't handle this line... # u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2) # Find weights by dotting the d x 1 against the d x n. weights = tf.einsum('...dm,...dn->...n', u, vecs) # Project out vector `u` from the trailing vectors. masked_weights = tf.where(tf.range(n) > i, weights, 0.)[..., tf.newaxis, :] vecs = vecs - tf.math.multiply_no_nan(u, masked_weights) tensorshape_util.set_shape(vecs, vectors.shape) return vecs, i + 1
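# Numpy sketch of the Gram-Schmidt step above (in the original code, `n` and
# `vectors` come from the enclosing loop): normalize column i, then project it
# out of the columns that follow it.
import numpy as np

vecs = np.random.randn(4, 3)          # d x n matrix of column vectors.
i = 0
u = vecs[:, i:i + 1] / np.linalg.norm(vecs[:, i:i + 1])
weights = u.T @ vecs                  # 1 x n dot products against u.
mask = (np.arange(vecs.shape[1]) > i).astype(vecs.dtype)
vecs = vecs - u @ (weights * mask)    # only trailing columns are modified.
print(np.abs(u.T @ vecs[:, i + 1:]).max())   # ~0: later columns are now orthogonal to u.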
def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.

  Returns:
    sanitized_slices: Python `list` of slice objects.
  """
  sanitized_slices = []
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      if idx < 0:
        raise ValueError(
            'Found multiple `...` in slices {}'.format(slices))
      num_remaining_non_newaxis_slices = sum([
          s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]
      ])
      idx = -num_remaining_non_newaxis_slices
    elif slc is tf.newaxis:
      pass
    else:
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        if start is not None:
          start = ps.where(is_broadcast, 0, start)
        if stop is not None:
          stop = ps.where(is_broadcast, 1, stop)
        if step is not None:
          step = ps.where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        slc = ps.where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices
def _calculate_batch_shape(self): """Computes fully defined batch shape for the new distribution.""" all_batch_shapes = [d.batch_shape.as_list() if tensorshape_util.is_fully_defined(d.batch_shape) else d.batch_shape_tensor() for d in self.distributions] original_shape = ps.stack(all_batch_shapes, axis=0) index_mask = ps.cast( ps.one_hot(self._axis, ps.shape(original_shape)[1]), dtype=tf.bool) new_concat_dim = ps.cast( ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32) return ps.where(index_mask, new_concat_dim, ps.reduce_max(original_shape, axis=0))
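# Numpy sketch of the shape arithmetic above: concatenating distributions with
# batch shapes [2, 3] and [4, 3] along axis 0 yields batch shape [6, 3].
import numpy as np

all_batch_shapes = np.array([[2, 3], [4, 3]])
axis = 0
index_mask = np.arange(all_batch_shapes.shape[1]) == axis
new_concat_dim = all_batch_shapes.sum(axis=0)[axis]                         # 6
print(np.where(index_mask, new_concat_dim, all_batch_shapes.max(axis=0)))   # [6 3]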
def reduce_fn(operands, inits, axis=None, keepdims=False): """Applies `reducer` to the given operands along the given axes. Args: operands: tuple of tensors, all having the same shape. inits: tuple of scalar tensors, with dtypes aligned to those of operands. axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of `int`. `None` is taken to mean "reduce all axes". keepdims: When `True`, we do not squeeze away the reduced dims, instead returning values with singleton dims in those axes. Returns: reduced: A tuple of the reduced operands. """ # Static shape consistency checks. args_shape = operands[0].shape for arg in operands[1:]: args_shape = tensorshape_util.merge_with(args_shape, arg.shape) ndims = tensorshape_util.rank(args_shape) if ndims is None: raise ValueError( 'Rank of at least one of `operands` must be known statically.') # Ensure the 'axis' arg is a tuple of non-negative ints. axis = np.arange(ndims) if axis is None else np.array(axis) if axis.ndim > 1: raise ValueError( '`axis` must be `None`, an `int`, or a sequence of ' '`int`, but got {}'.format(axis)) axis = np.reshape(axis, [-1]) axis = np.where(axis < 0, axis + ndims, axis) axis = tuple(int(ax) for ax in axis) axis_nhot = ps.reduce_sum(ps.one_hot(axis, depth=ndims, on_value=True, off_value=False, dtype=tf.bool), axis=0) in_shape = args_shape if not tensorshape_util.is_fully_defined(in_shape): in_shape = tf.shape(operands[0]) unsqueezed_shape = ps.where(axis_nhot, 1, in_shape) result = _variadic_reduce_custom_grad(operands, inits, axis, reducer, unsqueezed_shape) if keepdims: result = tf.nest.map_structure( lambda t: tf.reshape(t, unsqueezed_shape), result) return result
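# Sketch of the axis bookkeeping above in plain numpy: canonicalize negative
# axes, build a boolean "n-hot" mask over dims, and derive the keepdims shape.
import numpy as np

ndims = 4
axis = np.array([1, -1])
axis = np.where(axis < 0, axis + ndims, axis)   # -> [1, 3]
axis_nhot = np.zeros([ndims], dtype=bool)
axis_nhot[axis] = True
in_shape = np.array([2, 3, 4, 5])
print(np.where(axis_nhot, 1, in_shape))         # [2 1 4 1]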
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps): """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc.""" step_indices_to_trace = tf.convert_to_tensor( step_indices_to_trace, dtype_hint=tf.int32) # Warning: breaks gradients. traced_steps_have_rank_zero = ps.equal( ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0) # Canonicalize negative step indices as positive. step_indices_to_trace = ps.where(step_indices_to_trace < 0, num_timesteps + step_indices_to_trace, step_indices_to_trace) # Canonicalize scalars as length-one vectors. return (ps.reshape(step_indices_to_trace, [ps.size(step_indices_to_trace)]), traced_steps_have_rank_zero)
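# Numpy sketch of the canonicalization above: negative step indices wrap
# around `num_timesteps`, and scalars become length-1 vectors.
import numpy as np

num_timesteps = 10
steps = np.asarray(-2)
steps = np.where(steps < 0, num_timesteps + steps, steps)
print(np.reshape(steps, [steps.size]))   # [8]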
def _compute_observation_log_weights(step, particles, observations, observation_fn, num_transitions_per_observation=1): """Computes particle importance weights from an observation step. Args: step: int `Tensor` current step. particles: Nested structure of `Tensor`s, each of shape `concat([[num_particles, b1, ..., bN], event_shape])`, where `b1, ..., bN` are optional batch dimensions and `event_shape` may differ across `Tensor`s. observations: Nested structure of `Tensor`s, each of shape `concat([[num_observations, b1, ..., bN], event_shape])` where `b1, ..., bN` are optional batch dimensions and `event_shape` may differ across `Tensor`s. observation_fn: callable with signature `observation_dist = observation_fn(step, particles)`, producing a batch of distributions over the `observation` at the given `step`, one for each particle. num_transitions_per_observation: optional int `Tensor` number of times to apply the transition model between successive observation steps. Default value: `1`. Returns: log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`. """ with tf.name_scope('compute_observation_log_weights'): step_has_observation = ( # The second of these conditions subsumes the first, but both are # useful because the first can often be evaluated statically. ps.equal(num_transitions_per_observation, 1) | ps.equal(step % num_transitions_per_observation, 0)) observation_idx = step // num_transitions_per_observation observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) log_weights = observation_fn(step, particles).log_prob(observation) return ps.where(step_has_observation, log_weights, tf.zeros_like(log_weights))
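# Sketch of the gating above: with num_transitions_per_observation = 3, only
# steps 0 and 3 carry an observation; the other steps contribute zero
# log-weight and leave the particle weights unchanged.
num_transitions_per_observation = 3
for step in range(4):
  has_obs = (num_transitions_per_observation == 1 or
             step % num_transitions_per_observation == 0)
  print(step, has_obs)   # 0 True, 1 False, 2 False, 3 True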
def _sample_n(self, n, seed=None): batch_shape = self.batch_shape_tensor() batch_rank = ps.rank_from_shape(batch_shape) n_batch = ps.reduce_prod(batch_shape) underlying_batch_shape = self.distribution.batch_shape_tensor() underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape) underlying_n_batch = ps.reduce_prod(underlying_batch_shape) # Left pad underlying shape with any necessary ones. underlying_bcast_shp = ps.concat([ ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)], dtype=underlying_batch_shape.dtype), underlying_batch_shape ], axis=0) # Determine how many underlying samples to produce. n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch) samps = self.distribution.sample([n, n_bcast_samples], seed=seed) is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp) event_shape = self.event_shape_tensor() event_rank = ps.rank_from_shape(event_shape) shp = ps.concat([[n], ps.where(is_dim_bcast, batch_shape, 1), underlying_bcast_shp, event_shape], axis=0) # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp. samps = tf.reshape(samps, shp) # Interleave broadcast and underlying axis indices for transpose. interleaved_batch_axes = ps.reshape( ps.stack([ps.range(batch_rank), ps.range(batch_rank) + batch_rank], axis=-1), [-1]) + 1 event_axes = ps.range(event_rank) + (1 + 2 * batch_rank) perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0) samps = tf.transpose(samps, perm=perm) # Finally, reshape to the fully-broadcast batch shape. return tf.reshape(samps, ps.concat([[n], batch_shape, event_shape], axis=0))
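# The interleaving permutation above, written out in numpy for batch_rank == 3:
# broadcast axes and underlying axes end up adjacent in pairs (offset by 1 for
# the leading sample dimension).
import numpy as np

batch_rank = 3
interleaved_batch_axes = np.stack(
    [np.arange(batch_rank), np.arange(batch_rank) + batch_rank],
    axis=-1).reshape(-1) + 1
print(interleaved_batch_axes)   # [1 4 2 5 3 6]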
def _calculate_new_shape(self): # Try to get the old shape statically if available. original_shape = self._distribution.batch_shape if not tensorshape_util.is_fully_defined(original_shape): original_shape = self._distribution.batch_shape_tensor() # This is not a check for falseness, it's a check for exactly that shape. if original_shape == (): # pylint: disable=g-explicit-bool-comparison # Force the size to be an integer, not a float, when the shape contains no # dtype information. original_size = 1 else: original_size = ps.reduce_prod(original_shape) original_size = ps.cast(original_size, tf.int32) # Compute the new shape, filling in the `-1` dimension if present. new_shape = self._batch_shape_unexpanded implicit_dim_mask = ps.equal(new_shape, -1) size_implicit_dim = (original_size // ps.maximum(1, -ps.reduce_prod(new_shape))) expanded_new_shape = ps.where( # Assumes exactly one `-1`. implicit_dim_mask, size_implicit_dim, new_shape) # Return the original size on the side because one caller would otherwise # have to recompute it. return expanded_new_shape, original_size
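# Numpy sketch of the `-1` fill-in above: reshaping a batch of 12 elements to
# [3, -1] resolves the implicit dimension to 4.
import numpy as np

original_size = 12
new_shape = np.array([3, -1])
implicit_dim_mask = new_shape == -1
size_implicit_dim = original_size // max(1, -np.prod(new_shape))
print(np.where(implicit_dim_mask, size_implicit_dim, new_shape))   # [3 4]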
def _get_conditional_posterior(self, sampler_state): """Builds the joint posterior for a sparsity pattern (eqn (7) from [1]).""" indices = ps.where(sampler_state.nonzeros)[:, 0] conditional_posterior_precision_chol = tf.linalg.cholesky( tf.gather(tf.gather(sampler_state.weights_posterior_precision, indices), indices, axis=1)) conditional_weights_mean = tf.linalg.cholesky_solve( conditional_posterior_precision_chol, tf.gather(sampler_state.x_transpose_y, indices)[..., tf.newaxis])[..., 0] @joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched def posterior_jd(): observation_noise_variance = yield InverseGammaWithSampleUpperBound( concentration=( self.observation_noise_variance_posterior_concentration), scale=sampler_state.observation_noise_variance_posterior_scale, upper_bound=self.observation_noise_variance_upper_bound, name='observation_noise_variance') yield MVNPrecisionFactorHardZeros( loc=conditional_weights_mean, # Note that the posterior precision varies inversely with the # noise variance: in worlds with high noise we're also # more uncertain about the values of the weights. # TODO(colcarroll): Tests pass even without a square root on the # observation_noise_variance. Should add a test that would fail. precision_factor=tf.linalg.LinearOperatorLowerTriangular( conditional_posterior_precision_chol / tf.sqrt(observation_noise_variance[..., tf.newaxis, tf.newaxis])), nonzeros=sampler_state.nonzeros, name='weights') return posterior_jd
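# Numpy sketch of the gather-twice pattern above: select the rows and columns
# of a precision matrix corresponding to the active (nonzero) features.
import numpy as np

nonzeros = np.array([True, False, True])
indices = np.where(nonzeros)[0]          # [0 2]
precision = np.arange(9.).reshape(3, 3)
print(precision[indices][:, indices])    # the 2 x 2 active-feature block.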
def _initialize_sampler_state(self, targets, nonzeros,
                              observation_noise_variance):
  """Precompute quantities needed to sample with given targets.

  This method computes a sampler state (including factorized precision
  matrices) from scratch for a given sparsity pattern. This requires time
  proportional to `num_features**3`. If a sampler state is already available
  for an off-by-one sparsity pattern, the `_flip_feature` method (which takes
  time proportional to `num_features**2`) is generally more efficient.

  Args:
    targets: float Tensor regression outputs of shape `[num_outputs]`.
    nonzeros: boolean Tensor vector of shape `[num_features]`.
    observation_noise_variance: float Tensor used to scale the posterior
      precision.

  Returns:
    sampler_state: instance of `DynamicSpikeSlabSamplerState` collecting
      Tensor quantities relevant to the sampler. See
      `DynamicSpikeSlabSamplerState` for details.
  """
  with tf.name_scope('initialize_sampler_state'):
    targets = tf.convert_to_tensor(targets, dtype=self.dtype)
    nonzeros = tf.convert_to_tensor(nonzeros, dtype=tf.bool)
    indices = ps.where(nonzeros)[:, 0]

    x_transpose_y = tf.linalg.matvec(self.design_matrix, targets,
                                     adjoint_a=True)

    weights_posterior_precision = (
        self.x_transpose_x +
        self.weights_prior_precision * observation_noise_variance)
    y_transpose_y = tf.reduce_sum(targets**2, axis=-1)

    conditional_prior_precision_chol = tf.linalg.cholesky(
        tf.gather(tf.gather(self.weights_prior_precision, indices),
                  indices, axis=1))
    conditional_posterior_precision_chol = tf.linalg.cholesky(
        tf.gather(tf.gather(weights_posterior_precision, indices),
                  indices, axis=1))
    sub_x_transpose_y = tf.gather(x_transpose_y, indices)
    conditional_weights_mean = tf.linalg.cholesky_solve(
        conditional_posterior_precision_chol,
        sub_x_transpose_y[..., tf.newaxis])[..., 0]
    return self._compute_log_prob(
        x_transpose_y=x_transpose_y,
        y_transpose_y=y_transpose_y,
        nonzeros=nonzeros,
        conditional_prior_precision_chol=conditional_prior_precision_chol,
        conditional_posterior_precision_chol=(
            conditional_posterior_precision_chol),
        weights_posterior_precision=weights_posterior_precision,
        observation_noise_variance_posterior_scale=(
            self.observation_noise_variance_prior_scale +  # ss / 2
            (y_transpose_y -
             tf.reduce_sum(  # beta_gamma' V_gamma^{-1} beta_gamma
                 conditional_weights_mean * sub_x_transpose_y,
                 axis=-1)) / 2))
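# Numpy sketch of the Cholesky solve above: the conditional posterior mean m
# satisfies P m = X'y, solved via the Cholesky factor of the precision P.
import numpy as np

P = np.array([[4., 1.], [1., 3.]])       # conditional posterior precision.
x_transpose_y = np.array([1., 2.])
L = np.linalg.cholesky(P)
m = np.linalg.solve(L.T, np.linalg.solve(L, x_transpose_y))  # like cholesky_solve
print(np.allclose(P @ m, x_transpose_y))   # True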
def one_step(self, state, kernel_results, seed=None): """Takes one Sequential Monte Carlo inference step. Args: state: instance of `tfp.experimental.mcmc.WeightedParticles` representing the current particles with (log) weights. The `log_weights` must be a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The `particles` may be any structure of `Tensor`s, each of which must have shape `concat([log_weights.shape, event_shape])` for some `event_shape`, which may vary across components. kernel_results: instance of `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results from a previous step. seed: Optional seed for reproducible sampling. Returns: state: instance of `tfp.experimental.mcmc.WeightedParticles` representing new particles with (log) weights. kernel_results: instance of `tfp.experimental.mcmc.SequentialMonteCarloResults`. """ with tf.name_scope(self.name): with tf.name_scope('one_step'): seed = samplers.sanitize_seed(seed) proposal_seed, resample_seed = samplers.split_seed(seed) state = WeightedParticles(*state) # Canonicalize. num_particles = ps.size0(state.log_weights) # Propose new particles and update weights for this step, unless it's # the initial step, in which case, use the user-provided initial # particles and weights. proposed_state = self.propose_and_update_log_weights_fn( # Propose state[t] from state[t - 1]. ps.maximum(0, kernel_results.steps - 1), state, seed=proposal_seed) is_initial_step = ps.equal(kernel_results.steps, 0) # TODO(davmre): this `where` assumes the state size didn't change. state = tf.nest.map_structure( lambda a, b: tf.where(is_initial_step, a, b), state, proposed_state) normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0) # Every entry of `log_weights` differs from `normalized_log_weights` # by the same normalizing constant. We extract that constant by # examining an arbitrary entry. incremental_log_marginal_likelihood = ( state.log_weights[0] - normalized_log_weights[0]) do_resample = self.resample_criterion_fn(state) # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to # use the resampled values for each batch element according to # `do_resample`. If there were no batching, we might prefer to use # `tf.cond` to avoid the resampling computation on steps where it's not # needed---but we're ultimately interested in adaptive resampling # for statistical (not computational) purposes, so this isn't a # dealbreaker. resampled_particles, resample_indices = weighted_resampling.resample( state.particles, state.log_weights, self.resample_fn, seed=resample_seed) uniform_weights = tf.fill( ps.shape(state.log_weights), value=-tf.math.log( tf.cast(num_particles, state.log_weights.dtype))) (resampled_particles, resample_indices, log_weights) = tf.nest.map_structure( lambda r, p: ps.where(do_resample, r, p), (resampled_particles, resample_indices, uniform_weights), (state.particles, _dummy_indices_like(resample_indices), normalized_log_weights)) return ( WeightedParticles(particles=resampled_particles, log_weights=log_weights), SequentialMonteCarloResults( steps=kernel_results.steps + 1, parent_indices=resample_indices, incremental_log_marginal_likelihood=( incremental_log_marginal_likelihood), accumulated_log_marginal_likelihood=( kernel_results.accumulated_log_marginal_likelihood + incremental_log_marginal_likelihood), seed=seed))
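# Sketch of the identity used above: every entry of the unnormalized
# log-weights differs from its log_softmax by the same constant, namely
# logsumexp(log_weights), which is the incremental log marginal likelihood.
import numpy as np

log_weights = np.log([0.2, 0.3, 0.5]) + 1.7           # unnormalized.
log_normalizer = np.log(np.sum(np.exp(log_weights)))  # logsumexp.
normalized = log_weights - log_normalizer             # log_softmax.
print(np.allclose(log_weights[0] - normalized[0], log_normalizer))   # True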
def reduce_fn(operands, inits, axis=None, keepdims=False): """Applies `reducer` to the given operands along the given axes. Args: operands: tuple of tensors, all having the same shape. inits: tuple of scalar tensors, with dtypes aligned to those of operands. axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of `int`. `None` is taken to mean "reduce all axes". keepdims: When `True`, we do not squeeze away the reduced dims, instead returning values with singleton dims in those axes. Returns: reduced: A tuple of the reduced operands. """ # Static shape consistency checks. args_shape = operands[0].shape for arg in operands[1:]: args_shape = tensorshape_util.merge_with(args_shape, arg.shape) ndims = tensorshape_util.rank(args_shape) if ndims is None: raise ValueError( 'Rank of at least one of `operands` must be known statically.') # Ensure the 'axis' arg is a tuple of non-negative ints. axis = np.arange(ndims) if axis is None else np.array(axis) if axis.ndim > 1: raise ValueError( '`axis` must be `None`, an `int`, or a sequence of ' '`int`, but got {}'.format(axis)) axis = np.reshape(axis, [-1]) axis = np.where(axis < 0, axis + ndims, axis) axis = tuple(int(ax) for ax in axis) if JAX_MODE: from jax import lax # pylint: disable=g-import-not-at-top result = lax.reduce(operands, init_values=inits, dimensions=axis, computation=reducer) elif (tf.executing_eagerly() or not control_flow_util.GraphOrParentsInXlaContext( tf1.get_default_graph())): result = _variadic_reduce(operands, init=inits, axis=axis, reducer=reducer) else: result = _xla_reduce(operands, inits, axis) if keepdims: axis_nhot = ps.reduce_sum(ps.one_hot(axis, depth=ndims, on_value=True, off_value=False, dtype=tf.bool), axis=0) in_shape = args_shape if not tensorshape_util.is_fully_defined(in_shape): in_shape = tf.shape(operands[0]) final_shape = ps.where(axis_nhot, 1, in_shape) result = tf.nest.map_structure( lambda t: tf.reshape(t, final_shape), result) return result
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
  """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x: A numeric `Tensor` holding samples.
    y: Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axis hold samples).
      Default value: `0` (leftmost dimension).
    event_axis: Scalar or vector `Tensor`, or `None` (scalar events). Axis
      indexing random events, whose covariance we are interested in. If a
      vector, entries must form a contiguous block of dims. `sample_axis` and
      `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims: Boolean. Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as the `x`, and rank equal to
      `rank(x) - len(sample_axis) + 2 * len(event_axis)`.

  Raises:
    AssertionError: If `x` and `y` are found to have different shape.
    ValueError: If `sample_axis` and `event_axis` are found to overlap.
    ValueError: If `event_axis` is found to not be contiguous.
  """
  with tf.name_scope(name or 'covariance'):
    x = tf.convert_to_tensor(x, name='x')
    # Covariance *only* uses the centered versions of x (and y).
    x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

    if y is None:
      y = x
    else:
      y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
      # If x and y have different shape, sample_axis and event_axis will likely
      # be wrong for one of them!
      tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
      y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

    if event_axis is None:
      return tf.reduce_mean(
          x * tf.math.conj(y), axis=sample_axis, keepdims=keepdims)

    if sample_axis is None:
      raise ValueError(
          'sample_axis was None, which means all axis hold events, and this '
          'overlaps with event_axis ({})'.format(event_axis))

    event_axis = _make_positive_axis(event_axis, ps.rank(x))
    sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

    # If we get lucky and axis is statically defined, we can do some checks.
    if _is_list_like(event_axis) and _is_list_like(sample_axis):
      event_axis = tuple(map(int, event_axis))
      sample_axis = tuple(map(int, sample_axis))
      if set(event_axis).intersection(sample_axis):
        raise ValueError(
            'sample_axis ({}) and event_axis ({}) overlapped'.format(
                sample_axis, event_axis))
      if (np.diff(np.array(sorted(event_axis))) > 1).any():
        raise ValueError(
            'event_axis must be contiguous. Found: {}'.format(event_axis))
      batch_axis = list(
          sorted(
              set(range(tensorshape_util.rank(
                  x.shape))).difference(sample_axis + event_axis)))
    else:
      batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                ps.concat((sample_axis, event_axis), 0))

    event_axis = ps.cast(event_axis, dtype=tf.int32)
    sample_axis = ps.cast(sample_axis, dtype=tf.int32)
    batch_axis = ps.cast(batch_axis, dtype=tf.int32)

    # Permute x/y until shape = B + E + S
    perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
    x_permed = tf.transpose(a=x, perm=perm_for_xy)
    y_permed = tf.transpose(a=y, perm=perm_for_xy)

    batch_ndims = ps.size(batch_axis)
    batch_shape = ps.shape(x_permed)[:batch_ndims]
    event_ndims = ps.size(event_axis)
    event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
    sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
    sample_ndims = ps.size(sample_shape)
    n_samples = ps.reduce_prod(sample_shape)
    n_events = ps.reduce_prod(event_shape)

    # Flatten sample_axis into one long dim.
    x_permed_flat = tf.reshape(
        x_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0))
    y_permed_flat = tf.reshape(
        y_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0))
    # Do the same for event_axis.
    x_permed_flat = tf.reshape(
        x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
    y_permed_flat = tf.reshape(
        y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

    # After matmul, cov.shape = batch_shape + [n_events, n_events]
    cov = tf.matmul(
        x_permed_flat, y_permed_flat, adjoint_b=True) / ps.cast(
            n_samples, x.dtype)

    # Insert some singletons to make
    # cov.shape = batch_shape + event_shape**2 + [1,...,1]
    # This is just like x_permed.shape, except the sample_axis is all 1's, and
    # the [n_events] became event_shape**2.
    cov = tf.reshape(
        cov,
        ps.concat(
            (
                batch_shape,
                # event_shape**2 used here because it is the same length as
                # event_shape, and has the same number of elements as one
                # batch of covariance.
                event_shape**2,
                ps.ones([sample_ndims], tf.int32)),
            0))
    # Permuting by the argsort inverts the permutation, making
    # cov.shape have ones in the position where there were samples, and
    # [n_events * n_events] in the event position.
    cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

    # Now expand event_shape**2 into event_shape + event_shape.
    # We here use (for the first time) the fact that we require event_axis to
    # be contiguous.
    e_start = event_axis[0]
    e_len = 1 + event_axis[-1] - event_axis[0]
    cov = tf.reshape(
        cov,
        ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                   ps.shape(cov)[e_start + e_len:]), 0))

    # tf.squeeze requires python ints for axis, not Tensor. This is enough to
    # require our axis args to be constants.
    if not keepdims:
      squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                              sample_axis + e_len)
      cov = _squeeze(cov, axis=squeeze_axis)

    return cov
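# Numpy cross-check of the estimator above: dividing by N (not N - 1) matches
# np.cov(..., bias=True) in the scalar-event case.
import numpy as np

x = np.random.randn(100)
y = np.random.randn(100)
manual = np.mean((x - x.mean()) * np.conj(y - y.mean()))
print(np.allclose(manual, np.cov(x, y, bias=True)[0, 1]))   # True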
def _filter_one_step(step, observation, previous_particles, log_weights, transition_fn, observation_fn, proposal_fn, resample_criterion_fn, has_observation=True, seed=None): """Advances the particle filter by a single time step.""" with tf.name_scope('filter_one_step'): seed = SeedStream(seed, 'filter_one_step') num_particles = prefer_static.shape(log_weights)[0] proposed_particles, proposal_log_weights = _propose_with_log_weights( step=step - 1, particles=previous_particles, transition_fn=transition_fn, proposal_fn=proposal_fn, seed=seed) log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights, axis=-1) # If this step has an observation, compute its weights and marginal # likelihood (and otherwise, leave weights unchanged). observation_log_weights = prefer_static.cond( has_observation, lambda: prefer_static.broadcast_to( # pylint: disable=g-long-lambda _compute_observation_log_weights(step, proposed_particles, observation, observation_fn), prefer_static.shape(log_weights)), lambda: tf.zeros_like(log_weights)) unnormalized_log_weights = log_weights + observation_log_weights step_log_marginal_likelihood = tf.math.reduce_logsumexp( unnormalized_log_weights, axis=0) log_weights = (unnormalized_log_weights - step_log_marginal_likelihood) # Adaptive resampling: resample particles iff the specified criterion. do_resample = resample_criterion_fn(unnormalized_log_weights) # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to use # the resampled values for each batch element according to # `do_resample`. If there were no batching, we might prefer to use # `tf.cond` to avoid the resampling computation on steps where it's not # needed---but we're ultimately interested in adaptive resampling # for statistical (not computational) purposes, so this isn't a dealbreaker. resampled_particles, resample_indices = _resample(proposed_particles, log_weights, resample_independent, seed=seed) uniform_weights = (prefer_static.zeros_like(log_weights) - prefer_static.log(num_particles)) (resampled_particles, resample_indices, log_weights) = tf.nest.map_structure( lambda r, p: prefer_static.where(do_resample, r, p), (resampled_particles, resample_indices, uniform_weights), (proposed_particles, _dummy_indices_like(resample_indices), log_weights)) return ParticleFilterStepResults( particles=resampled_particles, log_weights=log_weights, parent_indices=resample_indices, step_log_marginal_likelihood=step_log_marginal_likelihood)
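# Quick check of the uniform reset above: `zeros - log(N)` are the log-weights
# of a uniform distribution over N particles.
import numpy as np

num_particles = 4
uniform_weights = np.zeros([num_particles]) - np.log(num_particles)
print(np.allclose(np.exp(uniform_weights).sum(), 1.0))   # True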
def __init__(self, loc, precision_factor, nonzeros, **kwargs): self._indices = ps.where(nonzeros) self._size = ps.dimension_size(nonzeros, -1) super().__init__(loc=loc, precision_factor=precision_factor, **kwargs)
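# Note on the indexing above: for a single boolean vector, `ps.where(nonzeros)`
# returns a [num_nonzeros, 1] matrix of indices (analogous to np.argwhere);
# here the trailing axis is kept, whereas the samplers above take `[:, 0]`.
import numpy as np

print(np.argwhere(np.array([True, False, True])))
# [[0]
#  [2]]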