def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Trace and take diagonals over pairs of axes of an even-dimensional kernel.

  `ntk` is split into two halves of `ntk.ndim // 2` axes each. Axis `a` of the
  first half is paired with axis `ntk.ndim // 2 + a` of the second half.

  Args:
    ntk: even-dimensional kernel array.
    trace_axes: axes (relative to one half) whose pairs are traced out and
      removed from the result.
    diagonal_axes: axes (relative to one half) whose pairs are replaced by a
      single diagonal axis. Must be disjoint from `trace_axes`.

  Returns:
    The traced / diagonalized kernel. The remaining paired axes are zipped by
    `utils.zip_axes`, and the diagonal axes are moved to the positions returned
    by `utils.get_res_batch_dims`. The whole array is divided by
    `utils.size_at` of the traced axes of the first half.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions, or if `trace_axes`
      and `diagonal_axes` overlap.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug at '
                     'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  overlap = set(trace_axes).intersection(diagonal_axes)
  if overlap:
    raise ValueError(f'`trace_axes` and `diagonal_axes` must be disjoint, got '
                     f'overlapping axes {sorted(overlap)}.')

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the last traced axis backwards. This keeps the first-half index
  # `c` valid. Its second-half partner shifts left by one for every pair
  # already removed (`i` of them, all located before it in the first half).
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  # Each `np.diagonal` removes both axes of a pair and appends the diagonal at
  # the end. Assuming `diagonal_axes` is ascending, before diagonal `d`:
  #   - every traced axis `c < d` shifted `d` left by one in each half;
  #   - the `i` previous diagonals removed `i` axes before `d` in the first
  #     half, and `2 * i` axes before its partner (one per half).
  for i, d in enumerate(diagonal_axes):
    n_traced_before = sum(1 for c in trace_axes if c < d)
    axis1 = d - i - n_traced_before
    axis2 = output_ndim - n_trace + d - 2 * i - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
def testAxes(self, diagonal_axes, trace_axes):
  """Checks that empirical kernel implementations agree for given axes.

  Builds `implicit`, `direct` (each also with `vmap_axes=0`) and `nngp` kernel
  functions from `KERNELS['empirical_logits_3']`. It then checks that
  - the NNGP output has the expected number of dimensions;
  - the direct and NNGP outputs have the same shape;
  - the implicit / batched variants match the direct one.
  """
  key = random.PRNGKey(0)
  key, split_1, split_2 = random.split(key, 3)
  x_1 = random.normal(split_1, (4, 5, 6, 3))
  x_2 = random.normal(split_2, (2, 5, 6, 3))

  canonical_diag = utils.canonicalize_axis(diagonal_axes, x_1)
  canonical_trace = utils.canonicalize_axis(trace_axes, x_1)
  if any(d in canonical_trace for d in canonical_diag):
    raise absltest.SkipTest(
        'diagonal axes must be different from channel axes.')

  make_kernels = KERNELS['empirical_logits_3']
  kernel_kwargs = {
      'key': key,
      'input_shape': (5, 6, 3),
      'network': CONV,
      'diagonal_axes': diagonal_axes,
      'trace_axes': trace_axes,
  }
  implicit, direct, nngp = make_kernels(**kernel_kwargs)
  implicit_batched, direct_batched, _ = make_kernels(vmap_axes=0,
                                                     **kernel_kwargs)

  # Each traced axis drops a pair of dimensions; each diagonal axis merges a
  # pair into one.
  expected_ndim = 2 * (x_1.ndim - len(canonical_trace)) - len(canonical_diag)

  def check(lhs, rhs):
    nngp_out = nngp(lhs, rhs)
    self.assertEqual(expected_ndim, nngp_out.ndim)
    direct_out = direct(lhs, rhs)
    self.assertEqual(nngp_out.shape, direct_out.shape)
    direct_batched_out = direct_batched(lhs, rhs)
    implicit_out = implicit(lhs, rhs)
    implicit_batched_out = implicit_batched(lhs, rhs)
    self.assertAllClose(direct_out, implicit_out)
    self.assertAllClose(direct_out, direct_batched_out)
    self.assertAllClose(direct_out, implicit_batched_out)

  check(x_1, None)

  # Axis 0 differs in size between `x_2` (2) and `x_1` (4). The cross-kernel
  # is therefore only checked when axis 0 is neither traced nor diagonalized.
  if 0 not in canonical_trace and 0 not in canonical_diag:
    check(x_2, x_1)
def sum_and_contract(j1, j2, output_ndim):
  """Contracts two Jacobian pytrees leaf-by-leaf and sums the results.

  For each pair of leaves, `utils.dot_general` contracts over two sets of axes:
    - the canonicalized `trace_axes`;
    - every axis past the first `output_ndim`.
  The canonicalized `diagonal_axes` are passed as its last argument. The
  per-leaf results are then added together. `trace_axes` and `diagonal_axes`
  are read from the enclosing scope.
  """
  diag = utils.canonicalize_axis(diagonal_axes, output_ndim)
  traced = utils.canonicalize_axis(trace_axes, output_ndim)

  def contract_leaf(lhs, rhs):
    trailing = list(range(lhs.ndim))[output_ndim:]
    return utils.dot_general(lhs, rhs, traced + trailing, diag)

  per_leaf = tree_multimap(contract_leaf, j1, j2)
  return tree_reduce(operator.add, per_leaf)
