def ntk_fn(x1: NTTree[np.ndarray],
           x2: Optional[NTTree[np.ndarray]],
           params: PyTree,
           **apply_fn_kwargs) -> np.ndarray:
  """Returns one sample of the empirical NTK using implicit differentiation.

  Instead of materializing Jacobians, the kernel is obtained as the Jacobian
  (with respect to the cotangent) of a VJP through `f(x2)` composed with a JVP
  through `f(x1)`.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`. They are split
      into `apply_fn_kwargs1` (used with `x1`) and `apply_fn_kwargs2` (used
      with `x2`) by `split_kwargs`. In particular, the rng key in
      `apply_fn_kwargs` is split into two different (if `x1 != x2`) or
      identical (if `x1 == x2`) rng keys. See `_read_key` for more details.

  Returns:
    A single sample of the empirical NTK. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """
  kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

  f1 = _flatten(_get_f_params(f, x1, **kwargs1))
  if utils.all_none(x2):
    # `x2=None` means the kernel of `x1` with itself.
    f2 = f1
  else:
    f2 = _flatten(_get_f_params(f, x2, **kwargs2))

  def delta_vjp_jvp(delta):
    # Pull `delta` back to parameter space through `f2`, then push the
    # resulting parameter cotangents forward through `f1`.
    param_cotangents = vjp(f2, params)[1](delta)
    return jvp(f1, (params,), param_cotangents)[1]

  @utils.nt_tree_fn()
  def dummy_output(fx_struct):
    return np.ones(fx_struct.shape, fx_struct.dtype)

  # `delta_vjp_jvp` is linear in `delta`, so its Jacobian does not depend on
  # the point at which it is evaluated. Any input with the right structure
  # will do, so we use all-ones arrays shaped and typed like `f2`'s output.
  fx_dummy = dummy_output(eval_shape(f2, params))
  ntk = jacobian(delta_vjp_jvp)(fx_dummy)

  if utils.is_list_or_tuple(fx_dummy):
    # With several outputs, `jacobian` returns every (i, j) output block.
    # Keep only the matching blocks `ntk[i][i]`, then restore the tree
    # structure of `f`'s output on `x1` (taken before `_flatten`).
    fx_treedef = tree_structure(
        eval_shape(_get_f_params(f, x1, **kwargs1), params))
    matching_blocks = [ntk[i][i] for i in range(len(fx_dummy))]
    ntk = tree_unflatten(fx_treedef, matching_blocks)

  return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
def ntk_fn(x1: NTTree[np.ndarray],
           x2: Optional[NTTree[np.ndarray]],
           params: PyTree,
           **apply_fn_kwargs) -> np.ndarray:
  """Returns one sample of the empirical NTK as a Jacobian outer product.

  The Jacobians of `f(x1)` and `f(x2)` with respect to `params` are computed
  explicitly (vectorized with `vmap` when batch axes are known) and then
  contracted by `sum_and_contract`.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`. They are split
      into `apply_fn_kwargs1` (used with `x1`) and `apply_fn_kwargs2` (used
      with `x2`) by `split_kwargs`. In particular, the rng key in
      `apply_fn_kwargs` is split into two different (if `x1!=x2`) or identical
      (if `x1==x2`) rng keys. See `_read_key` for more details.

  Returns:
    A single sample of the empirical NTK. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """
  kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
  fx1 = eval_shape(f, params, x1, **kwargs1)
  x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1)

  keys = apply_fn_kwargs.keys()

  def j_fn(x, *args):
    # Keyword values arrive positionally so that `vmap` can map over them;
    # re-attach their names before calling `f`.
    fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes,
                       **dict(zip(keys, args)))
    return jacobian(fx)(params)

  if x_axis is not None or kw_axes:
    kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys]
    j_fn = vmap(j_fn, in_axes=[x_axis] + kw_in_axes, out_axes=fx_axis)

  j1 = j_fn(x1, *(kwargs1[k] for k in keys))
  if utils.all_none(x2):
    j2 = j1
  else:
    j2 = j_fn(x2, *(kwargs2[k] for k in keys))

  return sum_and_contract(fx1, j1, j2)
