示例#1
0
def main(unused_argv):
    """Run infinite-width NNGP / NTK inference on a CIFAR-10 subset.

    Loads ``FLAGS.train_size`` training and ``FLAGS.test_size`` test examples,
    builds the infinite-width kernel of Dense -> Relu -> Dense, computes the
    ensemble predictions on the test set and prints a summary for each.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'cifar10', FLAGS.train_size, FLAGS.test_size)

    # Only the kernel function is needed; width 1 is a placeholder because no
    # finite network is instantiated.
    kernel_fn = stax.serial(stax.Dense(1, 2., 0.05), stax.Relu(),
                            stax.Dense(1, 2., 0.05))[2]

    # Evaluate the kernel in batches of FLAGS.batch_size.
    kernel_fn = nt.batch(kernel_fn, device_count=0,
                         batch_size=FLAGS.batch_size)

    t0 = time.time()
    # Closed-form Bayesian (NNGP) and infinite-time gradient descent (NTK)
    # predictions of the infinite ensemble.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    nngp_pred, ntk_pred = predict(x_test=x_test)
    # Wait for the (possibly asynchronous) computation so the timing is real.
    for pred in (nngp_pred, ntk_pred):
        pred.block_until_ready()

    elapsed = time.time() - t0
    print('Kernel construction and inference done in %s seconds.' % elapsed)

    def mse(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    for label, pred in (('NNGP test', nngp_pred), ('NTK test', ntk_pred)):
        util.print_summary(label, y_test, pred, None, mse)
示例#2
0
def main(unused_argv):
    """Train an Erf MLP on MNIST side by side with its linearization.

    Both the full network ``f`` and its first-order Taylor expansion around
    the initial parameters (``nt.linearize``) are trained with momentum SGD on
    a cross-entropy loss; per-epoch losses and final train/test summaries are
    printed.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Derive the epoch length from the actual training set instead of a
    # hard-coded 50000, and clamp to 1 so the modulo below cannot divide by
    # zero when the batch is larger than the data.
    steps_per_epoch = max(1, x_train.shape[0] // FLAGS.batch_size)

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Inside the loop the parameters are read *before* each update; refresh
    # them so the summaries include the final optimizer step (this also
    # defines params_lin if the loop never ran).
    params = get_params(state)
    params_lin = get_params(state_lin)

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
def main(unused_argv):
  """Compare SGD training of an Erf MLP on MNIST with the NTK prediction.

  The empirical NTK is computed at the initial parameters. The finite network
  is trained with full-batch SGD for ``FLAGS.train_time / FLAGS.learning_rate``
  steps, the initial outputs are evolved analytically for ``FLAGS.train_time``,
  and train/test summaries of both are printed.

  Args:
    unused_argv: command-line arguments (ignored).
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the network
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an mse loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Empirical NTK at the initial parameters (computed in batches of 4) and an
  # MSE predictor that solves the NTK dynamics in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                 batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Network outputs at initialization; the analytic prediction starts here.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)
  # `params` is read before each update inside the loop; refresh it so the
  # comparison below reflects all train_steps updates, not train_steps - 1.
  params = get_params(state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
                     loss)
示例#4
0
def infinite_resnet(train_embedding, test_embedding, data_set):
    """Print NNGP and NTK test summaries for an infinite-width wide ResNet.

    Args:
      train_embedding: training inputs handed to ``nt.predict.gp_inference``.
      test_embedding: test inputs handed to ``nt.predict.gp_inference``.
      data_set: mapping that provides ``'Y_train'`` and ``'Y_test'`` targets.
    """
    kernel_fn = wide_resnet(block_size=4, k=1, num_classes=2)[2]
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=0)

    nngp_pred, ntk_pred = nt.predict.gp_inference(
        kernel_fn,
        train_embedding,
        data_set['Y_train'],
        test_embedding,
        get=('nngp', 'ntk'),
        diag_reg=1e-3)
    for pred in (nngp_pred, ntk_pred):
        pred.block_until_ready()

    def mse(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    # Summaries of the infinite-network predictions on the test targets.
    y_test = data_set['Y_test']
    util.print_summary('NNGP test', y_test, nngp_pred, None, mse)
    util.print_summary('NTK test', y_test, ntk_pred, None, mse)
def main(l, w_std, b_std, x_train, y_train, x_test, y_test):
    """Build a depth-``l`` ReLU NNGP, run GP inference and report results.

    The network is ``Dense`` followed by ``l`` stages of ``Relu -> Dense``.
    The NNGP of ``x_train`` is recorded after every stage.

    Args:
      l: number of ``Relu -> Dense`` stages after the first ``Dense`` layer.
      w_std: weight standard deviation passed to every ``stax.Dense``.
      b_std: bias standard deviation passed to every ``stax.Dense``.
      x_train: training inputs.
      y_train: training targets.
      x_test: test inputs.
      y_test: test targets.

    Returns:
      ``(test_accuracy, train_loss, k_layer)``:
      - the first ``util.print_summary`` value for the test predictions;
      - the second value for the training predictions;
      - the list of ``l + 1`` train/train NNGP matrices, one per stage.
    """
    # Build the infinite network one stage at a time.
    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]

    k_layer = []
    K = net0[2](x_train, None)
    k_layer.append(K.nngp)

    # `_` rather than `l`, so the depth argument is not shadowed by the loop.
    for _ in range(l):
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        K = net_l[2](K)
        k_layer.append(K.nngp)
        nets += [stax.serial(nets[-1], net_l)]

    kernel_fn = nets[-1][2]

    start = time.time()
    # Bayesian and infinite-time gradient descent inference on the test set.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                        x_train,
                                                        y_train,
                                                        x_test,
                                                        get=('nngp', 'ntk'),
                                                        diag_reg=0)
    fx_test_nngp.block_until_ready()

    # The same inference evaluated on the training inputs.
    fx_train_nngp, fx_train_ntk = nt.predict.gp_inference(
        kernel_fn, x_train, y_train, x_train, get=('nngp', 'ntk'), diag_reg=0)
    fx_train_nngp.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    # Assumes util.print_summary returns (accuracy, loss), as unpacked here.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    test_accuracy, _ = util.print_summary('NNGP test', y_test, fx_test_nngp,
                                          None, loss)
    # Previously labelled 'NNGP test' although it summarizes the training set.
    _, train_loss = util.print_summary('NNGP train', y_train, fx_train_nngp,
                                       None, loss)
    return (test_accuracy, train_loss, k_layer)
def main(x_train, y_train, x_test, y_test, kernel_fn):
    """Return the NNGP test accuracy reported by ``util.print_summary``.

    Args:
      x_train: training inputs.
      y_train: training targets.
      x_test: test inputs.
      y_test: test targets.
      kernel_fn: infinite-width kernel function used for GP inference.

    Returns:
      The first value returned by ``util.print_summary`` for the NNGP test
      predictions.
    """
    nngp_pred, _ = nt.predict.gp_inference(
        kernel_fn, x_train, y_train, x_test,
        get=('nngp', 'ntk'), diag_reg=1e-3)
    nngp_pred.block_until_ready()

    def mse(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    accuracy, _ = util.print_summary('NNGP test', y_test, nngp_pred, None, mse)
    return accuracy
示例#7
0
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    """Run NNGP and NTK inference on IMDb reviews with an attention kernel.

    The infinite-width kernel is Conv -> Relu -> GlobalSelfAttention -> Relu
    -> GlobalAvgPool -> Dense. Test-set summaries for both predictions are
    printed.

    Args:
      *args: ignored.
      use_dummy_data: if True, take the data from ``_get_dummy_data`` instead
        of loading IMDb reviews and GloVe embeddings.
      **kwargs: ignored.
    """
    # Padding value; the same constant is handed to the predictor as
    # `mask_constant`.
    mask_constant = 100.

    if not use_dummy_data:
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')

        # Map words to GloVe vectors, padding / truncating every sentence to
        # FLAGS.max_sentence_length.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)
    else:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)

    def inverse_square_decay(d):
        return 1 / (1 + d**2)

    # Only kernel_fn is used, so every layer width is a placeholder 1.
    kernel_fn = stax.serial(
        stax.Conv(out_chan=1, filter_shape=(9, ), strides=(1, ),
                  padding='VALID'),
        stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1, n_chan_key=1, n_chan_val=1,
                                 pos_emb_type='SUM', W_pos_emb_std=1.,
                                 pos_emb_decay_fn=inverse_square_decay,
                                 n_heads=1),
        stax.Relu(),
        stax.GlobalAvgPool(),
        stax.Dense(out_dim=1))[2]

    # Evaluate the kernel in batches of FLAGS.batch_size.
    kernel_fn = nt.batch(kernel_fn, device_count=-1,
                         batch_size=FLAGS.batch_size)

    t0 = time.time()
    # Bayesian (NNGP) and infinite-time gradient descent (NTK) predictions.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)
    nngp_pred, ntk_pred = predict(x_test=x_test, get=('nngp', 'ntk'))
    for pred in (nngp_pred, ntk_pred):
        pred.block_until_ready()

    duration = time.time() - t0
    print(f'Kernel construction and inference done in {duration} seconds.')

    def mse(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    for label, pred in (('NNGP test', nngp_pred), ('NTK test', ntk_pred)):
        util.print_summary(label, y_test, pred, None, mse)
示例#8
0
def main(unused_argv):
  """Evaluate a depth-6 ReLU NNGP on MNIST and inspect its kernel spectrum.

  The network is Dense followed by 5 stages of Relu -> Dense. The train/train
  NNGP is recorded after every stage. After GP inference on the test set, the
  eigenvalues of the final train/train kernel are plotted (log10, descending)
  and the kernel is saved to a .mat file.

  Args:
    unused_argv: command-line arguments (ignored).
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
    datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  l = 5
  w_std = 1.5
  b_std = 2

  net0 = stax.Dense(1, w_std, b_std)
  nets = [net0]

  k_layer = []
  K = net0[2](x_train, None)
  k_layer.append(K.nngp)

  # `_` rather than `l`, so the depth setting is not shadowed by the loop.
  for _ in range(l):
    net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
    K = net_l[2](K)
    k_layer.append(K.nngp)
    nets += [stax.serial(nets[-1], net_l)]

  kernel_fn = nets[-1][2]

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                      x_train,
                                                      y_train,
                                                      x_test,
                                                      get=('nngp', 'ntk'),
                                                      diag_reg=1e-3)

  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)

  # log K[5, 5] after each layer, input for the (disabled) depth plot below.
  # Requires at least 6 training points.
  grid = list(range(1, len(k_layer) + 1))
  k_plot = [np.log(k[5, 5]) for k in k_layer]
  # plt.plot(grid, k_plot)
  # plt.xlabel('layer ; w_var = 10, b_var = 2, accuracy = 93%')
  # plt.ylabel('Log (K[5][5]) ')

  # The NNGP is symmetric, so use eigh: it returns real eigenvalues already
  # sorted ascending (eig may return complex values in arbitrary order).
  w, _ = LA.eigh(k_layer[-1])
  index = list(range(1, len(w) + 1))
  # Plot the spectrum from largest to smallest eigenvalue.
  plt.scatter(index, np.log10(w)[::-1])
  plt.ylabel("log10[eigen val]")
  plt.show()

  # NOTE(review): the file name encodes l10 / w_var=0.85 / b_var=0.1, which
  # does not match l=5, w_std=1.5, b_std=2 above - confirm before relying on it.
  sio.savemat('mnist_l10_wvar=0_85_b_var=0_1.mat', {
      'kernel': k_layer[-1]
  })
