def mnist_tutorial(train_start=0,
                   train_end=60000,
                   test_start=0,
                   test_end=10000,
                   nb_epochs=NB_EPOCHS,
                   batch_size=BATCH_SIZE,
                   learning_rate=LEARNING_RATE,
                   clean_train=CLEAN_TRAIN,
                   testing=False,
                   backprop_through_attack=BACKPROP_THROUGH_ATTACK,
                   nb_filters=NB_FILTERS,
                   num_threads=None,
                   label_smoothing=0.1):
    """
  MNIST cleverhans tutorial
  :param train_start: index of first training set example
  :param train_end: index of last training set example
  :param test_start: index of first test set example
  :param test_end: index of last test set example
  :param nb_epochs: number of epochs to train model
  :param batch_size: size of training batches
  :param learning_rate: learning rate for training
  :param clean_train: perform normal training on clean examples only
                      before performing adversarial training.
  :param testing: if true, complete an AccuracyReport for unit tests
                  to verify that performance is adequate
  :param backprop_through_attack: If True, backprop through adversarial
                                  example construction process during
                                  adversarial training.
  :param label_smoothing: float, amount of label smoothing for cross entropy
  :return: an AccuracyReport object
  """

    # Object used to keep track of (and return) key accuracies
    report = AccuracyReport()

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(1234)

    # Set logging level to see debug information
    set_log_level(logging.DEBUG)

    # Create TF session
    if num_threads:
        config_args = dict(intra_op_parallelism_threads=num_threads)
    else:
        config_args = {}
    sess = tf.Session(config=tf.ConfigProto(**config_args))

    # Get MNIST training and test data
    mnist = MNIST(train_start=train_start,
                  train_end=train_end,
                  test_start=test_start,
                  test_end=test_end)
    x_train, y_train = mnist.get_set('train')
    x_test, y_test = mnist.get_set('test')

    # Use Image Parameters
    img_rows, img_cols, nchannels = x_train.shape[1:4]
    nb_classes = y_train.shape[1]

    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
    y = tf.placeholder(tf.float32, shape=(None, nb_classes))

    # Train an MNIST model
    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate
    }
    eval_params = {'batch_size': batch_size}
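    # FGSM perturbation budget: eps=0.3 on [0, 1]-scaled MNIST pixels, with
    # clipping to keep adversarial examples in the valid image range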
    fgsm_params = {'eps': 0.3, 'clip_min': 0., 'clip_max': 1.}
    rng = np.random.RandomState([2017, 8, 30])

    def do_eval(preds, x_set, y_set, report_key, is_adv=None):
        acc = model_eval(sess, x, y, preds, x_set, y_set, args=eval_params)
        setattr(report, report_key, acc)
        if is_adv is None:
            report_text = None
        elif is_adv:
            report_text = 'adversarial'
        else:
            report_text = 'legitimate'
        if report_text:
            print('Test accuracy on %s examples: %0.4f' % (report_text, acc))

    if clean_train:
        model = make_basic_picklable_cnn()
        # Tag the model so that when it is saved to disk, future scripts will
        # be able to tell what data it was trained on
        model.dataset_factory = mnist.get_factory()
        preds = model.get_logits(x)
        assert len(model.get_params()) > 0
        loss = CrossEntropy(model, smoothing=label_smoothing)

        def evaluate():
            do_eval(preds, x_test, y_test, 'clean_train_clean_eval', False)

        train(sess,
              loss,
              x_train,
              y_train,
              evaluate=evaluate,
              args=train_params,
              rng=rng,
              var_list=model.get_params())

        with sess.as_default():
            save("clean_model.joblib", model)

            print("Now that the model has been saved, you can evaluate it in a"
                  " separate process using `evaluate_pickled_model.py`. "
                  "You should get exactly the same result for both clean and "
                  "adversarial accuracy as you get within this program.")

        # Calculate training error
        if testing:
            do_eval(preds, x_train, y_train, 'train_clean_train_clean_eval')

        # Initialize the Fast Gradient Sign Method (FGSM) attack object and
        # graph
        fgsm = FastGradientMethod(model, sess=sess)
        adv_x = fgsm.generate(x, **fgsm_params)
        preds_adv = model.get_logits(adv_x)

        # Evaluate the accuracy of the MNIST model on adversarial examples
        do_eval(preds_adv, x_test, y_test, 'clean_train_adv_eval', True)

        # Calculate training error
        if testing:
            do_eval(preds_adv, x_train, y_train, 'train_clean_train_adv_eval')

        print('Repeating the process, using adversarial training')

    # Create a new model and train it to be robust to FastGradientMethod
    model2 = make_basic_picklable_cnn()
    # Tag the model so that when it is saved to disk, future scripts will
    # be able to tell what data it was trained on
    model2.dataset_factory = mnist.get_factory()
    fgsm2 = FastGradientMethod(model2, sess=sess)

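    # Wrapping the attack in a closure lets the CrossEntropy loss regenerate
    # adversarial examples against the model's current parameters at every
    # training step.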
    def attack(x):
        return fgsm2.generate(x, **fgsm_params)

    loss2 = CrossEntropy(model2, smoothing=label_smoothing, attack=attack)
    preds2 = model2.get_logits(x)
    adv_x2 = attack(x)

    if not backprop_through_attack:
        # For the fgsm attack used in this tutorial, the attack has zero
        # gradient so enabling this flag does not change the gradient.
        # For some other attacks, enabling this flag increases the cost of
        # training, but gives the defender the ability to anticipate how
        # the attacker will change their strategy in response to updates to
        # the defender's parameters.
        adv_x2 = tf.stop_gradient(adv_x2)
    preds2_adv = model2.get_logits(adv_x2)

    def evaluate2():
        # Accuracy of adversarially trained model on legitimate test inputs
        do_eval(preds2, x_test, y_test, 'adv_train_clean_eval', False)
        # Accuracy of the adversarially trained model on adversarial examples
        do_eval(preds2_adv, x_test, y_test, 'adv_train_adv_eval', True)

    # Perform and evaluate adversarial training
    train(sess,
          loss2,
          x_train,
          y_train,
          evaluate=evaluate2,
          args=train_params,
          rng=rng,
          var_list=model2.get_params())

    with sess.as_default():
        save("adv_model.joblib", model2)
        print(
            "Now that the model has been saved, you can evaluate it in a "
            "separate process using "
            "`python evaluate_pickled_model.py adv_model.joblib`. "
            "You should get exactly the same result for both clean and "
            "adversarial accuracy as you get within this program."
            " You can also move beyond the tutorials directory and run the "
            " real `compute_accuracy.py` script (make sure cleverhans/scripts "
            "is in your PATH) to see that this FGSM-trained "
            "model is actually not very robust---it's just a model that trains "
            " quickly so the tutorial does not take a long time")

    # Calculate training errors
    if testing:
        do_eval(preds2, x_train, y_train, 'train_adv_train_clean_eval')
        do_eval(preds2_adv, x_train, y_train, 'train_adv_train_adv_eval')

    return report
Example #2
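
# NOTE: the evaluate() callback below tracks progress through module-level
# globals. They are initialized elsewhere in the original script; a plausible
# initialization (an assumption, not part of the original example) is:
epoch = 0
last_test_print = time.time()
last_train_print = time.time()
best_result = 0.
best_epoch = -1
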
def do_train(train_start=TRAIN_START,
             train_end=60000,
             test_start=0,
             test_end=10000,
             nb_epochs=NB_EPOCHS,
             batch_size=BATCH_SIZE,
             learning_rate=LEARNING_RATE,
             backprop_through_attack=False,
             nb_filters=NB_FILTERS,
             num_threads=None,
             use_ema=USE_EMA,
             ema_decay=EMA_DECAY):
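    """
    Adversarially train an MNIST model with PGD and save it to FLAGS.save_path.
    Parameter semantics mirror mnist_tutorial above; use_ema and ema_decay
    control an exponential moving average of the weights during training.
    """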
    print('Parameters')
    print('-' * 79)
    for name, value in sorted(locals().items()):
        print('%-32s %s' % (name, value))
    print('-' * 79)

    if os.path.exists(FLAGS.save_path):
        print("Model " + FLAGS.save_path +
              " already exists. Refusing to overwrite.")
        quit()

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(1234)

    # Create TF session
    if num_threads:
        config_args = dict(intra_op_parallelism_threads=num_threads)
    else:
        config_args = {}
    sess = tf.Session(config=tf.ConfigProto(**config_args))

    dataset = MNIST(train_start=train_start,
                    train_end=train_end,
                    test_start=test_start,
                    test_end=test_end,
                    center=True)

    # Use Image Parameters
    img_rows, img_cols, nchannels = dataset.x_train.shape[1:4]
    nb_classes = dataset.NB_CLASSES

    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
    y = tf.placeholder(tf.float32, shape=(None, nb_classes))

    train_params = {
        'nb_epochs': nb_epochs,
        'learning_rate': learning_rate,
        'batch_size': batch_size,
    }
    eval_params = {'batch_size': batch_size}
    rng = np.random.RandomState([2017, 8, 30])

    def do_eval(x_set, y_set, is_adv=None):
        acc = accuracy(sess, model, x_set, y_set)
        if is_adv is None:
            report_text = None
        elif is_adv:
            report_text = 'adversarial'
        else:
            report_text = 'clean'
        if report_text:
            print('Accuracy on %s examples: %0.4f' % (report_text, acc))
        return acc

    model = Model(filters=nb_filters)
    model.dataset_factory = dataset.get_factory()

    pgd = ProjectedGradientDescent(model=model, sess=sess)

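    # The dataset was loaded with center=True, so pixel values lie in [-1, 1]
    # rather than [0, 1]. Scale the conventional 8/255 epsilon (and per-step
    # size) by that value range, and clip to the centered domain.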
    center = dataset.kwargs['center']
    value_range = 1. + center
    base_eps = 8. / 255.

    attack_params = {
        'eps': base_eps * value_range,
        'clip_min': -float(center),
        'clip_max': float(center),
        'eps_iter': (2. / 255.) * value_range,
        'nb_iter': 40
    }

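    # Passing the attack into the loss turns this into adversarial training;
    # with adv_coeff=1. the loss is computed entirely on adversarial examples.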
    loss = CrossEntropy(
        model,
        attack=pgd,
        adv_coeff=1.,
        attack_params=attack_params,
    )

    print_test_period = 10
    print_train_period = 50

    def evaluate():
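        # Bookkeeping state shared across calls to evaluate(); initialized at
        # module scope (see the note above do_train).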
        global epoch
        global last_test_print
        global last_train_print
        global best_result
        global best_epoch
        with sess.as_default():
            print("Saving to ", FLAGS.save_path)
            save(FLAGS.save_path, model)
        if (epoch % print_test_period == 0
                or time.time() - last_test_print > 300):
            t1 = time.time()
            result = do_eval(dataset.x_test, dataset.y_test, False)
            t2 = time.time()
            if result >= best_result:
                if result > best_result:
                    best_epoch = epoch
                else:
                    # Keep track of ties
                    assert result == best_result
                    if not isinstance(best_epoch, list):
                        if best_epoch == -1:
                            best_epoch = []
                        else:
                            best_epoch = [best_epoch]
                    best_epoch.append(epoch)
                best_result = result
            print("Best so far: ", best_result)
            print("Best epoch: ", best_epoch)
            last_test_print = t2
            print("Test eval time: ", t2 - t1)
        if (epoch % print_train_period == 0
                or time.time() - last_train_print > 3000):
            t1 = time.time()
            print("Training set: ")
            do_eval(dataset.x_train, dataset.y_train, False)
            t2 = time.time()
            print("Train eval time: ", t2 - t1)
            last_train_print = t2
        epoch += 1

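    # No explicit optimizer is supplied; train() falls back to its default
    # (Adam with the learning rate from train_params).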
    optimizer = None

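    # ema_decay is passed in as the *name* of a decay-schedule function defined
    # at module scope; resolve the name to the callable before training.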
    ema_decay = globals()[ema_decay]
    assert callable(ema_decay)

    train(sess,
          loss,
          dataset.x_train,
          dataset.y_train,
          evaluate=evaluate,
          optimizer=optimizer,
          args=train_params,
          rng=rng,
          var_list=model.get_params(),
          use_ema=use_ema,
          ema_decay=ema_decay)
    # Make sure we always evaluate on the last epoch, so pickling bugs are more
    # obvious
    if (epoch - 1) % print_test_period != 0:
        do_eval(dataset.x_test, dataset.y_test, False)
    if (epoch - 1) % print_train_period != 0:
        print("Training set: ")
        do_eval(dataset.x_train, dataset.y_train, False)

    with sess.as_default():
        save(FLAGS.save_path, model)