def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size, inner_step_size, lr_decay_itr, epoch,
             num_inner_updates, load_task_mode, arch, tot_num_tasks, num_support, detector, attack_name,
             root_folder):
    """Store the meta-learning hyper-parameters, build the network and the validation loader.

    Args:
        dataset: dataset name; used as key into ``IN_CHANNELS`` / ``IMAGE_SIZE`` and
            forwarded to ``MetaTaskDataset``.
        num_classes: number of output classes of the network (ways per task).
        meta_batch_size: number of tasks per meta batch.
        meta_step_size: stored as ``self.meta_step_size``.
        inner_step_size: stored as ``self.inner_step_size``.
        lr_decay_itr: stored as ``self.lr_decay_itr``.
        epoch: stored as ``self.epoch``.
        num_inner_updates: stored both as ``self.num_inner_updates`` and as
            ``self.test_finetune_updates``.
        load_task_mode: forwarded to ``MetaTaskDataset``.
        arch: one of ``"conv3"``, ``"resnet10"``, ``"resnet18"``.
        tot_num_tasks: forwarded to ``MetaTaskDataset``.
        num_support: forwarded to ``MetaTaskDataset``.
        detector: forwarded to ``MetaTaskDataset``.
        attack_name: forwarded to ``MetaTaskDataset``.
        root_folder: forwarded to ``MetaTaskDataset``.

    Raises:
        ValueError: if ``arch`` is not one of the supported architectures.
    """
    # Zero-argument super() stays correct if this class is ever subclassed;
    # super(self.__class__, self) would recurse forever in that case.
    super().__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.meta_batch_size = meta_batch_size  # task number per batch
    self.meta_step_size = meta_step_size
    self.inner_step_size = inner_step_size
    self.lr_decay_itr = lr_decay_itr
    self.epoch = epoch
    self.num_inner_updates = num_inner_updates
    self.test_finetune_updates = num_inner_updates

    # Make the nets
    if arch == "conv3":
        network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
    elif arch == "resnet10":
        network = resnet10(num_classes, pretrained=False)
    elif arch == "resnet18":
        network = resnet18(num_classes, pretrained=False)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError("unsupported arch '{}', expected conv3, resnet10 or resnet18".format(arch))
    self.network = network
    self.network.cuda()

    # The query-set size is fixed to 15 for validation.
    val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15, dataset, load_task_mode,
                                  detector, attack_name, root_folder)
    # Batches of 100 tasks in fixed order; accuracy is measured for each task separately.
    self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=0, pin_memory=True)
