def _build_sampler_loop_body(model, observed_time_series, is_missing=None): """Builds a Gibbs sampler for the given model and observed data. Args: model: A `tf.sts.StructuralTimeSeries` model instance. This must be of the form constructed by `build_model_for_gibbs_sampling`. observed_time_series: Float `Tensor` time series of shape `[..., num_timesteps]`. is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A `True` value indicates that the observation for that timestep is missing. Returns: sampler_loop_body: Python callable that performs a single cycle of Gibbs sampling. Its first argument is a `GibbsSamplerState`, and it returns a new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is ignored. """ level_component = model.components[0] if not (isinstance(level_component, sts.LocalLevel) or isinstance(level_component, sts.LocalLinearTrend)): raise ValueError( 'Expected the first model component to be an instance of ' '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; ' 'instead saw {}'.format(level_component)) model_has_slope = isinstance(level_component, sts.LocalLinearTrend) # Require that the model has exactly the parameters expected by # `GibbsSamplerState`. if model_has_slope: (observation_noise_param, level_scale_param, slope_scale_param, weights_param) = model.parameters else: (observation_noise_param, level_scale_param, weights_param) = model.parameters # If all non-`slope_scale` parameters are as expected, by process of # elimination `slope_scale` must be correct as well. if (('observation_noise' not in observation_noise_param.name) or ('level_scale' not in level_scale_param.name) or ('weights' not in weights_param.name)): raise ValueError( 'Model parameters {} do not match the expected sampler ' 'state.'.format(model.parameters)) if is_missing is not None: # Ensure series does not contain NaNs. observed_time_series = tf.where(is_missing, tf.zeros_like(observed_time_series), observed_time_series) num_observed_steps = prefer_static.shape(observed_time_series)[-1] design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps] # Untransform scale priors -> variance priors by reaching thru Sqrt bijector. level_scale_variance_prior = level_scale_param.prior.distribution if model_has_slope: slope_scale_variance_prior = slope_scale_param.prior.distribution observation_noise_variance_prior = observation_noise_param.prior.distribution def sampler_loop_body(previous_sample, _): """Runs one sampler iteration, resampling all model variables.""" (weights_seed, level_seed, observation_noise_scale_seed, level_scale_seed, loop_seed) = samplers.split_seed(previous_sample.seed, n=5, salt='sampler_loop_body') # Preserve backward-compatible seed behavior by splitting slope separately. slope_scale_seed, = samplers.split_seed(previous_sample.seed, n=1, salt='sampler_loop_body_slope') # We encourage a reasonable initialization by sampling the weights first, # so at the first step they are regressed directly against the observed # time series. If we instead sampled the level first it might 'explain away' # some observed variation that we would ultimately prefer to explain through # the regression weights, because the level can represent arbitrary # variation, while the weights are limited to representing variation in the # subspace given by the design matrix. 
weights = _resample_weights( design_matrix=design_matrix, target_residuals=(observed_time_series - previous_sample.level), observation_noise_scale=previous_sample.observation_noise_scale, weights_prior_scale=weights_param.prior.distribution.scale, is_missing=is_missing, seed=weights_seed) regression_residuals = observed_time_series - tf.linalg.matvec( design_matrix, weights) latents = _resample_latents( observed_residuals=regression_residuals, level_scale=previous_sample.level_scale, slope_scale=previous_sample.slope_scale if model_has_slope else None, observation_noise_scale=previous_sample.observation_noise_scale, initial_state_prior=level_component.initial_state_prior, is_missing=is_missing, seed=level_seed) level = latents[..., 0] level_residuals = level[..., 1:] - level[..., :-1] if model_has_slope: slope = latents[..., 1] level_residuals -= slope[..., :-1] slope_residuals = slope[..., 1:] - slope[..., :-1] # Estimate level scale from the empirical changes in level. level_scale = _resample_scale(prior=level_scale_variance_prior, observed_residuals=level_residuals, is_missing=None, seed=level_scale_seed) if model_has_slope: slope_scale = _resample_scale(prior=slope_scale_variance_prior, observed_residuals=slope_residuals, is_missing=None, seed=slope_scale_seed) # Estimate noise scale from the residuals. observation_noise_scale = _resample_scale( prior=observation_noise_variance_prior, observed_residuals=regression_residuals - level, is_missing=is_missing, seed=observation_noise_scale_seed) return GibbsSamplerState( observation_noise_scale=observation_noise_scale, level_scale=level_scale, slope_scale=(slope_scale if model_has_slope else previous_sample.slope_scale), weights=weights, level=level, slope=(slope if model_has_slope else previous_sample.slope), seed=loop_seed) return sampler_loop_body
def fit_with_gibbs_sampling(model,
                            observed_time_series,
                            num_chains=(),
                            num_results=2000,
                            num_warmup_steps=200,
                            initial_state=None,
                            seed=None):
  """Fits parameters for an STS model using Gibbs sampling.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance returned by
      `build_model_for_gibbs_fitting`.
    observed_time_series: `float` `Tensor` of shape `[..., T, 1]` (omitting the
      trailing unit dimension is also supported when `T > 1`), specifying an
      observed time series. May optionally be an instance of
      `tfp.sts.MaskedTimeSeries`, which includes a mask `Tensor` to specify
      timesteps with missing observations.
    num_chains: Optional int to indicate the number of parallel MCMC chains.
      Defaults to an empty tuple to sample a single chain.
    num_results: Optional int to indicate the number of MCMC samples.
    num_warmup_steps: Optional int to indicate the number of warmup MCMC
      samples, which are discarded before returning results.
    initial_state: A `GibbsSamplerState` structure of the initial states of the
      MCMC chains.
    seed: Optional `Python` `int` seed controlling the sampled values.
  Returns:
    posterior_samples: A `GibbsSamplerState` structure of posterior samples.
  """
  if not hasattr(model, 'supports_gibbs_sampling'):
    raise ValueError('This STS model does not support Gibbs sampling. Models '
                     'for Gibbs sampling must be created using the '
                     'method `build_model_for_gibbs_fitting`.')
  if not tf.nest.is_nested(num_chains):
    num_chains = [num_chains]

  [
      observed_time_series,
      is_missing
  ] = sts_util.canonicalize_observed_time_series_with_mask(
      observed_time_series)
  dtype = observed_time_series.dtype

  # The canonicalized time series always has trailing dimension `1`,
  # because although LinearGaussianSSMs support vector observations, STS models
  # describe scalar time series only. For our purposes it'll be cleaner to
  # remove this dimension.
  observed_time_series = observed_time_series[..., 0]

  batch_shape = prefer_static.concat(
      [num_chains, prefer_static.shape(observed_time_series)[:-1]], axis=-1)
  level_slope_shape = prefer_static.concat(
      [num_chains, prefer_static.shape(observed_time_series)], axis=-1)

  # Treat a LocalLevel model as the special case of LocalLinearTrend where
  # the slope_scale is always zero.
  initial_slope_scale = 0.
  initial_slope = 0.
  if isinstance(model.components[0], sts.LocalLinearTrend):
    initial_slope_scale = 1. * tf.ones(batch_shape, dtype=dtype)
    initial_slope = tf.zeros(level_slope_shape, dtype=dtype)

  if initial_state is None:
    initial_state = GibbsSamplerState(
        observation_noise_scale=tf.ones(batch_shape, dtype=dtype),
        level_scale=tf.ones(batch_shape, dtype=dtype),
        slope_scale=initial_slope_scale,
        weights=tf.zeros(prefer_static.concat([
            batch_shape,
            _get_design_matrix(model).shape[-1:]], axis=0), dtype=dtype),
        level=tf.zeros(level_slope_shape, dtype=dtype),
        slope=initial_slope,
        seed=None)  # Set below.

  if isinstance(seed, six.integer_types):
    tf.random.set_seed(seed)

  # Always use the passed-in `seed` arg, ignoring any seed in the initial
  # state.
  initial_state = initial_state._replace(
      seed=samplers.sanitize_seed(seed, salt='initial_GibbsSamplerState'))

  sampler_loop_body = _build_sampler_loop_body(
      model, observed_time_series, is_missing)

  samples = tf.scan(sampler_loop_body,
                    np.arange(num_warmup_steps + num_results),
                    initial_state)
  return tf.nest.map_structure(lambda x: x[num_warmup_steps:], samples)
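# A minimal, self-contained sketch of the `tf.scan` pattern that
# `fit_with_gibbs_sampling` uses to drive the Gibbs loop body: a state is
# threaded through `tf.scan` over a dummy iteration index, and warmup draws
# are dropped afterwards. The toy loop body below is purely illustrative (a
# deterministic update standing in for one Gibbs sweep), not the actual
# sampler.
import tensorflow as tf


def _toy_loop_body(previous_value, _):
  # One 'sweep': a stand-in for resampling all model variables.
  return 0.9 * previous_value + 1.

toy_num_warmup_steps, toy_num_results = 10, 20
toy_samples = tf.scan(_toy_loop_body,
                      tf.range(toy_num_warmup_steps + toy_num_results),
                      initializer=tf.constant(0.))
toy_posterior_samples = toy_samples[toy_num_warmup_steps:]  # [toy_num_results]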
def _build_sampler_loop_body(model, observed_time_series, is_missing=None): """Builds a Gibbs sampler for the given model and observed data. Args: model: A `tf.sts.StructuralTimeSeries` model instance. This must be of the form constructed by `build_model_for_gibbs_sampling`. observed_time_series: Float `Tensor` time series of shape `[..., num_timesteps]`. is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A `True` value indicates that the observation for that timestep is missing. Returns: sampler_loop_body: Python callable that performs a single cycle of Gibbs sampling. Its first argument is a `GibbsSamplerState`, and it returns a new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is ignored. """ level_component = model.components[0] if not (isinstance(level_component, sts.LocalLevel) or isinstance(level_component, sts.LocalLinearTrend)): raise ValueError('Expected the first model component to be an instance of ' '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; ' 'instead saw {}'.format(level_component)) model_has_slope = isinstance(level_component, sts.LocalLinearTrend) regression_component = model.components[1] if not (isinstance(regression_component, sts.LinearRegression) or isinstance(regression_component, SpikeAndSlabSparseLinearRegression)): raise ValueError('Expected the second model component to be an instance of ' '`tfp.sts.LinearRegression` or ' '`SpikeAndSlabSparseLinearRegression`; ' 'instead saw {}'.format(regression_component)) model_has_spike_slab_regression = isinstance( regression_component, SpikeAndSlabSparseLinearRegression) if is_missing is not None: # Ensure series does not contain NaNs. observed_time_series = tf.where(is_missing, tf.zeros_like(observed_time_series), observed_time_series) num_observed_steps = prefer_static.shape(observed_time_series)[-1] design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps] if is_missing is not None: # Replace design matrix with zeros at unobserved timesteps. This ensures # they will not affect the posterior on weights. design_matrix = tf.where(is_missing[..., tf.newaxis], tf.zeros_like(design_matrix), design_matrix) # Untransform scale priors -> variance priors by reaching thru Sqrt bijector. 
observation_noise_param = model.parameters[0] if 'observation_noise' not in observation_noise_param.name: raise ValueError('Model parameters {} do not match the expected sampler ' 'state.'.format(model.parameters)) observation_noise_variance_prior = observation_noise_param.prior.distribution if model_has_slope: level_scale_variance_prior, slope_scale_variance_prior = [ p.prior.distribution for p in level_component.parameters] else: level_scale_variance_prior = ( level_component.parameters[0].prior.distribution) if model_has_spike_slab_regression: spike_and_slab_sampler = spike_and_slab.SpikeSlabSampler( design_matrix, weights_prior_precision=regression_component._weights_prior_precision, # pylint: disable=protected-access nonzero_prior_prob=regression_component._sparse_weights_nonzero_prob, # pylint: disable=protected-access observation_noise_variance_prior_concentration=( observation_noise_variance_prior.concentration), observation_noise_variance_prior_scale=( observation_noise_variance_prior.scale), observation_noise_variance_upper_bound=( observation_noise_variance_prior.upper_bound if hasattr(observation_noise_variance_prior, 'upper_bound') else None)) else: weights_prior_scale = ( regression_component.parameters[0].prior.scale) def sampler_loop_body(previous_sample, _): """Runs one sampler iteration, resampling all model variables.""" (weights_seed, level_seed, observation_noise_scale_seed, level_scale_seed, loop_seed) = samplers.split_seed( previous_sample.seed, n=5, salt='sampler_loop_body') # Preserve backward-compatible seed behavior by splitting slope separately. slope_scale_seed, = samplers.split_seed( previous_sample.seed, n=1, salt='sampler_loop_body_slope') # We encourage a reasonable initialization by sampling the weights first, # so at the first step they are regressed directly against the observed # time series. If we instead sampled the level first it might 'explain away' # some observed variation that we would ultimately prefer to explain through # the regression weights, because the level can represent arbitrary # variation, while the weights are limited to representing variation in the # subspace given by the design matrix. if model_has_spike_slab_regression: (observation_noise_variance, weights) = spike_and_slab_sampler.sample_noise_variance_and_weights( initial_nonzeros=tf.not_equal(previous_sample.weights, 0.), targets=observed_time_series - previous_sample.level, seed=weights_seed) observation_noise_scale = tf.sqrt(observation_noise_variance) else: weights = _resample_weights( design_matrix=design_matrix, target_residuals=observed_time_series - previous_sample.level, observation_noise_scale=previous_sample.observation_noise_scale, weights_prior_scale=weights_prior_scale, seed=weights_seed) # Noise scale will be resampled below. observation_noise_scale = previous_sample.observation_noise_scale regression_residuals = observed_time_series - tf.linalg.matvec( design_matrix, weights) latents = _resample_latents( observed_residuals=regression_residuals, level_scale=previous_sample.level_scale, slope_scale=previous_sample.slope_scale if model_has_slope else None, observation_noise_scale=observation_noise_scale, initial_state_prior=level_component.initial_state_prior, is_missing=is_missing, seed=level_seed) level = latents[..., 0] level_residuals = level[..., 1:] - level[..., :-1] if model_has_slope: slope = latents[..., 1] level_residuals -= slope[..., :-1] slope_residuals = slope[..., 1:] - slope[..., :-1] # Estimate level scale from the empirical changes in level. 
level_scale = _resample_scale( prior=level_scale_variance_prior, observed_residuals=level_residuals, is_missing=None, seed=level_scale_seed) if model_has_slope: slope_scale = _resample_scale( prior=slope_scale_variance_prior, observed_residuals=slope_residuals, is_missing=None, seed=slope_scale_seed) if not model_has_spike_slab_regression: # Estimate noise scale from the residuals. observation_noise_scale = _resample_scale( prior=observation_noise_variance_prior, observed_residuals=regression_residuals - level, is_missing=is_missing, seed=observation_noise_scale_seed) return GibbsSamplerState( observation_noise_scale=observation_noise_scale, level_scale=level_scale, slope_scale=(slope_scale if model_has_slope else previous_sample.slope_scale), weights=weights, level=level, slope=(slope if model_has_slope else previous_sample.slope), seed=loop_seed) return sampler_loop_body
def random_gamma_rejection(sample_shape, alpha, beta, internal_dtype=tf.float64, seed=None): """Samples from the gamma distribution. The sampling algorithm is rejection sampling [1], and pathwise gradients with respect to alpha are computed via implicit differentiation [2]. Args: sample_shape: The output sample shape. Must broadcast with both `alpha` and `beta`. alpha: Floating point tensor, the alpha params of the distribution(s). Must contain only positive values. Must broadcast with `beta`. beta: Floating point tensor, the inverse scale params of the distribution(s). Must contain only positive values. Must broadcast with `alpha`. internal_dtype: dtype to use for internal computations. seed: (optional) The random seed. Returns: Differentiable samples from the gamma distribution. #### References [1] George Marsaglia and Wai Wan Tsang. A simple method for generating Gamma variables. ACM Transactions on Mathematical Software, 2000. [2] Michael Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit Reparameterization Gradients. Neural Information Processing Systems, 2018. """ generate_and_test_samples_seed, alpha_fix_seed = samplers.split_seed( seed, salt='random_gamma') output_dtype = dtype_util.common_dtype([alpha, beta], dtype_hint=tf.float32) def rejection_sample(alpha): """Gamma rejection sampler.""" # Note that alpha here already has a shape that is broadcast with beta. cast_alpha = tf.cast(alpha, internal_dtype) good_params_mask = (alpha > 0.) # When replacing NaN values, use 100. for alpha, since that leads to a # high-likelihood of the rejection sampler accepting on the first pass. safe_alpha = tf.where(good_params_mask, cast_alpha, 100.) modified_safe_alpha = tf.where(safe_alpha < 1., safe_alpha + 1., safe_alpha) one_third = tf.constant(1. / 3, dtype=internal_dtype) d = modified_safe_alpha - one_third c = one_third / tf.sqrt(d) def generate_and_test_samples(seed): """Generate and test samples.""" v_seed, u_seed = samplers.split_seed(seed) def generate_positive_v(): """Generate positive v.""" def _inner(seed): x = samplers.normal(sample_shape, dtype=internal_dtype, seed=seed) # This implicitly broadcasts alpha up to sample shape. v = 1 + c * x return (x, v), v > 0. # Note: It should be possible to remove this 'inner' call to # `batched_las_vegas_algorithm` and merge the v > 0 check into the # overall check for a good sample. This would lead to a slightly simpler # implementation; it is unclear whether it would be faster. We include # the inner loop so this implementation is more easily comparable to # Ref. [1] and other implementations. return brs.batched_las_vegas_algorithm(_inner, v_seed)[0] (x, v) = generate_positive_v() x2 = x * x v3 = v * v * v u = samplers.uniform(sample_shape, dtype=internal_dtype, seed=u_seed) # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to # run the check below only if it fails, in order to avoid the relatively # expensive logarithm calls. Our algorithm operates in batch mode: we will # have to compute or not compute the logarithms for the entire batch, and # as the batch gets larger, the odds we compute it grow. Therefore we # don't bother with the "cheap" check. good_sample_mask = (tf.math.log(u) < x2 / 2. 
+ d * (1 - v3 + tf.math.log(v3))) return v3, good_sample_mask samples = brs.batched_las_vegas_algorithm( generate_and_test_samples, seed=generate_and_test_samples_seed)[0] samples = samples * d one = tf.constant(1., dtype=internal_dtype) alpha_lt_one_fix = tf.where( safe_alpha < 1., tf.math.pow( samplers.uniform(sample_shape, dtype=internal_dtype, seed=alpha_fix_seed), one / safe_alpha), one) samples = samples * alpha_lt_one_fix samples = tf.where(good_params_mask, samples, np.nan) output_type_samples = tf.cast(samples, output_dtype) # We use `tf.where` instead of `tf.maximum` because we need to allow for # `samples` to be `nan`, but `tf.maximum(nan, x) == x`. output_type_samples = tf.where( output_type_samples == 0, np.finfo(dtype_util.as_numpy_dtype( output_type_samples.dtype)).tiny, output_type_samples) return output_type_samples broadcast_alpha_shape = ps.broadcast_shape(ps.shape(alpha), ps.shape(beta)) broadcast_alpha = tf.broadcast_to(alpha, broadcast_alpha_shape) alpha_samples = rejection_sample(broadcast_alpha) corrected_beta = tf.where(beta > 0., beta, np.nan) return alpha_samples / corrected_beta
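# Hedged usage sketch for the sampler above: draw gamma variates and compare
# the sample mean to the analytic mean `alpha / beta`. The two-integer
# stateless seed format is an assumption carried over from the `samplers`
# utilities used in this file.
import tensorflow as tf

alpha = tf.constant([0.5, 2., 10.])
beta = tf.constant(3.)
gamma_draws = random_gamma_rejection(
    sample_shape=[100000, 3], alpha=alpha, beta=beta, seed=[4, 2])
tf.reduce_mean(gamma_draws, axis=0)  # Approximately alpha / beta.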
def auto_correlation(x, axis=-1, max_lags=None, center=True, normalize=True, name='auto_correlation'): """Auto correlation along one axis. Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate) ``` RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) }, W[n] := (X[n] - MU) / S, MU := E{ X[0] }, S**2 := E{ (X[0] - MU) Conj(X[0] - MU) }. ``` This function takes the viewpoint that `x` is (along one axis) a finite sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an estimate of `RXX[m]` as follows: After extending `x` from length `L` to `inf` by zero padding, the auto correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as ``` rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]), w[n] := (x[n] - mu) / s, mu := L**-1 sum_n x[n], s**2 := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu) ``` The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users often set `max_lags` small enough so that the entire output is meaningful. Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation contains a slight bias, which goes to zero as `len(x) - m --> infinity`. Args: x: `float32` or `complex64` `Tensor`. axis: Python `int`. The axis number along which to compute correlation. Other dimensions index different batch members. max_lags: Positive `int` tensor. The maximum value of `m` to consider (in equation above). If `max_lags >= x.shape[axis]`, we effectively re-set `max_lags` to `x.shape[axis] - 1`. center: Python `bool`. If `False`, do not subtract the mean estimate `mu` from `x[n]` when forming `w[n]`. normalize: Python `bool`. If `False`, do not divide by the variance estimate `s**2` when forming `w[n]`. name: `String` name to prepend to created ops. Returns: `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for `i != axis`, and `rxx.shape[axis] = max_lags + 1`. Raises: TypeError: If `x` is not a supported type. """ # Implementation details: # Extend length N / 2 1-D array x to length N by zero padding onto the end. # Then, set # F[x]_k := sum_n x_n exp{-i 2 pi k n / N }. # It is not hard to see that # F[x]_k Conj(F[x]_k) = F[R]_k, where # R_m := sum_n x_n Conj(x_{(n - m) mod N}). # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT # based version of estimating RXX. # Note that this is a special case of the Wiener-Khinchin Theorem. with tf.name_scope(name): x = tf.convert_to_tensor(x, name='x') # Rotate dimensions of x in order to put axis at the rightmost dim. # FFT op requires this. rank = ps.rank(x) if axis < 0: axis = rank + axis shift = rank - 1 - axis # Suppose x.shape[axis] = T, so there are T 'time' steps. # ==> x_rotated.shape = B + [T], # where B is x_rotated's batch shape. x_rotated = distribution_util.rotate_transpose(x, shift) if center: x_rotated = x_rotated - tf.reduce_mean( x_rotated, axis=-1, keepdims=True) # x_len = N / 2 from above explanation. The length of x along axis. # Get a value for x_len that works in all cases. x_len = ps.shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). 
x_len_float64 = ps.cast(x_len, np.float64) target_length = ps.pow(np.float64(2.), ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.))) pad_length = ps.cast(target_length - x_len_float64, np.int32) # We should have: # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length] # = B + [T + pad_length] x_rotated_pad = distribution_util.pad(x_rotated, axis=-1, back=True, count=pad_length) dtype = x.dtype if not dtype_util.is_complex(dtype): if not dtype_util.is_floating(dtype): raise TypeError( 'Argument x must have either float or complex dtype' ' found: {}'.format(dtype)) x_rotated_pad = tf.complex( x_rotated_pad, dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.)) # Autocorrelation is IFFT of power-spectral density (up to some scaling). fft_x_rotated_pad = tf.signal.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). shifted_product = tf.signal.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. shifted_product = tf.cast(shifted_product, dtype) # Figure out if we can deduce the final static shape, and set max_lags. # Use x_rotated as a reference, because it has the time dimension in the far # right, and was created before we performed all sorts of crazy shape # manipulations. know_static_shape = True if not tensorshape_util.is_fully_defined(x_rotated.shape): know_static_shape = False if max_lags is None: max_lags = x_len - 1 else: max_lags = tf.convert_to_tensor(max_lags, name='max_lags') max_lags_ = tf.get_static_value(max_lags) if max_lags_ is None or not know_static_shape: know_static_shape = False max_lags = tf.minimum(x_len - 1, max_lags) else: max_lags = min(x_len - 1, max_lags_) # Chop off the padding. # We allow users to provide a huge max_lags, but cut it off here. # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags] shifted_product_chopped = shifted_product[..., :max_lags + 1] # If possible, set shape. if know_static_shape: chopped_shape = tensorshape_util.as_list(x_rotated.shape) chopped_shape[-1] = min(x_len, max_lags + 1) tensorshape_util.set_shape(shifted_product_chopped, chopped_shape) # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = ps.cast(x_len, dtype_util.real_dtype(dtype)) max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype)) denominator = x_len - ps.range(0., max_lags + 1.) denominator = ps.cast(denominator, dtype) shifted_product_rotated = shifted_product_chopped / denominator if normalize: shifted_product_rotated /= shifted_product_rotated[..., :1] # Transpose dimensions back to those of x. return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
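# Illustrative usage, assuming this function is exposed as
# `tfp.stats.auto_correlation`: for white noise the normalized autocorrelation
# is exactly 1 at lag 0 and near 0 at higher lags.
import tensorflow as tf
import tensorflow_probability as tfp

noise = tf.random.stateless_normal([1000], seed=[1, 2])
rxx = tfp.stats.auto_correlation(noise, axis=-1, max_lags=5)
# rxx.shape == [6]; rxx[0] == 1. (by normalization), |rxx[1:]| is small.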
def __init__(self, num_timesteps, design_matrix, drift_scale, initial_state_prior, observation_noise_scale=0., initial_step=0, validate_args=False, allow_nan_stats=True, name=None): """State space model for a dynamic linear regression. Args: num_timesteps: Scalar `int` `Tensor` number of timesteps to model with this distribution. design_matrix: float `Tensor` of shape `concat([batch_shape, [num_timesteps, num_features]])`. drift_scale: Scalar (any additional dimensions are treated as batch dimensions) `float` `Tensor` indicating the standard deviation of the latent state transitions. initial_state_prior: instance of `tfd.MultivariateNormal` representing the prior distribution on latent states. Must have event shape `[num_features]`. observation_noise_scale: Scalar (any additional dimensions are treated as batch dimensions) `float` `Tensor` indicating the standard deviation of the observation noise. Default value: `0.`. initial_step: scalar `int` `Tensor` specifying the starting timestep. Default value: `0`. validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. Default value: `False`. allow_nan_stats: Python `bool`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. Default value: `True`. name: Python `str` name prefixed to ops created by this class. Default value: 'DynamicLinearRegressionStateSpaceModel'. """ with tf.name_scope( name or 'DynamicLinearRegressionStateSpaceModel') as name: dtype = dtype_util.common_dtype( [design_matrix, drift_scale, initial_state_prior]) design_matrix = tf.convert_to_tensor( value=design_matrix, name='design_matrix', dtype=dtype) design_matrix_with_time_in_first_dim = distribution_util.move_dimension( design_matrix, -2, 0) drift_scale = tf.convert_to_tensor( value=drift_scale, name='drift_scale', dtype=dtype) observation_noise_scale = tf.convert_to_tensor( value=observation_noise_scale, name='observation_noise_scale', dtype=dtype) num_features = prefer_static.shape(design_matrix)[-1] def observation_matrix_fn(t): observation_matrix = tf.linalg.LinearOperatorFullMatrix( tf.gather(design_matrix_with_time_in_first_dim, t)[..., tf.newaxis, :], name='observation_matrix') return observation_matrix self._drift_scale = drift_scale self._observation_noise_scale = observation_noise_scale super(DynamicLinearRegressionStateSpaceModel, self).__init__( num_timesteps=num_timesteps, transition_matrix=tf.linalg.LinearOperatorIdentity( num_rows=num_features, dtype=dtype, name='transition_matrix'), transition_noise=tfd.MultivariateNormalDiag( scale_diag=(drift_scale[..., tf.newaxis] * tf.ones([num_features], dtype=dtype)), name='transition_noise'), observation_matrix=observation_matrix_fn, observation_noise=tfd.MultivariateNormalDiag( scale_diag=observation_noise_scale[..., tf.newaxis], name='observation_noise'), initial_state_prior=initial_state_prior, initial_step=initial_step, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name)
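# Construction sketch, assuming this class is exposed as
# `tfp.sts.DynamicLinearRegressionStateSpaceModel` and that stateless
# two-integer seeds are accepted by `sample`.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

design_matrix = tf.random.stateless_normal([30, 2], seed=[0, 1])
ssm = tfp.sts.DynamicLinearRegressionStateSpaceModel(
    num_timesteps=30,
    design_matrix=design_matrix,
    drift_scale=0.1,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=tf.ones([2])),
    observation_noise_scale=0.5)
series = ssm.sample(seed=[2, 3])  # Shape [30, 1]: one scalar per timestep.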
def _batch_shape_tensor(self):
  return tf.broadcast_dynamic_shape(
      [] if self.amplitude is None else ps.shape(self.amplitude),
      [] if self.length_scale is None else ps.shape(self.length_scale))
def _distributional_transform(self, x, event_shape): """Performs distributional transform of the mixture samples. Distributional transform removes the parameters from samples of a multivariate distribution by applying conditional CDFs: (F(x_1), F(x_2 | x1_), ..., F(x_d | x_1, ..., x_d-1)) (the indexing is over the 'flattened' event dimensions). The result is a sample of product of Uniform[0, 1] distributions. We assume that the components are factorized, so the conditional CDFs become F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i), where w_i^k is the posterior mixture weight: for i > 0 w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1) and w_0^k = w_k is the mixture probability of the k-th component. Args: x: Sample of mixture distribution event_shape: The event shape of this distribution Returns: Result of the distributional transform """ if tensorshape_util.rank(x.shape) is None: # tf.math.softmax raises an error when applied to inputs of undefined # rank. raise ValueError( 'Distributional transform does not support inputs of ' 'undefined rank.') # Obtain factorized components distribution and assert that it's # a scalar distribution. if isinstance(self._components_distribution, independent.Independent): univariate_components = self._components_distribution.distribution else: univariate_components = self._components_distribution with tf.control_dependencies([ assert_util.assert_equal( univariate_components.is_scalar_event(), True, message='`univariate_components` must have scalar event') ]): event_ndims = ps.rank_from_shape(event_shape) x_padded = self._pad_sample_dims( x, event_ndims=event_ndims) # [S, B, 1, E] log_prob_x = univariate_components.log_prob( x_padded) # [S, B, k, E] cdf_x = univariate_components.cdf(x_padded) # [S, B, k, E] # log prob_k (x_1, ..., x_i-1) event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32) cumsum_log_prob_x = tf.reshape( tf.math.cumsum( # [S*prod(B)*k, prod(E)] tf.reshape(log_prob_x, [-1, event_size]), exclusive=True, axis=-1), ps.shape(log_prob_x)) # [S, B, k, E] event_ndims = ps.rank_from_shape(event_shape) logits_mix_prob = self.mixture_distribution.logits_parameter() logits_mix_prob = tf.reshape( logits_mix_prob, # [k] or [B, k] ps.concat([ ps.shape(logits_mix_prob), ps.ones([event_ndims], dtype=tf.int32), ], axis=0)) # [k, [1]*e] or [B, k, [1]*e] # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1) log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x component_axis = tensorshape_util.rank(x.shape) - event_ndims posterior_weights_x = tf.math.softmax(log_posterior_weights_x, axis=component_axis) return tf.reduce_sum(posterior_weights_x * cdf_x, axis=component_axis)
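# The first coordinate of the transform is just the marginal mixture CDF,
# F(x_1) = sum_k w_k F_k(x_1); later coordinates reweight the components by
# their posterior probability given the preceding coordinates. A small check
# of the first-coordinate formula (illustrative only; uses public
# distributions rather than the private method above):
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mixture_weights = tf.constant([0.3, 0.7])
mixture_components = tfd.Normal(loc=[-1., 2.], scale=[1., 0.5])
x1 = tf.constant(0.)
first_coordinate = tf.reduce_sum(mixture_weights * mixture_components.cdf(x1))
# Equals the mixture CDF of x1 under Categorical(probs=mixture_weights) mixing.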
def _batch_shape_tensor(self):
  return ps.shape(self.concentration)
def _reparameterize_sample(self, x, event_shape): """Adds reparameterization (pathwise) gradients to samples of the mixture. Implicit reparameterization gradients are dx/dphi = -(d transform(x, phi) / dx)^-1 * d transform(x, phi) / dphi, where transform(x, phi) is distributional transform that removes all parameters from samples x. We implement them by replacing x with -stop_gradient(d transform(x, phi) / dx)^-1 * transform(x, phi)] for the backward pass (gradient computation). The derivative of this quantity w.r.t. phi is then the implicit reparameterization gradient. Note that this replaces the gradients w.r.t. both the mixture distribution parameters and components distributions parameters. Limitations: 1. Fundamental: components must be fully reparameterized. 2. Distributional transform is currently only implemented for factorized components. 3. Distributional transform currently only works for known rank of the batch tensor. Args: x: Sample of mixture distribution event_shape: The event shape of this distribution Returns: Tensor with same value as x, but with reparameterization gradients """ # Remove the existing gradients of x wrt parameters of the components. x = tf.stop_gradient(x) event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32) x_2d_shape = [-1, event_size] # [S*prod(B), prod(E)] # Perform distributional transform of x in [S, B, E] shape, # but have Jacobian of size [S*prod(B), prod(E), prod(E)]. def reshaped_distributional_transform(x_2d): return tf.reshape( self._distributional_transform(tf.reshape(x_2d, ps.shape(x)), event_shape), x_2d_shape) # transform_2d: [S*prod(B), prod(E)] # jacobian: [S*prod(B), prod(E), prod(E)] x_2d = tf.reshape(x, x_2d_shape) transform_2d, jacobian = value_and_batch_jacobian( reshaped_distributional_transform, x_2d) # We only provide the first derivative; the second derivative computed by # autodiff would be incorrect, so we raise an error if it is requested. transform_2d = _prevent_2nd_derivative(transform_2d) # Compute [- stop_gradient(jacobian)^-1 * transform] by solving a linear # system. The Jacobian is lower triangular because the distributional # transform for i-th event dimension does not depend on the next # dimensions. surrogate_x_2d = -tf.linalg.triangular_solve( tf.stop_gradient(jacobian), transform_2d[..., tf.newaxis], lower=True) # [S*prod(B), prod(E), 1] surrogate_x = tf.reshape(surrogate_x_2d, ps.shape(x)) # Replace gradients of x with gradients of surrogate_x, but keep the value. return x + (surrogate_x - tf.stop_gradient(surrogate_x))
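# The key trick above is returning `x + (surrogate_x - stop_gradient(surrogate_x))`:
# the value is exactly `x`, while gradients flow through `surrogate_x`. A
# scalar toy illustration of that pattern (not the mixture machinery itself):
import tensorflow as tf

theta = tf.Variable(2.0)
with tf.GradientTape() as tape:
  x = tf.stop_gradient(3.0 * theta)                   # Value 6.0, no gradient.
  surrogate = 5.0 * theta                             # Carries the desired gradient.
  y = x + (surrogate - tf.stop_gradient(surrogate))   # Value 6.0.
tape.gradient(y, theta)                               # ==> 5.0, from the surrogate.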
def reshaped_distributional_transform(x_2d):
  return tf.reshape(
      self._distributional_transform(tf.reshape(x_2d, ps.shape(x)),
                                     event_shape),
      x_2d_shape)
def _sample_n(self, n, seed): components_seed, mix_seed = samplers.split_seed( seed, salt='MixtureSameFamily') try: seed_stream = SeedStream(seed, salt='MixtureSameFamily') except TypeError as e: # Can happen for Tensor seeds. seed_stream = None seed_stream_err = e try: x = self.components_distribution.sample( # [n, B, k, E] n, seed=components_seed) if seed_stream is not None: seed_stream() # Advance even if unused. except TypeError as e: if ('Expected int for argument' not in str(e) and TENSOR_SEED_MSG_PREFIX not in str(e)): raise if seed_stream is None: raise seed_stream_err msg = ( 'Falling back to stateful sampling for `components_distribution` ' '{} of type `{}`. Please update to use `tf.random.stateless_*` ' 'RNGs. This fallback may be removed after 20-Aug-2020. {}') warnings.warn( msg.format(self.components_distribution.name, type(self.components_distribution), str(e))) x = self.components_distribution.sample( # [n, B, k, E] n, seed=seed_stream()) event_shape = None event_ndims = tensorshape_util.rank(self.event_shape) if event_ndims is None: event_shape = self.components_distribution.event_shape_tensor() event_ndims = ps.rank_from_shape(event_shape) event_ndims_static = tf.get_static_value(event_ndims) num_components = None if event_ndims_static is not None: num_components = tf.compat.dimension_value( x.shape[-1 - event_ndims_static]) # We could also check if num_components can be computed statically from # self.mixture_distribution's logits or probs. if num_components is None: num_components = tf.shape(x)[-1 - event_ndims] # TODO(jvdillon): Consider using tf.gather (by way of index unrolling). npdt = dtype_util.as_numpy_dtype(x.dtype) try: mix_sample = self.mixture_distribution.sample( n, seed=mix_seed) # [n, B] or [n] except TypeError as e: if ('Expected int for argument' not in str(e) and TENSOR_SEED_MSG_PREFIX not in str(e)): raise if seed_stream is None: raise seed_stream_err msg = ( 'Falling back to stateful sampling for `mixture_distribution` ' '{} of type `{}`. Please update to use `tf.random.stateless_*` ' 'RNGs. This fallback may be removed after 20-Aug-2020. ({})') warnings.warn( msg.format(self.mixture_distribution.name, type(self.mixture_distribution), str(e))) mix_sample = self.mixture_distribution.sample( n, seed=seed_stream()) # [n, B] or [n] mask = tf.one_hot( indices=mix_sample, # [n, B] or [n] depth=num_components, on_value=npdt(1), off_value=npdt(0)) # [n, B, k] or [n, k] # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] . batch_ndims = ps.rank(x) - event_ndims - 1 mask_batch_ndims = ps.rank(mask) - 1 pad_ndims = batch_ndims - mask_batch_ndims mask_shape = ps.shape(mask) target_shape = ps.concat([ mask_shape[:-1], ps.ones([pad_ndims], dtype=tf.int32), mask_shape[-1:], ps.ones([event_ndims], dtype=tf.int32), ], axis=0) mask = tf.reshape(mask, shape=target_shape) if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype): masked = tf.math.multiply_no_nan(x, mask) else: masked = x * mask ret = tf.reduce_sum(masked, axis=-1 - event_ndims) # [n, B, E] if self._reparameterize: if event_shape is None: event_shape = self.components_distribution.event_shape_tensor() ret = self._reparameterize_sample(ret, event_shape=event_shape) return ret
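# Sketch of the selection step above: draw from every component, then pick one
# component per sample with a one-hot mask and a sum over the components axis
# (toy shapes, scalar events).
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

n, k = 4, 3
component_draws = tfd.Normal(loc=[-5., 0., 5.], scale=1.).sample(
    n, seed=[0, 1])                                                 # [n, k]
mix_draws = tfd.Categorical(probs=[0.2, 0.3, 0.5]).sample(
    n, seed=[2, 3])                                                 # [n]
mask = tf.one_hot(mix_draws, depth=k, dtype=component_draws.dtype)  # [n, k]
selected = tf.reduce_sum(component_draws * mask, axis=-1)           # [n]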
def _sample_n(self, n, seed=None): power = tf.convert_to_tensor(self.power) shape = ps.concat([[n], ps.shape(power)], axis=0) numpy_dtype = dtype_util.as_numpy_dtype(power.dtype) seed = samplers.sanitize_seed(seed, salt='zipf') # Because `_hat_integral` is montonically decreasing, the bounds for u will # switch. # Compute the hat_integral explicitly here since we can calculate the log of # the inputs statically in float64 with numpy. maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) - tf.math.log(power - 1.)) + 1. minval_u = tf.math.exp( -(power - 1.) * numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) - tf.math.log(power - 1.)) def loop_body(should_continue, k, seed): """Resample the non-accepted points.""" u_seed, next_seed = samplers.split_seed(seed) # Uniform variates must be sampled from the open-interval `(0, 1)` rather # than `[0, 1)`. To do so, we use # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` # because it is the smallest, positive, 'normal' number. A 'normal' number # is such that the mantissa has an implicit leading 1. Normal, positive # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In # this case, a subnormal number (i.e., np.nextafter) can cause us to # sample 0. u = samplers.uniform( shape, minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny, maxval=numpy_dtype(1.), dtype=power.dtype, seed=u_seed) # We use (1 - u) * maxval_u + u * minval_u rather than the other way # around, since we want to draw samples in (minval_u, maxval_u]. u = maxval_u + (minval_u - maxval_u) * u # set_shape needed here because of b/139013403 tensorshape_util.set_shape(u, should_continue.shape) # Sample the point X from the continuous density h(x) \propto x^(-power). x = self._hat_integral_inverse(u, power=power) # Rejection-inversion requires a `hat` function, h(x) such that # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the # support. A natural hat function for us is h(x) = x^(-power). # # After sampling X from h(x), suppose it lies in the interval # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if # if lies to the left of x_K, where x_K is defined by: # \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1), # where H(x) = \int_x^inf h(x) dx. # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)). # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)). # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1). # Update the non-accepted points. # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5). k = tf.where(should_continue, tf.floor(x + 0.5), k) accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(self._log_prob(k + 1, power=power))) return [should_continue & (~accept), k, next_seed] should_continue, samples, _ = tf.while_loop( cond=lambda should_continue, *ignore: tf.reduce_any(should_continue ), body=loop_body, loop_vars=[ tf.ones(shape, dtype=tf.bool), # should_continue tf.zeros(shape, dtype=power.dtype), # k seed, # seed ], maximum_iterations=self.sample_maximum_iterations, ) samples = samples + 1. 
if self.validate_args and dtype_util.is_integer(self.dtype): samples = distribution_util.embed_check_integer_casting_closed( samples, target_dtype=self.dtype, assert_positive=True) samples = tf.cast(samples, self.dtype) if self.validate_args: npdt = dtype_util.as_numpy_dtype(self.dtype) v = npdt( dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan ) samples = tf.where(should_continue, v, samples) return samples
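# Usage sketch, assuming this rejection-inversion sampler backs
# `tfp.distributions.Zipf` and that stateless two-integer seeds are accepted:
import tensorflow_probability as tfp
tfd = tfp.distributions

zipf = tfd.Zipf(power=2.5)
zipf_draws = zipf.sample(1000, seed=[0, 1])  # Positive integers, heavy-tailed.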
def _transpose_around_bijector_fn(self, bijector_fn, arg, src_event_ndims, dest_event_ndims=None, fn_reduces_event=False, **kwargs): # This function moves the axes corresponding to `self.sample_shape` to the # left of the batch shape, then applies `bijector_fn`, then moves the axes # corresponding to `self.sample_shape` back to the event part of the shape. # # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank # (omitting `self.sample_shape`) before and after applying `bijector_fn`. # # This function arose because forward and inverse ended up being quite # similar. It was then only a small generalization to also support {F/I}LDJ. batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor, self.distribution.batch_shape) extra_sample_ndims = ps.rank_from_shape(self.sample_shape) arg_ndims = ps.rank(arg) # (1) Expand arg's dims. d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims arg = tf.reshape(arg, shape=ps.pad(ps.shape(arg), paddings=[[ps.maximum(0, -d), 0]], constant_values=1)) arg_ndims = ps.rank(arg) sample_ndims = ps.maximum(0, d) # (2) Transpose arg's dims. sample_dims = ps.range(0, sample_ndims) batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims) extra_sample_dims = ps.range( sample_ndims + batch_ndims, sample_ndims + batch_ndims + extra_sample_ndims) event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims, arg_ndims) perm = ps.concat( [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0) arg = tf.transpose(arg, perm=perm) # (3) Apply underlying bijector. result = bijector_fn(arg, **kwargs) # (4) Transpose sample_shape from the sample to the event shape. result_ndims = ps.rank(result) if fn_reduces_event: dest_event_ndims = 0 d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims if fn_reduces_event: # In some cases, fn may reduce event too far, i.e. ildj may return a # scalar `0.`, which won't work with the transpose we do below. result = tf.reshape(result, shape=ps.pad(ps.shape(result), paddings=[[ps.maximum(0, -d), 0]], constant_values=1)) result_ndims = ps.rank(result) sample_ndims = ps.maximum(0, d) sample_dims = ps.range(0, sample_ndims) extra_sample_dims = ps.range(sample_ndims, sample_ndims + extra_sample_ndims) batch_dims = ps.range(sample_ndims + extra_sample_ndims, sample_ndims + extra_sample_ndims + batch_ndims) event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims, result_ndims) perm = ps.concat( [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0) return tf.transpose(result, perm=perm)
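# Shape illustration of step (2) above: the `sample_shape` axes (here [5, 2])
# are moved to sit after the leading sample dims and before the batch dims,
# prior to applying the underlying bijector. Toy shapes; the real code builds
# `perm` with `prefer_static`.
import tensorflow as tf

arg = tf.zeros([7, 4, 5, 2, 3])  # [sample, batch, extra_sample, extra_sample, event]
perm = [0, 2, 3, 1, 4]           # sample, extra_sample dims, batch, event
tf.transpose(arg, perm).shape    # ==> [7, 5, 2, 4, 3]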
def _kahan_reduce_bwd(aux, grads):
  operands, inits, axis, unsqueezed_shape = aux
  del inits, axis  # unused
  return (tf.broadcast_to(tf.reshape(grads[0], unsqueezed_shape),
                          ps.shape(operands[0])),
          None)
def _get_reduction_axes(x, nd):
  """Enumerates the final `nd` axis indices of `x`."""
  x_rank = prefer_static.rank_from_shape(prefer_static.shape(x))
  return prefer_static.range(x_rank - 1, x_rank - nd - 1, -1)
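# Quick illustration (hypothetical shapes): the final `nd` axes are returned
# in descending order.
import tensorflow as tf

reduction_axes = _get_reduction_axes(tf.zeros([2, 3, 4, 5]), nd=2)  # ==> [3, 2]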
def _batch_shape_tensor(self):
  return prefer_static.shape(self.concentration)
def convolution_batch(x, kernel, rank, strides, padding, data_format=None, dilations=None, name=None): """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" if rank != 2: raise NotImplementedError( 'Argument `rank` currently only supports `2`; ' 'saw "{}".'.format(rank)) if data_format is not None and data_format.upper() != 'NHWBC': raise ValueError( 'Argument `data_format` currently only supports "NHWBC"; ' 'saw "{}".'.format(data_format)) with tf.name_scope(name or 'conv2d_nhwbc'): # Prepare arguments. [ rank, _, # strides padding, dilations, data_format, ] = prepare_conv_args(rank, strides, padding, dilations) strides = prepare_tuple_argument(strides, rank + 2, arg_name='strides') dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) x = tf.convert_to_tensor(x, dtype=dtype, name='x') kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel') # Step 1: Transpose and double flatten kernel. # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c'] kernel_shape = prefer_static.shape(kernel) kernel_batch_shape, kernel_event_shape = prefer_static.split( kernel_shape, num_or_size_splits=[-1, rank + 2]) kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape) kernel_ndims = prefer_static.rank(kernel) kernel_batch_ndims = kernel_ndims - rank - 2 perm = prefer_static.concat([ prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank), prefer_static.range(0, kernel_batch_ndims), prefer_static.range(kernel_batch_ndims + rank, kernel_ndims), ], axis=0) # Eg, [1, 2, 0, 3, 4] kernel = tf.transpose(kernel, perm=perm) # F + B + [c, c'] kernel = tf.reshape(kernel, shape=prefer_static.concat([ kernel_event_shape[:rank], [ kernel_batch_size * kernel_event_shape[-2], kernel_event_shape[-1] ], ], axis=0)) # F + [bc, c'] # Step 2: Double flatten x. # x.shape = N + D + B + [c] x_shape = prefer_static.shape(x) [ x_sample_shape, x_rank_shape, x_batch_shape, x_channel_shape, ] = prefer_static.split( x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1]) x = tf.reshape( x, # N + D + B + [c] shape=prefer_static.concat([ [prefer_static.reduce_prod(x_sample_shape)], x_rank_shape, [ prefer_static.reduce_prod(x_batch_shape) * prefer_static.reduce_prod(x_channel_shape) ], ], axis=0)) # [n] + D + [bc] # Step 3: Apply convolution. y = tf.nn.depthwise_conv2d(x, kernel, strides=strides, padding=padding, data_format='NHWC', dilations=dilations) # SAME: y.shape = [n, h, w, bcc'] # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc'] # Step 4: Reshape/reduce for output. y_shape = prefer_static.shape(y) y = tf.reshape(y, shape=prefer_static.concat( [ x_sample_shape, y_shape[1:-1], kernel_batch_shape, kernel_event_shape[-2:], ], axis=0)) # N + D' + B + [c, c'] y = tf.reduce_sum(y, axis=-2) # N + D' + B + [c'] return y
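# Shape sketch for the 'NHWBC' layout handled above: `x` carries a kernel-batch
# axis `b` just before channels, and `kernel` carries matching leading batch
# dims. This assumes the helper functions referenced above (`prepare_conv_args`
# and friends) accept these arguments as shown.
import tensorflow as tf

x = tf.random.stateless_normal([4, 8, 8, 2, 3], seed=[0, 1])       # [n, h, w, b, c]
kernel = tf.random.stateless_normal([2, 3, 3, 3, 5], seed=[2, 3])  # [b, fh, fw, c, c']
y = convolution_batch(x, kernel, rank=2, strides=1, padding='SAME',
                      data_format='NHWBC')
# y.shape ==> [4, 8, 8, 2, 5], i.e. [n, h, w, b, c'] under 'SAME' padding.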
def __init__(self, design_matrix, drift_scale_prior=None, initial_weights_prior=None, observed_time_series=None, name=None): """Specify a dynamic linear regression. Args: design_matrix: float `Tensor` of shape `concat([batch_shape, [num_timesteps, num_features]])`. drift_scale_prior: instance of `tfd.Distribution` specifying a prior on the `drift_scale` parameter. If `None`, a heuristic default prior is constructed based on the provided `observed_time_series`. Default value: `None`. initial_weights_prior: instance of `tfd.MultivariateNormal` representing the prior distribution on the latent states (the regression weights). Must have event shape `[num_features]`. If `None`, a weakly-informative Normal(0., 10.) prior is used. Default value: `None`. observed_time_series: `float` `Tensor` of shape `batch_shape + [T, 1]` (omitting the trailing unit dimension is also supported when `T > 1`), specifying an observed time series. Any priors not explicitly set will be given default values according to the scale of the observed time series (or batch of time series). May optionally be an instance of `tfp.sts.MaskedTimeSeries`, which includes a mask `Tensor` to specify timesteps with missing observations. Default value: `None`. name: Python `str` for the name of this component. Default value: 'DynamicLinearRegression'. """ with tf.name_scope(name or 'DynamicLinearRegression') as name: dtype = dtype_util.common_dtype( [design_matrix, drift_scale_prior, initial_weights_prior]) num_features = prefer_static.shape(design_matrix)[-1] # Default to a weakly-informative Normal(0., 10.) for the initital state if initial_weights_prior is None: initial_weights_prior = tfd.MultivariateNormalDiag( scale_diag=10. * tf.ones([num_features], dtype=dtype)) # Heuristic default priors. Overriding these may dramatically # change inference performance and results. if drift_scale_prior is None: if observed_time_series is None: observed_stddev = tf.constant(1.0, dtype=dtype) else: _, observed_stddev, _ = sts_util.empirical_statistics( observed_time_series) drift_scale_prior = tfd.LogNormal( loc=tf.math.log(.05 * observed_stddev), scale=3., name='drift_scale_prior') self._initial_state_prior = initial_weights_prior self._design_matrix = design_matrix super(DynamicLinearRegression, self).__init__( parameters=[ Parameter('drift_scale', drift_scale_prior, tfb.Chain([tfb.AffineScalar(scale=observed_stddev), tfb.Softplus()])) ], latent_size=num_features, name=name)
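# Construction sketch for the public component, assuming it is exposed as
# `tfp.sts.DynamicLinearRegression`:
import tensorflow as tf
import tensorflow_probability as tfp

design_matrix = tf.random.stateless_normal([50, 3], seed=[0, 1])
dynamic_regression = tfp.sts.DynamicLinearRegression(
    design_matrix=design_matrix)
# dynamic_regression.latent_size == 3: one time-varying weight per feature.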
def _batch_shape(self):
  tensors = [self.loc, self.scale, self.tailweight, self.skewness]
  return functools.reduce(
      ps.broadcast_shape, [ps.shape(t) for t in tensors])
def _batch_shape_tensor(self, concentration=None, rate=None):
  return ps.broadcast_shape(
      ps.shape(self.concentration if concentration is None else concentration),
      ps.shape(self.rate if rate is None else rate))
def find_root_chandrupatla(objective_fn,
                           low,
                           high,
                           position_tolerance=1e-8,
                           value_tolerance=0.,
                           max_iterations=50,
                           stopping_policy_fn=tf.reduce_all,
                           validate_args=False,
                           name='find_root_chandrupatla'):
  r"""Finds root(s) of a scalar function using Chandrupatla's method.

  Chandrupatla's method [1, 2] is a root-finding algorithm that is guaranteed
  to converge if a root lies within the given bounds. It generalizes the
  [bisection method](https://en.wikipedia.org/wiki/Bisection_method); at each
  step it chooses to perform either bisection or inverse quadratic
  interpolation. This makes it similar in spirit to [Brent's method](
  https://en.wikipedia.org/wiki/Brent%27s_method), which also considers steps
  that use the secant method, but Chandrupatla's method is simpler and often
  converges at least as quickly [3].

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      callable of a single variable. `objective_fn` must return a `Tensor` with
      shape `batch_shape` and dtype matching `low` and `high`.
    low: Float `Tensor` of shape `batch_shape` representing a lower bound(s) on
      the value of a root(s).
    high: Float `Tensor` of shape `batch_shape` representing an upper bound(s)
      on the value of a root(s).
    position_tolerance: Optional `Tensor` representing the maximum absolute
      error in the positions of the estimated roots. Shape must broadcast with
      `batch_shape`.
      Default value: `1e-8`.
    value_tolerance: Optional `Tensor` representing the absolute error allowed
      in the value of the objective function. If the absolute value of
      `objective_fn` is smaller than `value_tolerance` at a given position,
      then that position is considered a root for the function. Shape must
      broadcast with `batch_shape`.
      Default value: `0.`.
    max_iterations: Optional `Tensor` or Python integer specifying the maximum
      number of steps to perform. Shape must broadcast with `batch_shape`.
      Default value: `50`.
    stopping_policy_fn: Python `callable` controlling the algorithm
      termination. It must be a callable accepting a `Tensor` of booleans with
      the same shape as `low` and `high` (denoting whether each search is
      finished), and returning a scalar boolean `Tensor` indicating whether the
      overall search should stop. Typical values are `tf.reduce_all` (which
      returns only when the search is finished for all points), and
      `tf.reduce_any` (which returns as soon as the search is finished for any
      point).
      Default value: `tf.reduce_all` (returns only when the search is finished
        for all points).
    validate_args: Python `bool` indicating whether to validate arguments.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'find_root_chandrupatla'.

  Returns:
    root_search_results: A Python `namedtuple` containing the following items:
      estimated_root: `Tensor` containing the last position explored. If the
        search was successful within the specified tolerance, this position is
        a root of the objective function.
      objective_at_estimated_root: `Tensor` containing the value of the
        objective function at `estimated_root`. If the search was successful
        within the specified tolerance, then this is close to 0.
      num_iterations: The number of iterations performed.

  #### References

  [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for
      finding the zero of a nonlinear function without using derivatives.
      _Advances in Engineering Software_, 28.3:145-149, 1997.
  [2] Philipp OJ Scherer. Computational Physics. _Springer Berlin_,
      Heidelberg, 2010.
Section 6.1.7.3 https://books.google.com/books?id=cC-8BAAAQBAJ&pg=PA95 [3] Jason Sachs. Ten Little Algorithms, Part 5: Quadratic Extremum Interpolation and Chandrupatla's Method (2015). https://www.embeddedrelated.com/showarticle/855.php """ ################################################ # Loop variables used by Chandrupatla's method: # # a: endpoint of an interval `[min(a, b), max(a, b)]` containing the # root. There is no guarantee as to which of `a` and `b` is larger. # b: endpoint of an interval `[min(a, b), max(a, b)]` containing the # root. There is no guarantee as to which of `a` and `b` is larger. # f_a: value of the objective at `a`. # f_b: value of the objective at `b`. # t: the next position to be evaluated as the coefficient of a convex # combination of `a` and `b` (i.e., a value in the unit interval). # num_iterations: integer number of steps taken so far. # converged: boolean indicating whether each batch element has converged. # # All variables have the same shape `batch_shape`. def _should_continue(a, b, f_a, f_b, t, num_iterations, converged): del a, b, f_a, f_b, t # Unused. all_converged = stopping_policy_fn( tf.logical_or(converged, num_iterations >= max_iterations)) return ~all_converged def _body(a, b, f_a, f_b, t, num_iterations, converged): """One step of Chandrupatla's method for root finding.""" previous_loop_vars = (a, b, f_a, f_b, t, num_iterations, converged) finalized_elements = tf.logical_or(converged, num_iterations >= max_iterations) # Evaluate the new point. x_new = (1 - t) * a + t * b f_new = objective_fn(x_new) # If we've bisected (t==0.5) and the new float value for `a` is identical to # that from the previous iteration, then we'll keep bisecting (the # logic below will set t==0.5 for the next step), and nothing further will # change. at_fixed_point = tf.equal(x_new, a) & tf.equal(t, 0.5) # Otherwise, tighten the bounds. a, b, c, f_a, f_b, f_c = _structure_broadcasting_where( tf.equal(tf.math.sign(f_new), tf.math.sign(f_a)), (x_new, b, a, f_new, f_b, f_a), (x_new, a, b, f_new, f_a, f_b)) # Check for convergence. f_best = tf.where(tf.abs(f_a) < tf.abs(f_b), f_a, f_b) interval_tolerance = position_tolerance / (tf.abs(b - c)) converged = tf.logical_or( interval_tolerance > 0.5, tf.logical_or( tf.math.abs(f_best) <= value_tolerance, at_fixed_point)) # Propose next point to evaluate. xi = (a - b) / (c - b) phi = (f_a - f_b) / (f_c - f_b) t = tf.where( # Condition for inverse quadratic interpolation. tf.logical_and(1 - tf.math.sqrt(1 - xi) < phi, tf.math.sqrt(xi) > phi), # Propose a point by inverse quadratic interpolation. (f_a / (f_b - f_a) * f_c / (f_b - f_c) + (c - a) / (b - a) * f_a / (f_c - f_a) * f_b / (f_c - f_b)), # Otherwise, just cut the interval in half (bisection). 0.5) # Constrain the proposal to the current interval (0 < t < 1). t = tf.minimum(tf.maximum(t, interval_tolerance), 1 - interval_tolerance) # Update elements that haven't converged. 
return _structure_broadcasting_where( finalized_elements, previous_loop_vars, (a, b, f_a, f_b, t, num_iterations + 1, converged)) with tf.name_scope(name): max_iterations = tf.convert_to_tensor(max_iterations, name='max_iterations', dtype_hint=tf.int32) a = tf.convert_to_tensor(low, name='lower_bound') b = tf.convert_to_tensor(high, name='upper_bound') f_a, f_b = objective_fn(a), objective_fn(b) batch_shape = ps.broadcast_shape(ps.shape(f_a), ps.shape(f_b)) assertions = [] if validate_args: assertions += [ assert_util.assert_none_equal( tf.math.sign(f_a), tf.math.sign(f_b), message='Bounds must be on different sides of a root.') ] with tf.control_dependencies(assertions): initial_loop_vars = [ a, b, f_a, f_b, tf.cast(0.5, dtype=f_a.dtype), tf.cast(0, dtype=max_iterations.dtype), False ] a, b, f_a, f_b, _, num_iterations, _ = tf.while_loop( _should_continue, _body, loop_vars=tf.nest.map_structure( lambda x: tf.broadcast_to(x, batch_shape), initial_loop_vars)) x_best, f_best = _structure_broadcasting_where( tf.abs(f_a) < tf.abs(f_b), (a, f_a), (b, f_b)) return RootSearchResults(estimated_root=x_best, objective_at_estimated_root=f_best, num_iterations=num_iterations)
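# Usage sketch for the routine above: find the root of x**3 - 1 in [0, 2].
import tensorflow as tf

root_results = find_root_chandrupatla(
    objective_fn=lambda x: x**3 - 1.,
    low=tf.constant(0.),
    high=tf.constant(2.))
# root_results.estimated_root               ==> approximately 1.0
# root_results.objective_at_estimated_root  ==> approximately 0.0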
def covariance(x, y=None, sample_axis=0, event_axis=-1, keepdims=False, name=None): """Sample covariance between observations indexed by `event_axis`. Given `N` samples of scalar random variables `X` and `Y`, covariance may be estimated as ```none Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)} Xbar := N^{-1} sum_{n=1}^N X_n Ybar := N^{-1} sum_{n=1}^N Y_n ``` For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`, one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`. ```python x = tf.random.normal(shape=(100, 2, 3)) y = tf.random.normal(shape=(100, 2, 3)) # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j]. cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None) # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n] cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1) ``` Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is slightly biased. Args: x: A numeric `Tensor` holding samples. y: Optional `Tensor` with same `dtype` and `shape` as `x`. Default value: `None` (`y` is effectively set to `x`). sample_axis: Scalar or vector `Tensor` designating axis holding samples, or `None` (meaning all axis hold samples). Default value: `0` (leftmost dimension). event_axis: Scalar or vector `Tensor`, or `None` (scalar events). Axis indexing random events, whose covariance we are interested in. If a vector, entries must form a contiguous block of dims. `sample_axis` and `event_axis` should not intersect. Default value: `-1` (rightmost axis holds events). keepdims: Boolean. Whether to keep the sample axis as singletons. name: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., `'covariance'`). Returns: cov: A `Tensor` of same `dtype` as the `x`, and rank equal to `rank(x) - len(sample_axis) + 2 * len(event_axis)`. Raises: AssertionError: If `x` and `y` are found to have different shape. ValueError: If `sample_axis` and `event_axis` are found to overlap. ValueError: If `event_axis` is found to not be contiguous. """ with tf.name_scope(name or 'covariance'): x = tf.convert_to_tensor(x, name='x') # Covariance *only* uses the centered versions of x (and y). x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True) if y is None: y = x else: y = tf.convert_to_tensor(y, name='y', dtype=x.dtype) # If x and y have different shape, sample_axis and event_axis will likely # be wrong for one of them! tensorshape_util.assert_is_compatible_with(x.shape, y.shape) y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True) if event_axis is None: return tf.reduce_mean(x * tf.math.conj(y), axis=sample_axis, keepdims=keepdims) if sample_axis is None: raise ValueError( 'sample_axis was None, which means all axis hold events, and this ' 'overlaps with event_axis ({})'.format(event_axis)) event_axis = _make_positive_axis(event_axis, ps.rank(x)) sample_axis = _make_positive_axis(sample_axis, ps.rank(x)) # If we get lucky and axis is statically defined, we can do some checks. if _is_list_like(event_axis) and _is_list_like(sample_axis): event_axis = tuple(map(int, event_axis)) sample_axis = tuple(map(int, sample_axis)) if set(event_axis).intersection(sample_axis): raise ValueError( 'sample_axis ({}) and event_axis ({}) overlapped'.format( sample_axis, event_axis)) if (np.diff(np.array(sorted(event_axis))) > 1).any(): raise ValueError( 'event_axis must be contiguous. 
Found: {}'.format( event_axis)) batch_axis = list( sorted( set(range(tensorshape_util.rank( x.shape))).difference(sample_axis + event_axis))) else: batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), ps.concat((sample_axis, event_axis), 0)) event_axis = ps.cast(event_axis, dtype=tf.int32) sample_axis = ps.cast(sample_axis, dtype=tf.int32) batch_axis = ps.cast(batch_axis, dtype=tf.int32) # Permute x/y until shape = B + E + S perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0) x_permed = tf.transpose(a=x, perm=perm_for_xy) y_permed = tf.transpose(a=y, perm=perm_for_xy) batch_ndims = ps.size(batch_axis) batch_shape = ps.shape(x_permed)[:batch_ndims] event_ndims = ps.size(event_axis) event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims] sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:] sample_ndims = ps.size(sample_shape) n_samples = ps.reduce_prod(sample_shape) n_events = ps.reduce_prod(event_shape) # Flatten sample_axis into one long dim. x_permed_flat = tf.reshape( x_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0)) y_permed_flat = tf.reshape( y_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0)) # Do the same for event_axis. x_permed_flat = tf.reshape( x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0)) y_permed_flat = tf.reshape( y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0)) # After matmul, cov.shape = batch_shape + [n_events, n_events] cov = tf.matmul(x_permed_flat, y_permed_flat, adjoint_b=True) / ps.cast(n_samples, x.dtype) # Insert some singletons to make # cov.shape = batch_shape + event_shape**2 + [1,...,1] # This is just like x_permed.shape, except the sample_axis is all 1's, and # the [n_events] became event_shape**2. cov = tf.reshape( cov, ps.concat( ( batch_shape, # event_shape**2 used here because it is the same length as # event_shape, and has the same number of elements as one # batch of covariance. event_shape**2, ps.ones([sample_ndims], tf.int32)), 0)) # Permuting by the argsort inverts the permutation, making # cov.shape have ones in the position where there were samples, and # [n_events * n_events] in the event position. cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy)) # Now expand event_shape**2 into event_shape + event_shape. # We here use (for the first time) the fact that we require event_axis to be # contiguous. e_start = event_axis[0] e_len = 1 + event_axis[-1] - event_axis[0] cov = tf.reshape( cov, ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape, ps.shape(cov)[e_start + e_len:]), 0)) # tf.squeeze requires python ints for axis, not Tensor. This is enough to # require our axis args to be constants. if not keepdims: squeeze_axis = ps.where(sample_axis < e_start, sample_axis, sample_axis + e_len) cov = _squeeze(cov, axis=squeeze_axis) return cov
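# Example (hypothetical helper, not part of the library): the two modes from
# the docstring above, with result shapes following the reshape/matmul logic.
def _example_covariance_modes():
  import tensorflow as tf  # Local imports keep the sketch self-contained.
  import tensorflow_probability as tfp
  x = tf.random.normal([1000, 2, 3])
  y = tf.random.normal([1000, 2, 3])
  # Scalar events: elementwise sample covariance of x[:, i, j] and y[:, i, j].
  elementwise_cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)
  # elementwise_cov.shape == [2, 3]
  # Vector events: a 3x3 covariance matrix for each of the 2 batch members.
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  # cov_matrix.shape == [2, 3, 3]
  return elementwise_cov, cov_matrix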
def _batch_shape_tensor(self):
  # The batch shape is just the shape of whichever parameterization
  # (`probs` or `logits`) was supplied.
  x = self._probs if self._logits is None else self._logits
  return ps.shape(x)
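# Example (hypothetical helper, not part of the library): illustrates the rule
# above, assuming a `Bernoulli`-style distribution parameterized by exactly one
# of `probs` or `logits`; the batch shape is the shape of whichever parameter
# was supplied.
def _example_probs_or_logits_batch_shape():
  import tensorflow as tf  # Local imports keep the sketch self-contained.
  import tensorflow_probability as tfp
  d = tfp.distributions.Bernoulli(logits=tf.zeros([4, 2]))
  return d.batch_shape_tensor()  # ==> [4, 2], the shape of `logits`.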
def one_step(self, current_state, previous_kernel_results, seed=None): with tf.name_scope( mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation', 'one_step')): variance_parts = previous_kernel_results.running_variance diags = [ variance_part.variance() for variance_part in variance_parts ] # Set the momentum. batch_ndims = ps.rank( unnest.get_innermost(previous_kernel_results, 'target_log_prob')) state_parts = tf.nest.flatten(current_state) new_momentum_distribution = _make_momentum_distribution( diags, state_parts, batch_ndims) inner_results = self.momentum_distribution_setter_fn( previous_kernel_results.inner_results, new_momentum_distribution) # Step the inner kernel. inner_kwargs = {} if seed is None else dict(seed=seed) new_state, new_inner_results = self.inner_kernel.one_step( current_state, inner_results, **inner_kwargs) new_state_parts = tf.nest.flatten(new_state) new_variance_parts = [] for variance_part, diag, state_part in zip(variance_parts, diags, new_state_parts): # Compute new variance for each variance part, accounting for partial # batching of the variance calculation across chains (ie, some, all, or # none of the chains may share the estimated mass matrix). # # For example, say # # state_part has shape [2, 3, 4] + [5, 6] (batch + event) # variance_part has shape [4] + [5, 6] # log_prob has shape [2, 3, 4] # # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass # matrices, each being shared across a [2, 3]-batch of chains. Note this # division is inferred from the shapes of the state part, the log_prob, # and the user-provided initial running variances. # # Until RunningVariance supports rank > 1 chunking, we need to flatten # the states that go into updating the variance estimates. In the above # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and # fed to `RunningVariance.update(state_part, axis=0)`, recording # 6 new observations in the running variance calculation. # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and # the resulting momentum distribution will have batch shape of # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part. state_rank = ps.rank(state_part) variance_rank = ps.rank(diag) num_reduce_dims = state_rank - variance_rank state_part_shape = ps.shape(state_part) # This reshape adds a 1 when reduce_dims==0, and collapses all the lead # dimensions to a single one otherwise. reshaped_state = ps.reshape( state_part, ps.concat( [[ps.reduce_prod(state_part_shape[:num_reduce_dims])], state_part_shape[num_reduce_dims:]], axis=0)) # The `axis=0` here removes the leading dimension we got from the # reshape above, so the new_variance_parts have the correct shape again. new_variance_parts.append( variance_part.update(reshaped_state, axis=0)) new_kernel_results = previous_kernel_results._replace( inner_results=new_inner_results, running_variance=new_variance_parts) return new_state, new_kernel_results
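# Example (hypothetical helper, not part of the library): makes the shape
# bookkeeping in the long comment above concrete, using the same made-up
# shapes: a [2, 3, 4] batch of chains with [5, 6] events sharing 4 variance
# estimates.
def _example_flatten_state_for_running_variance():
  import tensorflow as tf  # Local imports keep the sketch self-contained.
  from tensorflow_probability.python.internal import prefer_static as ps
  state_part = tf.zeros([2, 3, 4, 5, 6])
  variance_shape = [4, 5, 6]
  num_reduce_dims = len(state_part.shape) - len(variance_shape)  # == 2
  state_part_shape = ps.shape(state_part)
  reshaped_state = tf.reshape(
      state_part,
      ps.concat([[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                 state_part_shape[num_reduce_dims:]], axis=0))
  # reshaped_state.shape == [6, 4, 5, 6]: six new observations are recorded
  # for each of the 4 shared variance estimates.
  return reshaped_state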
def _sample_n(self, n, seed=None):
  # Sample by drawing uniforms of shape `[n] + batch_shape` and thresholding
  # them against `probs`.
  probs = self._probs_parameter_no_checks()
  new_shape = ps.concat([[n], ps.shape(probs)], 0)
  uniform = samplers.uniform(new_shape, seed=seed, dtype=probs.dtype)
  sample = tf.less(uniform, probs)
  return tf.cast(sample, self.dtype)
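# Example (hypothetical helper, not part of the library): the same
# uniform-thresholding scheme written with stock TensorFlow ops, using a
# placeholder `probs` and a stateless seed.
def _example_bernoulli_sampling_by_thresholding():
  import tensorflow as tf  # Local import keeps the sketch self-contained.
  probs = tf.constant([0.1, 0.5, 0.9])
  n = 4
  uniform = tf.random.stateless_uniform(
      tf.concat([[n], tf.shape(probs)], axis=0),
      seed=[0, 42], dtype=probs.dtype)
  return tf.cast(tf.less(uniform, probs), tf.int32)  # shape [4, 3]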
def one_step_predictive(model, posterior_samples, num_forecast_steps=0, original_mean=0., original_scale=1., thin_every=10): """Constructs a one-step-ahead predictive distribution at every timestep. Unlike the generic `tfp.sts.one_step_predictive`, this method uses the latent levels from Gibbs sampling to efficiently construct a predictive distribution that mixes over posterior samples. The predictive distribution may also include additional forecast steps. This method returns the predictive distributions for each timestep given previous timesteps and sampled model parameters, `p(observed_time_series[t] | observed_time_series[:t], weights, observation_noise_scale)`. Note that the posterior values of the weights and noise scale will in general be informed by observations from all timesteps *including the step being predicted*, so this is not a strictly kosher probabilistic quantity, but in general we assume that it's close, i.e., that the step being predicted had very small individual impact on the overall parameter posterior. Args: model: A `tfd.sts.StructuralTimeSeries` model instance. This must be of the form constructed by `build_model_for_gibbs_sampling`. posterior_samples: A `GibbsSamplerState` instance in which each element is a `Tensor` with initial dimension of size `num_samples`. num_forecast_steps: Python `int` number of additional forecast steps to append. Default value: `0`. original_mean: Optional scalar float `Tensor`, added to the predictive distribution to undo the effect of input normalization. Default value: `0.` original_scale: Optional scalar float `Tensor`, used to rescale the predictive distribution to undo the effect of input normalization. Default value: `1.` thin_every: Optional Python `int` factor by which to thin the posterior samples, to reduce complexity of the predictive distribution. For example, if `thin_every=10`, every `10`th sample will be used. Default value: `10`. Returns: predictive_dist: A `tfd.MixtureSameFamily` instance of event shape `[num_timesteps + num_forecast_steps]` representing the predictive distribution of each timestep given previous timesteps. """ dtype = dtype_util.common_dtype([ posterior_samples.level_scale, posterior_samples.observation_noise_scale, posterior_samples.level, original_mean, original_scale], dtype_hint=tf.float32) num_observed_steps = prefer_static.shape(posterior_samples.level)[-1] original_mean = tf.convert_to_tensor(original_mean, dtype=dtype) original_scale = tf.convert_to_tensor(original_scale, dtype=dtype) thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every], posterior_samples) if prefer_static.rank_from_shape( # If no slope was inferred, treat as zero. prefer_static.shape(thinned_samples.slope)) <= 1: thinned_samples = thinned_samples._replace( slope=tf.zeros_like(thinned_samples.level), slope_scale=tf.zeros_like(thinned_samples.level_scale)) num_steps_from_last_observation = tf.concat([ tf.ones([num_observed_steps], dtype=dtype), tf.range(1, num_forecast_steps + 1, dtype=dtype)], axis=0) # The local linear trend model expects that the level at step t + 1 is equal # to the level at step t, plus the slope at time t - 1, # plus transition noise of scale 'level_scale' (which we account for below). if num_forecast_steps > 0: num_batch_dims = prefer_static.rank_from_shape( prefer_static.shape(thinned_samples.level)) - 2 # All else equal, the current level will remain stationary. 
forecast_level = tf.tile(thinned_samples.level[..., -1:], tf.concat([tf.ones([num_batch_dims + 1], dtype=tf.int32), [num_forecast_steps]], axis=0)) # If the model includes slope, the level will steadily increase. forecast_level += (thinned_samples.slope[..., -1:] * tf.range(1., num_forecast_steps + 1., dtype=forecast_level.dtype)) level_pred = tf.concat([thinned_samples.level[..., :1], # t == 0 (thinned_samples.level[..., :-1] + thinned_samples.slope[..., :-1]) # 1 <= t < T ] + ( [forecast_level] if num_forecast_steps > 0 else []), axis=-1) design_matrix = _get_design_matrix( model).to_dense()[:num_observed_steps + num_forecast_steps] regression_effect = tf.linalg.matvec(design_matrix, thinned_samples.weights) y_mean = ((level_pred + regression_effect) * original_scale[..., tf.newaxis] + original_mean[..., tf.newaxis]) # To derive a forecast variance, including slope uncertainty, let # `r[:k]` be iid Gaussian RVs with variance `level_scale**2` and `s[:k]` be # iid Gaussian RVs with variance `slope_scale**2`. Then the forecast level at # step `T + k` can be written as # (level[T] + # Last known level. # r[0] + ... + r[k] + # Sum of random walk terms on level. # slope[T] * k # Contribution from last known slope. # (k - 1) * s[0] + # Contributions from random walk terms on slope. # (k - 2) * s[1] + # ... + # 1 * s[k - 1]) # which has variance of # (level_scale**2 * k + # slope_scale**2 * ( (k - 1)**2 + # (k - 2)**2 + # ... + 1 )) # Here the `slope_scale` coefficient is the `k - 1`th square pyramidal # number [1], which is given by # (k - 1) * k * (2 * k - 1) / 6. # # [1] https://en.wikipedia.org/wiki/Square_pyramidal_number variance_from_level = (thinned_samples.level_scale[..., tf.newaxis]**2 * num_steps_from_last_observation) variance_from_slope = thinned_samples.slope_scale[..., tf.newaxis]**2 * ( (num_steps_from_last_observation - 1) * num_steps_from_last_observation * (2 * num_steps_from_last_observation - 1)) / 6. y_scale = (original_scale * tf.sqrt( thinned_samples.observation_noise_scale[..., tf.newaxis]**2 + variance_from_level + variance_from_slope)) num_posterior_draws = prefer_static.shape(y_mean)[0] return tfd.MixtureSameFamily( mixture_distribution=tfd.Categorical( logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)), components_distribution=tfd.Normal( loc=dist_util.move_dimension(y_mean, 0, -1), scale=dist_util.move_dimension(y_scale, 0, -1)))
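# Example (hypothetical helper, not part of the library): numerically checks
# the square-pyramidal coefficient derived in the comment above, with made-up
# scales.
def _example_forecast_variance_coefficient():
  import numpy as np  # Local import keeps the sketch self-contained.
  level_scale, slope_scale, k = 0.3, 0.2, 5
  # Explicitly accumulate the slope random-walk coefficients (k-1), ..., 1.
  slope_coeffs = np.arange(k - 1, 0, -1)
  brute_force = level_scale**2 * k + slope_scale**2 * np.sum(slope_coeffs**2)
  closed_form = (level_scale**2 * k
                 + slope_scale**2 * (k - 1) * k * (2 * k - 1) / 6.)
  np.testing.assert_allclose(brute_force, closed_form)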
def _batch_shape_tensor(self, loc=None, concentration=None):
  # The batch shape is the broadcast of the (possibly overridden) parameter
  # shapes.
  return prefer_static.broadcast_shape(
      prefer_static.shape(self.loc if loc is None else loc),
      prefer_static.shape(
          self.concentration if concentration is None else concentration))
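# Example (hypothetical helper, not part of the library): a standalone
# illustration of the broadcasting rule above, with arbitrary shapes.
def _example_broadcast_batch_shape():
  import tensorflow as tf  # Local imports keep the sketch self-contained.
  from tensorflow_probability.python.internal import prefer_static
  loc = tf.zeros([3, 1])
  concentration = tf.ones([5])
  return prefer_static.broadcast_shape(
      prefer_static.shape(loc),
      prefer_static.shape(concentration))  # ==> [3, 5]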
def vectorized_fn(*args): """Vectorized version of `fn` that accepts arguments of any rank.""" with tf.name_scope(name or 'make_rank_polymorphic'): assertions = [] # If we got a single value for core_ndims, tile it across all args. core_ndims_structure = ( core_ndims if tf.nest.is_nested(core_ndims) else tf.nest.map_structure(lambda _: core_ndims, args)) # Build flat lists of all argument parts and their corresponding core # ndims. flat_core_ndims = tf.nest.flatten(core_ndims_structure) flat_args = nest.flatten_up_to( core_ndims_structure, args, check_types=False) # Filter to only the `Tensor`-valued args (taken to be those with `None` # values for `core_ndims`). Other args will be passed through to `fn` # unmodified. (vectorized_arg_core_ndims, vectorized_args, fn_of_vectorized_args) = _lock_in_non_vectorized_args( fn, arg_structure=core_ndims_structure, flat_core_ndims=flat_core_ndims, flat_args=flat_args) # `vectorized_map` requires all inputs to have a single, common batch # dimension `[n]`. So we broadcast all input parts to a common # batch shape, then flatten it down to a single dimension. # First, compute how many 'extra' (batch) ndims each part has. This must # be nonnegative. vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args] batch_ndims = [ ps.rank_from_shape(arg_shape) - nd for (arg_shape, nd) in zip( vectorized_arg_shapes, vectorized_arg_core_ndims)] static_ndims = [tf.get_static_value(nd) for nd in batch_ndims] if any([nd and nd < 0 for nd in static_ndims]): raise ValueError('Cannot broadcast a Tensor having lower rank than the ' 'specified `core_ndims`! (saw input ranks {}, ' '`core_ndims` {}).'.format( tf.nest.map_structure( ps.rank_from_shape, vectorized_arg_shapes), vectorized_arg_core_ndims)) if validate_args: for nd, part, core_nd in zip( batch_ndims, vectorized_args, vectorized_arg_core_ndims): assertions.append(tf.debugging.assert_non_negative( nd, message='Cannot broadcast a Tensor having lower rank than ' 'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format( part, core_nd))) # Next, split each part's shape into batch and core shapes, and # broadcast the batch shapes. with tf.control_dependencies(assertions): empty_shape = np.zeros([0], dtype=np.int32) batch_shapes, core_shapes = empty_shape, empty_shape if vectorized_arg_shapes: batch_shapes, core_shapes = zip(*[ (arg_shape[:nd], arg_shape[nd:]) for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)]) broadcast_batch_shape = ( functools.reduce(ps.broadcast_shape, batch_shapes, [])) # Flatten all of the batch dimensions into one. n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32) static_n = tf.get_static_value(n) if static_n == 1: # We can bypass `vectorized_map` if the batch shape is `[]`, `[1]`, # `[1, 1]`, etc., just by flattening to batch shape `[]`. result_batch_dims = 0 batched_result = fn_of_vectorized_args( tf.nest.map_structure( lambda x, nd: tf.reshape(x, ps.shape(x)[ps.rank(x) - nd:]), vectorized_args, vectorized_arg_core_ndims)) else: # Pad all input parts to the common shape, then flatten # into the single leading dimension `[n]`. # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap # when batch rank is static so that we can exploit broadcasting. 
broadcast_vectorized_args = [ tf.broadcast_to(part, ps.concat( [broadcast_batch_shape, core_shape], axis=0)) for (part, core_shape) in zip(vectorized_args, core_shapes)] vectorized_args_with_flattened_batch_dim = [ tf.reshape(part, ps.concat([[n], core_shape], axis=0)) for (part, core_shape) in zip( broadcast_vectorized_args, core_shapes)] result_batch_dims = 1 batched_result = tf.vectorized_map( fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim) # Unflatten any `Tensor`s in the result. unflatten = lambda x: tf.reshape(x, ps.concat([ # pylint: disable=g-long-lambda broadcast_batch_shape, ps.shape(x)[result_batch_dims:]], axis=0)) result = tf.nest.map_structure( lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result, expand_composites=True) return result
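# Example (hypothetical helper, not part of the library): summarizes the
# broadcast-then-flatten pattern implemented above in plain TensorFlow, using
# a made-up rank-1 core function.
def _example_flatten_batch_for_vectorized_map():
  import tensorflow as tf  # Local import keeps the sketch self-contained.
  def rank1_fn(x):
    # Core computation expecting a single rank-1 input.
    return tf.reduce_sum(x, axis=-1)
  x = tf.zeros([2, 3, 4])           # batch shape [2, 3], core shape [4].
  batch_shape = tf.shape(x)[:-1]
  # Flatten the batch dimensions into a single leading dimension of size 6.
  flat = tf.reshape(x, tf.concat([[-1], tf.shape(x)[-1:]], axis=0))
  flat_result = tf.vectorized_map(rank1_fn, flat)       # shape [6]
  return tf.reshape(flat_result, batch_shape)           # shape [2, 3]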
def _resample_weights(design_matrix, target_residuals, observation_noise_scale, weights_prior_scale, is_missing=None, seed=None): """Samples regression weights from their conditional posterior. This assumes a conjugate normal regression model, ``` weights ~ Normal(loc=0., covariance_matrix=weights_prior_scale**2 * I) target_residuals ~ Normal(loc=matvec(design_matrix, weights), covariance_matrix=observation_noise_scale**2 * I) ``` and returns a sample from `p(weights | target_residuals, observation_noise_scale, design_matrix)`. Args: design_matrix: Float `Tensor` design matrix of shape `[..., num_timesteps, num_features]`. target_residuals: Float `Tensor` of shape `[..., num_observations]` observation_noise_scale: Scalar float `Tensor` (with optional batch shape) standard deviation of the iid observation noise. weights_prior_scale: Scalar float `Tensor` (with optional batch shape) specifying the standard deviation of the Normal prior on regression weights. is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A `True` value indicates that the observation for that timestep is missing. seed: Optional `Python` `int` seed controlling the sampled values. Returns: weights: Float `Tensor` of shape `[..., num_features]`, sampled from the conditional posterior `p(weights | target_residuals, observation_noise_scale, weights_prior_scale)`. """ if is_missing is not None: # Replace design matrix with zeros at unobserved timesteps. This ensures # they will not affect the posterior on weights. design_matrix = tf.where(is_missing[..., tf.newaxis], tf.zeros_like(design_matrix), design_matrix) design_shape = prefer_static.shape(design_matrix) num_outputs = design_shape[-2] num_features = design_shape[-1] iid_prior_scale = tf.linalg.LinearOperatorScaledIdentity( num_rows=num_features, multiplier=weights_prior_scale) iid_likelihood_scale = tf.linalg.LinearOperatorScaledIdentity( num_rows=num_outputs, multiplier=observation_noise_scale) weights_mean, weights_prec = ( normal_conjugate_posteriors.mvn_conjugate_linear_update( linear_transformation=design_matrix, observation=target_residuals, prior_scale=iid_prior_scale, likelihood_scale=iid_likelihood_scale)) sampled_weights = weights_prec.cholesky().solvevec(samplers.normal( shape=prefer_static.shape(weights_mean), dtype=design_matrix.dtype, seed=seed), adjoint=True) return weights_mean + sampled_weights
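# Example (hypothetical helper, not part of the library): cross-checks the
# draw above against the textbook conjugate update in NumPy, with made-up
# shapes and scales.
def _example_conjugate_weights_posterior():
  import numpy as np  # Local import keeps the sketch self-contained.
  rng = np.random.default_rng(0)
  design_matrix = rng.normal(size=(50, 3))
  target_residuals = rng.normal(size=(50,))
  noise_scale, prior_scale = 0.5, 2.0
  # Posterior precision = I / prior_scale**2 + X^T X / noise_scale**2;
  # posterior mean = precision^{-1} X^T y / noise_scale**2.
  posterior_prec = (np.eye(3) / prior_scale**2
                    + design_matrix.T @ design_matrix / noise_scale**2)
  posterior_mean = np.linalg.solve(
      posterior_prec, design_matrix.T @ target_residuals / noise_scale**2)
  # Sampling: mean + solve(chol(prec).T, z) has covariance prec^{-1}, matching
  # the `weights_prec.cholesky().solvevec(z, adjoint=True)` call above.
  z = rng.normal(size=3)
  return posterior_mean + np.linalg.solve(
      np.linalg.cholesky(posterior_prec).T, z)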