def quick_grad_check(fun, arg0, extra_args=(), kwargs=None, verbose=True,
                     eps=EPS, rtol=RTOL, atol=ATOL, rs=None):
    """Checks the gradient of a function (w.r.t. to its first arg) in a random direction.

    Compares a numerical directional derivative (via ``unary_nd``) against the
    projection of the analytic gradient (via ``grad``) onto the same random
    unit direction; raises AssertionError if they disagree beyond tolerance.

    Parameters
    ----------
    fun : function whose gradient w.r.t. its first argument is checked.
    arg0 : point at which the gradient is checked.
    extra_args, kwargs : extra positional / keyword arguments forwarded to fun.
    verbose : if True, print what is being checked and the compared values.
    eps : finite-difference step passed to unary_nd.
    rtol, atol : tolerances passed to np.allclose.
    rs : optional np.random.RandomState; an unseeded one is created if None.
    """
    # Avoid the shared-mutable-default pitfall; None stands in for "no kwargs".
    kwargs = {} if kwargs is None else kwargs

    if verbose:
        print("Checking gradient of {0} at {1}".format(fun, arg0))

    if rs is None:
        rs = np.random.RandomState()

    # Random unit direction with the same shape as arg0.
    random_dir = rs.standard_normal(np.shape(arg0))
    random_dir = random_dir / np.sqrt(np.sum(random_dir * random_dir))

    # Restriction of fun to a scalar parameter along the random direction.
    unary_fun = lambda x: fun(arg0 + x * random_dir, *extra_args, **kwargs)
    numeric_grad = unary_nd(unary_fun, 0.0, eps=eps)

    # Directional derivative predicted by the analytic gradient.
    analytic_grad = np.sum(grad(fun)(arg0, *extra_args, **kwargs) * random_dir)

    assert np.allclose(numeric_grad, analytic_grad, rtol=rtol, atol=atol), \
        "Check failed! nd={0}, ad={1}".format(numeric_grad, analytic_grad)

    if verbose:
        print("Gradient projection OK (numeric grad: {0}, analytic grad: {1})".format(
            numeric_grad, analytic_grad))


def lower_half(mat):
    # Takes the lower half of the matrix plus half of the diagonal.
    # Needed because only the lower triangle of the symmetric covariance
    # matrix carries independent parameters (see covgrad below).
    # NOTE(review): the `def lower_half(mat):` header and the 2-D branch
    # condition were missing from the original file (the body was spliced
    # into quick_grad_check); reconstructed here so covgrad's call resolves.
    if len(mat.shape) == 2:
        return 0.5 * (np.tril(mat) + np.triu(mat, 1).T)
    elif len(mat.shape) == 3:
        return 0.5 * (np.tril(mat) + np.swapaxes(np.triu(mat, 1), 1, 2))
    else:
        raise ArithmeticError

def generalized_outer_product(mat):
    """Outer product of `mat` with itself, generalized over rank.

    A 1-D input of length n yields the ordinary (n, n) outer product; a 2-D
    input of shape (m, n) yields an (m, n, n) stack of per-row outer
    products. Any other rank raises ArithmeticError.
    """
    rank = len(mat.shape)
    if rank == 1:
        return np.outer(mat, mat)
    if rank == 2:
        # Row-wise outer products, batched over the first axis.
        return np.einsum('ij,ik->ijk', mat, mat)
    raise ArithmeticError

def covgrad(x, mean, cov):
    """Gradient term of the multivariate-normal density w.r.t. the covariance.

    Computes inv(cov) - (inv(cov) @ (x - mean)) outer itself, then keeps only
    the lower half (plus half the diagonal) via lower_half.
    """
    # I think once we have Cholesky we can make this nicer.
    centered = x - mean
    sigma_inv_centered = np.linalg.solve(cov, centered.T).T
    return lower_half(np.linalg.inv(cov)
                      - generalized_outer_product(sigma_inv_centered))

# Gradients of the multivariate-normal logpdf w.r.t. x (argnum=0),
# mean (argnum=1), and cov (argnum=2). Each defgrad registers a
# vector-Jacobian product; `unbroadcast` (defined elsewhere in this
# package) presumably sums the result back to the argument's shape.
# d(logpdf)/dx = -inv(cov) @ (x - mean); d/dmean is its negation.
logpdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, x,    lambda g: -np.expand_dims(g, 1) * np.linalg.solve(cov, (x - mean).T).T), argnum=0)
logpdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, mean, lambda g:  np.expand_dims(g, 1) * np.linalg.solve(cov, (x - mean).T).T), argnum=1)
# Cov gradient: g is reshaped to broadcast over covgrad's trailing matrix axes.
logpdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, cov,  lambda g: -np.reshape(g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov)), argnum=2)

# Same as log pdf, but multiplied by the pdf (ans).
# NOTE(review): unlike the logpdf gradients above, these solve against
# (x - mean) without the .T ... .T round-trip and without expand_dims /
# reshape of g — verify this is intentional for batched (2-D) x, since the
# logpdf versions handle that case explicitly.
pdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, x,    lambda g: -g * ans * np.linalg.solve(cov, x - mean)), argnum=0)
pdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, mean, lambda g:  g * ans * np.linalg.solve(cov, x - mean)), argnum=1)
pdf.defgrad(lambda ans, x, mean, cov: unbroadcast(ans, cov,  lambda g: -g * ans * covgrad(x, mean, cov)),          argnum=2)

# Entropy gradients: argument 0 (the mean, per the lambda below) gets a
# declared-zero gradient; the covariance gradient is 0.5 * inv(cov).T
# scaled by the incoming gradient g.
entropy.defgrad_is_zero(argnums=(0,))
entropy.defgrad(lambda ans, mean, cov: unbroadcast(ans, cov, lambda g:  0.5 * g * np.linalg.inv(cov).T), argnum=1)