def build_network(dataset, arch, model_path):
    """Create the image classifier for ``arch`` and restore its weights if a checkpoint exists.

    Args:
        dataset: dataset name, used as key into ``CLASS_NUM`` / ``IN_CHANNELS`` / ``IMAGE_SIZE``.
        arch: architecture name. Names found in ``models.__dict__`` are built from there;
            otherwise ``"resnet10"``, ``"resnet18"`` and ``"conv3"`` are supported.
        model_path: checkpoint path; when the file does not exist the network is
            returned with its freshly initialised weights.

    Returns:
        The constructed network (left on the device it was created on).

    Raises:
        ValueError: if ``arch`` is neither in ``models`` nor one of the supported names.
    """
    if arch in models.__dict__:
        # NOTE(review): the message says "pre-trained" but the model is built with
        # pretrained=False. Also, if `models` is torchvision.models, "resnet18" resolves
        # here and the custom resnet18 branch below is never reached — confirm intent.
        print("=> using pre-trained model '{}'".format(arch))
        img_classifier_network = models.__dict__[arch](pretrained=False)
    else:
        print("=> creating model '{}'".format(arch))
        if arch == "resnet10":
            img_classifier_network = resnet10(num_classes=CLASS_NUM[dataset], in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "resnet18":
            img_classifier_network = resnet18(num_classes=CLASS_NUM[dataset], in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "conv3":
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
        else:
            # Previously this fell through and crashed later with UnboundLocalError.
            raise ValueError("unsupported arch '{}'".format(arch))

    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # Map every tensor to CPU storage so loading does not depend on the saving device.
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        img_classifier_network.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
    return img_classifier_network
def evaluate_speed(args):
    """Time the rotation detector on CIFAR-10 for 0/1/5-shot tasks and dump the result as JSON.

    Only the first checkpoint that is a CIFAR-10 model trained under ``args.protocol``
    is timed. The 0-shot entry is measured with a 1-shot dataset and zero fine-tune
    updates; 1-shot and 5-shot use ``args.num_updates``.

    Args:
        args: parsed command-line arguments; reads ``gpus``, ``protocol`` and ``num_updates``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Example file name:
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # The directory may hold files that follow another naming scheme; skip them
            # instead of crashing on ``None.group``.
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        lr = float(ma.group(5))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # The checkpoint does not change between shots, so read it from disk only once.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        for report_shot in [0, 1, 5]:
            if report_shot == 0:
                # 0-shot: reuse the 1-shot tasks but perform no fine-tune update.
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates
            # FIXME: adversarial data for architectures other than conv3 is not implemented yet.
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True, adv_arch="conv3")
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            # A fresh model is built for every shot so runs do not influence each other.
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = speed_test(model, data_loader, lr, num_updates, update_BN=False)
            result[dataset][report_shot] = evaluate_result
        # One matching checkpoint is enough for the speed test.
        break
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_cross_arch(args):
    """Evaluate rotation detectors across architectures and write the results to JSON.

    Checkpoints trained under ``args.protocol`` on adversarial data generated from
    ``args.cross_arch_source`` are evaluated on tasks whose adversarial data come from
    ``args.cross_arch_target``, for 0/1/5 shots. The 0-shot entry uses the 1-shot tasks
    with zero fine-tune updates.

    Args:
        args: parsed command-line arguments; reads ``gpus``, ``protocol``,
            ``cross_arch_source``, ``cross_arch_target``, ``num_updates`` and ``load_mode``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Example file name:
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth.tar")
    result = defaultdict(dict)
    update_BN = False
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Files with a different naming scheme (e.g. a trailing "@no_fix_cnn_params")
            # do not match this pattern; skip them rather than crash.
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue
        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        result_key = "{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)
        # The checkpoint is identical for every shot; read it once.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        for report_shot in [0, 1, 5]:
            if report_shot == 0:
                # 0-shot: 1-shot tasks evaluated without any fine-tune update.
                shot, num_update = 1, 0
            else:
                shot, num_update = report_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            # FIXME (from the original author): results are very high with update_BN=False.
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)
            # Key by the requested shot so that running with args.num_updates == 0 does not
            # make the 1-shot and 5-shot results overwrite the 0-shot entry.
            result[result_key][report_shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(PY_ROOT, args.cross_arch_source,
                                                                            args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def build_DNN_detector(dataset, arch, adv_arch, protocol):
    """Load the two-class deep-learning detector checkpoint and return it on the GPU.

    The checkpoint path is derived from ``dataset``, ``protocol``, ``arch`` and
    ``adv_arch`` under ``PY_ROOT``. The network itself is always a ``Conv3`` with
    two outputs; the state dict is loaded strictly.
    """
    path_template = ("{}/train_pytorch_model/white_box_model/"
                     "DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar")
    model_path = path_template.format(PY_ROOT, dataset, protocol, arch, adv_arch)
    # Deserialize onto CPU storage first; the module is moved to the GPU afterwards.
    saved = torch.load(model_path, map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    detector_net = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    detector_net.load_state_dict(saved["state_dict"], strict=True)
    detector_net.cuda()
    return detector_net
def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size, inner_step_size, lr_decay_itr, epoch,
             num_inner_updates, load_task_mode, protocol, tot_num_tasks, num_support, num_query,
             no_random_way, tensorboard_data_prefix, train=True, adv_arch="conv3", rotate=False,
             need_val=False):
    """Set up data loaders, the attribute/relation/feature networks and their optimizer.

    Args:
        dataset: dataset name; key into ``IN_CHANNELS`` / ``IMAGE_SIZE`` and passed to
            ``MetaTaskDataset``.
        num_classes: number of ways per task.
        meta_batch_size: number of tasks per meta batch (training loader batch size).
        meta_step_size: learning rate of the Adam optimizer.
        inner_step_size: stored as ``self.inner_step_size``.
        lr_decay_itr: stored as ``self.lr_decay_itr``.
        epoch: stored as ``self.epoch``.
        num_inner_updates: stored as ``self.num_inner_updates`` and
            ``self.test_finetune_updates``.
        load_task_mode: forwarded to ``MetaTaskDataset`` as ``load_mode``.
        protocol: forwarded to ``MetaTaskDataset``.
        tot_num_tasks: forwarded to ``MetaTaskDataset``.
        num_support: support-set size per task.
        num_query: query-set size for training tasks (validation uses 15).
        no_random_way: forwarded to the training ``MetaTaskDataset``.
        tensorboard_data_prefix: passed to ``TensorBoardWriter``.
        train: when true, build ``self.train_loader`` and ``self.tensorboard``.
        adv_arch: forwarded to ``MetaTaskDataset``.
        rotate: forwarded to ``MetaTaskDataset``.
        need_val: when true, build ``self.val_loader``.
    """
    # Zero-argument super() stays correct under subclassing, unlike super(self.__class__, self).
    super().__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.meta_batch_size = meta_batch_size  # task number per batch
    self.meta_step_size = meta_step_size
    self.inner_step_size = inner_step_size
    self.lr_decay_itr = lr_decay_itr
    self.epoch = epoch
    self.num_inner_updates = num_inner_updates
    self.test_finetune_updates = num_inner_updates

    if train:
        # A special MetaTaskDataset is needed: during training the support set supplies
        # the class attributes of the two ways.
        trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                      dataset, is_train=True, load_mode=load_task_mode,
                                      protocol=protocol,
                                      no_random_way=no_random_way, adv_arch=adv_arch, rotate=rotate)
        self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True,
                                       num_workers=0, pin_memory=True)
        # Make sure the log directory exists before the writer is created.
        os.makedirs("{0}/zeroshot_tensorboard".format(PY_ROOT), exist_ok=True)
        self.tensorboard = TensorBoardWriter("{0}/zeroshot_tensorboard".format(PY_ROOT),
                                             tensorboard_data_prefix)
    if need_val:
        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                      dataset, is_train=False, load_mode=load_task_mode,
                                      protocol=protocol,
                                      no_random_way=True, adv_arch=adv_arch, rotate=rotate)
        # Batches of 100 tasks in fixed order; accuracy is measured for each task separately.
        self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=0,
                                     pin_memory=True)

    self.hidden_feature_size = 2048
    # Constructed with sizes (312, 1200, hidden_feature_size).
    self.attr_network = AttributeNetwork(312, 1200, self.hidden_feature_size)
    self.relation_network = RelationNetwork(2 * self.hidden_feature_size, 1200, 2)
    self.img_feature_extract_network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset],
                                             self.hidden_feature_size)
    self.attr_network.cuda()
    self.relation_network.cuda()
    # NOTE(review): img_feature_extract_network is not moved to the GPU here, unlike the
    # other two networks — confirm it is moved elsewhere before use.
    self.inner_attr_network = copy.deepcopy(self.attr_network)  # deal with each task
    self.inner_relation_network = copy.deepcopy(self.relation_network)  # deal with each task

    # No inner-loop update, only the outer update: a single optimizer owns the parameters
    # of all three networks. ``parameters()`` returns generators, which cannot be added
    # with ``+``; materialise them as lists first.
    meta_parameters = (list(self.img_feature_extract_network.parameters())
                       + list(self.relation_network.parameters())
                       + list(self.attr_network.parameters()))
    self.opt_attr_net = Adam(meta_parameters, lr=meta_step_size)
    self.sched = StepLR(self.opt_attr_net, step_size=30000, gamma=0.5)
