Пример #1
0
        def setup_keys():
            """Build a reference network and hand it to ``self.setup_keys``.

            Nothing happens when ``self.key_type`` is ``'random'``.  Otherwise
            the reference network uses torch-pretrained weights when
            ``self.pretrained_path`` is None, and weights loaded from
            ``self.pretrained_path`` when it is set.
            """
            if self.key_type == 'random':
                return

            use_torch_weights = self.pretrained_path is None

            if self.arch == 'alexnet':
                # with torch weights the AlexNet is built with norm_type 'none'
                chosen_norm = 'none' if use_torch_weights else self.norm_type
                reference = AlexNetNormal(self.in_channels,
                                          self.num_classes,
                                          norm_type=chosen_norm,
                                          pretrained=use_torch_weights)
            else:
                # with torch weights the ResNet18 is built with norm_type 'bn'
                chosen_norm = 'bn' if use_torch_weights else self.norm_type
                reference = ResNet18(num_classes=self.num_classes,
                                     norm_type=chosen_norm,
                                     pretrained=use_torch_weights)

            if use_torch_weights:
                print('Loading pretrained from torch-pretrained model')
            else:
                print('Loading pretrained from self-trained model')
                checkpoint = torch.load(self.pretrained_path)
                reference.load_state_dict(checkpoint)

            self.setup_keys(reference.to(self.device))
 def setup_keys():
     """Load the self-trained reference network and pass it to ``self.setup_keys``.

     Skipped entirely when ``self.key_type`` is ``'random'``; weights always
     come from ``self.pretrained_path``.
     """
     if self.key_type == 'random':
         return

     if self.arch == 'alexnet':
         reference = AlexNetNormal(self.in_channels, self.num_classes, self.norm_type)
     else:
         reference = ResNet18(num_classes=self.num_classes, norm_type=self.norm_type)

     weights = torch.load(self.pretrained_path)
     reference.load_state_dict(weights)
     self.setup_keys(reference.to(self.device))
Пример #3
0
    def transfer_learning(self):
        """Fine-tune a plain clone of ``self.model`` on ``self.tl_dataset``.

        A normal (non-passport) AlexNet/ResNet18 clone is initialised from
        ``self.model`` and trained through ``self.trainer``.  After every epoch
        the clone's weights are copied back into ``self.model``.  ``self.model``
        is then evaluated on the original trigger set (when
        ``self.train_backdoor``) and on the passport signature (when
        ``self.train_passport``), so the history records whether the watermark
        survives fine-tuning.

        Raises:
            Exception: if the run was not started with ``--transfer-learning``.
            KeyError: if ``self.tl_dataset`` is not one of the supported names.
        """
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.tl_dataset]

        def paired_blocks(clone_model):
            """Yield ``(message, clone_block, original_block)`` for each block that is
            copied individually when a whole-model ``load_state_dict`` fails.

            ``message`` is printed when that block cannot be loaded strictly.
            """
            if self.arch == 'alexnet':
                for clone_m, self_m in zip(clone_model.features, self.model.features):
                    yield ('Having problem to load state dict usually caused by missing keys, load by strict=False',
                           clone_m, self_m)
                return

            passport_settings = self.passport_config
            for l_key in passport_settings:  # layer
                if isinstance(passport_settings[l_key], dict):
                    for i in passport_settings[l_key]:  # sequential
                        for m_key in passport_settings[l_key][i]:  # convblock
                            clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                            self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: PassportBlock
                            yield f'{l_key}.{i}.{m_key} cannot load state dict directly', clone_m, self_m
                else:
                    clone_m = clone_model.__getattr__(l_key)
                    self_m = self.model.__getattr__(l_key)
                    yield f'{l_key} cannot load state dict directly', clone_m, self_m

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            clone_model = AlexNetNormal(self.in_channels,
                                        self.num_classes,
                                        self.norm_type)
        else:
            clone_model = ResNet18(num_classes=self.num_classes,
                                   norm_type=self.norm_type)

        ##### load / reset weights of passport layers for clone model #####
        # load_state_dict raises RuntimeError on missing/unexpected/mis-shaped keys;
        # only that is handled, so e.g. KeyboardInterrupt is not swallowed.
        try:
            clone_model.load_state_dict(self.model.state_dict())
        except RuntimeError:
            print('Having problem to direct load state dict, loading it manually')
            for message, clone_m, self_m in paired_blocks(clone_model):
                try:
                    clone_m.load_state_dict(self_m.state_dict())
                except RuntimeError:
                    print(message)
                    clone_m.load_state_dict(self_m.state_dict(), False)  # load conv weight, bn running mean
                    # the original block's get_scale()/get_bias() become the clone's bn affine params
                    clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                    clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        clone_model.to(self.device)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                clone_model.classifier.reset_parameters()
            except AttributeError:  # no resettable `classifier`; fall back to `linear`
                clone_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(clone_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no specify steps, then scheduler = None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        # training runs on the clone; the testers evaluate the original model
        self.trainer = Trainer(clone_model,
                               optimizer,
                               scheduler,
                               self.device)
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(clone_model.state_dict())
            except RuntimeError:
                for _, clone_m, self_m in paired_blocks(clone_model):
                    try:
                        self_m.load_state_dict(clone_m.state_dict())
                    except RuntimeError:
                        self_m.load_state_dict(clone_m.state_dict(), False)

            clone_model.to(self.device)
            self.model.to(self.device)

            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            if self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', clone_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', clone_model)

            self.save_last_model()
