def main(unused_argv):
    """Train an MNIST network and its linearization, then compare them.

    A Dense-Erf-Dense network and its linearization about the initial
    parameters (``nt.linearize``) are trained on identical minibatches with
    the same momentum optimizer.  The loss of both models is printed roughly
    once per epoch, followed by a train/test summary.

    Args:
        unused_argv: Ignored; present for the app-runner calling convention.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

    # Build the network.
    init_fn, f, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    def loss(fx, y_hat):
        return -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized
    # and full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Derive the epoch length from the data actually loaded instead of a
    # hard-coded 50000, and keep it >= 1 so the modulo below cannot divide
    # by zero when the batch size exceeds the dataset size.
    steps_per_epoch = max(1, x_train.shape[0] // FLAGS.batch_size)

    batches = datasets.minibatch(
        x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)
    for i, (x, y) in enumerate(batches):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            # Reported losses use the parameters from before this step.
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Inside the loop the parameters are read *before* each update, so they
    # lag the optimizer state by one step.  Refresh them so the summary
    # reflects the fully trained models; this also defines `params_lin` when
    # the loop body never ran.
    params = get_params(state)
    params_lin = get_params(state_lin)

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary(
        'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
def main(unused_argv):
    """Compare full-batch SGD training with the analytic NTK prediction.

    A Dense-Erf-Dense network is trained on an MNIST subset with full-batch
    SGD on an MSE loss.  Separately, the empirical NTK at initialization is
    used to build a function-space predictor, which is evaluated at
    ``FLAGS.train_time``.  Train/test summaries of both are printed.

    Args:
        unused_argv: Ignored; present for the app-runner calling convention.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(
        FLAGS.train_size, FLAGS.test_size)

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat) ** 2)

    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # `params` was read before the last update, so it is one step behind the
    # optimizer state.  Refresh it so the network is compared after exactly
    # `train_steps` updates, matching the time given to the predictor.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary(
        'train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary(
        'test', y_test, apply_fn(params, x_test), fx_test, loss)
