Exemplo n.º 1
0
 def construct_tiny_imagenet_model(arch, dataset):
     """Build a randomly initialised network for ``dataset`` (no pretrained weights).

     :param arch: architecture name. A name found in ``torch_models.__dict__``
         (other than ``densenet*`` / ``resnext*``) is first built from there; for
         ``resnet*`` and ``vgg*`` the classifier is then resized to
         ``CLASS_NUM[dataset]`` outputs. ``densenet121/161/169/201``,
         ``resnext32_4``, ``resnext64_4``, ``ghost_net``, ``inception*``,
         ``WRN-28-10-drop`` and ``WRN-40-10-drop`` use dedicated constructors.
         Any other name found in ``torch_models`` is returned as built, with its
         head left unchanged.
     :param dataset: key into ``CLASS_NUM`` (and ``IN_CHANNELS`` for ghost_net / WRN).
     :return: the constructed network.
     :raises NotImplementedError: if ``arch`` matches none of the cases above.
     """
     network = None
     if not arch.startswith(("densenet", "resnext")) and arch in torch_models.__dict__:
         network = torch_models.__dict__[arch](pretrained=False)
     num_classes = CLASS_NUM[dataset]
     # resnet*/vgg* only patch a network built above; if the lookup found nothing
     # they fall through so the "unknown arch" error at the end is raised.
     if arch.startswith("resnet") and network is not None:
         num_ftrs = network.fc.in_features
         network.fc = nn.Linear(num_ftrs, num_classes)
     elif arch.startswith("densenet"):
         if arch == "densenet161":
             network = densenet161(pretrained=False)
         elif arch == "densenet121":
             network = densenet121(pretrained=False)
         elif arch == "densenet169":
             network = densenet169(pretrained=False)
         elif arch == "densenet201":
             network = densenet201(pretrained=False)
     elif arch == "resnext32_4":
         network = resnext101_32x4d(pretrained=None)
     elif arch == "resnext64_4":
         network = resnext101_64x4d(pretrained=None)
     elif arch == "ghost_net":
         network = ghost_net(IN_CHANNELS[dataset], num_classes)
     elif arch.startswith("inception"):
         network = inception_v3(pretrained=False)
     elif arch == "WRN-28-10-drop":
         network = tiny_imagenet_wrn(in_channels=IN_CHANNELS[dataset], depth=28, num_classes=num_classes,
                                     widen_factor=10, dropRate=0.3)
     elif arch == "WRN-40-10-drop":
         network = tiny_imagenet_wrn(in_channels=IN_CHANNELS[dataset], depth=40, num_classes=num_classes,
                                     widen_factor=10, dropRate=0.3)
     elif arch.startswith("vgg") and network is not None:
         # Replace avgpool with Identity and size the first fc layer for the raw
         # conv output: 512 channels x 2 x 2 (assumes 64x64 inputs: 64 / 2**5 = 2).
         network.avgpool = Identity()
         network.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
         network.classifier[-1] = nn.Linear(4096, num_classes)
     if network is None:
         raise NotImplementedError("Unknown arch {}".format(arch))
     return network
