def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
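  # DP-SGD clips and noises gradients per microbatch, so the batch size must
  # be divisible by the number of microbatches.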
  if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
    raise ValueError('Number of microbatches should evenly divide batch_size')

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Define a sequential Keras model
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(16, 8,
                             strides=2,
                             padding='same',
                             activation='relu',
                             input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Conv2D(32, 4,
                             strides=2,
                             padding='valid',
                             activation='relu'),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  if FLAGS.dpsgd:
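    # The Gaussian average query clips each microbatch gradient to
    # l2_norm_clip, adds Gaussian noise with std dev
    # l2_norm_clip * noise_multiplier, and averages over the microbatches.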
    dp_average_query = GaussianAverageQuery(
        FLAGS.l2_norm_clip,
        FLAGS.l2_norm_clip * FLAGS.noise_multiplier,
        FLAGS.microbatches)
    optimizer = DPGradientDescentOptimizer(
        dp_average_query,
        FLAGS.microbatches,
        learning_rate=FLAGS.learning_rate,
        unroll_microbatches=True)
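    # unroll_microbatches=True runs the per-microbatch loop in Python rather
    # than a tf.while_loop, which sidesteps tf.while_loop limitations in some
    # graph constructions.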
    # Compute vector of per-example loss rather than its mean over a minibatch.
    loss = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE)
  else:
    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  # Compile model with Keras
  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

  # Train model with Keras
  model.fit(train_data, train_labels,
            epochs=FLAGS.epochs,
            validation_data=(test_data, test_labels),
            batch_size=FLAGS.batch_size)

  # Compute the privacy budget expended.
  if FLAGS.dpsgd:
    eps = compute_epsilon(FLAGS.epochs * 60000 // FLAGS.batch_size)
    print('For delta=1e-5, the current epsilon is: %.2f' % eps)
  else:
    print('Trained with vanilla non-private SGD optimizer')
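This example (and Example 2 below) calls a compute_epsilon helper that is not shown in this excerpt. A minimal sketch of what it plausibly looks like, assuming the RDP accountant shipped with tensorflow_privacy (compute_rdp and get_privacy_spent) and the same FLAGS, delta=1e-5, and 60,000-example training set referenced in the code:

def compute_epsilon(steps):
  """Estimates the epsilon of the (epsilon, 1e-5)-DP guarantee after `steps` DP-SGD steps."""
  from tensorflow_privacy.privacy.analysis.rdp_accountant import (
      compute_rdp, get_privacy_spent)
  if FLAGS.noise_multiplier == 0.0:
    return float('inf')  # No noise means no meaningful guarantee.
  # A standard grid of Renyi orders to search over.
  orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
  # Subsampling probability: batch size over training-set size.
  sampling_probability = FLAGS.batch_size / 60000
  rdp = compute_rdp(q=sampling_probability,
                    noise_multiplier=FLAGS.noise_multiplier,
                    steps=steps,
                    orders=orders)
  # Convert the RDP curve to an epsilon at the target delta.
  return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]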
Example 2
def main(_):
  # Fetch the mnist data
  train, test = tf.keras.datasets.mnist.load_data()
  train_images, train_labels = train
  test_images, test_labels = test

  # Create a dataset object and batch for the training data
  dataset = tf.data.Dataset.from_tensor_slices(
      (tf.cast(train_images[..., tf.newaxis]/255, tf.float32),
       tf.cast(train_labels, tf.int64)))
  dataset = dataset.shuffle(1000).batch(FLAGS.batch_size)

  # Create a dataset object and batch for the test data
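  # (a single batch of 10,000 covers the entire MNIST test set)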
  eval_dataset = tf.data.Dataset.from_tensor_slices(
      (tf.cast(test_images[..., tf.newaxis]/255, tf.float32),
       tf.cast(test_labels, tf.int64)))
  eval_dataset = eval_dataset.batch(10000)

  # Define the model using tf.keras.layers
  mnist_model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(16, 8,
                             strides=2,
                             padding='same',
                             activation='relu'),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Conv2D(32, 4, strides=2, activation='relu'),
      tf.keras.layers.MaxPool2D(2, 1),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  # Instantiate the optimizer
  if FLAGS.dpsgd:
    dp_average_query = GaussianAverageQuery(
        FLAGS.l2_norm_clip,
        FLAGS.l2_norm_clip * FLAGS.noise_multiplier,
        FLAGS.microbatches)
    opt = DPGradientDescentOptimizer(
        dp_average_query,
        FLAGS.microbatches,
        learning_rate=FLAGS.learning_rate)
  else:
    opt = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)

  # Training loop.
  steps_per_epoch = 60000 // FLAGS.batch_size  # 60,000 MNIST training examples

  # Track wall-clock time per epoch.
  from datetime import datetime
  start_times = []
  end_times = []

  for epoch in range(FLAGS.epochs):

    start_times.append(datetime.now())

    # Train the model for one epoch.
    for images, labels in dataset:
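      # persistent=True is required: the DP optimizer re-enters the tape once
      # per microbatch to compute per-microbatch gradients.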
      with tf.GradientTape(persistent=True) as gradient_tape:
        # This dummy call is needed to obtain the var list.
        logits = mnist_model(images, training=True)
        var_list = mnist_model.trainable_variables

        # In Eager mode, the optimizer takes a function that returns the loss.
        def loss_fn():
          logits = mnist_model(images, training=True)  # pylint: disable=undefined-loop-variable,cell-var-from-loop
          loss = tf.losses.sparse_softmax_cross_entropy(
              labels, logits, reduction=tf.losses.Reduction.NONE)  # pylint: disable=undefined-loop-variable,cell-var-from-loop
          # If training without privacy, the loss is a scalar not a vector.
          if not FLAGS.dpsgd:
            loss = tf.reduce_mean(loss)
          return loss

        if FLAGS.dpsgd:
          grads_and_vars = opt.compute_gradients(loss_fn, var_list,
                                                 gradient_tape=gradient_tape)
        else:
          grads_and_vars = opt.compute_gradients(loss_fn, var_list)

      global_step = tf.train.get_or_create_global_step()
      opt.apply_gradients(grads_and_vars, global_step=global_step)

    end_times.append(datetime.now())
    print('Trained epoch %d in %.2f seconds' %
          (epoch + 1, (end_times[-1] - start_times[-1]).total_seconds()))
    
    # Evaluate the model and print results. The single 10,000-example batch
    # created above covers the whole test set.
    for images, labels in eval_dataset:
      logits = mnist_model(images, training=False)
      correct_preds = tf.equal(tf.argmax(logits, axis=1), labels)
    test_accuracy = np.mean(correct_preds.numpy())
    print('Test accuracy after epoch %d is: %.3f' % (epoch + 1, test_accuracy))

    # Compute the privacy budget expended so far.
    if FLAGS.dpsgd:
      eps = compute_epsilon((epoch + 1) * steps_per_epoch)
      print('For delta=1e-5, the current epsilon is: %.2f' % eps)
    else:
      print('Trained with vanilla non-private SGD optimizer')

  print('timing starts:')
  print(start_times)
  print('timing ends:')
  print(end_times)
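Both examples also rely on module-level imports and flag definitions that are not shown (numpy as np, the GaussianAverageQuery and DPGradientDescentOptimizer classes from TensorFlow Privacy, and absl flags). A minimal sketch of the flag setup and entry point, with illustrative defaults rather than the originals':

from absl import app, flags

flags.DEFINE_boolean('dpsgd', True, 'Train with DP-SGD if True, plain SGD otherwise')
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1, 'Ratio of noise std dev to l2_norm_clip')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Per-microbatch gradient clipping norm')
flags.DEFINE_integer('batch_size', 250, 'Batch size (must be a multiple of microbatches)')
flags.DEFINE_integer('epochs', 15, 'Number of training epochs')
flags.DEFINE_integer('microbatches', 250, 'Number of microbatches into which each batch is split')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)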