def main():
    setup_train_experiment(logger, FLAGS, "%(model)s_at")

    logger.info("Loading data...")
    data = mnist_load(FLAGS.train_size, FLAGS.seed)
    X_train, y_train = data.X_train, data.y_train
    X_val, y_val = data.X_val, data.y_val
    X_test, y_test = data.X_test, data.y_test

    # symbolic inputs for training and validation batches
    img_shape = [None, 1, 28, 28]
    train_images = T.tensor4('train_images')
    train_labels = T.lvector('train_labels')
    val_images = T.tensor4('valid_images')
    val_labels = T.lvector('valid_labels')

    layer_dims = [int(dim) for dim in FLAGS.layer_dims.split("-")]
    num_classes = layer_dims[-1]
    net = create_network(FLAGS.model, img_shape, layer_dims=layer_dims)
    model = with_end_points(net)

    train_outputs = model(train_images)
    val_outputs = model(val_images, deterministic=True)

    # losses: standard cross-entropy plus weighted adversarial term
    train_ce = categorical_crossentropy(train_outputs['prob'],
                                        train_labels).mean()
    train_at = adversarial_training(lambda x: model(x)['prob'],
                                    train_images, train_labels,
                                    epsilon=FLAGS.epsilon).mean()
    train_loss = train_ce + FLAGS.lmbd * train_at
    val_ce = categorical_crossentropy(val_outputs['prob'], val_labels).mean()
    val_deepfool_images = deepfool(
        lambda x: model(x, deterministic=True)['logits'],
        val_images, val_labels, num_classes,
        max_iter=FLAGS.deepfool_iter, clip_dist=FLAGS.deepfool_clip,
        over_shoot=FLAGS.deepfool_overshoot)

    # metrics
    train_acc = categorical_accuracy(train_outputs['logits'],
                                     train_labels).mean()
    train_err = 1.0 - train_acc
    val_acc = categorical_accuracy(val_outputs['logits'], val_labels).mean()
    val_err = 1.0 - val_acc

    # deepfool robustness: l2-norm of the minimal perturbation found
    reduc_ind = range(1, train_images.ndim)
    l2_deepfool = (val_deepfool_images - val_images).norm(2, axis=reduc_ind)
    l2_deepfool_norm = l2_deepfool / val_images.norm(2, axis=reduc_ind)

    train_metrics = OrderedDict([('loss', train_loss), ('nll', train_ce),
                                 ('at', train_at), ('err', train_err)])
    val_metrics = OrderedDict([('nll', val_ce), ('err', val_err)])
    summary_metrics = OrderedDict([('l2', l2_deepfool.mean()),
                                   ('l2_norm', l2_deepfool_norm.mean())])

    lr = theano.shared(floatX(FLAGS.initial_learning_rate), 'learning_rate')
    train_params = get_all_params(net, trainable=True)
    train_updates = adam(train_loss, train_params, lr)

    logger.info("Compiling theano functions...")
    train_fn = theano.function([train_images, train_labels],
                               outputs=train_metrics.values(),
                               updates=train_updates)
    val_fn = theano.function([val_images, val_labels],
                             outputs=val_metrics.values())
    summary_fn = theano.function(
        [val_images, val_labels],
        outputs=summary_metrics.values() + [val_deepfool_images])

    logger.info("Starting training...")
    try:
        samples_per_class = FLAGS.summary_samples_per_class
        summary_images, summary_labels = select_balanced_subset(
            X_val, y_val, num_classes, samples_per_class)
        save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
        save_images(summary_images, save_path)

        epoch = 0
        batch_index = 0
        while epoch < FLAGS.num_epochs:
            epoch += 1
            start_time = time.time()

            train_iterator = batch_iterator(X_train, y_train,
                                            FLAGS.batch_size, shuffle=True)
            epoch_outputs = np.zeros(len(train_fn.outputs))
            for batch_index, (images, labels) in enumerate(
                    train_iterator, batch_index + 1):
                batch_outputs = train_fn(images, labels)
                epoch_outputs += batch_outputs
            epoch_outputs /= X_train.shape[0] // FLAGS.batch_size
            logger.info(
                build_result_str(
                    "Train epoch [{}, {:.2f}s]:".format(
                        epoch, time.time() - start_time),
                    train_metrics.keys(), epoch_outputs))

            # update learning rate
            if epoch > FLAGS.start_learning_rate_decay:
                new_lr_value = (lr.get_value()
                                * FLAGS.learning_rate_decay_factor)
                lr.set_value(floatX(new_lr_value))
                logger.debug("learning rate was changed to {:.10f}".format(
                    new_lr_value))

            # validation
            start_time = time.time()
            val_iterator = batch_iterator(X_val, y_val,
                                          FLAGS.test_batch_size,
                                          shuffle=False)
            val_epoch_outputs = np.zeros(len(val_fn.outputs))
            for images, labels in val_iterator:
                val_epoch_outputs += val_fn(images, labels)
            val_epoch_outputs /= X_val.shape[0] // FLAGS.test_batch_size
            logger.info(
                build_result_str(
                    "Test epoch [{}, {:.2f}s]:".format(
                        epoch, time.time() - start_time),
                    val_metrics.keys(), val_epoch_outputs))

            if epoch % FLAGS.summary_frequency == 0:
                summary = summary_fn(summary_images, summary_labels)
                logger.info(
                    build_result_str(
                        "Epoch [{}] adversarial statistics:".format(epoch),
                        summary_metrics.keys(), summary[:-1]))
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch-%d.png' % epoch)
                df_images = summary[-1]
                save_images(df_images, save_path)

            if epoch % FLAGS.checkpoint_frequency == 0:
                save_network(net, epoch=epoch)
    except KeyboardInterrupt:
        logger.debug("Keyboard interrupt. Stopping training...")
    finally:
        save_network(net)

    # evaluate final model on test set
    test_iterator = batch_iterator(X_test, y_test, FLAGS.test_batch_size,
                                   shuffle=False)
    test_results = np.zeros(len(val_fn.outputs))
    for images, labels in test_iterator:
        test_results += val_fn(images, labels)
    test_results /= X_test.shape[0] // FLAGS.test_batch_size
    logger.info(
        build_result_str("Final test results:", val_metrics.keys(),
                         test_results))
def main():
    setup_experiment()

    data = mnist_load()
    X_test = data.X_test
    y_test = data.y_test

    if FLAGS.sort_labels:
        ys_indices = np.argsort(y_test)
        X_test = X_test[ys_indices]
        y_test = y_test[ys_indices]

    img_shape = [None, 1, 28, 28]
    test_images = T.tensor4('test_images')
    test_labels = T.lvector('test_labels')

    # loaded discriminator number of classes and dims
    layer_dims = [int(dim) for dim in FLAGS.layer_dims.split("-")]
    num_classes = layer_dims[-1]

    # create and load discriminator
    net = create_network(FLAGS.model, img_shape, layer_dims=layer_dims)
    load_network(net, epoch=FLAGS.load_epoch)
    model = with_end_points(net)

    test_outputs = model(test_images, deterministic=True)

    # deepfool images; the second call omits the labels and presumably runs
    # the untargeted variant over all classes
    test_df_images = deepfool(
        lambda x: model(x, deterministic=True)['logits'],
        test_images, test_labels, num_classes,
        max_iter=FLAGS.deepfool_iter, clip_dist=FLAGS.deepfool_clip,
        over_shoot=FLAGS.deepfool_overshoot)
    test_df_images_all = deepfool(
        lambda x: model(x, deterministic=True)['logits'],
        test_images, num_classes=num_classes,
        max_iter=FLAGS.deepfool_iter, clip_dist=FLAGS.deepfool_clip,
        over_shoot=FLAGS.deepfool_overshoot)
    test_df_outputs = model(test_df_images, deterministic=True)

    # fast gradient sign images
    test_fgsm_images = test_images + fast_gradient_perturbation(
        test_images, test_outputs['logits'], test_labels,
        FLAGS.fgsm_epsilon)
    test_at_outputs = model(test_fgsm_images, deterministic=True)

    # test metrics
    test_acc = categorical_accuracy(test_outputs['logits'],
                                    test_labels).mean()
    test_err = 1 - test_acc
    test_fgsm_acc = categorical_accuracy(test_at_outputs['logits'],
                                         test_labels).mean()
    test_fgsm_err = 1 - test_fgsm_acc
    test_df_acc = categorical_accuracy(test_df_outputs['logits'],
                                       test_labels).mean()
    test_df_err = 1 - test_df_acc

    # adversarial noise statistics; the *_skip variants average only over
    # inputs with a nonzero perturbation (presumably those deepfool actually
    # had to perturb)
    reduc_ind = range(1, test_images.ndim)
    test_l2_df = T.sqrt(
        T.sum((test_df_images - test_images)**2, axis=reduc_ind))
    test_l2_df_norm = test_l2_df / T.sqrt(
        T.sum(test_images**2, axis=reduc_ind))
    test_l2_df_skip = test_l2_df.sum() / T.sum(test_l2_df > 0)
    test_l2_df_skip_norm = test_l2_df_norm.sum() / T.sum(test_l2_df_norm > 0)
    test_l2_df_all = T.sqrt(
        T.sum((test_df_images_all - test_images)**2, axis=reduc_ind))
    test_l2_df_all_norm = test_l2_df_all / T.sqrt(
        T.sum(test_images**2, axis=reduc_ind))

    test_metrics = OrderedDict([
        ('err', test_err),
        ('err_fgsm', test_fgsm_err),
        ('err_df', test_df_err),
        ('l2_df', test_l2_df.mean()),
        ('l2_df_norm', test_l2_df_norm.mean()),
        ('l2_df_skip', test_l2_df_skip),
        ('l2_df_skip_norm', test_l2_df_skip_norm),
        ('l2_df_all', test_l2_df_all.mean()),
        ('l2_df_all_norm', test_l2_df_all_norm.mean())
    ])

    logger.info("Compiling theano functions...")
    test_fn = theano.function([test_images, test_labels],
                              outputs=test_metrics.values())
    generate_fn = theano.function([test_images, test_labels],
                                  [test_df_images, test_df_images_all],
                                  on_unused_input='ignore')

    logger.info("Generate samples...")
    samples_per_class = 10
    summary_images, summary_labels = select_balanced_subset(
        X_test, y_test, num_classes, samples_per_class)
    save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
    save_images(summary_images, save_path)
    df_images, df_images_all = generate_fn(summary_images, summary_labels)
    save_path = os.path.join(FLAGS.samples_dir, 'deepfool.png')
    save_images(df_images, save_path)
    save_path = os.path.join(FLAGS.samples_dir, 'deepfool_all.png')
    save_images(df_images_all, save_path)

    logger.info("Starting...")
    test_iterator = batch_iterator(X_test, y_test, FLAGS.batch_size,
                                   shuffle=False)
    test_results = np.zeros(len(test_fn.outputs))
    start_time = time.time()
    for batch_index, (images, labels) in enumerate(test_iterator, 1):
        batch_results = test_fn(images, labels)
        test_results += batch_results
        if batch_index % FLAGS.summary_frequency == 0:
            df_images, df_images_all = generate_fn(images, labels)
            save_path = os.path.join(FLAGS.samples_dir,
                                     'b%d-df.png' % batch_index)
            save_images(df_images, save_path)
            save_path = os.path.join(FLAGS.samples_dir,
                                     'b%d-df_all.png' % batch_index)
            save_images(df_images_all, save_path)
            logger.info(
                build_result_str(
                    "Batch [{}] adversarial statistics:".format(batch_index),
                    test_metrics.keys(), batch_results))
    test_results /= batch_index
    logger.info(
        build_result_str(
            "Test results [{:.2f}s]:".format(time.time() - start_time),
            test_metrics.keys(), test_results))
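
# `fast_gradient_perturbation` used above is likewise project-local. A
# minimal sketch, assuming it returns the FGSM noise
# epsilon * sign(grad_x NLL) with the same signature as the call site
# (hypothetical implementation, not the repository's code):
def fast_gradient_perturbation_sketch(images, logits, labels, epsilon):
    """epsilon-scaled sign of the input gradient of the cross-entropy."""
    nll = categorical_crossentropy(T.nnet.softmax(logits), labels).mean()
    grad = T.grad(nll, images)
    # Treat the noise as a constant so nothing differentiates through it,
    # matching its use as an additive perturbation of the input.
    return epsilon * T.sgn(theano.gradient.disconnected_grad(grad))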