Example #1
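# Imports assumed by this example: the Theano/Lasagne imports below are
# standard, while the remaining names used in main() (setup_train_experiment,
# mnist_load, create_network, with_end_points, adversarial_training,
# deepfool, batch_iterator, select_balanced_subset, save_images,
# save_network, build_result_str, plus the module-level logger and FLAGS)
# are project-local helpers whose modules are not shown here.
import os
import time
from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from lasagne.layers import get_all_params
from lasagne.objectives import categorical_accuracy, categorical_crossentropy
from lasagne.updates import adam
from lasagne.utils import floatX
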
def main():
    setup_train_experiment(logger, FLAGS, "%(model)s_at")

    logger.info("Loading data...")
    data = mnist_load(FLAGS.train_size, FLAGS.seed)
    X_train, y_train = data.X_train, data.y_train
    X_val, y_val = data.X_val, data.y_val
    X_test, y_test = data.X_test, data.y_test

    img_shape = [None, 1, 28, 28]
    train_images = T.tensor4('train_images')
    train_labels = T.lvector('train_labels')
    val_images = T.tensor4('valid_images')
    val_labels = T.lvector('valid_labels')

    layer_dims = [int(dim) for dim in FLAGS.layer_dims.split("-")]
    num_classes = layer_dims[-1]
    net = create_network(FLAGS.model, img_shape, layer_dims=layer_dims)
    model = with_end_points(net)

    train_outputs = model(train_images)
    val_outputs = model(val_images, deterministic=True)

    # losses
    train_ce = categorical_crossentropy(train_outputs['prob'],
                                        train_labels).mean()
    train_at = adversarial_training(lambda x: model(x)['prob'],
                                    train_images,
                                    train_labels,
                                    epsilon=FLAGS.epsilon).mean()
    train_loss = train_ce + FLAGS.lmbd * train_at
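    # Assuming adversarial_training returns the cross-entropy on perturbed
    # inputs (as in Goodfellow et al., 2015), the objective above is
    #   loss = CE(f(x), y) + lmbd * CE(f(x_adv), y),
    # with FLAGS.lmbd trading off clean accuracy against robustness.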
    val_ce = categorical_crossentropy(val_outputs['prob'], val_labels).mean()
    val_deepfool_images = deepfool(
        lambda x: model(x, deterministic=True)['logits'],
        val_images,
        val_labels,
        num_classes,
        max_iter=FLAGS.deepfool_iter,
        clip_dist=FLAGS.deepfool_clip,
        over_shoot=FLAGS.deepfool_overshoot)

    # metrics
    train_acc = categorical_accuracy(train_outputs['logits'],
                                     train_labels).mean()
    train_err = 1.0 - train_acc
    val_acc = categorical_accuracy(val_outputs['logits'], val_labels).mean()
    val_err = 1.0 - val_acc
    # deepfool robustness
    reduc_ind = list(range(1, train_images.ndim))
    l2_deepfool = (val_deepfool_images - val_images).norm(2, axis=reduc_ind)
    l2_deepfool_norm = l2_deepfool / val_images.norm(2, axis=reduc_ind)
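    # Per the DeepFool paper (Moosavi-Dezfooli et al., 2016), the mean of
    # l2_deepfool_norm estimates the robustness measure
    #   rho_adv = E[ ||r(x)||_2 / ||x||_2 ],
    # where r(x) is the minimal perturbation DeepFool finds; larger values
    # indicate a model that is harder to fool.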

    train_metrics = OrderedDict([('loss', train_loss), ('nll', train_ce),
                                 ('at', train_at), ('err', train_err)])
    val_metrics = OrderedDict([('nll', val_ce), ('err', val_err)])
    summary_metrics = OrderedDict([('l2', l2_deepfool.mean()),
                                   ('l2_norm', l2_deepfool_norm.mean())])

    lr = theano.shared(floatX(FLAGS.initial_learning_rate), 'learning_rate')
    train_params = get_all_params(net, trainable=True)
    train_updates = adam(train_loss, train_params, lr)

    logger.info("Compiling theano functions...")
    train_fn = theano.function([train_images, train_labels],
                               outputs=list(train_metrics.values()),
                               updates=train_updates)
    val_fn = theano.function([val_images, val_labels],
                             outputs=list(val_metrics.values()))
    summary_fn = theano.function([val_images, val_labels],
                                 outputs=list(summary_metrics.values()) +
                                 [val_deepfool_images])

    logger.info("Starting training...")
    try:
        samples_per_class = FLAGS.summary_samples_per_class
        summary_images, summary_labels = select_balanced_subset(
            X_val, y_val, num_classes, samples_per_class)
        save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
        save_images(summary_images, save_path)

        epoch = 0
        batch_index = 0
        while epoch < FLAGS.num_epochs:
            epoch += 1

            start_time = time.time()
            train_iterator = batch_iterator(X_train,
                                            y_train,
                                            FLAGS.batch_size,
                                            shuffle=True)
            epoch_outputs = np.zeros(len(train_fn.outputs))
            for batch_index, (images,
                              labels) in enumerate(train_iterator,
                                                   batch_index + 1):
                batch_outputs = train_fn(images, labels)
                epoch_outputs += batch_outputs
            epoch_outputs /= X_train.shape[0] // FLAGS.batch_size
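            # epoch_outputs is a sum of per-batch means; dividing by the
            # number of full batches yields a per-batch average (exact when
            # batch_size divides the training set size). The validation and
            # test loops below use the same convention.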
            elapsed = time.time() - start_time
            logger.info(
                build_result_str(
                    "Train epoch [{}, {:.2f}s]:".format(epoch, elapsed),
                    train_metrics.keys(), epoch_outputs))

            # update learning rate
            if epoch > FLAGS.start_learning_rate_decay:
                new_lr_value = (lr.get_value() *
                                FLAGS.learning_rate_decay_factor)
                lr.set_value(floatX(new_lr_value))
                logger.debug("learning rate was changed to {:.10f}".format(
                    new_lr_value))

            # validation
            start_time = time.time()
            val_iterator = batch_iterator(X_val,
                                          y_val,
                                          FLAGS.test_batch_size,
                                          shuffle=False)
            val_epoch_outputs = np.zeros(len(val_fn.outputs))
            for images, labels in val_iterator:
                val_epoch_outputs += val_fn(images, labels)
            val_epoch_outputs /= X_val.shape[0] // FLAGS.test_batch_size
            elapsed = time.time() - start_time
            logger.info(
                build_result_str(
                    "Test epoch [{}, {:.2f}s]:".format(epoch, elapsed),
                    val_metrics.keys(), val_epoch_outputs))

            if epoch % FLAGS.summary_frequency == 0:
                summary = summary_fn(summary_images, summary_labels)
                logger.info(
                    build_result_str(
                        "Epoch [{}] adversarial statistics:".format(epoch),
                        summary_metrics.keys(), summary[:-1]))
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch-%d.png' % epoch)
                df_images = summary[-1]
                save_images(df_images, save_path)

            if epoch % FLAGS.checkpoint_frequency == 0:
                save_network(net, epoch=epoch)
    except KeyboardInterrupt:
        logger.debug("Keyboard interrupt. Stopping training...")
    finally:
        save_network(net)

    # evaluate final model on test set
    test_iterator = batch_iterator(X_test,
                                   y_test,
                                   FLAGS.test_batch_size,
                                   shuffle=False)
    test_results = np.zeros(len(val_fn.outputs))
    for images, labels in test_iterator:
        test_results += val_fn(images, labels)
    test_results /= X_test.shape[0] // FLAGS.test_batch_size
    logger.info(
        build_result_str("Final test results:", val_metrics.keys(),
                         test_results))
Example #2
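# Imports assumed by this example (same conventions as Example #1); names
# not imported below (setup_experiment, mnist_load, create_network,
# load_network, with_end_points, deepfool, fast_gradient_perturbation,
# batch_iterator, select_balanced_subset, save_images, build_result_str,
# logger, FLAGS) are project-local helpers.
import os
import time
from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from lasagne.objectives import categorical_accuracy
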
def main():
    setup_experiment()

    data = mnist_load()
    X_test = data.X_test
    y_test = data.y_test
    if FLAGS.sort_labels:
        ys_indices = np.argsort(y_test)
        X_test = X_test[ys_indices]
        y_test = y_test[ys_indices]

    img_shape = [None, 1, 28, 28]
    test_images = T.tensor4('test_images')
    test_labels = T.lvector('test_labels')

    # loaded discriminator number of classes and dims
    layer_dims = [int(dim) for dim in FLAGS.layer_dims.split("-")]
    num_classes = layer_dims[-1]

    # create and load discriminator
    net = create_network(FLAGS.model, img_shape, layer_dims=layer_dims)
    load_network(net, epoch=FLAGS.load_epoch)
    model = with_end_points(net)

    test_outputs = model(test_images, deterministic=True)
    # deepfool images
    test_df_images = deepfool(lambda x: model(x, deterministic=True)['logits'],
                              test_images,
                              test_labels,
                              num_classes,
                              max_iter=FLAGS.deepfool_iter,
                              clip_dist=FLAGS.deepfool_clip,
                              over_shoot=FLAGS.deepfool_overshoot)
    test_df_images_all = deepfool(
        lambda x: model(x, deterministic=True)['logits'],
        test_images,
        num_classes=num_classes,
        max_iter=FLAGS.deepfool_iter,
        clip_dist=FLAGS.deepfool_clip,
        over_shoot=FLAGS.deepfool_overshoot)
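    # The second call omits test_labels; presumably deepfool then falls back
    # to the model's own predictions, which would explain the '_all' suffix
    # (perturbations are computed for every sample).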
    test_df_outputs = model(test_df_images, deterministic=True)
    # fast gradient sign images
    test_fgsm_images = test_images + fast_gradient_perturbation(
        test_images, test_outputs['logits'], test_labels, FLAGS.fgsm_epsilon)
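    # Fast gradient sign method (Goodfellow et al., 2015): assuming
    # fast_gradient_perturbation returns eps * sign(grad_x CE(logits, y)),
    # this computes x_adv = x + eps * sign(grad_x CE(f(x), y)).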
    test_fgsm_outputs = model(test_fgsm_images, deterministic=True)

    # test metrics
    test_acc = categorical_accuracy(test_outputs['logits'], test_labels).mean()
    test_err = 1 - test_acc
    test_fgsm_acc = categorical_accuracy(test_fgsm_outputs['logits'],
                                         test_labels).mean()
    test_fgsm_err = 1 - test_fgsm_acc
    test_df_acc = categorical_accuracy(test_df_outputs['logits'],
                                       test_labels).mean()
    test_df_err = 1 - test_df_acc

    # adversarial noise statistics
    reduc_ind = list(range(1, test_images.ndim))
    test_l2_df = T.sqrt(
        T.sum((test_df_images - test_images)**2, axis=reduc_ind))
    test_l2_df_norm = test_l2_df / T.sqrt(T.sum(test_images**2,
                                                axis=reduc_ind))
    test_l2_df_skip = test_l2_df.sum() / T.sum(test_l2_df > 0)
    test_l2_df_skip_norm = test_l2_df_norm.sum() / T.sum(test_l2_df_norm > 0)
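    # The "skip" variants average only over samples with a nonzero DeepFool
    # perturbation, presumably excluding inputs the model already
    # misclassifies (for which no perturbation is needed).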
    test_l2_df_all = T.sqrt(
        T.sum((test_df_images_all - test_images)**2, axis=reduc_ind))
    test_l2_df_all_norm = test_l2_df_all / T.sqrt(
        T.sum(test_images**2, axis=reduc_ind))

    test_metrics = OrderedDict([('err', test_err), ('err_fgsm', test_fgsm_err),
                                ('err_df', test_df_err),
                                ('l2_df', test_l2_df.mean()),
                                ('l2_df_norm', test_l2_df_norm.mean()),
                                ('l2_df_skip', test_l2_df_skip),
                                ('l2_df_skip_norm', test_l2_df_skip_norm),
                                ('l2_df_all', test_l2_df_all.mean()),
                                ('l2_df_all_norm', test_l2_df_all_norm.mean())
                                ])
    logger.info("Compiling theano functions...")
    test_fn = theano.function([test_images, test_labels],
                              outputs=list(test_metrics.values()))
    generate_fn = theano.function([test_images, test_labels],
                                  [test_df_images, test_df_images_all],
                                  on_unused_input='ignore')
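    # on_unused_input='ignore' keeps Theano from raising an error when a
    # declared input does not appear in a compiled output's graph.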

    logger.info("Generate samples...")
    samples_per_class = 10
    summary_images, summary_labels = select_balanced_subset(
        X_test, y_test, num_classes, samples_per_class)
    save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
    save_images(summary_images, save_path)
    df_images, df_images_all = generate_fn(summary_images, summary_labels)
    save_path = os.path.join(FLAGS.samples_dir, 'deepfool.png')
    save_images(df_images, save_path)
    save_path = os.path.join(FLAGS.samples_dir, 'deepfool_all.png')
    save_images(df_images_all, save_path)

    logger.info("Starting...")
    test_iterator = batch_iterator(X_test,
                                   y_test,
                                   FLAGS.batch_size,
                                   shuffle=False)
    test_results = np.zeros(len(test_fn.outputs))
    start_time = time.time()
    for batch_index, (images, labels) in enumerate(test_iterator, 1):
        batch_results = test_fn(images, labels)
        test_results += batch_results
        if batch_index % FLAGS.summary_frequency == 0:
            df_images, df_images_all = generate_fn(images, labels)
            save_path = os.path.join(FLAGS.samples_dir,
                                     'b%d-df.png' % batch_index)
            save_images(df_images, save_path)
            save_path = os.path.join(FLAGS.samples_dir,
                                     'b%d-df_all.png' % batch_index)
            save_images(df_images_all, save_path)
            logger.info(
                build_result_str(
                    "Batch [{}] adversarial statistics:".format(batch_index),
                    test_metrics.keys(), batch_results))
    test_results /= batch_index
    logger.info(
        build_result_str(
            "Test results [{:.2f}s]:".format(time.time() - start_time),
            test_metrics.keys(), test_results))