def squared_error_loss(particles, fnet, ftrue):
    """MC estimate of the true loss (without using the div(f) trick).

    Up to rescaling + constant, this is equal to the squared error
    E[(f - f_true)**2].

    Args:
        particles: array shaped (n, d)
        fnet: callable, computes witness fn
        ftrue: callable, computes grad(log p) - grad(log q)
    """
    return jnp.mean(
        vmap(lambda x: jnp.inner(fnet(x), fnet(x)) / 2
             - jnp.inner(fnet(x), ftrue(x)))(particles))
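# Hedged usage sketch: fnet and ftrue below are illustrative stand-ins (a toy
# witness function and the score of N(0, I)), chosen only so that
# squared_error_loss runs end to end.
import jax
import jax.numpy as jnp
from jax import vmap

key = jax.random.PRNGKey(0)
particles = jax.random.normal(key, shape=(100, 2))  # n=100 particles, d=2
fnet = lambda x: 0.5 * x  # toy witness function
ftrue = lambda x: -x      # illustrative score difference
loss = squared_error_loss(particles, fnet, ftrue)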
def problem_fun(key):
    """Builds a quadratic loss problem."""
    pkey, ekey, qkey, vkey = jax.random.split(key, 4)
    # Sample eigenvalues.
    log_eigenvalues = jax.random.uniform(
        ekey, shape=(n,), minval=lambda_min, maxval=lambda_max)
    eigenvalues = 10**log_eigenvalues
    # Build orthonormal basis.
    basis = jax.nn.initializers.orthogonal()(qkey, shape=(n, n))
    # Define hessian.
    hess = jnp.dot(jnp.dot(basis, jnp.diag(eigenvalues), precision=precision),
                   basis.T, precision=precision)
    # Random vector for the linear term in the loss.
    v = jax.random.normal(vkey, shape=(n,))
    # Compute an offset such that the global minimum has a loss of zero.
    xstar = jnp.linalg.solve(hess, -v)
    offset = (-0.5 * quadform(hess, xstar, precision=precision)
              - jnp.inner(v, xstar, precision=precision))

    def loss_fun(x, _):
        return (0.5 * quadform(hess, x, precision=precision)
                + jnp.inner(v, x, precision=precision) + offset)

    x_init = jax.random.normal(pkey, shape=(n,))
    return x_init, loss_fun
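# Hedged usage sketch: n, lambda_min, lambda_max, and precision are free
# variables of problem_fun; the values below are assumptions. quadform is the
# helper defined further down in this file and must be in scope before the call.
import jax
import jax.numpy as jnp

n, lambda_min, lambda_max = 8, -2.0, 2.0
precision = jax.lax.Precision.HIGHEST
x_init, loss_fun = problem_fun(jax.random.PRNGKey(0))
# By construction, the minimizer x* = -H^{-1} v satisfies
# loss_fun(x*, None) == 0 up to floating-point error.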
def update(state, inp):
    w, i = state
    d = inp
    # Reference input: quadrature components of the sinusoid at step i.
    # Use jnp.cos/jnp.sin (not np.*) so the step index can be traced.
    x = jnp.array([A * jnp.cos(w0 * i * T + phi),
                   A * jnp.sin(w0 * i * T + phi)])
    y = jnp.inner(w, x)  # filter output
    e = d - y            # error signal
    w += 2 * mu * e * x  # LMS weight update
    i += 1
    state = (w, i)
    return state, e
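# Hedged sketch of driving the LMS update with jax.lax.scan; the constants
# (A, w0, T, phi, mu) and the synthetic measurement d are assumptions.
import jax
import jax.numpy as jnp

A, w0, T, phi, mu = 1.0, 2 * jnp.pi * 50.0, 1e-3, 0.0, 0.05
steps = jnp.arange(1000)
noise = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (1000,))
d = jnp.cos(w0 * steps * T + phi) + noise  # noisy sinusoid to track
init_state = (jnp.zeros(2), 0)             # (weights, step counter)
(w_final, _), errors = jax.lax.scan(update, init_state, d)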
def log_joint(self, theta: jnp.DeviceArray) -> jnp.DeviceArray:
    betas = theta[:2]
    sigma = theta[2]
    beta_prior = norm.logpdf(betas, 0, 10).sum()
    sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
    yhat = jnp.inner(self.x, betas)
    likelihood = norm.logpdf(self.y, yhat, sigma).sum()
    return beta_prior + sigma_prior + likelihood
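# Hedged sketch: a minimal container supplying the x and y attributes that
# log_joint expects; the class name and data are illustrative.
import jax.numpy as jnp
from jax.scipy.stats import norm, gamma

class LinearModel:
    def __init__(self, x, y):
        self.x, self.y = x, y
    log_joint = log_joint  # reuse the function defined above as a method

model = LinearModel(x=jnp.ones((10, 2)), y=jnp.zeros(10))
lp = model.log_joint(jnp.array([0.1, -0.2, 1.0]))  # theta = (b0, b1, sigma)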
def h(x, y, kernel, logp):
    k = kernel

    def h2(x_, y_):
        return np.inner(grad(logp)(y_), grad(k, argnums=0)(x_, y_))

    def d_xk(x_, y_):
        return grad(k, argnums=0)(x_, y_)

    out = (np.inner(grad(logp)(x), grad(logp)(y)) * k(x, y)
           + h2(x, y) + h2(y, x)
           + np.trace(jacfwd(d_xk, argnums=1)(x, y)))
    return out
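# Hedged sketch: evaluating the Stein kernel h under a standard-normal target
# with an RBF kernel; the bandwidth and evaluation points are illustrative.
import jax.numpy as np
from jax import grad, jacfwd

def rbf(x, y):
    return np.exp(-np.sum((x - y) ** 2) / 2.0)

def logp(x):
    return -np.sum(x ** 2) / 2.0  # unnormalized log-density of N(0, I)

hxy = h(np.array([0.5, -1.0]), np.array([1.5, 0.3]), rbf, logp)
# Averaging h over pairs of samples gives a V-statistic estimate of the KSD.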
def step_anf(params, inputs):
    w, mu = params
    d, x = inputs
    y = jnp.inner(w, x)  # filter output
    e = d - y            # error signal
    outputs = (e,)
    w += 2 * mu * e * x  # LMS weight update
    params = (w, mu)
    return params, outputs
def quadform(hess, x, precision):
    """Computes a quadratic form (x^T @ H @ x)."""
    u = jnp.dot(hess, x, precision=precision)  # u = Hx
    return jnp.inner(x, u, precision=precision)
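# Hedged usage sketch with a small SPD matrix; the precision argument takes
# values from jax.lax.Precision.
import jax
import jax.numpy as jnp

H = jnp.array([[2.0, 0.5], [0.5, 1.0]])
x = jnp.array([1.0, -1.0])
q = quadform(H, x, precision=jax.lax.Precision.HIGHEST)  # x^T H x = 2.0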
def inner(x1, x2):
    if isinstance(x1, JaxArray):
        x1 = x1.value
    if isinstance(x2, JaxArray):
        x2 = x2.value
    return JaxArray(jnp.inner(x1, x2))
def normsq(x):
    return np.inner(x, x)
def h(x, dlogp_x):
    div_f = np.trace(jacfwd(f)(x))
    return np.inner(f(x), dlogp_x) + div_f
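# Hedged sketch: h computes the Stein discrepancy integrand
# f(x) . dlogp(x) + div f(x), where f is a free variable of h. The witness f
# and the target score below are illustrative.
import jax.numpy as np
from jax import grad, jacfwd

f = lambda x: -x                               # toy witness function
dlogp = grad(lambda x: -np.sum(x ** 2) / 2.0)  # score of N(0, I)
x = np.array([1.0, 2.0])
val = h(x, dlogp(x))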
def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs)
    lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def h(x, dlogp_x, z):
    zdf = grad(lambda _x: np.vdot(z, f(_x)))
    div_f = np.vdot(zdf(x), z)  # z^T J_f(x) z
    return np.inner(f(x), dlogp_x) + div_f
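# Hedged sketch: with a Rademacher probe z, div_f above is Hutchinson's
# unbiased estimator of trace(J_f); averaging over several z reduces its
# variance. The witness f and the score -x below are illustrative.
import jax
import jax.numpy as np
from jax import grad

f = lambda x: -x
x = np.array([0.5, -0.5])
z = jax.random.rademacher(jax.random.PRNGKey(0), (2,), dtype=np.float32)
est = h(x, -x, z)  # dlogp_x = -x for a standard-normal target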
def stein_operator(fun, x, logp, transposed=False, aux=False):
    r"""
    Arguments:
    * fun: callable, transformation $\text{fun}: \mathbb R^d \to \mathbb R^d$,
      or $\text{fun}: \mathbb R^d \to \mathbb R$. Satisfies
      $\lim_{x \to \infty} \text{fun}(x) = 0$.
    * x: np.array of shape (d,).
    * logp: callable, takes argument of shape (d,). Computes log(p(x)).
      Can be unnormalized (just using gradient.)

    Returns:
    Stein operator $\mathcal A$ evaluated at fun and x:
    \[ \mathcal A_p [\text{fun}](x). \]
    This expression takes the form of a scalar if transposed, else a dxd
    matrix.

    Auxiliary data: values for G (kernel smoothed gradient) and R (kernel
    repulsion term) of shape((G, R)) = (2, d)
    """
    x = np.array(x, dtype=np.float32)
    if x.ndim != 1:
        raise ValueError(f"x needs to be an np.array of shape (d,). Instead, "
                         f"x has shape {x.shape}")
    fx = fun(x)
    if transposed:
        if fx.ndim == 0:  # f: R^d --> R
            raise ValueError(
                f"Got passed transposed = True, but the input "
                f"function {fun.__name__} returns a scalar. This "
                "doesn't make sense: the transposed Stein operator "
                "acts only on vector-valued functions.")
        elif fx.ndim == 1:  # f: R^d --> R^d
            drift_term = np.inner(grad(logp)(x), fx)
            repulsive_term = np.trace(jacfwd(fun)(x).transpose())
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        else:
            raise ValueError(f"Output of input function {fun.__name__} needs "
                             f"to have rank 0 or 1. Instead got output "
                             f"of shape {fx.shape}")
    else:
        if fx.ndim == 0:  # f: R^d --> R
            drift_term = grad(logp)(x) * fx
            repulsive_term = grad(fun)(x)
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        elif fx.ndim == 1:  # f: R^d --> R^d
            drift_term = np.einsum("i,j->ij", grad(logp)(x), fun(x))
            repulsive_term = jacfwd(fun)(x).transpose()
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        elif fx.ndim == 2 and fx.shape[0] == fx.shape[1]:  # f: R^d --> R^{dxd}
            raise NotImplementedError("Not implemented for matrix-valued f.")
        else:
            raise ValueError(f"Output of input function {fun.__name__} needs "
                             f"to be a scalar, a vector, or a square matrix. "
                             f"Instead got output of shape {fx.shape}")
    if aux:
        return out, auxdata
    else:
        return out
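# Hedged usage sketch for stein_operator under a standard-normal target; the
# witness field below decays at infinity, as the docstring requires.
import jax.numpy as np
from jax import grad, jacfwd

def logp(x):
    return -np.sum(x ** 2) / 2.0

def witness(x):
    return -x * np.exp(-np.sum(x ** 2))

xx = np.array([0.3, -0.7])
scalar_out = stein_operator(witness, xx, logp, transposed=True)  # scalar
matrix_out = stein_operator(witness, xx, logp)                   # d x d matrix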
def stein_op_true(x):
    return np.inner(optimal_witness(x), optimal_witness(x))
def h(y):
    return np.inner(dlogp_xi, kx(y)) + grad(kx)(y)
def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y - pred, y - pred) / 2.0
def scalar_product_kernel(x, y):
    """k(x, y) = x^T y"""
    return np.inner(x, y)
def sample(self, rng, sample_shape=()) -> jp.ndarray:
    z = random.normal(rng, shape=sample_shape + self.event_shape)
    return self.loc + jp.inner(z, self.scale_tril)
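# Hedged standalone check of the reparameterization: with z ~ N(0, I),
# jp.inner(z, L)[i] = sum_k L[i, k] z[k], i.e. L @ z, so the samples have
# covariance L L^T. loc and scale_tril below are illustrative.
import jax.numpy as jp
from jax import random

loc = jp.array([1.0, -1.0])
scale_tril = jp.array([[1.0, 0.0], [0.5, 2.0]])
z = random.normal(random.PRNGKey(0), shape=(10000, 2))
samples = loc + jp.inner(z, scale_tril)  # empirical cov approaches L @ L.T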
def fun_norm(x):
    return np.inner(fun(x), fun(x))
def entropy(self, alpha):
    alpha0 = np.sum(alpha, axis=-1)
    lnB = _lnB(alpha)
    K = alpha.shape[-1]
    # np.inner needs two arguments; for a single (1-D) alpha this computes
    # sum_j (alpha_j - 1) * digamma(alpha_j).
    return lnB + (alpha0 - K) * digamma(alpha0) - np.inner(
        alpha - 1, digamma(alpha))
def squared_error(y, y_pred):
    return jnp.inner(y - y_pred, y - y_pred) / 2.0
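# Hedged note: paired with jax.grad, this loss yields the residual directly,
# since d/dy_pred [ ||y - y_pred||^2 / 2 ] = -(y - y_pred).
import jax
import jax.numpy as jnp

y = jnp.array([1.0, 2.0])
y_pred = jnp.array([0.5, 2.5])
g = jax.grad(squared_error, argnums=1)(y, y_pred)  # == -(y - y_pred)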