def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255,
             pgd_iteration: int = 7, **kwargs):
    """Record PGD hyper-parameters for adversarial training and build the PGD helper.

    Args:
        pgd_alpha: Step size handed to ``PGD`` as ``alpha``.
        pgd_epsilon: Perturbation bound handed to ``PGD`` as ``epsilon``.
        pgd_iteration: Iteration count handed to ``PGD`` as ``iteration``.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    # Register the attribute names under the 'adv_train' group of param_list.
    tracked_names = ('pgd_alpha', 'pgd_epsilon', 'pgd_iteration')
    self.param_list['adv_train'] = list(tracked_names)
    self.pgd_alpha, self.pgd_epsilon, self.pgd_iteration = pgd_alpha, pgd_epsilon, pgd_iteration
    # stop_threshold=None is passed explicitly; its effect is defined by PGD (not visible here).
    self.pgd = PGD(alpha=self.pgd_alpha, epsilon=self.pgd_epsilon,
                   iteration=self.pgd_iteration, stop_threshold=None)
def __init__(self, preprocess_layer: str = 'features', pgd_epsilon: float = 16.0 / 255,
             pgd_iteration: int = 40, pgd_alpha: float = 4.0 / 255, **kwargs):
    """Configure the hidden-trigger PGD settings and a loader over the target class.

    Args:
        preprocess_layer: Layer name stored on ``self.preprocess_layer``
            (presumably the layer used for feature matching — confirm in callers).
        pgd_epsilon: Perturbation bound handed to ``PGD`` as ``epsilon``.
            Annotated ``float``: the default ``16.0 / 255`` is a float, not an int.
        pgd_iteration: Iteration count handed to ``PGD`` as ``iteration``.
        pgd_alpha: Step size handed to ``PGD`` as ``alpha``.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    # Register the attribute names under the 'hidden_trigger' group of param_list.
    self.param_list['hidden_trigger'] = [
        'preprocess_layer', 'pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
    self.preprocess_layer: str = preprocess_layer
    self.pgd_alpha: float = pgd_alpha
    self.pgd_epsilon: float = pgd_epsilon
    self.pgd_iteration: int = pgd_iteration
    # Training-split loader restricted to self.target_class (set by the parent,
    # presumably). drop_last=True and num_workers=0 are deliberate choices here.
    self.target_loader = self.dataset.get_dataloader(
        'train', full=True, classes=self.target_class,
        drop_last=True, num_workers=0)
    # Build PGD from the same parameter values throughout (the original mixed
    # self.pgd_alpha with raw parameters; the values are identical).
    self.pgd: PGD = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon,
                        iteration=pgd_iteration, output=self.output)
