def get_datasets(self) -> tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    r"""Get clean and poison datasets.

    Returns:
        (torch.utils.data.Dataset, torch.utils.data.Dataset):
            Clean training dataset and poison training dataset.
    """
    # Lazily build the attack's poison set once, sized to the clean train set.
    if self.attack.poison_set is None:
        self.attack.poison_set = self.attack.get_poison_dataset(
            poison_num=len(self.dataset.loader['train'].dataset))
    if not self.defense_input_num:
        return self.dataset.loader['train'].dataset, self.attack.poison_set
    if self.attack.train_mode != 'dataset':
        # Subsample `defense_input_num` inputs, split between clean and
        # poison according to the attack's poison percentage.
        poison_num = int(self.defense_input_num * self.attack.poison_percent)
        clean_num = self.defense_input_num - poison_num
        clean_input, clean_label = sample_batch(
            self.dataset.loader['train'].dataset, batch_size=clean_num)
        trigger_input, trigger_label = sample_batch(self.attack.poison_set,
                                                    batch_size=poison_num)
        clean_set = TensorListDataset(clean_input, clean_label.tolist())
        poison_set = TensorListDataset(trigger_input, trigger_label.tolist())
        return clean_set, poison_set
    # BUGFIX: the original fell through and implicitly returned ``None`` when
    # ``train_mode == 'dataset'``, violating the annotated return type and
    # crashing callers that unpack the tuple. Return the full datasets,
    # consistent with the ``not defense_input_num`` branch above.
    return self.dataset.loader['train'].dataset, self.attack.poison_set
def attack(self, epochs: int, **kwargs):
    r"""Train the MLP trigger recognizer and attach it to the target model."""
    # Build train / valid sets from trigger patterns plus random noise inputs.
    trigger_x, trigger_y = self.syn_trigger_candidates()
    noise_train_x, noise_train_y = self.syn_random_noises(length=self.train_noise_num)
    noise_valid_x, noise_valid_y = self.syn_random_noises(length=self.valid_noise_num)
    train_set = TensorListDataset(torch.cat((trigger_x, noise_train_x)),
                                  trigger_y + noise_train_y)
    valid_set = TensorListDataset(torch.cat((trigger_x, noise_valid_x)),
                                  trigger_y + noise_valid_y)
    loader_train = self.dataset.get_dataloader('train', dataset=train_set)
    loader_valid = self.dataset.get_dataloader('valid', dataset=valid_set)

    # Install the target-class trigger pattern as the watermark,
    # with a fully-opaque alpha channel.
    trigger = trigger_x[self.target_class].view_as(self.mark.mark[0]).unsqueeze(0)
    self.mark.load_mark(torch.cat((trigger, torch.ones_like(trigger))),
                        already_processed=True)

    self.mlp_model = ImageModel(name='mlpnet', model=_MLPNet,
                                input_dim=self.all_point,
                                output_dim=self.combination_number + 1,
                                dataset=self.dataset,
                                data_shape=[self.all_point],
                                loss_weights=None)
    self.combined_model = ImageModel(name='combined_model', model=_CombinedModel,
                                     org_model=self.model._model,
                                     mlp_model=self.mlp_model._model,
                                     mark=self.mark, dataset=self.dataset,
                                     alpha=self.mlp_alpha,
                                     temperature=self.comb_temperature,
                                     amplify_rate=self.amplify_rate)

    optimizer = torch.optim.Adam(params=self.mlp_model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=epochs)
    self.mlp_model._train(epochs, optimizer=optimizer,
                          lr_scheduler=lr_scheduler,
                          loader_train=loader_train,
                          loader_valid=loader_valid,
                          save_fn=self.save)
    return self.validate_fn()
