def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size, inner_step_size, lr_decay_itr,
             epoch, num_inner_updates, load_task_mode, protocol, arch, tot_num_tasks, num_support, num_query,
             no_random_way, tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
    super(self.__class__, self).__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.meta_batch_size = meta_batch_size  # number of tasks per meta-batch
    self.meta_step_size = meta_step_size
    self.inner_step_size = inner_step_size
    self.lr_decay_itr = lr_decay_itr
    self.epoch = epoch
    self.num_inner_updates = num_inner_updates
    self.test_finetune_updates = num_inner_updates
    # Build the network
    if arch == "conv3":
        # network = FourConvs(IN_CHANNELS[self.dataset_name], IMAGE_SIZE[self.dataset_name], num_classes)
        network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
    elif arch == "resnet10":
        network = MetaNetwork(resnet10(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                              IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
    elif arch == "resnet18":
        network = MetaNetwork(resnet18(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                              IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
    self.network = network
    self.network.cuda()
    if train:
        trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query, dataset,
                                      is_train=True, load_mode=load_task_mode, protocol=protocol,
                                      no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
        # the number of tasks per mini-batch is controlled by the DataLoader
        self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True,
                                       num_workers=4, pin_memory=True)
        self.tensorboard = TensorBoardWriter("{0}/pytorch_MAML_tensorboard".format(PY_ROOT),
                                             tensorboard_data_prefix)
        os.makedirs("{0}/pytorch_MAML_tensorboard".format(PY_ROOT), exist_ok=True)
    if need_val:
        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15, dataset,
                                      is_train=False, load_mode=load_task_mode, protocol=protocol,
                                      no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
        # a fixed set of 100 tasks; accuracy is reported per task
        self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)
    # InnerLoop performs the inner-loop adaptation for each task of a meta-batch in parallel
    self.fast_net = InnerLoop(self.network, self.num_inner_updates, self.inner_step_size, self.meta_batch_size)
    self.fast_net.cuda()
    self.opt = Adam(self.network.parameters(), lr=meta_step_size)
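# A minimal sketch (assumption, not part of the original class) of how lr_decay_itr could be applied
# to the meta optimizer: a simple step decay of the Adam learning rate every lr_decay_itr iterations.
# The actual schedule used during meta-training lives elsewhere and may differ.
def adjust_meta_step_size(opt, meta_step_size, lr_decay_itr, iteration, decay_factor=0.1):
    # step decay: scale the base meta step size by decay_factor once per lr_decay_itr iterations
    if lr_decay_itr > 0:
        new_lr = meta_step_size * (decay_factor ** (iteration // lr_decay_itr))
        for param_group in opt.param_groups:
            param_group['lr'] = new_lr
        return new_lr
    return meta_step_size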
def evaluate_speed(args):
    # for 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # e.g. IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        epoch = int(ma.group(4))
        lr = float(ma.group(5))
        batch_size = int(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        all_shots = [0, 1, 5]
        for shot in all_shots:
            report_shot = shot
            if shot == 0:
                num_updates = 0
                shot = 1
            else:
                num_updates = args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query, dataset,
                                                is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch="conv3")  # FIXME: other adversarial-example architectures are not supported yet
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = speed_test(model, data_loader, lr, num_updates, update_BN=False)
            result[dataset][report_shot] = evaluate_result
        break  # only the first matching model is timed
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_cross_arch(args):
    # for 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # e.g. IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar")
    result = defaultdict(dict)
    update_BN = False
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue
        epoch = int(ma.group(5))
        lr = float(ma.group(6))
        batch_size = int(ma.group(7))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        old_num_update = args.num_updates
        for shot in [0, 1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query, dataset,
                                                is_train=False, load_mode=args.load_mode, protocol=split_protocol,
                                                no_random_way=True, adv_arch=args.cross_arch_target,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            # FIXME: with update_BN=False the reported accuracy comes out suspiciously high
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)
            if num_update == 0:
                shot = 0
            result["{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)][shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(
            PY_ROOT, args.cross_arch_source, args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_finetune(model_path_list, lr, protocol):
    # The DL baseline is trained on the all_in / sampled-all_in data, but evaluation must be run on the task-based dataset.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        arch = ma.group(3)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        shot = 1
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD, protocol=protocol,
                                            no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        for num_update in range(0, 51):
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            shot = 0 if num_update == 0 else 1
            evaluate_result["shot"] = shot
            result["{}_{}".format(dataset, balance)][num_update] = evaluate_result
    return result
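# Usage sketch (assumption, not part of the original file): evaluate_finetune() only returns the
# result dict, so the caller is expected to dump it to JSON the way the other evaluators do.
# The glob pattern and output file name below are hypothetical placeholders.
def dump_finetune_eval_results(lr, protocol):
    import glob
    import json
    model_path_list = glob.glob("{}/train_pytorch_model/DL_DET/DL_DET@*".format(PY_ROOT))  # hypothetical location
    result = evaluate_finetune(model_path_list, lr, protocol)
    output_path = "{}/train_pytorch_model/DL_DET/finetune_eval_result.json".format(PY_ROOT)  # hypothetical file name
    with open(output_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
    return output_path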
def main():
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # e.g. DL_IMAGE_CLASSIFIER_CIFAR-10@conv4@epoch_40@lr_0.0001@batch_500.pth.tar
    extract_pattern_detail = re.compile(
        r".*?DL_IMAGE_CLASSIFIER_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar")
    # e.g. test_CIFAR-10_tot_num_tasks_20000_metabatch_10_way_5_shot_5_query_15.txt
    extract_dataset_pattern = re.compile(
        r".*?tot_num_tasks_(\d+)_metabatch_(\d+)_way_(\d+)_shot_(\d+)_query_(\d+).*")
    result = {}
    for model_path in glob.glob("{}/train_pytorch_model/DL_IMAGE_CLASSIFIER*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        arch = ma.group(2)
        epoch = int(ma.group(3))
        lr = float(ma.group(4))
        batch = int(ma.group(5))
        if arch == "conv4":
            model = FourConvs(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
        elif arch == "resnet10":
            model = resnet10(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        elif arch == "resnet18":
            model = resnet18(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint["state_dict"])
        model.cuda()
        print("loading {}".format(model_path))
        detector = DetectionEvaluator(model, dataset)
        for pkl_task_file_name in glob.glob("{}/task/{}/test_{}*.pkl".format(PY_ROOT, args.split_data_protocol, dataset)):
            preprocessor = get_preprocessor(input_size=IMAGE_SIZE[dataset], input_channels=IN_CHANNELS[dataset])
            if dataset == "CIFAR-10":
                train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor)
            elif dataset == "MNIST":
                train_dataset = MNIST(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor, download=True)
            elif dataset == "F-MNIST":
                train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor, download=True)
            elif dataset == "SVHN":
                train_dataset = SVHN(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor)
            ma_d = extract_dataset_pattern.match(pkl_task_file_name)
            tot_num_tasks = int(ma_d.group(1))
            num_classes = int(ma_d.group(3))
            num_support = int(ma_d.group(4))
            num_query = int(ma_d.group(5))
            meta_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query, dataset,
                                           is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                           pkl_task_dump_path=pkl_task_file_name,
                                           protocol=args.split_data_protocol)
            val_loader = DataLoader(meta_dataset, batch_size=args.batch_size, shuffle=False)
            # train_imgs = get_train_data(train_dataset)
            train_imgs = []
            accuracy = detector.evaluate_detections(train_imgs, val_loader)
            key1 = os.path.basename(model_path)
            key1 = key1[:key1.rindex(".")]
            key = os.path.basename(pkl_task_file_name)
            key = key[:key.rindex(".")]
            result["{}|{}".format(key1, key)] = accuracy
    with open(args.output_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
def main():
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    transform = get_preprocessor(IN_CHANNELS[args.ds_name], IMAGE_SIZE[args.ds_name])
    # The NeuralFP MNIST and F-MNIST experiments need to be rerun because of a single-channel bug.
    kwargs = {'pin_memory': True}
    if not args.evaluate:  # training mode
        if args.ds_name == "MNIST":
            trn_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False, transform=transform)
            val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False, transform=transform)
        elif args.ds_name == "F-MNIST":
            trn_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False, transform=transform)
            val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False, transform=transform)
        elif args.ds_name == "CIFAR-10":
            trn_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False, transform=transform)
            val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False, transform=transform)
        elif args.ds_name == "SVHN":
            trn_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name], train=True, transform=transform)
            val_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name], train=False, transform=transform)
        elif args.ds_name == "ImageNet":
            trn_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] + "/new2", train=True, transform=transform)
            val_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] + "/new2", train=False, transform=transform)
        train_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=args.workers, **kwargs)
        test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False,
                                                  num_workers=0, **kwargs)
        if args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.ds_name], IMAGE_SIZE[args.ds_name], CLASS_NUM[args.ds_name])
        elif args.arch == "resnet10":
            network = resnet10(in_channels=IN_CHANNELS[args.ds_name], num_classes=CLASS_NUM[args.ds_name])
        elif args.arch == "resnet18":
            network = resnet18(in_channels=IN_CHANNELS[args.ds_name], num_classes=CLASS_NUM[args.ds_name])
        network.cuda()
        model_path = os.path.join(PY_ROOT, "train_pytorch_model/NF_Det",
                                  "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar".format(
                                      args.ds_name, args.arch, args.epochs, args.lr, args.eps, args.num_dx,
                                      CLASS_NUM[args.ds_name]))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        detector = NeuralFingerprintDetector(args.ds_name, network, args.num_dx, CLASS_NUM[args.ds_name],
                                             eps=args.eps, out_fp_dxdy_dir=args.output_dx_dy_dir)
        optimizer = optim.SGD(network.parameters(), lr=args.lr, weight_decay=1e-6, momentum=0.9)
        resume_epoch = 0
        print("{}".format(model_path))
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, lambda storage, location: storage)
            optimizer.load_state_dict(checkpoint["optimizer"])
            resume_epoch = checkpoint["epoch"]
            network.load_state_dict(checkpoint["state_dict"])
        for epoch in range(resume_epoch, args.epochs + 1):
            if epoch == 1:
                detector.test(epoch, test_loader, test_length=0.1 * len(val_dataset))
            detector.train(epoch, optimizer, train_loader)
            print("Epoch {}, saving model to {}".format(epoch, model_path))
            torch.save({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_path)
    else:  # evaluation mode
        if args.study_subject == "speed_test":
            evaluate_speed(args)
            return
        extract_pattern = re.compile(
            r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar")
        results = defaultdict(dict)
        for model_path in glob.glob("{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
            ma = extract_pattern.match(model_path)
            ds_name = ma.group(1)
            arch = ma.group(2)
            epoch = int(ma.group(3))
            num_dx = int(ma.group(6))
            eps = float(ma.group(5))
            if arch == "conv3":
                network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name], CLASS_NUM[ds_name])
            elif arch == "resnet10":
                network = resnet10(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
            elif arch == "resnet18":
                network = resnet18(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
            reject_thresholds = [0. + 0.001 * i for i in range(2050)]
            network.load_state_dict(torch.load(model_path, lambda storage, location: storage)["state_dict"])
            network.cuda()
            print("load {} over".format(model_path))
            # fingerprints are tied to this trained network, so a cross-architecture setting does not apply here
            detector = NeuralFingerprintDetector(ds_name, network, num_dx, CLASS_NUM[ds_name], eps=eps,
                                                 out_fp_dxdy_dir=args.output_dx_dy_dir)
            if args.study_subject == "shots":
                all_shots = [0, 1, 5]
                old_updates = args.num_updates
                for shot in all_shots:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_updates
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                                  load_mode=args.load_task_mode, protocol=args.protocol,
                                                  no_random_way=True, adv_arch=args.adv_arch,
                                                  fetch_attack_name=True)
                    adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)
                    results[ds_name][report_shot] = {"F1": F1, "best_tau": tau, "eps": eps, "num_dx": num_dx,
                                                     "num_updates": args.num_updates,
                                                     "attack_stats": attacker_stats}
                    print("shot {} done".format(shot))
            elif args.study_subject == "cross_domain":
                source_dataset, target_dataset = args.cross_domain_source, args.cross_domain_target
                if ds_name != source_dataset:
                    continue
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, target_dataset, is_train=False,
                                                  load_mode=args.load_task_mode, protocol=args.protocol,
                                                  no_random_way=True, adv_arch=args.adv_arch,
                                                  fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, target_dataset, reject_thresholds, args.num_updates, args.lr)
                    results["{}--{}@data_adv_arch_{}".format(source_dataset, target_dataset, args.adv_arch)][report_shot] = {
                        "F1": F1, "best_tau": tau, "eps": eps, "num_dx": num_dx,
                        "num_updates": args.num_updates, "attack_stats": attacker_stats}
            elif args.study_subject == "cross_arch":
                target_arch = args.cross_arch_target
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                                  load_mode=args.load_task_mode, protocol=args.protocol,
                                                  no_random_way=True, adv_arch=target_arch,
                                                  fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)
                    results["{}_target_arch_{}".format(ds_name, target_arch)][report_shot] = {
                        "F1": F1, "best_tau": tau, "eps": eps, "num_dx": num_dx,
                        "num_updates": args.num_updates, "attack_stats": attacker_stats}
            elif args.study_subject == "finetune_eval":
                shot = 1
                old_updates = args.num_updates
                num_way = 2
                num_query = 15
                if ds_name != args.ds_name:
                    continue
                val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                              load_mode=args.load_task_mode, protocol=args.protocol,
                                              no_random_way=True, adv_arch=args.adv_arch)
                adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, **kwargs)
                args.num_updates = 50
                for num_update in range(0, 51):
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, num_update, args.lr)
                    results[ds_name][num_update] = {"F1": F1, "best_tau": tau, "eps": eps, "num_dx": num_dx,
                                                    "num_updates": num_update, "attack_stats": attacker_stats}
                print("finetune {} done".format(shot))
        if not args.profile:
            if args.study_subject == "cross_domain":
                filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
                    PY_ROOT, args.cross_domain_source, args.cross_domain_target, args.adv_arch)
            elif args.study_subject == "cross_arch":
                filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
                    PY_ROOT, args.cross_arch_target)
            else:
                filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
                    PY_ROOT, args.study_subject, args.adv_arch, args.protocol, args.lr, args.num_updates)
            with open(filename, "w") as file_obj:
                file_obj.write(json.dumps(results))
                file_obj.flush()
def evaluate_speed(args):
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar")
    results = defaultdict(dict)
    for model_path in glob.glob("{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        ds_name = ma.group(1)
        if ds_name != "CIFAR-10":
            continue
        arch = ma.group(2)
        epoch = int(ma.group(3))
        num_dx = int(ma.group(6))
        eps = float(ma.group(5))
        network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name], CLASS_NUM[ds_name])
        reject_thresholds = [0. + 0.001 * i for i in range(2050)]
        network.load_state_dict(torch.load(model_path, lambda storage, location: storage)["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        # fingerprints are tied to this trained network, so a cross-architecture setting does not apply here
        detector = NeuralFingerprintDetector(ds_name, network, num_dx, CLASS_NUM[ds_name], eps=eps,
                                             out_fp_dxdy_dir=args.output_dx_dy_dir)
        all_shots = [0, 1, 5]
        for shot in all_shots:
            report_shot = shot
            if shot == 0:
                num_updates = 0
                shot = 1
            else:
                num_updates = args.num_updates
            num_way = 2
            num_query = 15
            val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                          load_mode=args.load_task_mode, protocol=args.protocol,
                                          no_random_way=True, adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False)
            mean_time, var_time = detector.test_speed(adv_val_loader, ds_name, reject_thresholds,
                                                      num_updates, args.lr)
            results[ds_name][report_shot] = {"mean_time": mean_time, "var_time": var_time}
            print("shot {} done".format(shot))
        break  # only the first matching model is timed
    file_name = "{}/train_pytorch_model/NF_Det/speed_test_result.json".format(PY_ROOT)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(results))
        file_obj.flush()
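# Reader sketch (assumption, not part of the original file): load the speed-test JSON written by
# evaluate_speed() above and print the per-shot timing. It only relies on the result structure
# produced above ({dataset: {shot: {"mean_time": ..., "var_time": ...}}}).
def print_speed_test_summary():
    import json
    path = "{}/train_pytorch_model/NF_Det/speed_test_result.json".format(PY_ROOT)
    with open(path, "r") as f:
        stats = json.load(f)
    for ds_name, shot_dict in stats.items():
        for shot, timing in shot_dict.items():
            # JSON keys are strings after serialization; timings are assumed to be seconds
            print("{} shot={}: mean {} (var {})".format(ds_name, shot, timing["mean_time"], timing["var_time"]))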
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    # The DL baseline is trained on the all_in / sampled-all_in data, but evaluation must be run on the task-based dataset.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance, dataset, protocol)
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        arch = ma.group(3)
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        shot = 1
        num_update = 0
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD, protocol=protocol,
                                            no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
        result[key][0] = evaluate_result
    return result
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch, target_arch, updateBN):
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        old_num_update = num_update
        for shot in [0, 1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset, is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD, protocol=protocol,
                                                no_random_way=True, adv_arch=target_arch,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=updateBN)
            if num_update == 0:
                shot = 0
            result["{}-->{}@{}_{}".format(src_arch, target_arch, dataset, balance)][shot] = evaluate_result
    return result
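# Usage sketch (assumption, not part of the original file): how this cross-architecture evaluator
# might be driven and its result saved, mirroring the JSON-dump convention of the other evaluators.
# The glob pattern, learning rate, update count, architectures, and output path are hypothetical placeholders.
def run_dl_det_cross_arch_eval():
    import glob
    import json
    model_path_list = glob.glob("{}/train_pytorch_model/DL_DET/DL_DET@*".format(PY_ROOT))  # hypothetical location
    result = evaluate_cross_arch(model_path_list, num_update=20, lr=0.001,
                                 protocol=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II,
                                 src_arch="conv3", target_arch="resnet10", updateBN=False)
    out_path = "{}/train_pytorch_model/DL_DET/cross_arch_conv3--resnet10_result.json".format(PY_ROOT)  # hypothetical
    with open(out_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
    return out_path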