Example #1
0
    def __init__(self,
                 dataset,
                 num_classes,
                 meta_batch_size,
                 meta_step_size,
                 inner_step_size, lr_decay_itr,
                 epoch,
                 num_inner_updates, load_task_mode, protocol, arch,
                 tot_num_tasks, num_support, num_query, no_random_way,
                 tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
        """Set up the meta network, data loaders, inner-loop learner and meta optimizer.

        Args:
            dataset: Dataset name; key into ``IN_CHANNELS`` / ``IMAGE_SIZE`` and
                forwarded to ``MetaTaskDataset``.
            num_classes: Output size of the network; also forwarded to
                ``MetaTaskDataset``.
            meta_batch_size: Number of tasks per meta batch (training
                ``DataLoader`` batch size); also handed to ``InnerLoop``.
            meta_step_size: Learning rate of the outer ``Adam`` optimizer.
            inner_step_size: Step size handed to ``InnerLoop``.
            lr_decay_itr: Stored on the instance; not used in this method.
            epoch: Stored on the instance; not used in this method.
            num_inner_updates: Update count handed to ``InnerLoop``; also stored
                as ``test_finetune_updates``.
            load_task_mode: ``load_mode`` forwarded to ``MetaTaskDataset``.
            protocol: Split protocol forwarded to ``MetaTaskDataset``.
            arch: ``"conv3"``, ``"resnet10"`` or ``"resnet18"``.
            tot_num_tasks: Forwarded to ``MetaTaskDataset``.
            num_support: Forwarded to ``MetaTaskDataset``.
            num_query: Forwarded to the training ``MetaTaskDataset`` only; the
                validation set always uses 15.
            no_random_way: Forwarded to the training ``MetaTaskDataset`` only; the
                validation set always uses ``True``.
            tensorboard_data_prefix: Second argument of ``TensorBoardWriter``.
            train: When true, build ``self.train_loader`` and ``self.tensorboard``.
            adv_arch: Forwarded to ``MetaTaskDataset``.
            need_val: When true, build ``self.val_loader``.

        Raises:
            ValueError: If ``arch`` is not one of the supported architectures.
        """
        # Zero-argument super(): super(self.__class__, self) recurses forever as
        # soon as this class is subclassed, because self.__class__ is then the
        # subclass rather than the class that defines this method.
        super().__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates

        # Make the nets
        in_channels = IN_CHANNELS[self.dataset]
        image_size = IMAGE_SIZE[self.dataset]
        if arch == "conv3":
            network = Conv3(in_channels, image_size, num_classes)
        elif arch == "resnet10":
            network = MetaNetwork(resnet10(num_classes, in_channels=in_channels, pretrained=False),
                                  in_channels, image_size)
        elif arch == "resnet18":
            network = MetaNetwork(resnet18(num_classes, in_channels=in_channels, pretrained=False),
                                  in_channels, image_size)
        else:
            # Previously an unknown arch left `network` unbound and failed on
            # the next line with an unrelated-looking error.
            raise ValueError(
                "unsupported arch {!r}; expected 'conv3', 'resnet10' or 'resnet18'".format(arch))

        self.network = network
        self.network.cuda()
        if train:
            trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                          dataset, is_train=True, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
            # task number per mini-batch is controlled by DataLoader
            self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True, num_workers=4, pin_memory=True)
            # Create the log directory before handing it to the writer.
            os.makedirs("{0}/pytorch_MAML_tensorboard".format(PY_ROOT), exist_ok=True)
            self.tensorboard = TensorBoardWriter("{0}/pytorch_MAML_tensorboard".format(PY_ROOT),
                                                 tensorboard_data_prefix)
        if need_val:
            val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                          dataset, is_train=False, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
            # Fixed batches of 100 tasks; accuracy is measured per task.
            self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)
        # Original author's note: InnerLoop executes the tasks in parallel.
        self.fast_net = InnerLoop(self.network, self.num_inner_updates,
                                  self.inner_step_size, self.meta_batch_size)
        self.fast_net.cuda()
        self.opt = Adam(self.network.parameters(), lr=meta_step_size)
