Code Example #1
    def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size,
                 inner_step_size, lr_decay_itr, epoch, num_inner_updates,
                 load_task_mode, arch, tot_num_tasks, num_support, detector,
                 attack_name, root_folder):
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets
        if arch == "conv3":
            # network = FourConvs(IN_CHANNELS[self.dataset_name], IMAGE_SIZE[self.dataset_name], num_classes)
            network = Conv3(IN_CHANNELS[self.dataset],
                            IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            network = resnet10(num_classes, pretrained=False)
        elif arch == "resnet18":
            network = resnet18(num_classes, pretrained=False)
        self.network = network
        self.network.cuda()

        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support,
                                      15, dataset, load_task_mode, detector,
                                      attack_name, root_folder)
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=0,
            pin_memory=True)  # a fixed set of 100 tasks; accuracy is measured per task
Code Example #2
File: train.py  Project: machanic/MetaAdvDet
def build_network(dataset, arch, model_path):
    # if dataset!="ImageNet":
    #     assert os.path.exists(model_path), "{} not exists!".format(model_path)

    if arch in models.__dict__:
        print("=> using pre-trained model '{}'".format(arch))
        img_classifier_network = models.__dict__[arch](pretrained=False)
    else:
        print("=> creating model '{}'".format(arch))
        if arch == "resnet10":
            img_classifier_network = resnet10(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "resnet18":
            img_classifier_network = resnet18(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "conv3":
            img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                           IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        img_classifier_network.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
    return img_classifier_network
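
The arch in models.__dict__ branch above resolves an architecture name to a torchvision constructor before falling back to the project's own models. Below is a minimal runnable sketch of just that dispatch; the callable() guard and the make_backbone name are additions for illustration, not part of the project:

import torchvision.models as models

def make_backbone(arch):
    # torchvision registers its model factories in models.__dict__, so a
    # string such as "resnet18" resolves to a constructor at runtime.
    if arch in models.__dict__ and callable(models.__dict__[arch]):
        return models.__dict__[arch](pretrained=False)
    raise ValueError("unknown architecture: {}".format(arch))

net = make_backbone("resnet18")
print(sum(p.numel() for p in net.parameters()))  # parameter count of the backbone
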
Code Example #3
def evaluate_speed(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        epoch = int(ma.group(4))
        lr = float(ma.group(5))
        batch_size = int(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2

        num_query = 15
        old_num_update = args.num_updates

        all_shots = [0, 1, 5]
        num_updates = args.num_updates
        for shot in all_shots:
            report_shot = shot
            if shot == 0:
                num_updates = 0
                shot = 1
            else:
                num_updates = args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True, adv_arch="conv3") #FIXME adv arch还没做其他architecture的代码
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform, layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = speed_test(model, data_loader, lr, num_updates, update_BN=False)
            result[dataset][report_shot] = evaluate_result
        break

    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
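
The shot loop above encodes a convention repeated throughout these examples: a reported 0-shot result is evaluated on 1-shot tasks with zero fine-tune updates. A minimal standard-library sketch of that mapping (resolve_shot is a hypothetical helper name, not from the project):

def resolve_shot(shot, default_num_updates):
    """Return (dataset_shot, num_updates, report_shot) for one evaluation run."""
    if shot == 0:
        # 0-shot: sample a 1-shot task, but never fine-tune on its support set
        return 1, 0, 0
    return shot, default_num_updates, shot

for shot in [0, 1, 5]:
    print(resolve_shot(shot, default_num_updates=12))
# (1, 0, 0), (1, 12, 1), (5, 12, 5)
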
Code Example #4
def evaluate_cross_arch(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar")
    result = defaultdict(dict)
    update_BN = False
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue

        epoch = int(ma.group(5))
        lr = float(ma.group(6))
        batch_size = int(ma.group(7))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        tot_num_tasks = 20000
        num_classes = 2
        num_query = 15
        old_num_update = args.num_updates
        for shot in [0, 1, 5]:
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True, adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100","SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform, layer_number, num_classes=2)
            checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)  # FIXME: accuracy comes out very high with update_BN=False
            if num_update == 0:
                shot = 0
            result["{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)][shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(PY_ROOT, args.cross_arch_source,
                                                                            args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
Code Example #5
def build_DNN_detector(dataset, arch, adv_arch, protocol):
    model_path = "{}/train_pytorch_model/white_box_model/DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    network.load_state_dict(checkpoint["state_dict"], strict=True)
    network.cuda()
    return network
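
The map_location lambda used here (and in most of the loaders above) forces every tensor in the checkpoint onto CPU regardless of which GPU it was saved from, so a file saved on any device can be opened anywhere. A self-contained sketch of the round trip; the toy nn.Linear model and /tmp path are illustrative only:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
torch.save({"state_dict": net.state_dict(), "epoch": 3}, "/tmp/demo.pth.tar")

# the lambda ignores the saved device tag and keeps each storage on CPU
checkpoint = torch.load("/tmp/demo.pth.tar", map_location=lambda storage, location: storage)
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["state_dict"], strict=True)
print("restored epoch", checkpoint["epoch"])
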
Code Example #6
    def __init__(self,
                 dataset,
                 num_classes,
                 meta_batch_size,
                 meta_step_size,
                 inner_step_size, lr_decay_itr,
                 epoch,
                 num_inner_updates, load_task_mode, protocol,
                 tot_num_tasks, num_support, num_query, no_random_way,
                 tensorboard_data_prefix, train=True, adv_arch="conv3", rotate=False, need_val=False):
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets

        if train:
            # requires a special MetaTaskDataset: during training, the support set already supplies class attributes for both ways
            trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                          dataset, is_train=True, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=no_random_way, adv_arch=adv_arch, rotate=rotate)
            self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True, num_workers=0, pin_memory=True)
            self.tensorboard = TensorBoardWriter("{0}/zeroshot_tensorboard".format(PY_ROOT),
                                                 tensorboard_data_prefix)
            os.makedirs("{0}/zeroshot_tensorboard".format(PY_ROOT), exist_ok=True)
        if need_val:
            val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                          dataset, is_train=False, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=True, adv_arch=adv_arch, rotate=rotate)
            self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=0, pin_memory=True)  # a fixed set of 100 tasks; accuracy is measured per task

        self.hidden_feature_size = 2048
        self.attr_network = AttributeNetwork(312, 1200, self.hidden_feature_size)  # output 2048
        self.relation_network = RelationNetwork(2 * self.hidden_feature_size, 1200, 2)
        self.img_feature_extract_network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], self.hidden_feature_size)
        self.attr_network.cuda()
        self.relation_network.cuda()

        self.inner_attr_network = copy.deepcopy(self.attr_network)  # deal with each task
        self.inner_relation_network = copy.deepcopy(self.relation_network)  # deal with each task

        # No inner updates, only the outer update; the optimizer owns the parameters of all the networks.
        # .parameters() returns a generator, so materialize each as a list before concatenating.
        self.opt_attr_net = Adam(list(self.img_feature_extract_network.parameters()) +
                                 list(self.relation_network.parameters()) +
                                 list(self.attr_network.parameters()),
                                 lr=meta_step_size)
        self.sched = StepLR(self.opt_attr_net, step_size=30000, gamma=0.5)
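
Adam must receive a single iterable of parameters, and Module.parameters() returns a generator, so the three networks' parameters are combined above by materializing lists; itertools.chain is an equivalent idiom. A minimal sketch with toy Linear modules standing in for the three networks:

import itertools
import torch.nn as nn
from torch.optim import Adam

net_a, net_b = nn.Linear(8, 8), nn.Linear(8, 2)
# chain() lazily concatenates the parameter generators of several modules
opt = Adam(itertools.chain(net_a.parameters(), net_b.parameters()), lr=1e-3)
print(sum(len(group["params"]) for group in opt.param_groups))  # 4 tensors
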
Code Example #7
File: meta_adv_det.py  Project: machanic/MetaAdvDet
    def __init__(self,
                 dataset,
                 num_classes,
                 meta_batch_size,
                 meta_step_size,
                 inner_step_size, lr_decay_itr,
                 epoch,
                 num_inner_updates, load_task_mode, protocol, arch,
                 tot_num_tasks, num_support, num_query, no_random_way,
                 tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets
        if arch == "conv3":
            # network = FourConvs(IN_CHANNELS[self.dataset_name], IMAGE_SIZE[self.dataset_name], num_classes)
            network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            network = MetaNetwork(resnet10(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
        elif arch == "resnet18":
            network = MetaNetwork(resnet18(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])

        self.network = network
        self.network.cuda()
        if train:
            trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                          dataset, is_train=True, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
            # task number per mini-batch is controlled by DataLoader
            self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True, num_workers=4, pin_memory=True)
            self.tensorboard = TensorBoardWriter("{0}/pytorch_MAML_tensorboard".format(PY_ROOT),
                                                 tensorboard_data_prefix)
            os.makedirs("{0}/pytorch_MAML_tensorboard".format(PY_ROOT), exist_ok=True)
        if need_val:
            val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                          dataset, is_train=False, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
            self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)  # a fixed set of 100 tasks; accuracy is measured per task
        self.fast_net = InnerLoop(self.network, self.num_inner_updates,
                                  self.inner_step_size, self.meta_batch_size)  # runs the tasks of a meta-batch in parallel
        self.fast_net.cuda()
        self.opt = Adam(self.network.parameters(), lr=meta_step_size)
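
InnerLoop above adapts a scratch copy of self.network on each task so the meta weights change only at the outer Adam step. A single-task sketch of that fast-weights idea, with toy data and a plain deepcopy; it omits the meta-gradient bookkeeping and the per-meta-batch parallelism the real InnerLoop provides:

import copy
import torch
import torch.nn as nn

meta_net = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

fast_net = copy.deepcopy(meta_net)  # per-task scratch copy of the meta weights
inner_opt = torch.optim.SGD(fast_net.parameters(), lr=0.01)
for _ in range(5):  # num_inner_updates
    loss = nn.functional.cross_entropy(fast_net(x), y)
    inner_opt.zero_grad()
    loss.backward()
    inner_opt.step()
print("adapted:", not torch.equal(fast_net.weight, meta_net.weight))  # True
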
Code Example #8
def build_meta_adv_detector(dataset, arch, adv_arch, shot, protocol):
    # extract_pattern = re.compile(
    #     ".*/MAML@(.*?)_(.*?)@model_(.*?)@data.*?@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@rotate_(.*?)\.pth.tar")
    # str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    model_path = "{}/train_pytorch_model/white_box_model/MAML@{}_{}@model_{}@data_{}@epoch_4@meta_batch_size_30@way_2@shot_{}@num_query_35@num_updates_12@lr_0.0001@inner_lr_0.001@fixed_way_True@rotate_False.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch, shot)
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    network.load_state_dict(checkpoint['state_dict'], strict=True)
    network.cuda()
    return network
Code Example #9
def build_rotate_detector(dataset, arch, adv_arch, protocol):
    img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                   CLASS_NUM[dataset])
    model_path = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    image_transform = ImageTransformTorch(dataset, [5, 15])
    layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
    detector = Detector(dataset, img_classifier_network, CLASS_NUM[dataset],
                        image_transform, layer_number)
    detector.load_state_dict(checkpoint['state_dict'], strict=True)
    detector.cuda()
    return detector
Code Example #10
def evaluate_finetune(model_path_list, lr, protocol):
    # The deep-learning baseline is trained in all_in / sampled-all-in mode, but testing must run on the task-style dataset
    extract_pattern_detail = re.compile(r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        adv_arch = ma.group(4)
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        arch = ma.group(3)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
        shot = 1
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query,
                                            dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        for num_update in range(0, 51):
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            if num_update == 0:
                shot = 0
            else:
                shot = 1
            evaluate_result["shot"] = shot
            result["{}_{}".format(dataset,balance)][num_update] = evaluate_result
    return result
Code Example #11
def evaluate_whitebox(dataset, arch, adv_arch, detector, attack_name, num_update, lr, protocol, load_mode, result):
    # The deep-learning baseline is trained in all_in / sampled-all-in mode, but testing must run on the task-style dataset
    extract_pattern_detail = re.compile(r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15

    model_path = "{}/train_pytorch_model/white_box_model/DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    assert os.path.exists(model_path), "{} is not exists".format(model_path)
    root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(adv_arch, detector, attack_name)
    ma = extract_pattern_detail.match(model_path)
    balance = ma.group(8)
    if balance == "True":
        balance = "balance"
    else:
        balance = "no_balance"
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    arch = ma.group(3)
    model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    model = model.cuda()
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(model_path, checkpoint['epoch']))

    old_num_update = num_update
    # for shot in range(16):
    for shot in [0, 1, 5]:
        if shot == 0:
            shot = 1
            num_update = 0
        else:
            num_update = old_num_update

        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query,
                                            dataset,
                                            load_mode=load_mode, detector=detector, attack_name=attack_name, root_folder=root_folder)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=True)
        if num_update == 0:
            shot = 0
        result["{}_{}_{}_{}".format(dataset, attack_name, detector, adv_arch)][shot] = evaluate_result

    return result
Code Example #12
def build_neural_fingerprint_detector(dataset, arch, eps=0.1, num_dx=5):
    output_dx_dy_dir = "{}/NF_dx_dy".format(PY_ROOT)
    model_path = "{}/train_pytorch_model/white_box_model/NF_Det@{}@{}*.pth.tar".format(
        PY_ROOT, dataset, arch)
    model_path = glob.glob(model_path)[0]
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, location: storage)
    print("load {} to detector".format(model_path))
    network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                    CLASS_NUM[dataset])
    network.load_state_dict(checkpoint["state_dict"])
    network.cuda()
    detector = NeuralFingerprintDetector(dataset,
                                         network,
                                         num_dx,
                                         CLASS_NUM[dataset],
                                         eps=eps,
                                         out_fp_dxdy_dir=output_dx_dy_dir)
    return detector
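
glob.glob(model_path)[0] above raises a bare IndexError when no checkpoint matches the pattern. A slightly defensive sketch of the same lookup; find_checkpoint is a hypothetical helper, not part of the project:

import glob

def find_checkpoint(pattern):
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("no checkpoint matches {}".format(pattern))
    return matches[0]  # sorting makes the choice deterministic when several files match
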
Code Example #13
def main():
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    transform = get_preprocessor(
        IN_CHANNELS[args.ds_name],
        IMAGE_SIZE[args.ds_name])  # the NeuralFP MNIST and F-MNIST experiments need to be redone: a single-channel bug was found
    kwargs = {'pin_memory': True}
    if not args.evaluate:  # training mode
        if args.ds_name == "MNIST":
            trn_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name],
                                         train=True,
                                         download=False,
                                         transform=transform)
            val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.ds_name],
                                         train=False,
                                         download=False,
                                         transform=transform)
        elif args.ds_name == "F-MNIST":
            trn_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name],
                                                train=True,
                                                download=False,
                                                transform=transform)
            val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.ds_name],
                                                train=False,
                                                download=False,
                                                transform=transform)
        elif args.ds_name == "CIFAR-10":
            trn_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name],
                                           train=True,
                                           download=False,
                                           transform=transform)
            val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.ds_name],
                                           train=False,
                                           download=False,
                                           transform=transform)
        elif args.ds_name == "SVHN":
            trn_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name],
                               train=True,
                               transform=transform)
            val_dataset = SVHN(IMAGE_DATA_ROOT[args.ds_name],
                               train=False,
                               transform=transform)
        elif args.ds_name == "ImageNet":
            trn_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] +
                                              "/new2",
                                              train=True,
                                              transform=transform)
            val_dataset = ImageNetRealDataset(IMAGE_DATA_ROOT[args.ds_name] +
                                              "/new2",
                                              train=False,
                                              transform=transform)

        train_loader = torch.utils.data.DataLoader(trn_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=0,
            **kwargs)

        if args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.ds_name],
                            IMAGE_SIZE[args.ds_name], CLASS_NUM[args.ds_name])
        elif args.arch == "resnet10":
            network = resnet10(in_channels=IN_CHANNELS[args.ds_name],
                               num_classes=CLASS_NUM[args.ds_name])
        elif args.arch == "resnet18":
            network = resnet18(in_channels=IN_CHANNELS[args.ds_name],
                               num_classes=CLASS_NUM[args.ds_name])
        network.cuda()
        model_path = os.path.join(
            PY_ROOT, "train_pytorch_model/NF_Det",
            "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar"
            .format(args.ds_name, args.arch, args.epochs, args.lr, args.eps,
                    args.num_dx, CLASS_NUM[args.ds_name]))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        detector = NeuralFingerprintDetector(
            args.ds_name,
            network,
            args.num_dx,
            CLASS_NUM[args.ds_name],
            eps=args.eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)

        optimizer = optim.SGD(network.parameters(),
                              lr=args.lr,
                              weight_decay=1e-6,
                              momentum=0.9)
        resume_epoch = 0
        print("{}".format(model_path))
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path,
                                    lambda storage, location: storage)
            optimizer.load_state_dict(checkpoint["optimizer"])
            resume_epoch = checkpoint["epoch"]
            network.load_state_dict(checkpoint["state_dict"])

        for epoch in range(resume_epoch, args.epochs + 1):
            if epoch == 1:
                detector.test(epoch,
                              test_loader,
                              test_length=0.1 * len(val_dataset))
            detector.train(epoch, optimizer, train_loader)

            print("Epoch{}, Saving model in {}".format(epoch, model_path))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': network.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path)
    else:  # evaluation mode

        if args.study_subject == "speed_test":
            evaluate_speed(args)
            return

        extract_pattern = re.compile(
            r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar"
        )
        results = defaultdict(dict)
        for model_path in glob.glob(
                "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
            ma = extract_pattern.match(model_path)
            ds_name = ma.group(1)
            # if ds_name == "ImageNet": # FIXME
            #     continue
            # if ds_name != "CIFAR-10": # FIXME
            #     continue

            arch = ma.group(2)
            epoch = int(ma.group(3))
            num_dx = int(ma.group(6))
            eps = float(ma.group(5))
            if arch == "conv3":
                network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name],
                                CLASS_NUM[ds_name])
            elif arch == "resnet10":
                network = resnet10(in_channels=IN_CHANNELS[ds_name],
                                   num_classes=CLASS_NUM[ds_name])
            elif arch == "resnet18":
                network = resnet18(in_channels=IN_CHANNELS[ds_name],
                                   num_classes=CLASS_NUM[ds_name])
            reject_thresholds = [0. + 0.001 * i for i in range(2050)]
            network.load_state_dict(
                torch.load(model_path,
                           lambda storage, location: storage)["state_dict"])
            network.cuda()
            print("load {} over".format(model_path))
            detector = NeuralFingerprintDetector(
                ds_name,
                network,
                num_dx,
                CLASS_NUM[ds_name],
                eps=eps,
                out_fp_dxdy_dir=args.output_dx_dy_dir)
            # the cross-arch setting does not apply to this detector
            if args.study_subject == "shots":
                all_shots = [0, 1, 5]
                # threhold_dict = {0:0.885896, 1:1.23128099999,5:1.33487699}
                old_updates = args.num_updates
                for shot in all_shots:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_updates
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        ds_name,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=args.adv_arch,
                        fetch_attack_name=True)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    # if args.profile:
                    #     cProfile.runctx("detector.eval_with_fingerprints_finetune(adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)", globals(), locals(), "Profile.prof")
                    #     s = pstats.Stats("Profile.prof")
                    #     s.strip_dirs().sort_stats("time").print_stats()
                    # else:
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds,
                        args.num_updates, args.lr)
                    results[ds_name][report_shot] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": args.num_updates,
                        "attack_stats": attacker_stats
                    }
                    print("shot {} done".format(shot))

            elif args.study_subject == "cross_domain":
                source_dataset, target_dataset = args.cross_domain_source, args.cross_domain_target
                if ds_name != source_dataset:
                    continue
                # threhold_dict = {0: 0.885896, 1: 1.23128099999, 5: 1.33487699}
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        target_dataset,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=args.adv_arch,
                        fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, target_dataset, reject_thresholds,
                        args.num_updates, args.lr)
                    results["{}--{}@data_adv_arch_{}".format(
                        source_dataset, target_dataset,
                        args.adv_arch)][report_shot] = {
                            "F1": F1,
                            "best_tau": tau,
                            "eps": eps,
                            "num_dx": num_dx,
                            "num_updates": args.num_updates,
                            "attack_stats": attacker_stats
                        }
            elif args.study_subject == "cross_arch":
                target_arch = args.cross_arch_target
                old_num_update = args.num_updates
                for shot in [0, 1, 5]:
                    report_shot = shot
                    if shot == 0:
                        shot = 1
                        args.num_updates = 0
                    else:
                        args.num_updates = old_num_update
                    num_way = 2
                    num_query = 15
                    val_dataset = MetaTaskDataset(
                        20000,
                        num_way,
                        shot,
                        num_query,
                        ds_name,
                        is_train=False,
                        load_mode=args.load_task_mode,
                        protocol=args.protocol,
                        no_random_way=True,
                        adv_arch=target_arch,
                        fetch_attack_name=False)
                    adv_val_loader = torch.utils.data.DataLoader(
                        val_dataset, batch_size=100, shuffle=False, **kwargs)
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds,
                        args.num_updates, args.lr)
                    results["{}_target_arch_{}".format(
                        ds_name, target_arch)][report_shot] = {
                            "F1": F1,
                            "best_tau": tau,
                            "eps": eps,
                            "num_dx": num_dx,
                            "num_updates": args.num_updates,
                            "attack_stats": attacker_stats
                        }

            elif args.study_subject == "finetune_eval":
                shot = 1
                query_count = 15
                old_updates = args.num_updates
                num_way = 2
                num_query = 15
                # threhold_dict = {0: 0.885896, 1: 1.23128099999, 5: 1.33487699}
                if ds_name != args.ds_name:
                    continue
                val_dataset = MetaTaskDataset(20000,
                                              num_way,
                                              shot,
                                              num_query,
                                              ds_name,
                                              is_train=False,
                                              load_mode=args.load_task_mode,
                                              protocol=args.protocol,
                                              no_random_way=True,
                                              adv_arch=args.adv_arch)
                adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                             batch_size=100,
                                                             shuffle=False,
                                                             **kwargs)
                args.num_updates = 50
                for num_update in range(0, 51):
                    # if args.profile:
                    #     cProfile.runctx("detector.eval_with_fingerprints_finetune(adv_val_loader, ds_name, reject_thresholds, args.num_updates, args.lr)", globals(), locals(), "Profile.prof")
                    #     s = pstats.Stats("Profile.prof")
                    #     s.strip_dirs().sort_stats("time").print_stats()
                    # else:
                    F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                        adv_val_loader, ds_name, reject_thresholds, num_update,
                        args.lr)
                    results[ds_name][num_update] = {
                        "F1": F1,
                        "best_tau": tau,
                        "eps": eps,
                        "num_dx": num_dx,
                        "num_updates": num_update,
                        "attack_stats": attacker_stats
                    }
                    print("finetune {} done".format(shot))

        if not args.profile:
            if args.study_subject == "cross_domain":
                filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
                    PY_ROOT, args.cross_domain_source,
                    args.cross_domain_target, args.adv_arch)
            elif args.study_subject == "cross_arch":
                filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
                    PY_ROOT, args.cross_arch_target)
            else:
                filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
                    PY_ROOT, args.study_subject, args.adv_arch, args.protocol,
                    args.lr, args.num_updates)
            with open(filename, "w") as file_obj:
                file_obj.write(json.dumps(results))
                file_obj.flush()
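
Every evaluation branch above recovers its hyper-parameters by parsing them back out of the checkpoint filename with a regex. A minimal runnable sketch of that round trip; the filename below is a made-up example shaped like the pattern, not a real checkpoint:

import re

pattern = re.compile(
    r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar")
name = "NF_Det@CIFAR-10@conv3@epoch_20@lr_0.001@eps_0.1@num_dx_5@num_class_10.pth.tar"
ma = pattern.match(name)
print(ma.group(1), int(ma.group(3)), float(ma.group(5)))  # CIFAR-10 20 0.1
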
Code Example #14
def evaluate_whitebox_attack(args):
    # For 0-shot evaluation, pass args.num_updates = 0
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    attacks = ["FGSM", "CW_L2"]
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth\.tar")
    result = defaultdict(dict)
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    model_path = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, args.dataset, args.protocol, "conv3", args.adv_arch)
    assert os.path.exists(model_path), "{} not exists".format(model_path)
    ma = extract_pattern_detail.match(model_path)
    dataset = ma.group(1)
    # if dataset != "MNIST":
    #     continue
    split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
    arch = ma.group(3)
    epoch = int(ma.group(4))
    lr = float(ma.group(5))
    batch_size = int(ma.group(6))
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    tot_num_tasks = 20000
    num_classes = 2

    num_query = 15
    old_num_update = args.num_updates
    all_shots = [1, 5]
    detector = "RotateDet"
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    for attack_name in attacks:
        for shot in all_shots:
            root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(args.adv_arch,
                                                                                                               detector,
                                                                                                               attack_name)
            if shot == 0:
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = WhiteBoxMetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, args.load_mode, detector, attack_name, root_folder)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformTorch(dataset, [5, 15])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform, layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_rotate(model, data_loader, lr, num_update, update_BN=args.eval_update_BN)
            if num_update == 0:
                shot = 0
            result["{}_{}_{}_{}".format(dataset, attack_name, detector, args.adv_arch)][shot] = evaluate_result
    if args.eval_update_BN:
        update_BN_str="UpdateBN"
    else:
        update_BN_str = "NoUpdateBN"
    with open("{}/train_pytorch_model/white_box_model/white_box_RotateDet_{}_{}_using_{}__result.json".format(PY_ROOT,args.dataset, update_BN_str, args.protocol), "w") as file_obj:
        file_obj.write(json.dumps(result))
        file_obj.flush()
Code Example #15
def evaluate_speed(args):
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)@num_dx_(\d+)@num_class_(\d+)\.pth\.tar"
    )
    results = defaultdict(dict)
    for model_path in glob.glob(
            "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        ds_name = ma.group(1)
        if ds_name != "CIFAR-10":
            continue
        arch = ma.group(2)
        epoch = int(ma.group(3))
        num_dx = int(ma.group(6))
        eps = float(ma.group(5))
        network = Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name],
                        CLASS_NUM[ds_name])

        reject_thresholds = [0. + 0.001 * i for i in range(2050)]
        network.load_state_dict(
            torch.load(model_path,
                       lambda storage, location: storage)["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        detector = NeuralFingerprintDetector(
            ds_name,
            network,
            num_dx,
            CLASS_NUM[ds_name],
            eps=eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)
        # the cross-arch setting does not apply to this detector
        all_shots = [0, 1, 5]
        for shot in all_shots:
            report_shot = shot
            if shot == 0:
                num_updates = 0
                shot = 1
            else:
                num_updates = args.num_updates

            num_way = 2
            num_query = 15
            val_dataset = MetaTaskDataset(20000,
                                          num_way,
                                          shot,
                                          num_query,
                                          ds_name,
                                          is_train=False,
                                          load_mode=args.load_task_mode,
                                          protocol=args.protocol,
                                          no_random_way=True,
                                          adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                         batch_size=100,
                                                         shuffle=False)
            mean_time, var_time = detector.test_speed(
                adv_val_loader,
                ds_name,
                reject_thresholds,
                num_updates,
                args.lr,
            )
            results[ds_name][report_shot] = {
                "mean_time": mean_time,
                "var_time": var_time
            }
            print("shot {} done".format(shot))
        break

    file_name = "{}/train_pytorch_model/NF_Det/speed_test_result.json".format(
        PY_ROOT)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(results))
        file_obj.flush()
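
test_speed above reports the mean and variance of per-task adaptation time. A minimal standard-library sketch of that measurement pattern; time_per_task and the toy workload are illustrative, not the project's API:

import statistics
import time

def time_per_task(task_fn, tasks):
    durations = []
    for task in tasks:
        start = time.perf_counter()
        task_fn(task)
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.variance(durations)

mean_time, var_time = time_per_task(lambda n: sum(range(n)), [10000, 20000, 30000])
print(mean_time, var_time)
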
Code Example #16
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    # The deep-learning baseline is trained in all_in / sampled-all-in mode, but testing must run on the task-style dataset
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"

        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance,
                                          dataset, protocol)

        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        arch = ma.group(3)
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()

        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        shot = 1
        num_update = 0
        meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                            way,
                                            shot,
                                            query,
                                            dataset,
                                            is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol,
                                            no_random_way=True,
                                            adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model,
                                                      data_loader,
                                                      lr,
                                                      num_update,
                                                      update_BN=False)
        result[key][0] = evaluate_result
    return result
Code Example #17
def main():
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    preprocessor = get_preprocessor(input_channels=IN_CHANNELS[args.dataset])

    if args.dataset == "CIFAR-10":
        train_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.dataset],
                                         train=True,
                                         transform=preprocessor)
        val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.dataset],
                                       train=False,
                                       transform=preprocessor)
    elif args.dataset == "MNIST":
        train_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.dataset],
                                       train=True,
                                       transform=preprocessor,
                                       download=True)
        val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.dataset],
                                     train=False,
                                     transform=preprocessor,
                                     download=True)
    elif args.dataset == "F-MNIST":
        train_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.dataset],
                                              train=True,
                                              transform=preprocessor,
                                              download=True)
        val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.dataset],
                                            train=False,
                                            transform=preprocessor,
                                            download=True)
    elif args.dataset == "SVHN":
        train_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset],
                             train=True,
                             transform=preprocessor)
        val_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset],
                           train=False,
                           transform=preprocessor)

    # load image classifier model
    img_classifier_model_path = "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_40@lr_0.0001@batch_500.pth.tar".format(
        PY_ROOT, args.dataset, args.adv_arch)
    if args.adv_arch == "resnet10":
        img_classifier_network = resnet10(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "resnet18":
        img_classifier_network = resnet18(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "conv3":
        img_classifier_network = Conv3(IN_CHANNELS[args.dataset],
                                       IMAGE_SIZE[args.dataset],
                                       CLASS_NUM[args.dataset])

    print("=> loading checkpoint '{}'".format(img_classifier_model_path))
    checkpoint = torch.load(img_classifier_model_path,
                            map_location=lambda storage, loc: storage)
    img_classifier_network.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}) for img classifier".format(
        img_classifier_model_path, checkpoint['epoch']))
    img_classifier_network.eval()
    img_classifier_network = img_classifier_network.cuda()

    # load detector model
    if args.detector == "MetaAdvDet":
        # To attack the fine-tuned model, the perturbations must be generated dynamically:
        # after each fine-tune on the support set, attack immediately to produce fresh adversarial examples
        detector_net = build_meta_adv_detector(
            args.dataset, args.det_arch, args.adv_arch, args.shot,
            args.protocol)  # detection runs after fine-tuning on the support set; whether to attack the pre- or post-fine-tune model is the open question
    elif args.detector == "DNN":
        detector_net = build_DNN_detector(args.dataset, args.det_arch,
                                          args.adv_arch, args.protocol)
    elif args.detector == "RotateDet":
        detector_net = build_rotate_detector(args.dataset, args.det_arch,
                                             args.adv_arch, args.protocol)
    elif args.detector == "NeuralFP":
        detector_net = build_neural_fingerprint_detector(
            args.dataset, args.det_arch)

    if args.detector == "NeuralFP":
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2Fingerprint(img_classifier_network,
                                                targeted=True,
                                                confidence=0.3,
                                                search_steps=30,
                                                max_steps=args.atk_max_iter,
                                                optimizer_lr=0.01,
                                                neural_fp=detector_net)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargetedFingerprint(
                img_classifier_network,
                alpha=0.01,
                max_iters=args.atk_max_iter,
                neural_fp=detector_net)
    else:
        detector_net.eval()
        combined_model = CombinedModel(img_classifier_network, detector_net)
        combined_model.cuda()
        combined_model.eval()
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2(combined_model,
                                     True,
                                     confidence=0.3,
                                     search_steps=30,
                                     max_steps=args.atk_max_iter,
                                     optimizer_lr=0.01)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargeted(
                combined_model, alpha=0.01, max_iters=args.atk_max_iter)

    generate(
        attack,
        args.attack,
        args.dataset,
        args.detector,
        val_dataset,
        output_dir="{}/adversarial_images/white_box@data_{}@det_{}".format(
            IMAGE_DATA_ROOT[args.dataset], args.adv_arch, args.detector),
        args=args)
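The white-box pipeline above attacks the classifier and the detector jointly through CombinedModel, whose definition is not part of this excerpt. Below is a minimal sketch of one standard way to fuse the two, appending the detector's "adversarial" logit as an extra class so a targeted attack must change the prediction and evade detection at the same time. The class name NaiveCombinedModel and the assumption that the detector emits two logits (benign, adversarial) are illustrative; MetaAdvDet's actual CombinedModel may be implemented differently.

import torch
import torch.nn as nn

class NaiveCombinedModel(nn.Module):
    """Sketch: expose classifier logits plus one extra logit taken from a
    binary detector, so the attacker optimizes against both networks."""

    def __init__(self, classifier, detector):
        super().__init__()
        self.classifier = classifier
        self.detector = detector

    def forward(self, x):
        class_logits = self.classifier(x)   # shape (B, num_classes)
        det_logits = self.detector(x)       # assumed shape (B, 2)
        adv_score = det_logits[:, 1:2]      # keep only the adversarial logit
        return torch.cat([class_logits, adv_score], dim=1)  # (B, num_classes + 1)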
コード例 #18
0
def main_train_worker(args, model_path, META_ATTACKER_PART_I=None,
                      META_ATTACKER_PART_II=None, gpu="0"):
    if META_ATTACKER_PART_I is None:
        META_ATTACKER_PART_I = config.META_ATTACKER_PART_I
    if META_ATTACKER_PART_II is None:
        META_ATTACKER_PART_II = config.META_ATTACKER_PART_II
    print("Use GPU: {} for training".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    print("will save to {}".format(model_path))
    global best_acc1
    if args.arch == "conv3":
        model = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], 2)
    elif args.arch == "resnet10":
        model = resnet10(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    elif args.arch == "resnet18":
        model = resnet18(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    model = model.cuda()
    if args.dataset == "ImageNet":
        train_dataset = AdversaryRandomAccessNpyDataset(IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch),
                                                        True, args.protocol, META_ATTACKER_PART_I,META_ATTACKER_PART_II,
                                                        args.balance, args.dataset)
    else:
        train_dataset = AdversaryDataset(IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch),
                                     True, args.protocol, META_ATTACKER_PART_I,META_ATTACKER_PART_II, args.balance)

    # val_dataset = MetaTaskDataset(20000, 2, 1, 15,
    #                                     args.dataset, is_train=False, pkl_task_dump_path=args.test_pkl_path,
    #                                     load_mode=LOAD_TASK_MODE.LOAD,
    #                                     protocol=args.protocol, no_random_way=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()  # CUDA_VISIBLE_DEVICES is already set above, so no device id is needed
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    cudnn.benchmark = True

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=100, shuffle=False,
    #     num_workers=0, pin_memory=True)
    tensorboard = TensorBoardWriter("{0}/pytorch_DeepLearning_tensorboard".format(PY_ROOT), "DeepLearning")
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, None, model, criterion, optimizer, epoch,
              tensorboard, args)
        if args.balance:
            # Re-balance every epoch: keep all samples of class 1 and undersample
            # class 0 down to the same size.
            train_dataset.img_label_list.clear()
            train_dataset.img_label_list.extend(train_dataset.img_label_dict[1])
            train_dataset.img_label_list.extend(
                random.sample(train_dataset.img_label_dict[0],
                              len(train_dataset.img_label_dict[1])))
        # evaluate accuracy on the validation set (currently disabled)

        # acc1 = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
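The args.balance branch in the epoch loop above undersamples the majority class at the start of every epoch so the detector trains on an even class mix. A self-contained sketch of the same resampling idea on a toy list follows; all names here are illustrative and not part of the project:

import random
from collections import defaultdict

def rebalance(samples):
    """Keep every item of the minority class and undersample the other class
    to the same size; returns a freshly shuffled list of (item, label)."""
    by_label = defaultdict(list)
    for item, label in samples:
        by_label[label].append((item, label))
    minority, majority = sorted(by_label.values(), key=len)
    balanced = minority + random.sample(majority, len(minority))
    random.shuffle(balanced)
    return balanced

# Toy usage: 7 samples of class 0 versus 3 of class 1.
data = [(i, 0) for i in range(7)] + [(i, 1) for i in range(3)]
print(len(rebalance(data)))  # -> 6, i.e. 3 per class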
コード例 #19
0
ファイル: train.py プロジェクト: machanic/MetaAdvDet
def main_train_worker(gpu, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    # create model
    if args.pretrained and args.arch in models.__dict__:
        print("=> using pre-trained model '{}'".format(args.arch))
        network = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == "resnet10":
            network = resnet10(num_classes=CLASS_NUM[args.dataset], in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "resnet18":
            network = resnet18(num_classes=CLASS_NUM[args.dataset], in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset],
                            CLASS_NUM[args.dataset])

    if args.arch.startswith("resnet"):
        network.avgpool = Identity()
        network.fc = nn.Linear(512, CLASS_NUM[args.dataset])
    elif args.arch.startswith("vgg"):
        network.classifier[6] = nn.Linear(4096, CLASS_NUM[args.dataset])

    model_path = './train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        args.dataset, args.arch, args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    preprocessor = get_preprocessor(IN_CHANNELS[args.dataset])
    network.cuda()

    # define loss function (criterion) and optimizer
    image_classifier_loss = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.Adam(network.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            network.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.dataset == "CIFAR-10":
        train_dataset = CIFAR10(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor)
        val_dataset = CIFAR10(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor)
    elif args.dataset == "MNIST":
        train_dataset = MNIST(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor, download=True)
        val_dataset = MNIST(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor, download=True)
    elif args.dataset == "F-MNIST":
        train_dataset = FashionMNIST(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor, download=True)
        val_dataset = FashionMNIST(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor, download=True)
    elif args.dataset=="SVHN":
        train_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset], train=True, transform=preprocessor)
        val_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset], train=False, transform=preprocessor)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    for epoch in range(getattr(args, 'start_epoch', 0), args.epochs):  # honor a resumed checkpoint's epoch
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch, args)

        # evaluate accuracy on the validation set
        acc1 = validate(val_loader, network, image_classifier_loss, args)
        print(acc1)
        # remember best acc@1 and save checkpoint

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
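Both training workers call adjust_learning_rate, which is not included in these excerpts. The sketch below is a plausible step-decay implementation in the style of the standard PyTorch ImageNet example (divide the base rate by 10 every 30 epochs); it is a guess at the unshown helper, and the project's real schedule may differ:

def adjust_learning_rate(optimizer, epoch, args):
    """Assumed step decay: scale the base learning rate by 0.1 every 30
    epochs. Illustrative only; MetaAdvDet's actual schedule is not shown."""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr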
コード例 #20
0
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch,
                        target_arch, updateBN):
    # raw string so that \d stays a regex token rather than a string escape
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue

        arch = ma.group(3)
        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = ma.group(8)
        if balance == "True":
            balance = "balance"
        else:
            balance = "no_balance"
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        old_num_update = num_update
        for shot in [0, 1, 5]:
            if shot == 0:
                # the 0-shot setting is evaluated as 1-shot with zero fine-tune updates
                shot = 1
                num_update = 0
            else:
                num_update = old_num_update
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                way,
                                                shot,
                                                query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=protocol,
                                                no_random_way=True,
                                                adv_arch=target_arch,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          num_update,
                                                          update_BN=updateBN)
            if num_update == 0:
                shot = 0  # report the zero-update run under the 0-shot label
            result["{}-->{}@{}_{}".format(src_arch, target_arch, dataset,
                                          balance)][shot] = evaluate_result
    return result
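The regular expression above encodes the checkpoint naming scheme, and its capture groups feed the result keys. The small standalone check below, using a made-up file name that follows the scheme (the concrete values are purely illustrative), shows which group maps to which field:

import re

pattern = re.compile(
    r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)"
    r"@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")

name = ("DL_DET@CIFAR-10_TRAIN_II_TEST_I@model_conv3@data_resnet10"
        "@epoch_40@class_2@lr_0.0001@balance_True.pth.tar")
ma = pattern.match(name)
print(ma.group(1))  # dataset        -> CIFAR-10
print(ma.group(2))  # protocol       -> TRAIN_II_TEST_I
print(ma.group(3))  # detector arch  -> conv3
print(ma.group(4))  # adv arch       -> resnet10
print(ma.group(8))  # balance flag   -> True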