def __init__(self, mix_image_num: int = 100, clean_image_ratio: float = 0.95,
             retrain_epoch: int = 10, nb_clusters: int = 2,
             clustering_method: str = "KMeans", nb_dims: int = 10,
             reduce_method: str = "FastICA",
             cluster_analysis: str = "exclusionary-reclassification",
             **kwargs):
    r"""Prepare the mixed (clean + poison) dataset used by the defense."""
    super().__init__(**kwargs)
    # Configuration bookkeeping.
    self.mix_image_num = mix_image_num
    self.clean_image_ratio = clean_image_ratio
    self.clean_image_num = int(mix_image_num * clean_image_ratio)
    self.poison_image_num = self.mix_image_num - self.clean_image_num
    self.nb_clusters = nb_clusters
    self.clustering_method = clustering_method
    self.nb_dims = nb_dims
    self.reduce_method = reduce_method
    self.cluster_analysis = cluster_analysis
    self.retrain_epoch = retrain_epoch

    # Sample a clean subset from the training data.
    self.clean_dataset, _ = self.dataset.split_set(
        self.dataset.get_full_dataset(mode='train'), self.clean_image_num)

    # Sample another subset, stamp the attack trigger on it, and relabel
    # everything to the attack's target class to form the poison subset.
    poison_dataset, _ = self.dataset.split_set(
        self.dataset.get_full_dataset(mode='train'), self.poison_image_num)
    poison_dataloader = self.dataset.get_dataloader(
        mode='train', dataset=poison_dataset,
        batch_size=self.poison_image_num, num_workers=0)
    poison_imgs, _ = self.model.get_data(next(iter(poison_dataloader)))
    poison_imgs = self.attack.add_mark(poison_imgs).cpu()
    poison_label = [self.attack.target_class] * self.poison_image_num
    self.poison_dataset = TensorListDataset(poison_imgs, poison_label)

    self.mix_dataset = torch.utils.data.ConcatDataset(
        [self.clean_dataset, self.poison_dataset])
    self.mix_dataloader = self.dataset.get_dataloader(
        mode='train', dataset=self.mix_dataset,
        batch_size=self.dataset.batch_size, num_workers=0)
def get_poison_dataset(self, poison_num: int = None, load_mark: bool = True,
                       seed: int = None) -> torch.utils.data.Dataset:
    r"""Get poison dataset from target class (no clean data).

    Args:
        poison_num (int): Number of poison data.
            Defaults to ``self.poison_num``.
        load_mark (bool): Whether to load previously saved watermark.
            This should be ``False`` during attack. Defaults to ``True``.
        seed (int): Random seed to sample poison input indices.
            Defaults to ``env['data_seed']``.

    Returns:
        torch.utils.data.Dataset:
            Poison dataset from target class (no clean data).
    """
    file_path = os.path.join(self.folder_path, self.get_filename() + '.npy')
    if load_mark:
        # Restore the watermark produced by a previous attack run.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(file_path)
        self.load_mark = False
        self.mark.load_mark(file_path, already_processed=True)
    torch.random.manual_seed(seed if seed is not None else env['data_seed'])
    # Never request more samples than the target-class set contains.
    poison_num = min(poison_num or self.poison_num, len(self.target_set))
    _input, _label = sample_batch(self.target_set, batch_size=poison_num)
    return TensorListDataset(self.add_mark(_input), _label.tolist())
def get_poison_dataset(self, poison_label: bool = True,
                       poison_num: int = None,
                       seed: int = None) -> torch.utils.data.Dataset:
    r"""Get poison dataset (no clean data).

    Args:
        poison_label (bool): Whether to use target poison label
            for poison data. Defaults to ``True``.
        poison_num (int): Number of poison data.
            Defaults to ``round(self.poison_ratio * len(train_set))``.
        seed (int): Random seed to sample poison input indices.
            Defaults to ``env['data_seed']``.

    Returns:
        torch.utils.data.Dataset: Poison dataset (no clean data).
    """
    torch.random.manual_seed(env['data_seed'] if seed is None else seed)
    train_set = self.dataset.loader['train'].dataset
    # Falsy (``None``/``0``) poison_num falls back to the ratio-based count.
    poison_num = poison_num or round(self.poison_ratio * len(train_set))
    _input, _label = sample_batch(train_set, batch_size=poison_num)
    label_list = ([self.target_class] * len(_label) if poison_label
                  else _label.tolist())
    return TensorListDataset(self.add_mark(_input), label_list)
