def train(self, model_path, resume_epoch=0, need_val=False):
    # mtr_loss, mtr_acc, mval_loss, mval_acc = [], [], [], []
    for epoch in range(resume_epoch, self.epoch):
        for i, (support_images, _, support_labels, query_images, _, query_labels, *_) in enumerate(self.train_loader):
            itr = epoch * len(self.train_loader) + i
            self.adjust_learning_rate(itr, self.meta_step_size, self.lr_decay_itr)
            # Collect a meta batch update
            grads = []
            support_images, support_labels, query_images, query_labels = \
                support_images.cuda(), support_labels.cuda(), query_images.cuda(), query_labels.cuda()
            for task_idx in range(support_images.size(0)):
                self.fast_net.copy_weights(self.network)
                # fast_net only forwards one task's data
                g = self.fast_net.forward(support_images[task_idx], query_images[task_idx],
                                          support_labels[task_idx], query_labels[task_idx])
                # (trl, tra, vall, vala) = metrics
                grads.append(g)
            # Perform the meta update
            # print('Meta update', itr)
            self.meta_update(grads, query_images, query_labels)
            grads.clear()
            # Evaluate on test tasks
            if itr % 1000 == 0 and need_val:
                result_json = finetune_eval_task_accuracy(self.network, self.val_loader, self.inner_step_size,
                                                          self.test_finetune_updates, update_BN=True)
                query_F1_tensor = torch.Tensor(1)
                query_F1_tensor.fill_(result_json["query_F1"])
                self.tensorboard.record_val_query_F1(query_F1_tensor, itr)
        # Save a model snapshot every now and then
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.network.state_dict(),
            'optimizer': self.opt.state_dict(),
        }, model_path)
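# self.meta_update(grads, query_images, query_labels) is defined elsewhere in this class.
# As a rough illustration only, the sketch below shows the hook-based meta update commonly
# used in PyTorch MAML implementations: it assumes each entry of `grads` is a dict keyed by
# parameter name, and the helper name plus the `opt` / `loss_fn` arguments are placeholders,
# not this class's actual interface.
def meta_update_sketch(network, opt, loss_fn, grads, query_images, query_labels):
    # Average the per-task gradient dictionaries collected from fast_net.
    gradients = {k: sum(g[k] for g in grads) / len(grads) for k in grads[0].keys()}
    # Dummy forward/backward pass on one task's query data so autograd builds a graph;
    # hooks then swap in the averaged meta-gradient for every parameter before opt.step().
    logits = network(query_images[0])
    loss = loss_fn(logits, query_labels[0])
    hooks = []
    for name, param in network.named_parameters():
        hooks.append(param.register_hook((lambda key: lambda grad: gradients[key])(name)))
    opt.zero_grad()
    loss.backward()  # the hooks replace the dummy gradients with the averaged ones
    opt.step()
    for h in hooks:
        h.remove()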
def evaluate_cross_arch(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # e.g. IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+).*?\.pth\.tar")
    result = defaultdict(dict)
    update_BN = False
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue
        epoch = int(ma.group(5))
        lr = float(ma.group(6))
        batch_size = int(ma.group(7))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        old_num_update = args.num_updates
        for shot in [0, 1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset],
                             image_transform, layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update,
                                                          update_BN=update_BN)  # FIXME: accuracy becomes suspiciously high with update_BN=False
            if num_update == 0:
                shot = 0
            result["{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)][shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(
            PY_ROOT, args.cross_arch_source, args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_finetune(model_path_list, lr, protocol):
    # The deep-learning baselines are trained on the all_in (or sampled all_in) data,
    # but they must be evaluated on the task-based version of the dataset.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        arch = ma.group(3)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        shot = 1
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset,
                                            is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        for num_update in range(0, 51):
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            if num_update == 0:
                shot = 0
            else:
                shot = 1
            evaluate_result["shot"] = shot
            result["{}_{}".format(dataset, balance)][num_update] = evaluate_result
    return result
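# evaluate_finetune only returns its result dictionary, so a caller has to persist it.
# A minimal usage sketch is given below, mirroring the json.dumps pattern used by the other
# evaluation entry points in this file; the glob pattern and output file name are assumptions.
def run_evaluate_finetune_sketch(lr, protocol):
    model_path_list = glob.glob("{}/train_pytorch_model/DL_DET/DL_DET@*.pth.tar".format(PY_ROOT))
    result = evaluate_finetune(model_path_list, lr, protocol)
    with open("{}/train_pytorch_model/DL_DET/finetune_eval_result.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()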
def evaluate_whitebox(dataset, arch, adv_arch, detector, attack_name, num_update, lr, protocol, load_mode, result):
    # The deep-learning baselines are trained on the all_in (or sampled all_in) data,
    # but they must be evaluated on the task-based version of the dataset.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    model_path = "{}/train_pytorch_model/white_box_model/DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    assert os.path.exists(model_path), "{} does not exist".format(model_path)
    root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(adv_arch, detector, attack_name)
    ma = extract_pattern_detail.match(model_path)
    balance = ma.group(8)
    if balance == "True":
        balance = "balance"
    else:
        balance = "no_balance"
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    arch = ma.group(3)
    model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    model = model.cuda()
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
    old_num_update = num_update
    # for shot in range(16):
    for shot in [0, 1, 5]:
        if shot == 0:
            shot = 1
            num_update = 0
        else:
            num_update = old_num_update
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset,
                                            load_mode=load_mode, detector=detector,
                                            attack_name=attack_name, root_folder=root_folder)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=True)
        if num_update == 0:
            shot = 0
        result["{}_{}_{}_{}".format(dataset, attack_name, detector, adv_arch)][shot] = evaluate_result
    return result
def meta_zero_shot_evaluate(args):
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data.*?@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@rotate_(.*?)\.pth\.tar")
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth\.tar")
    report_result = defaultdict(dict)
    str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    for model_path in glob.glob("{}/train_pytorch_model/{}/MAML@*".format(PY_ROOT, args.study_subject)):
        if str(args.split_protocol) not in model_path:
            continue
        ma_prefix = extract_param_prefix.match(model_path)
        param_prefix = ma_prefix.group(1)
        ma = extract_pattern.match(model_path)
        orig_ma = ma
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        arch = ma.group(3)
        epoch = int(ma.group(4))
        meta_batch_size = int(ma.group(5))
        num_classes = int(ma.group(6))
        num_support = int(ma.group(7))
        num_query = int(ma.group(8))  # use this num_query for evaluation
        num_updates = int(ma.group(9))
        meta_lr = float(ma.group(10))
        inner_lr = float(ma.group(11))
        fixe_way = str2bool(ma.group(12))
        rotate = str2bool(ma.group(13))
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        key = "{}__{}".format(dataset, split_protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}-->{}__{}".format(args.cross_domain_source, dataset, split_protocol)
        learner = MetaLearner(dataset, num_classes, meta_batch_size, meta_lr, inner_lr,
                              args.lr_decay_itr, epoch,
                              0,  # zero_shot: no inner-loop updates
                              args.load_task_mode, split_protocol, arch, args.tot_num_tasks,
                              num_support,
                              num_query,  # num_query is fixed to 15 across experiments
                              no_random_way=True, tensorboard_data_prefix=param_prefix,
                              train=False, adv_arch=args.adv_arch, need_val=True)
        learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
        result_json = finetune_eval_task_accuracy(learner.network, learner.val_loader,
                                                  learner.inner_step_size,
                                                  learner.test_finetune_updates, update_BN=True)
        report_result[dataset][key] = result_json
    file_name = "{}/train_pytorch_model/{}/{}_result.json".format(PY_ROOT, args.study_subject, args.study_subject)
    if args.cross_domain_source:
        file_name = "{}/train_pytorch_model/{}/zero_shot_{}--{}_result.json".format(
            PY_ROOT, args.study_subject, args.cross_domain_source, args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(report_result))
        file_obj.flush()
def meta_cross_arch_evaluate(args):
    # A model trained with 1-shot tasks can only be evaluated on 1-shot data (likewise for 5-shot).
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@.*")
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth\.tar")
    report_result = defaultdict(dict)
    str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    updateBN = True
    for shot in [1, 5]:
        for model_path in glob.glob("{}/train_pytorch_model/{}/MAML@*".format(PY_ROOT, args.study_subject)):
            ma_prefix = extract_param_prefix.match(model_path)
            param_prefix = ma_prefix.group(1)
            ma = extract_pattern.match(model_path)
            dataset = ma.group(1)
            split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
            if split_protocol != args.split_protocol:
                continue
            arch = ma.group(3)
            data_arch = ma.group(4)
            if data_arch != args.cross_arch_source:
                continue
            epoch = int(ma.group(5))
            meta_batch_size = int(ma.group(6))
            num_classes = int(ma.group(7))
            model_train_num_support = int(ma.group(8))
            if shot != model_train_num_support:
                continue
            meta_lr = float(ma.group(11))
            inner_lr = float(ma.group(12))
            fixe_way = str2bool(ma.group(13))
            if not fixe_way:
                continue
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            learner = MetaLearner(dataset, num_classes, meta_batch_size, meta_lr, inner_lr,
                                  args.lr_decay_itr, epoch, args.test_num_updates,
                                  args.load_task_mode, split_protocol, arch, args.tot_num_tasks,
                                  shot, 15, True, param_prefix, train=False,
                                  adv_arch=args.cross_arch_target, need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
            result_json = finetune_eval_task_accuracy(learner.network, learner.val_loader,
                                                      learner.inner_step_size,
                                                      learner.test_finetune_updates, update_BN=updateBN)
            report_result[dataset + "@" + args.cross_arch_source + "--" + args.cross_arch_target][shot] = result_json
    with open("{}/train_pytorch_model/{}/{}--{}@finetune_{}_result_updateBN_{}.json".format(
            PY_ROOT, args.study_subject, args.cross_arch_source, args.cross_arch_target,
            args.test_num_updates, updateBN), "w") as file_obj:
        file_obj.write(json.dumps(report_result))
        file_obj.flush()
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    # The deep-learning baselines are trained on the all_in (or sampled all_in) data,
    # but they must be evaluated on the task-based version of the dataset.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"
        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance, dataset, protocol)
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        arch = ma.group(3)
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        shot = 1
        num_update = 0
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset,
                                            is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
        result[key][0] = evaluate_result
    return result
def evaluate_shots(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # e.g. IMG_ROTATE_DET@ImageNet_TRAIN_I_TEST_II@model_resnet10@data_resnet10@epoch_10@lr_0.001@batch_10@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+).*?\.pth\.tar")
    result = defaultdict(dict)
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        data_adv_arch = ma.group(4)
        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        old_num_update = args.num_updates
        all_shots = [0, 1, 5]
        for shot in all_shots:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=data_adv_arch)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            if arch == "resnet10":
                img_classifier_network = resnet10(num_classes=CLASS_NUM[dataset],
                                                  in_channels=IN_CHANNELS[dataset], pretrained=False)
            elif arch == "resnet18":
                img_classifier_network = resnet18(num_classes=CLASS_NUM[dataset],
                                                  in_channels=IN_CHANNELS[dataset], pretrained=False)
            elif arch == "conv3":
                img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset],
                             image_transform, layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            if num_update == 0:
                shot = 0
            result[dataset + "@" + data_adv_arch][shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/shots_result.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def evaluate_zero_shot(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # e.g. IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*.tar".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        # if dataset != "MNIST":
        #     continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != args.adv_arch:
            continue
        epoch = int(ma.group(5))
        lr = float(ma.group(6))
        batch_size = int(ma.group(7))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        old_num_update = args.num_updates
        shot = 1
        key = "{}@{}".format(dataset, split_protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}-->{}___{}".format(args.cross_domain_source, dataset, split_protocol)
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                            dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=split_protocol, no_random_way=True,
                                            adv_arch="conv3")  # FIXME: code for other adversarial architectures has not been written yet
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], CLASS_NUM[dataset])
        image_transform = ImageTransformCV2(dataset, [1, 2])
        layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
        model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset],
                         image_transform, layer_number, num_classes=2)
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_updates=0, update_BN=False)
        result[key][0] = evaluate_result
    file_name = "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/zero_shot_result.json".format(PY_ROOT)
    if args.cross_domain_source:
        file_name = "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/zero_shot_{}--{}_result.json".format(
            PY_ROOT, args.cross_domain_source, args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def meta_ablation_study_evaluate(args):
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data.*?@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@rotate_(.*?)\.pth\.tar")
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth\.tar")
    report_result = defaultdict(dict)
    str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    for model_path in glob.glob("{}/train_pytorch_model/{}/MAML@*".format(PY_ROOT, args.study_subject)):
        if str(args.split_protocol) not in model_path:
            continue
        ma_prefix = extract_param_prefix.match(model_path)
        param_prefix = ma_prefix.group(1)
        ma = extract_pattern.match(model_path)
        orig_ma = ma
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        arch = ma.group(3)
        epoch = int(ma.group(4))
        meta_batch_size = int(ma.group(5))
        num_classes = int(ma.group(6))
        num_support = int(ma.group(7))
        num_query = int(ma.group(8))  # use this num_query for evaluation
        num_updates = int(ma.group(9))
        meta_lr = float(ma.group(10))
        inner_lr = float(ma.group(11))
        fixe_way = str2bool(ma.group(12))
        rotate = str2bool(ma.group(13))
        extract_key = param_prefix
        if args.study_subject == "inner_update_ablation_study":
            extract_key = num_updates
        elif args.study_subject == "shots_ablation_study":
            extract_key = num_support
        elif args.study_subject == "cross_adv_group":
            extract_key = num_support
        elif args.study_subject == "tasks_ablation_study":
            extract_key = meta_batch_size
        elif args.study_subject == "ways_ablation_study":
            extract_key = num_classes
        elif args.study_subject == "random_vs_fix_way":
            extract_key = "shots_{}_fixed_way_{}".format(num_support, fixe_way)
        elif args.study_subject == "query_size_ablation_study":
            extract_key = num_query
        elif args.study_subject == "vs_deep_MAX":
            extract_key = num_support
        elif args.study_subject == "zero_shot":
            extract_key = "0-shot"
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        if args.study_subject == "fine_tune_update_ablation_study":  # this experiment needs to be redone
            learner = MetaLearner(dataset, num_classes, meta_batch_size, meta_lr, inner_lr,
                                  args.lr_decay_itr, epoch, num_updates, args.load_task_mode,
                                  split_protocol, arch, args.tot_num_tasks, num_support, 15,
                                  True, param_prefix, train=False, adv_arch=args.adv_arch, need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
            for test_num_updates in range(1, 51):
                result_json = finetune_eval_task_accuracy(learner.network, learner.val_loader,
                                                          inner_lr, test_num_updates, update_BN=True)
                report_result[dataset][test_num_updates] = result_json
        elif args.study_subject == "zero_shot":
            learner = MetaLearner(dataset, num_classes, meta_batch_size, meta_lr, inner_lr,
                                  args.lr_decay_itr, epoch, args.test_num_updates,
                                  args.load_task_mode, split_protocol, arch, args.tot_num_tasks,
                                  num_support,
                                  num_query,  # num_query is fixed to 15 across experiments
                                  no_random_way=True, tensorboard_data_prefix=param_prefix,
                                  train=True, adv_arch=args.adv_arch, need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
            result_json = learner.test_zero_shot_with_finetune_trainset()
            report_result[dataset][extract_key] = result_json
        else:
            load_mode = args.load_task_mode
            learner = MetaLearner(dataset, num_classes, meta_batch_size, meta_lr, inner_lr,
                                  args.lr_decay_itr, epoch, args.test_num_updates, load_mode,
                                  split_protocol, arch, args.tot_num_tasks, num_support,
                                  15,  # num_query is fixed to 15 across experiments
                                  True, param_prefix, train=False,
                                  adv_arch=args.adv_arch, need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
            result_json = finetune_eval_task_accuracy(learner.network, learner.val_loader,
                                                      inner_lr, args.test_num_updates, update_BN=True)
            report_result[dataset][extract_key] = result_json
    file_name = "{}/train_pytorch_model/{}/{}_result.json".format(PY_ROOT, args.study_subject, args.study_subject)
    if args.study_subject == "cross_domain":
        file_name = "{}/train_pytorch_model/{}/{}--{}_result.json".format(
            PY_ROOT, args.study_subject, args.cross_domain_source, args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(report_result))
        file_obj.flush()
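# All of the meta_* evaluation entry points above read their configuration from an `args`
# namespace. A hypothetical argparse setup covering the attributes they access is sketched
# below; the flag names and defaults are assumptions, only the attribute names come from
# the code above.
def build_arg_parser_sketch():
    import argparse
    parser = argparse.ArgumentParser(description="MAML detector evaluation (illustrative sketch)")
    parser.add_argument("--study_subject", type=str, required=True)      # e.g. shots_ablation_study
    parser.add_argument("--split_protocol", type=str, default="TRAIN_I_TEST_II")
    parser.add_argument("--load_task_mode", type=str, default="LOAD")
    parser.add_argument("--tot_num_tasks", type=int, default=20000)
    parser.add_argument("--lr_decay_itr", type=int, default=0)
    parser.add_argument("--test_num_updates", type=int, default=50)
    parser.add_argument("--adv_arch", type=str, default="conv3")
    parser.add_argument("--cross_domain_source", type=str, default=None)
    parser.add_argument("--cross_domain_target", type=str, default=None)
    return parser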
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch, target_arch, updateBN):
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch']))
        old_num_update = num_update
        for shot in [0, 1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query, dataset,
                                                is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=protocol, no_random_way=True,
                                                adv_arch=target_arch, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=updateBN)
            if num_update == 0:
                shot = 0
            result["{}-->{}@{}_{}".format(src_arch, target_arch, dataset, balance)][shot] = evaluate_result
    return result