def LoadShuffleDefenseAndCIFAR10(vis=False):
    """Build the ViT-L-16 + BiT-M-R101x3 shuffle defense and the CIFAR-10 validation loader.

    Both models are restored from checkpoints under ``Models/``, switched to
    eval mode, wrapped in ``ModelPlus`` and handed to
    ``ShuffleDefense.ShuffleDefense``.

    Args:
        vis: Forwarded unchanged as the ``vis`` argument of ``VisionTransformer``.

    Returns:
        A ``(valLoader, defense)`` tuple, where ``valLoader`` is the result of
        ``DMP.GetCIFAR10Validation(224, 8)`` and ``defense`` is the
        ``ShuffleDefense`` built from the two wrapped models.
    """
    # Basic variable and data setup
    device = torch.device("cuda")
    numClasses = 10
    imgSize = 224
    batchSize = 8

    # Load the CIFAR-10 data
    valLoader = DMP.GetCIFAR10Validation(imgSize, batchSize)

    # Load ViT-L-16
    config = CONFIGS["ViT-L_16"]
    model = VisionTransformer(config, imgSize, zero_head=True, num_classes=numClasses, vis=vis)
    vitCheckpointPath = "Models/ViT-L_16,cifar10,run0_15K_checkpoint.bin"
    # Map to CPU (as is done for the BiT checkpoint below) so loading does not
    # depend on the device the checkpoint was saved from; load_state_dict
    # copies the values into the model's own parameters either way.
    vitStateDict = torch.load(vitCheckpointPath, map_location="cpu")
    model.load_state_dict(vitStateDict)
    model.eval()

    # Wrap the model in the ModelPlus class
    modelPlusV = ModelPlus("ViT-L_16", model, device, imgSizeH=imgSize, imgSizeW=imgSize, batchSize=batchSize)

    # Load the BiT-M-R101x3
    bitCheckpointPath = "Models/BiT-M-R101x3-Run0.tar"
    modelB = BigTransferModels.KNOWN_MODELS["BiT-M-R101x3"](head_size=numClasses, zero_head=False)
    checkpoint = torch.load(bitCheckpointPath, map_location="cpu")

    # Remove the leading `module.` from the parameter names so they match the
    # bare model. Only keys that actually carry the prefix are rewritten;
    # blindly dropping the first seven characters would corrupt any other key.
    modulePrefix = "module."
    strippedStateDict = OrderedDict(
        (key[len(modulePrefix):] if key.startswith(modulePrefix) else key, value)
        for key, value in checkpoint["model"].items()
    )
    modelB.load_state_dict(strippedStateDict)
    modelB.eval()

    # Wrap the model in the ModelPlus class.
    # Here we hard code the Big Transfer ModelPlus input size to 160x128
    # (what it was trained on).
    modelBig101Plus = ModelPlus("BiT-M-R101x3", modelB, device, imgSizeH=160, imgSizeW=128, batchSize=batchSize)

    # Now time to build the defense (ViT first, then BiT, as before)
    modelPlusList = [modelPlusV, modelBig101Plus]
    defense = ShuffleDefense.ShuffleDefense(modelPlusList, numClasses)
    return valLoader, defense
def LoadViTLAndCIFAR10():
    """Return the CIFAR-10 validation loader and a ``ModelPlus``-wrapped ``MetaCNN``.

    NOTE: despite the function name and the ``"ViT-L_16"`` label passed to
    ``ModelPlus``, the wrapped model is a ``MetaCNN`` built from the local
    ``Arguments`` object; the original ViT-L-16 loading path is disabled.

    Returns:
        A ``(valLoader, modelPlus)`` tuple, where ``valLoader`` is the result
        of ``DMP.GetCIFAR10Validation(32, 128)`` and ``modelPlus`` wraps the
        ``MetaCNN`` instance.
    """

    class Arguments():
        """Attribute bag passed to ``MetaCNN`` as its second argument."""

        def __init__(self):
            self.cuda = "True"
            # Raw strings: the backslashes in these Windows paths are literal
            # path separators, not escape sequences.
            self.cnn = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10cnn.model"
            self.recon = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10recon.model"
            self.detect = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10detect.model"

    args = Arguments()
    metaCNN = MetaCNN(3, args)

    # Basic variable and data setup
    device = torch.device("cuda")
    imgSize = 32
    batchSize = 128

    # Load the CIFAR-10 data
    valLoader = DMP.GetCIFAR10Validation(imgSize, batchSize)

    # NOTE(review): eval() is never called on the MetaCNN here (the disabled
    # ViT path did call it) -- confirm MetaCNN/ModelPlus handle that themselves.
    model = metaCNN

    # Wrap the model in the ModelPlus class
    modelPlus = ModelPlus("ViT-L_16", model, device, imgSizeH=imgSize, imgSizeW=imgSize, batchSize=batchSize)
    return valLoader, modelPlus