def get_tinyimagenet_model(arch, dataset):
    """Build a (mostly) pretrained network with its classifier resized for ``dataset``.

    :param arch: architecture name. Names in ``models.__dict__`` are built with
        ``pretrained=True``; ``resnet*``, ``squeezenet*`` and ``vgg*`` then get
        a ``CLASS_NUM[dataset]``-way head. ``densenet121/161/169/201`` and
        ``inception*`` use the project's own constructors (``pretrained=True``);
        ``resnext32_4`` / ``resnext64_4`` are built with ``pretrained=None``.
        Any other name found in ``models`` is returned with its head unchanged.
    :param dataset: key into ``CLASS_NUM``.
    :return: the constructed network.
    :raises NotImplementedError: if ``arch`` matches none of the cases above.
    """
    network = None
    # densenet*/inception* are always rebuilt below from dedicated constructors,
    # so don't build (and download weights for) a torchvision copy only to discard it.
    if not arch.startswith(("densenet", "inception")) and arch in models.__dict__:
        network = models.__dict__[arch](pretrained=True)
    num_classes = CLASS_NUM[dataset]
    # Branches that patch an existing network require the lookup above to have
    # succeeded; otherwise they fall through to the "unknown arch" error.
    if arch.startswith("resnet") and network is not None:
        num_ftrs = network.fc.in_features
        network.fc = nn.Linear(num_ftrs, num_classes)
    elif arch.startswith("densenet"):
        if arch == "densenet161":
            network = densenet161(pretrained=True)
        elif arch == "densenet121":
            network = densenet121(pretrained=True)
        elif arch == "densenet169":
            network = densenet169(pretrained=True)
        elif arch == "densenet201":
            network = densenet201(pretrained=True)
    elif arch == "resnext32_4":
        # NOTE(review): built without pretrained weights, unlike most archs here;
        # pretrained="imagenet" would load them — confirm which is intended.
        network = resnext101_32x4d(pretrained=None)
    elif arch == "resnext64_4":
        network = resnext101_64x4d(pretrained=None)
    elif arch.startswith("squeezenet") and network is not None:
        network.classifier[-1] = nn.AdaptiveAvgPool2d(1)
        network.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
    elif arch.startswith("inception"):
        network = inception_v3(pretrained=True)
    elif arch.startswith("vgg") and network is not None:
        # Replace avgpool with Identity and size the first fc layer for the raw
        # conv output: 512 channels x 2 x 2 (assumes 64x64 inputs: 64 / 2**5 = 2).
        network.avgpool = Identity()
        network.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
        network.classifier[-1] = nn.Linear(4096, num_classes)
    if network is None:
        raise NotImplementedError("Unknown arch {}".format(arch))
    return network