def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_eps: float = 8.0 / 255,
             pgd_iter: int = 7, **kwargs):
    """Store PGD settings for adversarial training and create the PGD helper.

    Args:
        pgd_alpha: Step size handed to ``PGD`` as ``pgd_alpha``.
        pgd_eps: Perturbation bound handed to ``PGD`` as ``pgd_eps``.
        pgd_iter: Iteration count handed to ``PGD`` as ``iteration``.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    # Dict insertion order keeps the registered names in the original order.
    settings = {'pgd_alpha': pgd_alpha, 'pgd_eps': pgd_eps, 'pgd_iter': pgd_iter}
    self.param_list['adv_train'] = list(settings)
    for attr_name, attr_value in settings.items():
        setattr(self, attr_name, attr_value)
    # This PGD variant takes pgd_alpha/pgd_eps keywords; stop_threshold=None is explicit.
    self.pgd = PGD(pgd_alpha=pgd_alpha, pgd_eps=pgd_eps,
                   iteration=pgd_iter, stop_threshold=None)
def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
           print_prefix: str = 'Epoch', start_epoch: int = 0,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           epoch_fn: Callable[..., None] = None,
           get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           after_loss_fn: Callable[..., None] = None,
           validate_fn: Callable[..., tuple[float, float]] = None,
           save_fn: Callable[..., None] = None,
           file_path: str = None, folder_path: str = None, suffix: str = None,
           writer: SummaryWriter = None, main_tag: str = 'train', tag: str = '',
           verbose: bool = True, indent: int = 0,
           adv_train: bool = False, adv_train_alpha: float = 2.0 / 255,
           adv_train_epsilon: float = 8.0 / 255, adv_train_iter: int = 7, **kwargs):
    """Train via ``super()._train``, optionally wrapping it with PGD adversarial training.

    When ``adv_train`` is False, every argument except the ``adv_train*`` ones is
    forwarded unchanged to ``super()._train``.

    When ``adv_train`` is True:

    * ``loss_fn`` defaults to ``self.loss``. ``validate_fn`` defaults to
      ``self._validate``. ``after_loss_fn`` defaults to ``self.after_loss_fn``
      if that attribute exists.
    * ``self.pgd`` is (re)built from the ``adv_train_*`` settings.
    * ``after_loss_fn`` is replaced by a replay loop run ``self.pgd.iteration``
      times. Each replay steps the optimizer, clears gradients, takes one PGD
      step on a shared noise tensor (model in eval mode), and then
      back-propagates the loss on the adversarial input. This resembles "free"
      adversarial training.
    * ``validate_fn`` is replaced by a wrapper that validates on clean data and
      then on adversarial data, returning ``(adv_acc, clean_acc)``.

    Args:
        adv_train: Enable the adversarial-training wrappers.
        adv_train_alpha: PGD step size (``alpha``).
        adv_train_epsilon: PGD perturbation bound (``epsilon``).
        adv_train_iter: PGD iteration count, which is also the number of replays per batch.
        All other arguments: forwarded to ``super()._train``.
    """
    if adv_train:
        after_loss_fn_old = after_loss_fn
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn_old = getattr(self, 'after_loss_fn')
        validate_fn_old = validate_fn if callable(validate_fn) else self._validate
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        # Separate handle: after_loss_fn_new's own ``loss_fn`` parameter shadows this name.
        train_loss_fn = loss_fn

        from trojanvision.optim import PGD  # TODO: consider to move import sentences to top of file
        self.pgd = PGD(alpha=adv_train_alpha, epsilon=adv_train_epsilon,
                       iteration=adv_train_iter, stop_threshold=None)

        def after_loss_fn_new(_input: torch.Tensor, _label: torch.Tensor, _output: torch.Tensor,
                              loss: torch.Tensor, optimizer: Optimizer,
                              loss_fn: Callable[..., torch.Tensor] = None,
                              amp: bool = False, scaler: torch.cuda.amp.GradScaler = None,
                              **kwargs):
            # The trainer may omit loss_fn (default None). Fall back to the
            # resolved training loss instead of calling None.
            if not callable(loss_fn):
                loss_fn = train_loss_fn
            # One noise tensor is shared across replays; PGD.optimize receives it each time.
            noise = torch.zeros_like(_input)

            def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:
                # Negated so that PGD's objective pushes the classification loss up.
                return -loss_fn(X, _label)

            for _replay in range(self.pgd.iteration):
                # Apply the gradients from the previous backward pass (for the
                # first replay, whatever the trainer accumulated before calling us).
                if amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                # Clear gradients so each step uses only the latest replay's
                # gradient rather than a running sum over all replays.
                optimizer.zero_grad()
                self.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                             loss_fn=loss_fn_new, iteration=1)
                self.train()
                loss = loss_fn(adv_x, _label)
                # NOTE: the chained callback receives the clean ``_output`` but the adversarial ``loss``.
                if callable(after_loss_fn_old):
                    after_loss_fn_old(_input=_input, _label=_label, _output=_output,
                                      loss=loss, optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler, **kwargs)
                if amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

        def validate_fn_new(get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
                            print_prefix: str = 'Validate', main_tag: str = None,
                            **kwargs) -> tuple[float, float]:
            # print_prefix and main_tag are accepted only so they are not
            # forwarded twice alongside the explicit values below.
            if not callable(get_data_fn):
                # Fail before the (costly) clean pass instead of inside functools.partial.
                raise TypeError('adversarial validation requires a callable get_data_fn '
                                f'accepting adv=True, got {get_data_fn!r}')
            _, clean_acc = validate_fn_old(print_prefix='Validate Clean', main_tag='valid clean',
                                           get_data_fn=None, **kwargs)
            _, adv_acc = validate_fn_old(print_prefix='Validate Adv', main_tag='valid adv',
                                         get_data_fn=functools.partial(get_data_fn, adv=True),
                                         **kwargs)
            return adv_acc, clean_acc

        after_loss_fn = after_loss_fn_new
        validate_fn = validate_fn_new

    super()._train(epoch=epoch, optimizer=optimizer, lr_scheduler=lr_scheduler,
                   print_prefix=print_prefix, start_epoch=start_epoch,
                   validate_interval=validate_interval, save=save, amp=amp,
                   loader_train=loader_train, loader_valid=loader_valid,
                   epoch_fn=epoch_fn, get_data_fn=get_data_fn, loss_fn=loss_fn,
                   after_loss_fn=after_loss_fn, validate_fn=validate_fn,
                   save_fn=save_fn, file_path=file_path, folder_path=folder_path,
                   suffix=suffix, writer=writer, main_tag=main_tag, tag=tag,
                   verbose=verbose, indent=indent, **kwargs)