Пример #4
0
def run_attack_2(rep=1,
                 arch='alexnet',
                 dataset='cifar10',
                 scheme=1,
                 loadpath=''):
    """Passport attack 2: retrain only some bn affine parameters of a plain AlexNet.

    The checkpoint at ``loadpath`` is loaded non-strictly.  Features 0 and 2
    take their bn weight/bias from the checkpoint's ``features.{i}.scale`` /
    ``features.{i}.bias`` entries, and every parameter is frozen.  The bn
    affine parameters of features 4, 5 and 6 are re-initialised (weight =
    random sign * 0.5, bias = 0) and are the only ones trained, for 100
    epochs.  The last model of each epoch and the metric history are written
    under ``logs/passport_attack_2``.

    ``arch`` only appears in output file names; the model is always
    ``AlexNetNormal``.  ``scheme == 1`` selects ``'bn'`` normalisation,
    anything else ``'gn'``.
    """
    logdir = 'logs/passport_attack_2'
    epochs = 100
    batch_size = 64
    lr = 0.01
    inchan = 3
    nclass = 100 if dataset == 'cifar100' else 10
    device = torch.device('cuda')

    trainloader, valloader = prepare_dataset({
        'transfer_learning': False,
        'dataset': dataset,
        'tl_dataset': '',
        'batch_size': batch_size
    })

    norm_type = 'bn' if scheme == 1 else 'gn'
    model = AlexNetNormal(inchan, nclass, norm_type)
    model.to(device)

    state = torch.load(loadpath)
    model.load_state_dict(state, strict=False)

    # freeze everything; only the re-initialised bn layers below get gradients
    for p in model.parameters():
        p.requires_grad_(False)

    for idx in (0, 2):
        bn = model.features[idx].bn
        bn.weight.data.copy_(state[f'features.{idx}.scale'])
        bn.bias.data.copy_(state[f'features.{idx}.bias'])

    for idx in (4, 5, 6):
        bn = model.features[idx].bn
        bn.weight.data.normal_().sign_().mul_(0.5)
        bn.bias.data.zero_()
        bn.weight.requires_grad_(True)
        bn.bias.requires_grad_(True)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    # MultiStepLR (milestones at 50% / 75% of the epochs, gamma 0.1) is intentionally disabled.
    scheduler = None
    criterion = nn.CrossEntropyLoss()

    print('Before training')
    initial = test(model, criterion, valloader, device)
    row = {f'valid_{k}': initial[k] for k in initial}
    row['epoch'] = 0
    history = [row]
    print()

    os.makedirs(logdir, exist_ok=True)

    for ep in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()

        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        print(f'Epoch {ep:3d}:')
        print('Training')
        trainres = train(model, optimizer, criterion, trainloader, device)

        print('Testing')
        valres = test(model, criterion, valloader, device)

        print()

        row = {f'train_{k}': trainres[k] for k in trainres}
        row.update({f'valid_{k}': valres[k] for k in valres})
        row['epoch'] = ep
        history.append(row)

        torch.save(model.state_dict(),
                   f'{logdir}/{arch}-{scheme}-last-{dataset}-{rep}.pth')

    pd.DataFrame(history).to_csv(f'{logdir}/{arch}-{scheme}-history-{dataset}-{rep}.csv')
    def transfer_learning(self):
        """Fine-tune a plain clone of ``self.model`` on ``self.tl_dataset``.

        A normal (non-passport) AlexNet/ResNet18 clone is initialised from
        ``self.model`` (private-passport blocks contribute the first tensor of
        ``get_scale()`` and ``get_bias()`` as bn affine parameters) and trained.
        After every epoch the clone's weights are copied back into
        ``self.model``.  ``TesterPrivate`` then re-checks the signature (and
        the old trigger set when ``self.train_backdoor``).  Results go to
        ``<logdir>/history.csv``, and a per-epoch summary is appended to
        ``<logdir>/best.txt``.

        Raises:
            Exception: if the run was not started with ``--transfer-learning``.
        """
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        # any dataset not listed here falls back to 10 classes (cifar10)
        self.num_classes = {
            'caltech-101': 101,
            'cifar100': 100,
            'caltech-256': 257,
        }.get(self.tl_dataset, 10)

        def paired_blocks(tl_model):
            """Yield ``(message, tl_block, original_block)`` for each block that is
            copied individually when a whole-model ``load_state_dict`` fails.

            ``message`` is printed when that block cannot be loaded strictly.
            """
            if self.arch == 'alexnet':
                for tl_m, self_m in zip(tl_model.features, self.model.features):
                    yield ('Having problem to load state dict usually caused by missing keys, load by strict=False',
                           tl_m, self_m)
                return

            passport_settings = self.passport_config
            for l_key in passport_settings:  # layer
                if isinstance(passport_settings[l_key], dict):
                    for i in passport_settings[l_key]:  # sequential
                        for m_key in passport_settings[l_key][i]:  # convblock
                            tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                            self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                            yield f'{l_key}.{i}.{m_key} cannot load state dict directly', tl_m, self_m
                else:
                    tl_m = tl_model.__getattr__(l_key)
                    self_m = self.model.__getattr__(l_key)
                    yield f'{l_key} cannot load state dict directly', tl_m, self_m

        # load clone model
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type)

        ##### load / reset weights of passport layers for clone model #####
        # load_state_dict raises RuntimeError on missing/unexpected/mis-shaped keys;
        # only that is handled, so e.g. KeyboardInterrupt is not swallowed.
        try:
            tl_model.load_state_dict(self.model.state_dict())
        except RuntimeError:
            print('Having problem to direct load state dict, loading it manually')
            for message, tl_m, self_m in paired_blocks(tl_model):
                try:
                    tl_m.load_state_dict(self_m.state_dict())
                except RuntimeError:
                    print(message)
                    tl_m.load_state_dict(self_m.state_dict(), False)  # load conv weight, bn running mean
                    # get_scale() of these blocks returns two tensors; the first becomes the bn weight
                    scale, _ = self_m.get_scale()
                    tl_m.bn.weight.data.copy_(scale.detach().view(-1))
                    tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        tl_model.to(self.device)
        print('Loaded clone model')

        # tl scheme setup
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                tl_model.classifier.reset_parameters()
            except AttributeError:  # no resettable `classifier`; fall back to `linear`
                tl_model.linear.reset_parameters()

        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no specify steps, then scheduler = None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        # training runs on the clone; the tester evaluates the original model
        tl_trainer = Trainer(tl_model,
                             optimizer,
                             scheduler,
                             self.device)

        tester = TesterPrivate(self.model,
                               self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0
        best_file = os.path.join(self.logdir, 'best.txt')
        best_ep = 1

        for ep in range(1, self.epochs + 1):
            train_metrics = tl_trainer.train(ep, self.train_data)
            valid_metrics = tl_trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(tl_model.state_dict())
            except RuntimeError:
                for _, tl_m, self_m in paired_blocks(tl_model):
                    try:
                        self_m.load_state_dict(tl_m.state_dict())
                    except RuntimeError:
                        self_m.load_state_dict(tl_m.state_dict(), False)

            wm_metrics = tester.test_signature()
            # average of the signature-test values; nan instead of ZeroDivisionError if empty
            pri_sign = sum(wm_metrics.values()) / len(wm_metrics) if wm_metrics else float('nan')

            backdoor_metrics = {}
            if self.train_backdoor:
                backdoor_metrics = tester.test(self.wm_data, 'Old WM Accuracy')

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            for key in backdoor_metrics: metrics[f'backdoor_{key}'] = backdoor_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)
                best_ep = ep

            self.save_last_model()

            # append this epoch's summary; the handle is closed before the next epoch
            with open(best_file, 'a') as f:
                print(str(wm_metrics) + '\n', file=f)
                print(str(metrics) + '\n', file=f)
                f.write('Bset ACC %s' % str(best_acc) + "\n")
                print('Private Sign Detction:', str(pri_sign) + '\n', file=f)
                f.write("\n")
                f.write("best epoch: %s" % str(best_ep) + '\n')
