Example #1
    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            clone_model = AlexNetNormal(self.in_channels,
                                        self.num_classes,
                                        self.norm_type)
        else:
            clone_model = ResNet18(num_classes=self.num_classes,
                                   norm_type=self.norm_type)

        ##### load / reset weights of passport layers for clone model #####
        try:
            clone_model.load_state_dict(self.model.state_dict())
        except RuntimeError:
            print('Failed to load the state dict directly, loading it manually')
            if self.arch == 'alexnet':
                for clone_m, self_m in zip(clone_model.features, self.model.features):
                    try:
                        clone_m.load_state_dict(self_m.state_dict())
                    except RuntimeError:
                        print('Failed to load state dict, usually due to missing keys; loading with strict=False')
                        clone_m.load_state_dict(self_m.state_dict(), strict=False)  # loads conv weights and bn running stats
                        clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

            else:
                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock
                                clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: PassportBlock

                                try:
                                    clone_m.load_state_dict(self_m.state_dict())
                                except RuntimeError:
                                    print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                    clone_m.load_state_dict(self_m.state_dict(), strict=False)
                                    clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                                    clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    else:
                        clone_m = clone_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            clone_m.load_state_dict(self_m.state_dict())
                        except RuntimeError:
                            print(f'{l_key} cannot load state dict directly')
                            clone_m.load_state_dict(self_m.state_dict(), strict=False)
                            clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                            clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        clone_model.to(self.device)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                clone_model.classifier.reset_parameters()
            except AttributeError:  # fall back when the model's head is `linear` instead of `classifier`
                clone_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(clone_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, the scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(clone_model,
                               optimizer,
                               scheduler,
                               self.device)
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)
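        # both testers evaluate self.model (the protected model), not the clone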

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

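        # each epoch: train the clone, copy its weights back into self.model,
        # then re-check the old watermark / passport signature on self.model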
        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(clone_model.state_dict())
        except RuntimeError:
                if self.arch == 'alexnet':
                    for clone_m, self_m in zip(clone_model.features, self.model.features):
                        try:
                            self_m.load_state_dict(clone_m.state_dict())
                        except RuntimeError:
                            self_m.load_state_dict(clone_m.state_dict(), strict=False)
                else:
                    passport_settings = self.passport_config
                    for l_key in passport_settings:  # layer
                        if isinstance(passport_settings[l_key], dict):
                            for i in passport_settings[l_key]:  # sequential
                                for m_key in passport_settings[l_key][i]:  # convblock
                                    clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                    self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                    try:
                                        self_m.load_state_dict(clone_m.state_dict())
                                    except RuntimeError:
                                        self_m.load_state_dict(clone_m.state_dict(), strict=False)
                        else:
                            clone_m = clone_model.__getattr__(l_key)
                            self_m = self.model.__getattr__(l_key)

                            try:
                                self_m.load_state_dict(clone_m.state_dict())
                            except RuntimeError:
                                self_m.load_state_dict(clone_m.state_dict(), strict=False)

            clone_model.to(self.device)
            self.model.to(self.device)

            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            if self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', clone_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', clone_model)

            self.save_last_model()
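
    # Second variant: same transfer-learning flow, adapted for passport-private
    # models whose get_scale() returns a pair of scales (only the first is
    # copied into the clone's BN weight).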
    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        if self.tl_dataset == 'caltech-101':
            self.num_classes = 101
        elif self.tl_dataset == 'cifar100':
            self.num_classes = 100
        elif self.tl_dataset == 'caltech-256':
            self.num_classes = 257
        else:  # cifar10
            self.num_classes = 10

        # load clone model
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type)

        # # My change: fine-tune the AlexNet passport-private path all the way
        # if self.arch == 'alexnet':
        #     tl_model = AlexNetPassportPrivate(self.in_channels, self.num_classes, passport_kwargs)
        # else:
        #     tl_model = ResNet18Private(num_classes=self.num_classes, passport_kwargs=passport_kwargs)
        #

        ##### load / reset weights of passport layers for clone model #####

        try:
            tl_model.load_state_dict(self.model.state_dict())
            # tl_model.load_state_dict(self.copy_model.state_dict())
        except RuntimeError:
            print('Failed to load the state dict directly, loading it manually')
            if self.arch == 'alexnet':
                for tl_m, self_m in zip(tl_model.features, self.model.features):

                    try:
                        tl_m.load_state_dict(self_m.state_dict())
                    except RuntimeError:
                        print('Failed to load state dict, usually due to missing keys; loading with strict=False')
                        tl_m.load_state_dict(self_m.state_dict(), strict=False)  # loads conv weights and bn running stats
                        # original parameter loading:
                        # tl_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        # tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                        # Changed: get_scale() returns two scales here; copy the first into BN
                        scale1, scale2 = self_m.get_scale()
                        tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                        tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

            else:

                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock

                                tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                try:
                                    tl_m.load_state_dict(self_m.state_dict())
                                except RuntimeError:
                                    print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                    tl_m.load_state_dict(self_m.state_dict(), strict=False)

                                    scale1, scale2 = self_m.get_scale()
                                    tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                                    tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    else:
                        print("FFFFFFFFFFFFFFFFFFFFFFF")
                        tl_m = tl_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            tl_m.load_state_dict(self_m.state_dict())
                        except RuntimeError:
                            print(f'{l_key} cannot load state dict directly')
                            tl_m.load_state_dict(self_m.state_dict(), strict=False)

                            scale1, scale2 = self_m.get_scale()
                            tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                            tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        tl_model.to(self.device)
        print('Loaded clone model')

        # tl scheme setup
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                tl_model.classifier.reset_parameters()
            except AttributeError:  # fall back when the model's head is `linear` instead of `classifier`
                tl_model.linear.reset_parameters()

        # for name, m in self.model.named_modules():
        #     print('name',name)
        #     if name
        #
        # for i in self.model.fc.parameters():
        #     i.requires_grad = False
        #
        # for i in self.model.bn1.parameters():
        #     i.requires_grad = False


        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, the scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        tl_trainer = Trainer(tl_model,
                             optimizer,
                             scheduler,
                             self.device)

        tester = TesterPrivate(self.model,
                               self.device)
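        # one TesterPrivate serves both the signature check and, if enabled,
        # the old backdoor watermark test in the loop below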

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0
        best_file = os.path.join(self.logdir, 'best.txt')
        best_ep = 1

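        # same per-epoch flow as the first variant, plus logging of the
        # signature-detection rate and best accuracy to best.txt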
        for ep in range(1, self.epochs + 1):
            train_metrics = tl_trainer.train(ep, self.train_data)
            valid_metrics = tl_trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(tl_model.state_dict())
            except RuntimeError:
                if self.arch == 'alexnet':
                    for tl_m, self_m in zip(tl_model.features, self.model.features):
                        try:
                            self_m.load_state_dict(tl_m.state_dict())
                        except RuntimeError:
                            self_m.load_state_dict(tl_m.state_dict(), strict=False)
                else:
                    passport_settings = self.passport_config
                    for l_key in passport_settings:  # layer
                        if isinstance(passport_settings[l_key], dict):
                            for i in passport_settings[l_key]:  # sequential
                                for m_key in passport_settings[l_key][i]:  # convblock
                                    tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                    self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                    try:
                                        self_m.load_state_dict(tl_m.state_dict())
                                    except RuntimeError:
                                        self_m.load_state_dict(tl_m.state_dict(), strict=False)
                        else:
                            tl_m = tl_model.__getattr__(l_key)
                            self_m = self.model.__getattr__(l_key)

                            try:
                                self_m.load_state_dict(tl_m.state_dict())
                            except RuntimeError:
                                self_m.load_state_dict(tl_m.state_dict(), strict=False)

            wm_metrics = tester.test_signature()
            # average signature-detection rate across all passport layers
            pri_sign = sum(wm_metrics.values()) / len(wm_metrics)

            if self.train_backdoor:
                backdoor_metrics = tester.test(self.wm_data, 'Old WM Accuracy')

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            if self.train_backdoor:
                for key in backdoor_metrics: metrics[f'backdoor_{key}'] = backdoor_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)
                best_ep = ep

            self.save_last_model()
            with open(best_file, 'a') as f:  # context manager closes and flushes the log each epoch
                print(str(wm_metrics) + '\n', file=f)
                print(str(metrics) + '\n', file=f)
                f.write('Best ACC %s' % str(best_acc) + '\n')
                print('Private Sign Detection:', str(pri_sign) + '\n', file=f)
                f.write('\n')
                f.write('best epoch: %s' % str(best_ep) + '\n')