Example #1
# Imports assumed by this example; the original file header is not shown.
# Module paths follow the early neural-tangents / jax.experimental layout.
from jax import grad, jit, random
from jax.experimental import optimizers, stax
import jax.numpy as np

from neural_tangents import layers, tangents

import datasets  # example-local data helpers (assumed)
import util      # example-local summary helpers (assumed)


def main(unused_argv):
  # Load the training and test data.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

  # Build the network
  init_fn, f = stax.serial(
      layers.Dense(2048),
      stax.Tanh,
      layers.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = tangents.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
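
The hyper-parameters above are read from absl flags defined elsewhere in the script. A minimal sketch of the definitions this example assumes (the flag names come from the code; the defaults here are placeholders, not the original values):

from absl import app, flags

flags.DEFINE_float('learning_rate', 1.0, 'Learning rate for momentum SGD.')
flags.DEFINE_integer('batch_size', 128, 'Minibatch size.')
flags.DEFINE_integer('train_epochs', 10, 'Number of passes over the training set.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)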
Example #2

def inner_optimization_lin(params_init, x_train, y_train, n_inner_step):
    # Train the linearization of `f` about `params_init` for `n_inner_step`
    # steps of the inner optimizer and return the final parameters and loss.
    # (`f`, `loss`, and the `inner_opt_*` functions are captured from the
    # enclosing scope.)
    f_lin = tangents.linearize(f, params_init)
    grad_loss_lin = jit(grad(lambda p, x, y: loss(f_lin(p, x), y)))
    param_loss_lin = jit(lambda p, x, y: loss(f_lin(p, x), y))

    state = inner_opt_init(params_init)
    for i in range(n_inner_step):
        p = inner_get_params(state)
        g = grad_loss_lin(p, x_train, y_train)
        state = inner_opt_update(i, g, state)
    p = inner_get_params(state)
    l_train = param_loss_lin(p, x_train, y_train)
    return p, l_train, f_lin, param_loss_lin
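
For reference, tangents.linearize returns the first-order Taylor expansion of f about params_init: f_lin(p, x) = f(p0, x) + J_f(p0) (p - p0). A minimal sketch of that expansion in terms of jax.jvp (not the library's actual implementation):

import jax

def linearize_sketch(f, params0):
    def f_lin(params, x):
        # Tangent vector: displacement of params from the expansion point.
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        # jvp evaluates f at params0 and the Jacobian-vector product with dparams.
        fx0, dfx = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return fx0 + dfx
    return f_lin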
Example #3
  def testLinearization(self, shape):
    def f(w, x):
      return np.dot(w, x)

    key = random.PRNGKey(0)
    key, split = random.split(key)
    w0 = random.normal(split, shape)
    key, split = random.split(key)
    x = random.normal(split, (shape[-1],))

    f_lin = tangents.linearize(f, w0)

    # f is linear in w, so the linearization about w0 is exact: f_lin should
    # match f for arbitrary w, not just in a neighborhood of w0.
    for _ in range(10):
      key, split = random.split(key)
      w = random.normal(split, shape)
      self.assertAllClose(f(w, x), f_lin(w, x), True)
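
The test can demand exact agreement because f is linear in w. For a nonlinear function the match is only first-order near w0; a hypothetical check along the same lines (reusing w0, x, shape, and key from the test above):

def f_tanh(w, x):
  return np.tanh(np.dot(w, x))

f_tanh_lin = tangents.linearize(f_tanh, w0)

key, split = random.split(key)
dw = random.normal(split, shape)
# The gap should shrink roughly quadratically as the perturbation shrinks.
for eps in [1e-1, 1e-2, 1e-3]:
  gap = np.max(np.abs(f_tanh(w0 + eps * dw, x) - f_tanh_lin(w0 + eps * dw, x)))
  print(eps, gap)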
Example #4

ntk_frequency = 50
plot_update_frequency = 100
for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
                          total=args.n_train_task // args.task_batch_size):
    aux = dict()
    # Periodically evaluate the empirical NTK and log its rank and spectrum.
    if i == 0 or (i + 1) % (args.n_train_task // args.task_batch_size //
                            ntk_frequency) == 0:
        ntk = tangents.ntk(f, batch_size=100)(outer_get_params(outer_state),
                                              task_eval['x_train'])
        aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)
        f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
        ntk_lin = tangents.ntk(f_lin, batch_size=100)(
            outer_get_params(outer_state_lin), task_eval['x_train'])
        aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)
        log.append([(key, aux[key]) for key in win_rank_eval_keys])

        # spectrum
        evals, evecs = onp.linalg.eigh(ntk)  # eigenvectors are columns
        for j in range(len(evals)):
            aux[f'ntk_spectrum_{j}_eval'] = evals[j]
        log.append([(key, aux[key]) for key in win_spectrum_eval_keys])

        evals = evals.clip(min=1e-10)
        ind = onp.arange(len(evals)) + 1  # +1 because we are taking log
        ind = ind[::-1]
        X = onp.stack([ind, evals], axis=1)
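
For reference, the empirical NTK that tangents.ntk evaluates is Theta(x1, x2) = J(x1) J(x2)^T, where J is the Jacobian of the network outputs with respect to the parameters. A minimal unbatched sketch (assuming f(params, x) returns an array of outputs; not the library's batched implementation):

import jax

def ntk_sketch(f, params, x1, x2):
    def f_flat(p, x):
        return np.ravel(f(p, x))  # flatten (batch, outputs) into one axis

    # Jacobians w.r.t. every parameter leaf: shape (n_outputs, *leaf.shape).
    j1 = jax.jacobian(f_flat)(params, x1)
    j2 = jax.jacobian(f_flat)(params, x2)

    def stack(j, n):
        return np.concatenate(
            [np.reshape(leaf, (n, -1)) for leaf in jax.tree_util.tree_leaves(j)],
            axis=1)

    J1 = stack(j1, f_flat(params, x1).shape[0])
    J2 = stack(j2, f_flat(params, x2).shape[0])
    return np.dot(J1, J2.T)  # kernel matrix, shape (n1, n2)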
Example #5

    # NOTE: this snippet begins mid-call; only the keyword arguments were
    # preserved, so the constructor name `mlp` below is an assumption.
    net_init, f = mlp(n_hidden_unit=args.n_hidden_unit,
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)

    # initialize network
    key = random.PRNGKey(run)
    _, params = net_init(key, (-1, 1))

    # data
    task = sinusoid_task(n_support=args.n_support)
    x_train, y_train = task['x_train'], task['y_train']
    x_test, y_test = task['x_test'], task['y_test']

    # linearized network
    f_lin = tangents.linearize(f, params)

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train_ana_init = f(params, x_train)
    fx_test_ana_init = f(params, x_test)

    # optimizer for f and f_lin
    if args.inner_opt_alg == 'sgd':
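
For reference, analytic_mse_predictor above solves the continuous-time gradient-flow ODE on the MSE loss in function space, df/dt = -Theta (f - y). A minimal sketch of that closed form via an eigendecomposition of the train-train kernel (assuming scalar outputs and an invertible g_dd; not the library's implementation):

import numpy as onp

def analytic_mse_predictor_sketch(g_dd, y_train, g_td):
    evals, evecs = onp.linalg.eigh(g_dd)

    def predict(fx_train_0, fx_test_0, t):
        # Train points: f_t = y + exp(-t g_dd) (f_0 - y).
        expm = (evecs * onp.exp(-evals * t)) @ evecs.T
        fx_train_t = y_train + expm @ (fx_train_0 - y_train)
        # Test points: f_t = f_0 + g_td g_dd^{-1} (I - exp(-t g_dd)) (y - f_0_train).
        inv_expm = (evecs * ((1.0 - onp.exp(-evals * t)) / evals)) @ evecs.T
        fx_test_t = fx_test_0 + g_td @ inv_expm @ (y_train - fx_train_0)
        return fx_train_t, fx_test_t

    return predict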