Esempio n. 1
0
 def __init__(self,
              cuda=True,
              black_box_fn=None,
              mask_resolution=32,
              num_classes=1000,
              default_iterations=200):
     """Set up the black-box classifier and the saliency loss computed on it.

     Args:
         cuda: passed to ``get_black_box_fn`` when no classifier is supplied;
             also stored as ``self.cuda``.
         black_box_fn: classifier to explain. If ``None``, one is created
             with ``get_black_box_fn(cuda=cuda)``.
         mask_resolution: stored as ``self.mask_resolution``.
         num_classes: stored as ``self.num_classes``.
         default_iterations: stored as ``self.default_iterations``.
     """
     if black_box_fn is None:
         black_box_fn = get_black_box_fn(
             cuda=cuda)  # defaults to ResNet-50 on ImageNet
     # Always store the classifier actually in use. A caller-supplied
     # black_box_fn must be kept as well, because SaliencyLoss below reads
     # self.black_box_fn (it was previously left unset in that case).
     self.black_box_fn = black_box_fn
     self.default_iterations = default_iterations
     self.mask_resolution = mask_resolution
     self.num_classes = num_classes
     # The loss-term weights are fixed here rather than exposed as parameters.
     self.saliency_loss_calc = SaliencyLoss(self.black_box_fn,
                                            area_loss_coef=11,
                                            smoothness_loss_coef=0.5,
                                            preserver_loss_coef=0.2)
     self.cuda = cuda
Esempio n. 2
0
from sal.utils.pytorch_trainer import *
from sal.saliency_model import SaliencyModel, SaliencyLoss, get_black_box_fn,  Distribution_Controller
from sal.visual_masks import *
from sal.datasets import imagenet_dataset
from sal.utils.resnet_encoder import resnet50encoder
from torchvision.models.resnet import resnet50
from  torch.optim import lr_scheduler
import pycat
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import copy
import time
# Bind these explicitly. Both are used below, but otherwise they only reach
# this module through ``from sal.utils.pytorch_trainer import *``; note that
# ``import torch.nn as nn`` does not bind ``torch`` itself.
import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dts = imagenet_dataset
# Built here but not used further in this excerpt.
black_box_fn = get_black_box_fn(model_zoo_model=resnet50)
val_dts = dts.get_val_dataset()
allow_selector = True

# ----------------------------- Testing with Mask Estimator -----------------------------
batch_size = 16
# NOTE(review): the keyword is passed as ``Shuffle`` (capital S). Confirm it
# matches get_loader's signature; torch's DataLoader flag is lowercase ``shuffle``.
val_datas = dts.get_loader(val_dts, batch_size=batch_size, Shuffle=True)
saliency = SaliencyModel(
    resnet50encoder(pretrained=True, require_out_grad=False),
    5, 64, 3, 64,
    fix_encoder=True,
    use_simple_activation=False,
    allow_selector=allow_selector,
    num_classes=1000,
)
load_path = './yoursaliencymodel'
# Swap in Distribution_Controller *before* restoring. Presumably the saved
# checkpoint expects this head; confirm against the training script.
saliency.to_saliency_chans = Distribution_Controller()
# load_path is resolved relative to this script's directory, not the CWD.
saliency.minimialistic_restore(os.path.join(os.path.dirname(__file__), load_path))
# train(False) puts the model in evaluation mode.
saliency.train(False)
# For an nn.Module, .to() moves it in place and returns the same object.
saliency_p = saliency.to(device)

for it_step, batch in enumerate(val_datas):
    # NOTE(review): the body only unpacks the batch. it_step, images and paths
    # are not used further in this excerpt; confirm nothing is missing.
    images, _, paths = batch
Esempio n. 3
0
        
        x = x.transpose(1, 3)
        x = x.reshape(-1, 9216)
        x = self.drop3(self.act3(self.fc3(x)))
        
        #y = self.act4(self.fc4(x))
        y = self.fc4(x)
        
        return y

model_pytorch = CNNClassifier(padding=2)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine. load_state_dict copies the values into the model's existing
# parameters, so this does not change where the model itself lives.
_ = model_pytorch.load_state_dict(
    torch.load("../saved_models/mnist_binarized_cnn_10_digits_pytorch.pth",
               map_location="cpu"))


def model_func(pretrained=True):
    """Return the already-loaded ``model_pytorch``; ``pretrained`` is ignored.

    The keyword mirrors model-zoo constructors such as torchvision's
    ``resnet50(pretrained=...)``, presumably so this can be passed as
    ``get_black_box_fn(model_zoo_model=...)``.
    """
    return model_pytorch


black_box_fn = get_black_box_fn(model_zoo_model=model_func, image_domain=None)

# Datasets. NOTE(review): ``dts`` is defined outside this excerpt.
train_dts = dts.get_train_dataset()
val_dts = dts.get_val_dataset()

print(f"len(train_dts) = {len(train_dts)}")
print(f"len(val_dts) = {len(val_dts)}")

# Validation loader, used to evaluate prediction accuracy.
val_dataloader = DataLoader(val_dts, batch_size=32, num_workers=1, pin_memory=True)

# The same object as model_pytorch (model_func ignores its argument).
black_box_model = model_func(pretrained=True)