def get_pred_labels(self) -> torch.Tensor:
    r"""Get predicted labels for test inputs.

    Scores each clean test input and its trigger-stamped counterpart,
    derives an entropy threshold from the clean-score distribution at the
    configured false-positive rate, and flags low-entropy inputs as poison.

    Returns:
        torch.Tensor: ``torch.BoolTensor`` with shape
            ``(2 * defense_input_num)`` — clean scores first, poison second.
    """
    logger = MetricLogger(meter_length=40)
    str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
    logger.create_meters(clean_score=str_format, poison_score=str_format)
    test_set = TensorListDataset(self.test_input, self.test_label)
    test_loader = self.dataset.get_dataloader(mode='valid', dataset=test_set)
    for data in logger.log_every(test_loader):
        _input, _label = self.model.get_data(data)
        trigger_input = self.attack.add_mark(_input)
        logger.meters['clean_score'].update_list(
            self.get_score(_input).tolist())
        logger.meters['poison_score'].update_list(
            self.get_score(trigger_input).tolist())
    clean_score = torch.as_tensor(logger.meters['clean_score'].deque)
    poison_score = torch.as_tensor(logger.meters['poison_score'].deque)
    clean_score_sorted = clean_score.msort()
    # BUGFIX: the FPR threshold is defined on the clean-score distribution,
    # so index with len(clean_score). The original used len(poison_score),
    # which only worked because both meters happen to be the same size.
    threshold_low = float(
        clean_score_sorted[int(self.strip_fpr * len(clean_score))])
    entropy = torch.cat((clean_score, poison_score))
    print(f'Threshold: {threshold_low:5.3f}')
    # Entropy below the threshold indicates a trigger-dominated prediction.
    return entropy < threshold_low
def mix_dataset(self, poison_label: bool = True) -> torch.utils.data.Dataset:
    r"""Mix poison data into the clean training set.

    Args:
        poison_label (bool): Whether to relabel poison data
            as the target class. Defaults to ``True``.

    Returns:
        torch.utils.data.ConcatDataset: Clean dataset followed by
            the poison dataset built from its first inputs.
    """
    clean_dataset = self.dataset.loader['train'].dataset
    _input, _label = dataset_to_list(clean_dataset)
    # BUGFIX: the original indexed the input list with a float
    # (`_input[self.poison_percent * len(clean_dataset)]`), which raises
    # TypeError, and kept the full-length label list for a fractional
    # input set. Take the first `poison_num` samples and match labels.
    poison_num = int(self.poison_percent * len(clean_dataset))
    _input = torch.stack(_input[:poison_num])
    if poison_label:
        _label = [self.target_class] * poison_num
    else:
        _label = _label[:poison_num]
    poison_input = self.add_mark(_input)
    poison_dataset = TensorListDataset(poison_input, _label)
    return torch.utils.data.ConcatDataset([clean_dataset, poison_dataset])
def mix_dataset(self) -> torch.utils.data.Dataset:
    r"""Return the training set with a fraction of labels randomly flipped."""
    clean_set = self.dataset.loader['train'].dataset
    subset, rest = ImageSet.split_dataset(clean_set,
                                          percent=self.poison_percent)
    if not len(subset):
        return clean_set
    _input, _label = dataset_to_tensor(subset)
    # Shift each label by a nonzero random offset modulo the class count,
    # guaranteeing the new label differs from the original one.
    offset = torch.randint_like(_label, low=1, high=self.model.num_classes)
    noisy_label = (_label + offset) % self.model.num_classes
    poison_set = TensorListDataset(_input, noisy_label.tolist())
    return torch.utils.data.ConcatDataset([poison_set, rest])
def get_pred_labels(self) -> torch.Tensor:
    r"""Predict defense labels for each test input and its triggered twin.

    Processes the test set one sample at a time, scoring both the clean
    input and the trigger-stamped input, and concatenates the results
    (clean predictions first, poison predictions second).
    """
    logger = MetricLogger(meter_length=40)
    str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
    logger.create_meters(cls_diff=str_format, jaccard_idx=str_format)
    test_set = TensorListDataset(self.test_input, self.test_label)
    test_loader = self.dataset.get_dataloader(mode='valid', dataset=test_set,
                                              batch_size=1)
    clean_list = []
    poison_list = []
    for data in logger.log_every(test_loader):
        _input: torch.Tensor = data[0].to(env['device'], non_blocking=True)
        trigger_input = self.attack.add_mark(_input)
        clean_list.append(self.get_pred_label(_input[0], logger=logger))
        poison_list.append(self.get_pred_label(trigger_input[0],
                                               logger=logger))
    return torch.as_tensor(clean_list + poison_list, dtype=torch.bool)