def sum_and_contract(fx, j1, j2):
  """Contracts two Jacobian pytrees leaf-by-leaf, normalizes, and sums them.

  Leaves are contracted with `utils.dot_general` over two sets of axes:
    - the canonicalized `trace_axes`;
    - every axis past the first `fx.ndim`.
  The canonicalized `diagonal_axes` are passed as its last argument. Each
  per-leaf result is divided by `utils.size_at(fx, trace_axes)` before all of
  them are added together. `trace_axes` and `diagonal_axes` are read from the
  enclosing scope.
  """
  n_out = fx.ndim
  normalizer = utils.size_at(fx, trace_axes)
  diag = utils.canonicalize_axis(diagonal_axes, n_out)
  traced = utils.canonicalize_axis(trace_axes, n_out)

  def contract_leaf(lhs, rhs):
    trailing = list(range(lhs.ndim))[n_out:]
    contracted = utils.dot_general(lhs, rhs, traced + trailing, diag)
    return contracted / normalizer

  per_leaf = tree_multimap(contract_leaf, j1, j2)
  return tree_reduce(operator.add, per_leaf)
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk: input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes: axes (among `X, Y, Z, ...`) to trace over, i.e. compute the
      trace along, and remove, the respective pairs of axes from the `ntk`.
      The result is divided by `utils.size_at` of the traced axes (presumably
      the product of their sizes), so traces are effectively averages.
    diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along,
      i.e. extract the diagonal along the respective pairs of axes from the
      `ntk`. Each pair is replaced with a single axis, which reduces the
      resulting `ntk` axes count by 1 per diagonal axis. Must be disjoint
      from `trace_axes`.

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).

  Raises:
    ValueError: if `ntk` has an odd number of dimensions, or if `trace_axes`
      and `diagonal_axes` overlap.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  overlap = set(trace_axes).intersection(diagonal_axes)
  if overlap:
    raise ValueError(f'`trace_axes` and `diagonal_axes` must be disjoint, got '
                     f'overlapping axes {sorted(overlap)}.')

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the last traced axis backwards so the first-half index `c`
  # stays valid. Its second-half partner shifts left by one per removed pair.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  # `np.diagonal` removes both axes of a pair and appends the diagonal at the
  # end. Account for the traced axes preceding `d`, plus the `i` earlier
  # diagonals (assuming `diagonal_axes` is ascending).
  for i, d in enumerate(diagonal_axes):
    n_traced_before = sum(1 for c in trace_axes if c < d)
    axis1 = d - i - n_traced_before
    axis2 = output_ndim + d - 2 * i - n_trace - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
def testAxes(self, diagonal_axes, trace_axes):
  """Checks that implicit / direct / NNGP empirical kernels agree for given axes.

  Uses the TF random utilities to build inputs. It compares the `implicit`
  and `direct` kernel outputs, and checks the shape and dimensionality of the
  `nngp` output.
  """
  seed = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                           dtype=tf.int32)
  splits = tf_random_split(seed=tf.convert_to_tensor(seed, dtype=tf.int32),
                           num=3)
  key, split_1, split_2 = splits[0], splits[1], splits[2]

  x_1 = np.asarray(normal((4, 5, 6, 3), seed=split_1))
  x_2 = np.asarray(normal((2, 5, 6, 3), seed=split_2))

  canonical_diag = utils.canonicalize_axis(diagonal_axes, x_1)
  canonical_trace = utils.canonicalize_axis(trace_axes, x_1)
  if any(d in canonical_trace for d in canonical_diag):
    raise absltest.SkipTest(
        'diagonal axes must be different from channel axes.')

  implicit, direct, nngp = KERNELS['empirical_logits_3'](
      key, (5, 6, 3), CONV,
      diagonal_axes=diagonal_axes, trace_axes=trace_axes)

  # Each traced axis drops a pair of dimensions; each diagonal axis merges a
  # pair into one.
  expected_ndim = 2 * (x_1.ndim - len(canonical_trace)) - len(canonical_diag)

  def check(lhs, rhs):
    implicit_out = implicit(lhs, rhs)
    direct_out = direct(lhs, rhs)
    nngp_out = nngp(lhs, rhs)
    self.assertAllClose(implicit_out, direct_out)
    self.assertEqual(nngp_out.shape, implicit_out.shape)
    self.assertEqual(expected_ndim, nngp_out.ndim)

  check(x_1, None)

  # Axis 0 differs in size between `x_2` (2) and `x_1` (4). The cross-kernel
  # is therefore only checked when axis 0 is neither traced nor diagonalized.
  if 0 not in canonical_trace and 0 not in canonical_diag:
    check(x_2, x_1)
def _get_fx_test_shape(y_train: np.ndarray,
                       k_test_train: np.ndarray,
                       y_axes: Axes) -> Tuple[int, ...]:
  """Infers the shape of test-set outputs from the test-train kernel.

  Args:
    y_train: training targets.
    k_test_train: test-train kernel, or `None`. Its even axes (`0, 2, 4, ...`)
      give the test-output dimensions that are present in the kernel.
    y_axes: axes of `y_train` absent from `k_test_train`. Their sizes are
      taken from `y_train.shape`.

  Returns:
    `y_train.shape` if `k_test_train` is `None`. Otherwise, the even-axis
    sizes of `k_test_train`, with the `y_train` sizes of `y_axes` inserted at
    their positions.
  """
  if k_test_train is not None:
    dims = list(k_test_train.shape[::2])
    axes = utils.canonicalize_axis(y_axes, y_train)
    for position, size in enumerate(y_train.shape):
      if position in axes:
        dims.insert(position, size)
    return tuple(dims)

  return y_train.shape
