def main(unused_argv):
    """Train a network with SGD and compare it to the analytic NTK prediction.

    Loads MNIST, trains a one-hidden-layer network under an MSE loss for
    ``FLAGS.train_time // FLAGS.learning_rate`` steps, evaluates the analytic
    MSE predictor at ``FLAGS.train_time`` and prints a summary for the train
    and test sets.

    Args:
        unused_argv: Ignored; never read in the body.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Inside the loop params is read *before* each update, so it lags the
    # optimizer state by one step; refresh it so the summary below evaluates
    # the network after all train_steps updates.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
  def testNTKMSEPrediction(self, shape, out_logits):
    """Compares the analytic MSE predictor with SGD on a linear model.

    Builds random train/test inputs of the given `shape`, random Bernoulli
    labels with `out_logits` columns, and a model that is affine in its
    parameters. It then checks that:
      * at time 0.0 the predictor returns the initial function values, and
      * after training with SGD for `train_time`, the deviation between the
        trained outputs and the predicted outputs, normalised by the RMS
        displacement from the initial outputs, is close to zero.

    Args:
      shape: Shape of the train and test input arrays; `shape[0]` is used as
        the number of label rows and `shape[-1]` as the input feature size.
      out_logits: Number of output columns of the model and the labels.
    """
    rng = random.PRNGKey(0)

    # Keep the split order fixed: train inputs, labels, test inputs, params.
    rng, subkey = random.split(rng)
    data_train = random.normal(subkey, shape)

    rng, subkey = random.split(rng)
    label_shape = (shape[0], out_logits)
    data_labels = np.array(
        random.bernoulli(subkey, shape=label_shape), np.float32)

    rng, subkey = random.split(rng)
    data_test = random.normal(subkey, shape)

    rng, w_key, b_key = random.split(rng, 3)
    weights = random.normal(w_key, (shape[-1], out_logits))
    bias = random.normal(b_key, (out_logits,))
    init_params = (weights, bias)

    def model(p, x):
      # Affine map in the parameters, scaled by the input feature size.
      w, b = p
      return np.dot(x, w) / shape[-1] + b

    def mse_loss(p, x):
      # Regress to an MSE loss against the fixed labels.
      residual = model(p, x) - data_labels
      return 0.5 * np.mean(residual ** 2)

    kernel_fn = tangents.ntk(model)
    g_dd = kernel_fn(init_params, data_train)
    g_td = kernel_fn(init_params, data_test, data_train)
    predictor = tangents.analytic_mse_predictor(g_dd, data_labels, g_td)

    step_size = 1.0
    train_time = 100.0
    num_steps = int(train_time / step_size)

    opt_init, opt_update, get_params = opt.sgd(step_size)
    opt_state = opt_init(init_params)

    fx_initial_train = model(init_params, data_train)
    fx_initial_test = model(init_params, data_test)

    # At t = 0 the predictor should hand back the initial function values.
    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, 0.0)

    # NOTE(schsam): I think at the moment stax always generates 32-bit results
    # since the weights are explicitly cast to float32.
    # NOTE(review): the positional `False` is presumably a dtype-check flag;
    # confirm against assertAllClose's signature.
    self.assertAllClose(fx_initial_train, fx_pred_train, False)
    self.assertAllClose(fx_initial_test, fx_pred_test, False)

    # Run SGD; the gradient function is built once and reused every step.
    loss_grad = grad(mse_loss)
    for step in range(num_steps):
      current = get_params(opt_state)
      opt_state = opt_update(step, loss_grad(current, data_train), opt_state)

    trained_params = get_params(opt_state)
    fx_train = model(trained_params, data_train)
    fx_test = model(trained_params, data_test)

    fx_pred_train, fx_pred_test = predictor(
        fx_initial_train, fx_initial_test, train_time)

    # RMS distance travelled in function space, used to normalise the error.
    fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
    fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

    fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
    fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

    # NOTE(review): the two 0.1 values look like tolerances; their exact
    # meaning (relative vs. absolute) depends on assertAllClose.
    self.assertAllClose(
        fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
    self.assertAllClose(
        fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
    key = random.PRNGKey(run)
    _, params = net_init(key, (-1, 1))

    # data
    task = sinusoid_task(n_support=args.n_support)
    x_train, y_train, x_test, y_test = task['x_train'], task['y_train'], task[
        'x_test'], task['y_test']

    # linearized network
    f_lin = tangents.linearize(f, params)

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)
    import ipdb
    ipdb.set_trace()

    # Get initial values of the network in function space.
    fx_train_ana_init = f(params, x_train)
    fx_test_ana_init = f(params, x_test)

    # optimizer for f and f_lin
    if args.inner_opt_alg == 'sgd':
        optimizer = partial(optimizers.sgd, step_size=args.inner_step_size)
    elif args.inner_opt_alg == 'momentum':
        optimizer = partial(optimizers.momentum,
                            step_size=args.inner_step_size,
                            mass=0.9)
    elif args.inner_opt_alg == 'adam':