Пример #6
0
def run_attack_1(attack_rep=50, arch='alexnet', dataset='cifar10', scheme=1,
                 loadpath='', passport_config='passport_configs/alexnet_passport.json'):
    """Passport attack 1: repeatedly forge passports and record validation accuracy.

    A passport AlexNet (``AlexNetPassport`` for ``scheme == 1``, otherwise
    ``AlexNetPassportPrivate``) is loaded non-strictly from ``loadpath``.
    Then, ``attack_rep`` times, new intermediate keys are set from data
    obtained via ``get_passport`` on the validation loader, passed through
    a baseline ``AlexNetNormal`` (``logs/alexnet_{dataset}/1/models/best.pth``).
    The model is re-evaluated after each attack.

    History is written to
    ``logs/passport_attack_1/{arch}-{scheme}-history-{dataset}-{attack_rep}.csv``.
    Its ``attack_rep`` column is 0 for the un-attacked model and
    1..``attack_rep`` for the successive attacks.
    """
    batch_size = 64
    nclass = 100 if dataset == 'cifar100' else 10
    inchan = 3
    device = torch.device('cuda')

    baselinepath = f'logs/alexnet_{dataset}/1/models/best.pth'
    with open(passport_config) as fp:  # close the config file instead of leaking the handle
        passport_json = json.load(fp)
    passport_kwargs = construct_passport_kwargs_from_dict({'passport_config': passport_json,
                                                           'norm_type': 'bn',
                                                           'sl_ratio': 0.1,
                                                           'key_type': 'shuffle'})

    ModelClass = AlexNetPassport if scheme == 1 else AlexNetPassportPrivate
    model = ModelClass(inchan, nclass, passport_kwargs)
    # keep the attacked model on the same device as pretrained_model and the test data
    model.to(device)

    sd = torch.load(loadpath)
    model.load_state_dict(sd, strict=False)

    for fidx in [0, 2]:
        model.features[fidx].bn.weight.data.copy_(sd[f'features.{fidx}.scale'])
        model.features[fidx].bn.bias.data.copy_(sd[f'features.{fidx}.bias'])

    trainloader, valloader = prepare_dataset({'transfer_learning': False,
                                              'dataset': dataset,
                                              'tl_dataset': '',
                                              'batch_size': batch_size})
    passport_data = valloader

    pretrained_model = AlexNetNormal(inchan, nclass)
    pretrained_model.load_state_dict(torch.load(baselinepath))
    pretrained_model.to(device)

    criterion = nn.CrossEntropyLoss()

    def reset_passport():
        """Replace the model's intermediate keys with freshly derived ones."""
        print('Reset passport')
        x, y = get_passport(passport_data, device)
        set_intermediate_keys(model, pretrained_model, x, y)

    def run_test(rep_idx):
        """Evaluate on the validation set and tag the row with ``rep_idx``."""
        res = {}
        valres = test(model, criterion, valloader, device, 1 if scheme != 1 else 0)
        for key in valres: res[f'valid_{key}'] = valres[key]
        res['attack_rep'] = rep_idx
        return res

    os.makedirs('logs/passport_attack_1', exist_ok=True)

    print('Before training')
    history = [run_test(0)]

    for r in range(attack_rep):
        print(f'Attack count: {r}')
        reset_passport()
        history.append(run_test(r + 1))

    histdf = pd.DataFrame(history)
    histdf.to_csv(f'logs/passport_attack_1/{arch}-{scheme}-history-{dataset}-{attack_rep}.csv')
