Esempio n. 1
0
    def __init__(self, args=None):
        """
        Parse the program arguments, declare the attack state and set up logging.

        If ``args.log_file`` is set, its directory is created and the file is
        attached to the global log. Every parsed argument is then logged.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        # Arguments of program, as returned by the parser's parse_args.
        self.args = None

        parser = self.get_parser()
        self.args = parser.parse_args() if args is None else parser.parse_args(args)

        # (numpy.ndarray) Test classes.
        self.test_codes = None
        # (numpy.ndarray) Transformations for testing.
        self.test_theta = None
        # (int) Number of classes.
        self.N_class = None
        # (attacks.UntargetedAttack) Attack to use (as class).
        self.attack_class = None
        # (attacks.UntargetedObjective) Objective to use (as class).
        self.objective_class = None
        # (encoder.Encoder) Model to train.
        self.model = None
        # (numpy.ndarray) Perturbations per test image.
        self.perturbations = None
        # (numpy.ndarray) Success per test image.
        self.success = None
        # (numpy.ndarray) Minimum bound for codes.
        self.min_bound = None
        # (numpy.ndarray) Maximum bound for codes.
        self.max_bound = None
        # (numpy.ndarray) Accuracy.
        self.accuracy = None

        log_file = self.args.log_file
        if log_file:
            utils.makedir(os.path.dirname(log_file))
            Log.get_instance().attach(open(log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for name, value in vars(self.args).items():
            log('[Attack] %s=%s' % (name, str(value)))
Esempio n. 2
0
    def test(self):
        """
        Test classifier to identify valid samples to attack.

        Decodes all perturbations batch by batch. For each batch, the model code
        is fixed to the concatenation of the one-hot font code and the one-hot
        class code, so only the perturbation input varies. The model is then
        applied to the perturbations.

        Outputs are converted to numpy with axis 1 moved last (presumably
        NCHW -> NHWC). A singleton channel axis is dropped, and the results
        are accumulated in ``self.perturbation_images``.

        Finally, the images are reshaped to ``(N_samples, N_attempts, ...)``,
        swapped to ``(N_attempts, N_samples, ...)`` and written to
        ``self.args.perturbation_images_file``.
        """

        num_perturbations = self.perturbations.shape[0]
        num_batches = int(math.ceil(num_perturbations / self.args.batch_size))

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size, num_perturbations)
            batch_fonts = self.test_fonts[b_start:b_end]
            batch_classes = self.test_classes[b_start:b_end]
            batch_code = numpy.concatenate(
                (common.numpy.one_hot(batch_fonts, self.N_font),
                 common.numpy.one_hot(batch_classes, self.N_class)),
                axis=1).astype(numpy.float32)

            batch_inputs = common.torch.as_variable(
                self.perturbations[b_start:b_end], self.args.use_gpu)
            batch_code = common.torch.as_variable(batch_code,
                                                  self.args.use_gpu)

            # This basically allows to only optimize over theta, keeping the font/class code fixed.
            self.model.set_code(batch_code)
            output_images = self.model(batch_inputs)

            output_images = numpy.transpose(
                output_images.cpu().detach().numpy(), (0, 2, 3, 1))
            # Drop only a singleton channel axis. A blanket numpy.squeeze would
            # also drop the batch axis whenever a batch holds a single sample
            # (e.g. a final partial batch), which breaks the concatenation and
            # the reshape below.
            if output_images.shape[3] == 1:
                output_images = numpy.squeeze(output_images, axis=3)
            self.perturbation_images = common.numpy.concatenate(
                self.perturbation_images, output_images)

            if b % 100 == 0:
                log('[Testing] computing perturbation images %d' % b)

        utils.makedir(os.path.dirname(self.args.perturbation_images_file))
        # Split the flat leading axis into (N_samples, N_attempts) and keep the
        # remaining image axes as they are. This works for both 3-D (no
        # channel axis) and 4-D (with channel axis) results.
        self.perturbation_images = self.perturbation_images.reshape(
            (self.N_samples, self.N_attempts)
            + tuple(self.perturbation_images.shape[1:]))
        self.perturbation_images = numpy.swapaxes(self.perturbation_images, 0,
                                                  1)
        utils.write_hdf5(self.args.perturbation_images_file,
                         self.perturbation_images)
        log('[Testing] wrote %s' % self.args.perturbation_images_file)
Esempio n. 3
0
    def plot(self):
        """
        Write one line plot per row of ``self.data``.

        Does nothing unless ``self.args.output_directory`` is set. Otherwise the
        directory is created, and for every index ``i`` along the first axis of
        ``self.data`` a plot of ``self.data[i, :, 0]`` (x) against
        ``self.data[i, :, 1]`` (y) is written to the image file that
        ``paths.image_file`` derives from ``<output_directory>/<i>``.
        """

        output_directory = self.args.output_directory
        if not output_directory:
            return

        utils.makedir(output_directory)
        num_rows = self.data.shape[0]
        for row in range(num_rows):
            target = paths.image_file('%s/%d' % (output_directory, row))
            xs = self.data[row, :, 0]
            ys = self.data[row, :, 1]
            plot.line(target, xs, ys)
Esempio n. 4
0
    def __init__(self, args=None):
        """
        Parse the program arguments, declare the testing state and set up logging.

        If ``args.log_file`` is set, its directory is created and the file is
        attached to the global log. Every parsed argument is then logged.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        # Arguments of program, as returned by the parser's parse_args.
        self.args = None

        parser = self.get_parser()
        self.args = parser.parse_args() if args is None else parser.parse_args(args)

        # (numpy.ndarray) Font classes.
        self.test_fonts = None
        # (numpy.ndarray) Character classes.
        self.test_classes = None
        # (int) Number of attempts.
        self.N_attempts = None
        # (int) Number of samples.
        self.N_samples = None
        # (int) Number of fonts.
        self.N_font = None
        # (int) Number of classes.
        self.N_class = None
        # (encoder.Encoder) Model to train.
        self.model = None
        # (numpy.ndarray) Perturbations per test image.
        self.perturbations = None
        # (numpy.ndarray) Perturbation images.
        self.perturbation_images = None

        log_file = self.args.log_file
        if log_file:
            utils.makedir(os.path.dirname(log_file))
            Log.get_instance().attach(open(log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for name, value in vars(self.args).items():
            log('[Testing] %s=%s' % (name, str(value)))
    def __init__(self, args=None):
        """
        Parse the program arguments, declare the transfer-testing state and set up logging.

        If ``args.log_file`` is set, its directory is created and the file is
        attached to the global log. Every parsed argument is then logged.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        # Arguments of program, as returned by the parser's parse_args.
        self.args = None

        parser = self.get_parser()
        self.args = parser.parse_args() if args is None else parser.parse_args(args)

        # (numpy.ndarray) Codes for testing.
        self.test_codes = None
        # (numpy.ndarray) Perturbation codes for testing.
        self.perturbation_codes = None
        # (encoder.Encoder) Model to train.
        self.model = None
        # (numpy.ndarray) Perturbations per test image.
        self.perturbations = None
        # (numpy.ndarray) Success of classifier (original).
        self.original_accuracy = None
        # (numpy.ndarray) Success of classifier (transfer).
        self.transfer_accuracy = None
        # (numpy.ndarray) Success per test image (original).
        self.original_success = None
        # (numpy.ndarray) Success per test image (transfer).
        self.transfer_success = None

        log_file = self.args.log_file
        if log_file:
            utils.makedir(os.path.dirname(log_file))
            Log.get_instance().attach(open(log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for name, value in vars(self.args).items():
            log('[Testing] %s=%s' % (name, str(value)))
    def test(self):
        """
        Test classifier to identify valid samples to attack.

        Decodes all perturbations batch by batch. For each batch, the model is
        conditioned on the corresponding test images via ``set_image`` and then
        applied to the perturbations.

        Outputs are converted to numpy with axis 1 moved last (presumably
        NCHW -> NHWC). A singleton channel axis is dropped, and the results
        are accumulated in ``self.perturbation_images``.

        Finally, the images are reshaped to ``(N_samples, N_attempts, ...)``,
        swapped to ``(N_attempts, N_samples, ...)`` and written to
        ``self.args.perturbation_images_file``.
        """

        num_perturbations = self.perturbations.shape[0]
        num_batches = int(math.ceil(num_perturbations / self.args.batch_size))

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size, num_perturbations)
            batch_images = common.torch.as_variable(
                self.test_images[b_start:b_end], self.args.use_gpu)
            batch_inputs = common.torch.as_variable(
                self.perturbations[b_start:b_end], self.args.use_gpu)

            self.model.set_image(batch_images)
            output_images = self.model(batch_inputs)

            output_images = numpy.transpose(
                output_images.cpu().detach().numpy(), (0, 2, 3, 1))
            # Drop only a singleton channel axis. A blanket numpy.squeeze would
            # also drop the batch axis whenever a batch holds a single sample
            # (e.g. a final partial batch), which breaks the concatenation and
            # the reshape below.
            if output_images.shape[3] == 1:
                output_images = numpy.squeeze(output_images, axis=3)
            self.perturbation_images = common.numpy.concatenate(
                self.perturbation_images, output_images)

            if b % 100 == 0:
                log('[Testing] computing perturbation images %d' % b)

        utils.makedir(os.path.dirname(self.args.perturbation_images_file))
        # Split the flat leading axis into (N_samples, N_attempts) and keep the
        # remaining image axes as they are. This works for both 3-D (no
        # channel axis) and 4-D (with channel axis) results.
        self.perturbation_images = self.perturbation_images.reshape(
            (self.N_samples, self.N_attempts)
            + tuple(self.perturbation_images.shape[1:]))
        self.perturbation_images = numpy.swapaxes(self.perturbation_images, 0,
                                                  1)
        utils.write_hdf5(self.args.perturbation_images_file,
                         self.perturbation_images)
        log('[Testing] wrote %s' % self.args.perturbation_images_file)
Esempio n. 7
0
    def __init__(self, args=None):
        """
        Initialize.

        Declares the evaluation state and parses the arguments. If a log file
        is given, its directory is created and the file is attached to the
        global log. Every parsed argument is then logged.

        :param args: arguments; if None, the parser falls back to sys.argv
        :type args: list
        """

        self.args = None
        """ Arguments of program. """

        self.test_images = None
        """ (numpy.ndarray) Images to test on. """

        self.test_codes = None
        """ (numpy.ndarray) Codes for testing. """

        self.model = None
        """ (encoder.Encoder) Model to train. """

        self.loss = None
        """ (float) Will hold evaluated loss. """

        self.error = None
        """ (float) Will hold evaluated error. """

        self.accuracy = None
        """ (numpy.ndarray) Will hold success. """

        self.results = dict()
        """ (dict) Will hold evaluation results. """

        parser = self.get_parser()
        if args is not None:
            self.args = parser.parse_args(args)
        else:
            self.args = parser.parse_args()  # sys.argv

        # Only touch the filesystem when a log file was actually requested;
        # os.path.dirname(None) would raise a TypeError otherwise.
        if self.args.log_file:
            utils.makedir(os.path.dirname(self.args.log_file))
            Log.get_instance().attach(open(self.args.log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for key in vars(self.args):
            log('[Testing] %s=%s' % (key, str(getattr(self.args, key))))
    def debug(self, filename, images, cmap='gray'):
        """
        Simple debugging.

        Converts torch tensors to numpy and checks that ``images`` is an array
        of rank >= 4 whose axis 3 has 1 or 3 entries (presumably channels). If
        ``utils.display()`` is truthy and ``self.args.debug_directory`` is set,
        a mosaic of the images is then written to
        ``<debug_directory>/<filename>`` via ``vis.mosaic``.

        :param filename: filename in debug_directory
        :type filename: str
        :param images: images
        :type images: numpy.ndarray or torch.Tensor
        :param cmap: color map forwarded to vis.mosaic
        :type cmap: str
        """

        # isinstance also accepts tensor subclasses such as torch.nn.Parameter,
        # which an exact type comparison would reject.
        if isinstance(images, (torch.Tensor, torch.autograd.Variable)):
            images = images.cpu().detach().numpy()

        assert isinstance(images, numpy.ndarray)
        # Check the rank first, so a wrong-rank input fails this assertion
        # instead of raising IndexError on shape[3].
        assert len(images.shape) >= 4
        assert images.shape[3] == 1 or images.shape[3] == 3

        if utils.display() and self.args.debug_directory:
            utils.makedir(self.args.debug_directory)
            vis.mosaic(os.path.join(self.args.debug_directory, filename),
                       images,
                       cmap=cmap)
    def __init__(self, args=None):
        """
        Initialize.

        Parses the arguments and declares the training state. The directory of
        the state file is created. If a log file is given, its directory is
        created and the file is attached to the global log. Every parsed
        argument is then logged.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        self.args = None
        """ Arguments of program. """

        parser = self.get_parser()
        if args is not None:
            self.args = parser.parse_args(args)
        else:
            self.args = parser.parse_args()

        self.train_images = None
        """ (numpy.ndarray) Images to train on. """

        self.train_codes = None
        """ (numpy.ndarray) Codes for training. """

        self.val_images = None
        """ (numpy.ndarray) Images to validate on. """

        self.val_codes = None
        """ (numpy.ndarray) Codes to validate on. """

        self.val_error = None
        """ (float) Validation error. """

        self.test_images = None
        """ (numpy.ndarray) Images to test on. """

        self.test_codes = None
        """ (numpy.ndarray) Codes for testing. """

        self.model = None
        """ (encoder.Encoder) Model to train. """

        self.scheduler = None
        """ (Scheduler) Scheduler for training. """

        self.train_statistics = numpy.zeros((0, 6))
        """ (numpy.ndarray) Will hold training statistics. """

        self.test_statistics = numpy.zeros((0, 5))
        """ (numpy.ndarray) Will hold testing statistics. """

        self.epoch = 0
        """ (int) Current epoch. """

        self.N_class = None
        """ (int) Number of classes. """

        self.results = dict()
        """ (dict) Results. """

        utils.makedir(os.path.dirname(self.args.state_file))

        # The log file is optional: only create its directory when one is
        # given (os.path.dirname(None) would raise a TypeError otherwise).
        if self.args.log_file:
            utils.makedir(os.path.dirname(self.args.log_file))
            Log.get_instance().attach(open(self.args.log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for key in vars(self.args):
            log('[Training] %s=%s' % (key, str(getattr(self.args, key))))
    def test(self):
        """
        Test classifier to identify valid samples to attack.

        Puts the model into evaluation mode and makes two inference passes.
        Neither pass builds autograd graphs.

        1. Perturbations. Each perturbation is classified, and the predicted
           index is compared with ``self.perturbation_codes``. Entries of
           ``self.transfer_success`` where the prediction matches (error 0)
           are set to -1. The array is then reshaped to
           ``(N_samples, N_attempts)``, swapped to
           ``(N_attempts, N_samples)`` and written to
           ``self.args.transfer_success_file``.
        2. Test images. Each test image is classified and compared with
           ``self.test_codes``. ``self.transfer_accuracy`` becomes a boolean
           "correctly classified" array, and summary statistics against
           ``self.original_accuracy`` are logged. It is then AND-ed with
           ``self.original_accuracy`` and written to
           ``self.args.transfer_accuracy_file``.

        Both inputs must be 4-D (asserted). Given the ``permute(0, 3, 1, 2)``
        before the model, they are presumably channels-last.
        """

        self.model.eval()
        assert self.model.training is False
        assert self.perturbation_codes.shape[0] == self.perturbations.shape[0]
        assert self.test_codes.shape[0] == self.test_images.shape[0]
        assert len(self.perturbations.shape) == 4
        assert len(self.test_images.shape) == 4

        # Pass 1: classification error per perturbation (0 = correctly classified).
        perturbation_errors = None
        num_perturbations = self.perturbations.shape[0]
        num_batches = int(math.ceil(num_perturbations / self.args.batch_size))

        # Inference only; no gradients are needed.
        with torch.no_grad():
            for b in range(num_batches):
                b_start = b * self.args.batch_size
                b_end = min((b + 1) * self.args.batch_size, num_perturbations)
                batch_perturbations = common.torch.as_variable(self.perturbations[b_start: b_end], self.args.use_gpu)
                batch_classes = common.torch.as_variable(self.perturbation_codes[b_start: b_end], self.args.use_gpu)
                batch_perturbations = batch_perturbations.permute(0, 3, 1, 2)

                output_classes = self.model(batch_perturbations)
                _, indices = torch.max(torch.nn.functional.softmax(output_classes, dim=1), dim=1)
                errors = torch.abs(indices - batch_classes)
                perturbation_errors = common.numpy.concatenate(perturbation_errors, errors.data.cpu().numpy())

                for n in range(batch_perturbations.size(0)):
                    log('[Testing] %d: original success=%d, transfer accuracy=%d' % (n, self.original_success[b_start + n], errors[n].item()))

        # A correctly classified perturbation means the attack did not transfer.
        self.transfer_success[perturbation_errors == 0] = -1
        self.transfer_success = self.transfer_success.reshape((self.N_samples, self.N_attempts))
        self.transfer_success = numpy.swapaxes(self.transfer_success, 0, 1)

        utils.makedir(os.path.dirname(self.args.transfer_success_file))
        utils.write_hdf5(self.args.transfer_success_file, self.transfer_success)
        log('[Testing] wrote %s' % self.args.transfer_success_file)

        # Pass 2: classification error per clean test image. Accumulate in a
        # local so that a stale attribute value is never concatenated onto.
        test_errors = None
        num_test_images = self.test_images.shape[0]
        num_batches = int(math.ceil(num_test_images / self.args.batch_size))

        with torch.no_grad():
            for b in range(num_batches):
                b_start = b * self.args.batch_size
                b_end = min((b + 1) * self.args.batch_size, num_test_images)
                batch_images = common.torch.as_variable(self.test_images[b_start: b_end], self.args.use_gpu)
                batch_classes = common.torch.as_variable(self.test_codes[b_start: b_end], self.args.use_gpu)
                batch_images = batch_images.permute(0, 3, 1, 2)

                output_classes = self.model(batch_images)
                _, indices = torch.max(torch.nn.functional.softmax(output_classes, dim=1), dim=1)
                errors = torch.abs(indices - batch_classes)

                test_errors = common.numpy.concatenate(test_errors, errors.data.cpu().numpy())

                if b % 100 == 0:
                    log('[Testing] computing accuracy %d' % b)

        self.transfer_accuracy = test_errors == 0
        log('[Testing] original accuracy=%g' % (numpy.sum(self.original_accuracy)/float(self.original_accuracy.shape[0])))
        log('[Testing] transfer accuracy=%g' % (numpy.sum(self.transfer_accuracy)/float(self.transfer_accuracy.shape[0])))
        log('[Testing] accuracy difference=%g' % (numpy.sum(self.transfer_accuracy != self.original_accuracy)/float(self.transfer_accuracy.shape[0])))
        log('[Testing] accuracy difference on %d samples=%g' % (self.N_samples, numpy.sum(self.transfer_accuracy[:self.N_samples] != self.original_accuracy[:self.N_samples])/float(self.N_samples)))
        self.transfer_accuracy = numpy.logical_and(self.transfer_accuracy, self.original_accuracy)

        utils.makedir(os.path.dirname(self.args.transfer_accuracy_file))
        utils.write_hdf5(self.args.transfer_accuracy_file, self.transfer_accuracy)
        log('[Testing] wrote %s' % self.args.transfer_accuracy_file)
Esempio n. 11
0
    def __init__(self, args=None):
        """
        Initialize.

        Parses the arguments and declares the auto-encoder testing state. If a
        log file is given, its directory is created and the file is attached to
        the global log. Every parsed argument is then logged.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        self.args = None
        """ Arguments of program. """

        parser = self.get_parser()
        if args is not None:
            self.args = parser.parse_args(args)
        else:
            self.args = parser.parse_args()

        self.train_images = None
        """ (numpy.ndarray) Images to train on. """

        self.test_images = None
        """ (numpy.ndarray) Images to test on. """

        self.train_codes = None
        """ (numpy.ndarray) Labels to train on. """

        self.test_codes = None
        """ (numpy.ndarray) Labels to test on. """

        self.resolution = None
        """ (int) Resolution. """

        self.encoder = None
        """ (models.LearnedEncoder) Encoder. """

        self.decoder = None
        """ (models.LearnedDecoder) Decoder. """

        self.reconstruction_error = 0
        """ (int) Reconstruction error. """

        self.code_mean = 0
        """ (int) Code mean (initialized to 0). """

        self.code_var = 0
        """ (int) Code variance (initialized to 0). """

        self.pred_images = None
        """ (numpy.ndarray) Test images reconstructed. """

        self.pred_codes = None
        """ (numpy.ndarray) Test latent codes. """

        self.results = dict()
        """ (dict) Results. """

        # The log file is optional: only create its directory when one is
        # given (os.path.dirname(None) would raise a TypeError otherwise).
        if self.args.log_file:
            utils.makedir(os.path.dirname(self.args.log_file))
            Log.get_instance().attach(open(self.args.log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for key in vars(self.args):
            log('[Testing] %s=%s' % (key, str(getattr(self.args, key))))
    def __init__(self, args=None):
        """
        Initialize.

        Parses the arguments, sets up logging and declares the training state
        for the encoder, decoder and classifier.

        If a log file is given, its directory is created and the file is
        attached to the global log. Every parsed argument is logged. The
        directories of the encoder and decoder files are always created.

        :param args: optional arguments if not to use sys.argv
        :type args: [str]
        """

        self.args = None
        """ Arguments of program. """

        parser = self.get_parser()
        if args is not None:
            self.args = parser.parse_args(args)
        else:
            self.args = parser.parse_args()

        self.train_images = None
        """ (numpy.ndarray) Images to train on. """

        self.test_images = None
        """ (numpy.ndarray) Images to test on. """

        self.train_codes = None
        """ (numpy.ndarray) Labels to train on. """

        self.test_codes = None
        """ (numpy.ndarray) Labels to test on. """

        # The log file is optional; its directory is created only here, when a
        # log file is actually given (os.path.dirname(None) would raise).
        if self.args.log_file:
            utils.makedir(os.path.dirname(self.args.log_file))
            Log.get_instance().attach(open(self.args.log_file, 'w'))

        log('-- ' + self.__class__.__name__)
        for key in vars(self.args):
            log('[Training] %s=%s' % (key, str(getattr(self.args, key))))

        utils.makedir(os.path.dirname(self.args.encoder_file))
        utils.makedir(os.path.dirname(self.args.decoder_file))

        self.resolution = None
        """ (int) Resolution. """

        self.encoder = None
        """ (models.LearnedVariationalEncoder) Encoder. """

        self.decoder = None
        """ (models.LearnedDecoder) Decoder. """

        self.classifier = None
        """ (models.Classifier) Classifier. """

        self.encoder_scheduler = None
        """ (scheduler.Scheduler) Encoder scheduler. """

        self.decoder_scheduler = None
        """ (scheduler.Scheduler) Decoder scheduler. """

        self.classifier_scheduler = None
        """ (scheduler.Scheduler) Classifier scheduler. """

        self.random_codes = None
        """ (numpy.ndarray) Random codes. """

        self.train_statistics = numpy.zeros((0, 15))
        """ (numpy.ndarray) Will hold training statistics. """

        self.test_statistics = numpy.zeros((0, 12))
        """ (numpy.ndarray) Will hold testing statistics. """

        self.results = dict()
        """ (dict) Results. """

        self.logvar = -2.5
        """ (float) Log-variance hyper parameter. """
    def visualize_perturbations(self):
        """
        Visualize perturbations.

        Loops over (at most) the first 1000 samples, logging each sample's
        per-attempt success values and whether it was correctly classified. A
        sample is skipped if none of its success values is >= 0 or if it was
        not correctly classified. The loop stops at the next qualifying sample
        once more than 200 samples have been visualized.

        For each of the first (at most 6) attempts of a visualized sample,
        three files are written to ``self.args.output_directory``:

        - ``<i>_<j>_image_<label>.png``: the test image ``self.test_images[i]``;
        - ``<i>_<j>_attack_<label>.png``: ``self.perturbation_images[i][j]``;
        - ``<i>_<j>_perturbation_<max>.png``: the difference between the two,
          normalized by its maximum absolute value ``<max>``.

        Labels are the argmax of ``self.theta_representations[i]`` (image) and
        ``self.perturbation_representations[i][j]`` (attack).
        """

        num_attempts = self.perturbations.shape[1]
        num_attempts = min(num_attempts, 6)
        utils.makedir(self.args.output_directory)

        count = 0
        for i in range(min(1000, self.perturbations.shape[0])):

            log('[Visualization] sample %d, iterations %s and correctly classified: %s'
                % (i + 1, ' '.join(list(map(
                    str, self.success[i]))), self.accuracy[i]))
            if not numpy.any(self.success[i] >= 0) or not self.accuracy[i]:
                continue
            elif count > 200:
                break

            for j in range(num_attempts):
                image = self.test_images[i]
                image_attack = self.perturbation_images[i][j]
                image_perturbation = image_attack - image

                max_image_perturbation = numpy.max(
                    numpy.abs(image_perturbation))
                # Guard against an all-zero perturbation, which would
                # otherwise turn the whole image into NaNs.
                if max_image_perturbation > 0:
                    image_perturbation /= max_image_perturbation

                image_representation = self.theta_representations[i]
                attack_representation = self.perturbation_representations[i][j]

                image_label = numpy.argmax(image_representation)
                attack_label = numpy.argmax(attack_representation)

                image_file = os.path.join(
                    self.args.output_directory,
                    '%d_%d_image_%d.png' % (i, j, image_label))
                attack_file = os.path.join(
                    self.args.output_directory,
                    '%d_%d_attack_%d.png' % (i, j, attack_label))
                perturbation_file = os.path.join(
                    self.args.output_directory, '%d_%d_perturbation_%g.png' %
                    (i, j, max_image_perturbation))

                vis.image(image_file, image, scale=10)
                vis.image(attack_file, image_attack, scale=10)
                vis.perturbation(perturbation_file,
                                 image_perturbation,
                                 scale=10)

            count += 1
    def visualize_perturbations(self):
        """
        Visualize perturbations.

        Loops over (at most) the first 1000 samples. A sample is skipped unless
        ``numpy.any(self.success[i])`` holds and it was correctly classified.
        The loop stops at the next qualifying sample once more than 200
        samples have been visualized.

        For each of the first (at most 6) attempts of a visualized sample, the
        following files are written to ``self.args.output_directory``:

        - ``<i>_<j>_image_<label>.png``: the test image ``self.test_images[i]``;
        - ``<i>_<j>_attack_<label>.png``: the perturbed image
          ``self.perturbations[i][j]``;
        - ``<i>_<j>_perturbation_<max>.png``: the difference, normalized by its
          maximum absolute value ``<max>``;
        - for arrays with more than two axes additionally
          ``<i>_<j>_perturbation_magnitude_<max>.png``: the L2 norm of the
          normalized difference over axis 2, normalized by its maximum.

        Labels are the argmax of ``self.image_representations[i]`` (image) and
        ``self.perturbation_representations[i][j]`` (attack).
        """

        num_attempts = self.perturbations.shape[1]
        num_attempts = min(num_attempts, 6)
        utils.makedir(self.args.output_directory)

        count = 0
        for i in range(min(1000, self.perturbations.shape[0])):

            # NOTE(review): any non-zero success value, including a negative
            # one, counts as success here. Confirm this matches how
            # self.success is encoded.
            if not numpy.any(self.success[i]) or not self.accuracy[i]:
                continue
            elif count > 200:
                break

            for j in range(num_attempts):
                image = self.test_images[i]
                attack = self.perturbations[i][j]
                perturbation = attack - image
                max_perturbation = numpy.max(numpy.abs(perturbation))
                # Guard against an all-zero perturbation, which would
                # otherwise turn the whole image into NaNs.
                if max_perturbation > 0:
                    perturbation /= max_perturbation

                image_representation = self.image_representations[i]
                attack_representation = self.perturbation_representations[i][j]

                image_label = numpy.argmax(image_representation)
                attack_label = numpy.argmax(attack_representation)

                image_file = os.path.join(
                    self.args.output_directory,
                    '%d_%d_image_%d.png' % (i, j, image_label))
                attack_file = os.path.join(
                    self.args.output_directory,
                    '%d_%d_attack_%d.png' % (i, j, attack_label))
                perturbation_file = os.path.join(
                    self.args.output_directory,
                    '%d_%d_perturbation_%g.png' % (i, j, max_perturbation))

                vis.image(image_file, image, scale=10)
                vis.image(attack_file, attack, scale=10)
                vis.perturbation(perturbation_file, perturbation, scale=10)

                if len(perturbation.shape) > 2:
                    perturbation_magnitude = numpy.linalg.norm(perturbation,
                                                               ord=2,
                                                               axis=2)
                    max_perturbation_magnitude = numpy.max(
                        numpy.abs(perturbation_magnitude))
                    if max_perturbation_magnitude > 0:
                        perturbation_magnitude /= max_perturbation_magnitude

                    perturbation_file = os.path.join(
                        self.args.output_directory,
                        '%d_%d_perturbation_magnitude_%g.png' %
                        (i, j, max_perturbation_magnitude))
                    vis.perturbation(perturbation_file,
                                     perturbation_magnitude,
                                     scale=10)

            count += 1