Code Example #1
    def get_explation_feature(self) -> list[float]:
        dataset = self.dataset.get_dataset(mode='train')
        subset, _ = self.dataset.split_dataset(dataset,
                                               percent=self.sample_ratio)
        clean_loader = self.dataset.get_dataloader(mode='train',
                                                   dataset=subset)

        _input, _label = zip(*subset)
        _input = torch.stack(_input)
        _label = torch.tensor(_label)
        poison_input = self.attack.add_mark(_input)
        newset = TensorDataset(poison_input, _label)
        backdoor_loader = self.dataset.get_dataloader(mode='train',
                                                      dataset=newset)

        exp_features = []
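        # for each class, contrast saliency maps of trigger-stamped inputs
        # against those of clean inputs to build an explanation feature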
        for label in range(self.model.num_classes):
            print('Class: ', output_iter(label, self.model.num_classes))
            backdoor_saliency_maps = self.saliency_map(
                label, backdoor_loader)  # (N, H, W)
            benign_saliency_maps = self.saliency_map(label,
                                                     clean_loader)  # (N, H, W)
            exp_features.append(
                self.cal_explanation_feature(backdoor_saliency_maps,
                                             benign_saliency_maps))
        return exp_features
Code Example #2
 def get_potential_triggers(
         self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     mark_list, mask_list, loss_list = [], [], []
     # todo: parallel to avoid for loop
     file_path = os.path.normpath(
         os.path.join(
             self.folder_path,
             self.get_filename(target_class=self.target_class) + '.npz'))
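     # reverse-engineer a candidate trigger (mark + mask) for every class,
     # checkpointing intermediate results to disk after each class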
     for label in range(self.model.num_classes):
         print('Class: ', output_iter(label, self.model.num_classes))
         mark, mask, loss = self.remask(label)
         mark_list.append(mark)
         mask_list.append(mask)
         loss_list.append(loss)
         if not self.random_pos:
             overlap = jaccard_idx(mask,
                                   self.real_mask,
                                   select_num=self.attack.mark.mark_height *
                                   self.attack.mark.mark_width)
             print(f'Jaccard index: {overlap:.3f}')
         # checkpoint after each class; tensors are converted to
         # CPU numpy arrays so np.savez can serialize them
         np.savez(file_path,
                  mark_list=np.stack([m.detach().cpu().numpy()
                                      for m in mark_list]),
                  mask_list=np.stack([m.detach().cpu().numpy()
                                      for m in mask_list]),
                  loss_list=np.array(loss_list))
     print('Defense results saved at: ' + file_path)
     mark_list = torch.stack(mark_list)
     mask_list = torch.stack(mask_list)
     loss_list = torch.as_tensor(loss_list)
     return mark_list, mask_list, loss_list
Code Example #3
 def get_potential_triggers(self) -> tuple[torch.Tensor, torch.Tensor]:
     mark_list, loss_list = [], []
     # todo: parallel to avoid for loop
     for label in range(self.model.num_classes):
         print('Class: ', output_iter(label, self.model.num_classes))
         loss, mark = self.remask(label)
         loss_list.append(loss)
         mark_list.append(mark)
     loss_list = torch.as_tensor(loss_list)
     mark_list = torch.stack(mark_list)  # match the annotated return type
     return loss_list, mark_list
Code Example #4
    def get_mark_loss_list(
            self,
            verbose: bool = True,
            **kwargs) -> tuple[torch.Tensor, list[float], list[float]]:
        r"""Get list of mark, loss, asr of recovered trigger for each class.

        Args:
            verbose (bool): Whether to output jaccard index for each trigger.
                It's also passed to :meth:`optimize_mark()`.
            **kwargs: Keyword arguments passed to :meth:`optimize_mark()`.

        Returns:
            (torch.Tensor, list[float], list[float]):
                list of mark, loss, asr with length ``num_classes``.
        """
        mark_list: list[torch.Tensor] = []
        loss_list: list[float] = []
        asr_list: list[float] = []
        # todo: parallel to avoid for loop
        file_path = os.path.normpath(
            os.path.join(self.folder_path,
                         self.get_filename() + '.npz'))
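        # optimize a trigger for each class and record its loss and ASR;
        # results are checkpointed to the .npz file after every class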
        for label in range(self.model.num_classes):
            print('Class: ', output_iter(label, self.model.num_classes))
            mark, loss = self.optimize_mark(label, verbose=verbose, **kwargs)
            if verbose:
                asr, _ = self.attack.validate_fn(indent=4)
                if not self.mark_random_pos:
                    select_num = self.attack.mark.mark_height * self.attack.mark.mark_width
                    overlap = mask_jaccard(self.attack.mark.get_mask(),
                                           self.real_mask,
                                           select_num=select_num)
                    prints(f'Jaccard index: {overlap:.3f}', indent=4)
            else:
                asr, _ = self.model._validate(get_data_fn=self.attack.get_data,
                                              keep_org=False,
                                              poison_label=True,
                                              verbose=False)
            mark_list.append(mark)
            loss_list.append(loss)
            asr_list.append(asr)
            np.savez(file_path,
                     mark_list=np.stack(
                         [mark.detach().cpu().numpy() for mark in mark_list]),
                     loss_list=np.array(loss_list))
        print()
        print('Defense results saved at: ' + file_path)
        mark_list_tensor = torch.stack(mark_list)
        return mark_list_tensor, loss_list, asr_list
Code Example #5
    def prune(self, **kwargs):
        length = int(self.conv_module.out_channels)
        mask = self.conv_module.weight_mask
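        # prune (prune_num - 10) channels in one step, then continue one
        # channel at a time until clean accuracy drops by more than 20 points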
        self.prune_step(mask, prune_num=max(self.prune_num - 10, 0))
        self.attack.validate_fn()

        for i in range(min(10, length)):
            print('Iter: ', output_iter(i + 1, 10))
            self.prune_step(mask, prune_num=1)
            clean_acc, _ = self.attack.validate_fn()
            if self.attack.clean_acc - clean_acc > 20:
                break
        file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')
        self.model._train(validate_fn=self.attack.validate_fn, file_path=file_path, **kwargs)
        self.attack.validate_fn()
Code Example #6
    def train_mask_generator(self, verbose: bool = True):
        r"""Train :attr:`self.mask_generator`."""
        optimizer = torch.optim.Adam(self.mask_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_mask_epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        print_prefix = 'Mask Epoch'
        for _epoch in range(self.train_mask_epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
            header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
            self.mask_generator.train()
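            # each step pairs a batch with a second shuffled batch; the
            # diversity loss pushes distinct inputs toward distinct masks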
            for data in logger.log_every(loader, header=header) if verbose else loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size

                _mask = self.get_mask(_input)
                _mask2 = self.get_mask(_input2)

                input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
                mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_div = input_dist.div(mask_dist).mean()
                loss_norm = _mask.sub(self.mask_density).relu().mean()

                loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
                loss.backward()
                optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
            lr_scheduler.step()
            self.mask_generator.eval()
            if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0 or _epoch == self.train_mask_epochs):
                self.validate_mask_generator()
        optimizer.zero_grad()
Code Example #7
    def get_mark_loss_list(
            self) -> tuple[torch.Tensor, list[float], list[float]]:
        print('sample neurons')
        all_ps = self.sample_neuron(self.seed_data['input'])
        print('find min max')
        self.neuron_dict = self.find_min_max(all_ps, self.seed_data['label'])
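        # neuron_dict maps each class to candidate (layer, neuron, value)
        # records, printed below before trigger optimization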

        format_str = self.serialize_format(layer='20s',
                                           neuron='5d',
                                           value='10.3f')
        # Output neuron dict information
        for label in range(self.model.num_classes):
            print('Class: ', output_iter(label, self.model.num_classes))
            for _dict in reversed(self.neuron_dict[label]):
                prints(format_str.format(**_dict), indent=4)
        print()
        print('optimize marks')
        return super().get_mark_loss_list(verbose=False)
Code Example #8
    def get_potential_triggers(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mark_list, mask_list, loss_list = [], [], []
        # todo: parallel to avoid for loop
        for label in range(self.model.num_classes):
            print('Class: ', output_iter(label, self.model.num_classes))
            mark, mask, loss = self.remask(label)
            mark_list.append(mark)
            mask_list.append(mask)
            loss_list.append(loss)

            if not self.random_pos:
                overlap = jaccard_idx(mask, self.real_mask,
                                      select_num=self.attack.mark.mark_height * self.attack.mark.mark_width)
                print(f'Jaccard index: {overlap:.3f}')
        mark_list = torch.stack(mark_list)
        mask_list = torch.stack(mask_list)
        loss_list = torch.as_tensor(loss_list)
        return mark_list, mask_list, loss_list
Code Example #9
File: refool.py  Project: ain-soph/trojanzoo
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer, **kwargs):
        model_dict = copy.deepcopy(self.model.state_dict())
        W = torch.ones(len(self.reflect_imgs))
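        # one selection weight per reflection image: images selected this
        # round get their measured ASR, unselected ones the median ASR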
        refool_optimizer = torch.optim.SGD(optimizer.param_groups[0]['params'],
                                           lr=self.refool_lr, momentum=0.9,
                                           weight_decay=5e-4)
        # logger = MetricLogger(meter_length=35)
        # logger.create_meters(asr='{median:.3f} ({min:.3f}  {max:.3f})')
        # iterator = logger.log_every(range(self.rank_iter))
        for _iter in range(self.rank_iter):
            print('Select iteration: ', output_iter(_iter + 1, self.rank_iter))
            # prepare data
            idx = random.choices(range(len(W)), weights=W.tolist(), k=self.refool_sample_num)
            mark = torch.ones_like(self.mark.mark).expand(self.refool_sample_num, -1, -1, -1).clone()
            mark[:, :-1] = self.reflect_imgs[idx]
            clean_input, _ = sample_batch(self.target_set, self.refool_sample_num)
            trigger_input = self.add_mark(clean_input, mark=mark)
            dataset = TensorListDataset(trigger_input, [self.target_class] * len(trigger_input))
            loader = self.dataset.get_dataloader(mode='train', dataset=dataset)
            # train
            self.model._train(self.refool_epochs, optimizer=refool_optimizer,
                              loader_train=loader, validate_interval=0,
                              output_freq='epoch', indent=4)
            self.model._validate(indent=4)
            # test
            select_idx = list(set(idx))
            marks = self.reflect_imgs[select_idx]
            asr_result = self._get_asr_result(marks)
            # update W
            W[select_idx] = asr_result
            other_idx = list(set(range(len(W))) - set(idx))
            W[other_idx] = asr_result.median()

            # logger.reset().update_list(asr=asr_result)
            self.model.load_state_dict(model_dict)
        self.mark.mark[:-1] = self.reflect_imgs[W.argmax().item()]
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs=epochs, optimizer=optimizer, **kwargs)
Code Example #10
File: fine_pruning.py  Project: ain-soph/trojanzoo
    def prune(self, **kwargs):
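        # walk modules in reverse to find the last Conv2d layer and attach
        # a pruning identity mask to its weights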
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv2d):
                self.last_conv: nn.Conv2d = prune.identity(module, 'weight')
                break
        length = self.last_conv.out_channels

        mask: torch.Tensor = self.last_conv.weight_mask
        self.prune_step(mask, prune_num=max(self.prune_num - 10, 0))
        self.attack.validate_fn()

        for i in range(min(10, length)):
            print('Iter: ', output_iter(i + 1, 10))
            self.prune_step(mask, prune_num=1)
            _, clean_acc = self.attack.validate_fn()
            if self.attack.clean_acc - clean_acc > 20:
                break
        file_path = os.path.join(self.folder_path,
                                 self.get_filename() + '.pth')
        self.model._train(validate_fn=self.attack.validate_fn,
                          file_path=file_path,
                          **kwargs)
        self.attack.validate_fn()
Code Example #11
File: abs.py  Project: hkunzhe/trojanzoo
    def remask(self,
               _input: torch.Tensor,
               layer: str,
               neuron: int,
               label: int,
               use_mask: bool = True,
               validate_interval: int = 100,
               verbose=False) -> tuple[torch.Tensor, torch.Tensor, float]:
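        # optimize mark (and optionally mask) in atanh space so that
        # tanh_func keeps their values inside a valid bounded range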
        atanh_mark = torch.randn(self.data_shape, device=env['device'])
        atanh_mark.requires_grad_()
        parameters: list[torch.Tensor] = [atanh_mark]
        mask = torch.ones(self.data_shape[1:], device=env['device'])
        atanh_mask = torch.ones(self.data_shape[1:], device=env['device'])
        if use_mask:
            atanh_mask.requires_grad_()
            parameters.append(atanh_mask)
            mask = tanh_func(atanh_mask)  # (h, w)
        mark = tanh_func(atanh_mark)  # (c, h, w)

        optimizer = optim.Adam(parameters,
                               lr=self.remask_lr if use_mask else 0.01 *
                               self.remask_lr)
        optimizer.zero_grad()

        # best optimization results
        mark_best = None
        loss_best = float('inf')
        mask_best = None

        for _epoch in range(self.remask_epoch):
            epoch_start = time.perf_counter()

            loss = self.abs_loss(_input,
                                 mask,
                                 mark,
                                 layer=layer,
                                 neuron=neuron,
                                 use_mask=use_mask)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            mark = tanh_func(atanh_mark)  # (c, h, w)
            if use_mask:
                mask = tanh_func(atanh_mask)  # (h, w)

            with torch.no_grad():
                X = _input + mask * (mark - _input)
                _output = self.model(X)
            acc = float(_output.argmax(dim=1).eq(label).float().mean()) * 100
            loss = float(loss)

            if verbose:
                norm = mask.norm(p=1)
                epoch_time = str(
                    datetime.timedelta(seconds=int(time.perf_counter() -
                                                   epoch_start)))
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, self.remask_epoch),
                    **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {loss:10.3f},'.ljust(20),
                    f'Acc: {acc:.3f}, '.ljust(20),
                    f'Norm: {norm:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, indent=8)
            if loss < loss_best:
                loss_best = loss
                mark_best = mark
                if use_mask:
                    mask_best = mask
            if validate_interval != 0 and verbose:
                if ((_epoch + 1) % validate_interval == 0
                        or _epoch == self.remask_epoch - 1):
                    self.attack.mark.mark = mark
                    self.attack.mark.alpha_mask = mask
                    self.attack.mark.mask = torch.ones_like(mark,
                                                            dtype=torch.bool)
                    self.attack.target_class = label
                    self.model._validate(print_prefix='Validate Trigger Tgt',
                                         get_data_fn=self.attack.get_data,
                                         keep_org=False,
                                         indent=8)
                    print()
        atanh_mark.requires_grad = False
        if use_mask:
            atanh_mask.requires_grad = False
        return mark_best, mask_best, loss_best
Code Example #12
    def adv_train(self,
                  epochs: int,
                  optimizer: optim.Optimizer,
                  lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10,
                  save=False,
                  verbose=True,
                  indent=0,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = os.path.join(self.folder_path,
                                 self.get_filename() + '.pth')

        best_acc, _ = self.validate_fn(verbose=verbose,
                                       indent=indent,
                                       **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        for _epoch in range(epochs):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
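                # one clean update above, then train on PGD adversarial
                # examples crafted step by step against the current model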
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input,
                                                 noise=noise,
                                                 target=_label,
                                                 iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epochs),
                    **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str,
                       _str,
                       prefix='{upline}{clear_line}'.format(
                           **ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if ((_epoch + 1) % validate_interval == 0
                        or _epoch == epochs - 1):
                    adv_acc, _ = self.validate_fn(verbose=verbose,
                                                  indent=indent,
                                                  **kwargs)
                    if adv_acc >= best_acc:  # higher accuracy is better
                        prints('{purple}best result update!{reset}'.format(
                            **ansi),
                               indent=indent)
                        prints(
                            f'Current Acc: {adv_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                            indent=indent)
                        best_acc = adv_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
Code Example #13
    def sample(self,
               child_name: str = None,
               class_dict: dict = None,
               sample_num: int = None,
               verbose=True):
        if sample_num is None:
            assert class_dict
            sample_num = len(class_dict)
        if child_name is None:
            child_name = self.name + '_sample%d' % sample_num
        src_path = self.folder_path
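        # each subdirectory of the source folder (train/valid/...) is a
        # dataset split; hidden directories are skipped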
        mode_list = [
            _dir for _dir in os.listdir(src_path)
            if os.path.isdir(os.path.join(src_path, _dir)) and _dir[0] != '.'
        ]
        dst_path = os.path.normpath(
            os.path.join(os.path.dirname(self.folder_path), child_name))
        if verbose:
            print('{yellow}src path{reset}: '.format(**ansi), src_path)
            print('{yellow}dst path{reset}: '.format(**ansi), dst_path)
        if class_dict is None:
            assert sample_num
            idx_list = np.arange(self.num_classes)
            np.random.seed(env['seed'])
            np.random.shuffle(idx_list)
            idx_list = idx_list[:sample_num]
            class_list = np.array(
                os.listdir(os.path.join(src_path, mode_list[0])))[idx_list]
            class_dict = {}
            for class_name in class_list:
                class_dict[class_name] = [class_name]
        if verbose:
            print(class_dict)

        len_i = len(class_dict.keys())
        for src_mode in mode_list:
            if verbose:
                print('{purple}{0}{reset}'.format(src_mode, **ansi))
            assert src_mode in ['train', 'valid', 'test', 'val']
            dst_mode = 'valid' if src_mode == 'val' else src_mode
            for i, dst_class in enumerate(class_dict.keys()):
                if not os.path.exists(
                        os.path.join(dst_path, dst_mode, dst_class)):
                    os.makedirs(os.path.join(dst_path, dst_mode, dst_class))
                prints('{blue_light}{0}{reset}'.format(dst_class, **ansi),
                       indent=10)
                class_list = class_dict[dst_class]
                len_j = len(class_list)
                for j, src_class in enumerate(class_list):
                    _list = os.listdir(
                        os.path.join(src_path, src_mode, src_class))
                    prints(
                        output_iter(i + 1, len_i) + output_iter(j + 1, len_j) +
                        f'dst: {dst_class:15s}    src: {src_class:15s}    image_num: {len(_list):>8d}',
                        indent=10)
                    if env['tqdm']:
                        _list = tqdm(_list)
                    for _file in _list:
                        shutil.copyfile(
                            os.path.join(src_path, src_mode, src_class, _file),
                            os.path.join(dst_path, dst_mode, dst_class, _file))
                    if env['tqdm']:
                        print('{upline}{clear_line}'.format(**ansi))
Code Example #14
    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               validate_interval: int = 10, save: bool = False, amp: bool = False, verbose: bool = True, indent: int = 0,
               loader_train: torch.utils.data.DataLoader = None, loader_valid: torch.utils.data.DataLoader = None,
               get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               validate_func: Callable[..., tuple[float, ...]] = None, epoch_func: Callable[[], None] = None,
               save_fn: Callable = None, file_path: str = None, folder_path: str = None, suffix: str = None, **kwargs):
        loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
        get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
        loss_fn = loss_fn if loss_fn is not None else self.loss
        validate_func = validate_func if validate_func is not None else self._validate
        save_fn = save_fn if save_fn is not None else self.save

        scaler: torch.cuda.amp.GradScaler = None
        if amp and env['num_gpus']:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                       verbose=verbose, indent=indent, **kwargs)
        losses = AverageMeter('Loss')
        top1 = AverageMeter('Acc@1')
        top5 = AverageMeter('Acc@5')
        params: list[list[nn.Parameter]] = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if epoch_func is not None:
                self.activate_params([])
                epoch_func()
                self.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            loader = loader_train
            if verbose and env['tqdm']:
                loader = tqdm(loader_train)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for data in loader:
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
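                # with AMP enabled, scale the loss to avoid gradient
                # underflow in half precision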
                if amp and env['num_gpus']:
                    loss = loss_fn(_input, _label, amp=True)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss = loss_fn(_input, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.get_logits(_input)
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
                empty_cache()
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.eval()
            self.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
                    f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                                  verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path, folder_path=folder_path, suffix=suffix, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.zero_grad()
Code Example #15
File: imagefolder.py  Project: ain-soph/trojanzoo
    def sample(self, child_name: str = None,
               class_dict: dict[str, list[str]] = None,
               sample_num: int = None,
               method='zip'):
        r"""Sample a subset image folder dataset.

        Args:
            child_name (str): Name of the child subset.
                Defaults to ``'{self.name}_sample{sample_num}'``.
            class_dict (dict[str, list[str]] | None):
                Map from each new class name to a list of old class names.
                If ``None``, use :attr:`sample_num` to randomly
                sample a subset of classes (mapped 1-to-1).
                Defaults to ``None``.
            sample_num (int | None):
                The number of subset classes to sample
                if :attr:`class_dict` is ``None``.
                Defaults to ``None``.
            method (str): :attr:`data_format` of the new subset to save.
                Defaults to ``'zip'``.
        """
        if sample_num is None:
            assert class_dict
            sample_num = len(class_dict)
        if child_name is None and sample_num is not None:
            child_name = f'{self.name}_sample{sample_num:d}'
        src_path = self.folder_path
        dst_path = os.path.normpath(os.path.join(
            os.path.dirname(self.folder_path), child_name))
        if not os.path.exists(dst_path):
            os.makedirs(dst_path)
        print('{yellow}src path{reset}: '.format(**ansi), src_path)
        print('{yellow}dst path{reset}: '.format(**ansi), dst_path)

        mode_list = [mode for mode in ['train', 'valid', 'test']
                     if os.path.isdir(os.path.join(src_path, mode))]
        if method == 'zip':
            zip_path_list: list[str] = glob.glob(
                os.path.join(src_path, '*_store.zip'))
            mode_list = [os.path.basename(zip_path).removeprefix(
                self.name).removesuffix('_store.zip')
                for zip_path in zip_path_list]

        src2dst_dict: dict[str, str] = {}
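        # src2dst_dict maps every source class name to the destination
        # class name it is merged into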
        if class_dict is None:
            assert sample_num
            idx_list: np.ndarray = np.arange(self.num_classes)
            np.random.seed(env['data_seed'])
            np.random.shuffle(idx_list)
            idx_list = idx_list[:sample_num]
            mode = mode_list[0]
            class_list: list[str] = []
            match method:
                case 'zip':
                    zip_path = os.path.join(src_path,
                                            f'{self.name}_{mode}_store.zip')
                    with zipfile.ZipFile(zip_path, 'r',
                                         compression=zipfile.ZIP_STORED
                                         ) as src_zip:
                        name_list = src_zip.namelist()
                    for name in name_list:
                        name_dir, name_base = os.path.split(
                            os.path.dirname(name))
                        # namelist() has one entry per file,
                        # so skip classes we have already seen
                        if name_dir == mode and name_base not in class_list:
                            class_list.append(name_base)
                case 'folder':
                    folder_path = os.path.join(src_path, mode)
                    class_list = [_dir for _dir in os.listdir(folder_path)
                                  if os.path.isdir(os.path.join(
                                      folder_path, _dir))]
            # sample the class subset once, after sorting for determinism
            class_list.sort()
            class_array = np.array(class_list)[idx_list]
            class_list = class_array.tolist()
            for class_name in class_list:
                src2dst_dict[class_name] = class_name
        else:
            src2dst_dict = {src_class: dst_class
                            for dst_class, src_list in class_dict.items()
                            for src_class in src_list}
        src_class_list = src2dst_dict.keys()
        print(src2dst_dict)
        if method == 'zip':
            for mode in mode_list:
                print('{purple}mode: {0}{reset}'.format(mode, **ansi))
                assert mode in ['train', 'valid', 'test']
                dst_zip_path = os.path.join(dst_path,
                                            f'{child_name}_{mode}_store.zip')
                dst_zip = zipfile.ZipFile(dst_zip_path, 'w',
                                          compression=zipfile.ZIP_STORED)
                src_zip_path = os.path.join(src_path,
                                            f'{self.name}_{mode}_store.zip')
                src_zip = zipfile.ZipFile(src_zip_path, 'r',
                                          compression=zipfile.ZIP_STORED)
                _list = src_zip.namelist()
                if env['tqdm']:
                    _list = tqdm(_list, leave=False)
                for filename in _list:
                    if filename[-1] == '/':
                        continue
                    dirname, basename = os.path.split(filename)
                    mode_check, src_class = os.path.split(dirname)
                    if mode_check == mode and src_class in src_class_list:
                        dst_class = src2dst_dict[src_class]
                        dst_zip.writestr(f'{mode}/{dst_class}/{basename}',
                                         src_zip.read(filename))
                src_zip.close()
                dst_zip.close()
        elif method == 'folder':
            len_i = len(class_dict.keys())
            for mode in mode_list:
                print('{purple}{0}{reset}'.format(mode, **ansi))
                assert mode in ['train', 'valid', 'test']
                for i, dst_class in enumerate(class_dict.keys()):
                    if not os.path.exists(_path := os.path.join(dst_path,
                                                                mode,
                                                                dst_class)):
                        os.makedirs(_path)
                    prints('{blue_light}{0}{reset}'.format(dst_class, **ansi),
                           indent=10)
                    class_list = class_dict[dst_class]
                    len_j = len(class_list)
                    for j, src_class in enumerate(class_list):
                        _list = os.listdir(os.path.join(src_path,
                                                        mode,
                                                        src_class))
                        prints(output_iter(i + 1, len_i),
                               output_iter(j + 1, len_j),
                               f'dst: {dst_class:15s}    '
                               f'src: {src_class:15s}    '
                               f'image_num: {len(_list):>8d}',
                               indent=10)
                        if env['tqdm']:
                            _list = tqdm(_list, leave=False)
                        for _file in _list:
                            src_file_path = os.path.join(src_path, mode,
                                                         src_class, _file)
                            dst_file_path = os.path.join(dst_path, mode,
                                                         dst_class, _file)
                            shutil.copyfile(src_file_path, dst_file_path)
                        if env['tqdm']:
                            print('{upline}{clear_line}'.format(**ansi))
Code Example #16
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               validate_interval: int = 1, save: bool = False,
               verbose: bool = True, **kwargs):
        if verbose:
            print('train mask generator')
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_()
        self.model.requires_grad_(False)
        self.train_mask_generator(verbose=verbose)
        if verbose:
            print()
            print('train mark generator and model')

        self.mark_generator.requires_grad_()
        self.mask_generator.requires_grad_(False)
        if not self.natural:
            params: list[nn.Parameter] = []
            for param_group in optimizer.param_groups:
                params.extend(param_group['params'])
            self.model.activate_params(params)

        mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mark_optimizer, T_max=epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, ce=None)
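        # every batch is split into three parts: trigger inputs (relabeled
        # to target_class), cross inputs (mark/mask generated from another
        # sample, label kept), and untouched clean inputs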

        best_validate_result = None  # guard against validate_interval == 0
        if validate_interval != 0:
            best_validate_result = self.validate_fn(verbose=verbose)
            best_asr = best_validate_result[0]
        for _epoch in range(epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            if not self.natural:
                self.model.train()
            self.mark_generator.train()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            for data in logger.log_every(loader, header=header) if verbose else loader:
                if not self.natural:
                    optimizer.zero_grad()
                mark_optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size
                final_input, final_label = _input.clone(), _label.clone()

                # generate trigger input
                trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
                trigger_int = int(trigger_int)
                if random.uniform(0, 1) < trigger_dec:
                    trigger_int += 1
                x = _input[:trigger_int]
                trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
                trigger_input = x + trigger_mask * (trigger_mark - x)
                final_input[:trigger_int] = trigger_input
                final_label[:trigger_int] = self.target_class

                # generate cross input
                cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
                cross_int = int(cross_int)
                if random.uniform(0, 1) < cross_dec:
                    cross_int += 1
                x = _input[trigger_int:trigger_int + cross_int]
                x2 = _input2[trigger_int:trigger_int + cross_int]
                cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
                cross_input = x + cross_mask * (cross_mark - x)
                final_input[trigger_int:trigger_int + cross_int] = cross_input

                # div loss
                if len(trigger_input) <= len(cross_input):
                    length = len(trigger_input)
                    cross_input = cross_input[:length]
                    cross_mark = cross_mark[:length]
                    cross_mask = cross_mask[:length]
                else:
                    length = len(cross_input)
                    trigger_input = trigger_input[:length]
                    trigger_mark = trigger_mark[:length]
                    trigger_mask = trigger_mask[:length]
                input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
                mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_ce = self.model.loss(final_input, final_label)
                loss_div = input_dist.div(mark_dist).mean()

                loss = loss_ce + self.lambda_div * loss_div
                loss.backward()
                if not self.natural:
                    optimizer.step()
                mark_optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), ce=loss_ce.item())
            if not self.natural and lr_scheduler:
                lr_scheduler.step()
            mark_scheduler.step()
            if not self.natural:
                self.model.eval()
            self.mark_generator.eval()
            if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
                validate_result = self.validate_fn(verbose=verbose)
                cur_asr = validate_result[0]
                if cur_asr >= best_asr:
                    best_validate_result = validate_result
                    best_asr = cur_asr
                    if save:
                        self.save()
        if not self.natural:
            optimizer.zero_grad()
        mark_optimizer.zero_grad()
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_(False)
        self.model.requires_grad_(False)
        return best_validate_result
Code Example #17
    def remask(self, label: int) -> tuple[float, torch.Tensor]:
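        # train a conditional generator that maps (noise, class label) to
        # an additive trigger pattern for the given target label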
        generator = Generator(self.noise_dim, self.dataset.num_classes,
                              self.dataset.data_shape)
        for param in generator.parameters():
            param.requires_grad_()
        optimizer = optim.Adam(generator.parameters(), lr=self.remask_lr)
        optimizer.zero_grad()
        # mask = self.attack.mark.mask

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')
        torch.manual_seed(env['seed'])
        noise = torch.rand(1, self.noise_dim, device=env['device'])
        mark = torch.zeros(self.dataset.data_shape, device=env['device'])
        for _epoch in range(self.remask_epoch):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.loader
            if env['tqdm']:
                loader = tqdm(loader)
            for data in loader:
                _input, _label = self.model.get_data(data)
                batch_size = _label.size(0)
                poison_label = label * torch.ones_like(_label)
                mark = generator(
                    noise,
                    torch.tensor([label],
                                 device=poison_label.device,
                                 dtype=poison_label.dtype))
                poison_input = (_input + mark).clamp(0, 1)
                _output = self.model(poison_input)

                batch_acc = poison_label.eq(_output.argmax(1)).float().mean()
                batch_entropy = self.model.criterion(_output, poison_label)
                batch_norm = mark.flatten(start_dim=1).norm(p=1, dim=1).mean()
                batch_loss = batch_entropy + self.gamma_2 * batch_norm

                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, self.remask_epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str,
                   _str,
                   prefix='{upline}{clear_line}'.format(
                       **ansi) if env['tqdm'] else '',
                   indent=4)

        def get_data_fn(data, **kwargs):
            _input, _label = self.model.get_data(data)
            poison_label = torch.ones_like(_label) * label
            poison_input = (_input + mark).clamp(0, 1)
            return poison_input, poison_label

        self.model._validate(print_prefix='Validate Trigger Tgt',
                             get_data_fn=get_data_fn,
                             indent=4)

        if not self.attack.mark.random_pos:
            overlap = jaccard_idx(mark.mean(dim=0),
                                  self.real_mask,
                                  select_num=self.attack.mark.mark_height *
                                  self.attack.mark.mark_width)
            print(f'    Jaccard index: {overlap:.3f}')

        for param in generator.parameters():
            param.requires_grad = False
        return losses.avg, mark
Code Example #18
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        loader_train = (loader_train if loader_train is not None
                        else self.dataset.loader['train'])
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
        #     iter_fn = getattr(self, 'iter_fn')
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False
        if amp:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)
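        # epoch_fn runs once per epoch with parameters frozen;
        # after_loss_fn runs after each backward pass (e.g. for
        # gradient post-processing) before the optimizer step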
        for _epoch in range(epoch):
            _epoch += 1
            if callable(epoch_fn):
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                _iter = (_epoch - 1) * len(loader_train) + i  # _epoch is 1-based
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                    optimizer.step()
                optimizer.zero_grad()
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache()  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss = logger.meters['loss'].global_avg
            acc = logger.meters['top1'].global_avg
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(
                                **ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()
Code Example #19
    def optimize_mark(self, label: int) -> tuple[torch.Tensor, float]:
        r"""
        Args:
            label (int): The class label to optimize.

        Returns:
            (torch.Tensor, float):
                Optimized mark tensor with shape ``(C + 1, H, W)``
                and the average loss.
        """
        epochs = self.defense_remask_epoch
        generator = Generator(self.noise_dim, self.dataset.num_classes,
                              self.dataset.data_shape)
        generator.requires_grad_()
        optimizer = optim.Adam(generator.parameters(),
                               lr=self.defense_remask_lr)
        optimizer.zero_grad()

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')

        noise = torch.rand(1, self.noise_dim, device=env['device'])
        mark = torch.zeros(self.dataset.data_shape, device=env['device'])
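        # a single fixed noise vector is drawn once and reused; only the
        # generator parameters are optimized across epochs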
        for _epoch in range(epochs):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.loader
            if env['tqdm']:
                loader = tqdm(loader, leave=False)
            for data in loader:
                _input, _label = self.model.get_data(data)
                mark: torch.Tensor = generator(
                    noise,
                    torch.tensor([label],
                                 device=_label.device,
                                 dtype=_label.dtype))
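                # stamp the generated pattern into the attack mark;
                # the last (alpha) channel stays fully opaque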
                self.attack.mark.mark = torch.ones_like(self.attack.mark.mark)
                self.attack.mark.mark[:-1] = mark.squeeze()
                # Or directly add and clamp according to their paper?
                trigger_input = self.attack.add_mark(_input)
                trigger_label = label * torch.ones_like(_label)
                trigger_output = self.model(trigger_input)

                batch_acc = trigger_label.eq(
                    trigger_output.argmax(1)).float().mean()
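                # objective: classification loss toward the target class plus
                # an L1 penalty on the mark, weighted by gamma_2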
                batch_entropy = self.loss(_input,
                                          _label,
                                          target=label,
                                          trigger_output=trigger_output)
                batch_norm = torch.mean(self.attack.mark.mark[:-1].norm(p=1))
                batch_loss = batch_entropy + self.gamma_2 * batch_norm

                batch_size = _label.size(0)
                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str: str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epochs), **ansi)
            pre_str = pre_str.ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str, indent=4)
        generator.requires_grad_(False)
        return self.attack.mark.mark, losses.avg
Code example #20
    def sample(self,
               child_name: str = None,
               class_dict: dict[str, list[str]] = None,
               sample_num: int = None,
               method='zip'):
        if sample_num is None:
            assert class_dict
            sample_num = len(class_dict)
        if child_name is None:
            child_name = self.name + '_sample%d' % sample_num
        src_path = self.folder_path
        dst_path = os.path.normpath(
            os.path.join(os.path.dirname(self.folder_path), child_name))
        if not os.path.exists(dst_path):
            os.makedirs(dst_path)
        print('{yellow}src path{reset}: '.format(**ansi), src_path)
        print('{yellow}dst path{reset}: '.format(**ansi), dst_path)

        mode_list = [
            mode for mode in ['train', 'valid', 'test']
            if os.path.isdir(os.path.join(src_path, mode))
        ]
        if method == 'zip':
            zip_path_list: list[str] = glob.glob(
                os.path.join(src_path, '*_store.zip'))
            mode_list = [
                os.path.basename(zip_path).removeprefix(
                    self.name + '_').removesuffix('_store.zip')
                for zip_path in zip_path_list
            ]

        src2dst_dict: dict[str, str] = {}
        if class_dict is None:
            assert sample_num
            idx_list = np.arange(self.num_classes)
            np.random.seed(env['seed'])
            np.random.shuffle(idx_list)
            idx_list = idx_list[:sample_num]
            mode = mode_list[0]
            class_set: set[str] = set()
            if method == 'zip':
                zip_path = os.path.join(src_path,
                                        f'{self.name}_{mode}_store.zip')
                with zipfile.ZipFile(
                        zip_path, 'r',
                        compression=zipfile.ZIP_STORED) as src_zip:
                    name_list = src_zip.namelist()
                for name in name_list:
                    name_dir, name_base = os.path.split(os.path.dirname(name))
                    if name_dir == mode:
                        class_set.add(name_base)
            elif method == 'folder':
                folder_path = os.path.join(src_path, f'{mode}')
                class_set = {
                    _dir for _dir in os.listdir(folder_path)
                    if os.path.isdir(os.path.join(folder_path, _dir))
                }
            # deduplicate and sort before sampling, so idx_list is applied
            # exactly once to a stable ordering
            class_list = sorted(class_set)
            class_list = np.array(class_list)[idx_list].tolist()
            for class_name in class_list:
                src2dst_dict[class_name] = class_name
        else:
            src2dst_dict = {
                src_class: dst_class
                for src_class, dst_list in class_dict.items()
                for dst_class in dst_list
            }
        src_class_list = src2dst_dict.keys()
        print(src2dst_dict)
        if method == 'zip':
            for mode in mode_list:
                print('{purple}mode: {0}{reset}'.format(mode, **ansi))
                assert mode in ['train', 'valid', 'test']
                dst_zip = zipfile.ZipFile(os.path.join(
                    dst_path, f'{child_name}_{mode}_store.zip'),
                                          'w',
                                          compression=zipfile.ZIP_STORED)
                src_zip = zipfile.ZipFile(os.path.join(
                    src_path, f'{self.name}_{mode}_store.zip'),
                                          'r',
                                          compression=zipfile.ZIP_STORED)
                _list = src_zip.namelist()
                if env['tqdm']:
                    _list = tqdm(_list)
                for filename in _list:
                    if filename[-1] == '/':
                        continue
                    dirname, basename = os.path.split(filename)
                    mode_check, src_class = os.path.split(dirname)
                    if mode_check == mode and src_class in src_class_list:
                        print(filename)
                        dst_class = src2dst_dict[src_class]
                        dst_zip.writestr(f'{mode}/{dst_class}/{basename}',
                                         src_zip.read(filename))
                src_zip.close()
                dst_zip.close()
        elif method == 'folder':
            len_i = len(class_dict.keys())
            for mode in mode_list:
                print('{purple}{0}{reset}'.format(mode, **ansi))
                assert mode in ['train', 'valid', 'test']
                for i, dst_class in enumerate(class_dict.keys()):
                    if not os.path.exists(
                            _path := os.path.join(dst_path, mode, dst_class)):
                        os.makedirs(_path)
                    prints('{blue_light}{0}{reset}'.format(dst_class, **ansi),
                           indent=10)
                    class_list = class_dict[dst_class]
                    len_j = len(class_list)
                    for j, src_class in enumerate(class_list):
                        _list = os.listdir(
                            os.path.join(src_path, mode, src_class))
                        prints(
                            output_iter(i + 1, len_i) +
                            output_iter(j + 1, len_j) +
                            f'dst: {dst_class:15s}    src: {src_class:15s}    image_num: {len(_list):>8d}',
                            indent=10)
                        if env['tqdm']:
                            _list = tqdm(_list)
                        for _file in _list:
                            shutil.copyfile(
                                os.path.join(src_path, mode, src_class, _file),
                                os.path.join(dst_path, mode, dst_class, _file))
                        if env['tqdm']:
                            print('{upline}{clear_line}'.format(**ansi))
Code example #21
    def remask(self, label: int) -> tuple[torch.Tensor, torch.Tensor, float]:
        epoch = self.epoch
        # no bound
        atanh_mark = torch.randn(self.data_shape, device=env['device'])
        atanh_mark.requires_grad_()
        atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
        atanh_mask.requires_grad_()
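        # optimize in atanh-space so mask and mark stay inside the valid
        # value range after the tanh squashing below, without explicit clamping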
        mask = tanh_func(atanh_mask)  # (h, w)
        mark = tanh_func(atanh_mark)  # (c, h, w)

        optimizer = optim.Adam([atanh_mark, atanh_mask],
                               lr=0.1,
                               betas=(0.5, 0.9))
        optimizer.zero_grad()

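        # `cost` weights the mask-norm regularizer; it is adapted during
        # training (raised when the attack succeeds, lowered when it fails)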
        cost = self.init_cost
        cost_set_counter = 0
        cost_up_counter = 0
        cost_down_counter = 0
        cost_up_flag = False
        cost_down_flag = False

        # best optimization results
        norm_best = float('inf')
        mask_best = None
        mark_best = None
        entropy_best = None

        # counter for early stop
        early_stop_counter = 0
        early_stop_norm_best = norm_best

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')

        for _epoch in range(epoch):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.dataset.loader['train']
            if env['tqdm']:
                loader = tqdm(loader)
            for data in loader:
                _input, _label = self.model.get_data(data)
                batch_size = _label.size(0)
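                # blend trigger into input: X = (1 - mask) * input + mask * mark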
                X = _input + mask * (mark - _input)
                Y = label * torch.ones_like(_label, dtype=torch.long)
                _output = self.model(X)

                batch_acc = Y.eq(_output.argmax(1)).float().mean()
                batch_entropy = self.loss_fn(_input, _label, Y, mask, mark,
                                             label)
                batch_norm = mask.norm(p=1)
                batch_loss = batch_entropy + cost * batch_norm

                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                mask = tanh_func(atanh_mask)  # (h, w)
                mark = tanh_func(atanh_mark)  # (c, h, w)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str,
                   _str,
                   prefix='{upline}{clear_line}'.format(
                       **ansi) if env['tqdm'] else '',
                   indent=4)

            # check to save best mask or not
            if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
                mask_best = mask.detach()
                mark_best = mark.detach()
                norm_best = norm.avg
                entropy_best = entropy.avg

            # check early stop
            if self.early_stop:
                # only terminate if a valid attack has been found
                if norm_best < float('inf'):
                    if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                        early_stop_counter += 1
                    else:
                        early_stop_counter = 0
                early_stop_norm_best = min(norm_best, early_stop_norm_best)

                if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                    print('early stop')
                    break

            # check cost modification
            if cost == 0 and acc.avg >= self.attack_succ_threshold:
                cost_set_counter += 1
                if cost_set_counter >= self.patience:
                    cost = self.init_cost
                    cost_up_counter = 0
                    cost_down_counter = 0
                    cost_up_flag = False
                    cost_down_flag = False
                    print('initialize cost to %.2f' % cost)
            else:
                cost_set_counter = 0

            if acc.avg >= self.attack_succ_threshold:
                cost_up_counter += 1
                cost_down_counter = 0
            else:
                cost_up_counter = 0
                cost_down_counter += 1

            if cost_up_counter >= self.patience:
                cost_up_counter = 0
                prints('up cost from %.4f to %.4f' %
                       (cost, cost * self.cost_multiplier_up),
                       indent=4)
                cost *= self.cost_multiplier_up
                cost_up_flag = True
            elif cost_down_counter >= self.patience:
                cost_down_counter = 0
                prints('down cost from %.4f to %.4f' %
                       (cost, cost / self.cost_multiplier_down),
                       indent=4)
                cost /= self.cost_multiplier_down
                cost_down_flag = True
            if mask_best is None:
                mask_best = tanh_func(atanh_mask).detach()
                mark_best = tanh_func(atanh_mark).detach()
                norm_best = norm.avg
                entropy_best = entropy.avg
        atanh_mark.requires_grad = False
        atanh_mask.requires_grad = False

        self.attack.mark.mark = mark_best
        self.attack.mark.alpha_mark = mask_best
        self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
        self.attack.validate_fn()
        return mark_best, mask_best, entropy_best
Code example #22
    def attack(self, epochs: int, save: bool = False, **kwargs):
        # train generator
        match self.generator_mode:
            case 'resnet':
                resnet_model = torchvision.models.resnet18(pretrained=True).to(device=env['device'])
                model_extractor = nn.Sequential(*list(resnet_model.children())[:5])
            case _:
                resnet_model = trojanvision.models.create('resnet18_comp',
                                                          dataset=self.dataset,
                                                          pretrained=True)
                model_extractor = nn.Sequential(*list(resnet_model._model.features.children())[:4])
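        # freeze the extractor weights; train() is kept, so BatchNorm layers
        # use batch statistics rather than running estimates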
        model_extractor.requires_grad_(False)
        model_extractor.train()

        self.generator.train()
        if self.generator_mode == 'default':
            self.generator.requires_grad_()
            parameters = self.generator.parameters()
        else:
            self.generator.bottleneck.requires_grad_()
            self.generator.decoder.requires_grad_()
            parameters = self.generator[1:].parameters()
        optimizer = torch.optim.Adam(parameters, lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_generator_epochs,
            eta_min=1e-5)

        logger = MetricLogger()
        logger.create_meters(loss=None, acc=None)
        for _epoch in range(self.train_generator_epochs):
            _epoch += 1
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, self.train_generator_epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader = logger.log_every(self.dataset.loader['train'], header=header,
                                      tqdm_header='Batch')
            for data in loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
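                # rescale generator output from [-1, 1] (presumably tanh) to
                # the [0, 1] image range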
                adv_input = (self.generator(_input) + 1) / 2

                _feats = model_extractor(_input)
                adv_feats = model_extractor(adv_input)

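                # feature-space objective: make the scaled adversarial features
                # match the clean features (noise_coeff controls the scale)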
                loss = F.l1_loss(_feats, self.noise_coeff * adv_feats)
                loss.backward()
                optimizer.step()
                batch_size = len(_label)
                with torch.no_grad():
                    org_class = self.model.get_class(_input)
                    adv_class = self.model.get_class(adv_input)
                    acc = (org_class == adv_class).float().sum().item() * 100.0 / batch_size
                logger.update(n=batch_size, loss=loss.item(), acc=acc)
            lr_scheduler.step()
        self.save_generator()
        optimizer.zero_grad()
        self.generator.eval()
        self.generator.requires_grad_(False)

        self.mark.mark[:-1] = (self.generator(self.mark.mark[:-1].unsqueeze(0))[0] + 1) / 2
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs, save=save, **kwargs)
Code example #23
    def sample(self,
               child_name: str = None,
               class_dict: dict = None,
               sample_num: int = None,
               verbose=True):
        if sample_num is None:
            assert class_dict
            sample_num = len(class_dict)
        if child_name is None:
            child_name = self.name + '_sample%d' % sample_num
        src_path = self.folder_path + self.name + '/'
        mode_list = [
            _dir for _dir in os.listdir(src_path)
            if os.path.isdir(src_path + _dir) and _dir[0] != '.'
        ]
        dst_path = env['data_dir'] + self.data_type + \
            '/{0}/data/{0}/'.format(child_name)
        if verbose:
            print('src path: ', src_path)
            print('dst path: ', dst_path)
        if class_dict is None:
            assert sample_num
            idx_list = np.arange(self.num_classes)
            np.random.seed(env['seed'])
            np.random.shuffle(idx_list)
            idx_list = idx_list[:sample_num]
            class_list = np.array(
                sorted(os.listdir(src_path + mode_list[0])))[idx_list]
            class_dict = {}
            for class_name in class_list:
                class_dict[class_name] = [class_name]
        if verbose:
            print(class_dict)

        len_i = len(class_dict.keys())
        for src_mode in mode_list:
            if verbose:
                print(src_mode)
            assert src_mode in ['train', 'valid', 'test', 'val']
            dst_mode = 'valid' if src_mode == 'val' else src_mode
            for i, dst_class in enumerate(class_dict.keys()):
                if not os.path.exists(dst_path + dst_mode + '/' + dst_class):
                    os.makedirs(dst_path + dst_mode + '/' + dst_class)
                prints(dst_class, indent=10)
                class_list = class_dict[dst_class]
                len_j = len(class_list)
                for j, src_class in enumerate(class_list):
                    _list = os.listdir(src_path + src_mode + '/' + src_class)
                    prints(
                        output_iter(i + 1, len_i) + output_iter(j + 1, len_j) +
                        f'dst: {dst_class:15s}    src: {src_class:15s}    image_num: {len(_list):>8d}',
                        indent=10)
                    if env['tqdm']:
                        _list = tqdm(_list)
                    for _file in _list:
                        shutil.copyfile(
                            src_path + src_mode + '/' + src_class + '/' +
                            _file, dst_path + dst_mode + '/' + dst_class +
                            '/' + _file)
                    if env['tqdm']:
                        print('{upline}{clear_line}'.format(**ansi), end='')
Code example #24
    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0, epoch_fn: Callable = None,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if callable(epoch_fn):
                self.model.activate_params([])
                epoch_fn()
                self.model.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
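                # running PGD perturbation, refined one step per inner iteration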
                noise = torch.zeros_like(_input)

                poison_input, poison_label = self.get_poison_data(data)

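                # negate the loss so the PGD step maximizes classification
                # error on the clean labels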
                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)

                    optimizer.zero_grad()
                    self.model.train()

                    x = torch.cat((adv_x, poison_input))
                    y = torch.cat((_label, poison_label))
                    loss = self.model.loss(x, y)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            self.save()
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
Code example #25
File: train.py    Project: ain-soph/trojanzoo
def train(module: nn.Module,
          num_classes: int,
          epochs: int,
          optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train',
          start_epoch: int = 0,
          resume: int = 0,
          validate_interval: int = 10,
          save: bool = False,
          amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None,
          file_path: str = None,
          folder_path: str = None,
          suffix: str = None,
          writer=None,
          main_tag: str = 'train',
          tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True,
          output_freq: str = 'iter',
          indent: int = 0,
          change_train_eval: bool = True,
          lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model"""
    if epochs <= 0:
        return
    get_data_fn = get_data_fn or (lambda data, **kwargs: data)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs:
                          F.cross_entropy(_output if _output is not None
                                          else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

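    # automatic mixed precision is only enabled when a GPU is available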
    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None,
                                           tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose,
                                           indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=header,
                                    tqdm_header='Epoch',
                                    indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     _epoch=_epoch,
                     epochs=epochs,
                     start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train,
                                            header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
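            # KFAC/EKFAC pre-conditioning tracks curvature statistics during
            # this forward/backward pass (disabled again after the step below)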
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()

            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)

            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose,
                                          indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f}    '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path,
                            folder_path=folder_path,
                            suffix=suffix,
                            verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
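A minimal usage sketch of the train() helper above (hypothetical wiring, not part of the project): it assumes CIFAR-10 with a plain torchvision ResNet-18, and supplies a forward_fn wrapper because a vanilla nn.Module does not accept the amp/parallel keywords the default forward path passes. validate_fn, accuracy_fn, and the env/logging globals fall back to the defaults defined in this file.

import torch
import torchvision
import torchvision.transforms as T

model = torchvision.models.resnet18(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

transform = T.ToTensor()
train_set = torchvision.datasets.CIFAR10('data', train=True,
                                         download=True, transform=transform)
valid_set = torchvision.datasets.CIFAR10('data', train=False,
                                         download=True, transform=transform)
loader_train = torch.utils.data.DataLoader(train_set, batch_size=128,
                                           shuffle=True)
loader_valid = torch.utils.data.DataLoader(valid_set, batch_size=128)

# wrapper swallows the amp/parallel keywords meant for trojanzoo models
result = train(model, num_classes=10, epochs=100,
               optimizer=optimizer, lr_scheduler=scheduler,
               loader_train=loader_train, loader_valid=loader_valid,
               forward_fn=lambda _input, **kwargs: model(_input),
               save=False, verbose=True)  # (best_acc, best_loss)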