Example #2
0
def evaluate_speed(args):
    """Run ``speed_test`` on one CIFAR-10 rotation-detector checkpoint for 0/1/5 shots.

    Checkpoints are discovered under
    ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model``. Only the first
    file that parses, belongs to CIFAR-10 and matches ``args.protocol`` is
    evaluated. The per-shot results are written to ``speed_test.json`` in the
    same directory.

    Args:
        args: Parsed options; reads ``args.gpus``, ``args.protocol`` and
            ``args.num_updates``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Example file name:
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    # Raw string so that \d and \. reach the regex engine without relying on
    # Python's deprecated handling of unknown string escapes.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # The glob is broader than the regex; do not crash on other files.
            print("skip unrecognized model file :{}".format(os.path.basename(model_path)))
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        lr = float(ma.group(5))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))

        for report_shot in [0, 1, 5]:
            if report_shot == 0:
                # 0-shot: load 1-shot tasks but perform no fine-tune updates.
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates
            # FIXME: adversarial data for other architectures is not supported
            # yet, so adv_arch is hard-coded to "conv3".
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True, adv_arch="conv3")
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            # A fresh model is built and reloaded for every shot setting.
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = speed_test(model, data_loader, lr, num_updates, update_BN=False)
            result[dataset][report_shot] = evaluate_result
        # Only the first checkpoint that passes the filters is timed.
        break

    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
def evaluate_cross_arch(args):
    """Evaluate rotation-detector checkpoints on adversarial data from another architecture.

    Every checkpoint under
    ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` whose file name
    parses, matches ``args.protocol`` and whose ``data_`` field equals
    ``args.cross_arch_source`` is evaluated with
    ``finetune_eval_task_accuracy`` on tasks built with
    ``adv_arch=args.cross_arch_target`` for 0, 1 and 5 shots. Results are
    written as JSON next to the checkpoints.

    Args:
        args: Parsed options; reads ``args.gpus``, ``args.protocol``,
            ``args.cross_arch_source``, ``args.cross_arch_target``,
            ``args.load_mode`` and ``args.num_updates``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Example file name:
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar
    # The optional non-capturing group accepts a trailing "@<tag>" such as the
    # "@no_fix_cnn_params" in the example above; without it that name never
    # matched. Names without a tag still match and group numbers are unchanged.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)(?:@.*?)?\.pth.tar")
    result = defaultdict(dict)
    update_BN = False
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # The glob is broader than the regex; do not crash on other files.
            print("skip unrecognized model file :{}".format(os.path.basename(model_path)))
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue

        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        result_key = "{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)
        for report_shot in [0, 1, 5]:
            if report_shot == 0:
                # 0-shot: load 1-shot tasks but perform no fine-tune updates.
                shot, num_update = 1, 0
            else:
                shot, num_update = report_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            # A fresh model is built and reloaded for every shot setting.
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            # FIXME (original author): results are very high with update_BN=False.
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)
            # Key by the requested shot. The old code keyed by 0 whenever
            # num_update == 0, so with args.num_updates == 0 the 1- and 5-shot
            # results overwrote the 0-shot entry.
            result[result_key][report_shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(PY_ROOT, args.cross_arch_source,
                                                                            args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
Example #4
0
def evaluate_finetune(model_path_list, lr, protocol):
    """Evaluate CIFAR-10 ``DL_DET`` checkpoints on 1-shot tasks for 0..50 fine-tune updates.

    Per the original author's note, the models were trained on the
    "all in" / "sampled all in" data, while testing has to run on the task
    version of the dataset.

    Args:
        model_path_list: Iterable of checkpoint paths. Paths that do not parse,
            are not CIFAR-10, or whose protocol field differs from
            ``str(protocol)`` are skipped.
        lr: Learning rate passed to ``finetune_eval_task_accuracy``.
        protocol: Must equal ``SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II``.

    Returns:
        ``defaultdict(dict)`` mapping ``"<dataset>_<balance|no_balance>"`` to
        ``{num_update: evaluate_result}``; each ``evaluate_result`` gets an
        extra ``"shot"`` entry (0 when ``num_update`` is 0, otherwise 1).
    """
    import torch  # kept local, as in the original; imported once instead of per checkpoint

    # Raw string so that \d and \. reach the regex engine unchanged.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    support_shot = 1
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Do not crash on paths that are not DL_DET checkpoints.
            print("skip unrecognized model file :{}".format(os.path.basename(model_path)))
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, support_shot, query,
                                            dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        result_key = "{}_{}".format(dataset, balance)
        for num_update in range(0, 51):
            # NOTE(review): the same model object is reused for every
            # num_update; confirm finetune_eval_task_accuracy does not leave
            # it modified between calls.
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            # Zero updates is reported as the 0-shot setting.
            evaluate_result["shot"] = 0 if num_update == 0 else 1
            result[result_key][num_update] = evaluate_result
    return result
Example #5
0
def main():
    """Evaluate every ``DL_IMAGE_CLASSIFIER`` checkpoint on every matching test task file.

    For each checkpoint under ``PY_ROOT/train_pytorch_model`` the model named
    in the file name is rebuilt and loaded, then
    ``DetectionEvaluator.evaluate_detections`` is run on each
    ``PY_ROOT/task/<split_data_protocol>/test_<dataset>*.pkl`` task dump. The
    accuracies are written as JSON to ``args.output_path``, keyed by
    ``"<checkpoint name>|<task file name>"`` with the last extension stripped
    from each name.

    Files whose names do not parse, and checkpoints with an unsupported
    architecture, are skipped with a message.
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Example: DL_IMAGE_CLASSIFIER_CIFAR-10@conv4@epoch_40@lr_0.0001@batch_500.pth.tar
    # Raw strings so that \d and \. reach the regex engine unchanged.
    extract_pattern_detail = re.compile(
        r".*?DL_IMAGE_CLASSIFIER_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar"
    )
    # Example: test_CIFAR-10_tot_num_tasks_20000_metabatch_10_way_5_shot_5_query_15.txt
    extract_dataset_pattern = re.compile(
        r".*?tot_num_tasks_(\d+)_metabatch_(\d+)_way_(\d+)_shot_(\d+)_query_(\d+).*"
    )
    result = {}
    for model_path in glob.glob(
            "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized model file {}".format(model_path))
            continue
        dataset = ma.group(1)
        arch = ma.group(2)
        if arch == "conv4":
            model = FourConvs(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                              CLASS_NUM[dataset])
        elif arch == "resnet10":
            model = resnet10(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        elif arch == "resnet18":
            model = resnet18(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        else:
            # Previously an unknown arch reused the model left over from the
            # preceding checkpoint (or raised NameError on the first one).
            print("skip {}: unsupported arch {}".format(model_path, arch))
            continue
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint["state_dict"])
        model.cuda()
        print("loading {}".format(model_path))
        detector = DetectionEvaluator(model, dataset)
        model_key = os.path.splitext(os.path.basename(model_path))[0]
        for pkl_task_file_name in glob.glob("{}/task/{}/test_{}*.pkl".format(
                PY_ROOT, args.split_data_protocol, dataset)):
            ma_d = extract_dataset_pattern.match(pkl_task_file_name)
            if ma_d is None:
                print("skip unrecognized task file {}".format(pkl_task_file_name))
                continue
            tot_num_tasks = int(ma_d.group(1))
            num_classes = int(ma_d.group(3))
            num_support = int(ma_d.group(4))
            num_query = int(ma_d.group(5))
            meta_dataset = MetaTaskDataset(
                tot_num_tasks,
                num_classes,
                num_support,
                num_query,
                dataset,
                is_train=False,
                load_mode=LOAD_TASK_MODE.LOAD,
                pkl_task_dump_path=pkl_task_file_name,
                protocol=args.split_data_protocol)
            val_loader = DataLoader(meta_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False)

            # The evaluator is given an empty list of training images. The
            # training dataset that used to be constructed here for every task
            # file was never read, so it is no longer built.
            train_imgs = []
            accuracy = detector.evaluate_detections(train_imgs, val_loader)
            task_key = os.path.splitext(os.path.basename(pkl_task_file_name))[0]
            result["{}|{}".format(model_key, task_key)] = accuracy
    with open(args.output_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
Example #6
0
def _nf_load_datasets(ds_name, transform):
    """Return ``(train_dataset, val_dataset)`` for ``ds_name``; raise ValueError if unknown."""
    root = IMAGE_DATA_ROOT[ds_name]
    if ds_name == "MNIST":
        make = lambda is_train: datasets.MNIST(root, train=is_train, download=False, transform=transform)
    elif ds_name == "F-MNIST":
        make = lambda is_train: datasets.FashionMNIST(root, train=is_train, download=False, transform=transform)
    elif ds_name == "CIFAR-10":
        make = lambda is_train: datasets.CIFAR10(root, train=is_train, download=False, transform=transform)
    elif ds_name == "SVHN":
        make = lambda is_train: SVHN(root, train=is_train, transform=transform)
    elif ds_name == "ImageNet":
        make = lambda is_train: ImageNetRealDataset(root + "/new2", train=is_train, transform=transform)
    else:
        raise ValueError("unsupported dataset {!r}".format(ds_name))
    return make(True), make(False)


def _nf_build_network(arch, ds_name):
    """Build the classifier backbone named ``arch`` for ``ds_name``; raise ValueError if unknown."""
    if arch == "conv3":
        return Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name], CLASS_NUM[ds_name])
    if arch == "resnet10":
        return resnet10(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
    if arch == "resnet18":
        return resnet18(in_channels=IN_CHANNELS[ds_name], num_classes=CLASS_NUM[ds_name])
    raise ValueError(
        "unsupported arch {!r}; expected 'conv3', 'resnet10' or 'resnet18'".format(arch))


def _nf_eval_shots(detector, args, task_ds_name, adv_arch, fetch_attack_name,
                   reject_thresholds, eps, num_dx, loader_kwargs):
    """Run the detector on 0/1/5-shot tasks and return ``{shot: record}``.

    The 0-shot setting loads 1-shot tasks and uses zero fine-tune updates;
    the other settings use ``args.num_updates``. ``args`` is not modified.
    """
    records = {}
    for report_shot in [0, 1, 5]:
        if report_shot == 0:
            shot, num_updates = 1, 0
        else:
            shot, num_updates = report_shot, args.num_updates
        val_dataset = MetaTaskDataset(
            20000,
            2,   # num_way
            shot,
            15,  # num_query
            task_ds_name,
            is_train=False,
            load_mode=args.load_task_mode,
            protocol=args.protocol,
            no_random_way=True,
            adv_arch=adv_arch,
            fetch_attack_name=fetch_attack_name)
        adv_val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=100, shuffle=False, **loader_kwargs)
        F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
            adv_val_loader, task_ds_name, reject_thresholds, num_updates, args.lr)
        records[report_shot] = {
            "F1": F1,
            "best_tau": tau,
            "eps": eps,
            "num_dx": num_dx,
            "num_updates": num_updates,
            "attack_stats": attacker_stats
        }
        print("shot {} done".format(report_shot))
    return records


def _nf_train(args, transform, loader_kwargs):
    """Train (or resume training of) the NeuralFP model, saving a checkpoint after every epoch."""
    trn_dataset, val_dataset = _nf_load_datasets(args.ds_name, transform)
    train_loader = torch.utils.data.DataLoader(trn_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              **loader_kwargs)

    network = _nf_build_network(args.arch, args.ds_name)
    network.cuda()
    model_path = os.path.join(
        PY_ROOT, "train_pytorch_model/NF_Det",
        "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar"
        .format(args.ds_name, args.arch, args.epochs, args.lr, args.eps,
                args.num_dx, CLASS_NUM[args.ds_name]))
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    detector = NeuralFingerprintDetector(
        args.ds_name,
        network,
        args.num_dx,
        CLASS_NUM[args.ds_name],
        eps=args.eps,
        out_fp_dxdy_dir=args.output_dx_dy_dir)

    optimizer = optim.SGD(network.parameters(),
                          lr=args.lr,
                          weight_decay=1e-6,
                          momentum=0.9)
    resume_epoch = 0
    print("{}".format(model_path))
    if os.path.exists(model_path):
        # Resume from the checkpoint written by a previous run.
        checkpoint = torch.load(model_path,
                                lambda storage, location: storage)
        optimizer.load_state_dict(checkpoint["optimizer"])
        resume_epoch = checkpoint["epoch"]
        network.load_state_dict(checkpoint["state_dict"])

    for epoch in range(resume_epoch, args.epochs + 1):
        if epoch == 1:
            detector.test(epoch,
                          test_loader,
                          test_length=0.1 * len(val_dataset))
        detector.train(epoch, optimizer, train_loader)

        print("Epoch{}, Saving model in {}".format(epoch, model_path))
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_path)


def _nf_evaluate(args, loader_kwargs):
    """Evaluate saved NeuralFP checkpoints for ``args.study_subject`` and dump the results as JSON."""
    # Raw string so that \d reaches the regex engine unchanged.
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+).pth.tar"
    )
    subject = args.study_subject
    results = defaultdict(dict)
    for model_path in glob.glob(
            "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        if ma is None:
            print("skip unrecognized model file {}".format(model_path))
            continue
        ds_name = ma.group(1)
        # Filter by dataset before paying for the checkpoint load.
        if subject == "cross_domain" and ds_name != args.cross_domain_source:
            continue
        if subject == "finetune_eval" and ds_name != args.ds_name:
            continue

        arch = ma.group(2)
        num_dx = int(ma.group(6))
        eps = float(ma.group(5))
        network = _nf_build_network(arch, ds_name)
        reject_thresholds = [0. + 0.001 * i for i in range(2050)]
        network.load_state_dict(
            torch.load(model_path,
                       lambda storage, location: storage)["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        detector = NeuralFingerprintDetector(
            ds_name,
            network,
            num_dx,
            CLASS_NUM[ds_name],
            eps=eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)
        # Original author's note: there is no notion of "cross arch" for the
        # detector model itself here.
        if subject == "shots":
            results[ds_name].update(_nf_eval_shots(
                detector, args, ds_name, args.adv_arch, True,
                reject_thresholds, eps, num_dx, loader_kwargs))

        elif subject == "cross_domain":
            source_dataset, target_dataset = args.cross_domain_source, args.cross_domain_target
            key = "{}--{}@data_adv_arch_{}".format(
                source_dataset, target_dataset, args.adv_arch)
            results[key].update(_nf_eval_shots(
                detector, args, target_dataset, args.adv_arch, False,
                reject_thresholds, eps, num_dx, loader_kwargs))

        elif subject == "cross_arch":
            target_arch = args.cross_arch_target
            key = "{}_target_arch_{}".format(ds_name, target_arch)
            results[key].update(_nf_eval_shots(
                detector, args, ds_name, target_arch, False,
                reject_thresholds, eps, num_dx, loader_kwargs))

        elif subject == "finetune_eval":
            val_dataset = MetaTaskDataset(20000,
                                          2,   # num_way
                                          1,   # shot
                                          15,  # num_query
                                          ds_name,
                                          is_train=False,
                                          load_mode=args.load_task_mode,
                                          protocol=args.protocol,
                                          no_random_way=True,
                                          adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                         batch_size=100,
                                                         shuffle=False,
                                                         **loader_kwargs)
            # Kept from the original: this value ends up in the output file name.
            args.num_updates = 50
            for num_update in range(0, 51):
                F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                    adv_val_loader, ds_name, reject_thresholds, num_update,
                    args.lr)
                results[ds_name][num_update] = {
                    "F1": F1,
                    "best_tau": tau,
                    "eps": eps,
                    "num_dx": num_dx,
                    "num_updates": num_update,
                    "attack_stats": attacker_stats
                }
                print("finetune {} done".format(num_update))

    if not args.profile:
        if subject == "cross_domain":
            filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
                PY_ROOT, args.cross_domain_source,
                args.cross_domain_target, args.adv_arch)
        elif subject == "cross_arch":
            filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
                PY_ROOT, args.cross_arch_target)
        else:
            filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
                PY_ROOT, args.study_subject, args.adv_arch, args.protocol,
                args.lr, args.num_updates)
        with open(filename, "w") as file_obj:
            file_obj.write(json.dumps(results))


def main():
    """Entry point: train a NeuralFP detector, or evaluate saved ones.

    Without ``args.evaluate`` the model for ``args.ds_name`` / ``args.arch``
    is trained and checkpointed under ``PY_ROOT/train_pytorch_model/NF_Det``.
    With ``args.evaluate`` the study selected by ``args.study_subject``
    (``speed_test``, ``shots``, ``cross_domain``, ``cross_arch`` or
    ``finetune_eval``) is run over the saved checkpoints.

    Raises:
        ValueError: If the dataset or architecture name is not supported.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # Original author's note: the NeuralFP MNIST and F-MNIST experiments need
    # to be redone because a single-channel bug was found.
    transform = get_preprocessor(
        IN_CHANNELS[args.ds_name],
        IMAGE_SIZE[args.ds_name])
    kwargs = {'pin_memory': True}
    if not args.evaluate:  # training mode
        _nf_train(args, transform, kwargs)
    else:  # test mode
        if args.study_subject == "speed_test":
            evaluate_speed(args)
            return
        _nf_evaluate(args, kwargs)
Example #7
0
def evaluate_speed(args):
    """Time the neural-fingerprint detector on CIFAR-10 and dump the result as JSON.

    Globs the ``NF_Det@*`` files under ``PY_ROOT/train_pytorch_model/NF_Det``,
    picks the first one whose file name parses as a CIFAR-10 checkpoint, and
    calls ``detector.test_speed`` for the 0-, 1- and 5-shot settings. The two
    values returned by ``test_speed`` are stored as ``mean_time`` and
    ``var_time`` under ``results[ds_name][shot]`` and written to
    ``speed_test_result.json`` in the same directory. The file is written even
    when no usable checkpoint was found (it then contains an empty object).

    Args:
        args: Parsed options. Reads ``output_dx_dy_dir``, ``num_updates``,
            ``load_task_mode``, ``protocol``, ``adv_arch`` and ``lr``.
    """
    # Raw string so that ``\d`` / ``\.`` reach the regex engine untouched; the
    # dots of the ".pth.tar" suffix are escaped to match them literally.
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar"
    )
    # Loop-invariant settings.
    reject_thresholds = [0.001 * i for i in range(2050)]
    num_way = 2
    num_query = 15

    results = defaultdict(dict)
    for model_path in glob.glob(
            "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        if ma is None:
            # The glob is broader than the pattern; ignore unrelated files
            # instead of failing on ``None.group``.
            continue
        ds_name = ma.group(1)
        if ds_name != "CIFAR-10":
            continue
        eps = float(ma.group(5))
        num_dx = int(ma.group(6))

        network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name],
                        CLASS_NUM[ds_name])
        # Deserialize the storages as-is first; the model is moved to the GPU
        # by the explicit ``cuda()`` call below.
        checkpoint = torch.load(
            model_path, map_location=lambda storage, location: storage)
        network.load_state_dict(checkpoint["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        detector = NeuralFingerprintDetector(
            ds_name,
            network,
            num_dx,
            CLASS_NUM[ds_name],
            eps=eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)

        # There is no notion of cross-arch in this experiment.
        for report_shot in (0, 1, 5):
            # "0-shot" is run on the 1-shot tasks with zero fine-tune updates.
            if report_shot == 0:
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates

            val_dataset = MetaTaskDataset(20000,
                                          num_way,
                                          shot,
                                          num_query,
                                          ds_name,
                                          is_train=False,
                                          load_mode=args.load_task_mode,
                                          protocol=args.protocol,
                                          no_random_way=True,
                                          adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                         batch_size=100,
                                                         shuffle=False)
            mean_time, var_time = detector.test_speed(
                adv_val_loader,
                ds_name,
                reject_thresholds,
                num_updates,
                args.lr,
            )
            results[ds_name][report_shot] = {
                "mean_time": mean_time,
                "var_time": var_time
            }
            # Log the reported setting, not the internally substituted shot.
            print("shot {} done".format(report_shot))
        # Only the first CIFAR-10 checkpoint is timed.
        break

    file_name = "{}/train_pytorch_model/NF_Det/speed_test_result.json".format(
        PY_ROOT)
    # The context manager flushes and closes the file on exit.
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(results))
Example #8
0
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    """Evaluate ``DL_DET`` checkpoints on 1-shot tasks with zero fine-tune updates.

    Each path is parsed with the ``DL_DET@...`` file-name pattern. Paths that
    do not match the pattern, whose protocol differs from ``protocol``, or
    (when ``args.cross_domain_source`` is set) whose dataset is not the
    cross-domain source are skipped. For the remaining checkpoints a ``Conv3``
    model is loaded and passed to ``finetune_eval_task_accuracy`` with
    ``num_update = 0`` and ``update_BN=False``.

    Args:
        model_path_list: Iterable of checkpoint file paths.
        lr: Learning rate forwarded to ``finetune_eval_task_accuracy``.
        protocol: Protocol to evaluate; compared via ``str(protocol)`` with the
            protocol encoded in the file name.
        args: Parsed options. Reads ``cross_domain_source`` and
            ``cross_domain_target``.

    Returns:
        ``defaultdict(dict)`` mapping a descriptive key to
        ``{0: evaluate_result}``.
    """
    # The deep-learning model is trained on the all_in (or sampled all_in)
    # data, but testing has to be done on the task-based version of the
    # dataset.
    # Raw string: ``\d`` and ``\.`` are regex escapes, not string escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    shot = 1
    num_update = 0  # zero-shot: evaluate without any fine-tune step
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a DL_DET checkpoint name; skip instead of failing on
            # ``None.group``.
            print("skip {}: file name does not match the DL_DET pattern".format(
                model_path))
            continue
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"

        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            # Cross-domain: the model trained on the source dataset is
            # evaluated on tasks from the target dataset.
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance,
                                          dataset, protocol)

        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()

        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                            way,
                                            shot,
                                            query,
                                            dataset,
                                            is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol,
                                            no_random_way=True,
                                            adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model,
                                                      data_loader,
                                                      lr,
                                                      num_update,
                                                      update_BN=False)
        result[key][0] = evaluate_result
    return result
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch,
                        target_arch, updateBN):
    """Evaluate ``DL_DET`` checkpoints of ``src_arch`` on ``target_arch`` tasks.

    Each path is parsed with the ``DL_DET@...`` file-name pattern. Paths that
    do not match, whose protocol differs from ``protocol``, or whose ``data_``
    field differs from ``src_arch`` are skipped. Every remaining checkpoint is
    loaded into a ``Conv3`` model and evaluated with
    ``finetune_eval_task_accuracy`` for three settings: 0-shot (1-shot tasks
    with zero updates), 1-shot and 5-shot (both with ``num_update`` updates).

    Args:
        model_path_list: Iterable of checkpoint file paths.
        num_update: Fine-tune update count used for the 1- and 5-shot runs.
        lr: Learning rate forwarded to ``finetune_eval_task_accuracy``.
        protocol: Protocol to evaluate; compared via ``str(protocol)`` with the
            protocol encoded in the file name.
        src_arch: Value the checkpoint's ``data_`` field must equal.
        target_arch: ``adv_arch`` passed to ``MetaTaskDataset``.
        updateBN: Forwarded as ``update_BN``.

    Returns:
        ``defaultdict(dict)`` mapping
        ``"{src_arch}-->{target_arch}@{dataset}_{balance}"`` to
        ``{0: ..., 1: ..., 5: ...}``.
    """
    # Raw string: ``\d`` and ``\.`` are regex escapes, not string escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a DL_DET checkpoint name; skip instead of failing on
            # ``None.group``.
            print("skip {}: file name does not match the DL_DET pattern".format(
                model_path))
            continue
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue

        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))

        result_key = "{}-->{}@{}_{}".format(src_arch, target_arch, dataset,
                                            balance)
        for report_shot in (0, 1, 5):
            # "0-shot" is run on the 1-shot tasks with zero fine-tune updates.
            # The reported shot is tracked separately so that a caller-supplied
            # num_update of 0 cannot collapse the 1- and 5-shot results onto
            # key 0.
            if report_shot == 0:
                shot, finetune_updates = 1, 0
            else:
                shot, finetune_updates = report_shot, num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                way,
                                                shot,
                                                query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=protocol,
                                                no_random_way=True,
                                                adv_arch=target_arch,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          finetune_updates,
                                                          update_BN=updateBN)
            result[result_key][report_shot] = evaluate_result
    return result