return np.mean(predicted_class == target_class) init_random_params, predict = stax.serial( Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax) if __name__ == "__main__": rng = random.PRNGKey(0) step_size = 0.001 num_epochs = 10 batch_size = 128 momentum_mass = 0.9 train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): batch_idx = perm[i * batch_size:(i + 1) * batch_size] yield train_images[batch_idx], train_labels[batch_idx] batches = data_stream() opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)
def main(_):
    """Train an MNIST classifier with plain or differentially private SGD.

    Runs ``FLAGS.epochs`` epochs over shuffled minibatches.  After every
    epoch the test loss and accuracy are printed and, when ``FLAGS.dpsgd``
    is set, the privacy loss (epsilon) accumulated so far.

    Args:
        _: Ignored; present for the app-runner calling convention.

    Raises:
        NotImplementedError: If ``FLAGS.microbatches`` is set.
    """
    if FLAGS.microbatches:
        raise NotImplementedError(
            'Microbatches < batch size not currently supported'
        )

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    key = random.PRNGKey(FLAGS.seed)

    def data_stream():
        """Yield (images, labels) minibatches forever, reshuffling each pass."""
        rng = npr.RandomState(FLAGS.seed)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    @jit
    def update(_, i, opt_state, batch):
        """Apply one non-private SGD step; the first argument is unused."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        """Apply one SGD step using the gradient from `private_grad`."""
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        # NOTE(review): FLAGS.batch_size is passed even for the smaller
        # leftover batch; confirm how private_grad uses it.
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size),
            opt_state)

    _, init_params = init_random_params(key, (-1, 28, 28, 1))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print('\nStarting training...')
    for epoch in range(1, FLAGS.epochs + 1):
        start_time = time.time()
        # pylint: disable=no-value-for-parameter
        for _ in range(num_batches):
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(*next(batches), dummy_dim=True))
            else:
                opt_state = update(
                    key, next(itercount), opt_state,
                    shape_as_image(*next(batches)))
        # pylint: enable=no-value-for-parameter
        epoch_time = time.time() - start_time
        print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

        # evaluate test accuracy
        params = get_params(opt_state)
        test_acc = accuracy(params, shape_as_image(test_images, test_labels))
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
            test_loss, 100 * test_acc))

        # determine privacy loss so far
        if FLAGS.dpsgd:
            delta = 1e-5
            # Account for the steps actually taken: each epoch runs
            # `num_batches` updates (including the leftover batch), which
            # exceeds `num_train // batch_size` whenever the batch size does
            # not divide the dataset.  Undercounting steps would under-report
            # epsilon.  The dataset size comes from the loaded data.
            eps = compute_epsilon(epoch * num_batches, num_train, delta)
            print(
                'For delta={:.0e}, the current epsilon is: {:.2f}'.format(
                    delta, eps))
        else:
            print('Trained with vanilla non-private SGD optimizer')
Dense(512), Relu, Dense(28 * 28), ) if __name__ == "__main__": step_size = 0.001 num_epochs = 100 batch_size = 32 nrow, ncol = 10, 10 # sampled image grid size rng = random.PRNGKey(0) test_rng = random.PRNGKey(1) # fixed prng key for evaluation imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png") train_images, _, test_images, _ = datasets.mnist(permute_train=True) num_complete_batches, leftover = divmod(train_images.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) _, init_encoder_params = encoder_init((batch_size, 28 * 28)) _, init_decoder_params = decoder_init((batch_size, 10)) init_params = init_encoder_params, init_decoder_params opt_init, opt_update = minmax.momentum(step_size, mass=0.9) def binarize_batch(rng, i, images): i = i % num_batches batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size) return random.bernoulli(rng, batch) @jit
#%%
class Struct:
    """Minimal attribute container populated from keyword arguments."""

    def __init__(self, **entries):
        vars(self).update(entries)


# Stand-in for command-line flags so the cell can run interactively.
FLAGS = Struct(
    learning_rate=1.0,
    train_size=128,
    test_size=128,
    train_time=10000.0,
)

print('Loading data.')
x_train, y_train, x_test, y_test = datasets.mnist(
    FLAGS.train_size, FLAGS.test_size)

# Build the network: Dense -> Erf -> Dense.
_layers = (
    stax.Dense(2048, 1., 0.05),
    stax.Erf(),
    stax.Dense(10, 1., 0.05),
)
init_fn, apply_fn, _ = stax.serial(*_layers)

# Initialize parameters from a fixed PRNG key.
key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 784))

# Create and initialize an optimizer.
opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
state = opt_init(params)
def main(unused_argv):
    """Compare SGD training with the NTK prediction on MNIST digit parity.

    The labels returned by ``datasets.mnist`` are reduced to a single 0/1
    column (argmax of each label row, modulo 2).  A Dense-Erf-Dense network
    with one output is initialized from a random seed and trained with
    full-batch SGD on an MSE loss.  The empirical NTK at initialization
    drives a function-space predictor evaluated at ``FLAGS.train_time``, and
    train/test summaries of both are printed.

    Args:
        unused_argv: Ignored; present for the app-runner calling convention.
    """
    # Plain NumPy is imported locally; `np` in this file is used for the
    # array ops below.
    import numpy

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(
        FLAGS.train_size, FLAGS.test_size)

    # Collapse the label rows to one parity column: argmax index modulo 2.
    y_train = np.expand_dims(numpy.argmax(y_train, 1) % 2, 1)
    y_test = np.expand_dims(numpy.argmax(y_test, 1) % 2, 1)

    # Build the network (single output unit for the parity target).
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(1, 1., 0.05))

    # Draw a fresh seed uniformly from the full int32 range.  `randint`
    # replaces the deprecated `random_integers`; its upper bound is
    # exclusive, hence the `+ 1`, and int64 keeps that bound representable.
    int32_info = numpy.iinfo(numpy.int32)
    seed = int(numpy.random.randint(
        int32_info.min, int32_info.max + 1, dtype=numpy.int64))
    key = random.PRNGKey(seed)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    def loss(fx, y_hat):
        return 0.5 * np.mean((fx - y_hat) ** 2)

    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # `params` was read before the last update, so it is one step behind the
    # optimizer state.  Refresh it so the network is compared after exactly
    # `train_steps` updates, matching the time given to the predictor.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary(
        'train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary(
        'test', y_test, apply_fn(params, x_test), fx_test, loss)