def construct_model(self):
        """Build the private-passport model and derive its keys if needed.

        Sets ``self.passport_kwargs`` from ``construct_passport_kwargs(self)``
        and ``self.model`` to the private-passport network (AlexNet when
        ``self.arch == 'alexnet'``, ResNet-18 otherwise) moved to
        ``self.device``. Unless ``self.key_type`` is ``'random'``, a plain
        network of the same architecture is loaded from
        ``self.pretrained_path`` and passed to ``self.setup_keys``.
        """
        def _derive_keys_from_pretrained():
            # Random keys need no reference network.
            if self.key_type == 'random':
                return

            if self.arch != 'alexnet':
                reference = ResNet18(num_classes=self.num_classes,
                                     norm_type=self.norm_type)
            else:
                reference = AlexNetNormal(self.in_channels,
                                          self.num_classes,
                                          self.norm_type)

            state = torch.load(self.pretrained_path)
            reference.load_state_dict(state)
            self.setup_keys(reference.to(self.device))

        self.passport_kwargs = kwargs = construct_passport_kwargs(self)

        if self.arch != 'alexnet':
            network = ResNet18Private(num_classes=self.num_classes,
                                      passport_kwargs=kwargs)
        else:
            network = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                             kwargs)
        self.model = network.to(self.device)

        _derive_keys_from_pretrained()
Exemplo n.º 2
0
    def construct_model(self):
        """Build the private-passport model, derive its keys, and print it.

        Sets ``self.passport_kwargs`` from ``construct_passport_kwargs(self)``
        and ``self.model`` to the private-passport network (AlexNet when
        ``self.arch == 'alexnet'``, ResNet-18 otherwise) on ``self.device``.
        Unless ``self.key_type`` is ``'random'``, a plain reference network
        is built and handed to ``self.setup_keys``. Its weights come from
        ``self.pretrained_path`` when that is set. Otherwise the constructor
        is asked for its own pretrained weights (``pretrained=True``).
        """
        print('Construct Model')

        def _derive_keys_from_pretrained():
            if self.key_type == 'random':
                return

            # No self-trained checkpoint -> use the constructor's pretrained
            # weights, which dictate the norm type ('none' / 'bn').
            use_torch_weights = self.pretrained_path is None
            if self.arch == 'alexnet':
                reference = AlexNetNormal(
                    self.in_channels,
                    self.num_classes,
                    norm_type='none' if use_torch_weights else self.norm_type,
                    pretrained=use_torch_weights)
            else:
                reference = ResNet18(
                    num_classes=self.num_classes,
                    norm_type='bn' if use_torch_weights else self.norm_type,
                    pretrained=use_torch_weights)

            if use_torch_weights:
                print('Loading pretrained from torch-pretrained model')
            else:
                print('Loading pretrained from self-trained model')
                reference.load_state_dict(torch.load(self.pretrained_path))

            self.setup_keys(reference.to(self.device))

        self.passport_kwargs = kwargs = construct_passport_kwargs(self)

        print('Loading arch: ' + self.arch)
        if self.arch != 'alexnet':
            network = ResNet18Private(num_classes=self.num_classes,
                                      passport_kwargs=kwargs)
        else:
            network = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                             kwargs)
        self.model = network.to(self.device)

        _derive_keys_from_pretrained()

        pprint(self.model)
