def eigenspace(get: str):
  k_dd = getattr(get_k_train_train((get,)), get)
  k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                   diag_reg_absolute_scale)
  evals, evecs = np.linalg.eigh(k_dd)
  evals = np.expand_dims(evals, 0)
  return evals, evecs
def _get_fns_in_eigenbasis(k_train_train: np.ndarray,
                           diag_reg: float,
                           diag_reg_absolute_scale: bool,
                           fns: Iterable[Callable]) -> Iterable[Callable]:
  """Build functions of a matrix in its eigenbasis.

  Args:
    k_train_train: an `n x n` matrix.
    diag_reg: diagonal regularizer strength added to `k_train_train` before
      the eigendecomposition.
    diag_reg_absolute_scale: whether `diag_reg` is interpreted in absolute
      units or relative to the scale of `k_train_train`.
    fns: a sequence of functions that act on the eigenvalues,
      `(evals, dt) -> modified_evals`.

  Returns:
    A sequence of functions that act as functions of the matrix `mat` acting
    on vectors: `transform(vec, dt) = fn(mat, dt) @ vec`.
  """
  k_train_train = utils.make_2d(k_train_train)
  k_train_train = _add_diagonal_regularizer(k_train_train, diag_reg,
                                            diag_reg_absolute_scale)
  evals, evecs = np.linalg.eigh(k_train_train)

  def to_eigenbasis(fn):
    """Generates a transform given a function on the eigenvalues."""
    def new_fn(y_train, t):
      return np.einsum('ji,ti,ki,k...->tj...',
                       evecs, fn(evals, t), evecs, y_train, optimize=True)

    return new_fn

  return (to_eigenbasis(fn) for fn in fns)
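

# Illustrative check (hypothetical helper, plain NumPy/SciPy; not part of the
# library): for a symmetric `K = V @ diag(evals) @ V.T`, the einsum
# 'ji,ti,ki,k...->tj...' used above computes `V @ diag(fn(evals, t)) @ V.T @ y`
# for every time `t`, i.e. it applies `fn(K, t)` to `y` in the eigenbasis.
def _example_eigenbasis_identity():
  import numpy as onp
  from scipy.linalg import expm
  rng = onp.random.RandomState(0)
  a = rng.randn(5, 5)
  k = a @ a.T                    # symmetric stand-in for `k_train_train`.
  y = rng.randn(5, 2)            # stand-in for `y_train`.
  evals, evecs = onp.linalg.eigh(k)
  t = onp.array([[0.5], [2.0]])  # shape (T, 1), broadcasts against `evals`.
  fn = lambda evals, t: onp.exp(-evals * t)  # example spectral function.
  out = onp.einsum('ji,ti,ki,k...->tj...', evecs, fn(evals, t), evecs, y)
  # Reference: apply the matrix function `expm(-k * t)` to `y` for each `t`.
  ref = onp.stack([expm(-k * ti) @ y for ti in t[:, 0]])
  assert onp.allclose(out, ref)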
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      momentum: float = 0.,
                      eps: float = 1e-12) -> float:
  r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using mini-/full-batch GD + momentum
  with mean squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla
  SGD (i.e. `momentum = 0`) the maximal feasible learning rate is the largest
  `\eta` such that the operator
  `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
  `2 * batch_size * output_size / lambda_max(NTK)`. When `momentum > 0`, we
  use `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)` (see
  the `The Dynamics of Momentum` section in
  https://distill.pub/2017/momentum/).

  Args:
    ntk_train_train: analytic or empirical NTK on the training data.
    y_train_size: total training set output size, i.e.
      `f(x_train).size == y_train.size`. If `y_train_size=None` it is
      inferred from `ntk_train_train.shape` assuming `trace_axes=()`.
    momentum: the `momentum` parameter for momentum optimizers.
    eps: a float added to the denominator to avoid division by zero.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
  ntk_train_train = utils.make_2d(ntk_train_train)
  factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size

  if utils.is_on_cpu(ntk_train_train):
    max_eva = osp.linalg.eigvalsh(ntk_train_train,
                                  eigvals=(ntk_train_train.shape[0] - 1,
                                           ntk_train_train.shape[0] - 1))[-1]
  else:
    max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
  lr = 2 * (1 + momentum) * factor / (max_eva + eps)
  return lr
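

# Illustrative check (hypothetical helper, plain NumPy; not part of the
# library): with loss `1/(2 * n) * \|f(x) - y\|^2`, a full-batch GD step
# multiplies the residual by `(I - eta / n * NTK)`. The largest `eta` keeping
# every eigenvalue of that operator in `[-1, 1]` is `2 * n / lambda_max(NTK)`,
# which is what `max_learning_rate` returns (up to the `eps` regularizer and
# the `(1 + momentum)` factor).
def _example_max_learning_rate_check():
  import numpy as onp
  rng = onp.random.RandomState(0)
  a = rng.randn(8, 8)
  ntk = a @ a.T  # symmetric PSD stand-in for `ntk_train_train`.
  n = ntk.shape[0]
  lam_max = onp.linalg.eigvalsh(ntk)[-1]
  eta = 2 * n / lam_max
  step_op_evals = onp.linalg.eigvalsh(onp.eye(n) - eta / n * ntk)
  # All eigenvalues of the GD step operator lie in [-1, 1].
  assert onp.all(step_op_evals >= -1 - 1e-10)
  assert onp.all(step_op_evals <= 1 + 1e-10)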
def _get_cho_solve(
    A: np.ndarray,
    diag_reg: float,
    diag_reg_absolute_scale: bool,
    lower: bool = False) -> Callable[[np.ndarray, Axes], np.ndarray]:
  """Returns a Cholesky solver `b -> A^{-1} b` for the regularized matrix `A`.

  The factorization of `A` is computed once here and reused for every
  right-hand side `b` passed to the returned `cho_solve` closure.
  """
  x_non_channel_shape = A.shape[1::2]
  A = utils.make_2d(A)
  A = _add_diagonal_regularizer(A, diag_reg, diag_reg_absolute_scale)
  C = sp.linalg.cho_factor(A, lower)

  def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
    b_axes = utils.canonicalize_axis(b_axes, b)
    last_b_axes = range(-len(b_axes), 0)
    x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

    # Move the solve axes to the end and flatten `b` into a 2D right-hand
    # side matching the factorized matrix.
    b = np.moveaxis(b, b_axes, last_b_axes)
    b = b.reshape((A.shape[1], -1))

    x = sp.linalg.cho_solve(C, b)
    x = x.reshape(x_shape)
    return x

  return cho_solve
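

# Illustrative sketch (hypothetical helper, plain SciPy; not part of the
# library): `_get_cho_solve` amortizes a Cholesky factorization across many
# right-hand sides; the core identity is that `cho_solve(cho_factor(A), b)`
# returns `A^{-1} b` for a symmetric positive-definite `A`.
def _example_cho_solve_round_trip():
  import numpy as onp
  import scipy.linalg as osp_linalg
  rng = onp.random.RandomState(0)
  a = rng.randn(6, 6)
  A = a @ a.T + 6 * onp.eye(6)  # symmetric positive-definite matrix.
  b = rng.randn(6, 3)
  C = osp_linalg.cho_factor(A)
  x = osp_linalg.cho_solve(C, b)
  assert onp.allclose(A @ x, b)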
def predict_fn(t: ArrayOrScalar = None,
               x_test: np.ndarray = None,
               get: Get = None,
               compute_cov: bool = False) -> Dict[str, Gaussian]:
  """Return output mean and covariance on the test set at time[s] `t`.

  Args:
    t: a scalar or array of scalars of any shape. `t=None` is treated as
      infinity and returns the same result as `t=np.inf`, but is computed
      using a linear solve for test predictions instead of an
      eigendecomposition, saving time and precision.
    x_test: test inputs. `None` means to return non-regularized
      (`diag_reg=0`) predictions on the train-set inputs. For regularized
      predictions, pass `x_test=x_train`.
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
    compute_cov: if `True`, compute both `mean` and `variance`; only `mean`
      otherwise.

  Returns:
    `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
    `compute_cov == True` with potentially additional leading time
    dimensions.
  """
  if get is None:
    get = ('nngp', 'ntk')

  # train-train, test-train, test-test.
  k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

  # Infinite time.
  if t is None:
    return predict_inf(get)(get=get,
                            k_test_train=k_td,
                            nngp_test_test=nngp_tt)

  # Finite time.
  t = np.array(t) * learning_rate
  t_shape = t.shape
  t = t.reshape((-1, 1))

  def reshape_mean(mean):
    k = _get_first(k_dd if k_td is None else k_td)
    mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
    mean = np.moveaxis(mean, last_t_axes, trace_axes)
    return mean

  def reshape_cov(cov):
    k = _get_first(k_dd if k_td is None else k_td)
    cov_shape_t = t_shape + k.shape[::2] * 2
    return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

  out = {}

  for g in get:
    evals, evecs = eigenspace(g)

    # Training set.
    if k_td is None:
      mean = np.einsum('ji,ti,ki,k...->tj...',
                       evecs, -expm1(evals, t), evecs, y_train_flat,
                       optimize=True)

    # Test set.
    else:
      neg_inv_expm1 = -inv_expm1(evals, t)
      ktd_g = utils.make_2d(getattr(k_td, g))
      mean = np.einsum('lj,ji,ti,ki,k...->tl...',
                       ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
                       optimize=True)

    mean = reshape_mean(mean)

    if nngp_tt is not None:
      nngp_dd = utils.make_2d(k_dd.nngp)

      # Training set.
      if k_td is None:
        if g == 'nngp':
          cov = np.einsum('ji,ti,ki->tjk',
                          evecs,
                          (np.maximum(evals, 0.) *
                           np.exp(-2 * np.maximum(evals, 0.) * t /
                                  y_train.size)),
                          evecs,
                          optimize=True)

        elif g == 'ntk':
          exp = np.einsum('mi,ti,ki->tmk',
                          evecs,
                          np.exp(-np.maximum(evals, 0.) * t / y_train.size),
                          evecs,
                          optimize=True)
          cov = np.einsum('tmk,kl,tnl->tmn',
                          exp, nngp_dd, exp, optimize=True)

        else:
          raise ValueError(g)

      # Test set.
      else:
        _nngp_tt = np.expand_dims(utils.make_2d(nngp_tt), 0)

        if g == 'nngp':
          cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                     ktd_g, evecs, -inv_expm1(evals, 2 * t),
                                     evecs, ktd_g,
                                     optimize=True)

        elif g == 'ntk':
          term_1 = np.einsum('mi,ti,ki,lk->tml',
                             evecs, neg_inv_expm1, evecs, ktd_g,
                             optimize=True)
          term_2 = np.einsum(
              'mj,ji,ti,ki,lk->tml',
              ktd_g, evecs, neg_inv_expm1, evecs,
              utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
              optimize=True)
          term_2 += np.moveaxis(term_2, 1, 2)
          cov = np.einsum('tji,jk,tkl->til',
                          term_1, nngp_dd, term_1, optimize=True)
          cov += -term_2 + _nngp_tt

        else:
          raise ValueError(g)

      out[g] = Gaussian(mean, reshape_cov(cov))

    else:
      out[g] = mean

  return out
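

# Illustrative check (hypothetical helper, plain NumPy/SciPy; not part of the
# library): assuming the `expm1` helper used above computes
# `np.expm1(-evals * t / y_train.size)` elementwise, the train-set mean
# `-V @ diag(expm1(evals, t)) @ V.T @ y` equals `(I - expm(-K * t / n)) @ y`,
# i.e. the solution of continuous-time gradient descent on the MSE loss.
def _example_train_set_mean_identity():
  import numpy as onp
  from scipy.linalg import expm
  rng = onp.random.RandomState(0)
  a = rng.randn(5, 5)
  k = a @ a.T       # stand-in for the (regularized) train-train NTK.
  y = rng.randn(5)  # flattened train targets.
  n, t = y.size, 3.0
  evals, evecs = onp.linalg.eigh(k)
  mean = evecs @ (-onp.expm1(-evals * t / n) * (evecs.T @ y))
  ref = (onp.eye(n) - expm(-k * t / n)) @ y
  assert onp.allclose(mean, ref)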