def construct_model(arch, dataset):
    """Build network ``arch`` for ``dataset``.

    For any dataset other than ``"TinyImageNet"`` the project's own small-image
    implementations are built with ``IN_CHANNELS[dataset]`` input channels and
    ``CLASS_NUM[dataset]`` classes. For ``"TinyImageNet"`` a pretrained backbone
    is built (``pretrained=True`` / ``"imagenet"``) and, for ``resnet*`` and
    ``vgg*``, its classifier is resized to ``CLASS_NUM[dataset]`` outputs.

    :param arch: architecture name.
    :param dataset: dataset name; key into ``CLASS_NUM`` / ``IN_CHANNELS`` / ``IMAGE_SIZE``.
    :return: the constructed network.
    :raises NotImplementedError: if ``arch`` is not supported for ``dataset``.
    """
    network = None
    num_classes = CLASS_NUM[dataset]
    if dataset != "TinyImageNet":
        in_channels = IN_CHANNELS[dataset]
        if arch == "conv3":
            network = Conv3(in_channels, IMAGE_SIZE[dataset], num_classes)
        elif arch == "densenet121":
            network = DenseNet121(in_channels, num_classes)
        elif arch == "densenet169":
            network = DenseNet169(in_channels, num_classes)
        elif arch == "densenet201":
            network = DenseNet201(in_channels, num_classes)
        elif arch == "googlenet":
            network = GoogLeNet(in_channels, num_classes)
        elif arch == "mobilenet":
            network = MobileNet(in_channels, num_classes)
        elif arch == "mobilenet_v2":
            network = MobileNetV2(in_channels, num_classes)
        elif arch == "resnet18":
            network = ResNet18(in_channels, num_classes)
        elif arch == "resnet34":
            network = ResNet34(in_channels, num_classes)
        elif arch == "resnet50":
            network = ResNet50(in_channels, num_classes)
        elif arch == "resnet101":
            network = ResNet101(in_channels, num_classes)
        elif arch == "resnet152":
            network = ResNet152(in_channels, num_classes)
        elif arch == "pnasnetA":
            network = PNASNetA(in_channels, num_classes)
        elif arch == "pnasnetB":
            network = PNASNetB(in_channels, num_classes)
        elif arch == "efficientnet":
            network = EfficientNetB0(in_channels, num_classes)
        elif arch == "dpn26":
            network = DPN26(in_channels, num_classes)
        elif arch == "dpn92":
            network = DPN92(in_channels, num_classes)
        elif arch == "resnext29_2":
            network = ResNeXt29_2x64d(in_channels, num_classes)
        elif arch == "resnext29_4":
            network = ResNeXt29_4x64d(in_channels, num_classes)
        elif arch == "resnext29_8":
            network = ResNeXt29_8x64d(in_channels, num_classes)
        elif arch == "resnext29_32":
            network = ResNeXt29_32x4d(in_channels, num_classes)
        elif arch == "senet18":
            network = SENet18(in_channels, num_classes)
        elif arch == "shufflenet_G2":
            network = ShuffleNetG2(in_channels, num_classes)
        elif arch == "shufflenet_G3":
            network = ShuffleNetG3(in_channels, num_classes)
        elif arch == "vgg11":
            network = vgg11(in_channels, num_classes)
        elif arch == "vgg13":
            network = vgg13(in_channels, num_classes)
        elif arch == "vgg16":
            network = vgg16(in_channels, num_classes)
        elif arch == "vgg19":
            network = vgg19(in_channels, num_classes)
        elif arch == "preactresnet18":
            network = PreActResNet18(in_channels, num_classes)
        elif arch == "preactresnet34":
            network = PreActResNet34(in_channels, num_classes)
        elif arch == "preactresnet50":
            network = PreActResNet50(in_channels, num_classes)
        elif arch == "preactresnet101":
            network = PreActResNet101(in_channels, num_classes)
        elif arch == "preactresnet152":
            network = PreActResNet152(in_channels, num_classes)
        elif arch == "wideresnet28":
            network = wideresnet28(in_channels, num_classes)
        elif arch == "wideresnet34":
            network = wideresnet34(in_channels, num_classes)
        elif arch == "wideresnet40":
            network = wideresnet40(in_channels, num_classes)
    else:
        # densenet* is always rebuilt below, so skip building (and downloading
        # weights for) a torchvision copy that would be discarded.
        if not arch.startswith("densenet") and arch in torch_models.__dict__:
            network = torch_models.__dict__[arch](pretrained=True)
        # resnet*/vgg* only patch a network built above; otherwise they fall
        # through to the "unknown arch" error below.
        if arch.startswith("resnet") and network is not None:
            num_ftrs = network.fc.in_features
            network.fc = nn.Linear(num_ftrs, num_classes)
        elif arch.startswith("densenet"):
            if arch == "densenet161":
                network = densenet161(pretrained=True)
            elif arch == "densenet121":
                network = densenet121(pretrained=True)
            elif arch == "densenet169":
                network = densenet169(pretrained=True)
            elif arch == "densenet201":
                network = densenet201(pretrained=True)
        elif arch == "resnext32_4":
            network = resnext101_32x4d(pretrained="imagenet")
        elif arch == "resnext64_4":
            network = resnext101_64x4d(pretrained="imagenet")
        elif arch.startswith("vgg") and network is not None:
            # Replace avgpool with Identity and size the first fc layer for the raw
            # conv output: 512 x 2 x 2 (assumes 64x64 inputs: 64 / 2**5 = 2).
            network.avgpool = Identity()
            network.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
            network.classifier[-1] = nn.Linear(4096, num_classes)
    if network is None:
        raise NotImplementedError("Unknown arch {} for dataset {}".format(arch, dataset))
    return network