Exemplo n.º 3
0
def main(arch='alexnet', dataset='cifar10', scheme=1, loadpath='',
         passport_config='passport_configs/alexnet_passport.json', tagnum=1):
    """Evaluate a passport model's signature and accuracy under pruning.

    Loads the checkpoint at ``loadpath`` into a passport model. AlexNet is
    used when ``arch == 'alexnet'`` and ResNet-18 otherwise. ``scheme == 1``
    selects the public-passport variant and any other value the private one.
    The model is pruned at 0%, 10%, ..., 100% with ``pruning_resnet``. For
    each level, the ``detect_signature`` result and the validation metrics
    from ``test`` are recorded.

    The history is written to
    ``logs/pruning_attack/<part1>/<part2>/<arch>-<scheme>-history-<dataset>-<tagnum>.csv``,
    where ``part1``/``part2`` are the 2nd and 3rd ``/``-separated components
    of ``loadpath``. ``loadpath`` must therefore contain at least three
    components, otherwise ``IndexError`` is raised at the end of the run.
    """
    batch_size = 64
    nclass = {
        'cifar100': 100,
        'imagenet1000': 1000
    }.get(dataset, 10)
    inchan = 3
    device = torch.device('cuda')

    # Only the validation split is used; the train loader is discarded.
    _, valloader = prepare_dataset({'transfer_learning': False,
                                    'dataset': dataset,
                                    'tl_dataset': '',
                                    'batch_size': batch_size})

    # Close the config file deterministically instead of leaking the handle.
    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs, _ = construct_passport_kwargs_from_dict({'passport_config': passport_config_dict,
                                                             'norm_type': 'bn',
                                                             'sl_ratio': 0.1,
                                                             'key_type': 'shuffle'},
                                                            True)

    if arch == 'alexnet':
        if scheme == 1:
            model = AlexNetPassport(inchan, nclass, passport_kwargs)
        else:
            model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)
    else:
        if scheme == 1:
            model = ResNet18Passport(num_classes=nclass, passport_kwargs=passport_kwargs)
        else:
            model = ResNet18Private(num_classes=nclass, passport_kwargs=passport_kwargs)

    sd = torch.load(loadpath)
    criterion = nn.CrossEntropyLoss()
    prunedf = []
    for perc in range(0, 101, 10):
        # Reload the original weights at every level so pruning is applied to
        # the checkpoint, not on top of the previous (presumably in-place)
        # pruning.
        model.load_state_dict(sd)
        pruning_resnet(model, perc)
        model = model.to(device)

        res = detect_signature(model)

        res['perc'] = perc
        res['tag'] = arch
        res['dataset'] = dataset
        res.update(test(model, criterion, valloader, device))
        prunedf.append(res)

    path_parts = loadpath.split('/')
    dirname = f'logs/pruning_attack/{path_parts[1]}/{path_parts[2]}'
    os.makedirs(dirname, exist_ok=True)

    histdf = pd.DataFrame(prunedf)
    histdf.to_csv(f'{dirname}/{arch}-{scheme}-history-{dataset}-{tagnum}.csv')
