Example #1
class Solver(object):
    def __init__(self, args):
        self.args = args
        # Basic
        self.cuda = (args.cuda and torch.cuda.is_available())
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.y_dim = args.y_dim  # MNIST and CIFAR10 both have 10 classes
        self.target = args.target  # set this to aim the perturbation at a specific target class
        self.dataset = args.dataset
        self.data_loader = return_data(args)
        self.global_epoch = 0
        self.global_iter = 0
        self.print_ = not args.silent
        self.env_name = args.env_name  # experiment name
        self.visdom = args.visdom  # optional Visdom visualization (installed but not used here)
        self.ckpt_dir = Path(args.ckpt_dir)
        self.save_ckpt_dir = Path('./checkpoints/' + args.env_name)
        print(self.save_ckpt_dir)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        if not self.save_ckpt_dir.exists():
            self.save_ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = Path(args.output_dir).joinpath(args.env_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Visualization Tools
        self.visualization_init(args)

        # Histories
        self.history = dict()
        self.history['acc'] = 0.
        self.history['epoch'] = 0
        self.history['iter'] = 0

        # Models & Optimizers
        self.model_init(args)
        self.load_ckpt = args.load_ckpt
        if args.load_ckpt_flag and self.load_ckpt != '':
            self.load_checkpoint(self.load_ckpt)

        # Adversarial Perturbation Generator
        #criterion = cuda(torch.nn.CrossEntropyLoss(), self.cuda)
        criterion = F.cross_entropy
        self.attack_mode = args.attack_mode
        if self.attack_mode == 'FGSM':
            self.attack = Attack(self.net, criterion=criterion)
        elif self.attack_mode == 'ILLC':
            self.attack = Attack(self.net, criterion=criterion)

    def visualization_init(self, args):
        # Visdom
        if self.visdom:
            from utils.visdom_utils import VisFunc
            self.port = args.visdom_port
            self.vf = VisFunc(enval=self.env_name, port=self.port)

    def model_init(self, args):
        # Network
        if args.dataset == 'MNIST':
            print("MNIST")
            self.net = cuda(ToyNet_MNIST(y_dim=self.y_dim), self.cuda)
        elif args.dataset == 'CIFAR10':
            print("Dataset used CIFAR10")
            if args.network_choice == 'ToyNet':
                self.net = cuda(ToyNet_CIFAR10(y_dim=self.y_dim), self.cuda)
            elif args.network_choice == 'ResNet18':
                self.net = cuda(ResNet18(), self.cuda)
            elif args.network_choice == 'ResNet34':
                self.net = cuda(ResNet34(), self.cuda)
            elif args.network_choice == 'ResNet50':
                self.net = cuda(ResNet50(), self.cuda)
        self.net.weight_init(_type='kaiming')
        # setup optimizer
        self.optim = optim.Adam([{
            'params': self.net.parameters(),
            'lr': self.lr
        }],
                                betas=(0.5, 0.999))

    def train(self):
        self.set_mode('train')
        acc_train_plt = [0]
        loss_plt = []
        acc_test_plt = [0]
        for e in range(self.epoch):
            self.global_epoch += 1
            local_iter = 0
            correct = 0.
            cost = 0.
            total = 0.
            total_acc = 0.
            total_loss = 0.
            for batch_idx, (images,
                            labels) in enumerate(self.data_loader['train']):
                self.global_iter += 1
                local_iter += 1
                #print("image size is ", np.shape(images))

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))

                logit = self.net(x)
                prediction = logit.max(1)[1]

                correct = torch.eq(prediction, y).float().mean().data.item()
                cost = F.cross_entropy(logit, y)
                total_acc += correct
                total_loss += cost.data.item()

                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                if batch_idx % 100 == 0:
                    if self.print_:
                        print()
                        print(self.env_name)
                        print('[{:03d}:{:03d}]'.format(self.global_epoch,
                                                       batch_idx))
                        print('acc:{:.3f} loss:{:.3f}'.format(
                            correct, cost.data.item()))
            total_acc = total_acc / local_iter
            total_loss = total_loss / local_iter
            acc_train_plt.append(total_acc)
            loss_plt.append(total_loss)
            acc_test_plt.append(self.test())
        print(" [*] Training Finished!")
        self.plot_result(acc_train_plt, acc_test_plt, loss_plt,
                         self.history['acc'])

    def test(self):
        self.set_mode('eval')
        correct = 0.
        cost = 0.
        total = 0.
        data_loader = self.data_loader['test']
        for batch_idx, (images, labels) in enumerate(data_loader):
            x = Variable(cuda(images, self.cuda))
            y = Variable(cuda(labels, self.cuda))

            logit = self.net(x)
            prediction = logit.max(1)[1]

            correct += torch.eq(prediction, y).float().sum().data.item()
            cost += F.cross_entropy(logit, y, reduction='sum').data.item()
            total += x.size(0)
        accuracy = correct / total
        cost /= total

        if self.history['acc'] < accuracy:
            self.history['acc'] = accuracy
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            self.save_checkpoint('best_acc.tar')

        if self.print_:
            print()
            print('[{:03d}]\nTEST RESULT'.format(self.global_epoch))
            print('ACC:{:.4f}'.format(accuracy))
            print('*TOP* ACC:{:.4f} at e:{:03d}'.format(
                self.history['acc'],
                self.history['epoch'],
            ))
            print()

        self.set_mode('train')
        return accuracy

    def generate(self, target, epsilon, alpha, iteration):
        self.set_mode('eval')
        x_true, y_true = self.sample_data()  # draw one batch of size batch_size
        if isinstance(target, int) and (target in range(self.y_dim)):
            y_target = torch.LongTensor(y_true.size()).fill_(target)
        else:
            y_target = None

        # generate perturbed images; self.FGSM runs the FGSM / I-FGSM attacks
        # and self.ILLC runs the iterative least-likely-class method
        if self.attack_mode == 'FGSM':
            x_adv, changed, values = self.FGSM(x_true, y_true, y_target,
                                               epsilon, alpha, iteration)
        elif self.attack_mode == 'ILLC':
            x_adv, changed, values = self.ILLC(x_true, y_true, y_target,
                                               epsilon, alpha, iteration)
        accuracy, cost, accuracy_adv, cost_adv = values

        # save the result images; they are written to outputs/<experiment_name>
        save_image(x_true,
                   self.output_dir.joinpath(
                       'legitimate(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=2,
                   pad_value=0.5)
        save_image(x_adv,
                   self.output_dir.joinpath(
                       'perturbed(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=2,
                   pad_value=0.5)
        save_image(changed,
                   self.output_dir.joinpath(
                       'changed(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=3,
                   pad_value=0.5)

        if self.visdom:
            self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5)
            self.vf.imshow_multi(x_adv.cpu(),
                                 title='perturbed(e:{},i:{})'.format(
                                     epsilon, iteration),
                                 factor=1.5)
            self.vf.imshow_multi(changed.cpu(),
                                 title='changed(white)',
                                 factor=1.5)

        print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy, cost))
        print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy_adv, cost_adv))

        self.set_mode('train')

    def ad_train(self, target, alpha, iteration, lamb):
        self.set_mode('train')
        acc_train_plt = [0]
        acc_test_plt = [0]
        loss_plt = []
        for e in range(self.epoch):
            self.global_epoch += 1
            local_iter = 0
            correct = 0.
            cost = 0.
            total_acc = 0.
            total_loss = 0.
            total = 0.
            for batch_idx, (images,
                            labels) in enumerate(self.data_loader['train']):
                self.global_iter += 1
                local_iter += 1
                self.set_mode('eval')

                num_adv_image = self.batch_size // 2

                x_true = Variable(cuda(images[:num_adv_image], self.cuda))
                y_true = Variable(cuda(labels[:num_adv_image], self.cuda))

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))

                if isinstance(target, int) and (target in range(self.y_dim)):
                    y_target = torch.LongTensor(y_true.size()).fill_(target)
                else:
                    y_target = None

                epsilon = abs(np.random.normal(0, 8 / 255))
                if epsilon > 16 / 255:
                    epsilon = 0

                if self.attack_mode == 'FGSM':
                    x[:num_adv_image], _, _ = self.FGSM(
                        x_true, y_true, y_target, epsilon, alpha, iteration)
                elif self.attack_mode == 'ILLC':
                    x[:num_adv_image], _, _ = self.ILLC(
                        x_true, y_true, y_target, epsilon, alpha, iteration)

                self.set_mode('train')
                logit = self.net(x)
                prediction = logit.max(1)[1]

                correct = torch.eq(prediction, y).float().mean().data.item()
                cost = ((F.cross_entropy(logit[num_adv_image:], y[num_adv_image:])
                         + lamb * F.cross_entropy(logit[:num_adv_image], y[:num_adv_image]))
                        * num_adv_image
                        / (self.batch_size - (1 - lamb) * num_adv_image))

                total_acc += correct
                total_loss += cost.data.item()
                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                if batch_idx % 100 == 0:
                    if self.print_:
                        print()
                        print(self.env_name)
                        print('[{:03d}:{:03d}]'.format(self.global_epoch,
                                                       batch_idx))
                        print('acc:{:.3f} loss:{:.3f}'.format(
                            correct, cost.data.item()))

            total_acc = total_acc / local_iter
            total_loss = total_loss / local_iter
            acc_train_plt.append(total_acc)
            loss_plt.append(total_loss)
            acc_test_plt.append(self.test())
        print(" [*] Training Finished!")
        self.plot_result(acc_train_plt, acc_test_plt, loss_plt,
                         self.history['acc'])

    def ad_test(self, target, epsilon, alpha, iteration):
        self.set_mode('eval')
        correct = 0.
        cost = 0.
        total = 0.
        data_loader = self.data_loader['test']
        for batch_idx, (images, labels) in enumerate(data_loader):
            x_true = Variable(cuda(images, self.cuda))
            y_true = Variable(cuda(labels, self.cuda))

            if isinstance(target, int) and (target in range(self.y_dim)):
                y_target = torch.LongTensor(y_true.size()).fill_(target)
            else:
                y_target = None

            if self.attack_mode == 'FGSM':
                x, _, _ = self.FGSM(x_true, y_true, y_target, epsilon, alpha,
                                    iteration)
            elif self.attack_mode == 'ILLC':
                x, _, _ = self.ILLC(x_true, y_true, y_target, epsilon, alpha,
                                    iteration)

            logit = self.net(x)
            prediction = logit.max(1)[1]

            correct += torch.eq(prediction, y_true).float().sum().data.item()
            cost += F.cross_entropy(logit, y_true,
                                    reduction='sum').data.item()

            total += x.size(0)
        accuracy = correct / total
        cost /= total
        print('ACC:{:.4f} COST:{:.4f}'.format(accuracy, cost))
        self.set_mode('train')

    # sample one batch of test data (batch_size samples)
    def sample_data(self):
        data_loader = self.data_loader['test']
        for batch_idx, (images, labels) in enumerate(data_loader):
            x_true = Variable(cuda(images, self.cuda))
            y_true = Variable(cuda(labels, self.cuda))
            break
        return x_true, y_true

    def ILLC(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        self.set_mode('eval')
        x = Variable(cuda(x, self.cuda), requires_grad=True)
        y_true = Variable(cuda(y_true, self.cuda), requires_grad=False)

        if y_target is not None:
            targeted = True
            y_target = Variable(cuda(y_target, self.cuda), requires_grad=False)
        else:
            targeted = False

        # original image classification
        h = self.net(x)
        prediction = h.max(1)[1]
        accuracy = torch.eq(prediction, y_true).float().mean()

        cost = F.cross_entropy(h, y_true)

        # adversarial image classification
        if targeted:
            x_adv, h_adv, h = self.attack.IterativeLeastlikely(
                x, y_target, True, eps, alpha)
        else:
            x_adv, h_adv, h = self.attack.IterativeLeastlikely(
                x, y_true, False, eps, alpha)

        prediction_adv = h_adv.max(1)[1]
        accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
        cost_adv = F.cross_entropy(h_adv, y_true)

        # mark the perturbed images whose prediction changed
        # (each tile gets a red border if changed, green if unchanged)
        if targeted:
            changed = torch.eq(y_target, prediction_adv)
        else:
            changed = torch.eq(prediction, prediction_adv)
            changed = torch.eq(changed, 0)

        if self.dataset == 'MNIST':
            changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 28, 28)
        elif self.dataset == 'CIFAR10':
            changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 32, 32)

        # fill the border grid with color (red = changed, green = unchanged)
        changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
        changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
        changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
        changed = self.scale(changed / 255)

        # fill the inner part of the grid with the adversarial image
        if self.dataset == 'MNIST':
            changed[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2,
                                                                 3:-2]
        elif self.dataset == 'CIFAR10':
            changed[:, :, 3:-2, 3:-2] = x_adv[:, :, 3:-2, 3:-2]

        self.set_mode('train')

        return x_adv.data, changed.data,\
                (accuracy.data.item(), cost.data.item(), accuracy_adv.data.item(), cost_adv.data.item())

    # Key point
    def FGSM(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        self.set_mode('eval')
        x = Variable(cuda(x, self.cuda), requires_grad=True)
        y_true = Variable(cuda(y_true, self.cuda), requires_grad=False)

        if y_target is not None:
            targeted = True
            y_target = Variable(cuda(y_target, self.cuda), requires_grad=False)
        else:
            targeted = False

        # original image classification
        h = self.net(x)
        prediction = h.max(1)[1]
        accuracy = torch.eq(prediction, y_true).float().mean()

        cost = F.cross_entropy(h, y_true)

        # adversarial image classification
        if targeted:
            x_adv, h_adv, h = self.attack.i_fgsm(x, y_target, True, eps, alpha)
        else:
            x_adv, h_adv, h = self.attack.i_fgsm(x, y_true, False, eps, alpha)

        prediction_adv = h_adv.max(1)[1]
        accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
        cost_adv = F.cross_entropy(h_adv, y_true)

        # mark the perturbed images whose prediction changed
        # (each tile gets a red border if changed, green if unchanged)
        if targeted:
            changed = torch.eq(y_target, prediction_adv)
        else:
            changed = torch.eq(prediction, prediction_adv)
            changed = torch.eq(changed, 0)

        if self.dataset == 'MNIST':
            changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 28, 28)
        elif self.dataset == 'CIFAR10':
            changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 32, 32)

        # fill the border grid with color (red = changed, green = unchanged)
        changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
        changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
        changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
        changed = self.scale(changed / 255)

        # fill the inner part of the grid with the adversarial image
        if self.dataset == 'MNIST':
            changed[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2,
                                                                 3:-2]
        elif self.dataset == 'CIFAR10':
            changed[:, :, 3:-2, 3:-2] = x_adv[:, :, 3:-2, 3:-2]

        self.set_mode('train')

        return x_adv.data, changed.data,\
                (accuracy.data.item(), cost.data.item(), accuracy_adv.data.item(), cost_adv.data.item())

    def save_checkpoint(self, filename='ckpt.tar'):
        model_states = {
            'net': self.net.state_dict(),
        }
        optim_states = {
            'optim': self.optim.state_dict(),
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'history': self.history,
            'args': self.args,
            'model_states': model_states,
            'optim_states': optim_states,
        }

        file_path = self.save_ckpt_dir / filename
        print(file_path)
        torch.save(states, file_path.open('wb+'))
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self, filename='best_acc.tar'):
        file_path = self.ckpt_dir / filename
        if file_path.is_file():
            print("=> loading checkpoint '{}'".format(file_path))
            checkpoint = torch.load(file_path.open('rb'))
            self.global_epoch = checkpoint['epoch']
            self.global_iter = checkpoint['iter']
            self.history = checkpoint['history']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter))

        else:
            print("=> no checkpoint found at '{}'".format(file_path))

    # change the model mode
    def set_mode(self, mode='train'):
        if mode == 'train':
            self.net.train()
        elif mode == 'eval':
            self.net.eval()
        else:
            raise ValueError('mode error. It should be either train or eval')

    # rescale images from [0, 1] to zero-centered [-1, 1]
    def scale(self, image):
        return image.mul(2).add(-1)

    def convert_torch2numpy(self, torch_img):
        np_img = np.transpose(torch_img.data.cpu().numpy(), (0, 2, 3, 1))
        # PIL_image = transforms.ToPILImage()(transforms.ToTensor()(np_img),interpolation="bicubic")
        return np_img

    def plot_img(self, np_img, idx, title):
        plt.figure()
        plt.title(title)
        plt.imshow(np_img[idx], interpolation='nearest')

    def plot_result(self,
                    acc_train_plt,
                    acc_test_plt,
                    loss_plt,
                    best_acc,
                    title='train_graph'):
        epoch = range(0, self.epoch + 1)
        fig, ax1 = plt.subplots()
        ax1.plot(epoch, acc_train_plt, label='train_acc')
        ax1.plot(epoch, acc_test_plt, label='test_acc')
        ax1.set_xlabel('epoch')
        ax1.set_ylabel('accuracy')
        ax1.tick_params(axis='y')
        plt.legend(loc='upper left')
        color = 'tab:red'
        ax2 = ax1.twinx()
        ax2.plot(epoch[1:],
                 loss_plt,
                 linestyle="--",
                 label='train_loss',
                 color=color)
        ax2.set_ylabel('loss', color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        plt.title("{}".format(self.env_name))
        plt.savefig('{}/{}/best_acc_{}.png'.format(self.args.output_dir,
                                                   self.env_name, best_acc),
                    dpi=350)
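
A minimal driver sketch for the Solver in Example #1. The argparse flags mirror the attributes read in __init__ above; the flag names and defaults shown here are assumptions, not values taken from the original project.

# Hypothetical entry point for Example #1; all defaults are illustrative.
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--y_dim', type=int, default=10)
    parser.add_argument('--target', type=int, default=-1)  # -1 = untargeted
    parser.add_argument('--dataset', type=str, default='CIFAR10')
    parser.add_argument('--network_choice', type=str, default='ResNet18')
    parser.add_argument('--attack_mode', type=str, default='FGSM')
    parser.add_argument('--epsilon', type=float, default=0.03)
    parser.add_argument('--alpha', type=float, default=2 / 255)
    parser.add_argument('--iteration', type=int, default=1)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--silent', action='store_true')
    parser.add_argument('--visdom', action='store_true')
    parser.add_argument('--env_name', type=str, default='demo')
    parser.add_argument('--ckpt_dir', type=str, default='checkpoints')
    parser.add_argument('--output_dir', type=str, default='outputs')
    parser.add_argument('--load_ckpt', type=str, default='')
    parser.add_argument('--load_ckpt_flag', action='store_true')
    args = parser.parse_args()

    solver = Solver(args)
    solver.train()  # standard training with periodic test()
    solver.generate(args.target, args.epsilon, args.alpha, args.iteration)


if __name__ == '__main__':
    main()
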
Example #2
class Solver(object):
    def __init__(self, args, model, dataloader):
        self.args = args

        # Basic
        # self.cuda = (args.cuda and torch.cuda.is_available())
        # setting device
        # use CUDA when the flag is given and a GPU is available
        if args.cuda and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.eps = args.eps
        self.lr = args.lr
        self.y_dim = args.y_dim
        self.target = args.target
        self.dataset = args.dataset
        self.data_loader = dataloader  # dict with 'train' and 'test' loaders
        self.global_epoch = 0
        self.global_iter = 0
        self.print_ = not args.silent
        self.net = model  # the network must already be constructed by the caller
        self.env_name = args.env_name
        self.tensorboard = args.tensorboard
        self.visdom = args.visdom

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = Path(args.output_dir).joinpath(args.env_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Visualization Tools
        self.visualization_init(args)

        # Histories
        self.history = dict()
        self.history['acc'] = 0.
        self.history['epoch'] = 0
        self.history['iter'] = 0

        # Models & Optimizers
        # self.model_init(args)  # the network is injected above; only the optimizer is built here
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(0.5, 0.999))
        self.load_ckpt = args.load_ckpt
        if self.load_ckpt != '':
            self.load_checkpoint(self.load_ckpt)

        # Adversarial Perturbation Generator
        #criterion = cuda(torch.nn.CrossEntropyLoss(), self.cuda)
        criterion = F.cross_entropy
        self.attack = Attack(self.net, criterion=criterion)

    def visualization_init(self, args):
        # Visdom
        if self.visdom:
            from utils.visdom_utils import VisFunc
            self.port = args.visdom_port
            self.vf = VisFunc(enval=self.env_name, port=self.port)

        # TensorboardX
        if self.tensorboard:
            from tensorboardX import SummaryWriter
            self.summary_dir = Path(args.summary_dir).joinpath(args.env_name)
            if not self.summary_dir.exists():
                self.summary_dir.mkdir(parents=True, exist_ok=True)

            self.tf = SummaryWriter(log_dir=str(self.summary_dir))
            self.tf.add_text(tag='argument',
                             text_string=str(args),
                             global_step=self.global_epoch)

    # def model_init(self, args):
    #     # Network
    #     self.net = cuda(ToyNet(y_dim=self.y_dim), self.cuda)
    #     self.net.weight_init(_type='kaiming')

    #     # Optimizers
    #     self.optim = optim.Adam([{'params':self.net.parameters(), 'lr':self.lr}],
    #                             betas=(0.5, 0.999))

    def train(self):
        self.set_mode('train')
        for e in range(self.epoch):
            self.global_epoch += 1

            correct = 0.
            cost = 0.
            total = 0.
            for batch_idx, (images,
                            labels) in enumerate(self.data_loader['train']):
                self.global_iter += 1

                x = Variable(images).to(self.device)
                y = Variable(labels).to(self.device)

                logit = self.net(x)
                prediction = logit.max(1)[1]

                correct = torch.eq(prediction, y).float().mean().item()
                cost = F.cross_entropy(logit, y)

                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                if batch_idx % 100 == 0:
                    if self.print_:
                        print()
                        print(self.env_name)
                        print('[{:03d}:{:03d}]'.format(self.global_epoch,
                                                       batch_idx))
                        print('acc:{:.3f} loss:{:.3f}'.format(
                            correct, cost.item()))

                    if self.tensorboard:
                        self.tf.add_scalars(main_tag='performance/acc',
                                            tag_scalar_dict={'train': correct},
                                            global_step=self.global_iter)
                        self.tf.add_scalars(
                            main_tag='performance/error',
                            tag_scalar_dict={'train': 1 - correct},
                            global_step=self.global_iter)
                        self.tf.add_scalars(
                            main_tag='performance/cost',
                            tag_scalar_dict={'train': cost.item()},
                            global_step=self.global_iter)

            self.test()

        if self.tensorboard:
            self.tf.add_scalars(main_tag='performance/best/acc',
                                tag_scalar_dict={'test': self.history['acc']},
                                global_step=self.history['iter'])
        print(" [*] Training Finished!")

    def test(self):
        self.set_mode('eval')

        correct = 0.
        cost = 0.
        total = 0.

        data_loader = self.data_loader['test']
        for batch_idx, (images, labels) in enumerate(data_loader):
            x = Variable(images).to(self.device)
            y = Variable(labels).to(self.device)

            logit = self.net(x)
            prediction = logit.max(1)[1]

            correct += torch.eq(prediction, y).float().sum().item()
            cost += F.cross_entropy(logit, y, reduction='sum').item()
            total += x.size(0)

        accuracy = correct / total
        cost /= total

        if self.print_:
            print()
            print('[{:03d}]\nTEST RESULT'.format(self.global_epoch))
            print('ACC:{:.4f}'.format(accuracy))
            print('*TOP* ACC:{:.4f} at e:{:03d}'.format(
                accuracy,
                self.global_epoch,
            ))
            print()

            if self.tensorboard:
                self.tf.add_scalars(main_tag='performance/acc',
                                    tag_scalar_dict={'test': accuracy},
                                    global_step=self.global_iter)

                self.tf.add_scalars(main_tag='performance/error',
                                    tag_scalar_dict={'test': (1 - accuracy)},
                                    global_step=self.global_iter)

                self.tf.add_scalars(main_tag='performance/cost',
                                    tag_scalar_dict={'test': cost},
                                    global_step=self.global_iter)

        if self.history['acc'] < accuracy:
            self.history['acc'] = accuracy
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            self.save_checkpoint('best_acc.tar')

        self.set_mode('train')

    def generate(self,
                 num_sample=100,
                 target=-1,
                 epsilon=0.03,
                 alpha=2 / 255,
                 iteration=1):
        self.set_mode('eval')

        x_true, y_true = self.sample_data(num_sample)
        if isinstance(target, int) and (target in range(self.y_dim)):
            y_target = torch.LongTensor(y_true.size()).fill_(target)
        else:
            y_target = None

        x_adv, changed, values = self.FGSM(x_true, y_true, y_target, epsilon,
                                           alpha, iteration)
        accuracy, cost, accuracy_adv, cost_adv = values

        save_image(x_true,
                   self.output_dir.joinpath(
                       'legitimate(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=2,
                   pad_value=0.5)
        save_image(x_adv,
                   self.output_dir.joinpath(
                       'perturbed(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=2,
                   pad_value=0.5)
        save_image(changed,
                   self.output_dir.joinpath(
                       'changed(t:{},e:{},i:{}).jpg'.format(
                           target, epsilon, iteration)),
                   nrow=10,
                   padding=3,
                   pad_value=0.5)

        if self.visdom:
            self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5)
            self.vf.imshow_multi(x_adv.cpu(),
                                 title='perturbed(e:{},i:{})'.format(
                                     epsilon, iteration),
                                 factor=1.5)
            self.vf.imshow_multi(changed.cpu(),
                                 title='changed(white)',
                                 factor=1.5)

        print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy, cost))
        print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy_adv, cost_adv))

        self.set_mode('train')

    def sample_data(self, num_sample=100):

        total = len(self.data_loader['test'].dataset)
        # seed = torch.FloatTensor(num_sample).uniform_(1, total).long()  # only if the dataset is a tensor;
        # an ndarray cannot be indexed with a torch tensor
        seed = np.random.randint(0, total, size=num_sample)
        # print(seed)
        x = torch.from_numpy(self.data_loader['test'].dataset.test_data[seed])
        x = Variable(x.float(), requires_grad=True).to(self.device)
        x = self.scale(x.unsqueeze(1).div(255))
        y = torch.from_numpy(
            self.data_loader['test'].dataset.test_labels[seed]).to(self.device)

        return x, y

    def FGSM(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        self.set_mode('eval')
        if type(x) == np.ndarray:
            x = torch.from_numpy(x)
        if type(y_true) == np.ndarray:
            y_true = torch.from_numpy(y_true)
        x = Variable(x, requires_grad=True).to(self.device)
        y_true = Variable(y_true, requires_grad=False).to(self.device)

        if y_target is not None:
            targeted = True
            y_target = Variable(y_target, requires_grad=False).to(self.device)
        else:
            targeted = False

        h = self.net(x)
        prediction = h.max(1)[1]
        accuracy = torch.eq(prediction, y_true).float().mean()
        cost = F.cross_entropy(h, y_true)

        if iteration == 1:
            if targeted:
                x_adv, h_adv, h = self.attack.fgsm(x, y_target, True, eps)
            else:
                x_adv, h_adv, h = self.attack.fgsm(x, y_true, False, eps)
        else:
            if targeted:
                x_adv, h_adv, h = self.attack.i_fgsm(x, y_target, True, eps,
                                                     alpha, iteration)
            else:
                x_adv, h_adv, h = self.attack.i_fgsm(x, y_true, False, eps,
                                                     alpha, iteration)

        prediction_adv = h_adv.max(1)[1]
        accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
        cost_adv = F.cross_entropy(h_adv, y_true)

        # mark perturbed images whose prediction changed (red border) vs. unchanged (green border)
        if targeted:
            changed = torch.eq(y_target, prediction_adv)
        else:
            changed = torch.eq(prediction, prediction_adv)
            changed = torch.eq(changed, 0)
        changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 28, 28)

        changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
        changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
        changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
        changed = self.scale(changed / 255)
        changed[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2, 3:-2]

        self.set_mode('train')

        return x_adv.data, changed.data,\
                (accuracy.item(), cost.item(), accuracy_adv.item(), cost_adv.item())

    def save_checkpoint(self, filename='ckpt.tar'):
        model_states = {
            'net': self.net.state_dict(),
        }
        optim_states = {
            'optim': self.optim.state_dict(),
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'history': self.history,
            'args': self.args,
            'model_states': model_states,
            'optim_states': optim_states,
        }

        file_path = self.ckpt_dir / filename
        torch.save(states, file_path.open('wb+'))
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self, filename='best_acc.tar'):
        file_path = self.ckpt_dir / filename
        if file_path.is_file():
            print("=> loading checkpoint '{}'".format(file_path))
            checkpoint = torch.load(file_path.open('rb'))
            self.global_epoch = checkpoint['epoch']
            self.global_iter = checkpoint['iter']
            self.history = checkpoint['history']

            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])

            print("=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter))

        else:
            print("=> no checkpoint found at '{}'".format(file_path))

    def set_mode(self, mode='train'):
        if mode == 'train':
            self.net.train()
        elif mode == 'eval':
            self.net.eval()
        else:
            raise ValueError('mode error. It should be either train or eval')

    def scale(self, image):
        return image.mul(2).add(-1)

    def unscale(self, image):
        return image.add(1).mul(0.5)

    def summary_flush(self, silent=True):
        rm_dir(self.summary_dir, silent)

    def checkpoint_flush(self, silent=True):
        rm_dir(self.ckpt_dir, silent)
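
Example #2 (like Example #1) delegates the actual perturbation step to Attack.fgsm / Attack.i_fgsm, whose implementation is not shown here. Below is a minimal, self-contained sketch of the single FGSM step those calls are assumed to perform; the function name fgsm_step and the [-1, 1] clamp range (matching scale() above) are assumptions.

import torch
import torch.nn.functional as F


def fgsm_step(net, x, y, targeted=False, eps=0.03):
    # one fast-gradient-sign step on a batch x with true (or target) labels y
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(net(x_adv), y)
    grad = torch.autograd.grad(loss, x_adv)[0]
    # untargeted: increase the loss on the true class; targeted: decrease it on the target class
    step = -eps * grad.sign() if targeted else eps * grad.sign()
    x_adv = (x_adv + step).clamp(-1, 1).detach()
    return x_adv, net(x_adv)

An iterative variant (i_fgsm) would repeat this step with step size alpha and keep the accumulated perturbation within eps of the original input.
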
Example #3
class Solver(object):
    def __init__(self, args):
        self.args = args
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.z_dim = args.z_dim
        self.k_dim = args.k_dim
        self.beta = args.beta
        self.env_name = args.env_name
        self.ckpt_dir = os.path.join('checkpoints', args.env_name)
        self.global_iter = 0
        self.dataset = args.dataset
        self.fixed_x_num = args.fixed_x_num
        self.output_dir = os.path.join(args.output_dir, args.env_name)
        self.ckpt_load = args.ckpt_load
        self.ckpt_save = args.ckpt_save

        # Toy Network init
        if self.dataset == 'MNIST':
            self.model = MODEL_MNIST(k_dim=self.k_dim, z_dim=self.z_dim).cuda()
        elif self.dataset == 'CIFAR10':
            self.model = MODEL_CIFAR10(k_dim=self.k_dim,
                                       z_dim=self.z_dim).cuda()

        # Visdom Sample Visualization
        self.vf = VisFunc(enval=self.env_name, port=55558)

        # Criterions
        self.MSE_Loss = nn.MSELoss().cuda()

        # Dataset init
        self.train_data, self.train_loader = data_loader(args)
        self.fixed_x = next(iter(self.train_loader))[0][:self.fixed_x_num]

        # Optimizer
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.lr,
                                    betas=(0.5, 0.999))

        # Resume training
        if self.ckpt_load: self.load_checkpoint()

    def set_mode(self, mode='train'):
        if mode == 'train':
            self.model.train()
        elif mode == 'eval':
            self.model.eval()
        else:
            raise ValueError('mode error. It should be either train or eval')

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        if not os.path.exists(self.ckpt_dir): os.makedirs(self.ckpt_dir)
        file_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, file_path)
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self):
        filename = 'checkpoint.pth.tar'
        file_path = os.path.join(self.ckpt_dir, filename)
        if os.path.isfile(file_path):
            print("=> loading checkpoint '{}'".format(file_path))
            checkpoint = torch.load(file_path)
            self.global_iter = checkpoint['iter']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                filename, checkpoint['iter']))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))

    def image_save(self, imgs, name='fixed', **kwargs):
        # required imgs shape : batch_size x channels x width x height
        if not os.path.exists(self.output_dir): os.makedirs(self.output_dir)
        filename = os.path.join(self.output_dir,
                                name + '_' + str(self.global_iter) + '.jpg')
        torchvision.utils.save_image(imgs, filename, **kwargs)

    def train(self):
        self.set_mode('train')
        for e in range(self.epoch):
            recon_losses = []
            z_and_sg_embd_losses = []
            sg_z_and_embd_losses = []
            for idx, (images, labels) in enumerate(self.train_loader):
                self.global_iter += 1

                X = Variable(images.cuda(), requires_grad=False)
                X_recon, Z_enc, Z_dec, Z_enc_for_embd = self.model(X)

                recon_loss = self.MSE_Loss(X_recon, X)
                z_and_sg_embd_loss = self.MSE_Loss(Z_enc, Z_dec.detach())
                sg_z_and_embd_loss = self.MSE_Loss(
                    self.model._modules['embd'].weight,
                    Z_enc_for_embd.detach())

                total_loss = recon_loss + sg_z_and_embd_loss + self.beta * z_and_sg_embd_loss

                self.optimizer.zero_grad()
                total_loss.backward(retain_graph=True)
                Z_enc.backward(self.model.grad_for_encoder)
                self.optimizer.step()

                recon_losses.append(recon_loss.data)
                z_and_sg_embd_losses.append(z_and_sg_embd_loss.data)
                sg_z_and_embd_losses.append(sg_z_and_embd_loss.data)

            # Sample Visualization
            self.vf.imshow_multi(X_recon.data.cpu(),
                                 title='random:{:d}'.format(e + 1))
            self.image_save(X_recon.data, name='random')
            self.test()

            # AVG Losses
            recon_losses = torch.stack(recon_losses).mean().item()
            z_and_sg_embd_losses = torch.stack(z_and_sg_embd_losses).mean().item()
            sg_z_and_embd_losses = torch.stack(sg_z_and_embd_losses).mean().item()
            print(
                '[{:02d}/{:d}] recon_loss:{:.2f} z_sg_embd:{:.2f} sg_z_embd:{:.2f}'
                .format(e + 1, self.epoch, recon_losses, z_and_sg_embd_losses,
                        sg_z_and_embd_losses))

        print("[*] Training Finished!")

    def test(self):
        self.set_mode('eval')

        X = Variable(self.fixed_x, requires_grad=False).cuda()
        X_recon = self.model(X)[0]
        X_cat = torch.cat([X, X_recon], 0)
        self.vf.imshow_multi(X_cat.data.cpu(),
                             nrow=self.fixed_x_num,
                             title='fixed_x_test:' + str(self.global_iter))
        self.image_save(X_cat.data, name='fixed', nrow=self.fixed_x_num)
        if self.ckpt_save:
            self.save_checkpoint({
                'iter': self.global_iter,
                'args': self.args,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            })

        self.set_mode('train')
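
Example #3 appears to train a VQ-VAE-style model: the two extra MSE terms are the codebook and commitment losses (the latter weighted by beta), and Z_enc.backward(self.model.grad_for_encoder) copies the decoder's gradient past the non-differentiable code lookup. A minimal sketch of that straight-through quantization, with hypothetical names, is shown below.

import torch


def straight_through_quantize(z_enc, codebook):
    # z_enc: (N, D) encoder outputs, codebook: (K, D) embedding vectors
    dist = torch.cdist(z_enc, codebook)        # (N, K) pairwise Euclidean distances
    z_dec = codebook[dist.argmin(dim=1)]       # nearest codebook entry (non-differentiable)
    # forward pass returns the quantized codes; backward pass routes gradients straight to z_enc
    return z_enc + (z_dec - z_enc).detach()
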