def attack(self, epochs: int, optimizer: torch.optim.Optimizer, **kwargs):
    r"""Select the best reflection image by adaptive ranking, then attack.

    Repeatedly trains on candidate reflection triggers sampled by weight,
    scores candidates via attack success rate (ASR), and finally installs
    the highest-scoring reflection as the watermark before running the
    regular backdoor attack.
    """
    pristine_state = copy.deepcopy(self.model.state_dict())
    # W holds the current ASR-based sampling weight of each reflection image.
    W = torch.ones(len(self.reflect_imgs))
    refool_optimizer = torch.optim.SGD(optimizer.param_groups[0]['params'],
                                       lr=self.refool_lr, momentum=0.9,
                                       weight_decay=5e-4)
    for _iter in range(self.rank_iter):
        print('Select iteration: ', output_iter(_iter + 1, self.rank_iter))
        # --- prepare data: sample candidates proportionally to weights ---
        idx = random.choices(range(len(W)), weights=W.tolist(),
                             k=self.refool_sample_num)
        mark = torch.ones_like(self.mark.mark).expand(
            self.refool_sample_num, -1, -1, -1).clone()
        mark[:, :-1] = self.reflect_imgs[idx]
        clean_input, _ = sample_batch(self.target_set, self.refool_sample_num)
        trigger_input = self.add_mark(clean_input, mark=mark)
        dataset = TensorListDataset(trigger_input,
                                    [self.target_class] * len(trigger_input))
        loader = self.dataset.get_dataloader(mode='train', dataset=dataset)
        # --- train briefly on the poisoned data ---
        self.model._train(self.refool_epochs, optimizer=refool_optimizer,
                          loader_train=loader, validate_interval=0,
                          output_freq='epoch', indent=4)
        self.model._validate(indent=4)
        # --- test: measure ASR of the distinct sampled candidates ---
        select_idx = list(set(idx))
        marks = self.reflect_imgs[select_idx]
        asr_result = self._get_asr_result(marks)
        # --- update W: sampled candidates get their measured ASR, the
        # rest get the median as a neutral estimate ---
        W[select_idx] = asr_result
        other_idx = list(set(range(len(W))) - set(idx))
        W[other_idx] = asr_result.median()
        # Restore the pristine model before the next ranking round.
        self.model.load_state_dict(pristine_state)
    # Install the winning reflection image as the watermark color channels.
    self.mark.mark[:-1] = self.reflect_imgs[W.argmax().item()]
    self.poison_set = self.get_poison_dataset(load_mark=False)
    return super().attack(epochs=epochs, optimizer=optimizer, **kwargs)
def attack(self, epoch: int, save=False, **kwargs):
    r"""Run the backdoor attack according to ``self.train_mode``.

    Args:
        epoch (int): Number of training epochs.
        save (bool): Whether to save the trained model.
    """
    if self.train_mode == 'batch':
        # Poison every batch on the fly via ``self.get_data``.
        self.model._train(epoch, save=save,
                          validate_func=self.validate_func,
                          get_data_fn=self.get_data,
                          save_fn=self.save, **kwargs)
    elif self.train_mode == 'dataset':
        # Materialize a static poison dataset and concatenate it with the
        # clean training data.
        clean_dataset = self.dataset.loader['train'].dataset
        sample_size = int(self.poison_percent * len(clean_dataset))
        _input, _label = next(iter(self.dataset.get_dataloader(
            'train', batch_size=sample_size)))
        poison_label = (torch.ones_like(_label) * self.target_class).tolist()
        poison_dataset = TensorListDataset(self.add_mark(_input),
                                           poison_label)
        dataset = torch.utils.data.ConcatDataset(
            [clean_dataset, poison_dataset])
        loader = self.dataset.get_dataloader('train', dataset=dataset)
        self.model._train(epoch, save=save,
                          validate_func=self.validate_func,
                          loader_train=loader,
                          save_fn=self.save, **kwargs)
    elif self.train_mode == 'loss':
        # Inject the poison objective through a custom loss function.
        self.model._train(epoch, save=save,
                          validate_func=self.validate_func,
                          loss_fn=self.loss_fn,
                          save_fn=self.save, **kwargs)
