Exemplo n.º 1
0
def main(unused_argv):
    """Load a pickled dataset, build a finite ReLU network and dump its empirical NTK.

    Reads ``data_<train_size>.p`` (a pickled ``(x_train, y_train, x_test,
    y_test)`` tuple). It then initializes a 2048-wide one-hidden-layer ReLU
    network from a random seed and computes the empirical NTK on
    train/train and test/train pairs. Both kernels are pickled to disk and an
    MSE gradient-descent predictor is built from them.

    Args:
      unused_argv: command-line arguments (ignored; configuration comes from
        ``FLAGS``).
    """
    train_size = FLAGS.train_size
    # NOTE(review): pickle.load can execute arbitrary code -- only load data
    # files produced by a trusted source.
    with open("data_" + str(train_size) + ".p", "rb") as f:
        x_train, y_train, x_test, y_test = pickle.load(f)
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # Initialize the network with a fresh random seed; the NTK below is
    # computed at these parameters. `random_integers` is deprecated, so draw a
    # single int32-range seed with `randint` (upper bound exclusive).
    seed = numpy.random.randint(np.iinfo(np.int32).min, np.iinfo(np.int32).max)
    key = random.PRNGKey(seed)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # We assume that the NTK is approximately the same for any sample of
    # parameters (true in the limit of infinite width).

    print("Making NTK")
    sys.stdout.flush()
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)
    g_dd = ntk(x_train, None, params)
    # Context managers guarantee the dumps are flushed and the files closed.
    with open("ntk_train_" + str(train_size) + ".p", "wb") as f:
        pickle.dump(g_dd, f)
    g_td = ntk(x_test, x_train, params)
    with open("ntk_train_test_" + str(train_size) + ".p", "wb") as f:
        pickle.dump(g_td, f)
    # NOTE(review): passing `g_td` as the third positional argument matches the
    # older neural_tangents predict API -- confirm against the installed version.
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
Exemplo n.º 2
0
def main(unused_argv):
    """Run closed-form NNGP/NTK inference with an infinite ReLU network on CIFAR-10.

    Loads ``FLAGS.train_size`` / ``FLAGS.test_size`` examples and builds the
    kernel of a Dense-ReLU-Dense network, batched by ``FLAGS.batch_size``.
    It then predicts test outputs under both kernels and prints timing plus
    loss/accuracy summaries.

    Args:
      unused_argv: ignored; configuration comes from ``FLAGS``.
    """
    print('Loading data.')
    n_train, n_test = FLAGS.train_size, FLAGS.test_size
    x_train, y_train, x_test, y_test = datasets.get_dataset('cifar10', n_train, n_test)

    # Only the kernel function of the infinite-width network is used.
    layers = (stax.Dense(1, 2., 0.05), stax.Relu(), stax.Dense(1, 2., 0.05))
    kernel_fn = stax.serial(*layers)[2]
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=FLAGS.batch_size)

    tic = time.time()
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    for prediction in (fx_test_nngp, fx_test_ntk):
        prediction.block_until_ready()
    elapsed = time.time() - tic
    print('Kernel construction and inference done in %s seconds.' % elapsed)

    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    for label, prediction in (('NNGP test', fx_test_nngp), ('NTK test', fx_test_ntk)):
        util.print_summary(label, y_test, prediction, None, loss)