def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
  """Solves the Cholesky-factored system for `b`, keeping `b_axes` as channels.

  The steps are:
    1. Move `b_axes` to the end of `b`.
    2. Flatten `b` into a matrix with `A.shape[1]` rows.
    3. Solve it with `sp.linalg.cho_solve` using the factor `C`.
    4. Reshape to `x_non_channel_shape` followed by the sizes of `b_axes`.
  `A`, `C` and `x_non_channel_shape` come from the enclosing scope.
  """
  axes = utils.canonicalize_axis(b_axes, b)
  trailing = range(-len(axes), 0)
  out_shape = x_non_channel_shape + tuple(b.shape[a] for a in axes)

  b_matrix = np.moveaxis(b, axes, trailing).reshape((A.shape[1], -1))
  solution = sp.linalg.cho_solve(C, b_matrix)
  return solution.reshape(out_shape)
def gradient_descent_mse_ensemble(kernel_fn: KernelFn,
                                  x_train: np.ndarray,
                                  y_train: np.ndarray,
                                  learning_rate: float = 1.,
                                  diag_reg: float = 0.0,
                                  diag_reg_absolute_scale: bool = False,
                                  trace_axes: Axes = (-1, ),
                                  **kernel_fn_kwargs):
  r"""Predicts the gaussian embedding induced by gradient descent on MSE loss.

  This is equivalent to an infinite ensemble of infinite-width networks after
  marginalizing out the initialization, if `kernel_fn` is the kernel function of
  the infinite-width network. Note that `kernel_fn` can in principle also be an
  empirical / Monte Carlo finite-width kernel function, but in this case the
  returned output will not have a simple interpretation (unless these functions
  are used to approximate the infinite-width kernel).

  Note that the first invocation of the returned `predict_fn` will be slow and
  will allocate a lot of memory for its whole lifetime. This is because the
  kernel computation, and either an eigendecomposition (`t` is a scalar or an
  array) or a Cholesky factorization (`t=None`) of
  `kernel_fn(x_train, None, get)`, is performed and cached for future
  invocations (or both, if the function is called on both finite and infinite
  (`t=None`) times).

  Args:
    kernel_fn:
      A kernel function that computes NNGP and/or NTK. Must have a signature
      `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object
      or a `namedtuple` with `nngp` and/or `ntk` attributes. Therefore, it can
      be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn`, or an
      `EmpiricalKernelFn`. For the latter, only `nt.empirical_kernel_fn` works,
      not `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since those two do
      not accept a `get` argument. Note that for meaningful outputs, the kernel
      function must represent or at least approximate the infinite-width
      kernel.
    x_train:
      training inputs.
    y_train:
      training targets.
    learning_rate:
      learning rate, step size.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `kernel_fn(x_train, None, get)`, i.e. computing
      `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky
      factorization or eigendecomposition.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be
      `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))`.
    trace_axes:
      `f(x_train)` axes such that `kernel_fn(x_train, None, get)`,
      `kernel_fn(x_test, x_train, get)`[, and `kernel_fn(x_test, None, get)`]
      lack these pairs of dimensions and are to be interpreted as
      :math:`\Theta \otimes I`, i.e. block-diagonal along `trace_axes`. These
      can be specified either to save space and compute, or to even improve
      approximation accuracy of the infinite-width or infinite-samples limit,
      since in these limits the covariance along channel / feature / logit
      axes indeed converges to a constant-diagonal matrix. However, if you
      target linearized dynamics of a specific finite-width network,
      `trace_axes=()` will yield the most accurate result.
    **kernel_fn_kwargs:
      optional keyword arguments passed to `kernel_fn`.

  Returns:
    A function with signature `predict_fn(t, x_test, get, compute_cov)`
    returning either the mean, or the mean and covariance, of the infinite
    ensemble of infinite-width networks outputs on `x_test` at time[s] `t`, in
    the `get` regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
  """
  # Time-evolution helpers applied in the kernel eigenbasis (built by
  # `_make_expm1_fn` / `_make_inv_expm1_fn`, defined elsewhere). Both are
  # parameterized by `y_train.size`.
  expm1 = _make_expm1_fn(y_train.size)
  inv_expm1 = _make_inv_expm1_fn(y_train.size)

  # Convert `trace_axes` to negative indices (counted from the end). They then
  # stay valid on outputs that carry extra leading time dimensions.
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
  n_trace_axes = len(trace_axes)
  last_t_axes = range(-n_trace_axes, 0)
  trace_shape = tuple(y_train.shape[a] for a in trace_axes)

  # Targets with the traced axes moved last and all remaining axes flattened
  # into one leading axis: shape `(-1,) + trace_shape`.
  y_train_flat = np.moveaxis(y_train, trace_axes, last_t_axes
                             ).reshape((-1, ) + trace_shape)

  # Train-train kernels computed so far, keyed by "nngp" / "ntk".
  k_dd_cache = {}

  def get_k_train_train(get: Sequence[str]) -> _Kernel:
    """Returns train-train kernels for `get`, computing only uncached ones.

    Raises:
      ValueError: if `get` has neither 1 nor 2 elements.
    """
    if len(get) == 1:
      get = get[0]
      if get not in k_dd_cache:
        k_dd_cache[get] = kernel_fn(x_train, None, get, **kernel_fn_kwargs)

    elif len(get) == 2:
      if not any(g in k_dd_cache for g in get):
        # Nothing cached yet: compute both kernels with a single call.
        k_dd_cache.update(
            kernel_fn(x_train, None, get, **kernel_fn_kwargs)._asdict())
      else:
        for g in get:
          if g not in k_dd_cache:
            k_dd_cache[g] = kernel_fn(x_train, None, g, **kernel_fn_kwargs)

    else:
      raise ValueError(get)

    return _Kernel(**k_dd_cache)

  @lru_cache(2)
  def eigenspace(get: str):
    """Eigendecomposes the regularized, 2D-flattened train-train kernel.

    Returns:
      `(evals, evecs)` from `np.linalg.eigh`. `evals` gets a leading unit axis
      so it broadcasts against times reshaped to `(-1, 1)`.
    """
    k_dd = getattr(get_k_train_train((get, )), get)
    k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                     diag_reg_absolute_scale)
    evals, evecs = np.linalg.eigh(k_dd)
    evals = np.expand_dims(evals, 0)
    return evals, evecs

  @lru_cache(4)
  def predict_inf(get: Get):
    """Infinite-time (`t=None`) predictor built by `gp_inference`."""
    _, get = utils.canonicalize_get(get)
    k_dd = get_k_train_train(get)
    return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                        trace_axes)

  def get_matrices(get: Get, x_test: Optional[np.ndarray], compute_cov: bool):
    """Returns the `(k_dd, k_td, nngp_tt)` kernels needed for a prediction.

    `k_td` is `None` when `x_test is None`. `nngp_tt` is `None` unless
    `compute_cov` is set. When it is set, `nngp_tt` is `True` (a non-`None`
    placeholder) if `x_test is None`, and `kernel_fn(x_test, None, 'nngp')`
    otherwise.
    """
    get = _get_dependency(get, compute_cov)
    k_dd = get_k_train_train(get)
    if x_test is None:
      k_td = None
      nngp_tt = compute_cov or None
    else:
      k_td = kernel_fn(x_test, x_train, get, **kernel_fn_kwargs)
      if compute_cov:
        nngp_tt = kernel_fn(x_test, None, 'nngp', **kernel_fn_kwargs)
      else:
        nngp_tt = None
    return k_dd, k_td, nngp_tt

  @utils.get_namedtuple('Gaussians')
  def predict_fn(t: ArrayOrScalar = None,
                 x_test: np.ndarray = None,
                 get: Get = None,
                 compute_cov: bool = False) -> Dict[str, Gaussian]:
    """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t:
        a scalar or an array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`. It is computed
        using a linear solve for test predictions instead of an
        eigendecomposition, saving time and precision.
      x_test:
        test inputs. `None` means to return non-regularized (`diag_reg=0`)
        predictions on the train-set inputs. For regularized predictions, pass
        `x_test=x_train`.
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov:
        if `True`, compute both `mean` and `variance`; otherwise only `mean`.

    Returns:
      `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
      `compute_cov == True`, with potentially additional leading time
      dimensions.
    """
    if get is None:
      get = ('nngp', 'ntk')

    # train-train, test-train, test-test.
    k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

    # Infinite time.
    if t is None:
      return predict_inf(get)(get=get, k_test_train=k_td,
                              nngp_test_test=nngp_tt)

    # Finite time. Times are flattened to a column `(-1, 1)` so they broadcast
    # against `evals` of shape `(1, n)`; `t_shape` restores them at the end.
    t = np.array(t) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1, 1))

    def reshape_mean(mean):
      # Unflatten to `t_shape + <non-traced output dims> + trace_shape`, then
      # move the traced axes back to their original positions.
      k = _get_first(k_dd if k_td is None else k_td)
      mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
      mean = np.moveaxis(mean, last_t_axes, trace_axes)
      return mean

    def reshape_cov(cov):
      # Unflatten both covariance sides, then zip them after the time axes.
      k = _get_first(k_dd if k_td is None else k_td)
      cov_shape_t = t_shape + k.shape[::2] * 2
      return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

    out = {}

    for g in get:
      evals, evecs = eigenspace(g)

      # Training set.
      if k_td is None:
        mean = np.einsum(
            'ji,ti,ki,k...->tj...',
            evecs, -expm1(evals, t), evecs, y_train_flat,
            optimize=True)

      # Test set.
      else:
        neg_inv_expm1 = -inv_expm1(evals, t)
        ktd_g = utils.make_2d(getattr(k_td, g))
        mean = np.einsum(
            'lj,ji,ti,ki,k...->tl...',
            ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
            optimize=True)

      mean = reshape_mean(mean)

      if nngp_tt is not None:
        nngp_dd = utils.make_2d(k_dd.nngp)

        # Training set.
        if k_td is None:
          if g == 'nngp':
            # Negative eigenvalues are clipped to 0 before use.
            cov = np.einsum(
                'ji,ti,ki->tjk',
                evecs,
                (np.maximum(evals, 0.) *
                 np.exp(-2 * np.maximum(evals, 0.) * t / y_train.size)),
                evecs,
                optimize=True)

          elif g == 'ntk':
            # `exp` is the per-time matrix
            # `evecs @ diag(exp(-max(evals, 0) * t / y_train.size)) @ evecs.T`,
            # applied to both sides of the train-train NNGP.
            exp = np.einsum(
                'mi,ti,ki->tmk',
                evecs,
                np.exp(-np.maximum(evals, 0.) * t / y_train.size),
                evecs,
                optimize=True)
            cov = np.einsum(
                'tmk,kl,tnl->tmn',
                exp,
                nngp_dd,
                exp,
                optimize=True)

          else:
            raise ValueError(g)

        # Test set.
        else:
          _nngp_tt = np.expand_dims(utils.make_2d(nngp_tt), 0)

          if g == 'nngp':
            cov = _nngp_tt - np.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, -inv_expm1(evals, 2 * t), evecs, ktd_g,
                optimize=True)

          elif g == 'ntk':
            term_1 = np.einsum(
                'mi,ti,ki,lk->tml',
                evecs, neg_inv_expm1, evecs, ktd_g,
                optimize=True)
            term_2 = np.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, neg_inv_expm1, evecs,
                utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
                optimize=True)
            # Symmetrize the cross term by adding its transpose over the two
            # test axes.
            term_2 += np.moveaxis(term_2, 1, 2)
            cov = np.einsum(
                'tji,jk,tkl->til',
                term_1, nngp_dd, term_1,
                optimize=True)
            cov += -term_2 + _nngp_tt

          else:
            raise ValueError(g)

        out[g] = Gaussian(mean, reshape_cov(cov))

      else:
        out[g] = mean

    return out

  return predict_fn
