class Solver(object):
    """Train, evaluate and adversarially attack an image classifier.

    The solver owns the network built by ``model_init``, its Adam optimiser,
    the data loaders returned by ``return_data`` and an ``Attack`` helper used
    to craft adversarial examples. ``args.attack_mode`` selects the attack:
    'FGSM' (``Attack.i_fgsm``) or 'ILLC' (``Attack.IterativeLeastlikely``).
    """

    def __init__(self, args):
        """Set up loaders, directories, model, optimiser and attack from ``args``."""
        self.args = args

        # Basic
        self.cuda = (args.cuda and torch.cuda.is_available())
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.y_dim = args.y_dim  # number of classes (10 for MNIST and CIFAR10)
        self.target = args.target  # class index for targeted perturbations
        self.dataset = args.dataset
        self.data_loader = return_data(args)
        self.global_epoch = 0
        self.global_iter = 0
        self.print_ = not args.silent
        self.env_name = args.env_name  # experiment name
        self.visdom = args.visdom

        # Checkpoints are read from ckpt_dir but written to save_ckpt_dir.
        self.ckpt_dir = Path(args.ckpt_dir)
        self.save_ckpt_dir = Path('./checkpoints/' + args.env_name)
        print(self.save_ckpt_dir)
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.save_ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = Path(args.output_dir).joinpath(args.env_name)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Visualization Tools
        self.visualization_init(args)

        # Histories (best test accuracy seen so far and when it happened)
        self.history = {'acc': 0., 'epoch': 0, 'iter': 0}

        # Models & Optimizers
        self.model_init(args)
        self.load_ckpt = args.load_ckpt
        if args.load_ckpt_flag and self.load_ckpt:
            self.load_checkpoint(self.load_ckpt)

        # Adversarial Perturbation Generator
        self.attack_mode = args.attack_mode
        if self.attack_mode in ('FGSM', 'ILLC'):
            self.attack = Attack(self.net, criterion=F.cross_entropy)

    def visualization_init(self, args):
        """Create the Visdom helper when ``args.visdom`` was enabled."""
        if self.visdom:
            from utils.visdom_utils import VisFunc
            self.port = args.visdom_port
            self.vf = VisFunc(enval=self.env_name, port=self.port)

    def model_init(self, args):
        """Instantiate the network for ``args.dataset`` and its Adam optimiser.

        Raises:
            ValueError: if the dataset or (for CIFAR10) the network choice is
                not one of the supported names.
        """
        if args.dataset == 'MNIST':
            print("MNIST")
            net = ToyNet_MNIST(y_dim=self.y_dim)
        elif args.dataset == 'CIFAR10':
            print("Dataset used CIFAR10")
            if args.network_choice == 'ToyNet':
                net = ToyNet_CIFAR10(y_dim=self.y_dim)
            elif args.network_choice == 'ResNet18':
                net = ResNet18()
            elif args.network_choice == 'ResNet34':
                net = ResNet34()
            elif args.network_choice == 'ResNet50':
                net = ResNet50()
            else:
                # Previously fell through and failed later with AttributeError.
                raise ValueError(
                    'unknown network_choice: {!r}'.format(args.network_choice))
        else:
            raise ValueError('unknown dataset: {!r}'.format(args.dataset))

        self.net = cuda(net, self.cuda)
        # NOTE(review): the collapsed source does not show whether weight_init
        # applied to every network or only to some branch; it is applied to
        # all of them here - confirm the ResNet classes provide it.
        self.net.weight_init(_type='kaiming')

        # setup optimizer
        self.optim = optim.Adam([{
            'params': self.net.parameters(),
            'lr': self.lr
        }], betas=(0.5, 0.999))

    def _log_batch(self, batch_idx, correct, cost):
        """Print the running batch accuracy/loss every 100th batch."""
        if batch_idx % 100 != 0 or not self.print_:
            return
        print()
        print(self.env_name)
        print('[{:03d}:{:03d}]'.format(self.global_epoch, batch_idx))
        print('acc:{:.3f} loss:{:.3f}'.format(correct, cost))

    def _make_target(self, target, y_true):
        """Return a label tensor filled with ``target``, or None if untargeted.

        The attack is targeted only when ``target`` is an int inside
        ``range(self.y_dim)``.
        """
        if isinstance(target, int) and (target in range(self.y_dim)):
            return torch.LongTensor(y_true.size()).fill_(target)
        return None

    def _perturb(self, x_true, y_true, y_target, epsilon, alpha, iteration):
        """Dispatch to the attack selected by ``self.attack_mode``.

        Raises:
            ValueError: if ``attack_mode`` is neither 'FGSM' nor 'ILLC'.
        """
        if self.attack_mode == 'FGSM':
            return self.FGSM(x_true, y_true, y_target, epsilon, alpha,
                             iteration)
        if self.attack_mode == 'ILLC':
            return self.ILLC(x_true, y_true, y_target, epsilon, alpha,
                             iteration)
        raise ValueError('unknown attack_mode: {!r}'.format(self.attack_mode))

    def train(self):
        """Run standard training for ``self.epoch`` epochs, then plot curves.

        After every epoch the model is evaluated with ``test`` (which also
        saves the best checkpoint).
        """
        self.set_mode('train')
        acc_train_plt = [0]
        loss_plt = []
        acc_test_plt = [0]
        for _ in range(self.epoch):
            self.global_epoch += 1
            local_iter = 0
            total_acc = 0.
            total_loss = 0.
            for batch_idx, (images, labels) in enumerate(
                    self.data_loader['train']):
                self.global_iter += 1
                local_iter += 1

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))

                logit = self.net(x)
                prediction = logit.max(1)[1]
                correct = torch.eq(prediction, y).float().mean().item()
                cost = F.cross_entropy(logit, y)
                loss_value = cost.item()

                total_acc += correct
                total_loss += loss_value

                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                self._log_batch(batch_idx, correct, loss_value)

            # max(..., 1) keeps an empty loader from dividing by zero.
            acc_train_plt.append(total_acc / max(local_iter, 1))
            loss_plt.append(total_loss / max(local_iter, 1))
            acc_test_plt.append(self.test())

        print(" [*] Training Finished!")
        self.plot_result(acc_train_plt, acc_test_plt, loss_plt,
                         self.history['acc'])

    def test(self):
        """Evaluate on the test loader and return the accuracy.

        When the accuracy beats the best one recorded in ``self.history`` the
        history is updated and 'best_acc.tar' is saved. The network is put
        back in train mode before returning.
        """
        self.set_mode('eval')
        correct = 0.
        total = 0.
        # No gradients are needed for plain evaluation.
        with torch.no_grad():
            for images, labels in self.data_loader['test']:
                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))
                prediction = self.net(x).max(1)[1]
                correct += torch.eq(prediction, y).float().sum().item()
                total += x.size(0)
        accuracy = correct / total if total else 0.

        if self.history['acc'] < accuracy:
            self.history['acc'] = accuracy
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            self.save_checkpoint('best_acc.tar')

        if self.print_:
            print()
            print('[{:03d}]\nTEST RESULT'.format(self.global_epoch))
            # Current accuracy first, then the best one and its own epoch.
            print('ACC:{:.4f}'.format(accuracy))
            print('*TOP* ACC:{:.4f} at e:{:03d}'.format(
                self.history['acc'],
                self.history['epoch'],
            ))
            print()

        self.set_mode('train')
        return accuracy

    def generate(self, target, epsilon, alpha, iteration):
        """Attack one test batch and save legitimate/perturbed/changed grids.

        Images are written as JPEGs under ``self.output_dir`` (and shown in
        Visdom when enabled); accuracy and cost before and after the attack
        are printed.
        """
        self.set_mode('eval')

        # one test batch (size is whatever the test loader yields)
        x_true, y_true = self.sample_data()
        y_target = self._make_target(target, y_true)

        x_adv, changed, values = self._perturb(x_true, y_true, y_target,
                                               epsilon, alpha, iteration)
        accuracy, cost, accuracy_adv, cost_adv = values

        # save the result images under outputs/<experiment_name>
        suffix = '(t:{},e:{},i:{}).jpg'.format(target, epsilon, iteration)
        for name, tensor, padding in (('legitimate', x_true, 2),
                                      ('perturbed', x_adv, 2),
                                      ('changed', changed, 3)):
            save_image(tensor,
                       self.output_dir.joinpath(name + suffix),
                       nrow=10,
                       padding=padding,
                       pad_value=0.5)

        if self.visdom:
            self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5)
            self.vf.imshow_multi(x_adv.cpu(),
                                 title='perturbed(e:{},i:{})'.format(
                                     epsilon, iteration),
                                 factor=1.5)
            self.vf.imshow_multi(changed.cpu(),
                                 title='changed(white)',
                                 factor=1.5)

        print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy, cost))
        print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy_adv, cost_adv))

        self.set_mode('train')

    def ad_train(self, target, alpha, iteration, lamb):
        """Adversarial training: half of every batch is replaced by attacks.

        The first half of each batch is perturbed with a randomly drawn
        epsilon; the loss is the clean cross-entropy plus ``lamb`` times the
        adversarial cross-entropy, normalised by the weighted example count.
        The network is evaluated with ``test`` after every epoch.
        """
        self.set_mode('train')
        acc_train_plt = [0]
        acc_test_plt = [0]
        loss_plt = []
        for _ in range(self.epoch):
            self.global_epoch += 1
            local_iter = 0
            total_acc = 0.
            total_loss = 0.
            for batch_idx, (images, labels) in enumerate(
                    self.data_loader['train']):
                self.global_iter += 1
                local_iter += 1

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))

                # Size the split from the real batch so a short final batch
                # is neither mis-normalised nor left without clean examples.
                num_adv_image = x.size(0) // 2
                num_clean_image = x.size(0) - num_adv_image

                if num_adv_image > 0:
                    self.set_mode('eval')
                    x_true = Variable(cuda(images[:num_adv_image], self.cuda))
                    y_true = Variable(cuda(labels[:num_adv_image], self.cuda))
                    y_target = self._make_target(target, y_true)

                    # |N(0, 8/255)| epsilon; draws above 16/255 mean "no
                    # perturbation" for this batch.
                    epsilon = abs(np.random.normal(0, 8 / 255))
                    if epsilon > 16 / 255:
                        epsilon = 0

                    x_adv, _, _ = self._perturb(x_true, y_true, y_target,
                                                epsilon, alpha, iteration)
                    x[:num_adv_image] = x_adv

                self.set_mode('train')
                logit = self.net(x)
                prediction = logit.max(1)[1]
                correct = torch.eq(prediction, y).float().mean().item()

                # Summed losses divided by the weighted count; identical to
                # the previous mean-based formula for an evenly split batch.
                cost = F.cross_entropy(logit[num_adv_image:],
                                       y[num_adv_image:],
                                       reduction='sum')
                if num_adv_image > 0:
                    cost = cost + lamb * F.cross_entropy(
                        logit[:num_adv_image],
                        y[:num_adv_image],
                        reduction='sum')
                cost = cost / (num_clean_image + lamb * num_adv_image)
                loss_value = cost.item()

                total_acc += correct
                total_loss += loss_value

                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                self._log_batch(batch_idx, correct, loss_value)

            acc_train_plt.append(total_acc / max(local_iter, 1))
            loss_plt.append(total_loss / max(local_iter, 1))
            # A second, redundant test() call used to follow this one.
            acc_test_plt.append(self.test())

        print(" [*] Training Finished!")
        self.plot_result(acc_train_plt, acc_test_plt, loss_plt,
                         self.history['acc'])

    def ad_test(self, target, epsilon, alpha, iteration):
        """Print the accuracy on adversarially perturbed test data."""
        self.set_mode('eval')
        correct = 0.
        total = 0.
        for images, labels in self.data_loader['test']:
            x_true = Variable(cuda(images, self.cuda))
            y_true = Variable(cuda(labels, self.cuda))
            y_target = self._make_target(target, y_true)

            x, _, _ = self._perturb(x_true, y_true, y_target, epsilon, alpha,
                                    iteration)

            # FGSM/ILLC switch the net back to train mode before returning;
            # restore eval mode so the measurement is not taken in train mode.
            self.set_mode('eval')
            with torch.no_grad():
                prediction = self.net(x).max(1)[1]
            correct += torch.eq(prediction, y_true).float().sum().item()
            total += x.size(0)

        accuracy = correct / total if total else 0.
        print('ACC:{:.4f}'.format(accuracy))
        self.set_mode('train')

    def sample_data(self):
        """Return the first (images, labels) batch of the test loader."""
        images, labels = next(iter(self.data_loader['test']))
        return (Variable(cuda(images, self.cuda)),
                Variable(cuda(labels, self.cuda)))

    def _render_changed(self, changed, x_adv):
        """Build the framed visualisation of which predictions flipped.

        ``changed`` is a per-image flag; flagged images get one frame colour
        and the others a different one, and the adversarial image is pasted
        inside the frame. The frame size follows ``x_adv`` instead of being
        hard-coded per dataset.
        """
        height, width = x_adv.size(2), x_adv.size(3)
        changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, height, width)

        # fill the grid with color
        changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
        changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
        changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
        changed = self.scale(changed / 255)

        # fill the inner part of the grid with the image; single-channel
        # images are tiled to three channels first.
        inner = x_adv.repeat(1, 3, 1, 1) if x_adv.size(1) == 1 else x_adv
        changed[:, :, 3:-2, 3:-2] = inner[:, :, 3:-2, 3:-2]
        return changed

    def _adversarial_eval(self, attack_fn, x, y_true, y_target, eps, alpha):
        """Shared body of ``FGSM`` and ``ILLC``.

        Classifies the clean batch, crafts adversarial examples with
        ``attack_fn`` and classifies those as well.

        Returns:
            ``(x_adv, changed, (accuracy, cost, accuracy_adv, cost_adv))``
            where the first two are detached tensors and the metrics are
            Python floats. The network is left in train mode.
        """
        self.set_mode('eval')

        x = Variable(cuda(x, self.cuda), requires_grad=True)
        y_true = Variable(cuda(y_true, self.cuda), requires_grad=False)
        targeted = y_target is not None
        if targeted:
            y_target = Variable(cuda(y_target, self.cuda),
                                requires_grad=False)

        # original image classification
        h = self.net(x)
        prediction = h.max(1)[1]
        accuracy = torch.eq(prediction, y_true).float().mean()
        cost = F.cross_entropy(h, y_true)

        # adversarial image classification
        attack_label = y_target if targeted else y_true
        x_adv, h_adv, _ = attack_fn(x, attack_label, targeted, eps, alpha)
        prediction_adv = h_adv.max(1)[1]
        accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
        cost_adv = F.cross_entropy(h_adv, y_true)

        # Flag the images whose outcome the attack achieved: the target class
        # was hit (targeted) or the prediction moved (untargeted).
        # NOTE(review): the collapsed source is ambiguous on whether the
        # inversion also applied to the targeted case; only the untargeted
        # case is inverted here.
        if targeted:
            changed = torch.eq(y_target, prediction_adv)
        else:
            changed = torch.ne(prediction, prediction_adv)
        changed = self._render_changed(changed, x_adv)

        self.set_mode('train')

        return x_adv.data, changed.data,\
                (accuracy.item(), cost.item(), accuracy_adv.item(), cost_adv.item())

    def ILLC(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        """Attack ``x`` with ``Attack.IterativeLeastlikely``.

        ``iteration`` is accepted for interface symmetry but is not forwarded
        to the attack. See ``_adversarial_eval`` for the return value.
        """
        return self._adversarial_eval(self.attack.IterativeLeastlikely, x,
                                      y_true, y_target, eps, alpha)

    def FGSM(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        """Attack ``x`` with ``Attack.i_fgsm``.

        ``iteration`` is accepted for interface symmetry but is not forwarded
        to the attack. See ``_adversarial_eval`` for the return value.
        """
        return self._adversarial_eval(self.attack.i_fgsm, x, y_true, y_target,
                                      eps, alpha)

    def save_checkpoint(self, filename='ckpt.tar'):
        """Write model, optimiser and progress state to ``save_ckpt_dir``."""
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'history': self.history,
            'args': self.args,
            'model_states': {'net': self.net.state_dict()},
            'optim_states': {'optim': self.optim.state_dict()},
        }

        file_path = self.save_ckpt_dir / filename
        print(file_path)
        # The context manager closes the handle the old code leaked.
        with file_path.open('wb') as handle:
            torch.save(states, handle)
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self, filename='best_acc.tar'):
        """Restore model, optimiser and progress state from ``ckpt_dir``.

        A missing file is reported and otherwise ignored.
        """
        file_path = self.ckpt_dir / filename
        if not file_path.is_file():
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        print("=> loading checkpoint '{}'".format(file_path))
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        # map_location lets a GPU-written checkpoint load on a CPU-only host.
        with file_path.open('rb') as handle:
            checkpoint = torch.load(
                handle, map_location=None if self.cuda else 'cpu')
        self.global_epoch = checkpoint['epoch']
        self.global_iter = checkpoint['iter']
        self.history = checkpoint['history']
        self.net.load_state_dict(checkpoint['model_states']['net'])
        self.optim.load_state_dict(checkpoint['optim_states']['optim'])
        print("=> loaded checkpoint '{} (iter {})'".format(
            file_path, self.global_iter))

    def set_mode(self, mode='train'):
        """Switch the network between 'train' and 'eval' mode.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train':
            self.net.train()
        elif mode == 'eval':
            self.net.eval()
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError('mode error. It should be either train or eval')

    def scale(self, image):
        """Map values from [0, 1] to the zero-centred range [-1, 1]."""
        return image.mul(2).add(-1)

    def convert_torch2numpy(self, torch_img):
        """Return the tensor as a NumPy array with axes (0, 2, 3, 1)."""
        return np.transpose(torch_img.data.cpu().numpy(), (0, 2, 3, 1))

    def plot_img(self, np_img, idx, title):
        """Show ``np_img[idx]`` in a new matplotlib figure."""
        plt.figure()
        plt.title(title)
        plt.imshow(np_img[idx], interpolation='nearest')

    def plot_result(self,
                    acc_train_plt,
                    acc_test_plt,
                    loss_plt,
                    best_acc,
                    title='train_graph'):
        """Plot accuracy and loss curves and save them as a PNG.

        The accuracy lists need ``self.epoch + 1`` entries (a leading value
        for epoch 0) and ``loss_plt`` needs ``self.epoch`` entries.
        """
        epoch = range(0, self.epoch + 1)
        fig, ax1 = plt.subplots()
        ax1.plot(epoch, acc_train_plt, label='train_acc')
        ax1.plot(epoch, acc_test_plt, label='test_acc')
        ax1.set_xlabel('epoch')
        ax1.set_ylabel('accuracy')
        ax1.tick_params(axis='y')
        plt.legend(loc='upper left')

        color = 'tab:red'
        ax2 = ax1.twinx()
        ax2.plot(epoch[1:],
                 loss_plt,
                 linestyle="--",
                 label='train_loss',
                 color=color)
        ax2.set_ylabel('loss', color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        plt.title("{}".format(self.env_name))
        plt.savefig('{}/{}/best_acc_{}.png'.format(self.args.output_dir,
                                                   self.env_name, best_acc),
                    dpi=350)
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)
class Solver(object):
    """Train and attack a caller-supplied classifier with FGSM / iterative FGSM.

    Unlike a self-contained solver, the model and the dict of data loaders
    (keys 'train' and 'test') are passed in by the caller. Progress can be
    logged to Visdom and TensorBoard.
    """

    def __init__(self, args, model, dataloarder):
        """Store configuration, move the model to the device, build the attack.

        Args:
            args: parsed options (cuda, epoch, batch_size, eps, lr, y_dim,
                target, dataset, silent, env_name, tensorboard, visdom,
                ckpt_dir, output_dir, load_ckpt, ...).
            model: the network to train and attack.
            dataloarder: dict of data loaders; the misspelt name is kept so
                keyword callers keep working.
        """
        self.args = args

        # Basic: use the GPU only when it was requested and is available.
        if args.cuda and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.eps = args.eps
        self.lr = args.lr
        self.y_dim = args.y_dim
        self.target = args.target
        self.dataset = args.dataset
        self.data_loader = dataloarder  # dict
        self.global_epoch = 0
        self.global_iter = 0
        self.print_ = not args.silent

        # Inputs are moved to self.device, so the model must live there too.
        self.net = model.to(self.device)
        # train() and the checkpoint methods use self.optim, which was never
        # created; the settings mirror the commented-out model_init. Callers
        # may still replace the attribute afterwards.
        self.optim = torch.optim.Adam(self.net.parameters(),
                                      lr=self.lr,
                                      betas=(0.5, 0.999))

        self.env_name = args.env_name
        self.tensorboard = args.tensorboard
        self.visdom = args.visdom
        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name)
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = Path(args.output_dir).joinpath(args.env_name)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Visualization Tools
        self.visualization_init(args)

        # Histories (best test accuracy seen so far and when it happened)
        self.history = {'acc': 0., 'epoch': 0, 'iter': 0}

        self.load_ckpt = args.load_ckpt
        if self.load_ckpt != '':
            self.load_checkpoint(self.load_ckpt)

        # Adversarial Perturbation Generator
        self.attack = Attack(self.net, criterion=F.cross_entropy)

    def visualization_init(self, args):
        """Create the Visdom and/or TensorBoard writers that were enabled."""
        # Visdom
        if self.visdom:
            from utils.visdom_utils import VisFunc
            self.port = args.visdom_port
            self.vf = VisFunc(enval=self.env_name, port=self.port)

        # TensorboardX
        if self.tensorboard:
            from tensorboardX import SummaryWriter
            self.summary_dir = Path(args.summary_dir).joinpath(args.env_name)
            self.summary_dir.mkdir(parents=True, exist_ok=True)

            self.tf = SummaryWriter(log_dir=str(self.summary_dir))
            self.tf.add_text(tag='argument',
                             text_string=str(args),
                             global_step=self.global_epoch)

    def train(self):
        """Train for ``self.epoch`` epochs, evaluating after each one."""
        self.set_mode('train')
        for _ in range(self.epoch):
            self.global_epoch += 1

            for batch_idx, (images, labels) in enumerate(
                    self.data_loader['train']):
                self.global_iter += 1

                x = images.to(self.device)
                y = labels.to(self.device)

                logit = self.net(x)
                prediction = logit.max(1)[1]
                # .item() replaces .data[0], which fails on 0-dim tensors.
                correct = torch.eq(prediction, y).float().mean().item()
                cost = F.cross_entropy(logit, y)

                self.optim.zero_grad()
                cost.backward()
                self.optim.step()

                if batch_idx % 100 != 0:
                    continue

                loss_value = cost.item()
                if self.print_:
                    print()
                    print(self.env_name)
                    print('[{:03d}:{:03d}]'.format(self.global_epoch,
                                                   batch_idx))
                    print('acc:{:.3f} loss:{:.3f}'.format(
                        correct, loss_value))

                # NOTE(review): the collapsed source does not show whether
                # this logging was nested under print_; it is independent of
                # the console flag here.
                if self.tensorboard:
                    self.tf.add_scalars(main_tag='performance/acc',
                                        tag_scalar_dict={'train': correct},
                                        global_step=self.global_iter)
                    self.tf.add_scalars(
                        main_tag='performance/error',
                        tag_scalar_dict={'train': 1 - correct},
                        global_step=self.global_iter)
                    self.tf.add_scalars(
                        main_tag='performance/cost',
                        tag_scalar_dict={'train': loss_value},
                        global_step=self.global_iter)

            self.test()

        if self.tensorboard:
            self.tf.add_scalars(main_tag='performance/best/acc',
                                tag_scalar_dict={'test': self.history['acc']},
                                global_step=self.history['iter'])
        print(" [*] Training Finished!")

    def test(self):
        """Evaluate on the test loader; checkpoint when accuracy improves.

        The network is put back in train mode before returning.
        """
        self.set_mode('eval')

        correct = 0.
        cost = 0.
        total = 0.

        # No gradients are needed for plain evaluation.
        with torch.no_grad():
            for images, labels in self.data_loader['test']:
                x = images.to(self.device)
                y = labels.to(self.device)

                logit = self.net(x)
                prediction = logit.max(1)[1]

                correct += torch.eq(prediction, y).float().sum().item()
                # reduction='sum' replaces the deprecated size_average=False.
                cost += F.cross_entropy(logit, y, reduction='sum').item()
                total += x.size(0)

        accuracy = correct / total if total else 0.
        cost = cost / total if total else 0.

        # Update the record first so the "*TOP*" line reports the real best.
        improved = self.history['acc'] < accuracy
        if improved:
            self.history['acc'] = accuracy
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter

        if self.print_:
            print()
            print('[{:03d}]\nTEST RESULT'.format(self.global_epoch))
            print('ACC:{:.4f}'.format(accuracy))
            print('*TOP* ACC:{:.4f} at e:{:03d}'.format(
                self.history['acc'],
                self.history['epoch'],
            ))
            print()

        if self.tensorboard:
            self.tf.add_scalars(main_tag='performance/acc',
                                tag_scalar_dict={'test': accuracy},
                                global_step=self.global_iter)
            self.tf.add_scalars(main_tag='performance/error',
                                tag_scalar_dict={'test': (1 - accuracy)},
                                global_step=self.global_iter)
            self.tf.add_scalars(main_tag='performance/cost',
                                tag_scalar_dict={'test': cost},
                                global_step=self.global_iter)

        if improved:
            self.save_checkpoint('best_acc.tar')

        self.set_mode('train')

    def generate(self,
                 num_sample=100,
                 target=-1,
                 epsilon=0.03,
                 alpha=2 / 255,
                 iteration=1):
        """Attack ``num_sample`` random test images and save the image grids.

        The attack is targeted only when ``target`` is an int inside
        ``range(self.y_dim)``. JPEGs are written under ``self.output_dir``
        and accuracy/cost before and after the attack are printed.
        """
        self.set_mode('eval')

        x_true, y_true = self.sample_data(num_sample)
        if isinstance(target, int) and (target in range(self.y_dim)):
            y_target = torch.LongTensor(y_true.size()).fill_(target)
        else:
            y_target = None

        x_adv, changed, values = self.FGSM(x_true, y_true, y_target, epsilon,
                                           alpha, iteration)
        accuracy, cost, accuracy_adv, cost_adv = values

        suffix = '(t:{},e:{},i:{}).jpg'.format(target, epsilon, iteration)
        for name, tensor, padding in (('legitimate', x_true, 2),
                                      ('perturbed', x_adv, 2),
                                      ('changed', changed, 3)):
            save_image(tensor,
                       self.output_dir.joinpath(name + suffix),
                       nrow=10,
                       padding=padding,
                       pad_value=0.5)

        if self.visdom:
            self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5)
            self.vf.imshow_multi(x_adv.cpu(),
                                 title='perturbed(e:{},i:{})'.format(
                                     epsilon, iteration),
                                 factor=1.5)
            self.vf.imshow_multi(changed.cpu(),
                                 title='changed(white)',
                                 factor=1.5)

        print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy, cost))
        print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(
            accuracy_adv, cost_adv))

        self.set_mode('train')

    def sample_data(self, num_sample=100):
        """Draw ``num_sample`` random test images with their labels.

        Reads the raw arrays of the test dataset (``data``/``targets``, or
        the older ``test_data``/``test_labels`` names).

        Returns:
            ``(x, y)``: images as floats scaled to [-1, 1] with a channel
            axis inserted at position 1, and the matching labels as a long
            tensor, both on ``self.device``.
        """
        dataset = self.data_loader['test'].dataset
        total = len(dataset)

        # Valid indices are 0..total-1; the removed random_integers(1, total)
        # skipped index 0 and could index one past the end.
        seed = np.random.randint(0, total, size=num_sample)
        index = torch.from_numpy(seed).long()

        if hasattr(dataset, 'data'):
            images = dataset.data
        else:
            images = dataset.test_data
        if hasattr(dataset, 'targets'):
            labels = dataset.targets
        else:
            labels = dataset.test_labels

        # as_tensor accepts ndarrays, lists and tensors alike.
        x = torch.as_tensor(images)[index]
        # assumes images are (N, H, W) with values in 0..255 - TODO confirm
        x = self.scale(x.float().unsqueeze(1).div(255)).to(self.device)

        # The labels were previously read from the image array by mistake.
        y = torch.as_tensor(labels)[index].long().to(self.device)
        return x, y

    def FGSM(self,
             x,
             y_true,
             y_target=None,
             eps=0.03,
             alpha=2 / 255,
             iteration=1):
        """Craft adversarial examples and measure their effect.

        Uses ``Attack.fgsm`` when ``iteration == 1`` and ``Attack.i_fgsm``
        otherwise; the attack is targeted when ``y_target`` is given.
        NumPy inputs are converted to tensors.

        Returns:
            ``(x_adv, changed, (accuracy, cost, accuracy_adv, cost_adv))``
            where ``changed`` is a framed visualisation of which images the
            attack affected and the metrics are Python floats. The network
            is left in train mode.
        """
        self.set_mode('eval')

        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if isinstance(y_true, np.ndarray):
            y_true = torch.from_numpy(y_true)

        # Move first, then mark as requiring grad, so x stays a leaf tensor
        # whose .grad the attack can read.
        x = x.detach().to(self.device).requires_grad_(True)
        y_true = y_true.to(self.device)
        targeted = y_target is not None
        if targeted:
            y_target = y_target.to(self.device)

        h = self.net(x)
        prediction = h.max(1)[1]
        accuracy = torch.eq(prediction, y_true).float().mean()
        cost = F.cross_entropy(h, y_true)

        attack_label = y_target if targeted else y_true
        if iteration == 1:
            x_adv, h_adv, _ = self.attack.fgsm(x, attack_label, targeted, eps)
        else:
            x_adv, h_adv, _ = self.attack.i_fgsm(x, attack_label, targeted,
                                                 eps, alpha, iteration)

        prediction_adv = h_adv.max(1)[1]
        accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
        cost_adv = F.cross_entropy(h_adv, y_true)

        # Flag the images whose outcome the attack achieved: the target class
        # was hit (targeted) or the prediction moved (untargeted).
        # NOTE(review): the collapsed source is ambiguous on whether the
        # inversion also applied to the targeted case; only the untargeted
        # case is inverted here.
        if targeted:
            changed = torch.eq(y_target, prediction_adv)
        else:
            changed = torch.ne(prediction, prediction_adv)

        # Frame size follows the images instead of a hard-coded 28x28.
        height, width = x_adv.size(2), x_adv.size(3)
        changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, height, width)

        changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
        changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
        changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
        changed = self.scale(changed / 255)

        # Paste the adversarial image inside the coloured frame; tile
        # single-channel images to three channels first.
        inner = x_adv.repeat(1, 3, 1, 1) if x_adv.size(1) == 1 else x_adv
        changed[:, :, 3:-2, 3:-2] = inner[:, :, 3:-2, 3:-2]

        self.set_mode('train')

        return x_adv.data, changed.data,\
                (accuracy.item(), cost.item(), accuracy_adv.item(), cost_adv.item())

    def save_checkpoint(self, filename='ckpt.tar'):
        """Write model, optimiser and progress state to ``ckpt_dir``."""
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'history': self.history,
            'args': self.args,
            'model_states': {'net': self.net.state_dict()},
            'optim_states': {'optim': self.optim.state_dict()},
        }

        file_path = self.ckpt_dir / filename
        # The context manager closes the handle the old code leaked.
        with file_path.open('wb') as handle:
            torch.save(states, handle)
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self, filename='best_acc.tar'):
        """Restore model, optimiser and progress state from ``ckpt_dir``.

        A missing file is reported and otherwise ignored.
        """
        file_path = self.ckpt_dir / filename
        if not file_path.is_file():
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        print("=> loading checkpoint '{}'".format(file_path))
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        with file_path.open('rb') as handle:
            checkpoint = torch.load(handle, map_location=self.device)
        self.global_epoch = checkpoint['epoch']
        self.global_iter = checkpoint['iter']
        self.history = checkpoint['history']
        self.net.load_state_dict(checkpoint['model_states']['net'])
        self.optim.load_state_dict(checkpoint['optim_states']['optim'])
        print("=> loaded checkpoint '{} (iter {})'".format(
            file_path, self.global_iter))

    def set_mode(self, mode='train'):
        """Switch the network between 'train' and 'eval' mode.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train':
            self.net.train()
        elif mode == 'eval':
            self.net.eval()
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError('mode error. It should be either train or eval')

    def scale(self, image):
        """Map values from [0, 1] to [-1, 1]."""
        return image.mul(2).add(-1)

    def unscale(self, image):
        """Map values from [-1, 1] back to [0, 1]."""
        return image.add(1).mul(0.5)

    def summary_flush(self, silent=True):
        """Remove the TensorBoard summary directory, if one was created."""
        # summary_dir only exists when TensorBoard logging was enabled.
        summary_dir = getattr(self, 'summary_dir', None)
        if summary_dir is not None:
            rm_dir(summary_dir, silent)

    def checkpoint_flush(self, silent=True):
        """Remove the checkpoint directory."""
        rm_dir(self.ckpt_dir, silent)
