Code Example #1
def resume_finetuning_from_checkpoint(args, ds, finetuned_model_path):
    """Given arguments, dataset object and a finetuned model_path, returns a model
    with loaded weights and returns the checkpoint necessary for resuming training.
    """
    print("[Resuming finetuning from a checkpoint...]")
    if (
        args.dataset in list(transfer_datasets.DS_TO_FUNC.keys())
        and not args.cifar10_cifar10
    ):
        model, _ = model_utils.make_and_restore_model(
            arch=pytorch_models[args.arch](args.pytorch_pretrained)
            if args.arch in pytorch_models.keys()
            else args.arch,
            dataset=datasets.ImageNet(""),
            add_custom_forward=args.arch in pytorch_models.keys(),
        )
        while hasattr(model, "model"):
            model = model.model
        model = fine_tunify.ft(args.arch, model, ds.num_classes, args.additional_hidden)
        model, checkpoint = model_utils.make_and_restore_model(
            arch=model,
            dataset=ds,
            resume_path=finetuned_model_path,
            add_custom_forward=args.additional_hidden > 0
            or args.arch in pytorch_models.keys(),
        )
    else:
        model, checkpoint = model_utils.make_and_restore_model(
            arch=args.arch, dataset=ds, resume_path=finetuned_model_path
        )
    return model, checkpoint
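
A rough, hedged illustration of how the function above might be driven. It assumes the module-level names from the original transfer-learning repository (transfer_datasets, pytorch_models, model_utils, fine_tunify) are already imported; the argument values and paths below are placeholders.

import argparse
from robustness import datasets

# Illustrative arguments only; the attribute names mirror what the function reads.
args = argparse.Namespace(
    dataset='cifar10',
    cifar10_cifar10=False,
    arch='resnet50',
    pytorch_pretrained=True,
    additional_hidden=0,
)
ds = datasets.CIFAR('/path/to/cifar10')  # target-task dataset object
model, checkpoint = resume_finetuning_from_checkpoint(
    args, ds, '/path/to/out_dir/checkpoint.pt.latest')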
Code Example #2
def get_model(args, ds):
    """Given arguments and a dataset object, returns an ImageNet model (with appropriate last layer changes to 
    fit the target dataset) and a checkpoint.The checkpoint is set to None if noe resuming training.
    """
    finetuned_model_path = os.path.join(
        args.out_dir, "checkpoint.pt.latest"
    )
    if args.resume and os.path.isfile(finetuned_model_path):
        model, checkpoint = resume_finetuning_from_checkpoint(
            args, ds, finetuned_model_path
        )
    else:
        if (
            args.dataset in list(transfer_datasets.DS_TO_FUNC.keys())
            and not args.cifar10_cifar10
        ):
            model, _ = model_utils.make_and_restore_model(
                arch=pytorch_models[args.arch](args.pytorch_pretrained)
                if args.arch in pytorch_models.keys()
                else args.arch,
                dataset=datasets.ImageNet(""),
                resume_path=args.model_path,
                pytorch_pretrained=args.pytorch_pretrained,
                add_custom_forward=args.arch in pytorch_models.keys(),
            )
            checkpoint = None
        else:
            model, _ = model_utils.make_and_restore_model(
                arch=args.arch,
                dataset=ds,
                resume_path=args.model_path,
                pytorch_pretrained=args.pytorch_pretrained,
            )
            checkpoint = None

        if not args.no_replace_last_layer and not args.eval_only:
            print(
                f"[Replacing the last layer with {args.additional_hidden} "
                f"hidden layers and 1 classification layer that fits the {args.dataset} dataset.]"
            )
            while hasattr(model, "model"):
                model = model.model
            model = fine_tunify.ft(
                args.arch, model, ds.num_classes, args.additional_hidden
            )
            model, checkpoint = model_utils.make_and_restore_model(
                arch=model,
                dataset=ds,
                add_custom_forward=args.additional_hidden > 0
                or args.arch in pytorch_models.keys(),
            )
        else:
            print("[NOT replacing the last layer]")
    return model, checkpoint
Code Example #3
def load_model(arch, dataset=None):
    '''
    Load pretrained model with specified architecture.
    Args:
        arch (str): name of one of the pytorch pretrained models or 
                    "robust" for robust model
        dataset (dataset object): not None only for robust model
    Returns:
        model: loaded model
    '''

    if arch != 'robust':
        model = eval(arch)(pretrained=True).cuda()
        model.eval()
    else:
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'resume_path': f'./models/RestrictedImageNet.pt'
        }

        model, _ = model_utils.make_and_restore_model(**model_kwargs)
        model.eval()
        model = model.module.model
    return model
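
A minimal usage sketch for the loader above. The non-robust branch resolves the architecture name with eval(), so it assumes the torchvision constructors (e.g. vgg16) are imported into the defining module; the robust branch assumes the checkpoint path hard-coded in the function exists. Paths below are placeholders.

from robustness import datasets

vgg = load_model('vgg16')  # torchvision model, pretrained, moved to GPU and set to eval

ds = datasets.RestrictedImageNet('/path/to/imagenet')  # placeholder data root
robust_resnet = load_model('robust', dataset=ds)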
Code Example #4
def get_model(args, ds):
    # An option to resume finetuning from a checkpoint. Only for ImageNet-to-ImageNet transfer.
    finetuned_model_path = os.path.join(args.out_dir, args.exp_name,
                                        'checkpoint.pt.latest')
    if args.resume and os.path.isfile(finetuned_model_path):
        model, checkpoint = resume_finetuning_from_checkpoint(
            args, ds, finetuned_model_path)
    else:

        if args.dataset in list(transfer_datasets.DS_TO_FUNC.keys()
                                ) and not args.cifar10_cifar10:
            model, _ = model_utils.make_and_restore_model(
                arch=pytorch_models[args.arch](args.pytorch_pretrained)
                if args.arch in pytorch_models.keys() else args.arch,
                dataset=datasets.ImageNet(''),
                resume_path=args.model_path,
                pytorch_pretrained=args.pytorch_pretrained,
                add_custom_forward=args.arch in pytorch_models.keys())
            checkpoint = None
        else:
            model, _ = model_utils.make_and_restore_model(
                arch=args.arch,
                dataset=ds,
                resume_path=args.model_path,
                pytorch_pretrained=args.pytorch_pretrained)
            checkpoint = None

        # For all other datasets, replace the last layer then finetune, unless otherwise specified using
        # the args.no_replace_last_layer flag
        if not args.no_replace_last_layer and not args.eval_only:
            print(
                f'[Replacing the last layer with {args.additional_hidden} '
                f'hidden layers and 1 classification layer that fits the {args.dataset} dataset.]'
            )
            while hasattr(model, 'model'):
                model = model.model
            model = fine_tunify.ft(args.arch, model, ds.num_classes,
                                   args.additional_hidden)
            model, checkpoint = model_utils.make_and_restore_model(
                arch=model,
                dataset=ds,
                add_custom_forward=args.additional_hidden > 0
                or args.arch in pytorch_models.keys())
        else:
            print('[NOT replacing the last layer]')

    return model, checkpoint
Code Example #5
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--orig-data', default='./data/cifar10', help='path to the original dataset')
    parser.add_argument('--enc-data', default='./data', help='path to output encrypted dataset')
    parser.add_argument('--resume-path', default='./logs/checkpoints/dir/resnet50/checkpoint.pt.best', help='path to checkpoint to resume from')
    parser.add_argument('--workers', type=int, help='data loading workers', default=8)
    parser.add_argument('--batch-size', type=int, default=128, help='batch size for data loading')
    parser.add_argument('--arch', default='resnet50', help='architecture')
    parser.add_argument('--eps', type=float, default=0.5, help='adversarial perturbation budget')
    parser.add_argument('--attack-lr', type=float, default=0.1, help='step size for PGD')
    parser.add_argument('--attack-steps', type=int, default=100, help='number of steps for adversarial attack')
    parser.add_argument('--enc-method', default='basic', choices=['basic', 'mixup', 'horiz', 'mixandcat'], help='encryption method')
    parser.add_argument('--alpha', type=float, default=0.5, help='hyperparameter in horizontal concat')
    parser.add_argument('--lambd', type=float, default=0.5, help='hyperparameter in mixup')
    parser.add_argument('--manual-seed', type=int, default=23, help='manual seed')

    opt = parser.parse_args()

    if opt.manual_seed is None:
        opt.manual_seed = random.randint(1, 10000)
    print("Random Seed: ", opt.manual_seed)
    ch.manual_seed(opt.manual_seed)

    kwargs = {
        'constraint': '2',
        'eps': opt.eps,
        'step_size': opt.attack_lr,
        'iterations': opt.attack_steps,
        'targeted': True,
        'do_tqdm': False
    }

    ds = CIFAR(opt.orig_data)
    train_loader, test_loader = ds.make_loaders(workers=opt.workers,
                                                batch_size=opt.batch_size,
                                                data_aug=False,
                                                shuffle_train=False,
                                                )

    model, _ = make_and_restore_model(arch=opt.arch, dataset=ds, resume_path=opt.resume_path)
    model.eval()


    if opt.enc_method == 'basic':
        generate_train_data(train_loader, model, kwargs, opt)
        generate_test_data(test_loader, model, kwargs, opt)

    elif opt.enc_method == 'mixup':
        generate_train_data_mixup(train_loader, model, kwargs, opt)
        generate_test_data_mixup(test_loader, model, kwargs, opt)

    elif opt.enc_method == 'horiz':
        generate_train_data_horiz(train_loader, model, kwargs, opt)
        generate_test_data_horiz(test_loader, model, kwargs, opt)

    elif opt.enc_method == 'mixandcat':
        generate_train_data_mixandcat(train_loader, model, kwargs, opt)
        generate_test_data_mixandcat(test_loader, model, kwargs, opt)
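
The generate_* helpers are defined elsewhere in that project; the kwargs dict built above matches the attack arguments a robustness AttackerModel forwards to its internal PGD attacker. A hedged sketch of the core step such a helper would perform (the helper name and target labels are illustrative):

def encrypt_batch(model, im, targ, attack_kwargs):
    # Targeted PGD through the robustness AttackerModel: the second return
    # value is the adversarial (here: "encrypted") image batch.
    _, im_adv = model(im.cuda(), targ.cuda(), make_adv=True, **attack_kwargs)
    return im_adv.detach().cpu()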
Code Example #6
def main():
    path = sys.argv[-1]
    metrics = sys.argv[-2]

    name = 'pred'
    if sys.argv[-3].startswith('advs'):
        name += '_advs'

    if not os.path.exists('eval/{}_{}.npy'.format(name,
                                                  path.split('/')[-1][:-3])):
        if path.endswith('_best.pt'):
            ds = CustomCIFAR(int(path.split('/')[-1].split('_')[0]),
                             '/home/zhuzby/data')
        else:
            ds = CIFAR('/home/zhuzby/data')
        model, _ = make_and_restore_model(arch='resnet50',
                                          dataset=ds,
                                          resume_path=path)
        model = model.eval()
        if sys.argv[-3].startswith('advs'):
            im_adv = np.load(sys.argv[-3])
            labs = np.load(
                os.path.join(sys.argv[-3].split('/')[0],
                             'labels_' + sys.argv[-3].split('/')[1]))
            data = TensorDataset(torch.tensor(im_adv), torch.tensor(labs))
            test_loader = DataLoader(data,
                                     batch_size=128,
                                     num_workers=8,
                                     shuffle=False)
        else:
            _, test_loader = ds.make_loaders(workers=8, batch_size=128)
        preds, labels = [], []
        for i, (im, label) in enumerate(test_loader):
            output, _ = model(im)
            label = label.cpu().numpy()
            preds = output.detach().cpu().numpy() if len(
                preds) == 0 else np.vstack(
                    (preds, output.detach().cpu().numpy()))
            labels = label if len(labels) == 0 else np.hstack((labels, label))
        np.save('eval/{}_{}.npy'.format(name, path.split('/')[-1][:-3]), preds)
        np.save('eval/label_{}.npy'.format(path.split('/')[-1][:-3]), labels)
    else:
        preds = np.load('eval/{}_{}.npy'.format(name,
                                                path.split('/')[-1][:-3]))
        labels = np.load('eval/label_{}.npy'.format(path.split('/')[-1][:-3]))

    if metrics == 'origin':
        check_normal(preds, labels)
    elif metrics == 'hamming':
        check_hamming(int(path.split('/')[-1].split('_')[0]), preds, labels,
                      int(sys.argv[-4]),
                      path.split('/')[-1][:-3])
Code Example #7
 def __init__(self):
     super(resnet, self).__init__()
     dataset = datasets.RestrictedImageNet('')
     model_kwargs = {
         "arch": 'resnet50',
         'dataset': dataset,
         'resume_path': './RestrictedImageNet.pt',
         'parallel': False
     }
     model, ckpt = model_utils.make_and_restore_model(**model_kwargs)
     self.model = model.model
     for param in self.parameters():
         param.requires_grad = False
Code Example #8
def load_robust_model():
    dataset_function = getattr(dataset_utils, 'ImageNet')
    dataset = dataset_function('')

    model_kwargs = {
        'arch': 'resnet50',
        'dataset': dataset,
        'resume_path': f'./models/robust_resnet50.pth',
        'parallel': False
    }
    model, _ = model_utils.make_and_restore_model(**model_kwargs)
    model.eval()
    return model
Code Example #9
def load_madrylab_imagenet(arch):
    data = "ImageNet"
    dataset_function = getattr(datasets, data)
    dataset = dataset_function(DATA_PATH_DICT[data])
    model_kwargs = {
        "arch": arch,
        "dataset": dataset,
        "resume_path": f"madrylab_models/{data}.pt",
        "state_dict_path": "model",
    }
    (model, _) = model_utils.make_and_restore_model(**model_kwargs)

    return model
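
DATA_PATH_DICT and the state_dict_path entry come from that repository's own constants. A minimal stand-in for the dictionary, with placeholder paths, might look like this:

# Hypothetical stand-in; the real mapping lives in the repository's config.
DATA_PATH_DICT = {
    'ImageNet': '/path/to/imagenet',
    'RestrictedImageNet': '/path/to/imagenet',
    'CIFAR': '/path/to/cifar10',
}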
Code Example #10
File: utils.py Project: iamgroot42/auto-attack
 def get_model(self, m_type, arch='resnet50'):
     model_path = self.models.get(m_type, None)
     if not model_path:
         model_path = m_type
     else:
         model_path = self.model_prefix[arch] + self.models[m_type]
     model_kwargs = {
         'arch': arch,
         'dataset': self.dataset,
         'resume_path': model_path
     }
     model, _ = make_and_restore_model(**model_kwargs)
     model.eval()
     return model
Code Example #11
def resume_finetuning_from_checkpoint(args, ds, finetuned_model_path):
    print('[Resuming finetuning from a checkpoint...]')
    if args.dataset in list(
            transfer_datasets.DS_TO_FUNC.keys()) and not args.cifar10_cifar10:
        model, _ = model_utils.make_and_restore_model(
            arch=pytorch_models[args.arch](args.pytorch_pretrained)
            if args.arch in pytorch_models.keys() else args.arch,
            dataset=datasets.ImageNet(''),
            add_custom_forward=args.arch in pytorch_models.keys())
        while hasattr(model, 'model'):
            model = model.model
        model = fine_tunify.ft(args.arch, model, ds.num_classes,
                               args.additional_hidden)
        model, checkpoint = model_utils.make_and_restore_model(
            arch=model,
            dataset=ds,
            resume_path=finetuned_model_path,
            add_custom_forward=args.additional_hidden > 0
            or args.arch in pytorch_models.keys())
    else:
        model, checkpoint = model_utils.make_and_restore_model(
            arch=args.arch, dataset=ds, resume_path=finetuned_model_path)
    return model, checkpoint
Code Example #12
def load_restricted_imagenet_model():
    model_kwargs = {
        'arch': 'resnet50',
        'dataset': RestrictedImageNet('./data'),
        'resume_path': f'./models/RestrictedImageNet.pt'
    }

    model, _ = model_utils.make_and_restore_model(**model_kwargs)

    try:
        model = model.module.model
    except AttributeError:
        model = model.model

    return model
Code Example #13
def get_deep_features(dataset, loader, model_root, arch, device='cuda'): 
    rng = np.random.RandomState(random_seed)

    model, _ = make_and_restore_model(arch=arch, 
             dataset=dataset,
             resume_path=model_root
        )
    model.eval()
    model = ch.nn.DataParallel(model.to(device))

    latents, labels = [], []
    for _, (X,y) in tqdm(enumerate(loader), total=len(loader)): 
        (op, reps), _ = model(X.to(device), with_latent=True)
        latents.append(reps.cpu())
        labels.append(y)
    return ch.cat(latents), ch.cat(labels)
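
One possible way to drive the feature extractor above, assuming a robustness CIFAR dataset, a local checkpoint, and the module-level random_seed the function references (all values are placeholders):

from robustness.datasets import CIFAR

random_seed = 0                                   # module-level name used above
ds = CIFAR('/path/to/cifar10')                    # placeholder data root
_, val_loader = ds.make_loaders(workers=4, batch_size=128)
latents, labels = get_deep_features(ds, val_loader,
                                    '/path/to/checkpoint.pt.best', 'resnet50')
print(latents.shape, labels.shape)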
Code Example #14
def load_model(model_name, model_path, dataset):
    if model_name == 'ImageNetNat.pt':
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'pytorch_pretrained': True,
            'parallel': False
        }
    else:
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'resume_path': model_path + model_name,
            'parallel': False
        }
    model, _ = model_utils.make_and_restore_model(**model_kwargs)
    model.eval()
    return model
Code Example #15
def obtain_model(model_type):
    if model_type != 'robust':
        checkpoint_path = 'runs/amdim_cpt.pth'
        checkpointer = Checkpointer()
        print('Loading model')
        model = checkpointer.restore_model_from_checkpoint(checkpoint_path)
        torch_device = torch.device('cuda')
        model = model.to(torch_device)
    else:
        dataset = robustness.datasets.CIFAR()
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'resume_path': '../robust_classif/robustness_applications/models/CIFAR.pt'
        }
        model, _ = model_utils.make_and_restore_model(**model_kwargs)
    model.eval()
    model = CommonModel(model, model_type)
    return model
Code Example #16
File: generation.py Project: gatechke/xray-ai
data_iterator = enumerate(test_loader)

#arch = CombineNet()#.to(device)
arch = resnet14_1(add_softmax=False, output_channels=4, latent_dim=192)
#arch = nn.DataParallel(arch)
#arch.eval()
print("Parameters: {}".format(sum(p.numel() for p in arch.parameters())))

# Load model
model_kwargs = {
    'arch': arch,
    'dataset': dataset,
    'resume_path': args.load_model + "/best.ckpt"
}

model, _ = model_utils.make_and_restore_model(**model_kwargs)
model.eval()


def downsample(x, step=GRAIN):
    down = ch.zeros([len(x), 1, DATA_SHAPE // step, DATA_SHAPE // step])

    for i in range(0, DATA_SHAPE, step):
        for j in range(0, DATA_SHAPE, step):
            v = x[:, :, i:i + step,
                  j:j + step].mean(dim=2, keepdim=True).mean(dim=3,
                                                             keepdim=True)
            ii, jj = i // step, j // step
            down[:, :, ii:ii + 1, jj:jj + 1] = v
    return down
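
downsample averages each step x step patch of a single-channel image into one value, so the output is DATA_SHAPE // step pixels per side. A quick check, assuming illustrative values DATA_SHAPE = 224 and GRAIN = 7 for the module constants:

import torch as ch

DATA_SHAPE, GRAIN = 224, 7                 # assumed values for this illustration
x = ch.rand(2, 1, DATA_SHAPE, DATA_SHAPE)
print(downsample(x, step=GRAIN).shape)     # -> torch.Size([2, 1, 32, 32])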
Code Example #17
            'alexnet': models.alexnet,
            'vgg16': models.vgg16,
            'vgg16_bn': models.vgg16_bn,
            'squeezenet': models.squeezenet1_0,
            'densenet': models.densenet161,
            # 'inception': models.inception_v3,
            # 'googlenet': models.googlenet,
            'shufflenet': models.shufflenet_v2_x1_0,
            'mobilenet': models.mobilenet_v2,
            'resnext50_32x4d': models.resnext50_32x4d,
            'mnasnet': models.mnasnet1_0,
        }

        model, checkpoint = model_utils.make_and_restore_model(
            arch=pytorch_models[ARCH](True),
            dataset=ds,
            resume_path=model_path,
            add_custom_forward=True)
        # model, checkpoint = model_utils.make_and_restore_model(arch='resnet50', dataset=ds, resume_path=model_path, add_custom_forward=False)
        train_loader, val_loader = ds.make_loaders(batch_size=64, workers=4)

        correct = 0
        model.eval()
        with ch.no_grad():
            for X, y in tqdm(val_loader):
                X, y = X.cuda(), y.cuda()
                out = model(X, with_image=False)
                _, pred = out.topk(1, 1)
                correct += (pred.squeeze() == y).detach().cpu().sum()
        print(
            f'The clean accuracy is {1.*correct/len(val_loader.dataset)*100.}%')
Code Example #18
    elif args.net == "resnet50-imagenet":
        print('Using ImageNet pretrained ResNet-50')
        net = torchvision.models.resnet50(pretrained=True)
        net = nn.Sequential(
            net, nn.Linear(in_features=1000, out_features=200, bias=True))
        net.load_state_dict(
            torch.load('../../drive/My Drive/tiny_imagenet/best_model.pth')
            ['model_state_dict'])

    elif args.net == "resnet-madry":
        print("Using ResNet-50 with Madry training")
        from robustness.model_utils import make_and_restore_model
        from robustness.datasets import CIFAR
        ds = CIFAR('../data')
        net, _ = make_and_restore_model(parallel=False,
                                        arch='resnet50',
                                        dataset=ds,
                                        resume_path='./cifar_linf_8.pt')
        net = net.model

    elif args.net == "resnet-madry-kernel-def":
        print("Using ResNet-50 with Madry training")
        from robustness.model_utils import make_and_restore_model
        from robustness.datasets import CIFAR
        ds = CIFAR('../data')
        net, _ = make_and_restore_model(parallel=False,
                                        arch='resnet50',
                                        dataset=ds,
                                        resume_path='./cifar_linf_8.pt')
        net = net.model
        state_dict = net.state_dict()
        new_state_dict = OrderedDict()
Code Example #19
        ts.append(
            Vingette((c, args.crop_resize_crop, args.crop_resize_crop),
                     args.vingette,
                     pt=True,
                     batch_dim=False,
                     offset=args.vingette_offset))
    test_transform = transforms.Compose(ts)

    if args.dataset in ['imagenet', 'restricted_imagenet', 'cifar']:
        ds_class = datasets.DATASETS[args.dataset]
        ds = ds_class(args.data_dir)
        ds.transform_train = training_transform
        ds.transform_test = test_transform
        m, params = model_utils.make_and_restore_model(
            arch=('resnet18' if args.dataset == 'cifar' else 'resnet50'),
            dataset=ds,
            resume_path=args.resume_from,
            parallel=False)

        if args.start_epoch is None:
            if args.resume_from is not None:
                args.start_epoch = params['epoch']
            else:
                args.start_epoch = 0
        train_loader, _ = ds.make_loaders(batch_size=args.batch_size,
                                          workers=args.nr_workers)
        _, val_loader = ds.make_loaders(batch_size=args.batch_size // 2,
                                        workers=args.nr_workers)
    elif 'mnist' in args.dataset:
        ds_class = datasets.DATASETS['cifar']  # pretend we are cifar
        train_dataset = get_dataset(args,
Code Example #20
def get_boosted_model(args, ds):
    is_pt_model = args.arch in constants.NAME_TO_ARCH and args.dataset == 'imagenet'
    arch = constants.NAME_TO_ARCH[args.arch](
        args.pytorch_pretrained) if is_pt_model else args.arch
    num_classes = 1 if args.single_class else ds.num_classes

    if arch == 'linear':
        arch = LinearModel(num_classes, constants.DS_TO_DIM[args.dataset])

    kwargs = {
        'arch': arch,
        'dataset': ds,
        'resume_path': args.model_path,
        'add_custom_forward': is_pt_model or args.arch == 'linear',
        'pytorch_pretrained': args.pytorch_pretrained
    }

    model, _ = model_utils.make_and_restore_model(**kwargs)

    # Wrap the model with DataAugmentedModel even if there are no corruptions,
    # for consistency when loading from checkpoints.
    model = boosters.DataAugmentedModel(
        model, ds.ds_name,
        args.augmentations.split(',') if args.augmentations else [])

    # Don't pass the checkpoint to train_model, to avoid resuming epoch, optimizers, etc.
    if args.boosting == 'class_consistent':
        boosting_path = Path(args.out_dir) / BOOSTING_FP
        if boosting_path.exists():
            booster = ch.load(boosting_path)
        else:
            dim = constants.DS_TO_DIM[args.dataset]
            booster = boosters.ClassConsistentBooster(
                ds.num_classes,
                dim,
                constants.PATCH_TRANSFORMS,
                args.patch_size,
                model,
                apply_transforms=args.apply_booster_transforms)

        model = boosters.BoostedModel(model, booster, args.training_mode)
    elif args.boosting == '3d':
        boosting_path = Path(args.out_dir) / BOOSTING_FP
        if boosting_path.exists():
            booster = ch.load(boosting_path)
        else:
            dim = constants.DS_TO_DIM[args.dataset]
            render_options = {
                'min_zoom': args.min_zoom,
                'max_zoom': args.max_zoom,
                'min_light': args.min_light,
                'max_light': args.max_light,
                'samples': args.render_samples
            }
            corruptions = constants.THREE_D_CORRUPTIONS if args.add_corruptions else None
            booster = boosters.ThreeDBooster(
                num_classes=num_classes,
                tex_size=args.patch_size,
                image_size=dim,
                batch_size=args.batch_size,
                render_options=render_options,
                num_texcoords=args.num_texcoord_renderers,
                num_gpus=ch.cuda.device_count(),
                debug=args.debug,
                forward_render=args.forward_render,
                custom_file=args.custom_file,
                corruptions=corruptions)

        model = boosters.BoostedModel(model, booster, args.training_mode)
    elif args.boosting == 'none':
        # assert args.eval_only
        model = boosters.BoostedModel(model, None, args.training_mode)
    else:
        raise ValueError(f'boosting not found: {args.boosting}')

    return model.cuda()
Code Example #21
        return torch.clamp(torch.min(torch.max(adv, img - eps), img + eps),
                           0.0, 1.0)

    def forward(self, inp, target, eps, step_size, iterations, **kwargs):
        adv = inp + step_size * torch.rand_like(inp)
        for _ in range(iterations):
            adv = adv.clone().detach().requires_grad_(True)
            loss = self.calc_loss(adv, target)
            loss.backward()
            adv = self.clip(adv + step_size * torch.sign(adv.grad.data), inp,
                            eps)  # gradient ASCENT
        return adv.clone().detach()


ds = CIFAR('/scratch/raunakc/datasets/cifar10')
model, _ = make_and_restore_model(arch='resnet18', dataset=ds)
model.attacker = WhiteboxPGD(model.model, ds)

train_kwargs = {
    'dataset': 'cifar',
    'arch': 'resnet',
    'out_dir': "train_out",
    'adv_train': 1,
    'adv_eval': 1,
    'eps': 8 / 255,
    'attack_lr': 2 / 255,
    'attack_steps': 10,
    'constraint': 'inf'  # not required but arg checker requires it :(
}

args = utils.Parameters(train_kwargs)
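
The Parameters object built from train_kwargs is the same structure the robustness training loop consumes. A hedged continuation along the lines of the library's documented usage (loader settings are placeholders):

from robustness import train, defaults
from robustness.datasets import CIFAR

# Fill in any training/PGD defaults that train_kwargs leaves unspecified,
# then hand the model and loaders to the standard training loop.
args = defaults.check_and_fill_args(args, defaults.TRAINING_ARGS, CIFAR)
args = defaults.check_and_fill_args(args, defaults.PGD_ARGS, CIFAR)
train_loader, val_loader = ds.make_loaders(batch_size=128, workers=8)
train.train_model(args, model, (train_loader, val_loader))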
Code Example #22
	else:
		model = torch.load(MODEL, map_location=device)['net']
	print('model loaded')
else:
	print("=> no checkpoint found at '{}'".format(MODEL))
'''

from CPU_utils import make_and_restore_model_CPU_only

MODEL = "cifar_nat.pt"

ds = CIFAR('data/cifar-10-batches-py')

if use_cuda:
    model, _ = make_and_restore_model(arch='resnet50',
                                      dataset=ds,
                                      resume_path=MODEL,
                                      parallel=False)
else:
    model, _ = make_and_restore_model_CPU_only(arch='resnet50',
                                               dataset=ds,
                                               resume_path=MODEL,
                                               parallel=False)

model = model.model
model.eval()

import torchvision.models as models

# resnet18 = models.resnet18(pretrained=True)
# vgg16 = models.vgg16(pretrained=True)
Code Example #23
File: model_utils.py Project: RoZvEr/adversarial
def convert_to_robustness(model, state_dict):
    dataset = ImageNet('dataset/imagenet-airplanes')
    model, _ = make_and_restore_model(arch=model, dataset=dataset)
    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return model, state_dict
Code Example #24
from PIL import Image
import numpy as np

from robustness.datasets import ImageNet
from robustness.model_utils import make_and_restore_model
import torch
import matplotlib.pyplot as plt

ds = ImageNet('/tmp')
model, _ = make_and_restore_model(
    arch='resnet50',
    dataset=ds,
    resume_path='/home/siddhant/Downloads/imagenet_l2_3_0.pt')
model.eval()

img = np.asarray(
    Image.open(
        '/home/siddhant/CMU/robustness_applications/sample_inputs/img_bear.jpg'
    ).resize((224, 224)))
img = img / 254.
img = np.transpose(img, (2, 0, 1))

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STDDEV = [0.229, 0.224, 0.225]

img_var = torch.tensor(img, dtype=torch.float)[None, :]
img = img_var.clone().detach().cpu().numpy()
img = img[0]

img = img.transpose((1, 2, 0))
img *= 255
Code Example #25
def load_model(model_type):
    if model_type == "simclr":
        # load checkpoint for simclr
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50-1x.pth')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['state_dict'])
        # preprocess images for simclr
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor()
        ])
        return resnet

    if model_type == "simclr_v2_0":
        # load checkpoint for simclr
        checkpoint = torch.load('/content/gdrive/MyDrive/r50_1x_sk0.pth')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['resnet'])
        # preprocess images for simclr
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor()
        ])
        return resnet
    if model_type == "moco":
        # load checkpoints of moco
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/moco_v1_200ep_pretrain.pth.tar',
            map_location=torch.device('cpu'))['state_dict']
        resnet = models.resnet50(pretrained=False)
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q'
                            ) and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for moco
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "mocov2":
        # load checkpoints of mocov2
        state_dict = torch.load(
            '/content/gdrive/MyDrive/moco/moco_v2_200ep_pretrain.pth.tar',
            map_location=torch.device('cpu'))['state_dict']
        resnet = models.resnet50(pretrained=False)
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q'
                            ) and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for mocov2
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "InsDis":
        # load checkpoints for instance recoginition resnet
        resnet = models.resnet50(pretrained=False)
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/lemniscate_resnet50_update.pth',
            map_location=torch.device('cpu'))['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('module') and not k.startswith('module.fc'):
                state_dict[k[len("module."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for instance recoginition resnet
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "place365_rn50":
        # load checkpoints for place365 resnet
        resnet = models.resnet50(pretrained=False)
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50_places365.pth.tar',
            map_location=torch.device('cuda'))['state_dict']
        #     for k in list(state_dict.keys()):
        #         if k.startswith('module') and not k.startswith('module.fc'):
        #             state_dict[k[len("module."):]] = state_dict[k]
        #         del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        #     assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for place365-resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "resnext101":
        #load ResNeXt 101_32x8 imagenet trained model
        resnet = models.resnext101_32x8d(pretrained=True)
        #preprocess for resnext101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "wsl_resnext101":
        # load wsl resnext101
        resnet = models.resnext101_32x8d(pretrained=False)
        checkpoint = torch.load(
            "/content/gdrive/MyDrive/model_checkpoints/ig_resnext101_32x8-c38310e5.pth"
        )
        resnet.load_state_dict(checkpoint)
        #preprocess for wsl resnext101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "st_resnet":
        # load checkpoint for st resnet
        resnet = models.resnet50(pretrained=True)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "resnet101":
        # load checkpoint for st resnet
        resnet = models.resnet101(pretrained=True)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet
    if model_type == "wide_resnet101":
        # load checkpoint for st resnet
        resnet = models.wide_resnet101_2(pretrained=True)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet
    if model_type == "wide_resnet50":
        # load checkpoint for st resnet
        resnet = models.wide_resnet50_2(pretrained=True)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_resnet50":
        # load checkpoint for st resnet
        resnet = models.resnet50(pretrained=False)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_resnet101":
        # load checkpoint for st resnet
        resnet = models.resnet101(pretrained=False)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_wrn50":
        # load checkpoint for st resnet
        resnet = models.wide_resnet50_2(pretrained=False)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_wrn101":
        # load checkpoint for st resnet
        resnet = models.wide_resnet101_2(pretrained=False)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "st_alexnet":
        # load checkpoint for st alexnet
        alexnet = models.alexnet(pretrained=True)
        #preprocess for alexnet
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return alexnet

    if model_type == "clip":
        import clip
        resnet, preprocess = clip.load("RN50")
        return resnet

    if model_type == 'linf_8':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_8_model.pt') # https://drive.google.com/file/d/1DRkIcM_671KQNhz1BIXMK6PQmHmrYy_-/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_8.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):

                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type == 'linf_4':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/robust_resnet.pt')#https://drive.google.com/file/d/1_tOhMBqaBpfOojcueSnYQRw_QgXdPVS6/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_4.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            #         if k.startswith('module.attacker.model.') and not k.startswith('module.attacker.normalize') :
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type == 'l2_3':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/imagenet_l2_3_0_model.pt') # https://drive.google.com/file/d/1SM9wnNr_WnkEIo8se3qd3Di50SUT9apn/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_l2_3_0.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):

                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type in ('resnet50_l2_eps1', 'resnet50_l2_eps0.01', 'resnet50_l2_eps0.03',
                      'resnet50_l2_eps0.5', 'resnet50_l2_eps0.25', 'resnet50_l2_eps3',
                      'resnet50_l2_eps5'):
        resnet = models.resnet50(pretrained=False)
        ds = ImageNet('/tmp')
        total_resnet, checkpoint = make_and_restore_model(
            arch='resnet50',
            dataset=ds,
            resume_path=
            f'/content/gdrive/MyDrive/model_checkpoints/{model_type}.ckpt')
        # resnet=total_resnet.attacker
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet
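
The same key-renaming idiom recurs for every robust checkpoint above: keep only the entries under the AttackerModel's inner classifier and drop the wrapper prefix. A small helper capturing that pattern (the helper name and default prefix argument are illustrative):

def strip_robustness_prefix(state_dict, prefix='module.attacker.model.'):
    # Keep only the inner classifier weights and strip the wrapper prefix.
    return {k[len(prefix):]: v for k, v in state_dict.items()
            if k.startswith(prefix)}

# e.g.: resnet.load_state_dict(strip_robustness_prefix(checkpoint['model']))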
Code Example #26
    args = parser.parse_args()

    # Personal preference here to default to grad not enabled; 
    # explicitly enable grad when necessary for memory reasons
    ch.manual_seed(0)
    ch.set_grad_enabled(False)

    print("Initializing dataset and loader...")
    ds_imagenet = ImageNet(args.dataset_path)
    train_loader, test_loader = ds_imagenet.make_loaders(args.num_workers, args.batch_size, 
                                                        data_aug=False, shuffle_train=False, shuffle_val=False)

    print("Loading model...")
    model, _ = make_and_restore_model( 
        arch=args.arch, 
        dataset=ds_imagenet,
        resume_path=args.model_root
    )
    model.eval()
    model = ch.nn.DataParallel(model.to(args.device))

    out_dir = args.out_path
    if not os.path.exists(out_dir):
        print(f"Making directory {out_dir}")
        os.makedirs(out_dir)

    for mode,loader in zip(['train', 'test'], [train_loader, test_loader]): 
        print(f"Creating {mode} features in {out_dir}")
        all_latents, all_labels = [], []

        chunk_id, n = 0, 0
Code Example #27
def main():
    # Parse image location from command argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True)
    args = parser.parse_args()

    # Check if the given location exists and it is a valid image file
    if os.path.exists(args.image) and args.image.endswith(
        ('png', 'jpg', 'jpeg')):
        # Open the image from the given location
        image = Image.open(args.image)

        # Transform the image to a PyTorch tensor
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        image = transform(image)
        image = image.unsqueeze(0).cuda()

        # Attack parameters
        kwargs = {
            'constraint': 'inf',
            'eps': 16.0 / 255.0,
            'step_size': 1.0 / 255.0,
            'iterations': 500,
            'do_tqdm': True,
        }

        # Set the dataset for the robustness model
        dataset = ImageNet('dataset/')

        # Initialize a pretrained model via the robustness library
        model, _ = make_and_restore_model(arch='resnet50',
                                          dataset=dataset,
                                          pytorch_pretrained=True)
        model = model.cuda()

        # For evaluation, the standard ResNet50 from torchvision is used
        eval_model = torchvision.models.resnet50(pretrained=True).cpu().eval()

        # Get the model prediction for the original image
        label = eval_model(image.cpu())
        label = torch.argmax(label[0])
        label = label.view(1).cuda()

        # Create an adversarial example of the original images
        _, adversarial_example = model(image, label, make_adv=True, **kwargs)

        # Get the prediction of the model for the adversarial image
        adversarial_prediction = eval_model(adversarial_example.cpu())
        adversarial_prediction = torch.argmax(adversarial_prediction[0])

        # Print the original and the adversarial predictions
        print('Original prediction: ' + str(label.item()))
        print('Adversarial prediction: ' + str(adversarial_prediction.item()))

        # Save the adversarial example in the same folder as the original image
        filename_and_extension = args.image.split('.')
        adversarial_location = filename_and_extension[
            0] + '_adversarial.' + filename_and_extension[-1]
        save_image(adversarial_example[0], adversarial_location)

    else:
        print('Incorrect image path!')
Code Example #28
def get_boosted_model(args, ds):
    if args.boosting == 'qrcode':
        assert args.arch == 'None', 'With QRCodes, there is no model. Please pass "None" to --arch'
        model = boosters.DataAugmentedModel(boosters.QRCodeModel(ds.num_classes, 
                                        detector=args.qrcode_detector), ds.ds_name, [])
        booster = boosters.QRCodeBooster(ds.num_classes, constants.DS_TO_DIM[args.dataset],
                                        constants.PATCH_TRANSFORMS,args.patch_size,
                                        apply_transforms=args.apply_booster_transforms, 
                                        detector=args.qrcode_detector)
        return boosters.BoostedModel(model, booster, 'qrcode').cuda()

    is_pt_model = args.arch in constants.NAME_TO_ARCH and args.dataset == 'imagenet'
    arch = constants.NAME_TO_ARCH[args.arch](args.pytorch_pretrained) if is_pt_model else args.arch

    if arch == 'linear':
        arch = LinearModel(constants.DS_TO_CLASSES[args.dataset], constants.DS_TO_DIM[args.dataset])


    # Check if the checkpoint is of a BoostedModel or a robustness lib AttackerModel 
    checkpoint = ch.load(args.model_path, pickle_module=dill)
    sd = checkpoint['model']
    is_checkpoint_boosted = 'module.booster.patches' in sd.keys()

    kwargs = {'arch': arch, 'dataset': ds, 
              'resume_path': args.model_path if not is_checkpoint_boosted else None,
              'add_custom_forward': is_pt_model or args.arch=='linear',
              'pytorch_pretrained': args.pytorch_pretrained}

    model, _ = model_utils.make_and_restore_model(**kwargs)

    # Wrap the model with DataAugmentedModel even if there are no corruptions.
    # For consistency when loading from checkpoints
    model = boosters.DataAugmentedModel(model, ds.ds_name, [])

    # Don't pass the checkpoint to train_model, to avoid resuming epoch, optimizers, etc.
    if args.boosting == 'class_consistent':
        boosting_path = Path(args.out_dir) / BOOSTING_FP
        if boosting_path.exists():
            booster = ch.load(boosting_path)
        else:
            dim = constants.DS_TO_DIM[args.dataset]
            booster = boosters.ClassConsistentBooster(ds.num_classes, dim,
                                                      constants.PATCH_TRANSFORMS,
                                                      args.patch_size,
                                                      model, apply_transforms=args.apply_booster_transforms)

        model = boosters.BoostedModel(model, booster, None)
    elif args.boosting == 'best_images':
        dim = constants.DS_TO_DIM[args.dataset]
        booster = boosters.BestImageBooster(ds.num_classes, dim,
                                            constants.PATCH_TRANSFORMS,
                                            args.patch_size,
                                            args.path_best_images,
                                            apply_transforms=args.apply_booster_transforms)

        model = boosters.BoostedModel(model, booster, None)
    elif args.boosting == 'none':
        # assert args.eval_only
        model = boosters.BoostedModel(model, None, None)
    else:
        raise ValueError(f'boosting not found: {args.boosting}')

    if is_checkpoint_boosted:
        sd = {k[len('module.'):]:v for k,v in sd.items()}
        model.load_state_dict(sd)
        print("=> loaded checkpoint of BoostedModel'{}' (epoch {})".format(args.model_path, checkpoint['epoch']))

    return model.cuda()
Code Example #29
File: utilities.py Project: jtx1999/perceptual-advex
def get_dataset_model(
    args=None,
    dataset_path: Optional[str] = None,
    arch: Optional[str] = None,
    checkpoint_fname: Optional[str] = None,
    **kwargs,
) -> Tuple[DataSet, nn.Module]:
    """
    Given an argparse namespace with certain parameters, or those parameters
    as keyword arguments, returns a tuple (dataset, model) with a robustness
    dataset and a FeatureModel.
    """

    if dataset_path is None:
        if args is None:
            dataset_path = '~/datasets'
        else:
            dataset_path = args.dataset_path
    dataset_path = os.path.expandvars(dataset_path)

    dataset_name = kwargs.get('dataset') or args.dataset
    dataset = DATASETS[dataset_name](dataset_path)

    checkpoint_is_feature_model = False

    if checkpoint_fname is None:
        checkpoint_fname = getattr(args, 'checkpoint', None)
    if arch is None:
        arch = args.arch

    if arch.startswith('rob-') or (dataset_name.startswith('cifar')
                                   and 'resnet' in arch):
        if arch.startswith('rob-'):
            arch = arch[4:]
        if checkpoint_fname == 'pretrained':
            pytorch_pretrained = True
            checkpoint_fname = None
        else:
            pytorch_pretrained = False
        try:
            model, _ = make_and_restore_model(
                arch=arch,
                dataset=dataset,
                resume_path=checkpoint_fname,
                pytorch_pretrained=pytorch_pretrained,
                parallel=False,
            )
        except RuntimeError as error:
            if 'state_dict' in str(error):
                model, _ = make_and_restore_model(
                    arch=arch,
                    dataset=dataset,
                    parallel=False,
                )
                try:
                    state = torch.load(checkpoint_fname)
                    model.model.load_state_dict(state['model'])
                except RuntimeError as error:
                    if 'state_dict' in str(error):
                        checkpoint_is_feature_model = True
                    else:
                        raise error
            else:
                raise error  # type: ignore
    elif arch == 'trades-wrn':
        model = TradesWideResNet()
        if checkpoint_fname is not None:
            state = torch.load(checkpoint_fname)
            model.load_state_dict(state)
    elif hasattr(torchvision_models, arch):
        if (arch == 'alexnet' and dataset_name.startswith('cifar')
                and checkpoint_fname != 'pretrained'):
            model = CifarAlexNet(num_classes=dataset.num_classes)
        else:
            if checkpoint_fname == 'pretrained':
                model = getattr(torchvision_models, arch)(pretrained=True)
            else:
                model = getattr(torchvision_models,
                                arch)(num_classes=dataset.num_classes)

        if checkpoint_fname is not None and checkpoint_fname != 'pretrained':
            try:
                state = torch.load(checkpoint_fname)
                model.load_state_dict(state['model'])
            except RuntimeError as error:
                if 'state_dict' in str(error):
                    checkpoint_is_feature_model = True
                else:
                    raise error
    else:
        raise RuntimeError(f'Unsupported architecture {arch}.')

    if 'alexnet' in arch:
        model = AlexNetFeatureModel(model)
    elif 'vgg16' in arch:
        model = VGG16FeatureModel(model)
    elif 'resnet' in arch:
        if not isinstance(model, AttackerModel):
            model = AttackerModel(model, dataset)
        if dataset_name.startswith('cifar'):
            model = CifarResNetFeatureModel(model)
        elif (dataset_name.startswith('imagenet')
              or dataset_name == 'bird_or_bicycle'):
            model = ImageNetResNetFeatureModel(model)
        else:
            raise RuntimeError('Unsupported dataset.')
    elif arch == 'trades-wrn':
        pass  # We can't use this as a FeatureModel yet.
    else:
        raise RuntimeError(f'Unsupported architecture {arch}.')

    if checkpoint_is_feature_model:
        model.load_state_dict(state['model'])

    return dataset, model
Code Example #30
##

#ALTERNATIVELY do it Robustness_lib's way
dataset = datasets.RestrictedImageNet('')

model_kwargs = {
    'arch': 'resnet50',
    'dataset': dataset,
    'resume_path':
    '/home/jesse/Documents/robustness_lib/RestrictedImageNet.pt',
    'state_dict_path': 'model',
    'parallel': False
}

# Robust ResNet
model, ckpt = model_utils.make_and_restore_model(**model_kwargs)
robust_resnet = model.model

# Regular ResNet
reg_resnet = copy.deepcopy(robust_resnet)
new_params = reg_resnet.state_dict()
partial_params = model_zoo.load_url(
    'https://download.pytorch.org/models/resnet50-19c8e357.pth')
del partial_params['fc.bias']
del partial_params['fc.weight']
new_params.update(partial_params)
reg_resnet.load_state_dict(new_params)

# VGG
vgg = models.vgg19(pretrained=True).features