def max_learning_rate(kdd: np.ndarray,
                      num_outputs: int = -1,
                      eps: float = 1e-12) -> float:
  r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using SGD or full-batch GD with the mean
  squared loss. The loss is assumed to have the form
  :math:`1/(2 * batch_size * num_outputs) \|f(train_x) - train_y\|^2`. The
  maximal feasible learning rate is the largest `\eta` such that the operator
  :math:`(I - \eta / (batch_size * num_outputs) * NTK)` is a contraction, which
  is :math:`2 * batch_size * num_outputs / \lambda_{max}(NTK)`.

  Args:
    kdd: The analytic or empirical NTK of (train_x, train_x).
    num_outputs: The number of outputs of the neural network. If `kdd` is the
      analytic NTK, `num_outputs` must be provided. Otherwise `num_outputs=-1`
      and `num_outputs` is computed via the size of `kdd`.
    eps: A float added to the divisor to avoid division by zero.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
  if kdd.ndim not in [2, 4]:
    raise ValueError('`kdd` must be a 2d or 4d tensor.')
  if kdd.ndim == 2 and num_outputs == -1:
    raise ValueError('`num_outputs` must be provided for the analytic kernel.')
  if kdd.ndim == 2:
    factor = kdd.shape[0] * num_outputs
  else:
    kdd = empirical.flatten_features(kdd)
    factor = kdd.shape[0]
  if kdd.shape[0] != kdd.shape[1]:
    raise ValueError('`kdd` must be a square matrix.')
  if _is_on_cpu(kdd):
    max_eva = osp.linalg.eigvalsh(kdd, eigvals=(kdd.shape[0] - 1,
                                                kdd.shape[0] - 1))[-1]
  else:
    max_eva = np.linalg.eigvalsh(kdd)[-1]
  lr = 2 * factor / (max_eva + eps)
  return lr
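# A minimal, self-contained sketch (hypothetical, not part of the library API)
# checking the contraction condition behind `max_learning_rate` with plain
# NumPy: at lr = 2 * factor / lambda_max the operator I - (lr / factor) * K
# has spectral radius exactly 1, so any smaller step size is a strict
# contraction. The toy kernel `k` merely stands in for `kdd`.
def _sketch_max_lr_contraction():
  import numpy as onp  # Plain NumPy, independent of the module's jax.numpy.
  rng = onp.random.RandomState(0)
  a = rng.randn(8, 8)
  k = a @ a.T  # A toy PSD "NTK".
  factor = k.shape[0]  # Empirical convention: batch_size * num_outputs.
  lam_max = onp.linalg.eigvalsh(k)[-1]
  lr = 2 * factor / lam_max
  op = onp.eye(k.shape[0]) - (lr / factor) * k
  # Spectral radius is 1 at the critical learning rate, < 1 just below it.
  assert onp.isclose(onp.abs(onp.linalg.eigvals(op)).max(), 1.0)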
def gradient_descent_mse(g_dd, y_train, g_td=None, diag_reg=0.):
  """Predicts the outcome of function space gradient descent training on MSE.

  Analytically solves for the continuous-time version of gradient descent.

  Uses the analytic solution for gradient descent on an MSE loss in function
  space detailed in [*] given a Neural Tangent Kernel over the dataset. Given
  NTKs, this function will return a function that predicts the time evolution
  for function space points at arbitrary times. Note that times are continuous
  and are measured in units of the learning rate so that
  t = learning_rate * steps.

  [*] https://arxiv.org/abs/1806.07572

  Example:
    ```python
    >>> from neural_tangents import predict
    >>>
    >>> train_time = 1e-7
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> predict_fn = predict.gradient_descent_mse(g_dd, y_train, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>     train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A kernel on the training data. The kernel should be an `np.ndarray`
      of shape [n_train * output_dim, n_train * output_dim] or
      [n_train, n_train]. In the latter case, the kernel is assumed to be block
      diagonal over the logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of targets for the
      training data.
    g_td: A kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: `g_td` should have been created in the
      convention `kernel_fn(x_test, x_train, params)`.
    diag_reg: A float, representing the strength of the diagonal regularization
      applied to `g_dd`.

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs of shape [n_train, output_dim] and t is a floating point
      time. predict(t, fx) returns an `np.ndarray` of predictions of shape
      [n_train, output_dim].

    If g_td is not None:
      If a test set kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs of shape [n_train, output_dim] and [n_test, output_dim]
      respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """
  g_dd = empirical.flatten_features(g_dd)

  normalization = y_train.size
  output_dimension = y_train.shape[-1]

  expm1_fn, inv_expm1_fn = (_make_expm1_fn(normalization),
                            _make_inv_expm1_fn(normalization))

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('The shape of `g_dd` is incompatible with that of '
                       '`y_train`.')
    fl = lambda x: x
    ufl = lambda x: x

  g_dd_plus_reg = _add_diagonal_regularizer(g_dd, diag_reg)
  expm1_dot_vec, inv_expm1_dot_vec = _eigen_fns(g_dd_plus_reg,
                                                (expm1_fn, inv_expm1_fn))

  if g_td is None:

    def train_predict(dt, fx=0.0):
      gx_train = fl(fx - y_train)
      dgx = expm1_dot_vec(gx_train, dt)
      return ufl(dgx) + fx

    return train_predict

  g_td = empirical.flatten_features(g_td)

  def predict_using_kernel(dt, fx_train=0., fx_test=0.):
    gx_train = fl(fx_train - y_train)
    dgx = expm1_dot_vec(gx_train, dt)
    # Note: consider using a linear solve instead of the eigen-inverse, e.g.
    #   dfx = sp.linalg.solve(g_dd, dgx, sym_pos=True)
    dfx = inv_expm1_dot_vec(gx_train, dt)
    dfx = np.dot(g_td, dfx)
    return ufl(dgx) + fx_train, fx_test + ufl(dfx)

  return predict_using_kernel
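# A minimal sketch (hypothetical names, plain NumPy/SciPy) of the closed-form
# train dynamics that the helpers `_make_expm1_fn` / `_eigen_fns` implement
# above: with the residual g = f - y, continuous-time gradient descent on MSE
# gives g(t) = expm(-K * t / normalization) @ g(0), so the applied update is
# dgx = (expm(-K * t / n) - I) @ gx, matching expm1_dot_vec(gx, t) in spirit.
def _sketch_mse_train_dynamics():
  import numpy as onp
  import scipy.linalg as osp_linalg
  rng = onp.random.RandomState(1)
  a = rng.randn(6, 6)
  k = a @ a.T                      # Toy PSD kernel standing in for g_dd.
  g0 = rng.randn(6)                # Initial residual f(0) - y, flattened.
  n, t = float(g0.size), 0.3
  g_t = osp_linalg.expm(-k * t / n) @ g0
  dgx = g_t - g0                   # The train-set update ufl(dgx) + fx uses.
  return dgx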
def momentum(g_dd, y_train, loss, learning_rate, g_td=None, momentum=0.9):
  r"""Predicts the outcome of function space training using momentum descent.

  Solves a continuous-time version of standard momentum (rather than Nesterov
  momentum) using an ODE solver.

  Solves the function space ODE for momentum with a given loss (detailed
  in [*]) given a Neural Tangent Kernel over the dataset. This function returns
  a triplet of functions that initialize state variables, predict the time
  evolution for function space points at arbitrary times, and retrieve the
  function-space outputs from the state. Note that times are continuous and are
  measured in units of the learning rate so that
  t = \sqrt(learning_rate) * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> train_time = 1e-7
    >>> learning_rate = 1e-2
    >>>
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> init_fn, predict_fn, get_fn = predict.momentum(
    >>>     g_dd, y_train, cross_entropy, learning_rate, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> lin_state = init_fn(fx_train_initial, fx_test_initial)
    >>> lin_state = predict_fn(lin_state, train_time)
    >>> fx_train_final, fx_test_final = get_fn(lin_state)
    ```

  Args:
    g_dd: Kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim].
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      labels. Note: the loss function should treat the batch and output
      dimensions symmetrically.
    learning_rate: A float specifying the learning rate.
    g_td: Kernel relating training data with test data. Should be an
      `np.ndarray` of shape [n_test * output_dim, n_train * output_dim].
      Note: `g_td` should have been created in the convention
      `g_td = kernel_fn(x_test, x_train, params)`.
    momentum: A float specifying the momentum.

  Returns:
    Functions to predict outputs after t = \sqrt(learning_rate) * steps of
    training. Generically three functions are returned: an init_fn that creates
    auxiliary velocity variables needed for optimization and packs them into a
    state variable, a predict_fn that computes the time-evolution of the state
    for some dt, and a get_fn that extracts the predictions from the state.

    If g_td is None:
      init_fn(fx_train): Takes a single `np.ndarray` of shape
        [n_train, output_dim] and returns an `np.ndarray` of shape
        [2 * n_train * output_dim].
      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.
      get_fn(state): Takes a state and returns an `np.ndarray` of shape
        [n_train, output_dim].

    If g_td is not None:
      init_fn(fx_train, fx_test): Takes two `np.ndarray`s of shape
        [n_train, output_dim] and [n_test, output_dim] respectively. Returns a
        tuple containing an int giving 2 * n_train * output_dim and an
        `np.ndarray` of shape [2 * (n_train + n_test) * output_dim].
      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.
      get_fn(state): Takes a state and returns two `np.ndarray`s of shape
        [n_train, output_dim] and [n_test, output_dim] respectively.
  """
  output_dimension = y_train.shape[-1]

  g_dd = empirical.flatten_features(g_dd)

  momentum = (momentum - 1.0) / np.sqrt(learning_rate)

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # These functions are used inside the integrator only if the kernel is
  # diagonal over the logits.
  ifl = lambda x: x
  iufl = lambda x: x

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('The shape of `g_dd` is incompatible with that of '
                       '`y_train`.')
    ifl = fl
    iufl = ufl

  y_train = np.reshape(y_train, (-1,))
  grad_loss = grad(functools.partial(loss, y_hat=y_train))

  if g_td is None:

    def dr_dt(unused_t, r):
      fx, qx = np.split(r, 2)
      dfx = qx
      dqx = momentum * qx - ifl(np.dot(g_dd, iufl(grad_loss(fx))))
      return np.concatenate((dfx, dqx), axis=0)

    def init_fn(fx_train=0.):
      fx_train = fl(fx_train)
      qx_train = np.zeros_like(fx_train)
      return np.concatenate((fx_train, qx_train), axis=0)

    def predict_fn(state, dt):
      solver = ode(dr_dt).set_integrator('dopri5')
      solver.set_initial_value(state, 0)
      solver.integrate(dt)
      return solver.y

    def get_fn(state):
      return ufl(np.split(state, 2)[0])

  else:
    g_td = empirical.flatten_features(g_td)

    def dr_dt(unused_t, r, train_size):
      train, test = r[:train_size], r[train_size:]
      fx_train, qx_train = np.split(train, 2)
      _, qx_test = np.split(test, 2)
      dfx_train = qx_train
      dqx_train = momentum * qx_train - ifl(
          np.dot(g_dd, iufl(grad_loss(fx_train))))
      dfx_test = qx_test
      dqx_test = momentum * qx_test - ifl(
          np.dot(g_td, iufl(grad_loss(fx_train))))
      return np.concatenate((dfx_train, dqx_train, dfx_test, dqx_test),
                            axis=0)

    def init_fn(fx_train=0., fx_test=0.):
      train_size = fx_train.shape[0]
      fx_train, fx_test = fl(fx_train), fl(fx_test)
      qx_train = np.zeros_like(fx_train)
      qx_test = np.zeros_like(fx_test)
      return (2 * train_size * output_dimension,
              np.concatenate((fx_train, qx_train, fx_test, qx_test), axis=0))

    def predict_fn(state, dt):
      train_size, state = state
      solver = ode(dr_dt).set_integrator('dopri5')
      solver.set_initial_value(state, 0).set_f_params(train_size)
      solver.integrate(dt)
      return train_size, solver.y

    def get_fn(state):
      train_size, state = state
      train, test = state[:train_size], state[train_size:]
      return ufl(np.split(train, 2)[0]), ufl(np.split(test, 2)[0])

  return init_fn, predict_fn, get_fn
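# A minimal sketch (hypothetical, plain NumPy/SciPy) of the first-order system
# integrated above in the g_td is None branch: the state is [fx, qx], with
# dfx/dt = qx and dqx/dt = beta * qx - K @ dL/dfx. Here we hard-code an MSE
# loss, so dL/dfx = (fx - y) / n, whereas the library differentiates a
# user-supplied `loss` with jax's `grad`.
def _sketch_momentum_ode():
  import numpy as onp
  from scipy.integrate import ode as _ode
  rng = onp.random.RandomState(2)
  a = rng.randn(5, 5)
  k = a @ a.T                          # Toy kernel standing in for g_dd.
  y = rng.randn(5)
  beta = (0.9 - 1.0) / onp.sqrt(1e-2)  # Same reparametrization as above.

  def dr_dt(_, r):
    fx, qx = onp.split(r, 2)
    dqx = beta * qx - k @ ((fx - y) / y.size)
    return onp.concatenate((qx, dqx))

  solver = _ode(dr_dt).set_integrator('dopri5')
  solver.set_initial_value(onp.zeros(10), 0.)
  solver.integrate(1.0)
  return onp.split(solver.y, 2)[0]     # fx after sqrt(lr)-rescaled time 1.0.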
def gradient_descent(g_dd, y_train, loss, g_td=None):
  """Predicts the outcome of function space gradient descent training on `loss`.

  Solves for continuous-time gradient descent using an ODE solver.

  Solves the function space ODE for continuous gradient descent with a given
  loss (detailed in [*]) given a Neural Tangent Kernel over the dataset. This
  function returns a function that predicts the time evolution for function
  space points at arbitrary times. Note that times are continuous and are
  measured in units of the learning rate so that t = learning_rate * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> from jax.experimental import stax
    >>> from neural_tangents import predict
    >>>
    >>> train_time = 1e-7
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(
    >>>     g_dd, y_train, cross_entropy, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>     train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A kernel on the training data. The kernel should be an `np.ndarray`
      of shape [n_train * output_dim, n_train * output_dim] or
      [n_train, n_train]. In the latter case it is assumed that the kernel is
      block diagonal over the logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      targets. Note: the loss function should treat the batch and output
      dimensions symmetrically.
    g_td: A kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: `g_td` should have been created in the
      convention `kernel_fn(x_test, x_train, params)`.

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs of shape [n_train, output_dim] and t is a floating point
      time. predict(t, fx) returns an `np.ndarray` of predictions of shape
      [n_train, output_dim].

    If g_td is not None:
      If a test set kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs of shape [n_train, output_dim] and [n_test, output_dim]
      respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """
  output_dimension = y_train.shape[-1]

  g_dd = empirical.flatten_features(g_dd)

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # These functions are used inside the integrator only if the kernel is
  # diagonal over the logits.
  ifl = lambda x: x
  iufl = lambda x: x

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('The shape of `g_dd` is incompatible with that of '
                       '`y_train`.')
    ifl = fl
    iufl = ufl

  y_train = np.reshape(y_train, (-1,))
  grad_loss = grad(functools.partial(loss, y_hat=y_train))

  if g_td is None:
    dfx_dt = lambda unused_t, fx: -ifl(np.dot(g_dd, iufl(grad_loss(fx))))

    def predict(dt, fx=0.):
      r = ode(dfx_dt).set_integrator('dopri5')
      r.set_initial_value(fl(fx), 0)
      r.integrate(dt)
      return ufl(r.y)

  else:
    g_td = empirical.flatten_features(g_td)

    def dfx_dt(unused_t, fx, train_size):
      fx_train = fx[:train_size]
      dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
      dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
      return np.concatenate((dfx_train, dfx_test), axis=0)

    def predict(dt, fx_train=0., fx_test=0.):
      r = ode(dfx_dt).set_integrator('dopri5')
      fx = fl(np.concatenate((fx_train, fx_test), axis=0))
      train_size, output_dim = fx_train.shape
      r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
      r.integrate(dt)
      fx = ufl(r.y)
      return fx[:train_size], fx[train_size:]

  return predict
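# A minimal sketch (hypothetical, plain NumPy/SciPy) of the train-only ODE
# integrated above: continuous-time gradient descent follows
# dfx/dt = -K @ dL/dfx. With an MSE loss this reproduces the closed form in
# `gradient_descent_mse`; here we just integrate it numerically with 'dopri5'.
def _sketch_gradient_flow():
  import numpy as onp
  from scipy.integrate import ode as _ode
  rng = onp.random.RandomState(3)
  a = rng.randn(5, 5)
  k = a @ a.T                      # Toy kernel standing in for g_dd.
  y = rng.randn(5)

  dfx_dt = lambda _, fx: -k @ ((fx - y) / y.size)  # MSE gradient in f-space.

  solver = _ode(dfx_dt).set_integrator('dopri5')
  solver.set_initial_value(onp.zeros(5), 0.)
  solver.integrate(10.0)
  return solver.y                  # Approaches y as t grows, since k is PSD.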
def analytic_mse(g_dd, y_train, g_td=None):
  """Predicts the outcome of function space training with an MSE loss.

  Uses the analytic solution for gradient descent on an MSE loss in function
  space detailed in [*] given a Neural Tangent Kernel over the dataset. Given
  NTKs, this function will return a function that predicts the time evolution
  for function space points at arbitrary times. Note that times are continuous
  and are measured in units of the learning rate so that
  t = learning_rate * steps.

  [*] https://arxiv.org/abs/1806.07572

  Example:
    ```python
    >>> train_time = 1e-7
    >>> ker_fun = empirical(f)
    >>> g_td = ker_fun(x_test, x_train, params)
    >>>
    >>> predict_fn = analytic_mse(g_dd, y_train, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>     train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A kernel on the training data. The kernel should be an `np.ndarray`
      of shape [n_train * output_dim, n_train * output_dim] or
      [n_train, n_train]. In the latter case, the kernel is assumed to be block
      diagonal over the logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of targets for the
      training data.
    g_td: A kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: `g_td` should have been created in the
      convention `ker_fun(x_test, x_train, params)`.

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs of shape [n_train, output_dim] and t is a floating point
      time. predict(t, fx) returns an `np.ndarray` of predictions of shape
      [n_train, output_dim].

    If g_td is not None:
      If a test set kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs of shape [n_train, output_dim] and [n_test, output_dim]
      respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """
  g_dd = _canonicalize_kernel_to_ntk(g_dd)
  g_td = _canonicalize_kernel_to_ntk(g_td)

  g_dd = empirical.flatten_features(g_dd)

  # TODO(schsam): Eventually, we may want to handle non-symmetric kernels for
  # e.g. masking. Additionally, once JAX supports eigh on TPU, we probably want
  # to switch to JAX's eigh.
  if xla_bridge.get_backend().platform == 'tpu':
    eigh = np.onp.linalg.eigh
  else:
    eigh = np.linalg.eigh
  evals, evecs = eigh(g_dd)
  ievecs = np.transpose(evecs)

  normalization = y_train.size
  output_dimension = y_train.shape[-1]

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('The shape of `g_dd` is incompatible with that of '
                       '`y_train`.')
    fl = lambda x: x
    ufl = lambda x: x

  def predict(gx, dt):
    gx_ = np.diag(np.exp(-evals * dt / normalization))
    gx_ = np.dot(evecs, gx_)
    gx_ = np.dot(gx_, ievecs)
    gx_ = np.dot(gx_, gx)
    return gx_

  if g_td is None:
    return lambda dt, fx=0.: ufl(predict(fl(fx - y_train), dt)) + y_train

  g_td = empirical.flatten_features(g_td)
  mevals = np.diag(1.0 / evals)
  inverse = np.dot(np.dot(evecs, mevals), ievecs)

  def predict_using_kernel(dt, fx_train=0., fx_test=0.):
    gx_train = fl(fx_train - y_train)
    dgx = predict(gx_train, dt) - gx_train
    dfx = np.dot(inverse, dgx)
    dfx = np.dot(g_td, dfx)
    return ufl(dgx) + fx_train, fx_test + ufl(dfx)

  return predict_using_kernel
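# A minimal sketch (hypothetical, plain NumPy/SciPy) checking that the
# eigendecomposition-based propagator used in `analytic_mse` agrees with a
# direct matrix exponential: for symmetric K,
# evecs @ diag(exp(-evals * t / n)) @ evecs.T == expm(-K * t / n).
def _sketch_analytic_mse_propagator():
  import numpy as onp
  import scipy.linalg as osp_linalg
  rng = onp.random.RandomState(4)
  a = rng.randn(6, 6)
  k = a @ a.T                      # Toy symmetric kernel standing in for g_dd.
  n, t = float(k.shape[0]), 0.5
  evals, evecs = onp.linalg.eigh(k)
  prop = evecs @ onp.diag(onp.exp(-evals * t / n)) @ evecs.T
  assert onp.allclose(prop, osp_linalg.expm(-k * t / n))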