def mat_vec_factory(forward_fn, params, model_state, samples):
    # "forward function" that maps params to outputs
    def fun(W):
        return forward_fn({"params": W, **model_state}, samples)

    _, jvp_fn = jax.linearize(fun, params)
    # Partial (presumably jax.tree_util.Partial) keeps the returned closure a
    # valid pytree; mat_vec is defined elsewhere in the enclosing module
    return Partial(mat_vec, jvp_fn)
def _hessianopt(x, f):
    _, hvp = jax.linearize(jax.grad(f), x)
    hvp = jax.jit(hvp)
    vhvp = jax.vmap(hvp)
    vhvp = jax.jit(vhvp)
    # probe the Hessian-vector product with the full standard basis at once
    basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
    return vhvp(basis).reshape(x.shape + x.shape)
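# Usage sketch for _hessianopt above; the quadratic test function and its
# expected Hessian are illustrative assumptions, not part of the original.
import jax
import jax.numpy as jnp
import numpy as np

def quad(x):
    # hypothetical test function: 0.5 * x^T diag(1, 2, 3) x
    return 0.5 * jnp.sum(jnp.arange(1., x.size + 1) * x ** 2)

x0 = np.zeros(3)
print(_hessianopt(x0, quad))  # expected: diag([1., 2., 3.])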
def hvp(self, x, v):
    # re-linearize only when the evaluation point changes; otherwise reuse the
    # cached linearization for repeated Hessian-vector products at the same x
    if self.x is None or not np.equal(x, self.x).all():
        _, flin = jax.linearize(self.grad, x)
        self.flin = jax.jit(flin)
        self.x = x
    return self.flin(v)
def make_mixed_jvp(f, first_args, second_args, opposite=False):
    """Make a mixed Jacobian-vector product function.

    Args:
        f (callable): Binary callable with signature f(x, y).
        first_args (numpy.ndarray): First arguments to f.
        second_args (numpy.ndarray): Second arguments to f.
        opposite (bool, optional): Take Dyx if False, Dxy if True.
            Defaults to False.

    Returns:
        callable: Unary callable `jvp(v)` taking a numpy.ndarray as input.
    """
    if not opposite:
        given = second_args
        gradfun = jax.grad(f, 0)

        def frozen_grad(y):
            return gradfun(first_args, y)
    else:
        given = first_args
        gradfun = jax.grad(f, 1)

        def frozen_grad(x):
            return gradfun(x, second_args)

    return jax.linearize(frozen_grad, given)[1]
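# Quick check of the convention above; the bilinear test function is an
# illustrative assumption. With f(x, y) = x @ (A @ y), grad_x f = A @ y,
# so the default D_yx product should map v to A @ v.
import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])
f_bilinear = lambda x, y: x @ (A @ y)
mixed_jvp = make_mixed_jvp(f_bilinear, jnp.ones(2), jnp.ones(2))  # D_yx by default
print(mixed_jvp(jnp.array([1., 0.])))  # expected: A @ [1, 0] = [1., 3.]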
def test_odeint_linearize_fwrap():
    def odeint_fwrap(y0, ts, fargs):
        return odeint(y0, ts, func=f, fargs=fargs)

    _, out_tangent = jvp(odeint_fwrap, (y0, ts, fargs), (y0, ts, fargs))  # if this breaks, this is why
    y, f_jvp = linearize(odeint_fwrap, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))
    # print(make_jaxpr(f_jvp)((y0, t0, t1, fargs),))
    check_close(out_tangent, out_tangent_2)
def _hessianopt(x, f, vsize):
    # vsize (a free variable in the original, made explicit here) is the batch
    # size used to split the basis so the vmapped HVP is applied chunk by chunk
    _, hvp = jax.linearize(jax.grad(f), x)
    hvp = jax.jit(hvp)
    n = np.prod(x.shape)
    idxs = np.arange(vsize, n, vsize)
    basis = np.eye(n).reshape(-1, *x.shape)
    splitbasis = np.split(basis, idxs)
    vhvp = jax.vmap(hvp)
    vhvp = jax.jit(vhvp)
    return np.concatenate([vhvp(b) for b in splitbasis]).reshape(x.shape + x.shape)
def _expect_kernel(self, logpsi: Callable, params: PyTree, x: Array, mass: Optional[PyTree]):
    def logpsi_x(x):
        return logpsi(params, x)

    dlogpsi_x = jax.grad(logpsi_x)
    # linearize the gradient once, then probe it with the standard basis to
    # read off the diagonal of the Hessian (the Laplacian terms)
    basis = jnp.eye(x.shape[0])
    y, f_jvp = jax.linearize(dlogpsi_x, x)
    dp_dx2 = jnp.diag(jax.vmap(f_jvp)(basis))
    dp_dx = dlogpsi_x(x) ** 2
    return -0.5 * jnp.sum(mass * (dp_dx2 + dp_dx), axis=-1)
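# The same linearize-based Laplacian as a standalone sketch, checked against a
# Gaussian log-amplitude; the helper name and test function are illustrative
# assumptions. For logpsi(params, x) = -0.5 * sum(x**2), the Hessian diagonal
# is -1 everywhere and the squared gradient is x**2.
import jax
import jax.numpy as jnp

def local_kinetic(logpsi, params, x, mass):
    dlogpsi = jax.grad(lambda x: logpsi(params, x))
    _, f_jvp = jax.linearize(dlogpsi, x)
    lap_diag = jnp.diag(jax.vmap(f_jvp)(jnp.eye(x.shape[0])))  # Hessian diagonal
    return -0.5 * jnp.sum(mass * (lap_diag + dlogpsi(x) ** 2), axis=-1)

print(local_kinetic(lambda p, x: -0.5 * jnp.sum(x ** 2), None, jnp.array([1., 2.]), 1.0))
# expected: -0.5 * ((-1 + 1) + (-1 + 4)) = -1.5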
def test_linearize(self):
    @djax.djit
    def f(x):
        y = sin(x)
        return reduce_sum(y, axes=(0,))

    x = bbarray((5,), jnp.arange(2.))
    with jax.enable_checks(False):  # TODO implement dxla_call abs eval rule
        z, f_lin = jax.linearize(f, x)
        z_dot = f_lin(ones_like(x))

    def g(x):
        return jnp.sin(x).sum()

    expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))
    self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
    self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(*args).jaxpr
    res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
    res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

    results = []
    for x in res_lits:
        results.append((x.aval, 'from a literal'))

    for v in jaxpr.constvars:
        if v in res_vars:
            results.append((v.aval, 'from a constant'))

    assert len(jaxpr.invars) == len(args)
    for i, v in enumerate(jaxpr.invars):
        if v in res_vars:
            src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
            results.append((v.aval, src))

    for eqn in jaxpr.eqns:
        src = source_info_util.summarize(eqn.source_info)
        for v in eqn.outvars:
            if v in res_vars:
                if eqn.primitive is name_p:
                    results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
                else:
                    results.append((v.aval, f'from {src}'))

    assert len(results) == len(jaxpr.outvars)
    return results
def linearized(fn: Callable[..., Tensor], *primals: Tensor) -> LinFun:
    """Returns a linear function that is tangent to `fn` at the given primal point."""
    val, deriv = jax.linearize(fn, *primals)
    return lambda *xs: val + deriv(*[x - p for x, p in zip(xs, primals)])
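# For example (the quadratic and the expansion point are illustrative):
# linearizing x ** 2 at 3.0 gives the tangent line 9 + 6 * (x - 3).
import jax.numpy as jnp

tangent = linearized(lambda x: x ** 2, jnp.float32(3.0))
print(tangent(jnp.float32(4.0)))  # 9 + 6 * (4 - 3) = 15.0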
def jvp_unlinearized(f, primals, tangents):
    out, jvp = linearize(f, *primals)
    return out, jvp(*tangents)
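# Equivalence check (sin and the point pi are illustrative): linearize-then-apply
# should agree with jax.jvp exactly.
import jax
import jax.numpy as jnp

out, tangent = jvp_unlinearized(jnp.sin, (jnp.pi,), (1.0,))
print(out, tangent)                         # ~0.0, cos(pi) = -1.0
print(jax.jvp(jnp.sin, (jnp.pi,), (1.0,)))  # should match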
def hessian_fn(x):
    _, hvp = jax.linearize(jax.grad(f), x)
    hvp = jax.jit(hvp)  # seems like a substantial speedup to do this
    basis = jnp.eye(jnp.prod(x.shape)).reshape(-1, *x.shape)
    return jnp.stack([hvp(e) for e in basis]).reshape(x.shape + x.shape)
# print ("pred: ", predictions) def net_apply_reverse(inputs, net_params): return net_apply(net_params, inputs) @jit def test_loss(net_params, inputs): return np.sum(net_apply(net_params, inputs)) primals_out, vjpfun = vjp(partial(net_apply_reverse, inputs), net_params) print(primals_out) primals_out, jvpfun = linearize(partial(net_apply_reverse, inputs), net_params) # primals_out, vp = jvp(net_apply, (net_params, inputs), random.normal(rng, (1, 256))) print(primals_out) input("") for i in range(10): import time s = time.time() out = vjpfun(random.normal(rng, (1, 10))) e = time.time() print("vjp time: ", (e - s)) s = time.time() out = jvpfun(net_params) # print (out) e = time.time()
def f(y):
    z, g_lin = jax.linearize(lambda y: g(x, y), y)
    zdot = g_lin(y)
    return z, zdot
print(f(x, n))  # type: ignore
print(f'should be\n{np.broadcast_to(np.nonzero(x)[0], (4, 2))}')

## ad

@djit
def f(x):
    y = sin(x)
    return reduce_sum(y, axes=(0,))

x = bbarray((5,), jnp.arange(2.))

p('basic jvp')
z, z_dot = jax.jvp(f, (x,), (ones_like(x),))
print(z, z_dot)

p('basic linearize')
_, f_lin = jax.linearize(f, x)
print(f_lin(ones_like(x)))

## vmap

@djit
def f(x):
    return nonzero(x)

p('vmap of nonzero')
xs = jnp.array([[0, 1, 0, 1, 0, 1], [1, 1, 1, 1, 0, 1]])
print(jax.vmap(f)(xs))
def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)
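# This is the shape of the tangent-solve step used with lax.custom_root; a
# self-contained sketch where the residual function and the dense
# tangent_solve are illustrative assumptions.
import jax
import jax.numpy as jnp

def residual(x):  # hypothetical residual whose linearization we solve against
    return jnp.array([2. * x[0] + x[1], 3. * x[1]])

def dense_tangent_solve(g, y):
    # g is linear, so its Jacobian (evaluated anywhere of the right shape) is its matrix
    return jnp.linalg.solve(jax.jacobian(g)(y), y)

def solve_linearized(x, b):
    unchecked_zeros, f_jvp = jax.linearize(residual, x)
    return dense_tangent_solve(f_jvp, b)

print(solve_linearized(jnp.ones(2), jnp.array([1., 0.])))
# solves J dx = b with J = [[2., 1.], [0., 3.]] -> dx = [0.5, 0.]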
def test_odeint_2_linearize():
    def odeint2(y0, ts, fargs):
        return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8)

    odeint2_prim = custom_transforms(odeint2).primitive

    def odeint2_jvp(primals, tangents):
        # jvp rule: unpack the primal and tangent tuples explicitly
        y0, ts, fargs = primals
        tan_y, tan_ts, tan_fargs = tangents
        return jvp_odeint(f, (y0, ts, fargs), (tan_y, tan_ts, tan_fargs))

    ad.defjvp(odeint2_prim, odeint2_jvp)

    _, out_tangent = jvp(odeint2, (y0, ts, fargs), (y0, ts, fargs))  # if this breaks, this is why
    y, f_jvp = linearize(odeint2, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))
    # print(make_jaxpr(f_jvp)(y0, t0, t1, fargs))
    check_close(out_tangent, out_tangent_2)
def diffGiltHOTRG(A, Anorm, isom, RABs, RABsh, scaleN=20, isom_corr=False):
    """
    Similar to diffRGnew, but designed for the Gilt-HOTRG-imp version.
    """
    # define the invariant tensor, with the magnitude properly taken care of
    Ainv = Anorm ** (-1 / 3) * A
    # read off the isometries and R matrices used in Gilt
    w, v = isom
    RAl, RAr, RBl, RBr = RABs[2:]
    RAlh, RArh, RBlh, RBrh = RABsh[2:]
    # convert everything to numpy.array for consistency, since we fall back
    # to ordinary tensor multiplication in the calculation here
    Ainv = convertAbeBack(Ainv)
    N1, N2, N3, N4 = Ainv.shape
    w = convertAbeBack(w)
    v = convertAbeBack(v)
    RAl = convertAbeBack(RAl)
    RAr = convertAbeBack(RAr)
    RBl = convertAbeBack(RBl)
    RBr = convertAbeBack(RBr)
    RAlh = convertAbeBack(RAlh)
    RArh = convertAbeBack(RArh)
    RBlh = convertAbeBack(RBlh)
    RBrh = convertAbeBack(RBrh)

    # define the RG equation
    def equationRG(psiA):
        Aorg = psiA.reshape(N1, N2, N3, N4)
        # Gilt before y-contraction
        Ap = jncon([Aorg, RAl, RAr], [[1, 2, -3, -4], [1, -1], [2, -2]])
        Bp = jncon([Aorg, RBl, RBr], [[1, 2, -3, -4], [1, -1], [2, -2]])
        # perform HOTRG y-contraction
        if not isom_corr:
            Ap = doHalfHOTRGknownWV(Bp, Ap, w, direction="v")
        else:
            chiH = w.shape[2]
            Ap = halfHOTRG(Bp, Ap, chiH, direction="v", verbose=False, isjax=True)[0]
        # Gilt before x-contraction
        App = jncon([Ap, RAlh, RArh], [[-1, -2, 3, 4], [4, -4], [3, -3]])
        Bpp = jncon([Ap, RBlh, RBrh], [[-1, -2, 3, 4], [4, -4], [3, -3]])
        # perform HOTRG x-contraction
        if not isom_corr:
            Ap = doHalfHOTRGknownWV(Bpp, App, v, direction="h")
        else:
            chiV = v.shape[2]
            Ap = halfHOTRG(Bpp, App, chiV, direction="h", verbose=False, isjax=True)[0]
        psiAp = Ap.reshape(N1 * N2 * N3 * N4)
        return psiAp

    # linearize the RG equation to get the response matrix
    dimA = N1 * N2 * N3 * N4
    psiA = Ainv.reshape(dimA)
    psiAp, responseMat = jax.linearize(equationRG, psiA)
    # calculate its eigenvalues
    RGhyperM = LinearOperator((dimA, dimA), matvec=responseMat)
    dtemp = np.sort(abs(eigs(RGhyperM, k=scaleN, which='LM', return_eigenvectors=False)))
    dtemp = dtemp[::-1]
    # calculate scaling dimensions
    scDims = -np.log2(abs(dtemp / dtemp[0]))
    return scDims
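# The linearize -> LinearOperator -> eigs pattern in isolation; the toy map and
# its spectrum are illustrative assumptions. Only the Jacobian-vector product is
# exposed to ARPACK; the dense response matrix is never materialized.
import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigs

def toy_rg_map(psi):  # hypothetical stand-in for equationRG
    return jnp.arange(1., 5.) * psi + psi ** 3

_, response = jax.linearize(toy_rg_map, jnp.zeros(4))
op = LinearOperator((4, 4), matvec=lambda v: np.asarray(response(jnp.asarray(v).ravel())))
print(np.sort(np.abs(eigs(op, k=2, which='LM', return_eigenvectors=False))))
# expected: the two largest |eigenvalues| of diag(1., 2., 3., 4.) -> [3., 4.]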