Example 1 (score: 0)
 def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255, pgd_iteration: int = 7, **kwargs):
     """Configure PGD-based adversarial training.

     Args:
         pgd_alpha: per-step perturbation size.
         pgd_epsilon: maximum perturbation budget.
         pgd_iteration: number of PGD steps.
     """
     super().__init__(**kwargs)
     # Store the hyper-parameters on the instance first ...
     self.pgd_alpha = pgd_alpha
     self.pgd_epsilon = pgd_epsilon
     self.pgd_iteration = pgd_iteration
     # ... then register them under the 'adv_train' group so they show up
     # in the printable parameter summary.
     self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
     # stop_threshold=None: run the full number of PGD iterations every time.
     self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon,
                    iteration=pgd_iteration, stop_threshold=None)
Example 2 (score: 0)
    def __init__(self,
                 preprocess_layer: str = 'features',
                 pgd_epsilon: float = 16.0 / 255,
                 pgd_iteration: int = 40,
                 pgd_alpha: float = 4.0 / 255,
                 **kwargs):
        """Configure the hidden-trigger attack and its PGD optimizer.

        Args:
            preprocess_layer: name of the feature layer used during
                crafting (presumably consumed elsewhere — confirm).
            pgd_epsilon: maximum perturbation budget.
            pgd_iteration: number of PGD steps per crafting round.
            pgd_alpha: per-step perturbation size.
        """
        super().__init__(**kwargs)

        # Register hyper-parameters under the 'hidden_trigger' group so they
        # appear in the printable parameter summary.
        self.param_list['hidden_trigger'] = [
            'preprocess_layer', 'pgd_alpha', 'pgd_epsilon', 'pgd_iteration'
        ]

        self.preprocess_layer: str = preprocess_layer
        self.pgd_alpha: float = pgd_alpha
        self.pgd_epsilon: float = pgd_epsilon
        self.pgd_iteration: int = pgd_iteration

        # Loader over target-class training samples only; drop_last keeps
        # batch shapes constant during crafting.
        self.target_loader = self.dataset.get_dataloader(
            'train',
            full=True,
            classes=self.target_class,
            drop_last=True,
            num_workers=0)
        # FIX: annotate pgd_epsilon as float (its default 16.0 / 255 is a
        # float, the original `int` annotation was wrong) and read the stored
        # attributes consistently instead of mixing self.* with locals.
        self.pgd: PGD = PGD(alpha=self.pgd_alpha,
                            epsilon=self.pgd_epsilon,
                            iteration=self.pgd_iteration,
                            output=self.output)
Example 3 (score: 0)
 def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_eps: float = 8.0 / 255, pgd_iter: int = 7, **kwargs):
     """Configure PGD-based adversarial training.

     Args:
         pgd_alpha: per-step perturbation size.
         pgd_eps: maximum perturbation budget.
         pgd_iter: number of PGD steps.
     """
     super().__init__(**kwargs)
     # Expose the hyper-parameters in the printable parameter summary.
     self.param_list['adv_train'] = ['pgd_alpha', 'pgd_eps', 'pgd_iter']
     self.pgd_alpha = pgd_alpha
     self.pgd_eps = pgd_eps
     self.pgd_iter = pgd_iter
     # BUG FIX: PGD's constructor takes alpha/epsilon/iteration (see the other
     # call sites in this file); the original passed pgd_alpha=/pgd_eps=,
     # which would raise TypeError for unexpected keyword arguments.
     # stop_threshold=None: run the full number of PGD iterations every time.
     self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_eps, iteration=pgd_iter, stop_threshold=None)
Example 4 (score: 0)
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer: SummaryWriter = None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               adv_train: bool = False,
               adv_train_alpha: float = 2.0 / 255,
               adv_train_epsilon: float = 8.0 / 255,
               adv_train_iter: int = 7,
               **kwargs):
        """Train the model, optionally with PGD adversarial training.

        When ``adv_train`` is False this is a pure pass-through to the
        parent ``_train``.  When True, the ``after_loss_fn`` and
        ``validate_fn`` hooks are replaced with adversarial variants
        before delegating:

        * ``after_loss_fn_new`` interleaves optimizer steps with one-step
          PGD crafting for ``adv_train_iter`` rounds, backpropagating the
          loss on the adversarial example each round.
        * ``validate_fn_new`` reports both clean and adversarial accuracy
          and returns them as ``(adv_acc, clean_acc)``.

        ``adv_train_alpha`` / ``adv_train_epsilon`` / ``adv_train_iter``
        are the PGD step size, budget, and iteration count.
        """
        if adv_train:
            # Preserve any caller-supplied hook (or the instance's own
            # `after_loss_fn` attribute) so the adversarial wrapper can
            # still invoke it each round.
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(
                validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider to move import sentences to top of file
            # stop_threshold=None: PGD always runs its full iteration budget.
            self.pgd = PGD(alpha=adv_train_alpha,
                           epsilon=adv_train_epsilon,
                           iteration=adv_train_iter,
                           stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor,
                                  _label: torch.Tensor,
                                  _output: torch.Tensor,
                                  loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None,
                                  **kwargs):
                # Adversarial training loop run after the clean loss/backward:
                # each round first applies the pending gradients, then crafts
                # a one-step PGD example and backpropagates its loss.
                # NOTE(review): stepping the optimizer *before* the
                # adversarial backward reuses gradients across rounds
                # ("free"-style adversarial training) — confirm intended.
                noise = torch.zeros_like(_input)

                def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:
                    # Negated loss: PGD minimizes this, i.e. maximizes the
                    # training loss w.r.t. the input perturbation.
                    return -loss_fn(X, _label)

                for m in range(self.pgd.iteration):
                    # Apply the gradients accumulated so far (AMP-aware).
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    # Craft in eval mode so crafting does not perturb
                    # train-mode layers (e.g. batch-norm statistics).
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 loss_fn=loss_fn_new,
                                                 iteration=1)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    # Give the original hook a chance to run each round.
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input,
                                          _label=_label,
                                          _output=_output,
                                          loss=loss,
                                          optimizer=optimizer,
                                          loss_fn=loss_fn,
                                          amp=amp,
                                          scaler=scaler,
                                          **kwargs)
                    # Backpropagate the adversarial loss; gradients are
                    # consumed by the step at the top of the next round
                    # (or by the outer training loop on the last round).
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[
                torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate',
                                **kwargs) -> tuple[float, float]:
                # Clean pass: get_data_fn=None lets the underlying validate
                # fall back to its default data pipeline.
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None,
                                               **kwargs)
                # Adversarial pass: same validate, but data is fetched with
                # adv=True so inputs are perturbed.
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(
                                                 get_data_fn, adv=True),
                                             **kwargs)
                # Adversarial accuracy first: it becomes the primary metric
                # seen by the caller (e.g. for checkpoint selection).
                return adv_acc, clean_acc

            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new

        # Delegate to the parent implementation with the (possibly wrapped)
        # hooks; all other arguments pass through unchanged.
        super()._train(epoch=epoch,
                       optimizer=optimizer,
                       lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix,
                       start_epoch=start_epoch,
                       validate_interval=validate_interval,
                       save=save,
                       amp=amp,
                       loader_train=loader_train,
                       loader_valid=loader_valid,
                       epoch_fn=epoch_fn,
                       get_data_fn=get_data_fn,
                       loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn,
                       validate_fn=validate_fn,
                       save_fn=save_fn,
                       file_path=file_path,
                       folder_path=folder_path,
                       suffix=suffix,
                       writer=writer,
                       main_tag=main_tag,
                       tag=tag,
                       verbose=verbose,
                       indent=indent,
                       **kwargs)