def _get_dependency(get, compute_var):
  _, get = canonicalize_get(get)
  for g in get:
    if g not in ['nngp', 'ntk']:
      raise NotImplementedError(
          'Can only get either "nngp" or "ntk" predictions, got %s.' % g)
  get_dependency = ()
  if 'nngp' in get or ('ntk' in get and compute_var):
    get_dependency += ('nngp',)
  if 'ntk' in get:
    get_dependency += ('ntk',)
  return get_dependency

def _get_dependency(get: Get, compute_cov: bool) -> Tuple[str, ...]:
  """Figure out dependency for get."""
  _, get = utils.canonicalize_get(get)
  for g in get:
    if g not in ['nngp', 'ntk']:
      raise NotImplementedError(
          'Can only get either "nngp" or "ntk" predictions, got %s.' % g)
  get_dependency = ()
  if 'nngp' in get or ('ntk' in get and compute_cov):
    get_dependency += ('nngp',)
  if 'ntk' in get:
    get_dependency += ('ntk',)
  return get_dependency

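# Illustrative trace of `_get_dependency` (comments only, not library code).
# Assuming `utils.canonicalize_get` turns a string or tuple into a tuple of
# requested kernels, the helper resolves which kernel matrices must actually
# be computed:
#
#   _get_dependency('ntk', compute_cov=False)             # -> ('ntk',)
#   _get_dependency('ntk', compute_cov=True)              # -> ('nngp', 'ntk')
#   _get_dependency(('nngp', 'ntk'), compute_cov=False)   # -> ('nngp', 'ntk')
#
# The NNGP kernel is required whenever the NTK predictive covariance is
# requested, because that covariance mixes NNGP and NTK terms (see the NTK
# covariance computation in `gradient_descent_mse_gp` below).
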
def predict_inf(get: Get):
  """Infinite-time prediction.

  An inner closure: `get_k_train_train`, `gp_inference`, `y_train`,
  `diag_reg`, `diag_reg_absolute_scale` and `trace_axes` are taken from the
  enclosing predictor's scope.
  """
  _, get = utils.canonicalize_get(get)
  k_dd = get_k_train_train(get)
  return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                      trace_axes)

def gradient_descent_mse_gp(kernel_fn,
                            x_train,
                            y_train,
                            x_test,
                            get,
                            diag_reg=0.0,
                            compute_cov=False):
  """Predicts the Gaussian embedding induced by gradient descent with MSE loss.

  This is equivalent to an infinite ensemble of networks after marginalizing
  out the initialization.

  Args:
    kernel_fn: A kernel function that computes NNGP and NTK.
    x_train: A `np.ndarray`, representing the training data.
    y_train: A `np.ndarray`, representing the labels of the training data.
    x_test: A `np.ndarray`, representing the test data.
    get: A string or tuple of strings, the mode of the Gaussian process,
      either "nngp", "ntk", or both.
    diag_reg: A float, representing the strength of the diagonal
      regularization.
    compute_cov: A boolean. If `True`, computes both `mean` and `covariance`;
      otherwise only `mean`.

  Returns:
    A function that predicts the Gaussian parameters at time t:
    `predict(t) -> Gaussian(mean, covariance)`. If `compute_cov` is `False`,
    only the mean is returned.
  """
  if get is None:
    get = ('nngp', 'ntk')

  if isinstance(get, str):
    # NOTE(schsam): This seems like an ugly solution that involves an extra
    # indirection. It might be nice to clean it up.
    return lambda t: gradient_descent_mse_gp(
        kernel_fn,
        x_train,
        y_train,
        x_test,
        diag_reg=diag_reg,
        get=(get,),
        compute_cov=compute_cov)(t)[0]

  _, get = canonicalize_get(get)

  normalization = y_train.size
  op_fn = _make_inv_expm1_fn(normalization)

  eigenspace = {}

  kdd, ktd, ktt = _get_matrices(kernel_fn, x_train, x_test, get, compute_cov)
  gp_inference_mat = (_gp_inference_mat_jit_cpu if _is_on_cpu(kdd)
                      else _gp_inference_mat_jit)

  @_jit_cpu(kdd)
  def predict(t=None):
    """`t=None` is equivalent to infinite time and calls `gp_inference`."""
    if t is None:
      return gp_inference_mat(kdd, ktd, ktt, y_train, get, diag_reg)

    if not eigenspace:
      for g in get:
        k = kdd.nngp if g == 'nngp' else kdd.ntk
        k_dd_plus_reg = _add_diagonal_regularizer(k, diag_reg)
        eigenspace[g] = _eigh(k_dd_plus_reg)

    out = {}

    if 'nngp' in get:
      evals, evecs = eigenspace['nngp']
      op_evals = -op_fn(evals, t)
      pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.nngp, y_train)
      if compute_cov:
        op_evals_x2 = -op_fn(evals, 2 * t)
        pred_cov = ktt - np.einsum(
            'mj,ji,i,ki,lk->ml',
            ktd.nngp, evecs, op_evals_x2, evecs, ktd.nngp,
            optimize=True)
      out['nngp'] = Gaussian(pred_mean, pred_cov) if compute_cov else pred_mean

    if 'ntk' in get:
      evals, evecs = eigenspace['ntk']
      op_evals = -op_fn(evals, t)
      pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.ntk, y_train)
      if compute_cov:
        # Inline the covariance calculation with einsum.
        term_1 = np.einsum(
            'mi,i,ki,lk->ml', evecs, op_evals, evecs, ktd.ntk, optimize=True)
        pred_cov = np.einsum(
            'ji,jk,kl->il', term_1, kdd.nngp, term_1, optimize=True)
        term_2 = np.einsum(
            'mj,ji,i,ki,lk->ml',
            ktd.ntk, evecs, op_evals, evecs, ktd.nngp,
            optimize=True)
        term_2 += np.transpose(term_2)
        pred_cov += (-term_2 + ktt)
      out['ntk'] = Gaussian(pred_mean, pred_cov) if compute_cov else pred_mean

    returntype = named_tuple_factory('Gaussians', get)
    return returntype(*tuple(out[g] for g in get))

  return predict

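# A minimal usage sketch (illustrative only, not part of the module above).
# It assumes a `kernel_fn` produced by `neural_tangents.stax` and uses toy
# placeholder arrays and shapes.
#
# from jax import random
# from neural_tangents import stax
#
# _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
#
# key1, key2, key3 = random.split(random.PRNGKey(0), 3)
# x_train = random.normal(key1, (20, 5))
# y_train = random.normal(key2, (20, 1))
# x_test = random.normal(key3, (10, 5))
#
# predict = gradient_descent_mse_gp(kernel_fn, x_train, y_train, x_test,
#                                   get=('nngp', 'ntk'), diag_reg=1e-4,
#                                   compute_cov=True)
#
# finite = predict(t=100.)  # Gaussians(nngp=Gaussian(mean, covariance),
#                           #           ntk=Gaussian(mean, covariance))
# infinite = predict()      # t=None falls back to closed-form `gp_inference`.
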
def gradient_descent_mse_gp(kernel_fn,
                            x_train,
                            y_train,
                            x_test,
                            get=('nngp', 'ntk'),
                            diag_reg=0.0,
                            compute_var=False):
  """Predicts the Gaussian embedding induced by gradient descent with MSE loss.

  This is equivalent to an infinite ensemble of networks after marginalizing
  out the initialization.

  Args:
    kernel_fn: A kernel function that computes NNGP and NTK.
    x_train: A `np.ndarray`, representing the training data.
    y_train: A `np.ndarray`, representing the labels of the training data.
    x_test: A `np.ndarray`, representing the test data.
    get: A string or tuple of strings, the mode of the Gaussian process,
      either "nngp", "ntk", or both.
    diag_reg: A float, representing the strength of the diagonal
      regularization.
    compute_var: A boolean. If `True`, computes both `mean` and `variance`;
      otherwise only `mean`.

  Returns:
    A function that predicts the Gaussian parameters at time t:
    `prediction(t) -> Gaussian(mean, variance)`. If `compute_var` is `False`,
    only the mean is returned.
  """
  if isinstance(get, str):
    # NOTE(schsam): This seems like an ugly solution that involves an extra
    # indirection. It might be nice to clean it up.
    return lambda t: gradient_descent_mse_gp(
        kernel_fn,
        x_train,
        y_train,
        x_test,
        diag_reg=diag_reg,
        get=(get,),
        compute_var=compute_var)(t)[0]

  _, get = canonicalize_get(get)

  get_dependency = _get_dependency(get, compute_var)
  kdd = kernel_fn(x_train, None, get_dependency)
  ktd = kernel_fn(x_test, x_train, get_dependency)
  if compute_var:
    ktt = kernel_fn(x_test, None, 'nngp')

  normalization = y_train.size
  op_fn = _make_inv_expm1_fn(normalization)

  eigenspace = {}
  for g in get:
    k = kdd.nngp if g == 'nngp' else kdd.ntk
    k_dd_plus_reg = _add_diagonal_regularizer(k, diag_reg)
    eigenspace[g] = _eigh(k_dd_plus_reg)

  def prediction(t):
    out = {}

    if 'nngp' in get:
      evals, evecs = eigenspace['nngp']
      op_evals = -op_fn(evals, t)
      pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.nngp, y_train)
      if compute_var:
        # The variance involves the operator evaluated at time 2 * t.
        op_evals_x2 = -op_fn(evals, 2 * t)
        pred_var = ktt - np.einsum(
            'mj,ji,i,ki,lk->ml',
            ktd.nngp, evecs, op_evals_x2, evecs, ktd.nngp,
            optimize=True)
      out['nngp'] = Gaussian(pred_mean, pred_var) if compute_var else pred_mean

    if 'ntk' in get:
      evals, evecs = eigenspace['ntk']
      op_evals = -op_fn(evals, t)
      pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.ntk, y_train)
      if compute_var:
        # Inline the covariance calculation with einsum.
        pred_var = np.einsum(
            'mj,ji,i,ki,lk->ml',
            kdd.nngp, evecs, op_evals, evecs, ktd.ntk,
            optimize=True)
        pred_var -= 2. * np.transpose(ktd.nngp)
        pred_var = np.einsum(
            'mj,ji,i,ki,kl->ml',
            ktd.ntk, evecs, op_evals, evecs, pred_var,
            optimize=True)
        pred_var = pred_var + ktt
      out['ntk'] = Gaussian(pred_mean, pred_var) if compute_var else pred_mean

    returntype = named_tuple_factory('GPGradientDescent', get)
    return returntype(*tuple(out[g] for g in get))

  return prediction

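# Sanity-check sketch (illustrative, standalone; toy matrices, not library
# code). As t -> inf, (1 - exp(-lambda * t / n)) / lambda tends to 1 / lambda,
# so the time-t mean K_td U diag((1 - e^{-lambda t / n}) / lambda) U^T y
# converges to the closed-form solution K_td K_dd^{-1} y used by the
# infinite-time / `gp_inference` path.
#
# import jax.numpy as np
# from jax import random
#
# key1, key2, key3 = random.split(random.PRNGKey(1), 3)
# a = random.normal(key1, (6, 6))
# kdd = a @ a.T + 1e-3 * np.eye(6)     # toy train-train kernel (PSD)
# ktd = random.normal(key2, (3, 6))    # toy test-train kernel
# y = random.normal(key3, (6, 1))
#
# evals, evecs = np.linalg.eigh(kdd)
# n, t = y.size, 1e8
# op_evals = -np.expm1(-evals * t / n) / evals   # (1 - e^{-lambda t/n}) / lambda
# mean_t = ktd @ (evecs * op_evals) @ evecs.T @ y
# mean_inf = ktd @ np.linalg.solve(kdd, y)
# assert np.allclose(mean_t, mean_inf, rtol=1e-3, atol=1e-3)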