def transfer_learning(self):
    if not self.is_tl:
        raise Exception('Please run with --transfer-learning')

    self.num_classes = {
        'cifar10': 10,
        'cifar100': 100,
        'caltech-101': 101,
        'caltech-256': 256
    }[self.tl_dataset]

    ##### load clone model #####
    print('Loading clone model')
    if self.arch == 'alexnet':
        clone_model = AlexNetNormal(self.in_channels,
                                    self.num_classes,
                                    self.norm_type)
    else:
        clone_model = ResNet18(num_classes=self.num_classes,
                               norm_type=self.norm_type)

    ##### load / reset weights of passport layers for clone model #####
    try:
        clone_model.load_state_dict(self.model.state_dict())
    except:
        print('Unable to load state dict directly, loading it manually')

        if self.arch == 'alexnet':
            for clone_m, self_m in zip(clone_model.features, self.model.features):
                try:
                    clone_m.load_state_dict(self_m.state_dict())
                except:
                    print('Unable to load state dict, usually caused by missing keys; loading with strict=False')
                    clone_m.load_state_dict(self_m.state_dict(), False)  # load conv weight, bn running mean
                    clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                    clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))
        else:
            passport_settings = self.passport_config
            for l_key in passport_settings:  # layer
                if isinstance(passport_settings[l_key], dict):
                    for i in passport_settings[l_key]:  # sequential
                        for m_key in passport_settings[l_key][i]:  # convblock
                            clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                            self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: PassportBlock

                            try:
                                clone_m.load_state_dict(self_m.state_dict())
                            except:
                                print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                clone_m.load_state_dict(self_m.state_dict(), False)
                                clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                                clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))
                else:
                    clone_m = clone_model.__getattr__(l_key)
                    self_m = self.model.__getattr__(l_key)

                    try:
                        clone_m.load_state_dict(self_m.state_dict())
                    except:
                        print(f'{l_key} cannot load state dict directly')
                        clone_m.load_state_dict(self_m.state_dict(), False)
                        clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

    clone_model.to(self.device)
    print('Loaded clone model')

    ##### dataset is created at constructor #####

    ##### tl scheme setup #####
    if self.tl_scheme == 'rtal':
        # rtal = reset last layer + train all layers
        # ftal = train all layers
        try:
            clone_model.classifier.reset_parameters()
        except:
            clone_model.linear.reset_parameters()

    ##### optimizer setup #####
    optimizer = optim.SGD(clone_model.parameters(),
                          lr=self.lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   self.lr_config[self.lr_config['type']],
                                                   self.lr_config['gamma'])
    else:
        scheduler = None

    self.trainer = Trainer(clone_model, optimizer, scheduler, self.device)
    tester = Tester(self.model, self.device)
    tester_passport = TesterPrivate(self.model, self.device)

    history_file = os.path.join(self.logdir, 'history.csv')
    first = True
    best_acc = 0

    for ep in range(1, self.epochs + 1):
        train_metrics = self.trainer.train(ep, self.train_data)
        valid_metrics = self.trainer.test(self.valid_data)

        ##### load transfer learning weights from clone model #####
        try:
            self.model.load_state_dict(clone_model.state_dict())
        except:
            if self.arch == 'alexnet':
                for clone_m, self_m in zip(clone_model.features, self.model.features):
                    try:
                        self_m.load_state_dict(clone_m.state_dict())
                    except:
                        self_m.load_state_dict(clone_m.state_dict(), False)
            else:
                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock
                                clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                try:
                                    self_m.load_state_dict(clone_m.state_dict())
                                except:
                                    self_m.load_state_dict(clone_m.state_dict(), False)
                    else:
                        clone_m = clone_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            self_m.load_state_dict(clone_m.state_dict())
                        except:
                            self_m.load_state_dict(clone_m.state_dict(), False)

        clone_model.to(self.device)
        self.model.to(self.device)

        wm_metrics = {}
        if self.train_backdoor:
            wm_metrics = tester.test(self.wm_data, 'WM Result')

        if self.train_passport:
            res = tester_passport.test_signature()
            for key in res:
                wm_metrics['passport_' + key] = res[key]

        metrics = {}
        for key in train_metrics:
            metrics[f'train_{key}'] = train_metrics[key]
        for key in valid_metrics:
            metrics[f'valid_{key}'] = valid_metrics[key]
        for key in wm_metrics:
            metrics[f'old_wm_{key}'] = wm_metrics[key]
        self.append_history(history_file, metrics, first)
        first = False

        if self.save_interval and ep % self.save_interval == 0:
            self.save_model(f'epoch-{ep}.pth')
            self.save_model(f'tl-epoch-{ep}.pth', clone_model)

        if best_acc < metrics['valid_acc']:
            print(f'Found best at epoch {ep}\n')
            best_acc = metrics['valid_acc']
            self.save_model('best.pth')
            self.save_model('tl-best.pth', clone_model)

        self.save_last_model()
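
# A minimal sketch (hypothetical helper, not part of this repository) of what the manual
# copy above achieves: the passport layer derives its normalization scale and bias from the
# passport keys, and the clone's plain BatchNorm absorbs those values, so the clone matches
# the passport model's behaviour without needing the passport inputs at inference time.
# `clone_m` / `self_m` follow the same pairing as in the loops above; `torch` is assumed
# to be imported at module level.
#
#   def _check_folded_bn(clone_m, self_m, atol=1e-6):
#       scale = self_m.get_scale().detach().view(-1)
#       bias = self_m.get_bias().detach().view(-1)
#       assert torch.allclose(clone_m.bn.weight.data, scale, atol=atol)
#       assert torch.allclose(clone_m.bn.bias.data, bias, atol=atol)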
def transfer_learning(self):
    if not self.is_tl:
        raise Exception('Please run with --transfer-learning')

    if self.tl_dataset == 'caltech-101':
        self.num_classes = 101
    elif self.tl_dataset == 'cifar100':
        self.num_classes = 100
    elif self.tl_dataset == 'caltech-256':
        self.num_classes = 257
    else:  # cifar10
        self.num_classes = 10

    # load clone model
    print('Loading clone model')
    if self.arch == 'alexnet':
        tl_model = AlexNetNormal(self.in_channels,
                                 self.num_classes,
                                 self.norm_type)
    else:
        tl_model = ResNet18(num_classes=self.num_classes,
                            norm_type=self.norm_type)

    # # own change: fine-tune the passport branch of AlexNet directly
    # if self.arch == 'alexnet':
    #     tl_model = AlexNetPassportPrivate(self.in_channels, self.num_classes, passport_kwargs)
    # else:
    #     tl_model = ResNet18Private(num_classes=self.num_classes, passport_kwargs=passport_kwargs)

    ##### load / reset weights of passport layers for clone model #####
    try:
        tl_model.load_state_dict(self.model.state_dict())
        # tl_model.load_state_dict(self.copy_model.state_dict())
    except:
        print('Unable to load state dict directly, loading it manually')

        if self.arch == 'alexnet':
            for tl_m, self_m in zip(tl_model.features, self.model.features):
                try:
                    tl_m.load_state_dict(self_m.state_dict())
                except:
                    print('Unable to load state dict, usually caused by missing keys; loading with strict=False')
                    tl_m.load_state_dict(self_m.state_dict(), False)  # load conv weight, bn running mean

                    # original parameter loading:
                    # tl_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                    # tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    # changed: get_scale() now returns two scales; fold the first into bn
                    scale1, scale2 = self_m.get_scale()
                    tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                    tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))
        else:
            passport_settings = self.passport_config
            for l_key in passport_settings:  # layer
                if isinstance(passport_settings[l_key], dict):
                    for i in passport_settings[l_key]:  # sequential
                        for m_key in passport_settings[l_key][i]:  # convblock
                            tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                            self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                            try:
                                tl_m.load_state_dict(self_m.state_dict())
                            except:
                                print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                tl_m.load_state_dict(self_m.state_dict(), False)
                                scale1, scale2 = self_m.get_scale()
                                tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                                tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))
                else:
                    tl_m = tl_model.__getattr__(l_key)
                    self_m = self.model.__getattr__(l_key)

                    try:
                        tl_m.load_state_dict(self_m.state_dict())
                    except:
                        print(f'{l_key} cannot load state dict directly')
                        tl_m.load_state_dict(self_m.state_dict(), False)
                        # tl_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        scale1, scale2 = self_m.get_scale()
                        tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                        tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

    tl_model.to(self.device)
    print('Loaded clone model')

    # tl scheme setup
    if self.tl_scheme == 'rtal':
        # rtal = reset last layer + train all layers
        # ftal = train all layers
        try:
            tl_model.classifier.reset_parameters()
        except:
            tl_model.linear.reset_parameters()

    # for name, m in self.model.named_modules():
    #     print('name', name)
    #
    # for i in self.model.fc.parameters():
    #     i.requires_grad = False
    #
    # for i in self.model.bn1.parameters():
    #     i.requires_grad = False

    optimizer = optim.SGD(tl_model.parameters(),
                          lr=self.lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   self.lr_config[self.lr_config['type']],
                                                   self.lr_config['gamma'])
    else:
        scheduler = None

    tl_trainer = Trainer(tl_model, optimizer, scheduler, self.device)
    tester = TesterPrivate(self.model, self.device)

    history_file = os.path.join(self.logdir, 'history.csv')
    first = True
    best_acc = 0
    best_file = os.path.join(self.logdir, 'best.txt')
    best_ep = 1

    for ep in range(1, self.epochs + 1):
        train_metrics = tl_trainer.train(ep, self.train_data)
        valid_metrics = tl_trainer.test(self.valid_data)

        ##### load transfer learning weights from clone model #####
        try:
            self.model.load_state_dict(tl_model.state_dict())
        except:
            if self.arch == 'alexnet':
                for tl_m, self_m in zip(tl_model.features, self.model.features):
                    try:
                        self_m.load_state_dict(tl_m.state_dict())
                    except:
                        self_m.load_state_dict(tl_m.state_dict(), False)
            else:
                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock
                                tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                try:
                                    self_m.load_state_dict(tl_m.state_dict())
                                except:
                                    self_m.load_state_dict(tl_m.state_dict(), False)
                    else:
                        tl_m = tl_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            self_m.load_state_dict(tl_m.state_dict())
                        except:
                            self_m.load_state_dict(tl_m.state_dict(), False)

        wm_metrics = tester.test_signature()
        L = len(wm_metrics)
        S = sum(wm_metrics.values())
        pri_sign = S / L

        if self.train_backdoor:
            backdoor_metrics = tester.test(self.wm_data, 'Old WM Accuracy')

        metrics = {}
        for key in train_metrics:
            metrics[f'train_{key}'] = train_metrics[key]
        for key in valid_metrics:
            metrics[f'valid_{key}'] = valid_metrics[key]
        for key in wm_metrics:
            metrics[f'old_wm_{key}'] = wm_metrics[key]
        if self.train_backdoor:
            for key in backdoor_metrics:
                metrics[f'backdoor_{key}'] = backdoor_metrics[key]
        self.append_history(history_file, metrics, first)
        first = False

        if self.save_interval and ep % self.save_interval == 0:
            self.save_model(f'epoch-{ep}.pth')
            self.save_model(f'tl-epoch-{ep}.pth', tl_model)

        if best_acc < metrics['valid_acc']:
            print(f'Found best at epoch {ep}\n')
            best_acc = metrics['valid_acc']
            self.save_model('best.pth')
            self.save_model('tl-best.pth', tl_model)
            best_ep = ep

        self.save_last_model()

        # append this epoch's signature / accuracy summary; the context manager closes
        # (and flushes) the handle each iteration
        with open(best_file, 'a') as f:
            print(str(wm_metrics) + '\n', file=f)
            print(str(metrics) + '\n', file=f)
            f.write('Best ACC %s' % str(best_acc) + '\n')
            print('Private Sign Detection:', str(pri_sign) + '\n', file=f)
            f.write('\n')
            f.write('best epoch: %s' % str(best_ep) + '\n')
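
# Usage sketch (hypothetical driver; the experiment class name and argument fields are
# assumptions, not verified against this repository's CLI): both transfer_learning()
# variants above are meant to run on an experiment object built with --transfer-learning,
# after the pretrained passport model and the transfer-learning data loaders have been set
# up in the constructor.
#
#   exp = ClassificationPrivateExperiment(args)   # args must set is_tl / tl_dataset / tl_scheme
#   exp.transfer_learning()                       # trains the clone, appends history.csv,
#                                                 # and saves tl-best.pth / best.pth in logdir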