class Solver(object):
    """Train an autoencoder model with an embedding table and visualise it.

    Every batch minimises a reconstruction loss plus two embedding losses
    (see ``train``). Reconstructions are sent to Visdom and saved as JPEGs
    under ``output_dir``; checkpoints go under ``checkpoints/<env_name>``.
    """

    def __init__(self, args):
        """Build model, losses, data loader and optimiser from ``args``.

        Raises:
            ValueError: if ``args.dataset`` is neither 'MNIST' nor 'CIFAR10'.
        """
        self.args = args
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.z_dim = args.z_dim
        self.k_dim = args.k_dim
        self.beta = args.beta
        self.env_name = args.env_name
        self.ckpt_dir = os.path.join('checkpoints', args.env_name)
        self.global_iter = 0
        self.dataset = args.dataset
        self.fixed_x_num = args.fixed_x_num
        self.output_dir = os.path.join(args.output_dir, args.env_name)
        self.ckpt_load = args.ckpt_load
        self.ckpt_save = args.ckpt_save

        # Fall back to the CPU when no GPU is present instead of crashing on
        # an unconditional .cuda() call.
        self.use_cuda = torch.cuda.is_available()

        # Toy Network init
        if self.dataset == 'MNIST':
            model = MODEL_MNIST(k_dim=self.k_dim, z_dim=self.z_dim)
        elif self.dataset == 'CIFAR10':
            model = MODEL_CIFAR10(k_dim=self.k_dim, z_dim=self.z_dim)
        else:
            # Previously fell through and failed later with AttributeError.
            raise ValueError('unknown dataset: {!r}'.format(self.dataset))
        self.model = self._to_device(model)

        # Visdom Sample Visualization
        self.vf = VisFunc(enval=self.env_name, port=55558)

        # Criterions
        self.MSE_Loss = self._to_device(nn.MSELoss())

        # Dataset init
        self.train_data, self.train_loader = data_loader(args)
        # A fixed set of images reused by test() so reconstructions are
        # comparable across iterations. next(iter(...)) replaces the
        # Python-2-only .next() method.
        self.fixed_x = next(iter(self.train_loader))[0][:self.fixed_x_num]

        # Optimizer
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.lr,
                                    betas=(0.5, 0.999))

        # Resume training
        if self.ckpt_load:
            self.load_checkpoint()

    def _to_device(self, obj):
        """Move a tensor or module to the GPU when one is available."""
        return obj.cuda() if self.use_cuda else obj

    def set_mode(self, mode='train'):
        """Switch the model between 'train' and 'eval' mode.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train':
            self.model.train()
        elif mode == 'eval':
            self.model.eval()
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError('mode error. It should be either train or eval')

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Save ``state`` under ``ckpt_dir``, creating the directory if needed."""
        os.makedirs(self.ckpt_dir, exist_ok=True)

        file_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, file_path)
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self):
        """Restore iteration count, model and optimiser from the checkpoint.

        A missing file is reported and otherwise ignored.
        """
        filename = 'checkpoint.pth.tar'
        file_path = os.path.join(self.ckpt_dir, filename)
        if not os.path.isfile(file_path):
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        print("=> loading checkpoint '{}'".format(file_path))
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        # map_location lets a GPU-written checkpoint load on a CPU-only host.
        checkpoint = torch.load(
            file_path,
            map_location=None if torch.cuda.is_available() else 'cpu')
        self.global_iter = checkpoint['iter']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (iter {})".format(
            filename, checkpoint['iter']))

    def image_save(self, imgs, name='fixed', **kwargs):
        """Save ``imgs`` as ``<name>_<global_iter>.jpg`` in ``output_dir``.

        ``imgs`` must be shaped batch_size x channels x width x height;
        extra keyword arguments go to ``torchvision.utils.save_image``.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        filename = os.path.join(self.output_dir,
                                name + '_' + str(self.global_iter) + '.jpg')
        torchvision.utils.save_image(imgs, filename, **kwargs)

    def train(self):
        """Train for ``self.epoch`` epochs.

        After every epoch the last reconstruction batch is visualised,
        ``test`` runs on the fixed images, and the epoch's mean losses are
        printed.
        """
        self.set_mode('train')
        for e in range(self.epoch):
            recon_losses = []
            z_and_sg_embd_losses = []
            sg_z_and_embd_losses = []
            for images, _ in self.train_loader:
                self.global_iter += 1

                X = Variable(self._to_device(images), requires_grad=False)
                X_recon, Z_enc, Z_dec, Z_enc_for_embd = self.model(X)

                recon_loss = self.MSE_Loss(X_recon, X)
                # "sg" = stop-gradient: the detached side is a fixed target.
                z_and_sg_embd_loss = self.MSE_Loss(Z_enc, Z_dec.detach())
                sg_z_and_embd_loss = self.MSE_Loss(
                    self.model._modules['embd'].weight,
                    Z_enc_for_embd.detach())

                total_loss = (recon_loss + sg_z_and_embd_loss +
                              self.beta * z_and_sg_embd_loss)

                self.optimizer.zero_grad()
                # The graph is retained because a second backward pass feeds
                # model.grad_for_encoder into the encoder output.
                total_loss.backward(retain_graph=True)
                Z_enc.backward(self.model.grad_for_encoder)
                self.optimizer.step()

                # Plain floats: torch.cat cannot join 0-dim loss tensors.
                recon_losses.append(recon_loss.item())
                z_and_sg_embd_losses.append(z_and_sg_embd_loss.item())
                sg_z_and_embd_losses.append(sg_z_and_embd_loss.item())

            # Sample Visualization
            self.vf.imshow_multi(X_recon.data.cpu(),
                                 title='random:{:d}'.format(e + 1))
            self.image_save(X_recon.data, name='random')
            self.test()

            # AVG Losses
            num_batches = max(len(recon_losses), 1)
            print(
                '[{:02d}/{:d}] recon_loss:{:.2f} z_sg_embd:{:.2f} sg_z_embd:{:.2f}'
                .format(e + 1, self.epoch,
                        sum(recon_losses) / num_batches,
                        sum(z_and_sg_embd_losses) / num_batches,
                        sum(sg_z_and_embd_losses) / num_batches))

        print("[*] Training Finished!")

    def test(self):
        """Reconstruct the fixed images, visualise them and checkpoint.

        Originals and reconstructions are stacked into one image grid. A
        checkpoint is written when ``ckpt_save`` is enabled. The model is
        put back in train mode before returning.
        """
        self.set_mode('eval')
        # Deliberately not wrapped in torch.no_grad(): train() relies on
        # model.grad_for_encoder, so the forward pass may register gradient
        # hooks - NOTE(review): confirm against the model definition.
        X = self._to_device(Variable(self.fixed_x, requires_grad=False))
        X_recon = self.model(X)[0]
        X_cat = torch.cat([X, X_recon], 0)
        self.vf.imshow_multi(X_cat.data.cpu(),
                             nrow=self.fixed_x_num,
                             title='fixed_x_test:' + str(self.global_iter))
        self.image_save(X_cat.data, name='fixed', nrow=self.fixed_x_num)

        if self.ckpt_save:
            self.save_checkpoint({
                'iter': self.global_iter,
                'args': self.args,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            })

        self.set_mode('train')