Example 1
 def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_eps: float = 8.0 / 255, pgd_iter: int = 7, **kwargs):
     super().__init__(**kwargs)
     self.param_list['adv_train'] = ['pgd_alpha', 'pgd_eps', 'pgd_iter']
     self.pgd_alpha = pgd_alpha
     self.pgd_eps = pgd_eps
     self.pgd_iter = pgd_iter
     self.pgd = PGD(pgd_alpha=pgd_alpha, pgd_eps=pgd_eps, iteration=pgd_iter, stop_threshold=None)
Example 2
 def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255, pgd_iteration: int = 7, **kwargs):
     super().__init__(**kwargs)
     self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
     self.pgd_alpha = pgd_alpha
     self.pgd_epsilon = pgd_epsilon
     self.pgd_iteration = pgd_iteration
     self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon, iteration=pgd_iteration, stop_threshold=None)
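
Examples 1 and 2 build the same adversarial-training setup; only the keyword names differ (`pgd_eps`/`pgd_iter` vs. `pgd_epsilon`/`pgd_iteration`), reflecting two revisions of the `PGD` class. For reference, here is a minimal, self-contained sketch of the L-inf PGD loop these three hyperparameters control; `pgd_attack` is an illustrative helper, not part of the repository:

import torch
import torch.nn.functional as F


def pgd_attack(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor,
               alpha: float = 2.0 / 255, epsilon: float = 8.0 / 255,
               iteration: int = 7) -> torch.Tensor:
    """Maximize cross-entropy within an L-inf ball of radius epsilon around x."""
    noise = torch.zeros_like(x)
    for _ in range(iteration):
        adv_x = (x + noise).clamp(0, 1).detach().requires_grad_()
        loss = F.cross_entropy(model(adv_x), y)
        grad = torch.autograd.grad(loss, adv_x)[0]
        # alpha is the per-step size; epsilon bounds the total perturbation.
        noise = (noise + alpha * grad.sign()).clamp(-epsilon, epsilon)
    return (x + noise).clamp(0, 1).detach()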
Example 3
    def __init__(self,
                 preprocess_layer: str = 'features',
                 pgd_epsilon: float = 16.0 / 255,
                 pgd_iteration: int = 40,
                 pgd_alpha: float = 4.0 / 255,
                 **kwargs):
        super().__init__(**kwargs)

        self.param_list['hidden_trigger'] = [
            'preprocess_layer', 'pgd_alpha', 'pgd_epsilon', 'pgd_iteration'
        ]

        self.preprocess_layer: str = preprocess_layer
        self.pgd_alpha: float = pgd_alpha
        self.pgd_epsilon: float = pgd_epsilon
        self.pgd_iteration: int = pgd_iteration

        self.target_loader = self.dataset.get_dataloader(
            'train',
            full=True,
            classes=self.target_class,
            drop_last=True,
            num_workers=0)
        self.pgd: PGD = PGD(alpha=self.pgd_alpha,
                            epsilon=pgd_epsilon,
                            iteration=pgd_iteration,
                            output=self.output)
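
The `preprocess_layer` and the target-class loader suggest the feature-collision objective of the hidden-trigger backdoor attack (Saha et al., "Hidden Trigger Backdoor Attacks"): PGD pushes poison images toward triggered source images in feature space, so the trigger never appears in the poisoned data itself. A hedged sketch of that loss, assuming the `get_layer` API shown in Example 10; `hidden_trigger_loss` is illustrative, not a repo function:

import torch


def hidden_trigger_loss(model, poison_input: torch.Tensor,
                        trigger_input: torch.Tensor,
                        preprocess_layer: str = 'features') -> torch.Tensor:
    # Match poison features at preprocess_layer to those of triggered inputs.
    poison_feats = model.get_layer(poison_input, layer_output=preprocess_layer)
    trigger_feats = model.get_layer(trigger_input,
                                    layer_output=preprocess_layer).detach()
    return (poison_feats - trigger_feats).flatten(start_dim=1).norm(p=2, dim=1).mean()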
Example 4
class Grad_Train(Defense):

    name: str = 'grad_train'

    def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255, pgd_iteration: int = 7,
                 grad_lambda: float = 10, **kwargs):
        super().__init__(**kwargs)
        self.param_list['grad_train'] = ['grad_lambda']
        self.grad_lambda = grad_lambda

        self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon, iteration=pgd_iteration, stop_threshold=None)

    def detect(self, **kwargs):
        self.model._train(loss_fn=self.loss_fn, validate_func=self.validate_func, verbose=True, **kwargs)

    def loss_fn(self, _input, _label, **kwargs):
        new_input = _input.repeat(4, 1, 1, 1)
        new_label = _label.repeat(4)
        noise = torch.randn_like(new_input)
        noise = noise / noise.norm(p=float('inf')) * self.pgd_epsilon
        new_input = new_input + noise
        new_input = new_input.clamp(0, 1).detach()
        new_input.requires_grad_()
        loss = self.model.loss(new_input, new_label)
        grad = torch.autograd.grad(loss, new_input, create_graph=True)[0]
        new_loss = loss + self.grad_lambda * grad.flatten(start_dim=1).norm(p=1, dim=1).mean()
        return new_loss

    def validate_func(self, get_data_fn=None, loss_fn=None, **kwargs) -> tuple[float, float, float]:
        clean_loss, clean_acc, _ = self.model._validate(print_prefix='Validate Clean',
                                                        get_data_fn=None, **kwargs)
        adv_loss, adv_acc, _ = self.model._validate(print_prefix='Validate Adv',
                                                    get_data_fn=self.get_data, **kwargs)
        # todo: Return value
        if self.clean_acc - clean_acc > 20 and self.clean_acc > 40:
            adv_acc = 0.0
        return clean_loss + adv_loss, adv_acc, clean_acc

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data, **kwargs)

        def loss_fn(X: torch.FloatTensor):
            return -self.model.loss(X, _label)
        adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn)
        return adv_x, _label

    def save(self, **kwargs):
        self.model.save(folder_path=self.folder_path, suffix='_grad_train', verbose=True, **kwargs)
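
The core of `loss_fn` above is a double-backprop gradient penalty: the cross-entropy loss is augmented with the L1 norm of its input gradient, which stays differentiable because of `create_graph=True`. A minimal standalone sketch (`grad_penalty_loss` is illustrative, not a repo function):

import torch
import torch.nn.functional as F


def grad_penalty_loss(model: torch.nn.Module, x: torch.Tensor,
                      y: torch.Tensor, grad_lambda: float = 10.0) -> torch.Tensor:
    x = x.detach().requires_grad_()
    loss = F.cross_entropy(model(x), y)
    # create_graph=True keeps the graph, so the penalty term itself is
    # differentiable with respect to the model parameters (double backprop).
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    return loss + grad_lambda * grad.flatten(start_dim=1).norm(p=1, dim=1).mean()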
Example 5
 def __init__(self, pgd_alpha: float = 1.0, pgd_epsilon: float = 8.0, pgd_iteration: int = 8,
              stop_conf: float = 0.9,
              magnet: bool = False, randomized_smooth: bool = False, curvature: bool = False, **kwargs):
     super().__init__(**kwargs)
     self.param_list['pgd'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
     self.pgd_alpha: float = pgd_alpha
     self.pgd_epsilon: float = pgd_epsilon
     self.pgd_iteration: int = pgd_iteration
     self.pgd = PGD_Optimizer(alpha=self.pgd_alpha / 255, epsilon=self.pgd_epsilon / 255,
                              iteration=self.pgd_iteration)
     self.stop_conf: float = stop_conf
     if magnet:
         self.magnet: MagNet = MagNet(dataset=self.dataset, pretrain=True)
     self.randomized_smooth: bool = randomized_smooth
     if curvature:
         self.curvature: Curvature = Curvature(model=self.model)
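
Note the unit convention here: `pgd_alpha` and `pgd_epsilon` are given in 0-255 pixel units and divided by 255 at construction, unlike the earlier examples that pass pre-scaled values. The `randomized_smooth` flag is later consumed by `get_target_prob` (see Example 7); a hedged sketch of what such a smoothed confidence estimate typically computes, in the style of Cohen et al. randomized smoothing (`sigma` and `n` are illustrative choices, not repo defaults):

import torch


def smoothed_target_prob(model, x: torch.Tensor, target: torch.Tensor,
                         sigma: float = 0.25, n: int = 100) -> torch.Tensor:
    # Average the target-class softmax probability over Gaussian-noised copies.
    probs = torch.zeros(x.size(0), device=x.device)
    with torch.no_grad():
        for _ in range(n):
            noisy = (x + sigma * torch.randn_like(x)).clamp(0, 1)
            probs += model(noisy).softmax(dim=1).gather(
                1, target.unsqueeze(1)).squeeze(1)
    return probs / n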
Example 6
class IMC_AdvTrain(IMC):

    r"""
    The Input Model Co-optimization (IMC) backdoor attack is described in detail in the paper `A Tale of Evil Twins`_ by Ren Pang.

    Based on :class:`trojanzoo.attacks.backdoor.BadNet`,
    IMC optimizes the watermark pixel values with a PGD attack to enhance attack performance.

    Args:
        pgd_alpha (float): PGD step size. Default: 2.0 / 255.
        pgd_epsilon (float): PGD perturbation budget. Default: 8.0 / 255.
        pgd_iteration (int): Number of PGD iterations. Default: 7.

    .. _A Tale of Evil Twins:
        https://arxiv.org/abs/1911.01559

    """

    name: str = 'imc_advtrain'

    def __init__(self, pgd_alpha: float = 2.0 / 255, pgd_epsilon: float = 8.0 / 255, pgd_iteration: int = 7, **kwargs):
        super().__init__(**kwargs)
        self.param_list['adv_train'] = ['pgd_alpha', 'pgd_epsilon', 'pgd_iteration']
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha, epsilon=pgd_epsilon, iteration=pgd_iteration, stop_threshold=None)

    def get_poison_data(self, data: tuple[torch.Tensor, torch.Tensor], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data)
        decimal, integer = math.modf(self.poison_num)
        integer = int(integer)
        if random.uniform(0, 1) < decimal:
            integer += 1
        if integer:
            org_input, org_label = _input, _label
            _input = self.add_mark(org_input[:integer])
            _label = self.target_class * torch.ones_like(org_label[:integer])
        return _input, _label

    def attack(self, epoch: int, save=False, **kwargs):
        self.adv_train(epoch, save=save,
                       validate_fn=self.validate_fn, get_data_fn=self.get_data,
                       epoch_fn=self.epoch_fn, save_fn=self.save, **kwargs)

    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0, epoch_fn: Callable = None,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if callable(epoch_fn):
                self.model.activate_params([])
                epoch_fn()
                self.model.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                poison_input, poison_label = self.get_poison_data(data)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)

                    optimizer.zero_grad()
                    self.model.train()

                    x = torch.cat((adv_x, poison_input))
                    y = torch.cat((_label, poison_label))
                    loss = self.model.loss(x, y)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.save()
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
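
Distilled, the per-batch structure of `adv_train` above is: one warm-up update on clean data, then one model update per PGD step, with the poisoned samples concatenated into every adversarial batch. A sketch under those assumptions (`pgd.optimize` follows the signature used throughout this file; `imc_advtrain_batch` is illustrative):

import torch
import torch.nn.functional as F


def imc_advtrain_batch(model, optimizer, pgd, x, y, poison_x, poison_y):
    def adv_loss_fn(X):
        return -F.cross_entropy(model(X), y)

    model.train()
    F.cross_entropy(model(x), y).backward()    # warm-up step on clean data
    optimizer.step()
    optimizer.zero_grad()
    noise = torch.zeros_like(x)
    for _ in range(pgd.iteration):
        model.eval()                           # freeze BN stats while crafting
        adv_x, _ = pgd.optimize(_input=x, noise=noise,
                                loss_fn=adv_loss_fn, iteration=1)
        model.train()
        loss = F.cross_entropy(model(torch.cat((adv_x, poison_x))),
                               torch.cat((y, poison_y)))
        loss.backward()                        # one model update per PGD step
        optimizer.step()
        optimizer.zero_grad()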
Example 7
class IMC_Poison(PoisonBasic):

    name: str = 'imc_poison'

    # TODO: change PGD to Uname.optimizer
    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--pgd_alpha', dest='pgd_alpha', type=float)
        group.add_argument('--pgd_eps', dest='pgd_eps', type=float)
        group.add_argument('--pgd_iter', dest='pgd_iter', type=int)
        group.add_argument('--stop_conf', dest='stop_conf', type=float)

        group.add_argument('--magnet', dest='magnet', action='store_true')
        group.add_argument('--randomized_smooth',
                           dest='randomized_smooth',
                           action='store_true')
        group.add_argument('--curvature',
                           dest='curvature',
                           action='store_true')

    def __init__(self,
                 pgd_alpha: float = 1.0,
                 pgd_eps: float = 8.0,
                 pgd_iter: int = 8,
                 stop_conf: float = 0.9,
                 magnet: bool = False,
                 randomized_smooth: bool = False,
                 curvature: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.param_list['pgd'] = ['pgd_alpha', 'pgd_eps', 'pgd_iter']
        self.pgd_alpha: float = pgd_alpha
        self.pgd_eps: float = pgd_eps
        self.pgd_iter: int = pgd_iter
        self.pgd = PGD_Optimizer(pgd_alpha=self.pgd_alpha / 255,
                                 pgd_eps=self.pgd_eps / 255,
                                 iteration=self.pgd_iter)
        self.stop_conf: float = stop_conf
        if magnet:
            self.magnet: MagNet = MagNet(dataset=self.dataset, pretrain=True)
        self.randomized_smooth: bool = randomized_smooth
        if curvature:
            self.curvature: Curvature = Curvature(model=self.model)

    def attack(self, epoch: int, **kwargs):
        # model._validate()
        total = 0
        target_conf_list = []
        target_acc_list = []
        clean_acc_list = []
        pgd_norm_list = []
        pgd_alpha = 1.0 / 255
        pgd_eps = 8.0 / 255
        if self.dataset.name in ['cifar10', 'gtsrb', 'isic2018']:
            pgd_alpha = 1.0 / 255
            pgd_eps = 8.0 / 255
        if self.dataset.name in ['sample_imagenet', 'sample_vggface2']:
            pgd_alpha = 0.25 / 255
            pgd_eps = 2.0 / 255
        pgd_checker = PGD(pgd_alpha=pgd_alpha,
                          pgd_eps=pgd_eps,
                          iteration=8,
                          dataset=self.dataset,
                          model=self.model,
                          target_idx=self.target_idx,
                          stop_threshold=0.95)
        easy = 0
        difficult = 0
        normal = 0
        loader = self.dataset.get_dataloader(
            mode='valid', batch_size=self.dataset.test_batch_size)
        if 'curvature' in self.__dict__.keys():
            benign_curvature = self.curvature.benign_measure()
            tgt_curvature_list = []
            org_curvature_list = []
        if self.randomized_smooth:
            org_conf_list = []
            tgt_conf_list = []
        if 'magnet' in self.__dict__.keys():
            org_magnet_list = []
            tgt_magnet_list = []
        for data in loader:
            print(easy, normal, difficult)
            if normal >= 100:
                break
            self.model.load()
            _input, _label = self.model.remove_misclassify(data)
            if len(_label) == 0:
                continue
            target_label = self.model.generate_target(_input,
                                                      idx=self.target_idx)
            self.temp_input = _input
            self.temp_label = target_label
            _, _iter = pgd_checker.craft_example(_input)
            if _iter is None:
                difficult += 1
                continue
            if _iter < 4:
                easy += 1
                continue
            normal += 1
            target_conf, target_acc, clean_acc = self.validate_fn()
            noise = torch.zeros_like(_input)
            poison_input = self.craft_example(_input=_input,
                                              _label=target_label,
                                              epoch=epoch,
                                              noise=noise,
                                              **kwargs)
            pgd_norm = float(noise.norm(p=float('inf')))
            target_conf, target_acc, clean_acc = self.validate_fn()
            target_conf_list.append(target_conf)
            target_acc_list.append(target_acc)
            clean_acc_list.append(max(self.clean_acc - clean_acc, 0.0))
            pgd_norm_list.append(pgd_norm)
            print(
                f'[{total+1} / 100]\n'
                f'target confidence: {np.mean(target_conf_list)}({np.std(target_conf_list)})\n'
                f'target accuracy: {np.mean(target_acc_list)}({np.std(target_acc_list)})\n'
                f'clean accuracy Drop: {np.mean(clean_acc_list)}({np.std(clean_acc_list)})\n'
                f'PGD Norm: {np.mean(pgd_norm_list)}({np.std(pgd_norm_list)})\n\n\n'
            )
            org_conf = self.model.get_target_prob(_input=poison_input,
                                                  target=_label)
            tgt_conf = self.model.get_target_prob(_input=poison_input,
                                                  target=target_label)
            if 'curvature' in self.__dict__.keys():
                org_curvature_list.extend(
                    to_list(self.curvature.measure(poison_input,
                                                   _label)))  # type: ignore
                tgt_curvature_list.extend(
                    to_list(self.curvature.measure(
                        poison_input, target_label)))  # type: ignore
                print('Curvature:')
                print(
                    f'    org_curvature: {ks_2samp(org_curvature_list, benign_curvature)}'
                )  # type: ignore
                print(
                    f'    tgt_curvature: {ks_2samp(tgt_curvature_list, benign_curvature)}'
                )  # type: ignore
                print()
            if self.randomized_smooth:
                org_new = self.model.get_target_prob(_input=poison_input,
                                                     target=_label,
                                                     randomized_smooth=True)
                tgt_new = self.model.get_target_prob(_input=poison_input,
                                                     target=target_label,
                                                     randomized_smooth=True)
                org_increase = (org_new - org_conf).clamp(min=0.0)
                tgt_decrease = (tgt_conf - tgt_new).clamp(min=0.0)
                org_conf_list.extend(to_list(org_increase))  # type: ignore
                tgt_conf_list.extend(to_list(tgt_decrease))  # type: ignore
                print('Randomized Smooth:')
                print(f'    org_confidence: {np.mean(org_conf_list)}'
                      )  # type: ignore
                print(f'    tgt_confidence: {np.mean(tgt_conf_list)}'
                      )  # type: ignore
                print()
            if 'magnet' in self.__dict__.keys():
                poison_input = self.magnet(poison_input)
                org_new = self.model.get_target_prob(_input=poison_input,
                                                     target=_label)
                tgt_new = self.model.get_target_prob(_input=poison_input,
                                                     target=target_label)
                org_increase = (org_new - org_conf).clamp(min=0.0)
                tgt_decrease = (tgt_conf - tgt_new).clamp(min=0.0)
                org_magnet_list.extend(to_list(org_increase))  # type: ignore
                tgt_magnet_list.extend(to_list(tgt_decrease))  # type: ignore
                print('MagNet:')
                print(f'    org_confidence: {np.mean(org_magnet_list)}'
                      )  # type: ignore
                print(f'    tgt_confidence: {np.mean(tgt_magnet_list)}'
                      )  # type: ignore
                print()
            total += 1

    def craft_example(self,
                      _input: torch.Tensor,
                      _label: torch.Tensor,
                      noise: torch.Tensor = None,
                      save=False,
                      **kwargs):
        if noise is None:
            noise = torch.zeros_like(_input)
        poison_input = None
        for _iter in range(self.pgd_iter):
            target_conf, target_acc = self.validate_target(indent=4,
                                                           verbose=False)
            if target_conf > self.stop_conf:
                break
            poison_input, _ = self.pgd.optimize(_input,
                                                noise=noise,
                                                loss_fn=self.loss_pgd,
                                                iteration=1)
            self.temp_input = poison_input
            target_conf, target_acc = self.validate_target(indent=4,
                                                           verbose=False)
            if target_conf > self.stop_conf:
                break
            self._train(_input=poison_input, _label=_label, **kwargs)
        target_conf, target_acc = self.validate_target(indent=4, verbose=False)
        return poison_input

    def save(self, **kwargs):
        filename = self.get_filename(**kwargs)
        file_path = os.path.join(self.folder_path, filename)
        self.model.save(file_path + '.pth')
        print('attack results saved at: ', file_path)

    def get_filename(self, **kwargs):
        return self.model.name

    def loss_pgd(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.loss(x, self.temp_label)
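
`craft_example` is the input-model co-optimization at the heart of IMC: alternate a single PGD step on the poison input with a short fine-tuning step of the model on that input, stopping early once the target confidence clears `stop_conf`. The pattern in isolation (a sketch; `pgd`, `train_step`, `target_conf`, and `loss_fn` stand for the corresponding pieces above):

def co_optimize(pgd, train_step, target_conf, loss_fn, x, y, noise,
                pgd_iter: int, stop_conf: float = 0.9):
    poison_x = None
    for _ in range(pgd_iter):
        if target_conf() > stop_conf:               # confident enough already
            break
        poison_x, _ = pgd.optimize(x, noise=noise,  # one step on the input
                                   loss_fn=loss_fn, iteration=1)
        if target_conf() > stop_conf:
            break
        train_step(poison_x, y)                     # one step on the model
    return poison_x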
Example 8
class AdvTrain(BackdoorDefense):

    name: str = 'adv_train'

    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--pgd_alpha', dest='pgd_alpha', type=float)
        group.add_argument('--pgd_epsilon', dest='pgd_epsilon', type=float)
        group.add_argument('--pgd_iteration', dest='pgd_iteration', type=int)

    def __init__(self,
                 pgd_alpha: float = 2.0 / 255,
                 pgd_epsilon: float = 8.0 / 255,
                 pgd_iteration: int = 7,
                 **kwargs):
        super().__init__(**kwargs)
        self.param_list['adv_train'] = [
            'pgd_alpha', 'pgd_epsilon', 'pgd_iteration'
        ]
        self.pgd_alpha = pgd_alpha
        self.pgd_epsilon = pgd_epsilon
        self.pgd_iteration = pgd_iteration
        self.pgd = PGD(alpha=pgd_alpha,
                       epsilon=pgd_epsilon,
                       iteration=pgd_iteration,
                       stop_threshold=None)

    def detect(self, **kwargs):
        super().detect(**kwargs)
        print()
        self.adv_train(verbose=True, **kwargs)
        self.attack.validate_func()

    def validate_func(self,
                      get_data_fn=None,
                      **kwargs) -> tuple[float, float, float]:
        clean_loss, clean_acc = self.model._validate(
            print_prefix='Validate Clean', get_data_fn=None, **kwargs)
        adv_loss, adv_acc = self.model._validate(print_prefix='Validate Adv',
                                                 get_data_fn=self.get_data,
                                                 **kwargs)
        # todo: Return value
        if self.clean_acc - clean_acc > 20 and self.clean_acc > 40:
            adv_acc = 0.0
        return clean_loss + adv_loss, adv_acc, clean_acc

    def get_data(self, data: tuple[torch.Tensor, torch.Tensor],
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        _input, _label = self.model.get_data(data, **kwargs)

        def loss_fn(X: torch.FloatTensor):
            return -self.model.loss(X, _label)

        adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn)
        return adv_x, _label

    def adv_train(self,
                  epoch: int,
                  optimizer: optim.Optimizer,
                  lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10,
                  save=False,
                  verbose=True,
                  indent=0,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'

        _, best_acc, _ = self.validate_func(verbose=verbose,
                                            indent=indent,
                                            **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [
            param_group['params'] for param_group in optimizer.param_groups
        ]
        for _epoch in range(epoch):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)

                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 loss_fn=loss_fn,
                                                 iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch),
                    **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str,
                       _str,
                       prefix='{upline}{clear_line}'.format(
                           **ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch +
                        1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc, _ = self.validate_func(verbose=verbose,
                                                       indent=indent,
                                                       **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(
                            f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                            indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
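
The `eval()`/`train()` toggling inside the inner loop above is deliberate: crafting perturbations in train mode would let every PGD forward pass pollute BatchNorm running statistics with adversarial batches. A quick illustration of the mechanism:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.train()
_ = bn(torch.randn(8, 3, 4, 4))        # train mode: running stats update
bn.eval()
before = bn.running_mean.clone()
_ = bn(torch.randn(8, 3, 4, 4))        # eval mode: running stats frozen
assert torch.equal(bn.running_mean, before)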
Example 9
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer: SummaryWriter = None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               adv_train: bool = False,
               adv_train_alpha: float = 2.0 / 255,
               adv_train_epsilon: float = 8.0 / 255,
               adv_train_iter: int = 7,
               **kwargs):
        if adv_train:
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(
                validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider moving import statements to the top of the file
            self.pgd = PGD(alpha=adv_train_alpha,
                           epsilon=adv_train_epsilon,
                           iteration=adv_train_iter,
                           stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor,
                                  _label: torch.Tensor,
                                  _output: torch.Tensor,
                                  loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None,
                                  **kwargs):
                noise = torch.zeros_like(_input)

                def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:
                    return -loss_fn(X, _label)

                for m in range(self.pgd.iteration):
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 loss_fn=loss_fn_new,
                                                 iteration=1)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input,
                                          _label=_label,
                                          _output=_output,
                                          loss=loss,
                                          optimizer=optimizer,
                                          loss_fn=loss_fn,
                                          amp=amp,
                                          scaler=scaler,
                                          **kwargs)
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[
                torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate',
                                **kwargs) -> tuple[float, float]:
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None,
                                               **kwargs)
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(
                                                 get_data_fn, adv=True),
                                             **kwargs)
                return adv_acc, clean_acc

            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new

        super()._train(epoch=epoch,
                       optimizer=optimizer,
                       lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix,
                       start_epoch=start_epoch,
                       validate_interval=validate_interval,
                       save=save,
                       amp=amp,
                       loader_train=loader_train,
                       loader_valid=loader_valid,
                       epoch_fn=epoch_fn,
                       get_data_fn=get_data_fn,
                       loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn,
                       validate_fn=validate_fn,
                       save_fn=save_fn,
                       file_path=file_path,
                       folder_path=folder_path,
                       suffix=suffix,
                       writer=writer,
                       main_tag=main_tag,
                       tag=tag,
                       verbose=verbose,
                       indent=indent,
                       **kwargs)
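
The `adv_train` branch above is a callback-wrapping pattern: the subclass intercepts `after_loss_fn` and `validate_fn`, wraps them with adversarial behavior, and forwards everything else to the parent trainer untouched. The validation wrapper in isolation (a sketch; `wrap_validate` is illustrative):

import functools
from typing import Callable


def wrap_validate(validate_fn: Callable, get_data_fn: Callable) -> Callable:
    # Run validation twice: once on clean data, once on PGD-perturbed data.
    def validate_fn_new(**kwargs):
        _, clean_acc = validate_fn(print_prefix='Validate Clean',
                                   main_tag='valid clean',
                                   get_data_fn=None, **kwargs)
        _, adv_acc = validate_fn(print_prefix='Validate Adv',
                                 main_tag='valid adv',
                                 get_data_fn=functools.partial(get_data_fn,
                                                               adv=True),
                                 **kwargs)
        return adv_acc, clean_acc
    return validate_fn_new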
Example 10
class ImageModel(Model):
    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--layer',
                           dest='layer',
                           type=int,
                           help='layer (optional, maybe embedded in --model)')
        group.add_argument(
            '--width_factor',
            dest='width_factor',
            type=int,
            help=
            'width factor for wide-ResNet (optional, maybe embedded in --model)'
        )
        group.add_argument(
            '--sgm',
            dest='sgm',
            action='store_true',
            help='whether to use sgm gradient, defaults to False')
        group.add_argument('--sgm_gamma',
                           dest='sgm_gamma',
                           type=float,
                           help='sgm gamma, defaults to 1.0')
        return group

    def __init__(self,
                 name: str = 'imagemodel',
                 layer: int = None,
                 width_factor: int = None,
                 model_class: type[_ImageModel] = _ImageModel,
                 dataset: ImageSet = None,
                 sgm: bool = False,
                 sgm_gamma: float = 1.0,
                 **kwargs):
        name, layer, width_factor = self.split_model_name(
            name, layer=layer, width_factor=width_factor)
        self.layer = layer
        self.width_factor = width_factor
        if 'norm_par' not in kwargs.keys() and isinstance(dataset, ImageSet):
            kwargs['norm_par'] = dataset.norm_par
        if 'num_classes' not in kwargs.keys() and dataset is None:
            kwargs['num_classes'] = 1000
        super().__init__(name=name,
                         model_class=model_class,
                         layer=layer,
                         width_factor=width_factor,
                         dataset=dataset,
                         **kwargs)
        self.sgm: bool = sgm
        self.sgm_gamma: float = sgm_gamma
        self.param_list['imagemodel'] = ['layer', 'width_factor', 'sgm']
        if sgm:
            self.param_list['imagemodel'].extend(['sgm_gamma'])
        self._model: _ImageModel
        self.dataset: ImageSet
        self.pgd = None  # TODO: python 3.10 type annotation

    def get_data(self,
                 data: tuple[torch.Tensor, torch.Tensor],
                 adv: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        if adv and self.pgd is not None:
            _input, _label = super().get_data(data, **kwargs)

            def loss_fn_new(
                X: torch.FloatTensor
            ) -> torch.Tensor:  # TODO: use functools.partial
                return -self.loss(X, _label)

            adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=loss_fn_new)
            return adv_x, _label
        return super().get_data(data, **kwargs)

    def get_layer(self,
                  x: torch.Tensor,
                  layer_output: str = 'logits',
                  layer_input: str = 'input') -> torch.Tensor:
        return self._model.get_layer(x,
                                     layer_output=layer_output,
                                     layer_input=layer_input)

    def get_layer_name(self) -> list[str]:
        return self._model.get_layer_name()

    def get_all_layer(self,
                      x: torch.Tensor,
                      layer_input: str = 'input') -> dict[str, torch.Tensor]:
        return self._model.get_all_layer(x, layer_input=layer_input)

    # TODO: requires _input shape (N, C, H, W)
    # Reference: https://keras.io/examples/vision/grad_cam/
    def get_heatmap(self,
                    _input: torch.Tensor,
                    _label: torch.Tensor,
                    method: str = 'grad_cam',
                    cmap: Colormap = jet) -> torch.Tensor:
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)  # (N, C, H, W)
            squeeze_flag = True
        if isinstance(_label, int):
            _label = [_label] * len(_input)
        _label = torch.as_tensor(_label, device=_input.device)
        heatmap = _input  # linting purpose
        if method == 'grad_cam':
            feats = self._model.get_fm(_input).detach()  # (N, C', H', W')
            feats.requires_grad_()
            _output: torch.Tensor = self._model.pool(feats)  # (N, C', 1, 1)
            _output = self._model.flatten(_output)  # (N, C')
            _output = self._model.classifier(_output)  # (N, num_classes)
            _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, feats)[0]  # (N, C',H', W')
            feats.requires_grad_(False)
            weights = grad.mean(dim=-2,
                                keepdim=True).mean(dim=-1,
                                                   keepdim=True)  # (N, C',1,1)
            heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(
                0)  # (N, 1, H', W')
            # heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(
                heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1,
                                                         keepdim=True)[0])
            heatmap: torch.Tensor = F.interpolate(heatmap,
                                                  _input.shape[-2:],
                                                  mode='bilinear')[:, 0]  # (N, H, W)
            # Note that we violate the image order convention (W, H, C)
        elif method == 'saliency_map':
            _input.requires_grad_()
            _output = self(_input).gather(dim=1,
                                          index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, _input)[0]  # (N,C,H,W)
            _input.requires_grad_(False)

            heatmap = grad.abs().max(dim=1)[0]  # (N,H,W)
            heatmap.sub_(
                heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1,
                                                         keepdim=True)[0])
            heatmap.div_(
                heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1,
                                                         keepdim=True)[0])
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap

    @staticmethod
    def split_model_name(name: str,
                         layer: int = None,
                         width_factor: int = None) -> tuple[str, int, int]:
        re_list = re.findall(r'[0-9]+|[a-z]+|_', name)
        if len(re_list) > 1:
            name = re_list[0]
            layer = int(re_list[1])
        if len(re_list) > 2 and re_list[-2] == 'x':
            width_factor = int(re_list[-1])
        if layer is not None:
            name += str(layer)
        if width_factor is not None:
            name += f'x{width_factor:d}'
        return name, layer, width_factor

    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer: SummaryWriter = None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               adv_train: bool = False,
               adv_train_alpha: float = 2.0 / 255,
               adv_train_epsilon: float = 8.0 / 255,
               adv_train_iter: int = 7,
               **kwargs):
        if adv_train:
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(
                validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider moving import statements to the top of the file
            self.pgd = PGD(alpha=adv_train_alpha,
                           epsilon=adv_train_epsilon,
                           iteration=adv_train_iter,
                           stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor,
                                  _label: torch.Tensor,
                                  _output: torch.Tensor,
                                  loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None,
                                  **kwargs):
                noise = torch.zeros_like(_input)

                def loss_fn_new(X: torch.FloatTensor) -> torch.Tensor:
                    return -loss_fn(X, _label)

                for m in range(self.pgd.iteration):
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 loss_fn=loss_fn_new,
                                                 iteration=1)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input,
                                          _label=_label,
                                          _output=_output,
                                          loss=loss,
                                          optimizer=optimizer,
                                          loss_fn=loss_fn,
                                          amp=amp,
                                          scaler=scaler,
                                          **kwargs)
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[
                torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate',
                                **kwargs) -> tuple[float, float]:
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None,
                                               **kwargs)
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(
                                                 get_data_fn, adv=True),
                                             **kwargs)
                return adv_acc, clean_acc

            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new

        super()._train(epoch=epoch,
                       optimizer=optimizer,
                       lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix,
                       start_epoch=start_epoch,
                       validate_interval=validate_interval,
                       save=save,
                       amp=amp,
                       loader_train=loader_train,
                       loader_valid=loader_valid,
                       epoch_fn=epoch_fn,
                       get_data_fn=get_data_fn,
                       loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn,
                       validate_fn=validate_fn,
                       save_fn=save_fn,
                       file_path=file_path,
                       folder_path=folder_path,
                       suffix=suffix,
                       writer=writer,
                       main_tag=main_tag,
                       tag=tag,
                       verbose=verbose,
                       indent=indent,
                       **kwargs)
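
`split_model_name` tokenizes the model string with `re.findall(r'[0-9]+|[a-z]+|_', name)` and rebuilds a canonical name from the parsed layer and width factor. A few checks that follow directly from that regex (a sketch using the staticmethod above):

assert ImageModel.split_model_name('resnet18') == ('resnet18', 18, None)
assert ImageModel.split_model_name('resnet', layer=18) == ('resnet18', 18, None)
assert ImageModel.split_model_name('wideresnet28x10') == ('wideresnet28x10', 28, 10)
assert ImageModel.split_model_name('vgg') == ('vgg', None, None)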
Example 11
class ImageModel(Model):
    @classmethod
    def add_argument(cls, group: argparse._ArgumentGroup):
        super().add_argument(group)
        group.add_argument('--adv_train',
                           action='store_true',
                           help='enable adversarial training.')
        group.add_argument(
            '--adv_train_iter',
            type=int,
            help='adversarial training PGD iteration, defaults to 7.')
        group.add_argument(
            '--adv_train_alpha',
            type=float,
            help='adversarial training PGD alpha, defaults to 2/255.')
        group.add_argument(
            '--adv_train_eps',
            type=float,
            help='adversarial training PGD eps, defaults to 8/255.')
        group.add_argument(
            '--adv_train_valid_eps',
            type=float,
            help='adversarial training PGD eps used at validation, defaults to 8/255.')

        group.add_argument(
            '--sgm',
            action='store_true',
            help='whether to use sgm gradient, defaults to False')
        group.add_argument('--sgm_gamma',
                           type=float,
                           help='sgm gamma, defaults to 1.0')
        return group

    def __init__(self,
                 name: str = 'imagemodel',
                 layer: int = None,
                 model: Union[type[_ImageModel], _ImageModel] = _ImageModel,
                 dataset: ImageSet = None,
                 adv_train: bool = False,
                 adv_train_iter: int = 7,
                 adv_train_alpha: float = 2 / 255,
                 adv_train_eps: float = 8 / 255,
                 adv_train_valid_eps: float = 8 / 255,
                 sgm: bool = False,
                 sgm_gamma: float = 1.0,
                 norm_par: dict[str, list[float]] = None,
                 **kwargs):
        name = self.get_name(name, layer=layer)
        if norm_par is None and isinstance(dataset, ImageSet):
            norm_par = dataset.norm_par
        if 'num_classes' not in kwargs.keys() and dataset is None:
            kwargs['num_classes'] = 1000
        super().__init__(name=name,
                         model=model,
                         dataset=dataset,
                         norm_par=norm_par,
                         **kwargs)
        self.sgm: bool = sgm
        self.sgm_gamma: float = sgm_gamma
        self.adv_train = adv_train
        self.adv_train_iter = adv_train_iter
        self.adv_train_alpha = adv_train_alpha
        self.adv_train_eps = adv_train_eps
        self.adv_train_valid_eps = adv_train_valid_eps
        self.param_list['imagemodel'] = []
        if sgm:
            self.param_list['imagemodel'].append('sgm_gamma')
        if adv_train:
            self.param_list['adv_train'] = [
                'adv_train_iter', 'adv_train_alpha', 'adv_train_eps',
                'adv_train_valid_eps'
            ]
            self.suffix += '_adv_train'
            if 'suffix' not in self.param_list['model']:
                self.param_list['model'].append('suffix')
        self._model: _ImageModel
        self.dataset: ImageSet
        self.pgd = None  # TODO: python 3.10 type annotation
        self._ce_loss_fn = nn.CrossEntropyLoss(weight=self.loss_weights)

    @classmethod
    def get_name(cls, name: str, layer: int = None) -> str:
        full_list = name.split('_')
        partial_name = full_list[0]
        re_list = re.findall(r'\d+|\D+', partial_name)
        if len(re_list) > 1:
            layer = int(re_list[1])
        elif layer is not None:
            partial_name += str(layer)
        full_list[0] = partial_name
        return '_'.join(full_list)

    def adv_loss(self, _input: torch.Tensor,
                 _label: torch.Tensor) -> torch.Tensor:
        _output = self(_input)
        return -self._ce_loss_fn(_output, _label)

    def get_data(self,
                 data: tuple[torch.Tensor, torch.Tensor],
                 adv: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        if adv and self.pgd is not None:
            _input, _label = super().get_data(data, **kwargs)
            adv_loss_fn = functools.partial(self.adv_loss, _label=_label)
            adv_x, _ = self.pgd.optimize(_input=_input, loss_fn=adv_loss_fn)
            return adv_x, _label
        return super().get_data(data, **kwargs)

    # TODO: requires _input shape (N, C, H, W)
    # Reference: https://keras.io/examples/vision/grad_cam/
    def get_heatmap(self,
                    _input: torch.Tensor,
                    _label: torch.Tensor,
                    method: str = 'grad_cam',
                    cmap: Colormap = jet) -> torch.Tensor:
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)  # (N, C, H, W)
            squeeze_flag = True
        if isinstance(_label, int):
            _label = [_label] * len(_input)
        _label = torch.as_tensor(_label, device=_input.device)
        heatmap = _input  # linting purpose
        if method == 'grad_cam':
            feats = self._model.get_fm(_input).detach()  # (N, C', H', W')
            feats.requires_grad_()
            _output: torch.Tensor = self._model.pool(feats)  # (N, C', 1, 1)
            _output = self._model.flatten(_output)  # (N, C')
            _output = self._model.classifier(_output)  # (N, num_classes)
            _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, feats)[0]  # (N, C',H', W')
            feats.requires_grad_(False)
            weights = grad.mean(dim=-2,
                                keepdim=True).mean(dim=-1,
                                                   keepdim=True)  # (N, C',1,1)
            heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(
                0)  # (N, 1, H', W')
            # heatmap.sub_(heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0])
            heatmap.div_(
                heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1,
                                                         keepdim=True)[0])
            heatmap: torch.Tensor = F.interpolate(heatmap,
                                                  _input.shape[-2:],
                                                  mode='bilinear')[:, 0]  # (N, H, W)
            # Note that we violate the image order convention (W, H, C)
        elif method == 'saliency_map':
            _input.requires_grad_()
            _output = self(_input).gather(dim=1,
                                          index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, _input)[0]  # (N,C,H,W)
            _input.requires_grad_(False)

            heatmap = grad.abs().max(dim=1)[0]  # (N,H,W)
            heatmap.sub_(
                heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1,
                                                         keepdim=True)[0])
            heatmap.div_(
                heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1,
                                                         keepdim=True)[0])
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap

    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               resume: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        if self.adv_train:
            after_loss_fn_old = after_loss_fn
            if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
                after_loss_fn_old = getattr(self, 'after_loss_fn')
            validate_fn_old = validate_fn if callable(
                validate_fn) else self._validate
            loss_fn = loss_fn if callable(loss_fn) else self.loss
            from trojanvision.optim import PGD  # TODO: consider moving import statements to the top of the file
            self.pgd = PGD(pgd_alpha=self.adv_train_alpha,
                           pgd_eps=self.adv_train_valid_eps,
                           iteration=self.adv_train_iter,
                           stop_threshold=None)

            def after_loss_fn_new(_input: torch.Tensor,
                                  _label: torch.Tensor,
                                  _output: torch.Tensor,
                                  loss: torch.Tensor,
                                  optimizer: Optimizer,
                                  loss_fn: Callable[..., torch.Tensor] = None,
                                  amp: bool = False,
                                  scaler: torch.cuda.amp.GradScaler = None,
                                  **kwargs):
                noise = torch.zeros_like(_input)
                adv_loss_fn = functools.partial(self.adv_loss, _label=_label)

                for m in range(self.pgd.iteration):
                    if amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    self.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 loss_fn=adv_loss_fn,
                                                 iteration=1,
                                                 pgd_eps=self.adv_train_eps)
                    self.train()
                    loss = loss_fn(adv_x, _label)
                    if callable(after_loss_fn_old):
                        after_loss_fn_old(_input=_input,
                                          _label=_label,
                                          _output=_output,
                                          loss=loss,
                                          optimizer=optimizer,
                                          loss_fn=loss_fn,
                                          amp=amp,
                                          scaler=scaler,
                                          **kwargs)
                    if amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            def validate_fn_new(get_data_fn: Callable[..., tuple[
                torch.Tensor, torch.Tensor]] = None,
                                print_prefix: str = 'Validate',
                                **kwargs) -> tuple[float, float]:
                _, clean_acc = validate_fn_old(print_prefix='Validate Clean',
                                               main_tag='valid clean',
                                               get_data_fn=None,
                                               **kwargs)
                _, adv_acc = validate_fn_old(print_prefix='Validate Adv',
                                             main_tag='valid adv',
                                             get_data_fn=functools.partial(
                                                 get_data_fn, adv=True),
                                             **kwargs)
                return adv_acc, clean_acc + adv_acc

            after_loss_fn = after_loss_fn_new
            validate_fn = validate_fn_new

        super()._train(epoch=epoch,
                       optimizer=optimizer,
                       lr_scheduler=lr_scheduler,
                       print_prefix=print_prefix,
                       start_epoch=start_epoch,
                       resume=resume,
                       validate_interval=validate_interval,
                       save=save,
                       amp=amp,
                       loader_train=loader_train,
                       loader_valid=loader_valid,
                       epoch_fn=epoch_fn,
                       get_data_fn=get_data_fn,
                       loss_fn=loss_fn,
                       after_loss_fn=after_loss_fn,
                       validate_fn=validate_fn,
                       save_fn=save_fn,
                       file_path=file_path,
                       folder_path=folder_path,
                       suffix=suffix,
                       writer=writer,
                       main_tag=main_tag,
                       tag=tag,
                       verbose=verbose,
                       indent=indent,
                       **kwargs)
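
This revision replaces Example 10's `split_model_name` with `get_name`, which only parses the segment before the first underscore, and swaps the ad-hoc `loss_fn_new` closure for a reusable `adv_loss` bound with `functools.partial`. `get_name`'s behavior, checked against its regex `r'\d+|\D+'` (a sketch using the classmethod above):

assert ImageModel.get_name('resnet18_comp') == 'resnet18_comp'  # layer parsed in place
assert ImageModel.get_name('resnet', layer=18) == 'resnet18'    # layer appended
assert ImageModel.get_name('vgg13_bn') == 'vgg13_bn'            # suffix preserved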