Example 1
def max_learning_rate(kdd: np.ndarray,
                      num_outputs: int = -1,
                      eps: float = 1e-12) -> float:
    r"""Computing the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using SGD or full-batch GD with mean
  squared loss. The loss is assumed to have the form
  :math:`1/(2 * batch_size * num_outputs) \|f(train_x) - train_y\|^2`. The
  maximal feasible learning rate is the largest :math:`\eta` such that the
  operator :math:`(I - \eta / (batch_size * num_outputs) * NTK)` is a
  contraction, which gives
  :math:`\eta_{max} = 2 * batch_size * num_outputs / \lambda_{max}(NTK)`.

  Args:
    kdd: The analytic or empirical NTK of (train_x, train_x).
    num_outputs: The number of outputs of the neural network. If `kdd` is the
      analytic NTK, `num_outputs` must be provided. Otherwise leave
      `num_outputs=-1`; the output dimension is already folded into the shape
      of the flattened empirical `kdd`.
    eps: A float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """

    if kdd.ndim not in [2, 4]:
        raise ValueError('`kdd` must be a 2d or 4d tensor.')
    if kdd.ndim == 2 and num_outputs == -1:
        raise ValueError(
            '`num_outputs` must be provided for theoretical kernel.')
    if kdd.ndim == 2:
        factor = kdd.shape[0] * num_outputs
    else:
        kdd = empirical.flatten_features(kdd)
        factor = kdd.shape[0]
    if kdd.shape[0] != kdd.shape[1]:
        raise ValueError('`kdd` must be a square matrix.')
    if _is_on_cpu(kdd):
        max_eva = osp.linalg.eigvalsh(kdd,
                                      eigvals=(kdd.shape[0] - 1,
                                               kdd.shape[0] - 1))[-1]
    else:
        max_eva = np.linalg.eigvalsh(kdd)[-1]
    lr = 2 * factor / (max_eva + eps)
    return lr
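
As a quick sanity check of the bound returned above, the sketch below applies the same formula to a small made-up 2d (analytic-style) kernel and verifies that the gradient-descent operator on the residual, I - eta / (batch_size * num_outputs) * K, stops being a contraction exactly at the returned rate. The kernel `K` and `num_outputs` here are purely illustrative.

```python
import numpy as onp

# Hypothetical stand-in for a 2d analytic NTK on 4 training points.
K = onp.array([[2.0, 0.5, 0.1, 0.0],
               [0.5, 2.0, 0.3, 0.1],
               [0.1, 0.3, 2.0, 0.2],
               [0.0, 0.1, 0.2, 2.0]])
num_outputs = 3
batch_size = K.shape[0]

# Same bound as `max_learning_rate` computes for a 2d kernel:
# eta_max = 2 * batch_size * num_outputs / lambda_max(K).
lam_max = onp.linalg.eigvalsh(K)[-1]
eta_max = 2 * batch_size * num_outputs / lam_max

def spectral_radius(eta):
    # Largest |eigenvalue| of the GD operator acting on the residual f - y.
    op = onp.eye(batch_size) - eta / (batch_size * num_outputs) * K
    return onp.abs(onp.linalg.eigvalsh(op)).max()

# Just below eta_max the operator is a contraction; just above, it is not.
assert spectral_radius(0.99 * eta_max) < 1.0
assert spectral_radius(1.01 * eta_max) > 1.0
```
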
Example 2
def gradient_descent_mse(g_dd, y_train, g_td=None, diag_reg=0.):
    """Predicts the outcome of function space gradient descent training on MSE.

  Analytically solves for the continuous-time version of gradient descent.

  Uses the analytic solution for gradient descent on an MSE loss in function
  space detailed in [*] given a Neural Tangent Kernel over the dataset. Given
  NTKs, this function will return a function that predicts the time evolution
  for function space points at arbitrary times. Note that times are continuous
  and are measured in units of the learning rate so t = learning_rate * steps.

  [*] https://arxiv.org/abs/1806.07572

  Example:
    ```python
    >>> from neural_tangents import predict
    >>>
    >>> train_time = 1e-7
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> predict_fn = predict.gradient_descent_mse(g_dd, y_train, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>          train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim] or [n_train, n_train].
      In the latter case, the kernel is assumed to be block diagonal over the
      logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of targets for the
      training data.
    g_td: A Kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: g_td should have been created in the convention
      kernel_fn(x_test, x_train, params).
    diag_reg: A float, representing the strength of the regularization.

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs and has shape [n_train, output_dim], and t is a floating
      point time. predict(t, fx) returns an `np.ndarray` of predictions of
      shape [n_train, output_dim].

    If g_td is not None:
      If a test set Kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs and have shape [n_train, output_dim] and
      [n_test, output_dim] respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """

    g_dd = empirical.flatten_features(g_dd)

    normalization = y_train.size
    output_dimension = y_train.shape[-1]
    expm1_fn, inv_expm1_fn = (_make_expm1_fn(normalization),
                              _make_inv_expm1_fn(normalization))

    def fl(fx):
        """Flatten outputs."""
        return np.reshape(fx, (-1, ))

    def ufl(fx):
        """Unflatten outputs."""
        return np.reshape(fx, (-1, output_dimension))

    # Check to see whether the kernel has a logit dimension.
    if y_train.size > g_dd.shape[-1]:
        out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
        if ragged or out_dim != y_train.shape[-1]:
            raise ValueError('y_train and g_dd have incompatible shapes.')
        fl = lambda x: x
        ufl = lambda x: x

    g_dd_plus_reg = _add_diagonal_regularizer(g_dd, diag_reg)
    expm1_dot_vec, inv_expm1_dot_vec = _eigen_fns(g_dd_plus_reg,
                                                  (expm1_fn, inv_expm1_fn))

    if g_td is None:

        def train_predict(dt, fx=0.0):
            gx_train = fl(fx - y_train)
            dgx = expm1_dot_vec(gx_train, dt)
            return ufl(dgx) + fx

        return train_predict

    g_td = empirical.flatten_features(g_td)

    def predict_using_kernel(dt, fx_train=0., fx_test=0.):
        gx_train = fl(fx_train - y_train)
        dgx = expm1_dot_vec(gx_train, dt)
        # Note: consider using a linalg solve instead of the eigen-inverse, e.g.
        # dfx = sp.linalg.solve(g_dd, dgx, sym_pos=True)
        dfx = inv_expm1_dot_vec(gx_train, dt)
        dfx = np.dot(g_td, dfx)
        return ufl(dgx) + fx_train, fx_test + ufl(dfx)

    return predict_using_kernel
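
For intuition about the closed form this predictor implements, the self-contained sketch below reproduces the train-set dynamics f_t - y = exp(-g_dd * t / n) (f_0 - y) with n = y_train.size, using `scipy.linalg.expm` in place of the eigendecomposition-based `expm1` helpers. The kernel, targets, and initial outputs are synthetic stand-ins, not values from the library.

```python
import numpy as onp
from scipy.linalg import expm

n_train, output_dim = 5, 2
rng = onp.random.RandomState(0)

# Synthetic positive-definite kernel with an explicit logit dimension.
A = rng.randn(n_train * output_dim, n_train * output_dim)
g_dd = A @ A.T + onp.eye(n_train * output_dim)

y_train = rng.randn(n_train, output_dim)
fx_0 = rng.randn(n_train, output_dim)

# Continuous-time GD on 1/(2 n) ||f - y||^2 gives, in function space,
#   f_t - y = exp(-g_dd * t / n) (f_0 - y),  with n = y_train.size.
n = y_train.size

def fx_at(t):
    gx = expm(-g_dd * t / n) @ (fx_0 - y_train).reshape(-1)
    return gx.reshape(n_train, output_dim) + y_train

# The train residual shrinks with time and vanishes as t grows large.
assert onp.linalg.norm(fx_at(3.0) - y_train) < onp.linalg.norm(fx_0 - y_train)
assert onp.allclose(fx_at(1e3), y_train, atol=1e-6)
```
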
Example 3
def momentum(g_dd, y_train, loss, learning_rate, g_td=None, momentum=0.9):
    r"""Predicts the outcome of function space training using momentum descent.

  Solves a continuous-time version of standard momentum instead of
  Nesterov momentum using an ODE solver.

  Solves the function space ODE for momentum with a given loss (detailed
  in [*]) given a Neural Tangent Kernel over the dataset. This function returns
  a triplet of functions that initialize state variables, predicts the time
  evolution for function space points at arbitrary times and retrieves the
  function-space outputs from the state. Note that times are continuous and are
  measured in units of the learning rate so that
  t = \sqrt(learning_rate) * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> train_time = 1e-7
    >>> learning_rate = 1e-2
    >>>
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> init_fn, predict_fn, get_fn = predict.momentum(
    >>>                   g_dd, y_train, cross_entropy, learning_rate, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> lin_state = init_fn(fx_train_initial, fx_test_initial)
    >>> lin_state = predict_fn(lin_state, train_time)
    >>> fx_train_final, fx_test_final = get_fn(lin_state)
    ```

  Args:
    g_dd: Kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim].
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      labels. Note: the loss function should treat the batch and output
        dimensions symmetrically.
    learning_rate:  A float specifying the learning rate.
    g_td: Kernel relating training data with test data. Should be an
      `np.ndarray` of shape [n_test * output_dim, n_train * output_dim]. Note:
        g_td should have been created in the convention g_td = kernel_fn(x_test,
        x_train, params).
    momentum: float specifying the momentum.

  Returns:
    Functions to predict outputs after t = \sqrt(learning_rate) * steps of
    training. Generically three functions are returned, an init_fn that creates
    auxiliary velocity variables needed for optimization and packs them into
    a state variable, a predict_fn that computes the time-evolution of the state
    for some dt, and a get_fn that extracts the predictions from the state.

    If g_td is None:
      init_fn(fx_train): Takes a single `np.ndarray` of shape
        [n_train, output_dim] and returns an `np.ndarray` of shape
        [2 * n_train * output_dim].

      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.

      get_fn(state): Takes a state and returns an `np.ndarray` of shape
        [n_train, output_dim].

    If g_td is not None:
      init_fn(fx_train, fx_test): Takes two `np.ndarray`s of shape
        [n_train, output_dim] and [n_test, output_dim] respectively. Returns a
        tuple of an int giving 2 * n_train * output_dim and an `np.ndarray` of
        shape [2 * (n_train + n_test) * output_dim].

      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.

      get_fn(state): Takes a state and returns two `np.ndarray`s of shape
        [n_train, output_dim] and [n_test, output_dim] respectively.
  """
    output_dimension = y_train.shape[-1]

    g_dd = empirical.flatten_features(g_dd)

    momentum = (momentum - 1.0) / np.sqrt(learning_rate)

    def fl(fx):
        """Flatten outputs."""
        return np.reshape(fx, (-1, ))

    def ufl(fx):
        """Unflatten outputs."""
        return np.reshape(fx, (-1, output_dimension))

    # These functions are used inside the integrator only if the kernel is
    # diagonal over the logits.
    ifl = lambda x: x
    iufl = lambda x: x

    # Check to see whether the kernel has a logit dimension.
    if y_train.size > g_dd.shape[-1]:
        out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
        if ragged or out_dim != y_train.shape[-1]:
            raise ValueError('y_train and g_dd have incompatible shapes.')
        ifl = fl
        iufl = ufl

    y_train = np.reshape(y_train, (-1))
    grad_loss = grad(functools.partial(loss, y_hat=y_train))

    if g_td is None:

        def dr_dt(unused_t, r):
            fx, qx = np.split(r, 2)
            dfx = qx
            dqx = momentum * qx - ifl(np.dot(g_dd, iufl(grad_loss(fx))))
            return np.concatenate((dfx, dqx), axis=0)

        def init_fn(fx_train=0.):
            fx_train = fl(fx_train)
            qx_train = np.zeros_like(fx_train)
            return np.concatenate((fx_train, qx_train), axis=0)

        def predict_fn(state, dt):
            solver = ode(dr_dt).set_integrator('dopri5')
            solver.set_initial_value(state, 0)
            solver.integrate(dt)

            return solver.y

        def get_fn(state):
            return ufl(np.split(state, 2)[0])

    else:
        g_td = empirical.flatten_features(g_td)

        def dr_dt(unused_t, r, train_size):
            train, test = r[:train_size], r[train_size:]
            fx_train, qx_train = np.split(train, 2)
            _, qx_test = np.split(test, 2)
            dfx_train = qx_train
            dqx_train = \
                momentum * qx_train - ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
            dfx_test = qx_test
            dqx_test = \
                momentum * qx_test - ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
            return np.concatenate((dfx_train, dqx_train, dfx_test, dqx_test),
                                  axis=0)

        def init_fn(fx_train=0., fx_test=0.):
            train_size = fx_train.shape[0]
            fx_train, fx_test = fl(fx_train), fl(fx_test)
            qx_train = np.zeros_like(fx_train)
            qx_test = np.zeros_like(fx_test)
            return (2 * train_size * output_dimension,
                    np.concatenate((fx_train, qx_train, fx_test, qx_test),
                                   axis=0))

        def predict_fn(state, dt):
            train_size, state = state
            solver = ode(dr_dt).set_integrator('dopri5')
            solver.set_initial_value(state, 0).set_f_params(train_size)
            solver.integrate(dt)

            return train_size, solver.y

        def get_fn(state):
            train_size, state = state
            train, test = state[:train_size], state[train_size:]
            return ufl(np.split(train, 2)[0]), ufl(np.split(test, 2)[0])

    return init_fn, predict_fn, get_fn
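
The sketch below integrates the same second-order system as `dr_dt` with scipy's 'dopri5' solver, for a made-up 3x3 kernel and an MSE loss whose gradient is written out by hand. The variable `mass` plays the role of the `momentum` argument before the reparameterization above; the kernel, data, and time horizon are illustrative assumptions only.

```python
import numpy as onp
from scipy.integrate import ode

# Toy setup: 3 training points, a single output, MSE loss, made-up kernel.
rng = onp.random.RandomState(1)
A = rng.randn(3, 3)
g_dd = A @ A.T + onp.eye(3)
y_train = rng.randn(3)
fx_0 = rng.randn(3)

learning_rate, mass = 1e-2, 0.9
gamma = (mass - 1.0) / onp.sqrt(learning_rate)  # same reparameterization as above

def grad_loss(fx):
    # Gradient of the MSE loss 1/(2 n) ||fx - y_train||^2.
    return (fx - y_train) / y_train.size

def dr_dt(_, r):
    # State r = (fx, qx), mirroring the structure used by `momentum` above.
    fx, qx = onp.split(r, 2)
    return onp.concatenate((qx, gamma * qx - g_dd @ grad_loss(fx)))

solver = ode(dr_dt).set_integrator('dopri5', nsteps=10000)
solver.set_initial_value(onp.concatenate((fx_0, onp.zeros_like(fx_0))), 0.0)
solver.integrate(onp.sqrt(learning_rate) * 1000)  # t = sqrt(learning_rate) * steps

fx_t = onp.split(solver.y, 2)[0]
# The damped momentum dynamics drive the train outputs toward the targets.
assert onp.linalg.norm(fx_t - y_train) < onp.linalg.norm(fx_0 - y_train)
```
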
Example 4
def gradient_descent(g_dd, y_train, loss, g_td=None):
    """Predicts the outcome of function space gradient descent training on `loss`.

  Solves for continuous-time gradient descent using an ODE solver.

  Solves the function space ODE for continuous gradient descent with a given
  loss (detailed in [*]) given a Neural Tangent Kernel over the dataset. This
  function returns a function that predicts the time evolution for function
  space points at arbitrary times. Note that times are continuous and are
  measured in units of the learning rate so that t = learning_rate * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> from jax.experimental import stax
    >>> from neural_tangents import predict
    >>>
    >>> train_time = 1e-7
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(
    >>>     g_dd, y_train, cross_entropy, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>     train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A Kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim] or [n_train, n_train].
      In the latter case it is assumed that the kernel is block diagonal over
      the logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      targets. Note: the loss function should treat the batch and output
        dimensions symmetrically.
    g_td: A Kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: g_td should have been created in the convention
        kernel_fn(x_test, x_train, params).

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs and has shape [n_train, output_dim], and t is a floating
      point time. predict(t, fx) returns an `np.ndarray` of predictions of
      shape [n_train, output_dim].

    If g_td is not None:
      If a test set Kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs and have shape [n_train, output_dim] and
      [n_test, output_dim] respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """

    output_dimension = y_train.shape[-1]

    g_dd = empirical.flatten_features(g_dd)

    def fl(fx):
        """Flatten outputs."""
        return np.reshape(fx, (-1, ))

    def ufl(fx):
        """Unflatten outputs."""
        return np.reshape(fx, (-1, output_dimension))

    # These functions are used inside the integrator only if the kernel is
    # diagonal over the logits.
    ifl = lambda x: x
    iufl = lambda x: x

    # Check to see whether the kernel has a logit dimension.
    if y_train.size > g_dd.shape[-1]:
        out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
        if ragged or out_dim != y_train.shape[-1]:
            raise ValueError('y_train and g_dd have incompatible shapes.')
        ifl = fl
        iufl = ufl

    y_train = np.reshape(y_train, (-1))
    grad_loss = grad(functools.partial(loss, y_hat=y_train))

    if g_td is None:
        dfx_dt = lambda unused_t, fx: -ifl(np.dot(g_dd, iufl(grad_loss(fx))))

        def predict(dt, fx=0.):
            r = ode(dfx_dt).set_integrator('dopri5')
            r.set_initial_value(fl(fx), 0)
            r.integrate(dt)

            return ufl(r.y)
    else:
        g_td = empirical.flatten_features(g_td)

        def dfx_dt(unused_t, fx, train_size):
            fx_train = fx[:train_size]
            dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
            dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
            return np.concatenate((dfx_train, dfx_test), axis=0)

        def predict(dt, fx_train=0., fx_test=0.):
            r = ode(dfx_dt).set_integrator('dopri5')

            fx = fl(np.concatenate((fx_train, fx_test), axis=0))
            train_size, output_dim = fx_train.shape
            r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
            r.integrate(dt)
            fx = ufl(r.y)

            return fx[:train_size], fx[train_size:]

    return predict
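
Below is a self-contained sketch of the ODE df/dt = -g_dd * grad_loss(f) that this function integrates, using the block-diagonal [n_train, n_train] kernel convention and a hand-coded softmax cross-entropy gradient. The toy kernel, one-hot targets, and time horizon are made up for illustration.

```python
import numpy as onp
from scipy.integrate import ode

# Toy setup: 4 training points, 3 classes, one-hot targets, and a made-up
# [n_train, n_train] kernel (the block-diagonal-over-logits convention).
rng = onp.random.RandomState(2)
n_train, n_classes = 4, 3
A = rng.randn(n_train, n_train)
g_dd = A @ A.T + onp.eye(n_train)
y_train = onp.eye(n_classes)[rng.randint(0, n_classes, size=n_train)]
fx_0 = rng.randn(n_train, n_classes)

def log_softmax(fx):
    fx = fx - fx.max(axis=-1, keepdims=True)
    return fx - onp.log(onp.exp(fx).sum(axis=-1, keepdims=True))

def loss(fx):
    # Cross entropy treating the batch and output dimensions symmetrically.
    return -onp.mean(log_softmax(fx) * y_train)

def grad_loss(fx):
    # Hand-coded gradient of `loss` for one-hot targets.
    return (onp.exp(log_softmax(fx)) - y_train) / y_train.size

def dfx_dt(_, fx_flat):
    fx = fx_flat.reshape(n_train, n_classes)
    return -(g_dd @ grad_loss(fx)).reshape(-1)

solver = ode(dfx_dt).set_integrator('dopri5')
solver.set_initial_value(fx_0.reshape(-1), 0.0)
solver.integrate(50.0)  # t = learning_rate * steps

fx_t = solver.y.reshape(n_train, n_classes)
# Gradient flow in function space monotonically decreases the training loss.
assert loss(fx_t) < loss(fx_0)
```
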
Example 5
def analytic_mse(g_dd, y_train, g_td=None):
    """Predicts the outcome of function space training with an MSE loss.

  Uses the analytic solution for gradient descent on an MSE loss in function
  space detailed in [*] given a Neural Tangent Kernel over the dataset. Given
  NTKs, this function will return a function that predicts the time evolution
  for function space points at arbitrary times. Note that times are continuous
  and are measured in units of the learning rate so t = learning_rate * steps.

  [*] https://arxiv.org/abs/1806.07572

  Example:
    ```python
    >>> train_time = 1e-7
    >>> ker_fun = empirical(f)
    >>> g_td = ker_fun(x_test, x_train, params)
    >>>
    >>> predict_fn = analytic_mse(g_dd, y_train, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>          train_time, fx_train_initial, fx_test_initial)
    ```

  Args:
    g_dd: A kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim] or [n_train, n_train].
      In the latter case, the kernel is assumed to be block diagonal over the
      logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of targets for the
      training data.
    g_td: A Kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train].
      Note: g_td should have been created in the convention ker_fun(x_test,
        x_train, params).

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs and has shape [n_train, output_dim], and t is a floating
      point time. predict(t, fx) returns an `np.ndarray` of predictions of
      shape [n_train, output_dim].

    If g_td is not None:
      If a test set Kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs and have shape [n_train, output_dim] and
      [n_test, output_dim] respectively, and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """

    g_dd = _canonicalize_kernel_to_ntk(g_dd)
    g_td = _canonicalize_kernel_to_ntk(g_td)

    g_dd = empirical.flatten_features(g_dd)

    # TODO(schsam): Eventually, we may want to handle non-symmetric kernels for
    # e.g. masking. Additionally, once JAX supports eigh on TPU, we probably want
    # to switch to JAX's eigh.
    if xla_bridge.get_backend().platform == 'tpu':
        eigh = np.onp.linalg.eigh
    else:
        eigh = np.linalg.eigh

    evals, evecs = eigh(g_dd)
    ievecs = np.transpose(evecs)

    normalization = y_train.size
    output_dimension = y_train.shape[-1]

    def fl(fx):
        """Flatten outputs."""
        return np.reshape(fx, (-1, ))

    def ufl(fx):
        """Unflatten outputs."""
        return np.reshape(fx, (-1, output_dimension))

    # Check to see whether the kernel has a logit dimension.
    if y_train.size > g_dd.shape[-1]:
        out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
        if ragged or out_dim != y_train.shape[-1]:
            raise ValueError('y_train and g_dd have incompatible shapes.')
        fl = lambda x: x
        ufl = lambda x: x

    def predict(gx, dt):
        gx_ = np.diag(np.exp(-evals * dt / normalization))
        gx_ = np.dot(evecs, gx_)
        gx_ = np.dot(gx_, ievecs)
        gx_ = np.dot(gx_, gx)
        return gx_

    if g_td is None:
        return lambda dt, fx=0.: ufl(predict(fl(fx - y_train), dt)) + y_train

    g_td = empirical.flatten_features(g_td)
    mevals = np.diag(1.0 / evals)
    inverse = np.dot(np.dot(evecs, mevals), ievecs)

    def predict_using_kernel(dt, fx_train=0., fx_test=0.):
        gx_train = fl(fx_train - y_train)
        dgx = predict(gx_train, dt) - gx_train
        dfx = np.dot(inverse, dgx)
        dfx = np.dot(g_td, dfx)
        return ufl(dgx) + fx_train, fx_test + ufl(dfx)

    return predict_using_kernel
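
To make the role of the eigendecomposition explicit: the propagator assembled inside `predict` is exactly the matrix exponential exp(-g_dd * dt / normalization), and `inverse` is g_dd^{-1}. A small numpy/scipy check on a made-up positive-definite kernel (the `g_dd`, `normalization`, and `dt` below are illustrative):

```python
import numpy as onp
from scipy.linalg import expm

# Made-up symmetric positive-definite kernel standing in for g_dd.
rng = onp.random.RandomState(3)
A = rng.randn(6, 6)
g_dd = A @ A.T + onp.eye(6)
normalization = 6.0
dt = 2.5

evals, evecs = onp.linalg.eigh(g_dd)
ievecs = evecs.T

# The propagator assembled in `predict` ...
propagator = evecs @ onp.diag(onp.exp(-evals * dt / normalization)) @ ievecs
# ... equals the matrix exponential exp(-g_dd * dt / normalization).
assert onp.allclose(propagator, expm(-g_dd * dt / normalization))

# Likewise, `inverse` built from the reciprocal eigenvalues is g_dd^{-1}.
inverse = evecs @ onp.diag(1.0 / evals) @ ievecs
assert onp.allclose(inverse, onp.linalg.inv(g_dd))
```
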