def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size, inner_step_size, lr_decay_itr, epoch,
             num_inner_updates, load_task_mode, protocol, arch, tot_num_tasks, num_support, num_query,
             no_random_way, tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
    """Build the meta network, the data loaders, the inner-loop learner and the meta optimizer.

    Args:
        dataset: dataset name; key into ``IN_CHANNELS`` / ``IMAGE_SIZE`` and passed to
            ``MetaTaskDataset``.
        num_classes: number of output classes (ways per task).
        meta_batch_size: number of tasks per meta batch.
        meta_step_size: learning rate of the outer Adam optimizer.
        inner_step_size: step size handed to ``InnerLoop``.
        lr_decay_itr: stored as ``self.lr_decay_itr``.
        epoch: stored as ``self.epoch``.
        num_inner_updates: number of inner updates handed to ``InnerLoop``; also stored
            as ``self.test_finetune_updates``.
        load_task_mode: forwarded to ``MetaTaskDataset`` as ``load_mode``.
        protocol: forwarded to ``MetaTaskDataset``.
        arch: one of ``"conv3"``, ``"resnet10"``, ``"resnet18"``.
        tot_num_tasks: forwarded to ``MetaTaskDataset``.
        num_support: support-set size per task.
        num_query: query-set size for training tasks (validation uses 15).
        no_random_way: forwarded to the training ``MetaTaskDataset``.
        tensorboard_data_prefix: passed to ``TensorBoardWriter``.
        train: when true, build ``self.train_loader`` and ``self.tensorboard``.
        adv_arch: forwarded to ``MetaTaskDataset``.
        need_val: when true, build ``self.val_loader``.

    Raises:
        ValueError: if ``arch`` is not one of the supported architectures.
    """
    # Zero-argument super() stays correct under subclassing, unlike super(self.__class__, self).
    super().__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.meta_batch_size = meta_batch_size  # task number per batch
    self.meta_step_size = meta_step_size
    self.inner_step_size = inner_step_size
    self.lr_decay_itr = lr_decay_itr
    self.epoch = epoch
    self.num_inner_updates = num_inner_updates
    self.test_finetune_updates = num_inner_updates

    # Make the nets
    if arch == "conv3":
        network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
    elif arch == "resnet10":
        network = MetaNetwork(resnet10(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                              IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
    elif arch == "resnet18":
        network = MetaNetwork(resnet18(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                              IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError("unsupported arch '{}', expected conv3, resnet10 or resnet18".format(arch))
    self.network = network
    self.network.cuda()

    if train:
        trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                      dataset, is_train=True, load_mode=load_task_mode,
                                      protocol=protocol,
                                      no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
        # task number per mini-batch is controlled by DataLoader
        self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True,
                                       num_workers=4, pin_memory=True)
        # Make sure the log directory exists before the writer is created.
        os.makedirs("{0}/pytorch_MAML_tensorboard".format(PY_ROOT), exist_ok=True)
        self.tensorboard = TensorBoardWriter("{0}/pytorch_MAML_tensorboard".format(PY_ROOT),
                                             tensorboard_data_prefix)
    if need_val:
        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                      dataset, is_train=False, load_mode=load_task_mode,
                                      protocol=protocol,
                                      no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
        # Batches of 100 tasks in fixed order; accuracy is measured for each task separately.
        self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4,
                                     pin_memory=True)

    # Per the original author's note, the inner loop handles the tasks of a batch in parallel.
    self.fast_net = InnerLoop(self.network, self.num_inner_updates,
                              self.inner_step_size, self.meta_batch_size)
    self.fast_net.cuda()
    self.opt = Adam(self.network.parameters(), lr=meta_step_size)
def build_meta_adv_detector(dataset, arch, adv_arch, shot, protocol):
    """Load the MAML-trained detector checkpoint for the given setting and return it on the GPU.

    The checkpoint file name is built from ``dataset``, ``protocol``, ``arch``,
    ``adv_arch`` and ``shot``; all other hyper-parameters in the name are fixed.
    The returned network is a two-output ``Conv3`` with strictly loaded weights.
    """
    path_template = ("{}/train_pytorch_model/white_box_model/"
                     "MAML@{}_{}@model_{}@data_{}@epoch_4@meta_batch_size_30@way_2@shot_{}"
                     "@num_query_35@num_updates_12@lr_0.0001@inner_lr_0.001"
                     "@fixed_way_True@rotate_False.pth.tar")
    model_path = path_template.format(PY_ROOT, dataset, protocol, arch, adv_arch, shot)
    # Deserialize onto CPU storage first; the module is moved to the GPU afterwards.
    saved = torch.load(model_path, map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    detector_net = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    detector_net.load_state_dict(saved['state_dict'], strict=True)
    detector_net.cuda()
    return detector_net
def build_rotate_detector(dataset, arch, adv_arch, protocol):
    """Build the image-rotation detector, restore its checkpoint and return it on the GPU.

    A ``Conv3`` classifier is wrapped in a ``Detector`` together with an
    ``ImageTransformTorch(dataset, [5, 15])`` transform. The layer number is 3 for
    CIFAR-10 / CIFAR-100 / SVHN and 2 for every other dataset.
    """
    backbone = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
    path_template = ("{}/train_pytorch_model/white_box_model/"
                     "IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100"
                     "@no_fix_cnn_params.pth.tar")
    model_path = path_template.format(PY_ROOT, dataset, protocol, arch, adv_arch)
    # Deserialize onto CPU storage first; the module is moved to the GPU afterwards.
    saved = torch.load(model_path, map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    image_transform = ImageTransformTorch(dataset, [5, 15])
    if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"]:
        layer_number = 3
    else:
        layer_number = 2
    rotate_detector = Detector(dataset, backbone, CLASS_NUM[dataset], image_transform, layer_number)
    rotate_detector.load_state_dict(saved['state_dict'], strict=True)
    rotate_detector.cuda()
    return rotate_detector
def evaluate_finetune(model_path_list, lr, protocol):
    """Evaluate CIFAR-10 deep-learning detectors with 0..50 fine-tune updates on 1-shot tasks.

    The detector is trained on the all-in (or sampled all-in) data, but testing has to
    be done on the task-structured dataset, hence the ``MetaTaskDataset`` here.

    Args:
        model_path_list: iterable of checkpoint paths; paths that do not follow the
            ``DL_DET@...`` naming scheme, are not CIFAR-10, or belong to another
            protocol are skipped.
        lr: fine-tune learning rate passed to ``finetune_eval_task_accuracy``.
        protocol: must be ``SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II``.

    Returns:
        ``defaultdict(dict)`` mapping ``"<dataset>_<balance|no_balance>"`` to a dict from
        the number of updates to the evaluation result. Each result also carries a
        ``"shot"`` entry (0 for zero updates, otherwise 1).

    Raises:
        AssertionError: if ``protocol`` is not ``TRAIN_I_TEST_II``.
    """
    import torch

    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    shot = 1
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a DL_DET checkpoint name; skip instead of crashing on ``None.group``.
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query,
                                            dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        result_key = "{}_{}".format(dataset, balance)
        for num_update in range(0, 51):
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            # Zero updates is reported as 0-shot; anything else as 1-shot.
            evaluate_result["shot"] = 0 if num_update == 0 else 1
            result[result_key][num_update] = evaluate_result
    return result
def evaluate_whitebox(dataset, arch, adv_arch, detector, attack_name, num_update, lr, protocol, load_mode,result):
    """Evaluate the deep-learning detector on white-box adversarial tasks for 0/1/5 shots.

    The detector is trained on the all-in (or sampled all-in) data, but testing has to
    be done on the task-structured dataset, hence the ``MetaTaskDataset`` here.

    Args:
        dataset: dataset name.
        arch: detector architecture name used in the checkpoint file name.
        adv_arch: architecture the adversarial data were generated from.
        detector: detector name; part of the adversarial image folder and the result key.
        attack_name: attack name; part of the adversarial image folder and the result key.
        num_update: number of fine-tune updates for the 1-shot and 5-shot runs
            (the 0-shot run always uses 0).
        lr: fine-tune learning rate.
        protocol: data split protocol used in the checkpoint file name.
        load_mode: forwarded to ``MetaTaskDataset``.
        result: mapping of mappings that is updated in place
            (``result[key][shot] = evaluation``); ``result[key]`` must be creatable
            on access, e.g. a ``defaultdict(dict)``.

    Returns:
        The same ``result`` object.

    Raises:
        AssertionError: if the checkpoint file does not exist.
    """
    tot_num_tasks = 20000
    way = 2
    query = 15
    model_path = "{}/train_pytorch_model/white_box_model/DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    assert os.path.exists(model_path), "{} is not exists".format(model_path)
    root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(adv_arch, detector,attack_name)
    # The path is built above from known parts, so re-parsing it with a regular
    # expression (as an earlier version did) added nothing but a possible crash.
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    model = model.cuda()
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(model_path, checkpoint['epoch']))
    result_key = "{}_{}_{}_{}".format(dataset, attack_name, detector, adv_arch)
    for report_shot in [0, 1, 5]:
        if report_shot == 0:
            # 0-shot: 1-shot tasks evaluated without any fine-tune update.
            shot, shot_updates = 1, 0
        else:
            shot, shot_updates = report_shot, num_update
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset,
                                            load_mode=load_mode, detector=detector,
                                            attack_name=attack_name, root_folder=root_folder)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, shot_updates, update_BN=True)
        # Key by the requested shot so that num_update == 0 does not make the 1-shot and
        # 5-shot results overwrite the 0-shot entry.
        result[result_key][report_shot] = evaluate_result
    return result
def build_neural_fingerprint_detector(dataset, arch, eps=0.1, num_dx=5):
    """Restore a NeuralFP network from the first matching checkpoint and wrap it in a detector.

    The checkpoint is the first entry returned by globbing
    ``NF_Det@<dataset>@<arch>*.pth.tar`` in the white-box model folder; an
    ``IndexError`` is raised when nothing matches.
    """
    fingerprint_dir = "{}/NF_dx_dy".format(PY_ROOT)
    checkpoint_glob = "{}/train_pytorch_model/white_box_model/NF_Det@{}@{}*.pth.tar".format(
        PY_ROOT, dataset, arch)
    candidates = glob.glob(checkpoint_glob)
    model_path = candidates[0]
    # Deserialize onto CPU storage first; the module is moved to the GPU afterwards.
    saved = torch.load(model_path, map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    num_labels = CLASS_NUM[dataset]
    network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], num_labels)
    network.load_state_dict(saved["state_dict"])
    network.cuda()
    return NeuralFingerprintDetector(dataset, network, num_dx, num_labels, eps=eps,
                                     out_fp_dxdy_dir=fingerprint_dir)