Пример #7
0
    def transfer_learning(self):
        """Fine-tune a plain clone of ``self.model`` on ``self.tl_dataset``.

        ``is_imagenet`` is computed from the *source* model's class count,
        before ``self.num_classes`` is overwritten with the target dataset's,
        and is passed to the clone's constructor.  The clone is initialised
        from ``self.model`` (via ``load_normal_model_to_normal_model`` for
        baselines, ``load_passport_model_to_normal_model`` otherwise).  Each
        epoch the clone is trained, and its weights are copied back into
        ``self.model``, which is evaluated on the old trigger set and/or
        passport signature.

        Raises:
            Exception: if the run was not started with ``--transfer-learning``.
            KeyError: if ``self.tl_dataset`` is not one of the supported names.
        """
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        # must be read before self.num_classes is replaced below
        is_imagenet = self.num_classes == 1000

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256,
            'imagenet1000': 1000
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type,
                                     imagenet=is_imagenet)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type,
                                imagenet=is_imagenet)

        ##### load / reset weights of passport layers for clone model #####
        tl_model.to(self.device)
        if self.is_baseline:  # baseline
            load_normal_model_to_normal_model(self.arch, tl_model, self.model)
        else:
            load_passport_model_to_normal_model(self.arch, self.plkeys, self.model, tl_model)

        print(tl_model)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                if isinstance(tl_model.classifier, nn.Sequential):
                    tl_model.classifier[-1].reset_parameters()
                else:
                    tl_model.classifier.reset_parameters()
            except AttributeError:  # no resettable `classifier`; fall back to `linear`
                tl_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no specify steps, then scheduler = None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        ##### training is on finetune model
        self.trainer = Trainer(tl_model,
                               optimizer,
                               scheduler,
                               self.device)

        ##### tester is on original model
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            ##### transfer learning on new tasks #####
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            if self.is_baseline:
                load_normal_model_to_normal_model(self.arch, self.model, tl_model)
            else:
                load_normal_model_to_passport_model(self.arch, self.plkeys, self.model, tl_model)

            tl_model.to(self.device)
            self.model.to(self.device)

            ##### check if using weight of finetuned model is still able to detect trigger set watermark #####
            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            ##### check if using weight of finetuned model is still able to extract signature correctly #####
            if not self.is_baseline and self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            ##### store results #####
            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)

            self.save_last_model()