def run_maximize(rep=1, flipperc=0, arch='alexnet', dataset='cifar10', scheme=1,
                 loadpath='', passport_config=''):
    """Passport-forgery attack: train fake passports against a frozen model.

    Loads ``<loadpath>1/models/best.pth`` (non-strictly) into an AlexNet
    passport model. ``scheme == 1`` uses ``AlexNetPassport`` and any other
    value ``AlexNetPassportPrivate``. ``arch`` only appears in output file
    names.

    Every weight is frozen. Each passport block's key/skey is replaced by a
    trainable copy with small Gaussian noise added. When ``flipperc`` is
    non-zero, that fraction of every block's signature bits is reversed.
    The fake passports are then optimised with ``train_maximize`` for 100
    epochs, logging validation metrics and the cosine similarity to the
    real passports.

    Outputs are written under
    ``/data-x/g12/zhangjie/DeepIPR/baseline/passport_attack/<task_name>/<rep>``:
      * ``log.txt``
      * ``best.txt``
      * a copy of ``attack_3.py``
      * per-epoch checkpoints
      * a history CSV

    ``task_name`` is the second-to-last ``/`` component of ``loadpath``.
    Because ``loadpath`` is concatenated directly with
    ``'1/models/best.pth'``, it is presumably passed with a trailing ``/``.
    Confirm this with callers.
    """
    epochs = 100
    batch_size = 64
    nclass = 100 if dataset == 'cifar100' else 10
    inchan = 3
    lr = 0.01
    device = torch.device('cuda')

    trainloader, valloader = prepare_dataset({'transfer_learning': False,
                                              'dataset': dataset,
                                              'tl_dataset': '',
                                              'batch_size': batch_size})

    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs = construct_passport_kwargs_from_dict({'passport_config': passport_config_dict,
                                                           'norm_type': 'gn',
                                                           'sl_ratio': 0.1,
                                                           'key_type': 'random'})

    # Every scheme other than 1 uses the private-passport network.
    if scheme == 1:
        model = AlexNetPassport(inchan, nclass, passport_kwargs)
    else:
        model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)

    task_name = loadpath.split('/')[-2]
    loadpath_all = loadpath + '1/models/best.pth'
    sd = torch.load(loadpath_all)
    model.load_state_dict(sd, strict=False)
    # rep defaults to an int; str() avoids a TypeError when concatenating.
    logdir = '/data-x/g12/zhangjie/DeepIPR/baseline/passport_attack/' + task_name + '/' + str(rep)
    os.makedirs(logdir, exist_ok=True)
    best_file = os.path.join(logdir, 'best.txt')
    log_file = os.path.join(logdir, 'log.txt')

    # The log stays open for the whole run and is closed even on error.
    with open(log_file, 'a') as lf:
        # Snapshot the attack script next to its results.
        shutil.copy('attack_3.py', str(logdir) + "/attack_3.py")

        # Freeze the network; only the fake passports registered below train.
        for param in model.parameters():
            param.requires_grad_(False)

        passblocks = []
        origpassport = []
        fakepassport = []

        for m in model.modules():
            if isinstance(m, (PassportBlock, PassportPrivateBlock)):
                passblocks.append(m)

                if scheme == 1:
                    keyname = 'key'
                    skeyname = 'skey'
                else:
                    keyname = 'key_private'
                    skeyname = 'skey_private'

                key, skey = m.__getattr__(keyname).data.clone(), m.__getattr__(skeyname).data.clone()
                origpassport.append(key.to(device))
                origpassport.append(skey.to(device))

                # Remove the real passport attributes ...
                m.__delattr__(keyname)
                m.__delattr__(skeyname)

                # ... and register trainable fakes: the real passport plus
                # Gaussian noise scaled by 0.001.
                m.register_parameter(keyname, nn.Parameter(key.clone() + torch.randn(*key.size()) * 0.001))
                m.register_parameter(skeyname, nn.Parameter(skey.clone() + torch.randn(*skey.size()) * 0.001))

                fakepassport.append(m.__getattr__(keyname))
                fakepassport.append(m.__getattr__(skeyname))

        if flipperc != 0:
            print(f'Reverse {flipperc * 100:.2f}% of binary signature')
            for m in passblocks:
                mflip = flipperc
                if scheme == 1:
                    oldb = m.sign_loss.b
                else:
                    oldb = m.sign_loss_private.b
                newb = oldb.clone()

                npidx = np.arange(len(oldb))  # candidate bit positions
                randsize = int(oldb.view(-1).size(0) * mflip)
                randomidx = np.random.choice(npidx, randsize, replace=False)  # random subset

                newb[randomidx] = oldb[randomidx] * -1  # reverse the selected bits

                if scheme == 1:
                    m.sign_loss.set_b(newb)
                else:
                    m.sign_loss_private.set_b(newb)

        model.to(device)

        optimizer = torch.optim.SGD(fakepassport,
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=0.0005)

        scheduler = None
        criterion = nn.CrossEntropyLoss()

        history = []

        def run_cs():
            """Return the cosine similarity of each (real, fake) passport pair."""
            cs = []

            for d1, d2 in zip(origpassport, fakepassport):
                d1 = d1.view(d1.size(0), -1)
                d2 = d2.view(d2.size(0), -1)

                cs.append(F.cosine_similarity(d1, d2).item())

            return cs

        print('Before training')
        print('Before training', file=lf)

        res = {}
        valres = test(model, criterion, valloader, device, scheme)
        for key in valres:
            res[f'valid_{key}'] = valres[key]

        with torch.no_grad():
            cs = run_cs()

            mseloss = 0
            for l, r in zip(origpassport, fakepassport):
                mse = F.mse_loss(l, r)
                mseloss += mse.item()
            mseloss /= len(origpassport)

        print(f'MSE of Real and Maximize passport: {mseloss:.4f}')
        print(f'MSE of Real and Maximize passport: {mseloss:.4f}', file=lf)
        print(f'Cosine Similarity of Real and Maximize passport: {sum(cs) / len(origpassport):.4f}')
        print(f'Cosine Similarity of Real and Maximize passport: {sum(cs) / len(origpassport):.4f}', file=lf)
        print()

        res['epoch'] = 0
        res['cosine_similarity'] = cs
        res['flipperc'] = flipperc
        res['train_mseloss'] = mseloss

        history.append(res)

        torch.save({'origpassport': origpassport,
                    'fakepassport': fakepassport,
                    'state_dict': model.state_dict()},
                   f'{logdir}/{arch}-{scheme}-last-{dataset}-{rep}-{flipperc:.1f}-e0.pth')

        best_acc = 0
        best_ep = 0

        for ep in range(1, epochs + 1):
            if scheduler is not None:
                scheduler.step()

            print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
            print(f'Epoch {ep:3d}:')
            print(f'Epoch {ep:3d}:', file=lf)
            print('Training')
            trainres = train_maximize(origpassport, fakepassport, model, optimizer, criterion, trainloader, device, scheme)

            print('Testing')
            print('Testing', file=lf)
            valres = test(model, criterion, valloader, device, scheme)

            print(valres, file=lf)
            print('\n', file=lf)

            if best_acc < valres['acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = valres['acc']
                best_ep = ep

            # Append the running best each epoch; the handle is closed (and
            # thus flushed) immediately instead of leaking one per epoch.
            with open(best_file, 'a') as f:
                f.write(str(best_acc) + '\n')
                f.write("best epoch: %s" % str(best_ep) + '\n')

            res = {}

            for key in trainres:
                res[f'train_{key}'] = trainres[key]
            for key in valres:
                res[f'valid_{key}'] = valres[key]
            res['epoch'] = ep
            res['flipperc'] = flipperc

            with torch.no_grad():
                cs = run_cs()
                res['cosine_similarity'] = cs

            print(f'Cosine Similarity of Real and Maximize passport: '
                  f'{sum(cs) / len(origpassport):.4f}')
            print()

            print(f'Cosine Similarity of Real and Maximize passport: '
                  f'{sum(cs) / len(origpassport):.4f}' + '\n', file=lf)
            lf.flush()

            history.append(res)

            torch.save({'origpassport': origpassport,
                        'fakepassport': fakepassport,
                        'state_dict': model.state_dict()},
                       f'{logdir}/{arch}-{scheme}-last-{dataset}-{rep}-{flipperc:.1f}-e{ep}.pth')

        # Build the history table once, after training.
        histdf = pd.DataFrame(history)
        histdf.to_csv(f'{logdir}/{arch}-{scheme}-history-{dataset}-{rep}-{flipperc:.1f}.csv')
Exemplo n.º 5
0
def run_attack_1(attack_rep=50, arch='alexnet', dataset='cifar10', scheme=1,
                 loadpath='', passport_config='passport_configs/alexnet_passport.json'):
    """Passport-reset attack: repeatedly re-derive passports and re-test.

    Loads the AlexNet passport checkpoint at ``loadpath`` non-strictly.
    ``scheme == 1`` selects ``AlexNetPassport`` and any other value
    ``AlexNetPassportPrivate``. The BN affine parameters of ``features[0]``
    and ``features[2]`` are copied from the checkpoint's ``scale``/``bias``.

    The model is evaluated once. Then, ``attack_rep`` times, a batch is drawn
    from the validation loader via ``get_passport``, new keys are set with
    ``set_intermediate_keys`` using the baseline AlexNet from
    ``logs/alexnet_<dataset>/1/models/best.pth``, and the model is
    evaluated again.

    The history, with one row per evaluation including the attack index, is
    written to
    ``logs/passport_attack_1/<arch>-<scheme>-history-<dataset>-<attack_rep>.csv``.
    """
    batch_size = 64
    nclass = 100 if dataset == 'cifar100' else 10
    inchan = 3
    device = torch.device('cuda')

    baselinepath = f'logs/alexnet_{dataset}/1/models/best.pth'
    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs = construct_passport_kwargs_from_dict({'passport_config': passport_config_dict,
                                                           'norm_type': 'bn',
                                                           'sl_ratio': 0.1,
                                                           'key_type': 'shuffle'})

    # Every scheme other than 1 uses the private-passport network.
    if scheme == 1:
        model = AlexNetPassport(inchan, nclass, passport_kwargs)
    else:
        model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)

    sd = torch.load(loadpath)
    model.load_state_dict(sd, strict=False)

    # Seed the BN affine parameters of features 0 and 2 from the checkpoint's
    # passport scale/bias entries.
    for fidx in [0, 2]:
        model.features[fidx].bn.weight.data.copy_(sd[f'features.{fidx}.scale'])
        model.features[fidx].bn.bias.data.copy_(sd[f'features.{fidx}.bias'])

    # The pretrained model and the passport batches live on `device`; the
    # attacked model must be there too (the later revision of this attack
    # does the same).
    model = model.to(device)

    # Only the validation split is used; the train loader is discarded.
    _, valloader = prepare_dataset({'transfer_learning': False,
                                    'dataset': dataset,
                                    'tl_dataset': '',
                                    'batch_size': batch_size})
    passport_data = valloader

    pretrained_model = AlexNetNormal(inchan, nclass)
    pretrained_model.load_state_dict(torch.load(baselinepath))
    pretrained_model.to(device)

    criterion = nn.CrossEntropyLoss()

    def reset_passport():
        print('Reset passport')
        x, y = get_passport(passport_data, device)
        set_intermediate_keys(model, pretrained_model, x, y)

    def run_test(attack_idx=0):
        """Evaluate on the validation set and tag the row with `attack_idx`."""
        res = {}
        valres = test(model, criterion, valloader, device, 1 if scheme != 1 else 0)
        for key in valres:
            res[f'valid_{key}'] = valres[key]
        res['attack_rep'] = attack_idx
        return res

    os.makedirs('logs/passport_attack_1', exist_ok=True)

    history = []

    print('Before training')
    res = run_test()
    history.append(res)

    for r in range(attack_rep):
        print(f'Attack count: {r}')
        reset_passport()
        res = run_test(r)
        history.append(res)

    histdf = pd.DataFrame(history)
    histdf.to_csv(f'logs/passport_attack_1/{arch}-{scheme}-history-{dataset}-{attack_rep}.csv')