Exemplo n.º 4
0
    def make_model(self, dataset, arch, **kwargs):
        """
        Make model, and load pre-trained weights.

        :param dataset: 'CIFAR-10', 'MNIST', 'FashionMNIST', 'TinyImageNet' or 'ImageNet'
        :param arch: arch name, e.g., alexnet_bn
        :param kwargs: ``train_data`` picks the training run whose weights are
            loaded ('full' or 'cifar10.1' for CIFAR-10/MNIST/FashionMNIST, 'full'
            or 'imagenetv2-val' for ImageNet; not read for TinyImageNet);
            ``epoch`` ('final' or 'best') picks the checkpoint file in the
            branches that build a checkpoint path from it.
        :return: model (presumably on CPU and in training mode — nothing here
            calls .cuda() or .eval(); TODO confirm for the load_weight_* helpers)
        :raises NotImplementedError: unknown dataset, arch, train_data or epoch.
        :raises FileNotFoundError: no TinyImageNet checkpoint exists for ``arch``.
        """
        if dataset in ['CIFAR-10', "MNIST", "FashionMNIST"]:
            # these archs are only supported with weights trained on the full set
            if arch in ('gdas', 'pyramidnet272', 'carlinet') and kwargs['train_data'] != 'full':
                raise NotImplementedError('Unknown train data {} for arch {}'.format(kwargs['train_data'], arch))
            if arch == 'gdas':
                model = target_models.cifar.gdas('{}/subspace_attack/data/cifar10-models/gdas/seed-6293/checkpoint-cifar10-model.pth'.format(PY_ROOT))
                model.mean = [125.3 / 255, 123.0 / 255, 113.9 / 255]
                model.std = [63.0 / 255, 62.1 / 255, 66.7 / 255]
                model.input_space = 'RGB'
                model.input_range = [0, 1]
                model.input_size = [IN_CHANNELS[dataset], IMAGE_SIZE[dataset][0], IMAGE_SIZE[dataset][1]]
            elif arch == 'pyramidnet272':
                model = target_models.cifar.pyramidnet272(num_classes=10)
                self.load_weight_from_pth_checkpoint(model, '{}/subspace_attack/data/cifar10-models/pyramidnet272/checkpoint.pth'.format(PY_ROOT))
                model.mean = [0.49139968, 0.48215841, 0.44653091]
                model.std = [0.24703223, 0.24348513, 0.26158784]
                model.input_space = 'RGB'
                model.input_range = [0, 1]
                model.input_size = [IN_CHANNELS[dataset], IMAGE_SIZE[dataset][0], IMAGE_SIZE[dataset][1]]
            elif arch == "carlinet":
                model = carlinet()
                self.load_weight_from_h5_checkpoint(model, '{}/subspace_attack/data/cifar10-models/carlinet'.format(PY_ROOT))
                model.mean = [0.5, 0.5, 0.5]
                model.std = [1, 1, 1]
                model.input_space = 'RGB'
                model.input_range = [0, 1]
                model.input_size = [IN_CHANNELS[dataset], IMAGE_SIZE[dataset][0], IMAGE_SIZE[dataset][1]]
            else:
                # decide weight filename prefix (directory) and suffix (file name)
                train_data = kwargs['train_data']
                if train_data == 'cifar10.1':
                    # models trained on cifar10.1 (2,000 images)
                    prefix = '{}/subspace_attack/data/cifar10.1-models'.format(PY_ROOT)
                    suffixes = {'final': 'final.pth', 'best': 'model_best.pth'}
                elif train_data == 'full':
                    # models trained on the full training set
                    prefix = '{}/subspace_attack/data/cifar10-models'.format(PY_ROOT)
                    suffixes = {'final': 'checkpoint.pth.tar', 'best': 'model_best.pth.tar'}
                else:
                    raise NotImplementedError('Unknown train_simulate_grad_mode data {}'.format(train_data))
                if kwargs['epoch'] not in suffixes:
                    raise NotImplementedError('Unknown epoch {} for train data {}'.format(
                        kwargs['epoch'], train_data))
                suffix = suffixes[kwargs['epoch']]

                if arch == 'alexnet_bn':
                    model = target_models.cifar.alexnet_bn(num_classes=10)
                elif arch == 'vgg11_bn':
                    model = target_models.cifar.vgg11_bn(num_classes=10)
                elif arch == 'vgg13_bn':
                    model = target_models.cifar.vgg13_bn(num_classes=10)
                elif arch == 'vgg16_bn':
                    model = target_models.cifar.vgg16_bn(num_classes=10)
                elif arch == 'vgg19_bn':
                    model = target_models.cifar.vgg19_bn(num_classes=10)
                elif arch == 'wrn-28-10-drop':
                    model = target_models.cifar.wrn(depth=28, widen_factor=10, dropRate=0.3, num_classes=10)
                else:
                    raise NotImplementedError('Unknown arch {}'.format(arch))

                # load weight
                checkpoint_path = osp.join(prefix, arch, suffix)
                self.load_weight_from_pth_checkpoint(model, checkpoint_path)
                print("load model checkpoint from {}".format(checkpoint_path))

                # assign meta info
                model.mean = [0.4914, 0.4822, 0.4465]
                model.std = [0.2023, 0.1994, 0.2010]
                model.input_space = 'RGB'
                model.input_range = [0, 1]
                model.input_size = [3, 32, 32]
        elif dataset == "TinyImageNet":
            class Identity(nn.Module):
                def __init__(self):
                    super(Identity, self).__init__()

                def forward(self, x):
                    return x
            model = None
            # densenet*/inception* are always rebuilt below, so don't build a
            # pretrained copy from target_models only to discard it.
            if not arch.startswith(("densenet", "inception")) and arch in target_models.__dict__:
                model = target_models.__dict__[arch](pretrained=True)
            num_classes = CLASS_NUM[dataset]
            # resnet*/vgg* only patch a model built above; otherwise they fall
            # through to the "unknown arch" error below.
            if arch.startswith("resnet") and model is not None:
                num_ftrs = model.fc.in_features
                model.fc = nn.Linear(num_ftrs, num_classes)
            elif arch.startswith("densenet"):
                if arch == "densenet161":
                    model = densenet161(pretrained=True)
                elif arch == "densenet121":
                    model = densenet121(pretrained=True)
                elif arch == "densenet169":
                    model = densenet169(pretrained=True)
                elif arch == "densenet201":
                    model = densenet201(pretrained=True)
            elif arch == "resnext32_4":
                model = resnext101_32x4d(pretrained="imagenet")
            elif arch == "resnext64_4":
                model = resnext101_64x4d(pretrained="imagenet")
            elif arch.startswith("inception"):
                model = inception_v3(pretrained=True)
            elif arch.startswith("vgg") and model is not None:
                # 512 x 2 x 2 conv output for 64x64 inputs (64 / 2**5 = 2)
                model.avgpool = Identity()
                model.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
                model.classifier[-1] = nn.Linear(4096, num_classes)
            if model is None:
                raise NotImplementedError('Unknown arch {}'.format(arch))
            checkpoint_dir = "{}/train_pytorch_model/real_image_model".format(PY_ROOT)
            name_prefix = "{}@{}".format(dataset, arch)
            # The glob '*' alone would let e.g. arch "vgg16" also match "vgg16_bn"
            # checkpoints, so require the arch name to be followed by '@' or by the
            # extension. Sort so the pick among several matches is deterministic.
            model_paths = sorted(
                path for path in glob.glob("{}/{}*.pth.tar".format(checkpoint_dir, name_prefix))
                if osp.basename(path)[len(name_prefix)] in ("@", ".")
            )
            if not model_paths:
                raise FileNotFoundError("No checkpoint for {} found in {}".format(name_prefix, checkpoint_dir))
            model.load_state_dict(
                torch.load(model_paths[0], map_location=lambda storage, location: storage)["state_dict"])
        elif dataset == 'ImageNet':
            # Resolve the constructor with getattr rather than eval() so that
            # ``arch`` is never executed as Python source.
            model = getattr(target_models.imagenet, arch)(num_classes=1000, pretrained='imagenet')

            train_data = kwargs['train_data']
            if train_data == 'full':
                # pretrained='imagenet' above is expected to have loaded the weights
                pass
            elif train_data == 'imagenetv2-val':
                prefix = '{}/subspace_attack/data/imagenetv2-v1val45000-target_models'.format(PY_ROOT)
                if kwargs['epoch'] == 'final':
                    suffix = 'checkpoint.pth.tar'
                elif kwargs['epoch'] == 'best':
                    suffix = 'model_best.pth.tar'
                else:
                    raise NotImplementedError('Unknown epoch {} for train_simulate_grad_mode data {}'.format(
                        kwargs['epoch'], train_data))
                # load weight
                self.load_weight_from_pth_checkpoint(model, osp.join(prefix, arch, suffix))
            else:
                raise NotImplementedError('Unknown train_simulate_grad_mode data {}'.format(train_data))
        else:
            raise NotImplementedError('Unknown dataset {}'.format(dataset))

        return model