def main():
    """Train a NeuralFP detector, or evaluate the stored detectors for one study subject.

    Training mode (``args.evaluate`` false): trains on ``args.ds_name`` with ``args.arch``
    and saves a checkpoint after every epoch, resuming from an existing checkpoint.

    Evaluation mode: for ``study_subject == "speed_test"`` delegates to ``evaluate_speed``.
    Otherwise every ``NF_Det@*`` checkpoint is loaded and evaluated for the chosen study
    (``shots``, ``cross_domain``, ``cross_arch`` or ``finetune_eval``); results are
    written to a JSON file unless ``args.profile`` is set.

    Raises:
        ValueError: on an unsupported dataset name (training) or architecture name.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # NOTE (original author): the NeuralFP MNIST and F-MNIST experiments need to be redone
    # because a single-channel bug was found.
    transform = get_preprocessor(IN_CHANNELS[args.ds_name], IMAGE_SIZE[args.ds_name])
    kwargs = {'pin_memory': True}
    if not args.evaluate:  # training mode
        if args.ds_name == "MNIST":
            trn_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False,
                                         transform=transform)
            val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False,
                                         transform=transform)
        elif args.ds_name == "F-MNIST":
            trn_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False,
                                                transform=transform)
            val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False,
                                                transform=transform)
        elif args.ds_name == "CIFAR-10":
            trn_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name], train=True, download=False,
                                           transform=transform)
            val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name], train=False, download=False,
                                           transform=transform)
        elif args.ds_name == "SVHN":
            trn_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name], train=True, transform=transform)
            val_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name], train=False, transform=transform)
        elif args.ds_name == "ImageNet":
            trn_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] + "/new2", train=True,
                                              transform=transform)
            val_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] + "/new2", train=False,
                                              transform=transform)
        else:
            # Previously this fell through and crashed later with UnboundLocalError.
            raise ValueError("unsupported dataset '{}'".format(args.ds_name))
        train_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=args.workers, **kwargs)
        test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False,
                                                  num_workers=0, **kwargs)
        if args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.ds_name], IMAGE_SIZE[args.ds_name], CLASS_NUM[args.ds_name])
        elif args.arch == "resnet10":
            network = resnet10(in_channels=IN_CHANNELS[args.ds_name], num_classes=CLASS_NUM[args.ds_name])
        elif args.arch == "resnet18":
            network = resnet18(in_channels=IN_CHANNELS[args.ds_name], num_classes=CLASS_NUM[args.ds_name])
        else:
            raise ValueError("unsupported arch '{}'".format(args.arch))
        network.cuda()
        model_path = os.path.join(
            PY_ROOT, "train_pytorch_model/NF_Det",
            "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar"
            .format(args.ds_name, args.arch, args.epochs, args.lr, args.eps,
                    args.num_dx, CLASS_NUM[args.ds_name]))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        detector = NeuralFingerprintDetector(
            args.ds_name, network, args.num_dx, CLASS_NUM[args.ds_name], eps=args.eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)
        optimizer = optim.SGD(network.parameters(), lr=args.lr, weight_decay=1e-6, momentum=0.9)
        resume_epoch = 0
        print("{}".format(model_path))
        if os.path.exists(model_path):
            # Resume: restore optimizer state, weights and the epoch to continue from.
            checkpoint = torch.load(model_path, lambda storage, location: storage)
            optimizer.load_state_dict(checkpoint["optimizer"])
            resume_epoch = checkpoint["epoch"]
            network.load_state_dict(checkpoint["state_dict"])
        for epoch in range(resume_epoch, args.epochs + 1):
            if epoch == 1:
                # Test once, at epoch 1, with test_length set to 10% of the validation set size.
                detector.test(epoch, test_loader, test_length=0.1 * len(val_dataset))
            detector.train(epoch, optimizer, train_loader)
            print("Epoch{}, Saving model in {}".format(epoch, model_path))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': network.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path)
    else:  # evaluation mode
        if args.study_subject == "speed_test":
            evaluate_speed(args)
            return
        extract_pattern = re.compile(
            r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+).pth.tar")
        results = defaultdict(dict)
        for model_path in glob.glob("{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
            ma = extract_pattern.match(model_path)
            if ma is None:
                # Not a checkpoint following the expected naming scheme; skip it
                # instead of crashing on ``None.group``.
                continue
            ds_name = ma.group(1)
            arch = ma.group(2)
            num_dx = int(ma.group(6))
            eps = float(ma.group(5))
            if arch == "conv3":
                network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name], CLASS_NUM[ds_name])
            elif arch == "resnet10":
                network = resnet10(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
            elif arch == "resnet18":
                network = resnet18(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
            else:
                # Without this, an unknown arch would silently reuse the network built in
                # the previous loop iteration (or crash on the first one).
                raise ValueError("unsupported arch '{}' in checkpoint {}".format(arch, model_path))
            reject_thresholds = [0. + 0.001 * i for i in range(2050)]
            network.load_state_dict(
                torch.load(model_path, lambda storage, location: storage)["state_dict"])
            network.cuda()
            print("load {} over".format(model_path))
            # There is no notion of "cross arch" for the detector network itself.
            detector = NeuralFingerprintDetector(
                ds_name, network, num_dx, CLASS_NUM[ds_name], eps=eps,
                out_fp_dxdy_dir=args.output_dx_dy_dir)

            if args.study_subject == "shots":
                old_updates = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        # 0-shot: 1-shot tasks evaluated without any fine-tune update.
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_updates
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000, num_way, shot, num_query, ds_name, is_train=False,
                        load_mode=args.load_task_mode, protocol=args.protocol, no_random_way=True,
                        adv_arch=args.adv_arch, fetch_attack_name=True)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)
                    results[ds_name][report_shot] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": args.num_updates,
                        "attack_stats": attacker_stats
                    }
                    print("shot {} done".format(shot))
            elif args.study_subject == "cross_domain":
                source_dataset, target_dataset = args.cross_domain_source, args.cross_domain_target
                if ds_name != source_dataset:
                    continue
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000, num_way, shot, num_query, target_dataset, is_train=False,
                        load_mode=args.load_task_mode, protocol=args.protocol, no_random_way=True,
                        adv_arch=args.adv_arch, fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, target_dataset, reject_thresholds, args.num_updates, args.lr)
                    results["{}--{}@data_adv_arch_{}".format(
                        source_dataset, target_dataset, args.adv_arch)][report_shot] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": args.num_updates,
                        "attack_stats": attacker_stats
                    }
            elif args.study_subject == "cross_arch":
                target_arch = args.cross_arch_target
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000, num_way, shot, num_query, ds_name, is_train=False,
                        load_mode=args.load_task_mode, protocol=args.protocol, no_random_way=True,
                        adv_arch=target_arch, fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)
                    results["{}_target_arch_{}".format(ds_name, target_arch)][report_shot] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": args.num_updates,
                        "attack_stats": attacker_stats
                    }
            elif args.study_subject == "finetune_eval":
                shot = 1
                num_way = 2
                num_query = 15
                if ds_name != args.ds_name:
                    continue
                val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                              load_mode=args.load_task_mode, protocol=args.protocol,
                                              no_random_way=True, adv_arch=args.adv_arch)
                adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                                             **kwargs)
                # Recorded in the output file name below ("finetune_50").
                args.num_updates = 50
                for num_update in range(0, 51):
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, num_update, args.lr)
                    results[ds_name][num_update] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": num_update,
                        "attack_stats": attacker_stats
                    }
                    print("finetune {} done".format(shot))

        if not args.profile:
            if args.study_subject == "cross_domain":
                filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
                    PY_ROOT, args.cross_domain_source, args.cross_domain_target, args.adv_arch)
            elif args.study_subject == "cross_arch":
                filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
                    PY_ROOT, args.cross_arch_target)
            else:
                filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
                    PY_ROOT, args.study_subject, args.adv_arch, args.protocol, args.lr, args.num_updates)
            with open(filename, "w") as file_obj:
                file_obj.write(json.dumps(results))
                file_obj.flush()
def evaluate_whitebox_attack(args):
    """Evaluate the rotation detector on white-box FGSM / CW_L2 tasks and write a JSON report.

    For every attack and for 1-shot and 5-shot tasks, a fresh ``Detector`` is restored
    from the checkpoint and evaluated with ``finetune_eval_task_rotate``. When the
    number of updates is 0 the result is stored under shot 0.

    Args:
        args: parsed command-line arguments; reads ``gpus``, ``dataset``, ``protocol``,
            ``adv_arch``, ``num_updates``, ``load_mode`` and ``eval_update_BN``.

    Raises:
        AssertionError: if the checkpoint file does not exist.
    """
    # For 0-shot evaluation, pass args.num_updates = 0.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    attacks = ["FGSM", "CW_L2"]
    # Older example file name:
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        ".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    result = defaultdict(dict)
    model_path = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, args.dataset, args.protocol, "conv3", args.adv_arch)
    assert os.path.exists(model_path), "{} not exists".format(model_path)

    # Recover dataset name and learning rate from the checkpoint file name.
    parsed = extract_pattern_detail.match(model_path)
    dataset = parsed.group(1)
    split_protocol = SPLIT_DATA_PROTOCOL[parsed.group(2)]
    arch = parsed.group(3)
    epoch = int(parsed.group(4))
    lr = float(parsed.group(5))
    batch_size = int(parsed.group(6))
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))

    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    base_num_update = args.num_updates
    detector = "RotateDet"
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    for attack_name in attacks:
        # The folder depends only on the attack, not on the shot.
        root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(args.adv_arch, detector, attack_name)
        result_key = "{}_{}_{}_{}".format(dataset, attack_name, detector, args.adv_arch)
        for shot in [1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = base_num_update
            meta_task_dataset = WhiteBoxMetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                        dataset, args.load_mode, detector, attack_name,
                                                        root_folder)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            # A fresh detector is restored for every run.
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformTorch(dataset, [5, 15])
            if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"]:
                layer_number = 3
            else:
                layer_number = 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_rotate(model, data_loader, lr, num_update,
                                                        update_BN=args.eval_update_BN)
            # Zero updates is reported as 0-shot.
            stored_shot = 0 if num_update == 0 else shot
            result[result_key][stored_shot] = evaluate_result

    update_BN_str = "UpdateBN" if args.eval_update_BN else "NoUpdateBN"
    report_path = "{}/train_pytorch_model/white_box_model/white_box_RotateDet_{}_{}_using_{}__result.json".format(PY_ROOT,args.dataset, update_BN_str, args.protocol)
    with open(report_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_speed(args):
    """Time the NeuralFP detector on CIFAR-10 for 0/1/5-shot tasks and write the result as JSON.

    Only the first CIFAR-10 checkpoint found is timed. The 0-shot entry is measured with
    the 1-shot tasks and zero fine-tune updates; 1-shot and 5-shot use ``args.num_updates``.

    Args:
        args: parsed command-line arguments; reads ``output_dx_dy_dir``, ``num_updates``,
            ``load_task_mode``, ``protocol``, ``adv_arch`` and ``lr``.
    """
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+).pth.tar")
    results = defaultdict(dict)
    for model_path in glob.glob("{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        if ma is None:
            # Not a checkpoint following the expected naming scheme; skip it
            # instead of crashing on ``None.group``.
            continue
        ds_name = ma.group(1)
        if ds_name != "CIFAR-10":
            continue
        num_dx = int(ma.group(6))
        eps = float(ma.group(5))
        network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name], CLASS_NUM[ds_name])
        reject_thresholds = [0. + 0.001 * i for i in range(2050)]
        network.load_state_dict(
            torch.load(model_path, lambda storage, location: storage)["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        # There is no notion of "cross arch" for the detector network itself.
        detector = NeuralFingerprintDetector(
            ds_name, network, num_dx, CLASS_NUM[ds_name], eps=eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)
        num_way = 2
        num_query = 15
        for report_shot in [0, 1, 5]:
            if report_shot == 0:
                # 0-shot: 1-shot tasks evaluated without any fine-tune update.
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates
            val_dataset = MetaTaskDataset(20000, num_way, shot, num_query, ds_name, is_train=False,
                                          load_mode=args.load_task_mode, protocol=args.protocol,
                                          no_random_way=True, adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False)
            mean_time, var_time = detector.test_speed(
                adv_val_loader, ds_name, reject_thresholds, num_updates, args.lr)
            results[ds_name][report_shot] = {
                "mean_time": mean_time,
                "var_time": var_time
            }
            print("shot {} done".format(shot))
        # One matching checkpoint is enough for the speed test.
        break
    file_name = "{}/train_pytorch_model/NF_Det/speed_test_result.json".format(PY_ROOT)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(results))
        file_obj.flush()
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    """Evaluate deep-learning detector checkpoints without any fine-tuning.

    The detector is trained under the all_in / sampled all_in setting, but it
    is evaluated here on the task-based version of the dataset: each matching
    checkpoint is loaded into a 2-class ``Conv3`` and passed to
    ``finetune_eval_task_accuracy`` with zero update steps.

    Args:
        model_path_list: iterable of checkpoint paths. Paths whose file name
            does not follow the ``DL_DET@...`` naming scheme are skipped.
        lr: learning rate forwarded to ``finetune_eval_task_accuracy``.
        protocol: only checkpoints whose file name carries ``str(protocol)``
            are evaluated.
        args: namespace providing ``cross_domain_source`` and
            ``cross_domain_target``. When ``cross_domain_source`` is not
            ``None``, only checkpoints trained on that dataset are used and
            they are evaluated on ``cross_domain_target``.

    Returns:
        ``defaultdict(dict)`` mapping a result key to ``{0: evaluate_result}``.
    """
    # Raw string so the ``\d`` / ``\.`` escapes are regex escapes, not
    # (invalid) string-literal escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    # Zero-shot is measured on 1-shot tasks with no update step.
    shot = 1
    num_update = 0
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a detector checkpoint name; skip instead of crashing on
            # ``None.group``.
            print("skip unrecognized model path :{}".format(model_path))
            continue
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            # From here on ``dataset`` is the evaluation (target) domain.
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance,
                                          dataset, protocol)
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                            way,
                                            shot,
                                            query,
                                            dataset,
                                            is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol,
                                            no_random_way=True,
                                            adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model,
                                                      data_loader,
                                                      lr,
                                                      num_update,
                                                      update_BN=False)
        result[key][0] = evaluate_result
    return result
def main():
    """Generate white-box adversarial images against a classifier + detector pair.

    Steps, all driven by the parsed command-line ``args``:

    1. Build the validation split of ``args.dataset``.
    2. Load the pre-trained image classifier of architecture ``args.adv_arch``.
    3. Build the detector selected by ``args.detector``.
    4. Build the attack selected by ``args.attack`` (``CW_L2`` or ``FGSM``).
       For ``NeuralFP`` the fingerprint-aware attack classes are used; for the
       other detectors the attack targets ``CombinedModel(classifier, detector)``.
    5. Call ``generate`` on the validation set.

    Raises:
        ValueError: if the dataset, classifier architecture, detector or
            attack name is not one of the supported values (previously this
            surfaced later as an ``UnboundLocalError``).
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    preprocessor = get_preprocessor(input_channels=IN_CHANNELS[args.dataset])
    data_root = IMAGE_DATA_ROOT[args.dataset]
    # Only the validation split is consumed below, so the training split is
    # no longer instantiated.
    if args.dataset == "CIFAR-10":
        val_dataset = datasets.CIFAR10(data_root,
                                       train=False,
                                       transform=preprocessor)
    elif args.dataset == "MNIST":
        val_dataset = datasets.MNIST(data_root,
                                     train=False,
                                     transform=preprocessor,
                                     download=True)
    elif args.dataset == "F-MNIST":
        val_dataset = datasets.FashionMNIST(data_root,
                                            train=False,
                                            transform=preprocessor,
                                            download=True)
    elif args.dataset == "SVHN":
        val_dataset = SVHN(data_root, train=False, transform=preprocessor)
    else:
        raise ValueError("unsupported dataset: {}".format(args.dataset))

    # load image classifier model
    img_classifier_model_path = "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_40@lr_0.0001@batch_500.pth.tar".format(
        PY_ROOT, args.dataset, args.adv_arch)
    if args.adv_arch == "resnet10":
        img_classifier_network = resnet10(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "resnet18":
        img_classifier_network = resnet18(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "conv3":
        img_classifier_network = Conv3(IN_CHANNELS[args.dataset],
                                       IMAGE_SIZE[args.dataset],
                                       CLASS_NUM[args.dataset])
    else:
        raise ValueError("unsupported adv_arch: {}".format(args.adv_arch))
    print("=> loading checkpoint '{}'".format(img_classifier_model_path))
    # Load onto CPU first; the network is moved to the GPU afterwards.
    checkpoint = torch.load(img_classifier_model_path,
                            map_location=lambda storage, loc: storage)
    img_classifier_network.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}) for img classifier".format(
        img_classifier_model_path, checkpoint['epoch']))
    img_classifier_network.eval()
    img_classifier_network = img_classifier_network.cuda()

    # load detector model
    if args.detector == "MetaAdvDet":
        # If the fine-tuned model is the attack target, the noise has to be
        # generated dynamically: after every fine-tune on the support set the
        # attack must be re-run to produce new adversarial examples.
        detector_net = build_meta_adv_detector(args.dataset, args.det_arch,
                                               args.adv_arch, args.shot,
                                               args.protocol)
        # Detection happens after fine-tuning on the support set; whether to
        # attack the model after or before fine-tuning is an open question.
    elif args.detector == "DNN":
        detector_net = build_DNN_detector(args.dataset, args.det_arch,
                                          args.adv_arch, args.protocol)
    elif args.detector == "RotateDet":
        detector_net = build_rotate_detector(args.dataset, args.det_arch,
                                             args.adv_arch, args.protocol)
    elif args.detector == "NeuralFP":
        detector_net = build_neural_fingerprint_detector(
            args.dataset, args.det_arch)
    else:
        raise ValueError("unsupported detector: {}".format(args.detector))

    if args.attack not in ("CW_L2", "FGSM"):
        raise ValueError("unsupported attack: {}".format(args.attack))

    if args.detector == "NeuralFP":
        # Fingerprint-aware attacks take the classifier and the detector
        # separately.
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2Fingerprint(img_classifier_network,
                                                targeted=True,
                                                confidence=0.3,
                                                search_steps=30,
                                                max_steps=args.atk_max_iter,
                                                optimizer_lr=0.01,
                                                neural_fp=detector_net)
        else:  # "FGSM"
            attack = IterativeFastGradientSignTargetedFingerprint(
                img_classifier_network,
                alpha=0.01,
                max_iters=args.atk_max_iter,
                neural_fp=detector_net)
    else:
        detector_net.eval()
        combined_model = CombinedModel(img_classifier_network, detector_net)
        combined_model.cuda()
        combined_model.eval()
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2(combined_model,
                                     True,
                                     confidence=0.3,
                                     search_steps=30,
                                     max_steps=args.atk_max_iter,
                                     optimizer_lr=0.01)
        else:  # "FGSM"
            attack = IterativeFastGradientSignTargeted(
                combined_model, alpha=0.01, max_iters=args.atk_max_iter)

    generate(
        attack,
        args.attack,
        args.dataset,
        args.detector,
        val_dataset,
        output_dir="{}/adversarial_images/white_box@data_{}@det_{}".format(
            IMAGE_DATA_ROOT[args.dataset], args.adv_arch, args.detector),
        args=args)
