    # Note: the following methods are excerpts from larger trainer classes
    # and assume module-level imports (math, os, numpy, scipy.interpolate,
    # torch) as well as project-specific helpers such as common.torch,
    # common.numpy, models, utils, plot, vis, and log.
    def reconstruction_error(self, batch_images, output_images):
        """
        Reconstruction loss.

        :param batch_images: target images
        :type batch_images: torch.autograd.Variable
        :param output_images: predicted images
        :type output_images: torch.autograd.Variable
        :return: error
        :rtype: torch.autograd.Variable
        """

        diff = batch_images - output_images
        return torch.mean(diff * diff)
Example #2
    def train(self):
        """
        Train with fair data augmentation.
        """

        self.model.train()
        assert self.model.training is True

        split = self.args.batch_size // 2
        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        permutation = numpy.random.permutation(self.train_images.shape[0])

        for b in range(num_batches):
            self.scheduler.update(self.epoch, float(b) / num_batches)

            perm = numpy.take(permutation,
                              range(b * self.args.batch_size,
                                    (b + 1) * self.args.batch_size),
                              mode='wrap')
            batch_images = common.torch.as_variable(self.train_images[perm],
                                                    self.args.use_gpu)
            batch_classes = common.torch.as_variable(self.train_codes[perm],
                                                     self.args.use_gpu)
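            # Convert images from NHWC (channels last) to NCHW for PyTorch.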
            batch_images = batch_images.permute(0, 3, 1, 2)

            loss = error = gradient = 0

            if self.args.full_variant:
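                # Full variant: every optimizer step is taken on randomly
                # perturbed images; no clean images are used for training.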
                for t in range(self.args.max_iterations):
                    if self.args.strong_variant:
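                        # Strong variant: sample transformation parameters
                        # uniformly over the full [min_bound, max_bound] box.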
                        min_bound = numpy.repeat(self.min_bound.reshape(1, -1),
                                                 self.args.batch_size,
                                                 axis=0)
                        max_bound = numpy.repeat(self.max_bound.reshape(1, -1),
                                                 self.args.batch_size,
                                                 axis=0)
                        random = numpy.random.uniform(
                            min_bound, max_bound,
                            (self.args.batch_size, self.args.N_theta))
                        batch_perturbed_theta = common.torch.as_variable(
                            random.astype(numpy.float32), self.args.use_gpu)

                        self.decoder.set_image(batch_images)
                        batch_perturbed_images = self.decoder(
                            batch_perturbed_theta)
                    else:
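                        # Weak variant: sample within an epsilon-ball around
                        # zero and clip to the valid parameter bounds.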
                        random = common.numpy.uniform_ball(
                            self.args.batch_size,
                            self.args.N_theta,
                            epsilon=self.args.epsilon,
                            ord=self.norm)
                        batch_perturbed_theta = common.torch.as_variable(
                            random.astype(numpy.float32), self.args.use_gpu)
                        batch_perturbed_theta = torch.min(
                            common.torch.as_variable(self.max_bound,
                                                     self.args.use_gpu),
                            batch_perturbed_theta)
                        batch_perturbed_theta = torch.max(
                            common.torch.as_variable(self.min_bound,
                                                     self.args.use_gpu),
                            batch_perturbed_theta)

                        self.decoder.set_image(batch_images)
                        batch_perturbed_images = self.decoder(
                            batch_perturbed_theta)

                    output_classes = self.model(batch_perturbed_images)

                    self.scheduler.optimizer.zero_grad()
                    l = self.loss(batch_classes, output_classes)
                    l.backward()
                    self.scheduler.optimizer.step()
                    loss += l.item()

                    g = torch.mean(
                        torch.abs(list(self.model.parameters())[0].grad))
                    gradient += g.item()

                    e = self.error(batch_classes, output_classes)
                    error += e.item()

                batch_perturbations = batch_perturbed_images - batch_images
                gradient /= self.args.max_iterations
                loss /= self.args.max_iterations
                error /= self.args.max_iterations
                perturbation_loss = loss
                perturbation_error = error
            else:
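                # Non-full variant: train on a clean half-batch first, then
                # on perturbed versions of the remaining half.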
                output_classes = self.model(batch_images[:split])

                self.scheduler.optimizer.zero_grad()
                l = self.loss(batch_classes[:split], output_classes)
                l.backward()
                self.scheduler.optimizer.step()
                loss = l.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                e = self.error(batch_classes[:split], output_classes)
                error = e.item()

                perturbation_loss = perturbation_error = 0
                for t in range(self.args.max_iterations):
                    if self.args.strong_variant:
                        min_bound = numpy.repeat(self.min_bound.reshape(1, -1),
                                                 split,
                                                 axis=0)
                        max_bound = numpy.repeat(self.max_bound.reshape(1, -1),
                                                 split,
                                                 axis=0)
                        random = numpy.random.uniform(
                            min_bound, max_bound, (split, self.args.N_theta))

                        batch_perturbed_theta = common.torch.as_variable(
                            random.astype(numpy.float32), self.args.use_gpu)

                        self.decoder.set_image(batch_images[split:])
                        batch_perturbed_images = self.decoder(
                            batch_perturbed_theta)
                    else:
                        random = common.numpy.uniform_ball(
                            split,
                            self.args.N_theta,
                            epsilon=self.args.epsilon,
                            ord=self.norm)
                        batch_perturbed_theta = common.torch.as_variable(
                            random.astype(numpy.float32), self.args.use_gpu)
                        batch_perturbed_theta = torch.min(
                            common.torch.as_variable(self.max_bound,
                                                     self.args.use_gpu),
                            batch_perturbed_theta)
                        batch_perturbed_theta = torch.max(
                            common.torch.as_variable(self.min_bound,
                                                     self.args.use_gpu),
                            batch_perturbed_theta)

                        self.decoder.set_image(batch_images[split:])
                        batch_perturbed_images = self.decoder(
                            batch_perturbed_theta)

                    output_classes = self.model(batch_perturbed_images)

                    self.scheduler.optimizer.zero_grad()
                    l = self.loss(batch_classes[split:], output_classes)
                    l.backward()
                    self.scheduler.optimizer.step()
                    perturbation_loss += l.item()

                    g = torch.mean(
                        torch.abs(list(self.model.parameters())[0].grad))
                    gradient += g.item()

                    e = self.error(batch_classes[split:], output_classes)
                    perturbation_error += e.item()

                batch_perturbations = batch_perturbed_images - batch_images[
                    split:]
                gradient /= self.args.max_iterations + 1
                perturbation_loss /= self.args.max_iterations
                perturbation_error /= self.args.max_iterations

            iteration = self.epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack((
                self.train_statistics,
                numpy.array([[
                    iteration,  # iterations
                    iteration * (1 + self.args.max_iterations) *
                    self.args.batch_size,  # samples seen
                    min(num_batches, iteration) * self.args.batch_size +
                    iteration * self.args.max_iterations *
                    self.args.batch_size,  # unique samples seen
                    loss,
                    error,
                    perturbation_loss,
                    perturbation_error,
                    gradient
                ]])))

            if b % self.args.skip == self.args.skip // 2:
                log('[Training] %d | %d: %g (%g) %g (%g) [%g]' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 3]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 4]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 5]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 6]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, -1]),
                ))

        self.debug('clean.%d.png' % self.epoch,
                   batch_images.permute(0, 2, 3, 1))
        self.debug('perturbed.%d.png' % self.epoch,
                   batch_perturbed_images.permute(0, 2, 3, 1))
        self.debug('perturbation.%d.png' % self.epoch,
                   batch_perturbations.permute(0, 2, 3, 1),
                   cmap='seismic')
    def train(self):
        """
        Train for one epoch.
        """

        self.model.train()
        log('[Training] %d set classifier to train' % self.epoch)
        assert self.model.training is True

        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        permutation = numpy.random.permutation(self.train_images.shape[0])

        for b in range(num_batches):
            self.scheduler.update(self.epoch, float(b) / num_batches)

            perm = numpy.take(permutation,
                              range(b * self.args.batch_size,
                                    (b + 1) * self.args.batch_size),
                              mode='wrap')
            assert perm.shape[0] == self.args.batch_size

            batch_images = common.torch.as_variable(self.train_images[perm],
                                                    self.args.use_gpu)
            batch_true_classes = common.torch.as_variable(
                self.train_codes[perm], self.args.use_gpu)
            batch_training_classes = common.torch.as_variable(
                self.train_codes[perm], self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            output_classes = self.model(batch_images)

            self.scheduler.optimizer.zero_grad()
            loss = self.loss(batch_training_classes, output_classes)
            loss.backward()
            self.scheduler.optimizer.step()
            loss = loss.item()

            gradient = torch.mean(
                torch.abs(list(self.model.parameters())[0].grad))
            gradient = gradient.item()

            error = self.error(batch_true_classes, output_classes)
            error = error.item()

            iteration = self.epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack(
                (self.train_statistics,
                 numpy.array([
                     iteration, iteration * self.args.batch_size,
                     min(num_batches, iteration) * self.args.batch_size, loss,
                     error, gradient
                 ])))

            if b % self.args.skip == self.args.skip // 2:
                log('[Training] %d | %d: %g (%g) [%g]' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 3]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 4]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, -1]),
                ))

        # Only debug the last batch for efficiency!
        self.debug('clean.png', batch_images.permute(0, 2, 3, 1))
Example #4
    def test_test(self):
        """
        Test on testing set.
        """

        num_batches = int(
            math.ceil(self.test_images.shape[0] / self.args.batch_size))

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size,
                        self.test_images.shape[0])

            batch_images = common.torch.as_variable(
                self.test_images[b_start:b_end], self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            # Important to get the correct codes!
            output_codes, output_logvar = self.encoder(batch_images)
            output_images = self.decoder(output_codes)
            e = self.reconstruction_loss(batch_images, output_images)
            self.reconstruction_error += e.item()

            self.code_mean += torch.mean(output_codes).item()
            self.code_var += torch.var(output_codes).item()

            output_images = numpy.squeeze(
                numpy.transpose(output_images.cpu().detach().numpy(),
                                (0, 2, 3, 1)))
            self.pred_images = common.numpy.concatenate(
                self.pred_images, output_images)

            output_codes = output_codes.cpu().detach().numpy()
            self.pred_codes = common.numpy.concatenate(self.pred_codes,
                                                       output_codes)

            if b % 100 == 50:
                log('[Testing] %d' % b)

        assert self.pred_images.shape[0] == self.test_images.shape[
            0], 'computed invalid number of test images'
        if self.args.reconstruction_file:
            utils.write_hdf5(self.args.reconstruction_file, self.pred_images)
            log('[Testing] wrote %s' % self.args.reconstruction_file)

        if self.args.test_theta_file:
            assert self.pred_codes.shape[0] == self.test_images.shape[
                0], 'computed invalid number of test codes'
            utils.write_hdf5(self.args.test_theta_file, self.pred_codes)
            log('[Testing] wrote %s' % self.args.test_theta_file)

        threshold = 0.9
        percentage = 0
        # values = numpy.linalg.norm(pred_codes, ord=2, axis=1)
        values = numpy.max(numpy.abs(self.pred_codes), axis=1)

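        # Raise the threshold in steps of 0.1 until at least 90% of the
        # predicted codes lie within it.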
        while percentage < 0.9:
            threshold += 0.1
            percentage = numpy.sum(values <= threshold) / float(
                values.shape[0])
            log('[Testing] threshold %g percentage %g' %
                (threshold, percentage))
        log('[Testing] taking threshold %g with percentage %g' %
            (threshold, percentage))

        if self.args.output_directory and utils.display():
            # fit = 10
            # plot_file = os.path.join(self.args.output_directory, 'test_codes')
            # plot.manifold(plot_file, pred_codes[::fit], None, None, 'tsne', None, title='t-SNE of Test Codes')
            # log('[Testing] wrote %s' % plot_file)

            for d in range(1, self.pred_codes.shape[1]):
                plot_file = os.path.join(self.args.output_directory,
                                         'test_codes_%s' % d)
                plot.scatter(
                    plot_file,
                    self.pred_codes[:, 0],
                    self.pred_codes[:, d], (values <= threshold).astype(int),
                    ['greater %g' % threshold,
                     'smaller %g' % threshold],
                    title='Dimensions 0 and %d of Test Codes' % d)
                log('[Testing] wrote %s' % plot_file)

        self.reconstruction_error /= num_batches
        log('[Testing] reconstruction error %g' % self.reconstruction_error)
    def test(self, epoch):
        """
        Test the model.

        :param epoch: current epoch
        :type epoch: int
        """

        self.encoder.eval()
        log('[Training] %d set encoder to eval' % epoch)
        self.decoder.eval()
        log('[Training] %d set decoder to eval' % epoch)
        self.classifier.eval()
        log('[Training] %d set classifier to eval' % epoch)

        latent_loss = 0
        reconstruction_loss = 0
        reconstruction_error = 0
        decoder_loss = 0
        discriminator_loss = 0
        mean = 0
        var = 0
        logvar = 0
        pred_images = None
        pred_codes = None

        num_batches = int(
            math.ceil(self.test_images.shape[0] / self.args.batch_size))
        assert self.encoder.training is False

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size,
                        self.test_images.shape[0])
            batch_images = common.torch.as_variable(
                self.test_images[b_start:b_end], self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

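            # Decode from the posterior mean; no sampling at test time.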
            output_mu, output_logvar = self.encoder(batch_images)
            output_images = self.decoder(output_mu)

            output_real_classes = self.classifier(batch_images)
            output_reconstructed_classes = self.classifier(output_images)

            # Latent loss.
            e = self.latent_loss(output_mu, output_logvar)
            latent_loss += e.item()

            # Reconstruction loss.
            e = self.reconstruction_loss(batch_images, output_images)
            reconstruction_loss += e.item()

            # Reconstruction error.
            e = self.reconstruction_error(batch_images, output_images)
            reconstruction_error += e.item()

            e = self.decoder_loss(output_reconstructed_classes)
            decoder_loss += e.item()

            # Adversarial loss.
            e = self.discriminator_loss(output_real_classes,
                                        output_reconstructed_classes)
            discriminator_loss += e.item()

            mean += torch.mean(output_mu).item()
            var += torch.var(output_mu).item()
            logvar += torch.mean(output_logvar).item()

            output_images = numpy.squeeze(
                numpy.transpose(output_images.cpu().detach().numpy(),
                                (0, 2, 3, 1)))
            pred_images = common.numpy.concatenate(pred_images, output_images)
            output_codes = output_mu.cpu().detach().numpy()
            pred_codes = common.numpy.concatenate(pred_codes, output_codes)

        utils.write_hdf5(self.args.reconstruction_file, pred_images)
        log('[Training] %d: wrote %s' % (epoch, self.args.reconstruction_file))

        if utils.display():
            png_file = self.args.reconstruction_file + '.%d.png' % epoch
            if epoch == 0:
                vis.mosaic(png_file, self.test_images[:225], 15, 5, 'gray', 0,
                           1)
            else:
                vis.mosaic(png_file, pred_images[:225], 15, 5, 'gray', 0, 1)
            log('[Training] %d: wrote %s' % (epoch, png_file))

        latent_loss /= num_batches
        reconstruction_loss /= num_batches
        reconstruction_error /= num_batches
        decoder_loss /= num_batches
        discriminator_loss /= num_batches
        mean /= num_batches
        var /= num_batches
        logvar /= num_batches
        log('[Training] %d: test %g (%g) %g (%g, %g, %g)' %
            (epoch, reconstruction_loss, reconstruction_error, latent_loss,
             mean, var, logvar))
        log('[Training] %d: test %g %g' %
            (epoch, decoder_loss, discriminator_loss))

        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        iteration = epoch * num_batches
        self.test_statistics = numpy.vstack(
            (self.test_statistics,
             numpy.array([
                 iteration, iteration * self.args.batch_size,
                 min(num_batches, iteration),
                 min(num_batches, iteration) * self.args.batch_size,
                 reconstruction_loss, reconstruction_error, latent_loss, mean,
                 var, logvar, decoder_loss, discriminator_loss
             ])))

        pred_images = None
        if self.random_codes is None:
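            # Draw a fixed set of latent codes once so random generations
            # stay comparable across epochs.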
            self.random_codes = common.numpy.truncated_normal(
                (1000, self.args.latent_space_size)).astype(numpy.float32)
        num_batches = int(
            math.ceil(self.random_codes.shape[0] / self.args.batch_size))

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size,
                        self.random_codes.shape[0])
            if b_start >= b_end: break

            batch_codes = common.torch.as_variable(
                self.random_codes[b_start:b_end], self.args.use_gpu)
            output_images = self.decoder(batch_codes)

            output_images = numpy.squeeze(
                numpy.transpose(output_images.cpu().detach().numpy(),
                                (0, 2, 3, 1)))
            pred_images = common.numpy.concatenate(pred_images, output_images)

        utils.write_hdf5(self.args.random_file, pred_images)
        log('[Training] %d: wrote %s' % (epoch, self.args.random_file))

        if utils.display() and epoch > 0:
            png_file = self.args.random_file + '.%d.png' % epoch
            vis.mosaic(png_file, pred_images[:225], 15, 5, 'gray', 0, 1)
            log('[Training] %d: wrote %s' % (epoch, png_file))

        interpolations = None
        perm = numpy.random.permutation(pred_codes.shape[0])

        for i in range(50):
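            # Linearly interpolate between random pairs of test codes in
            # latent space.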
            first = pred_codes[i]
            second = pred_codes[perm[i]]
            linfit = scipy.interpolate.interp1d([0, 1],
                                                numpy.vstack([first, second]),
                                                axis=0)
            interpolations = common.numpy.concatenate(
                interpolations, linfit(numpy.linspace(0, 1, 10)))

        pred_images = None
        num_batches = int(
            math.ceil(interpolations.shape[0] / self.args.batch_size))
        interpolations = interpolations.astype(numpy.float32)

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size,
                        interpolations.shape[0])
            if b_start >= b_end: break

            batch_codes = common.torch.as_variable(
                interpolations[b_start:b_end], self.args.use_gpu)
            output_images = self.decoder(batch_codes)
            output_images = numpy.squeeze(
                numpy.transpose(output_images.cpu().detach().numpy(),
                                (0, 2, 3, 1)))
            pred_images = common.numpy.concatenate(pred_images, output_images)

            if b % 100 == 50:
                log('[Testing] %d' % b)

        utils.write_hdf5(self.args.interpolation_file, pred_images)
        log('[Testing] wrote %s' % self.args.interpolation_file)

        if utils.display() and epoch > 0:
            png_file = self.args.interpolation_file + '.%d.png' % epoch
            vis.mosaic(png_file, pred_images[:100], 10, 5, 'gray', 0, 1)
            log('[Training] %d: wrote %s' % (epoch, png_file))
    def train(self, epoch):
        """
        Train for one epoch.

        :param epoch: current epoch
        :type epoch: int
        """

        self.encoder.train()
        log('[Training] %d set encoder to train' % epoch)
        self.decoder.train()
        log('[Training] %d set decoder to train' % epoch)
        self.classifier.train()
        log('[Training] %d set classifier to train' % epoch)

        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        assert self.encoder.training is True

        permutation = numpy.random.permutation(self.train_images.shape[0])
        permutation = numpy.concatenate(
            (permutation, permutation[:self.args.batch_size]), axis=0)

        for b in range(num_batches):
            self.encoder_scheduler.update(epoch, float(b) / num_batches)
            self.decoder_scheduler.update(epoch, float(b) / num_batches)
            self.classifier_scheduler.update(epoch, float(b) / num_batches)

            perm = permutation[b * self.args.batch_size:(b + 1) *
                               self.args.batch_size]
            batch_images = common.torch.as_variable(self.train_images[perm],
                                                    self.args.use_gpu, True)
            batch_images = batch_images.permute(0, 3, 1, 2)

            output_mu, output_logvar = self.encoder(batch_images)
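            # Reparameterization trick: draw codes from N(mu, sigma^2) in a
            # differentiable way.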
            output_codes = self.reparameterize(output_mu, output_logvar)
            output_images = self.decoder(output_codes)

            output_real_classes = self.classifier(batch_images)
            output_reconstructed_classes = self.classifier(output_images)

            latent_loss = self.latent_loss(output_mu, output_logvar)
            reconstruction_loss = self.reconstruction_loss(
                batch_images, output_images)
            decoder_loss = self.decoder_loss(output_reconstructed_classes)
            discriminator_loss = self.discriminator_loss(
                output_real_classes, output_reconstructed_classes)

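            # Encoder, decoder, and classifier are updated with separate
            # losses; retain_graph keeps the shared graph alive across the
            # first two backward passes.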
            self.encoder_scheduler.optimizer.zero_grad()
            loss = (latent_loss + self.args.beta * reconstruction_loss
                    + self.args.gamma * decoder_loss
                    + self.args.eta * torch.sum(torch.abs(output_logvar)))
            loss.backward(retain_graph=True)
            self.encoder_scheduler.optimizer.step()

            self.decoder_scheduler.optimizer.zero_grad()
            loss = self.args.beta * reconstruction_loss + self.args.gamma * decoder_loss
            loss.backward(retain_graph=True)
            self.decoder_scheduler.optimizer.step()

            self.classifier_scheduler.optimizer.zero_grad()
            loss = self.args.gamma * discriminator_loss
            loss.backward()
            self.classifier_scheduler.optimizer.step()

            reconstruction_error = self.reconstruction_error(
                batch_images, output_images)
            iteration = epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack(
                (self.train_statistics,
                 numpy.array([
                     iteration, iteration * self.args.batch_size,
                     min(num_batches, iteration),
                     min(num_batches, iteration) * self.args.batch_size,
                     reconstruction_loss.item(), reconstruction_error.item(),
                     latent_loss.item(),
                     torch.mean(output_mu).item(),
                     torch.var(output_mu).item(),
                     torch.mean(output_logvar).item(),
                     decoder_loss.item(),
                     discriminator_loss.item(),
                     torch.mean(
                         torch.abs(list(
                             self.encoder.parameters())[0].grad)).item(),
                     torch.mean(
                         torch.abs(list(
                             self.decoder.parameters())[0].grad)).item(),
                     torch.mean(
                         torch.abs(list(
                             self.classifier.parameters())[0].grad)).item()
                 ])))

            skip = 10
            if b % skip == skip // 2:
                log('[Training] %d | %d: %g (%g) %g (%g, %g, %g)' % (
                    epoch,
                    b,
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 4]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 5]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 6]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 7]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 8]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 9]),
                ))
                log('[Training] %d | %d: %g %g (%g, %g, %g)' % (
                    epoch,
                    b,
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 10]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 11]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 12]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 13]),
                    numpy.mean(self.train_statistics[max(0, iteration -
                                                         skip):iteration, 14]),
                ))
Example #7
    def train(self):
        """
        Train adversarially.
        """

        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        permutation = numpy.random.permutation(self.train_images.shape[0])
        perturbation_permutation = numpy.random.permutation(
            self.train_images.shape[0])
        if self.args.safe:
            perturbation_permutation = perturbation_permutation[
                self.train_valid == 1]
        else:
            perturbation_permutation = permutation

        for b in range(num_batches):
            self.scheduler.update(self.epoch, float(b) / num_batches)

            self.model.eval()
            assert self.model.training is False
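            # Attacks are computed in eval mode; the model is switched back
            # to train mode before the parameter update.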
            objective = self.objective_class()
            split = self.args.batch_size // 2

            if self.args.full_variant:
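                # Full variant: half of the batch receives image-space
                # adversarial perturbations, the other half on-manifold
                # perturbations found by attacking theta through the decoder.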
                perm = numpy.concatenate(
                    (numpy.take(permutation,
                                range(b * self.args.batch_size,
                                      b * self.args.batch_size + split),
                                mode='wrap'),
                     numpy.take(perturbation_permutation,
                                range(b * self.args.batch_size + split,
                                      (b + 1) * self.args.batch_size),
                                mode='wrap')),
                    axis=0)
                batch_images = common.torch.as_variable(
                    self.train_images[perm], self.args.use_gpu)
                batch_classes = common.torch.as_variable(
                    self.train_codes[perm], self.args.use_gpu)
                batch_theta = common.torch.as_variable(self.train_theta[perm],
                                                       self.args.use_gpu)
                batch_images = batch_images.permute(0, 3, 1, 2)

                attack = self.setup_attack(self.model, batch_images[:split],
                                           batch_classes[:split])
                success, perturbations, _, _, _ = attack.run(
                    objective, self.args.verbose)
                batch_perturbations1 = common.torch.as_variable(
                    perturbations.astype(numpy.float32), self.args.use_gpu)
                batch_perturbed_images1 = batch_images[:split] + batch_perturbations1

                if isinstance(self.decoder, models.SelectiveDecoder):
                    self.decoder.set_code(batch_classes[split:])
                attack = self.setup_decoder_attack(self.decoder_classifier,
                                                   batch_theta[split:],
                                                   batch_classes[split:])
                attack.set_bound(torch.from_numpy(self.min_bound),
                                 torch.from_numpy(self.max_bound))
                decoder_success, decoder_perturbations, probabilities, norm, _ = attack.run(
                    objective, self.args.verbose)

                batch_perturbed_theta = batch_theta[
                    split:] + common.torch.as_variable(decoder_perturbations,
                                                       self.args.use_gpu)
                batch_perturbed_images2 = self.decoder(batch_perturbed_theta)
                batch_perturbations2 = batch_perturbed_images2 - batch_images[
                    split:]

                batch_input_images = torch.cat(
                    (batch_perturbed_images1, batch_perturbed_images2), dim=0)

                self.model.train()
                assert self.model.training is True

                output_classes = self.model(batch_input_images)

                self.scheduler.optimizer.zero_grad()
                perturbation_loss = self.loss(batch_classes[:split],
                                              output_classes[:split])
                decoder_perturbation_loss = self.loss(batch_classes[split:],
                                                      output_classes[split:])
                loss = (perturbation_loss + decoder_perturbation_loss) / 2
                loss.backward()
                self.scheduler.optimizer.step()
                loss = loss.item()
                perturbation_loss = perturbation_loss.item()
                decoder_perturbation_loss = decoder_perturbation_loss.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                perturbation_error = self.error(batch_classes[:split],
                                                output_classes[:split])
                perturbation_error = perturbation_error.item()

                decoder_perturbation_error = self.error(
                    batch_classes[split:], output_classes[split:])
                decoder_perturbation_error = decoder_perturbation_error.item()

                error = (perturbation_error + decoder_perturbation_error) / 2
            else:
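                # Otherwise: the first quarter of the batch gets on-manifold
                # (decoder) adversarial examples, the second quarter
                # image-space adversarial examples; the second half stays
                # clean.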
                perm = numpy.concatenate((
                    numpy.take(
                        perturbation_permutation,
                        range(b * self.args.batch_size + split + split // 2,
                              (b + 1) * self.args.batch_size),
                        mode='wrap'),
                    numpy.take(
                        permutation,
                        range(b * self.args.batch_size,
                              b * self.args.batch_size + split + split // 2),
                        mode='wrap'),
                ), axis=0)
                batch_images = common.torch.as_variable(
                    self.train_images[perm], self.args.use_gpu)
                batch_classes = common.torch.as_variable(
                    self.train_codes[perm], self.args.use_gpu)
                batch_theta = common.torch.as_variable(self.train_theta[perm],
                                                       self.args.use_gpu)
                batch_images = batch_images.permute(0, 3, 1, 2)

                attack = self.setup_attack(self.model,
                                           batch_images[split // 2:split],
                                           batch_classes[split // 2:split])
                success, perturbations, _, _, _ = attack.run(
                    objective, self.args.verbose)
                batch_perturbations1 = common.torch.as_variable(
                    perturbations.astype(numpy.float32), self.args.use_gpu)
                batch_perturbed_images1 = batch_images[
                    split // 2:split] + batch_perturbations1

                if isinstance(self.decoder, models.SelectiveDecoder):
                    self.decoder.set_code(batch_classes[:split // 2])
                attack = self.setup_decoder_attack(self.decoder_classifier,
                                                   batch_theta[:split // 2],
                                                   batch_classes[:split // 2])
                attack.set_bound(torch.from_numpy(self.min_bound),
                                 torch.from_numpy(self.max_bound))
                decoder_success, decoder_perturbations, probabilities, norm, _ = attack.run(
                    objective, self.args.verbose)

                batch_perturbed_theta = (
                    batch_theta[:split // 2]
                    + common.torch.as_variable(decoder_perturbations,
                                               self.args.use_gpu))
                batch_perturbed_images2 = self.decoder(batch_perturbed_theta)
                batch_perturbations2 = (batch_perturbed_images2
                                        - batch_images[:split // 2])

                batch_input_images = torch.cat(
                    (batch_perturbed_images2, batch_perturbed_images1,
                     batch_images[split:]),
                    dim=0)

                self.model.train()
                assert self.model.training is True

                output_classes = self.model(batch_input_images)

                self.scheduler.optimizer.zero_grad()
                loss = self.loss(batch_classes[split:], output_classes[split:])
                perturbation_loss = self.loss(batch_classes[split // 2:split],
                                              output_classes[split // 2:split])
                decoder_perturbation_loss = self.loss(
                    batch_classes[:split // 2], output_classes[:split // 2])
                l = (loss + perturbation_loss + decoder_perturbation_loss) / 3
                l.backward()
                self.scheduler.optimizer.step()
                loss = loss.item()
                perturbation_loss = perturbation_loss.item()
                decoder_perturbation_loss = decoder_perturbation_loss.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                error = self.error(batch_classes[split:],
                                   output_classes[split:])
                error = error.item()

                perturbation_error = self.error(
                    batch_classes[split // 2:split],
                    output_classes[split // 2:split])
                perturbation_error = perturbation_error.item()

                decoder_perturbation_error = self.error(
                    batch_classes[:split // 2], output_classes[:split // 2])
                decoder_perturbation_error = decoder_perturbation_error.item()

            iterations = numpy.mean(
                success[success >= 0]) if numpy.sum(success >= 0) > 0 else -1
            norm = numpy.mean(
                numpy.linalg.norm(perturbations.reshape(
                    perturbations.shape[0], -1),
                                  axis=1,
                                  ord=self.norm))
            success = numpy.sum(success >= 0) / self.args.batch_size

            decoder_iterations = numpy.mean(
                decoder_success[decoder_success >= 0]) if numpy.sum(
                    decoder_success >= 0) > 0 else -1
            decoder_norm = numpy.mean(
                numpy.linalg.norm(decoder_perturbations, axis=1,
                                  ord=self.norm))
            decoder_success = numpy.sum(
                decoder_success >= 0) / self.args.batch_size

            iteration = self.epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack((
                self.train_statistics,
                numpy.array([[
                    iteration,  # iterations
                    iteration * (1 + self.args.max_iterations) *
                    self.args.batch_size,  # samples seen
                    min(num_batches, iteration) * self.args.batch_size +
                    iteration * self.args.max_iterations *
                    self.args.batch_size,  # unique samples seen
                    loss,
                    error,
                    perturbation_loss,
                    perturbation_error,
                    decoder_perturbation_loss,
                    decoder_perturbation_error,
                    success,
                    iterations,
                    norm,
                    decoder_success,
                    decoder_iterations,
                    decoder_norm,
                    gradient
                ]])))

            if b % self.args.skip == self.args.skip // 2:
                log('[Training] %d | %d: %g (%g) %g (%g) %g (%g) [%g]' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 3]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 4]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 5]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 6]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 7]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 8]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, -1]),
                ))
                log('[Training] %d | %d: %g (%g, %g) %g (%g, %g)' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 9]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 10]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 11]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 12]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 13]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 14]),
                ))

        self.debug('clean.%d.png' % self.epoch,
                   batch_images.permute(0, 2, 3, 1))
        self.debug('perturbed.%d.png' % self.epoch,
                   batch_perturbed_images1.permute(0, 2, 3, 1))
        self.debug('perturbed2.%d.png' % self.epoch,
                   batch_perturbed_images2.permute(0, 2, 3, 1))
        self.debug('perturbation.%d.png' % self.epoch,
                   batch_perturbations1.permute(0, 2, 3, 1),
                   cmap='seismic')
        self.debug('perturbation2.%d.png' % self.epoch,
                   batch_perturbations2.permute(0, 2, 3, 1),
                   cmap='seismic')
    def train(self, epoch):
        """
        Train for one epoch.

        :param epoch: current epoch
        :type epoch: int
        """

        assert self.encoder is not None and self.decoder is not None
        assert self.scheduler is not None

        self.auto_encoder.train()
        log('[Training] %d set auto encoder to train' % epoch)
        self.encoder.train()
        log('[Training] %d set encoder to train' % epoch)
        self.decoder.train()
        log('[Training] %d set decoder to train' % epoch)

        num_batches = int(math.ceil(self.train_images.shape[0]/self.args.batch_size))
        assert self.encoder.training is True

        permutation = numpy.random.permutation(self.train_images.shape[0])
        permutation = numpy.concatenate((permutation, permutation[:self.args.batch_size]), axis=0)

        for b in range(num_batches):
            self.scheduler.update(epoch, float(b)/num_batches)

            perm = permutation[b * self.args.batch_size: (b + 1) * self.args.batch_size]
            batch_images = common.torch.as_variable(self.train_images[perm], self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            output_images, output_mu, output_logvar = self.auto_encoder(batch_images)
            reconstruction_loss = self.reconstruction_loss(batch_images, output_images)

            self.scheduler.optimizer.zero_grad()
            latent_loss = self.latent_loss(output_mu, output_logvar)
            loss = self.args.beta*reconstruction_loss + latent_loss
            loss.backward()
            self.scheduler.optimizer.step()
            reconstruction_loss = reconstruction_loss.item()
            latent_loss = latent_loss.item()

            reconstruction_error = self.reconstruction_error(batch_images, output_images)
            reconstruction_error = reconstruction_error.item()

            iteration = epoch*num_batches + b + 1
            self.train_statistics = numpy.vstack((self.train_statistics, numpy.array([
                iteration,
                iteration * self.args.batch_size,
                min(num_batches, iteration),
                min(num_batches, iteration) * self.args.batch_size,
                reconstruction_loss,
                reconstruction_error,
                latent_loss,
                torch.mean(output_mu).item(),
                torch.var(output_mu).item(),
                torch.mean(output_logvar).item(),
            ])))

            skip = 10
            if b%skip == skip//2:
                log('[Training] %d | %d: %g (%g) %g %g %g %g' % (
                    epoch,
                    b,
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 4]),
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 5]),
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 6]),
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 7]),
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 8]),
                    numpy.mean(self.train_statistics[max(0, iteration-skip):iteration, 9]),
                ))
Example #9
    def train(self):
        """
        Train with fair data augmentation.
        """

        self.model.train()
        log('[Training] %d set classifier to train' % self.epoch)
        assert self.model.training is True

        split = self.args.batch_size // 2
        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        permutation = numpy.random.permutation(self.train_images.shape[0])

        for b in range(num_batches):
            self.scheduler.update(self.epoch, float(b) / num_batches)

            perm = numpy.take(permutation,
                              range(b * self.args.batch_size,
                                    (b + 1) * self.args.batch_size),
                              mode='wrap')
            batch_images = common.torch.as_variable(self.train_images[perm],
                                                    self.args.use_gpu)
            batch_classes = common.torch.as_variable(self.train_codes[perm],
                                                     self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            if self.args.full_variant:
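                # Full variant: train only on randomly perturbed images,
                # with perturbations drawn from an epsilon-ball and clipped
                # so that pixel values stay in [0, 1].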
                loss = error = gradient = 0
                for t in range(self.args.max_iterations):
                    size = batch_images.size()
                    batch_perturbations = common.numpy.uniform_ball(
                        size[0],
                        numpy.prod(size[1:]),
                        epsilon=self.args.epsilon,
                        ord=self.norm)
                    batch_perturbations = common.torch.as_variable(
                        batch_perturbations.reshape(size).astype(
                            numpy.float32), self.args.use_gpu)
                    batch_perturbations = torch.min(
                        torch.ones_like(batch_images) - batch_images,
                        batch_perturbations)
                    batch_perturbations = torch.max(
                        torch.zeros_like(batch_images) - batch_images,
                        batch_perturbations)

                    batch_perturbed_images = batch_images + batch_perturbations
                    output_perturbed_classes = self.model(
                        batch_perturbed_images)

                    self.scheduler.optimizer.zero_grad()
                    l = self.loss(batch_classes, output_perturbed_classes)
                    l.backward()
                    self.scheduler.optimizer.step()
                    loss += l.item()

                    g = torch.mean(
                        torch.abs(list(self.model.parameters())[0].grad))
                    gradient += g.item()

                    e = self.error(batch_classes, output_perturbed_classes)
                    error += e.item()

                gradient /= self.args.max_iterations
                loss /= self.args.max_iterations
                error /= self.args.max_iterations
                perturbation_loss = loss
                perturbation_error = error

            elif self.args.strong_variant:
                raise NotImplementedError('strong_variant not implemented yet')
            else:
                output_classes = self.model(batch_images[:split])

                self.scheduler.optimizer.zero_grad()
                l = self.loss(batch_classes[:split], output_classes)
                l.backward()
                self.scheduler.optimizer.step()
                loss = l.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                e = self.error(batch_classes[:split], output_classes)
                error = e.item()

                perturbation_loss = perturbation_error = 0
                for t in range(self.args.max_iterations):
                    size = batch_images.size()
                    batch_perturbations = common.numpy.uniform_ball(
                        split,
                        numpy.prod(size[1:]),
                        epsilon=self.args.epsilon,
                        ord=self.norm)
                    batch_perturbations = common.torch.as_variable(
                        batch_perturbations.reshape(
                            split, size[1], size[2],
                            size[3]).astype(numpy.float32), self.args.use_gpu)
                    batch_perturbations = torch.min(
                        torch.ones_like(batch_images[split:]) -
                        batch_images[split:], batch_perturbations)
                    batch_perturbations = torch.max(
                        torch.zeros_like(batch_images[split:]) -
                        batch_images[split:], batch_perturbations)

                    batch_perturbed_images = batch_images[
                        split:] + batch_perturbations
                    output_perturbed_classes = self.model(
                        batch_perturbed_images)

                    self.scheduler.optimizer.zero_grad()
                    l = self.loss(batch_classes[split:],
                                  output_perturbed_classes)
                    l.backward()
                    self.scheduler.optimizer.step()
                    perturbation_loss += l.item()

                    g = torch.mean(
                        torch.abs(list(self.model.parameters())[0].grad))
                    gradient += g.item()

                    e = self.error(batch_classes[split:],
                                   output_perturbed_classes)
                    perturbation_error += e.item()

                gradient /= self.args.max_iterations
                perturbation_loss /= self.args.max_iterations
                perturbation_error /= self.args.max_iterations

            iteration = self.epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack((
                self.train_statistics,
                numpy.array([[
                    iteration,  # iterations
                    iteration * (1 + self.args.max_iterations) *
                    self.args.batch_size,  # samples seen
                    min(num_batches, iteration) * self.args.batch_size +
                    iteration * self.args.max_iterations *
                    self.args.batch_size,  # unique samples seen
                    loss,  # clean loss
                    error,  # clean error (1-accuracy)
                    perturbation_loss,  # perturbation loss
                    perturbation_error,  # perturbation error (1-accuracy)
                    gradient
                ]])))

            if b % self.args.skip == self.args.skip // 2:
                log('[Training] %d | %d: %g (%g) %g (%g) [%g]' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 3]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 4]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 5]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 6]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, -1]),
                ))

        self.debug('clean.%d.png' % self.epoch,
                   batch_images.permute(0, 2, 3, 1))
        self.debug('perturbed.%d.png' % self.epoch,
                   batch_perturbed_images.permute(0, 2, 3, 1))
        self.debug('perturbation.%d.png' % self.epoch,
                   batch_perturbations.permute(0, 2, 3, 1),
                   cmap='seismic')
    def train(self):
        """
        Train adversarially.
        """

        split = self.args.batch_size // 2
        num_batches = int(
            math.ceil(self.train_images.shape[0] / self.args.batch_size))
        permutation = numpy.random.permutation(self.train_images.shape[0])

        for b in range(num_batches):
            self.scheduler.update(self.epoch, float(b) / num_batches)

            perm = numpy.take(permutation,
                              range(b * self.args.batch_size,
                                    (b + 1) * self.args.batch_size),
                              mode='wrap')
            batch_images = common.torch.as_variable(self.train_images[perm],
                                                    self.args.use_gpu)
            batch_theta = common.torch.as_variable(self.train_theta[perm],
                                                   self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            batch_fonts = self.train_codes[perm, 1]
            batch_classes = self.train_codes[perm, self.args.label_index]
            batch_code = numpy.concatenate(
                (common.numpy.one_hot(batch_fonts, self.N_font),
                 common.numpy.one_hot(batch_classes, self.N_class)),
                axis=1).astype(numpy.float32)
            batch_code = common.torch.as_variable(batch_code,
                                                  self.args.use_gpu)
            batch_classes = common.torch.as_variable(batch_classes,
                                                     self.args.use_gpu)

            self.model.eval()
            assert self.model.training is False
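            # The attack below runs with the classifier in evaluation mode so that
            # batch normalization and dropout behave deterministically; training
            # mode is restored right before the parameter update.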

            if self.args.full_variant:
                objective = self.objective_class()
                self.decoder.set_code(batch_code)
                attack = self.setup_attack(self.decoder_classifier,
                                           batch_theta, batch_classes)
                attack.set_bound(torch.from_numpy(self.min_bound),
                                 torch.from_numpy(self.max_bound))
                success, perturbations, probabilities, norm, _ = attack.run(
                    objective, self.args.verbose)

                batch_perturbed_theta = batch_theta + common.torch.as_variable(
                    perturbations, self.args.use_gpu)
                batch_perturbed_images = self.decoder(batch_perturbed_theta)
                batch_perturbations = batch_perturbed_images - batch_images
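                # batch_perturbations is the pixel-space difference between the
                # decoded attacked images and the clean images; it is only needed
                # for the visualizations written at the end of the epoch.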

                self.model.train()
                assert self.model.training is True

                output_classes = self.model(batch_perturbed_images)

                self.scheduler.optimizer.zero_grad()
                loss = self.loss(batch_classes, output_classes)
                loss.backward()
                self.scheduler.optimizer.step()
                loss = perturbation_loss = loss.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                error = self.error(batch_classes, output_classes)
                error = perturbation_error = error.item()
            else:
                objective = self.objective_class()
                self.decoder.set_code(batch_code[split:])
                attack = self.setup_attack(self.decoder_classifier,
                                           batch_theta[split:],
                                           batch_classes[split:])
                attack.set_bound(torch.from_numpy(self.min_bound),
                                 torch.from_numpy(self.max_bound))
                success, perturbations, probabilities, norm, _ = attack.run(
                    objective, self.args.verbose)

                batch_perturbed_theta = batch_theta[
                    split:] + common.torch.as_variable(perturbations,
                                                       self.args.use_gpu)
                batch_perturbed_images = self.decoder(batch_perturbed_theta)
                batch_perturbations = batch_perturbed_images - batch_images[
                    split:]

                self.model.train()
                assert self.model.training is True

                batch_input_images = torch.cat(
                    (batch_images[:split], batch_perturbed_images), dim=0)
                output_classes = self.model(batch_input_images)

                self.scheduler.optimizer.zero_grad()
                loss = self.loss(batch_classes[:split], output_classes[:split])
                perturbation_loss = self.loss(batch_classes[split:],
                                              output_classes[split:])
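                # Half/half adversarial training: the overall objective averages
                # the clean loss on the first half of the batch and the
                # perturbation loss on the adversarial second half.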
                l = (loss + perturbation_loss) / 2
                l.backward()
                self.scheduler.optimizer.step()
                loss = loss.item()
                perturbation_loss = perturbation_loss.item()

                gradient = torch.mean(
                    torch.abs(list(self.model.parameters())[0].grad))
                gradient = gradient.item()

                error = self.error(batch_classes[:split],
                                   output_classes[:split])
                error = error.item()

                perturbation_error = self.error(batch_classes[split:],
                                                output_classes[split:])
                perturbation_error = perturbation_error.item()

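            # Attack statistics for logging: `iterations` is the mean number of
            # attack steps until success (-1 if no attack succeeded), `norm` the
            # mean perturbation norm, and `success` the fraction of successful
            # attacks.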
            iterations = numpy.mean(
                success[success >= 0]) if numpy.sum(success >= 0) > 0 else -1
            norm = numpy.mean(
                numpy.linalg.norm(perturbations.reshape(
                    perturbations.shape[0], -1),
                                  axis=1,
                                  ord=self.norm))
            success = numpy.sum(success >= 0) / success.shape[0]

            iteration = self.epoch * num_batches + b + 1
            self.train_statistics = numpy.vstack((
                self.train_statistics,
                numpy.array([[
                    iteration,  # iterations
                    iteration * (1 + self.args.max_iterations) *
                    self.args.batch_size,  # samples seen
                    min(num_batches, iteration) * self.args.batch_size +
                    iteration * self.args.max_iterations *
                    self.args.batch_size,  # unique samples seen
                    loss,
                    error,
                    perturbation_loss,
                    perturbation_error,
                    success,
                    iterations,
                    norm,
                    gradient
                ]])))

            if b % self.args.skip == self.args.skip // 2:
                log('[Training] %d | %d: %g (%g) %g (%g) [%g]' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 3]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 4]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 5]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 6]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, -1]),
                ))
                log('[Training] %d | %d: %g (%g, %g)' % (
                    self.epoch,
                    b,
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 7]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 8]),
                    numpy.mean(self.train_statistics[
                        max(0, iteration - self.args.skip):iteration, 9]),
                ))

        self.debug('clean.%d.png' % self.epoch,
                   batch_images.permute(0, 2, 3, 1))
        self.debug('perturbed.%d.png' % self.epoch,
                   batch_perturbed_images.permute(0, 2, 3, 1))
        self.debug('perturbation.%d.png' % self.epoch,
                   batch_perturbations.permute(0, 2, 3, 1),
                   cmap='seismic')
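
    # The attacks above perturb the latent parameters theta, not pixels:
    # `setup_attack` receives `self.decoder_classifier`, which chains decoder and
    # classifier so that gradient-based attacks can operate in theta space. The
    # class below is only a sketch of such a composition under the interfaces
    # used above; the repository's actual wiring may differ.
    class DecoderClassifierSketch(torch.nn.Module):
        """
        Compose a (code-conditioned) decoder with a classifier (sketch).
        """

        def __init__(self, decoder, classifier):
            super(DecoderClassifierSketch, self).__init__()
            self.decoder = decoder
            self.classifier = classifier

        def forward(self, theta):
            # Decode the latent parameters to images, then classify the images;
            # gradients flow back to theta through both networks.
            return self.classifier(self.decoder(theta))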

    def compute_appr(self):
        """
        Compute approximate projections of the perturbations and the test images
        onto the decoder manifold by optimizing the latent parameters theta.
        """

        assert self.test_codes is not None
        num_batches = int(math.ceil(self.perturbations.shape[0] / self.args.batch_size))

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size, self.perturbations.shape[0])

            batch_classes = common.torch.as_variable(self.test_codes[b_start: b_end], self.args.use_gpu)
            batch_theta = common.torch.as_variable(self.test_theta[b_start: b_end].astype(numpy.float32), self.args.use_gpu, True)
            batch_perturbation = common.torch.as_variable(self.perturbations[b_start: b_end].astype(numpy.float32), self.args.use_gpu)

            if isinstance(self.model, models.SelectiveDecoder):
                self.model.set_code(batch_classes)
            batch_theta = torch.nn.Parameter(batch_theta)
            optimizer = torch.optim.Adam([batch_theta], lr=0.1)

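            # Projection onto the decoder's range: optimize theta with Adam so
            # that the decoded output matches the attacked image, yielding the
            # closest on-manifold reconstruction of the perturbation.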
            log('[Detection] %d: start' % b)
            for t in range(100):
                optimizer.zero_grad()
                output_perturbation = self.model.forward(batch_theta)
                error = torch.mean(torch.mul(output_perturbation - batch_perturbation, output_perturbation - batch_perturbation))
                error.backward()
                optimizer.step()

                log('[Detection] %d: %d = %g' % (b, t, error.item()))

            output_perturbation = numpy.squeeze(output_perturbation.cpu().detach().numpy())
            self.projected_perturbations = common.numpy.concatenate(self.projected_perturbations, output_perturbation)
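            # common.numpy.concatenate is a repository helper; presumably it
            # initializes the accumulator when it is still None and otherwise
            # appends along the first axis.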

            batch_theta = common.torch.as_variable(self.test_theta[b_start: b_end].astype(numpy.float32), self.args.use_gpu, True)
            batch_images = common.torch.as_variable(self.test_images[b_start: b_end].astype(numpy.float32), self.args.use_gpu)

            batch_theta = torch.nn.Parameter(batch_theta)
            optimizer = torch.optim.Adam([batch_theta], lr=0.5)

            log('[Detection] %d: start' % b)
            for t in range(100):
                optimizer.zero_grad()
                output_images = self.model.forward(batch_theta)
                error = torch.mean(torch.mul(output_images - batch_images, output_images - batch_images))
                error.backward()
                optimizer.step()

                log('[Detection] %d: %d = %g' % (b, t, error.item()))

            output_images = numpy.squeeze(output_images.cpu().detach().numpy())
            self.projected_test_images = common.numpy.concatenate(self.projected_test_images, output_images)

        projected_perturbations = self.projected_perturbations.reshape((self.projected_perturbations.shape[0], -1))
        projected_test_images = self.projected_test_images.reshape((self.projected_test_images.shape[0], -1))

        perturbations = self.perturbations.reshape((self.perturbations.shape[0], -1))
        test_images = self.test_images.reshape((self.test_images.shape[0], -1))

        success = numpy.logical_and(self.success >= 0, self.accuracy)
        log('[Detection] %d valid attacked samples' % numpy.sum(success))

        self.distances['true'] = numpy.linalg.norm(perturbations - projected_perturbations, ord=2, axis=1)
        self.angles['true'] = numpy.rad2deg(common.numpy.angles(perturbations.T, projected_perturbations.T))

        self.distances['true'] = self.distances['true'][success]
        self.angles['true'] = self.angles['true'][success]

        self.distances['test'] = numpy.linalg.norm(test_images - projected_test_images, ord=2, axis=1)
        self.angles['test'] = numpy.rad2deg(common.numpy.angles(test_images.T, projected_test_images.T))

        self.distances['test'] = self.distances['test'][success]
        self.angles['test'] = self.angles['test'][success]
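
    # `common.numpy.angles` used above is a repository utility; the sketch below
    # shows the column-wise angle computation it presumably performs (angle
    # between corresponding columns of two D x N matrices, in radians).
    def angles_sketch(self, vectors_a, vectors_b):
        """
        Column-wise angles between two D x N matrices (sketch).
        """

        norms_a = numpy.linalg.norm(vectors_a, axis=0)
        norms_b = numpy.linalg.norm(vectors_b, axis=0)
        dots = numpy.sum(vectors_a * vectors_b, axis=0)
        cosines = dots / numpy.maximum(norms_a * norms_b, 1e-12)
        return numpy.arccos(numpy.clip(cosines, -1, 1))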

    def compute_true(self):
        """
        Compute true (on-manifold) projections of the perturbations using the
        known codes of the decoder.
        """

        assert self.test_codes is not None
        num_batches = int(math.ceil(self.perturbations.shape[0] / self.args.batch_size))

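        # Settings for ADAMScheduler, a repository utility; presumably it wraps
        # Adam with a per-epoch learning rate decay of lr_decay down to lr_min.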
        params = {
            'lr': 0.09,
            'lr_decay': 0.95,
            'lr_min': 0.0000001,
            'weight_decay': 0,
        }

        for b in range(num_batches):
            b_start = b * self.args.batch_size
            b_end = min((b + 1) * self.args.batch_size, self.perturbations.shape[0])

            batch_fonts = self.test_codes[b_start: b_end, 1]
            batch_classes = self.test_codes[b_start: b_end, 2]
            batch_code = numpy.concatenate((common.numpy.one_hot(batch_fonts, self.N_font), common.numpy.one_hot(batch_classes, self.N_class)), axis=1).astype( numpy.float32)
            batch_code = common.torch.as_variable(batch_code, self.args.use_gpu)

            batch_images = common.torch.as_variable(self.test_images[b_start: b_end], self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            batch_theta = common.torch.as_variable(self.test_theta[b_start: b_end].astype(numpy.float32), self.args.use_gpu, True)
            batch_perturbation = common.torch.as_variable(self.perturbations[b_start: b_end].astype(numpy.float32), self.args.use_gpu)

            self.model.set_code(batch_code)

            # Sanity check (optional): decoding batch_theta and comparing the
            # result against batch_images verifies that the decoder reproduces
            # the clean test images before the perturbations are projected.

            batch_theta = torch.nn.Parameter(batch_theta)
            scheduler = ADAMScheduler([batch_theta], **params)

            log('[Detection] %d: start' % b)
            for t in range(100):
                scheduler.update(t // 10, float(t % 10) / 10)
                scheduler.optimizer.zero_grad()
                output_perturbation = self.model.forward(batch_theta)
                error = torch.mean(torch.mul(output_perturbation - batch_perturbation, output_perturbation - batch_perturbation))
                test_error = torch.mean(torch.mul(output_perturbation - batch_images, output_perturbation - batch_images))
                error.backward()
                scheduler.optimizer.step()

                log('[Detection] %d: %d = %g, %g' % (b, t, error.item(), test_error.item()))

            output_perturbation = numpy.squeeze(numpy.transpose(output_perturbation.cpu().detach().numpy(), (0, 2, 3, 1)))
            self.projected_perturbations = common.numpy.concatenate(self.projected_perturbations, output_perturbation)

        projected_perturbations = self.projected_perturbations.reshape((self.projected_perturbations.shape[0], -1))
        perturbations = self.perturbations.reshape((self.perturbations.shape[0], -1))

        success = numpy.logical_and(self.success >= 0, self.accuracy)
        log('[Detection] %d valid attacked samples' % numpy.sum(success))

        self.distances['true'] = numpy.linalg.norm(perturbations - projected_perturbations, ord=2, axis=1)
        self.angles['true'] = numpy.rad2deg(common.numpy.angles(perturbations.T, projected_perturbations.T))

        self.distances['true'] = self.distances['true'][success]
        self.angles['true'] = self.angles['true'][success]

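        # With the true codes, the test images lie on the decoder manifold by
        # construction, so their projection distances and angles are zero.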
        self.distances['test'] = numpy.zeros((numpy.sum(success)))
        self.angles['test'] = numpy.zeros((numpy.sum(success)))