def main(unused_argv):
    """Compare SGD on an MNIST ReLU MLP with its NTK prediction, plus a bound.

    Steps:
      1. Build data and a 2-hidden-layer ReLU MLP with 10 outputs.
      2. Compute empirical NTK blocks (train/train, test/train, test/test) and
         empirical NNGP blocks at initialization.
      3. Compute a KL term and a PAC-Bayes-style bound from the NNGP, and
         print both.
      4. Draw a joint initial function sample for train and test outputs.
      5. Train with full-batch SGD and print train/test summaries next to the
         analytic NTK prediction started from that sample.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    import numpy  # Host NumPy: mutable arrays and multivariate_normal.

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    n_out = 10  # width of the final Dense layer
    m = FLAGS.train_size
    m_test = FLAGS.test_size
    n = m * n_out
    n_test = m_test * n_out

    # Flatten the NTK blocks into square matrices indexed by
    # (sample, output), sample-major. Assumes shape (N1, N2, 10, 10) -
    # NOTE(review): confirm against the neural_tangents version in use.
    theta = g_dd.transpose((0, 2, 1, 3)).reshape(n, n)
    theta_test = ntk(x_test, None, params).transpose(
        (0, 2, 1, 3)).reshape(n_test, n_test)
    theta_tilde = g_td.transpose((0, 2, 1, 3)).reshape(n_test, n)

    # NNGP blocks, expanded with an identity over the outputs so their
    # indexing matches the flattened NTK. The previous code replaced the NNGP
    # with kron(theta, I), giving (10n, 10n) matrices that broke every later
    # n-sized operation. Assumes empirical_nngp_fn returns (N1, N2) -
    # NOTE(review): confirm.
    nngp_fn = nt.empirical_nngp_fn(apply_fn)
    eye_out = np.eye(n_out)
    K = np.kron(nngp_fn(x_train, None, params), eye_out)
    K_test = np.kron(nngp_fn(x_test, None, params), eye_out)
    K_tilde = np.kron(nngp_fn(x_test, x_train, params), eye_out)

    # `t` was previously undefined. FLAGS.train_time (the horizon of the
    # analytic prediction below) is assumed - NOTE(review): confirm.
    # NOTE(review): Sigma is not used further in this function.
    t = FLAGS.train_time
    decay_matrix = np.eye(n) - scipy.linalg.expm(-t * theta)
    Sigma = K + np.matmul(
        decay_matrix,
        np.matmul(K, np.matmul(np.linalg.inv(theta),
                               np.matmul(decay_matrix, theta))) - 2 * K)

    # KL term for a noisy GP with NNGP prior K and noise variance sigma^2.
    sigma_noise = 1.0
    noise_var = sigma_noise ** 2
    Y = y_train.reshape(n)
    # alpha = (sigma^2 I + K)^{-1} Y, via a solve instead of an explicit inverse.
    alpha = np.linalg.solve(np.eye(n) * noise_var + K, Y)
    # With cov = (K^{-1} + I / sigma^2)^{-1}, cov^{-1} K = I + K / sigma^2.
    # coviK was previously referenced but only defined in a comment.
    coviK = np.eye(n) + K / noise_var
    # log det(I + K / sigma^2) from the eigenvalues of K.
    eigs = np.linalg.eigh(K)[0]
    logdetcoviK = np.sum(np.log((eigs + noise_var) / noise_var))
    KL = (0.5 * logdetcoviK
          + 0.5 * np.trace(np.linalg.inv(coviK))
          + 0.5 * np.matmul(alpha.T, np.matmul(K, alpha))
          - n / 2)
    print(KL)

    delta = 2**-10
    bound = (KL + 2 * np.log(m) + 1 - np.log(delta)) / m
    bound = 1 - np.exp(-bound)
    print("bound", bound)

    # Joint prior covariance over (train, test) outputs, built from the NNGP
    # blocks. The previous version mixed the NNGP train block with NTK cross
    # and test blocks, which need not be positive semi-definite, and left
    # K_tilde / K_test unused.
    bigK = numpy.zeros((n + n_test, n + n_test))
    bigK[:n, :n] = K
    bigK[:n, n:] = K_tilde.T
    bigK[n:, :n] = K_tilde
    bigK[n:, n:] = K_test
    init_ntk_f = numpy.random.multivariate_normal(numpy.zeros(n + n_test), bigK)
    fx_train = init_ntk_f[:n].reshape(m, n_out)
    fx_test = init_ntk_f[n:].reshape(m_test, n_out)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # Refresh so the summaries include the final optimizer step.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
示例#10
0
# Continuation of a script whose earlier part defines bigK, K, theta_tilde,
# theta_test, n, n_test, m, m_test, state, predictor, ... (not shown here).
# Fill the joint (train, test) covariance block by block.
bigK[0:n, 0:n] = K
bigK[0:n, n:] = theta_tilde.T
# Lower blocks start at row n; `bigK[:n, ...]` overwrote the train block
# and mismatched theta_tilde's (n_test, n) shape.
bigK[n:, 0:n] = theta_tilde
bigK[n:, n:] = theta_test
init_ntk_f = numpy.random.multivariate_normal(np.zeros(n + n_test), bigK)
fx_train = init_ntk_f[:n].reshape(m, 10)
fx_test = init_ntk_f[n:].reshape(m_test, 10)

# Train the network.
train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
# NOTE(review): SGD runs 10x the steps implied by FLAGS.train_time, while the
# analytic prediction below uses FLAGS.train_time - confirm this is intended.
total_steps = train_steps * 10
print('Training for {} steps'.format(total_steps))

for i in range(total_steps):
    params = get_params(state)
    # Must be inside the loop; it was dedented, so only one step was taken.
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)
# Refresh so the summaries include the final optimizer step.
params = get_params(state)

# Get predictions from analytic computation.
print('Computing analytic prediction.')
fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

# Print out summary data comparing the linear / nonlinear model.
util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
示例#11
0
def main(unused_argv):
    """Train a 1-output Erf MLP on MNIST digit parity and compare with NTK.

    One-hot MNIST labels are replaced by the parity of the digit
    (``argmax % 2``) as an (N, 1) column. The network uses a fresh random seed
    on every run. It is trained with full-batch SGD on MSE and compared with
    the analytic empirical-NTK prediction at ``FLAGS.train_time``.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    import numpy  # Host NumPy for argmax and seeding.

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Relabel every example with the parity of its digit, as an (N, 1) column.
    y_train = np.expand_dims(numpy.argmax(y_train, 1) % 2, 1)
    y_test = np.expand_dims(numpy.argmax(y_test, 1) % 2, 1)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # Random seed drawn uniformly from the full int32 range. This replaces the
    # deprecated numpy.random.random_integers(low, high), which equals
    # randint(low, high + 1). int64 avoids overflow of high + 1 on platforms
    # whose default integer is 32-bit.
    int32_info = np.iinfo(np.int32)
    seed = numpy.random.randint(int32_info.min, int32_info.max + 1,
                                dtype=numpy.int64)
    key = random.PRNGKey(seed)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # Refresh so the summaries include the final optimizer step.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    # NOTE(review): targets are a single 0/1 column; confirm that the accuracy
    # reported by util.print_summary is meaningful for one output.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
