Example #1
    def nllfun(self, x, alpha, scale):
        r"""Implements the negative log-likelihood (NLL).

    Specifically, we implement -log(p(x | 0, \alpha, c)) of Equation 16 in the
    paper as nllfun(x, alpha, scale).

    Args:
      x: The residual for which the NLL is being computed. x can have any shape,
        and alpha and scale will be broadcast to match x's shape if necessary.
        Must be a JAX array or numpy array of floats.
      alpha: The shape parameter of the NLL (\alpha in the paper), where lower
        values cause outliers to "cost" less and inliers to "cost" more. Alpha
        can be any non-negative value, but the gradient of the NLL with respect
        to alpha has singularities at 0 and 2, so you may want to limit usage to
        (0, 2) during gradient descent. Must be a JAX array or numpy array of
        floats. Varying alpha in that range allows for smooth interpolation
        between a Cauchy distribution (alpha = 0) and a Normal distribution
        (alpha = 2), similar to a Student's t-distribution.
      scale: The scale parameter of the loss. When |x| < scale, the NLL is like
        that of a (possibly unnormalized) normal distribution, and when |x| >
        scale the NLL takes on a different shape according to alpha. Must be a
        JAX array or numpy array of floats.

    Returns:
      The NLLs for each element of x, in the same shape as x. This is returned
      as a JAX array of floats with the same precision as x.
    """
        alpha = jnp.maximum(0, alpha)
        scale = jnp.maximum(jnp.finfo(jnp.float32).eps, scale)
        loss = general.lossfun(x, alpha, scale)
        return loss + jnp.log(scale) + self.log_base_partition_function(alpha)
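
A minimal usage sketch for nllfun, assuming the surrounding class can be imported and instantiated as below; the robust_loss_jax.distribution module path and the Distribution class name are assumptions.

import jax.numpy as jnp
from robust_loss_jax import distribution  # assumed module path

dist = distribution.Distribution()  # assumed class name
x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])  # residuals
alpha = jnp.full_like(x, 1.0)               # shape parameter, kept in (0, 2)
scale = jnp.full_like(x, 0.5)               # scale parameter, > 0

# Elementwise NLLs with the same shape as x; minimizing their sum w.r.t.
# alpha and scale fits the robust distribution to the residuals.
nll = dist.nllfun(x, alpha, scale)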
Example #2
    def testGradientMatchesFiniteDifferences(self):
        # Test that the loss returns gradients that are close to the numerical
        # gradient from forward finite differences. Returning correct gradients
        # is JAX's job, so this is just an aggressive sanity check.
        num_samples = 100000
        rng = random.PRNGKey(0)

        # Normally distributed inputs.
        rng, key = random.split(rng)
        x = random.normal(key, shape=[num_samples])

        # Uniformly distributed values in (-16, 3), quantized to the nearest
        # 0.1 and then shifted by 0.05 so that we avoid the special cases at
        # 0 and 2 where the analytical gradient won't match finite differences.
        rng, key = random.split(rng)
        alpha = jnp.round(
            random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
            10) / 10. + 0.05

        # Random uniformly distributed values in (0.5, 1.5):
        rng, key = random.split(rng)
        scale = random.uniform(key,
                               shape=[num_samples],
                               minval=0.5,
                               maxval=1.5)

        loss = general.lossfun(x, alpha, scale)
        d_x, d_alpha, d_scale = (jax.grad(
            lambda x, a, s: jnp.sum(general.lossfun(x, a, s)),
            [0, 1, 2])(x, alpha, scale))

        step_size = 1e-3
        fn = self.variant(general.lossfun)
        n_x = (fn(x + step_size, alpha, scale) - loss) / step_size
        n_alpha = (fn(x, alpha + step_size, scale) - loss) / step_size
        n_scale = (fn(x, alpha, scale + step_size) - loss) / step_size

        chex.assert_tree_all_close(n_x, d_x, atol=1e-2, rtol=1e-2)
        chex.assert_tree_all_close(n_alpha, d_alpha, atol=1e-2, rtol=1e-2)
        chex.assert_tree_all_close(n_scale, d_scale, atol=1e-2, rtol=1e-2)
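
The same forward-difference check can be reproduced in isolation. A minimal sketch assuming only jax, with a simple elementwise function standing in for general.lossfun:

import jax
import jax.numpy as jnp

def forward_difference(fn, x, step_size=1e-3):
    # One-sided (forward) finite-difference estimate of an elementwise derivative.
    return (fn(x + step_size) - fn(x)) / step_size

fn = lambda x: jnp.log1p(x ** 2)  # stand-in for the loss, not general.lossfun
x = jnp.linspace(-3.0, 3.0, 7)
analytical = jax.vmap(jax.grad(fn))(x)
numerical = forward_difference(fn, x)
assert jnp.allclose(analytical, numerical, atol=1e-2, rtol=1e-2)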
Example #3
    def testLossIsScaleInvariant(self):
        # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale)
        (num_samples, loss, x, alpha, scale, _, _,
         _) = (self._precompute_lossfun_inputs())
        # Random log-normally distributed scalings, clipped from below at 0.2
        # (approximately in (0.2, 20)).
        rng = random.PRNGKey(0)
        mult = jnp.maximum(0.2, jnp.exp(random.normal(rng,
                                                      shape=[num_samples])))

        # Compute the scaled loss.
        loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
        chex.assert_tree_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)
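
The invariance being checked follows from the loss depending on x only through x / scale. Below is a self-contained numerical sketch of the same identity using a simplified re-implementation of the general loss (valid only for alpha outside {0, 2}), not the library's numerically stabilized general.lossfun:

import jax.numpy as jnp

def simple_lossfun(x, alpha, scale):
    # Simplified form of the general robust loss for alpha not in {0, 2}.
    b = jnp.abs(alpha - 2.0)
    return (b / alpha) * (((x / scale) ** 2 / b + 1.0) ** (alpha / 2.0) - 1.0)

x = jnp.array([-1.0, 0.3, 2.5])
alpha, scale, mult = 1.0, 0.7, 4.0
# Rescaling x and scale together leaves x / scale, and hence the loss, unchanged.
assert jnp.allclose(simple_lossfun(mult * x, alpha, mult * scale),
                    simple_lossfun(x, alpha, scale), atol=1e-6)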
Example #4
    def testGradientsAreFiniteWithAllInputs(self, alpha):
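        # Build x values that span many orders of magnitude, plus 0 and +/-inf,
        # then check that the loss and its gradients with respect to x and scale
        # are finite everywhere for this alpha.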
        x_half = jnp.concatenate(
            [jnp.exp(jnp.linspace(-80, 80, 1001)),
             jnp.array([jnp.inf])])
        x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
        scale = jnp.full_like(x, 1.)

        fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
        loss = fn(x, scale)
        d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

        for v in [loss, d_x, d_scale]:
            chex.assert_tree_all_finite(v)
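
The per-element gradient pattern above (jax.vmap over jax.grad) and the finiteness check can be reproduced without chex. A minimal sketch with a stand-in elementwise loss; unlike general.lossfun, the stand-in is only kept finite here by restricting the input range:

import jax
import jax.numpy as jnp

def assert_all_finite(tree):
    # Fail if any leaf of the pytree contains NaN or +/-inf.
    for leaf in jax.tree_util.tree_leaves(tree):
        assert bool(jnp.all(jnp.isfinite(leaf))), "non-finite values found"

fn = lambda x, s: jnp.log1p((x / s) ** 2)  # stand-in for general.lossfun
x_half = jnp.exp(jnp.linspace(-15.0, 15.0, 101))
x = jnp.concatenate([-x_half[::-1], jnp.array([0.0]), x_half])
scale = jnp.ones_like(x)
d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)
assert_all_finite([fn(x, scale), d_x, d_scale])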
Example #5
    def draw_samples(self, rng, alpha, scale):
        r"""Draw samples from the robust distribution.

    This function implements Algorithm 1 of the paper. This code is written to
    allow for sampling from a set of different distributions, each parametrized
    by its own alpha and scale values, as opposed to the more standard approach
    of drawing N samples from the same distribution. This is done by repeatedly
    performing N instances of rejection sampling for each of the N distributions
    until at least one proposal for each of the N distributions has been
    accepted. All samples assume a zero mean --- to get non-zero mean samples,
    just add each mean to each sample.

    Args:
      rng: A JAX pseudo-random number key, from random.PRNGKey().
      alpha: A tensor where each element is the shape parameter of that
        element's distribution. Must be >= 0.
      scale: A tensor where each element is the scale parameter of that
        element's distribution. Must be > 0 and the same shape as `alpha`.

    Returns:
      A tensor with the same shape as `alpha` and `scale` where each element is
      a sample drawn from the zero-mean distribution specified for that element
      by `alpha` and `scale`.
    """
        assert jnp.all(scale > 0)
        assert jnp.all(alpha >= 0)
        assert jnp.all(jnp.array(alpha.shape) == jnp.array(scale.shape))
        shape = alpha.shape

        samples = jnp.zeros(shape)
        accepted = jnp.zeros(shape, dtype=bool)

        # Rejection sampling.
        while not jnp.all(accepted):

            # The sqrt(2) scaling of the Cauchy distribution corrects for our
            # differing conventions for standardization.
            rng, key = random.split(rng)
            cauchy_sample = random.cauchy(key, shape=shape) * jnp.sqrt(2)

            # Compute the likelihood of each sample under its target distribution.
            nll = self.nllfun(cauchy_sample, alpha, 1)

            # Bound the NLL. We don't use the approximate loss as it may cause
            # unpredictable behavior in the context of sampling.
            nll_bound = (general.lossfun(cauchy_sample, 0, 1) +
                         self.log_base_partition_function(alpha))

            # Draw N samples from a uniform distribution, and use each uniform
            # sample to decide whether or not to accept each proposal sample.
            rng, key = random.split(rng)
            uniform_sample = random.uniform(key, shape=shape)
            accept = uniform_sample <= jnp.exp(nll_bound - nll)

            # If a sample is accepted, replace its element in `samples` with the
            # proposal sample, and set its bit in `accepted` to True.
            samples = jnp.where(accept, cauchy_sample, samples)
            accepted = accept | accepted

        # Because our distribution is a location-scale family, we sample from
        # p(x | 0, \alpha, 1) and then scale each sample by `scale`.
        samples *= scale

        return samples
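
A usage sketch for the sampler above; the robust_loss_jax.distribution module path and the Distribution class name are assumptions:

import jax.numpy as jnp
from jax import random
from robust_loss_jax import distribution  # assumed module path

dist = distribution.Distribution()  # assumed class name
rng = random.PRNGKey(0)
alpha = jnp.array([0.5, 1.0, 1.5, 2.0])  # per-element shape parameters (>= 0)
scale = jnp.full_like(alpha, 2.0)        # per-element scale parameters (> 0)

# One zero-mean sample per (alpha, scale) pair; add a mean to shift the samples.
samples = dist.draw_samples(rng, alpha, scale)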