示例#1
0
def _kernel_fns(key, input_shape, network, out_logits):
    """Build jitted empirical NTK functions with network params pre-bound.

    NOTE(review): a second `def _kernel_fns` later in this file (taking extra
    `diagonal_axes` / `trace_axes` arguments) rebinds this name, so this
    definition is unreachable by name once the module has loaded.

    Args:
      key: forwarded as the first argument of the network's `init_fn`
        (presumably a JAX PRNG key -- confirm against `_build_network`).
      input_shape: tuple shape of a single input; a leading `-1` is prepended
        before calling `init_fn` (presumably a placeholder batch dimension).
      network: forwarded unchanged to `_build_network`.
      out_logits: forwarded unchanged to `_build_network`.

    Returns:
      A 2-tuple `(implicit_ntk, direct_ntk)` of `jit`-compiled empirical NTK
      functions, each with `params=` already bound via `functools.partial`.
    """
    init_fn, f, _ = _build_network(input_shape, network, out_logits)
    _, params = init_fn(key, (-1,) + input_shape)

    # Order matters: callers unpack the result as (implicit, direct).
    ntk_builders = (empirical.empirical_implicit_ntk_fn,
                    empirical.empirical_direct_ntk_fn)
    return tuple(partial(jit(build(f)), params=params)
                 for build in ntk_builders)
def _kernel_fns(key, input_shape, network, out_logits, diagonal_axes,
                trace_axes):
    """Build jitted empirical NTK/NNGP functions with network params pre-bound.

    Args:
      key: forwarded as the first argument of the network's `init_fn`
        (presumably a JAX PRNG key -- confirm against `_build_network`).
      input_shape: tuple shape of a single input; a leading `-1` is prepended
        before calling `init_fn` (presumably a placeholder batch dimension).
      network: forwarded unchanged to `_build_network`.
      out_logits: forwarded unchanged to `_build_network`.
      diagonal_axes: forwarded to every empirical kernel builder.
      trace_axes: forwarded to every empirical kernel builder.

    Returns:
      A 3-tuple `(implicit_ntk, direct_ntk, nngp)` of `jit`-compiled empirical
      kernel functions, each with `params=` already bound via
      `functools.partial`.
    """
    init_fn, f, _ = _build_network(input_shape, network, out_logits)
    _, params = init_fn(key, (-1,) + input_shape)

    # Order matters: callers unpack the result as (implicit, direct, nngp).
    kernel_builders = (empirical.empirical_implicit_ntk_fn,
                       empirical.empirical_direct_ntk_fn,
                       empirical.empirical_nngp_fn)

    # The builders take the axes positionally as (trace_axes, diagonal_axes),
    # the reverse of this function's own parameter order -- keep it that way.
    return tuple(
        partial(jit(build(f, trace_axes, diagonal_axes)), params=params)
        for build in kernel_builders)