Exemplo n.º 3
0
def train(kernel_fn, x_train, y_train, x_test, y_test):
  """Fit infinite-width NTK kernel regression and evaluate it on the test set.

  Computes the NTK between test and train inputs and on the train set
  (``nt.batch`` with batch size 25). It then forms the kernel-regression
  prediction ``K_test_train @ K_train_train^{-1} @ y_train``.

  Args:
    kernel_fn: kernel function of the network.
    x_train, y_train: training inputs and targets.
    x_test, y_test: test inputs and targets. Predictions are thresholded to
      ``+1.`` / ``-1.`` before comparing with ``y_test``, so presumably the
      targets are +-1 labels (TODO confirm with callers).

  Returns:
    ``(loss_d, acc_d)``: the mean squared error of the regression prediction
    on the test set, and the fraction of test entries whose thresholded
    prediction equals ``y_test``.
  """
  batched_kernel_fn = nt.batch(kernel_fn, 25)

  K_test_train = batched_kernel_fn(x_test, x_train).ntk
  K_train_train = batched_kernel_fn(x_train, x_train).ntk

  # Solve K_train_train @ alpha = y_train rather than forming the explicit
  # inverse. The result is the same in exact arithmetic, but a solve is cheaper
  # and numerically more stable for an ill-conditioned Gram matrix.
  alpha = np.linalg.solve(K_train_train, y_train)
  y_test_pred = K_test_train @ alpha

  loss_d = np.mean((y_test_pred - y_test)**2)
  y_test_class = np.where(y_test_pred > 0, 1., -1.)
  acc_d = np.mean(y_test_class == y_test)
  return loss_d, acc_d
Exemplo n.º 4
0
def main():
    """Run infinite-width WideResNet NNGP/NTK inference on a CIFAR-10 subset.

    Uses 1000 training and 1000 test images and computes the kernel in a single
    batch (``batch_size=0``). It runs closed-form inference and prints timing
    plus loss/accuracy summaries for both kernels.
    """
    train_size = 1000
    test_size = 1000
    batch_size = 0  # 0: no batching, the whole kernel at once

    init_fn, apply_fn, kernel_fn = WideResnet(block_size=4,
                                              k=1,
                                              num_classes=10)
    x_train, y_train, x_test, y_test = get_dataset(
        'cifar10', train_size, test_size, do_flatten_and_normalize=False)
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    # `gradient_descent_mse_ensemble` takes only the training data and returns
    # a prediction function; test inputs and the requested kernels go to that
    # function. In the current neural_tangents API a fourth positional argument
    # would be taken as the learning rate.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()
    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def define(args):
    """Build a ReLU MLP described by ``args``.

    The network has ``args.num_hidden_layers`` Dense(``args.hidden_neurons``)
    + ReLU pairs followed by a Dense(``args.output_dim``) readout. Every Dense
    layer uses ``args.W_std`` / ``args.b_std``.

    Returns:
      ``(init_fn, apply_fn, kernel_fn, batched_kernel_fn)``. ``apply_fn`` is
      jitted, and ``batched_kernel_fn`` is ``kernel_fn`` wrapped by
      ``nt.batch`` with ``args.batch_size`` and ``device_count=-1``.
    """
    layers = []
    for _ in range(args.num_hidden_layers):
        layers += [
            stax.Dense(args.hidden_neurons, W_std=args.W_std, b_std=args.b_std),
            stax.Relu(),
        ]
    layers.append(stax.Dense(args.output_dim, W_std=args.W_std, b_std=args.b_std))

    init_fn, raw_apply_fn, kernel_fn = stax.serial(*layers)
    jitted_apply_fn = jit(raw_apply_fn)
    batched = nt.batch(kernel_fn, batch_size=args.batch_size, device_count=-1)
    return init_fn, jitted_apply_fn, kernel_fn, batched
