Example #1
def _get_dependency(get, compute_var):
    _, get = canonicalize_get(get)
    for g in get:
        if g not in ['nngp', 'ntk']:
            raise NotImplementedError(
                'Can only get either "nngp" or "ntk" predictions, got %s.' % g)
    get_dependency = ()
    if 'nngp' in get or ('ntk' in get and compute_var):
        get_dependency += ('nngp', )
    if 'ntk' in get:
        get_dependency += ('ntk', )
    return get_dependency
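The dependency logic above can be sanity-checked by hand. The calls below assume `canonicalize_get` passes already-canonical tuples through unchanged (these asserts are illustrative, not the library's tests); note that requesting "ntk" variances also pulls in the NNGP kernel, since the NTK posterior covariance involves it.

# Illustrative checks (hypothetical), assuming `canonicalize_get` returns
# canonical tuples unchanged:
assert _get_dependency(('ntk',), False) == ('ntk',)
assert _get_dependency(('nngp',), True) == ('nngp',)
# NTK variances additionally require the NNGP kernel:
assert _get_dependency(('ntk',), True) == ('nngp', 'ntk')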
Example #2
def _get_dependency(get: Get, compute_cov: bool) -> Tuple[str, ...]:
    """Figure out dependency for get."""
    _, get = utils.canonicalize_get(get)
    for g in get:
        if g not in ['nngp', 'ntk']:
            raise NotImplementedError(
                'Can only get either "nngp" or "ntk" predictions, got %s.' % g)
    get_dependency = ()
    if 'nngp' in get or ('ntk' in get and compute_cov):
        get_dependency += ('nngp', )
    if 'ntk' in get:
        get_dependency += ('ntk', )
    return get_dependency
Example #3
def predict_inf(get: Get):
    """Infinite-time prediction: infers the GP posterior via `gp_inference`."""
    _, get = utils.canonicalize_get(get)
    k_dd = get_k_train_train(get)
    return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                        trace_axes)
Example #4
def gradient_descent_mse_gp(kernel_fn,
                            x_train,
                            y_train,
                            x_test,
                            get,
                            diag_reg=0.0,
                            compute_cov=False):
    """Predicts the gaussian embedding induced by gradient descent with mse loss.

  This is equivalent to an infinite ensemble of networks after marginalizing
  out the initialization.

  Args:
    kernel_fn: A kernel function that computes NNGP and NTK.
    x_train: A `np.ndarray`, representing the training data.
    y_train: A `np.ndarray`, representing the labels of the training data.
    x_test: A `np.ndarray`, representing the test data.
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple.
    diag_reg: A float, representing the strength of the regularization.
    compute_cov: A boolean. If `True` computing both `mean` and `variance` and
      only `mean` otherwise.

  Returns:
    A function that predicts the gaussian parameters at t:
      predict(t) -> Gaussian(mean, variance).
      If compute_cov is False, only returns the mean.
  """
    if get is None:
        get = ('nngp', 'ntk')
    if isinstance(get, str):
        # NOTE(schsam): This seems like an ugly solution that involves an extra
        # indirection. It might be nice to clean it up.
        return lambda t: gradient_descent_mse_gp(kernel_fn,
                                                 x_train,
                                                 y_train,
                                                 x_test,
                                                 diag_reg=diag_reg,
                                                 get=(get, ),
                                                 compute_cov=compute_cov)(t)[0]

    _, get = canonicalize_get(get)

    normalization = y_train.size
    op_fn = _make_inv_expm1_fn(normalization)

    eigenspace = {}

    kdd, ktd, ktt = _get_matrices(kernel_fn, x_train, x_test, get, compute_cov)
    gp_inference_mat = (_gp_inference_mat_jit_cpu
                        if _is_on_cpu(kdd) else _gp_inference_mat_jit)

    @_jit_cpu(kdd)
    def predict(t=None):
        """`t=None` is equivalent to infinite time and calls `gp_inference`."""
        if t is None:
            return gp_inference_mat(kdd, ktd, ktt, y_train, get, diag_reg)

        if not eigenspace:
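            # Compute the eigendecomposition of the regularized train-train
            # kernel lazily on the first call and cache it in `eigenspace`,
            # so later calls at other times t only cost matrix products.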
            for g in get:
                k = kdd.nngp if g == 'nngp' else kdd.ntk
                k_dd_plus_reg = _add_diagonal_regularizer(k, diag_reg)
                eigenspace[g] = _eigh(k_dd_plus_reg)

        out = {}

        if 'nngp' in get:
            evals, evecs = eigenspace['nngp']
            op_evals = -op_fn(evals, t)
            pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.nngp,
                                                y_train)
            if compute_cov:
                op_evals_x2 = -op_fn(evals, 2 * t)
                pred_cov = ktt - np.einsum('mj,ji,i,ki,lk->ml',
                                           ktd.nngp,
                                           evecs,
                                           op_evals_x2,
                                           evecs,
                                           ktd.nngp,
                                           optimize=True)

            out['nngp'] = Gaussian(pred_mean,
                                   pred_cov) if compute_cov else pred_mean

        if 'ntk' in get:
            evals, evecs = eigenspace['ntk']
            op_evals = -op_fn(evals, t)
            pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.ntk,
                                                y_train)
            if compute_cov:
                # inline the covariance calculation with einsum.
                term_1 = np.einsum('mi,i,ki,lk->ml',
                                   evecs,
                                   op_evals,
                                   evecs,
                                   ktd.ntk,
                                   optimize=True)
                pred_cov = np.einsum('ji,jk,kl->il',
                                     term_1,
                                     kdd.nngp,
                                     term_1,
                                     optimize=True)
                term_2 = np.einsum('mj,ji,i,ki,lk->ml',
                                   ktd.ntk,
                                   evecs,
                                   op_evals,
                                   evecs,
                                   ktd.nngp,
                                   optimize=True)
                term_2 += np.transpose(term_2)
                pred_cov += (-term_2 + ktt)

            out['ntk'] = Gaussian(pred_mean,
                                  pred_cov) if compute_cov else pred_mean

        returntype = named_tuple_factory('Gaussians', get)
        return returntype(*tuple(out[g] for g in get))

    return predict
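Below is a minimal usage sketch of the function above. It assumes an older `neural_tangents` release that exposes this function as `nt.predict.gradient_descent_mse_gp` (later releases renamed it); the two-layer architecture and random data are placeholders, not part of the original example.

from jax import random
import neural_tangents as nt
from neural_tangents import stax

# Infinite-width network defining the NNGP/NTK kernels.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(0), 3)
x_train = random.normal(key1, (20, 3))
y_train = random.normal(key2, (20, 1))
x_test = random.normal(key3, (10, 3))

predict = nt.predict.gradient_descent_mse_gp(
    kernel_fn, x_train, y_train, x_test, get='ntk',
    diag_reg=1e-4, compute_cov=True)

mean, cov = predict(100.0)         # Gaussian prediction at training time t = 100.
mean_inf, cov_inf = predict(None)  # t=None: the infinite-time limit.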
Example #5
def gradient_descent_mse_gp(kernel_fn,
                            x_train,
                            y_train,
                            x_test,
                            get=('nngp', 'ntk'),
                            diag_reg=0.0,
                            compute_var=False):
    """Predicts the gaussian embedding induced by gradient descent with mse loss.

  This is equivalent to an infinite ensemble of networks after marginalizing
  out the initialization.

  Args:
    kernel_fn: A kernel function that computes NNGP and NTK.
    x_train: A `np.ndarray`, representing the training data.
    y_train: A `np.ndarray`, representing the labels of the training data.
    x_test: A `np.ndarray`, representing the test data.
    diag_reg: A float, representing the strength of the regularization.
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple.
    compute_var: A boolean. If `True` computing both `mean` and `variance` and
      only `mean` otherwise.

  Returns:
    A function that predicts the gaussian parameters at t:
      prediction(t) -> Gaussian(mean, variance).
      If compute_var is False, only returns the mean.
  """
    if isinstance(get, str):
        # NOTE(schsam): This seems like an ugly solution that involves an extra
        # indirection. It might be nice to clean it up.
        return lambda t: gradient_descent_mse_gp(kernel_fn,
                                                 x_train,
                                                 y_train,
                                                 x_test,
                                                 diag_reg=diag_reg,
                                                 get=(get, ),
                                                 compute_var=compute_var)(t)[0]

    _, get = canonicalize_get(get)
    get_dependency = _get_dependency(get, compute_var)

    kdd = kernel_fn(x_train, None, get_dependency)
    ktd = kernel_fn(x_test, x_train, get_dependency)
    if compute_var:
        ktt = kernel_fn(x_test, None, 'nngp')

    normalization = y_train.size
    op_fn = _make_inv_expm1_fn(normalization)

    eigenspace = {}
    for g in get:
        k = kdd.nngp if g == 'nngp' else kdd.ntk
        k_dd_plus_reg = _add_diagonal_regularizer(k, diag_reg)
        eigenspace[g] = _eigh(k_dd_plus_reg)
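    # The eigendecompositions above are computed once, outside `prediction`,
    # so evaluating the predictor at many times t only costs matrix products.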

    def prediction(t):
        out = {}

        if 'nngp' in get:
            evals, evecs = eigenspace['nngp']
            op_evals = -op_fn(evals, t)
            pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.nngp,
                                                y_train)
            if compute_var:
                # The posterior variance decays at twice the rate of the mean.
                op_evals_x2 = -op_fn(evals, 2 * t)
                pred_var = ktt - np.einsum('mj,ji,i,ki,lk->ml',
                                           ktd.nngp,
                                           evecs,
                                           op_evals_x2,
                                           evecs,
                                           ktd.nngp,
                                           optimize=True)

            out['nngp'] = Gaussian(pred_mean,
                                   pred_var) if compute_var else pred_mean

        if 'ntk' in get:
            evals, evecs = eigenspace['ntk']
            op_evals = -op_fn(evals, t)
            pred_mean = _mean_prediction_einsum(evecs, op_evals, ktd.ntk,
                                                y_train)
            if compute_var:
                # Inline the variance calculation with einsum; the cross term
                # is symmetrized so the resulting covariance is symmetric.
                term_1 = np.einsum('mi,i,ki,lk->ml',
                                   evecs,
                                   op_evals,
                                   evecs,
                                   ktd.ntk,
                                   optimize=True)
                pred_var = np.einsum('ji,jk,kl->il',
                                     term_1,
                                     kdd.nngp,
                                     term_1,
                                     optimize=True)
                term_2 = np.einsum('mj,ji,i,ki,lk->ml',
                                   ktd.ntk,
                                   evecs,
                                   op_evals,
                                   evecs,
                                   ktd.nngp,
                                   optimize=True)
                term_2 += np.transpose(term_2)
                pred_var += (-term_2 + ktt)

            out['ntk'] = Gaussian(pred_mean,
                                  pred_var) if compute_var else pred_mean

        returntype = named_tuple_factory('GPGradientDescent', get)
        return returntype(*tuple(out[g] for g in get))

    return prediction
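Neither example shows `_make_inv_expm1_fn`. From the way `op_fn` is used, `-op_fn(evals, t)` must evaluate elementwise to `(1 - exp(-evals * t / n)) / evals` with `n = y_train.size`, i.e. the eigenvalues of K_dd^{-1}(I - e^{-K_dd t / n}). A plausible sketch under that assumption, not the library's exact code:

import jax.numpy as np

def _make_inv_expm1_fn(normalization):
    # Hypothetical reconstruction: op_fn(evals, t) computes
    # expm1(-evals * t / normalization) / evals elementwise, so that
    # -op_fn(evals, t) = (1 - exp(-evals * t / normalization)) / evals.
    def inv_expm1_fn(evals, t):
        return np.expm1(-evals * t / normalization) / evals

    return inv_expm1_fn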