def nngp_fn(x1: np.ndarray,
            x2: Optional[np.ndarray],
            params: PyTree,
            **apply_fn_kwargs) -> np.ndarray:
  """Returns one sample of the empirical NNGP kernel.

  The network outputs on `x1` and `x2` are computed once each and contracted
  with `utils.dot_general` over `trace_axes` and `diagonal_axes`.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters at which `f` is evaluated.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`. They are split
      into `apply_fn_kwargs1` (used with `x1`) and `apply_fn_kwargs2` (used
      with `x2`) by `split_kwargs`. In particular, the rng key in
      `apply_fn_kwargs` is split into two different (if `x1!=x2`) or identical
      (if `x1==x2`) rng keys. See `_read_key` for more details.

  Returns:
    A single sample of the empirical NNGP. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """

  @utils.nt_tree_fn()
  def unmask(masked):
    return masked.masked_value

  def output(x, **kwargs):
    # Keep only the `masked_value` of each output leaf.
    return unmask(utils.get_masked_array(f(params, x, **kwargs)))

  kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
  out1 = output(x1, **kwargs1)
  out2 = out1 if utils.all_none(x2) else output(x2, **kwargs2)

  @utils.nt_tree_fn()
  def contract(lhs, rhs):
    # Divide by `utils.size_at(lhs, trace_axes)` (presumably the total size of
    # the contracted axes) so the trace is an average rather than a sum.
    dot = utils.dot_general(lhs, rhs, trace_axes, diagonal_axes)
    return dot / utils.size_at(lhs, trace_axes)

  return contract(out1, out2)
def parallel_fn_x1(x1, x2=None, *args, **kwargs):
  """Computes `kernel_fn(x1, x2)` with `x1` split across devices.

  The following are reshaped to `(device_count, n1_per_device, ...)`:
    - `x1`;
    - the first element of every array-valued keyword-argument pair.
  The kernel is then evaluated by `kernel_fn` wrapped with
  `_jit_or_pmap_broadcast`. `x2` is passed through unsplit.

  Args:
    x1: first batch of inputs. Its leading axis is the batch axis (for a
      list/tuple, the batch size is read from its first element).
    x2: optional second batch of inputs. `None` means `x2=x1`.
    *args: forwarded to `kernel_fn`.
    **kwargs: forwarded to `kernel_fn`. Every value for which `_is_np_ndarray`
      holds must be a pair `(v1, v2)`; `v1` is split across devices like `x1`.

  Returns:
    The kernel, reassembled by `_flatten_kernel`.

  Raises:
    ValueError: if an array-valued keyword argument is not a pair. This
      replaces the former `assert`, which is stripped under `python -O`.
  """
  x2_is_none = utils.all_none(x2)
  if x2_is_none:
    # TODO(schsam): Only compute the upper triangular part of the kernel.
    x2 = x1

  def get_batch_size(x):
    if utils.is_list_or_tuple(x):
      return get_batch_size(x[0])
    return x.shape[0]

  n1 = get_batch_size(x1)
  n2 = n1 if x2_is_none else get_batch_size(x2)

  _check_dropout(n1, n2, kwargs)
  n1_per_device, _device_count = _get_n_per_device(n1, n2)

  _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

  @utils.nt_tree_fn()
  def batch_data(x):
    input_shape = x.shape[1:]
    return np.reshape(x, (_device_count, n1_per_device) + input_shape)

  # Build a new dict instead of writing back into `kwargs` while iterating it.
  sharded_kwargs = {}
  for k, v in kwargs.items():
    if _is_np_ndarray(v):
      if not (isinstance(v, tuple) and len(v) == 2):
        raise ValueError(
            f'Array-valued keyword argument `{k}` must be a pair '
            f'`(value_for_x1, value_for_x2)` when batching in parallel; '
            f'got {type(v).__name__}.')
      v0 = np.reshape(v[0], (_device_count, n1_per_device) + v[0].shape[1:])
      v = (v0, v[1])
    sharded_kwargs[k] = v

  x1 = batch_data(x1)

  kernel = _kernel_fn(x1, x2, *args, **sharded_kwargs)
  return _flatten_kernel(kernel, x2_is_none, True)