Exemplo n.º 6
0
def main(unused_argv):
  """Compare SGD training of a finite Erf network with its NTK linearization on MNIST.

  Builds a 512-wide one-hidden-layer Erf network and trains it with
  full-batch SGD on MSE for ``FLAGS.train_time`` (``train_time //
  learning_rate`` steps). It then prints train/test summaries comparing the
  trained network with the function-space prediction from its empirical NTK
  at initialization.

  Args:
    unused_argv: ignored; configuration comes from ``FLAGS``.
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the network
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an mse loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space,
  # using the empirical NTK at the initial parameters.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                 batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)
  # The loop reads `params` *before* each update, so read them once more.
  # Otherwise the network evaluated below would lag the analytic prediction
  # (taken at FLAGS.train_time) by one SGD step.
  params = get_params(state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
                     loss)
Exemplo n.º 7
0
def maker(beta=0, W_std=0.8, b_std=0, diag_reg=1e-4):
    """Build an NTK predictor for the given hyper-parameters and return a metrics closure.

    Reads the module-level ``net_type``, ``batch_size``, ``num_of_gpus``,
    ``train_images``, ``train_labels`` and ``test_images``.

    Returns:
      ``calc_metrics`` partially applied with the predictor, its ``t=0``
      predictions on the train and test images, and ``beta``.
    """
    _, _, kernel = get_network(net_type=net_type, w_std=W_std, b_std=b_std)
    kernel = nt.batch(kernel,
                      batch_size=batch_size,
                      device_count=num_of_gpus,
                      store_on_device=True)
    base_predict = gradient_descent_mse_vib(beta, train_images, train_labels,
                                            diag_reg, kernel)
    ntk_predict = partial(base_predict, get='ntk', compute_cov=False)
    initial_predictions = [ntk_predict(t=0, x_test=images)
                           for images in (train_images, test_images)]
    return partial(calc_metrics,
                   predict_fn=ntk_predict,
                   init_preds=initial_predictions,
                   beta=beta)
Exemplo n.º 8
0
def infinite_resnet(train_embedding, test_embedding, data_set):
    """Run NNGP/NTK GP inference with an infinite WideResNet on precomputed embeddings.

    Args:
      train_embedding, test_embedding: kernel inputs for train and test.
      data_set: mapping providing ``'Y_train'`` and ``'Y_test'`` targets.
    """
    _, _, kernel_fn = wide_resnet(block_size=4, k=1, num_classes=2)
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=0)

    predictions = nt.predict.gp_inference(kernel_fn,
                                          train_embedding,
                                          data_set['Y_train'],
                                          test_embedding,
                                          get=('nngp', 'ntk'),
                                          diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predictions
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    # Accuracy and loss summaries for both infinite-network predictions.
    for label, fx in (('NNGP test', fx_test_nngp), ('NTK test', fx_test_ntk)):
        util.print_summary(label, data_set['Y_test'], fx, None, loss)
Exemplo n.º 9
0
def infinite_fcn(train_embedding, test_embedding, data_set, binary=True):
    """Compute the NNGP posterior mean and covariance of an infinite 3-layer ReLU MLP.

    Args:
      train_embedding, test_embedding: kernel inputs for train and test.
      data_set: mapping providing ``'Y_train'`` and ``'Y_test'`` targets.
      binary: forwarded to ``utils.print_summary``.
    """
    # Only the kernel function of the network is used.
    _, _, kernel_fn = stax.serial(
        stax.Dense(64, 2., 0.05),
        stax.Relu(),
        stax.Dense(32, 2., 0.05),
        stax.Relu(),
        stax.Dense(4, 2., 0.05),
        stax.Relu(),
    )
    # 0 for no batching, whole batch
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=0)
    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, train_embedding, data_set['Y_train'],
        diag_reg_absolute_scale=True, learning_rate=1, diag_reg=1e-3)

    nngp_mean, nngp_covariance = predict_fn(x_test=test_embedding,
                                            get='nngp',
                                            compute_cov=True)

    # JAX dispatches work asynchronously: wait for the results so the reported
    # duration covers the computation itself, not just its dispatch.
    nngp_mean.block_until_ready()
    nngp_covariance.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    utils.print_summary('NNGP test',
                        data_set['Y_test'],
                        nngp_mean,
                        None,
                        loss,
                        nngp_covariance,
                        binary=binary)
Exemplo n.º 10
0
def main(unused_argv):
    """Run infinite-width LeNet NNGP/NTK inference on a small MNIST subset, batched over test points.

    Loads 10 train / 10 test MNIST images and reshapes them to 28x28x1. It
    builds the kernel of ``util.build_le_net(FLAGS.network_width)`` and
    predicts test outputs ``FLAGS.batch_size_output`` points at a time. Loss
    and accuracy are printed for both kernels.

    Args:
      unused_argv: ignored; configuration comes from ``FLAGS``.
    """
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', n_train=10, n_test=10,
                                                            permute_train=True)

    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print(f'Data loaded and reshaped with n_train = {x_train.shape[0]} (batch size {FLAGS.batch_size_kernel}) and '
          f'n_test = {x_test.shape[0]}.')

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translations by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print('Network build complete')

    # Construct the kernel function
    # Use 'store_on_device = False' for larger kernels
    kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel, store_on_device=False)

    # Set start time
    start_inf = time.time()

    # Bayesian and infinite-time gradient descent inference with infinite network
    print('Starting bayesian and infinite-time gradient descent inference with infinite network')
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6
    )

    # NOTE(review): the train-train kernel may only be computed on the first
    # prediction call -- confirm this measures what the message says.
    duration_kernel = time.time() - start_inf
    print(f'Kernel constructed in {duration_kernel} seconds.')

    # Collect per-batch predictions and concatenate them at the end. The final
    # batch may be smaller than FLAGS.batch_size_output, so no test point is
    # dropped when n_test is not a multiple of the batch size.
    n_test = x_test.shape[0]
    batch_size = FLAGS.batch_size_output
    nngp_batches, ntk_batches = [], []
    print('Output vector allocated.')

    # Compute predictions in batches
    for i, start in enumerate(range(0, n_test, batch_size)):
        time_batch = time.time()
        x = x_test[start:start + batch_size]
        tmp_nngp, tmp_ntk = predict_fn(x_test=x, get=('nngp', 'ntk'))
        # Wait for the asynchronous computation so the batch timing is real.
        tmp_nngp.block_until_ready()
        tmp_ntk.block_until_ready()
        duration_batch = time.time() - time_batch
        print(f'Batch {i+1} predicted in {duration_batch} seconds.')
        nngp_batches.append(tmp_nngp)
        ntk_batches.append(tmp_ntk)

    fx_test_nngp = np.concatenate(nngp_batches, axis=0)
    fx_test_ntk = np.concatenate(ntk_batches, axis=0)

    duration_inf = time.time() - start_inf

    print(f'Inference done in {duration_inf} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
Exemplo n.º 11
0
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    """Run infinite-width convolution + self-attention NNGP/NTK inference on IMDb reviews.

    Args:
      *args, **kwargs: ignored.
      use_dummy_data: if true, use ``_get_dummy_data`` instead of loading IMDb
        and embedding it with GloVe.
    """
    # Value written into every padded position.
    mask_constant = 100.

    if not use_dummy_data:
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')
        # Word embeddings, padded / truncated to a common sentence length.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)
    else:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)

    # Only the infinite-width kernel is used, so every width is 1.
    layers = [
        stax.Conv(out_chan=1, filter_shape=(9, ), strides=(1, ), padding='VALID'),
        stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1,
                                 n_chan_key=1,
                                 n_chan_val=1,
                                 pos_emb_type='SUM',
                                 W_pos_emb_std=1.,
                                 pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                 n_heads=1),
        stax.Relu(),
        stax.GlobalAvgPool(),
        stax.Dense(out_dim=1),
    ]
    _, _, kernel_fn = stax.serial(*layers)

    # Batched (and parallel, where devices allow) kernel evaluation.
    kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size)

    tic = time.time()
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)
    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))
    for fx in (fx_test_nngp, fx_test_ntk):
        fx.block_until_ready()
    elapsed = time.time() - tic
    print(f'Kernel construction and inference done in {elapsed} seconds.')

    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat)**2)

    for label, fx in (('NNGP test', fx_test_nngp), ('NTK test', fx_test_ntk)):
        util.print_summary(label, y_test, fx, None, loss)
Exemplo n.º 12
0
def main(unused_argv):
  """Run deep infinite ReLU-network GP inference on MNIST and plot the kernel spectrum.

  Builds a depth-``num_layers`` infinite network layer by layer and records
  the training-set NNGP kernel after every layer. It runs NNGP/NTK inference
  and prints the NNGP test summary. Finally it plots the eigenvalue spectrum
  of the deepest NNGP kernel and saves that kernel to a .mat file.

  Args:
    unused_argv: ignored; configuration comes from ``FLAGS``.
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
    datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  num_layers = 5
  w_std = 1.5
  b_std = 2

  net0 = stax.Dense(1, w_std, b_std)
  nets = [net0]

  # Training-set NNGP kernel after the first Dense and after each
  # Relu + Dense block.
  k_layer = []
  K = net0[2](x_train, None)
  k_layer.append(K.nngp)

  for _ in range(num_layers):
    net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
    # Propagate the previous layer's Kernel object through Relu + Dense.
    K = net_l[2](K)
    k_layer.append(K.nngp)
    nets.append(stax.serial(nets[-1], net_l))

  kernel_fn = nets[-1][2]

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                      x_train,
                                                      y_train,
                                                      x_test,
                                                      get=('nngp', 'ntk'),
                                                      diag_reg=1e-3)

  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)

  # Log of one diagonal kernel entry per layer (for the optional plot below).
  grid = list(range(1, len(k_layer) + 1))
  k_plot = [np.log(k[5, 5]) for k in k_layer]
  # plt.plot(grid, k_plot)
  # plt.xlabel('layer')
  # plt.ylabel('Log (K[5][5]) ')

  # The NNGP Gram matrix of the training set is symmetric, so use the
  # symmetric eigen-solver. It returns real eigenvalues, whereas `eig` may
  # return complex values with round-off imaginary parts that break the
  # sort/log below. Assumes LA is numpy/scipy/jax linalg (all provide
  # eigvalsh) -- TODO confirm.
  w = np.sort(LA.eigvalsh(k_layer[-1]))
  index = list(range(1, len(w) + 1))

  # Eigenvalues in descending order on a log10 scale.
  plt.scatter(index, np.log10(w)[::-1])
  plt.ylabel("log10[eigen val]")
  plt.show()

  # NOTE(review): this file name encodes a depth and variances that do not
  # match the values used above (num_layers=5, w_std=1.5, b_std=2) -- confirm.
  sio.savemat('mnist_l10_wvar=0_85_b_var=0_1.mat', {
      'kernel': k_layer[-1]
  })
