Esempio n. 1
0
def build_network(dataset, arch, model_path):
    """Construct an image classifier and load its checkpoint if one exists.

    Args:
        dataset: Dataset name used to index the module-level ``CLASS_NUM``,
            ``IN_CHANNELS`` and ``IMAGE_SIZE`` tables (custom archs only).
        arch: Either an entry of ``models.__dict__`` (instantiated with
            ``pretrained=False``) or one of ``"resnet10"``, ``"resnet18"``,
            ``"conv3"``.
        model_path: Checkpoint path. If the file exists, its ``'state_dict'``
            is loaded into the network. A missing file is tolerated and the
            freshly initialised network is returned.

    Returns:
        The constructed network.

    Raises:
        ValueError: If ``arch`` names no supported architecture.
    """
    if arch in models.__dict__:
        # NOTE(review): the message says "pre-trained" but pretrained=False is
        # passed; trained weights can only come from ``model_path`` below.
        # If ``models`` is torchvision.models, names such as "resnet18" are
        # resolved here, so the dataset-aware branch below never sees them.
        # Confirm this is intended.
        print("=> using pre-trained model '{}'".format(arch))
        img_classifier_network = models.__dict__[arch](pretrained=False)
    else:
        print("=> creating model '{}'".format(arch))
        if arch == "resnet10":
            img_classifier_network = resnet10(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "resnet18":
            img_classifier_network = resnet18(num_classes=CLASS_NUM[dataset],
                                              in_channels=IN_CHANNELS[dataset],
                                              pretrained=False)
        elif arch == "conv3":
            img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                           IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
        else:
            # Without this, an unknown arch failed later with an opaque
            # UnboundLocalError on ``img_classifier_network``.
            raise ValueError("unsupported arch '{}'".format(arch))
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # Returning ``storage`` unchanged keeps tensors where they were
        # deserialised (CPU), so GPU-saved checkpoints load anywhere.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        img_classifier_network.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
    return img_classifier_network
Esempio n. 2
0
    def __init__(self, **kwargs):
        """Assemble the image-conditioned RNVP model from keyword options.

        Every option is read with ``kwargs.get``, so an absent key becomes
        ``None`` instead of raising. Sub-modules are created in this order:
        ``p_prior``, ``pc_decoder``, ``image_encoder``, ``g0_prior``.
        """
        super(Conditional_RNVP_with_image_prior, self).__init__()

        # (attribute name, kwargs key) pairs, copied onto the instance verbatim.
        option_names = (
            ('mode', 'usage_mode'),
            ('deterministic', 'deterministic'),
            ('g_latent_space_size', 'g_latent_space_size'),
            ('g_prior_n_flows', 'g_prior_n_flows'),
            ('g_prior_n_features', 'g_prior_n_features'),
            ('p_latent_space_size', 'p_latent_space_size'),
            ('p_prior_n_layers', 'p_prior_n_layers'),
            ('p_decoder_n_flows', 'p_decoder_n_flows'),
            ('p_decoder_n_features', 'p_decoder_n_features'),
            ('p_decoder_base_type', 'p_decoder_base_type'),
            ('p_decoder_base_var', 'p_decoder_base_var'),
        )
        for attr, key in option_names:
            setattr(self, attr, kwargs.get(key))

        g_size = self.g_latent_space_size

        # NOTE(review): FeatureEncoder's positional args look like
        # (n_layers, in_features, out_features). Confirm against its definition.
        self.p_prior = FeatureEncoder(self.p_prior_n_layers, g_size,
                                      self.p_latent_space_size,
                                      deterministic=False,
                                      mu_weight_std=0.001, mu_bias=0.0,
                                      logvar_weight_std=0.01, logvar_bias=0.0)

        self.pc_decoder = LocalCondRNVPDecoder(self.p_decoder_n_flows,
                                               self.p_decoder_n_features,
                                               g_size,
                                               weight_std=0.01)

        self.image_encoder = resnet18(num_classes=g_size)
        self.g_prior_n_layers = kwargs.get('g_prior_n_layers')
        self.g0_prior = FeatureEncoder(self.g_prior_n_layers, g_size, g_size,
                                       deterministic=False,
                                       mu_weight_std=0.0033, mu_bias=0.0,
                                       logvar_weight_std=0.033, logvar_bias=0.0)
Esempio n. 3
0
def load_pretrained_model(prefix):
    """Rebuild a ResNet-18 and load ``model_best.pth.tar`` from ``prefix``.

    Args:
        prefix: Location passed to ``add_prefix`` together with the
            checkpoint file name ``'model_best.pth.tar'``.

    Returns:
        The ResNet-18 with the checkpoint's ``'state_dict'`` loaded, after
        the state dict has been passed through ``remove_prefix``.
    """
    checkpoint = torch.load(add_prefix(prefix, 'model_best.pth.tar'))
    # NOTE(review): ``is_ptrtrained`` matches the project's resnet18 keyword
    # as used throughout this file; it looks like a typo of "is_pretrained".
    model = resnet18(is_ptrtrained=False)
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    # Report success only once the weights are actually in place.
    print('load pretrained resnet18 successfully.')
    return model
def model_selector(model_type):
    """Return an untrained classifier selected by name.

    ``'vgg'`` yields a 2-class VGG-19 and ``'resnet'`` yields the project's
    ResNet-18. Any other value raises ``ValueError``.
    """
    if model_type == 'vgg':
        return vgg19(pretrained=False, num_classes=2)
    if model_type == 'resnet':
        return resnet18(is_ptrtrained=False)
    raise ValueError('')
    def __init__(self, dataset, num_classes, meta_batch_size, meta_step_size,
                 inner_step_size, lr_decay_itr, epoch, num_inner_updates,
                 load_task_mode, arch, tot_num_tasks, num_support, detector,
                 attack_name, root_folder):
        """Build the meta-learner network and its fixed validation task loader.

        Args:
            dataset: Dataset name; indexes ``IN_CHANNELS`` / ``IMAGE_SIZE``.
            num_classes: Number of classes. It sizes the network output and
                is forwarded to ``MetaTaskDataset``.
            meta_batch_size: Number of tasks per meta-batch.
            meta_step_size: Stored as ``self.meta_step_size``.
            inner_step_size: Stored as ``self.inner_step_size``.
            lr_decay_itr: Stored as ``self.lr_decay_itr``.
            epoch: Stored as ``self.epoch``.
            num_inner_updates: Inner-loop steps. Also reused as the number of
                fine-tuning steps at test time.
            load_task_mode, detector, attack_name, root_folder: Forwarded to
                ``MetaTaskDataset``.
            arch: ``"conv3"``, ``"resnet10"`` or ``"resnet18"``.
            tot_num_tasks, num_support: Task-sampling sizes forwarded to
                ``MetaTaskDataset``, which uses 15 query examples per task.

        Raises:
            ValueError: If ``arch`` is not a supported architecture.
        """
        # Zero-argument super() resolves against the defining class.
        # ``super(self.__class__, self)`` recurses forever once subclassed.
        super().__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets
        if arch == "conv3":
            network = Conv3(IN_CHANNELS[self.dataset],
                            IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            # NOTE(review): unlike other constructions in this project, no
            # in_channels is passed here. Confirm the default suits
            # single-channel datasets.
            network = resnet10(num_classes, pretrained=False)
        elif arch == "resnet18":
            network = resnet18(num_classes, pretrained=False)
        else:
            # Previously fell through to an UnboundLocalError on ``network``.
            raise ValueError("unsupported arch '{}'".format(arch))
        self.network = network
        self.network.cuda()

        val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support,
                                      15, dataset, load_task_mode, detector,
                                      attack_name, root_folder)
        # Fixed order, 100 tasks per batch; accuracy is measured per task.
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=0,
            pin_memory=True)
