Example #1
def eigenspace(get: str):
    # Regularized eigendecomposition of the train-train kernel for `get`
    # ('nngp' or 'ntk'); the eigenvalues get a leading axis so they broadcast
    # against the time axis later on.
    k_dd = getattr(get_k_train_train((get,)), get)
    k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                     diag_reg_absolute_scale)
    evals, evecs = np.linalg.eigh(k_dd)
    evals = np.expand_dims(evals, 0)
    return evals, evecs
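The helper above is a nested function and relies on closure variables (`get_k_train_train`, `diag_reg`, `diag_reg_absolute_scale`) and on the private helper `_add_diagonal_regularizer`. Below is a minimal self-contained sketch of the same pattern in plain NumPy; the regularizer stand-in is a hypothetical placeholder, assuming `diag_reg` is scaled by the mean diagonal entry unless `diag_reg_absolute_scale` is set, which is an assumption and not the library's verbatim code.

import numpy as np

def add_diagonal_regularizer(k, diag_reg, diag_reg_absolute_scale):
    # Hypothetical stand-in for `_add_diagonal_regularizer` (assumption):
    # scale `diag_reg` by the mean diagonal entry unless an absolute scale
    # is requested.
    if not diag_reg_absolute_scale:
        diag_reg *= np.trace(k) / k.shape[0]
    return k + diag_reg * np.eye(k.shape[0])

def eigenspace(k_dd, diag_reg=1e-4, diag_reg_absolute_scale=False):
    # Regularize, then eigendecompose the symmetric train-train kernel.
    k_dd = add_diagonal_regularizer(k_dd, diag_reg, diag_reg_absolute_scale)
    evals, evecs = np.linalg.eigh(k_dd)
    return np.expand_dims(evals, 0), evecs  # leading axis broadcasts over time

x = np.random.default_rng(0).normal(size=(5, 3))
k = x @ x.T                                  # toy PSD "kernel"
evals, evecs = eigenspace(k)
print(evals.shape, evecs.shape)              # (1, 5) (5, 5)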
Example #2
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using SGD or full-batch GD with mean
  squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. The maximal
  feasible learning rate is the largest `\eta` such that the operator
  `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
  `2 * batch_size * output_size / lambda_max(NTK)`.

  Args:
    ntk_train_train: analytic or empirical NTK on the training data.
    y_train_size: total training set output size, i.e.
      `f(x_train).size == y_train.size`. If `y_train_size=None` it is inferred
      from `ntk_train_train.shape` assuming `trace_axes=()`.
    eps: a float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
    ntk_train_train = utils.make_2d(ntk_train_train)
    factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size

    if utils.is_on_cpu(ntk_train_train):
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            eigvals=(ntk_train_train.shape[0] - 1,
                     ntk_train_train.shape[0] - 1))[-1]
    else:
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
    lr = 2 * factor / (max_eva + eps)
    return lr
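A quick numerical sanity check of this bound (a plain-NumPy sketch, independent of the function above): with `lr = 2 * n / lambda_max`, the gradient-descent operator `I - lr / n * NTK` is still a contraction, while a slightly larger rate is not.

import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(size=(4, 4))
ntk = a @ a.T                     # toy symmetric PSD "NTK"
n = ntk.shape[0]

lam_max = np.linalg.eigvalsh(ntk)[-1]
lr = 2 * n / lam_max              # 2 * batch_size * output_size / lambda_max

def spectral_radius(eta):
    # Radius of the GD update operator `I - eta / n * NTK`.
    return np.abs(np.linalg.eigvals(np.eye(n) - eta / n * ntk)).max()

print(spectral_radius(0.99 * lr) < 1.0)  # True: still a contraction
print(spectral_radius(1.01 * lr) < 1.0)  # False: training diverges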
Example #3
def _get_fns_in_eigenbasis(k_train_train: np.ndarray,
                           diag_reg: float,
                           diag_reg_absolute_scale: bool,
                           fns: Iterable[Callable]) -> Iterable[Callable]:
  """Build functions of a matrix in its eigenbasis.

  Args:
    k_train_train:
      an `n x n` matrix.
    diag_reg:
      diagonal regularization strength added to `k_train_train` before the
      eigendecomposition.
    diag_reg_absolute_scale:
      whether `diag_reg` is used as an absolute value or is scaled relative to
      `k_train_train`.
    fns:
      a sequence of functions that act on the eigenvalues,
      `(evals, dt) -> modified_evals`.

  Returns:
    A generator of functions that act as functions of the matrix `mat`
    acting on vectors: `transform(vec, dt) = fn(mat, dt) @ vec`.
  """
  k_train_train = utils.make_2d(k_train_train)
  k_train_train = _add_diagonal_regularizer(k_train_train, diag_reg,
                                            diag_reg_absolute_scale)
  evals, evecs = np.linalg.eigh(k_train_train)

  def to_eigenbasis(fn):
    """Generates a transform given a function on the eigenvalues."""
    def new_fn(y_train, t):
      return np.einsum('ji,ti,ki,k...->tj...',
                       evecs, fn(evals, t), evecs, y_train,
                       optimize=True)

    return new_fn

  return (to_eigenbasis(fn) for fn in fns)
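A self-contained sketch of the eigenbasis trick used by `to_eigenbasis`, in plain NumPy with `fn(evals, t) = exp(-evals * t)` as the eigenvalue function: rotating into the eigenbasis, scaling by `fn(evals, t)`, and rotating back matches applying the matrix function directly.

import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(2)
a = rng.normal(size=(4, 4))
k = a @ a.T                            # symmetric matrix
y = rng.normal(size=(4, 2))            # "y_train" with an output dimension
evals, evecs = np.linalg.eigh(k)

def fn(evals, t):
    return np.exp(-evals * t)          # function applied to the eigenvalues

t = np.array([0.1, 1.0])               # two time points
# Same contraction pattern as `new_fn` above: rotate into the eigenbasis,
# scale by fn(evals, t), rotate back, and apply the result to y.
out = np.einsum('ji,ti,ki,k...->tj...', evecs, fn(evals, t), evecs, y,
                optimize=True)

# Direct check against the matrix function exp(-k * t) applied to y.
direct = np.stack([expm(-k * ti) @ y for ti in t])
print(np.allclose(out, direct))        # True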
Example #4
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      momentum: float = 0.,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using mini-/full-batch GD + momentum
  with mean squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla SGD
  (i.e. `momentum = 0`) the maximal feasible learning rate is the largest `\eta`
  such that the operator
                `(I - \eta / (batch_size * output_size) * NTK)`
  is a contraction, which is
                `2 * batch_size * output_size / lambda_max(NTK)`.
  When `momentum > 0`, we use (see `The Dynamics of Momentum` section in
  https://distill.pub/2017/momentum/)
                `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)`.

  Args:
    ntk_train_train:
      analytic or empirical NTK on the training data.
    y_train_size:
      total training set output size, i.e.
      `f(x_train).size == y_train.size`. If `y_train_size=None` it is inferred
      from `ntk_train_train.shape` assuming `trace_axes=()`.
    momentum:
      The `momentum` for momentum optimizers.
    eps:
      a float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
    ntk_train_train = utils.make_2d(ntk_train_train)
    factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size

    if utils.is_on_cpu(ntk_train_train):
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            eigvals=(ntk_train_train.shape[0] - 1,
                     ntk_train_train.shape[0] - 1))[-1]
    else:
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
    lr = 2 * (1 + momentum) * factor / (max_eva + eps)
    return lr
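To see where the `(1 + momentum)` factor comes from, one can check the per-eigenvalue heavy-ball recursion described in the Distill article cited above. The sketch below uses the Distill-style parameterization (an illustration, not library code): the 2x2 transition matrix stays a contraction exactly while `eta * lambda / n < 2 * (1 + momentum)`.

import numpy as np

def momentum_radius(eta, lam, beta, n=1):
    # Heavy-ball recursion for one NTK eigenvalue `lam` (Distill notation):
    #   z_{k+1} = beta * z_k + (lam / n) * w_k
    #   w_{k+1} = w_k - eta * z_{k+1}
    r = np.array([[beta, lam / n],
                  [-eta * beta, 1.0 - eta * lam / n]])
    return np.abs(np.linalg.eigvals(r)).max()

lam, beta, n = 3.0, 0.9, 1
lr = 2 * (1 + beta) * n / lam     # the bound used by `max_learning_rate`

print(momentum_radius(0.99 * lr, lam, beta, n) < 1.0)  # True: converges
print(momentum_radius(1.01 * lr, lam, beta, n) < 1.0)  # False: diverges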
Example #5
def _get_cho_solve(
        A: np.ndarray,
        diag_reg: float,
        diag_reg_absolute_scale: bool,
        lower: bool = False) -> Callable[[np.ndarray, Axes], np.ndarray]:
    """Returns a Cholesky solver for the regularized, 2D-flattened `A`."""
    x_non_channel_shape = A.shape[1::2]
    A = utils.make_2d(A)
    A = _add_diagonal_regularizer(A, diag_reg, diag_reg_absolute_scale)
    # Factor once; `cho_solve` below reuses the factorization for each `b`.
    C = sp.linalg.cho_factor(A, lower)

    def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
        b_axes = utils.canonicalize_axis(b_axes, b)
        last_b_axes = range(-len(b_axes), 0)
        x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

        b = np.moveaxis(b, b_axes, last_b_axes)
        b = b.reshape((A.shape[1], -1))

        x = sp.linalg.cho_solve(C, b)
        x = x.reshape(x_shape)
        return x

    return cho_solve
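A self-contained sketch of the factor-once / solve-many pattern behind `_get_cho_solve`, using SciPy's `cho_factor` / `cho_solve` on a toy regularized kernel (the `moveaxis` / `reshape` axis bookkeeping is left out).

import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(3)
x = rng.normal(size=(6, 4))
k = x @ x.T + 1e-3 * np.eye(6)    # regularized train-train kernel

c = cho_factor(k, lower=False)    # factor once...

def solve(b):
    # ...then reuse the factorization for every right-hand side.
    return cho_solve(c, b)

y = rng.normal(size=(6, 2))
print(np.allclose(k @ solve(y), y))  # True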
Example #6
    def predict_fn(t: ArrayOrScalar = None,
                   x_test: np.ndarray = None,
                   get: Get = None,
                   compute_cov: bool = False) -> Dict[str, Gaussian]:
        """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`, but is computed
        using linear solve for test predictions instead of eigendecomposition,
        saving time and precision.
      x_test:
        test inputs. `None` means to return non-regularized (`diag_reg=0`)
        predictions on the train-set inputs. For regularized predictions, pass
        `x_test=x_train`.
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov:
        if `True`, compute both `mean` and `covariance`; otherwise compute only
        `mean`.

    Returns:
      For each requested `get`, `fx_test_mean_t` if `compute_cov == False`, or a
      `Gaussian(fx_test_mean_t, fx_test_cov_t)` if `compute_cov == True`, with
      potentially additional leading time dimensions.
    """
        if get is None:
            get = ('nngp', 'ntk')

        # train-train, test-train, test-test.
        k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

        # Infinite time.
        if t is None:
            return predict_inf(get)(get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)

        # Finite time.
        t = np.array(t) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, 1))

        def reshape_mean(mean):
            k = _get_first(k_dd if k_td is None else k_td)
            mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
            mean = np.moveaxis(mean, last_t_axes, trace_axes)
            return mean

        def reshape_cov(cov):
            k = _get_first(k_dd if k_td is None else k_td)
            cov_shape_t = t_shape + k.shape[::2] * 2
            return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

        out = {}

        for g in get:
            evals, evecs = eigenspace(g)

            # Training set.
            if k_td is None:
                mean = np.einsum('ji,ti,ki,k...->tj...',
                                 evecs,
                                 -expm1(evals, t),
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            # Test set.
            else:
                neg_inv_expm1 = -inv_expm1(evals, t)
                ktd_g = utils.make_2d(getattr(k_td, g))
                mean = np.einsum('lj,ji,ti,ki,k...->tl...',
                                 ktd_g,
                                 evecs,
                                 neg_inv_expm1,
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            mean = reshape_mean(mean)

            if nngp_tt is not None:
                nngp_dd = utils.make_2d(k_dd.nngp)

                # Training set.
                if k_td is None:
                    if g == 'nngp':
                        cov = np.einsum('ji,ti,ki->tjk',
                                        evecs,
                                        (np.maximum(evals, 0.) *
                                         np.exp(-2 * np.maximum(evals, 0.) *
                                                t / y_train.size)),
                                        evecs,
                                        optimize=True)

                    elif g == 'ntk':
                        exp = np.einsum('mi,ti,ki->tmk',
                                        evecs,
                                        np.exp(-np.maximum(evals, 0.) * t /
                                               y_train.size),
                                        evecs,
                                        optimize=True)
                        cov = np.einsum('tmk,kl,tnl->tmn',
                                        exp,
                                        nngp_dd,
                                        exp,
                                        optimize=True)

                    else:
                        raise ValueError(g)

                # Test set.
                else:
                    _nngp_tt = np.expand_dims(utils.make_2d(nngp_tt), 0)

                    if g == 'nngp':
                        cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                                   ktd_g,
                                                   evecs,
                                                   -inv_expm1(evals, 2 * t),
                                                   evecs,
                                                   ktd_g,
                                                   optimize=True)

                    elif g == 'ntk':
                        term_1 = np.einsum('mi,ti,ki,lk->tml',
                                           evecs,
                                           neg_inv_expm1,
                                           evecs,
                                           ktd_g,
                                           optimize=True)
                        term_2 = np.einsum(
                            'mj,ji,ti,ki,lk->tml',
                            ktd_g,
                            evecs,
                            neg_inv_expm1,
                            evecs,
                            utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
                            optimize=True)
                        term_2 += np.moveaxis(term_2, 1, 2)
                        cov = np.einsum('tji,jk,tkl->til',
                                        term_1,
                                        nngp_dd,
                                        term_1,
                                        optimize=True)
                        cov += -term_2 + _nngp_tt

                    else:
                        raise ValueError(g)

                out[g] = Gaussian(mean, reshape_cov(cov))

            else:
                out[g] = mean

        return out
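For the train-set branch (`k_td is None`), the `einsum` with `-expm1(evals, t)` is the closed-form mean of gradient-descent MSE training started from `f_0 = 0`: `mean(t) = (I - exp(-K t / n)) y_train`, evaluated in the kernel's eigenbasis. Below is a small NumPy sketch of that identity, assuming `expm1(evals, t)` computes `np.expm1(-np.maximum(evals, 0) * t / y_train.size)`, which is consistent with the exponentials used in the covariance branch above.

import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(4)
x = rng.normal(size=(5, 3))
k_dd = x @ x.T                       # toy train-train NTK
y_train = rng.normal(size=(5, 1))    # single-output targets
n = y_train.size

evals, evecs = np.linalg.eigh(k_dd)
t = np.array([[0.5], [2.0]])         # shape (2, 1): two time points

# Train-set mean in the eigenbasis, mirroring the `k_td is None` branch:
#   mean(t) = V diag(1 - exp(-evals * t / n)) V^T y_train
neg_expm1 = -np.expm1(-np.maximum(evals, 0.) * t / n)
mean = np.einsum('ji,ti,ki,k...->tj...', evecs, neg_expm1, evecs, y_train,
                 optimize=True)

# Equivalent closed form: (I - exp(-K t / n)) y_train for each time point.
direct = np.stack([(np.eye(5) - expm(-k_dd * ti / n)) @ y_train
                   for ti in t.ravel()])
print(np.allclose(mean, direct))     # True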
Example #7
def eigenspace(get: str):
    # Same helper as in Example #1, but returns the raw (evals, evecs) pair
    # without adding a leading axis to the eigenvalues.
    k_dd = getattr(get_k_train_train((get,)), get)
    k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                     diag_reg_absolute_scale)
    return np.linalg.eigh(k_dd)