def get_ntk(x1, x2, *args):
  """Computes the NTK between `x1` and `x2` via a transposed VJP∘JVP.

  `args` holds the values of `keys` for `x1` followed by those for `x2`, in
  two halves of equal length. Relies on `keys`, `f`, `params`, `x_axis`,
  `fx_axis` and `kw_axes` from the enclosing scope.
  """
  half = len(args) // 2
  kwargs_x1 = dict(zip(keys, args[:half]))
  kwargs_x2 = dict(zip(keys, args[half:]))

  f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs_x1)
  if utils.all_none(x2):
    f2 = f1
  else:
    f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes, **kwargs_x2)

  def delta_vjp_jvp(delta):
    # VJP through `f2`, then JVP through `f1`; linear in `delta`.
    param_cotangents = vjp(f2, params)[1](delta)
    return jvp(f1, (params,), param_cotangents)[1]

  out1_struct = eval_shape(f1, params)
  out2_struct = eval_shape(f2, params)

  # Map the transpose of the linear map over `_std_basis(out1_struct)`
  # (presumably a standard basis of `f1`'s output space) to materialize the
  # kernel one basis direction at a time.
  basis = _std_basis(out1_struct)
  ntk = vmap(linear_transpose(delta_vjp_jvp, out2_struct))(basis)
  ntk = tree_map(
      lambda fx12: _unravel_array_into_pytree(out1_struct, 0, fx12), ntk)
  return _diagonal(ntk, out1_struct)
def ntk_fn(x1: NTTree[np.ndarray],
           x2: Optional[NTTree[np.ndarray]],
           params: PyTree,
           **apply_fn_kwargs) -> np.ndarray:
  """Returns one sample of the empirical NTK using implicit differentiation.

  The kernel is obtained by transposing the linear map `delta -> JVP_f1(VJP_f2
  (delta))` and evaluating it on a basis of `f(x1)`'s output space. When batch
  axes are known (`vmap_axes`), the computation is vectorized with two nested
  `vmap`s: the inner one over `x1`, the outer one over `x2`.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`. They are split
      into `apply_fn_kwargs1` (used with `x1`) and `apply_fn_kwargs2` (used
      with `x2`) by `split_kwargs`. In particular, the rng key in
      `apply_fn_kwargs` is split into two different (if `x1 != x2`) or
      identical (if `x1 == x2`) rng keys. See `_read_key` for more details.

  Returns:
    A single sample of the empirical NTK. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """
  kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
  fx1 = eval_shape(f, params, x1, **kwargs1)
  x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1)

  keys = apply_fn_kwargs.keys()

  def value_for_x2(k):
    # For keyword arguments mapped over by `vmap`, a `None` value on the `x2`
    # side falls back to the value used for `x1`.
    v = kwargs2[k]
    if v is None and k in kw_axes:
      return kwargs1[k]
    return v

  args1 = (kwargs1[k] for k in keys)
  args2 = (value_for_x2(k) for k in keys)

  def get_ntk(x1, x2, *args):
    half = len(args) // 2
    kwargs_x1 = dict(zip(keys, args[:half]))
    kwargs_x2 = dict(zip(keys, args[half:]))

    f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs_x1)
    if utils.all_none(x2):
      f2 = f1
    else:
      f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes, **kwargs_x2)

    def delta_vjp_jvp(delta):
      # VJP through `f2`, then JVP through `f1`; linear in `delta`.
      param_cotangents = vjp(f2, params)[1](delta)
      return jvp(f1, (params,), param_cotangents)[1]

    out1_struct = eval_shape(f1, params)
    out2_struct = eval_shape(f2, params)
    basis = _std_basis(out1_struct)
    ntk = vmap(linear_transpose(delta_vjp_jvp, out2_struct))(basis)
    ntk = tree_map(
        lambda fx12: _unravel_array_into_pytree(out1_struct, 0, fx12), ntk)
    return _diagonal(ntk, out1_struct)

  if x_axis is not None or kw_axes:
    if utils.all_none(x2):
      x2 = x1
    kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys]
    unmapped = [None] * len(kw_in_axes)
    # Inner vmap: over `x1` and its kwargs; outer vmap: over `x2` and its
    # kwargs. Positional layout is (x1, x2, *x1_kwargs, *x2_kwargs).
    in_axes1 = [x_axis, None] + kw_in_axes + unmapped
    in_axes2 = [None, x_axis] + unmapped + kw_in_axes
    inner = vmap(get_ntk, in_axes1, fx_axis)
    get_ntk = vmap(inner, in_axes2, _add(fx_axis, _ndim(fx1)))

  ntk = get_ntk(x1, x2, *args1, *args2)
  return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