Esempio n. 6
0
def model_selector(model_type):
    """Return an untrained classifier selected by name.

    ``'vgg'`` yields ``vgg()`` and ``'resnet18'`` yields the project's
    ResNet-18. Any other value raises ``ValueError``.
    """
    if model_type == 'vgg':
        return vgg()
    if model_type == 'resnet18':
        return resnet18(is_ptrtrained=False)
    raise ValueError('')
Esempio n. 7
0
    def __init__(self,
                 dataset,
                 num_classes,
                 meta_batch_size,
                 meta_step_size,
                 inner_step_size, lr_decay_itr,
                 epoch,
                 num_inner_updates, load_task_mode, protocol, arch,
                 tot_num_tasks, num_support, num_query, no_random_way,
                 tensorboard_data_prefix, train=True, adv_arch="conv4", need_val=False):
        """Build the MAML learner: network, task loaders, inner loop and optimizer.

        Args:
            dataset: Dataset name; indexes ``IN_CHANNELS`` / ``IMAGE_SIZE``.
            num_classes: Number of classes. It sizes the network output and
                is forwarded to ``MetaTaskDataset``.
            meta_batch_size: Tasks per meta-batch (the training DataLoader's
                batch size).
            meta_step_size: Adam learning rate for the outer (meta) update.
            inner_step_size: Step size handed to ``InnerLoop``.
            lr_decay_itr: Stored as ``self.lr_decay_itr``.
            epoch: Stored as ``self.epoch``.
            num_inner_updates: Inner-loop steps. Also reused as the number of
                fine-tuning steps at test time.
            load_task_mode, protocol, no_random_way, adv_arch: Forwarded to
                ``MetaTaskDataset``.
            arch: ``"conv3"``, ``"resnet10"`` or ``"resnet18"``.
            tot_num_tasks, num_support, num_query: Task-sampling sizes
                forwarded to ``MetaTaskDataset``.
            tensorboard_data_prefix: Run prefix for ``TensorBoardWriter``.
            train: If true, build the training loader and TensorBoard writer.
            need_val: If true, build a validation loader with 15 query
                examples per task.

        Raises:
            ValueError: If ``arch`` is not a supported architecture.
        """
        # Zero-argument super() resolves against the defining class.
        # ``super(self.__class__, self)`` recurses forever once subclassed.
        super().__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.meta_batch_size = meta_batch_size  # task number per batch
        self.meta_step_size = meta_step_size
        self.inner_step_size = inner_step_size
        self.lr_decay_itr = lr_decay_itr
        self.epoch = epoch
        self.num_inner_updates = num_inner_updates
        self.test_finetune_updates = num_inner_updates
        # Make the nets
        if arch == "conv3":
            network = Conv3(IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset], num_classes)
        elif arch == "resnet10":
            network = MetaNetwork(resnet10(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
        elif arch == "resnet18":
            network = MetaNetwork(resnet18(num_classes, in_channels=IN_CHANNELS[self.dataset], pretrained=False),
                                  IN_CHANNELS[self.dataset], IMAGE_SIZE[self.dataset])
        else:
            # Previously fell through to an UnboundLocalError on ``network``.
            raise ValueError("unsupported arch '{}'".format(arch))

        self.network = network
        self.network.cuda()
        if train:
            trn_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, num_query,
                                          dataset, is_train=True, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=no_random_way, adv_arch=adv_arch, fetch_attack_name=False)
            # task number per mini-batch is controlled by DataLoader
            self.train_loader = DataLoader(trn_dataset, batch_size=meta_batch_size, shuffle=True, num_workers=4, pin_memory=True)
            tensorboard_dir = "{0}/pytorch_MAML_tensorboard".format(PY_ROOT)
            # Create the log directory before the writer that logs into it.
            os.makedirs(tensorboard_dir, exist_ok=True)
            self.tensorboard = TensorBoardWriter(tensorboard_dir,
                                                 tensorboard_data_prefix)
        if need_val:
            val_dataset = MetaTaskDataset(tot_num_tasks, num_classes, num_support, 15,
                                          dataset, is_train=False, load_mode=load_task_mode,
                                          protocol=protocol,
                                          no_random_way=True, adv_arch=adv_arch, fetch_attack_name=False)
            # Fixed order, 100 tasks per batch; accuracy is measured per task.
            self.val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)
        # Runs the inner-loop adaptation for all tasks of a meta-batch in parallel.
        self.fast_net = InnerLoop(self.network, self.num_inner_updates,
                                  self.inner_step_size, self.meta_batch_size)
        self.fast_net.cuda()
        self.opt = Adam(self.network.parameters(), lr=meta_step_size)
Esempio n. 8
0
def load_pretrained_model(pretrained_path, model_type):
    """Rebuild a classifier of ``model_type`` and load its best checkpoint.

    Args:
        pretrained_path: Location passed to ``add_prefix`` together with the
            checkpoint file name ``'model_best.pth.tar'``.
        model_type: ``'vgg'`` (2-class VGG-19) or ``'resnet'`` (project
            ResNet-18).

    Returns:
        The model with the checkpoint's ``'state_dict'`` loaded, after the
        state dict has been passed through ``remove_prefix``.

    Raises:
        ValueError: If ``model_type`` is unsupported. This is checked before
            any file is read, so a bad type is never masked by an I/O error.
    """
    if model_type == 'vgg':
        model = vgg19(pretrained=False, num_classes=2)
        success_message = 'load vgg successfully.'
    elif model_type == 'resnet':
        model = resnet18(is_ptrtrained=False)
        success_message = 'load resnet18 successfully.'
    else:
        raise ValueError('')
    checkpoint = torch.load(add_prefix(pretrained_path, 'model_best.pth.tar'))
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    # Report success only once the weights are actually in place.
    print(success_message)
    return model
def load_pretrained_model(prefix, model_type):
    """Rebuild a classifier of ``model_type`` and load its best checkpoint.

    Args:
        prefix: Location passed to ``add_prefix`` together with the
            checkpoint file name ``'model_best.pth.tar'``.
        model_type: ``'resnet'`` (project ResNet-18) or ``'vgg'`` (2-class
            VGG-19).

    Returns:
        The model with the checkpoint's ``'state_dict'`` loaded, after the
        state dict has been passed through ``remove_prefix``.

    Raises:
        ValueError: If ``model_type`` is unsupported (checked before any
            file is read).
    """
    if model_type == 'resnet':
        model = resnet18(is_ptrtrained=False)
    elif model_type == 'vgg':
        model = vgg19(num_classes=2, pretrained=False)
    else:
        raise ValueError('')

    checkpoint = torch.load(add_prefix(prefix, 'model_best.pth.tar'))
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    # Report success only once the weights are actually in place.
    print('load pretrained model successfully.')
    # A checkpoint may not record 'best_accuracy'. Do not turn a successful
    # load into a KeyError just because that statistic is absent.
    best_accuracy = checkpoint.get('best_accuracy')
    if best_accuracy is not None:
        print('best acc=%.4f' % best_accuracy)
    return model
