def get_samples(target, nb_class=10, sample_index=0):
    '''
    params:
        target : [mnist, cifar10]
        nb_class : number of classes
        sample_index : index of image by class
    returns:
        original_images (numpy array) : original images, shape = (number of class, W, H, C)
        original_targets (torch array) : target labels of the selected images
        pre_images (torch array) : preprocessed images, shape = (number of class, C, W, H)
        target_classes (dictionary) : keys = class index, values = class name
        model (pytorch model) : pretrained model
    '''
    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    testset = testloader.dataset

    # idx2class: invert the dataset's class_to_idx mapping
    target_class2idx = testset.class_to_idx
    target_classes = dict(zip(target_class2idx.values(), target_class2idx.keys()))

    # select one image per class
    idx_by_class = [
        np.where(np.array(testset.targets) == i)[0][sample_index]
        for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    original_images = original_images.reshape((nb_class,) + image_size)

    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load
    weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(target))
    model = SimpleCNN(target)
    model.load_state_dict(weights['model'])

    # image preprocessing: apply the test transform to each image and
    # store channel-first as (N, C, H, W)
    pre_images = torch.zeros((nb_class,) + image_size).permute(0, 3, 1, 2)
    for i in range(len(original_images)):
        pre_images[i] = testset.transform(original_images[i])

    return original_images, original_targets, pre_images, target_classes, model
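# Usage sketch (added; not from the source): a minimal call to get_samples.
# It assumes the MNIST checkpoint '../checkpoint/simple_cnn_mnist.pth' has
# already been trained and saved by the training script below.
original_images, original_targets, pre_images, target_classes, model = \
    get_samples(target='mnist', nb_class=10, sample_index=0)

print(original_images.shape)   # (10, 28, 28, 1): one image per class
print(pre_images.shape)        # torch.Size([10, 1, 28, 28])
print(target_classes[0])       # class name for index 0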
def __init__(self, model, target, batch_size, method, ensemble=None,
             sample_pct=0.1, nb_class=10, sample_index=0):
    '''
    Args
        model : pretrained model
        target : [mnist, cifar10]
        batch_size : batch size used when computing saliency maps
        method : {
            VBP : Vanilla Backpropagation,
            IB : Input x Backpropagation,
            IG : Integrated Gradients,
            GB : Guided Backpropagation,
            GC : Grad-CAM,
            GB-GC : Guided Grad-CAM,
            DeconvNet : DeconvNet,
        }
        ensemble : {
            SG : SmoothGrad,
            SG-SQ : SmoothGrad Squared,
            SG-VAR : SmoothGrad VAR,
        }
        sample_pct : fraction of the test set to sample
        nb_class : number of classes
        sample_index : sample image index by class
    '''
    # data load
    if target == 'mnist':
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        _, _, testloader = cifar10_load()
    self.target = target
    self.testset = testloader.dataset
    # mnist : (28, 28), cifar10 : (32, 32, 3)
    self.img_size = self.testset.data.shape[1:]
    self.batch_size = batch_size

    # sampling: evaluate on a fixed random subset of the test set
    seed_everything()
    sample_size = int(len(self.testset) * sample_pct)
    sample_idx = np.random.choice(len(self.testset), sample_size, replace=False)
    self.testset.data = self.testset.data[sample_idx]
    self.testset.targets = np.array(self.testset.targets)[sample_idx]
    self.sample_pct = sample_pct
    self.data_size = len(self.testset)

    # model setting
    self.model = model
    self.model.eval()
    # deconvolution model; only used if method is set to DeconvNet
    self.deconv_model = None

    # saliency map
    self.method = method
    self.ensemble = ensemble
    self.saliency_map, self.layer, self.color = self.saliency_map_choice()

    # sample: one image index per class
    self.nb_class = nb_class
    self.nb_checkpoint = 5
    self.idx_by_class = [
        np.where(np.array(self.testset.targets) == i)[0][sample_index]
        for i in range(self.nb_class)
    ]
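# Instantiation sketch (added; not from the source). The name of the class
# that owns this __init__ is not shown here, so 'SaliencyEvaluator' below is
# a hypothetical placeholder; the method/ensemble codes follow the docstring.
_, _, _, _, model = get_samples(target='cifar10')
evaluator = SaliencyEvaluator(model=model,        # hypothetical class name
                              target='cifar10',
                              batch_size=128,
                              method='GC',        # Grad-CAM
                              ensemble='SG',      # optional SmoothGrad
                              sample_pct=0.1)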
def main(args, **kwargs):
    #################################
    # Config
    #################################
    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    target = args.target
    attention = args.attention
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)
    if attention in ['CAM', 'CBAM']:
        model_name = model_name + '_{}'.format(attention)
    elif attention in ['RAN', 'WARN']:
        model_name = '{}_{}'.format(target, attention)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting: cpu or cuda (gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Training: ', args.train)
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Attention: ', attention)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Evaluation: ', args.eval)
    if args.eval is not None:
        print('Pixel ratio: ', kwargs['ratio'])
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    #################################
    # Data Load
    #################################
    print('=====Data Load=====')
    if target == 'mnist':
        trainloader, validloader, testloader = mnist_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)
    elif target == 'cifar10':
        trainloader, validloader, testloader = cifar10_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    #################################
    # ROAR or KAR
    #################################
    if (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # saliency map load
        filename = f'../saliency_maps/[{args.target}]{args.method}'
        if attention in ['CBAM', 'RAN']:
            filename += f'_{attention}'
        hf = h5py.File(f'{filename}_train.hdf5', 'r')
        sal_maps = np.array(hf['saliencys'])

        # adjust image: mask training images according to the saliency
        # ranking (ROAR removes the most salient pixels, KAR keeps them)
        data_lst = adjust_image(kwargs['ratio'], trainloader, sal_maps,
                                args.eval)

        # hdf5 close
        hf.close()

        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

        # skip runs that already have a log file
        if os.path.isfile('{}/{}_logs.txt'.format(logdir, model_name)):
            sys.exit()

    #################################
    # Load model
    #################################
    print('=====Model Load=====')
    if attention == 'RAN':
        net = RAN(target).to(device)
    elif attention == 'WARN':
        net = WideResNetAttention(target).to(device)
    else:
        net = SimpleCNN(target, attention).to(device)
    n_parameters = sum([np.prod(p.size()) for p in net.parameters()])
    print('Total number of parameters:', n_parameters)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    #################################
    # Train
    #################################
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)

    #################################
    # Test
    #################################
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)
    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    if not os.path.isdir(logdir):
        os.mkdir(logdir)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)
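# Entry-point sketch (added; not from the source). The real argparse setup is
# not shown, so the flag names and defaults below are inferred from the
# attributes main() reads (args.epochs, args.attention, args.eval,
# kwargs['ratio'], ...) and should be treated as assumptions.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--valid_rate', type=float, default=0.2)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--target', type=str, default='mnist',
                        choices=['mnist', 'cifar10'])
    parser.add_argument('--attention', type=str, default=None,
                        choices=['CAM', 'CBAM', 'RAN', 'WARN'])
    parser.add_argument('--monitor', type=str, default='acc')
    parser.add_argument('--mode', type=str, default='max')
    parser.add_argument('--eval', type=str, default=None,
                        choices=['ROAR', 'KAR'])
    parser.add_argument('--method', type=str, default=None)
    parser.add_argument('--ratio', type=float, default=0.1)
    args = parser.parse_args()

    if args.eval is not None:
        main(args, ratio=args.ratio)
    else:
        main(args)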
def get_samples(target, nb_class=10, sample_index=0, attention=None,
                device='cpu'):
    '''
    Get samples: original images, preprocessed images, target classes,
    trained model

    args:
    - target: [mnist, cifar10]
    - nb_class: number of classes
    - sample_index: index of image by class
    - attention: attention module used by the checkpoint
      [CAM, CBAM, RAN, WARN], or None
    - device: device to load the model on

    return:
    - original_images (numpy array): original images, shape = (number of class, W, H, C)
    - original_targets (torch array): target labels of the selected images
    - pre_images (torch array): preprocessed images, shape = (number of class, C, W, H)
    - target_classes (dictionary): keys = class index, values = class name
    - model (pytorch model): pretrained model
    '''
    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    testset = testloader.dataset

    # idx2class: invert the dataset's class_to_idx mapping
    target_class2idx = testset.class_to_idx
    target_classes = dict(zip(target_class2idx.values(), target_class2idx.keys()))

    # select one image per class
    idx_by_class = [
        np.where(np.array(testset.targets) == i)[0][sample_index]
        for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    original_images = original_images.reshape((nb_class,) + image_size)

    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load: the checkpoint name follows the same scheme as training
    filename = f'simple_cnn_{target}'
    if attention in ['CAM', 'CBAM']:
        filename += f'_{attention}'
    elif attention in ['RAN', 'WARN']:
        filename = f'{target}_{attention}'
    print('filename: ', filename)
    weights = torch.load(f'../checkpoint/{filename}.pth')

    if attention == 'RAN':
        model = RAN(target).to(device)
    elif attention == 'WARN':
        model = WideResNetAttention(target).to(device)
    else:
        model = SimpleCNN(target, attention).to(device)
    model.load_state_dict(weights['model'])

    # image preprocessing: apply the test transform to each image and
    # store channel-first as (N, C, H, W)
    pre_images = torch.zeros((nb_class,) + image_size).permute(0, 3, 1, 2)
    for i in range(len(original_images)):
        pre_images[i] = testset.transform(original_images[i])

    return original_images, original_targets, pre_images, target_classes, model
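# Prediction sketch (added; not from the source): run the returned model on
# the per-class samples and map predictions back to class names. Assumes the
# CBAM checkpoint exists and that the model's forward returns class logits.
original_images, original_targets, pre_images, target_classes, model = \
    get_samples(target='cifar10', attention='CBAM', device='cpu')

model.eval()
with torch.no_grad():
    logits = model(pre_images)    # shape: (nb_class, nb_class)
    preds = logits.argmax(dim=1)

for true_idx, pred_idx in zip(original_targets.tolist(), preds.tolist()):
    print(f'true: {target_classes[true_idx]} / pred: {target_classes[pred_idx]}')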
def main(args, **kwargs):
    # Config
    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    target = args.target
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting: cpu or cuda (gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    # Data Load
    print('=====Data Load=====')
    if target == 'mnist':
        trainloader, validloader, testloader = mnist_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)
    elif target == 'cifar10':
        trainloader, validloader, testloader = cifar10_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    # ROAR or KAR
    if (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # saliency map load
        hf = h5py.File(
            f'../saliency_maps/[{args.target}]{args.method}_train.hdf5', 'r')
        sal_maps = np.array(hf['saliencys'])

        # adjust image: mask training images according to the saliency
        # ranking (ROAR removes the most salient pixels, KAR keeps them)
        data_lst = adjust_image(kwargs['ratio'], trainloader, sal_maps,
                                args.eval)

        # hdf5 close
        hf.close()

        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    # Load model
    print('=====Model Load=====')
    net = SimpleCNN(target).to(device)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    # Train
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)

    # Test
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)
    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    if not os.path.isdir(logdir):
        os.mkdir(logdir)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)
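# Sweep sketch (added; not from the source): how a ROAR experiment might call
# this main() once per pixel-removal ratio. The Namespace fields mirror the
# attributes main() reads; the ratio grid and field values are assumptions.
from argparse import Namespace

args = Namespace(epochs=100, batch_size=128, valid_rate=0.2, lr=0.01,
                 verbose=1, target='cifar10', monitor='acc', mode='max',
                 eval='ROAR', method='VBP')
for ratio in [0.1, 0.3, 0.5, 0.7, 0.9]:
    main(args, ratio=ratio)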