def main_train_worker(args, model_path, META_ATTACKER_PART_I=None, META_ATTACKER_PART_II=None, gpu="0"):
    """Train a 2-class detector on adversarial images and checkpoint every epoch.

    Args:
        args: namespace; reads ``arch``, ``dataset``, ``adv_arch``,
            ``protocol``, ``balance``, ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size``, ``workers``, ``start_epoch`` and ``epochs``.
            ``args.start_epoch`` is overwritten when ``model_path`` already
            exists and is resumed from.
        model_path: checkpoint path; resumed from if it exists, and written
            after every epoch.
        META_ATTACKER_PART_I: forwarded to the dataset; defaults to
            ``config.META_ATTACKER_PART_I`` when ``None``.
        META_ATTACKER_PART_II: forwarded to the dataset; defaults to
            ``config.META_ATTACKER_PART_II`` when ``None``.
        gpu: value exported as ``CUDA_VISIBLE_DEVICES`` (converted with
            ``str`` so an integer id is accepted too).

    Raises:
        ValueError: if ``args.arch`` is not ``conv3``, ``resnet10`` or
            ``resnet18``.
    """
    if META_ATTACKER_PART_I is None:
        META_ATTACKER_PART_I = config.META_ATTACKER_PART_I
    if META_ATTACKER_PART_II is None:
        META_ATTACKER_PART_II = config.META_ATTACKER_PART_II
    print("Use GPU: {} for training".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ only accepts strings.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    print("will save to {}".format(model_path))

    if args.arch == "conv3":
        model = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], 2)
    elif args.arch == "resnet10":
        model = resnet10(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    elif args.arch == "resnet18":
        model = resnet18(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    else:
        raise ValueError("unsupported arch: {}".format(args.arch))
    model = model.cuda()

    adv_root = IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch)
    if args.dataset == "ImageNet":
        train_dataset = AdversaryRandomAccessNpyDataset(adv_root, True, args.protocol,
                                                        META_ATTACKER_PART_I, META_ATTACKER_PART_II,
                                                        args.balance, args.dataset)
    else:
        train_dataset = AdversaryDataset(adv_root, True, args.protocol,
                                         META_ATTACKER_PART_I, META_ATTACKER_PART_II,
                                         args.balance)

    # define loss function (criterion) and optimizer
    # The only visible device is the one selected through ``gpu`` above, so
    # the criterion goes to the default device exactly like the model does.
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    cudnn.benchmark = True

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    tensorboard = TensorBoardWriter("{0}/pytorch_DeepLearning_tensorboard".format(PY_ROOT), "DeepLearning")

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch (no validation loader is passed)
        train(train_loader, None, model, criterion, optimizer, epoch, tensorboard, args)
        if args.balance:
            # Rebuild the sample list in place: all entries under key 1 plus
            # a fresh random draw of equally many entries under key 0.
            train_dataset.img_label_list.clear()
            train_dataset.img_label_list.extend(train_dataset.img_label_dict[1])
            train_dataset.img_label_list.extend(
                random.sample(train_dataset.img_label_dict[0],
                              len(train_dataset.img_label_dict[1])))
        # save a checkpoint after every epoch
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
def main_train_worker(gpu, args):
    """Train an image classifier and save a checkpoint after every epoch.

    Args:
        gpu: stored on ``args.gpu`` and echoed to stdout when not ``None``.
        args: namespace; reads ``pretrained``, ``arch``, ``dataset``,
            ``epochs``, ``lr``, ``batch_size``, ``weight_decay``, ``resume``
            and ``workers``. ``args.gpu`` is always written, and
            ``args.start_epoch`` is written when resuming.

    Raises:
        ValueError: if no network can be built for ``args.arch`` or if
            ``args.dataset`` is not one of ``CIFAR-10``, ``MNIST``,
            ``F-MNIST`` or ``SVHN`` (previously an ``UnboundLocalError``).
    """
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    num_classes = CLASS_NUM[args.dataset]
    if args.pretrained and args.arch in models.__dict__:
        print("=> using pre-trained model '{}'".format(args.arch))
        network = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == "resnet10":
            network = resnet10(num_classes=num_classes, in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "resnet18":
            network = resnet18(num_classes=num_classes, in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], num_classes)
        else:
            raise ValueError("unsupported arch: {}".format(args.arch))

    # Replace the classification head so its output size matches the dataset.
    if args.arch.startswith("resnet"):
        network.avgpool = Identity()
        network.fc = nn.Linear(512, num_classes)
    elif args.arch.startswith("vgg"):
        network.classifier[6] = nn.Linear(4096, num_classes)

    model_path = './train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        args.dataset, args.arch, args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    preprocessor = get_preprocessor(IN_CHANNELS[args.dataset])
    network.cuda()

    # define loss function (criterion) and optimizer
    image_classifier_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(network.parameters(), args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (weights only; the optimizer state
    # is intentionally not restored)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            network.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data_root = IMAGE_DATA_ROOT[args.dataset]
    if args.dataset == "CIFAR-10":
        train_dataset = CIFAR10(data_root, train=True, transform=preprocessor)
        val_dataset = CIFAR10(data_root, train=False, transform=preprocessor)
    elif args.dataset == "MNIST":
        train_dataset = MNIST(data_root, train=True, transform=preprocessor, download=True)
        val_dataset = MNIST(data_root, train=False, transform=preprocessor, download=True)
    elif args.dataset == "F-MNIST":
        train_dataset = FashionMNIST(data_root, train=True, transform=preprocessor, download=True)
        val_dataset = FashionMNIST(data_root, train=False, transform=preprocessor, download=True)
    elif args.dataset == "SVHN":
        train_dataset = SVHN(data_root, train=True, transform=preprocessor)
        val_dataset = SVHN(data_root, train=False, transform=preprocessor)
    else:
        raise ValueError("unsupported dataset: {}".format(args.dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # NOTE(review): the loop always starts at epoch 0 even after a resume has
    # set args.start_epoch -- confirm whether that is intended.
    for epoch in range(0, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch, args)
        # evaluate_accuracy on validation set
        acc1 = validate(val_loader, network, image_classifier_loss, args)
        print(acc1)
        # save a checkpoint after every epoch
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch,
                        target_arch, updateBN):
    """Evaluate detectors trained against one classifier arch on another.

    For every checkpoint whose file name matches ``protocol`` and whose
    ``data_`` field equals ``src_arch``, the detector (a 2-class ``Conv3``)
    is evaluated with ``finetune_eval_task_accuracy`` on tasks built with
    ``adv_arch=target_arch`` for the 0-, 1- and 5-shot settings.

    Args:
        model_path_list: iterable of checkpoint paths. Paths that do not
            follow the ``DL_DET@...`` naming scheme are skipped.
        num_update: number of fine-tune updates for the 1- and 5-shot
            settings (the 0-shot setting always uses 0).
        lr: fine-tune learning rate.
        protocol: only checkpoints carrying ``str(protocol)`` are used.
        src_arch: classifier architecture the detector was trained against.
        target_arch: classifier architecture of the evaluation tasks.
        updateBN: forwarded as ``update_BN``.

    Returns:
        ``defaultdict(dict)`` mapping
        ``"{src_arch}-->{target_arch}@{dataset}_{balance}"`` to
        ``{shot: evaluate_result}`` with ``shot`` in ``{0, 1, 5}``.
    """
    # Raw string so the ``\d`` / ``\.`` escapes are regex escapes, not
    # (invalid) string-literal escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a detector checkpoint name; skip instead of crashing on
            # ``None.group``.
            print("skip unrecognized model path :{}".format(model_path))
            continue
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        result_key = "{}-->{}@{}_{}".format(src_arch, target_arch, dataset,
                                            balance)
        # NOTE(review): the same model instance is reused for every shot --
        # confirm finetune_eval_task_accuracy does not mutate it.
        for report_shot in [0, 1, 5]:
            # The 0-shot setting is run on 1-shot tasks with zero updates.
            # The reported shot is tracked separately so that a caller
            # passing num_update == 0 no longer has the 1- and 5-shot
            # results collapsed onto key 0.
            if report_shot == 0:
                shot = 1
                shot_updates = 0
            else:
                shot = report_shot
                shot_updates = num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                way,
                                                shot,
                                                query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=protocol,
                                                no_random_way=True,
                                                adv_arch=target_arch,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          shot_updates,
                                                          update_BN=updateBN)
            result[result_key][report_shot] = evaluate_result
    return result