def get_marginal_distribution(self, index_points=None):
    """Compute the marginal of this GP over function values at `index_points`.

    Args:
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the GP is defined. Shape has the form
        `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e` is the number
        (size) of index points in each batch. Ultimately this distribution
        corresponds to an `e`-dimensional multivariate normal. The batch shape
        must be broadcastable with `kernel.batch_shape` and any batch dims
        yielded by `mean_fn`.

    Returns:
      marginal: a `Normal` or `MultivariateNormalLinearOperator` distribution,
        according to whether `index_points` consists of one or many index
        points, respectively.
    """
    with self._name_and_control_scope('get_marginal_distribution'):
      # TODO(cgs): consider caching the result here, keyed on `index_points`.
      index_points = self._get_index_points(index_points)
      covariance = self._compute_covariance(index_points)
      loc = self._mean_fn(index_points)
      # If we're sure the number of index points is 1, we can just construct a
      # scalar Normal. This has computational benefits and supports things like
      # CDF that aren't otherwise straightforward to provide.
      if self._is_univariate_marginal(index_points):
        scale = tf.sqrt(covariance)
        # `loc` has a trailing 1 in the shape; squeeze it.
        loc = tf.squeeze(loc, axis=-1)
        return normal.Normal(
            loc=loc,
            scale=scale,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
      else:
        return self._marginal_fn(
            loc=loc,
            covariance=covariance,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
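A minimal usage sketch (not part of the source above) of how this method is typically reached through the public `tfd.GaussianProcess` API; the kernel choice and index points below are assumptions for illustration:

```
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

index_points = np.linspace(-1., 1., 5, dtype=np.float64)[..., np.newaxis]  # [5, 1]
gp = tfd.GaussianProcess(kernel=psd_kernels.ExponentiatedQuadratic(),
                         index_points=index_points)
marginal = gp.get_marginal_distribution()  # multivariate normal over the 5 points
print(marginal.event_shape)  # [5]
```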
Example #2
    def testNormalSurvivalFunction(self):
        batch_size = 50
        mu = self._rng.randn(batch_size)
        sigma = self._rng.rand(batch_size) + 1.0
        x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

        normal = normal_lib.Normal(loc=mu, scale=sigma)

        sf = normal.survival_function(x)
        self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()),
                            sf.shape)
        self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()),
                            self.evaluate(sf).shape)
        self.assertAllEqual(normal.batch_shape, sf.shape)
        self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
        if not stats:
            return
        expected_sf = stats.norm(mu, sigma).sf(x)
        self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
Example #3
    def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
        g = tf.Graph()
        with g.as_default():
            mu = tf.Variable(dtype(0.0))
            sigma = tf.Variable(dtype(1.0))
            dist = normal_lib.Normal(loc=mu, scale=sigma)
            p = tf.Variable(
                np.array([
                    0.,
                    np.exp(-32.),
                    np.exp(-2.), 1. - np.exp(-2.), 1. - np.exp(-32.), 1.
                ]).astype(dtype))

            value = dist.quantile(p)
            grads = tf.gradients(value, [mu, p])
            with self.cached_session(graph=g):
                tf.global_variables_initializer().run()
                self.assertAllFinite(grads[0])
                self.assertAllFinite(grads[1])
Example #4
 def _apply_variational_kernel(self, inputs):
     if (not isinstance(self.kernel_posterior, independent_lib.Independent)
             or not isinstance(self.kernel_posterior.distribution,
                               normal_lib.Normal)):
         raise TypeError('`DenseLocalReparameterization` requires '
                         '`kernel_posterior_fn` produce an instance of '
                         '`tfd.Independent(tfd.Normal)` '
                         '(saw: \"{}\").'.format(
                             self.kernel_posterior.name))
     self.kernel_posterior_affine = normal_lib.Normal(
         loc=tf.matmul(inputs, self.kernel_posterior.distribution.loc),
         scale=tf.sqrt(
             tf.matmul(tf.square(inputs),
                       tf.square(
                           self.kernel_posterior.distribution.scale))))
     self.kernel_posterior_affine_tensor = (self.kernel_posterior_tensor_fn(
         self.kernel_posterior_affine))
     self.kernel_posterior_tensor = None
     return self.kernel_posterior_affine_tensor
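The layer above relies on the local-reparameterization identity: if the kernel `W` has independent `Normal(mu, sigma)` entries, then `inputs @ W` is Normal with mean `inputs @ mu` and stddev `sqrt(inputs**2 @ sigma**2)`. A small NumPy sanity check of that identity (illustrative only; shapes are arbitrary):

```
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 3))
mu = rng.normal(size=(3, 2))
sigma = rng.uniform(0.5, 1.5, size=(3, 2))

# Draw many weight matrices W ~ Normal(mu, sigma) and push x through them.
w = rng.normal(mu, sigma, size=(100000, 3, 2))
samples = x @ w  # shape [100000, 1, 2]

print(np.allclose(samples.mean(axis=0), x @ mu, atol=0.05))                       # True
print(np.allclose(samples.std(axis=0), np.sqrt((x ** 2) @ (sigma ** 2)), atol=0.05))  # True
```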
Example #5
    def __init__(self,
                 loc,
                 scale,
                 hinge_softness=1.,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='InverseSoftplusNormal'):
        """Construct an inverse-softplus-normal distribution.

        The InverseSoftplusNormal distribution models positive-valued
        random variables whose inverse-softplus is normally distributed
        with mean `loc` and standard deviation `scale`. It is
        constructed as the softplus transformation of a Normal
        distribution.

        Arguments:
            loc: Floating-point `Tensor`; the means of the underlying
                Normal distribution(s).
            scale: Floating-point `Tensor`; the stddevs of the
                underlying Normal distribution(s).
            hinge_softness: Floating-point `Tensor`; governs the softness
                of the softplus hinge.
            validate_args: Python `bool`, default `False`. Whether to
                validate input with asserts. If `validate_args` is
                `False`, and the inputs are invalid, correct behavior
                is not guaranteed.
            allow_nan_stats: Python `bool`, default `True`. If `False`,
                raise an exception if a statistic (e.g. mean, mode,
                etc...) is undefined for any batch member. If `True`,
                batch members with valid parameters leading to
                undefined statistics will return NaN for this statistic.
            name: The name to give Ops created by the initializer.

        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(InvSoftplusNormal, self).__init__(
                distribution=normal.Normal(loc=loc, scale=scale),
                bijector=Softplus(hinge_softness=hinge_softness),
                validate_args=validate_args,
                parameters=parameters,
                name=name)
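For reference, a hedged sketch of the same construction written directly against TFP's public API (the class above is a thin wrapper around this softplus-transformed-Normal pattern):

```
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Positive-valued distribution whose inverse-softplus is Normal(0, 1).
inv_softplus_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Softplus(hinge_softness=1.))
samples = inv_softplus_normal.sample(3)  # strictly positive
```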
Example #6
  def cdf_func(concentration):
    """A helper function that is passed to _compute_value_and_grad."""
    # z is an "almost Normally distributed" random variable.
    z = ((np.sqrt(2. / np.pi) / tf.math.bessel_i0e(concentration)) *
         tf.sin(.5 * x))

    # This is the correction described in [1] which reduces the error
    # of the Normal approximation.
    z2 = z ** 2
    z3 = z2 * z
    z4 = z2 ** 2
    c = 24. * concentration
    c1 = 56.

    xi = z - z3 / ((c - 2. * z2 - 16.) / 3. -
                   (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)) ** 2

    distrib = normal.Normal(tf.cast(0., dtype), tf.cast(1., dtype))

    return distrib.cdf(xi)
Example #7
 def testFiniteGradientAtDifficultPoints(self):
     for dtype in [np.float32, np.float64]:
         g = tf.Graph()
         with g.as_default():
             mu = tf.Variable(dtype(0.0))
             sigma = tf.Variable(dtype(1.0))
             dist = normal_lib.Normal(loc=mu, scale=sigma)
             x = np.array([-100., -20., -5., 0., 5., 20.,
                           100.]).astype(dtype)
             for func in [
                     dist.cdf, dist.log_cdf, dist.survival_function,
                     dist.log_survival_function, dist.log_prob, dist.prob
             ]:
                 value = func(x)
                 grads = tf.gradients(value, [mu, sigma])
                 with self.session(graph=g):
                     tf.global_variables_initializer().run()
                     self.assertAllFinite(value)
                     self.assertAllFinite(grads[0])
                     self.assertAllFinite(grads[1])
 def _fn(dtype, shape, name, trainable, add_variable_fn):
     loc_init = tf.compat.v1.constant_initializer(loc)
     scale_init = tf.compat.v1.constant_initializer(scale)
     new_loc = add_variable_fn(name=name + '_loc',
                               shape=shape,
                               initializer=loc_init,
                               regularizer=None,
                               constraint=None,
                               dtype=dtype,
                               trainable=isPosterior)
     new_scale = add_variable_fn(name=name + '_scale',
                                 shape=shape,
                                 initializer=scale_init,
                                 regularizer=None,
                                 constraint=None,
                                 dtype=dtype,
                                 trainable=isPosterior)
     dist = normal_lib.Normal(loc=new_loc, scale=new_scale)
     batch_ndims = tf.size(input=dist.batch_shape_tensor())
     return independent_lib.Independent(
         dist, reinterpreted_batch_ndims=batch_ndims)
Example #9
  def testNormalLogCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    cdf = normal.log_cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.shape)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.shape)
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)

    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).logcdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
Example #10
 def _eval(self, x, weights):
     kernel_dist, bias_dist = self.unpack_weights_fn(  # pylint: disable=not-callable
         self.posterior.sample_distributions(value=weights)[0])
     kernel_loc, kernel_scale = vi_lib.get_spherical_normal_loc_scale(
         kernel_dist)
     loc = tf.matmul(x, kernel_loc)
     scale = tf.sqrt(tf.matmul(tf.square(x), tf.square(kernel_scale)))
     _, sampled_bias = self.unpack_weights_fn(weights)  # pylint: disable=not-callable
     if sampled_bias is not None:
         try:
             bias_loc, bias_scale = vi_lib.get_spherical_normal_loc_scale(
                 bias_dist)
             is_bias_spherical_normal = True
         except TypeError:
             is_bias_spherical_normal = False
         if is_bias_spherical_normal:
             loc = loc + bias_loc
             scale = tf.sqrt(tf.square(scale) + tf.square(bias_scale))
         else:
             loc = loc + sampled_bias
     y = normal_lib.Normal(loc=loc, scale=scale).sample(seed=self._seed())
     return y
Example #11
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="LogNormal"):
        """Construct a log-normal distribution.

    The LogNormal distribution models positive-valued random variables
    whose logarithm is normally distributed with mean `loc` and
    standard deviation `scale`. It is constructed as the exponential
    transformation of a Normal distribution.

    Args:
      loc: Floating-point `Tensor`; the means of the underlying
        Normal distribution(s).
      scale: Floating-point `Tensor`; the stddevs of the underlying
        Normal distribution(s).
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], tf.float32)
            super(LogNormal, self).__init__(distribution=normal.Normal(
                loc=tf.convert_to_tensor(value=loc, name="loc", dtype=dtype),
                scale=tf.convert_to_tensor(value=scale,
                                           name="scale",
                                           dtype=dtype)),
                                            bijector=exp_bijector.Exp(),
                                            validate_args=validate_args,
                                            parameters=parameters,
                                            name=name)
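A short usage sketch of the resulting distribution via the public `tfd.LogNormal` alias (the mean identity `exp(loc + scale**2 / 2)` is standard):

```
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.LogNormal(loc=0., scale=1.)
dist.mean()        # exp(0. + 1. / 2) ~= 1.6487
dist.log_prob(1.)  # Normal(0., 1.).log_prob(log(1.)) - log(1.) ~= -0.9189
```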
  def params_and_state_transition_fn(step,
                                     params_and_state,
                                     perturbation_scale,
                                     **kwargs):
    """Transition function operating on a `ParamsAndState` namedtuple."""
    # Extract the state, to pass through to the observation fn.
    unconstrained_params, state = params_and_state
    if 'state_history' in kwargs:
      kwargs['state_history'] = kwargs['state_history'].state

    # Perturb each (unconstrained) parameter with normally-distributed noise.
    if not tf.nest.is_nested(perturbation_scale):
      perturbation_scale = tf.nest.map_structure(
          lambda x: tf.convert_to_tensor(perturbation_scale,  # pylint: disable=g-long-lambda
                                         name='perturbation_scale',
                                         dtype=x.dtype),
          unconstrained_params)
    perturbed_unconstrained_parameter_dists = tf.nest.map_structure(
        lambda x, p, s: independent.Independent(  # pylint: disable=g-long-lambda
            normal.Normal(loc=x, scale=p),
            reinterpreted_batch_ndims=prefer_static.rank_from_shape(s)),
        unconstrained_params,
        perturbation_scale,
        parameter_prior.event_shape_tensor())

    # For the joint transition, pass the perturbed parameters
    # into the original transition fn (after pushing them into constrained
    # space).
    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=_maybe_build_joint_distribution(
                perturbed_unconstrained_parameter_dists),
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_transition_fn(
                    step,
                    state,
                    parameters=parameter_constraining_bijector.forward(
                        unconstrained_parameters),
                    **kwargs))))
Example #13
def default_multivariate_normal_fn(dtype, shape, name, trainable,
                                   add_variable_fn):
    """Creates multivariate standard `Normal` distribution.

  Args:
    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` name prepended to any created (or existing)
      `tf.Variable`s.
    trainable: Python `bool` indicating all created `tf.Variable`s should be
      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  Returns:
    Multivariate standard `Normal` distribution.
  """
    del name, trainable, add_variable_fn  # unused
    dist = normal_lib.Normal(loc=tf.zeros(shape, dtype),
                             scale=dtype.as_numpy_dtype(1))
    batch_ndims = tf.size(input=dist.batch_shape_tensor())
    return independent_lib.Independent(dist,
                                       reinterpreted_batch_ndims=batch_ndims)
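Factories with this signature are typically passed as priors to the Bayesian Keras-style layers; a hedged usage sketch (layer and argument names as in `tfp.layers`, the input shape is an assumption):

```
import tensorflow as tf
import tensorflow_probability as tfp

# Standard-Normal prior over the kernel of a Bayesian dense layer.
layer = tfp.layers.DenseReparameterization(
    units=4, kernel_prior_fn=tfp.layers.default_multivariate_normal_fn)
outputs = layer(tf.random.normal([8, 16]))  # batch of 8 examples, 16 features
```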
Example #14
    def testNormalQuantile(self):
        batch_size = 52
        mu = self._rng.randn(batch_size)
        sigma = self._rng.rand(batch_size) + 1.0
        p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
        # Quantile performs a piecewise rational approximation, so we add some
        # special input values to make sure we hit all the pieces.
        p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))

        normal = normal_lib.Normal(loc=mu, scale=sigma)
        x = normal.quantile(p)

        self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()),
                            x.shape)
        self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()),
                            self.evaluate(x).shape)
        self.assertAllEqual(normal.batch_shape, x.shape)
        self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)

        if not stats:
            return
        expected_x = stats.norm(mu, sigma).ppf(p)
        self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
Example #15
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='LogitNormal'):
        """Construct a logit-normal distribution.

    The LogitNormal distribution models random variables between 0 and 1 whose
    logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
    distributed with mean `loc` and standard deviation `scale`. It is
    constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a
    Normal distribution.

    Args:
      loc: Floating-point `Tensor`; the mean of the underlying
        Normal distribution(s). Must broadcast with `scale`.
      scale: Floating-point `Tensor`; the stddev of the underlying
        Normal distribution(s). Must broadcast with `loc`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(LogitNormal,
                  self).__init__(distribution=normal.Normal(loc=loc,
                                                            scale=scale),
                                 bijector=sigmoid_bijector.Sigmoid(),
                                 validate_args=validate_args,
                                 parameters=parameters,
                                 name=name)
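A short usage sketch via the public `tfd.LogitNormal`; samples live in `(0, 1)`:

```
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.LogitNormal(loc=0., scale=1.)
samples = dist.sample(5)  # values strictly between 0 and 1
dist.log_prob(0.5)        # logit(0.5) = 0., the center of the underlying Normal
```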
Example #16
    def __call__(self, x, **kwargs):
        x = tf.convert_to_tensor(x, dtype=self.dtype, name='x')
        self._posterior_value = self.posterior_value_fn(self.posterior,
                                                        seed=self._seed())  # pylint: disable=not-callable
        kernel, bias = self.unpack_weights_fn(self.posterior_value)  # pylint: disable=not-callable
        y = x

        if kernel is not None:
            kernel_dist, _ = self.unpack_weights_fn(  # pylint: disable=not-callable
                self.posterior.sample_distributions(
                    value=self.posterior_value)[0])
            kernel_loc, kernel_scale = get_spherical_normal_loc_scale(
                kernel_dist)

            # batch_size = tf.shape(x)[0]
            # sign_input_shape = ([batch_size] +
            #                     [1] * self._rank +
            #                     [self._input_channels])
            y *= tfp_random.rademacher(ps.shape(y),
                                       dtype=y.dtype,
                                       seed=self._seed())
            kernel_perturb = normal_lib.Normal(loc=0., scale=kernel_scale)
            y = self._apply_kernel_fn(  # E.g., tf.matmul.
                y, kernel_perturb.sample(seed=self._seed()))
            y *= tfp_random.rademacher(ps.shape(y),
                                       dtype=y.dtype,
                                       seed=self._seed())
            y += self._apply_kernel_fn(x, kernel_loc)

        if bias is not None:
            y = y + bias

        if self.activation_fn is not None:
            y = self.activation_fn(y)  # pylint: disable=not-callable

        return y
def quadrature_scheme_softmaxnormal_quantiles(normal_loc,
                                              normal_scale,
                                              quadrature_size,
                                              validate_args=False,
                                              name=None):
    """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs:  Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      probabilities associated with each grid point.
  """
    with tf.name_scope(name or "softmax_normal_grid_and_probs"):
        normal_loc = tf.convert_to_tensor(value=normal_loc, name="normal_loc")
        dt = dtype_util.base_dtype(normal_loc.dtype)
        normal_scale = tf.convert_to_tensor(value=normal_scale,
                                            dtype=dt,
                                            name="normal_scale")

        normal_scale = maybe_check_quadrature_param(normal_scale,
                                                    "normal_scale",
                                                    validate_args)

        dist = normal.Normal(loc=normal_loc, scale=normal_scale)

        def _get_batch_ndims():
            """Helper to get rank(dist.batch_shape), statically if possible."""
            ndims = tensorshape_util.rank(dist.batch_shape)
            if ndims is None:
                ndims = tf.shape(input=dist.batch_shape_tensor())[0]
            return ndims

        batch_ndims = _get_batch_ndims()

        def _get_final_shape(qs):
            """Helper to build `TensorShape`."""
            bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
            num_components = tf.compat.dimension_value(bs[-1])
            if num_components is not None:
                num_components += 1
            tail = tf.TensorShape([num_components, qs])
            return bs[:-1].concatenate(tail)

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
                quantiles)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(a=quantiles, perm=perm)
            tensorshape_util.set_shape(quantiles,
                                       _get_final_shape(quadrature_size + 1))
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
Example #18
 def testSampleLikeArgsGetDistDType(self):
   dist = normal_lib.Normal(0., 1.)
   self.assertEqual(tf.float32, dist.dtype)
   for method in ("log_prob", "prob", "log_cdf", "cdf",
                  "log_survival_function", "survival_function", "quantile"):
     self.assertEqual(tf.float32, getattr(dist, method)(1).dtype)
 def _batched_isotropic_normal_like(state_part):
     event_ndims = ps.rank(state_part) - batch_rank
     return independent.Independent(
         normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
         reinterpreted_batch_ndims=event_ndims)
Example #20
def _get_cdf_pdf(c):
    dtype = dtype_util.as_numpy_dtype(c.dtype)
    d = normal_lib.Normal(dtype(0), 1)
    return d.cdf, d.prob
Example #21
    def __init__(self,
                 loc,
                 scale,
                 num_probit_terms_approx=2,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='LogitNormal'):
        """Construct a logit-normal distribution.

    The LogitNormal distribution models random variables between 0 and 1 whose
    logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
    distributed with mean `loc` and standard deviation `scale`. It is
    constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a
    Normal distribution.

    Args:
      loc: Floating-point `Tensor`; the mean of the underlying
        Normal distribution(s). Must broadcast with `scale`.
      scale: Floating-point `Tensor`; the stddev of the underlying
        Normal distribution(s). Must broadcast with `loc`.
      num_probit_terms_approx: The `k` used in the approximation,
        `sigmoid(x) approx= sum_i^k p[k,i] Normal(0, c[k, i]).cdf(x)`
        where `sum_i^k p[k,i]=1` and `p[k,i],c[k,i] > 0`
        [(Monahan and Stefanski, 1989)][1] and used in `mean_*_approx` functions
        [(Owen, 1980)][2]. Must be a python scalar integer between `1` and `8`
        (inclusive). Using `num_probit_terms_approx=2` should result in
        `mean_approx` error not exceeding `10**-4`.
        Default value: `2`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    #### References

    [1]: Monahan, John H., and Leonard A. Stefanski. Normal scale mixture
         approximations to the logistic distribution with applications. North
         Carolina State University. Dept. of Statistics, 1989.
         http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.154.5032
    [2]: Owen, Donald Bruce. "A table of normal integrals: A table."
         Communications in Statistics-Simulation and Computation 9.4 (1980):
         389-419.
         https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
    """
        parameters = dict(locals())
        num_probit_terms_approx = int(num_probit_terms_approx)
        if num_probit_terms_approx < 1 or num_probit_terms_approx > 8:
            raise ValueError(
                'Argument `num_probit_terms_approx` must be an integer between '
                '`1` and `8` (inclusive).')
        self._num_probit_terms_approx = num_probit_terms_approx
        with tf.name_scope(name) as name:
            super(LogitNormal,
                  self).__init__(distribution=normal_lib.Normal(loc=loc,
                                                                scale=scale),
                                 bijector=sigmoid_bijector.Sigmoid(),
                                 validate_args=validate_args,
                                 parameters=parameters,
                                 name=name)
Example #22
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalLinearOperator"):
        """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if scale is None:
            raise ValueError("Missing required `scale` parameter.")
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                "`scale` parameter must have floating-point dtype.")

        with tf.name_scope(name) as name:
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = loc if loc is None else tf.convert_to_tensor(
                loc, name="loc", dtype=scale.dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)

        super(MultivariateNormalLinearOperator, self).__init__(
            distribution=normal.Normal(loc=tf.zeros([], dtype=scale.dtype),
                                       scale=tf.ones([], dtype=scale.dtype)),
            bijector=affine_linear_operator_bijector.AffineLinearOperator(
                shift=loc, scale=scale, validate_args=validate_args),
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
            name=name)
        self._parameters = parameters
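A minimal construction sketch using a `LinearOperator` scale; recall from the docstring that the covariance is `scale @ scale.T`:

```
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

scale = tf.linalg.LinearOperatorLowerTriangular([[1., 0.], [0.5, 2.]])
mvn = tfd.MultivariateNormalLinearOperator(loc=[1., -1.], scale=scale)
mvn.covariance()  # scale @ scale.T = [[1., 0.5], [0.5, 4.25]]
```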
Example #23
    def __init__(self,
                 skewness,
                 tailweight,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    mean `loc`, and scale `scale`.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of skewness, tailweight, loc and scale are different
        dtypes.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'JohnsonSU') as name:
            dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                            tf.float32)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)

            norm_shift = invert_bijector.Invert(
                shift_bijector.Shift(shift=self._skewness,
                                     validate_args=validate_args))

            norm_scale = invert_bijector.Invert(
                scale_bijector.Scale(scale=self._tailweight,
                                     validate_args=validate_args))

            sinh = sinh_bijector.Sinh(validate_args=validate_args)

            scale = scale_bijector.Scale(scale=self._scale,
                                         validate_args=validate_args)

            shift = shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args)

            bijector = shift(scale(sinh(norm_scale(norm_shift))))

            batch_rank = ps.reduce_max([
                distribution_util.prefer_static_rank(x)
                for x in (self._skewness, self._tailweight, self._loc,
                          self._scale)
            ])

            super(JohnsonSU, self).__init__(
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=normal.Normal(loc=tf.zeros(ps.ones(
                    batch_rank, tf.int32),
                                                        dtype=dtype),
                                           scale=tf.ones([], dtype=dtype),
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats),
                bijector=bijector,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
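A short usage sketch via the public `tfd.JohnsonSU`; the parameter values are arbitrary:

```
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
dist.sample(3)
dist.mean()  # with nonzero skewness the mean differs from `loc`
```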
Example #24
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="SinhArcsinh"):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tfd.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.compat.v2.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale,
                                         name="scale",
                                         dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if skewness is None else skewness
            tailweight = tf.convert_to_tensor(value=tailweight,
                                              name="tailweight",
                                              dtype=dtype)
            skewness = tf.convert_to_tensor(value=skewness,
                                            name="skewness",
                                            dtype=dtype)

            batch_shape = distribution_util.get_broadcast_shape(
                loc, scale, tailweight, skewness)

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)
            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    loc = distribution_util.with_dependencies(asserts, loc)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness,
                                                  tailweight=tailweight)
            if has_default_skewness:
                f_noskew = f
            else:
                f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
                    skewness=skewness.dtype.as_numpy_dtype(0.),
                    tailweight=tailweight)

            # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
            c = 2 * scale / f_noskew.forward(
                tf.convert_to_tensor(value=2, dtype=dtype))
            affine = affine_scalar_bijector.AffineScalar(
                shift=loc, scale=c, validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              batch_shape=batch_shape,
                                              validate_args=validate_args,
                                              name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale
        self._tailweight = tailweight
        self._skewness = skewness
def extended_kalman_filter_one_step(
    state, observation, transition_fn, observation_fn,
    transition_jacobian_fn, observation_jacobian_fn, name=None):
  """A single step of the EKF.

  Args:
    state: A `Tensor` of shape
      `concat([[num_timesteps, b1, ..., bN], [state_size]])` with state
      dimension `state_size` and optional batch dimensions `b1, ..., bN`.
    observation: A `Tensor` of shape
      `concat([[num_timesteps, b1, ..., bN], [event_size]])` with scalar
      `event_size` and optional batch dimensions `b1, ..., bN`.
    transition_fn: a Python `callable` that accepts (batched) vectors of length
      `state_size`, and returns a `tfd.Distribution` instance, typically a
      `MultivariateNormal`, representing the state transition and covariance.
    observation_fn: a Python `callable` that accepts a (batched) vector of
      length `state_size` and returns a `tfd.Distribution` instance, typically
      a `MultivariateNormal` representing the observation model and covariance.
    transition_jacobian_fn: a Python `callable` that accepts a (batched) vector
      of length `state_size` and returns a (batched) matrix of shape
      `[state_size, state_size]`, representing the Jacobian of `transition_fn`.
    observation_jacobian_fn: a Python `callable` that accepts a (batched) vector
      of length `state_size` and returns a (batched) matrix of size
      `[event_size, state_size]`, representing the Jacobian of `observation_fn`.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'extended_kalman_filter_one_step'`).
  Returns:
    updated_state: `KalmanFilterState` object containing the updated state
      estimate.
  """
  with tf.name_scope(name or 'extended_kalman_filter_one_step') as name:

    # If observations are scalar, we can avoid some matrix ops.
    observation_size_is_static_and_scalar = (observation.shape[-1] == 1)

    current_state = state.filtered_mean
    current_covariance = state.filtered_cov
    current_jacobian = transition_jacobian_fn(current_state)
    state_prior = transition_fn(current_state)

    predicted_cov = (tf.matmul(
        current_jacobian,
        tf.matmul(current_covariance, current_jacobian, transpose_b=True)) +
                     state_prior.covariance())
    predicted_mean = state_prior.mean()

    observation_dist = observation_fn(predicted_mean)
    observation_mean = observation_dist.mean()
    observation_cov = observation_dist.covariance()

    predicted_jacobian = observation_jacobian_fn(predicted_mean)
    tmp_obs_cov = tf.matmul(predicted_jacobian, predicted_cov)
    residual_covariance = tf.matmul(
        predicted_jacobian, tmp_obs_cov, transpose_b=True) + observation_cov

    if observation_size_is_static_and_scalar:
      gain_transpose = tmp_obs_cov / residual_covariance
    else:
      chol_residual_cov = tf.linalg.cholesky(residual_covariance)
      gain_transpose = tf.linalg.cholesky_solve(chol_residual_cov, tmp_obs_cov)

    filtered_mean = predicted_mean + tf.matmul(
        gain_transpose,
        (observation - observation_mean)[..., tf.newaxis],
        transpose_a=True)[..., 0]

    tmp_term = -tf.matmul(predicted_jacobian, gain_transpose, transpose_a=True)
    tmp_term = tf.linalg.set_diag(tmp_term, tf.linalg.diag_part(tmp_term) + 1.)
    filtered_cov = (
        tf.matmul(
            tmp_term, tf.matmul(predicted_cov, tmp_term), transpose_a=True) +
        tf.matmul(gain_transpose,
                  tf.matmul(observation_cov, gain_transpose), transpose_a=True))

    if observation_size_is_static_and_scalar:
      # A plain Normal would have event shape `[]`; wrapping with Independent
      # ensures `event_shape=[1]` as required.
      predictive_dist = independent.Independent(
          normal.Normal(loc=observation_mean,
                        scale=tf.sqrt(residual_covariance[..., 0])),
          reinterpreted_batch_ndims=1)

    else:
      predictive_dist = mvn_tril.MultivariateNormalTriL(
          loc=observation_mean,
          scale_tril=chol_residual_cov)

    log_marginal_likelihood = predictive_dist.log_prob(observation)

    return linear_gaussian_ssm.KalmanFilterState(
        filtered_mean=filtered_mean,
        filtered_cov=filtered_cov,
        predicted_mean=predicted_mean,
        predicted_cov=predicted_cov,
        observation_mean=observation_mean,
        observation_cov=observation_cov,
        log_marginal_likelihood=log_marginal_likelihood,
        timestep=state.timestep + 1)
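A hedged sketch of the callables this step expects, for a simple linear-Gaussian model; the matrix `F`, noise scales, and state size are assumptions, not part of the source:

```
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

F = tf.constant([[1., 0.1],
                 [0., 1.]])  # assumed constant-velocity transition matrix

def transition_fn(x):
  # Distribution of the next state given the current (batched) state `x`.
  return tfd.MultivariateNormalDiag(
      loc=tf.linalg.matvec(F, x), scale_diag=[0.1, 0.1])

def transition_jacobian_fn(x):
  # For a linear model the Jacobian is just `F`, broadcast over batch dims of `x`.
  return tf.broadcast_to(F, tf.concat([tf.shape(x)[:-1], [2, 2]], axis=0))
```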
Example #26
  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name='SinhArcsinh'):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                      tf.float32)
      self._loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      self._scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if has_default_skewness else skewness
      self._tailweight = tensor_util.convert_nonref_to_tensor(
          tailweight, name='tailweight', dtype=dtype)
      self._skewness = tensor_util.convert_nonref_to_tensor(
          skewness, name='skewness', dtype=dtype)

      # Recall, with Z a random variable,
      #   Y := loc + scale * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
      #   C := 2 / F_0(2)
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      if distribution is None:
        batch_shape = functools.reduce(
            ps.broadcast_shape,
            [ps.shape(x)
             for x in (self._skewness, self._tailweight,
                       self._loc, self._scale)])

        distribution = normal.Normal(
            loc=tf.zeros(batch_shape, dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats,
            validate_args=validate_args)

      # Make the SAS bijector, 'F'.
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=self._skewness, tailweight=self._tailweight,
          validate_args=validate_args)

      # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
      affine = shift_bijector.Shift(shift=self._loc)(
          scale_bijector.Scale(scale=self._scale))
      bijector = chain_bijector.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
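A short usage sketch via the public `tfd.SinhArcsinh`; with `skewness=0.` and `tailweight=1.` it reduces to `Normal(loc, scale)`:

```
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.2)
dist.sample(3)
heavy = tfd.SinhArcsinh(loc=0., scale=1., tailweight=2.)  # heavier tails than Normal
```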
Example #27
def quadrature_scheme_lognormal_quantiles(loc,
                                          scale,
                                          quadrature_size,
                                          validate_args=False,
                                          name=None):
    """Use LogNormal quantiles to form quadrature on positive-reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: (Batch of) length-`quadrature_size` vectors representing the
      weight associated with each `grid` value.
  """
    with tf.name_scope(name, "quadrature_scheme_lognormal_quantiles",
                       [loc, scale]):
        # Create a LogNormal distribution.
        dist = transformed_distribution.TransformedDistribution(
            distribution=normal.Normal(loc=loc, scale=scale),
            bijector=exp_bijector.Exp(),
            validate_args=validate_args)
        batch_ndims = dist.batch_shape.ndims
        if batch_ndims is None:
            batch_ndims = tf.shape(dist.batch_shape_tensor())[0]

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(quantiles, perm)
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        grid.set_shape(dist.batch_shape.concatenate([quadrature_size]))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
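A hedged NumPy/SciPy sketch of the same quantile-midpoint construction, to make the shapes concrete (scalar `loc`/`scale`, so no batch dimensions):

```
import numpy as np
from scipy import stats

loc, scale, quadrature_size = 0., 1., 7
# Equally spaced probabilities, omitting {0, 1} to avoid infinite quantiles.
edges = np.linspace(0., 1., quadrature_size + 3)[1:-1]
quantiles = stats.lognorm(s=scale, scale=np.exp(loc)).ppf(edges)
grid = (quantiles[:-1] + quantiles[1:]) / 2.              # quantile midpoints
probs = np.full([quadrature_size], 1. / quadrature_size)  # constant weights
```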
    def variational_loss(self,
                         observations,
                         observation_index_points=None,
                         kl_weight=1.,
                         name='variational_loss'):
        """Variational loss for the VGP.

    Given `observations` and `observation_index_points`, compute the
    negative variational lower bound as specified in [Hensman, 2013][1].

    Args:
      observations: `float` `Tensor` representing collection, or batch of
        collections, of observations corresponding to
        `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
        must be broadcastable with the batch and example shapes of
        `observation_index_points`. The batch shape `[b1, ..., bB]` must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `observation_index_points`, etc.).
      observation_index_points: `float` `Tensor` representing finite (batch of)
        vector(s) of points where observations are defined. Shape has the
        form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e1` is the number
        (size) of index points in each batch (we denote it `e1` to distinguish
        it from the number of inducing index points, denoted `e2` below). If
        set to `None` uses `index_points` as the origin for observations.
        Default value: None.
      kl_weight: Amount by which to scale the KL divergence loss between prior
        and posterior.
        Default value: 1.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "GaussianProcess".
    Returns:
      loss: Scalar tensor representing the negative variational lower bound.
        Can be directly used in a `tf.Optimizer`.
    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.

    #### References

    [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013
         https://arxiv.org/abs/1309.6835
    """

        with tf.name_scope(name or 'variational_gp_loss'):
            if observation_index_points is None:
                observation_index_points = self._index_points
            observation_index_points = tf.convert_to_tensor(
                observation_index_points,
                dtype=self._dtype,
                name='observation_index_points')
            observations = tf.convert_to_tensor(observations,
                                                dtype=self._dtype,
                                                name='observations')
            kl_weight = tf.convert_to_tensor(kl_weight,
                                             dtype=self._dtype,
                                             name='kl_weight')

            # The variational loss is a negative ELBO. The ELBO can be broken down
            # into three terms:
            #  1. a likelihood term,
            #  2. a trace term arising from the covariance of the posterior
            #     predictive, and
            #  3. a KL divergence term between the variational posterior over
            #     inducing point function values and the GP prior over them.

            kzx = self.kernel.matrix(self._inducing_index_points,
                                     observation_index_points)

            kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx)
            loc = (self._mean_fn(observation_index_points) +
                   kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True))

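            # Likelihood term: independent Normal observations centered at the
            # variational predictive mean, with standard deviation
            # sqrt(observation_noise_variance + jitter);
            # `reinterpreted_batch_ndims=1` sums log-probs over the `e`
            # observation dimensions.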
            likelihood = independent.Independent(
                normal.Normal(
                    loc=loc,
                    scale=tf.sqrt(
                        self._observation_noise_variance + self._jitter),
                    name='NormalLikelihood'),
                reinterpreted_batch_ndims=1)
            obs_ll = likelihood.log_prob(observations)

            chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular(
                self._chol_kzz)
            chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx)
            kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True)

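            # Trace term Tr(Kxx - Kxz Kzz^-1 Kzx): the sum of the diagonal of
            # Kxx minus ||L^-1 Kzx||_F^2, where L = chol(Kzz), using
            # ||L^-1 Kzx||_F^2 = Tr(Kxz Kzz^-1 Kzx).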
            kxx_diag = self.kernel.apply(observation_index_points,
                                         observation_index_points,
                                         example_ndims=1)
            ktilde_trace_term = (
                tf.reduce_sum(kxx_diag, axis=-1) -
                tf.reduce_sum(chol_kzz_inv_kzx**2, axis=[-2, -1]))

            # Tr(SB)
            # where S = A A.T, A = variational_inducing_observations_scale
            # and B = Kzz^-1 Kzx Kzx.T Kzz^-1
            #
            # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1)
            #            = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A)
            #            = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2
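            # (NumPy sanity check of this identity, with hypothetical A, M:
            #  A = np.tril(np.random.randn(3, 3)); M = np.random.randn(3, 5);
            #  np.trace(A @ A.T @ M @ M.T) ~= np.sum((A.T @ M) ** 2).)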
            other_trace_term = tf.reduce_sum(
                (self._variational_inducing_observations_posterior.scale.
                 matmul(kzz_inv_kzx)**2),
                axis=[-2, -1])

            trace_term = (.5 * (ktilde_trace_term + other_trace_term) /
                          self._observation_noise_variance)

            kl_term = (
                kl_weight * self.surrogate_posterior_kl_divergence_prior())

            lower_bound = (obs_ll - trace_term - kl_term)

            return -tf.reduce_mean(lower_bound)
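
A minimal training sketch for `variational_loss` follows. It is not part of the
snippet above: the kernel, the toy data, the optimizer, and all concrete values
are assumptions for illustration. A `VariationalGaussianProcess` is built with
trainable inducing-point parameters and the returned loss is minimized
directly; when minibatching, the usual choice is
`kl_weight = batch_size / num_training_points`.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfk = tfp.math.psd_kernels

# Toy 1-d regression data (assumed for illustration only).
x_train = np.random.uniform(-1., 1., size=[100, 1]).astype(np.float64)
y_train = np.sin(3. * x_train[..., 0]) + 0.1 * np.random.randn(100)

# Trainable variational parameters over 10 inducing points.
inducing_index_points = tf.Variable(
    np.linspace(-1., 1., 10)[:, np.newaxis], dtype=tf.float64)
variational_loc = tf.Variable(np.zeros([10]), dtype=tf.float64)
variational_scale = tf.Variable(np.eye(10), dtype=tf.float64)

vgp = tfd.VariationalGaussianProcess(
    kernel=tfk.ExponentiatedQuadratic(
        amplitude=np.float64(1.), length_scale=np.float64(0.5)),
    index_points=x_train,
    inducing_index_points=inducing_index_points,
    variational_inducing_observations_loc=variational_loc,
    variational_inducing_observations_scale=variational_scale,
    observation_noise_variance=np.float64(0.01))

optimizer = tf.optimizers.Adam(learning_rate=0.05)

@tf.function
def train_step():
  with tf.GradientTape() as tape:
    # Full batch here, so kl_weight=1.; with minibatches, scale the KL term by
    # batch_size / num_training_points.
    loss = vgp.variational_loss(
        observations=y_train,
        observation_index_points=x_train,
        kl_weight=1.)
  grads = tape.gradient(loss, vgp.trainable_variables)
  optimizer.apply_gradients(zip(grads, vgp.trainable_variables))
  return loss

for _ in range(100):
  train_step()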
Example #29
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalLinearOperator"):
    """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness: Skewness parameter. Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight: Tailweight parameter. Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tfd.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
    parameters = dict(locals())

    with tf.name_scope(
        name,
        values=[
            loc, scale_diag, scale_identity_multiplier, skewness, tailweight
        ]) as name:
      dtype = dtype_util.common_dtype(
          [loc, scale_diag, scale_identity_multiplier, skewness, tailweight],
          tf.float32)
      loc = loc if loc is None else tf.convert_to_tensor(
          value=loc, name="loc", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)
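      #   C is chosen so that, when skewness == 0, Z = +/-2 maps to
      #   loc +/- 2 * scale, i.e. the nominal "2 sigma" points are preserved.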

      # Construct shapes and 'scale' out of the scale_* and loc kwargs.
      # scale_linop is only an intermediary to:
      #  1. get shapes from looking at loc and the two scale args.
      #  2. combine scale_diag with scale_identity_multiplier, which gives us
      #     'scale', which in turn gives us 'C'.
      scale_linop = distribution_util.make_diag_scale(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False,
          dtype=dtype)
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale_linop)
      # scale_linop.diag_part() is efficient since it is a diag type linop.
      scale_diag_part = scale_linop.diag_part()
      dtype = scale_diag_part.dtype

      if distribution is None:
        distribution = normal.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          scale_diag_part = distribution_util.with_dependencies(
              asserts, scale_diag_part)

      # Make the SAS bijector, 'F'.
      skewness = tf.convert_to_tensor(
          value=skewness, dtype=dtype, name="skewness")
      tailweight = tf.convert_to_tensor(
          value=tailweight, dtype=dtype, name="tailweight")
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness, tailweight=tailweight)
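      # f_noskew is the zero-skewness version of F; it is used only to compute
      # the normalizing constant C below.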
      if has_default_skewness:
        f_noskew = f
      else:
        f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight)

      # Make the Affine bijector, Z --> loc + C * Z.
      c = 2 * scale_diag_part / f_noskew.forward(
          tf.convert_to_tensor(value=2, dtype=dtype))
      affine = affine_bijector.Affine(
          shift=loc, scale_diag=c, validate_args=validate_args)

      bijector = chain_bijector.Chain([affine, f])

      super(VectorSinhArcsinhDiag, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          event_shape=event_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale_linop
    self._tailweight = tailweight
    self._skewness = skewness
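
A minimal usage sketch, assuming a TFP release that still exposes
`tfd.VectorSinhArcsinhDiag`; the parameter values are illustrative only.
Positive `skewness` shifts mass to the right, and `tailweight > 1` gives
heavier-than-Gaussian tails.

import tensorflow_probability as tfp

tfd = tfp.distributions

# A 3-dimensional, right-skewed, heavy-tailed vector distribution.
dist = tfd.VectorSinhArcsinhDiag(
    loc=[0., 1., -1.],
    scale_diag=[1., 2., 0.5],
    skewness=0.5,
    tailweight=1.5)

x = dist.sample(4)       # shape [4, 3]
lp = dist.log_prob(x)    # shape [4]
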
Example #30
  def __init__(self,
               loc=None,
               scale=None,
               validate_args=False,
               allow_nan_stats=True,
               experimental_use_kahan_sum=False,
               name='MultivariateNormalLinearOperator'):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
        summation to aggregate independent underlying log_prob values. For best
        results, Kahan summation should also be applied when computing the
        log-determinant of the `LinearOperator` representing the scale matrix.
        Kahan summation improves on the precision of a naive float32 sum.
        This can be noticeable in particular for large dimensions in float32.
        See CPU caveat on `tfp.math.reduce_kahan_sum`.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if `scale.dtype` is not a floating-point type.
    """
    parameters = dict(locals())
    self._experimental_use_kahan_sum = experimental_use_kahan_sum
    if scale is None:
      raise ValueError('Missing required `scale` parameter.')
    if not dtype_util.is_floating(scale.dtype):
      raise TypeError('`scale` parameter must have floating-point dtype.')

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
      # Since expand_dims doesn't preserve constant-ness, we obtain the
      # non-dynamic value if possible.
      loc = tensor_util.convert_nonref_to_tensor(
          loc, dtype=dtype, name='loc')
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale)
    self._loc = loc
    self._scale = scale

    bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      bijector = shift_bijector.Shift(
          shift=loc, validate_args=validate_args)(bijector)
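    # The composed bijector maps z -> loc + scale @ z, so pushing the standard
    # Normal base distribution through it yields mean `loc` and covariance
    # `scale @ scale.T`.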
    super(MultivariateNormalLinearOperator, self).__init__(
        # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
        # shape instead of tf.zeros.
        # We use `Sample` instead of `Independent` because `Independent`
        # requires concatenating `batch_shape` and `event_shape`, which loses
        # static `batch_shape` information when `event_shape` is not statically
        # known.
        distribution=sample.Sample(
            normal.Normal(
                loc=tf.zeros(batch_shape, dtype=dtype),
                scale=tf.ones([], dtype=dtype)),
            event_shape,
            experimental_use_kahan_sum=experimental_use_kahan_sum),
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
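
A short usage sketch (the particular operator and values are assumptions for
illustration): any square `LinearOperator` can serve as `scale`, and the
implied covariance is `scale @ scale.T`.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Lower-triangular scale => a Cholesky-style parameterization of the
# covariance.
scale = tf.linalg.LinearOperatorLowerTriangular(
    [[1.0, 0.0],
     [0.5, 2.0]])
mvn = tfd.MultivariateNormalLinearOperator(loc=[1., -1.], scale=scale)

x = mvn.sample(3)        # shape [3, 2]
lp = mvn.log_prob(x)     # shape [3]
cov = mvn.covariance()   # [[1.0, 0.5], [0.5, 4.25]]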