예제 #1
0
def train_detector(gpu, arch, adv_data_arch, img_classifier_model_path,
                   dataset, args):
    """Wrap a pretrained classifier in a ``Detector`` and train it on adversarial data.

    Args:
        gpu: GPU id exported through ``CUDA_VISIBLE_DEVICES`` before any CUDA work.
        arch: classifier architecture name; forwarded to ``build_network`` and
            ``train_epochs`` and embedded in the checkpoint file name.
        adv_data_arch: sub-folder of ``adversarial_images/`` holding the training
            data; also embedded in the checkpoint file name.
        img_classifier_model_path: classifier weights handed to ``build_network``.
        dataset: dataset name used for ``IMAGE_DATA_ROOT`` / ``CLASS_NUM`` lookups.
        args: parsed options. Reads ``use_cv_transform``, ``dataset``,
            ``protocol``, ``epochs``, ``lr``, ``batch_size``, ``momentum``,
            ``weight_decay``, ``workers`` and ``gpu``.
            Side effect: ``args.balance`` is overwritten.
    """
    print("using GPU {}".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    classifier = build_network(dataset, arch, img_classifier_model_path)
    if dataset in ("ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"):
        layer_number = 3
    else:
        layer_number = 2

    # The cv2 and torch transform variants store checkpoints in different folders.
    if args.use_cv_transform:
        image_transform = ImageTransformCV2(dataset, [1, 2])
        path_template = '{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_{}@lr_{}@batch_{}.pth.tar'
    else:
        image_transform = ImageTransformTorch(dataset, [5, 15])
        path_template = '{}/train_pytorch_model/ROTATE_DET/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_{}@lr_{}@batch_{}.pth.tar'
    detector_model_path = path_template.format(
        PY_ROOT, args.dataset, args.protocol, arch, adv_data_arch,
        args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(detector_model_path), exist_ok=True)

    detector = Detector(dataset, classifier, CLASS_NUM[dataset],
                        image_transform, layer_number)
    detector.cuda()
    optimizer = torch.optim.SGD(detector.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True

    # Only the TRAIN_ALL_TEST_ALL protocol asks the dataset for balancing.
    args.balance = bool(args.protocol == SPLIT_DATA_PROTOCOL.TRAIN_ALL_TEST_ALL)

    adv_root = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/{}".format(adv_data_arch)
    if dataset == "ImageNet":
        # The ImageNet loader additionally receives the dataset name.
        train_dataset = AdversaryRandomAccessNpyDataset(
            adv_root, True, args.protocol, config.META_ATTACKER_PART_I,
            config.META_ATTACKER_PART_II, args.balance, dataset)
    else:
        train_dataset = AdversaryDataset(
            adv_root, True, args.protocol, config.META_ATTACKER_PART_I,
            config.META_ATTACKER_PART_II, balance=args.balance)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # NOTE(review): train_epochs receives args.gpu, not the ``gpu`` argument set
    # in CUDA_VISIBLE_DEVICES above; presumably they agree — confirm with callers.
    train_epochs(detector_model_path, train_loader, detector, optimizer, arch,
                 args, args.gpu)
예제 #2
0
def evaluate_speed(args):
    """Run ``speed_test`` on one cv2-rotate ROTATE_DET checkpoint for CIFAR-10.

    Scans ``train_pytorch_model/ROTATE_DET/cv2_rotate_model`` for
    ``IMG_ROTATE_DET*`` files. It keeps CIFAR-10 checkpoints whose split
    protocol equals ``args.protocol`` and evaluates only the first such file,
    in ``glob`` order. Shots 0, 1 and 5 are run. A 0-shot run is a 1-shot
    task with zero fine-tuning updates; the other shots use
    ``args.num_updates`` updates. Files whose names do not match the expected
    pattern are skipped with a message.

    The result is written as JSON ``{dataset: {shot: speed_test_result}}`` to
    ``speed_test.json`` in the same folder.

    Args:
        args: reads ``gpus`` (``args.gpus[0][0]`` becomes ``CUDA_VISIBLE_DEVICES``),
            ``protocol`` and ``num_updates``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Expected name, e.g.
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    # Groups: 1 dataset, 2 protocol, 3 arch, 4 epoch, 5 lr, 6 batch size, 7 suffix.
    # Raw string: "\d" / "\." in a plain literal are invalid escape sequences.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    tot_num_tasks = 20000
    num_classes = 2  # 2-way tasks, matching Detector(num_classes=2) below
    num_query = 15
    result = defaultdict(dict)
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # e.g. names without the trailing "@<suffix>" segment; skip instead
            # of crashing on ma.group().
            print("skip checkpoint with unexpected name: {}".format(os.path.basename(model_path)))
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        lr = float(ma.group(5))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # Read the checkpoint once; every shot still gets a freshly built model.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        for report_shot in (0, 1, 5):
            if report_shot == 0:
                # 0-shot: evaluate a 1-shot task without any fine-tuning update.
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates
            # FIXME: adv_arch is hard-coded to conv3; other architectures are not supported yet.
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol, no_random_way=True, adv_arch="conv3")
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            result[dataset][report_shot] = speed_test(model, data_loader, lr, num_updates, update_BN=False)
        break  # only the first matching checkpoint is evaluated

    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/speed_test.json".format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
예제 #3
0
def evaluate_cross_arch(args):
    """Evaluate cv2-rotate ROTATE_DET checkpoints on another architecture's adversarial data.

    Selects checkpoints in ``train_pytorch_model/ROTATE_DET/cv2_rotate_model``
    whose protocol equals ``args.protocol`` and whose ``data_`` field equals
    ``args.cross_arch_source``. Each one is evaluated with
    ``finetune_eval_task_accuracy`` on tasks built from
    ``args.cross_arch_target`` adversarial data. Shots 0, 1 and 5 are run. A
    0-shot run is a 1-shot task with zero fine-tuning updates; the other shots
    use ``args.num_updates`` updates. Each result is stored under its own shot
    key, so results no longer overwrite each other when ``num_updates == 0``.
    Files whose names do not match the expected pattern are skipped.

    Results are written as JSON ``{"<dataset>@<src>--><tgt>": {shot: result}}``.

    Args:
        args: reads ``gpus``, ``protocol``, ``cross_arch_source``,
            ``cross_arch_target``, ``num_updates`` and ``load_mode``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Expected name, e.g.
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100.pth.tar
    # Groups: 1 dataset, 2 protocol, 3 model arch, 4 adversarial-data arch,
    # 5 epoch, 6 lr, 7 batch size. Raw string avoids invalid "\d" / "\." escapes.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth.tar")
    result = defaultdict(dict)
    update_BN = False
    tot_num_tasks = 20000
    num_classes = 2  # 2-way tasks, matching Detector(num_classes=2) below
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # e.g. names carrying an extra "@<suffix>" segment; skip instead of
            # crashing on ma.group().
            print("skip checkpoint with unexpected name: {}".format(os.path.basename(model_path)))
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue
        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # Read the checkpoint once; every shot still gets a freshly built model.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        result_key = "{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)
        for report_shot in (0, 1, 5):
            if report_shot == 0:
                # 0-shot: evaluate a 1-shot task without any fine-tuning update.
                shot, num_update = 1, 0
            else:
                shot, num_update = report_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True,
                                                adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            # FIXME: accuracy comes out very high with update_BN=False.
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)
            result[result_key][report_shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(PY_ROOT, args.cross_arch_source,
                                                                            args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
def build_rotate_detector(dataset, arch, adv_arch, protocol):
    """Rebuild a white-box ``Detector`` from its checkpoint and move it to the GPU.

    The checkpoint is read from ``train_pytorch_model/white_box_model``. Its
    file name is fixed to the ``epoch_10@lr_0.0001@batch_100@no_fix_cnn_params``
    variant; ``dataset``, ``protocol``, ``arch`` and ``adv_arch`` fill in the
    remaining fields. The backbone is a freshly built ``Conv3`` and the
    transform is ``ImageTransformTorch(dataset, [5, 15])``. The saved
    ``state_dict`` is loaded with ``strict=True``.

    Returns:
        The loaded ``Detector`` after ``.cuda()``.
    """
    backbone = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                     CLASS_NUM[dataset])
    checkpoint_file = (
        "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar"
    ).format(PY_ROOT, dataset, protocol, arch, adv_arch)
    # The identity map_location keeps storages as deserialized (CPU) until .cuda().
    saved = torch.load(checkpoint_file,
                       map_location=lambda storage, location: storage)
    print("load {} to detector".format(checkpoint_file))
    if dataset in ("CIFAR-10", "CIFAR-100", "SVHN"):
        layer_number = 3
    else:
        layer_number = 2
    detector = Detector(dataset, backbone, CLASS_NUM[dataset],
                        ImageTransformTorch(dataset, [5, 15]), layer_number)
    detector.load_state_dict(saved['state_dict'], strict=True)
    detector.cuda()
    return detector
예제 #5
0
def evaluate_whitebox_attack(args):
    """Evaluate the white-box RotateDet checkpoint against FGSM and CW_L2 white-box attacks.

    Loads ``white_box_model/IMG_ROTATE_DET@<dataset>_<protocol>@model_conv3@data_<adv_arch>@...``.
    For each attack and each shot in ``[1, 5]``, it builds a fresh ``Detector``
    from that checkpoint and scores it with ``finetune_eval_task_rotate``,
    using ``args.num_updates`` fine-tuning updates. Results are keyed by the
    requested shot, so they no longer collapse onto key 0 when
    ``num_updates == 0``. They are written as JSON
    ``{"<dataset>_<attack>_RotateDet_<adv_arch>": {shot: result}}``.

    Args:
        args: reads ``gpus``, ``dataset``, ``protocol``, ``adv_arch``,
            ``num_updates``, ``load_mode`` and ``eval_update_BN``.

    Raises:
        FileNotFoundError: if the expected checkpoint file does not exist.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    attacks = ["FGSM", "CW_L2"]
    # Groups: 1 dataset, 2 protocol, 3 "model_<arch>@data_<adv_arch>", 4 epoch,
    # 5 lr, 6 batch size, 7 suffix. Raw string avoids invalid "\d" / "\." escapes.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    result = defaultdict(dict)
    model_path = "{}/train_pytorch_model/white_box_model/IMG_ROTATE_DET@{}_{}@model_{}@data_{}@epoch_10@lr_0.0001@batch_100@no_fix_cnn_params.pth.tar".format(
        PY_ROOT, args.dataset, args.protocol, "conv3", args.adv_arch)
    # Explicit check instead of assert, which is stripped under "python -O".
    if not os.path.exists(model_path):
        raise FileNotFoundError("{} not exists".format(model_path))
    ma = extract_pattern_detail.match(model_path)
    dataset = ma.group(1)
    lr = float(ma.group(5))
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    tot_num_tasks = 20000
    num_classes = 2  # 2-way tasks, matching Detector(num_classes=2) below
    num_query = 15
    all_shots = [1, 5]
    detector = "RotateDet"
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    for attack_name in attacks:
        root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(
            args.adv_arch, detector, attack_name)
        result_key = "{}_{}_{}_{}".format(dataset, attack_name, detector, args.adv_arch)
        for report_shot in all_shots:
            if report_shot == 0:
                # 0-shot: evaluate a 1-shot task without any fine-tuning update.
                shot, num_update = 1, 0
            else:
                shot, num_update = report_shot, args.num_updates
            meta_task_dataset = WhiteBoxMetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                        dataset, args.load_mode, detector, attack_name, root_folder)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformTorch(dataset, [5, 15])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform,
                             layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_rotate(model, data_loader, lr, num_update, update_BN=args.eval_update_BN)
            result[result_key][report_shot] = evaluate_result
    update_BN_str = "UpdateBN" if args.eval_update_BN else "NoUpdateBN"
    with open("{}/train_pytorch_model/white_box_model/white_box_RotateDet_{}_{}_using_{}__result.json".format(PY_ROOT, args.dataset, update_BN_str, args.protocol), "w") as file_obj:
        file_obj.write(json.dumps(result))