def attack(self, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs):
    r"""Craft poison data (via PGD or GAN interpolation) and retrain.

    Depending on ``self.poison_generation_method``, poison images are either
    PGD-perturbed target-class inputs or WGAN latent-space interpolations
    between source-class and target-class images; the model is then trained
    on the poison set concatenated with the full training set.
    """
    # Sample ``poison_num`` images from the target class.
    target_class_dataset = self.dataset.get_dataset(
        'train', full=True, classes=[self.target_class])
    sample_target_class_dataset, target_original_dataset = \
        self.dataset.split_set(target_class_dataset, self.poison_num)
    sample_target_dataloader = self.dataset.get_dataloader(
        mode='train', dataset=sample_target_class_dataset,
        batch_size=self.poison_num, num_workers=0)
    target_imgs, _ = self.model.get_data(next(iter(sample_target_dataloader)))
    full_set = self.dataset.get_dataset('train', full=False)
    poison_set: TensorListDataset = None  # TODO
    if self.poison_generation_method == 'pgd':
        # Perturb correctly-classified target-class images with PGD,
        # then stamp the watermark on them.
        poison_label = self.target_class * torch.ones(
            len(target_imgs), dtype=torch.long, device=target_imgs.device)
        poison_imgs, _ = self.model.remove_misclassify(
            data=(target_imgs, poison_label))
        poison_imgs, _ = self.pgd.craft_example(_input=poison_imgs)
        poison_imgs = self.add_mark(poison_imgs).cpu()
        poison_set = TensorListDataset(
            poison_imgs, [self.target_class] * len(target_imgs))
    elif self.poison_generation_method == 'gan':
        # For every non-target class, interpolate source and target images
        # in the WGAN latent space to synthesize poison images.
        other_classes = list(range(self.dataset.num_classes))
        other_classes.pop(self.target_class)
        x_list = []
        y_list = []
        for source_class in other_classes:
            print('Process data of Source Class: ', source_class)
            source_class_dataset = self.dataset.get_dataset(
                mode='train', full=True, classes=[source_class])
            sample_source_class_dataset, _ = self.dataset.split_set(
                source_class_dataset, self.poison_num)
            sample_source_class_dataloader = self.dataset.get_dataloader(
                mode='train', dataset=sample_source_class_dataset,
                batch_size=self.poison_num, num_workers=0)
            source_imgs, _ = self.model.get_data(
                next(iter(sample_source_class_dataloader)))
            g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
            d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
            if os.path.exists(g_path) and os.path.exists(d_path) \
                    and not self.train_gan:
                # Reuse previously trained per-class GAN weights.
                self.wgan.G.load_state_dict(
                    torch.load(g_path, map_location=env['device']))
                self.wgan.D.load_state_dict(
                    torch.load(d_path, map_location=env['device']))
                print(f' load model from: \n {g_path}\n {d_path}')
            else:
                # Train the GAN for this class pair, save it, and skip
                # poison generation for this round.
                self.train_gan = True
                self.wgan.reset_parameters()
                gan_dataset = torch.utils.data.ConcatDataset(
                    [source_class_dataset, target_class_dataset])
                gan_dataloader = self.dataset.get_dataloader(
                    mode='train', dataset=gan_dataset,
                    batch_size=self.dataset.batch_size, num_workers=0)
                self.wgan.train(gan_dataloader)
                torch.save(self.wgan.G.state_dict(), g_path)
                torch.save(self.wgan.D.state_dict(), d_path)
                print(f'GAN Model Saved at : \n{g_path}\n{d_path}')
                continue
            source_encode = self.wgan.get_encode_value(source_imgs).detach()
            target_encode = self.wgan.get_encode_value(target_imgs).detach()
            interpolation_encode = (source_encode * self.tau
                                    + target_encode * (1 - self.tau))
            poison_imgs = self.wgan.G(interpolation_encode).detach()
            poison_imgs = self.add_mark(poison_imgs)
            poison_label = [self.target_class] * len(poison_imgs)
            poison_imgs = poison_imgs.cpu()
            x_list.append(poison_imgs)
            y_list.extend(poison_label)
        # All per-class GANs must be trained before poison generation
        # can complete in one pass.
        assert not self.train_gan
        x_list = torch.cat(x_list)
        poison_set = TensorListDataset(x_list, y_list)
    final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
    final_loader = self.dataset.get_dataloader(mode='train',
                                               dataset=final_set,
                                               num_workers=0)
    self.model._train(optimizer=optimizer, lr_scheduler=lr_scheduler,
                      save_fn=self.save, loader_train=final_loader,
                      validate_fn=self.validate_fn, **kwargs)
