def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(
      datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                         FLAGS.train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                        loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', y_test, f(params, x_test),
                     f_lin(params_lin, x_test), loss)
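# The example above references FLAGS and a script entry point that are not
# shown. Below is a minimal, assumed sketch of that boilerplate (absl-style
# flags; the flag names match the usages above, but the default values are
# illustrative, and `datasets`/`util` are the example's own helper modules,
# not shown here).
from absl import app
from absl import flags
from jax import grad, jit, random
from jax.example_libraries import optimizers
from jax.nn import log_softmax as logsoftmax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

flags.DEFINE_float('learning_rate', 1.0, 'Learning rate for momentum SGD.')
flags.DEFINE_integer('batch_size', 256, 'Minibatch size.')
flags.DEFINE_integer('train_epochs', 10,
                     'Number of passes over the training data.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)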
def testLinearization(self, shape):
  key, params, x0 = self._get_init_data(shape)

  f_lin = nt.linearize(EmpiricalTest.f, x0)

  for _ in range(TAYLOR_RANDOM_SAMPLES):
    for do_alter in [True, False]:
      for do_shift_x in [True, False]:
        key, split = random.split(key)
        x = random.normal(split, (shape[-1], 1))
        self.assertAllClose(
            EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                      do_shift_x=do_shift_x),
            f_lin(x, params, do_alter, do_shift_x=do_shift_x))
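# For reference: nt.linearize(f, x0) returns the first-order Taylor expansion
# of f in its first argument around x0, i.e.
#   f_lin(x, ...) = f(x0, ...) + (df/dx at x0) applied to (x - x0),
# which the test compares against an analytically derived f_lin_exact. The
# sketch below is an assumed, illustrative way to compute the same quantity
# with a single jvp; it is not the library's implementation.
import jax

def reference_linearize(f, x0):
  def f_lin(x, *args, **kwargs):
    # Tangent direction: the displacement from the linearization point.
    dx = jax.tree_util.tree_map(lambda a, b: a - b, x, x0)
    # jvp returns f(x0, ...) and the directional derivative along dx.
    fx0, dfx = jax.jvp(lambda p: f(p, *args, **kwargs), (x0,), (dx,))
    return fx0 + dfx
  return f_lin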
print(
    f"Total number of parameters={sum([np.prod(x.shape) for x in params_flat])}"
)
sys.exit()  # Stops the script after reporting the parameter count.

# Load pre-saved model #
if args.load_path is not None:
  params = load_jax_params(params, args.load_path)

# (Optionally) Taylorize the network #
tb_flag = 'f'
params_0 = copy_jax_array(params)
if args.load_path_param0 is not None:
  params_0 = load_jax_params(params_0, args.load_path_param0)

if args.linearize:
  f = nt.linearize(f, params_0)
  tb_flag = 'f_lin'
elif args.expand_order >= 2:
  f = nt.taylor_expand(f, params_0, args.expand_order)
  if args.expand_order == 2:
    tb_flag = 'f_quad'
  elif args.expand_order == 3:
    tb_flag = 'f_cubic'
  else:
    tb_flag = f'f_order_{args.expand_order}'

# Define optimizer, loss, and accuracy #
# Learning rate decay schedule
if args.lr_decay and args.decay_epoch is not None:
  def lr_schedule(epoch):
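# nt.taylor_expand(f, params_0, order) used above returns the order-`order`
# Taylor expansion of f in its parameters around params_0. As a rough, assumed
# reference (not the library's implementation), the second-order case can be
# written with nested jvp calls:
import jax

def taylor_expand_order2(f, params0):
  def f_quad(params, *args, **kwargs):
    g = lambda p: f(p, *args, **kwargs)
    dp = jax.tree_util.tree_map(lambda a, b: a - b, params, params0)
    # Zeroth- and first-order terms: f(p0) and J(p0) . dp.
    f0, df = jax.jvp(g, (params0,), (dp,))
    # Second-order term: the directional second derivative dp^T H(p0) dp,
    # obtained by differentiating the directional derivative once more.
    _, d2f = jax.jvp(lambda p: jax.jvp(g, (p,), (dp,))[1], (params0,), (dp,))
    return f0 + df + 0.5 * d2f
  return f_quad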
def main(unused_argv):
  # print(f'Available GPU memory: {util.get_gpu_memory()}')

  # Load and normalize data
  print('Loading data...')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          n_train=60000,
                                                          n_test=10000,
                                                          permute_train=True)
  # print(f'Available GPU memory: {util.get_gpu_memory()}')

  # Reformat MNIST data to 28x28x1 pictures
  x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
  x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
  print('Data loaded and reshaped')
  # print(f'Available GPU memory: {util.get_gpu_memory()}')

  # Set random seed
  key = random.PRNGKey(0)

  # # Add random translation to images
  # x_train = util.add_translation(x_train, FLAGS.max_pixel)
  # x_test = util.add_translation(x_test, FLAGS.max_pixel)
  # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

  # # Add random translations with padding
  # x_train = util.add_padded_translation(x_train, 10)
  # x_test = util.add_padded_translation(x_test, 10)
  # print(f'Random translations with additional padding up to 10 pixels added')

  # Build the LeNet network with NTK parameterization
  init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
  print(f'Network of width x{FLAGS.network_width} built.')

  # # Construct the kernel function
  # kernel_fn = nt.batch(kernel_fn, device_count=-1,
  #                      batch_size=FLAGS.batch_size_kernel)
  # print('Kernel constructed')
  # print(f'Available GPU memory: {util.get_gpu_memory()}')

  # Compute random initial parameters
  _, params = init_fn(key, (-1, 28, 28, 1))
  params_lin = params
  print('Initial parameters constructed')
  # print(f'Available GPU memory: {util.get_gpu_memory()}')

  # # Save initial parameters
  # with open('init_params.npy', 'wb') as file:
  #   np.save(file, params)

  # Linearize the network about its initial parameters.
  # Use jit for faster GPU computation (only feasible for small widths).
  f_lin = nt.linearize(f, params)
  if FLAGS.network_width <= 10:
    f_jit = jit(f)
    f_lin_jit = jit(f_lin)
  else:
    f_jit = f
    f_lin_jit = f_lin

  # Create a callable function for dynamic learning rates.
  # Starts at learning_rate and is divided by 10 after every learning_decline
  # epochs.
  dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
      (iteration_step // (x_train.shape[0] // FLAGS.batch_size)) //
      FLAGS.learning_decline)

  # Create and initialize an optimizer for both f and f_lin.
  # Use momentum with coefficient 0.9 and jit.
  opt_init, opt_apply, get_params = optimizers.momentum(dynamic_learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  # Compute the initial states
  state = opt_init(params)
  state_lin = opt_init(params)

  # Define the accuracy function
  accuracy = lambda fx, y_hat: np.mean(
      np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

  # Define mean squared error loss function
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  # # Create a cross-entropy loss function.
  # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print(f'Training with learning rate decline after every '
        f'{FLAGS.learning_decline} epochs...')
  print('Epoch\tTime\tAccuracy\tLin. Accuracy\tLoss\tLin. Loss\t'
        'Accuracy Test\tLin. Accuracy Test')
  print('----------------------------------------------------------------------------------------------------------')

  # Initialize training
  epoch = 0
  steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

  # Set start time (total and per 100 epochs)
  start = time.time()
  start_epoch = time.time()

  for i, (x, y) in enumerate(
      datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                         FLAGS.train_epochs)):
    # Update the parameters
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    # Print information after every 100 epochs
    if (i + 1) % (steps_per_epoch * 100) == 0:
      time_point = time.time() - start_epoch

      # Update epoch
      epoch += 100

      # Compute outputs in batches to bound memory use
      f_x = util.output_in_batches(x_train, params, f_jit,
                                   FLAGS.batch_count_accuracy)
      f_x_test = util.output_in_batches(x_test, params, f_jit,
                                        FLAGS.batch_count_accuracy)
      f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                       FLAGS.batch_count_accuracy)
      f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                            FLAGS.batch_count_accuracy)

      # Print information about the past 100 epochs
      print('{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
            .format(epoch, time_point,
                    accuracy(f_x, y_train) * 100,
                    accuracy(f_x_lin, y_train) * 100,
                    loss(f_x, y_train), loss(f_x_lin, y_train),
                    accuracy(f_x_test, y_test) * 100,
                    accuracy(f_x_lin_test, y_test) * 100))

      # # Save params if epoch is a multiple of learning_decline
      # if epoch % FLAGS.learning_decline == 0:
      #   filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
      #   with open(filename, 'wb') as file:
      #     np.save(file, params)
      #   filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
      #   with open(filename_lin, 'wb') as file_lin:
      #     np.save(file_lin, params_lin)

      # Reset timer
      start_epoch = time.time()

  duration = time.time() - start
  print('----------------------------------------------------------------------------------------------------------')
  print(f'Training complete in {duration} seconds.')

  # # Save final params in file
  # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy'
  # with open(filename_final, 'wb') as final:
  #   np.save(final, params)
  # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
  # with open(filename_final_lin, 'wb') as final_lin:
  #   np.save(final_lin, params_lin)

  # Compute final outputs in batches
  f_x = util.output_in_batches(x_train, params, f_jit,
                               FLAGS.batch_count_accuracy)
  f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                   FLAGS.batch_count_accuracy)
  f_x_test = util.output_in_batches(x_test, params, f_jit,
                                    FLAGS.batch_count_accuracy)
  f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                        FLAGS.batch_count_accuracy)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f_x, f_x_lin, loss)
  util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
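# util.build_le_net and util.output_in_batches above are project-local helpers
# that are not shown. As an assumption about its behavior (matching the call
# signature used above), output_in_batches presumably evaluates the network
# chunk by chunk so the full dataset never has to sit on the device at once;
# a plausible sketch:
import jax.numpy as jnp

def output_in_batches(x, params, apply_fn, batch_count):
  # Split the inputs into batch_count chunks and concatenate the outputs.
  chunks = jnp.array_split(x, batch_count)
  return jnp.concatenate([apply_fn(params, chunk) for chunk in chunks])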
def weight_space(train_embedding, test_embedding, data_set):
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      # 2 output units, one per class.
      stax.Dense(2, 1., 0.05))

  key = random.PRNGKey(0)
  # Input shape (-1, 135): 135 is the feature length, here 9 * 15 = 135.
  _, params = init_fn(key, (-1, 135))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  # Mini-batch training settings.
  batch_size = 64
  train_epochs = 10
  steps_per_epoch = 100

  for i, (x, y) in enumerate(
      datasets.mini_batch(train_embedding, data_set['Y_train'], batch_size,
                          train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                        loss(f_lin(params_lin, x), y)))
      epoch += 1

    if i / steps_per_epoch == train_epochs:
      break

  # Print out summary data comparing the linear / nonlinear model.
  x, y = train_embedding[:10000], data_set['Y_train'][:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                     f_lin(params_lin, test_embedding), loss)
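# datasets.mini_batch above is a project-local generator that is not shown.
# A minimal sketch of the assumed behavior (the name, shuffling strategy, and
# epoch handling are assumptions, not the project's actual implementation):
# yield shuffled (x, y) minibatches for the requested number of epochs.
import numpy as onp

def mini_batch(x, y, batch_size, train_epochs):
  n = x.shape[0]
  for _ in range(train_epochs):
    # Reshuffle the data once per epoch.
    perm = onp.random.permutation(n)
    for start in range(0, n, batch_size):
      idx = perm[start:start + batch_size]
      yield x[idx], y[idx]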