Code example #1
def squared_error_loss(particles, fnet, ftrue):
    """
    MC estimate of the true loss (without using the div(f) trick).
    Up to rescaling + constant, this is equal to the squared
    error E[(f - f_true)**2].
    args:
        particles: array shaped (n, d)
        fnet: callable, computes witness fn
        ftrue: callable, computes grad(log p) - grad(log q)
    """
    return jnp.mean(
        vmap(lambda x: jnp.inner(fnet(x), fnet(x)) / 2 - jnp.inner(
            fnet(x), ftrue(x)))(particles))
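A minimal usage sketch for the loss above, assuming a toy witness function and score difference (fnet, ftrue, and the particle set below are illustrative stand-ins, not from the original project):

import jax.numpy as jnp
from jax import vmap, random

particles = random.normal(random.PRNGKey(0), (100, 2))
fnet = lambda x: 0.1 * x                # toy witness function
# grad log p - grad log q for p = N(0, I), q = N(1, I) is the constant -1 vector.
ftrue = lambda x: -jnp.ones_like(x)
print(squared_error_loss(particles, fnet, ftrue))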
Code example #2
  def problem_fun(key):
    """Builds a quadratic loss problem."""
    pkey, ekey, qkey, vkey = jax.random.split(key, 4)

    # Sample eigenvalues, log-uniform between 10**lambda_min and 10**lambda_max.
    log_eigenvalues = jax.random.uniform(ekey,
                                         shape=(n,),
                                         minval=lambda_min,
                                         maxval=lambda_max)
    eigenvalues = 10**log_eigenvalues

    # Build orthonormal basis.
    basis = jax.nn.initializers.orthogonal()(qkey, shape=(n, n))

    # Assemble the Hessian: basis @ diag(eigenvalues) @ basis.T.
    hess = jnp.dot(jnp.dot(basis, jnp.diag(eigenvalues), precision=precision),
                   basis.T,
                   precision=precision)

    # Random vector for the linear term in the loss.
    v = jax.random.normal(vkey, shape=(n,))

    # Compute an offset such that the global minimum has a loss of zero.
    xstar = jnp.linalg.solve(hess, -v)
    offset = (-0.5 * quadform(hess, xstar, precision=precision)
              - jnp.inner(v, xstar, precision=precision))

    def loss_fun(x, _):
      return (0.5 * quadform(hess, x, precision=precision)
              + jnp.inner(v, x, precision=precision) + offset)

    x_init = jax.random.normal(pkey, shape=(n,))
    return x_init, loss_fun
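A hedged way to instantiate the problem, assuming problem_fun is defined at module level; n, lambda_min, lambda_max, and precision are free variables of the snippet (the values below are assumptions), and quadform is the helper shown in code example #7:

import jax
import jax.numpy as jnp
from jax import lax

n, lambda_min, lambda_max = 10, -3.0, 0.0
precision = lax.Precision.HIGHEST

x_init, loss_fun = problem_fun(jax.random.PRNGKey(42))
print(loss_fun(x_init, None))   # positive; by construction the minimum value is ~0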
Code example #3
def update(state, inp):
    w, i = state
    d = inp
    # Reference input: in-phase and quadrature components of the oscillator.
    x = jnp.array([A * jnp.cos(w0 * i * T + phi),
                   A * jnp.sin(w0 * i * T + phi)])
    y = jnp.inner(w, x)    # filter output
    e = d - y              # error signal
    w += 2 * mu * e * x    # LMS weight update
    i += 1
    state = (w, i)
    return state, e
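update has the (carry, input) -> (carry, output) shape that jax.lax.scan expects; a hedged driver (the filter constants below are assumptions):

import jax
import jax.numpy as jnp

A, w0, T, phi, mu = 1.0, 2 * jnp.pi * 50.0, 1e-3, 0.0, 0.01

d = A * jnp.cos(w0 * jnp.arange(1000) * T + phi)   # toy desired signal
(w_final, _), errors = jax.lax.scan(update, (jnp.zeros(2), 0), d)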
Code example #4
File: models.py  Project: sagar87/jaxvi
    def log_joint(self, theta: jnp.ndarray) -> jnp.ndarray:
        betas = theta[:2]
        sigma = theta[2]

        beta_prior = norm.logpdf(betas, 0, 10).sum()
        sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
        yhat = jnp.inner(self.x, betas)
        likelihood = norm.logpdf(self.y, yhat, sigma).sum()

        return beta_prior + sigma_prior + likelihood
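norm and gamma here are presumably jax.scipy.stats.norm and jax.scipy.stats.gamma. Because log_joint is a plain function of theta, the score for a VI update is one jax.grad away (a hedged sketch; model is a stand-in instance of this class):

import jax
import jax.numpy as jnp

theta = jnp.array([0.5, -0.3, 1.2])       # [beta_0, beta_1, sigma]
score = jax.grad(model.log_joint)(theta)  # d log p(theta, y) / d theta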
Code example #5
def h(x, y, kernel, logp):
    k = kernel

    def h2(x_, y_):
        return np.inner(grad(logp)(y_), grad(k, argnums=0)(x_, y_))

    def d_xk(x_, y_):
        return grad(k, argnums=0)(x_, y_)

    out = (np.inner(grad(logp)(x), grad(logp)(y)) * k(x, y)
           + h2(x, y) + h2(y, x)
           + np.trace(jacfwd(d_xk, argnums=1)(x, y)))
    return out
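h(x, y, kernel, logp) is the Stein kernel; averaging it over all sample pairs gives a V-statistic estimate of the squared kernelized Stein discrepancy. A hedged sketch with an assumed RBF kernel and Gaussian target (and assuming the snippet's np is jax.numpy):

import jax.numpy as np
from jax import grad, jacfwd, vmap, random

rbf = lambda x, y: np.exp(-np.sum((x - y) ** 2) / 2.0)
logp = lambda x: -0.5 * np.sum(x ** 2)   # standard normal, up to a constant

xs = random.normal(random.PRNGKey(0), (50, 2))
pairwise = vmap(vmap(lambda a, b: h(a, b, rbf, logp), (None, 0)), (0, None))
ksd_sq = np.mean(pairwise(xs, xs))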
Code example #6
def step_anf(params, inputs):
    w, mu = params
    d, x = inputs

    y = jnp.inner(w, x)

    e = d - y

    outputs = (e,)

    w += 2 * mu * e * x

    params = (w, mu)

    return params, outputs
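As with update in code example #3, step_anf fits jax.lax.scan directly; a hedged one-liner (w_init, mu, d_seq, and x_seq are assumed to exist):

import jax
(w_final, _), (errs,) = jax.lax.scan(step_anf, (w_init, mu), (d_seq, x_seq))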
Code example #7
def quadform(hess, x, precision):
  """Computes a quadratic form (x^T @ H @ x)."""
  u = jnp.dot(hess, x, precision=precision)  # u = Hx
  return jnp.inner(x, u, precision=precision)
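A quick sanity check of the helper (precision values live in jax.lax):

import jax.numpy as jnp
from jax import lax

H = jnp.array([[2.0, 0.0], [0.0, 4.0]])
print(quadform(H, jnp.ones(2), precision=lax.Precision.HIGHEST))   # 2.0 + 4.0 = 6.0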
Code example #8
def inner(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.inner(x1, x2))
Code example #9
def normsq(x):
    return np.inner(x, x)
Code example #10
 def h(x, dlogp_x):
     # Stein operator applied to f: <f(x), grad log p(x)> + div f(x).
     div_f = np.trace(jacfwd(f)(x))
     return np.inner(f(x), dlogp_x) + div_f
Code example #11
 def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
   args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
   onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs)
   lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
   self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Code example #12
 def h(x, dlogp_x, z):
     # Single-probe Hutchinson estimate of div f: z^T J_f(x) z.
     zdf = grad(lambda _x: np.vdot(z, f(_x)))
     div_f = np.vdot(zdf(x), z)
     return np.inner(f(x), dlogp_x) + div_f
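The Hutchinson identity E_z[z^T J_f(x) z] = trace(J_f(x)) holds whenever E[z z^T] = I, so averaging over several probes reduces variance; a hedged sketch with an assumed toy vector field f:

import jax.numpy as np
from jax import grad, vmap, random

f = lambda x: np.tanh(x)                 # toy vector field
x = np.array([0.3, -0.7])
dlogp_x = -x                             # score of a standard normal

zs = random.rademacher(random.PRNGKey(0), (64, 2), dtype=np.float32)
print(np.mean(vmap(lambda z: h(x, dlogp_x, z))(zs)))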
Code example #13
 def loss_fun(x, _):
   return (0.5 * quadform(hess, x, precision=precision)
           + jnp.inner(v, x, precision=precision) + offset)
Code example #14
def stein_operator(fun, x, logp, transposed=False, aux=False):
    """
    Arguments:
    * fun: callable, transformation $\text{fun}: \mathbb R^d \to \mathbb R^d$,
    or $\text{fun}: \mathbb R^d \to \mathbb R$.
    Satisfies $\lim_{x \to \infty} \text{fun}(x) = 0$.
    * x: np.array of shape (d,).
    * p: callable, takes argument of shape (d,). Computes log(p(x)). Can be
    unnormalized (just using gradient.)

    Returns:
    Stein operator $\mathcal A$ evaluated at fun and x:
    \[ \mathcal A_p [\text{fun}](x) .\]
    This expression takes the form of a scalar if transposed else a dxd matrix

    Auxiliary data: values for G (kernel smoothed gradient) and R (kernel
    repulsion term) of shape((G, R)) = (2, d)
    """
    x = np.array(x, dtype=np.float32)
    if x.ndim != 1:
        raise ValueError(f"x needs to be an np.array of shape (d,). Instead, "
                         f"x has shape {x.shape}")
    fx = fun(x)
    if transposed:
        if fx.ndim == 0:  # f: R^d --> R
            raise ValueError(
                f"Got passed transposed = True, but the input "
                f"function {fun.__name__} returns a scalar. This "
                "doesn't make sense: the transposed Stein operator "
                "acts only on vector-valued functions.")
        elif fx.ndim == 1:  # f: R^d --> R^d
            drift_term = np.inner(grad(logp)(x), fx)
            repulsive_term = np.trace(jacfwd(fun)(x).transpose())
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        else:
            raise ValueError(f"Output of input function {fun.__name__} needs "
                             f"to have rank 0 or 1. Instead got output "
                             f"of shape {fx.shape}")
    else:
        if fx.ndim == 0:  # f: R^d --> R
            drift_term = grad(logp)(x) * fx
            repulsive_term = grad(fun)(x)
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        elif fx.ndim == 1:  # f: R^d --> R^d
            drift_term = np.einsum("i,j->ij", grad(logp)(x), fun(x))
            repulsive_term = jacfwd(fun)(x).transpose()
            auxdata = np.asarray([drift_term, repulsive_term])
            out = drift_term + repulsive_term
        elif fx.ndim == 2 and fx.shape[0] == fx.shape[1]:  # f: R^d --> R^{dxd}
            raise NotImplementedError("Not implemented for matrix-valued f.")
        else:
            raise ValueError(f"Output of input function {fun.__name__} needs "
                             f"to be a scalar, a vector, or a square matrix. "
                             f"Instead got output of shape {fx.shape}")
    if aux:
        return out, auxdata
    else:
        return out
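A hedged invocation on a toy target (both the density and the rapidly decaying witness function below are assumptions, and the snippet's np is assumed to be jax.numpy):

import jax.numpy as np
from jax import grad, jacfwd

logp = lambda x: -0.5 * np.sum(x ** 2)           # standard normal, unnormalized
fun = lambda x: -x * np.exp(-np.sum(x ** 2))     # decays to zero at infinity

val = stein_operator(fun, np.ones(3), logp, transposed=True)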
Code example #15
 def stein_op_true(x):
     return np.inner(optimal_witness(x), optimal_witness(x))
Code example #16
 def h2(x_, y_):
     return np.inner(grad(logp)(y_), grad(k, argnums=0)(x_, y_))
Code example #17
 def h(y):
     return np.inner(dlogp_xi, kx(y)) + grad(kx)(y)
Code example #18
 def squared_error(x, y):
     pred = model.apply(params, x)
     return jnp.inner(y - pred, y - pred) / 2.0
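squared_error scores one example through an (assumed) Flax-style model.apply; the batch loss and its parameter gradient follow the usual vmap + grad pattern (a hedged sketch; model, params, xs, and ys are stand-ins):

import jax
import jax.numpy as jnp

def batch_loss(p, xs, ys):
    preds = jax.vmap(lambda x: model.apply(p, x))(xs)
    return jnp.mean(jax.vmap(lambda y, yp: jnp.inner(y - yp, y - yp) / 2.0)(ys, preds))

grads = jax.grad(batch_loss)(params, xs, ys)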
Code example #19
def scalar_product_kernel(x, y):
    """k(x, y) = x^T y"""
    return np.inner(x, y)
Code example #20
File: __init__.py  Project: samuela/research
 def sample(self, rng, sample_shape=()) -> jp.ndarray:
     z = random.normal(rng, shape=sample_shape + self.event_shape)
     return self.loc + jp.inner(z, self.scale_tril)
Code example #21
 def fun_norm(x):
     return np.inner(fun(x), fun(x))
Code example #22
 def entropy(self, alpha):
     alpha0 = np.sum(alpha, axis=-1)
     lnB = _lnB(alpha)
     K = alpha.shape[-1]
     # Sum_j (alpha_j - 1) * digamma(alpha_j), written as an inner product (1-D alpha).
     return lnB + (alpha0 - K) * digamma(alpha0) - np.inner(
         alpha - 1, digamma(alpha))
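A hedged sanity check: scipy implements the same closed-form Dirichlet entropy, so for a concrete alpha the two should agree (dist stands in for an instance of the snippet's class):

import numpy as onp
from scipy.stats import dirichlet

alpha = onp.array([2.0, 3.0, 4.0])
# dist.entropy(alpha) from above should match:
print(dirichlet(alpha).entropy())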
Code example #23
 def squared_error(y, y_pred):
     return jnp.inner(y - y_pred, y - y_pred) / 2.0