def gp_inference(k_train_train,
                 y_train: np.ndarray,
                 diag_reg: float = 0.,
                 diag_reg_absolute_scale: bool = False,
                 trace_axes: Axes = (-1, )):
  r"""Compute the mean and variance of the `posterior` of NNGP and NTK.

  Note that the first invocation of the returned `predict_fn` will be slow and
  will allocate a lot of memory for its whole lifetime. This is because a
  Cholesky factorization of `k_train_train.nngp` or `k_train_train.ntk` (or
  both) is performed and cached for future invocations.

  Args:
    k_train_train:
      train-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, or
      (c) `Kernel` object. Must contain the necessary `nngp` and/or `ntk`
      kernels for the arguments provided to the returned `predict_fn`
      function. For example, if you request to compute the posterior test
      [only] NTK covariance in future `predict_fn` invocations,
      `k_train_train` must contain both `ntk` and `nngp` kernels.
    y_train:
      train targets.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.
    trace_axes:
      `f(x_train)` axes such that `k_train_train`, `k_test_train`[, and
      `nngp_test_test`] lack these pairs of dimensions and are to be
      interpreted as :math:`\Theta \otimes I`, i.e. block-diagonal along
      `trace_axes`. These can be specified either to save space and compute,
      or to even improve approximation accuracy of the infinite-width or
      infinite-samples limit, since in these limits the covariance along
      channel / feature / logit axes indeed converges to a constant-diagonal
      matrix. However, if you target linearized dynamics of a specific
      finite-width network, `trace_axes=()` will yield the most accurate
      result.

  Returns:
    A function of signature `predict_fn(get, k_test_train, nngp_test_test)`
    computing the posterior Gaussian distribution (mean or mean and covariance)
    on a given test set.
  """
  # Axis groups of the (paired-axes) train-train kernel, from `_get_axes`.
  even, odd, first, last = _get_axes(_get_first(k_train_train))
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)

  @lru_cache(2)
  def solve(g: str):
    # Cached solver from `_get_cho_solve` for the `g` train-train kernel.
    # Presumably it applies the inverse of the regularized kernel to its
    # argument.
    k_dd = _get_attr(k_train_train, g)
    return _get_cho_solve(k_dd, diag_reg, diag_reg_absolute_scale)

  @lru_cache(2)
  def k_inv_y(g: str):
    # Solver applied to the targets, with `trace_axes` treated as channels.
    return solve(g)(y_train, trace_axes)

  @utils.get_namedtuple('Gaussians')
  def predict_fn(
      get: Get,
      k_test_train=None,
      nngp_test_test: np.ndarray = None
  ) -> Dict[str, Union[np.ndarray, Gaussian]]:
    """`test`-set posterior given respective covariance matrices.

    Args:
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
        returned.
      k_test_train:
        test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple,
        or (c) `Kernel` object. Must contain the necessary `nngp` and/or `ntk`
        kernels for the arguments provided to the returned `predict_fn`
        function. For example, if you request to compute the posterior test
        [only] NTK covariance, `k_test_train` must contain both `ntk` and
        `nngp` kernels. If `None`, returns predictions on the training set.
        Note that train-set outputs are always `N(y_train, 0)` and mostly
        returned for API consistency.
      nngp_test_test:
        A test-test NNGP array. Provide it if you want to compute the test-test
        posterior covariance. `nngp_test_test=None` means to not compute it.
        If `k_test_train is None`, pass any non-`None` value (e.g. `True`) if
        you want to get the non-regularized (`diag_reg=0`) train-train
        posterior covariance. Note that non-regularized train-set outputs will
        always be the zero-variance Gaussian `N(y_train, 0)` and are mostly
        returned for API consistency. For regularized train-set posterior
        outputs according to a positive `diag_reg`, pass
        `k_test_train=k_train_train`, and, optionally,
        `nngp_test_test=nngp_train_train`.

    Returns:
      Either a `Gaussian('mean', 'variance')` namedtuple or the `mean` of the
      GP posterior on the `test` set.

    Raises:
      ValueError: if the NTK test-set covariance is requested but either
        kernel lacks an `nngp` attribute, or if `get` contains an unknown
        mode while computing the covariance.
    """
    if get is None:
      get = ('nngp', 'ntk')

    out = {}

    for g in get:
      k_dd = _get_attr(k_train_train, g)
      k_td = None if k_test_train is None else _get_attr(k_test_train, g)

      if k_td is None:
        # Train set predictions.
        y = y_train.astype(k_dd.dtype)
      else:
        # Test set predictions.
        y = np.tensordot(k_td, k_inv_y(g), (odd, first))
        # The solver leaves the traced axes last; move them back into place.
        y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

      if nngp_test_test is not None:
        if k_td is None:
          # Non-regularized train-set posterior has zero covariance.
          out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
        else:
          if (g == 'ntk' and
              (not hasattr(k_train_train, 'nngp') or
               not hasattr(k_test_train, 'nngp'))):
            raise ValueError(
                'If `"ntk" in get`, and `nngp_test_test is not None`, '
                'and `k_test_train is not None`, i.e. you request the '
                'NTK posterior covariance on the test set, you need '
                'both NTK and NNGP train-train and test-train matrices '
                'contained in `k_test_train` and `k_train_train`. '
                'Hence they must be `namedtuple`s with `nngp` and '
                '`ntk` attributes.')

          # `g`-kernel solver applied to the test-train NNGP, with the test
          # (even) axes treated as channels.
          k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

          if g == 'nngp':
            cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
            cov = nngp_test_test - utils.zip_axes(cov)
            out[g] = Gaussian(y, cov)

          elif g == 'ntk':
            # term_1: NTK solver applied to the test-train NTK.
            term_1 = solve(g)(k_td, even)
            # Sandwich the train-train NNGP between `term_1` on both sides.
            cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                               (odd, first))
            cov = np.tensordot(term_1, cov, (first, first))

            # Cross term, symmetrized by adding its transpose (first axis
            # group moved to the last positions).
            term_2 = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
            term_2 += np.moveaxis(term_2, first, last)
            cov = utils.zip_axes(cov - term_2) + nngp_test_test
            out[g] = Gaussian(y, cov)

          else:
            raise ValueError(g)

      else:
        out[g] = y

    return out

  return predict_fn
