def AdaptiveAttackShuffleDefense():
    """Run the adaptive black-box attack against the shuffle defense on CIFAR-10.

    Sets up the attack/training hyper-parameters, loads the defense (used as
    the query oracle) and the CIFAR-10 loaders, builds a synthetic
    VisionTransformer (config "ViT-B_32", weights loaded from
    ``Models//imagenet21k_ViT-B_32.npz``), and hands everything to
    ``AttackWrappersAdaptiveBlackBox.AdaptiveAttack``, which performs the
    attack itself. Returns nothing.
    """
    # Tag used when saving files. First part: type of defense; second part:
    # synthetic model; last part: strength of the attack (100%).
    saveTag = "ViT-L-16, ViT-32(ImageNet21K), p100"
    device = torch.device("cuda")

    # Attack parameters.
    numAttackSamples = 1000
    epsForAttacks = 0.031
    clipMin = 0.0
    clipMax = 1.0

    # Parameters for training the synthetic model.
    imgSize = 224
    batchSize = 32
    numClasses = 10
    numIterations = 4
    epochsPerIteration = 10
    # eps of the FGSM step used when generating synthetic data.
    epsForAug = 0.1
    # Learning rate of the synthetic model.
    learningRate = (3e-2) / 2

    # Load the validation dataset and the defense, plus the training dataset.
    valLoader, defense = LoadShuffleDefenseAndCIFAR10()
    trainLoader = DMP.GetCIFAR10Training(imgSize, batchSize)

    # Clean samples to attack: a class-balanced set of numAttackSamples
    # drawn from valLoader (per the helper's name, ones the defense
    # classifies correctly — confirm in DMP).
    # NOTE: the whole validation set is no longer converted to tensors
    # here; those tensors were never used by this function.
    cleanLoader = DMP.GetCorrectlyIdentifiedSamplesBalancedDefense(
        defense, numAttackSamples, valLoader, numClasses)

    # Create the synthetic model.
    syntheticDir = "Models//imagenet21k_ViT-B_32.npz"
    config = CONFIGS["ViT-B_32"]
    # NOTE(review): zero_head=True presumably re-initializes the
    # classification head for numClasses outputs — confirm in VisionTransformer.
    syntheticModel = VisionTransformer(config, imgSize, zero_head=True,
                                       num_classes=numClasses)
    syntheticModel.load_from(numpy.load(syntheticDir))
    syntheticModel.to(device)

    # Do the attack: the defense is queried as the oracle while the synthetic
    # model is trained on data drawn from trainLoader.
    oracle = defense
    dataLoaderForTraining = trainLoader
    optimizerName = "sgd"
    AttackWrappersAdaptiveBlackBox.AdaptiveAttack(
        saveTag, device, oracle, syntheticModel, numIterations,
        epochsPerIteration, epsForAug, learningRate, optimizerName,
        dataLoaderForTraining, cleanLoader, numClasses, epsForAttacks,
        clipMin, clipMax)
def RaySAttackShuffleDefense():
    """Attack the shuffle defense on CIFAR-10 with RayS and print the results.

    Selects 1000 class-balanced samples from the validation loader (via
    ``DMP.GetCorrectlyIdentifiedSamplesBalancedDefense``) and runs
    ``AttackWrappersRayS.RaySAttack`` on them with ``epsMax = 0.031`` and a
    query limit of 10000. It then prints the query limit, the defense's
    accuracy on the adversarial loader, and its accuracy on the full
    validation loader. Returns nothing.
    """
    # Experiment constants.
    numClasses = 10
    attackSampleNum = 1000
    epsMax = 0.031
    queryLimit = 10000

    # Defense plus the CIFAR-10 validation data.
    validationLoader, shuffleDefense = LoadShuffleDefenseAndCIFAR10()

    # Class-balanced subset of validation samples to attack.
    samplesToAttack = DMP.GetCorrectlyIdentifiedSamplesBalancedDefense(
        shuffleDefense, attackSampleNum, validationLoader, numClasses)

    # Run RayS against the defense.
    adversarialLoader = AttackWrappersRayS.RaySAttack(
        shuffleDefense, epsMax, queryLimit, samplesToAttack)

    # Tuple items are evaluated left to right: robust accuracy first, then
    # clean accuracy.
    # NOTE(review): "Queries used:" reports the configured limit, not a
    # count returned by RaySAttack.
    summary = (
        ("Queries used:", queryLimit),
        ("Robust acc:", shuffleDefense.validateD(adversarialLoader)),
        ("Clean acc:", shuffleDefense.validateD(validationLoader)),
    )
    for label, value in summary:
        print(label, value)