def nllfun(self, x, alpha, scale):
  r"""Implements the negative log-likelihood (NLL).

  Specifically, we implement -log(p(x | 0, \alpha, c)) of Equation 16 in the
  paper as nllfun(x, alpha, scale).

  Args:
    x: The residual for which the NLL is being computed. x can have any
      shape, and alpha and scale will be broadcast to match x's shape if
      necessary. Must be a JAX or numpy array of floats.
    alpha: The shape parameter of the NLL (\alpha in the paper), where smaller
      values produce a heavier-tailed distribution, so large outliers "cost"
      less while inliers "cost" slightly more (via the larger partition
      function). Alpha can be any non-negative value, but the gradient of the
      NLL with respect to alpha has singularities at 0 and 2, so you may want
      to limit usage to (0, 2) during gradient descent. Must be a JAX or numpy
      array of floats. Varying alpha in that range allows for smooth
      interpolation between a Cauchy distribution (alpha = 0) and a Normal
      distribution (alpha = 2), similar to a Student's t-distribution.
    scale: The scale parameter of the loss. When |x| < scale, the NLL is like
      that of a (possibly unnormalized) normal distribution, and when
      |x| > scale the NLL takes on a different shape according to alpha.
      Must be a JAX or numpy array of floats.

  Returns:
    The NLLs for each element of x, in the same shape as x, returned as a
    JAX array of floats with the same precision as x.
  """
  alpha = jnp.maximum(0, alpha)
  scale = jnp.maximum(jnp.finfo(jnp.float32).eps, scale)
  loss = general.lossfun(x, alpha, scale)
  return loss + jnp.log(scale) + self.log_base_partition_function(alpha)
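# --- Usage sketch (illustrative, not from the source) -----------------------
# Assumes `dist` is an instance of the class that defines `nllfun` above
# (e.g. the library's distribution object); the helper name and values are
# placeholders, shown only to clarify how the method is called.
import jax.numpy as jnp

def _nllfun_example(dist):
  """Evaluates the NLL of a few residuals under one alpha and scale."""
  x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
  alpha = jnp.array(1.0)  # Somewhere between Cauchy (0) and Normal (2).
  scale = jnp.array(0.5)
  nll = dist.nllfun(x, alpha, scale)
  # Unlike the unnormalized loss, this NLL includes log(scale) and the log
  # partition function, so minimizing it over alpha and scale is well-posed.
  return nll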
def testGradientMatchesFiniteDifferences(self):
  # Test that the loss returns gradients that are close to the numerical
  # gradient from finite differences, with forward differencing. Returning
  # correct gradients is JAX's job, so this is just an aggressive sanity
  # check.
  num_samples = 100000
  rng = random.PRNGKey(0)

  # Normally distributed inputs.
  rng, key = random.split(rng)
  x = random.normal(key, shape=[num_samples])

  # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
  # and then shifted by 0.05 so that we avoid the special cases at 0 and 2
  # where the analytical gradient won't match finite differences.
  rng, key = random.split(rng)
  alpha = jnp.round(
      random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
      10) / 10. + 0.05

  # Uniformly distributed values in (0.5, 1.5).
  rng, key = random.split(rng)
  scale = random.uniform(key, shape=[num_samples], minval=0.5, maxval=1.5)

  loss = general.lossfun(x, alpha, scale)
  d_x, d_alpha, d_scale = jax.grad(
      lambda x, a, s: jnp.sum(general.lossfun(x, a, s)),
      [0, 1, 2])(x, alpha, scale)

  step_size = 1e-3
  fn = self.variant(general.lossfun)
  n_x = (fn(x + step_size, alpha, scale) - loss) / step_size
  n_alpha = (fn(x, alpha + step_size, scale) - loss) / step_size
  n_scale = (fn(x, alpha, scale + step_size) - loss) / step_size

  chex.assert_tree_all_close(n_x, d_x, atol=1e-2, rtol=1e-2)
  chex.assert_tree_all_close(n_alpha, d_alpha, atol=1e-2, rtol=1e-2)
  chex.assert_tree_all_close(n_scale, d_scale, atol=1e-2, rtol=1e-2)
def testLossIsScaleInvariant(self):
  # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale).
  (num_samples, loss, x, alpha, scale, _, _, _) = (
      self._precompute_lossfun_inputs())

  # Random log-normally distributed scalings in ~(0.2, 20).
  rng = random.PRNGKey(0)
  mult = jnp.maximum(0.2, jnp.exp(random.normal(rng, shape=[num_samples])))

  # Compute the scaled loss.
  loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
  chex.assert_tree_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)
def testGradientsAreFiniteWithAllInputs(self, alpha):
  x_half = jnp.concatenate(
      [jnp.exp(jnp.linspace(-80, 80, 1001)),
       jnp.array([jnp.inf])])
  x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
  scale = jnp.full_like(x, 1.)

  fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
  loss = fn(x, scale)
  d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

  for v in [loss, d_x, d_scale]:
    chex.assert_tree_all_finite(v)
def draw_samples(self, rng, alpha, scale):
  r"""Draw samples from the robust distribution.

  This function implements Algorithm 1 of the paper. This code is written to
  allow for sampling from a set of different distributions, each parametrized
  by its own alpha and scale values, as opposed to the more standard approach
  of drawing N samples from the same distribution. This is done by repeatedly
  performing N instances of rejection sampling for each of the N distributions
  until at least one proposal for each of the N distributions has been
  accepted. All samples assume a zero mean --- to get non-zero mean samples,
  just add each mean to each sample.

  Args:
    rng: A JAX pseudo-random number key, from random.PRNGKey().
    alpha: A tensor where each element is the shape parameter of that
      element's distribution. Must be >= 0 and the same shape as `scale`.
    scale: A tensor where each element is the scale parameter of that
      element's distribution. Must be > 0 and the same shape as `alpha`.

  Returns:
    A tensor with the same shape as `alpha` and `scale` where each element is
    a sample drawn from the zero-mean distribution specified for that element
    by `alpha` and `scale`.
  """
  assert jnp.all(scale > 0)
  assert jnp.all(alpha >= 0)
  assert jnp.all(jnp.array(alpha.shape) == jnp.array(scale.shape))

  shape = alpha.shape
  samples = jnp.zeros(shape)
  accepted = jnp.zeros(shape, dtype=bool)

  # Rejection sampling.
  while not jnp.all(accepted):
    # The sqrt(2) scaling of the Cauchy distribution corrects for our
    # differing conventions for standardization.
    rng, key = random.split(rng)
    cauchy_sample = random.cauchy(key, shape=shape) * jnp.sqrt(2)

    # Compute the likelihood of each sample under its target distribution.
    nll = self.nllfun(cauchy_sample, alpha, 1)

    # Bound the NLL. We don't use the approximate loss as it may cause
    # unpredictable behavior in the context of sampling.
    nll_bound = (
        general.lossfun(cauchy_sample, 0, 1) +
        self.log_base_partition_function(alpha))

    # Draw N samples from a uniform distribution, and use each uniform
    # sample to decide whether or not to accept each proposal sample.
    rng, key = random.split(rng)
    uniform_sample = random.uniform(key, shape=shape)
    accept = uniform_sample <= jnp.exp(nll_bound - nll)

    # If a sample is accepted, replace its element in `samples` with the
    # proposal sample, and set its bit in `accepted` to True.
    samples = jnp.where(accept, cauchy_sample, samples)
    accepted = accept | accepted

  # Because our distribution is a location-scale family, we sample from
  # p(x | 0, \alpha, 1) and then scale each sample by `scale`.
  samples *= scale

  return samples
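# --- Usage sketch (illustrative, not from the source) -----------------------
# Draws one sample per (alpha, scale) pair, sweeping alpha from Cauchy-like
# (0) toward Normal (2) at a fixed scale. `dist` is assumed to be an instance
# of the class that defines `draw_samples` above; the helper name is a
# placeholder.
from jax import random
import jax.numpy as jnp

def _draw_samples_example(dist):
  rng = random.PRNGKey(0)
  alpha = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])
  scale = jnp.full_like(alpha, 0.1)
  # Zero-mean samples; add a per-element mean afterwards if one is needed.
  samples = dist.draw_samples(rng, alpha, scale)
  return samples  # Same shape as `alpha` and `scale`.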