Example #1
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
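
Aside, a minimal standalone sketch of the nt.linearize contract used throughout these examples (assuming the same imports the examples rely on: jax.numpy as np, jax.random, neural_tangents as nt and neural_tangents.stax): nt.linearize(f, params_0) returns a function with the same signature as f that evaluates the first-order Taylor expansion of f in its parameters around params_0, so the two agree exactly at params_0.

import jax.numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

# Small fully-connected network with NTK parameterization.
init_fn, f, _ = stax.serial(stax.Dense(32, 1., 0.05), stax.Erf(),
                            stax.Dense(10, 1., 0.05))
key = random.PRNGKey(0)
_, params_0 = init_fn(key, (-1, 784))

# f_lin(params, x) = f(params_0, x) + J(params_0) * (params - params_0),
# with J the Jacobian of f with respect to the parameters.
f_lin = nt.linearize(f, params_0)

x = random.normal(key, (8, 784))
# At the linearization point the original and linearized networks coincide.
assert np.allclose(f(params_0, x), f_lin(params_0, x))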
Example #2
    def testLinearization(self, shape):
        key, params, x0 = self._get_init_data(shape)

        f_lin = nt.linearize(EmpiricalTest.f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    key, split = random.split(key)
                    x = random.normal(split, (shape[-1], 1))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
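
The reference EmpiricalTest.f_lin_exact is not shown above. A hedged sketch of what such a reference plausibly computes (only the call signature is taken from the test; the body below is an assumption): the exact first-order Taylor expansion of EmpiricalTest.f in its first argument, which is what nt.linearize(EmpiricalTest.f, x0) evaluates.

from jax import jvp

def f_lin_exact(x0, x, params, do_alter, do_shift_x=True):
    # f(x0, ...) plus the directional derivative of f at x0 in the direction (x - x0).
    f0, df = jvp(
        lambda w: EmpiricalTest.f(w, params, do_alter, do_shift_x=do_shift_x),
        (x0,), (x - x0,))
    return f0 + df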
Example #3
    print(
        f"Total number of parameters={sum([np.prod(x.shape) for x in params_flat])}"
    )
    sys.exit()

# Load pre-saved model #
if args.load_path is not None:
    params = load_jax_params(params, args.load_path)

# (Optionally) Taylorize the network #
tb_flag = 'f'
params_0 = copy_jax_array(params)
if args.load_path_param0 is not None:
    params_0 = load_jax_params(params_0, args.load_path_param0)
if args.linearize:
    f = nt.linearize(f, params_0)
    tb_flag = 'f_lin'
elif args.expand_order >= 2:
    f = nt.taylor_expand(f, params_0, args.expand_order)
    if args.expand_order == 2:
        tb_flag = 'f_quad'
    elif args.expand_order == 3:
        tb_flag = 'f_cubic'
    else:
        tb_flag = f'f_order_{args.expand_order}'
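
# Note: nt.taylor_expand(f, params_0, k) returns the degree-k Taylor expansion of f in
# its parameters around params_0; degree 1 coincides with nt.linearize(f, params_0).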

# Define optimizer, loss, and accuracy
# Learning rate decay schedule
if args.lr_decay and args.decay_epoch is not None:

    def lr_schedule(epoch):
Example #4
def main(unused_argv):
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            n_train=60000,
                                                            n_test=10000,
                                                            permute_train=True)
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print('Data loaded and reshaped')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Set random seed
    key = random.PRNGKey(0)

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network with NTK parameterization
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print(f'Network of width x{FLAGS.network_width} built.')

    # # Construct the kernel function
    # kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel)
    # print('Kernel constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Compute random initial parameters
    _, params = init_fn(key, (-1, 28, 28, 1))
    params_lin = params

    print('Initial parameters constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # # Save initial parameters
    # with open('init_params.npy', 'wb') as file:
    #     np.save(file, params)

    # Linearize the network about its initial parameters.
    # Use jit for faster GPU computation (only feasible for width < 25)
    f_lin = nt.linearize(f, params)
    if FLAGS.network_width <= 10:
        f_jit = jit(f)
        f_lin_jit = jit(f_lin)
    else:
        f_jit = f
        f_lin_jit = f_lin

    # Create a callable function for dynamic learning rates.
    # Starts at learning_rate and is divided by 10 every learning_decline epochs.
    dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
        (iteration_step //
         (x_train.shape[0] // FLAGS.batch_size)) // FLAGS.learning_decline)

    # Create and initialize an optimizer for both f and f_lin.
    # Use momentum with coefficient 0.9 and jit
    opt_init, opt_apply, get_params = optimizers.momentum(
        dynamic_learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    # Compute the initial states
    state = opt_init(params)
    state_lin = opt_init(params)

    # Define the accuracy function
    accuracy = lambda fx, y_hat: np.mean(
        np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

    # Define mean square error loss function
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)

    # # Create a cross-entropy loss function.
    # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print(
        f'Training with dynamic learning decline after {FLAGS.learning_decline} epochs...'
    )
    print(
        'Epoch\tTime\tTrain Acc.\tLin. Train Acc.\tTrain Loss\tLin. Train Loss\tTest Acc.\tLin. Test Acc.'
    )
    print(
        '----------------------------------------------------------------------------------------------------------'
    )

    # Initialize training
    epoch = 0
    steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

    # Record start times (overall and per 100-epoch block)
    start = time.time()
    start_epoch = time.time()

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        # Update the parameters
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # Print information every 100 epochs
        if (i + 1) % (steps_per_epoch * 100) == 0:
            time_point = time.time() - start_epoch

            # Update epoch
            epoch += 100

            # Accuracy in batches
            f_x = util.output_in_batches(x_train, params, f_jit,
                                         FLAGS.batch_count_accuracy)
            f_x_test = util.output_in_batches(x_test, params, f_jit,
                                              FLAGS.batch_count_accuracy)
            f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                             FLAGS.batch_count_accuracy)
            f_x_lin_test = util.output_in_batches(x_test, params_lin,
                                                  f_lin_jit,
                                                  FLAGS.batch_count_accuracy)
            # time_point = time.time() - start_epoch

            # Print information about the past 100 epochs
            print(
                '{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
                .format(epoch, time_point,
                        accuracy(f_x, y_train) * 100,
                        accuracy(f_x_lin, y_train) * 100, loss(f_x, y_train),
                        loss(f_x_lin, y_train),
                        accuracy(f_x_test, y_test) * 100,
                        accuracy(f_x_lin_test, y_test) * 100))

            # # Save params if epoch is multiple of learning decline or multiple of fixed value
            # if epoch % FLAGS.learning_decline == 0:
            #     filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
            #     with open(filename, 'wb') as file:
            #         np.save(file, params)
            #     filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
            #     with open(filename_lin, 'wb') as file_lin:
            #         np.save(file_lin, params_lin)

            # Reset timer
            start_epoch = time.time()

    duration = time.time() - start
    print(
        '----------------------------------------------------------------------------------------------------------'
    )
    print(f'Training complete in {duration} seconds.')

    # # Save final params in file
    # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy '
    # with open(filename_final, 'wb') as final:
    #     np.save(final, params)
    # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
    # with open(filename_final_lin, 'wb') as final_lin:
    #     np.save(final_lin, params_lin)

    # Compute output in batches
    f_x = util.output_in_batches(x_train, params, f_jit,
                                 FLAGS.batch_count_accuracy)
    f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                     FLAGS.batch_count_accuracy)

    f_x_test = util.output_in_batches(x_test, params, f_jit,
                                      FLAGS.batch_count_accuracy)
    f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                          FLAGS.batch_count_accuracy)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f_x, f_x_lin, loss)
    util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
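
The helper util.output_in_batches is project-specific and not shown here. A hedged sketch of what it plausibly does, inferred only from its call sites above (the name, signature, and body are assumptions): evaluate the network on a large array in a fixed number of chunks so the full training or test set never passes through the network at once.

import jax.numpy as np

def output_in_batches(x, params, apply_fn, batch_count):
    # Split x into batch_count roughly equal chunks, run the network on each
    # chunk, and stitch the outputs back together.
    batches = np.array_split(x, batch_count)
    return np.concatenate([apply_fn(params, xb) for xb in batches], axis=0)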
Example #5
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # 2 denotes the number of classes
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # (-1, 135): 135 is the feature length, here 9 * 15 = 135
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Mini-batch size
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)
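
A hypothetical invocation of weight_space (all names, sizes, and shapes below are assumptions inferred from the function body): embeddings with feature length 135 and one-hot labels over the two classes.

import numpy as onp

n_train, n_test = 1000, 200
train_embedding = onp.random.randn(n_train, 135).astype('float32')
test_embedding = onp.random.randn(n_test, 135).astype('float32')
data_set = {
    'Y_train': onp.eye(2)[onp.random.randint(2, size=n_train)],
    'Y_test': onp.eye(2)[onp.random.randint(2, size=n_test)],
}
weight_space(train_embedding, test_embedding, data_set)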