def serial_fn_x1(x1: NTTree[np.ndarray],
                 x2: Optional[NTTree[Optional[np.ndarray]]] = None,
                 *args,
                 **kwargs) -> NTTree[Kernel]:
  """Computes `kernel_fn(x1, x2)` block by block, scanning over batches.

  `x1` and `x2` are reshaped into `(n_batches, batch_size, ...)` using the
  sizes returned by `_get_n_batches_and_batch_sizes`. An outer `_scan` runs
  over the batches of `x1`; for each of them, an inner `_scan` runs over the
  batches of `x2`.

  Args:
    x1: first batch of inputs; its leading axis is the batch axis.
    x2: optional second batch of inputs. `None` means `x2=x1`.
    *args: forwarded to every `kernel_fn` call.
    **kwargs: forwarded to `kernel_fn`. Values for which `_is_np_ndarray` holds
      are batched alongside the inputs:
        - `rng` is split into one key per batch for each side;
        - every other such value must be a pair `(v1, v2)` whose elements are
          batched like `x1` and `x2` respectively.
      All remaining values are passed unchanged to every call.

  Returns:
    The kernel assembled from all blocks by `flatten`.

  Raises:
    ValueError: if an array-valued keyword argument other than `rng` is not a
      pair. This replaces the former `assert`, which is stripped under
      `python -O`.
  """
  x2_is_none = utils.all_none(x2)
  if x2_is_none:
    # TODO(schsam): Only compute the upper triangular part of the kernel.
    x2 = x1

  @utils.nt_tree_fn(reduce=lambda x: x[0])
  def get_n1_n2(x1, x2):
    n1, n2 = x1.shape[0], x2.shape[0]
    return n1, n2

  n1, n2 = get_n1_n2(x1, x2)

  (n1_batches, n1_batch_size,
   n2_batches, n2_batch_size) = _get_n_batches_and_batch_sizes(
       n1, n2, batch_size, device_count)

  @utils.nt_tree_fn(nargs=1)
  def batch_input(x, batch_count, batch_size):
    input_shape = x.shape[1:]
    return np.reshape(x, (batch_count, batch_size) + input_shape)

  x1s = batch_input(x1, n1_batches, n1_batch_size)
  x2s = batch_input(x2, n2_batches, n2_batch_size)

  # Array-valued kwargs are batched alongside x1 / x2 and scanned over with
  # them; everything else is passed to every `kernel_fn` call unchanged.
  kwargs_np1 = {}
  kwargs_np2 = {}
  kwargs_other = {}
  for k, v in kwargs.items():
    if not _is_np_ndarray(v):
      kwargs_other[k] = v
      continue

    if k == 'rng':
      # Distinct keys for the x1 and x2 sides, one per batch.
      key1, key2 = random.split(v)
      v1 = random.split(key1, n1_batches)
      v2 = random.split(key2, n2_batches)
    else:
      if not (isinstance(v, tuple) and len(v) == 2):
        raise ValueError(
            f'Array-valued keyword argument `{k}` must be a pair '
            f'`(value_for_x1, value_for_x2)` when batching serially; '
            f'got {type(v).__name__}.')
      v1 = np.reshape(v[0], (n1_batches, n1_batch_size) + v[0].shape[1:])
      v2 = np.reshape(v[1], (n2_batches, n2_batch_size) + v[1].shape[1:])
    kwargs_np1[k] = v1
    kwargs_np2[k] = v2

  def row_fn(_, x1):
    # One row of blocks: a fixed `x1` batch against every `x2` batch.
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]

  def col_fn(x1, x2):
    x1, kwargs1 = x1
    x2, kwargs2 = x2
    kwargs_merge = {
        **kwargs_other,
        **{k: (kwargs1[k], kwargs2[k]) for k in kwargs1},
    }
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)

  _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  return flatten(kernel, x2_is_none)