Example #1
def mat_vec_factory(forward_fn, params, model_state, samples):
    # "forward function" that maps params to outputs
    def fun(W):
        return forward_fn({"params": W, **model_state}, samples)

    _, jvp_fn = jax.linearize(fun, params)
    return Partial(mat_vec, jvp_fn)
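A hedged sketch of the pattern in Example #1 (the toy forward_fn, params, and samples below are made up; mat_vec and Partial come from the original codebase and are not shown): jax.linearize of the params-to-outputs map yields a function that sends a parameter-tangent pytree to an output tangent, which is the building block for matrix-free curvature products.

import jax
import jax.numpy as jnp

def forward_fn(variables, samples):            # toy stand-in for a model apply function
    return jnp.tanh(samples @ variables["params"]["w"])

params = {"w": jnp.ones((2, 1))}
samples = jnp.arange(6.0).reshape(3, 2)

def fun(W):
    return forward_fn({"params": W}, samples)

out, jvp_fn = jax.linearize(fun, params)       # out: model outputs at params
tangent = jvp_fn({"w": jnp.ones((2, 1))})      # directional derivative of the outputs
print(out.shape, tangent.shape)                # (3, 1) (3, 1)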
Example #2
 def _hessianopt(x, f):
     _, hvp = jax.linearize(jax.grad(f), x)
     hvp = jax.jit(hvp)
     vhvp = jax.vmap(hvp)
     vhvp = jax.jit(vhvp)
     basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
     return vhvp(basis).reshape(x.shape + x.shape)
Example #3
 def hvp(self, x, v):
     if self.x is None or not np.equal(x, self.x).all():
         _, flin = jax.linearize(self.grad, x)
         self.flin = jax.jit(flin)
         #self.flin = flin
         self.x = x
     return self.flin(v)
Example #4
def make_mixed_jvp(f, first_args, second_args, opposite=False):
    """Make a mixed jacobian-vector product function
    Args:
        f (callable): Binary callable with signature f(x,y)
        first_args (numpy.ndarray): First arguments to f
        second_args (numpy.ndarray): Second arguments to f
        opposite (bool, optional): Take Dyx if False, Dxy if True. Defaults to
            False.
    Returns:
        callable: Unary callable 'jvp(v)' taking a numpy.ndarray as input.
    """
    if opposite is not True:
        given = second_args
        gradfun = jax.grad(f, 0)

        def frozen_grad(y):
            return gradfun(first_args, y)
    else:
        given = first_args
        gradfun = jax.grad(f, 1)

        def frozen_grad(x):
            return gradfun(x, second_args)

    return jax.linearize(frozen_grad, given)[1]
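A minimal usage sketch for Example #4 (the quadratic f, the points x0 and y0, and the vector v below are made up for illustration): with opposite=False the returned callable applies the mixed derivative D_yx f(x0, y0) to v, which can be cross-checked against jax.jacfwd(jax.grad(f, 0), 1).

import jax
import jax.numpy as jnp

def f(x, y):                                   # made-up binary test function
    return jnp.dot(x, y) + jnp.sum(x ** 2) * jnp.sum(y)

x0 = jnp.arange(3.0)
y0 = jnp.ones(3)
mixed_jvp = make_mixed_jvp(f, x0, y0)          # v -> D_yx f(x0, y0) @ v
v = jnp.array([1.0, 0.0, 0.0])
print(mixed_jvp(v))
print(jax.jacfwd(jax.grad(f, 0), 1)(x0, y0) @ v)   # should match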
Example #5
def test_odeint_linearize_fwrap():
    def odeint_fwrap(y0, ts, fargs):
        return odeint(y0, ts, func=f, fargs=fargs)

    _, out_tangent = jvp(odeint_fwrap, (y0, ts, fargs),
                         (y0, ts, fargs))  # if this breaks, this is why
    y, f_jvp = linearize(odeint_fwrap, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))

    # print(make_jaxpr(f_jvp)((y0,t0,t1,fargs),))
    check_close(out_tangent, out_tangent_2)
Example #6
 def _hessianopt(x, f):
     _, hvp = jax.linearize(jax.grad(f), x)
     hvp = jax.jit(hvp)
     n = np.prod(x.shape)
     idxs = np.arange(vsize, n, vsize)  # split points; `vsize` is defined outside this snippet
     basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
     splitbasis = np.split(basis, idxs)
     vhvp = jax.vmap(hvp)
     vhvp = jax.jit(vhvp)
     return np.concatenate([vhvp(b)
                            for b in splitbasis]).reshape(x.shape + x.shape)
Example #7
    def _expect_kernel(self, logpsi: Callable, params: PyTree, x: Array,
                       mass: Optional[PyTree]):
        def logpsi_x(x):
            return logpsi(params, x)

        dlogpsi_x = jax.grad(logpsi_x)

        basis = jnp.eye(x.shape[0])

        y, f_jvp = jax.linearize(dlogpsi_x, x)
        dp_dx2 = jnp.diag(jax.vmap(f_jvp)(basis))

        dp_dx = dlogpsi_x(x)**2

        return -0.5 * jnp.sum(mass * (dp_dx2 + dp_dx), axis=-1)
Example #8
  def test_linearize(self):
    @djax.djit
    def f(x):
      y = sin(x)
      return reduce_sum(y, axes=(0,))
    x = bbarray((5,), jnp.arange(2.))
    with jax.enable_checks(False):  # TODO implement dxla_call abs eval rule
      z, f_lin = jax.linearize(f, x)
    z_dot = f_lin(ones_like(x))

    def g(x):
      return jnp.sin(x).sum()
    expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))

    self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
    self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
Example #9
def saved_residuals(f, *args,
                    **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(
        *args).jaxpr
    res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
    res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

    results = []

    for x in res_lits:
        results.append((x.aval, 'from a literal'))

    for v in jaxpr.constvars:
        if v in res_vars:
            results.append((v.aval, 'from a constant'))

    assert len(jaxpr.invars) == len(args)
    for i, v in enumerate(jaxpr.invars):
        if v in res_vars:
            src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
            results.append((v.aval, src))

    for eqn in jaxpr.eqns:
        src = source_info_util.summarize(eqn.source_info)
        for v in eqn.outvars:
            if v in res_vars:
                if eqn.primitive is name_p:
                    results.append(
                        (v.aval, f"named '{eqn.params['name']}' from {src}"))
                else:
                    results.append((v.aval, f'from {src}'))

    assert len(results) == len(jaxpr.outvars)
    return results
Example #10
def linearized(fn: Callable[..., Tensor], *primals: Tensor) -> LinFun:
    """Returns linear function that is tangent to `fn` at given primal point."""
    val, deriv = jax.linearize(fn, *primals)
    return lambda *xs: val + deriv(*[x - p for x, p in zip(xs, primals)])
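A quick sanity check for linearized (the test function and primal point below are made up, and the snippet assumes linearized and its Tensor/LinFun aliases are importable): the returned callable reproduces fn at the primal point and follows its tangent nearby.

import jax.numpy as jnp

fn = jnp.sin                                   # made-up test function
x0 = jnp.array(0.5)
lin = linearized(fn, x0)
print(lin(x0))                                 # equals fn(x0)
print(lin(x0 + 1e-3), jnp.sin(x0) + jnp.cos(x0) * 1e-3)   # first-order agreement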
Example #11
def jvp_unlinearized(f, primals, tangents):
  out, jvp = linearize(f, *primals)
  return out, jvp(*tangents)
Example #12
 def hessian_fn(x):
     _, hvp = jax.linearize(jax.grad(f), x)
     hvp = jax.jit(hvp)  # seems like a substantial speedup to do this
     basis = jnp.eye(jnp.prod(x.shape)).reshape(-1, *x.shape)
     return jnp.stack([hvp(e) for e in basis]).reshape(x.shape + x.shape)
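The jit comment above points at the core trick: jax.linearize(jax.grad(f), x) yields a Hessian-vector product that can be re-evaluated cheaply for many vectors. A minimal self-contained sketch of that pattern (the objective f and point x below are made up):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(jnp.sin(x) ** 2)         # made-up scalar objective
x = jnp.arange(3.0)
_, hvp = jax.linearize(jax.grad(f), x)         # hvp(v) == H(x) @ v
hvp = jax.jit(hvp)                             # reuse without retracing
H = jnp.stack([hvp(e) for e in jnp.eye(3)])    # assemble the full 3x3 Hessian
print(jnp.allclose(H, jax.hessian(f)(x)))      # True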
# print ("pred: ", predictions)


def net_apply_reverse(inputs, net_params):
    return net_apply(net_params, inputs)


@jit
def test_loss(net_params, inputs):
    return np.sum(net_apply(net_params, inputs))


primals_out, vjpfun = vjp(partial(net_apply_reverse, inputs), net_params)
print(primals_out)

primals_out, jvpfun = linearize(partial(net_apply_reverse, inputs), net_params)
# primals_out, vp = jvp(net_apply, (net_params, inputs), random.normal(rng, (1, 256)))
print(primals_out)
input("")

import time

for i in range(10):
    s = time.time()
    out = vjpfun(random.normal(rng, (1, 10)))
    e = time.time()
    print("vjp time: ", (e - s))

    s = time.time()
    out = jvpfun(net_params)
    # print (out)
    e = time.time()
    print("jvp time: ", (e - s))
Example #14
 def f(y):
     z, g_lin = jax.linearize(lambda y: g(x, y), y)
     zdot = g_lin(y)
     return z, zdot
Example #15
  print(f(x, n))  # type: ignore
  print(f'should be\n{np.broadcast_to(np.nonzero(x)[0], (4, 2))}')


  ## ad

  @djit
  def f(x):
    y = sin(x)
    return reduce_sum(y, axes=(0,))
  x = bbarray((5,), jnp.arange(2.))
  p('basic jvp')
  z, z_dot = jax.jvp(f, (x,), (ones_like(x),))
  print(z, z_dot)


  p('basic linearize')
  _, f_lin = jax.linearize(f, x)
  print(f_lin(ones_like(x)))


  ## vmap

  @djit
  def f(x):
    return nonzero(x)
  p('vmap of nonzero')
  xs = jnp.array([[0, 1, 0, 1, 0, 1],
                  [1, 1, 1, 1, 0, 1]])
  print(jax.vmap(f)(xs))
Example #16
 def linearize_and_solve(x, b):
     unchecked_zeros, f_jvp = jax.linearize(f, x)
     return tangent_solve(f_jvp, b)
Example #17

def test_odeint_2_linearize():
    def odeint2(y0, ts, fargs):
        return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8)

    odeint2_prim = custom_transforms(odeint2).primitive

    def odeint2_jvp(primals, tangents):
        # tuple-parameter unpacking is Python 2 only; take the packed tuples directly
        return jvp_odeint(f, primals, tangents)

    ad.defjvp(odeint2_prim, odeint2_jvp)

    _, out_tangent = jvp(odeint2, (y0, ts, fargs),
                         (y0, ts, fargs))  # if this breaks, this is why
    y, f_jvp = linearize(odeint2, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))

    # print(make_jaxpr(f_jvp)(y0,t0,t1,fargs))
    check_close(out_tangent, out_tangent_2)


def test_odeint_linearize_fwrap():
    def odeint_fwrap(y0, ts, fargs):
        return odeint(y0, ts, func=f, fargs=fargs)

    _, out_tangent = jvp(odeint_fwrap, (y0, ts, fargs),
                         (y0, ts, fargs))  # if this breaks, this is why
    y, f_jvp = linearize(odeint_fwrap, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))
Example #18
def diffGiltHOTRG(A, Anorm, isom, RABs, RABsh, scaleN=20,
                  isom_corr=False):
    """
    Similar to diffRGnew, but designed for the Gilt-HOTRG-imp version.
    """
    # define the invariant tensor, with its magnitude properly taken care of
    Ainv = Anorm**(-1/3) * A
    # read off the isometries and R matrices used in Gilt
    w, v = isom
    RAl, RAr, RBl, RBr = RABs[2:]
    RAlh, RArh, RBlh, RBrh = RABsh[2:]
    # convert everything to numpy.array for consistency, since we
    # fall back to ordinary tensor multiplication in the calculation here
    Ainv = convertAbeBack(Ainv)
    N1, N2, N3, N4 = Ainv.shape
    w = convertAbeBack(w)
    v = convertAbeBack(v)
    RAl = convertAbeBack(RAl)
    RAr = convertAbeBack(RAr)
    RBl = convertAbeBack(RBl)
    RBr = convertAbeBack(RBr)

    RAlh = convertAbeBack(RAlh)
    RArh = convertAbeBack(RArh)
    RBlh = convertAbeBack(RBlh)
    RBrh = convertAbeBack(RBrh)

    # define the RG equation
    def equationRG(psiA):
        Aorg = psiA.reshape(N1,N2,N3,N4)
        # Gilt before y-contraction
        Ap = jncon([Aorg, RAl, RAr], [[1, 2, -3, -4], [1,-1], [2,-2]])
        Bp = jncon([Aorg, RBl, RBr], [[1, 2, -3, -4], [1,-1], [2,-2]])
        # perform HOTRG y-contraction
        if not isom_corr:
            Ap = doHalfHOTRGknownWV(Bp, Ap, w, direction = "v")
        else:
            chiH = w.shape[2]
            Ap = halfHOTRG(Bp, Ap, chiH, direction = "v", verbose = False,
                           isjax = True)[0]
        # Gilt before x-contraction
        App = jncon([Ap, RAlh, RArh], [[-1,-2,3,4], [4,-4], [3,-3]])
        Bpp = jncon([Ap, RBlh, RBrh], [[-1,-2,3,4], [4,-4], [3,-3]])
        # perform HOTRG x-contraction
        if not isom_corr:
            Ap = doHalfHOTRGknownWV(Bpp, App, v, direction = "h")
        else:
            chiV = v.shape[2]
            Ap = halfHOTRG(Bpp, App, chiV, direction = "h", verbose = False,
                           isjax = True)[0]
        psiAp = Ap.reshape(N1 * N2 * N3 * N4)
        return psiAp
    # linearize the RG equation to get the response matrix
    dimA = N1 * N2 * N3 * N4
    psiA = Ainv.reshape(dimA)
    psiAp, responseMat = jax.linearize(equationRG, psiA)
    # calculate its eigenvalues
    RGhyperM = LinearOperator((dimA,dimA), matvec = responseMat)
    dtemp = np.sort(abs(eigs(RGhyperM, k=scaleN,
                    which='LM', return_eigenvectors=False)))
    dtemp = dtemp[::-1]
    # calculate scaling dimensions
    scDims = -np.log2(abs(dtemp/dtemp[0]))
    return scDims
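The closing steps of Example #18 illustrate a generally useful pattern: the jvp returned by jax.linearize is a linear map, so it can be wrapped in a SciPy LinearOperator and handed to an iterative eigensolver without ever materializing the Jacobian. A minimal self-contained sketch (the map g below is a made-up stand-in for equationRG):

import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigs

def g(psi):                                    # toy nonlinear map
    return jnp.tanh(psi) + 0.1 * psi ** 2

psi0 = jnp.linspace(0.0, 1.0, 50)
_, jac_vec = jax.linearize(g, psi0)            # v -> J(psi0) @ v, Jacobian never formed
op = LinearOperator((50, 50), matvec=lambda v: np.asarray(jac_vec(jnp.asarray(v))))
vals = eigs(op, k=5, which="LM", return_eigenvectors=False)
print(np.sort(np.abs(vals))[::-1])             # five largest-magnitude eigenvalues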