Ejemplo n.º 1
0
    def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size,
                 inner_step_size, lr_decay_itr, epoch, num_inner_updates,
                 load_task_mode, arch, tot_num_tasks, num_support, detector,
                 attack_name, root_folder):
        """
        Build the meta-learner: backbone network (moved to GPU) plus a
        validation DataLoader over meta-tasks.

        :param dataset: dataset name, key into IN_CHANNELS / IMAGE_SIZE
        :param num_classes: number of classes (ways) per task
        :param meta_batch_size: number of tasks per meta-batch
        :param meta_step_size: outer-loop (meta) learning rate
        :param inner_step_size: inner-loop learning rate
        :param lr_decay_itr: iteration interval for learning-rate decay
        :param epoch: number of training epochs
        :param num_inner_updates: inner-loop update steps; also reused as the
            number of test-time fine-tune updates
        :param load_task_mode: task loading mode forwarded to MetaTaskDataset
        :param arch: backbone architecture: "conv3", "resnet10" or "resnet18"
        :param tot_num_tasks: total number of tasks in the validation pool
        :param num_support: support-set size per task
        :param detector: detector identifier forwarded to MetaTaskDataset
        :param attack_name: adversarial attack name forwarded to MetaTaskDataset
        :param root_folder: root folder holding the task data
        :raises ValueError: if ``arch`` is not a supported architecture
        """
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets.  Fail fast on an unknown architecture instead of
        # letting an unbound `network` raise a confusing NameError below.
        if arch == "conv3":
            network = Conv3(IN_CHANNELS[self.dataset],
                            IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            network = resnet10(num_classes, pretrained=False)
        elif arch == "resnet18":
            network = resnet18(num_classes, pretrained=False)
        else:
            raise ValueError("Unsupported arch: {}".format(arch))
        self.network = network
        self.network.cuda()

        # 15 = query-set size per task.
        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support,
                                      15, dataset, load_task_mode, detector,
                                      attack_name, root_folder)
        # Fixed batch of 100 tasks; accuracy is measured per task.
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=0,
            pin_memory=True)
