def train_detector(gpu, arch, adv_data_arch, img_classifier_model_path, dataset, args):
    """Train a rotation-based detector on adversarial images and save the checkpoint.

    The classifier backbone comes from ``build_network`` and is wrapped in a
    ``Detector`` together with an image transform. ``train_epochs`` then
    trains it with SGD, writing to a checkpoint path derived from the
    hyper-parameters.

    Args:
        gpu: GPU index; exported as ``CUDA_VISIBLE_DEVICES`` for this process.
        arch: backbone architecture name passed to ``build_network`` and
            ``train_epochs``.
        adv_data_arch: selects the ``adversarial_images/<adv_data_arch>``
            training-data folder.
        img_classifier_model_path: classifier checkpoint handed to ``build_network``.
        dataset: dataset name; key into ``IMAGE_DATA_ROOT`` and ``CLASS_NUM``.
        args: parsed CLI namespace. Reads ``use_cv_transform``, ``dataset``,
            ``protocol``, ``epochs``, ``lr``, ``batch_size``, ``momentum``,
            ``weight_decay``, ``workers`` and ``gpu``. ``args.balance`` is
            overwritten here as a side effect.
    """
    print("using GPU {}".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    img_classifier_network = build_network(dataset, arch, img_classifier_model_path)
    if dataset in ["ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"]:
        layer_number = 3
    else:
        layer_number = 2

    # The transform flavour also decides the checkpoint sub-directory.
    # NOTE(review): the checkpoint name is formatted with ``args.dataset`` while
    # everything else uses the ``dataset`` argument -- confirm they always agree.
    if args.use_cv_transform:
        image_transform = ImageTransformCV2(dataset, [1, 2])
        path_template = '{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_{}@lr_{}@batch_{}.pth.tar'
    else:
        image_transform = ImageTransformTorch(dataset, [5, 15])
        path_template = '{}/train_pytorch_model/ROTATE_DET/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_{}@lr_{}@batch_{}.pth.tar'
    detector_model_path = path_template.format(
        PY_ROOT, args.dataset, args.protocol, arch, adv_data_arch,
        args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(detector_model_path), exist_ok=True)

    detector = Detector(dataset, img_classifier_network, CLASS_NUM[dataset],
                        image_transform, layer_number)
    detector.cuda()
    optimizer = torch.optim.SGD(detector.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True

    # Only the TRAIN_ALL_TEST_ALL protocol trains on a balanced dataset.
    args.balance = bool(args.protocol == SPLIT_DATA_PROTOCOL.TRAIN_ALL_TEST_ALL)

    adv_image_root = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/{}".format(adv_data_arch)
    if dataset == "ImageNet":
        train_dataset = AdversaryRandomAccessNpyDataset(
            adv_image_root, True, args.protocol,
            config.META_ATTACKER_PART_I, config.META_ATTACKER_PART_II,
            args.balance, dataset)
    else:
        train_dataset = AdversaryDataset(
            adv_image_root, True, args.protocol,
            config.META_ATTACKER_PART_I, config.META_ATTACKER_PART_II,
            balance=args.balance)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    # NOTE(review): ``args.gpu`` (not the ``gpu`` argument) is forwarded here --
    # confirm this is intended when the two can differ.
    train_epochs(detector_model_path, train_loader, detector, optimizer, arch, args, args.gpu)
def evaluate_speed(args):
    """Run ``speed_test`` for one CIFAR-10 CV2-rotation detector checkpoint.

    The function scans ``{PY_ROOT}/train_pytorch_model/ROTATE_DET/cv2_rotate_model``
    for ``IMG_ROTATE_DET*`` files. It takes the first CIFAR-10 checkpoint whose
    split protocol equals ``args.protocol`` and evaluates it for the 0-, 1- and
    5-shot settings. Results are written to ``speed_test.json`` in the same
    directory, keyed by dataset and then by shot.

    The 0-shot setting is run as a 1-shot task with zero fine-tuning updates;
    the 1- and 5-shot settings use ``args.num_updates`` updates.

    Args:
        args: parsed CLI namespace; reads ``gpus``, ``protocol`` and ``num_updates``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Expected file name, e.g.
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    # Raw string: "\d" and "\." are invalid escape sequences in a normal literal.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Names without the trailing "@<suffix>" segment (such as the ones
            # written by train_detector) do not fit this pattern; skip them
            # instead of crashing on ``ma.group``.
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        lr = float(ma.group(5))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # Read the checkpoint once; load_state_dict copies it into each fresh model.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        for shot in [0, 1, 5]:
            report_shot = shot
            if shot == 0:
                # 0-shot: build 1-shot tasks but perform no fine-tuning update.
                shot = 1
                num_updates = 0
            else:
                num_updates = args.num_updates
            # FIXME: adv_arch is hard-coded to "conv3"; other architectures are not supported yet.
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query, dataset,
                                                is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch="conv3")
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = speed_test(model, data_loader, lr, num_updates, update_BN=False)
            result[dataset][report_shot] = evaluate_result
        # Only the first matching checkpoint is evaluated.
        break
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_cross_arch(args):
    """Evaluate CV2-rotation detectors across adversarial-data architectures.

    Scans ``{PY_ROOT}/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` for
    checkpoints matching both of these:

    - split protocol equal to ``args.protocol``;
    - training-data architecture (the ``data_`` field) equal to
      ``args.cross_arch_source``.

    Each such checkpoint is evaluated with ``finetune_eval_task_accuracy`` on
    meta-test tasks built with ``adv_arch=args.cross_arch_target``, for shots
    0, 1 and 5. Results are written to one JSON file per
    (source, target, protocol) combination.

    The 0-shot setting is run as a 1-shot task with zero updates. Any run
    performed with zero updates is reported under shot key 0. With
    ``args.num_updates == 0`` every shot therefore lands on key 0, and the last
    run wins; this is the original convention.

    Args:
        args: parsed CLI namespace; reads ``gpus``, ``protocol``,
            ``cross_arch_source``, ``cross_arch_target``, ``num_updates`` and
            ``load_mode``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Expected file name, e.g.
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100.pth.tar
    # Raw string: "\d" and "\." are invalid escape sequences in a normal literal.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar")
    result = defaultdict(dict)
    # FIXME (original note): results are very high with update_BN=False.
    update_BN = False
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    old_num_update = args.num_updates
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Other naming schemes in this directory (e.g. with an extra
            # "@<suffix>" segment) do not fit; skip instead of crashing.
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue
        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # Read the checkpoint once; load_state_dict copies it into each fresh model.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        for shot in [0, 1, 5]:
            if shot == 0:
                # 0-shot: build 1-shot tasks but perform no fine-tuning update.
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query, dataset,
                                                is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update,
                                                          update_BN=update_BN)
            if num_update == 0:
                # Zero-update runs are reported as 0-shot.
                shot = 0
            result["{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)][shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(
            PY_ROOT, args.cross_arch_source, args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def build_rotate_detector(dataset, arch, adv_arch, protocol):
    """Load a white-box rotation detector checkpoint and return it on the GPU.

    The checkpoint is looked up under
    ``{PY_ROOT}/train_pytorch_model/white_box_model`` with the fixed
    training suffix ``epoch_10@lr_0.0001@batch_100@no_fix_cnn_params``. The
    backbone is always a ``Conv3`` network, and the image transform is
    ``ImageTransformTorch(dataset, [5, 15])``. Weights are loaded with
    ``strict=True``.

    Args:
        dataset: dataset name; key into ``IN_CHANNELS``, ``IMAGE_SIZE`` and ``CLASS_NUM``.
        arch: value placed in the ``model_`` field of the checkpoint name.
        adv_arch: value placed in the ``data_`` field of the checkpoint name.
        protocol: split protocol placed in the checkpoint name.

    Returns:
        The ``Detector`` with the checkpoint weights loaded, moved to CUDA.
    """
    backbone = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
    checkpoint_file = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    # Map every tensor onto the storage it was saved with (CPU-safe load).
    saved = torch.load(checkpoint_file, map_location=lambda storage, location: storage)
    print("load {} to detector".format(checkpoint_file))

    transform = ImageTransformTorch(dataset, [5, 15])
    num_layers = 2
    if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"]:
        num_layers = 3

    rotate_detector = Detector(dataset, backbone, CLASS_NUM[dataset], transform, num_layers)
    rotate_detector.load_state_dict(saved['state_dict'], strict=True)
    rotate_detector.cuda()
    return rotate_detector
def evaluate_whitebox_attack(args):
    """Evaluate the white-box rotation detector on white-box adversarial images.

    The detector checkpoint is selected by ``args.dataset``, ``args.protocol``
    and ``args.adv_arch``, with a ``conv3`` backbone. The function evaluates it
    for each attack in ``["FGSM", "CW_L2"]`` and each shot in ``[1, 5]``, using
    ``finetune_eval_task_rotate`` on a ``WhiteBoxMetaTaskDataset``. Results are
    written as JSON under ``{PY_ROOT}/train_pytorch_model/white_box_model``.

    Runs performed with zero fine-tuning updates (``args.num_updates == 0``)
    are reported under shot key 0, which is the original 0-shot convention.
    The last such run for an attack overwrites earlier ones.

    Args:
        args: parsed CLI namespace; reads ``gpus``, ``dataset``, ``protocol``,
            ``adv_arch``, ``num_updates``, ``load_mode`` and ``eval_update_BN``.

    Raises:
        FileNotFoundError: if the detector checkpoint does not exist.
        ValueError: if the checkpoint path does not fit the expected name pattern.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    attacks = ["FGSM", "CW_L2"]
    all_shots = [1, 5]
    # Detector name used in the data folder and in the result keys.
    detector = "RotateDet"
    # Expected file name, e.g.
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar
    # Raw string: "\d" and "\." are invalid escape sequences in a normal literal.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    model_path = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, args.dataset, args.protocol, "conv3", args.adv_arch)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not os.path.exists(model_path):
        raise FileNotFoundError("{} not exists".format(model_path))
    ma = extract_pattern_detail.match(model_path)
    if ma is None:
        raise ValueError("cannot parse model path {}".format(model_path))
    dataset = ma.group(1)
    lr = float(ma.group(5))
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    old_num_update = args.num_updates
    # Read the checkpoint once; load_state_dict copies it into each fresh model.
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    for attack_name in attacks:
        root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(
            args.adv_arch, detector, attack_name)
        for shot in all_shots:
            if shot == 0:
                # 0-shot: build 1-shot tasks but perform no fine-tuning update.
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = WhiteBoxMetaTaskDataset(tot_num_tasks, num_classes, shot, num_query, dataset,
                                                        args.load_mode, detector, attack_name, root_folder)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformTorch(dataset, [5, 15])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_rotate(model, data_loader, lr, num_update,
                                                        update_BN=args.eval_update_BN)
            if num_update == 0:
                # Zero-update runs are reported as 0-shot.
                shot = 0
            result["{}_{}_{}_{}".format(dataset, attack_name, detector, args.adv_arch)][shot] = evaluate_result
    update_BN_str = "UpdateBN" if args.eval_update_BN else "NoUpdateBN"
    with open("{}/train_pytorch_model/white_box_model/white_box_RotateDet_{}_{}_using_{}__result.json".format(
            PY_ROOT, args.dataset, update_BN_str, args.protocol), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()