def model_builder():
    """Build the ``Locator`` model (pretrained U-Net + ResNet-18) on the GPU.

    Reads the module-level ``args`` for ``args.cuda`` and
    ``args.pretrain_unet``.

    Returns:
        The ``Locator`` wrapped in ``DataParallel`` and moved to CUDA.

    Raises:
        ValueError: If ``args.cuda`` is falsy; CPU execution is unsupported.
    """
    # Fail fast: check the GPU requirement before building networks and
    # reading the U-Net weights, which was wasted work on CPU-only runs.
    if not args.cuda:
        raise ValueError('there is no gpu')

    classifier = resnet18(is_ptrtrained=False)
    print('use resnet18')
    auto_encoder = UNet(3, depth=5, in_channels=3)
    auto_encoder.load_state_dict(weight_to_cpu(args.pretrain_unet))
    print('load pretrained unet!')

    model = Locator(aer=auto_encoder, classifier=classifier)
    return DataParallel(model).cuda()
 def __init__(self, feature_size, n_classes, device):
     """Build the discriminator: a ResNet-18 trunk feeding two linear heads.

     Args:
         feature_size: Width of the trunk output. The trunk's final ``fc``
             is replaced by a linear projection to this many features.
         n_classes: Output size of the auxiliary (bias-free) linear head.
         device: Stored as ``self.device``.
     """
     super(netD, self).__init__()
     self.device = device
     self.feats = feature_size

     # Replace the trunk's classifier with a projection to ``feature_size``.
     trunk = resnet18()
     trunk.fc = nn.Linear(trunk.fc.in_features, feature_size)
     self.feature_extractor = trunk

     self.bn = nn.BatchNorm1d(feature_size, momentum=0.01)
     self.ReLU = nn.ReLU()
     # Two heads: class logits (no bias) and a single-logit output
     # (presumably the real/fake score, given the sigmoid below).
     self.aux_linear = nn.Linear(feature_size, n_classes, bias=False)
     self.disc_linear = nn.Linear(feature_size, 1)
     self.n_classes = n_classes
     self.softmax = nn.Softmax(dim=1)
     self.sigmoid = nn.Sigmoid()
Esempio n. 12
0
 def __init__(self, net_type, image_size=32, args=None):
     """Wrap a pretrained backbone selected by name.

     Args:
         net_type: One of ``'inception_v3'``, ``'resnet18'``, ``'resnet34'``,
             ``'resnet50'``, ``'resnet101'`` or ``'darts'``.
         image_size: Forwarded to the Inception-v3 constructor only.
         args: Required for ``'darts'``; must provide ``darts_model``.

     Raises:
         ValueError: If ``net_type`` is unknown, or if ``'darts'`` is
             requested without ``args``.
     """
     super(StrongDisc, self).__init__()
     self.net_type = net_type
     if net_type == 'inception_v3':
         self.net = inception_v3.inception_v3(pretrained=True,
                                              image_size=image_size)
     elif net_type == 'resnet18':
         self.net = resnet.resnet18(pretrained=True)
     elif net_type == 'resnet34':
         self.net = resnet.resnet34(pretrained=True)
     elif net_type == 'resnet50':
         self.net = resnet.resnet50(pretrained=True)
     elif net_type == 'resnet101':
         self.net = resnet.resnet101(pretrained=True)
     elif net_type == 'darts':
         if args is None:
             raise ValueError("net_type 'darts' requires args.darts_model")
         self.net = darts.AugmentCNNOneOutput(model_path=args.darts_model)
     else:
         # An explicit raise: ``assert 0`` is stripped under ``python -O``,
         # which silently left ``self.net`` unset.
         raise ValueError('unknown net_type: {!r}'.format(net_type))
Esempio n. 13
0
def evaluate_shots(model_path_list, num_update, lr, protocol):
    """Evaluate trained detectors in the 0-, 1- and 5-shot settings.

    The deep-learning detectors are trained in the "all in" (or sampled
    "all in") setting, but testing must run on the task version of the
    dataset, so each model is evaluated on ``MetaTaskDataset`` tasks.

    Args:
        model_path_list: Checkpoint paths. Names must match the
            ``DL_DET@<dataset>_<TRAIN_...protocol>@model_<arch>@data_<adv_arch>
            @epoch_<n>@class_<n>@lr_<lr>@balance_<bool>.pth.tar`` pattern.
            Non-matching paths, ImageNet models and models trained under a
            different protocol are skipped.
        num_update: Fine-tuning steps for the 1- and 5-shot settings. The
            0-shot setting always takes 0 steps.
        lr: Fine-tuning learning rate.
        protocol: Must equal ``SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II``.

    Returns:
        ``defaultdict(dict)`` mapping ``"<dataset>@<balance|no_balance>@<adv_arch>"``
        to ``{shot: evaluate_result}`` for shot in ``(0, 1, 5)``.

    Raises:
        AssertionError: If ``protocol`` is not TRAIN_I_TEST_II.
        ValueError: If a checkpoint names an unsupported architecture.
    """
    # Raw string: "\d" and "\." are invalid escapes in ordinary literals.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(
        protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Previously crashed with AttributeError on ``None.group``.
            print("skip unrecognised model file :{}".format(model_path))
            continue
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"

        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        arch = ma.group(3)
        adv_arch = ma.group(4)

        if arch == "conv3":
            model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        elif arch == "resnet10":
            model = resnet10(2,
                             in_channels=IN_CHANNELS[dataset],
                             pretrained=False)
        elif arch == "resnet18":
            model = resnet18(2,
                             in_channels=IN_CHANNELS[dataset],
                             pretrained=False)
        else:
            # Otherwise the previous iteration's ``model`` would be silently
            # reused (or UnboundLocalError raised on the first iteration).
            raise ValueError("unsupported arch '{}' in {}".format(arch, model_path))
        model = model.cuda()
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))

        result_key = "{}@{}@{}".format(dataset, balance, adv_arch)
        # Result key 0 means "no fine-tuning": tasks still carry one support
        # example, but 0 update steps are taken. The result key is kept
        # separate from the support size so that a caller passing
        # num_update == 0 no longer files the 1- and 5-shot results under 0.
        for result_shot in (0, 1, 5):
            num_support = max(result_shot, 1)
            shot_num_update = 0 if result_shot == 0 else num_update
            meta_task_dataset = MetaTaskDataset(
                tot_num_tasks,
                way,
                num_support,
                query,
                dataset,
                is_train=False,
                load_mode=LOAD_TASK_MODE.NO_LOAD,
                protocol=protocol,
                no_random_way=True,
                adv_arch=adv_arch)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            result[result_key][result_shot] = finetune_eval_task_accuracy(
                model,
                data_loader,
                lr,
                shot_num_update,
                update_BN=False)
    return result
