class Grad_Train(Defense):
    name: str = 'grad_train'

    def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255,
                 pgd_iteration: int = 7, grad_lambda: float = 10, **kwargs):
        super().__init__(**kwargs)
        self.param_list['grad_train'] = ['grad_lambda']
        self.grad_lambda = grad_lambda
        self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon,
                       iteration=pgd_iteration, stop_threshold=None)

    def detect(self, **kwargs):
        self.model._train(loss_fn=self.loss_fn, validate_func=self.validate_func,
                          verbose=True, **kwargs)

    def loss_fn(self, _input, _label, **kwargs):
        # Replicate each sample 4 times and perturb with random noise
        # scaled into the epsilon ball.
        new_input = _input.repeat(4, 1, 1, 1)
        new_label = _label.repeat(4)
        noise = torch.randn_like(new_input)
        noise = noise / noise.norm(p=float('inf')) * self.pgd_epsilon
        new_input = new_input + noise
        new_input = new_input.clamp(0, 1).detach()
        new_input.requires_grad_()
        loss = self.model.loss(new_input, new_label)
        # create_graph=True keeps the gradient penalty differentiable
        # w.r.t. the model parameters.
        grad = torch.autograd.grad(loss, new_input, create_graph=True)[0]
        new_loss = loss + self.grad_lambda * grad.flatten(start_dim=1).norm(p=1, dim=1).mean()
        return new_loss

    def validate_func(self, get_data_fn=None, loss_fn=None, **kwargs) -> tuple[float, float, float]:
        clean_loss, clean_acc, _ = self.model._validate(print_prefix='Validate Clean',
                                                        get_data_fn=None, **kwargs)
        adv_loss, adv_acc, _ = self.model._validate(print_prefix='Validate Adv',
                                                    get_data_fn=self.get_data, **kwargs)
        # TODO: return value
        if self.clean_acc - clean_acc > 20 and self.clean_acc > 40:
            adv_acc = 0.0
        return clean_loss + adv_loss, adv_acc, clean_acc

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor],
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data, **kwargs)

        def loss_fn(X: torch.FloatTensor):
            return -self.model.loss(X, _label)
        adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn)
        return adv_x, _label

    def save(self, **kwargs):
        self.model.save(folder_path=self.folder_path, suffix='_grad_train',
                        verbose=True, **kwargs)
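# A minimal self-contained sketch of the gradient-regularization objective used in
# Grad_Train.loss_fn above: replicate the batch, perturb it with random noise inside
# the epsilon ball, then add the L1 norm of the input gradient to the loss.
# The toy model and data below are hypothetical stand-ins, not trojanzoo APIs.
import torch
import torch.nn as nn

def grad_reg_loss(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
                  epsilon: float = 8 / 255, grad_lambda: float = 10.0) -> torch.Tensor:
    x = x.repeat(4, 1, 1, 1)                      # 4 noisy copies per sample
    y = y.repeat(4)
    noise = torch.randn_like(x)
    noise = noise / noise.norm(p=float('inf')) * epsilon
    x = (x + noise).clamp(0, 1).detach().requires_grad_()
    loss = nn.functional.cross_entropy(model(x), y)
    # create_graph=True so the penalty itself is differentiable w.r.t. model weights
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    return loss + grad_lambda * grad.flatten(start_dim=1).norm(p=1, dim=1).mean()

if __name__ == '__main__':
    toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    x = torch.rand(2, 3, 8, 8)
    y = torch.randint(10, (2,))
    grad_reg_loss(toy, x, y).backward()           # gradients flow into toy's weights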
class IMC_AdvTrain(IMC):
    r"""
    Input Model Co-optimization (IMC) Backdoor Attack is described in detail
    in the paper `A Tale of Evil Twins`_ by Ren Pang.

    Based on :class:`trojanzoo.attacks.backdoor.BadNet`, IMC optimizes the
    watermark pixel values using a PGD attack to enhance the performance.

    Args:
        target_value (float): The proportion of malicious images in the training set (Max 0.5). Default: 10.

    .. _A Tale of Evil Twins:
        https://arxiv.org/abs/1911.01559
    """
    name: str = 'imc_advtrain'

    def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255,
                 pgd_iteration: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon,
                       iteration=pgd_iteration, stop_threshold=None)

    def get_poison_data(self, data: tuple[torch.Tensor, torch.Tensor],
                        **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data)
        # poison_num may be fractional: round it up stochastically so the
        # expected poison count per batch equals poison_num.
        decimal, integer = math.modf(self.poison_num)
        integer = int(integer)
        if random.uniform(0, 1) < decimal:
            integer += 1
        if integer:
            org_input, org_label = _input, _label
            _input = self.add_mark(org_input[:integer])
            _label = self.target_class * torch.ones_like(org_label[:integer])
        return _input, _label

    def attack(self, epoch: int, save=False, **kwargs):
        self.adv_train(epoch, save=save,
                       validate_fn=self.validate_fn, get_data_fn=self.get_data,
                       epoch_fn=self.epoch_fn, save_fn=self.save, **kwargs)

    def adv_train(self, epoch: int, optimizer: optim.Optimizer,
                  lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False,
                  verbose=True, indent=0, epoch_fn: Callable = None, **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'
        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if callable(epoch_fn):
                self.model.activate_params([])
                epoch_fn()
                self.model.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)
                poison_input, poison_label = self.get_poison_data(data)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                                 loss_fn=loss_fn, iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    # train jointly on adversarial and poisoned samples
                    x = torch.cat((adv_x, poison_input))
                    y = torch.cat((_label, poison_label))
                    loss = self.model.loss(x, y)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f},'.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str,
                       prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    '
                               f'Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            self.save()
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
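# get_poison_data above handles a fractional poison_num by stochastic rounding:
# with probability equal to the fractional part, one extra sample is poisoned, so
# the expected poison count per batch equals poison_num exactly. A self-contained
# sketch of that trick (names here are illustrative only):
import math
import random

def stochastic_round(poison_num: float) -> int:
    decimal, integer = math.modf(poison_num)
    integer = int(integer)
    if random.uniform(0, 1) < decimal:
        integer += 1
    return integer

if __name__ == '__main__':
    counts = [stochastic_round(2.3) for _ in range(100_000)]
    print(sum(counts) / len(counts))    # ~2.3 on average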
class IMC_Poison(PoisonBasic):
    name: str = 'imc_poison'

    # TODO: change PGD to Uname.optimizer
    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--pgd_alpha', dest='pgd_alpha', type=float)
        group.add_argument('--pgd_eps', dest='pgd_eps', type=float)
        group.add_argument('--pgd_iter', dest='pgd_iter', type=int)
        group.add_argument('--stop_conf', dest='stop_conf', type=float)
        group.add_argument('--magnet', dest='magnet', action='store_true')
        group.add_argument('--randomized_smooth', dest='randomized_smooth', action='store_true')
        group.add_argument('--curvature', dest='curvature', action='store_true')

    def __init__(self, pgd_alpha: float = 1.0, pgd_eps: float = 8.0, pgd_iter: int = 8,
                 stop_conf: float = 0.9, magnet: bool = False,
                 randomized_smooth: bool = False, curvature: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.param_list['pgd'] = ['pgd_alpha', 'pgd_eps', 'pgd_iter']
        self.pgd_alpha: float = pgd_alpha
        self.pgd_eps: float = pgd_eps
        self.pgd_iter: int = pgd_iter
        self.pgd = PGD_Optimizer(pgd_alpha=self.pgd_alpha / 255,
                                 pgd_eps=self.pgd_eps / 255,
                                 iteration=self.pgd_iter)
        self.stop_conf: float = stop_conf
        if magnet:
            self.magnet: MagNet = MagNet(dataset=self.dataset, pretrain=True)
        self.randomized_smooth: bool = randomized_smooth
        if curvature:
            self.curvature: Curvature = Curvature(model=self.model)

    def attack(self, epoch: int, **kwargs):
        # model._validate()
        total = 0
        target_conf_list = []
        target_acc_list = []
        clean_acc_list = []
        pgd_norm_list = []
        pgd_alpha = 1.0 / 255
        pgd_eps = 8.0 / 255
        if self.dataset.name in ['cifar10', 'gtsrb', 'isic2018']:
            pgd_alpha = 1.0 / 255
            pgd_eps = 8.0 / 255
        if self.dataset.name in ['sample_imagenet', 'sample_vggface2']:
            pgd_alpha = 0.25 / 255
            pgd_eps = 2.0 / 255
        pgd_checker = PGD(pgd_alpha=pgd_alpha, pgd_eps=pgd_eps, iteration=8,
                          dataset=self.dataset, model=self.model,
                          target_idx=self.target_idx, stop_threshold=0.95)
        easy = 0
        difficult = 0
        normal = 0
        loader = self.dataset.get_dataloader(mode='valid',
                                             batch_size=self.dataset.test_batch_size)
        if 'curvature' in self.__dict__.keys():
            benign_curvature = self.curvature.benign_measure()
            tgt_curvature_list = []
            org_curvature_list = []
        if self.randomized_smooth:
            org_conf_list = []
            tgt_conf_list = []
        if 'magnet' in self.__dict__.keys():
            org_magnet_list = []
            tgt_magnet_list = []
        for data in loader:
            print(easy, normal, difficult)
            if normal >= 100:
                break
            self.model.load()
            _input, _label = self.model.remove_misclassify(data)
            if len(_label) == 0:
                continue
            target_label = self.model.generate_target(_input, idx=self.target_idx)
            self.temp_input = _input
            self.temp_label = target_label
            # Triage samples by how quickly a plain PGD attack succeeds.
            _, _iter = pgd_checker.craft_example(_input)
            if _iter is None:
                difficult += 1
                continue
            if _iter < 4:
                easy += 1
                continue
            normal += 1
            target_conf, target_acc, clean_acc = self.validate_fn()
            noise = torch.zeros_like(_input)
            poison_input = self.craft_example(_input=_input, _label=target_label,
                                              epoch=epoch, noise=noise, **kwargs)
            pgd_norm = float(noise.norm(p=float('inf')))
            target_conf, target_acc, clean_acc = self.validate_fn()
            target_conf_list.append(target_conf)
            target_acc_list.append(target_acc)
            clean_acc_list.append(max(self.clean_acc - clean_acc, 0.0))
            pgd_norm_list.append(pgd_norm)
            print(f'[{total + 1} / 100]\n'
                  f'target confidence: {np.mean(target_conf_list)}({np.std(target_conf_list)})\n'
                  f'target accuracy: {np.mean(target_acc_list)}({np.std(target_acc_list)})\n'
                  f'clean accuracy Drop: {np.mean(clean_acc_list)}({np.std(clean_acc_list)})\n'
                  f'PGD Norm: {np.mean(pgd_norm_list)}({np.std(pgd_norm_list)})\n\n\n')
            org_conf = self.model.get_target_prob(_input=poison_input, target=_label)
            tgt_conf = self.model.get_target_prob(_input=poison_input, target=target_label)
            if 'curvature' in self.__dict__.keys():
                org_curvature_list.extend(
                    to_list(self.curvature.measure(poison_input, _label)))  # type: ignore
                tgt_curvature_list.extend(
                    to_list(self.curvature.measure(poison_input, target_label)))  # type: ignore
                print('Curvature:')
                print(f'    org_curvature: {ks_2samp(org_curvature_list, benign_curvature)}')  # type: ignore
                print(f'    tgt_curvature: {ks_2samp(tgt_curvature_list, benign_curvature)}')  # type: ignore
                print()
            if self.randomized_smooth:
                org_new = self.model.get_target_prob(_input=poison_input, target=_label,
                                                     randomized_smooth=True)
                tgt_new = self.model.get_target_prob(_input=poison_input, target=target_label,
                                                     randomized_smooth=True)
                org_increase = (org_new - org_conf).clamp(min=0.0)
                tgt_decrease = (tgt_new - tgt_conf).clamp(min=0.0)
                org_conf_list.extend(to_list(org_increase))  # type: ignore
                tgt_conf_list.extend(to_list(tgt_decrease))  # type: ignore
                print('Randomized Smooth:')
                print(f'    org_confidence: {np.mean(org_conf_list)}')  # type: ignore
                print(f'    tgt_confidence: {np.mean(tgt_conf_list)}')  # type: ignore
                print()
            if 'magnet' in self.__dict__.keys():
                poison_input = self.magnet(poison_input)
                org_new = self.model.get_target_prob(_input=poison_input, target=_label)
                tgt_new = self.model.get_target_prob(_input=poison_input, target=target_label)
                org_increase = (org_new - org_conf).clamp(min=0.0)
                tgt_decrease = (tgt_conf - tgt_new).clamp(min=0.0)
                org_magnet_list.extend(to_list(org_increase))  # type: ignore
                tgt_magnet_list.extend(to_list(tgt_decrease))  # type: ignore
                print('MagNet:')
                print(f'    org_confidence: {np.mean(org_magnet_list)}')  # type: ignore
                print(f'    tgt_confidence: {np.mean(tgt_magnet_list)}')  # type: ignore
                print()
            total += 1

    def craft_example(self, _input: torch.Tensor, _label: torch.Tensor,
                      noise: torch.Tensor = None, save=False, **kwargs):
        if noise is None:
            noise = torch.zeros_like(_input)
        poison_input = None
        for _iter in range(self.pgd_iter):
            target_conf, target_acc = self.validate_target(indent=4, verbose=False)
            if target_conf > self.stop_conf:
                break
            # Alternate one PGD step on the input with one model retraining step.
            poison_input, _ = self.pgd.optimize(_input, noise=noise,
                                                loss_fn=self.loss_pgd, iteration=1)
            self.temp_input = poison_input
            target_conf, target_acc = self.validate_target(indent=4, verbose=False)
            if target_conf > self.stop_conf:
                break
            self._train(_input=poison_input, _label=_label, **kwargs)
        target_conf, target_acc = self.validate_target(indent=4, verbose=False)
        return poison_input

    def save(self, **kwargs):
        filename = self.get_filename(**kwargs)
        file_path = os.path.join(self.folder_path, filename)
        self.model.save(file_path + '.pth')
        print('attack results saved at: ', file_path)

    def get_filename(self, **kwargs):
        return self.model.name

    def loss_pgd(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.loss(x, self.temp_label)
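# craft_example above alternates a single PGD step on the poison input with a model
# retraining step, stopping once target confidence is high enough. A toy sketch of
# that input/model co-optimization (hypothetical model and optimizer; a plain
# signed-gradient step stands in for the repo's PGD_Optimizer):
import torch
import torch.nn as nn

def co_optimize(model: nn.Module, x: torch.Tensor, target: torch.Tensor,
                eps: float = 8 / 255, alpha: float = 1 / 255, iters: int = 8):
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    noise = torch.zeros_like(x)
    for _ in range(iters):
        # input step: move x + noise toward the target class
        adv = (x + noise).detach().requires_grad_()
        loss = nn.functional.cross_entropy(model(adv), target)
        grad = torch.autograd.grad(loss, adv)[0]
        noise = (noise - alpha * grad.sign()).clamp(-eps, eps)
        poison = (x + noise).clamp(0, 1).detach()
        # model step: retrain the model on the current poison input
        opt.zero_grad()
        nn.functional.cross_entropy(model(poison), target).backward()
        opt.step()
    return poison

if __name__ == '__main__':
    toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    x = torch.rand(2, 3, 8, 8)
    t = torch.zeros(2, dtype=torch.long)    # target class 0
    co_optimize(toy, x, t)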
class AdvTrain(BackdoorDefense):
    name: str = 'adv_train'

    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--pgd_alpha', dest='pgd_alpha', type=float)
        group.add_argument('--pgd_epsilon', dest='pgd_epsilon', type=float)
        group.add_argument('--pgd_iteration', dest='pgd_iteration', type=int)

    def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255,
                 pgd_iteration: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon,
                       iteration=pgd_iteration, stop_threshold=None)

    def detect(self, **kwargs):
        super().detect(**kwargs)
        print()
        self.adv_train(verbose=True, **kwargs)
        self.attack.validate_func()

    def validate_func(self, get_data_fn=None, **kwargs) -> tuple[float, float, float]:
        clean_loss, clean_acc = self.model._validate(print_prefix='Validate Clean',
                                                     get_data_fn=None, **kwargs)
        adv_loss, adv_acc = self.model._validate(print_prefix='Validate Adv',
                                                 get_data_fn=self.get_data, **kwargs)
        # TODO: return value
        if self.clean_acc - clean_acc > 20 and self.clean_acc > 40:
            adv_acc = 0.0
        return clean_loss + adv_loss, adv_acc, clean_acc

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor],
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data, **kwargs)

        def loss_fn(X: torch.FloatTensor):
            return -self.model.loss(X, _label)
        adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn)
        return adv_x, _label

    def adv_train(self, epoch: int, optimizer: optim.Optimizer,
                  lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False,
                  verbose=True, indent=0, **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'
        # validate_func returns (loss, adv_acc, clean_acc); track adv_acc as the metric.
        _, best_acc, _ = self.validate_func(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                                 loss_fn=loss_fn, iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f},'.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str,
                       prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc, _ = self.validate_func(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    '
                               f'Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
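# The inner loop of adv_train above interleaves optimizer steps with single PGD
# steps, in the spirit of 'free' adversarial training (Shafahi et al., 2019): each
# iteration both updates the adversarial noise and takes a model update on the new
# adversarial example. A self-contained sketch on a toy model (a manual
# signed-gradient ascent step stands in for trojanzoo's PGD):
import torch
import torch.nn as nn

def free_adv_batch(model: nn.Module, opt: torch.optim.Optimizer,
                   x: torch.Tensor, y: torch.Tensor,
                   eps: float = 8 / 255, alpha: float = 2 / 255, iters: int = 7):
    noise = torch.zeros_like(x)
    for _ in range(iters):
        model.eval()
        adv = (x + noise).detach().requires_grad_()
        # ascent step: maximize the loss w.r.t. the input
        grad = torch.autograd.grad(
            nn.functional.cross_entropy(model(adv), y), adv)[0]
        noise = (noise + alpha * grad.sign()).clamp(-eps, eps)
        model.train()
        opt.zero_grad()
        # descent step: minimize the loss on the adversarial example
        loss = nn.functional.cross_entropy(model((x + noise).clamp(0, 1)), y)
        loss.backward()
        opt.step()
    return loss

if __name__ == '__main__':
    toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    opt = torch.optim.SGD(toy.parameters(), lr=0.01)
    free_adv_batch(toy, opt, torch.rand(4, 3, 8, 8), torch.randint(10, (4,)))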
class ImageModel(Model):

    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--layer', dest='layer', type=int,
                           help='layer (optional, maybe embedded in --model)')
        group.add_argument('--width_factor', dest='width_factor', type=int,
                           help='width factor for wide-ResNet (optional, maybe embedded in --model)')
        group.add_argument('--sgm', dest='sgm', action='store_true',
                           help='whether to use sgm gradient, defaults to False')
        group.add_argument('--sgm_gamma', dest='sgm_gamma', type=float,
                           help='sgm gamma, defaults to 1.0')
        return group

    def __init__(self, name: str = 'imagemodel', layer: int = None, width_factor: int = None,
                 model_class: type[_ImageModel] = _ImageModel, dataset: ImageSet = None,
                 sgm: bool = False, sgm_gamma: float = 1.0, **kwargs):
        name, layer, width_factor = self.split_model_name(name, layer=layer,
                                                          width_factor=width_factor)
        self.layer = layer
        self.width_factor = width_factor
        if 'norm_par' not in kwargs.keys() and isinstance(dataset, ImageSet):
            kwargs['norm_par'] = dataset.norm_par
        if 'num_classes' not in kwargs.keys() and dataset is None:
            kwargs['num_classes'] = 1000
        super().__init__(name=name, model_class=model_class, layer=layer,
                         width_factor=width_factor, dataset=dataset, **kwargs)
        self.sgm: bool = sgm
        self.sgm_gamma: float = sgm_gamma
        self.param_list['imagemodel'] = ['layer', 'width_factor', 'sgm']
        if sgm:
            self.param_list['imagemodel'].extend(['sgm_gamma'])
        self._model: _ImageModel
        self.dataset: ImageSet
        self.pgd = None  # TODO: python 3.10 type annotation

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor], adv: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        if adv and self.pgd is not None:
            _input, _label = super().get_data(data, **kwargs)

            def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:  # TODO: use functools.partial
                return -self.loss(X, _label)
            adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn_new)
            return adv_x, _label
        return super().get_data(data, **kwargs)

    def get_layer(self, x: torch.Tensor, layer_output: str = 'logits',
                  layer_input: str = 'input') -> torch.Tensor:
        return self._model.get_layer(x, layer_output=layer_output, layer_input=layer_input)

    def get_layer_name(self) -> list[str]:
        return self._model.get_layer_name()

    def get_all_layer(self, x: torch.Tensor,
                      layer_input: str = 'input') -> dict[str, torch.Tensor]:
        return self._model.get_all_layer(x, layer_input=layer_input)

    # TODO: requires _input shape (N, C, H, W)
    # Reference: https://keras.io/examples/vision/grad_cam/
    def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor,
                    method: str = 'grad_cam', cmap: Colormap = jet) -> torch.Tensor:
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)    # (N, C, H, W)
            squeeze_flag = True
        if isinstance(_label, int):
            _label = [_label] * len(_input)
        _label = torch.as_tensor(_label, device=_input.device)
        heatmap = _input    # linting purpose
        if method == 'grad_cam':
            feats = self._model.get_fm(_input).detach()     # (N, C', H', W')
            feats.requires_grad_()
            _output: torch.Tensor = self._model.pool(feats)     # (N, C', 1, 1)
            _output = self._model.flatten(_output)      # (N, C')
            _output = self._model.classifier(_output)   # (N, num_classes)
            _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, feats)[0]   # (N, C', H', W')
            feats.requires_grad_(False)
            weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True)    # (N, C', 1, 1)
            heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0)   # (N, 1, H', W')
            # heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0])
            heatmap: torch.Tensor = F.upsample(heatmap, _input.shape[-2:],
                                               mode='bilinear')[:, 0]   # (N, H, W)
            # Note that we violate the image order convention (W, H, C)
        elif method == 'saliency_map':
            _input.requires_grad_()
            _output = self(_input).gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, _input)[0]  # (N, C, H, W)
            _input.requires_grad_(False)
            heatmap = grad.abs().max(dim=1)[0]  # (N, H, W)
            heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0])
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap

    @staticmethod
    def split_model_name(name: str, layer: int = None,
                         width_factor: int = None) -> tuple[str, int, int]:
        re_list = re.findall(r'[0-9]+|[a-z]+|_', name)
        if len(re_list) > 1:
            name = re_list[0]
            layer = int(re_list[1])
            if len(re_list) > 2 and re_list[-2] == 'x':
                width_factor = int(re_list[-1])
        if layer is not None:
            name += str(layer)
        if width_factor is not None:
            name += f'x{width_factor:d}'
        return name, layer, width_factor

    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch', start_epoch: int = 0,
               validate_interval: int = 10, save: bool = False, amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None, file_path: str = None,
               folder_path: str = None, suffix: str = None,
               writer: SummaryWriter = None, main_tag: str = 'train', tag: str = '',
               verbose: bool = True, indent: int = 0,
               adv_train: bool = False, adv_train_alpha: float = 2.0 / 255,
               adv_train_epsilon: float = 8.0 / 255, adv_train_iter: int = 7, **kwargs):
        if adv_train:
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider moving import statements to the top of the file
            self.pgd = PGD(alpha=adv_train_alpha, epsilon=adv_train_epsilon,
                           iteration=adv_train_iter, stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor, _label: torch.Tensor,
                                  _output: torch.Tensor, loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None, **kwargs):
                noise = torch.zeros_like(_input)

                def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:
                    return -loss_fn(X, _label)
                for m in range(self.pgd.iteration):
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                                 loss_fn=loss_fn_new, iteration=1)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input, _label=_label, _output=_output,
                                          loss=loss, optimizer=optimizer, loss_fn=loss_fn,
                                          amp=amp, scaler=scaler, **kwargs)
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate', **kwargs) -> tuple[float, float]:
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None, **kwargs)
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(get_data_fn, adv=True),
                                             **kwargs)
                return adv_acc, clean_acc
            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new
        super()._train(epoch=epoch, optimizer=optimizer, lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix, start_epoch=start_epoch,
                       validate_interval=validate_interval, save=save, amp=amp,
                       loader_train=loader_train, loader_valid=loader_valid,
                       epoch_fn=epoch_fn, get_data_fn=get_data_fn, loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn, validate_fn=validate_fn,
                       save_fn=save_fn, file_path=file_path, folder_path=folder_path,
                       suffix=suffix, writer=writer, main_tag=main_tag, tag=tag,
                       verbose=verbose, indent=indent, **kwargs)
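# get_heatmap above implements Grad-CAM: weight the last conv feature maps by the
# spatial mean of the class-score gradient, then ReLU and normalize. A minimal
# self-contained sketch on a toy conv net (hypothetical layers, mirroring the
# steps above):
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, 3, padding=1)             # toy feature extractor
head = nn.Linear(8, 10)                          # toy classifier

x = torch.rand(1, 3, 32, 32)
feats = conv(x).detach().requires_grad_()        # (1, 8, 32, 32)
logits = head(F.adaptive_avg_pool2d(feats, 1).flatten(1))   # (1, 10)
score = logits[:, 3].sum()                       # score of class 3
grad = torch.autograd.grad(score, feats)[0]      # (1, 8, 32, 32)
weights = grad.mean(dim=(-2, -1), keepdim=True)  # channel importance (1, 8, 1, 1)
cam = (feats * weights).sum(dim=1, keepdim=True).clamp(min=0)
cam = cam / cam.amax(dim=(-2, -1), keepdim=True).clamp(min=1e-12)
cam = F.interpolate(cam, x.shape[-2:], mode='bilinear')[:, 0]   # (1, 32, 32)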
class ImageModel(Model):

    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--adv_train', action='store_true',
                           help='enable adversarial training.')
        group.add_argument('--adv_train_iter', type=int,
                           help='adversarial training PGD iteration, defaults to 7.')
        group.add_argument('--adv_train_alpha', type=float,
                           help='adversarial training PGD alpha, defaults to 2/255.')
        group.add_argument('--adv_train_eps', type=float,
                           help='adversarial training PGD eps, defaults to 8/255.')
        group.add_argument('--adv_train_valid_eps', type=float,
                           help='adversarial training PGD eps at validation, defaults to 8/255.')
        group.add_argument('--sgm', action='store_true',
                           help='whether to use sgm gradient, defaults to False')
        group.add_argument('--sgm_gamma', type=float,
                           help='sgm gamma, defaults to 1.0')
        return group

    def __init__(self, name: str = 'imagemodel', layer: int = None,
                 model: Union[type[_ImageModel], _ImageModel] = _ImageModel,
                 dataset: ImageSet = None,
                 adv_train: bool = False, adv_train_iter: int = 7,
                 adv_train_alpha: float = 2 / 255, adv_train_eps: float = 8 / 255,
                 adv_train_valid_eps: float = 8 / 255,
                 sgm: bool = False, sgm_gamma: float = 1.0,
                 norm_par: dict[str, list[float]] = None, **kwargs):
        name = self.get_name(name, layer=layer)
        norm_par = dataset.norm_par if norm_par is None else norm_par
        if 'num_classes' not in kwargs.keys() and dataset is None:
            kwargs['num_classes'] = 1000
        super().__init__(name=name, model=model, dataset=dataset,
                         norm_par=norm_par, **kwargs)
        self.sgm: bool = sgm
        self.sgm_gamma: float = sgm_gamma
        self.adv_train = adv_train
        self.adv_train_iter = adv_train_iter
        self.adv_train_alpha = adv_train_alpha
        self.adv_train_eps = adv_train_eps
        self.adv_train_valid_eps = adv_train_valid_eps
        self.param_list['imagemodel'] = []
        if sgm:
            self.param_list['imagemodel'].append('sgm_gamma')
        if adv_train:
            self.param_list['adv_train'] = ['adv_train_iter', 'adv_train_alpha',
                                            'adv_train_eps', 'adv_train_valid_eps']
            self.suffix += '_adv_train'
            if 'suffix' not in self.param_list['model']:
                self.param_list['model'].append('suffix')
        self._model: _ImageModel
        self.dataset: ImageSet
        self.pgd = None  # TODO: python 3.10 type annotation
        self._ce_loss_fn = nn.CrossEntropyLoss(weight=self.loss_weights)

    @classmethod
    def get_name(cls, name: str, layer: int = None) -> str:
        full_list = name.split('_')
        partial_name = full_list[0]
        re_list = re.findall(r'\d+|\D+', partial_name)
        if len(re_list) > 1:
            layer = int(re_list[1])
        elif layer is not None:
            partial_name += str(layer)
        full_list[0] = partial_name
        return '_'.join(full_list)

    def adv_loss(self, _input: torch.Tensor, _label: torch.Tensor) -> torch.Tensor:
        _output = self(_input)
        return -self._ce_loss_fn(_output, _label)

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor], adv: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        if adv and self.pgd is not None:
            _input, _label = super().get_data(data, **kwargs)
            adv_loss_fn = functools.partial(self.adv_loss, _label=_label)
            adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=adv_loss_fn)
            return adv_x, _label
        return super().get_data(data, **kwargs)

    # TODO: requires _input shape (N, C, H, W)
    # Reference: https://keras.io/examples/vision/grad_cam/
    def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor,
                    method: str = 'grad_cam', cmap: Colormap = jet) -> torch.Tensor:
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)    # (N, C, H, W)
            squeeze_flag = True
        if isinstance(_label, int):
            _label = [_label] * len(_input)
        _label = torch.as_tensor(_label, device=_input.device)
        heatmap = _input    # linting purpose
        if method == 'grad_cam':
            feats = self._model.get_fm(_input).detach()     # (N, C', H', W')
            feats.requires_grad_()
            _output: torch.Tensor = self._model.pool(feats)     # (N, C', 1, 1)
            _output = self._model.flatten(_output)      # (N, C')
            _output = self._model.classifier(_output)   # (N, num_classes)
            _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, feats)[0]   # (N, C', H', W')
            feats.requires_grad_(False)
            weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True)    # (N, C', 1, 1)
            heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0)   # (N, 1, H', W')
            # heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0])
            heatmap: torch.Tensor = F.upsample(heatmap, _input.shape[-2:],
                                               mode='bilinear')[:, 0]   # (N, H, W)
            # Note that we violate the image order convention (W, H, C)
        elif method == 'saliency_map':
            _input.requires_grad_()
            _output = self(_input).gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, _input)[0]  # (N, C, H, W)
            _input.requires_grad_(False)
            heatmap = grad.abs().max(dim=1)[0]  # (N, H, W)
            heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0])
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap

    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch', start_epoch: int = 0, resume: int = 0,
               validate_interval: int = 10, save: bool = False, amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None, file_path: str = None,
               folder_path: str = None, suffix: str = None,
               writer=None, main_tag: str = 'train', tag: str = '',
               verbose: bool = True, indent: int = 0, **kwargs):
        if self.adv_train:
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider moving import statements to the top of the file
            self.pgd = PGD(pgd_alpha=self.adv_train_alpha,
                           pgd_eps=self.adv_train_valid_eps,
                           iteration=self.adv_train_iter, stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor, _label: torch.Tensor,
                                  _output: torch.Tensor, loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None, **kwargs):
                noise = torch.zeros_like(_input)
                adv_loss_fn = functools.partial(self.adv_loss, _label=_label)
                for m in range(self.pgd.iteration):
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                                 loss_fn=adv_loss_fn, iteration=1,
                                                 pgd_eps=self.adv_train_eps)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input, _label=_label, _output=_output,
                                          loss=loss, optimizer=optimizer, loss_fn=loss_fn,
                                          amp=amp, scaler=scaler, **kwargs)
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate', **kwargs) -> tuple[float, float]:
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None, **kwargs)
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(get_data_fn, adv=True),
                                             **kwargs)
                return adv_acc, clean_acc + adv_acc
            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new
        super()._train(epoch=epoch, optimizer=optimizer, lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix, start_epoch=start_epoch, resume=resume,
                       validate_interval=validate_interval, save=save, amp=amp,
                       loader_train=loader_train, loader_valid=loader_valid,
                       epoch_fn=epoch_fn, get_data_fn=get_data_fn, loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn, validate_fn=validate_fn,
                       save_fn=save_fn, file_path=file_path, folder_path=folder_path,
                       suffix=suffix, writer=writer, main_tag=main_tag, tag=tag,
                       verbose=verbose, indent=indent, **kwargs)
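# get_name above normalizes model names: a layer number may be embedded in the name
# ('resnet18') or passed separately (name='resnet', layer=18), and any '_' suffix is
# preserved. A standalone re-implementation of the same parsing, for illustration:
import re

def get_name(name: str, layer: int = None) -> str:
    full_list = name.split('_')
    partial_name = full_list[0]
    re_list = re.findall(r'\d+|\D+', partial_name)
    if len(re_list) > 1:
        layer = int(re_list[1])     # layer embedded in the name wins
    elif layer is not None:
        partial_name += str(layer)  # otherwise append the explicit layer
    full_list[0] = partial_name
    return '_'.join(full_list)

assert get_name('resnet18') == 'resnet18'
assert get_name('resnet', layer=18) == 'resnet18'
assert get_name('resnet18_comp') == 'resnet18_comp'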