Exemplo n.º 6
0
def run_maximize(rep=1,
                 flipperc=0,
                 arch='alexnet',
                 dataset='cifar10',
                 scheme=1,
                 loadpath='',
                 passport_config='',
                 tagnum=1):
    """Passport-forgery attack: train fake passports against a frozen model.

    Loads the checkpoint at ``loadpath`` (strictly) into a passport model.
    AlexNet is used when ``arch == 'alexnet'`` and ResNet-18 otherwise.
    ``scheme == 1`` selects the public-passport variant and any other value
    the private one.

    Every weight is frozen. Each passport block's key/skey is replaced by a
    trainable copy with small Gaussian noise added. When ``flipperc`` is
    non-zero, that fraction of every block's signature bits is reversed.
    The fake passports are then optimised with ``train_maximize`` (30 epochs
    for ``imagenet1000``, otherwise 100) while validation metrics and the
    cosine similarity to the real passports are recorded.

    Outputs go to ``logs/passport_attack_3/<part1>/<part2>/``, where
    ``part1``/``part2`` are the 2nd and 3rd ``/`` components of
    ``loadpath``. They are:
      * an epoch-0 checkpoint
      * a "last" checkpoint, overwritten each epoch
      * a history CSV, rewritten each epoch
    """
    epochs = {'imagenet1000': 30}.get(dataset, 100)
    batch_size = 64
    nclass = {'cifar100': 100, 'imagenet1000': 1000}.get(dataset, 10)
    inchan = 3
    lr = 0.01
    device = torch.device('cuda')

    trainloader, valloader = prepare_dataset({
        'transfer_learning': False,
        'dataset': dataset,
        'tl_dataset': '',
        'batch_size': batch_size
    })

    # Close the config file deterministically instead of leaking the handle.
    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs = construct_passport_kwargs_from_dict({
        'passport_config': passport_config_dict,
        'norm_type': 'bn',
        'sl_ratio': 0.1,
        'key_type': 'shuffle'
    })

    if arch == 'alexnet':
        if scheme == 1:
            model = AlexNetPassport(inchan, nclass, passport_kwargs)
        else:
            model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)
    else:
        if scheme == 1:
            model = ResNet18Passport(num_classes=nclass,
                                     passport_kwargs=passport_kwargs)
        else:
            model = ResNet18Private(num_classes=nclass,
                                    passport_kwargs=passport_kwargs)

    sd = torch.load(loadpath)
    model.load_state_dict(sd)

    # Freeze the network; only the fake passports registered below train.
    for param in model.parameters():
        param.requires_grad_(False)

    passblocks = []
    origpassport = []
    fakepassport = []

    for m in model.modules():
        if isinstance(m, (PassportBlock, PassportPrivateBlock)):
            passblocks.append(m)

            if scheme == 1:
                keyname = 'key'
                skeyname = 'skey'
            else:
                keyname = 'key_private'
                skeyname = 'skey_private'

            key, skey = m.__getattr__(keyname).data.clone(), m.__getattr__(
                skeyname).data.clone()
            origpassport.append(key.to(device))
            origpassport.append(skey.to(device))

            m.__delattr__(keyname)
            m.__delattr__(skeyname)

            # Re-initialise key and skey as trainable copies of the real
            # passport plus Gaussian noise scaled by 0.001.
            m.register_parameter(
                keyname,
                nn.Parameter(key.clone() + torch.randn(*key.size()) * 0.001))
            m.register_parameter(
                skeyname,
                nn.Parameter(skey.clone() + torch.randn(*skey.size()) * 0.001))
            fakepassport.append(m.__getattr__(keyname))
            fakepassport.append(m.__getattr__(skeyname))

    if flipperc != 0:
        print(f'Reverse {flipperc * 100:.2f}% of binary signature')
        for m in passblocks:
            mflip = flipperc
            if scheme == 1:
                oldb = m.sign_loss.b
            else:
                oldb = m.sign_loss_private.b
            newb = oldb.clone()

            npidx = np.arange(len(oldb))
            randsize = int(oldb.view(-1).size(0) * mflip)
            randomidx = np.random.choice(npidx, randsize, replace=False)

            newb[randomidx] = oldb[randomidx] * -1  # reverse bit
            if scheme == 1:
                m.sign_loss.set_b(newb)
            else:
                m.sign_loss_private.set_b(newb)

    model.to(device)

    optimizer = torch.optim.SGD(fakepassport,
                                lr=lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    scheduler = None
    criterion = nn.CrossEntropyLoss()

    history = []

    dirname = f'logs/passport_attack_3/{loadpath.split("/")[1]}/{loadpath.split("/")[2]}'
    os.makedirs(dirname, exist_ok=True)

    def run_cs():
        """Return the cosine similarity of each (real, fake) passport pair."""
        cs = []

        for d1, d2 in zip(origpassport, fakepassport):
            d1 = d1.view(d1.size(0), -1)
            d2 = d2.view(d2.size(0), -1)

            cs.append(F.cosine_similarity(d1, d2).item())

        return cs

    print('Before training')
    res = {}
    valres = test(model, criterion, valloader, device, scheme)
    for key in valres:
        res[f'valid_{key}'] = valres[key]
    with torch.no_grad():
        cs = run_cs()

        mseloss = 0
        for l, r in zip(origpassport, fakepassport):
            mse = F.mse_loss(l, r)
            mseloss += mse.item()
        mseloss /= len(origpassport)

    print(f'MSE of Real and Maximize passport: {mseloss:.4f}')
    print(
        f'Cosine Similarity of Real and Maximize passport: {sum(cs) / len(origpassport):.4f}'
    )
    print()

    res['epoch'] = 0
    res['cosine_similarity'] = cs
    res['flipperc'] = flipperc
    res['train_mseloss'] = mseloss

    history.append(res)

    torch.save(
        {
            'origpassport': origpassport,
            'fakepassport': fakepassport,
            'state_dict': model.state_dict()
        },
        f'{dirname}/{arch}-{scheme}-last-{dataset}-{rep}-{tagnum}-{flipperc:.1f}-e0.pth'
    )

    for ep in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()

        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        print(f'Epoch {ep:3d}:')
        print('Training')
        trainres = train_maximize(origpassport, fakepassport, model, optimizer,
                                  criterion, trainloader, device, scheme)

        print('Testing')
        valres = test(model, criterion, valloader, device, scheme)

        res = {}

        for key in trainres:
            res[f'train_{key}'] = trainres[key]
        for key in valres:
            res[f'valid_{key}'] = valres[key]
        res['epoch'] = ep
        res['flipperc'] = flipperc

        with torch.no_grad():
            cs = run_cs()
            res['cosine_similarity'] = cs

        print(f'Cosine Similarity of Real and Maximize passport: '
              f'{sum(cs) / len(origpassport):.4f}')
        print()

        history.append(res)

        # Overwritten every epoch so the latest state survives interruption.
        torch.save(
            {
                'origpassport': origpassport,
                'fakepassport': fakepassport,
                'state_dict': model.state_dict()
            },
            f'{dirname}/{arch}-{scheme}-{dataset}-{rep}-{tagnum}-{flipperc:.1f}-last.pth'
        )

        # History is rewritten each epoch for the same reason.
        histdf = pd.DataFrame(history)
        histdf.to_csv(
            f'{dirname}/{arch}-{scheme}-history-{dataset}-{rep}-{tagnum}-{flipperc:.1f}.csv'
        )
Exemplo n.º 7
0
def run_attack_1(attack_rep=50,
                 arch='alexnet',
                 dataset='cifar10',
                 scheme=1,
                 loadpath='',
                 passport_config='passport_configs/alexnet_passport.json',
                 tagnum=1):
    """Passport-reset attack: repeatedly regenerate passports and re-test.

    Loads the checkpoint at ``loadpath`` (strictly) into a passport model.
    AlexNet is used when ``arch == 'alexnet'`` and ResNet-18 otherwise.
    ``scheme == 1`` selects the public-passport variant and any other value
    the private one.

    The model is evaluated once. Then, ``attack_rep`` times, a batch is drawn
    via ``get_passport`` from the validation loader, new keys are set with
    ``passport_generator.set_key`` using ``load_pretrained(arch, nclass)``,
    and the model is evaluated again.

    The history, with one row per evaluation tagged with the attack index,
    is written to
    ``logs/passport_attack_1/<part1>/<part2>/<arch>-<scheme>-history-<dataset>-<attack_rep>-<tagnum>.csv``.
    ``part1``/``part2`` are the 2nd and 3rd ``/`` components of ``loadpath``.
    """
    batch_size = 64
    nclass = {'cifar100': 100, 'imagenet1000': 1000}.get(dataset, 10)
    inchan = 3
    device = torch.device('cuda')

    # Close the config file deterministically instead of leaking the handle.
    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs, _ = construct_passport_kwargs_from_dict(
        {
            'passport_config': passport_config_dict,
            'norm_type': 'bn',
            'sl_ratio': 0.1,
            'key_type': 'shuffle'
        }, True)

    if arch == 'alexnet':
        if scheme == 1:
            model = AlexNetPassport(inchan, nclass, passport_kwargs)
        else:
            model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)
    else:
        if scheme == 1:
            model = ResNet18Passport(num_classes=nclass,
                                     passport_kwargs=passport_kwargs)
        else:
            model = ResNet18Private(num_classes=nclass,
                                    passport_kwargs=passport_kwargs)

    sd = torch.load(loadpath)
    model.load_state_dict(sd)
    model = model.to(device)

    # Only the validation split is used (presumably shuffled via
    # 'shuffle_val' so each reset draws a different batch); the train loader
    # is discarded.
    _, valloader = prepare_dataset({
        'transfer_learning': False,
        'dataset': dataset,
        'tl_dataset': '',
        'batch_size': batch_size,
        'shuffle_val': True
    })
    passport_data = valloader

    pretrained_model = load_pretrained(arch, nclass).to(device)

    criterion = nn.CrossEntropyLoss()

    def reset_passport():
        print('Reset passport')
        x, y = get_passport(passport_data, device)
        passport_generator.set_key(pretrained_model, model, x, y)

    def run_test():
        res = {}
        valres = test(model, criterion, valloader, device,
                      1 if scheme != 1 else 0)
        for key in valres:
            res[f'valid_{key}'] = valres[key]
        res['attack_rep'] = 0
        return res

    dirname = f'logs/passport_attack_1/{loadpath.split("/")[1]}/{loadpath.split("/")[2]}'
    os.makedirs(dirname, exist_ok=True)

    history = []

    print('Before training')
    res = run_test()
    history.append(res)

    for r in range(attack_rep):
        print(f'Attack count: {r}')
        reset_passport()
        res = run_test()
        res['attack_rep'] = r
        history.append(res)

    histdf = pd.DataFrame(history)
    histdf.to_csv(
        f'{dirname}/{arch}-{scheme}-history-{dataset}-{attack_rep}-{tagnum}.csv'
    )