Esempio n. 14
0
def main():
    """Evaluate every saved image classifier on every matching test-task file.

    For each checkpoint matching
    ``{PY_ROOT}/train_pytorch_model/DL_IMAGE_CLASSIFIER*``, the classifier is
    rebuilt, its weights are loaded and it is wrapped in a
    ``DetectionEvaluator``. It is then evaluated on every
    ``{PY_ROOT}/task/{split_data_protocol}/test_{dataset}*.pkl`` task file.
    The accuracies are written as JSON to ``args.output_path``, keyed by
    ``"<checkpoint name>|<task file name>"`` (both without extension).
    Checkpoint or task file names that do not match the expected patterns
    are skipped with a message.

    Raises:
        ValueError: If a checkpoint names an unsupported architecture.
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Raw strings: "\d" and "\." are invalid escapes in ordinary literals.
    # e.g. DL_IMAGE_CLASSIFIER_CIFAR-10@conv4@epoch_40@lr_0.0001@batch_500.pth.tar
    extract_pattern_detail = re.compile(
        r".*?DL_IMAGE_CLASSIFIER_(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth\.tar"
    )
    # e.g. test_CIFAR-10_tot_num_tasks_20000_metabatch_10_way_5_shot_5_query_15.txt
    extract_dataset_pattern = re.compile(
        r".*?tot_num_tasks_(\d+)_metabatch_(\d+)_way_(\d+)_shot_(\d+)_query_(\d+).*"
    )
    result = {}
    for model_path in glob.glob(
            "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognised model file {}".format(model_path))
            continue
        dataset = ma.group(1)
        arch = ma.group(2)
        if arch == "conv4":
            model = FourConvs(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                              CLASS_NUM[dataset])
        elif arch == "resnet10":
            model = resnet10(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        elif arch == "resnet18":
            model = resnet18(CLASS_NUM[dataset], IN_CHANNELS[dataset])
        else:
            # Otherwise the previous iteration's ``model`` would be silently
            # reused (or UnboundLocalError raised on the first iteration).
            raise ValueError("unsupported arch '{}' in {}".format(arch, model_path))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint["state_dict"])
        model.cuda()
        print("loading {}".format(model_path))
        detector = DetectionEvaluator(model, dataset)
        key1 = os.path.basename(model_path)
        key1 = key1[:key1.rindex(".")]

        pkl_task_files = glob.glob("{}/task/{}/test_{}*.pkl".format(
            PY_ROOT, args.split_data_protocol, dataset))
        if pkl_task_files:
            # These depend only on ``dataset``: build them once per
            # checkpoint instead of once per task file.
            preprocessor = get_preprocessor(
                input_size=IMAGE_SIZE[dataset],
                input_channels=IN_CHANNELS[dataset])
            # NOTE(review): train_dataset is currently unused because its
            # consumer (get_train_data) is disabled below. It is kept for
            # its construction side effects (e.g. download=True).
            if dataset == "CIFAR-10":
                train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                        train=True,
                                        transform=preprocessor)
            elif dataset == "MNIST":
                train_dataset = MNIST(IMAGE_DATA_ROOT[dataset],
                                      train=True,
                                      transform=preprocessor,
                                      download=True)
            elif dataset == "F-MNIST":
                train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset],
                                             train=True,
                                             transform=preprocessor,
                                             download=True)
            elif dataset == "SVHN":
                train_dataset = SVHN(IMAGE_DATA_ROOT[dataset],
                                     train=True,
                                     transform=preprocessor)
        for pkl_task_file_name in pkl_task_files:
            ma_d = extract_dataset_pattern.match(pkl_task_file_name)
            if ma_d is None:
                print("skip unrecognised task file {}".format(pkl_task_file_name))
                continue
            tot_num_tasks = int(ma_d.group(1))
            num_classes = int(ma_d.group(3))
            num_support = int(ma_d.group(4))
            num_query = int(ma_d.group(5))
            meta_dataset = MetaTaskDataset(
                tot_num_tasks,
                num_classes,
                num_support,
                num_query,
                dataset,
                is_train=False,
                load_mode=LOAD_TASK_MODE.LOAD,
                pkl_task_dump_path=pkl_task_file_name,
                protocol=args.split_data_protocol)
            val_loader = DataLoader(meta_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False)

            # train_imgs = get_train_data(train_dataset)
            train_imgs = []
            accuracy = detector.evaluate_detections(train_imgs, val_loader)
            key = os.path.basename(pkl_task_file_name)
            key = key[:key.rindex(".")]
            result["{}|{}".format(key1, key)] = accuracy
    with open(args.output_path, "w") as file_obj:
        file_obj.write(json.dumps(result))
def main():
    """Generate white-box adversarial examples against a classifier+detector pair.

    Reads options from the module-level ``parser``. It loads the validation
    split of ``args.dataset``, the image-classifier checkpoint for
    ``args.adv_arch`` and the detector selected by ``args.detector``, then
    builds the attack selected by ``args.attack`` and hands everything to
    ``generate``.

    Raises:
        ValueError: If the dataset, classifier architecture, detector or
            attack name is not supported. Previously each of these fell
            through to an UnboundLocalError later on.
    """
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    preprocessor = get_preprocessor(input_channels=IN_CHANNELS[args.dataset])

    # NOTE(review): train_dataset is built but not used below.
    if args.dataset == "CIFAR-10":
        train_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.dataset],
                                         train=True,
                                         transform=preprocessor)
        val_dataset = datasets.CIFAR10(IMAGE_DATA_ROOT[args.dataset],
                                       train=False,
                                       transform=preprocessor)
    elif args.dataset == "MNIST":
        train_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.dataset],
                                       train=True,
                                       transform=preprocessor,
                                       download=True)
        val_dataset = datasets.MNIST(IMAGE_DATA_ROOT[args.dataset],
                                     train=False,
                                     transform=preprocessor,
                                     download=True)
    elif args.dataset == "F-MNIST":
        train_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.dataset],
                                              train=True,
                                              transform=preprocessor,
                                              download=True)
        val_dataset = datasets.FashionMNIST(IMAGE_DATA_ROOT[args.dataset],
                                            train=False,
                                            transform=preprocessor,
                                            download=True)
    elif args.dataset == "SVHN":
        train_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset],
                             train=True,
                             transform=preprocessor)
        val_dataset = SVHN(IMAGE_DATA_ROOT[args.dataset],
                           train=False,
                           transform=preprocessor)
    else:
        raise ValueError("unsupported dataset '{}'".format(args.dataset))

    # load image classifier model
    img_classifier_model_path = "{}/train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_40@lr_0.0001@batch_500.pth.tar".format(
        PY_ROOT, args.dataset, args.adv_arch)
    if args.adv_arch == "resnet10":
        img_classifier_network = resnet10(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "resnet18":
        img_classifier_network = resnet18(
            num_classes=CLASS_NUM[args.dataset],
            in_channels=IN_CHANNELS[args.dataset])
    elif args.adv_arch == "conv3":
        img_classifier_network = Conv3(IN_CHANNELS[args.dataset],
                                       IMAGE_SIZE[args.dataset],
                                       CLASS_NUM[args.dataset])
    else:
        raise ValueError("unsupported adv_arch '{}'".format(args.adv_arch))

    print("=> loading checkpoint '{}'".format(img_classifier_model_path))
    checkpoint = torch.load(img_classifier_model_path,
                            map_location=lambda storage, loc: storage)
    img_classifier_network.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}) for img classifier".format(
        img_classifier_model_path, checkpoint['epoch']))
    img_classifier_network.eval()
    img_classifier_network = img_classifier_network.cuda()

    # load detector model
    if args.detector == "MetaAdvDet":
        # If the fine-tuned model is attacked, the perturbation has to be
        # generated dynamically: after every fine-tune on the support set,
        # attack again to produce fresh adversarial examples.
        # Open question: detection requires fine-tuning on the support set
        # first -- should the attack target the model before or after it?
        detector_net = build_meta_adv_detector(
            args.dataset, args.det_arch, args.adv_arch, args.shot,
            args.protocol)
    elif args.detector == "DNN":
        detector_net = build_DNN_detector(args.dataset, args.det_arch,
                                          args.adv_arch, args.protocol)
    elif args.detector == "RotateDet":
        detector_net = build_rotate_detector(args.dataset, args.det_arch,
                                             args.adv_arch, args.protocol)
    elif args.detector == "NeuralFP":
        detector_net = build_neural_fingerprint_detector(
            args.dataset, args.det_arch)
    else:
        raise ValueError("unsupported detector '{}'".format(args.detector))

    if args.detector == "NeuralFP":
        # The fingerprint attacks take the detector separately.
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2Fingerprint(img_classifier_network,
                                                targeted=True,
                                                confidence=0.3,
                                                search_steps=30,
                                                max_steps=args.atk_max_iter,
                                                optimizer_lr=0.01,
                                                neural_fp=detector_net)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargetedFingerprint(
                img_classifier_network,
                alpha=0.01,
                max_iters=args.atk_max_iter,
                neural_fp=detector_net)
        else:
            raise ValueError("unsupported attack '{}'".format(args.attack))
    else:
        # Other detectors are attacked jointly with the classifier.
        detector_net.eval()
        combined_model = CombinedModel(img_classifier_network, detector_net)
        combined_model.cuda()
        combined_model.eval()
        if args.attack == "CW_L2":
            attack = CarliniWagnerL2(combined_model,
                                     True,
                                     confidence=0.3,
                                     search_steps=30,
                                     max_steps=args.atk_max_iter,
                                     optimizer_lr=0.01)
        elif args.attack == "FGSM":
            attack = IterativeFastGradientSignTargeted(
                combined_model, alpha=0.01, max_iters=args.atk_max_iter)
        else:
            raise ValueError("unsupported attack '{}'".format(args.attack))

    # NOTE(review): the "data_{}" segment of output_dir is filled with
    # args.adv_arch, not the dataset name. Confirm the intended layout.
    generate(
        attack,
        args.attack,
        args.dataset,
        args.detector,
        val_dataset,
        output_dir="{}/adversarial_images/white_box@data_{}@det_{}".format(
            IMAGE_DATA_ROOT[args.dataset], args.adv_arch, args.detector),
        args=args)
Esempio n. 16
0
    def __init__(self,
                 layers=50,
                 bins=(1, 2, 3, 6),
                 dropout=0.1,
                 classes=2,
                 zoom_factor=8,
                 use_ppm=True,
                 criterion=nn.CrossEntropyLoss(ignore_index=255),
                 BatchNorm=nn.BatchNorm2d,
                 pretrained=True):
        """Build PSPNet on top of a ResNet backbone with dilated layer3/layer4.

        Args:
            layers: Backbone depth; one of 18, 50, 101, 152.
            bins: Pooling bin sizes forwarded to ``PPM``.
            dropout: ``Dropout2d`` probability in the classifier heads.
            classes: Number of classes (must be > 1). With exactly 2 classes,
                the heads emit a single output channel instead of 2.
            zoom_factor: Stored on the instance; must be 1, 2, 4 or 8.
            use_ppm: If true, add a ``PPM`` module. This doubles the feature
                width seen by the classifier head.
            criterion: Loss module stored as ``self.criterion``. The default
                instance is created once, when the ``def`` executes, and is
                shared by every PSPNet built with the default.
            BatchNorm: Normalisation layer class used in the heads and
                ``PPM``. It is also assigned to ``models.BatchNorm`` (a
                module-global side effect) before the backbone is built.
            pretrained: Forwarded to the ``models.resnet*`` constructor.
        """
        super(PSPNet, self).__init__()
        assert layers in [18, 50, 101, 152]
        # NOTE(review): divisibility is checked against 2048 even when
        # layers == 18, where fea_dim is 512.
        assert 2048 % len(bins) == 0
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        self.criterion = criterion
        # Module-global side effect; presumably read by the ``models.resnet*``
        # constructors to choose their norm layer. Confirm in ``models``.
        models.BatchNorm = BatchNorm

        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained)
        elif layers == 101:
            resnet = models.resnet101(pretrained=pretrained)
        elif layers == 18:
            resnet = models.resnet18(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)

        if layers == 18:
            # Single-conv stem (conv1/bn1/relu/maxpool).
            self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                        resnet.maxpool)
            self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            # Make layer3/layer4 dilated instead of strided. Every 'conv1'
            # gets stride 1 with dilation/padding 2 (layer3) or 4 (layer4),
            # and the downsample convs get stride 1, so these stages no
            # longer reduce spatial resolution.
            for n, m in self.layer3.named_modules():
                if 'conv1' in n:
                    # print('find conv1', m)
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    # print('find downsample.0', m)
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv1' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            fea_dim = 512

        else:
            # Three-conv stem exposed by the deeper ``models`` ResNets.
            self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                        resnet.conv2, resnet.bn2, resnet.relu,
                                        resnet.conv3, resnet.bn3, resnet.relu,
                                        resnet.maxpool)
            self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            # Same dilation surgery as above, but applied to the 'conv2'
            # modules of these blocks.
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    # print('find conv2',m)
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    # print('find downsample.0',m)
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            fea_dim = 2048

        # print('======*********=============')

        if use_ppm:
            # The PPM output is concatenated with its input (presumably),
            # hence the doubled width fed to the classifier.
            self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins, BatchNorm)
            fea_dim *= 2
        # Main classifier head. For classes == 2 it outputs one channel.
        # NOTE(review): with the defaults (classes=2 and a CrossEntropyLoss
        # criterion), cross-entropy over a single channel is degenerate
        # unless forward() treats the binary case specially. Confirm.
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            # nn.Conv2d(512, classes, kernel_size=1)
            nn.Conv2d(512, classes, kernel_size=1)
            if classes > 2 else nn.Conv2d(512, 1, kernel_size=1))
        # Auxiliary head, built only if the module is in training mode at
        # construction time (nn.Module defaults to training mode). Its input
        # width (1024, or 256 for ResNet-18) is presumably layer3's output.
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False)
                if layers != 18 else nn.Conv2d(
                    256, 256, kernel_size=3, padding=1, bias=False),
                BatchNorm(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                # nn.Conv2d(256, classes, kernel_size=1)
                nn.Conv2d(256, classes, kernel_size=1)
                if classes > 2 else nn.Conv2d(256, 1, kernel_size=1))
Esempio n. 17
0
def main():
    """Train a NeuralFingerprint (NF) detector, or evaluate saved NF checkpoints.

    Training mode (``args.evaluate`` is false):
        Builds the train/validation split of ``args.ds_name`` and the
        ``args.arch`` classifier. It wraps the classifier in
        ``NeuralFingerprintDetector`` and trains it. If a checkpoint already
        exists at the derived ``model_path``, training resumes from it, and the
        checkpoint is re-saved after every epoch.

    Evaluation mode:
        ``speed_test`` is delegated to ``evaluate_speed``. Otherwise every
        ``NF_Det@*`` checkpoint under ``PY_ROOT/train_pytorch_model/NF_Det`` is
        loaded and evaluated according to ``args.study_subject`` (``shots``,
        ``cross_domain``, ``cross_arch`` or ``finetune_eval``). Results are
        dumped to a JSON file unless ``args.profile`` is set.

    Raises:
        ValueError: if a dataset or architecture name is not supported.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # The NeuralFP MNIST and F-MNIST experiments need to be redone because a
    # single-channel bug was found.
    transform = get_preprocessor(IN_CHANNELS[args.ds_name],
                                 IMAGE_SIZE[args.ds_name])
    kwargs = {'pin_memory': True}

    def make_network(arch, ds_name):
        # Build the (untrained) classifier that the NF detector wraps.
        if arch == "conv3":
            return Conv3(IN_CHANNELS[ds_name], IMAGE_SIZE[ds_name],
                         CLASS_NUM[ds_name])
        if arch == "resnet10":
            return resnet10(in_channels=IN_CHANNELS[ds_name],
                            num_classes=CLASS_NUM[ds_name])
        if arch == "resnet18":
            return resnet18(in_channels=IN_CHANNELS[ds_name],
                            num_classes=CLASS_NUM[ds_name])
        # Fail loudly. Otherwise `network` would be unbound, or the previous
        # loop iteration's network would be silently reused.
        raise ValueError("Unsupported arch: {}".format(arch))

    def make_train_val_datasets(ds_name):
        # Return (train_dataset, val_dataset) for a supported dataset name.
        root = IMAGE_DATA_ROOT[ds_name]
        torchvision_sets = {
            "MNIST": datasets.MNIST,
            "F-MNIST": datasets.FashionMNIST,
            "CIFAR-10": datasets.CIFAR10,
        }
        if ds_name in torchvision_sets:
            dataset_cls = torchvision_sets[ds_name]
            return (dataset_cls(root, train=True, download=False,
                                transform=transform),
                    dataset_cls(root, train=False, download=False,
                                transform=transform))
        if ds_name == "SVHN":
            return (SVHN(root, train=True, transform=transform),
                    SVHN(root, train=False, transform=transform))
        if ds_name == "ImageNet":
            return (ImageNetRealDataset(root + "/new2", train=True,
                                        transform=transform),
                    ImageNetRealDataset(root + "/new2", train=False,
                                        transform=transform))
        raise ValueError("Unsupported dataset: {}".format(ds_name))

    def run_shot_study(detector, task_ds_name, adv_arch, fetch_attack_name,
                       result_key, eps, num_dx, reject_thresholds, results,
                       log_progress):
        # Evaluate the 0-, 1- and 5-shot settings on 2-way meta tasks.
        # "0-shot" uses 1-shot tasks with zero fine-tuning updates; the other
        # settings use args.num_updates updates. args itself is not mutated.
        num_way, num_query = 2, 15
        for report_shot in (0, 1, 5):
            if report_shot == 0:
                shot, num_updates = 1, 0
            else:
                shot, num_updates = report_shot, args.num_updates
            val_dataset = MetaTaskDataset(
                20000,
                num_way,
                shot,
                num_query,
                task_ds_name,
                is_train=False,
                load_mode=args.load_task_mode,
                protocol=args.protocol,
                no_random_way=True,
                adv_arch=adv_arch,
                fetch_attack_name=fetch_attack_name)
            adv_val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=100, shuffle=False, **kwargs)
            F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                adv_val_loader, task_ds_name, reject_thresholds, num_updates,
                args.lr)
            results[result_key][report_shot] = {
                "F1": F1,
                "best_tau": tau,
                "eps": eps,
                "num_dx": num_dx,
                "num_updates": num_updates,
                "attack_stats": attacker_stats
            }
            if log_progress:
                print("shot {} done".format(report_shot))

    if not args.evaluate:  # training mode
        trn_dataset, val_dataset = make_train_val_datasets(args.ds_name)
        train_loader = torch.utils.data.DataLoader(trn_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=0,
            **kwargs)

        network = make_network(args.arch, args.ds_name)
        network.cuda()
        model_path = os.path.join(
            PY_ROOT, "train_pytorch_model/NF_Det",
            "NF_Det@{}@{}@epoch_{}@lr_{}@eps_{}@num_dx_{}@num_class_{}.pth.tar"
            .format(args.ds_name, args.arch, args.epochs, args.lr, args.eps,
                    args.num_dx, CLASS_NUM[args.ds_name]))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        detector = NeuralFingerprintDetector(
            args.ds_name,
            network,
            args.num_dx,
            CLASS_NUM[args.ds_name],
            eps=args.eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)

        optimizer = optim.SGD(network.parameters(),
                              lr=args.lr,
                              weight_decay=1e-6,
                              momentum=0.9)
        resume_epoch = 0
        print("{}".format(model_path))
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path,
                                    lambda storage, location: storage)
            optimizer.load_state_dict(checkpoint["optimizer"])
            # The saved 'epoch' is already the next epoch to run.
            resume_epoch = checkpoint["epoch"]
            network.load_state_dict(checkpoint["state_dict"])

        for epoch in range(resume_epoch, args.epochs + 1):
            if epoch == 1:
                detector.test(epoch,
                              test_loader,
                              test_length=0.1 * len(val_dataset))
            detector.train(epoch, optimizer, train_loader)

            print("Epoch{}, Saving model in {}".format(epoch, model_path))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': network.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path)
        return

    # ---- evaluation (test) mode ----
    if args.study_subject == "speed_test":
        evaluate_speed(args)
        return

    # Use a raw string so that "\d" is a regex class and not an invalid string
    # escape. The dots in ".pth.tar" are escaped so they match literally.
    extract_pattern = re.compile(
        r".*NF_Det@(.*?)@(.*?)@epoch_(\d+)@lr_(.*?)@eps_(.*?)"
        r"@num_dx_(\d+)@num_class_(\d+)\.pth\.tar")
    results = defaultdict(dict)
    for model_path in glob.glob(
            "{}/train_pytorch_model/NF_Det/NF_Det@*".format(PY_ROOT)):
        ma = extract_pattern.match(model_path)
        if ma is None:
            # Skip non-matching files instead of crashing on ma.group(...).
            print("skip {}: file name does not match the NF_Det pattern"
                  .format(model_path))
            continue
        ds_name = ma.group(1)
        arch = ma.group(2)
        eps = float(ma.group(5))
        num_dx = int(ma.group(6))

        # These subjects only concern one source dataset. Filter before
        # building and loading the network.
        if (args.study_subject == "cross_domain"
                and ds_name != args.cross_domain_source):
            continue
        if (args.study_subject == "finetune_eval"
                and ds_name != args.ds_name):
            continue

        network = make_network(arch, ds_name)
        reject_thresholds = [0. + 0.001 * i for i in range(2050)]
        network.load_state_dict(
            torch.load(model_path,
                       lambda storage, location: storage)["state_dict"])
        network.cuda()
        print("load {} over".format(model_path))
        detector = NeuralFingerprintDetector(
            ds_name,
            network,
            num_dx,
            CLASS_NUM[ds_name],
            eps=eps,
            out_fp_dxdy_dir=args.output_dx_dy_dir)

        if args.study_subject == "shots":
            run_shot_study(detector, ds_name, args.adv_arch, True, ds_name,
                           eps, num_dx, reject_thresholds, results,
                           log_progress=True)
        elif args.study_subject == "cross_domain":
            source_dataset = args.cross_domain_source
            target_dataset = args.cross_domain_target
            result_key = "{}--{}@data_adv_arch_{}".format(
                source_dataset, target_dataset, args.adv_arch)
            run_shot_study(detector, target_dataset, args.adv_arch, False,
                           result_key, eps, num_dx, reject_thresholds,
                           results, log_progress=False)
        elif args.study_subject == "cross_arch":
            target_arch = args.cross_arch_target
            result_key = "{}_target_arch_{}".format(ds_name, target_arch)
            run_shot_study(detector, ds_name, target_arch, False, result_key,
                           eps, num_dx, reject_thresholds, results,
                           log_progress=False)
        elif args.study_subject == "finetune_eval":
            shot = 1
            num_way = 2
            num_query = 15
            val_dataset = MetaTaskDataset(20000,
                                          num_way,
                                          shot,
                                          num_query,
                                          ds_name,
                                          is_train=False,
                                          load_mode=args.load_task_mode,
                                          protocol=args.protocol,
                                          no_random_way=True,
                                          adv_arch=args.adv_arch)
            adv_val_loader = torch.utils.data.DataLoader(val_dataset,
                                                         batch_size=100,
                                                         shuffle=False,
                                                         **kwargs)
            # The value 50 is recorded in the output file name below.
            args.num_updates = 50
            for num_update in range(0, 51):
                F1, tau, attacker_stats = detector.eval_with_fingerprints_finetune(
                    adv_val_loader, ds_name, reject_thresholds, num_update,
                    args.lr)
                results[ds_name][num_update] = {
                    "F1": F1,
                    "best_tau": tau,
                    "eps": eps,
                    "num_dx": num_dx,
                    "num_updates": num_update,
                    "attack_stats": attacker_stats
                }
                print("finetune {} done".format(num_update))

    if args.profile:
        return
    if args.study_subject == "cross_domain":
        filename = "{}/train_pytorch_model/NF_Det/cross_domain_{}--{}@adv_arch_{}.json".format(
            PY_ROOT, args.cross_domain_source,
            args.cross_domain_target, args.adv_arch)
    elif args.study_subject == "cross_arch":
        filename = "{}/train_pytorch_model/NF_Det/cross_arch_target_{}.json".format(
            PY_ROOT, args.cross_arch_target)
    else:
        filename = "{}/train_pytorch_model/NF_Det/{}@data_{}@protocol_{}@lr_{}@finetune_{}.json".format(
            PY_ROOT, args.study_subject, args.adv_arch, args.protocol,
            args.lr, args.num_updates)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w", encoding="utf-8") as file_obj:
        file_obj.write(json.dumps(results))
