# Imports assumed by this excerpt; datasets, FLAGS, loss, accuracy,
# init_random_params, and compute_epsilon are provided elsewhere in the module.
import itertools
import time

import numpy.random as npr

import jax.numpy as jnp
from jax import grad, jit, random, vmap
from jax.example_libraries import optimizers
from jax.tree_util import tree_flatten, tree_unflatten


def main(unused_argv):
  layer_sizes = [784, 1024, 1024, 10]  # TODO: revise to a standard architecture
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 32

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train_images = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train_images, batch_size)
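  # count the final partial batch, if any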
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
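    # Infinite generator: reshuffle the training set each epoch and yield
    # minibatches; the fixed seed makes the order reproducible.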
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train_images)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  @jit
  def update(params, batch):
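    # One SGD step: differentiate the loss and move each (w, b) pair downhill.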
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
def main(_):
  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported'
    )

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  num_batches = num_complete_batches + bool(leftover)
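  # root PRNG key: used to initialize parameters and, in the private path,
  # folded per step to derive fresh noise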
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  @jit
  def update(_, i, opt_state, batch):
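    # Non-private SGD step; the unused leading argument keeps the signature
    # aligned with private_update so the training loop can call either.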
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
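    # DP-SGD step: clip each per-example gradient to FLAGS.l2_norm_clip and
    # add Gaussian noise scaled by FLAGS.noise_multiplier before updating.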
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

  _, init_params = init_random_params(key, (-1, 28, 28, 1))  # NHWC input shape
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  steps_per_epoch = 60000 // FLAGS.batch_size  # MNIST: 60,000 training examples
  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state, shape_as_image(*next(batches)))
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch} in {epoch_time:0.2f} sec')

    # evaluate test loss and accuracy
    params = get_params(opt_state)
    test_batch = shape_as_image(test_images, test_labels)
    test_acc = accuracy(params, test_batch)
    test_loss = loss(params, test_batch)
    print(f'Test set loss, accuracy (%): '
          f'({test_loss:.2f}, {100 * test_acc:.2f})')

    # determine privacy loss so far
    if FLAGS.dpsgd:
      delta = 1e-5  # target delta, kept below 1 / num_examples
      num_examples = 60000
      eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
      print(f'For delta={delta:.0e}, the current epsilon is: {eps:.2f}')
    else:
      print('Trained with vanilla non-private SGD optimizer')
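

# The helpers below are assumed to be defined elsewhere in this module; the
# following are minimal sketches of what they might look like, not definitive
# implementations.

# Sketch of shape_as_image, assuming flattened (N, 784) MNIST batches and NHWC
# layout; dummy_dim inserts a singleton axis so that, under vmap, each mapped
# element looks like a batch of one example.
def shape_as_image(images, labels, dummy_dim=False):
  target_shape = (-1, 1, 28, 28, 1) if dummy_dim else (-1, 28, 28, 1)
  return jnp.reshape(images, target_shape), labels


# Sketch of the DP-SGD gradient: compute per-example gradients with vmap, clip
# each to l2_norm_clip in global L2 norm, sum them, add Gaussian noise of scale
# l2_norm_clip * noise_multiplier, and average over the batch.
def clipped_grad(params, l2_norm_clip, single_example_batch):
  # Gradient for a single example, clipped in global L2 norm.
  grads = grad(loss)(params, single_example_batch)
  leaves, tree_def = tree_flatten(grads)
  leaf_norms = jnp.asarray([jnp.linalg.norm(g.ravel()) for g in leaves])
  total_norm = jnp.linalg.norm(leaf_norms)
  divisor = jnp.maximum(total_norm / l2_norm_clip, 1.)
  return tree_unflatten(tree_def, [g / divisor for g in leaves])


def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                 batch_size):
  # Per-example clipped gradients, stacked along a leading batch axis.
  clipped = vmap(clipped_grad, (None, None, 0))(params, l2_norm_clip, batch)
  leaves, tree_def = tree_flatten(clipped)
  aggregated = [g.sum(0) for g in leaves]
  rngs = random.split(rng, len(aggregated))
  noised = [g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
            for r, g in zip(rngs, aggregated)]
  return tree_unflatten(tree_def, [g / batch_size for g in noised])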