示例#1
0
def AdaptiveAttackShuffleDefense():
    """Run the adaptive black-box attack against the shuffle defense on CIFAR-10.

    A ViT-B/32 synthetic model is built and initialised from the
    ImageNet-21K checkpoint at ``Models//imagenet21k_ViT-B_32.npz``. It is
    then handed, together with the defense (used as the oracle), the CIFAR-10
    training loader and a set of clean attack samples, to
    ``AttackWrappersAdaptiveBlackBox.AdaptiveAttack``, which performs the
    attack. Results are tagged with ``saveTag`` when saved.

    The clean samples come from
    ``DMP.GetCorrectlyIdentifiedSamplesBalancedDefense``. Going by its name,
    they are class-balanced samples the defense classifies correctly
    (TODO confirm).

    Returns:
        None. All output is produced by ``AdaptiveAttack``.
    """
    # Tag used when saving result files. The first part indicates the type of
    # defense, the second part the synthetic model, and the last part the
    # strength of the attack (100%).
    saveTag = "ViT-L-16, ViT-32(ImageNet21K), p100"
    device = torch.device("cuda")
    # Attack parameters.
    numAttackSamples = 1000
    epsForAttacks = 0.031  # roughly 8/255 for inputs in [clipMin, clipMax]
    clipMin = 0.0
    clipMax = 1.0
    # Parameters for training the synthetic model.
    imgSize = 224
    batchSize = 32
    numClasses = 10
    numIterations = 4
    epochsPerIteration = 10
    # FGSM eps used when generating synthetic (augmented) training data.
    epsForAug = 0.1
    learningRate = (3e-2) / 2  # learning rate of the synthetic model
    # Load the validation loader and the defense, plus the training loader.
    valLoader, defense = LoadShuffleDefenseAndCIFAR10()
    trainLoader = DMP.GetCIFAR10Training(imgSize, batchSize)
    # Select the clean samples that will be attacked.
    cleanLoader = DMP.GetCorrectlyIdentifiedSamplesBalancedDefense(
        defense, numAttackSamples, valLoader, numClasses)
    # Create the synthetic model from the ImageNet-21K ViT-B/32 checkpoint.
    syntheticDir = "Models//imagenet21k_ViT-B_32.npz"
    config = CONFIGS["ViT-B_32"]
    syntheticModel = VisionTransformer(config,
                                       imgSize,
                                       zero_head=True,
                                       num_classes=numClasses)
    # numpy.load on an .npz returns an NpzFile that keeps the archive open.
    # The context manager closes it once the weights have been read.
    with numpy.load(syntheticDir) as pretrainedWeights:
        syntheticModel.load_from(pretrainedWeights)
    syntheticModel.to(device)
    # The defense serves as the oracle; the CIFAR-10 training loader
    # supplies the data used to train the synthetic model.
    oracle = defense
    dataLoaderForTraining = trainLoader
    optimizerName = "sgd"
    # This call performs the attack.
    AttackWrappersAdaptiveBlackBox.AdaptiveAttack(
        saveTag, device, oracle, syntheticModel, numIterations,
        epochsPerIteration, epsForAug, learningRate, optimizerName,
        dataLoaderForTraining, cleanLoader, numClasses, epsForAttacks, clipMin,
        clipMax)
示例#2
0
def RaySAttackShuffleDefense():
    """Run the RayS attack against the shuffle defense on CIFAR-10 and report accuracy.

    Steps:
        1. Load the validation loader and the defense.
        2. Select 1000 clean samples via
           ``DMP.GetCorrectlyIdentifiedSamplesBalancedDefense``.
        3. Attack them with ``AttackWrappersRayS.RaySAttack``, using eps 0.031
           and a budget of 10000 queries.
        4. Print the configured query budget, the defense's accuracy on the
           adversarial loader, and its accuracy on the full validation loader.

    Returns:
        None. Results are printed.
    """
    # Attack configuration.
    class_count = 10
    sample_count = 1000
    max_perturbation = 0.031
    query_budget = 10000

    validation_loader, shuffle_defense = LoadShuffleDefenseAndCIFAR10()
    clean_samples = DMP.GetCorrectlyIdentifiedSamplesBalancedDefense(
        shuffle_defense, sample_count, validation_loader, class_count)

    # Run RayS on the clean samples.
    adversarial_loader = AttackWrappersRayS.RaySAttack(
        shuffle_defense, max_perturbation, query_budget, clean_samples)

    # Robust accuracy is evaluated before clean accuracy.
    robust_accuracy = shuffle_defense.validateD(adversarial_loader)
    clean_accuracy = shuffle_defense.validateD(validation_loader)

    # NOTE(review): "Queries used:" reports the configured query budget,
    # not a count measured from the attack.
    report = (
        ("Queries used:", query_budget),
        ("Robust acc:", robust_accuracy),
        ("Clean acc:", clean_accuracy),
    )
    for label, value in report:
        print(label, value)