Exemplo n.º 13
0
def main(unused_argv):
    """Run an empirical NTK / NNGP analysis of a finite two-hidden-layer ReLU net on MNIST.

    Steps:
      1. Build the network and its empirical NTK at initialization.
      2. Form flattened NTK (``theta``) and NNGP (``K``) matrices.
      3. Compute a KL term and a bound derived from it.
      4. Sample initial function values from a joint Gaussian.
      5. Train with SGD.
      6. Compare the trained network with the linearized NTK prediction.

    Args:
      unused_argv: ignored; configuration comes from ``FLAGS``.
    """
    import numpy  # host numpy: `bigK` below is filled by in-place assignment

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Relu(),
      stax.Dense(2048, 1., 0.05),
      stax.Relu(),
      stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    # NOTE(review): passing `g_td` as the third positional argument matches the
    # older neural_tangents predict API -- confirm against the installed version.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    m = FLAGS.train_size
    print(m)
    n = m * 10  # train (example, output) pairs; the readout has 10 outputs
    m_test = FLAGS.test_size
    n_test = m_test * 10

    # Flatten the NTK into (example*output) x (example*output) matrices.
    # Assumes ntk(...) returns (examples1, examples2, 10, 10) blocks, as implied
    # by the transpose/reshape -- TODO confirm.
    theta = g_dd.transpose((0, 2, 1, 3)).reshape(n, n)
    theta_test = ntk(x_test, None, params).transpose((0, 2, 1, 3)).reshape(n_test, n_test)
    theta_tilde = g_td.transpose((0, 2, 1, 3)).reshape(n_test, n)

    # NNGP kernels, made block-diagonal over the 10 outputs with kron(., I_10).
    # The kron must wrap the NNGP kernel, not `theta`: kron(theta, I_10) is
    # (10n x 10n) and cannot be combined with the n x n matrices below.
    # Assumes empirical_nngp_fn returns an (examples1, examples2) kernel --
    # TODO confirm against the installed neural_tangents version.
    nngp_fn = nt.empirical_nngp_fn(apply_fn)
    K = np.kron(nngp_fn(x_train, None, params), np.eye(10))
    K_test = np.kron(nngp_fn(x_test, None, params), np.eye(10))
    K_tilde = np.kron(nngp_fn(x_test, x_train, params), np.eye(10))

    # NOTE(review): `t` had no definition in this function; the training time
    # FLAGS.train_time is assumed here -- confirm.
    t = FLAGS.train_time
    decay_matrix = np.eye(n) - scipy.linalg.expm(-t * theta)
    # inv(theta) @ (decay_matrix @ theta) is evaluated as a linear solve.
    # NOTE(review): Sigma is not used below.
    Sigma = K + decay_matrix @ (
        K @ np.linalg.solve(theta, decay_matrix @ theta) - 2 * K)

    sigma_noise = 1.0
    Y = y_train.reshape(n)
    # alpha = (sigma^2 I + K)^{-1} Y, via a linear solve instead of an inverse.
    alpha = np.linalg.solve(np.eye(n) * (sigma_noise ** 2) + K, Y)

    # With C = I + K / sigma^2 and eigenvalues lambda of the symmetric K:
    #   log det C  = sum log((lambda + sigma^2) / sigma^2)
    #   trace C^-1 = sum sigma^2 / (lambda + sigma^2)
    # so neither C nor its inverse has to be formed explicitly.
    eigs = np.linalg.eigh(K)[0]
    logdetcoviK = np.sum(np.log((eigs + sigma_noise ** 2) / sigma_noise ** 2))
    trace_inv_coviK = np.sum(sigma_noise ** 2 / (eigs + sigma_noise ** 2))
    KL = (0.5 * logdetcoviK + 0.5 * trace_inv_coviK
          + 0.5 * (alpha @ (K @ alpha)) - n / 2)
    print(KL)

    delta = 2**-10
    bound = (KL + 2 * np.log(m) + 1 - np.log(delta)) / m
    bound = 1 - np.exp(-bound)
    print("bound", bound)

    # Joint covariance over train and test outputs, used to sample initial
    # function values.
    # NOTE(review): the train block is the NNGP kernel K, but the cross and test
    # blocks are NTK kernels (theta_tilde, theta_test). K_tilde / K_test may
    # have been intended. A mixed matrix is not guaranteed to be positive
    # semi-definite -- confirm.
    bigK = numpy.zeros((n + n_test, n + n_test))
    bigK[0:n, 0:n] = K
    bigK[0:n, n:] = theta_tilde.T
    bigK[n:, 0:n] = theta_tilde
    bigK[n:, n:] = theta_test
    init_ntk_f = numpy.random.multivariate_normal(np.zeros(n + n_test), bigK)
    fx_train = init_ntk_f[:n].reshape(m, 10)
    fx_test = init_ntk_f[n:].reshape(m_test, 10)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # `params` is read before each update; read the final state once more so
    # the network evaluated below is not one SGD step behind.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
