def test_multiclass_hinge_loss(): from lasagne.objectives import multiclass_hinge_loss from lasagne.nonlinearities import rectify p = theano.tensor.matrix('p') t = theano.tensor.ivector('t') c = multiclass_hinge_loss(p, t) # numeric version floatX = theano.config.floatX predictions = np.random.rand(10, 20).astype(floatX) targets = np.random.random_integers(0, 19, (10,)).astype("int8") one_hot = np.zeros((10, 20)) one_hot[np.arange(10), targets] = 1 correct = predictions[one_hot > 0] rest = predictions[one_hot < 1].reshape((10, 19)) rest = np.max(rest, axis=1) hinge = rectify(1 + rest - correct) # compare assert np.allclose(hinge, c.eval({p: predictions, t: targets}))
def test_multiclass_hinge_loss_invalid(): from lasagne.objectives import multiclass_hinge_loss with pytest.raises(TypeError) as exc: multiclass_hinge_loss(theano.tensor.vector(), theano.tensor.matrix()) assert 'rank mismatch' in exc.value.args[0]
def main(): setup_train_experiment(logger, FLAGS, "%(model)s_margin") logger.info("Loading data...") data = mnist_load(FLAGS.train_size, FLAGS.seed) X_train, y_train = data.X_train, data.y_train X_val, y_val = data.X_val, data.y_val X_test, y_test = data.X_test, data.y_test img_shape = [None, 1, 28, 28] train_images = T.tensor4('train_images') train_labels = T.lvector('train_labels') val_images = T.tensor4('valid_labels') val_labels = T.lvector('valid_labels') layer_dims = [int(dim) for dim in FLAGS.layer_dims.split("-")] num_classes = layer_dims[-1] net = create_network(FLAGS.model, img_shape, layer_dims=layer_dims) model = with_end_points(net) train_outputs = model(train_images) val_outputs = model(val_images, deterministic=True) # losses train_ma = multiclass_hinge_loss(train_outputs['logits'], train_labels).mean() if FLAGS.margin_order == 2: train_ma_sens = margin_sensitivity(train_images, train_outputs['logits'], train_labels, num_classes).mean() else: train_ma_sens = margin_sensitivity(train_images, train_outputs['logits'], train_labels, num_classes, ord=np.inf).mean() train_loss = train_ma + FLAGS.lmbd * train_ma_sens val_ma = multiclass_hinge_loss(val_outputs['logits'], val_labels).mean() val_deepfool_images = deepfool( lambda x: model(x, deterministic=True)['logits'], val_images, val_labels, num_classes, max_iter=FLAGS.deepfool_iter, clip_dist=FLAGS.deepfool_clip, over_shoot=FLAGS.deepfool_overshoot) # metrics train_acc = categorical_accuracy(train_outputs['logits'], train_labels).mean() train_err = 1.0 - train_acc val_acc = categorical_accuracy(val_outputs['logits'], val_labels).mean() val_err = 1.0 - val_acc # deepfool robustness reduc_ind = range(1, train_images.ndim) l2_deepfool = (val_deepfool_images - val_images).norm(2, axis=reduc_ind) l2_deepfool_norm = l2_deepfool / val_images.norm(2, axis=reduc_ind) train_metrics = OrderedDict([('loss', train_loss), ('margin', train_ma), ('sensitivity', train_ma_sens), ('err', train_err)]) val_metrics = OrderedDict([('margin', val_ma), ('err', val_err)]) summary_metrics = OrderedDict([('l2', l2_deepfool.mean()), ('l2_norm', l2_deepfool_norm.mean())]) lr = theano.shared(floatX(FLAGS.initial_learning_rate), 'learning_rate') train_params = get_all_params(net, trainable=True) train_updates = adam(train_loss, train_params, lr) logger.info("Compiling theano functions...") train_fn = theano.function([train_images, train_labels], outputs=train_metrics.values(), updates=train_updates) val_fn = theano.function([val_images, val_labels], outputs=val_metrics.values()) summary_fn = theano.function([val_images, val_labels], outputs=summary_metrics.values() + [val_deepfool_images]) logger.info("Starting training...") try: samples_per_class = FLAGS.summary_samples_per_class summary_images, summary_labels = select_balanced_subset( X_val, y_val, num_classes, samples_per_class) save_path = os.path.join(FLAGS.samples_dir, 'orig.png') save_images(summary_images, save_path) epoch = 0 batch_index = 0 while epoch < FLAGS.num_epochs: epoch += 1 start_time = time.time() train_iterator = batch_iterator(X_train, y_train, FLAGS.batch_size, shuffle=True) epoch_outputs = np.zeros(len(train_fn.outputs)) for batch_index, (images, labels) in enumerate(train_iterator, batch_index + 1): batch_outputs = train_fn(images, labels) epoch_outputs += batch_outputs epoch_outputs /= X_train.shape[0] // FLAGS.batch_size logger.info( build_result_str( "Train epoch [{}, {:.2f}s]:".format( epoch, time.time() - start_time), train_metrics.keys(), epoch_outputs)) # update learning rate if epoch > FLAGS.start_learning_rate_decay: new_lr_value = lr.get_value( ) * FLAGS.learning_rate_decay_factor lr.set_value(floatX(new_lr_value)) logger.debug("learning rate was changed to {:.10f}".format( new_lr_value)) # validation start_time = time.time() val_iterator = batch_iterator(X_val, y_val, FLAGS.test_batch_size, shuffle=False) val_epoch_outputs = np.zeros(len(val_fn.outputs)) for images, labels in val_iterator: val_epoch_outputs += val_fn(images, labels) val_epoch_outputs /= X_val.shape[0] // FLAGS.test_batch_size logger.info( build_result_str( "Test epoch [{}, {:.2f}s]:".format( epoch, time.time() - start_time), val_metrics.keys(), val_epoch_outputs)) if epoch % FLAGS.summary_frequency == 0: summary = summary_fn(summary_images, summary_labels) logger.info( build_result_str( "Epoch [{}] adversarial statistics:".format(epoch), summary_metrics.keys(), summary[:-1])) save_path = os.path.join(FLAGS.samples_dir, 'epoch-%d.png' % epoch) df_images = summary[-1] save_images(df_images, save_path) if epoch % FLAGS.checkpoint_frequency == 0: save_network(net, epoch=epoch) except KeyboardInterrupt: logger.debug("Keyboard interrupt. Stopping training...") finally: save_network(net) # evaluate final model on test set test_iterator = batch_iterator(X_test, y_test, FLAGS.test_batch_size, shuffle=False) test_results = np.zeros(len(val_fn.outputs)) for images, labels in test_iterator: test_results += val_fn(images, labels) test_results /= X_test.shape[0] // FLAGS.test_batch_size logger.info( build_result_str("Final test results:", val_metrics.keys(), test_results))