def test_std_basis(self):
  """Check the shapes and values produced by `api._std_basis`.

  For a single array of shape `s` the basis is expected to have shape
  `(prod(s),) + s` and to equal an identity matrix reshaped to that shape.
  For a pytree input, every leaf is expected to gain a leading axis whose
  length is the total number of entries across all leaves (here
  1 + 3 + 3 * 4 = 16), while the list / tuple container structure is kept.
  """
  def shape_of(x):
    # `getattr` with a default keeps the check from raising AttributeError
    # when the result is not an array; a missing shape compares as None.
    return getattr(x, "shape", None)

  # Rank-1 input: the basis is the 3x3 identity.
  vec_basis = api._std_basis(np.zeros(3))
  assert shape_of(vec_basis) == (3, 3)
  assert onp.allclose(vec_basis, onp.eye(3))

  # Rank-2 input: 9 basis elements, each shaped like the (3, 3) input.
  mat_basis = api._std_basis(np.zeros((3, 3)))
  assert shape_of(mat_basis) == (9, 3, 3)
  assert onp.allclose(mat_basis, onp.eye(9).reshape(9, 3, 3))

  # Pytree input: scalar + (length-3 vector, 3x4 matrix) -> 16 entries total.
  tree_basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
  assert isinstance(tree_basis, list) and len(tree_basis) == 2
  scalar_leaf, inner = tree_basis
  assert shape_of(scalar_leaf) == (16, )
  assert isinstance(inner, tuple) and len(inner) == 2
  vec_leaf, mat_leaf = inner
  assert shape_of(vec_leaf) == (16, 3)
  assert shape_of(mat_leaf) == (16, 3, 4)
def get_ntk(x1, x2, *args):
  """Evaluate the kernel between `x1` and `x2` as a transposed VJP/JVP product.

  `args` is split at its midpoint. The first half is zipped with `keys` to
  form keyword arguments for the `x1` side; the second half does the same
  for the `x2` side. `f`, `params`, `keys`, `x_axis`, `fx_axis` and `kw_axes`
  are taken from the enclosing scope.

  With J1 / J2 the Jacobians of f1 / f2 with respect to `params`, the inner
  linear map sends a cotangent `d` (shaped like f2's output) to
  J1 @ (J2^T @ d). Its transpose is applied to every element of the standard
  basis of f1's output. The leading basis axis of each result is then
  unravelled into f1's output structure by `_unravel_array_into_pytree`, and
  the result is passed through `_diagonal`.
  NOTE(review): the exact output layout depends on `_unravel_array_into_pytree`
  and `_diagonal`, which are not visible here. Confirm against their
  definitions.

  Args:
    x1: first input, used to build f1.
    x2: second input. When `utils.all_none(x2)` holds, f1 is reused as f2.
    *args: keyword-argument values for the x1 side followed by those for the
      x2 side, presumably in the order given by `keys`.

  Returns:
    The output of `_diagonal` applied to the unravelled kernel pytree.
  """
  split = len(args) // 2
  kwargs1 = dict(zip(keys, args[:split]))
  kwargs2 = dict(zip(keys, args[split:]))

  f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs1)
  if utils.all_none(x2):
    # No separate second input: reuse the x1-side function.
    f2 = f1
  else:
    f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes, **kwargs2)

  def vjp_then_jvp(cotangent):
    # Pull the cotangent back through f2 into parameter space, then push the
    # resulting parameter-space vector forward through f1.
    _, f2_pullback = vjp(f2, params)
    param_tangents = f2_pullback(cotangent)
    return jvp(f1, (params, ), param_tangents)[1]

  out1_struct = eval_shape(f1, params)
  out2_struct = eval_shape(f2, params)

  basis = _std_basis(out1_struct)
  transposed = linear_transpose(vjp_then_jvp, out2_struct)
  kernel = vmap(transposed)(basis)
  kernel = tree_map(
      lambda fx12: _unravel_array_into_pytree(out1_struct, 0, fx12), kernel)
  return _diagonal(kernel, out1_struct)