Example #1
def get_model(name, input_size=None, output=None):
    name = name.lower()
    if name == 'lenet-300-100':
        model = LeNet_300_100(input_size, output)
    elif name == 'lenet-5':
        model = LeNet(input_size, output)
    elif 'vgg' in name:
        # if 'bn' in name:
        if name == 'vgg11':
            model = vgg11(pretrained=False, num_classes=output)
        elif name == 'vgg16':
            model = vgg16(pretrained=False, num_classes=output)
        else:
            raise ValueError(f'unsupported VGG variant: {name}')

        # remove biases that are identically zero (sum == 0) so the module carries no bias parameter
        for n, m in model.named_modules():
            if hasattr(m, 'bias') and not isinstance(m, _BatchNorm):
                if m.bias is not None:
                    if m.bias.sum() == 0:
                        m.bias = None

    elif 'alexnet' in name:
        model = AlexNet(num_classes=output)

        for n, m in model.named_modules():
            if hasattr(m, 'bias') and not isinstance(m, _BatchNorm):
                if m.bias is not None:
                    if m.bias.sum() == 0:
                        m.bias = None
    elif 'resnet' in name:
        if name == 'resnet20':
            model = resnet20(num_classes=output)
        elif name == 'resnet32':
            model = resnet32(num_classes=output)
        else:
            raise ValueError(f'unsupported ResNet variant: {name}')

        for n, m in model.named_modules():
            if hasattr(m, 'bias') and not isinstance(m, _BatchNorm):
                if m.bias is not None:
                    if m.bias.sum() == 0:
                        m.bias = None

    else:
        raise ValueError(f'unknown model name: {name}')

    return model
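
A quick usage sketch for this factory (the model name and class count below are illustrative; input_size is only consumed by the LeNet variants):

# Hypothetical usage: build a ResNet-20 for a 10-class dataset such as CIFAR-10.
model = get_model('resnet20', output=10)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)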
Example #2
noise_multiplier0 = noise_multiplier1 = sigma
print('noise scale for gradient embedding: ', noise_multiplier0,
      'noise scale for residual gradient: ', noise_multiplier1,
      '\n rgp enabled: ', args.rgp,
      'privacy guarantee: ', eps)

print('\n==> Creating GEP class instance')
gep = GEP(args.num_bases, args.batchsize, args.clip0, args.clip1, args.power_iter).cuda()
## attach auxiliary data to GEP instance
gep.public_inputs = public_inputs
gep.public_targets = public_targets

print('\n==> Creating ResNet20 model instance')
if args.resume:
    try:
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint_file = './checkpoint/' + args.sess  + '.ckpt'
        checkpoint = torch.load(checkpoint_file)
        net = resnet20()
        net.cuda()
        restore_param(net.state_dict(), checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
        approx_error = checkpoint['approx_error']
    except Exception as e:
        print('resume from checkpoint failed:', e)
else:
    net = resnet20() 
    net.cuda()

net = extend(net)

net.gep = gep
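
For context, extend(net) is BackPACK's way of preparing a module for per-sample quantities. The following is a minimal sketch, not part of the original script, of how per-sample gradients are typically read off afterwards; the dummy tensors and the CrossEntropyLoss are assumptions standing in for the real data pipeline and loss.

# Sketch only: assumes BackPACK is installed and a GPU is available.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

inputs = torch.randn(8, 3, 32, 32).cuda()       # placeholder batch
targets = torch.randint(10, (8,)).cuda()        # placeholder labels
loss_fn = extend(torch.nn.CrossEntropyLoss())   # the loss must also be extended

loss = loss_fn(net(inputs), targets)
with backpack(BatchGrad()):
    loss.backward()

for p in net.parameters():
    per_sample = p.grad_batch                   # shape: (batch_size, *p.shape)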
Example #3
if args.dataset == 'imagenet':
    if args.arch == 'resnet101':
        net = resnet_imagenet.resnet101()
    elif args.arch == 'resnet50':
        net = resnet_imagenet.resnet50()
    elif args.arch == 'resnet34':
        net = resnet_imagenet.resnet34()
    elif args.arch == 'resnet18':
        net = resnet_imagenet.resnet18()
else:
    if args.arch == 'resnet110':
        net = models.resnet110(num_classes=10)
    elif args.arch == 'resnet56':
        net = models.resnet56(num_classes=10)
    elif args.arch == 'resnet32':
        net = models.resnet32(num_classes=10)
    elif args.arch == 'resnet20':
        net = models.resnet20(num_classes=10)

if args.dataset == 'imagenet':

    if args.arch == 'resnet101':
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet101-5d3b4d8f.pth'))
    elif args.arch == 'resnet50':
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet50-19c8e357.pth'))
    elif args.arch == 'resnet34':
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet34-333f7ec4.pth'))
    elif args.arch == 'resnet18':
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet18-5c106cde.pth'))
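
    # Sketch (not shown in the excerpt): the loaded checkpoint would typically be
    # copied into the network next; strict=False is an assumption that tolerates a
    # classifier head whose shape differs from the torchvision checkpoint.
    net.load_state_dict(state_dict, strict=False)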
Example #4

if __name__ == '__main__':
    import torch
    from models import resnet20

    # import torchvision.models as models
    from collections import defaultdict
    from utils import calculate_trainable_parameters
    from methods.supermask.models_utils import (remove_wrappers_from_model, get_masks_from_gradients,
                                                add_wrappers_to_model, extract_inner_model)

    resnet18 = resnet20(num_classes=10)  # note: despite the variable name, this is a ResNet-20

    add_wrappers_to_model(resnet18,
                          ensemble=2,
                          masks_params={
                              'name': 'weights',
                              'initialization': {
                                  'name': 'constant',
                                  'c': 1
                              }
                          },
                          batch_ensemble=True)

    x = torch.rand((12, 3, 32, 32))
    y = torch.randint(10, size=(12,))  # random labels in [0, 10) for the 10-class model
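
    # Sketch (assumption, not part of the original test): sanity-check the dummy
    # CIFAR-10 batch against an unwrapped resnet20. The wrapped `resnet18` above
    # may replicate the batch when batch_ensemble=True, so a fresh model is used.
    plain_model = resnet20(num_classes=10)
    with torch.no_grad():
        logits = plain_model(x)
    assert logits.shape == (12, 10), logits.shape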
Example #5
def get_models(model_name):
    if model_name == 'resnet20':
        return models.resnet20()
Example #6
def get_models(model_name):
    if model_name == 'resnet20':
        return models.resnet20()
    elif model_name == 'modelA':
        return models.ModelA()
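
A minimal usage sketch for this factory (illustrative, not from the original). Note that any unrecognized name falls through and the function returns None, so callers may want to guard against that:

net = get_models('resnet20')
assert net is not None, 'unknown model name'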
Example #7
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128,
                                         shuffle=False)

classes = ("airplane", "automobile", "bird", "cat", "deer", "dog", "frog",
           "horse", "ship", "truck")

# Model
print('==> Building model..')
model = resnet20()
# If you want to resume training (instead of training from scratch),
# you can continue from a previously saved checkpoint
# by uncommenting the following two lines.
# Do not forget to adjust start_epoch and end_epoch accordingly.
#restore_model_path = 'pretrained/ckpt_4_acc_63.320000.pth'
#model.load_state_dict(torch.load(restore_model_path)['net'])

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=5e-4)


def train(epoch):
    model.train()
    train_loss = 0
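    # Sketch (not in the original excerpt): a conventional completion of this
    # training loop, assuming the model, trainloader, criterion and optimizer
    # defined above; device placement is omitted, as in the excerpt.
    correct = 0
    total = 0
    for inputs, targets in trainloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print('epoch %d | train loss: %.3f | train acc: %.2f%%'
          % (epoch, train_loss / len(trainloader), 100. * correct / total))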