Ejemplo n.º 2
0
def build_network(dataset, arch, model_path):
    """
    Construct an image-classifier network for ``dataset`` and, when
    ``model_path`` exists on disk, load its weights from that checkpoint.

    :param dataset: dataset name, key into CLASS_NUM / IN_CHANNELS / IMAGE_SIZE
    :param arch: a torchvision model name, or one of "resnet10",
        "resnet18", "conv3"
    :param model_path: path to a ``.pth.tar`` checkpoint; optional — when it
        does not exist the network keeps its randomly initialized weights
    :return: the constructed (and possibly checkpoint-loaded) network
    :raises ValueError: if ``arch`` is neither a torchvision model nor one of
        the locally defined architectures
    """
    if arch in models.__dict__:
        # Torchvision architecture; weights come from the checkpoint below,
        # so no ImageNet pretraining is requested here.
        print("=> using pre-trained model '{}'".format(arch))
        img_classifier_network = models.__dict__[arch](pretrained=False)
    else:
        print("=> creating model '{}'".format(arch))
        if arch == "resnet10":
            img_classifier_network = resnet10(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "resnet18":
            img_classifier_network = resnet18(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "conv3":
            img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                           IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
        else:
            # Without this, img_classifier_network would be unbound and the
            # code below would die with a NameError instead of a clear message.
            raise ValueError("Unsupported arch: {}".format(arch))
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # map_location keeps tensors on CPU regardless of the device they
        # were saved from.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        img_classifier_network.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
    return img_classifier_network
Ejemplo n.º 3
0
    def __init__(self,
                 dataset,
                 num_classes,
                 meta_batch_size,
                 meta_step_size,
                 inner_step_size, lr_decay_itr,
                 epoch,
                 num_inner_updates, load_task_mode, protocol, arch,
                 tot_num_tasks, num_support, num_query, no_random_way,
                 tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
        """
        Build the MAML trainer: backbone network (on GPU), optional train and
        validation DataLoaders over meta-tasks, the inner-loop worker, and the
        outer-loop Adam optimizer.

        :param dataset: dataset name, key into IN_CHANNELS / IMAGE_SIZE
        :param num_classes: number of classes (ways) per task
        :param meta_batch_size: number of tasks per meta-batch
        :param meta_step_size: outer-loop (meta) learning rate
        :param inner_step_size: inner-loop learning rate
        :param lr_decay_itr: iteration interval for learning-rate decay
        :param epoch: number of training epochs
        :param num_inner_updates: inner-loop update steps; also reused as the
            number of test-time fine-tune updates
        :param load_task_mode: task loading mode forwarded to MetaTaskDataset
        :param protocol: data-split protocol forwarded to MetaTaskDataset
        :param arch: backbone architecture: "conv3", "resnet10" or "resnet18"
        :param tot_num_tasks: total number of tasks in the task pool
        :param num_support: support-set size per task
        :param num_query: query-set size per task (training only; validation
            uses a fixed 15)
        :param no_random_way: whether class/way assignment is fixed per task
        :param tensorboard_data_prefix: prefix for TensorBoard event files
        :param train: build the training loader and TensorBoard writer
        :param adv_arch: architecture the adversarial examples were crafted on
        :param need_val: build the validation loader
        :raises ValueError: if ``arch`` is not a supported architecture
        """
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets.  Fail fast on an unknown architecture instead of
        # letting an unbound `network` raise a confusing NameError below.
        if arch == "conv3":
            network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            network = MetaNetwork(resnet10(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
        elif arch == "resnet18":
            network = MetaNetwork(resnet18(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
        else:
            raise ValueError("Unsupported arch: {}".format(arch))

        self.network = network
        self.network.cuda()
        if train:
            trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                          dataset, is_train=True, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
            # Task number per mini-batch is controlled by DataLoader.
            self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True, num_workers=4, pin_memory=True)
            # Ensure the output directory exists *before* the writer opens
            # event files in it (the original created the writer first).
            os.makedirs("{0}/pytorch_MAML_tensorboard".format(PY_ROOT), exist_ok=True)
            self.tensorboard = TensorBoardWriter("{0}/pytorch_MAML_tensorboard".format(PY_ROOT),
                                                 tensorboard_data_prefix)
        if need_val:
            val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                          dataset, is_train=False, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
            # Fixed batch of 100 tasks; accuracy is measured per task.
            self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)
        # Inner-loop worker: runs the per-task adaptation in parallel.
        self.fast_net = InnerLoop(self.network, self.num_inner_updates,
                                  self.inner_step_size, self.meta_batch_size)
        self.fast_net.cuda()
        self.opt = Adam(self.network.parameters(), lr=meta_step_size)
Ejemplo n.º 4
0
def evaluate_shots(model_path_list, num_update, lr, protocol):
    """
    Evaluate deep-learning detector checkpoints at several shot counts.

    The deep-learning models were trained in the all-in (or sampled all-in)
    regime, but testing must be performed on the task-based version of the
    dataset.  For each checkpoint whose filename matches the expected pattern
    and protocol, the model is fine-tuned/evaluated at 0, 1 and 5 shots
    (0 shot = 1-shot data with zero fine-tune updates).

    :param model_path_list: checkpoint paths whose filenames encode
        dataset/protocol/arch/adv-arch/balance metadata
    :param num_update: number of fine-tune updates for the non-zero-shot runs
    :param lr: fine-tune learning rate
    :param protocol: must be SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II
    :return: dict mapping "dataset@balance@adv_arch" -> {shot: eval result}
    :raises ValueError: if a checkpoint encodes an unsupported architecture
    """
    # Raw string: the pattern contains \d escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(
        protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"

        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        arch = ma.group(3)
        adv_arch = ma.group(4)

        # Binary (2-way) detection head regardless of the source dataset.
        if arch == "conv3":
            model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        elif arch == "resnet10":
            model = resnet10(2,
                             in_channels=IN_CHANNELS[dataset],
                             pretrained=False)
        elif arch == "resnet18":
            model = resnet18(2,
                             in_channels=IN_CHANNELS[dataset],
                             pretrained=False)
        else:
            raise ValueError("Unsupported arch: {}".format(arch))
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        old_num_update = num_update
        for shot in [0, 1, 5]:
            # "0 shot" is run as 1-shot data with zero fine-tune updates.
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(
                tot_num_tasks,
                way,
                shot,
                query,
                dataset,
                is_train=False,
                load_mode=LOAD_TASK_MODE.NO_LOAD,
                protocol=protocol,
                no_random_way=True,
                adv_arch=adv_arch)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          num_update,
                                                          update_BN=False)
            if num_update == 0:
                # Restore the reported key for the zero-update run.
                shot = 0
            result["{}@{}@{}".format(dataset, balance,
                                     adv_arch)][shot] = evaluate_result
    return result
Ejemplo n.º 5
0
def main():
    """
    Evaluate detection accuracy of saved image classifiers on meta-task files.

    Walks every DL_IMAGE_CLASSIFIER checkpoint, rebuilds the classifier,
    wraps it in a DetectionEvaluator, and scores it against every matching
    test-task pickle for the same dataset.  Results (keyed by
    "checkpoint|task_file") are written as JSON to ``args.output_path``.
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # e.g. DL_IMAGE_CLASSIFIER_CIFAR-10@conv4@epoch_40@lr_0.0001@batch_500.pth.tar
    # Raw strings: the patterns contain \d escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_IMAGE_CLASSIFIER_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar"
    )
    # e.g. test_CIFAR-10_tot_num_tasks_20000_metabatch_10_way_5_shot_5_query_15.txt
    extract_dataset_pattern = re.compile(
        r".*?tot_num_tasks_(\d+)_metabatch_(\d+)_way_(\d+)_shot_(\d+)_query_(\d+).*"
    )
    result = {}
    for model_path in glob.glob(
            "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # The glob can match files that do not follow the naming scheme;
            # skip them instead of crashing on ma.group().
            continue
        dataset = ma.group(1)
        arch = ma.group(2)
        if arch == "conv4":
            model = FourConvs(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                              CLASS_NUM[dataset])
        elif arch == "resnet10":
            model = resnet10(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        elif arch == "resnet18":
            model = resnet18(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        else:
            raise ValueError("Unsupported arch: {}".format(arch))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint["state_dict"])
        model.cuda()
        print("loading {}".format(model_path))
        detector = DetectionEvaluator(model, dataset)
        for pkl_task_file_name in glob.glob("{}/task/{}/test_{}*.pkl".format(
                PY_ROOT, args.split_data_protocol, dataset)):
            preprocessor = get_preprocessor(
                input_size=IMAGE_SIZE[dataset],
                input_channels=IN_CHANNELS[dataset])
            # The train split is constructed (some dataset classes download
            # on first use) but its images are not fed to the evaluator below.
            if dataset == "CIFAR-10":
                train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                        train=True,
                                        transform=preprocessor)
            elif dataset == "MNIST":
                train_dataset = MNIST(IMAGE_DATA_ROOT[dataset],
                                      train=True,
                                      transform=preprocessor,
                                      download=True)
            elif dataset == "F-MNIST":
                train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset],
                                             train=True,
                                             transform=preprocessor,
                                             download=True)
            elif dataset == "SVHN":
                train_dataset = SVHN(IMAGE_DATA_ROOT[dataset],
                                     train=True,
                                     transform=preprocessor)
            ma_d = extract_dataset_pattern.match(pkl_task_file_name)
            tot_num_tasks = int(ma_d.group(1))
            num_classes = int(ma_d.group(3))
            num_support = int(ma_d.group(4))
            num_query = int(ma_d.group(5))
            meta_dataset = MetaTaskDataset(
                tot_num_tasks,
                num_classes,
                num_support,
                num_query,
                dataset,
                is_train=False,
                load_mode=LOAD_TASK_MODE.LOAD,
                pkl_task_dump_path=pkl_task_file_name,
                protocol=args.split_data_protocol)
            val_loader = DataLoader(meta_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False)

            # train_imgs = get_train_data(train_dataset)
            train_imgs = []
            accuracy = detector.evaluate_detections(train_imgs, val_loader)
            key1 = os.path.basename(model_path)
            key1 = key1[:key1.rindex(".")]
            key = os.path.basename(pkl_task_file_name)
            key = key[:key.rindex(".")]
            result["{}|{}".format(key1, key)] = accuracy
    with open(args.output_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
Ejemplo n.º 6
0
def main():
    """
    Train or evaluate a NeuralFingerprint (NF) adversarial-example detector.

    Training mode (``args.evaluate`` is False): build the dataset and
    backbone from args, train the fingerprint detector, and checkpoint after
    every epoch.  Evaluation mode: sweep every saved NF_Det checkpoint and
    run the study selected by ``args.study_subject`` ("speed_test", "shots",
    "cross_domain", "cross_arch" or "finetune_eval"), then dump the collected
    results to a JSON file.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # NOTE(review): the NeuralFP MNIST and F-MNIST experiments need to be
    # redone because a single-channel bug was found in this preprocessor.
    transform = get_preprocessor(
        IN_CHANNELS[args.ds_name],
        IMAGE_SIZE[args.ds_name])
    kwargs = {'pin_memory': True}
    if not args.evaluate:  # training mode
        if args.ds_name == "MNIST":
            trn_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name],
                                         train=True,
                                         download=False,
                                         transform=transform)
            val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name],
                                         train=False,
                                         download=False,
                                         transform=transform)
        elif args.ds_name == "F-MNIST":
            trn_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name],
                                                train=True,
                                                download=False,
                                                transform=transform)
            val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name],
                                                train=False,
                                                download=False,
                                                transform=transform)
        elif args.ds_name == "CIFAR-10":
            trn_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name],
                                           train=True,
                                           download=False,
                                           transform=transform)
            val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name],
                                           train=False,
                                           download=False,
                                           transform=transform)
        elif args.ds_name == "SVHN":
            trn_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name],
                               train=True,
                               transform=transform)
            val_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name],
                               train=False,
                               transform=transform)
        elif args.ds_name == "ImageNet":
            trn_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] +
                                              "/new2",
                                              train=True,
                                              transform=transform)
            val_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] +
                                              "/new2",
                                              train=False,
                                              transform=transform)
        else:
            # Fail fast: otherwise trn_dataset/val_dataset are unbound and
            # the DataLoader construction below raises a confusing NameError.
            raise ValueError("Unsupported dataset: {}".format(args.ds_name))

        train_loader = torch.utils.data.DataLoader(trn_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=0,
            **kwargs)

        if args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.ds_name],
                            IMAGE_SIZE[args.ds_name], CLASS_NUM[args.ds_name])
        elif args.arch == "resnet10":
            network = resnet10(in_channels=IN_CHANNELS[args.ds_name],
                               num_classes=CLASS_NUM[args.ds_name])
        elif args.arch == "resnet18":
            network = resnet18(in_channels=IN_CHANNELS[args.ds_name],
                               num_classes=CLASS_NUM[args.ds_name])
        else:
            raise ValueError("Unsupported arch: {}".format(args.arch))
        network.cuda()
        model_path = os.path.join(
            PY_ROOT, "train_pytorch_model/NF_Det",
            "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar"
            .format(args.ds_name, args.arch, args.epochs, args.lr, args.eps,
                    args.num_dx, CLASS_NUM[args.ds_name]))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        detector = NeuralFingerprintDetector(
            args.ds_name,
            network,
            args.num_dx,
            CLASS_NUM[args.ds_name],
            eps=args.eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)

        optimizer = optim.SGD(network.parameters(),
                              lr=args.lr,
                              weight_decay=1e-6,
                              momentum=0.9)
        resume_epoch = 0
        print("{}".format(model_path))
        # Resume from an existing checkpoint for this exact configuration.
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path,
                                    lambda storage, location: storage)
            optimizer.load_state_dict(checkpoint["optimizer"])
            resume_epoch = checkpoint["epoch"]
            network.load_state_dict(checkpoint["state_dict"])

        for epoch in range(resume_epoch, args.epochs + 1):
            if epoch == 1:
                # One quick sanity test on 10% of the validation set.
                detector.test(epoch,
                              test_loader,
                              test_length=0.1 * len(val_dataset))
            detector.train(epoch, optimizer, train_loader)

            print("Epoch{}, Saving model in {}".format(epoch, model_path))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': network.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path)
    else:  # evaluation mode

        if args.study_subject == "speed_test":
            evaluate_speed(args)
            return

        # Raw string: the pattern contains \d escapes.
        extract_pattern = re.compile(
            r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+).pth.tar"
        )
        results = defaultdict(dict)
        for model_path in glob.glob(
                "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
            ma = extract_pattern.match(model_path)
            if ma is None:
                # Skip files that don't follow the naming scheme instead of
                # crashing on ma.group().
                continue
            ds_name = ma.group(1)

            arch = ma.group(2)
            epoch = int(ma.group(3))
            num_dx = int(ma.group(6))
            eps = float(ma.group(5))
            if arch == "conv3":
                network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name],
                                CLASS_NUM[ds_name])
            elif arch == "resnet10":
                network = resnet10(in_channels=IN_CHANNELS[ds_name],
                                   num_classes=CLASS_NUM[ds_name])
            elif arch == "resnet18":
                network = resnet18(in_channels=IN_CHANNELS[ds_name],
                                   num_classes=CLASS_NUM[ds_name])
            else:
                raise ValueError("Unsupported arch: {}".format(arch))
            # Sweep of candidate rejection thresholds: 0.000, 0.001, ..., 2.049
            reject_thresholds = [0. + 0.001 * i for i in range(2050)]
            network.load_state_dict(
                torch.load(model_path,
                           lambda storage, location: storage)["state_dict"])
            network.cuda()
            print("load {} over".format(model_path))
            detector = NeuralFingerprintDetector(
                ds_name,
                network,
                num_dx,
                CLASS_NUM[ds_name],
                eps=eps,
                out_fp_dxdy_dir=args.output_dx_dy_dir)
            # There is no cross-arch notion for the "shots" study.
            if args.study_subject == "shots":
                all_shots = [0, 1, 5]
                old_updates = args.num_updates
                for shot in all_shots:
                    report_shot = shot
                    # "0 shot" = 1-shot data with zero fine-tune updates.
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_updates
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        ds_name,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=args.adv_arch,
                        fetch_attack_name=True)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds,
                        args.num_updates, args.lr)
                    results[ds_name][report_shot] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": args.num_updates,
                        "attack_stats": attacker_stats
                    }
                    print("shot {} done".format(shot))

            elif args.study_subject == "cross_domain":
                # Train on the source dataset, evaluate on the target one.
                source_dataset, target_dataset = args.cross_domain_source, args.cross_domain_target
                if ds_name != source_dataset:
                    continue
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        target_dataset,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=args.adv_arch,
                        fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, target_dataset, reject_thresholds,
                        args.num_updates, args.lr)
                    results["{}--{}@data_adv_arch_{}".format(
                        source_dataset, target_dataset,
                        args.adv_arch)][report_shot] = {
                            "F1": F1,
                            "best_tau": tau,
                            "eps": eps,
                            "num_dx": num_dx,
                            "num_updates": args.num_updates,
                            "attack_stats": attacker_stats
                        }
            elif args.study_subject == "cross_arch":
                # Same dataset, adversarial examples crafted on another arch.
                target_arch = args.cross_arch_target
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        ds_name,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=target_arch,
                        fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds,
                        args.num_updates, args.lr)
                    results["{}_target_arch_{}".format(
                        ds_name, target_arch)][report_shot] = {
                            "F1": F1,
                            "best_tau": tau,
                            "eps": eps,
                            "num_dx": num_dx,
                            "num_updates": args.num_updates,
                            "attack_stats": attacker_stats
                        }

            elif args.study_subject == "finetune_eval":
                # Fix the task (1-shot) and sweep the number of fine-tune
                # updates from 0 to 50.
                shot = 1
                query_count = 15
                old_updates = args.num_updates
                num_way = 2
                num_query = 15
                if ds_name != args.ds_name:
                    continue
                val_dataset = MetaTaskDataset(20000,
                                              num_way,
                                              shot,
                                              num_query,
                                              ds_name,
                                              is_train=False,
                                              load_mode=args.load_task_mode,
                                              protocol=args.protocol,
                                              no_random_way=True,
                                              adv_arch=args.adv_arch)
                adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                             batch_size=100,
                                                             shuffle=False,
                                                             **kwargs)
                args.num_updates = 50
                for num_update in range(0, 51):
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, num_update,
                        args.lr)
                    results[ds_name][num_update] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": num_update,
                        "attack_stats": attacker_stats
                    }
                    print("finetune {} done".format(shot))

        if not args.profile:
            if args.study_subject == "cross_domain":
                filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
                    PY_ROOT, args.cross_domain_source,
                    args.cross_domain_target, args.adv_arch)
            elif args.study_subject == "cross_arch":
                filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
                    PY_ROOT, args.cross_arch_target)
            else:
                filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
                    PY_ROOT, args.study_subject, args.adv_arch, args.protocol,
                    args.lr, args.num_updates)
            with open(filename, "w") as file_obj:
                file_obj.write(json.dumps(results))
                file_obj.flush()