def main_train_worker(args):
    """Train ``args.arch`` on the TinyImageNet dataset class and checkpoint it.

    Builds the network (torchvision names with ``pretrained=True``; resnet*,
    squeezenet* and vgg* heads resized to ``CLASS_NUM[args.dataset]``), trains
    it on GPU with SGD for ``args.epochs`` epochs, and after every epoch
    overwrites ``{PY_ROOT}/train_pytorch_model/real_image_model/
    {dataset}@{arch}@epoch_{epochs}@lr_{lr}@batch_{batch_size}.pth.tar`` with
    the epoch number, arch, validation accuracy, model and optimizer state.

    :param args: parsed options; reads gpu, arch, dataset, epochs, lr,
        batch_size, momentum, weight_decay and workers here, plus whatever
        ``adjust_learning_rate`` / ``train`` / ``validate`` read.
    :raises NotImplementedError: if ``args.arch`` is not supported.
    """
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    print("=> creating model '{}'".format(args.arch))
    arch = args.arch
    network = None
    # densenet*/inception* are always rebuilt below from dedicated constructors,
    # so don't build (and download weights for) a torchvision copy to discard it.
    if not arch.startswith(("densenet", "inception")) and arch in models.__dict__:
        network = models.__dict__[arch](pretrained=True)
    num_classes = CLASS_NUM[args.dataset]
    # Branches that patch an existing network require the lookup above to have
    # succeeded; otherwise they fall through to the "unknown arch" error.
    if arch.startswith("resnet") and network is not None:
        num_ftrs = network.fc.in_features
        network.fc = nn.Linear(num_ftrs, num_classes)
    elif arch.startswith("densenet"):
        if arch == "densenet161":
            network = densenet161(pretrained=True)
        elif arch == "densenet121":
            network = densenet121(pretrained=True)
        elif arch == "densenet169":
            network = densenet169(pretrained=True)
        elif arch == "densenet201":
            network = densenet201(pretrained=True)
    elif arch == "resnext32_4":
        # NOTE(review): trained from scratch (pretrained=None), unlike most archs
        # here; pretrained="imagenet" would load weights — confirm intended.
        network = resnext101_32x4d(pretrained=None)
    elif arch == "resnext64_4":
        network = resnext101_64x4d(pretrained=None)
    elif arch.startswith("squeezenet") and network is not None:
        network.classifier[-1] = nn.AdaptiveAvgPool2d(1)
        network.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
    elif arch.startswith("inception"):
        network = inception_v3(pretrained=True)
    elif arch.startswith("vgg") and network is not None:
        # 512 x 2 x 2 conv output for 64x64 inputs (64 / 2**5 = 2)
        network.avgpool = Identity()
        network.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
        network.classifier[-1] = nn.Linear(4096, num_classes)
    if network is None:
        raise NotImplementedError("Unknown arch {}".format(arch))

    # densenet and inception must use our own modified copies, because their
    # forward() uses F.avg_pool2d.
    model_path = '{}/train_pytorch_model/real_image_model/{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.lr,
        args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("after train_simulate_grad_mode, model will be saved to {}".format(
        model_path))
    # NOTE(review): the same use_flip=True preprocessor is also used for the
    # validation set below; if it flips randomly, validation is augmented too.
    preprocessor = get_preprocessor(IMAGE_SIZE[args.dataset], use_flip=True)
    network.cuda()
    image_classifier_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(network.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    train_dataset = TinyImageNet(IMAGE_DATA_ROOT[args.dataset],
                                 preprocessor,
                                 train=True)
    test_dataset = TinyImageNet(IMAGE_DATA_ROOT[args.dataset],
                                preprocessor,
                                train=False)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    for epoch in range(0, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch,
              args)
        # evaluate accuracy on the validation set
        val_acc = validate(val_loader, network, image_classifier_loss, args)
        # overwrite the checkpoint with this epoch's state (the latest epoch is
        # kept, not the best one)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                "val_acc": val_acc,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            filename=model_path)