def get_marginal_distribution(self, index_points=None): """Compute the marginal of this GP over function values at `index_points`. Args: index_points: `float` `Tensor` representing finite (batch of) vector(s) of points in the index set over which the GP is defined. Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature dimensions and must equal `kernel.feature_ndims` and `e` is the number (size) of index points in each batch. Ultimately this distribution corresponds to a `e`-dimensional multivariate normal. The batch shape must be broadcastable with `kernel.batch_shape` and any batch dims yielded by `mean_fn`. Returns: marginal: a `Normal` or `MultivariateNormalLinearOperator` distribution, according to whether `index_points` consists of one or many index points, respectively. """ with self._name_and_control_scope('get_marginal_distribution'): # TODO(cgs): consider caching the result here, keyed on `index_points`. index_points = self._get_index_points(index_points) covariance = self._compute_covariance(index_points) loc = self._mean_fn(index_points) # If we're sure the number of index points is 1, we can just construct a # scalar Normal. This has computational benefits and supports things like # CDF that aren't otherwise straightforward to provide. if self._is_univariate_marginal(index_points): scale = tf.sqrt(covariance) # `loc` has a trailing 1 in the shape; squeeze it. loc = tf.squeeze(loc, axis=-1) return normal.Normal( loc=loc, scale=scale, validate_args=self._validate_args, allow_nan_stats=self._allow_nan_stats, name='marginal_distribution') else: return self._marginal_fn( loc=loc, covariance=covariance, validate_args=self._validate_args, allow_nan_stats=self._allow_nan_stats, name='marginal_distribution')
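# A minimal usage sketch of the univariate/multivariate marginal distinction above,
# assuming the public TFP API (tfp.distributions.GaussianProcess and
# tfp.math.psd_kernels); the method above is the internal implementation.
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
kernel = tfp.math.psd_kernels.ExponentiatedQuadratic()
index_points = np.linspace(-1., 1., 5, dtype=np.float32)[..., np.newaxis]
gp = tfd.GaussianProcess(kernel=kernel, index_points=index_points)
marginal = gp.get_marginal_distribution()    # 5-dimensional multivariate normal
print(marginal.event_shape)                  # [5]
gp_single = tfd.GaussianProcess(kernel=kernel, index_points=index_points[:1])
single = gp_single.get_marginal_distribution()  # scalar Normal, per the univariate branch above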
def testNormalSurvivalFunction(self): batch_size = 50 mu = self._rng.randn(batch_size) sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) sf = normal.survival_function(x) self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), sf.shape) self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), self.evaluate(sf).shape) self.assertAllEqual(normal.batch_shape, sf.shape) self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) if not stats: return expected_sf = stats.norm(mu, sigma).sf(x) self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
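# The identity exercised by the test above, as a standalone sketch (assumes
# tensorflow_probability and scipy are installed): survival_function(x) is
# P[X > x] == 1 - cdf(x), and it should agree with scipy's sf.
import numpy as np
import tensorflow_probability as tfp
from scipy import stats

dist = tfp.distributions.Normal(loc=0., scale=2.)
x = np.float32(1.5)
sf = dist.survival_function(x)
np.testing.assert_allclose(sf, 1. - dist.cdf(x), rtol=1e-6)
np.testing.assert_allclose(sf, stats.norm(0., 2.).sf(1.5), rtol=1e-6)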
def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): g = tf.Graph() with g.as_default(): mu = tf.Variable(dtype(0.0)) sigma = tf.Variable(dtype(1.0)) dist = normal_lib.Normal(loc=mu, scale=sigma) p = tf.Variable( np.array([ 0., np.exp(-32.), np.exp(-2.), 1. - np.exp(-2.), 1. - np.exp(-32.), 1. ]).astype(dtype)) value = dist.quantile(p) grads = tf.gradients(value, [mu, p]) with self.cached_session(graph=g): tf.global_variables_initializer().run() self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1])
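# The same finiteness check in eager/TF2 style, as a sketch (the test above
# uses the older graph-mode tf.gradients API); interior probabilities only:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

mu = tf.Variable(0.)
p = tf.constant([np.exp(-32.), np.exp(-2.), 1. - np.exp(-2.)], tf.float32)
with tf.GradientTape() as tape:
  tape.watch(p)
  value = tfp.distributions.Normal(loc=mu, scale=1.).quantile(p)
grad_mu, grad_p = tape.gradient(value, [mu, p])
assert np.all(np.isfinite(grad_mu)) and np.all(np.isfinite(grad_p))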
def _apply_variational_kernel(self, inputs): if (not isinstance(self.kernel_posterior, independent_lib.Independent) or not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): raise TypeError('`DenseLocalReparameterization` requires ' '`kernel_posterior_fn` produce an instance of ' '`tfd.Independent(tfd.Normal)` ' '(saw: \"{}\").'.format( self.kernel_posterior.name)) self.kernel_posterior_affine = normal_lib.Normal( loc=tf.matmul(inputs, self.kernel_posterior.distribution.loc), scale=tf.sqrt( tf.matmul(tf.square(inputs), tf.square( self.kernel_posterior.distribution.scale)))) self.kernel_posterior_affine_tensor = (self.kernel_posterior_tensor_fn( self.kernel_posterior_affine)) self.kernel_posterior_tensor = None return self.kernel_posterior_affine_tensor
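# Why the affine posterior above is valid (local reparameterization): for an
# elementwise-independent Gaussian kernel W with mean M and stddev S, the
# pre-activation x @ W is itself Gaussian with mean x @ M and variance
# (x**2) @ (S**2). A small Monte Carlo sanity check, as a sketch (names here
# are illustrative, not part of the layer API):
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([[1., 2., 3.]])                         # a single input row
M = tf.random.normal([3, 4])                            # kernel posterior means
S = tf.random.uniform([3, 4], minval=0.1, maxval=0.5)   # kernel posterior stddevs
w = tfp.distributions.Normal(M, S).sample(10000, seed=0)  # [10000, 3, 4]
pre_act = tf.einsum('bi,sij->sbj', x, w)                # [10000, 1, 4]
empirical_var = tf.math.reduce_variance(pre_act, axis=0)
analytic_var = tf.matmul(tf.square(x), tf.square(S))    # matches up to MC error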
def __init__(self, loc, scale, hinge_softness=1., validate_args=False, allow_nan_stats=True, name='InverseSoftplusNormal'): """Construct an inverse-softplus-normal distribution. The InverseSoftplusNormal distribution models positive-valued random variables whose inverse-softplus is normally distributed with mean `loc` and standard deviation `scale`. It is constructed as the softplus transformation of a Normal distribution. Args: loc: Floating-point `Tensor`; the means of the underlying Normal distribution(s). scale: Floating-point `Tensor`; the stddevs of the underlying Normal distribution(s). hinge_softness: Floating-point `Tensor`; governs the hinge of the softplus operation. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean, mode, etc...) is undefined for any batch member. If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. """ parameters = dict(locals()) with tf.name_scope(name) as name: super(InvSoftplusNormal, self).__init__( distribution=normal.Normal(loc=loc, scale=scale), bijector=Softplus(hinge_softness=hinge_softness), validate_args=validate_args, parameters=parameters, name=name)
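# An equivalent construction from public TFP building blocks, as a sketch:
# the class above bundles exactly this Normal + Softplus bijector pattern.
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors
inv_softplus_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Softplus(hinge_softness=1.))
samples = inv_softplus_normal.sample(3, seed=42)  # positive-valued samples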
def cdf_func(concentration): """A helper function that is passed to _compute_value_and_grad.""" # z is an "almost Normally distributed" random variable. z = ((np.sqrt(2. / np.pi) / tf.math.bessel_i0e(concentration)) * tf.sin(.5 * x)) # This is the correction described in [1] which reduces the error # of the Normal approximation. z2 = z ** 2 z3 = z2 * z z4 = z2 ** 2 c = 24. * concentration c1 = 56. xi = z - z3 / ((c - 2. * z2 - 16.) / 3. - (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)) ** 2 distrib = normal.Normal(tf.cast(0., dtype), tf.cast(1., dtype)) return distrib.cdf(xi)
def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: g = tf.Graph() with g.as_default(): mu = tf.Variable(dtype(0.0)) sigma = tf.Variable(dtype(1.0)) dist = normal_lib.Normal(loc=mu, scale=sigma) x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) for func in [ dist.cdf, dist.log_cdf, dist.survival_function, dist.log_survival_function, dist.log_prob, dist.prob ]: value = func(x) grads = tf.gradients(value, [mu, sigma]) with self.session(graph=g): tf.global_variables_initializer().run() self.assertAllFinite(value) self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1])
def _fn(dtype, shape, name, trainable, add_variable_fn): loc_init = tf.compat.v1.constant_initializer(loc) scale_init = tf.compat.v1.constant_initializer(scale) new_loc = add_variable_fn(name=name + '_loc', shape=shape, initializer=loc_init, regularizer=None, constraint=None, dtype=dtype, trainable=isPosterior) new_scale = add_variable_fn(name=name + '_scale', shape=shape, initializer=scale_init, regularizer=None, constraint=None, dtype=dtype, trainable=isPosterior) dist = normal_lib.Normal(loc=new_loc, scale=new_scale) batch_ndims = tf.size(input=dist.batch_shape_tensor()) return independent_lib.Independent( dist, reinterpreted_batch_ndims=batch_ndims)
def testNormalLogCDF(self): batch_size = 50 mu = self._rng.randn(batch_size) sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) cdf = normal.log_cdf(x) self.assertAllEqual( self.evaluate(normal.batch_shape_tensor()), cdf.shape) self.assertAllEqual( self.evaluate(normal.batch_shape_tensor()), self.evaluate(cdf).shape) self.assertAllEqual(normal.batch_shape, cdf.shape) self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) if not stats: return expected_cdf = stats.norm(mu, sigma).logcdf(x) self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
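# The reason the test probes x as low as -100: Normal.log_cdf is evaluated
# with a numerically stable tail expansion, whereas log(cdf(x)) underflows.
# A short sketch of the difference:
import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=0., scale=1.)
print(dist.log_cdf(-50.))            # finite, roughly -1254.8
print(tf.math.log(dist.cdf(-50.)))   # -inf: cdf underflows to 0 in float32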
def _eval(self, x, weights): kernel_dist, bias_dist = self.unpack_weights_fn( # pylint: disable=not-callable self.posterior.sample_distributions(value=weights)[0]) kernel_loc, kernel_scale = vi_lib.get_spherical_normal_loc_scale( kernel_dist) loc = tf.matmul(x, kernel_loc) scale = tf.sqrt(tf.matmul(tf.square(x), tf.square(kernel_scale))) _, sampled_bias = self.unpack_weights_fn(weights) # pylint: disable=not-callable if sampled_bias is not None: try: bias_loc, bias_scale = vi_lib.get_spherical_normal_loc_scale( bias_dist) is_bias_spherical_normal = True except TypeError: is_bias_spherical_normal = False if is_bias_spherical_normal: loc = loc + bias_loc scale = tf.sqrt(tf.square(scale) + tf.square(bias_scale)) else: loc = loc + sampled_bias y = normal_lib.Normal(loc=loc, scale=scale).sample(seed=self._seed()) return y
def __init__(self, loc=None, scale=None, validate_args=False, allow_nan_stats=True, name="LogNormal"): """Construct a log-normal distribution. The LogNormal distribution models positive-valued random variables whose logarithm is normally distributed with mean `loc` and standard deviation `scale`. It is constructed as the exponential transformation of a Normal distribution. Args: loc: Floating-point `Tensor`; the means of the underlying Normal distribution(s). scale: Floating-point `Tensor`; the stddevs of the underlying Normal distribution(s). validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. """ parameters = dict(locals()) with tf.name_scope(name) as name: dtype = dtype_util.common_dtype([loc, scale], tf.float32) super(LogNormal, self).__init__(distribution=normal.Normal( loc=tf.convert_to_tensor(value=loc, name="loc", dtype=dtype), scale=tf.convert_to_tensor(value=scale, name="scale", dtype=dtype)), bijector=exp_bijector.Exp(), validate_args=validate_args, parameters=parameters, name=name)
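# Usage sketch for the resulting distribution (assumes the public name
# tfp.distributions.LogNormal): the mean of exp(Normal(loc, scale**2)) is
# exp(loc + scale**2 / 2).
import numpy as np
import tensorflow_probability as tfp

ln = tfp.distributions.LogNormal(loc=0.5, scale=0.25)
np.testing.assert_allclose(ln.mean(), np.exp(0.5 + 0.25**2 / 2.), rtol=1e-5)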
def params_and_state_transition_fn(step, params_and_state, perturbation_scale, **kwargs): """Transition function operating on a `ParamsAndState` namedtuple.""" # Extract the state, to pass through to the observation fn. unconstrained_params, state = params_and_state if 'state_history' in kwargs: kwargs['state_history'] = kwargs['state_history'].state # Perturb each (unconstrained) parameter with normally-distributed noise. if not tf.nest.is_nested(perturbation_scale): perturbation_scale = tf.nest.map_structure( lambda x: tf.convert_to_tensor(perturbation_scale, # pylint: disable=g-long-lambda name='perturbation_scale', dtype=x.dtype), unconstrained_params) perturbed_unconstrained_parameter_dists = tf.nest.map_structure( lambda x, p, s: independent.Independent( # pylint: disable=g-long-lambda normal.Normal(loc=x, scale=p), reinterpreted_batch_ndims=prefer_static.rank_from_shape(s)), unconstrained_params, perturbation_scale, parameter_prior.event_shape_tensor()) # For the joint transition, pass the perturbed parameters # into the original transition fn (after pushing them into constrained # space). return joint_distribution_named.JointDistributionNamed( ParametersAndState( unconstrained_parameters=_maybe_build_joint_distribution( perturbed_unconstrained_parameter_dists), state=lambda unconstrained_parameters: ( # pylint: disable=g-long-lambda parameterized_transition_fn( step, state, parameters=parameter_constraining_bijector.forward( unconstrained_parameters), **kwargs))))
def default_multivariate_normal_fn(dtype, shape, name, trainable, add_variable_fn): """Creates multivariate standard `Normal` distribution. Args: dtype: Type of parameter's event. shape: Python `list`-like representing the parameter's event shape. name: Python `str` name prepended to any created (or existing) `tf.Variable`s. trainable: Python `bool` indicating all created `tf.Variable`s should be added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. add_variable_fn: `tf.get_variable`-like `callable` used to create (or access existing) `tf.Variable`s. Returns: Multivariate standard `Normal` distribution. """ del name, trainable, add_variable_fn # unused dist = normal_lib.Normal(loc=tf.zeros(shape, dtype), scale=dtype.as_numpy_dtype(1)) batch_ndims = tf.size(input=dist.batch_shape_tensor()) return independent_lib.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
def testNormalQuantile(self): batch_size = 52 mu = self._rng.randn(batch_size) sigma = self._rng.rand(batch_size) + 1.0 p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) # Quantile performs piecewise rational approximation so adding some # special input values to make sure we hit all the pieces. p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) normal = normal_lib.Normal(loc=mu, scale=sigma) x = normal.quantile(p) self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), x.shape) self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), self.evaluate(x).shape) self.assertAllEqual(normal.batch_shape, x.shape) self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) if not stats: return expected_x = stats.norm(mu, sigma).ppf(p) self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
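# quantile is the inverse CDF, so cdf(quantile(p)) should recover p; a small
# round-trip sketch of the property the test relies on:
import numpy as np
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=1., scale=3.)
p = np.array([0.05, 0.5, 0.95], dtype=np.float32)
np.testing.assert_allclose(dist.cdf(dist.quantile(p)), p, atol=1e-5)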
def __init__(self, loc, scale, validate_args=False, allow_nan_stats=True, name='LogitNormal'): """Construct a logit-normal distribution. The LogitNormal distribution models random variables taking values in (0, 1) whose logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally distributed with mean `loc` and standard deviation `scale`. It is constructed as the sigmoid transformation (i.e., `1 / (1 + exp(-x))`) of a Normal distribution. Args: loc: Floating-point `Tensor`; the mean of the underlying Normal distribution(s). Must broadcast with `scale`. scale: Floating-point `Tensor`; the stddev of the underlying Normal distribution(s). Must broadcast with `loc`. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. """ parameters = dict(locals()) with tf.name_scope(name) as name: super(LogitNormal, self).__init__(distribution=normal.Normal(loc=loc, scale=scale), bijector=sigmoid_bijector.Sigmoid(), validate_args=validate_args, parameters=parameters, name=name)
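# Usage sketch (assumes the public name tfp.distributions.LogitNormal). The
# mean has no closed form, so it is typically estimated by sampling:
import tensorflow as tf
import tensorflow_probability as tfp

ln = tfp.distributions.LogitNormal(loc=0., scale=1.)
samples = ln.sample(10000, seed=7)        # values lie in (0, 1)
mean_estimate = tf.reduce_mean(samples)   # roughly 0.5 by symmetry of loc=0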
def __call__(self, x, **kwargs): x = tf.convert_to_tensor(x, dtype=self.dtype, name='x') self._posterior_value = self.posterior_value_fn(self.posterior, seed=self._seed()) # pylint: disable=not-callable kernel, bias = self.unpack_weights_fn(self.posterior_value) # pylint: disable=not-callable y = x if kernel is not None: kernel_dist, _ = self.unpack_weights_fn( # pylint: disable=not-callable self.posterior.sample_distributions( value=self.posterior_value)[0]) kernel_loc, kernel_scale = get_spherical_normal_loc_scale( kernel_dist) # batch_size = tf.shape(x)[0] # sign_input_shape = ([batch_size] + # [1] * self._rank + # [self._input_channels]) y *= tfp_random.rademacher(ps.shape(y), dtype=y.dtype, seed=self._seed()) kernel_perturb = normal_lib.Normal(loc=0., scale=kernel_scale) y = self._apply_kernel_fn( # E.g., tf.matmul. y, kernel_perturb.sample(seed=self._seed())) y *= tfp_random.rademacher(ps.shape(y), dtype=y.dtype, seed=self._seed()) y += self._apply_kernel_fn(x, kernel_loc) if bias is not None: y = y + bias if self.activation_fn is not None: y = self.activation_fn(y) # pylint: disable=not-callable return y
def quadrature_scheme_softmaxnormal_quantiles(normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex. A `SoftmaxNormal` random variable `Y` may be generated via ``` Y = SoftmaxCentered(X), X = Normal(normal_loc, normal_scale) ``` Args: normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. The location parameter of the Normal used to construct the SoftmaxNormal. normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`. The scale parameter of the Normal used to construct the SoftmaxNormal. quadrature_size: Python `int` scalar representing the number of quadrature points. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. name: Python `str` name prefixed to Ops created by this class. Returns: grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the convex combination of affine parameters for `K` components. `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the associated with each grid point. """ with tf.name_scope(name or "softmax_normal_grid_and_probs"): normal_loc = tf.convert_to_tensor(value=normal_loc, name="normal_loc") dt = dtype_util.base_dtype(normal_loc.dtype) normal_scale = tf.convert_to_tensor(value=normal_scale, dtype=dt, name="normal_scale") normal_scale = maybe_check_quadrature_param(normal_scale, "normal_scale", validate_args) dist = normal.Normal(loc=normal_loc, scale=normal_scale) def _get_batch_ndims(): """Helper to get rank(dist.batch_shape), statically if possible.""" ndims = tensorshape_util.rank(dist.batch_shape) if ndims is None: ndims = tf.shape(input=dist.batch_shape_tensor())[0] return ndims batch_ndims = _get_batch_ndims() def _get_final_shape(qs): """Helper to build `TensorShape`.""" bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1) num_components = tf.compat.dimension_value(bs[-1]) if num_components is not None: num_components += 1 tail = tf.TensorShape([num_components, qs]) return bs[:-1].concatenate(tail) def _compute_quantiles(): """Helper to build quantiles.""" # Omit {0, 1} since they might lead to Inf/NaN. zero = tf.zeros([], dtype=dist.dtype) edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1] # Expand edges so its broadcast across batch dims. edges = tf.reshape( edges, shape=tf.concat( [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0)) quantiles = dist.quantile(edges) quantiles = softmax_centered_bijector.SoftmaxCentered().forward( quantiles) # Cyclically permute left by one. perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0) quantiles = tf.transpose(a=quantiles, perm=perm) tensorshape_util.set_shape(quantiles, _get_final_shape(quadrature_size + 1)) return quantiles quantiles = _compute_quantiles() # Compute grid as quantile midpoints. grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. # Set shape hints. tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size)) # By construction probs is constant, i.e., `1 / quadrature_size`. This is # important, because non-constant probs leads to non-reparameterizable # samples. probs = tf.fill(dims=[quadrature_size], value=1. / tf.cast(quadrature_size, dist.dtype)) return grid, probs
def testSampleLikeArgsGetDistDType(self): dist = normal_lib.Normal(0., 1.) self.assertEqual(tf.float32, dist.dtype) for method in ("log_prob", "prob", "log_cdf", "cdf", "log_survival_function", "survival_function", "quantile"): self.assertEqual(tf.float32, getattr(dist, method)(1).dtype)
def _batched_isotropic_normal_like(state_part): event_ndims = ps.rank(state_part) - batch_rank return independent.Independent( normal.Normal(ps.zeros_like(state_part, tf.float32), 1.), reinterpreted_batch_ndims=event_ndims)
def _get_cdf_pdf(c): dtype = dtype_util.as_numpy_dtype(c.dtype) d = normal_lib.Normal(dtype(0), 1) return d.cdf, d.prob
def __init__(self, loc, scale, num_probit_terms_approx=2, validate_args=False, allow_nan_stats=True, name='LogitNormal'): """Construct a logit-normal distribution. The LogitNormal distribution models random variables between 0 and 1 whose logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally distributed with mean `loc` and standard deviation `scale`. It is constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a Normal distribution. Args: loc: Floating-point `Tensor`; the mean of the underlying Normal distribution(s). Must broadcast with `scale`. scale: Floating-point `Tensor`; the stddev of the underlying Normal distribution(s). Must broadcast with `loc`. num_probit_terms_approx: The `k` used in the approximation, `sigmoid(x) approx= sum_i^k p[k,i] Normal(0, c[k, i]).cdf(x)` where `sum_i^k p[k,i]=1` and `p[k,i],c[k,i] > 0` [(Monahan and Stefanski, 1989)][1] and used in `mean_*_approx` functions [(Owen, 1980)][2]. Must be a python scalar integer between `1` and `8` (inclusive). Using `num_probit_terms_approx=2` should result in `mean_approx` error not exceeding `10**-4`. Default value: `2`. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. #### References [1]: Monahan, John H., and Leonard A. Stefanski. Normal scale mixture approximations to the logistic distribution with applications. North Carolina State University. Dept. of Statistics, 1989. http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.154.5032 [2]: Owen, Donald Bruce. "A table of normal integrals: A table." Communications in Statistics-Simulation and Computation 9.4 (1980): 389-419. https://www.tandfonline.com/doi/abs/10.1080/03610918008812164 """ parameters = dict(locals()) num_probit_terms_approx = int(num_probit_terms_approx) if num_probit_terms_approx < 1 or num_probit_terms_approx > 8: raise ValueError( 'Argument `num_probit_terms_approx` must be an integer between ' '`1` and `8` (inclusive).') self._num_probit_terms_approx = num_probit_terms_approx with tf.name_scope(name) as name: super(LogitNormal, self).__init__(distribution=normal_lib.Normal(loc=loc, scale=scale), bijector=sigmoid_bijector.Sigmoid(), validate_args=validate_args, parameters=parameters, name=name)
def __init__(self, loc=None, scale=None, validate_args=False, allow_nan_stats=True, name="MultivariateNormalLinearOperator"): """Construct Multivariate Normal distribution on `R^k`. The `batch_shape` is the broadcast shape between `loc` and `scale` arguments. The `event_shape` is given by last dimension of the matrix implied by `scale`. The last dimension of `loc` (if provided) must broadcast with this. Recall that `covariance = scale @ scale.T`. Additional leading dimensions (if any) will index batches. Args: loc: Floating-point `Tensor`. If this is set to `None`, `loc` is implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size. scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape `[B1, ..., Bb, k, k]`. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not dtype_util.is_floating(scale.dtype): raise TypeError( "`scale` parameter must have floating-point dtype.") with tf.name_scope(name) as name: # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. loc = loc if loc is None else tf.convert_to_tensor( loc, name="loc", dtype=scale.dtype) batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( loc, scale) super(MultivariateNormalLinearOperator, self).__init__( distribution=normal.Normal(loc=tf.zeros([], dtype=scale.dtype), scale=tf.ones([], dtype=scale.dtype)), bijector=affine_linear_operator_bijector.AffineLinearOperator( shift=loc, scale=scale, validate_args=validate_args), batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args, name=name) self._parameters = parameters
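# Construction sketch for the class above: a full-covariance MVN via a
# lower-triangular LinearOperator scale, with covariance = scale @ scale^T.
import tensorflow as tf
import tensorflow_probability as tfp

scale = tf.linalg.LinearOperatorLowerTriangular([[1.0, 0.0], [0.5, 2.0]])
mvn = tfp.distributions.MultivariateNormalLinearOperator(
    loc=[1., -1.], scale=scale)
print(mvn.event_shape)    # [2]
print(mvn.covariance())   # [[1.0, 0.5], [0.5, 4.25]]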
def __init__(self, skewness, tailweight, loc, scale, validate_args=False, allow_nan_stats=True, name=None): """Construct Johnson's SU distributions. The distributions have shape parameters `tailweight` and `skewness`, mean `loc`, and scale `scale`. The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped in a way that supports broadcasting (e.g. `skewness + tailweight + loc + scale` is a valid operation). Args: skewness: Floating-point `Tensor`. Skewness of the distribution(s). tailweight: Floating-point `Tensor`. Tail weight of the distribution(s). `tailweight` must contain only positive values. loc: Floating-point `Tensor`. The mean(s) of the distribution(s). scale: Floating-point `Tensor`. The scaling factor(s) for the distribution(s). Note that `scale` is not technically the standard deviation of this distribution but has semantics more similar to standard deviation than variance. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value '`NaN`' to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. Raises: TypeError: if any of skewness, tailweight, loc and scale are different dtypes. """ parameters = dict(locals()) with tf.name_scope(name or 'JohnsonSU') as name: dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale], tf.float32) self._skewness = tensor_util.convert_nonref_to_tensor( skewness, name='skewness', dtype=dtype) self._tailweight = tensor_util.convert_nonref_to_tensor( tailweight, name='tailweight', dtype=dtype) self._loc = tensor_util.convert_nonref_to_tensor(loc, name='loc', dtype=dtype) self._scale = tensor_util.convert_nonref_to_tensor(scale, name='scale', dtype=dtype) norm_shift = invert_bijector.Invert( shift_bijector.Shift(shift=self._skewness, validate_args=validate_args)) norm_scale = invert_bijector.Invert( scale_bijector.Scale(scale=self._tailweight, validate_args=validate_args)) sinh = sinh_bijector.Sinh(validate_args=validate_args) scale = scale_bijector.Scale(scale=self._scale, validate_args=validate_args) shift = shift_bijector.Shift(shift=self._loc, validate_args=validate_args) bijector = shift(scale(sinh(norm_scale(norm_shift)))) batch_rank = ps.reduce_max([ distribution_util.prefer_static_rank(x) for x in (self._skewness, self._tailweight, self._loc, self._scale) ]) super(JohnsonSU, self).__init__( # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden # `batch_shape` and `batch_shape_tensor` when # TransformedDistribution's bijector can modify its `batch_shape`. distribution=normal.Normal(loc=tf.zeros(ps.ones( batch_rank, tf.int32), dtype=dtype), scale=tf.ones([], dtype=dtype), validate_args=validate_args, allow_nan_stats=allow_nan_stats), bijector=bijector, validate_args=validate_args, parameters=parameters, name=name)
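# Usage sketch (assumes the public name tfp.distributions.JohnsonSU). Per the
# bijector chain above, a draw satisfies
# Y = loc + scale * sinh((Z - skewness) / tailweight) with Z ~ Normal(0, 1),
# so inverting that map on samples should recover roughly standard-normal
# statistics:
import tensorflow as tf
import tensorflow_probability as tfp

jsu = tfp.distributions.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
y = jsu.sample(20000, seed=3)
z = 2. * tf.math.asinh((y - 1.1) / 1.5) + (-2.)  # tailweight * asinh(.) + skewness
# tf.reduce_mean(z) ~ 0. and tf.math.reduce_std(z) ~ 1. up to sampling error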
def __init__(self, loc, scale, skewness=None, tailweight=None, distribution=None, validate_args=False, allow_nan_stats=True, name="SinhArcsinh"): """Construct SinhArcsinh distribution on `(-inf, inf)`. Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape (indexing batch dimensions). They must all have the same `dtype`. Args: loc: Floating-point `Tensor`. scale: `Tensor` of same `dtype` as `loc`. skewness: Skewness parameter. Default is `0.0` (no skew). tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. Default is `tfd.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through a `SinhArcsinh` sample and `distribution` is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then the gradient will be incorrect! validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ parameters = dict(locals()) with tf.compat.v2.name_scope(name) as name: dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight], tf.float32) loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype) scale = tf.convert_to_tensor(value=scale, name="scale", dtype=dtype) tailweight = 1. if tailweight is None else tailweight has_default_skewness = skewness is None skewness = 0. if skewness is None else skewness tailweight = tf.convert_to_tensor(value=tailweight, name="tailweight", dtype=dtype) skewness = tf.convert_to_tensor(value=skewness, name="skewness", dtype=dtype) batch_shape = distribution_util.get_broadcast_shape( loc, scale, tailweight, skewness) # Recall, with Z a random variable, # Y := loc + C * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) # C := 2 * scale / F_0(2) if distribution is None: distribution = normal.Normal(loc=tf.zeros([], dtype=dtype), scale=tf.ones([], dtype=dtype), allow_nan_stats=allow_nan_stats) else: asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) if asserts: loc = distribution_util.with_dependencies(asserts, loc) # Make the SAS bijector, 'F'. 
f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = sinh_arcsinh_bijector.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), tailweight=tailweight) # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) c = 2 * scale / f_noskew.forward( tf.convert_to_tensor(value=2, dtype=dtype)) affine = affine_scalar_bijector.AffineScalar( shift=loc, scale=c, validate_args=validate_args) bijector = chain_bijector.Chain([affine, f]) super(SinhArcsinh, self).__init__(distribution=distribution, bijector=bijector, batch_shape=batch_shape, validate_args=validate_args, name=name) self._parameters = parameters self._loc = loc self._scale = scale self._tailweight = tailweight self._skewness = skewness
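# Usage sketch (assumes the public name tfp.distributions.SinhArcsinh). With
# the default skewness=0 and tailweight=1 the transform F is the identity, so
# the result coincides with Normal(loc, scale); other values reshape skew and
# tail weight while keeping the loc/scale semantics:
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
sas = tfd.SinhArcsinh(loc=2., scale=0.5)   # defaults: skewness=0, tailweight=1
np.testing.assert_allclose(sas.log_prob(2.3),
                           tfd.Normal(2., 0.5).log_prob(2.3), rtol=1e-5)
heavy = tfd.SinhArcsinh(loc=2., scale=0.5, skewness=0.4, tailweight=1.5)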
def extended_kalman_filter_one_step( state, observation, transition_fn, observation_fn, transition_jacobian_fn, observation_jacobian_fn, name=None): """A single step of the EKF. Args: state: A `Tensor` of shape `concat([[num_timesteps, b1, ..., bN], [state_size]])` with scalar `event_size` and optional batch dimensions `b1, ..., bN`. observation: A `Tensor` of shape `concat([[num_timesteps, b1, ..., bN], [event_size]])` with scalar `event_size` and optional batch dimensions `b1, ..., bN`. transition_fn: a Python `callable` that accepts (batched) vectors of length `state_size`, and returns a `tfd.Distribution` instance, typically a `MultivariateNormal`, representing the state transition and covariance. observation_fn: a Python `callable` that accepts a (batched) vector of length `state_size` and returns a `tfd.Distribution` instance, typically a `MultivariateNormal` representing the observation model and covariance. transition_jacobian_fn: a Python `callable` that accepts a (batched) vector of length `state_size` and returns a (batched) matrix of shape `[state_size, state_size]`, representing the Jacobian of `transition_fn`. observation_jacobian_fn: a Python `callable` that accepts a (batched) vector of length `state_size` and returns a (batched) matrix of size `[state_size, event_size]`, representing the Jacobian of `observation_fn`. name: Python `str` name for ops created by this method. Default value: `None` (i.e., `'extended_kalman_filter_one_step'`). Returns: updated_state: `KalmanFilterState` object containing the updated state estimate. """ with tf.name_scope(name or 'extended_kalman_filter_one_step') as name: # If observations are scalar, we can avoid some matrix ops. observation_size_is_static_and_scalar = (observation.shape[-1] == 1) current_state = state.filtered_mean current_covariance = state.filtered_cov current_jacobian = transition_jacobian_fn(current_state) state_prior = transition_fn(current_state) predicted_cov = (tf.matmul( current_jacobian, tf.matmul(current_covariance, current_jacobian, transpose_b=True)) + state_prior.covariance()) predicted_mean = state_prior.mean() observation_dist = observation_fn(predicted_mean) observation_mean = observation_dist.mean() observation_cov = observation_dist.covariance() predicted_jacobian = observation_jacobian_fn(predicted_mean) tmp_obs_cov = tf.matmul(predicted_jacobian, predicted_cov) residual_covariance = tf.matmul( predicted_jacobian, tmp_obs_cov, transpose_b=True) + observation_cov if observation_size_is_static_and_scalar: gain_transpose = tmp_obs_cov / residual_covariance else: chol_residual_cov = tf.linalg.cholesky(residual_covariance) gain_transpose = tf.linalg.cholesky_solve(chol_residual_cov, tmp_obs_cov) filtered_mean = predicted_mean + tf.matmul( gain_transpose, (observation - observation_mean)[..., tf.newaxis], transpose_a=True)[..., 0] tmp_term = -tf.matmul(predicted_jacobian, gain_transpose, transpose_a=True) tmp_term = tf.linalg.set_diag(tmp_term, tf.linalg.diag_part(tmp_term) + 1.) filtered_cov = ( tf.matmul( tmp_term, tf.matmul(predicted_cov, tmp_term), transpose_a=True) + tf.matmul(gain_transpose, tf.matmul(observation_cov, gain_transpose), transpose_a=True)) if observation_size_is_static_and_scalar: # A plain Normal would have event shape `[]`; wrapping with Independent # ensures `event_shape=[1]` as required. 
predictive_dist = independent.Independent( normal.Normal(loc=observation_mean, scale=tf.sqrt(residual_covariance[..., 0])), reinterpreted_batch_ndims=1) else: predictive_dist = mvn_tril.MultivariateNormalTriL( loc=observation_mean, scale_tril=chol_residual_cov) log_marginal_likelihood = predictive_dist.log_prob(observation) return linear_gaussian_ssm.KalmanFilterState( filtered_mean=filtered_mean, filtered_cov=filtered_cov, predicted_mean=predicted_mean, predicted_cov=predicted_cov, observation_mean=observation_mean, observation_cov=observation_cov, log_marginal_likelihood=log_marginal_likelihood, timestep=state.timestep + 1)
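# The scalar-observation shortcut used above, in plain numpy (sketch): when
# event_size == 1 the innovation covariance S is a 1x1 matrix, so the Kalman
# gain K = P H^T S^{-1} reduces to an elementwise division instead of a
# Cholesky solve. Values here are illustrative only.
import numpy as np

P = np.array([[2.0, 0.3], [0.3, 1.0]])   # predicted state covariance
H = np.array([[1.0, 0.0]])               # observation Jacobian, shape [1, 2]
R = np.array([[0.5]])                    # observation noise covariance
S = H @ P @ H.T + R                      # innovation covariance, shape [1, 1]
gain_transpose = (H @ P) / S             # equals solve(S, H @ P) since S is 1x1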
def __init__(self, loc, scale, skewness=None, tailweight=None, distribution=None, validate_args=False, allow_nan_stats=True, name='SinhArcsinh'): """Construct SinhArcsinh distribution on `(-inf, inf)`. Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape (indexing batch dimensions). They must all have the same `dtype`. Args: loc: Floating-point `Tensor`. scale: `Tensor` of same `dtype` as `loc`. skewness: Skewness parameter. Default is `0.0` (no skew). tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. Must have a batch shape to which the shapes of `loc`, `scale`, `skewness`, and `tailweight` all broadcast. Default is `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted shape of the parameters. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through a `SinhArcsinh` sample and `distribution` is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then the gradient will be incorrect! validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ parameters = dict(locals()) with tf.name_scope(name) as name: dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight], tf.float32) self._loc = tensor_util.convert_nonref_to_tensor( loc, name='loc', dtype=dtype) self._scale = tensor_util.convert_nonref_to_tensor( scale, name='scale', dtype=dtype) tailweight = 1. if tailweight is None else tailweight has_default_skewness = skewness is None skewness = 0. if has_default_skewness else skewness self._tailweight = tensor_util.convert_nonref_to_tensor( tailweight, name='tailweight', dtype=dtype) self._skewness = tensor_util.convert_nonref_to_tensor( skewness, name='skewness', dtype=dtype) # Recall, with Z a random variable, # Y := loc + scale * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C # C := 2 / F_0(2) # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) if distribution is None: batch_shape = functools.reduce( ps.broadcast_shape, [ps.shape(x) for x in (self._skewness, self._tailweight, self._loc, self._scale)]) distribution = normal.Normal( loc=tf.zeros(batch_shape, dtype=dtype), scale=tf.ones([], dtype=dtype), allow_nan_stats=allow_nan_stats, validate_args=validate_args) # Make the SAS bijector, 'F'. f = sinh_arcsinh_bijector.SinhArcsinh( skewness=self._skewness, tailweight=self._tailweight, validate_args=validate_args) # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) affine = shift_bijector.Shift(shift=self._loc)( scale_bijector.Scale(scale=self._scale)) bijector = chain_bijector.Chain([affine, f]) super(SinhArcsinh, self).__init__( distribution=distribution, bijector=bijector, validate_args=validate_args, name=name) self._parameters = parameters
def quadrature_scheme_lognormal_quantiles(loc, scale, quadrature_size, validate_args=False, name=None): """Use LogNormal quantiles to form quadrature on positive-reals. Args: loc: `float`-like (batch of) scalar `Tensor`; the location parameter of the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. quadrature_size: Python `int` scalar representing the number of quadrature points. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. name: Python `str` name prefixed to Ops created by this class. Returns: grid: (Batch of) length-`quadrature_size` vectors representing the `log_rate` parameters of a `Poisson`. probs: (Batch of) length-`quadrature_size` vectors representing the weight associate with each `grid` value. """ with tf.name_scope(name, "quadrature_scheme_lognormal_quantiles", [loc, scale]): # Create a LogNormal distribution. dist = transformed_distribution.TransformedDistribution( distribution=normal.Normal(loc=loc, scale=scale), bijector=exp_bijector.Exp(), validate_args=validate_args) batch_ndims = dist.batch_shape.ndims if batch_ndims is None: batch_ndims = tf.shape(dist.batch_shape_tensor())[0] def _compute_quantiles(): """Helper to build quantiles.""" # Omit {0, 1} since they might lead to Inf/NaN. zero = tf.zeros([], dtype=dist.dtype) edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1] # Expand edges so its broadcast across batch dims. edges = tf.reshape( edges, shape=tf.concat( [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0)) quantiles = dist.quantile(edges) # Cyclically permute left by one. perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0) quantiles = tf.transpose(quantiles, perm) return quantiles quantiles = _compute_quantiles() # Compute grid as quantile midpoints. grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. # Set shape hints. grid.set_shape(dist.batch_shape.concatenate([quadrature_size])) # By construction probs is constant, i.e., `1 / quadrature_size`. This is # important, because non-constant probs leads to non-reparameterizable # samples. probs = tf.fill(dims=[quadrature_size], value=1. / tf.cast(quadrature_size, dist.dtype)) return grid, probs
def variational_loss(self, observations, observation_index_points=None, kl_weight=1., name='variational_loss'): """Variational loss for the VGP. Given `observations` and `observation_index_points`, compute the negative variational lower bound as specified in [Hensman, 2013][1]. Args: observations: `float` `Tensor` representing collection, or batch of collections, of observations corresponding to `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which must be broadcastable with the batch and example shapes of `observation_index_points`. The batch shape `[b1, ..., bB]` must be broadcastable with the shapes of all other batched parameters (`kernel.batch_shape`, `observation_index_points`, etc.). observation_index_points: `float` `Tensor` representing finite (batch of) vector(s) of points where observations are defined. Shape has the form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature dimensions and must equal `kernel.feature_ndims` and `e1` is the number (size) of index points in each batch (we denote it `e1` to distinguish it from the number of inducing index points, denoted `e2` below). If set to `None` uses `index_points` as the origin for observations. Default value: None. kl_weight: Amount by which to scale the KL divergence loss between prior and posterior. Default value: 1. name: Python `str` name prefixed to Ops created by this class. Default value: 'variational_loss'. Returns: loss: Scalar tensor representing the negative variational lower bound. Can be directly used in a `tf.Optimizer`. Raises: ValueError: if `mean_fn` is not `None` and is not callable. #### References [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013 https://arxiv.org/abs/1309.6835 """ with tf.name_scope(name or 'variational_gp_loss'): if observation_index_points is None: observation_index_points = self._index_points observation_index_points = tf.convert_to_tensor( observation_index_points, dtype=self._dtype, name='observation_index_points') observations = tf.convert_to_tensor(observations, dtype=self._dtype, name='observations') kl_weight = tf.convert_to_tensor(kl_weight, dtype=self._dtype, name='kl_weight') # The variational loss is a negative ELBO. The ELBO can be broken down # into three terms: # 1. a likelihood term # 2. a trace term arising from the covariance of the posterior predictive # 3. a KL divergence term between the variational posterior over inducing # point function values and its prior kzx = self.kernel.matrix(self._inducing_index_points, observation_index_points) kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx) loc = (self._mean_fn(observation_index_points) + kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True)) likelihood = independent.Independent(normal.Normal( loc=loc, scale=tf.sqrt(self._observation_noise_variance + self._jitter), name='NormalLikelihood'), reinterpreted_batch_ndims=1) obs_ll = likelihood.log_prob(observations) chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular( self._chol_kzz) chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx) kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True) kxx_diag = self.kernel.apply(observation_index_points, observation_index_points, example_ndims=1) ktilde_trace_term = ( tf.reduce_sum(kxx_diag, axis=-1) - tf.reduce_sum(chol_kzz_inv_kzx**2, axis=[-2, -1])) # Tr(SB) # where S = A A.T, A = variational_inducing_observations_scale # and B = Kzz^-1 Kzx Kzx.T Kzz^-1 # # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1) # = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A) # = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2 other_trace_term = tf.reduce_sum( (self._variational_inducing_observations_posterior.scale.
matmul(kzz_inv_kzx)**2), axis=[-2, -1]) trace_term = (.5 * (ktilde_trace_term + other_trace_term) / self._observation_noise_variance) kl_term = kl_weight * self.surrogate_posterior_kl_divergence_prior( ) lower_bound = (obs_ll - trace_term - kl_term) return -tf.reduce_mean(lower_bound)
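# Typical training usage of the loss above, as a sketch: minimize the negative
# ELBO with a stochastic optimizer. Here `vgp` is assumed to be an
# already-built tfd.VariationalGaussianProcess with trainable variational
# parameters, and `x_batch`/`y_batch`, `batch_size`, `num_examples` are
# assumed minibatch quantities, not part of the API above.
import tensorflow as tf

optimizer = tf.optimizers.Adam(learning_rate=0.01)

@tf.function
def train_step(x_batch, y_batch):
  with tf.GradientTape() as tape:
    loss = vgp.variational_loss(
        observations=y_batch,
        observation_index_points=x_batch,
        kl_weight=float(batch_size) / num_examples)  # minibatch KL scaling
  grads = tape.gradient(loss, vgp.trainable_variables)
  optimizer.apply_gradients(zip(grads, vgp.trainable_variables))
  return loss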
def __init__(self, loc=None, scale_diag=None, scale_identity_multiplier=None, skewness=None, tailweight=None, distribution=None, validate_args=False, allow_nan_stats=True, name="MultivariateNormalLinearOperator"): """Construct VectorSinhArcsinhDiag distribution on `R^k`. The arguments `scale_diag` and `scale_identity_multiplier` combine to define the diagonal `scale` referred to in this class docstring: ```none scale = diag(scale_diag + scale_identity_multiplier * ones(k)) ``` The `batch_shape` is the broadcast shape between `loc` and `scale` arguments. The `event_shape` is given by last dimension of the matrix implied by `scale`. The last dimension of `loc` (if provided) must broadcast with this Additional leading dimensions (if any) will index batches. Args: loc: Floating-point `Tensor`. If this is set to `None`, `loc` is implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size. scale_diag: Non-zero, floating-point `Tensor` representing a diagonal matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, and characterizes `b`-batches of `k x k` diagonal matrices added to `scale`. When both `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is the `Identity`. scale_identity_multiplier: Non-zero, floating-point `Tensor` representing a scale-identity-matrix added to `scale`. May have shape `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale `k x k` identity matrices added to `scale`. When both `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is the `Identity`. skewness: Skewness parameter. floating-point `Tensor` with shape broadcastable with `event_shape`. tailweight: Tailweight parameter. floating-point `Tensor` with shape broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is `tfd.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through a VectorSinhArcsinhDiag sample and `distribution` is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then the gradient will be incorrect! validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ parameters = dict(locals()) with tf.name_scope( name, values=[ loc, scale_diag, scale_identity_multiplier, skewness, tailweight ]) as name: dtype = dtype_util.common_dtype( [loc, scale_diag, scale_identity_multiplier, skewness, tailweight], tf.float32) loc = loc if loc is None else tf.convert_to_tensor( value=loc, name="loc", dtype=dtype) tailweight = 1. if tailweight is None else tailweight has_default_skewness = skewness is None skewness = 0. 
if skewness is None else skewness # Recall, with Z a random variable, # Y := loc + C * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) # C := 2 * scale / F_0(2) # Construct shapes and 'scale' out of the scale_* and loc kwargs. # scale_linop is only an intermediary to: # 1. get shapes from looking at loc and the two scale args. # 2. combine scale_diag with scale_identity_multiplier, which gives us # 'scale', which in turn gives us 'C'. scale_linop = distribution_util.make_diag_scale( loc=loc, scale_diag=scale_diag, scale_identity_multiplier=scale_identity_multiplier, validate_args=False, assert_positive=False, dtype=dtype) batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( loc, scale_linop) # scale_linop.diag_part() is efficient since it is a diag type linop. scale_diag_part = scale_linop.diag_part() dtype = scale_diag_part.dtype if distribution is None: distribution = normal.Normal( loc=tf.zeros([], dtype=dtype), scale=tf.ones([], dtype=dtype), allow_nan_stats=allow_nan_stats) else: asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) if asserts: scale_diag_part = distribution_util.with_dependencies( asserts, scale_diag_part) # Make the SAS bijector, 'F'. skewness = tf.convert_to_tensor( value=skewness, dtype=dtype, name="skewness") tailweight = tf.convert_to_tensor( value=tailweight, dtype=dtype, name="tailweight") f = sinh_arcsinh_bijector.SinhArcsinh( skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = sinh_arcsinh_bijector.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), tailweight=tailweight) # Make the Affine bijector, Z --> loc + C * Z. c = 2 * scale_diag_part / f_noskew.forward( tf.convert_to_tensor(value=2, dtype=dtype)) affine = affine_bijector.Affine( shift=loc, scale_diag=c, validate_args=validate_args) bijector = chain_bijector.Chain([affine, f]) super(VectorSinhArcsinhDiag, self).__init__( distribution=distribution, bijector=bijector, batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args, name=name) self._parameters = parameters self._loc = loc self._scale = scale_linop self._tailweight = tailweight self._skewness = skewness
def __init__(self, loc=None, scale=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalLinearOperator'): """Construct Multivariate Normal distribution on `R^k`. The `batch_shape` is the broadcast shape between `loc` and `scale` arguments. The `event_shape` is given by last dimension of the matrix implied by `scale`. The last dimension of `loc` (if provided) must broadcast with this. Recall that `covariance = scale @ scale.T`. Additional leading dimensions (if any) will index batches. Args: loc: Floating-point `Tensor`. If this is set to `None`, `loc` is implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size. scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape `[B1, ..., Bb, k, k]`. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan summation to aggregate independent underlying log_prob values. For best results, Kahan summation should also be applied when computing the log-determinant of the `LinearOperator` representing the scale matrix. Kahan summation improves against the precision of a naive float32 sum. This can be noticeable in particular for large dimensions in float32. See CPU caveat on `tfp.math.reduce_kahan_sum`. name: The name to give Ops created by the initializer. Raises: ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ parameters = dict(locals()) self._experimental_use_kahan_sum = experimental_use_kahan_sum if scale is None: raise ValueError('Missing required `scale` parameter.') if not dtype_util.is_floating(scale.dtype): raise TypeError('`scale` parameter must have floating-point dtype.') with tf.name_scope(name) as name: dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32) # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. loc = tensor_util.convert_nonref_to_tensor( loc, dtype=dtype, name='loc') batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( loc, scale) self._loc = loc self._scale = scale bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator( scale, validate_args=validate_args) if loc is not None: bijector = shift_bijector.Shift( shift=loc, validate_args=validate_args)(bijector) super(MultivariateNormalLinearOperator, self).__init__( # TODO(b/137665504): Use batch-adding meta-distribution to set the batch # shape instead of tf.zeros. # We use `Sample` instead of `Independent` because `Independent` # requires concatenating `batch_shape` and `event_shape`, which loses # static `batch_shape` information when `event_shape` is not statically # known. distribution=sample.Sample( normal.Normal( loc=tf.zeros(batch_shape, dtype=dtype), scale=tf.ones([], dtype=dtype)), event_shape, experimental_use_kahan_sum=experimental_use_kahan_sum), bijector=bijector, validate_args=validate_args, name=name) self._parameters = parameters