def chees_criterion(previous_state, proposed_state, accept_prob,
                    validate_args=False):
  """The ChEES criterion from [1].

  ChEES stands for Change in the Estimator of the Expected Square.

  ```none
  ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
  ```

  where `x` is the previous chain state, `x'` is the next chain state, and
  `||.||` is the L2 norm. Both expectations are with respect to the chain's
  stationary distribution. In practice, the inner expectation is replaced by
  the empirical mean across chains, so computing this criterion requires that
  at least 2 chains are present. The outer expectation is computed by the
  caller (e.g. in the `GradientBasedTrajectoryLengthAdaptation` kernel).

  This can be thought of as the standard expected squared jump distance (ESJD)
  criterion, except that the jump distance is computed in the space of
  centered squared L2 norms. Unlike ChEES, regular ESJD is maximized by
  perfectly anticorrelated proposals, which can give excellent mean estimates
  but terrible variance estimates; maximizing ChEES should give good estimates
  across a wider range of types of posterior expectations.

  Args:
    previous_state: (Possibly nested) floating point `Tensor`. The previous
      state of the HMC chain.
    proposed_state: (Possibly nested) floating point `Tensor`. The proposed
      state of the HMC chain.
    accept_prob: Floating `Tensor`. Probability of accepting the proposed
      state.
    validate_args: Whether to perform non-static argument validation.

  Returns:
    chees: The value of the ChEES criterion.

  Raises:
    ValueError: If `accept_prob` indicates that there are fewer than 2 chains.

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.
  """
  batch_ndims = ps.rank(accept_prob)
  batch_axes = ps.range(batch_ndims, dtype=tf.int32)
  num_chains = ps.size(accept_prob)
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:
    if num_chains_ < 2:
      raise ValueError(
          'chees_criterion requires at least 2 chains. Got: {}'.format(
              num_chains_))
  elif validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            num_chains, 2, 'chees_criterion requires at least 2 chains.')
    ]):
      previous_state = tf.nest.map_structure(tf.identity, previous_state)

  def _center_previous_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term.
    return x - tf.stop_gradient(tf.reduce_mean(x, axis=batch_axes))

  def _center_proposed_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term. The goal here is to get a
    # reliable diagnostic of the underlying dynamics, rather than
    # incorporating the effect of the Metropolis-Hastings correction.
    # TODO(mhoffman): Needs more experimentation.
    expanded_accept_prob = mcmc_util.left_justified_expand_dims_like(
        accept_prob, x)

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll discard the resultant gradients later on, so it's fine.
    x_center = (
        tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
        (tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))
    return x - tf.stop_gradient(x_center)

  def _sum_event_part(x):
    event_axes = ps.range(batch_ndims, ps.rank(x))
    return tf.reduce_sum(x, axis=event_axes)

  def _sum_event(x):
    return sum(tf.nest.flatten(tf.nest.map_structure(_sum_event_part, x)))

  def _square(x):
    return tf.nest.map_structure(tf.square, x)

  def _sub(x, y):
    return tf.nest.map_structure(lambda x, y: x - y, x, y)

  previous_state = tf.nest.map_structure(_center_previous_state,
                                         previous_state)
  proposed_state = tf.nest.map_structure(_center_proposed_state,
                                         proposed_state)
  chees = 0.25 * tf.square(
      _sum_event(_sub(_square(proposed_state), _square(previous_state))))
  return chees
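# A minimal NumPy sketch (not the TFP implementation; names are illustrative)
# of the ChEES formula above for an unnested state of shape
# [num_chains, num_dims]. Unlike `chees_criterion`, it ignores `accept_prob`
# when centering the proposed state.
def _chees_toy_example():
  import numpy as np
  rng = np.random.default_rng(0)
  prev = rng.normal(size=(4, 3))   # 4 chains, 3 dimensions
  prop = rng.normal(size=(4, 3))
  # Center each state by the cross-chain empirical mean (stand-in for E[x]).
  prev_c = prev - prev.mean(axis=0)
  prop_c = prop - prop.mean(axis=0)
  # Jump distance measured in the space of centered squared L2 norms.
  delta = (prop_c**2).sum(axis=-1) - (prev_c**2).sum(axis=-1)
  chees_per_chain = 0.25 * delta**2
  return chees_per_chain.mean()    # the caller's outer expectation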
def posterior_marginals(self, observations, mask=None, name=None):
  """Compute marginal posterior distribution for each state.

  This function computes, for each time step, the marginal conditional
  probability that the hidden Markov model was in each possible state given
  the observations that were made at each time step. So if the hidden states
  are `z[0], ..., z[num_steps - 1]` and the observations are
  `x[0], ..., x[num_steps - 1]`, then this function computes
  `P(z[i] | x[0], ..., x[num_steps - 1])` for all `i` from `0` to
  `num_steps - 1`. This operation is sometimes called smoothing. It uses a
  form of the forward-backward algorithm.

  Note: the behavior of this function is undefined if the `observations`
  argument represents impossible observations from the model.

  Args:
    observations: A tensor representing a batch of observations made on the
      hidden Markov model. The rightmost dimension of this tensor gives the
      steps in a sequence of observations from a single sample from the
      hidden Markov model. The size of this dimension should match the
      `num_steps` parameter of the hidden Markov model object. The other
      dimensions are the dimensions of the batch and these are broadcast with
      the hidden Markov model's parameters.
    mask: optional bool-type `Tensor` with rightmost dimension matching
      `num_steps` indicating which observations the result of this function
      should be conditioned on. When the mask has value `True` the
      corresponding observations aren't used. If `mask` is `None` then all of
      the observations are used. The `mask` dimensions left of the last are
      broadcast with the HMM batch as well as with the observations.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "posterior_marginals".

  Returns:
    posterior_marginal: A `Categorical` distribution object representing the
      marginal probability of the hidden Markov model being in each state at
      each step. The rightmost dimension of the `Categorical` distribution's
      batch will equal the `num_steps` parameter, providing one marginal
      distribution for each step. The other dimensions are the dimensions
      corresponding to the batch of observations.

  Raises:
    ValueError: if the rightmost dimension of `observations` does not have
      size `num_steps`.
""" with tf.name_scope(name or "posterior_marginals"): with tf.control_dependencies(self._runtime_assertions): observation_tensor_shape = tf.shape(observations) mask_tensor_shape = tf.shape( mask) if mask is not None else None with self._observation_mask_shape_preconditions( observation_tensor_shape, mask_tensor_shape): observation_log_probs = self._observation_log_probs( observations, mask) log_prob = self._log_init + observation_log_probs[0] log_transition = self._log_trans log_adjoint_prob = tf.zeros_like(log_prob) def _scan_multiple_steps_forwards(): def forward_step(log_previous_step, log_prob_observation): return _log_vector_matrix( log_previous_step, log_transition) + log_prob_observation forward_log_probs = tf.scan(forward_step, observation_log_probs[1:], initializer=log_prob, name="forward_log_probs") return tf.concat([[log_prob], forward_log_probs], axis=0) forward_log_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_forwards, lambda: tf.convert_to_tensor([log_prob])) total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1], axis=-1) def _scan_multiple_steps_backwards(): """Perform `scan` operation when `num_steps` > 1.""" def backward_step(log_previous_step, log_prob_observation): return _log_matrix_vector( log_transition, log_prob_observation + log_previous_step) backward_log_adjoint_probs = tf.scan( backward_step, observation_log_probs[1:], initializer=log_adjoint_prob, reverse=True, name="backward_log_adjoint_probs") return tf.concat( [backward_log_adjoint_probs, [log_adjoint_prob]], axis=0) backward_log_adjoint_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_backwards, lambda: tf.convert_to_tensor([log_adjoint_prob])) log_likelihoods = forward_log_probs + backward_log_adjoint_probs marginal_log_probs = distribution_util.move_dimension( log_likelihoods - total_log_prob[..., tf.newaxis], 0, -2) return categorical.Categorical(logits=marginal_log_probs)
def op(x, kernel):
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  batch_shape, event_shape = ps.split(ps.shape(x), num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)
  fh, fw = filter_shape

  assertions = _maybe_validate_input_shapes(
      ps.shape(kernel), channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)
  with tf.control_dependencies(assertions):
    if tf.get_static_value(ps.rank(kernel)) == 2:
      # Fast path: fold any leading batch dimensions of `x` into a single
      # batch dimension and dispatch to the fused conv2d kernel.
      flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0))
      flat_y = tf.nn.conv2d(
          flat_x,  # Was `x`; conv2d requires a rank-4 NHWC input.
          filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]),
          strides=strides,
          padding=padding,
          data_format='NHWC',
          dilations=dilations)
      output_shape = ps.shape(flat_y)[-3:]
      return tf.reshape(flat_y,
                        shape=ps.concat([batch_shape, output_shape], axis=0))

    pad_values = [
        _get_conv_padding(
            xdim, filter_dim=k, stride=s, dilation=d, padding=padding)
        for (xdim, k, s, d) in zip(
            (xh, xw), filter_shape, strides, dilations)
    ]

    idx, shape = im2row_index(
        (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
        block_shape=filter_shape,
        slice_step=strides,
        dilations=dilations,
        dtype=dtype)

    if padding == 'SAME':
      n = ps.maximum(0, ps.rank(x) - 3)
      paddings = ps.pad(
          pad_values, paddings=[[n, 1], [0, 0]], constant_values=0)
      x = tf.pad(x, paddings=paddings, constant_values=0)

    flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.gather(tf.reshape(x, shape=flat_shape), indices=idx, axis=-1)
    im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
    return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
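# A sketch of the standard 'SAME' padding arithmetic that a helper like
# `_get_conv_padding` needs to reproduce per spatial dimension. This follows
# the usual TensorFlow convention and is a hypothetical stand-in, not the
# library's actual implementation.
def _same_padding_1d_example(input_dim, filter_dim, stride, dilation):
  effective_filter_dim = (filter_dim - 1) * dilation + 1
  output_dim = -(-input_dim // stride)  # ceil(input_dim / stride)
  pad_total = max(
      (output_dim - 1) * stride + effective_filter_dim - input_dim, 0)
  pad_before = pad_total // 2
  # e.g. (28, 3, 1, 1) -> (1, 1);  (28, 3, 2, 1) -> (0, 1)
  return pad_before, pad_total - pad_before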
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None): """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`. Args: lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`. perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. Default value: `False` (i.e., don't validate arguments). name: Python `str` name given to ops managed by this object. Default value: `None` (i.e., 'lu_reconstruct'). Returns: x: The original input to `tf.linalg.lu`, i.e., `x` as in, `lu_reconstruct(*tf.linalg.lu(x))`. #### Examples ```python import numpy as np import tensorflow as tf import tensorflow_probability as tfp x = [[[3., 4], [1, 2]], [[7., 8], [3, 4]]] x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x)) tf.assert_near(x, x_reconstructed) # ==> True ``` """ with tf.name_scope(name or 'lu_reconstruct'): lower_upper = tf.convert_to_tensor(lower_upper, dtype_hint=tf.float32, name='lower_upper') perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm') assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args) if assertions: with tf.control_dependencies(assertions): lower_upper = tf.identity(lower_upper) perm = tf.identity(perm) shape = tf.shape(lower_upper) lower = tf.linalg.set_diag( tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0), tf.ones(shape[:-1], dtype=lower_upper.dtype)) upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1) x = tf.matmul(lower, upper) if lower_upper.shape.ndims is None or lower_upper.shape.ndims != 2: # We either don't know the batch rank or there are >0 batch dims. batch_size = tf.reduce_prod(shape[:-2]) d = shape[-1] x = tf.reshape(x, [batch_size, d, d]) perm = tf.reshape(perm, [batch_size, d]) perm = tf.map_fn(tf.math.invert_permutation, perm) batch_indices = tf.broadcast_to( tf.range(batch_size)[:, tf.newaxis], [batch_size, d]) x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1)) x = tf.reshape(x, shape) else: x = tf.gather(x, tf.math.invert_permutation(perm)) x.set_shape(lower_upper.shape) return x
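# A NumPy analogue (illustrative only) of the non-batched reconstruction path
# above: rebuild L and U from `lower_upper`, multiply, then undo the row
# permutation recorded in `perm`, mirroring the gather at the end of
# `lu_reconstruct`.
def _lu_reconstruct_numpy_example():
  import numpy as np
  x = np.array([[3., 4.], [1., 2.]])
  lu, perm = tf.linalg.lu(x)
  lu, perm = lu.numpy(), perm.numpy()
  lower = np.tril(lu, k=-1) + np.eye(2)   # unit-diagonal L
  upper = np.triu(lu)                     # U
  inv_perm = np.argsort(perm)             # invert_permutation(perm)
  x_reconstructed = (lower @ upper)[inv_perm]
  return np.allclose(x, x_reconstructed)  # True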
def _event_shape_tensor(self):
  with tf.control_dependencies(self._runtime_assertions):
    return tf.concat(
        [[self._num_steps],
         self.observation_distribution.event_shape_tensor()],
        axis=0)
def fit_with_hmc(model, observed_time_series, num_results=100, num_warmup_steps=50, num_leapfrog_steps=15, initial_state=None, initial_step_size=None, chain_batch_shape=(), num_variational_steps=150, variational_optimizer=None, variational_sample_size=5, seed=None, name=None): """Draw posterior samples using Hamiltonian Monte Carlo (HMC). Markov chain Monte Carlo (MCMC) methods are considered the gold standard of Bayesian inference; under suitable conditions and in the limit of infinitely many draws they generate samples from the true posterior distribution. HMC [1] uses gradients of the model's log-density function to propose samples, allowing it to exploit posterior geometry. However, it is computationally more expensive than variational inference and relatively sensitive to tuning. This method attempts to provide a sensible default approach for fitting StructuralTimeSeries models using HMC. It first runs variational inference as a fast posterior approximation, and initializes the HMC sampler from the variational posterior, using the posterior standard deviations to set per-variable step sizes (equivalently, a diagonal mass matrix). During the warmup phase, it adapts the step size to target an acceptance rate of 0.75, which is thought to be in the desirable range for optimal mixing [2]. Args: model: An instance of `StructuralTimeSeries` representing a time-series model. This represents a joint distribution over time-series and their parameters with batch shape `[b1, ..., bN]`. observed_time_series: `float` `Tensor` of shape `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]` dimension may (optionally) be omitted if `num_timesteps > 1`. Any `NaN`s are interpreted as missing observations; missingness may be also be explicitly specified by passing a `tfp.sts.MaskedTimeSeries` instance. num_results: Integer number of Markov chain draws. Default value: `100`. num_warmup_steps: Integer number of steps to take before starting to collect results. The warmup steps are also used to adapt the step size towards a target acceptance rate of 0.75. Default value: `50`. num_leapfrog_steps: Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to `step_size * num_leapfrog_steps`. Default value: `15`. initial_state: Optional Python `list` of `Tensor`s, one for each model parameter, representing the initial state(s) of the Markov chain(s). These should have shape `concat([chain_batch_shape, param.prior.batch_shape, param.prior.event_shape])`. If `None`, the initial state is set automatically using a sample from a variational posterior. Default value: `None`. initial_step_size: Python `list` of `Tensor`s, one for each model parameter, representing the step size for the leapfrog integrator. Must broadcast with the shape of `initial_state`. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. If `None`, the step size is set automatically using the standard deviation of a variational posterior. Default value: `None`. chain_batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of chains to run in parallel. Default value: `[]` (i.e., a single chain). num_variational_steps: Python `int` number of steps to run the variational optimization to determine the initial state and step sizes. Default value: `150`. variational_optimizer: Optional `tf.train.Optimizer` instance to use in the variational optimization. 
If `None`, defaults to `tf.train.AdamOptimizer(0.1)`. Default value: `None`. variational_sample_size: Python `int` number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: `1`. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. name: Python `str` name prefixed to ops created by this function. Default value: `None` (i.e., 'fit_with_hmc'). Returns: samples: Python `list` of `Tensors` representing posterior samples of model parameters, with shapes `[concat([[num_results], chain_batch_shape, param.prior.batch_shape, param.prior.event_shape]) for param in model.parameters]`. kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within the HMC sampler. #### Examples Assume we've built a structural time-series model: ```python day_of_week = tfp.sts.Seasonal( num_seasons=7, observed_time_series=observed_time_series, name='day_of_week') local_linear_trend = tfp.sts.LocalLinearTrend( observed_time_series=observed_time_series, name='local_linear_trend') model = tfp.sts.Sum(components=[day_of_week, local_linear_trend], observed_time_series=observed_time_series) ``` To draw posterior samples using HMC under default settings: ```python samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series) print("acceptance rate: {}".format( np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0))) print("posterior means: {}".format( {param.name: np.mean(param_draws, axis=0) for (param, param_draws) in zip(model.parameters, samples)})) ``` We can also run multiple chains. This may help diagnose convergence issues and allows us to exploit vectorization to draw samples more quickly, although warmup still requires the same number of sequential steps. ```python from matplotlib import pylab as plt samples, kernel_results = tfp.sts.fit_with_hmc( model, observed_time_series, chain_batch_shape=[10]) print("acceptance rate: {}".format( np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0))) # Plot the sampled traces for each parameter. If the chains have mixed, their # traces should all cover the same region of state space, frequently crossing # over each other. for (param, param_draws) in zip(model.parameters, samples): if param.prior.event_shape.ndims > 0: print("Only plotting traces for scalar parameters, skipping {}".format( param.name)) continue plt.figure(figsize=[10, 4]) plt.title(param.name) plt.plot(param_draws.numpy()) plt.ylabel(param.name) plt.xlabel("HMC step") # Combining the samples from multiple chains into a single dimension allows # us to easily pass sampled parameters to downstream forecasting methods. combined_samples = [np.reshape(param_draws, [-1] + list(param_draws.shape[2:])) for param_draws in samples] ``` For greater flexibility, you may prefer to implement your own sampler using the TensorFlow Probability primitives in `tfp.mcmc`. The following recipe constructs a basic HMC sampler, using a `TransformedTransitionKernel` to incorporate constraints on the parameter space. 
```python transformed_hmc_kernel = tfp.mcmc.TransformedTransitionKernel( inner_kernel=tfp.mcmc.DualAveragingStepSizeAdaptation( inner_kernel=tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob, step_size=step_size, num_leapfrog_steps=num_leapfrog_steps, state_gradients_are_stopped=True, seed=seed), num_adaptation_steps = int(0.8 * num_warmup_steps)), bijector=[param.bijector for param in model.parameters]) # Initialize from a Uniform[-2, 2] distribution in unconstrained space. initial_state = [tfp.sts.sample_uniform_initial_state( param, return_constrained=True) for param in model.parameters] samples, kernel_results = tfp.mcmc.sample_chain( kernel=transformed_hmc_kernel, num_results=num_results, current_state=initial_state, num_burnin_steps=num_warmup_steps) ``` #### References [1]: Radford Neal. MCMC Using Hamiltonian Dynamics. _Handbook of Markov Chain Monte Carlo_, 2011. https://arxiv.org/abs/1206.1901 [2] M.J. Betancourt, Simon Byrne, and Mark Girolami. Optimizing The Integrator Step Size for Hamiltonian Monte Carlo. https://arxiv.org/abs/1411.6669 """ with tf.name_scope(name or 'fit_with_hmc') as name: seed = tfp_util.SeedStream(seed, salt='StructuralTimeSeries_fit_with_hmc') observed_time_series = sts_util.pad_batch_dimension_for_multiple_chains( observed_time_series, model, chain_batch_shape=chain_batch_shape) target_log_prob_fn = model.joint_distribution( observed_time_series).log_prob # Initialize state and step sizes from a variational posterior if not # specified. if initial_step_size is None or initial_state is None: variational_posterior = build_factored_surrogate_posterior( model, batch_shape=chain_batch_shape, seed=seed()) if variational_optimizer is None: variational_optimizer = tf1.train.AdamOptimizer( learning_rate=0.1 ) # TODO(b/137299119) Replace with TF2 optimizer. loss_curve = vi.fit_surrogate_posterior( target_log_prob_fn, variational_posterior, sample_size=variational_sample_size, num_steps=num_variational_steps, optimizer=variational_optimizer, seed=seed()) with tf.control_dependencies([loss_curve]): if initial_state is None: posterior_sample = variational_posterior.sample() initial_state = [ posterior_sample[p.name] for p in model.parameters ] # Set step sizes using the unconstrained variational distribution. if initial_step_size is None: q_dists_by_name, _ = (variational_posterior.distribution. sample_distributions()) initial_step_size = [ q_dists_by_name[p.name].stddev() for p in model.parameters ] # Run HMC to sample from the posterior on parameters. @tf.function(autograph=False) def run_hmc(): return mcmc.sample_chain( num_results=num_results, current_state=initial_state, num_burnin_steps=num_warmup_steps, kernel=mcmc.DualAveragingStepSizeAdaptation( inner_kernel=mcmc.TransformedTransitionKernel( inner_kernel=mcmc.HamiltonianMonteCarlo( target_log_prob_fn=target_log_prob_fn, step_size=initial_step_size, num_leapfrog_steps=num_leapfrog_steps, state_gradients_are_stopped=True), bijector=[ param.bijector for param in model.parameters ]), num_adaptation_steps=int(num_warmup_steps * 0.8)), seed=seed()) samples, kernel_results = run_hmc() return samples, kernel_results
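# A brief usage sketch: passing the posterior draws from `fit_with_hmc` to the
# public `tfp.sts.forecast` API. `model` and `observed_time_series` are
# assumed to be defined as in the docstring examples above, and
# `num_steps_forecast=50` is an arbitrary choice.
def _fit_and_forecast_example(model, observed_time_series):
  import tensorflow_probability as tfp
  samples, _ = fit_with_hmc(model, observed_time_series)
  forecast_dist = tfp.sts.forecast(
      model,
      observed_time_series=observed_time_series,
      parameter_samples=samples,
      num_steps_forecast=50)
  # Pointwise forecast and uncertainty over the next 50 steps.
  return forecast_dist.mean()[..., 0], forecast_dist.stddev()[..., 0]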
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): """Solves systems of linear eqns `A X = RHS`, given LU factorizations. Note: this function does not verify the implied matrix is actually invertible nor is this condition checked even when `validate_args=True`. Args: lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`. perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. rhs: Matrix-shaped float `Tensor` representing targets for which to solve; `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. Note: this function does not verify the implied matrix is actually invertible, even when `validate_args=True`. Default value: `False` (i.e., don't validate arguments). name: Python `str` name given to ops managed by this object. Default value: `None` (i.e., 'lu_solve'). Returns: x: The `X` in `A @ X = RHS`. #### Examples ```python import numpy as np import tensorflow as tf import tensorflow_probability as tfp x = [[[1., 2], [3, 4]], [[7, 8], [3, 4]]] inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2)) tf.assert_near(tf.matrix_inverse(x), inv_x) # ==> True ``` """ with tf.name_scope(name or 'lu_solve'): lower_upper = tf.convert_to_tensor(lower_upper, dtype_hint=tf.float32, name='lower_upper') perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm') rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs') assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args) if assertions: with tf.control_dependencies(assertions): lower_upper = tf.identity(lower_upper) perm = tf.identity(perm) rhs = tf.identity(rhs) if rhs.shape.ndims == 2 and perm.shape.ndims == 1: # Both rhs and perm have scalar batch_shape. permuted_rhs = tf.gather(rhs, perm, axis=-2) else: # Either rhs or perm have non-scalar batch_shape or we can't determine # this information statically. rhs_shape = tf.shape(rhs) broadcast_batch_shape = tf.broadcast_dynamic_shape( rhs_shape[:-2], tf.shape(perm)[:-1]) d, m = rhs_shape[-2], rhs_shape[-1] rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0) # Tile out rhs. broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape) broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m]) # Tile out perm and add batch indices. broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1]) broadcast_perm = tf.reshape(broadcast_perm, [-1, d]) broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape) broadcast_batch_indices = tf.broadcast_to( tf.range(broadcast_batch_size)[:, tf.newaxis], [broadcast_batch_size, d]) broadcast_perm = tf.stack( [broadcast_batch_indices, broadcast_perm], axis=-1) permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm) permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape) lower = tf.linalg.set_diag( tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0), tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype)) return linear_operator_util.matrix_triangular_solve_with_broadcast( lower_upper, # Only upper is accessed. linear_operator_util.matrix_triangular_solve_with_broadcast( lower, permuted_rhs), lower=False)
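# A usage sketch (values are arbitrary): factor a matrix once with
# `tf.linalg.lu`, then reuse the factorization for several right-hand sides
# via `lu_solve`.
def _lu_solve_reuse_example():
  a = tf.constant([[4., 3.], [6., 3.]])
  lu, perm = tf.linalg.lu(a)
  x1 = lu_solve(lu, perm, rhs=tf.constant([[1.], [0.]]))  # solves A x1 = e1
  x2 = lu_solve(lu, perm, rhs=tf.constant([[0.], [1.]]))  # solves A x2 = e2
  return tf.concat([x1, x2], axis=-1)  # columns of inv(A)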
def __init__(self, mean_direction, concentration, validate_args=False, allow_nan_stats=True, name='VonMisesFisher'): """Creates a new `VonMisesFisher` instance. Args: mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D]. A unit vector indicating the mode of the distribution, or the unit-normalized direction of the mean. (This is *not* in general the mean of the distribution; the mean is not generally in the support of the distribution.) NOTE: `D` is currently restricted to <= 5. concentration: Floating-point `Tensor` having batch shape [B1, ... Bn] broadcastable with `mean_direction`. The level of concentration of samples around the `mean_direction`. `concentration=0` indicates a uniform distribution over the unit hypersphere, and `concentration=+inf` indicates a `Deterministic` distribution (delta function) at `mean_direction`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. Raises: ValueError: For known-bad arguments, i.e. unsupported event dimension. """ parameters = dict(locals()) with tf.name_scope(name) as name: dtype = dtype_util.common_dtype([mean_direction, concentration], tf.float32) mean_direction = tf.convert_to_tensor(mean_direction, name='mean_direction', dtype=dtype) concentration = tf.convert_to_tensor(concentration, name='concentration', dtype=dtype) assertions = [ assert_util.assert_non_negative( concentration, message='`concentration` must be non-negative'), assert_util.assert_greater( tf.shape(mean_direction)[-1], 1, message='`mean_direction` may not have scalar event shape' ), assert_util.assert_near( 1., tf.linalg.norm(mean_direction, axis=-1), message='`mean_direction` must be unit-length') ] if validate_args else [] static_event_dim = tf.compat.dimension_value( tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1]) if static_event_dim is not None and static_event_dim > 5: raise ValueError('vMF ndims > 5 is not currently supported') elif validate_args: assertions += [ assert_util.assert_less_equal( tf.shape(mean_direction)[-1], 5, message='vMF ndims > 5 is not currently supported') ] with tf.control_dependencies(assertions): self._mean_direction = tf.identity(mean_direction) self._concentration = tf.identity(concentration) dtype_util.assert_same_float_dtype( [self._mean_direction, self._concentration]) # mean_direction is always reparameterized. # concentration is only for event_dim==3, via an inversion sampler. reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED if static_event_dim == 3 else reparameterization.NOT_REPARAMETERIZED) super(VonMisesFisher, self).__init__( dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, reparameterization_type=reparameterization_type, parameters=parameters, graph_parents=[self._mean_direction, self._concentration], name=name)
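# A brief usage sketch (arbitrary values): a vMF distribution on the circle
# (D=2) concentrated around the direction (0, 1). Assumes the class is exposed
# as `tfd.VonMisesFisher`.
def _von_mises_fisher_usage_example():
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  vmf = tfd.VonMisesFisher(mean_direction=[0., 1.], concentration=2.)
  samples = vmf.sample(5, seed=42)   # shape [5, 2]; each row is unit-length
  return vmf.log_prob(samples)       # shape [5]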
def _sample_n(self, n, seed=None): seed = seed_stream.SeedStream(seed, salt='vom_mises_fisher') # The sampling strategy relies on the fact that vMF variates are symmetric # about the mean direction. Accordingly, if we have a sampling strategy for # the away-from-mean angle, then we can uniformly sample the remaining # dimensions on the S^{dim-2} sphere for , and rotate these samples from a # (1, 0, 0, ..., 0)-mode distribution into the target orientation. # # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a # von-Mises distributed `x` value in [-1, 1], then uniformly select what # amounts to a "up" or "down" additional degree of freedom after unit # normalizing, followed by a final rotation to the desired mean direction # from a basis of (1, 0). # # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the # unit sphere over which the distribution is uniform, in particular the # circle where x = \hat{x} intersects the unit sphere. We pick a point on # that circle, then rotate to the desired mean direction from a basis of # (1, 0, 0). event_dim = (tf.compat.dimension_value(self.event_shape[0]) or self._event_shape_tensor()[0]) sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0) dim = tf.cast(event_dim - 1, self.dtype) if event_dim == 3: samples_dim0 = self._sample_3d(n, seed=seed) else: # Wood'94 provides a rejection algorithm to sample the x coordinate. # Wood'94 definition of b: # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim # https://stats.stackexchange.com/questions/156729 suggests: b = dim / (2 * self.concentration + tf.sqrt(4 * self.concentration**2 + dim**2)) # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE # https://github.com/nicola-decao/s-vae-tf/ x = (1 - b) / (1 + b) c = self.concentration * x + dim * tf.math.log1p(-x**2) beta = beta_lib.Beta(dim / 2, dim / 2) def cond_fn(w, should_continue): del w return tf.reduce_any(should_continue) def body_fn(w, should_continue): z = beta.sample(sample_shape=sample_batch_shape, seed=seed()) w = tf1.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w) w = tf.debugging.check_numerics(w, 'w') should_continue = tf.logical_and( should_continue, self.concentration * w + dim * tf.math.log1p(-x * w) - c < tf.math.log( tf.random.uniform(sample_batch_shape, seed=seed(), dtype=self.dtype))) return w, should_continue w = tf.zeros(sample_batch_shape, dtype=self.dtype) should_continue = tf.ones(sample_batch_shape, dtype=tf.bool) samples_dim0 = tf.while_loop(cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0] samples_dim0 = samples_dim0[..., tf.newaxis] if not self._allow_nan_stats: # Verify samples are w/in -1, 1, with useful error output tensors (top # value rather than all values). 
with tf.control_dependencies([ assert_util.assert_less_equal( samples_dim0, dtype_util.as_numpy_dtype(self.dtype)(1.01), data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]), assert_util.assert_greater_equal( samples_dim0, dtype_util.as_numpy_dtype(self.dtype)(-1.01), data=[ -tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0] ]) ]): samples_dim0 = tf.identity(samples_dim0) samples_otherdims_shape = tf.concat( [sample_batch_shape, [event_dim - 1]], axis=0) unit_otherdims = tf.nn.l2_normalize(tf.random.normal( samples_otherdims_shape, seed=seed(), dtype=self.dtype), axis=-1) samples = tf.concat( [ samples_dim0, # we must avoid sqrt(1 - (>1)**2) tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims ], axis=-1) samples = tf.nn.l2_normalize(samples, axis=-1) if not self._allow_nan_stats: samples = tf.debugging.check_numerics(samples, 'samples') # Runtime assert that samples are unit length. if not self._allow_nan_stats: worst, idx = tf.nn.top_k( tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1])) with tf.control_dependencies([ assert_util.assert_near( dtype_util.as_numpy_dtype(self.dtype)(0), worst, data=[ worst, idx, tf.gather(tf.reshape(samples, [-1, event_dim]), idx) ], atol=1e-4, summarize=100) ]): samples = tf.identity(samples) # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0). # Now, we move the mode to `self.mean_direction` using a rotation matrix. if not self._allow_nan_stats: # Assert that the basis vector rotates to the mean direction, as expected. basis = tf.cast( tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0), self.dtype) with tf.control_dependencies([ assert_util.assert_less( tf.linalg.norm(self._rotate(basis) - self.mean_direction, axis=-1), dtype_util.as_numpy_dtype(self.dtype)(1e-5)) ]): return self._rotate(samples) return self._rotate(samples)
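# A scalar NumPy sketch (illustrative, not the batched TF implementation) of
# the Wood (1994) rejection step used above, sampling the coordinate of a vMF
# draw along the mean direction with the same b/x/c definitions as
# `_sample_n`.
def _wood_rejection_sampler_example(kappa, event_dim, rng):
  import numpy as np
  dim = event_dim - 1
  b = dim / (2. * kappa + np.sqrt(4. * kappa**2 + dim**2))
  x = (1. - b) / (1. + b)
  c = kappa * x + dim * np.log1p(-x**2)
  while True:
    z = rng.beta(dim / 2., dim / 2.)
    w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
    if kappa * w + dim * np.log1p(-x * w) - c >= np.log(rng.uniform()):
      return w  # accepted coordinate in [-1, 1]; remaining dims are uniform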
def push(self, value, mask, name=None): """Pushes `value` onto the stack, advances frame of batch members in `mask`. In this impl, we update each thread's top-of-stack (regardless of `mask`) to the corresponding `value`, then advance the stack pointers of only those threads indicated by `mask`. Args: value: `Tensor` having the shape of a single batch of the variable. mask: Boolean `Tensor` of shape `[batch_size]`. Threads at `True` indices of `mask` have their stack frames advanced; the others remain. name: Optional name for this op. Returns: stack: Updated stack. Does not mutate `self`. asserted_value: A assertion-bound snapshot of the input `value`, assertions used to catch stack overflows. """ with tf.name_scope(name or 'Stack.push'): value = tf.convert_to_tensor(value=value, name='value') mask = tf.convert_to_tensor(value=mask, name='mask') # self.stack: [max_stack_depth * batch_size, ...] # self.stack_index: [batch_size] # value: [batch_size, ...] batch_size = (tf.compat.dimension_value(self.stack_index.shape[0]) or tf.shape(input=self.stack_index)[0]) max_stack_depth = (tf.compat.dimension_value(self.stack.shape[0]) or tf.shape(input=self.stack)[0]) // batch_size max_stack_depth_tensor = tf.convert_to_tensor( value=max_stack_depth) tiled_value = tf.tile( input=value[tf.newaxis, ...], multiples=tf.concat( [[max_stack_depth_tensor], tf.ones(tf.rank(value), dtype=max_stack_depth_tensor.dtype)], axis=0)) update_stack_mask = tf.one_hot( self.stack_index, depth=max_stack_depth, axis= 0, # Stack depth x batch are both in outermost dim, stack major. on_value=True, off_value=False, dtype=tf.bool) new_stack = tf1.where( tf.reshape(update_stack_mask, [-1]), tf.reshape(tiled_value, tf.shape(input=self.stack)), self.stack) new_stack.set_shape(self.stack.shape) new_stack_index = self.stack_index + tf.cast( mask, self.stack_index.dtype) new_stack_index.set_shape(self.stack_index.shape) if self._safety_checks(): with tf.control_dependencies([ tf1.assert_less( new_stack_index, tf.cast(max_stack_depth_tensor, new_stack_index.dtype)) ]): value = tf.identity(value) new_stack_index = tf.identity(new_stack_index) return type(self)(new_stack, new_stack_index), value
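# A NumPy sketch (hypothetical, unbatched layout) of the push semantics
# described above, with the stack stored as [max_stack_depth, batch_size]
# rather than flattened stack-major as in the real implementation.
def _stack_push_toy_example():
  import numpy as np
  stack = np.zeros([3, 2])                 # depth 3, batch of 2 threads
  stack_index = np.array([0, 0])
  value = np.array([7., 8.])
  mask = np.array([True, False])
  # Overwrite every thread's current top-of-stack, regardless of `mask`.
  stack[stack_index, np.arange(stack.shape[1])] = value
  # Advance only the stack pointers of threads selected by `mask`.
  stack_index = stack_index + mask.astype(stack_index.dtype)
  return stack, stack_index  # stack[0] == [7., 8.]; stack_index == [1, 0]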
def _solve( self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None, ): # This function is comprised of the following sequential stages: # (1) Make static assertions. # (2) Initialize variables. # (3) Make non-static assertions. # (4) Solve up to final time. # (5) Return `Results` object. # # The stages can be found in the code by searching for (n) where n=1..5. # # By static vs. non-static assertions (see stages 1 and 3), we mean # assertions that can be made before the graph is run vs. those that can # only be made at run time. The latter are constructed as a list of # tf.Assert operations by the function `assert_ops` (see below). # # If `solution_times` is specified as a `Tensor`, stage 4 consists of three # nested loops, which can be conceptually understood as follows: # ``` # current_time, current_state = initial_time, initial_state # order, step_size = 1, first_step_size # for solution_time in solution_times: # while current_time < solution_time: # while True: # next_time = current_time + step_size # next_state, error = ( # solve_nonlinear_equation_to_get_approximate_state_at_next_time( # current_time, current_state, next_time, order)) # if error < tolerance: # current_time, current_state = next_time, next_state # order, step_size = ( # maybe_update_order_and_step_size(order, step_size)) # break # else: # step_size = decrease_step_size(step_size) # ``` # The outermost loop advances the solver to the next `solution_time` (see # `advance_to_solution_time`). The middle loop advances the solver by a # small timestep (see `step`). The innermost loop determines the size of # that timestep (see `maybe_step`). # # If `solution_times` is specified as # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped # and `solution_time` in the middle loop is replaced by `final_time`. def assert_ops(): """Creates a list of assert operations.""" if not self._validate_args: return [] assert_ops = [] if previous_solver_internal_state is not None: assert_initial_state_matches_previous_solver_internal_state = ( tf1.assert_near( tf.norm( initial_state_vec - previous_solver_internal_state. 
backward_differences[0], np.inf), 0., message= '`previous_solver_internal_state` does not match ' '`initial_state`.')) assert_ops.append( assert_initial_state_matches_previous_solver_internal_state ) if solution_times_chosen_by_solver: assert_ops.append( util.assert_positive(final_time - initial_time, 'final_time - initial_time')) else: assert_ops += [ util.assert_increasing(solution_times, 'solution_times'), util.assert_nonnegative( solution_times[0] - initial_time, 'solution_times[0] - initial_time'), ] if max_num_steps is not None: assert_ops.append( util.assert_positive(max_num_steps, 'max_num_steps')) if max_num_newton_iters is not None: assert_ops.append( util.assert_positive(max_num_newton_iters, 'max_num_newton_iters')) assert_ops += [ util.assert_positive(rtol, 'rtol'), util.assert_positive(atol, 'atol'), util.assert_positive(first_step_size, 'first_step_size'), util.assert_positive(safety_factor, 'safety_factor'), util.assert_positive(min_step_size_factor, 'min_step_size_factor'), util.assert_positive(max_step_size_factor, 'max_step_size_factor'), tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [ '`max_order` must be between 1 and {}.'.format( bdf_util.MAX_ORDER) ]), util.assert_positive(newton_tol_factor, 'newton_tol_factor'), util.assert_positive(newton_step_size_factor, 'newton_step_size_factor'), ] return assert_ops def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state, state_vec_array, time_array): """Takes multiple steps to advance time to `solution_times[n]`.""" def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( diagnostics.status, 0)) nth_solution_time = solution_time_array.read(n) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop(step_cond, step, [ nth_solution_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) state_vec_array = state_vec_array.write( n, solver_internal_state.backward_differences[0]) time_array = time_array.write(n, nth_solution_time) return (n + 1, diagnostics, iterand, solver_internal_state, state_vec_array, time_array) def step(next_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array): """Takes a single step.""" distance_to_next_time = next_time - iterand.time overstepped = iterand.new_step_size > distance_to_next_time iterand = iterand._replace(new_step_size=tf1.where( overstepped, distance_to_next_time, iterand.new_step_size), should_update_step_size=overstepped | iterand.should_update_step_size) if not self._evaluate_jacobian_lazily: diagnostics = diagnostics._replace( num_jacobian_evaluations=diagnostics. 
num_jacobian_evaluations + 1) iterand = iterand._replace(jacobian_mat=jacobian_fn_mat( iterand.time, solver_internal_state.backward_differences[0]), jacobian_is_up_to_date=True) def maybe_step_cond(accepted, diagnostics, *_): return tf.logical_not(accepted) & tf.equal( diagnostics.status, 0) _, diagnostics, iterand, solver_internal_state = tf.while_loop( maybe_step_cond, maybe_step, [False, diagnostics, iterand, solver_internal_state]) if solution_times_chosen_by_solver: state_vec_array = state_vec_array.write( state_vec_array.size(), solver_internal_state.backward_differences[0]) time_array = time_array.write(time_array.size(), iterand.time) return (next_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array) def maybe_step(accepted, diagnostics, iterand, solver_internal_state): """Takes a single step only if the outcome has a low enough error.""" [ num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status ] = diagnostics [ jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper ] = iterand [backward_differences, order, step_size] = solver_internal_state if max_num_steps is not None: status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0) backward_differences = tf1.where( should_update_step_size, bdf_util.interpolate_backward_differences( backward_differences, order, new_step_size / step_size), backward_differences) step_size = tf1.where(should_update_step_size, new_step_size, step_size) should_update_factorization = should_update_step_size num_steps_same_size = tf1.where(should_update_step_size, 0, num_steps_same_size) def update_factorization(): return bdf_util.newton_qr( jacobian_mat, newton_coefficients_array.read(order), step_size) if self._evaluate_jacobian_lazily: def update_jacobian_and_factorization(): new_jacobian_mat = jacobian_fn_mat(time, backward_differences[0]) new_unitary, new_upper = update_factorization() return [ new_jacobian_mat, True, num_jacobian_evaluations + 1, new_unitary, new_upper ] def maybe_update_factorization(): new_unitary, new_upper = tf.cond( should_update_factorization, update_factorization, lambda: [unitary, upper]) return [ jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations, new_unitary, new_upper ] [ jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations, unitary, upper ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization, maybe_update_factorization) else: unitary, upper = update_factorization() num_matrix_factorizations += 1 tol = atol + rtol * tf.abs(backward_differences[0]) newton_tol = newton_tol_factor * tf.norm(tol) [ newton_converged, next_backward_difference, next_state_vec, newton_num_iters ] = bdf_util.newton(backward_differences, max_num_newton_iters, newton_coefficients_array.read(order), ode_fn_vec, order, step_size, time, newton_tol, unitary, upper) num_steps += 1 num_ode_fn_evaluations += newton_num_iters # If Newton's method failed and the Jacobian was up to date, decrease the # step size. newton_failed = tf.logical_not(newton_converged) should_update_step_size = newton_failed & jacobian_is_up_to_date new_step_size = step_size * tf1.where(should_update_step_size, newton_step_size_factor, 1.) # If Newton's method failed and the Jacobian was NOT up to date, update # the Jacobian. 
should_update_jacobian = newton_failed & tf.logical_not( jacobian_is_up_to_date) error_ratio = tf1.where( newton_converged, bdf_util.error_ratio(next_backward_difference, error_coefficients_array.read(order), tol), np.nan) accepted = error_ratio < 1. converged_and_rejected = newton_converged & tf.logical_not( accepted) # If Newton's method converged but the solution was NOT accepted, decrease # the step size. new_step_size = tf1.where( converged_and_rejected, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = should_update_step_size | converged_and_rejected # If Newton's method converged and the solution was accepted, update the # matrix of backward differences. time = tf1.where(accepted, time + step_size, time) backward_differences = tf1.where( accepted, bdf_util.update_backward_differences(backward_differences, next_backward_difference, next_state_vec, order), backward_differences) jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not( accepted) num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1, num_steps_same_size) # Order and step size are only updated if we have taken strictly more than # order + 1 steps of the same size. This is to prevent the order from # being throttled. should_update_order_and_step_size = accepted & (num_steps_same_size > order + 1) backward_differences_array = tf.TensorArray( backward_differences.dtype, size=bdf_util.MAX_ORDER + 3, clear_after_read=False, element_shape=next_backward_difference.get_shape()).unstack( backward_differences) new_order = order new_error_ratio = error_ratio for offset in [-1, +1]: proposed_order = tf.clip_by_value(order + offset, 1, max_order) proposed_error_ratio = bdf_util.error_ratio( backward_differences_array.read(proposed_order + 1), error_coefficients_array.read(proposed_order), tol) proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio new_order = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_order, new_order) new_error_ratio = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_error_ratio, new_error_ratio) order = new_order error_ratio = new_error_ratio new_step_size = tf1.where( should_update_order_and_step_size, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = (should_update_step_size | should_update_order_and_step_size) diagnostics = _BDFDiagnostics(num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status) iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper) solver_internal_state = _BDFSolverInternalState( backward_differences, order, step_size) return accepted, diagnostics, iterand, solver_internal_state # (1) Make static assertions. # TODO(b/138304296): Support specifying Jacobian sparsity patterns. if jacobian_sparsity is not None: raise NotImplementedError( 'The BDF solver does not support specifying ' 'Jacobian sparsity patterns.') if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'The BDF solver does not support batching.') solution_times_chosen_by_solver = (isinstance(solution_times, base.ChosenBySolver)) with tf.name_scope(self._name): # (2) Convert to tensors. 
error_if_wrong_dtype = functools.partial( util.error_if_not_real_or_complex, identifier='initial_state') initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state) tf.nest.map_structure(error_if_wrong_dtype, initial_state) state_shape = tf.nest.map_structure(tf.shape, initial_state) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) if jacobian_fn is None and common_state_dtype.is_complex: raise NotImplementedError( 'The BDF solver does not support automatic ' 'Jacobian computations for complex dtypes.') # Convert everything to operate on a single, concatenated vector form. initial_state_vec = util.get_state_vec(initial_state) ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape) jacobian_fn_mat = util.get_jacobian_fn_mat( jacobian_fn, ode_fn_vec, state_shape, use_pfor=self._use_pfor_to_compute_jacobian, dtype=common_state_dtype, ) num_odes = tf.size(initial_state_vec) # Use tf.cast instead of tf.convert_to_tensor for differentiable # parameters because the tf.custom_gradient decorator converts raw floats # into tf.float32, which cannot be converted to tf.float64. initial_time = tf.cast(initial_time, real_dtype) num_solution_times = 0 if solution_times_chosen_by_solver: final_time = tf.cast(solution_times.final_time, real_dtype) else: solution_times = tf.cast(solution_times, real_dtype) num_solution_times = tf.size(solution_times) solution_time_array = tf.TensorArray( solution_times.dtype, size=num_solution_times, element_shape=[]).unstack(solution_times) util.error_if_not_vector(solution_times, 'solution_times') rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype) atol = tf.convert_to_tensor(self._atol, dtype=real_dtype) safety_factor = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype) min_step_size_factor = tf.convert_to_tensor( self._min_step_size_factor, dtype=real_dtype) max_step_size_factor = tf.convert_to_tensor( self._max_step_size_factor, dtype=real_dtype) max_num_steps = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32) max_num_newton_iters = self._max_num_newton_iters if max_num_newton_iters is not None: max_num_newton_iters = tf.convert_to_tensor( max_num_newton_iters, dtype=tf.int32) newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor, dtype=real_dtype) newton_step_size_factor = tf.convert_to_tensor( self._newton_step_size_factor, dtype=real_dtype) bdf_coefficients = tf.cast( tf.concat([[0.], tf.convert_to_tensor(self._bdf_coefficients, dtype=real_dtype)], 0), common_state_dtype) util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients') if self._validate_args: initial_time = tf.ensure_shape(initial_time, []) if solution_times_chosen_by_solver: final_time = tf.ensure_shape(final_time, []) safety_factor = tf.ensure_shape(safety_factor, []) min_step_size_factor = tf.ensure_shape(min_step_size_factor, []) max_step_size_factor = tf.ensure_shape(max_step_size_factor, []) if max_num_steps is not None: max_num_steps = tf.ensure_shape(max_num_steps, []) max_order = tf.ensure_shape(max_order, []) if max_num_newton_iters is not None: max_num_newton_iters = tf.ensure_shape( max_num_newton_iters, []) newton_tol_factor = tf.ensure_shape(newton_tol_factor, []) newton_step_size_factor = tf.ensure_shape( newton_step_size_factor, []) bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6]) newton_coefficients = 1. / ( (1. 
- bdf_coefficients) * bdf_util.RECIPROCAL_SUMS) newton_coefficients_array = tf.TensorArray( newton_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(newton_coefficients) error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / ( bdf_util.ORDERS + 1) error_coefficients_array = tf.TensorArray( error_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(error_coefficients) first_step_size = self._first_step_size if first_step_size is None: first_step_size = bdf_util.first_step_size( atol, error_coefficients_array.read(1), initial_state_vec, initial_time, ode_fn_vec, rtol, safety_factor) elif previous_solver_internal_state is not None: tf.logging.warn( '`first_step_size` is ignored since' '`previous_solver_internal_state` was specified.') first_step_size = tf.convert_to_tensor(first_step_size, dtype=real_dtype) if self._validate_args: first_step_size = tf.ensure_shape(first_step_size, []) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: first_order_backward_difference = ode_fn_vec( initial_time, initial_state_vec) * tf.cast( first_step_size, common_state_dtype) backward_differences = tf.concat([ initial_state_vec[tf.newaxis, :], first_order_backward_difference[tf.newaxis, :], tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]), dtype=common_state_dtype), ], 0) solver_internal_state = _BDFSolverInternalState( backward_differences=backward_differences, order=1, step_size=first_step_size) state_vec_array = tf.TensorArray( common_state_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=initial_state_vec.get_shape()) time_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=tf.TensorShape([])) diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0, num_matrix_factorizations=0, num_ode_fn_evaluations=0, status=0) iterand = _BDFIterand( jacobian_mat=tf.zeros([num_odes, num_odes], dtype=common_state_dtype), jacobian_is_up_to_date=False, new_step_size=solver_internal_state.step_size, num_steps=0, num_steps_same_size=0, should_update_jacobian=True, should_update_step_size=False, time=initial_time, unitary=tf.zeros([num_odes, num_odes], dtype=common_state_dtype), upper=tf.zeros([num_odes, num_odes], dtype=common_state_dtype)) # (3) Make non-static assertions. with tf.control_dependencies(assert_ops()): # (4) Solve up to final time. if solution_times_chosen_by_solver: def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop(step_cond, step, [ final_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) else: def advance_to_solution_time_cond(n, diagnostics, *_): return (n < num_solution_times) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop( advance_to_solution_time_cond, advance_to_solution_time, [ 0, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) # (6) Return `Results` object. 
states = util.get_state_from_vec(state_vec_array.stack(), state_shape) times = time_array.stack() if not solution_times_chosen_by_solver: times.set_shape(solution_times.get_shape()) tf.nest.map_structure( lambda s, ini_s: s.set_shape( solution_times.get_shape( # pylint: disable=g-long-lambda ).concatenate(ini_s.shape)), states, initial_state) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
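# A short usage sketch of the public BDF interface that wraps `_solve`, on the
# stiff scalar ODE dy/dt = -1000 * y (values are illustrative).
def _bdf_usage_example():
  import tensorflow_probability as tfp
  ode_fn = lambda t, y: -1000. * y
  results = tfp.math.ode.BDF().solve(
      ode_fn,
      initial_time=0.,
      initial_state=tf.constant([1.], dtype=tf.float64),
      solution_times=[0.001, 0.01, 0.1])
  # `results.states` holds y(t) at each requested time; `results.diagnostics`
  # reports ODE function evaluations, Jacobian evaluations, etc.
  return results.states, results.diagnostics.num_ode_fn_evaluations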
def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6):
  r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any
      shape, and alpha and scale will be broadcasted to match x's shape if
      necessary. Must be a tensorflow tensor or numpy array of floats.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers
      "cost" less), and more positive values produce a loss with less robust
      behavior (outliers are penalized more heavily). Alpha can be any value
      in [-infinity, infinity], but the gradient of the loss with respect to
      alpha is 0 at -infinity, infinity, 0, and 2. Must be a tensorflow tensor
      or numpy array of floats with the same precision as `x`. Varying alpha
      allows for smooth interpolation between a number of discrete robust
      losses:
        alpha=-Infinity: Welsch/Leclerc Loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on
      a different shape according to alpha. Must be a tensorflow tensor or
      numpy array of single-precision floats.
    approximate: a bool, where if True, this function returns an approximate
      and faster form of the loss, as described in the appendix of the paper.
      This approximation holds well everywhere except as x and alpha approach
      zero.
    epsilon: A float that determines how inaccurate the "approximate" version
      of the loss will be. Larger values are less accurate but more
      numerically stable. Must be greater than single-precision machine
      epsilon.

  Returns:
    The losses for each element of x, in the same shape as x. This is returned
    as a TensorFlow graph node of single precision floats.
  """
  # `scale` and `alpha` must have the same type as `x`.
  float_dtype = x.dtype
  tf.debugging.assert_type(scale, float_dtype)
  tf.debugging.assert_type(alpha, float_dtype)
  # `scale` must be > 0.
  assert_ops = [tf.Assert(tf.reduce_all(tf.greater(scale, 0.)), [scale])]
  with tf.control_dependencies(assert_ops):
    # Broadcast `alpha` and `scale` to have the same shape as `x`.
    alpha = tf.broadcast_to(alpha, tf.shape(x))
    scale = tf.broadcast_to(scale, tf.shape(x))

    if approximate:
      # `epsilon` must be greater than single-precision machine epsilon.
      assert epsilon > np.finfo(np.float32).eps
      # Compute an approximate form of the loss which is faster, but
      # inaccurate when x and alpha are near zero.
      b = tf.abs(alpha - tf.cast(2., float_dtype)) + epsilon
      d = tf.where(
          tf.greater_equal(alpha, 0.), alpha + epsilon, alpha - epsilon)
      loss = (b / d) * (tf.pow(tf.square(x / scale) / b + 1., 0.5 * d) - 1.)
    else:
      # Compute the exact loss.

      # This will be used repeatedly.
      squared_scaled_x = tf.square(x / scale)

      # The loss when alpha == 2.
      loss_two = 0.5 * squared_scaled_x
      # The loss when alpha == 0.
      loss_zero = util.log1p_safe(0.5 * squared_scaled_x)
      # The loss when alpha == -infinity.
      loss_neginf = -tf.math.expm1(-0.5 * squared_scaled_x)
      # The loss when alpha == +infinity.
      loss_posinf = util.expm1_safe(0.5 * squared_scaled_x)

      # The loss when not in one of the above special cases.
      machine_epsilon = tf.cast(np.finfo(np.float32).eps, float_dtype)
      # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
beta_safe = tf.maximum(machine_epsilon, tf.abs(alpha - 2.)) # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. alpha_safe = tf.where( tf.greater_equal(alpha, 0.), tf.ones_like(alpha), -tf.ones_like(alpha)) * tf.maximum(machine_epsilon, tf.abs(alpha)) loss_otherwise = (beta_safe / alpha_safe) * ( tf.pow(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.) # Select which of the cases of the loss to return. loss = tf.where( tf.equal(alpha, -tf.cast(float('inf'), float_dtype)), loss_neginf, tf.where( tf.equal(alpha, 0.), loss_zero, tf.where( tf.equal(alpha, 2.), loss_two, tf.where( tf.equal(alpha, tf.cast(float('inf'), float_dtype)), loss_posinf, loss_otherwise)))) return loss
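# Illustrative usage sketch (not part of the original module): shows one
# plausible call to `lossfun` using this file's TensorFlow import; the residual
# and parameter values below are invented for the example.
def _example_lossfun_usage():
  x = tf.constant([-2., -0.5, 0., 0.5, 2.])
  # alpha=1 gives Charbonnier/pseudo-Huber-like behavior; `scale` sets the
  # width of the quadratic bowl around zero. The output has the same shape as `x`.
  return lossfun(x, alpha=tf.constant(1.), scale=tf.constant(0.5))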
def _kl_independent(a, b, name='kl_independent'): """Batched KL divergence `KL(a || b)` for Independent distributions. We can leverage the fact that ``` KL(Independent(a) || Independent(b)) = sum(KL(a || b)) ``` where the sum is over the `reinterpreted_batch_ndims`. Args: a: Instance of `Independent`. b: Instance of `Independent`. name: (optional) name to use for created ops. Default 'kl_independent'. Returns: Batchwise `KL(a || b)`. Raises: ValueError: If the event space for `a` and `b`, or their underlying distributions don't match. """ p = a.distribution q = b.distribution # The KL between any two (non)-batched distributions is a scalar. # Given that the KL between two factored distributions is the sum, i.e. # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions. if (tensorshape_util.is_fully_defined(a.event_shape) and tensorshape_util.is_fully_defined(b.event_shape)): if a.event_shape == b.event_shape: if p.event_shape == q.event_shape: num_reduce_dims = (tensorshape_util.rank(a.event_shape) - tensorshape_util.rank(p.event_shape)) reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)] return tf.reduce_sum(kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) else: raise NotImplementedError( 'KL between Independents with different ' 'event shapes not supported.') else: raise ValueError('Event shapes do not match.') else: p_event_shape_tensor = p.event_shape_tensor() q_event_shape_tensor = q.event_shape_tensor() # NOTE: We could optimize by passing the event_shape_tensor of p and q # to a.event_shape_tensor() and b.event_shape_tensor(). a_event_shape_tensor = a.event_shape_tensor() b_event_shape_tensor = b.event_shape_tensor() with tf.control_dependencies([ assert_util.assert_equal(a_event_shape_tensor, b_event_shape_tensor, message='Event shapes do not match.'), assert_util.assert_equal(p_event_shape_tensor, q_event_shape_tensor, message='Event shapes do not match.'), ]): num_reduce_dims = ( ps.rank_from_shape(a_event_shape_tensor, a.event_shape) - ps.rank_from_shape(p_event_shape_tensor, p.event_shape)) reduce_dims = ps.range(-num_reduce_dims, 0, 1) return tf.reduce_sum(kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
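# Hedged numerical check (illustrative only; assumes `tensorflow_probability`
# is importable): the identity used above, KL(Independent(a) || Independent(b))
# == sum of the component KLs over the reinterpreted batch dims, can be
# verified directly for a pair of factored Normals.
def _example_kl_independent_check():
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  a = tfd.Independent(tfd.Normal(loc=[0., 1.], scale=[1., 2.]),
                      reinterpreted_batch_ndims=1)
  b = tfd.Independent(tfd.Normal(loc=[1., 0.], scale=[2., 1.]),
                      reinterpreted_batch_ndims=1)
  kl_joint = tfd.kl_divergence(a, b)
  kl_summed = tf.reduce_sum(tfd.kl_divergence(a.distribution, b.distribution))
  return kl_joint, kl_summed  # These two scalars should agree.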
def _parse_train_data(self, data): """Parse data for ShapeMask training.""" classes = data['groundtruth_classes'] boxes = data['groundtruth_boxes'] masks = data['groundtruth_instance_masks'] is_crowds = data['groundtruth_is_crowd'] # Skips annotations with `is_crowd` = True. if self._skip_crowd_during_training and self._is_training: num_groundtrtuhs = tf.shape(classes)[0] with tf.control_dependencies([num_groundtrtuhs, is_crowds]): indices = tf.cond( tf.greater(tf.size(is_crowds), 0), lambda: tf.where(tf.logical_not(is_crowds))[:, 0], lambda: tf.cast(tf.range(num_groundtrtuhs), tf.int64)) classes = tf.gather(classes, indices) boxes = tf.gather(boxes, indices) masks = tf.gather(masks, indices) # Gets original image and its size. image = data['image'] image_shape = tf.shape(image)[0:2] # If not using category, makes all categories with id = 0. if not self._use_category: classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32) # Normalizes image with mean and std pixel values. image = input_utils.normalize_image(image) # Flips image randomly during training. if self._aug_rand_hflip: image, boxes, masks = input_utils.random_horizontal_flip( image, boxes, masks) # Converts boxes from normalized coordinates to pixel coordinates. boxes = box_utils.denormalize_boxes(boxes, image_shape) # Resizes and crops image. image, image_info = input_utils.resize_and_crop_image( image, self._output_size, self._output_size, aug_scale_min=self._aug_scale_min, aug_scale_max=self._aug_scale_max) image_scale = image_info[2, :] offset = image_info[3, :] # Resizes and crops boxes and masks. boxes = input_utils.resize_and_crop_boxes(boxes, image_scale, self._output_size, offset) # Filters out ground truth boxes that are all zeros. indices = input_utils.get_non_empty_box_indices(boxes) boxes = tf.gather(boxes, indices) classes = tf.gather(classes, indices) masks = tf.gather(masks, indices) # Assigns anchors. input_anchor = anchor.Anchor(self._min_level, self._max_level, self._num_scales, self._aspect_ratios, self._anchor_size, self._output_size) anchor_labeler = anchor.AnchorLabeler(input_anchor, self._match_threshold, self._unmatched_threshold) (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors( boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32)) # Sample groundtruth masks/boxes/classes for mask branch. num_masks = tf.shape(masks)[0] mask_shape = tf.shape(masks)[1:3] # Pad sampled boxes/masks/classes to a constant batch size. padded_boxes = input_utils.pad_to_fixed_size(boxes, self._num_sampled_masks) padded_classes = input_utils.pad_to_fixed_size(classes, self._num_sampled_masks) padded_masks = input_utils.pad_to_fixed_size(masks, self._num_sampled_masks) # Randomly sample groundtruth masks for mask branch training. For the image # without groundtruth masks, it will sample the dummy padded tensors. rand_indices = tf.random.shuffle( tf.range(tf.maximum(num_masks, self._num_sampled_masks))) rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1)) rand_indices = rand_indices[0:self._num_sampled_masks] rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks]) sampled_boxes = tf.gather(padded_boxes, rand_indices) sampled_classes = tf.gather(padded_classes, rand_indices) sampled_masks = tf.gather(padded_masks, rand_indices) # Jitter the sampled boxes to mimic the noisy detections. 
sampled_boxes = box_utils.jitter_boxes( sampled_boxes, noise_scale=self._box_jitter_scale) sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size) # Compute mask targets in feature crop. A feature crop fully contains a # sampled box. mask_outer_boxes = box_utils.compute_outer_boxes( sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale) mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes, self._output_size) # Compensate the offset of mask_outer_boxes to map it back to original image # scale. mask_outer_boxes_ori = mask_outer_boxes mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2]) mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2]) norm_mask_outer_boxes_ori = box_utils.normalize_boxes( mask_outer_boxes_ori, mask_shape) # Set sampled_masks shape to [batch_size, height, width, 1]. sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1), tf.float32) mask_targets = tf.image.crop_and_resize( sampled_masks, norm_mask_outer_boxes_ori, box_indices=tf.range(self._num_sampled_masks), crop_size=[self._mask_crop_size, self._mask_crop_size], method='bilinear', extrapolation_value=0, name='train_mask_targets') mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5), tf.ones_like(mask_targets), tf.zeros_like(mask_targets)) mask_targets = tf.squeeze(mask_targets, axis=-1) if self._up_sample_factor > 1: fine_mask_targets = tf.image.crop_and_resize( sampled_masks, norm_mask_outer_boxes_ori, box_indices=tf.range(self._num_sampled_masks), crop_size=[ self._mask_crop_size * self._up_sample_factor, self._mask_crop_size * self._up_sample_factor ], method='bilinear', extrapolation_value=0, name='train_mask_targets') fine_mask_targets = tf.where( tf.greater_equal(fine_mask_targets, 0.5), tf.ones_like(fine_mask_targets), tf.zeros_like(fine_mask_targets)) fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1) else: fine_mask_targets = mask_targets # If bfloat16 is used, casts input image to tf.bfloat16. if self._use_bfloat16: image = tf.cast(image, dtype=tf.bfloat16) valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32) if self._mask_train_class == 'all': mask_is_valid = valid_image * tf.ones_like(sampled_classes, tf.int32) else: # Get the intersection of sampled classes with training splits. mask_valid_classes = tf.cast( tf.expand_dims( class_utils.coco_split_class_ids(self._mask_train_class), 1), sampled_classes.dtype) match = tf.reduce_any( tf.equal(tf.expand_dims(sampled_classes, 0), mask_valid_classes), 0) mask_is_valid = valid_image * tf.cast(match, tf.int32) # Packs labels for model_fn outputs. labels = { 'cls_targets': cls_targets, 'box_targets': box_targets, 'anchor_boxes': input_anchor.multilevel_boxes, 'num_positives': num_positives, 'image_info': image_info, # For ShapeMask. 'mask_boxes': sampled_boxes, 'mask_outer_boxes': mask_outer_boxes, 'mask_targets': mask_targets, 'fine_mask_targets': fine_mask_targets, 'mask_classes': sampled_classes, 'mask_is_valid': mask_is_valid, } return image, labels
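# Minimal sketch (shapes and values are made up): the mask-target computation
# above reduces to cropping each sampled instance mask with
# `tf.image.crop_and_resize` using normalized outer boxes, then binarizing the
# interpolated crop back to {0, 1}.
def _example_mask_target_crop():
  masks = tf.random.uniform([4, 64, 64, 1])       # [num_sampled_masks, H, W, 1]
  boxes = tf.constant([[0., 0., .5, .5]] * 4)     # normalized [y1, x1, y2, x2]
  crops = tf.image.crop_and_resize(
      masks, boxes, box_indices=tf.range(4), crop_size=[32, 32],
      method='bilinear', extrapolation_value=0)
  # Threshold the bilinear output back to a binary mask target.
  return tf.cast(crops >= 0.5, tf.float32)[..., 0]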
def interpolate1d(x, values, tangents): r"""Perform cubic hermite spline interpolation on a 1D spline. The x coordinates of the spline knots are at [0 : 1 : len(values)-1]. Queries outside of the range of the spline are computed using linear extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline for details, where "x" corresponds to `x`, "p" corresponds to `values`, and "m" corresponds to `tangents`. Args: x: A tensor of any size of single or double precision floats containing the set of values to be used for interpolation into the spline. values: A vector of single or double precision floats containing the value of each knot of the spline being interpolated into. Must be the same length as `tangents` and the same type as `x`. tangents: A vector of single or double precision floats containing the tangent (derivative) of each knot of the spline being interpolated into. Must be the same length as `values` and the same type as `x`. Returns: The result of interpolating along the spline defined by `values`, and `tangents`, using `x` as the query values. Will be the same length and type as `x`. """ # `values` and `tangents` must have the same type as `x`. tf.debugging.assert_type(values, x.dtype) tf.debugging.assert_type(tangents, x.dtype) float_dtype = x.dtype assert_ops = [ # `values` must be a vector. tf.Assert(tf.equal(tf.rank(values), 1), [tf.shape(values)]), # `tangents` must be a vector. tf.Assert(tf.equal(tf.rank(tangents), 1), [tf.shape(values)]), # `values` and `tangents` must have the same length. tf.Assert( tf.equal(tf.shape(values)[0], tf.shape(tangents)[0]), [tf.shape(values)[0], tf.shape(tangents)[0]]), ] with tf.control_dependencies(assert_ops): # Find the indices of the knots below and above each x. x_lo = tf.cast( tf.floor( tf.clip_by_value(x, 0., tf.cast(tf.shape(values)[0] - 2, float_dtype))), tf.int32) x_hi = x_lo + 1 # Compute the relative distance between each `x` and the knot below it. t = x - tf.cast(x_lo, float_dtype) # Compute the cubic hermite expansion of `t`. t_sq = tf.square(t) t_cu = t * t_sq h01 = -2. * t_cu + 3. * t_sq h00 = 1. - h01 h11 = t_cu - t_sq h10 = h11 - t_sq + t # Linearly extrapolate above and below the extents of the spline for all # values. value_before = tangents[0] * t + values[0] value_after = tangents[-1] * (t - 1.) + values[-1] # Cubically interpolate between the knots below and above each query point. neighbor_values_lo = tf.gather(values, x_lo) neighbor_values_hi = tf.gather(values, x_hi) neighbor_tangents_lo = tf.gather(tangents, x_lo) neighbor_tangents_hi = tf.gather(tangents, x_hi) value_mid = (neighbor_values_lo * h00 + neighbor_values_hi * h01 + neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11) # Return the interpolated or extrapolated values for each query point, # depending on whether or not the query lies within the span of the spline. return tf.where(t < 0., value_before, tf.where(t > 1., value_after, value_mid))
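# Usage sketch (not from the original module; knot values are arbitrary):
# builds a tiny three-knot spline and queries it, including points outside
# [0, 2] that exercise the linear-extrapolation branch.
def _example_interpolate1d_usage():
  values = tf.constant([0., 1., 0.])      # knot values at x = 0, 1, 2
  tangents = tf.constant([1., 0., -1.])   # knot derivatives at the same knots
  x = tf.constant([-0.5, 0.25, 1.5, 2.5])
  return interpolate1d(x, values, tangents)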
def interpolate(x, x_data, y_data, left_slope=None, right_slope=None, validate_args=False, optimize_for_tpu=False, dtype=None, name=None): """Performs linear interpolation for supplied points. Given a set of knots whose x- and y- coordinates are in `x_data` and `y_data`, this function returns y-values for x-coordinates in `x` via piecewise linear interpolation. `x_data` must be non decreasing, but `y_data` don't need to be because we do not require the function approximated by these knots to be monotonic. #### Examples ```python x = [-10, -1, 1, 3, 6, 7, 8, 15, 18, 25, 30, 35] x_data = [-1, 2, 6, 8, 18, 30.0] y_data = [10, -1, -5, 7, 9, 20] result = linear_interpolation(x, x_data, y_data) # [ 10, 10, 2.66666667, -2, -5, 1, 7, 8.4, 9, 15.41666667, 20, 20] ``` Args: x: x-coordinates for which we need to get interpolation. A N-D `Tensor` of real dtype. First N-1 dimensions represent batching dimensions. x_data: x coordinates. A N-D `Tensor` of real dtype. Should be sorted in non decreasing order. First N-1 dimensions represent batching dimensions. y_data: y coordinates. A N-D `Tensor` of real dtype. Should have the compatible shape as `x_data`. First N-1 dimensions represent batching dimensions. left_slope: The slope to use for extrapolation with x-coordinate smaller than the min `x_data`. It's a 0-D or N-D `Tensor`. Default value: `None`, which maps to `0.0` meaning constant extrapolation, i.e. extrapolated value will be the leftmost `y_data`. right_slope: The slope to use for extrapolation with x-coordinate greater than the max `x_data`. It's a 0-D or N-D `Tensor`. Default value: `None` which maps to `0.0` meaning constant extrapolation, i.e. extrapolated value will be the rightmost `y_data`. validate_args: Python `bool` that indicates whether the function performs the check if the shapes of `x_data` and `y_data` are equal and that the elements in `x_data` are non decreasing. If this value is set to `False` and the elements in `x_data` are not increasing, the result of linear interpolation may be wrong. Default value: `False`. optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot encoding to lookup indices of `x_values` in `x_data`. This significantly improves performance of the algorithm on a TPU device but may slow down performance on the CPU. Default value: `False`. dtype: Optional tf.dtype for `x`, x_data`, `y_data`, `left_slope` and `right_slope`. Default value: `None` which means that the `dtype` inferred by TensorFlow is used. name: Python str. The name prefixed to the ops created by this function. Default value: `None` which maps to 'linear_interpolation'. Returns: A N-D `Tensor` of real dtype corresponding to the x-values in `x`. 
""" name = name or 'linear_interpolation' with tf.name_scope(name): x = tf.convert_to_tensor(x, dtype=dtype, name='x') dtype = dtype or x.dtype x_data = tf.convert_to_tensor(x_data, dtype=dtype, name='x_data') y_data = tf.convert_to_tensor(y_data, dtype=dtype, name='y_data') batch_shape = x.shape.as_list()[:-1] if not batch_shape: x = tf.expand_dims(x, 0) x_data = tf.expand_dims(x_data, 0) y_data = tf.expand_dims(y_data, 0) if left_slope is None: left_slope = tf.constant(0.0, dtype=x.dtype, name='left_slope') else: left_slope = tf.convert_to_tensor(left_slope, dtype=dtype, name='left_slope') if right_slope is None: right_slope = tf.constant(0.0, dtype=x.dtype, name='right_slope') else: right_slope = tf.convert_to_tensor(right_slope, dtype=dtype, name='right_slope') control_deps = [] if validate_args: # Check that `x_data` elements is non-decreasing diffs = x_data[..., 1:] - x_data[..., :-1] assertion = tf.compat.v1.debugging.assert_greater_equal( diffs, tf.zeros_like(diffs), message='x_data is not sorted in non-decreasing order.') control_deps.append(assertion) # Check that the shapes of `x_data` and `y_data` are equal control_deps.append( tf.compat.v1.assert_equal(tf.shape(x_data), tf.shape(y_data))) with tf.control_dependencies(control_deps): # Get upper bound indices for `x`. upper_indices = tf.searchsorted(x_data, x, side='left', out_type=tf.int32) x_data_size = x_data.shape.as_list()[-1] at_min = tf.equal(upper_indices, 0) at_max = tf.equal(upper_indices, x_data_size) # Create tensors in order to be used by `tf.where`. # `values_min` are extrapolated values for x-coordinates less than or # equal to `x_data[..., 0]`. # `values_max` are extrapolated values for x-coordinates greater than # `x_data[..., -1]`. values_min = tf.expand_dims( y_data[..., 0], -1) + left_slope * (x - tf.broadcast_to( tf.expand_dims(x_data[..., 0], -1), shape=tf.shape(x))) values_max = tf.expand_dims( y_data[..., -1], -1) + right_slope * (x - tf.broadcast_to( tf.expand_dims(x_data[..., -1], -1), shape=tf.shape(x))) # `tf.where` evaluates all branches, need to cap indices to ensure it # won't go out of bounds. capped_lower_indices = tf.math.maximum(upper_indices - 1, 0) capped_upper_indices = tf.math.minimum(upper_indices, x_data_size - 1) # Prepare indices for `tf.gather_nd` or `tf.one_hot` # TODO(b/156720909): Extract get_slice logic into a common utilities # module for cubic and linear interpolation if optimize_for_tpu: lower_encoding = tf.one_hot(capped_lower_indices, x_data_size, dtype=dtype) upper_encoding = tf.one_hot(capped_upper_indices, x_data_size, dtype=dtype) else: index_matrix = _prepare_indices(upper_indices) lower_encoding = tf.concat( [index_matrix, tf.expand_dims(capped_lower_indices, -1)], -1) upper_encoding = tf.concat( [index_matrix, tf.expand_dims(capped_upper_indices, -1)], -1) def get_slice(x, encoding): if optimize_for_tpu: return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) * encoding, axis=-1) else: return tf.gather_nd(x, encoding) x_data_lower = get_slice(x_data, lower_encoding) x_data_upper = get_slice(x_data, upper_encoding) y_data_lower = get_slice(y_data, lower_encoding) y_data_upper = get_slice(y_data, upper_encoding) # Nan in unselected branches could propagate through gradient calculation, # hence we need to clip the values to ensure no nan would occur. In this # case we need to ensure there is no division by zero. 
x_data_diff = x_data_upper - x_data_lower floor_x_diff = tf.where(at_min | at_max, x_data_diff + 1, x_data_diff) interpolated = y_data_lower + (x - x_data_lower) * ( y_data_upper - y_data_lower) / floor_x_diff interpolated = tf.where(at_min, values_min, interpolated) interpolated = tf.where(at_max, values_max, interpolated) if batch_shape: return interpolated else: return tf.squeeze(interpolated, 0)
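# Hedged sketch (data invented): `left_slope`/`right_slope` control
# extrapolation; leaving them as `None` clamps queries outside the knot range
# to the first/last `y_data` value, while explicit slopes extrapolate linearly.
def _example_interpolate_slopes():
  x = tf.constant([-2., 0.5, 5.])
  x_data = tf.constant([0., 1., 2., 3.])
  y_data = tf.constant([0., 1., 4., 9.])
  flat = interpolate(x, x_data, y_data)            # constant outside the knots
  sloped = interpolate(x, x_data, y_data,
                       left_slope=1.0, right_slope=6.0)
  return flat, sloped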
def update_state(self, values, sample_weight=None): """Accumulates statistics for computing the metric. Args: values: Per-example value. sample_weight: Optional weighting of each example. Defaults to 1. Returns: Update op. """ [ values ], sample_weight = metrics_utils.ragged_assert_compatible_and_get_flat_values( # noqa: E501 [values], sample_weight) try: values = tf.cast(values, self._dtype) except (ValueError, TypeError): msg = ( "The output of a metric function can only be a single Tensor. " f"Received: {values}. ") if isinstance(values, dict): msg += ( "To return a dict of values, implement a custom Metric " "subclass.") raise RuntimeError(msg) if sample_weight is not None: sample_weight = tf.cast(sample_weight, self._dtype) # Update dimensions of weights to match with values if possible. ( values, _, sample_weight, ) = losses_utils.squeeze_or_expand_dimensions( values, sample_weight=sample_weight) try: # Broadcast weights if possible. sample_weight = tf.__internal__.ops.broadcast_weights( sample_weight, values) except ValueError: # Reduce values to same ndim as weight array ndim = backend.ndim(values) weight_ndim = backend.ndim(sample_weight) if self.reduction == metrics_utils.Reduction.SUM: values = tf.reduce_sum(values, axis=list(range(weight_ndim, ndim))) else: values = tf.reduce_mean(values, axis=list(range(weight_ndim, ndim))) values = tf.multiply(values, sample_weight) value_sum = tf.reduce_sum(values) with tf.control_dependencies([value_sum]): update_total_op = self.total.assign_add(value_sum) # Exit early if the reduction doesn't have a denominator. if self.reduction == metrics_utils.Reduction.SUM: return update_total_op # Update `count` for reductions that require a denominator. if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: num_values = tf.cast(tf.size(values), self._dtype) elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: if sample_weight is None: num_values = tf.cast(tf.size(values), self._dtype) else: num_values = tf.reduce_sum(sample_weight) else: raise NotImplementedError( f'Reduction "{self.reduction}" not implemented. Expected ' '"sum", "weighted_mean", or "sum_over_batch_size".') with tf.control_dependencies([update_total_op]): return self.count.assign_add(num_values)
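# Small numeric sketch (no Keras objects involved; numbers invented): with the
# WEIGHTED_MEAN reduction the metric result is sum(w * v) / sum(w), which is
# exactly what the `total` and `count` variables above accumulate over batches.
def _example_weighted_mean_reduction():
  values = tf.constant([1., 2., 4.])
  weights = tf.constant([1., 1., 2.])
  total = tf.reduce_sum(values * weights)   # contribution to `total`
  count = tf.reduce_sum(weights)            # contribution to `count`
  return total / count                      # == 2.75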
def _inverse(self, y): with tf.control_dependencies(self._assertions(y)): return -y, y
def pinv(a, rcond=None, validate_args=False, name=None): """Compute the Moore-Penrose pseudo-inverse of a matrix. Calculate the [generalized inverse of a matrix]( https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its singular-value decomposition (SVD) and including all large singular values. The pseudo-inverse of a matrix `A`, is defined as: 'the matrix that 'solves' [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution, then `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then `A_pinv = V @ inv(Sigma) U^T`. [(Strang, 1980)][1] This function is analogous to [`numpy.linalg.pinv`]( https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html). It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the default `rcond` is `1e-15`. Here the default is `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`. Args: a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be pseudo-inverted. rcond: `Tensor` of small singular value cutoffs. Singular values smaller (in modulus) than `rcond` * largest_singular_value (again, in modulus) are set to zero. Must broadcast against `tf.shape(a)[:-2]`. Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`. validate_args: When `True`, additional assertions might be embedded in the graph. Default value: `False` (i.e., no graph assertions are added). name: Python `str` prefixed to ops created by this function. Default value: 'pinv'. Returns: a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except rightmost two dimensions are transposed. Raises: TypeError: if input `a` does not have `float`-like `dtype`. ValueError: if input `a` has fewer than 2 dimensions. #### Examples ```python import tensorflow as tf import tensorflow_probability as tfp a = tf.constant([[1., 0.4, 0.5], [0.4, 0.2, 0.25], [0.5, 0.25, 0.35]]) tf.matmul(tfp.math.pinv(a), a) # ==> array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32) a = tf.constant([[1., 0.4, 0.5, 1.], [0.4, 0.2, 0.25, 2.], [0.5, 0.25, 0.35, 3.]]) tf.matmul(tfp.math.pinv(a), a) # ==> array([[ 0.76, 0.37, 0.21, -0.02], [ 0.37, 0.43, -0.33, 0.02], [ 0.21, -0.33, 0.81, 0.01], [-0.02, 0.02, 0.01, 1. ]], dtype=float32) ``` #### References [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press, Inc., 1980, pp. 139-142. """ with tf.name_scope(name or 'pinv'): a = tf.convert_to_tensor(a, name='a') assertions = _maybe_validate_matrix(a, validate_args) if assertions: with tf.control_dependencies(assertions): a = tf.identity(a) dtype = a.dtype.as_numpy_dtype if rcond is None: def get_dim_size(dim): if tf.compat.dimension_value(a.shape[dim]) is not None: return tf.compat.dimension_value(a.shape[dim]) return tf.shape(a)[dim] num_rows = get_dim_size(-2) num_cols = get_dim_size(-1) if isinstance(num_rows, int) and isinstance(num_cols, int): max_rows_cols = float(max(num_rows, num_cols)) else: max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype) rcond = 10. * max_rows_cols * np.finfo(dtype).eps rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond') # Calculate pseudo inverse via SVD. # Note: if a is symmetric then u == v. (We might observe additional # performance by explicitly setting `v = u` in such cases.) [ singular_values, # Sigma left_singular_vectors, # U right_singular_vectors, # V ] = tf.linalg.svd(a, full_matrices=False, compute_uv=True) # Saturate small singular values to inf. 
# This has the effect of making `1. / s = 0.` while not resulting in `NaN` gradients. cutoff = rcond * tf.reduce_max(singular_values, axis=-1) singular_values = tf.where(singular_values > cutoff[..., tf.newaxis], singular_values, np.array(np.inf, dtype)) # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e., # a matrix inverse has 'transposed' semantics. a_pinv = tf.matmul(right_singular_vectors / singular_values[..., tf.newaxis, :], left_singular_vectors, adjoint_b=True) if a.shape.ndims is not None: a_pinv.set_shape(a.shape[:-2].concatenate( [a.shape[-1], a.shape[-2]])) return a_pinv
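# Hedged sketch (matrix values invented): one common use of `pinv` is solving
# an overdetermined least-squares system, x_hat = pinv(A) @ b.
def _example_pinv_least_squares():
  a = tf.constant([[1., 0.], [1., 1.], [1., 2.]])   # 3 equations, 2 unknowns
  b = tf.constant([[1.], [2.], [2.]])
  return tf.matmul(pinv(a), b)                      # least-squares solution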
def _inverse_log_det_jacobian(self, y): with tf.control_dependencies(self._maybe_assert_valid_y(y)): return (self.power - 1.) * tf.math.log(y)
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): """Computes a matrix inverse given the matrix's LU decomposition. This op is conceptually identical to, ```python inv_X = tfp.math.lu_matrix_inverse(*tf.linalg.lu(X)) tf.debugging.assert_near(tf.linalg.inv(X), inv_X) # ==> True ``` Note: this function does not verify the implied matrix is actually invertible nor is this condition checked even when `validate_args=True`. Args: lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`. perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. Note: this function does not verify the implied matrix is actually invertible, even when `validate_args=True`. Default value: `False` (i.e., don't validate arguments). name: Python `str` name given to ops managed by this object. Default value: `None` (i.e., 'lu_matrix_inverse'). Returns: inv_x: The matrix_inv, i.e., `tf.linalg.inv(tfp.math.lu_reconstruct(lu, perm))`. #### Examples ```python import numpy as np import tensorflow as tf import tensorflow_probability as tfp x = [[[3., 4], [1, 2]], [[7., 8], [3, 4]]] inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x)) tf.debugging.assert_near(tf.linalg.inv(x), inv_x) # ==> True ``` """ with tf.name_scope(name or 'lu_matrix_inverse'): lower_upper = tf.convert_to_tensor(lower_upper, dtype_hint=tf.float32, name='lower_upper') perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm') assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args) if assertions: with tf.control_dependencies(assertions): lower_upper = tf.identity(lower_upper) perm = tf.identity(perm) shape = tf.shape(lower_upper) return lu_solve(lower_upper, perm, rhs=tf.eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype), validate_args=False)
def _forward_log_det_jacobian(self, x): with tf.control_dependencies(self._maybe_assert_valid_x(x)): if self.power == 0.: return x return (1. / self.power - 1.) * tf.math.log1p(x * self.power)
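# Worked check (sketch only; assumes this is the power-transform bijector whose
# forward map is y = (1 + x * power)**(1 / power), reducing to y = exp(x) when
# power == 0): then dy/dx = (1 + x * power)**(1 / power - 1), so
# log|dy/dx| = (1 / power - 1) * log1p(x * power), matching the value returned
# above. The snippet verifies this with autodiff for one arbitrary setting.
def _example_power_fldj_check(power=2., x=0.3):
  x = tf.constant(x)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = (1. + x * power) ** (1. / power)
  dydx = tape.gradient(y, x)
  return tf.math.log(dydx), (1. / power - 1.) * tf.math.log1p(x * power)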
def __init__(self, initial_distribution, transition_distribution, observation_distribution, num_steps, validate_args=False, allow_nan_stats=True, name="HiddenMarkovModel"): """Initialize hidden Markov model. Args: initial_distribution: A `Categorical`-like instance. Determines probability of first hidden state in Markov chain. The number of categories must match the number of categories of `transition_distribution` as well as both the rightmost batch dimension of `transition_distribution` and the rightmost batch dimension of `observation_distribution`. transition_distribution: A `Categorical`-like instance. The rightmost batch dimension indexes the probability distribution of each hidden state conditioned on the previous hidden state. observation_distribution: A `tfp.distributions.Distribution`-like instance. The rightmost batch dimension indexes the distribution of each observation conditioned on the corresponding hidden state. num_steps: The number of steps taken in Markov chain. A python `int`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. Default value: `True`. name: Python `str` name prefixed to Ops created by this class. Default value: "HiddenMarkovModel". Raises: ValueError: if `num_steps` is not at least 1. ValueError: if `initial_distribution` does not have scalar `event_shape`. ValueError: if `transition_distribution` does not have scalar `event_shape.` ValueError: if `transition_distribution` and `observation_distribution` are fully defined but don't have matching rightmost dimension. 
""" parameters = dict(locals()) # pylint: disable=protected-access with tf.name_scope(name) as name: self._runtime_assertions = [] # pylint: enable=protected-access num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps") if validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.rank(num_steps), 0, message="`num_steps` must be a scalar") ] self._runtime_assertions += [ assert_util.assert_greater_equal( num_steps, 1, message="`num_steps` must be at least 1.") ] self._initial_distribution = initial_distribution self._observation_distribution = observation_distribution self._transition_distribution = transition_distribution if (initial_distribution.event_shape is not None and tensorshape_util.rank( initial_distribution.event_shape) != 0): raise ValueError( "`initial_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape(initial_distribution.event_shape_tensor())[0], 0, message="`initial_distribution` must have scalar" "`event_dim`s") ] if (transition_distribution.event_shape is not None and tensorshape_util.rank( transition_distribution.event_shape) != 0): raise ValueError( "`transition_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape( transition_distribution.event_shape_tensor())[0], 0, message="`transition_distribution` must have scalar" "`event_dim`s") ] if (transition_distribution.batch_shape is not None and tensorshape_util.rank( transition_distribution.batch_shape) == 0): raise ValueError( "`transition_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(transition_distribution.batch_shape_tensor()), 0, message="`transition_distribution` can't have scalar " "batches") ] if (observation_distribution.batch_shape is not None and tensorshape_util.rank( observation_distribution.batch_shape) == 0): raise ValueError( "`observation_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(observation_distribution.batch_shape_tensor()), 0, message="`observation_distribution` can't have scalar " "batches") ] # Infer number of hidden states and check consistency # between transitions and observations with tf.control_dependencies(self._runtime_assertions): self._num_states = ( (transition_distribution.batch_shape and transition_distribution.batch_shape[-1]) or transition_distribution.batch_shape_tensor()[-1]) observation_states = ( (observation_distribution.batch_shape and observation_distribution.batch_shape[-1]) or observation_distribution.batch_shape_tensor()[-1]) if (tf.is_tensor(self._num_states) or tf.is_tensor(observation_states)): if validate_args: self._runtime_assertions += [ assert_util.assert_equal( self._num_states, observation_states, message="`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") ] elif self._num_states != observation_states: raise ValueError("`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") self._log_init = _extract_log_probs(self._num_states, initial_distribution) self._log_trans = _extract_log_probs(self._num_states, transition_distribution) self._num_steps = num_steps self._num_states = tf.shape(self._log_init)[-1] self._underlying_event_rank = tf.size( self._observation_distribution.event_shape_tensor()) num_steps_ = 
tf.get_static_value(num_steps) if num_steps_ is not None: self.static_event_shape = tf.TensorShape([ num_steps_ ]).concatenate(self._observation_distribution.event_shape) else: self.static_event_shape = None with tf.control_dependencies(self._runtime_assertions): self.static_batch_shape = tf.broadcast_static_shape( self._initial_distribution.batch_shape, tf.broadcast_static_shape( self._transition_distribution.batch_shape[:-1], self._observation_distribution.batch_shape[:-1])) # pylint: disable=protected-access super(HiddenMarkovModel, self).__init__( dtype=self._observation_distribution.dtype, reparameterization_type=reparameterization.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=(self._initial_distribution._graph_parents + self._transition_distribution._graph_parents + self._observation_distribution._graph_parents), name=name) # pylint: enable=protected-access self._parameters = parameters
def _validate_arg_if_not_none(arg, assertion, validate_args): if arg is None: return arg with tf.control_dependencies([assertion(arg)] if validate_args else []): result = tf.identity(arg) return result
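# Illustrative sketch (values invented; assumes the `assert_util` module used
# elsewhere in this file): how a distribution might attach an assertion to an
# optional argument only when `validate_args` is requested.
def _example_validate_arg_if_not_none():
  scale = tf.constant([1., 2.])
  return _validate_arg_if_not_none(
      scale, assert_util.assert_positive, validate_args=True)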
def _sample_n(self, n, seed=None): with tf.control_dependencies(self._runtime_assertions): seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel") num_states = self._num_states batch_shape = self.batch_shape_tensor() batch_size = tf.reduce_prod(batch_shape) # The batch sizes of the underlying initial distributions and # transition distributions might not match the batch size of # the HMM distribution. # As a result we need to ask for more samples from the # underlying distributions and then reshape the results into # the correct batch size for the HMM. init_repeat = ( tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod( self._initial_distribution.batch_shape_tensor())) init_state = self._initial_distribution.sample(n * init_repeat, seed=seed()) init_state = tf.reshape(init_state, [n, batch_size]) # init_state :: n batch_size transition_repeat = ( tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod( self._transition_distribution.batch_shape_tensor()[:-1])) def generate_step(state, _): """Take a single step in Markov chain.""" gen = self._transition_distribution.sample(n * transition_repeat, seed=seed()) # gen :: (n * transition_repeat) transition_batch new_states = tf.reshape(gen, [n, batch_size, num_states]) # new_states :: n batch_size num_states old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32) # old_states :: n batch_size num_states return tf.reduce_sum(old_states_one_hot * new_states, axis=-1) def _scan_multiple_steps(): dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) hidden_states = tf.scan(generate_step, dummy_index, initializer=init_state) # TODO(b/115618503): add/use prepend_initializer to tf.scan return tf.concat([[init_state], hidden_states], axis=0) hidden_states = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps, lambda: init_state[tf.newaxis, ...]) hidden_one_hot = tf.one_hot( hidden_states, num_states, dtype=self._observation_distribution.dtype) # hidden_one_hot :: num_steps n batch_size num_states # The observation distribution batch size might not match # the required batch size so as with the initial and # transition distributions we generate more samples and # reshape. observation_repeat = (batch_size // tf.reduce_prod( self._observation_distribution.batch_shape_tensor()[:-1])) possible_observations = self._observation_distribution.sample( [self._num_steps, observation_repeat * n]) inner_shape = self._observation_distribution.event_shape # possible_observations :: num_steps (observation_repeat * n) # observation_batch[:-1] num_states inner_shape possible_observations = tf.reshape( possible_observations, tf.concat([[self._num_steps, n], batch_shape, [num_states], inner_shape], axis=0)) # possible_observations :: steps n batch_size num_states inner_shape hidden_one_hot = tf.reshape( hidden_one_hot, tf.concat([[self._num_steps, n], batch_shape, [num_states], tf.ones_like(inner_shape)], axis=0)) # hidden_one_hot :: steps n batch_size num_states "inner_shape" observations = tf.reduce_sum(hidden_one_hot * possible_observations, axis=-1 - tf.size(inner_shape)) # observations :: steps n batch_size inner_shape observations = distribution_util.move_dimension( observations, 0, 1 + tf.size(batch_shape)) # returned :: n batch_shape steps inner_shape return observations
def _effective_sample_size_single_state(states, filter_beyond_lag, filter_threshold, filter_beyond_positive_pairs, cross_chain_dims, validate_args): """ESS computation for one single Tensor argument.""" with tf.name_scope('effective_sample_size_single_state'): states = tf.convert_to_tensor(states, name='states') dt = states.dtype # filter_beyond_lag == None ==> auto_corr is the full sequence. auto_cov = stats.auto_correlation( states, axis=0, max_lags=filter_beyond_lag, normalize=False) n = _axis_size(states, axis=0) if cross_chain_dims is not None: num_chains = _axis_size(states, cross_chain_dims) num_chains_ = tf.get_static_value(num_chains) assertions = [] msg = ('When `cross_chain_dims` is not `None`, there must be > 1 chain ' 'in `states`.') if num_chains_ is not None: if num_chains_ < 2: raise ValueError(msg) elif validate_args: assertions.append( assert_util.assert_greater(num_chains, 1., message=msg)) with tf.control_dependencies(assertions): # We're computing the R[k] from equation 10 of Vehtari et al. # (2019): # # R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+), # # where: # C := number of chains # N := length of chains # x_hat[c] := 1 / N Sum_{n=1}^N x[n, c], chain mean. # x_hat := 1 / C Sum_{c=1}^C x_hat[c], overall mean. # W := 1/C Sum_{c=1}^C s_c**2, within-chain variance. # B := N / (C - 1) Sum_{c=1}^C (x_hat[c] - x_hat)**2, between chain # variance. # s_c**2 := 1 / (N - 1) Sum_{n=1}^N (x[n, c] - x_hat[c])**2, chain # variance # R[k, m] := auto_corr[k, m, ...], auto-correlation indexed by chain. # var^+ := (N - 1) / N * W + B / N cross_chain_dims = ps.non_negative_axis( cross_chain_dims, ps.rank(states)) # B / N between_chain_variance_div_n = _reduce_variance( tf.reduce_mean(states, axis=0), biased=False, # This makes the denominator be C - 1. axis=cross_chain_dims - 1) # W * (N - 1) / N biased_within_chain_variance = tf.reduce_mean(auto_cov[0], cross_chain_dims - 1) # var^+ approx_variance = ( biased_within_chain_variance + between_chain_variance_div_n) # 1/C * Sum_{c=1}^C s_c**2 R[k, c] mean_auto_cov = tf.reduce_mean(auto_cov, cross_chain_dims) auto_corr = 1. - (biased_within_chain_variance - mean_auto_cov) / approx_variance else: auto_corr = auto_cov / auto_cov[:1] num_chains = 1 # With R[k] := auto_corr[k, ...], # ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N} # = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N} (since R[0] = 1) # approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N} # where M is the filter_beyond_lag truncation point chosen above. # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total # ndims the same as auto_corr k = tf.range(0., _axis_size(auto_corr, axis=0)) nk_factor = (n - k) / n if tensorshape_util.rank(auto_corr.shape) is not None: new_shape = [-1] + [1] * (tensorshape_util.rank(auto_corr.shape) - 1) else: new_shape = tf.concat( ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)), axis=0) nk_factor = tf.reshape(nk_factor, new_shape) weighted_auto_corr = nk_factor * auto_corr if filter_beyond_positive_pairs: def _sum_pairs(x): x_len = ps.shape(x)[0] # For odd sequences, we drop the final value. x = x[:x_len - x_len % 2] new_shape = ps.concat([[x_len // 2, 2], ps.shape(x)[1:]], axis=0) return tf.reduce_sum(tf.reshape(x, new_shape), 1) # Pairwise sums are all positive for auto-correlation spectra derived from # reversible MCMC chains. # E.g. imagine the pairwise sums are [0.2, 0.1, -0.1, -0.2] # Step 1: mask = [False, False, True, True] mask = _sum_pairs(auto_corr) < 0. 
# Step 2: mask = [0, 0, 1, 1] mask = tf.cast(mask, dt) # Step 3: mask = [0, 0, 1, 2] mask = tf.cumsum(mask, axis=0) # Step 4: mask = [1, 1, 0, 0] mask = tf.maximum(1. - mask, 0.) # N.B. this reduces the length of weighted_auto_corr by a factor of 2. # It still works fine in the formula below. weighted_auto_corr = _sum_pairs(weighted_auto_corr) * mask elif filter_threshold is not None: filter_threshold = tf.convert_to_tensor( filter_threshold, dtype=dt, name='filter_threshold') # Get a binary mask to zero out values of auto_corr below the threshold. # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, # mask[i, ...] = 0, otherwise. # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] # Building step by step, # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. # Step 1: mask = [False, False, True, False] mask = auto_corr < filter_threshold # Step 2: mask = [0, 0, 1, 0] mask = tf.cast(mask, dtype=dt) # Step 3: mask = [0, 0, 1, 1] mask = tf.cumsum(mask, axis=0) # Step 4: mask = [1, 1, 0, 0] mask = tf.maximum(1. - mask, 0.) weighted_auto_corr *= mask return num_chains * n / (-1 + 2 * tf.reduce_sum(weighted_auto_corr, axis=0))
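# Usage sketch (hedged; kwargs assume a recent TFP release): this helper is
# normally reached through the public `tfp.mcmc.effective_sample_size` wrapper.
# With fabricated draws of shape [num_samples, num_chains] it might be used as:
def _example_effective_sample_size():
  import tensorflow_probability as tfp
  states = tf.random.normal([1000, 4])   # 1000 draws from 4 chains
  return tfp.mcmc.effective_sample_size(
      states, filter_beyond_positive_pairs=True, cross_chain_dims=1)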
def posterior_mode(self, observations, mask=None, name=None): """Compute maximum likelihood sequence of hidden states. When this function is provided with a sequence of observations `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden states `z[0], ..., z[num_steps - 1]`, drawn from the underlying Markov chain, that is most likely to yield those observations. It uses the [Viterbi algorithm]( https://en.wikipedia.org/wiki/Viterbi_algorithm). Note: the behavior of this function is undefined if the `observations` argument represents impossible observations from the model. Note: if there isn't a unique most likely sequence then one of the equally most likely sequences is chosen. Args: observations: A tensor representing a batch of observations made on the hidden Markov model. The rightmost dimensions of this tensor correspond to the dimensions of the observation distributions of the underlying Markov chain. The next dimension from the right indexes the steps in a sequence of observations from a single sample from the hidden Markov model. The size of this dimension should match the `num_steps` parameter of the hidden Markov model object. The other dimensions are the dimensions of the batch and these are broadcast with the hidden Markov model's parameters. mask: optional bool-type `tensor` with rightmost dimension matching `num_steps` indicating which observations the result of this function should be conditioned on. When the mask has value `True` the corresponding observations aren't used. if `mask` is `None` then all of the observations are used. the `mask` dimensions left of the last are broadcast with the hmm batch as well as with the observations. name: Python `str` name prefixed to Ops created by this class. Default value: "HiddenMarkovModel". Returns: posterior_mode: A `Tensor` representing the most likely sequence of hidden states. The rightmost dimension of this tensor will equal the `num_steps` parameter providing one hidden state for each step. The other dimensions are those of the batch. Raises: ValueError: if the `observations` tensor does not consist of sequences of `num_steps` observations. #### Examples ```python tfd = tfp.distributions # A simple weather model. # Represent a cold day with 0 and a hot day with 1. # Suppose the first day of a sequence has a 0.8 chance of being cold. initial_distribution = tfd.Categorical(probs=[0.8, 0.2]) # Suppose a cold day has a 30% chance of being followed by a hot day # and a hot day has a 20% chance of being followed by a cold day. transition_distribution = tfd.Categorical(probs=[[0.7, 0.3], [0.2, 0.8]]) # Suppose additionally that on each day the temperature is # normally distributed with mean and standard deviation 0 and 5 on # a cold day and mean and standard deviation 15 and 10 on a hot day. observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.]) # This gives the hidden Markov model: model = tfd.HiddenMarkovModel( initial_distribution=initial_distribution, transition_distribution=transition_distribution, observation_distribution=observation_distribution, num_steps=7) # Suppose we observe gradually rising temperatures over a week: temps = [-2., 0., 2., 4., 6., 8., 10.] # We can now compute the most probable sequence of hidden states: model.posterior_mode(temps) # The result is [0 0 0 0 0 1 1] telling us that the transition # from "cold" to "hot" most likely happened between the # 5th and 6th days. 
``` """ with tf.name_scope(name or "posterior_mode"): observations = tf.convert_to_tensor(observations, name="observations") if mask is not None: mask = tf.convert_to_tensor(mask, name="mask", dtype_hint=tf.bool) with tf.control_dependencies(self._runtime_assertions): observation_tensor_shape = tf.shape(observations) mask_tensor_shape = tf.shape( mask) if mask is not None else None with self._observation_mask_shape_preconditions( observation_tensor_shape, mask_tensor_shape): observation_log_probs = self._observation_log_probs( observations, mask) log_prob = self._log_init + observation_log_probs[0] def _reduce_multiple_steps(): """Perform `reduce_max` operation when `num_steps` > 1.""" def forward_step(previous_step_pair, log_prob_observation): log_prob_previous = previous_step_pair[0] log_prob = ( log_prob_previous[..., tf.newaxis] + self._log_trans + log_prob_observation[..., tf.newaxis, :]) most_likely_given_successor = tf.argmax(log_prob, axis=-2) max_log_p_given_successor = tf.reduce_max( input_tensor=log_prob, axis=-2) return (max_log_p_given_successor, most_likely_given_successor) forward_log_probs, all_most_likely_given_successor = tf.scan( forward_step, observation_log_probs[1:], initializer=(log_prob, tf.zeros(tf.shape(log_prob), dtype=tf.int64)), name="forward_log_probs") most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1) # We require the operation that gives C from A and B where # C[i...j] = A[i...j, B[i...j]] # and A = most_likely_given_successor # B = most_likely_successor. # tf.gather requires indices of known shape so instead we use # reduction with tf.one_hot(B) to pick out elements from B def backward_step(most_likely_successor, most_likely_given_successor): return tf.reduce_sum( input_tensor=(most_likely_given_successor * tf.one_hot(most_likely_successor, self._num_states, dtype=tf.int64)), axis=-1) backward_scan = tf.scan( backward_step, all_most_likely_given_successor, most_likely_end, reverse=True) most_likely_sequences = tf.concat( [backward_scan, [most_likely_end]], axis=0) return distribution_util.move_dimension( most_likely_sequences, 0, -1) return prefer_static.cond( self.num_steps > 1, _reduce_multiple_steps, lambda: tf.argmax(log_prob, axis=-1)[..., tf.newaxis])
def _potential_scale_reduction_single_state(state, independent_chain_ndims, split_chains, validate_args): """potential_scale_reduction for one single state `Tensor`.""" # casting integers to floats for floating-point division # check to see if the `state` is a numpy object for the numpy test suite if dtype_util.as_numpy_dtype(state.dtype) is np.int64: state = tf.cast(state, tf.float64) elif dtype_util.is_integer(state.dtype): state = tf.cast(state, tf.float32) with tf.name_scope('potential_scale_reduction_single_state'): # We assume exactly one leading dimension indexes e.g. correlated samples # from each Markov chain. state = tf.convert_to_tensor(state, name='state') n_samples_ = tf.compat.dimension_value(state.shape[0]) if n_samples_ is not None: # If available statically. if split_chains and n_samples_ < 4: raise ValueError( 'Must provide at least 4 samples when splitting chains. ' 'Found {}'.format(n_samples_)) if not split_chains and n_samples_ < 2: raise ValueError( 'Must provide at least 2 samples. Found {}'.format(n_samples_)) elif validate_args: if split_chains: assertions = [assert_util.assert_greater( ps.shape(state)[0], 4, message='Must provide at least 4 samples when splitting chains.')] with tf.control_dependencies(assertions): state = tf.identity(state) else: assertions = [assert_util.assert_greater( ps.shape(state)[0], 2, message='Must provide at least 2 samples.')] with tf.control_dependencies(assertions): state = tf.identity(state) # Define so it's not a magic number. # Warning! `if split_chains` logic assumes this is 1! sample_ndims = 1 if split_chains: # Split the sample dimension in half, doubling the number of # independent chains. # For odd number of samples, keep all but the last sample. state_shape = ps.shape(state) n_samples = state_shape[0] state = state[:n_samples - n_samples % 2] # Suppose state = [0, 1, 2, 3, 4, 5] # Step 1: reshape into [[0, 1, 2], [3, 4, 5]] # E.g. reshape states of shape [a, b] into [2, a//2, b]. state = tf.reshape( state, ps.concat([[2, n_samples // 2], state_shape[1:]], axis=0) ) # Step 2: Put the size `2` dimension in the right place to be treated as a # chain, changing [[0, 1, 2], [3, 4, 5]] into [[0, 3], [1, 4], [2, 5]], # reshaping [2, a//2, b] into [a//2, 2, b]. state = tf.transpose( a=state, perm=ps.concat( [[1, 0], ps.range(2, ps.rank(state))], axis=0)) # We're treating the new dim as indexing 2 chains, so increment. independent_chain_ndims += 1 sample_axis = ps.range(0, sample_ndims) chain_axis = ps.range(sample_ndims, sample_ndims + independent_chain_ndims) sample_and_chain_axis = ps.range( 0, sample_ndims + independent_chain_ndims) n = _axis_size(state, sample_axis) m = _axis_size(state, chain_axis) # In the language of Brooks and Gelman (1998), # B / n is the between chain variance, the variance of the chain means. # W is the within sequence variance, the mean of the chain variances. b_div_n = _reduce_variance( tf.reduce_mean(state, axis=sample_axis, keepdims=True), sample_and_chain_axis, biased=False) w = tf.reduce_mean( _reduce_variance(state, sample_axis, keepdims=True, biased=False), axis=sample_and_chain_axis) # sigma^2_+ is an estimate of the true variance, which would be unbiased if # each chain was drawn from the target. c.f. "law of total variance." sigma_2_plus = ((n - 1) / n) * w + b_div_n return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
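# Usage sketch (hedged; kwargs assume a recent TFP release): the single-state
# helper above backs the public `tfp.mcmc.potential_scale_reduction`. With
# made-up chains of shape [num_samples, num_chains]:
def _example_potential_scale_reduction():
  import tensorflow_probability as tfp
  chains = tf.random.normal([500, 8])    # 500 draws from 8 chains
  return tfp.mcmc.potential_scale_reduction(
      chains, independent_chain_ndims=1, split_chains=True)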
def op(x, kernel): input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') batch_shape, event_shape = ps.split(ps.shape(x), num_or_size_splits=[-1, 3]) xh, xw, c_in = ps.unstack(event_shape, num=3) kernel_shape = ps.shape(kernel) assertions = _maybe_validate_input_shapes( kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, validate_args=validate_args) with tf.control_dependencies(assertions): # If the kernel does not have batch shape, fall back to # `conv2d_transpose` (unless dilations > 1, which is not implemented in # `conv2d_transpose`). if (tf.get_static_value(ps.rank(kernel)) == 2 and all(d == 1 for d in dilations)): return _call_conv2d_transpose(x, kernel, filter_shape, strides, padding, dilations, kernel_shape[-1], batch_shape, event_shape) idx, shape = im2row_index((xh * sh + sum(pad_values[0]), xw * sw + sum(pad_values[1]), c_in), block_shape=filter_shape, slice_step=(1, 1), dilations=dilations, dtype=dtype, transpose=True) n = ps.maximum(0, ps.rank(x) - 3) paddings = ps.pad(pad_values, paddings=[[n, 1], [0, 0]], constant_values=0) # Interleave the rows and columns of the input with rows and columns of # zeros equal to the number of strides. x_half_dilated = tf.concat([ tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0), dtype=input_dtype), tf.reshape(x, shape=ps.concat( [batch_shape, (xh * xw, 1, c_in)], axis=0)) ], axis=-2) y = tf.reshape(x_half_dilated, shape=ps.concat( [batch_shape, (xh, 1, xw * sw, c_in)], axis=0)) x = tf.reshape(tf.concat([ tf.zeros(ps.concat( [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0), dtype=input_dtype), y ], axis=-3), shape=ps.concat( [batch_shape, (xh * sh, xw * sw, c_in)], axis=0)) truncations = -ps.minimum(ps.cast(paddings, dtype=tf.int32), 0) truncate_start, truncate_end = ps.unstack(truncations, axis=1) x_truncate = tf.slice(x, begin=truncate_start, size=ps.shape(x) - (truncate_start + truncate_end)) x_pad = tf.pad(x_truncate, paddings=ps.maximum(paddings, 0), constant_values=0) flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1) flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape), indices=idx, axis=-1) im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0)) return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def _sample_control_dependencies(self, x): """Helper which validates sample arg, e.g., input to `log_prob`.""" x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None else tensorshape_util.rank(x.shape)) event_ndims = (tf.size(self.event_shape_tensor()) if tensorshape_util.rank(self.event_shape) is None else tensorshape_util.rank(self.event_shape)) batch_ndims = (tf.size(self.batch_shape_tensor()) if tensorshape_util.rank(self.batch_shape) is None else tensorshape_util.rank(self.batch_shape)) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and isinstance(expected_batch_event_ndims, int)): if x_ndims < expected_batch_event_ndims: raise NotImplementedError( 'Broadcasting is not supported; too few batch and event dims ' '(expected at least {}, saw {}).'.format( expected_batch_event_ndims, x_ndims)) ndims_assertion = [] elif self.validate_args: ndims_assertion = [ assert_util.assert_greater_equal( x_ndims, expected_batch_event_ndims, message=('Broadcasting is not supported; too few ' 'batch and event dims.'), name='assert_batch_and_event_ndims_large_enough'), ] if (tensorshape_util.is_fully_defined(self.batch_shape) and tensorshape_util.is_fully_defined(self.event_shape)): expected_batch_event_shape = np.int32( tensorshape_util.concatenate(self.batch_shape, self.event_shape)) else: expected_batch_event_shape = tf.concat([ self.batch_shape_tensor(), self.event_shape_tensor(), ], axis=0) sample_ndims = x_ndims - expected_batch_event_ndims if isinstance(sample_ndims, int): sample_ndims = max(sample_ndims, 0) if (isinstance(sample_ndims, int) and tensorshape_util.is_fully_defined(x.shape[sample_ndims:])): actual_batch_event_shape = np.int32(x.shape[sample_ndims:]) else: sample_ndims = tf.maximum(sample_ndims, 0) actual_batch_event_shape = tf.shape(x)[sample_ndims:] assertions = [] if (isinstance(expected_batch_event_shape, np.ndarray) and isinstance(actual_batch_event_shape, np.ndarray)): if any(expected_batch_event_shape != actual_batch_event_shape): raise NotImplementedError('Broadcasting is not supported; ' 'unexpected batch and event shape ' '(expected {}, saw {}).'.format( expected_batch_event_shape, actual_batch_event_shape)) assertions.extend(ndims_assertion) elif self.validate_args: with tf.control_dependencies(ndims_assertion): shape_assertion = assert_util.assert_equal( expected_batch_event_shape, actual_batch_event_shape, message=('Broadcasting is not supported; ' 'unexpected batch and event shape.'), name='assert_batch_and_event_shape_same') assertions.append(shape_assertion) return assertions