Exemplo n.º 14
0
     stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),\
     stax.Relu(),\
     stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\
     stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),\
     stax.Relu(),\
     stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\
     stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\
     stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\
     stax.Flatten(),\
     stax.Dense(10, W_std, b_std))
else:
    raise Exception('Invalid Input Error')

apply_fn = jit(apply_fn)
# Argument 2 ('ntk' / 'nngp' selector) is a string, so it must be static.
kernel_fn = jit(kernel_fn, static_argnums=(2, ))
kernel_fn = nt.batch(kernel_fn, batch_size=20)

# Row block `row_id` of the inputs: examples row_id*m .. (row_id+1)*m - 1.
X1 = X[row_id * m:(row_id + 1) * m, :, :, :]
assert X1.shape[:4] == (m, 32, 32, 3)

# Training kernel: NTK between this row block and every column block from
# `row_id` onwards. Blocks left of the diagonal are not computed here.
K = onp.zeros((m, n), dtype=onp.float32)
# `onp.int` was removed in NumPy 1.24; integer division gives the same count.
col_count = n // m
for col_id in range(row_id, col_count):
    t1 = time.time()
    X2 = X[col_id * m:(col_id + 1) * m, :, :, :]
    assert X2.shape[:4] == (m, 32, 32, 3)
    temp = kernel_fn(X1, X2, 'ntk')
    K[:, col_id * m:(col_id + 1) * m] = temp.astype(onp.float32)
