def predict(t=None):
    """Evaluate the requested kernel predictions at training time `t`.

    Args:
      t: training time forwarded to `op_fn`. When `None`, the call is
        delegated to `gp_inference_mat` (the infinite-time case) and nothing
        below runs.

    Returns:
      For `t=None`, whatever `gp_inference_mat` returns. Otherwise the object
      built by `named_tuple_factory('Gaussians', get)`, filled with one value
      per entry of `get` (in that order): a `Gaussian(mean, covariance)` when
      `compute_var` is truthy, else the mean alone.
    """
    if t is None:
        return gp_inference_mat(kdd, ktd, ktt, y_train, get, diag_reg)

    # The eigendecompositions are computed once, on the first finite-time
    # call, and stored in the enclosing-scope `eigenspace` cache.
    if not eigenspace:
        for name in get:
            if name == 'nngp':
                train_kernel = kdd.nngp
            else:
                train_kernel = kdd.ntk
            regularized = _add_diagonal_regularizer(train_kernel, diag_reg)
            eigenspace[name] = _eigh(regularized)

    results = {}

    if 'nngp' in get:
        eigvals, eigvecs = eigenspace['nngp']
        spectrum_t = -op_fn(eigvals, t)
        mean = _mean_prediction_einsum(
            eigvecs, spectrum_t, ktd.nngp, y_train)
        if compute_var:
            # The covariance term is evaluated at twice the time of the mean.
            spectrum_2t = -op_fn(eigvals, 2 * t)
            cov = ktt - np.einsum(
                'mj,ji,i,ki,lk->ml',
                ktd.nngp, eigvecs, spectrum_2t, eigvecs, ktd.nngp,
                optimize=True)
            results['nngp'] = Gaussian(mean, cov)
        else:
            results['nngp'] = mean

    if 'ntk' in get:
        eigvals, eigvecs = eigenspace['ntk']
        spectrum_t = -op_fn(eigvals, t)
        mean = _mean_prediction_einsum(
            eigvecs, spectrum_t, ktd.ntk, y_train)
        if compute_var:
            # Covariance is assembled in stages: an inner einsum product,
            # minus twice the transposed NNGP test/train kernel, then a
            # second einsum contraction, and finally `ktt` is added.
            inner = np.einsum(
                'mj,ji,i,ki,lk->ml',
                kdd.nngp, eigvecs, spectrum_t, eigvecs, ktd.ntk,
                optimize=True)
            inner -= 2. * np.transpose(ktd.nngp)
            cov = np.einsum(
                'mj,ji,i,ki,kl->ml',
                ktd.ntk, eigvecs, spectrum_t, eigvecs, inner,
                optimize=True)
            cov = cov + ktt
            results['ntk'] = Gaussian(mean, cov)
        else:
            results['ntk'] = mean

    result_cls = named_tuple_factory('Gaussians', get)
    return result_cls(*[results[name] for name in get])
def prediction(t):
    """Evaluate the requested kernel predictions at training time `t`.

    The enclosing-scope `eigenspace` mapping must already hold an
    `(evals, evecs)` pair for every kernel requested in `get`; this function
    does not populate it.

    Args:
      t: training time forwarded to `op_fn`.

    Returns:
      The object built by `named_tuple_factory('GPGradientDescent', get)`,
      filled with one value per entry of `get` (in that order): a
      `Gaussian(mean, covariance)` when `compute_var` is truthy, else the
      mean alone.
    """
    out = {}

    if 'nngp' in get:
        evals, evecs = eigenspace['nngp']
        # The mean is evaluated at time `t`, exactly as in the 'ntk' branch
        # below. Only the covariance term uses the doubled time `2 * t`;
        # feeding the doubled-time spectrum into the mean would report the
        # prediction for the wrong training time.
        op_evals = -op_fn(evals, t)
        pred_mean = _mean_prediction_einsum(
            evecs, op_evals, ktd.nngp, y_train)
        if compute_var:
            op_evals_x2 = -op_fn(evals, 2 * t)
            pred_var = ktt - np.einsum(
                'mj,ji,i,ki,lk->ml',
                ktd.nngp, evecs, op_evals_x2, evecs, ktd.nngp,
                optimize=True)
            out['nngp'] = Gaussian(pred_mean, pred_var)
        else:
            out['nngp'] = pred_mean

    if 'ntk' in get:
        evals, evecs = eigenspace['ntk']
        op_evals = -op_fn(evals, t)
        pred_mean = _mean_prediction_einsum(
            evecs, op_evals, ktd.ntk, y_train)
        if compute_var:
            # Inline the covariance calculation with einsum: an inner
            # product, minus twice the transposed NNGP test/train kernel,
            # contracted again, then `ktt` added.
            pred_var = np.einsum(
                'mj,ji,i,ki,lk->ml',
                kdd.nngp, evecs, op_evals, evecs, ktd.ntk,
                optimize=True)
            pred_var -= 2. * np.transpose(ktd.nngp)
            pred_var = np.einsum(
                'mj,ji,i,ki,kl->ml',
                ktd.ntk, evecs, op_evals, evecs, pred_var,
                optimize=True)
            pred_var = pred_var + ktt
            out['ntk'] = Gaussian(pred_mean, pred_var)
        else:
            out['ntk'] = pred_mean

    returntype = named_tuple_factory('GPGradientDescent', get)
    return returntype(*tuple(out[g] for g in get))