def attack(self, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs):
    r"""Craft poison data (via PGD or GAN interpolation) and retrain.

    Depending on ``self.poison_generation_method``, poison images are either
    PGD-perturbed target-class inputs or WGAN latent-space interpolations
    between source-class and target-class images; the model is then trained
    on the poison set concatenated with the full training set.
    """
    target_class_set = self.dataset.get_dataset(
        'train', class_list=[self.target_class])
    target_imgs, _ = sample_batch(target_class_set,
                                  batch_size=self.poison_num)
    target_imgs = target_imgs.to(env['device'])
    full_set = self.dataset.get_dataset('train')
    poison_set: TensorListDataset = None  # TODO
    if self.poison_generation_method == 'pgd':
        # Perturb correctly-classified target-class images with PGD
        # (chunk by batch size), then stamp the watermark on them.
        trigger_label = self.target_class * torch.ones(
            len(target_imgs), dtype=torch.long, device=target_imgs.device)
        chunk_results = []
        for data in zip(target_imgs.chunk(self.dataset.batch_size),
                        trigger_label.chunk(self.dataset.batch_size)):
            poison_img, _ = self.model.remove_misclassify(data)
            poison_img, _ = self.pgd.optimize(poison_img)
            chunk_results.append(self.add_mark(poison_img).cpu())
        poison_imgs = torch.cat(chunk_results)
        poison_set = TensorListDataset(
            poison_imgs, [self.target_class] * len(poison_imgs))
    elif self.poison_generation_method == 'gan':
        # For every non-target class, interpolate source and target images
        # in the WGAN latent space to synthesize poison images.
        other_classes = list(range(self.dataset.num_classes))
        other_classes.pop(self.target_class)
        x_list = []
        y_list = []
        for source_class in other_classes:
            print('Process data of Source Class: ', source_class)
            source_class_dataset = self.dataset.get_dataset(
                mode='train', class_list=[source_class])
            sample_source_class_dataset, _ = self.dataset.split_dataset(
                source_class_dataset, self.poison_num)
            source_imgs = dataset_to_tensor(
                sample_source_class_dataset)[0].to(device=env['device'])
            g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
            d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
            if os.path.exists(g_path) and os.path.exists(d_path) \
                    and not self.train_gan:
                # Reuse previously trained per-class GAN weights.
                self.wgan.G.load_state_dict(
                    torch.load(g_path, map_location=env['device']))
                self.wgan.D.load_state_dict(
                    torch.load(d_path, map_location=env['device']))
                print(f' load model from: \n {g_path}\n {d_path}')
            else:
                # Train the GAN for this class pair, save it, and skip
                # poison generation for this round.
                self.train_gan = True
                self.wgan.reset_parameters()
                gan_dataset = torch.utils.data.ConcatDataset(
                    [source_class_dataset, target_class_set])
                gan_dataloader = self.dataset.get_dataloader(
                    mode='train', dataset=gan_dataset,
                    batch_size=self.dataset.batch_size, num_workers=0)
                self.wgan.train(gan_dataloader)
                torch.save(self.wgan.G.state_dict(), g_path)
                torch.save(self.wgan.D.state_dict(), d_path)
                print(f'GAN Model Saved at : \n{g_path}\n{d_path}')
                continue
            for source_chunk, target_chunk in zip(
                    source_imgs.chunk(self.dataset.batch_size),
                    target_imgs.chunk(self.dataset.batch_size)):
                source_encode = self.wgan.get_encode_value(
                    source_chunk).detach()
                target_encode = self.wgan.get_encode_value(
                    target_chunk).detach()
                interpolation_encode = (source_encode * self.tau
                                        + target_encode * (1 - self.tau))
                poison_imgs = self.wgan.G(interpolation_encode).detach()
                poison_imgs = self.add_mark(poison_imgs)
                poison_imgs = poison_imgs.cpu()
                x_list.append(poison_imgs)
            y_list.extend([self.target_class] * len(source_imgs))
        # All per-class GANs must be trained before poison generation
        # can complete in one pass.
        assert not self.train_gan
        x_list = torch.cat(x_list)
        poison_set = TensorListDataset(x_list, y_list)
    final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
    final_loader = self.dataset.get_dataloader(mode='train',
                                               dataset=final_set,
                                               num_workers=0)
    self.model._train(optimizer=optimizer, lr_scheduler=lr_scheduler,
                      save_fn=self.save, loader_train=final_loader,
                      validate_fn=self.validate_fn, **kwargs)