def get_datasets(self) -> tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    r"""Get clean and poison datasets.

    Returns:
        (torch.utils.data.Dataset, torch.utils.data.Dataset):
            Clean training dataset and poison training dataset.
    """
    if self.attack.poison_set is None:
        self.attack.poison_set = self.attack.get_poison_dataset(
            poison_num=len(self.dataset.loader['train'].dataset))
    if not self.defense_input_num:
        return self.dataset.loader['train'].dataset, self.attack.poison_set
    if self.attack.train_mode != 'dataset':
        poison_num = int(self.defense_input_num * self.attack.poison_percent)
        clean_num = self.defense_input_num - poison_num
        clean_input, clean_label = sample_batch(
            self.dataset.loader['train'].dataset, batch_size=clean_num)
        trigger_input, trigger_label = sample_batch(self.attack.poison_set,
                                                    batch_size=poison_num)
        clean_set = TensorListDataset(clean_input, clean_label.tolist())
        poison_set = TensorListDataset(trigger_input, trigger_label.tolist())
        return clean_set, poison_set
    # NOTE: when train_mode == 'dataset', control falls through and returns None.
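# A standalone sketch of the clean/poison split above: `defense_input_num`
# samples are divided according to the attack's poison percentage
# (numbers illustrative, not library defaults):
defense_input_num, poison_percent = 100, 0.1
poison_num = int(defense_input_num * poison_percent)  # 10 poison samples
clean_num = defense_input_num - poison_num            # 90 clean samples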
def sample_data(self) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    r"""Sample data from each class. The returned data dict is:

    * ``'other'``: ``(input, label)`` from source classes
      with batch size ``self.class_sample_num * len(source_class)``.
    * ``'target'``: ``(input, label)`` from target class
      with batch size ``self.class_sample_num``.

    Returns:
        dict[str, tuple[torch.Tensor, torch.Tensor]]: Data dict.
    """
    source_class = self.source_class or list(range(self.dataset.num_classes))
    source_class = source_class.copy()
    if self.target_class in source_class:
        source_class.remove(self.target_class)
    other_x, other_y = [], []
    dataset = self.dataset.get_dataset('train')
    for _class in source_class:
        class_set = self.dataset.get_class_subset(dataset, class_list=[_class])
        _input, _label = sample_batch(class_set, batch_size=self.class_sample_num)
        other_x.append(_input)
        other_y.append(_label)
    other_x = torch.cat(other_x)
    other_y = torch.cat(other_y)
    target_set = self.dataset.get_class_subset(dataset, class_list=[self.target_class])
    target_x, target_y = sample_batch(target_set, batch_size=self.class_sample_num)
    data = {'other': (other_x, other_y),
            'target': (target_x, target_y)}
    return data
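# A minimal standalone sketch of the class-balanced sampling pattern above,
# with a toy in-memory dataset standing in for `self.dataset` (all names here
# are illustrative, not part of the library API):
import torch

num_classes, class_sample_num, target_class = 3, 4, 0
inputs = torch.randn(30, 1, 8, 8)
labels = torch.arange(30) % num_classes  # 10 samples per class

def sample_class(c: int) -> tuple[torch.Tensor, torch.Tensor]:
    # pick `class_sample_num` random samples belonging to class c
    idx = (labels == c).nonzero(as_tuple=True)[0]
    pick = idx[torch.randperm(len(idx))[:class_sample_num]]
    return inputs[pick], labels[pick]

source_classes = [c for c in range(num_classes) if c != target_class]
pairs = [sample_class(c) for c in source_classes]
other_x = torch.cat([x for x, _ in pairs])  # batch size: class_sample_num * len(source_classes)
other_y = torch.cat([y for _, y in pairs])
target_x, target_y = sample_class(target_class)
data = {'other': (other_x, other_y), 'target': (target_x, target_y)}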
def validate_mask_generator(self):
    loader = self.dataset.loader['valid']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)
    idx = torch.randperm(len(dataset))
    pos = 0

    print_prefix = 'Validate'
    header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
    header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
    for data in logger.log_every(loader, header=header):
        _input, _label = self.model.get_data(data)
        batch_size = len(_input)
        data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
        _input2, _label2 = self.model.get_data(data2)
        pos += batch_size

        _mask = self.get_mask(_input)
        _mask2 = self.get_mask(_input2)

        input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
        mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

        loss_div = input_dist.div(mask_dist).mean()
        loss_norm = _mask.sub(self.mask_density).relu().mean()

        loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
        logger.update(n=batch_size, loss=loss.item(),
                      div=loss_div.item(), norm=loss_norm.item())
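# Standalone sketch of the two mask-generator losses evaluated above, on toy
# tensors (the density/weight constants are illustrative, not library defaults):
import torch

mask_density, lambda_norm, lambda_div = 0.032, 100.0, 1.0
_input, _input2 = torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32)
_mask, _mask2 = torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32)

input_dist = (_input - _input2).flatten(1).norm(p=2, dim=1)
mask_dist = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5  # avoid division by zero
loss_div = input_dist.div(mask_dist).mean()        # distinct inputs should yield distinct masks
loss_norm = _mask.sub(mask_density).relu().mean()  # penalize mask values above the density budget
loss = lambda_norm * loss_norm + lambda_div * loss_div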
def gen_seed_data(self) -> dict[str, np.ndarray]:
    r"""Generate seed data.

    Returns:
        dict[str, numpy.ndarray]:
            Seed data dict with keys ``'input'`` and ``'label'``.
    """
    torch.manual_seed(env['seed'])
    if self.seed_data_num % self.model.num_classes:
        raise ValueError(
            f'seed_data_num({self.seed_data_num:d}) % '
            f'num_classes({self.model.num_classes:d}) should be 0.')
    seed_class_num: int = self.seed_data_num // self.model.num_classes
    x, y = [], []
    for _class in range(self.model.num_classes):
        class_set = self.dataset.get_dataset(mode='train', class_list=[_class])
        _input, _label = sample_batch(class_set, batch_size=seed_class_num)
        x.append(_input)
        y.append(_label)
    x = torch.cat(x).numpy()
    y = torch.cat(y).numpy()
    seed_data = {'input': x, 'label': y}
    seed_path = os.path.join(self.folder_path, f'seed_{self.seed_data_num}.npz')
    np.savez(seed_path, **seed_data)
    print('seed data saved at: ', seed_path)
    return seed_data
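# The archive written above can be read back with numpy; a short sketch
# (the file name depends on seed_data_num and is illustrative here):
import numpy as np

seed_data = np.load('seed_100.npz')
x, y = seed_data['input'], seed_data['label']  # keys written by gen_seed_data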
def get_poison_dataset(self, poison_num: int = None,
                       load_mark: bool = True,
                       seed: int = None) -> torch.utils.data.Dataset:
    r"""Get poison dataset from target class (no clean data).

    Args:
        poison_num (int): Number of poison data.
            Defaults to ``self.poison_num``.
        load_mark (bool): Whether to load previously saved watermark.
            This should be ``False`` during attack. Defaults to ``True``.
        seed (int): Random seed to sample poison input indices.
            Defaults to ``env['data_seed']``.

    Returns:
        torch.utils.data.Dataset:
            Poison dataset from target class (no clean data).
    """
    file_path = os.path.join(self.folder_path, self.get_filename() + '.npy')
    if load_mark:
        if os.path.isfile(file_path):
            self.load_mark = False
            self.mark.load_mark(file_path, already_processed=True)
        else:
            raise FileNotFoundError(file_path)
    if seed is None:
        seed = env['data_seed']
    torch.random.manual_seed(seed)
    poison_num = min(poison_num or self.poison_num, len(self.target_set))
    _input, _label = sample_batch(self.target_set, batch_size=poison_num)
    _label = _label.tolist()
    trigger_input = self.add_mark(_input)
    return TensorListDataset(trigger_input, _label)
def get_poison_dataset(self, poison_label: bool = True,
                       poison_num: int = None,
                       seed: int = None) -> torch.utils.data.Dataset:
    r"""Get poison dataset (no clean data).

    Args:
        poison_label (bool): Whether to use target poison label
            for poison data. Defaults to ``True``.
        poison_num (int): Number of poison data.
            Defaults to ``round(self.poison_ratio * len(train_set))``.
        seed (int): Random seed to sample poison input indices.
            Defaults to ``env['data_seed']``.

    Returns:
        torch.utils.data.Dataset: Poison dataset (no clean data).
    """
    if seed is None:
        seed = env['data_seed']
    torch.random.manual_seed(seed)
    train_set = self.dataset.loader['train'].dataset
    poison_num = poison_num or round(self.poison_ratio * len(train_set))
    _input, _label = sample_batch(train_set, batch_size=poison_num)
    _label = _label.tolist()
    if poison_label:
        _label = [self.target_class] * len(_label)
    trigger_input = self.add_mark(_input)
    return TensorListDataset(trigger_input, _label)
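# Standalone sketch of the seeding pattern shared by both get_poison_dataset
# variants above: fixing the RNG seed before drawing sample indices makes the
# poison subset reproducible across runs (toy example, illustrative names):
import torch

def sample_idx(dataset_len: int, num: int, seed: int) -> torch.Tensor:
    torch.random.manual_seed(seed)
    return torch.randperm(dataset_len)[:num]

assert sample_idx(1000, 10, seed=42).equal(sample_idx(1000, 10, seed=42))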
def _get_cross_data(self, data: tuple[torch.Tensor, torch.Tensor],
                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Get cross-trigger mode data.

    Sample another batch from the train set
    and apply its marks and masks to the current batch.
    """
    _input, _label = self.model.get_data(data)
    batch_size = len(_input)
    data2 = sample_batch(self.train_set, idx=self.idx[self.pos:self.pos + batch_size])
    _input2, _label2 = self.model.get_data(data2)
    self.pos += batch_size
    if self.pos >= len(self.idx):
        self.pos = 0
        self.idx = torch.randperm(len(self.idx))
    mark, mask = self.get_mark(_input2), self.get_mask(_input2)
    _input = _input + mask * (mark - _input)
    return _input, _label
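# The blending line above is the usual convex mask composition
# x' = (1 - mask) * x + mask * mark, written as x + mask * (mark - x).
# A standalone check on toy tensors:
import torch

x, mark = torch.rand(4, 3, 32, 32), torch.rand(4, 3, 32, 32)
mask = torch.rand(4, 1, 32, 32)  # single-channel mask broadcasts over channels
blended = x + mask * (mark - x)
assert torch.allclose(blended, (1 - mask) * x + mask * mark)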
def train_mask_generator(self, verbose: bool = True):
    r"""Train :attr:`self.mask_generator`."""
    optimizer = torch.optim.Adam(self.mask_generator.parameters(),
                                 lr=1e-2, betas=(0.5, 0.9))
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.train_mask_epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)
    print_prefix = 'Mask Epoch'
    for _epoch in range(self.train_mask_epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        self.mask_generator.train()
        for data in logger.log_every(loader, header=header) if verbose else loader:
            optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)

            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()

            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            loss.backward()
            optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), norm=loss_norm.item())
        lr_scheduler.step()
        self.mask_generator.eval()
        if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0
                        or _epoch == self.train_mask_epochs):
            self.validate_mask_generator()
    optimizer.zero_grad()
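# Minimal sketch of the optimizer/scheduler pairing used above: Adam with
# betas=(0.5, 0.9) and a cosine schedule that anneals the learning rate
# toward zero over the epoch budget (toy parameter, no real training):
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-2, betas=(0.5, 0.9))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
for _epoch in range(25):
    optimizer.step()  # placeholder for one epoch of updates
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # ~0 after the full schedule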
def attack(self, epochs: int, optimizer: torch.optim.Optimizer, **kwargs):
    model_dict = copy.deepcopy(self.model.state_dict())
    W = torch.ones(len(self.reflect_imgs))
    refool_optimizer = torch.optim.SGD(optimizer.param_groups[0]['params'],
                                       lr=self.refool_lr, momentum=0.9,
                                       weight_decay=5e-4)
    # logger = MetricLogger(meter_length=35)
    # logger.create_meters(asr='{median:.3f} ({min:.3f} {max:.3f})')
    # iterator = logger.log_every(range(self.rank_iter))
    for _iter in range(self.rank_iter):
        print('Select iteration: ', output_iter(_iter + 1, self.rank_iter))
        # prepare data
        idx = random.choices(range(len(W)), weights=W.tolist(), k=self.refool_sample_num)
        mark = torch.ones_like(self.mark.mark).expand(
            self.refool_sample_num, -1, -1, -1).clone()
        mark[:, :-1] = self.reflect_imgs[idx]
        clean_input, _ = sample_batch(self.target_set, self.refool_sample_num)
        trigger_input = self.add_mark(clean_input, mark=mark)
        dataset = TensorListDataset(trigger_input, [self.target_class] * len(trigger_input))
        loader = self.dataset.get_dataloader(mode='train', dataset=dataset)
        # train
        self.model._train(self.refool_epochs, optimizer=refool_optimizer,
                          loader_train=loader, validate_interval=0,
                          output_freq='epoch', indent=4)
        self.model._validate(indent=4)
        # test
        select_idx = list(set(idx))
        marks = self.reflect_imgs[select_idx]
        asr_result = self._get_asr_result(marks)
        # update W
        W[select_idx] = asr_result
        other_idx = list(set(range(len(W))) - set(idx))
        W[other_idx] = asr_result.median()
        # logger.reset().update_list(asr=asr_result)

        # restore the original model before the next selection iteration
        self.model.load_state_dict(model_dict)
    self.mark.mark[:-1] = self.reflect_imgs[W.argmax().item()]
    self.poison_set = self.get_poison_dataset(load_mark=False)
    return super().attack(epochs=epochs, optimizer=optimizer, **kwargs)
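# Standalone sketch of the candidate-selection loop above: reflection images
# are sampled in proportion to their current scores W, scores of sampled
# candidates are overwritten by their measured ASR, and the rest fall back to
# the median (toy numbers; `asr_result` is a stand-in for real measurements):
import random
import torch

W = torch.ones(5)  # one effectiveness score per candidate reflection image
idx = random.choices(range(len(W)), weights=W.tolist(), k=8)
select_idx = list(set(idx))
asr_result = torch.rand(len(select_idx))  # placeholder attack success rates
W[select_idx] = asr_result
other_idx = list(set(range(len(W))) - set(idx))
W[other_idx] = asr_result.median()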
def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
           validate_interval: int = 1, save: bool = False,
           verbose: bool = True, **kwargs):
    if verbose:
        print('train mask generator')
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_()
    self.model.requires_grad_(False)
    self.train_mask_generator(verbose=verbose)
    if verbose:
        print()
        print('train mark generator and model')

    self.mark_generator.requires_grad_()
    self.mask_generator.requires_grad_(False)
    if not self.natural:
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        self.model.activate_params(params)

    mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(),
                                      lr=1e-2, betas=(0.5, 0.9))
    mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        mark_optimizer, T_max=epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, ce=None)

    if validate_interval != 0:
        best_validate_result = self.validate_fn(verbose=verbose)
        best_asr = best_validate_result[0]
    for _epoch in range(epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        if not self.natural:
            self.model.train()
        self.mark_generator.train()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            'Epoch', output_iter(_epoch, epochs), **ansi)
        header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
        for data in logger.log_every(loader, header=header) if verbose else loader:
            if not self.natural:
                optimizer.zero_grad()
            mark_optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size
            final_input, final_label = _input.clone(), _label.clone()

            # generate trigger input
            trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
            trigger_int = int(trigger_int)
            if random.uniform(0, 1) < trigger_dec:
                trigger_int += 1
            x = _input[:trigger_int]
            trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
            trigger_input = x + trigger_mask * (trigger_mark - x)
            final_input[:trigger_int] = trigger_input
            final_label[:trigger_int] = self.target_class

            # generate cross input
            cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
            cross_int = int(cross_int)
            if random.uniform(0, 1) < cross_dec:
                cross_int += 1
            x = _input[trigger_int:trigger_int + cross_int]
            x2 = _input2[trigger_int:trigger_int + cross_int]
            cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
            cross_input = x + cross_mask * (cross_mark - x)
            final_input[trigger_int:trigger_int + cross_int] = cross_input

            # div loss
            if len(trigger_input) <= len(cross_input):
                length = len(trigger_input)
                cross_input = cross_input[:length]
                cross_mark = cross_mark[:length]
                cross_mask = cross_mask[:length]
            else:
                length = len(cross_input)
                trigger_input = trigger_input[:length]
                trigger_mark = trigger_mark[:length]
                trigger_mask = trigger_mask[:length]
            input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
            mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_ce = self.model.loss(final_input, final_label)
            loss_div = input_dist.div(mark_dist).mean()

            loss = loss_ce + self.lambda_div * loss_div
            loss.backward()
            if not self.natural:
                optimizer.step()
            mark_optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), ce=loss_ce.item())
        if not self.natural and lr_scheduler:
            lr_scheduler.step()
        mark_scheduler.step()
        if not self.natural:
            self.model.eval()
        self.mark_generator.eval()
        if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
            validate_result = self.validate_fn(verbose=verbose)
            cur_asr = validate_result[0]
            if cur_asr >= best_asr:
                best_validate_result = validate_result
                best_asr = cur_asr
                if save:
                    self.save()
    if not self.natural:
        optimizer.zero_grad()
    mark_optimizer.zero_grad()
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_(False)
    self.model.requires_grad_(False)
    return best_validate_result
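# Standalone sketch of the fractional-count trick used twice above: math.modf
# splits batch_size * percent into fractional and integer parts, and the
# fractional part is realized probabilistically, so the poison/cross counts
# are exact in expectation (illustrative numbers):
import math
import random

def stochastic_count(batch_size: int, percent: float) -> int:
    dec, integer = math.modf(batch_size * percent)
    return int(integer) + (1 if random.uniform(0, 1) < dec else 0)

counts = [stochastic_count(batch_size=50, percent=0.11) for _ in range(10000)]
print(sum(counts) / len(counts))  # ≈ 5.5 on average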
def attack(self, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
           **kwargs):
    target_class_set = self.dataset.get_dataset('train', class_list=[self.target_class])
    target_imgs, _ = sample_batch(target_class_set, batch_size=self.poison_num)
    target_imgs = target_imgs.to(env['device'])
    full_set = self.dataset.get_dataset('train')
    poison_set: TensorListDataset = None    # TODO
    match self.poison_generation_method:
        case 'pgd':
            trigger_label = self.target_class * torch.ones(
                len(target_imgs), dtype=torch.long, device=target_imgs.device)
            result = []
            for data in zip(target_imgs.chunk(self.dataset.batch_size),
                            trigger_label.chunk(self.dataset.batch_size)):
                poison_img, _ = self.model.remove_misclassify(data)
                poison_img, _ = self.pgd.optimize(poison_img)
                poison_img = self.add_mark(poison_img).cpu()
                result.append(poison_img)
            poison_imgs = torch.cat(result)
            poison_set = TensorListDataset(poison_imgs, [self.target_class] * len(poison_imgs))
            # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])
        case 'gan':
            other_classes = list(range(self.dataset.num_classes))
            other_classes.pop(self.target_class)
            x_list = []
            y_list = []
            for source_class in other_classes:
                print('Process data of Source Class: ', source_class)
                source_class_dataset = self.dataset.get_dataset(mode='train',
                                                                class_list=[source_class])
                sample_source_class_dataset, _ = self.dataset.split_dataset(
                    source_class_dataset, self.poison_num)
                source_imgs = dataset_to_tensor(sample_source_class_dataset)[0].to(device=env['device'])

                g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
                d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
                if os.path.exists(g_path) and os.path.exists(d_path) and not self.train_gan:
                    self.wgan.G.load_state_dict(torch.load(g_path, map_location=env['device']))
                    self.wgan.D.load_state_dict(torch.load(d_path, map_location=env['device']))
                    print(f'  load model from:\n  {g_path}\n  {d_path}')
                else:
                    self.train_gan = True
                    self.wgan.reset_parameters()
                    gan_dataset = torch.utils.data.ConcatDataset([source_class_dataset,
                                                                  target_class_set])
                    gan_dataloader = self.dataset.get_dataloader(
                        mode='train', dataset=gan_dataset,
                        batch_size=self.dataset.batch_size, num_workers=0)
                    self.wgan.train(gan_dataloader)
                    torch.save(self.wgan.G.state_dict(), g_path)
                    torch.save(self.wgan.D.state_dict(), d_path)
                    print(f'GAN Model Saved at:\n{g_path}\n{d_path}')
                    continue
                for source_chunk, target_chunk in zip(source_imgs.chunk(self.dataset.batch_size),
                                                      target_imgs.chunk(self.dataset.batch_size)):
                    source_encode = self.wgan.get_encode_value(source_chunk).detach()
                    target_encode = self.wgan.get_encode_value(target_chunk).detach()
                    # noise = torch.randn_like(source_encode)
                    # source_img = self.wgan.G(source_encode)
                    # target_img = self.wgan.G(target_encode)
                    # if not os.path.exists('./imgs'):
                    #     os.makedirs('./imgs')
                    # for i in range(len(source_img)):
                    #     F.to_pil_image(source_img[i]).save(f'./imgs/source_{i}.png')
                    # for i in range(len(target_img)):
                    #     F.to_pil_image(target_img[i]).save(f'./imgs/target_{i}.png')
                    # exit()
                    interpolation_encode = source_encode * self.tau + target_encode * (1 - self.tau)
                    poison_imgs = self.wgan.G(interpolation_encode).detach()
                    poison_imgs = self.add_mark(poison_imgs)
                    poison_imgs = poison_imgs.cpu()
                    x_list.append(poison_imgs)
                # labels are extended once per source class to match the chunks above
                y_list.extend([self.target_class] * len(source_imgs))
            # all GANs must already be trained; if any were (re)trained this run,
            # poison generation was skipped via `continue`, so rerun afterwards
            assert not self.train_gan
            x_list = torch.cat(x_list)
            poison_set = TensorListDataset(x_list, y_list)
            # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])
    final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
    # final_set = poison_set
    final_loader = self.dataset.get_dataloader(mode='train', dataset=final_set, num_workers=0)
    self.model._train(optimizer=optimizer, lr_scheduler=lr_scheduler,
                      save_fn=self.save, loader_train=final_loader,
                      validate_fn=self.validate_fn, **kwargs)
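# Standalone sketch of the latent interpolation at the heart of the 'gan'
# branch above: poison encodings lie on the segment between the source and
# target encodings, weighted by tau (toy tensors; 100 stands in for the
# latent dimension self.noise_dim):
import torch

tau = 0.4
source_encode = torch.randn(8, 100)
target_encode = torch.randn(8, 100)
interpolation_encode = source_encode * tau + target_encode * (1 - tau)
# larger tau keeps the poison encoding closer to the source image's encoding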