Esempio n. 18
0
def evaluate_shots(args):
    """Evaluate saved IMG_ROTATE_DET detectors in the 0-, 1- and 5-shot settings.

    Every checkpoint under
    ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` whose split
    protocol equals ``args.protocol`` is evaluated; ImageNet checkpoints are
    skipped. For each shot count the detector is rebuilt and reloaded from
    disk, then scored with ``finetune_eval_task_accuracy`` on 2-way meta tasks.
    "0-shot" means 1-shot tasks with zero fine-tuning updates. The 1- and
    5-shot settings use ``args.num_updates`` updates.

    Results are written to ``shots_result.json`` in the same directory. They
    are keyed by ``"<dataset>@<data_adv_arch>"`` and then by shot count.

    Raises:
        ValueError: if a checkpoint names an unsupported architecture.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Matches names such as
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_I_TEST_II@model_resnet10@data_resnet10@epoch_10@lr_0.001@batch_10.pth.tar
    # Names with an extra suffix after the batch field (e.g.
    # "@no_fix_cnn_params") do not match and are skipped below.
    # A raw string is used so that "\d" is not an invalid string escape.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)"
        r"@lr_(.*?)@batch_(\d+)\.pth\.tar")
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2  # number of ways per meta task
    num_query = 15
    for model_path in glob.glob(
            "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*"
            .format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Skip unrelated files instead of crashing on ma.group(...).
            print("skip {}: file name does not match the expected pattern"
                  .format(model_path))
            continue
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        data_adv_arch = ma.group(4)
        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))

        for report_shot in (0, 1, 5):
            # The result is keyed by report_shot. The old code derived the key
            # from num_update, so a 1-shot run with args.num_updates == 0
            # overwrote the 0-shot result.
            if report_shot == 0:
                shot, num_update = 1, 0
            else:
                shot, num_update = report_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                num_classes,
                                                shot,
                                                num_query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol,
                                                no_random_way=True,
                                                adv_arch=data_adv_arch)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)

            if arch == "resnet10":
                img_classifier_network = resnet10(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            elif arch == "resnet18":
                img_classifier_network = resnet18(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            elif arch == "conv3":
                img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                               IMAGE_SIZE[dataset],
                                               CLASS_NUM[dataset])
            else:
                raise ValueError("Unsupported arch: {}".format(arch))

            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in [
                "ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"
            ] else 2
            model = Detector(dataset,
                             img_classifier_network,
                             CLASS_NUM[dataset],
                             image_transform,
                             layer_number,
                             num_classes=2)
            # Reload for every shot count so each evaluation starts from the
            # saved weights.
            checkpoint = torch.load(
                model_path, map_location=lambda storage, location: storage)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          num_update,
                                                          update_BN=False)
            result[dataset + "@" + data_adv_arch][report_shot] = evaluate_result

    out_path = "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/shots_result.json".format(
        PY_ROOT)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as file_obj:
        file_obj.write(json.dumps(result))
Esempio n. 19
0
def define_model(model_type,
                 pretrained_path='',
                 neighbour_slice=args.neighbour_slice,
                 input_type=args.input_type,
                 output_type=args.output_type):
    """Build the network named by *model_type* and adapt its stem and ``fc`` head.

    Args:
        model_type: one of 'prevost', 'resnext50', 'resnext101', 'resnet152',
            'resnet101', 'resnet50', 'resnet34', 'resnet18', 'mynet',
            'mynet2', 'p3d', 'densenet121'. Any other value prints a warning
            and falls back to ``generators.PrevostNet``.
        pretrained_path: optional state-dict file loaded (with
            ``map_location='cuda:0'``) after the head is replaced. A missing
            file only prints a message.
        neighbour_slice: number of neighbouring slices. The 2D ResNet stems
            take ``neighbour_slice - 1`` input channels when
            ``input_type == 'diff_img'`` and ``neighbour_slice`` otherwise.
        input_type: see ``neighbour_slice``.
        output_type: 'average_dof' and 'sum_dof' give a 6-unit ``fc``. Any
            other value gives ``(neighbour_slice - 1) * 6`` units.

    Note that the defaults are read from the global ``args`` once, when this
    function is defined.

    Returns:
        The model, moved with ``.cuda()`` and then ``.to(device)``.
    """
    if input_type == 'diff_img':
        input_channel = neighbour_slice - 1
    else:
        input_channel = neighbour_slice

    def conv2d_stem():
        # Re-created 2D ResNet stem that accepts `input_channel` channels.
        return nn.Conv2d(input_channel,
                         64,
                         kernel_size=7,
                         stride=2,
                         padding=3,
                         bias=False)

    def conv3d_stem(depth_padding):
        # Single-channel 3D stem; only the depth padding differs per model.
        return nn.Conv3d(in_channels=1,
                         out_channels=64,
                         kernel_size=(3, 7, 7),
                         stride=(1, 2, 2),
                         padding=(depth_padding, 3, 3),
                         bias=False)

    # 2D ResNets: model_type -> `pretrained` flag (resnet34 is intentionally
    # built without pretrained weights, as in the original code).
    resnet_2d_pretrained = {
        'resnet152': True,
        'resnet101': True,
        'resnet50': True,
        'resnet34': False,
        'resnet18': True,
    }

    if model_type == 'prevost':
        model_ft = generators.PrevostNet()
    elif model_type in ('resnext50', 'resnext101'):
        builder = (resnext.resnet50 if model_type == 'resnext50'
                   else resnext.resnet101)
        model_ft = builder(sample_size=2, sample_duration=16, cardinality=32)
        model_ft.conv1 = conv3d_stem(1)
    elif model_type in resnet_2d_pretrained:
        model_ft = getattr(resnet, model_type)(
            pretrained=resnet_2d_pretrained[model_type])
        model_ft.conv1 = conv2d_stem()
    elif model_type == 'mynet':
        model_ft = mynet.resnet50(sample_size=2,
                                  sample_duration=16,
                                  cardinality=32)
        model_ft.conv1 = conv3d_stem(0)
    elif model_type == 'mynet2':
        model_ft = generators.My3DNet()
    elif model_type == 'p3d':
        model_ft = p3d.P3D63()
        model_ft.conv1_custom = nn.Conv3d(1,
                                          64,
                                          kernel_size=(1, 7, 7),
                                          stride=(1, 2, 2),
                                          padding=(0, 3, 3),
                                          bias=False)
    elif model_type == 'densenet121':
        model_ft = densenet.densenet121()
    else:
        # This message previously formatted the undefined name
        # `network_type`, which raised NameError instead of falling back.
        print('network type of <{}> is not supported, use original instead'.
              format(model_type))
        model_ft = generators.PrevostNet()

    # Some models have hard-coded head widths. Only query fc otherwise.
    fc_in_features_override = {'mynet': 384, 'prevost': 576}
    if model_type in fc_in_features_override:
        num_ftrs = fc_in_features_override[model_type]
    else:
        num_ftrs = model_ft.fc.in_features

    if output_type in ('average_dof', 'sum_dof'):
        out_features = 6
    else:
        out_features = (neighbour_slice - 1) * 6
    model_ft.fc = nn.Linear(num_ftrs, out_features)

    if pretrained_path:
        if path.isfile(pretrained_path):
            print('Loading model from <{}>...'.format(pretrained_path))
            model_ft.load_state_dict(
                torch.load(pretrained_path, map_location='cuda:0'))
            print('Done')
        else:
            print('<{}> not exists! Training from scratch...'.format(
                pretrained_path))
    else:
        print('Train this model from scratch!')

    model_ft.cuda()
    model_ft = model_ft.to(device)
    print('define model device {}'.format(device))
    return model_ft
Esempio n. 20
0
def main_train_worker(args, model_path, META_ATTACKER_PART_I=None, META_ATTACKER_PART_II=None, gpu="0"):
    """Train a 2-class classifier on adversarial-image data and checkpoint it every epoch.

    Args:
        args: parsed options. The attributes read here are arch, dataset,
            adv_arch, protocol, balance, lr, momentum, weight_decay,
            batch_size, workers, start_epoch and epochs. ``args.start_epoch``
            is overwritten when a checkpoint is resumed.
        model_path: checkpoint file. It is resumed from if it exists and is
            overwritten after every epoch.
        META_ATTACKER_PART_I, META_ATTACKER_PART_II: attacker groups forwarded
            to the dataset. They default to ``config.META_ATTACKER_PART_I`` and
            ``config.META_ATTACKER_PART_II``.
        gpu: string assigned to ``CUDA_VISIBLE_DEVICES``. Presumably this only
            takes effect if CUDA has not been initialised yet in this process;
            TODO confirm.

    Raises:
        ValueError: if ``args.arch`` is not supported.
    """
    if META_ATTACKER_PART_I is None:
        META_ATTACKER_PART_I = config.META_ATTACKER_PART_I
    if META_ATTACKER_PART_II is None:
        META_ATTACKER_PART_II = config.META_ATTACKER_PART_II
    print("Use GPU: {} for training".format(gpu))
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    print("will save to {}".format(model_path))
    if args.arch == "conv3":
        model = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], 2)
    elif args.arch == "resnet10":
        model = resnet10(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    elif args.arch == "resnet18":
        model = resnet18(2, in_channels=IN_CHANNELS[args.dataset], pretrained=False)
    else:
        # Previously `model` was left unbound, causing a NameError below.
        raise ValueError("Unsupported arch: {}".format(args.arch))
    model = model.cuda()
    adv_root = IMAGE_DATA_ROOT[args.dataset] + "/adversarial_images/{}".format(args.adv_arch)
    if args.dataset == "ImageNet":
        train_dataset = AdversaryRandomAccessNpyDataset(adv_root,
                                                        True, args.protocol, META_ATTACKER_PART_I, META_ATTACKER_PART_II,
                                                        args.balance, args.dataset)
    else:
        train_dataset = AdversaryDataset(adv_root,
                                         True, args.protocol, META_ATTACKER_PART_I, META_ATTACKER_PART_II, args.balance)

    # Define the loss function (criterion) and the optimizer. Use the same
    # device as the model: the old `.cuda(args.gpu)` ignored the `gpu`
    # argument and required `args.gpu` to exist.
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint.
    if os.path.exists(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # Load onto the CPU, as the other loaders in this file do.
        # load_state_dict copies the tensors to the model's / parameters'
        # device.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    cudnn.benchmark = True

    # Data loading
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    tensorboard = TensorBoardWriter("{0}/pytorch_DeepLearning_tensorboard".format(PY_ROOT), "DeepLearning")
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # Train for one epoch (no validation loader is used).
        train(train_loader, None, model, criterion, optimizer, epoch, tensorboard, args)
        if args.balance:
            # Re-balance the sample list for the next epoch: keep every
            # label-1 sample and draw an equal-sized random subset of label-0
            # samples.
            train_dataset.img_label_list.clear()
            train_dataset.img_label_list.extend(train_dataset.img_label_dict[1])
            train_dataset.img_label_list.extend(
                random.sample(train_dataset.img_label_dict[0], len(train_dataset.img_label_dict[1])))
        # Save 'epoch' as the next epoch to run, so that resuming continues
        # from there.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)
Esempio n. 21
0
def _create_network(args):
    """Build the classifier selected by ``args.arch`` for ``args.dataset``.

    When ``args.pretrained`` is set and ``args.arch`` names an entry of
    ``models``, that model is built with ``pretrained=True``. Otherwise the
    project builders ``resnet10`` / ``resnet18`` / ``Conv3`` are used, falling
    back to ``models.<arch>(pretrained=False)`` for any other name found in
    ``models``. For resnet/vgg architectures the final linear layer is replaced
    so the network emits ``CLASS_NUM[args.dataset]`` logits.

    Raises:
        ValueError: if ``args.arch`` is not a known architecture.
    """
    num_classes = CLASS_NUM[args.dataset]
    if args.pretrained and args.arch in models.__dict__:
        print("=> using pre-trained model '{}'".format(args.arch))
        network = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == "resnet10":
            network = resnet10(num_classes=num_classes, in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "resnet18":
            network = resnet18(num_classes=num_classes, in_channels=IN_CHANNELS[args.dataset])
        elif args.arch == "conv3":
            network = Conv3(IN_CHANNELS[args.dataset], IMAGE_SIZE[args.dataset], num_classes)
        elif args.arch in models.__dict__:
            # Previously this case left `network` unbound; mirror build_network().
            network = models.__dict__[args.arch](pretrained=False)
        else:
            raise ValueError("Unknown architecture '{}'".format(args.arch))

    if args.arch.startswith("resnet"):
        # Global average pooling is removed and the new head expects 512 input
        # features. NOTE(review): this presumably assumes a 512-channel, 1x1
        # final feature map (resnet18-style); deeper resnets (e.g. 2048-channel
        # bottleneck variants) would not match — confirm before using them.
        network.avgpool = Identity()
        network.fc = nn.Linear(512, num_classes)
    elif args.arch.startswith("vgg"):
        # Replace the last layer of the vgg classifier head.
        network.classifier[6] = nn.Linear(4096, num_classes)
    return network


def _build_datasets(args, preprocessor):
    """Return ``(train_dataset, val_dataset)`` for ``args.dataset``.

    Both splits are transformed with ``preprocessor``.

    Raises:
        ValueError: if ``args.dataset`` is not one of CIFAR-10, MNIST,
            F-MNIST or SVHN.
    """
    dataset = args.dataset
    if dataset == "CIFAR-10":
        train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor)
        val_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset], train=False, transform=preprocessor)
    elif dataset == "MNIST":
        train_dataset = MNIST(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor, download=True)
        val_dataset = MNIST(IMAGE_DATA_ROOT[dataset], train=False, transform=preprocessor, download=True)
    elif dataset == "F-MNIST":
        train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor, download=True)
        val_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset], train=False, transform=preprocessor, download=True)
    elif dataset == "SVHN":
        # NOTE(review): torchvision's SVHN takes split='train'/'test', not
        # train=; if SVHN here is torchvision's class these calls raise
        # TypeError — confirm which SVHN is imported.
        train_dataset = SVHN(IMAGE_DATA_ROOT[dataset], train=True, transform=preprocessor)
        val_dataset = SVHN(IMAGE_DATA_ROOT[dataset], train=False, transform=preprocessor)
    else:
        raise ValueError("Unsupported dataset '{}'".format(dataset))
    return train_dataset, val_dataset


def main_train_worker(gpu, args):
    """Train an image classifier on ``args.dataset`` and checkpoint every epoch.

    Args:
        gpu: GPU index stored on ``args.gpu`` (only used for logging here);
            may be ``None``.
        args: parsed command-line namespace. Read here: ``arch``, ``dataset``,
            ``pretrained``, ``epochs``, ``lr``, ``batch_size``,
            ``weight_decay``, ``resume``, ``workers`` and, optionally,
            ``start_epoch``. ``args`` is also passed to ``train``, ``validate``
            and ``adjust_learning_rate``.

    If ``args.resume`` points to an existing checkpoint, its network weights
    are loaded and training continues from the stored epoch. Validation
    accuracy is printed after each epoch. The checkpoint (epoch, arch, network
    and optimizer state) is written to ``model_path``, overwriting it each time.

    Raises:
        ValueError: for an unknown ``args.arch`` or unsupported ``args.dataset``.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    # create model
    network = _create_network(args)

    model_path = './train_pytorch_model/DL_IMAGE_CLASSIFIER_{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        args.dataset, args.arch, args.epochs, args.lr, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    preprocessor = get_preprocessor(IN_CHANNELS[args.dataset])
    network.cuda()

    # define loss function (criterion) and optimizer
    image_classifier_loss = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.Adam(network.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            network.load_state_dict(checkpoint['state_dict'])
            # NOTE(review): optimizer state is not restored (disabled below),
            # so Adam's moment estimates restart from scratch on resume.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # Honor the resumed epoch; previously the loop always restarted at 0.
    start_epoch = getattr(args, "start_epoch", 0)

    cudnn.benchmark = True

    train_dataset, val_dataset = _build_datasets(args, preprocessor)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch, args)

        # evaluate_accuracy on validation set
        acc1 = validate(val_loader, network, image_classifier_loss, args)
        print(acc1)

        # Save after every epoch (overwrites the same file; no best-model tracking).
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=model_path)