def gradient_descent_mse(
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    diag_reg: float = 0.,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1, )
) -> Callable[
    [ArrayOrScalar, ArrayOrScalar, ArrayOrScalar, Optional[np.ndarray]],
    Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]:
  r"""Predicts the outcome of function space gradient descent training on MSE.

  Solves in closed form for the continuous-time version of gradient descent.

  Uses the closed-form solution for gradient descent on an MSE loss in
  function space detailed in [*,**], given a Neural Tangent or Neural Network
  Gaussian Process Kernel over the dataset. Given the NNGP or NTK, this
  function returns a function that predicts the time evolution of function
  space points at arbitrary time[s] (training step[s]) `t`. Note that these
  time[s] (step[s]) are continuous and are interpreted in units of the
  `learning_rate`, so `absolute_time = learning_rate * t`, and the scales of
  `learning_rate` and `t` are interchangeable.

  Note that the first invocation of the returned `predict_fn` will be slow and
  will allocate a lot of memory for its whole lifetime. This is because either
  an eigendecomposition (`t` is a scalar or an array) or a Cholesky
  factorization (`t=None`) of `k_train_train` is performed and cached for
  future invocations (or both, if the function is called on both finite and
  infinite (`t=None`) times).

  [*] https://arxiv.org/abs/1806.07572
  [**] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>>
    >>> t = 1e-7
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_train_train = kernel_fn(x_train, None, params)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    k_train_train:
      kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train:
      targets for the training data.
    learning_rate:
      learning rate, step size.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization or eigendecomposition.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.
    trace_axes:
      `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to save
      space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of
      a specific finite-width network, `trace_axes=()` will yield the most
      accurate result.

  Returns:
    A function of signature
    `predict_fn(t, fx_train_0, fx_test_0, k_test_train)` that returns output
    train [and test] set[s] predictions at time[s] `t`.
  """
  # Axis groups of the paired-axes train-train kernel, from `_get_axes`.
  _, odd, first, _ = _get_axes(k_train_train)

  # Negative (counted-from-the-end) trace axes stay valid on outputs with
  # extra leading time dimensions.
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
  n_t_axes, n_non_t_axes = len(trace_axes), y_train.ndim - len(trace_axes)
  last_t_axes = tuple(range(-n_t_axes, 0))
  # The trailing `y_train.ndim` axes, excluding the last `n_t_axes`. When
  # `n_t_axes == 0`, `-n_t_axes == 0` and this covers all `y_train.ndim` axes.
  non_t_axes = tuple(range(-y_train.ndim, -n_t_axes))

  @lru_cache(1)
  def get_predict_fn_inf():
    # The Cholesky solver is built once, on first infinite-time call.
    # NOTE(review): `jax.core.eval_context()` presumably forces eager
    # evaluation so the cached solver holds concrete values even when first
    # called under a JAX transformation -- confirm.
    with jax.core.eval_context():
      solve = _get_cho_solve(k_train_train, diag_reg, diag_reg_absolute_scale)

    def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
      # At infinite time, train outputs equal the targets (`fx_train_0` is not
      # used for them).
      fx_train_t = y_train.astype(k_train_train.dtype)
      if fx_test_0 is None:
        return fx_train_t

      rhs = y_train if fx_train_0 is None else y_train - fx_train_0
      dfx_test = np.tensordot(k_test_train, solve(rhs, trace_axes),
                              (odd, first))
      # The solver leaves the traced axes last; move them back into place.
      dfx_test = np.moveaxis(dfx_test, last_t_axes, trace_axes)
      fx_test_t = fx_test_0 + dfx_test

      if fx_train_0 is None:
        return fx_test_t
      return fx_train_t, fx_test_t

    return predict_fn_inf

  @lru_cache(1)
  def get_predict_fn_finite():
    # The eigenbasis functions are built once, on first finite-time call.
    with jax.core.eval_context():
      expm1_fn, inv_expm1_fn = _get_fns_in_eigenbasis(
          k_train_train,
          diag_reg,
          diag_reg_absolute_scale,
          (_make_expm1_fn(y_train.size),
           _make_inv_expm1_fn(y_train.size))
      )

    # Sizes of the traced axes of `y_train`.
    rhs_shape = tuple(y_train.shape[a] for a in trace_axes)

    def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
      t = np.array(t) * learning_rate
      t_shape, t_ndim = t.shape, t.ndim
      first_t_axes = tuple(range(t_ndim))

      # Flatten times to a column so they broadcast in the eigenbasis.
      t = t.reshape((-1, 1))
      rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
      # Traced axes last, all other axes flattened into one leading axis.
      rhs = np.moveaxis(rhs, trace_axes, last_t_axes).reshape(
          (-1, ) + rhs_shape)
      # `t_shape` + train (odd) kernel axes + traced axes.
      shape = t_shape + k_train_train.shape[1::2] + rhs_shape

      if fx_train_0 is not None:
        dfx_train = expm1_fn(rhs, t).reshape(shape)
        dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
        fx_train_t = np.expand_dims(fx_train_0, first_t_axes) + dfx_train

      if fx_test_0 is not None:
        dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
        dfx_test = np.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
        # After `tensordot` the axes are: test non-traced axes, then the time
        # axes, then the traced axes. Move time to the front and the traced
        # axes back to their positions.
        dfx_test = np.moveaxis(
            dfx_test,
            tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
            tuple(range(t_ndim)) + trace_axes)
        fx_test_t = np.expand_dims(fx_test_0, first_t_axes) + dfx_test

      if fx_train_0 is not None and fx_test_0 is not None:
        return fx_train_t, fx_test_t
      if fx_test_0 is None:
        return fx_train_t
      return fx_test_t

    return predict_fn_finite

  def predict_fn(
      t: ArrayOrScalar = None,
      fx_train_0: ArrayOrScalar = 0.,
      fx_test_0: ArrayOrScalar = None,
      k_test_train: np.ndarray = None
  ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or an array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`. Train and test
        predictions are then computed using the identity and a linear solve,
        respectively, instead of an eigendecomposition, saving time and
        precision. Equivalent of training steps (but can be fractional).
      fx_train_0:
        output of the network at `t == 0` on the training set.
        `fx_train_0=None` means to not compute predictions on the training
        set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need non-regularized (`diag_reg=0`)
        predictions on the training set. For regularized train-set
        predictions, pass `k_test_train=k_train_train`.

    Returns:
      `fx_train_t` if `fx_test_0 is None`; `fx_test_t` if `fx_train_0 is None`
      and `fx_test_0` is given; otherwise `(fx_train_t, fx_test_t)`. Finite-time
      outputs have additional leading time dimensions matching `t.shape`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_0, fx_test_0, k_test_train)

    # Infinite time
    if t is None:
      return get_predict_fn_inf()(fx_train_0, fx_test_0, k_test_train)

    # Finite time
    return get_predict_fn_finite()(t, fx_train_0, fx_test_0, k_test_train)

  return predict_fn
def gradient_descent(
    loss: Callable[[np.ndarray, np.ndarray], float],
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    momentum: Optional[float] = None,
    trace_axes: Axes = (-1,)
) -> Callable[
    [ArrayOrScalar, Union[ArrayOrScalar, ODEState], ArrayOrScalar,
     Optional[np.ndarray]],
    Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]]:
  r"""Predicts the outcome of function space training using gradient descent.

  Uses an ODE solver. If `momentum != None`, solves a continuous-time version of
  gradient descent with momentum (note: this case uses standard momentum as
  opposed to Nesterov momentum).

  Solves the function space ODE for [momentum] gradient descent with a given
  `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s]
  at arbitrary time[s] (step[s]) `t`. Note that for gradient descent
  `absolute_time = learning_rate * t` and the scales of the learning rate and
  query step[s] `t` are interchangeable. However, the momentum gradient descent
  ODE is solved in the units of `learning_rate**0.5`, and therefore
  `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and
  training time[s] (step[s]) `t` scales are not interchangeable.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>> from jax.nn import log_softmax
    >>>
    >>> t = 1e-7
    >>> learning_rate = 1e-2
    >>> momentum = 0.9
    >>>
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_train_train = kernel_fn(x_train, x_train, params)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> cross_entropy = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
    ...                                       y_train, learning_rate, momentum)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    ...                                    k_test_train)

  Args:
    loss: a loss function whose signature is `loss(f(x_train), y_train)`.
      Note: the loss function should treat the batch and output dimensions
      symmetrically.
    k_train_train: kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train: targets for the training data.
    learning_rate: learning rate, step size.
    momentum: momentum scalar. `None` (default) means plain gradient descent
      without momentum.
    trace_axes: `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to save
      space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield most accurate
      result.

  Returns:
    A function that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
  _, odd, _, _ = _get_axes(k_train_train)
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  non_t_axes = tuple(a for a in range(y_train.ndim) if a not in trace_axes)
  # `np.tensordot` leaves the uncontracted `trace_axes` of the gradient at the
  # end of its output; these positions are moved back to `trace_axes`.
  last_t_axes = tuple(range(-len(trace_axes), 0))
  dtype = k_train_train.dtype

  grad_loss = grad(lambda fx: loss(fx, y_train))

  if momentum is not None:
    # The momentum ODE is solved in units of `learning_rate**0.5` (see the
    # docstring), so both the time scale and the momentum coefficient are
    # rescaled accordingly.
    learning_rate **= 0.5
    momentum = (momentum - 1.0) / learning_rate

  def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
    """Builds the initial `ODEState` from raw outputs or an existing state."""
    if isinstance(fx_train_or_state_0, ODEState):
      fx_train_0 = fx_train_or_state_0.fx_train
      fx_test_0 = fx_train_or_state_0.fx_test
      qx_train_0 = fx_train_or_state_0.qx_train
      qx_test_0 = fx_train_or_state_0.qx_test
    else:
      fx_train_0 = fx_train_or_state_0
      qx_train_0 = qx_test_0 = None

    # Train outputs always drive the dynamics, so a missing value starts at 0.
    if fx_train_0 is None:
      fx_train_0 = np.zeros_like(y_train, dtype)
    else:
      fx_train_0 = np.broadcast_to(fx_train_0, y_train.shape)

    if fx_test_0 is not None:
      fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

    if momentum is None:
      if qx_train_0 is not None or qx_test_0 is not None:
        raise ValueError('Got passed momentum state variables, while '
                         '`momentum is None`.')
    else:
      # Auxiliary momentum variables default to zero when not provided.
      qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None else
                    np.broadcast_to(qx_train_0, y_train.shape))
      qx_test_0 = (None if fx_test_0 is None else
                   (np.zeros(fx_test_shape, dtype) if qx_test_0 is None
                    else np.broadcast_to(qx_test_0, fx_test_shape)))

    return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count

  def get_dstate_dt(k_test_train):
    """Returns the ODE right-hand side, closing over `k_test_train`."""

    def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
      fx_train_t, fx_test_t, qx_train_t, qx_test_t = (
          state_t.fx_train, state_t.fx_test, state_t.qx_train, state_t.qx_test)

      dy_df_t = grad_loss(fx_train_t)

      # -Theta @ dL/df, with the traced axes moved back to their positions.
      fx_train_t = -np.moveaxis(
          np.tensordot(k_train_train, dy_df_t, (odd, non_t_axes)),
          last_t_axes, trace_axes)
      if fx_test_t is not None:
        fx_test_t = -np.moveaxis(
            np.tensordot(k_test_train, dy_df_t, (odd, non_t_axes)),
            last_t_axes, trace_axes)

      if momentum is None:
        return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

      # Second-order system: d fx / dt = qx and
      # d qx / dt = -Theta @ dL/df + momentum * qx.
      fx_train_t += momentum * qx_train_t
      if qx_test_t is not None:
        fx_test_t += momentum * qx_test_t
      return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

    return dstate_dt

  def predict_fn(
      t: ArrayOrScalar = None,
      fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
      fx_test_0: ArrayOrScalar = None,
      k_test_train: np.ndarray = None
  ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t: a scalar or array of scalars of any shape in strictly increasing
        order. `t=None` is equivalent to `t=np.inf` and may not converge.
        Equivalent of training steps (but can be fractional).
      fx_train_or_state_0: either (a) output of the network at `t == 0` on the
        training set or (b) complete ODE state (`predict.ODEState`). Pass an ODE
        state if you want to operate on the full ODE state instead of output
        variables only (useful for inspecting auxiliary variables or resuming
        an optimizer with auxiliary variables from a specific state). Note that
        only `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed,
        its `fx_test` field is used as the initial test-set output, and an ODE
        state is returned. `fx_train_or_state_0=None` means to not compute
        predictions on the training set.
      fx_test_0: output of the network at `t == 0` on the test set.
        `fx_test_0=None` means to not compute predictions on the test set.
      k_test_train: kernel relating test data with training data. Must have the
        shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent.
        Pass `k_test_train=None` if you only need predictions on the training
        set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`;
        or if momentum state variables are passed while `momentum is None`.
    """
    _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

    t = np.array(t if t is not None else np.inf, dtype) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1,))

    # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
    # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
    # timesteps, so we always temporarily prepend an [almost] `0`.
    t0 = np.where(t[0] == 0,
                  np.full((1,), -1e-24, t.dtype),
                  np.zeros((1,), t.dtype))
    t = np.concatenate([t0, t])

    # Solve the ODE.
    fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
    state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

    # Remove the added `t0` and restore the requested time shape.
    trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
    trim_tree = lambda tree: tree_map(trim, tree)
    state_t = trim_tree(state_t)

    # `ODEState` -> `ODEState`
    if isinstance(fx_train_or_state_0, ODEState):
      return state_t

    # `np.ndarray` -> `np.ndarray`
    fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

    if fx_train_or_state_0 is not None and fx_test_0 is None:
      return fx_train_t
    if fx_test_0 is not None and fx_train_or_state_0 is None:
      return fx_test_t
    return fx_train_t, fx_test_t

  return predict_fn