def _kernel_fns(key, input_shape, network, out_logits):
  """Create jitted empirical NTK functions with freshly initialised params.

  NOTE(review): a later definition of `_kernel_fns` in this file takes
  additional `diagonal_axes`/`trace_axes` arguments and rebinds the name,
  so this version is shadowed at module level.

  Args:
    key: PRNG key forwarded to the network's `init_fn`.
    input_shape: per-example input shape; `(-1,)` is prepended before
      calling `init_fn` (presumably a wildcard batch dimension — confirm).
    network: forwarded to `_build_network`.
    out_logits: forwarded to `_build_network`.

  Returns:
    A 2-tuple `(implicit_ntk_fn, direct_ntk_fn)`. Each is a jitted
    empirical-NTK function with `params` pre-bound as a keyword argument.
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits)
  _, params = init_fn(key, (-1,) + input_shape)

  # Same builder order as the returned tuple: implicit first, then direct.
  ntk_builders = (
      empirical.empirical_implicit_ntk_fn,
      empirical.empirical_direct_ntk_fn,
  )
  return tuple(
      partial(jit(build(apply_fn)), params=params) for build in ntk_builders
  )
def _kernel_fns(key, input_shape, network, out_logits, diagonal_axes,
                trace_axes):
  """Create jitted empirical NTK and NNGP functions with initialised params.

  Args:
    key: PRNG key forwarded to the network's `init_fn`.
    input_shape: per-example input shape; `(-1,)` is prepended before
      calling `init_fn` (presumably a wildcard batch dimension — confirm).
    network: forwarded to `_build_network`.
    out_logits: forwarded to `_build_network`.
    diagonal_axes: passed positionally as the third argument to each
      `empirical.*_fn` builder.
    trace_axes: passed positionally as the second argument to each
      `empirical.*_fn` builder.

  Returns:
    A 3-tuple `(implicit_ntk_fn, direct_ntk_fn, nngp_fn)`. Each is a jitted
    empirical kernel function with `params` pre-bound as a keyword argument.
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits)
  _, params = init_fn(key, (-1,) + input_shape)

  # Order matters: it defines the positions in the returned tuple.
  kernel_builders = (
      empirical.empirical_implicit_ntk_fn,
      empirical.empirical_direct_ntk_fn,
      empirical.empirical_nngp_fn,
  )
  jitted = [
      jit(build(apply_fn, trace_axes, diagonal_axes))
      for build in kernel_builders
  ]
  return tuple(partial(kernel, params=params) for kernel in jitted)