Exemplo n.º 8
0
def run_attack_2(rep=1, arch='alexnet', dataset='cifar10', scheme=1, loadpath='',
                 passport_config='passport_configs/alexnet_passport.json', tagnum=1):
    """Fine-tuning attack on a plain network initialised from a passport model.

    Builds a plain network and a matching passport network and loads both
    from ``loadpath``:
      * Plain network: ``AlexNetNormal``, ``ResNet18`` or ``ResNet9``, with
        BN when ``scheme == 1`` and GN otherwise.
      * Passport network: ``scheme == 1`` selects the public variant and any
        other value the private one. ``resnet9`` with a private scheme
        raises ``NotImplementedError``.

    The plain model is then prepared as follows:
      1. All plain-model parameters are frozen.
      2. For every passport layer key in ``plkeys``, the BN weight/bias are
         initialised from the passport block's ``get_scale()`` /
         ``get_bias()`` and made trainable.
      3. Those BN biases are zeroed and the BN weights are replaced by their
         sign pattern.
      4. A fraction ``args.flipperc`` of the signs, drawn uniformly over all
         passport layers, is reversed.

    Finally the model is trained with ``train`` (30 epochs for
    ``imagenet1000``, otherwise 100).

    Outputs go to ``logs/passport_attack_2/<part1>/<part2>/``, where
    ``part1``/``part2`` are the 2nd and 3rd ``/`` components of
    ``loadpath``. They are:
      * the run arguments as JSON
      * the last checkpoint
      * a history CSV

    NOTE(review): this reads a module-level ``args`` namespace
    (``args.flipperc`` and ``vars(args)``) that is not a parameter; it
    raises ``NameError`` unless the calling module defines ``args``.
    """
    epochs = {
        'imagenet1000': 30
    }.get(dataset, 100)
    batch_size = 64
    nclass = {
        'cifar100': 100,
        'imagenet1000': 1000
    }.get(dataset, 10)
    inchan = 3
    lr = 0.01
    device = torch.device('cuda')

    trainloader, valloader = prepare_dataset({'transfer_learning': False,
                                              'dataset': dataset,
                                              'tl_dataset': '',
                                              'batch_size': batch_size})
    # Close the config file deterministically instead of leaking the handle.
    with open(passport_config) as config_file:
        passport_config_dict = json.load(config_file)
    passport_kwargs, plkeys = construct_passport_kwargs_from_dict({'passport_config': passport_config_dict,
                                                                   'norm_type': 'bn',
                                                                   'sl_ratio': 0.1,
                                                                   'key_type': 'shuffle'},
                                                                  True)

    # Plain (non-passport) network that the attacker fine-tunes.
    if arch == 'alexnet':
        model = AlexNetNormal(inchan, nclass, 'bn' if scheme == 1 else 'gn')
    else:
        ResNetClass = ResNet18 if arch == 'resnet18' else ResNet9
        model = ResNetClass(num_classes=nclass,
                            norm_type='bn' if scheme == 1 else 'gn')

    # Passport network; used only as the source of per-layer scale/bias.
    if arch == 'alexnet':
        if scheme == 1:
            passport_model = AlexNetPassport(inchan, nclass, passport_kwargs)
        else:
            passport_model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)
    else:
        if scheme == 1:
            ResNetClass = ResNet18Passport if arch == 'resnet18' else ResNet9Passport
            passport_model = ResNetClass(num_classes=nclass, passport_kwargs=passport_kwargs)
        else:
            if arch == 'resnet9':
                raise NotImplementedError
            passport_model = ResNet18Private(num_classes=nclass, passport_kwargs=passport_kwargs)

    # load_state_dict copies tensor values, so one loaded dict can seed both
    # models (the checkpoint used to be read from disk twice).
    sd = torch.load(loadpath)
    passport_model.load_state_dict(sd)
    passport_model = passport_model.to(device)

    # strict=False: the passport checkpoint's keys do not match the plain
    # model's exactly (presumably the BN affine entries are missing); the BN
    # weight/bias at passport layers are filled in from the passport model
    # below.
    model.load_state_dict(sd, strict=False)
    model = model.to(device)

    for param in model.parameters():
        param.requires_grad_(False)

    def passport_layer(m, fidx):
        """Return the layer of `m` addressed by the passport-layer key `fidx`.

        AlexNet keys are indices into ``m.features``. ResNet keys are dotted
        ``<layer>.<index>.<module>`` paths.
        """
        if arch == 'alexnet':
            return m.features[int(fidx)]
        layer_key, i, module_key = fidx.split('.')
        return getattr(getattr(m, layer_key)[int(i)], module_key)

    # Initialise each passport layer's BN affine params from the passport
    # block and make only those trainable.
    for fidx in plkeys:
        bn = passport_layer(model, fidx).bn
        passblock = passport_layer(passport_model, fidx)
        bn.weight.data.copy_(passblock.get_scale().view(-1))
        bn.bias.data.copy_(passblock.get_bias().view(-1))

        bn.weight.requires_grad_(True)
        bn.bias.requires_grad_(True)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    scheduler = None
    criterion = nn.CrossEntropyLoss()

    history = []

    def evaluate():
        """Record validation metrics as epoch 0 of the history."""
        print('Before training')
        valres = test(model, criterion, valloader, device)
        res = {}
        for key in valres:
            res[f'valid_{key}'] = valres[key]
        res['epoch'] = 0
        history.append(res)
        print()

    # Collect the BN weights that carry the signature and zero their biases.
    conv_weights_to_reset = []
    total_weight_size = 0

    for fidx in plkeys:
        bn = passport_layer(model, fidx).bn

        w = bn.weight
        conv_weights_to_reset.append(w)
        total_weight_size += w.size(0)

        bn.bias.data.zero_()
        bn.weight.requires_grad_(True)
        bn.bias.requires_grad_(True)

    # Pick the bits to reverse in one global index space spanning all layers.
    randidxs = torch.randperm(total_weight_size)
    idxs = randidxs[:int(total_weight_size * args.flipperc)]
    print(total_weight_size, len(idxs))
    sim = 0

    for w in conv_weights_to_reset:
        size = w.size(0)
        # e.g. first layer size 64: indices 0..63 minus 64 give -64..-1, i.e.
        # the indices that fall inside the current layer.
        print(len(idxs), size)
        widxs = idxs[(idxs - size) < 0]

        # Replace the weights by their sign pattern ...
        origsign = w.data.sign()
        newsign = origsign.clone()

        # ... with the selected bits reversed.
        newsign[widxs] *= -1

        w.data.copy_(newsign)

        sim += ((w.data.sign() == origsign).float().mean())

        # Drop this layer's indices and shift the rest into the next layer's range.
        idxs = idxs[(idxs - size) >= 0] - size

    print('signature similarity', sim / len(conv_weights_to_reset))

    evaluate()

    dirname = f'logs/passport_attack_2/{loadpath.split("/")[1]}/{loadpath.split("/")[2]}'
    os.makedirs(dirname, exist_ok=True)

    # Use `with` so the args dump is flushed and closed deterministically.
    with open(f'{dirname}/{arch}-{scheme}-last-{dataset}-{rep}-{tagnum}.json', 'w+') as args_file:
        json.dump(vars(args), args_file)

    for ep in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()

        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        print(f'Epoch {ep:3d}:')
        print('Training')
        trainres = train(model, optimizer, criterion, trainloader, device)

        print('Testing')
        valres = test(model, criterion, valloader, device)

        print()

        res = {}

        for key in trainres:
            res[f'train_{key}'] = trainres[key]
        for key in valres:
            res[f'valid_{key}'] = valres[key]
        res['epoch'] = ep

        history.append(res)

        # Overwritten every epoch so the latest state survives interruption.
        torch.save(model.state_dict(),
                   f'{dirname}/{arch}-{scheme}-last-{dataset}-{rep}-{tagnum}.pth')

        histdf = pd.DataFrame(history)
        histdf.to_csv(f'{dirname}/{arch}-{scheme}-history-{dataset}-{tagnum}.csv')