Example 1
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss. Note that `f` is a late-bound closure: it is
        # defined inside the `lr_factor` loop below before `loss` is ever called.
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            return loss(get_params(opt_state), x_train)

        steps = 20

        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                self.assertLess(loss_ratio, 0.1)
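A more self-contained version of the same `max_learning_rate` check, written against plain JAX. This is a minimal sketch: the toy network, shapes and variable names are illustrative, and only the `neural_tangents` calls already exercised above (`empirical_ntk_fn`, `predict.max_learning_rate`) plus standard JAX APIs are assumed.

import jax.numpy as jnp
from jax import grad, jit, random
from jax.experimental import optimizers  # `jax.example_libraries.optimizers` on newer JAX
from neural_tangents import empirical_ntk_fn, predict, stax

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
init_fn, f, _ = stax.serial(stax.Dense(64, 1., 0.05), stax.Erf(),
                            stax.Dense(1, 1., 0.05))
x_train = random.normal(k1, (32, 8))
y_train = random.bernoulli(k2, shape=(32, 1)).astype(jnp.float32)
_, params = init_fn(k3, x_train.shape)

# Empirical NTK on the training set, then the critical SGD step size.
g_dd = empirical_ntk_fn(f)(x_train, None, params)
lr_max = predict.max_learning_rate(g_dd, y_train_size=y_train.size)

# Any step size safely below `lr_max` should make the MSE loss shrink.
loss = lambda p, x: 0.5 * jnp.mean((f(p, x) - y_train) ** 2)
grad_loss = jit(grad(loss))
opt_init, opt_update, get_params = optimizers.sgd(0.5 * lr_max)
state = opt_init(params)
for i in range(20):
    state = opt_update(i, grad_loss(get_params(state), x_train), state)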
Example 2
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                       stax.Dense(10, 1., 0.05))

    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    _, params = init_fn(key, (1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
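As a side note on the `nt.predict.gradient_descent_mse` predictor built above, it can be queried in two other ways. A hedged sketch, reusing `predictor`, `fx_train`, `fx_test` and `g_td` from this example:

# Fully-trained (infinite-time) limit of the same MSE dynamics in function
# space; `t=None` is treated as infinite training time by the predictor.
fx_train_inf, fx_test_inf = predictor(None, fx_train, fx_test, g_td)

# Train-set predictions only: omit the test-side arguments entirely.
fx_train_only = predictor(FLAGS.train_time, fx_train)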
Example 3
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        ensemble_key = tf_random_split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(
            lambda state, x, y: grad(loss)(get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                mean_emp = np.mean(ensemble_fx, axis=0)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
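The analytic half of this comparison is also useful on its own. A minimal sketch of querying `predict.gradient_descent_mse_ensemble` directly, assuming the same `kernel_fn`, `x_train`, `y_train` and `x_test` as in the test; the step count is illustrative:

predict_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train,
                                                   learning_rate=0.1)

# Mean and covariance of the Gaussian over f(x_test) after 1000 steps of MSE
# gradient descent, and in the fully-trained limit (`t=None`).
ntk_t = predict_fn(t=1000., x_test=x_test, get='ntk', compute_cov=True)
ntk_inf = predict_fn(t=None, x_test=x_test, get='ntk', compute_cov=True)
mean, cov = ntk_inf.mean, ntk_inf.covariance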
Example 4
    def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                            fn_and_kernel, momentum, learning_rate, t, loss):
        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        g_dd = ntk(x_train, None, 'ntk')
        g_td = ntk(x_test, x_train, 'ntk')

        # Regress to an MSE loss.
        loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

        trace_axes = () if g_dd.ndim == 4 else (-1, )
        if loss == 'mse_analytic':
            if momentum is not None:
                raise absltest.SkipTest(momentum)
            predictor = predict.gradient_descent_mse(
                g_dd,
                y_train,
                learning_rate=learning_rate,
                trace_axes=trace_axes)
        elif loss == 'mse':
            predictor = predict.gradient_descent(loss_fn,
                                                 g_dd,
                                                 y_train,
                                                 learning_rate=learning_rate,
                                                 momentum=momentum,
                                                 trace_axes=trace_axes)
        else:
            raise NotImplementedError(loss)

        predictor = jit(predictor)

        fx_train_0 = f(params, x_train)
        fx_test_0 = f(params, x_test)

        self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
        self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
        if loss == 'mse_analytic':
            self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td,
                                y_train)

        if momentum is None:
            opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        else:
            opt_init, opt_update, get_params = optimizers.momentum(
                learning_rate, momentum)

        opt_state = opt_init(params)
        for i in range(t):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

        params = get_params(opt_state)

        fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
        fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

        self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
        self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
Example 5
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    _, params = init_fn(key, (1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = momentum(FLAGS.learning_rate, 0.9)
    # opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)

    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # momentum = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9)
    # momentum_lin = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # x = np.asarray(x)
        # y = np.asarray(y)

        # momentum.apply_gradients((grad_loss(params, x, y), params))
        # momentum.apply_gradients((grad_loss_lin(params_lin, x, y), params_lin))

        if i % steps_per_epoch == 0:
            print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                      loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
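The key API in this example is `nt.linearize`. A minimal sketch of what it provides, independent of any training loop; the toy shapes are illustrative, and only `nt.linearize(apply_fn, params)` as used above is assumed:

import jax.numpy as jnp
from jax import random
import neural_tangents as nt
from neural_tangents import stax

key = random.PRNGKey(0)
init_fn, f, _ = stax.serial(stax.Dense(32, 1., 0.05), stax.Erf(),
                            stax.Dense(10, 1., 0.05))
_, params = init_fn(key, (-1, 784))

# f_lin(p, x) is the first-order Taylor expansion of f around `params`:
# f_lin(p, x) = f(params, x) + df/dparams(params, x) . (p - params).
f_lin = nt.linearize(f, params)

x = random.normal(key, (3, 784))
# At the expansion point the linearized and original networks agree exactly.
assert jnp.allclose(f(params, x), f_lin(params, x))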
Example 6
def gradient_descent(
    loss: Callable[[np.ndarray, np.ndarray], float],
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    momentum: Optional[float] = None,
    trace_axes: Axes = (-1, )
) -> Callable[
    [ArrayOrScalar, Union[ArrayOrScalar, ODEState], ArrayOrScalar,
     Optional[np.ndarray]],
    Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]]:
    r"""Predicts the outcome of function space training using gradient descent.

  Uses an ODE solver. If `momentum != None`, solves a continuous-time version of
  gradient descent with momentum (note: this case uses standard momentum as
  opposed to Nesterov momentum).

  Solves the function space ODE for [momentum] gradient descent with a given
  `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s]
  at arbitrary time[s] (step[s]) `t`. Note that for gradient descent
  `absolute_time = learning_rate * t` and the scales of the learning rate and
  query step[s] `t` are interchangeable. However, the momentum gradient descent
  ODE is solved in the units of `learning_rate**0.5`, and therefore
  `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and
  training time[s] (step[s]) `t` scales are not interchangeable.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>>
    >>> t = 1e-7
    >>> learning_rate = 1e-2
    >>> momentum = 0.9
    >>>
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
    >>>                                       y_train, learning_rate, momentum)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    loss:
      a loss function whose signature is `loss(f(x_train), y_train)`. Note:
      the loss function should treat the batch and output dimensions
      symmetrically.
    k_train_train:
      kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train:
      targets for the training data.
    learning_rate:
      learning rate, step size.
    momentum:
      momentum scalar.
    trace_axes:
      `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to
      save space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield the most
      accurate result.

  Returns:
    A function that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
    _, odd, _, _ = _get_axes(k_train_train)
    trace_axes = utils.canonicalize_axis(trace_axes, y_train)
    non_t_axes = tuple(a for a in range(y_train.ndim) if a not in trace_axes)
    last_t_axes = range(-len(trace_axes), 0)

    dtype = k_train_train.dtype
    grad_loss = grad(lambda fx: loss(fx, y_train))

    if momentum is not None:
        learning_rate **= 0.5
        momentum = (momentum - 1.0) / learning_rate

    def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
        if isinstance(fx_train_or_state_0, ODEState):
            fx_train_0 = fx_train_or_state_0.fx_train
            fx_test_0 = fx_train_or_state_0.fx_test
            qx_train_0 = fx_train_or_state_0.qx_train
            qx_test_0 = fx_train_or_state_0.qx_test
        else:
            fx_train_0 = fx_train_or_state_0
            qx_train_0 = qx_test_0 = None

        if fx_train_0 is None:
            fx_train_0 = np.zeros_like(y_train, dtype)
        else:
            fx_train_0 = np.broadcast_to(fx_train_0, y_train.shape)

        if fx_test_0 is not None:
            fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

        if momentum is None:
            if qx_train_0 is not None or qx_test_0 is not None:
                raise ValueError('Got passed momentum state variables, while '
                                 '`momentum is None`.')
        else:
            qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None
                          else np.broadcast_to(qx_train_0, y_train.shape))
            qx_test_0 = (None if fx_test_0 is None else
                         (np.zeros(fx_test_shape, dtype) if qx_test_0 is None
                          else np.broadcast_to(qx_test_0, fx_test_shape)))

        return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count

    def get_dstate_dt(k_test_train):
        def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
            fx_train_t, fx_test_t, qx_train_t, qx_test_t = (state_t.fx_train,
                                                            state_t.fx_test,
                                                            state_t.qx_train,
                                                            state_t.qx_test)

            dy_df_t = grad_loss(fx_train_t)

            # Note: from here on the local `fx_*` / `qx_*` variables are reused
            # to build the components of d(state)/dt returned by this function.
            fx_train_t = -np.moveaxis(
                np.tensordot(k_train_train, dy_df_t,
                             (odd, non_t_axes)), last_t_axes, trace_axes)
            if fx_test_t is not None:
                fx_test_t = -np.moveaxis(
                    np.tensordot(k_test_train, dy_df_t,
                                 (odd, non_t_axes)), last_t_axes, trace_axes)

            if momentum is None:
                return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

            fx_train_t += momentum * qx_train_t
            if qx_test_t is not None:
                fx_test_t += momentum * qx_test_t

            return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

        return dstate_dt

    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state). Note that only the
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily append an [almost] `0` at the start.
        t0 = np.where(t[0] == 0, np.full((1, ), -1e-24, t.dtype),
                      np.zeros((1, ), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t

    return predict_fn
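Two usage details from the docstring above, shown as a hedged sketch. It continues the names from the docstring's own example; `cross_entropy`, `k_train_train`, `k_test_train`, `y_train`, `fx_train_0` and `fx_test_0` are assumed to already exist.

predict_fn = gradient_descent(cross_entropy, k_train_train, y_train,
                              learning_rate=1e-2, momentum=0.9)

# 1. Train-set-only predictions: omit `fx_test_0` and `k_test_train`.
fx_train_t = predict_fn(t=100., fx_train_or_state_0=fx_train_0)

# 2. Operating on the full ODE state (momentum adds auxiliary `qx` variables).
#    Passing an `ODEState` in returns an `ODEState` out, so the dynamics can be
#    paused and resumed from a saved state.
state_0 = ODEState(fx_train_0, fx_test_0)  # momenta default to zeros
state_t = predict_fn(100., state_0, None, k_test_train)
state_resumed = predict_fn(100., state_t, None, k_test_train)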