Code example #1
#Standard imports; the project-specific modules used below (DMP, CONFIGS,
#VisionTransformer, BigTransferModels, ModelPlus, ShuffleDefense) are assumed to be
#importable from the surrounding repository.
from collections import OrderedDict
import torch

def LoadShuffleDefenseAndCIFAR10(vis=False):
    modelPlusList = []
    #Basic variable and data setup
    device = torch.device("cuda")
    numClasses = 10
    imgSize = 224
    batchSize = 8
    #Load the CIFAR-10 data
    valLoader = DMP.GetCIFAR10Validation(imgSize, batchSize)
    #Load ViT-L-16
    config = CONFIGS["ViT-L_16"]
    model = VisionTransformer(config,
                              imgSize,
                              zero_head=True,
                              num_classes=numClasses,
                              vis=vis)
    dir = "Models/ViT-L_16,cifar10,run0_15K_checkpoint.bin"
    dict = torch.load(dir)
    model.load_state_dict(dict)
    model.eval()
    #Wrap the model in the ModelPlus class
    modelPlusV = ModelPlus("ViT-L_16",
                           model,
                           device,
                           imgSizeH=imgSize,
                           imgSizeW=imgSize,
                           batchSize=batchSize)
    modelPlusList.append(modelPlusV)
    #Load the BiT-M-R101x3
    dirB = "Models/BiT-M-R101x3-Run0.tar"
    modelB = BigTransferModels.KNOWN_MODELS["BiT-M-R101x3"](
        head_size=numClasses, zero_head=False)
    #Get the checkpoint
    checkpoint = torch.load(dirB, map_location="cpu")
    #Strip the "module." prefix (added when the weights were saved from a DataParallel model) so the state dict loads properly
    new_state_dict = OrderedDict()
    for k, v in checkpoint["model"].items():
        name = k[7:]  # drop the leading "module." (7 characters)
        new_state_dict[name] = v
    #Load the dictionary
    modelB.load_state_dict(new_state_dict)
    modelB.eval()
    #Wrap the model in the ModelPlus class
    #Here we hard-code the Big Transfer ModelPlus input size to 160x128 (the resolution it was trained on)
    modelBig101Plus = ModelPlus("BiT-M-R101x3",
                                modelB,
                                device,
                                imgSizeH=160,
                                imgSizeW=128,
                                batchSize=batchSize)
    modelPlusList.append(modelBig101Plus)
    #Now time to build the defense
    defense = ShuffleDefense.ShuffleDefense(modelPlusList, numClasses)
    return valLoader, defense
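
The loader above can be exercised with a short driver. The sketch below is a minimal example under stated assumptions: the predictD(dataLoader, numSamples) call is a hypothetical batch-prediction entry point on the defense (the actual ShuffleDefense interface may differ), and the ground-truth labels are re-collected from the loader in yield order.

#Minimal usage sketch for code example #1 (defense.predictD is a hypothetical interface)
import torch

def CheckCleanAccuracy():
    valLoader, defense = LoadShuffleDefenseAndCIFAR10()
    numSamples = len(valLoader.dataset)
    #Hypothetical batch-prediction call: one logit row per validation sample
    yPred = defense.predictD(valLoader, numSamples)
    #Gather the ground-truth labels in the same order the loader yields them
    yTrue = torch.cat([labels for _, labels in valLoader])
    #Clean accuracy = fraction of samples whose top-1 prediction matches the label
    acc = (yPred.argmax(dim=1).cpu() == yTrue).float().mean().item()
    print("Clean validation accuracy:", acc)

if __name__ == "__main__":
    CheckCleanAccuracy()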
Code example #2
import torch

#The project-specific modules used below (MetaCNN, DMP, CONFIGS, ModelPlus) are
#assumed to be importable from the surrounding repository.
def LoadViTLAndCIFAR10():
    class Arguments:
        def __init__(self):
            self.cuda = "True"
            #Raw strings keep the Windows backslashes from being read as escape sequences
            self.cnn = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10cnn.model"
            self.recon = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10recon.model"
            self.detect = r"E:\Projects\CPSC-597\AdversarialDetection\State\cifar10detect.model"

    args = Arguments()
    metaCNN = MetaCNN(3, args)
    # metaCNN = ResNet18()
    # stateDict = torch.load( "E:\Projects\CPSC-597\AdversarialDetection\State\cifar10resnet.model" )
    # metaCNN.load_state_dict( stateDict['model'] )

    #Basic variable and data setup
    device = torch.device("cuda")
    numClasses = 10
    imgSize = 32
    batchSize = 128
    #Load the CIFAR-10 data
    valLoader = DMP.GetCIFAR10Validation(imgSize, batchSize)
    #ViT-L/16 configuration (kept for reference; the ViT itself is not loaded in this variant)
    config = CONFIGS["ViT-L_16"]
    #model = VisionTransformer(config, imgSize, zero_head=True, num_classes=numClasses)
    #dir = "Models/ViT-L_16,cifar10,run0_15K_checkpoint.bin"
    #dict = torch.load(dir)
    #model.load_state_dict(dict)
    #model.eval()
    #Use the MetaCNN detector in place of the ViT and put it in evaluation mode
    model = metaCNN
    model.eval()

    #Wrap the model in the ModelPlus class (the "ViT-L_16" name string is kept from the original loader even though the wrapped model is the MetaCNN)
    modelPlus = ModelPlus("ViT-L_16",
                          model,
                          device,
                          imgSizeH=imgSize,
                          imgSizeW=imgSize,
                          batchSize=batchSize)
    return valLoader, modelPlus
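
As with the first example, a short driver shows how the returned pair might be used. The validateD(dataLoader) call below is a hypothetical ModelPlus convenience method that returns clean accuracy; substitute whatever validation helper the wrapper actually exposes.

#Minimal usage sketch for code example #2 (modelPlus.validateD is a hypothetical helper)
if __name__ == "__main__":
    valLoader, modelPlus = LoadViTLAndCIFAR10()
    #Hypothetical call: evaluates the wrapped MetaCNN on the CIFAR-10 validation loader
    cleanAcc = modelPlus.validateD(valLoader)
    print("Clean accuracy of the wrapped MetaCNN:", cleanAcc)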