def test_debug_print_transpose_rule(self):
    def f(x):
        debug_print('should never be called: {}', x)
        return x

    with capture_stdout() as output:
        jax.linear_transpose(f, 1.)(1.)
        jax.effects_barrier()
    # `debug_print` should be dropped by `partial_eval` because of no
    # output data-dependence.
    self.assertEqual(output(), "")

def mvt(v, f, params, samples, centered, w):
    (res,) = jax.linear_transpose(
        lambda v_: qgt_onthefly_logic.mat_vec(
            v_, f, params, samples, 0.0, centered),
        v,
    )(w)
    return res

def mvpT(A, y):
    assert y.ndim == 3
    input_image = types.SimpleNamespace(shape=(H, W, C), dtype=jnp.float32)
    mvpt = jax.linear_transpose(lambda x: mvp(A, x), input_image)
    out = mvpt(y)[0]
    assert out.ndim == 3
    return out

def test_allreduce_transpose():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
    assert jnp.array_equal(_arr, res)

def _transpose_one_output(linear_fun, primals):
    transpose_fun = jax.linear_transpose(linear_fun, primals)

    def transposed_fun(x):
        (y,) = transpose_fun(x)
        return y

    return transposed_fun

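A short usage sketch (hypothetical names, assuming the `_transpose_one_output` helper above): transposing a single-output matrix-vector product so the caller gets an array back rather than a length-1 tuple.

import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
matvec = lambda x: A @ x                           # linear in x
matvec_T = _transpose_one_output(matvec, jnp.zeros(3))

y = jnp.array([1.0, 2.0])
assert jnp.allclose(matvec_T(y), A.T @ y)          # acts as A^T
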
def test_matvec_linear_transpose():
    w = v
    (actual,) = jax.linear_transpose(
        lambda v_: mat_vec(v_, f, params, samples, 0.0), v)(w)

    # use that S is hermitian:
    # S^T = (O^H O)^T = O^T O* = (O^H O)* = S*
    # S^T w = S* w = (S w*)*
    expected = tree_conj(mat_vec(tree_conj(w), f, params, samples, 0.0))
    # (expected,) = jax.linear_transpose(
    #     lambda v_: reassemble_complex(S_real @ tree_toreal_flat(v_)), v)(v)

    assert tree_allclose(actual, expected)

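A minimal self-contained sketch (illustrative values, not from the original test) of the Hermitian identity used above: for S = OᴴO, `jax.linear_transpose` of v ↦ S v computes Sᵀw, and for Hermitian S this equals (S w*)*.

import jax
import jax.numpy as jnp

O = jnp.array([[1.0 + 2.0j, 0.5 - 1.0j],
               [0.0 + 1.0j, 2.0 + 0.0j]])
S = O.conj().T @ O                        # Hermitian by construction
w = jnp.array([0.3 + 0.7j, -1.2 + 0.1j])

# transposing the holomorphic linear map v -> S @ v applies S^T (no conjugation)
(actual,) = jax.linear_transpose(lambda v: S @ v, w)(w)
expected = (S @ w.conj()).conj()          # S^T w = S* w = (S w*)*
assert jnp.allclose(actual, expected)
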
def test_allreduce_transpose2():
    # test transposing twice
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()
    _arr2 = arr.copy()

    def lt(y):
        return jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(y)[0]

    (res,) = jax.linear_transpose(lt, _arr)(_arr2)
    expected, _ = allreduce(_arr2, op=MPI.SUM)
    assert jnp.array_equal(expected, res)

def mat_vec(jvp_fn, v, diag_shift):
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI
    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # vjp_fn packages output into a length-1 tuple
    (res,) = tree_conj(vjp_fn(w.conjugate()))
    res = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)

    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v

def get_ntk(x1, x2, *args):
    args = tuple(args)
    args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
    _kwargs1 = {k: v for k, v in zip(keys, args1)}
    _kwargs2 = {k: v for k, v in zip(keys, args2)}

    f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
    f2 = f1 if utils.all_none(x2) else _get_f_params(
        f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

    def delta_vjp_jvp(delta):
        def delta_vjp(delta):
            return vjp(f2, params)[1](delta)
        return jvp(f1, (params,), delta_vjp(delta))[1]

    fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
    eye = utils.std_basis(fx1)
    ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
    ntk = tree_map(lambda fx12: utils.unravel_array_into_pytree(fx1, 0, fx12), ntk)
    ntk = _diagonal(ntk, fx1)
    return ntk

def reassemble_complex(x, fun=tree_toreal_flat, target=params):
    # target: a tree with the expected shape and types of the result
    (res,) = jax.linear_transpose(fun, target)(x)
    res = tree_conj(res)
    # fix the dtypes:
    return tree_cast(res, target)

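A minimal self-contained sketch (not from the original repo) of the trick behind `reassemble_complex`: transposing a complex→real flattening map yields the conjugate of the desired complex tree, which the final conjugation undoes.

import jax
import jax.numpy as jnp

def to_real_flat(z):
    # linear over the reals: stack real and imaginary parts
    return jnp.concatenate([z.real, z.imag])

target = jnp.zeros(2, dtype=jnp.complex64)    # shape/dtype prototype
x = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

(res,) = jax.linear_transpose(to_real_flat, target)(x)
res = res.conj()                              # the `tree_conj` step
assert jnp.allclose(res, jnp.array([1.0 + 3.0j, 2.0 + 4.0j]))
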
def lt(y):
    return jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(y)[0]

def f(x):
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(x)
    return res

def f(x):
    (res,) = jax.linear_transpose(lt, _arr)(x)
    return res

def reassemble_complex(x, target, fun=tree_toreal_flat):
    # target: a tree with the expected shape and types of the result
    (res,) = jax.linear_transpose(fun, target)(x)
    res = qgt_onthefly_logic.tree_conj(res)
    # fix the dtypes:
    return qgt_onthefly_logic.tree_cast(res, target)

def mvt(v, w):
    (res,) = jax.linear_transpose(lambda v_: mv(v_, 0.0), v)(w)
    return res

def func_transpose(x):
    return jax.linear_transpose(func, x)(x)[0]

def fT(y):
    return jax.linear_transpose(f, x)(y)[0]

def _tree_reassemble_complex(x, target, fun=_tree_to_reim):
    (res,) = jax.linear_transpose(fun, target)(x)
    return nkjax.tree_conj(res)

def view_update(data, view_fun):
    item, view_transpose = view_fun(data), linear_transpose(view_fun, data)

    def update(new_item):
        diff, = view_transpose(tree_multimap(jnp.subtract, new_item, item))
        return tree_multimap(jnp.add, data, diff)

    return item, update

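A hypothetical usage sketch of `view_update` above (illustrative pytree and view; assumes the same names the snippet relies on, i.e. `jnp`, `linear_transpose`, and `tree_multimap`): read one leaf of a dict, then write a new value back through the transposed view.

# hypothetical data tree and linear "view" selecting the 'a' leaf
data = {'a': jnp.zeros(3), 'b': jnp.ones(2)}
item, update = view_update(data, lambda tree: tree['a'])

new_data = update(jnp.arange(3.0))   # write back through the transposed view
# new_data['a'] == jnp.arange(3.0); new_data['b'] stays jnp.ones(2)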