def testNTKBatched(self, shape, out_logits):

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, shape)
    data_other = random.normal(other_split, shape)

    key, w_split, b_split = random.split(key, 3)
    params = (random.normal(w_split, (shape[-1], out_logits)),
              random.normal(b_split, (out_logits,)))

    def f(params, x):
      w, b = params
      return np.dot(x, w) / shape[-1] + b

    g_fn = tangents.ntk(f)
    g_batched_fn = tangents.ntk(f, batch_size=2)
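    # Batching evaluates the kernel in blocks of two examples; the result
    # should match the unbatched kernel up to numerical precision.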

    g = g_fn(params, data_self)
    g_batched = g_batched_fn(params, data_self)
    self.assertAllClose(g, g_batched, check_dtypes=False)

    g = g_fn(params, data_other, data_self)
    g_batched = g_batched_fn(params, data_other, data_self)
    self.assertAllClose(g, g_batched, check_dtypes=False)
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network.
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)
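    # predictor maps the network's initial train/test outputs and a training
    # time to the linear model's outputs at that time under the MSE loss.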

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    params = get_params(state)  # parameters after the final update

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
  def testNTKAgainstDirect(self, shape, out_logits):

    # Reference NTK: contract the full Jacobians of f over every parameter
    # leaf, i.e. J(x1) @ J(x2)^T summed across the parameter pytree.
    def sum_and_contract(j1, j2):
      def contract(x, y):
        param_count = int(np.prod(x.shape[2:]))
        x = np.reshape(x, (-1, param_count))
        y = np.reshape(y, (-1, param_count))
        return np.dot(x, np.transpose(y))

      return tree_reduce(operator.add, tree_multimap(contract, j1, j2))

    def ntk_direct(f, params, x1, x2):
      jac_fn = jacobian(f)
      j1 = jac_fn(params, x1)

      if x2 is None:
        j2 = j1
      else:
        j2 = jac_fn(params, x2)

      return sum_and_contract(j1, j2)

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, shape)
    data_other = random.normal(other_split, shape)

    key, w_split, b_split = random.split(key, 3)
    params = (random.normal(w_split, (shape[-1], out_logits)),
              random.normal(b_split, (out_logits,)))

    def f(params, x):
      w, b = params
      return np.dot(x, w) / shape[-1] + b

    g_fn = tangents.ntk(f)

    g = g_fn(params, data_self)
    g_direct = ntk_direct(f, params, data_self, data_self)
    self.assertAllClose(g, g_direct, check_dtypes=False)

    g = g_fn(params, data_other, data_self)
    g_direct = ntk_direct(f, params, data_other, data_self)
    self.assertAllClose(g, g_direct, check_dtypes=False)
                      n_query=args.n_query)
else:
    raise ValueError('unsupported task configuration')

ntk_frequency = 50
plot_update_frequency = 100
for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
                          total=args.n_train_task // args.task_batch_size):
    aux = dict()
    # Periodically evaluate the empirical NTK and track its rank.
    if i == 0 or (i + 1) % (args.n_train_task // args.task_batch_size //
                            ntk_frequency) == 0:
        ntk = tangents.ntk(f, batch_size=100)(outer_get_params(outer_state),
                                              task_eval['x_train'])
        aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)
        f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
        ntk_lin = tangents.ntk(f_lin, batch_size=100)(
            outer_get_params(outer_state_lin), task_eval['x_train'])
        aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)
        log.append([(key, aux[key]) for key in win_rank_eval_keys])

        # spectrum
        evals, evecs = onp.linalg.eigh(ntk)  # eigenvectors are columns
        for j in range(len(evals)):
            aux[f'ntk_spectrum_{j}_eval'] = evals[j]
        log.append([(key, aux[key]) for key in win_spectrum_eval_keys])

        evals = evals.clip(min=1e-10)
        ind = onp.arange(len(evals)) + 1  # +1 because we are taking log
  def testNTKMomentumPrediction(self, shape, out_logits):

    key = random.PRNGKey(1)

    key, split = random.split(key)
    data_train = random.normal(split, shape)

    key, split = random.split(key)
    label_ids = random.randint(split, (shape[0],), 0, out_logits)
    data_labels = np.eye(out_logits)[label_ids]

    key, split = random.split(key)
    data_test = random.normal(split, shape)

    key, w_split, b_split = random.split(key, 3)
    params = (random.normal(w_split, (shape[-1], out_logits)),
              random.normal(b_split, (out_logits,)))

    def f(params, x):
      w, b = params
      return np.dot(x, w) / shape[-1] + b

    loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat) ** 2)
    grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

    theta = tangents.ntk(f)
    g_dd = theta(params, data_train)
    g_td = theta(params, data_test, data_train)

    step_size = 1.0
    train_time = 100.0
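    # Each discrete momentum step advances continuous training time by
    # sqrt(step_size), hence the step count below.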
    steps = int(train_time / np.sqrt(step_size))

    init_fn, predict_fn, get_fn = tangents.momentum_predictor(
        g_dd, data_labels, loss, step_size, g_td)
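    # init_fn packs the initial function values, predict_fn evolves them to a
    # given training time, and get_fn unpacks the train/test predictions.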

    opt_init, opt_update, get_params = momentum(step_size, 0.9)
    opt_state = opt_init(params)

    fx_initial_train = f(params, data_train)
    fx_initial_test = f(params, data_test)

    lin_state = init_fn(fx_initial_train, fx_initial_test)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

    params = get_params(opt_state)
    fx_train = f(params, data_train)
    fx_test = f(params, data_test)

    lin_state = predict_fn(lin_state, train_time)

    fx_pred_train, fx_pred_test = get_fn(lin_state)

    # Put errors in units of RMS distance of the function values during
    # optimization.
    fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
    fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

    fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
    fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

    self.assertAllClose(
        fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
    self.assertAllClose(
        fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
  def testNTKGDPrediction(self, shape, out_logits):

    key = random.PRNGKey(1)

    key, split = random.split(key)
    data_train = random.normal(split, shape)

    key, split = random.split(key)
    label_ids = random.randint(split, (shape[0],), 0, out_logits)
    data_labels = np.eye(out_logits)[label_ids]

    key, split = random.split(key)
    data_test = random.normal(split, shape)

    key, w_split, b_split = random.split(key, 3)
    params = (random.normal(w_split, (shape[-1], out_logits)),
              random.normal(b_split, (out_logits,)))

    def f(params, x):
      w, b = params
      return np.dot(x, w) / shape[-1] + b

    loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat) ** 2)
    grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

    theta = tangents.ntk(f)
    g_dd = theta(params, data_train)
    g_td = theta(params, data_test, data_train)

    predictor = tangents.gradient_descent_predictor(
        g_dd, data_labels, loss, g_td)
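    # The predictor evolves the train/test function values under gradient
    # descent on `loss` for the requested training time.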

    step_size = 1.0
    train_time = 100.0
    steps = int(train_time / step_size)

    opt_init, opt_update, get_params = opt.sgd(step_size)
    opt_state = opt_init(params)

    fx_initial_train = f(params, data_train)
    fx_initial_test = f(params, data_test)

    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, 0.0)

    # NOTE(schsam): I think at the moment stax always generates 32-bit results
    # since the weights are explicitly cast to float32.
    self.assertAllClose(fx_initial_train, fx_pred_train, False)
    self.assertAllClose(fx_initial_test, fx_pred_test, False)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

    params = get_params(opt_state)
    fx_train = f(params, data_train)
    fx_test = f(params, data_test)

    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, train_time)

    # Put errors in units of RMS distance of the function values during
    # optimization.
    fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
    fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

    fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
    fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

    self.assertAllClose(
        fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
    self.assertAllClose(
        fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
  def testNTKMSEPrediction(self, shape, out_logits):

    key = random.PRNGKey(0)

    key, split = random.split(key)
    data_train = random.normal(split, shape)

    key, split = random.split(key)
    data_labels = np.array(
        random.bernoulli(split, shape=(shape[0], out_logits)), np.float32)

    key, split = random.split(key)
    data_test = random.normal(split, shape)

    key, w_split, b_split = random.split(key, 3)
    params = (random.normal(w_split, (shape[-1], out_logits)),
              random.normal(b_split, (out_logits,)))

    def f(params, x):
      w, b = params
      return np.dot(x, w) / shape[-1] + b

    # Regress to an MSE loss.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - data_labels) ** 2)

    theta = tangents.ntk(f)
    g_dd = theta(params, data_train)
    g_td = theta(params, data_test, data_train)

    predictor = tangents.analytic_mse_predictor(g_dd, data_labels, g_td)
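    # For MSE loss the function-space dynamics admit a closed-form solution,
    # which is what analytic_mse_predictor evaluates.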

    step_size = 1.0
    train_time = 100.0
    steps = int(train_time / step_size)

    opt_init, opt_update, get_params = opt.sgd(step_size)
    opt_state = opt_init(params)

    fx_initial_train = f(params, data_train)
    fx_initial_test = f(params, data_test)

    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, 0.0)

    # NOTE(schsam): I think at the moment stax always generates 32-bit results
    # since the weights are explicitly cast to float32.
    self.assertAllClose(fx_initial_train, fx_pred_train, False)
    self.assertAllClose(fx_initial_test, fx_pred_test, False)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad(loss)(params, data_train), opt_state)

    params = get_params(opt_state)
    fx_train = f(params, data_train)
    fx_test = f(params, data_test)

    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, train_time)

    fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
    fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

    fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
    fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

    self.assertAllClose(
        fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
    self.assertAllClose(
        fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
                      norm=args.norm)

    # initialize network
    key = random.PRNGKey(run)
    _, params = net_init(key, (-1, 1))

    # data
    task = sinusoid_task(n_support=args.n_support)
    x_train, y_train = task['x_train'], task['y_train']
    x_test, y_test = task['x_test'], task['y_test']

    # linearized network
    f_lin = tangents.linearize(f, params)
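    # f_lin is the linearization (first-order Taylor expansion) of f around
    # the current params.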

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train_ana_init = f(params, x_train)
    fx_test_ana_init = f(params, x_test)

    # optimizer for f and f_lin
    if args.inner_opt_alg == 'sgd':
        optimizer = partial(optimizers.sgd, step_size=args.inner_step_size)
    elif args.inner_opt_alg == 'momentum':
        optimizer = partial(optimizers.momentum,