Exemplo n.º 15
0
# Optimizer: plain SGD starting from the initial parameters.
opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
state = opt_init(params)
#%%


# Full-batch mean-squared-error loss and its jitted parameter gradient.
def loss(fx, y_hat):
    return 0.5 * np.mean((fx - y_hat) ** 2)


grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

# Linearized model: empirical NTK on train/train and test/train pairs,
# followed by the MSE gradient-descent predictor built from them.
ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
g_dd, g_td = ntk(x_train, None, params), ntk(x_test, x_train, params)
predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
#%%

# Sizes of the flattened (example x output) kernels; the factor 10 is
# presumably the number of network outputs.
m, m_test = FLAGS.train_size, FLAGS.test_size
n, n_test = m * 10, m_test * 10
theta = g_dd.transpose((0, 2, 1, 3)).reshape(n, n)
Exemplo n.º 16
0
from train import validate
import datasets

from matplotlib import pyplot as plt

# Model definition: two Erf hidden layers of width 512 and a 10-way Dense
# readout, every Dense layer using W_std=1.0 and b_std=0.05.
init_fn, apply_fn, kernel_fn = nt.stax.serial(
    *(layer
      for _ in range(2)
      for layer in (nt.stax.Dense(512, 1.0, 0.05), nt.stax.Erf())),
    nt.stax.Dense(10, 1.0, 0.05),
)