Пример #8
0
def _get_passport_block(net, arch, fidx):
    """Return the sub-module of ``net`` addressed by the passport-layer key ``fidx``.

    For ``arch == 'alexnet'`` the key is an index into ``net.features``
    (converted with ``int``). For every other architecture the key has the
    form ``'<layer>.<index>.<module>'`` and resolves to
    ``net.<layer>[<index>].<module>``.
    """
    if arch == 'alexnet':
        return net.features[int(fidx)]
    layer_key, i, module_key = fidx.split('.')
    return getattr(getattr(net, layer_key)[int(i)], module_key)


def run_attack_2(rep=1, arch='alexnet', dataset='cifar10', scheme=1, loadpath='',
                 passport_config='passport_configs/alexnet_passport.json', tagnum=1):
    """Fine-tune a plain model built from a passport checkpoint after flipping signature signs.

    Steps, as implemented below:

    1. Build a passport model and a plain (non-passport) model of the same
       architecture. ``scheme == 1`` uses the ``*Passport`` classes and ``bn``
       normalisation; any other scheme uses the ``*Private`` classes and ``gn``.
       Load the checkpoint at ``loadpath`` into both models. The plain model is
       loaded non-strictly.
    2. Freeze every parameter of the plain model. Then, for each passport layer
       key in ``plkeys``, copy the passport model's scale/bias into that layer's
       ``bn`` weight/bias and make those two tensors trainable.
    3. Zero those biases. Replace each such ``bn`` weight with its sign, so
       magnitudes become 1 and zeros stay 0. Invert the sign of a random
       ``args.flipperc`` fraction of all these weight entries.
    4. Train with SGD for ``epochs`` epochs. After every epoch, save the last
       state dict and the history CSV under
       ``logs/passport_attack_2/<loadpath component 1>/<loadpath component 2>``.

    Args:
        rep: Repetition index. It is only used in the ``.json``/``.pth`` file names.
        arch: ``'alexnet'`` or ``'resnet18'``. Any other value builds a plain
            ``ResNet9`` (and ``ResNet9Passport`` when ``scheme == 1``).
        dataset: Dataset name passed to ``prepare_dataset``. It also selects the
            epoch count (30 for ``'imagenet1000'``, else 100) and the number of
            classes (100 for ``'cifar100'``, 1000 for ``'imagenet1000'``, else 10).
        scheme: ``1`` for the bn/Passport variant; anything else for gn/Private.
        loadpath: Checkpoint path. Its 2nd and 3rd ``'/'``-separated components
            name the log directory, so it needs at least three components.
        passport_config: Path to the passport JSON config file.
        tagnum: Tag included in every output file name.

    Raises:
        NotImplementedError: if ``arch == 'resnet9'`` and ``scheme != 1``.

    Note:
        The function reads the module-level ``args``: ``args.flipperc`` sets the
        flip fraction, and ``vars(args)`` is dumped to JSON. It also hard-codes
        ``torch.device('cuda')``.
    """
    epochs = {
        'imagenet1000': 30
    }.get(dataset, 100)
    batch_size = 64
    nclass = {
        'cifar100': 100,
        'imagenet1000': 1000
    }.get(dataset, 10)
    inchan = 3
    lr = 0.01
    device = torch.device('cuda')

    trainloader, valloader = prepare_dataset({'transfer_learning': False,
                                              'dataset': dataset,
                                              'tl_dataset': '',
                                              'batch_size': batch_size})

    # Close the config file deterministically instead of leaking the handle.
    with open(passport_config) as f:
        passport_config_dict = json.load(f)
    passport_kwargs, plkeys = construct_passport_kwargs_from_dict(
        {'passport_config': passport_config_dict,
         'norm_type': 'bn',
         'sl_ratio': 0.1,
         'key_type': 'shuffle'},
        True)

    # Build the plain model first, then the matching passport model.
    norm_type = 'bn' if scheme == 1 else 'gn'
    if arch == 'alexnet':
        model = AlexNetNormal(inchan, nclass, norm_type)
        if scheme == 1:
            passport_model = AlexNetPassport(inchan, nclass, passport_kwargs)
        else:
            passport_model = AlexNetPassportPrivate(inchan, nclass, passport_kwargs)
    else:
        ResNetClass = ResNet18 if arch == 'resnet18' else ResNet9
        model = ResNetClass(num_classes=nclass, norm_type=norm_type)
        if scheme == 1:
            ResNetPassportClass = ResNet18Passport if arch == 'resnet18' else ResNet9Passport
            passport_model = ResNetPassportClass(num_classes=nclass,
                                                 passport_kwargs=passport_kwargs)
        else:
            if arch == 'resnet9':
                raise NotImplementedError
            passport_model = ResNet18Private(num_classes=nclass,
                                             passport_kwargs=passport_kwargs)

    # Read the checkpoint once and reuse it for both models. load_state_dict
    # copies values into each model's own parameters, so sharing `sd` is safe.
    sd = torch.load(loadpath)
    passport_model.load_state_dict(sd)
    passport_model = passport_model.to(device)

    # strict=False: the passport checkpoint has no scale/bias entries, so the
    # plain model's norm weight/bias keys are missing from it.
    model.load_state_dict(sd, strict=False)
    model = model.to(device)

    for param in model.parameters():
        param.requires_grad_(False)

    # Initialise each passport layer's norm affine parameters from the passport
    # model's scale/bias; these are the only trainable tensors.
    for fidx in plkeys:
        convblock = _get_passport_block(model, arch, fidx)
        passblock = _get_passport_block(passport_model, arch, fidx)
        convblock.bn.weight.data.copy_(passblock.get_scale().view(-1))
        convblock.bn.bias.data.copy_(passblock.get_bias().view(-1))
        convblock.bn.weight.requires_grad_(True)
        convblock.bn.bias.requires_grad_(True)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    # No LR schedule: the learning rate stays at `lr` for every epoch.
    # NOTE(review): if a scheduler is re-enabled, PyTorch expects
    # scheduler.step() after optimizer.step(), not before training.
    scheduler = None
    criterion = nn.CrossEntropyLoss()

    history = []

    # Collect the (already trainable) norm weights of the passport layers and
    # zero their biases.
    conv_weights_to_reset = []
    for fidx in plkeys:
        convblock = _get_passport_block(model, arch, fidx)
        conv_weights_to_reset.append(convblock.bn.weight)
        convblock.bn.bias.data.zero_()
    total_weight_size = sum(w.size(0) for w in conv_weights_to_reset)

    # Pick global indices, over the concatenation of all collected weights, whose sign will be inverted.
    randidxs = torch.randperm(total_weight_size)
    idxs = randidxs[:int(total_weight_size * args.flipperc)]
    print(total_weight_size, len(idxs))
    sim = 0

    for w in conv_weights_to_reset:
        size = w.size(0)
        print(len(idxs), size)
        # Indices below this layer's size fall inside this layer.
        widxs = idxs[idxs < size]

        # Overwrite the weights with their sign (keeping the signature sign
        # bit), inverting the sign at the selected positions.
        origsign = w.data.sign()
        newsign = origsign.clone()
        newsign[widxs] *= -1
        w.data.copy_(newsign)

        sim += (w.data.sign() == origsign).float().mean()

        # Drop this layer's indices; shift the rest so they are relative to the next layer.
        idxs = idxs[idxs >= size] - size

    print('signature similarity', sim / len(conv_weights_to_reset))

    # Baseline evaluation, recorded as epoch 0.
    print('Before training')
    valres = test(model, criterion, valloader, device)
    res = {f'valid_{key}': valres[key] for key in valres}
    res['epoch'] = 0
    history.append(res)
    print()

    dirname = f'logs/passport_attack_2/{loadpath.split("/")[1]}/{loadpath.split("/")[2]}'
    os.makedirs(dirname, exist_ok=True)

    # Context manager so the args JSON is flushed and closed immediately.
    with open(f'{dirname}/{arch}-{scheme}-last-{dataset}-{rep}-{tagnum}.json', 'w+') as f:
        json.dump(vars(args), f)

    for ep in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()

        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        print(f'Epoch {ep:3d}:')
        print('Training')
        trainres = train(model, optimizer, criterion, trainloader, device)

        print('Testing')
        valres = test(model, criterion, valloader, device)

        print()

        res = {f'train_{key}': trainres[key] for key in trainres}
        res.update({f'valid_{key}': valres[key] for key in valres})
        res['epoch'] = ep
        history.append(res)

        # Overwrite the "last" checkpoint and the history CSV after every epoch.
        torch.save(model.state_dict(),
                   f'{dirname}/{arch}-{scheme}-last-{dataset}-{rep}-{tagnum}.pth')

        # NOTE(review): unlike the .json/.pth names, the CSV name omits `rep`,
        # so different repetitions overwrite each other's history; confirm
        # whether downstream scripts rely on this name before changing it.
        histdf = pd.DataFrame(history)
        histdf.to_csv(f'{dirname}/{arch}-{scheme}-history-{dataset}-{tagnum}.csv')