def main(unused_argv):
  """Train a fully-connected MNIST classifier with hand-rolled minibatch SGD."""
  # TODO: Revise to standard arch
  layer_sizes = [784, 1024, 1024, 10]
  param_scale, step_size = 0.1, 0.001
  num_epochs, batch_size = 10, 32

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train_images = train_images.shape[0]
  full_batches, leftover = divmod(num_train_images, batch_size)
  # A trailing partial batch still counts as one batch.
  num_batches = full_batches + bool(leftover)

  def data_stream():
    # Deterministic shuffling: fixed seed, fresh permutation each pass.
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train_images)
      for start in range(0, num_batches * batch_size, batch_size):
        idx = perm[start:start + batch_size]
        yield train_images[idx], train_labels[idx]

  batches = data_stream()

  @jit
  def update(params, batch):
    # One plain SGD step over every (weight, bias) layer pair.
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
def main(_):
  """Train a conv MNIST classifier, optionally with DP-SGD.

  Configuration comes entirely from FLAGS: batch_size, learning_rate,
  epochs, seed, dpsgd, l2_norm_clip, noise_multiplier, microbatches.
  When FLAGS.dpsgd is set, per-example gradients are clipped and noised
  via `private_grad` and the cumulative (epsilon, delta) privacy loss is
  reported each epoch via `compute_epsilon` (both defined elsewhere in
  this file).
  """
  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported'
    )
  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  # A trailing partial batch still counts as one batch.
  num_batches = num_complete_batches + bool(leftover)
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    # Deterministic shuffling: fixed seed, fresh permutation each pass.
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  @jit
  def update(_, i, opt_state, batch):
    # Plain (non-private) SGD step; leading arg kept so the two update
    # functions share a call signature.
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    # DP-SGD step: clipped, noised gradients from private_grad.
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size),
        opt_state)

  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  # Fix: use the actual dataset size instead of a hard-coded 60000 so the
  # privacy accounting stays correct if the training set size ever changes.
  # (For standard MNIST num_train == 60000, so behavior is unchanged.)
  steps_per_epoch = num_train // FLAGS.batch_size

  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches)))
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch} in {epoch_time:0.2f} sec')

    # evaluate test accuracy
    params = get_params(opt_state)
    # Reshape the test set once per epoch instead of twice.
    test_batch = shape_as_image(test_images, test_labels)
    test_acc = accuracy(params, test_batch)
    test_loss = loss(params, test_batch)
    print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
        test_loss, 100 * test_acc))

    # determine privacy loss so far
    if FLAGS.dpsgd:
      delta = 1e-5
      eps = compute_epsilon(epoch * steps_per_epoch, num_train, delta)
      print(f'For delta={delta:.0e}, the current epsilon is: {eps:.2f}')
    else:
      print('Trained with vanilla non-private SGD optimizer')