apply_fn = jax.jit(apply_fn)
# Evaluate the analytic kernel in batches of 64 examples.
kernel_fn = nt.batch(kernel_fn, batch_size=64)


def kernel_fit(x_tr, y_tr, lam=1e-3):
    """Fit NTK kernel regression on ``(x_tr, y_tr)`` using the module-level ``kernel_fn``.

    Args:
      x_tr: training inputs.
      y_tr: training targets; the regression is fit to ``y_tr - 0.1``.
      lam: regularizer forwarded to the predictor as ``diag_reg``.

    Returns:
      ``model(x_te)``, which evaluates the NTK between ``x_te`` and ``x_tr`` and
      calls the predictor with ``(None, None, -1, <that kernel>)``.
    """
    train_ntk = kernel_fn(x_tr, None, "ntk")
    shifted_targets = y_tr - 0.1
    predict = nt.predict.gradient_descent_mse(train_ntk, shifted_targets, diag_reg=lam)

    def model(x_te):
        cross_ntk = kernel_fn(x_te, x_tr, "ntk")
        return predict(None, None, -1, cross_ntk)

    return model


n_train = 2048
n_test = 128
# Generating dataset
Exemplo n.º 17
0
def main(unused_argv):
    """Compare a finite Erf network with its NTK prediction on a binary MNIST task.

    Labels are replaced by ``argmax(y, 1) % 2`` reshaped to a single column,
    presumably the parity of the digit class. A 2048-wide one-hidden-layer Erf
    network is initialized from a random seed and trained with SGD on MSE for
    ``FLAGS.train_time``. It is then compared with the prediction from its
    empirical NTK at initialization.

    Args:
      unused_argv: ignored; configuration comes from ``FLAGS``.
    """
    import numpy

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Binary targets: parity of the argmax class index, as an (N, 1) column.
    y_train = np.expand_dims(numpy.argmax(y_train, 1) % 2, 1)
    y_test = np.expand_dims(numpy.argmax(y_test, 1) % 2, 1)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # Fresh random seed on every run. `random_integers` is deprecated, so draw
    # one int32-range value with `randint` (upper bound exclusive).
    seed = numpy.random.randint(np.iinfo(np.int32).min, np.iinfo(np.int32).max)
    key = random.PRNGKey(seed)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    # NOTE(review): passing `g_td` as the third positional argument matches the
    # older neural_tangents predict API -- confirm against the installed version.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # `params` is read before each update; read the final state once more so
    # the network evaluated below is not one SGD step behind the analytic
    # prediction at FLAGS.train_time.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)