示例#12
0
def weight_space(train_embedding, test_embedding, data_set, *,
                 learning_rate=1.0, batch_size=64, train_epochs=10,
                 steps_per_epoch=100, num_classes=2):
    """Train an Erf MLP and its linearization side by side and compare them.

    Both the network ``f`` and ``nt.linearize(f, initial_params)`` are trained
    with momentum SGD on a cross-entropy loss. Per-epoch losses and final
    train/test summaries are printed.

    Args:
      train_embedding: training features. The last axis is the feature length
        (135 = 9 * 15 in the original setup).
      test_embedding: test features with the same feature length.
      data_set: mapping providing ``'Y_train'`` and ``'Y_test'`` targets.
      learning_rate: step size of the momentum optimizer.
      batch_size: minibatch size.
      train_epochs: number of logged epochs to train for.
      steps_per_epoch: optimizer steps per logged epoch.
      num_classes: output width of the final Dense layer.

    Raises:
      ValueError: if ``steps_per_epoch`` is not positive.
    """
    if steps_per_epoch <= 0:
        raise ValueError('steps_per_epoch must be positive, got %r'
                         % (steps_per_epoch,))

    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(num_classes, 1., 0.05))

    key = random.PRNGKey(0)
    # Take the feature length from the data instead of hard-coding 135.
    _, params = init_fn(key, (-1, train_embedding.shape[-1]))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    max_step = steps_per_epoch * train_epochs

    # Minibatches of `batch_size` examples (not the whole training set).
    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        # Integer comparison instead of the float test
        # `i / steps_per_epoch == train_epochs`.
        if i >= max_step:
            break

    # Parameters are read before each update inside the loop; refresh them so
    # the summaries include the final step (and params_lin is always defined).
    params = get_params(state)
    params_lin = get_params(state_lin)

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)