0
def main_train_worker(gpu, args):
    """Train an image classifier on one GPU and checkpoint it every epoch.

    Args:
        gpu: GPU index to report (or None); stored back onto ``args.gpu``.
        args: parsed command-line namespace; must provide ``pretrained``,
            ``arch``, ``dataset``, ``epochs``, ``lr``, ``batch_size``,
            ``weight_decay``, ``workers`` and optionally ``resume``.

    Raises:
        ValueError: if ``args.arch`` or ``args.dataset`` is not supported
            (the original code would instead fail later with a NameError).
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    # create model
    if args.pretrained and args.arch in models.__dict__:
        print("=> using pre-trained model '{}'".format(args.arch))
        network = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == "resnet10":
            network = resnet10(num_classes=CLASS_NUM[args.dataset], in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "resnet18":
            network = resnet18(num_classes=CLASS_NUM[args.dataset], in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], CLASS_NUM[args.dataset])
        else:
            # fail fast instead of a NameError on the first use of `network`
            raise ValueError("unsupported architecture: {}".format(args.arch))

    # replace the classification head so torchvision models match the dataset
    if args.arch.startswith("resnet"):
        network.avgpool = Identity()
        network.fc = nn.Linear(512, CLASS_NUM[args.dataset])
    elif args.arch.startswith("vgg"):
        network.classifier[6] = nn.Linear(4096, CLASS_NUM[args.dataset])

    model_path = './train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        args.dataset, args.arch, args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    preprocessor = get_preprocessor(IN_CHANNELS[args.dataset])
    network.cuda()

    # define loss function (criterion) and optimizer
    image_classifier_loss = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.Adam(network.parameters(), args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    start_epoch = getattr(args, "start_epoch", 0)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # BUGFIX: the loop below previously started at 0 even after a
            # resume, retraining from scratch; continue from the saved epoch.
            start_epoch = checkpoint['epoch']
            network.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.dataset == "CIFAR-10":
        train_dataset = CIFAR10(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor)
        val_dataset = CIFAR10(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor)
    elif args.dataset == "MNIST":
        train_dataset = MNIST(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor, download=True)
        val_dataset = MNIST(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor, download=True)
    elif args.dataset == "F-MNIST":
        train_dataset = FashionMNIST(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor, download=True)
        val_dataset = FashionMNIST(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor, download=True)
    elif args.dataset == "SVHN":
        train_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor)
        val_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor)
    else:
        # fail fast instead of a NameError on the DataLoader construction
        raise ValueError("unsupported dataset: {}".format(args.dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch, args)

        # evaluate accuracy on the validation set
        acc1 = validate(val_loader, network, image_classifier_loss, args)
        print(acc1)

        # checkpoint every epoch (overwrites the same file)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
Ejemplo n.º 8
0
def evaluate_shots(args):
    """Evaluate saved rotation-based detectors under 0/1/5-shot fine-tuning.

    For every IMG_ROTATE_DET checkpoint matching ``args.protocol`` (ImageNet
    checkpoints are skipped), rebuilds the detector network, fine-tunes and
    evaluates it on meta-task data, and dumps a per-dataset / per-shot result
    report to ``shots_result.json``.

    Note: pass ``args.num_updates = 0`` for the 0-shot case.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Example file name:
    # IMG_ROTATE_DET@ImageNet_TRAIN_I_TEST_II@model_resnet10@data_resnet10@epoch_10@lr_0.001@batch_10@no_fix_cnn_params.pth.tar
    # BUGFIX: allow an optional trailing "@..." segment (e.g. "@no_fix_cnn_params")
    # before the extension and escape the dots; the original pattern rejected
    # such files, leaving ma == None and crashing on ma.group(1).
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)(?:@.*?)?\.pth\.tar"
    )
    result = defaultdict(dict)
    for model_path in glob.glob(
            "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*"
            .format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # skip files that do not follow the naming scheme instead of crashing
            continue
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        data_adv_arch = ma.group(4)

        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2

        num_query = 15
        old_num_update = args.num_updates

        all_shots = [0, 1, 5]
        for shot in all_shots:
            if shot == 0:
                # 0-shot: use 1-shot support data but zero fine-tune steps
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                num_classes,
                                                shot,
                                                num_query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol,
                                                no_random_way=True,
                                                adv_arch=data_adv_arch)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)

            if arch == "resnet10":
                img_classifier_network = resnet10(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            elif arch == "resnet18":
                img_classifier_network = resnet18(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            elif arch == "conv3":
                img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                               IMAGE_SIZE[dataset],
                                               CLASS_NUM[dataset])
            else:
                # fail fast instead of a NameError when building Detector
                raise ValueError("unsupported architecture: {}".format(arch))

            image_transform = ImageTransformCV2(dataset, [1, 2])
            # deeper feature layer for the larger / color datasets
            layer_number = 3 if dataset in [
                "ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"
            ] else 2
            model = Detector(dataset,
                             img_classifier_network,
                             CLASS_NUM[dataset],
                             image_transform,
                             layer_number,
                             num_classes=2)
            checkpoint = torch.load(
                model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          num_update,
                                                          update_BN=False)
            if num_update == 0:
                # report the zero-update run under its original 0-shot key
                shot = 0
            result[dataset + "@" + data_adv_arch][shot] = evaluate_result
    with open(
            "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/shots_result.json"
            .format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
def main():
    """Generate white-box adversarial images against a classifier+detector stack.

    Loads the validation split of the chosen dataset, the pre-trained image
    classifier, and the selected detector; wires them into a CW-L2 or FGSM
    attack (a fingerprint-aware variant for NeuralFP) and writes adversarial
    images to disk via ``generate``.

    Raises:
        ValueError: for an unsupported dataset, adv_arch, detector or attack
            name (the original code would instead fail later with a NameError).
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    preprocessor = get_preprocessor(input_channels=IN_CHANNELS[args.dataset])

    # only the validation split is attacked; the unused train split the
    # original built has been dropped
    if args.dataset == "CIFAR-10":
        val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.dataset],
                                       train=False,
                                       transform=preprocessor)
    elif args.dataset == "MNIST":
        val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.dataset],
                                     train=False,
                                     transform=preprocessor,
                                     download=True)
    elif args.dataset == "F-MNIST":
        val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.dataset],
                                            train=False,
                                            transform=preprocessor,
                                            download=True)
    elif args.dataset == "SVHN":
        val_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset],
                           train=False,
                           transform=preprocessor)
    else:
        raise ValueError("unsupported dataset: {}".format(args.dataset))

    # load image classifier model
    img_classifier_model_path = "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_40@lr_0.0001@batch_500.pth.tar".format(
        PY_ROOT, args.dataset, args.adv_arch)
    if args.adv_arch == "resnet10":
        img_classifier_network = resnet10(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "resnet18":
        img_classifier_network = resnet18(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "conv3":
        img_classifier_network = Conv3(IN_CHANNELS[args.dataset],
                                       IMAGE_SIZE[args.dataset],
                                       CLASS_NUM[args.dataset])
    else:
        raise ValueError("unsupported adv_arch: {}".format(args.adv_arch))

    print("=> loading checkpoint '{}'".format(img_classifier_model_path))
    checkpoint = torch.load(img_classifier_model_path,
                            map_location=lambda storage, loc: storage)
    img_classifier_network.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}) for img classifier".format(
        img_classifier_model_path, checkpoint['epoch']))
    img_classifier_network.eval()
    img_classifier_network = img_classifier_network.cuda()

    # load detector model
    if args.detector == "MetaAdvDet":
        # When attacking the fine-tuned model, noise must be generated
        # dynamically: after each fine-tune step on the support set, new
        # adversarial examples are produced immediately.
        detector_net = build_meta_adv_detector(
            args.dataset, args.det_arch, args.adv_arch, args.shot, args.
            protocol)  # detection happens after fine-tuning on the support set
    elif args.detector == "DNN":
        detector_net = build_DNN_detector(args.dataset, args.det_arch,
                                          args.adv_arch, args.protocol)
    elif args.detector == "RotateDet":
        detector_net = build_rotate_detector(args.dataset, args.det_arch,
                                             args.adv_arch, args.protocol)
    elif args.detector == "NeuralFP":
        detector_net = build_neural_fingerprint_detector(
            args.dataset, args.det_arch)
    else:
        raise ValueError("unsupported detector: {}".format(args.detector))

    if args.detector == "NeuralFP":
        # NeuralFP is attacked through fingerprint-aware attack variants
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2Fingerprint(img_classifier_network,
                                                targeted=True,
                                                confidence=0.3,
                                                search_steps=30,
                                                max_steps=args.atk_max_iter,
                                                optimizer_lr=0.01,
                                                neural_fp=detector_net)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargetedFingerprint(
                img_classifier_network,
                alpha=0.01,
                max_iters=args.atk_max_iter,
                neural_fp=detector_net)
        else:
            raise ValueError("unsupported attack: {}".format(args.attack))
    else:
        # other detectors are attacked through a combined classifier+detector
        detector_net.eval()
        combined_model = CombinedModel(img_classifier_network, detector_net)
        combined_model.cuda()
        combined_model.eval()
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2(combined_model,
                                     True,
                                     confidence=0.3,
                                     search_steps=30,
                                     max_steps=args.atk_max_iter,
                                     optimizer_lr=0.01)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargeted(
                combined_model, alpha=0.01, max_iters=args.atk_max_iter)
        else:
            raise ValueError("unsupported attack: {}".format(args.attack))

    generate(
        attack,
        args.attack,
        args.dataset,
        args.detector,
        val_dataset,
        output_dir="{}/adversarial_images/white_box@data_{}@det_{}".format(
            IMAGE_DATA_ROOT[args.dataset], args.adv_arch, args.detector),
        args=args)
Ejemplo n.º 10
0
def main_train_worker(args, model_path, META_ATTACKER_PART_I=None,META_ATTACKER_PART_II=None,gpu="0"):
    """Train a 2-class (clean vs. adversarial) detector CNN on one GPU.

    Args:
        args: parsed namespace; this function reads ``arch``, ``dataset``,
            ``adv_arch``, ``protocol``, ``balance``, ``lr``, ``momentum``,
            ``weight_decay``, ``batch_size``, ``workers``, ``start_epoch``,
            ``epochs`` and ``gpu``.
        model_path: checkpoint path; if the file exists, training resumes
            from it (model, optimizer and epoch are restored).
        META_ATTACKER_PART_I / META_ATTACKER_PART_II: attack-name splits for
            building the adversarial dataset; default to the ``config`` values.
        gpu: CUDA device id string, exported via CUDA_VISIBLE_DEVICES.
    """
    if META_ATTACKER_PART_I  is None:
        META_ATTACKER_PART_I = config.META_ATTACKER_PART_I
    if META_ATTACKER_PART_II is None:
        META_ATTACKER_PART_II = config.META_ATTACKER_PART_II
    print("Use GPU: {} for training".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    print("will save to {}".format(model_path))
    global best_acc1
    # 2 output classes: clean vs. adversarial
    # NOTE(review): no else branch — an unrecognized args.arch leaves `model`
    # unbound and fails with a NameError at model.cuda() below.
    if args.arch == "conv3":
        model = Conv3(IN_CHANNELS[args.dataset],IMAGE_SIZE[args.dataset], 2)
    elif args.arch == "resnet10":
        model = resnet10(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    elif args.arch == "resnet18":
        model = resnet18(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    model = model.cuda()
    # ImageNet adversarial images are too large to hold in memory, so a
    # random-access .npy-backed dataset is used for it.
    if args.dataset == "ImageNet":
        train_dataset = AdversaryRandomAccessNpyDataset(IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch),
                                                        True, args.protocol, META_ATTACKER_PART_I,META_ATTACKER_PART_II,
                                                        args.balance, args.dataset)
    else:
        train_dataset = AdversaryDataset(IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch),
                                     True, args.protocol, META_ATTACKER_PART_I,META_ATTACKER_PART_II, args.balance)

    # val_dataset = MetaTaskDataset(20000, 2, 1, 15,
    #                                     args.dataset, is_train=False, pkl_task_dump_path=args.test_pkl_path,
    #                                     load_mode=LOAD_TASK_MODE.LOAD,
    #                                     protocol=args.protocol, no_random_way=True)

    # define loss function (criterion) and optimizer
    # NOTE(review): uses args.gpu here although the device id for this worker
    # comes in through the `gpu` parameter — confirm args.gpu is set by the caller.
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (restores optimizer state too,
    # unlike the image-classifier trainer above)
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    cudnn.benchmark = True

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=100, shuffle=False,
    #     num_workers=0, pin_memory=True)
    tensorboard = TensorBoardWriter("{0}/pytorch_DeepLearning_tensorboard".format(PY_ROOT), "DeepLearning")
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, None, model, criterion, optimizer, epoch,tensorboard, args)
        if args.balance:
            # Re-balance the training set each epoch: keep every positive
            # (label 1) sample and draw an equally sized random subset of the
            # negatives (label 0). Mutates the dataset's internal sample list.
            train_dataset.img_label_list.clear()
            train_dataset.img_label_list.extend(train_dataset.img_label_dict[1])
            train_dataset.img_label_list.extend(random.sample(train_dataset.img_label_dict[0], len(train_dataset.img_label_dict[1])))
        # evaluate_accuracy on validation set

        # acc1 = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)