Example #1
    def discrim_train(self, epoch: int, D: nn.Sequential, discrim_loader: torch.utils.data.DataLoader):
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        d_optimizer = optim.Adam(D.parameters(), lr=self.discrim_lr)
        d_optimizer.zero_grad()
        for _epoch in range(epoch):
            losses.reset()
            top1.reset()
            self.model.activate_params([D.parameters()])
            D.train()
            for data in discrim_loader:
                # train D
                _input, _label = self.model.get_data(data)
                out_f = self.model.get_final_fm(_input).detach()  # detach: only D receives gradients here
                out_d = D(out_f)
                loss_d = self.model.criterion(out_d, _label)

                acc1 = self.model.accuracy(out_d, _label, topk=(1, ))[0]
                batch_size = int(_label.size(0))
                losses.update(loss_d.item(), batch_size)
                top1.update(acc1, batch_size)

                loss_d.backward()
                d_optimizer.step()
                d_optimizer.zero_grad()
            print(f'Discriminator - epoch {_epoch:4d} / {epoch:4d} | loss {losses.avg:.4f} | acc {top1.avg:.4f}')
            self.model.activate_params([])
            D.eval()
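For reference: AverageMeter is used throughout these snippets but never defined. A minimal sketch, assuming it matches the running-average helper popularized by the official PyTorch ImageNet example (the name and format string are display-only):

    class AverageMeter:
        """Tracks a running sum and count and exposes the current average."""

        def __init__(self, name: str, fmt: str = ':f'):
            self.name, self.fmt = name, fmt
            self.reset()

        def reset(self):
            self.val = self.avg = self.sum = 0.0
            self.count = 0

        def update(self, val: float, n: int = 1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count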
Example #2
    def joint_train(self,
                    epoch: int = 0,
                    optimizer: optim.Optimizer = None,
                    lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                    poison_loader=None,
                    discrim_loader=None,
                    save=False,
                    **kwargs):
        in_dim = self.model._model.classifier[0].in_features
        D = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(in_dim, 256)),
            ('bn1', nn.BatchNorm1d(256)),
            ('relu1', nn.LeakyReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('bn2', nn.BatchNorm1d(128)),
            ('relu2', nn.ReLU()),
            ('fc3', nn.Linear(128, 2)),
        ]))
        if env['num_gpus']:
            D.cuda()
        optim_params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            optim_params.extend(param_group['params'])
        optimizer.zero_grad()

        best_acc = 0.0
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')

        for _epoch in range(epoch):
            self.discrim_train(epoch=100, D=D, discrim_loader=discrim_loader)

            self.model.train()
            self.model.activate_params(optim_params)
            for data in poison_loader:
                optimizer.zero_grad()
                _input, _label_f, _label_d = self.bypass_get_data(data)  # task labels and poison/clean domain labels
                out_f = self.model(_input)
                loss_f = self.model.criterion(out_f, _label_f)
                out_d = D(self.model.get_final_fm(_input))
                loss_d = self.model.criterion(out_d, _label_d)

                loss = loss_f - self.lambd * loss_d  # fit the task while pushing D's loss up
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            if lr_scheduler:
                lr_scheduler.step()
            self.model.activate_params([])
            self.model.eval()
            _, cur_acc = self.validate_fn(get_data_fn=self.bypass_get_data)
            if cur_acc >= best_acc:
                prints('best result update!', indent=0)
                prints(
                    f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                    indent=0)
                best_acc = cur_acc
                if save:
                    self.save()
            print('-' * 50)
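The pivotal line is loss = loss_f - self.lambd * loss_d: the network is updated to keep the task loss low while raising the discriminator's loss, so poisoned and clean feature representations become indistinguishable. A self-contained sketch of that objective on toy modules (every name below is illustrative, not from the library):

    import torch
    import torch.nn as nn

    features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # stand-in backbone
    classifier = nn.Linear(16, 10)                         # task head
    D = nn.Linear(16, 2)                                   # poison-vs-clean discriminator
    criterion = nn.CrossEntropyLoss()
    lambd = 1.0

    x = torch.randn(4, 8)
    y_task = torch.randint(0, 10, (4,))
    y_domain = torch.randint(0, 2, (4,))  # 1 = poisoned, 0 = clean

    f = features(x)
    loss_f = criterion(classifier(f), y_task)  # fit the task
    loss_d = criterion(D(f), y_domain)         # how well D separates the features
    loss = loss_f - lambd * loss_d             # descend on the task, ascend on D
    loss.backward()  # D itself is trained separately, as in discrim_train above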
Example #3
    def optimize_mark(self, loss_fn=None, **kwargs):
        atanh_mark = torch.randn_like(self.mark.mark) * self.mark.mask
        atanh_mark.requires_grad_()
        self.mark.mark = tanh_func(atanh_mark)  # optimize in atanh space; tanh maps back to valid pixels
        optimizer = optim.Adam([atanh_mark], lr=self.inner_lr)
        optimizer.zero_grad()

        if loss_fn is None:
            loss_fn = self.model.loss

        losses = AverageMeter('Loss', ':.4e')
        for _epoch in range(self.inner_iter):
            for i, data in enumerate(self.dataset.loader['train']):
                if i > 20:
                    break
                _input, _label = self.model.get_data(data)
                poison_x = self.mark.add_mark(_input)
                loss = loss_fn(poison_x,
                               self.target_class * torch.ones_like(_label))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                self.mark.mark = tanh_func(atanh_mark)
                losses.update(loss.item(), n=len(_label))
        atanh_mark.requires_grad = False
        self.mark.mark.detach_()
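tanh_func is also external. It is presumably the usual change of variables that squashes an unconstrained tensor into [0, 1], which lets Adam optimize the trigger freely without any projection or clamping step. A minimal sketch:

    import torch

    def tanh_func(x: torch.Tensor) -> torch.Tensor:
        # smooth bijection from R to (0, 1); gradients flow through it
        return (x.tanh() + 1) * 0.5

    atanh_mark = torch.randn(3, 32, 32, requires_grad=True)
    optimizer = torch.optim.Adam([atanh_mark], lr=0.1)
    mark = tanh_func(atanh_mark)  # always a valid image
    loss = mark.mean()            # stand-in objective
    loss.backward()
    optimizer.step()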
Example #4
    def preprocess_mark(self, data: dict[str, tuple[torch.Tensor,
                                                    torch.Tensor]]):
        other_x, _ = data['other']
        other_set = TensorDataset(other_x)
        other_loader = self.dataset.get_dataloader(mode='train',
                                                   dataset=other_set,
                                                   num_workers=0)

        atanh_mark = torch.randn_like(self.mark.mark) * self.mark.mask
        atanh_mark.requires_grad_()
        self.mark.mark = tanh_func(atanh_mark)
        optimizer = optim.Adam([atanh_mark], lr=self.preprocess_lr)
        optimizer.zero_grad()

        losses = AverageMeter('Loss', ':.4e')
        for _epoch in range(self.preprocess_epoch):
            loader = other_loader
            for (batch_x, ) in loader:
                poison_x = self.mark.add_mark(to_tensor(batch_x))
                loss = self.loss_mse(poison_x)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                self.mark.mark = tanh_func(atanh_mark)
                losses.update(loss.item(), n=len(batch_x))
        atanh_mark.requires_grad = False
        self.mark.mark.detach_()
Example #5
 def confirm_backdoor(self):
     top1 = AverageMeter('Acc@1', ':6.2f')
     for data in self.dataset.loader['valid']:
         _input, _ = self.model.get_data(data, mode='valid')
         poison_input = self.attack.add_mark(_input)
         with torch.no_grad():
             _class = self.model.get_class(_input)
             poison_class = self.model.get_class(poison_input)
         result = ~(_class.eq(poison_class))  # predictions that flip under the trigger
         acc1 = result.float().mean() * 100
         top1.update(acc1.item(), len(_input))
     return top1.avg
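Note that this measures the flip rate against the model's own clean predictions rather than agreement with a fixed target class, so flips to any label count as success.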
Example #6
 def _validate(self, full=True, print_prefix='Validate', indent=0, verbose=True,
               loader: torch.utils.data.DataLoader = None,
               get_data_fn: Callable = None, loss_fn: Callable[..., float] = None, **kwargs) -> tuple[float, ...]:
     self.eval()
     if loader is None:
         loader = self.dataset.loader['valid'] if full else self.dataset.loader['valid2']
     get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
     loss_fn = loss_fn if loss_fn is not None else self.loss
     losses = AverageMeter('Loss', ':.4e')
     top1 = AverageMeter('Acc@1', ':6.2f')
     top5 = AverageMeter('Acc@5', ':6.2f')
     epoch_start = time.perf_counter()
     if verbose and env['tqdm']:
         loader = tqdm(loader)
     for data in loader:
         _input, _label = get_data_fn(data, mode='valid', **kwargs)
         with torch.no_grad():
             loss = loss_fn(_input, _label)
             _output = self.get_logits(_input)
         # measure accuracy and record loss
         batch_size = int(_label.size(0))
         losses.update(loss.item(), batch_size)
         acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
         top1.update(acc1, batch_size)
         top5.update(acc5, batch_size)
     epoch_time = str(datetime.timedelta(seconds=int(
         time.perf_counter() - epoch_start)))
     if verbose:
         pre_str = '{yellow}{0}:{reset}'.format(print_prefix, **ansi).ljust(35)
         _str = ' '.join([
             f'Loss: {losses.avg:.4f},'.ljust(20),
             f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
             f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
             f'Time: {epoch_time},'.ljust(20),
         ])
         prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '', indent=indent)
     return losses.avg, top1.avg, top5.avg
Example #7
 def validate_confidence(self) -> float:
     confidence = AverageMeter('Confidence', ':.4e')
     with torch.no_grad():
         for data in self.dataset.loader['valid']:
             _input, _label = self.model.get_data(data)
             idx1 = _label != self.target_class
             _input = _input[idx1]
             _label = _label[idx1]
             if len(_input) == 0:
                 continue
             poison_input = self.add_mark(_input)
             poison_label = self.model.get_class(poison_input)
             idx2 = poison_label == self.target_class
             poison_input = poison_input[idx2]
             if len(poison_input) == 0:
                 continue
             batch_conf = self.model.get_prob(
                 poison_input)[:, self.target_class].mean().item()
             confidence.update(batch_conf, len(poison_input))
     return float(confidence.avg)
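Only samples that start outside the target class and are actually flipped to it contribute here, so the return value is the mean target-class probability conditioned on a successful attack.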
Example #8
    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0, epoch_fn: Callable = None,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = self.folder_path + self.get_filename() + '.pth'

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if callable(epoch_fn):
                self.model.activate_params([])
                epoch_fn()
                self.model.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                poison_input, poison_label = self.get_poison_data(data)

                def loss_fn(X: torch.FloatTensor):
                    # negated so that minimizing it drives the classification loss up
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for _ in range(self.pgd.iteration):  # alternate one PGD step with one training step
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)

                    optimizer.zero_grad()
                    self.model.train()

                    x = torch.cat((adv_x, poison_input))
                    y = torch.cat((_label, poison_label))
                    loss = self.model.loss(x, y)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc < best_acc:  # note: unlike plain training, a *lower* value counts as best here
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.save()
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
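self.pgd.optimize(...) is an external helper; below is a minimal sketch of the single L-infinity PGD step it presumably performs each iteration (the eps and alpha values are illustrative):

    import torch
    import torch.nn.functional as F

    def pgd_step(model, x, y, noise, eps=8 / 255, alpha=2 / 255):
        """One projected-gradient step that increases the loss of (x + noise)."""
        adv = (x + noise).detach().requires_grad_(True)
        loss = F.cross_entropy(model(adv), y)
        grad = torch.autograd.grad(loss, adv)[0]
        noise = (noise + alpha * grad.sign()).clamp(-eps, eps)  # ascend, then project onto the eps-ball
        adv_x = (x + noise).clamp(0, 1)                         # keep a valid image
        return adv_x, noise.detach()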
Example #9
    def remask(self, label: int):
        epoch = self.epoch
        # no bound
        atanh_mark = torch.randn(self.data_shape, device=env['device'])
        atanh_mark.requires_grad_()
        atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
        atanh_mask.requires_grad_()
        mask = tanh_func(atanh_mask)  # (h, w)
        mark = tanh_func(atanh_mark)  # (c, h, w)

        optimizer = optim.Adam([atanh_mark, atanh_mask],
                               lr=0.1,
                               betas=(0.5, 0.9))
        optimizer.zero_grad()

        cost = self.init_cost
        cost_set_counter = 0
        cost_up_counter = 0
        cost_down_counter = 0
        cost_up_flag = False
        cost_down_flag = False

        # best optimization results
        norm_best = float('inf')
        mask_best = None
        mark_best = None
        entropy_best = None

        # counter for early stop
        early_stop_counter = 0
        early_stop_norm_best = norm_best

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')

        for _epoch in range(epoch):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.dataset.loader['train']
            if env['tqdm']:
                loader = tqdm(loader)
            for data in loader:
                _input, _label = self.model.get_data(data)
                batch_size = _label.size(0)
                X = _input + mask * (mark - _input)  # blend the trigger into the image pixel-wise
                Y = label * torch.ones_like(_label, dtype=torch.long)  # force the target label
                _output = self.model(X)

                batch_acc = Y.eq(_output.argmax(1)).float().mean()
                batch_entropy = self.loss_fn(_input, _label, Y, mask, mark,
                                             label)
                batch_norm = mask.norm(p=1)
                batch_loss = batch_entropy + cost * batch_norm  # cross-entropy plus weighted L1 mask penalty

                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                mask = tanh_func(atanh_mask)  # (h, w)
                mark = tanh_func(atanh_mark)  # (c, h, w)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str,
                   _str,
                   prefix='{upline}{clear_line}'.format(
                       **ansi) if env['tqdm'] else '',
                   indent=4)

            # check to save best mask or not
            if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
                mask_best = mask.detach()
                mark_best = mark.detach()
                norm_best = norm.avg
                entropy_best = entropy.avg

            # check early stop
            if self.early_stop:
                # only terminate if a valid attack has been found
                if norm_best < float('inf'):
                    if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                        early_stop_counter += 1
                    else:
                        early_stop_counter = 0
                early_stop_norm_best = min(norm_best, early_stop_norm_best)

                if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                    print('early stop')
                    break

            # check cost modification
            if cost == 0 and acc.avg >= self.attack_succ_threshold:
                cost_set_counter += 1
                if cost_set_counter >= self.patience:
                    cost = self.init_cost
                    cost_up_counter = 0
                    cost_down_counter = 0
                    cost_up_flag = False
                    cost_down_flag = False
                    print('initialize cost to %.2f' % cost)
            else:
                cost_set_counter = 0

            if acc.avg >= self.attack_succ_threshold:
                cost_up_counter += 1
                cost_down_counter = 0
            else:
                cost_up_counter = 0
                cost_down_counter += 1

            if cost_up_counter >= self.patience:
                cost_up_counter = 0
                prints('up cost from %.4f to %.4f' %
                       (cost, cost * self.cost_multiplier_up),
                       indent=4)
                cost *= self.cost_multiplier_up
                cost_up_flag = True
            elif cost_down_counter >= self.patience:
                cost_down_counter = 0
                prints('down cost from %.4f to %.4f' %
                       (cost, cost / self.cost_multiplier_down),
                       indent=4)
                cost /= self.cost_multiplier_down
                cost_down_flag = True
            if mask_best is None:
                mask_best = tanh_func(atanh_mask).detach()
                mark_best = tanh_func(atanh_mark).detach()
                norm_best = norm.avg
                entropy_best = entropy.avg
        atanh_mark.requires_grad = False
        atanh_mask.requires_grad = False

        self.attack.mark.mark = mark_best
        self.attack.mark.alpha_mask = mask
        self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
        self.attack.validate_fn()
        return mark_best, mask_best, entropy_best
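Stripped of bookkeeping, the cost schedule above is a small state machine: cost weights the L1 mask penalty, goes up after patience consecutive epochs in which attack accuracy stays above the threshold (shrink the mask harder), and comes down after patience consecutive failures (let the attack recover). A condensed sketch of that logic (multiplier values are illustrative):

    def update_cost(cost, acc, threshold, patience, state, up=1.5, down=1.5):
        """state is a two-element list [up_counter, down_counter]; returns the adjusted cost."""
        if acc >= threshold:
            state[0] += 1
            state[1] = 0
        else:
            state[0] = 0
            state[1] += 1
        if state[0] >= patience:
            state[0] = 0
            cost *= up    # attack succeeds easily: penalize mask size more
        elif state[1] >= patience:
            state[1] = 0
            cost /= down  # attack keeps failing: relax the penalty
        return cost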
Example #10
 def get_potential_triggers(self,
                            neuron_dict: dict[int, list[dict]],
                            _input: torch.Tensor,
                            _label: torch.Tensor,
                            use_mask=True) -> dict[int, list[dict]]:
     losses = AverageMeter('Loss', ':.4e')
     norms = AverageMeter('Norm', ':6.2f')
     jaccard = AverageMeter('Jaccard Idx', ':6.2f')
     score_list = [0.0] * len(neuron_dict)
     result_dict = {}
     for label, label_list in neuron_dict.items():
         print('label: ', label)
         best_score = 100.0
         for _dict in reversed(label_list):
             layer = _dict['layer']
             neuron = _dict['neuron']
             value = _dict['value']
             # color = ('{red}' if label == self.attack.target_class else '{green}').format(**ansi)
             # _str = f'layer: {layer:<20} neuron: {neuron:<5d} label: {label:<5d}'
             # prints('{color}{_str}{reset}'.format(color=color, _str=_str, **ansi), indent=4)
             mark, mask, loss = self.remask(_input,
                                            layer=layer,
                                            neuron=neuron,
                                            label=label,
                                            use_mask=use_mask)
             self.attack.mark.mark = mark
             self.attack.mark.alpha_mask = mask
             self.attack.mark.mask = torch.ones_like(mark, dtype=torch.bool)
             self.attack.target_class = label
             attack_loss, attack_acc = self.model._validate(
                 verbose=False,
                 get_data_fn=self.attack.get_data,
                 keep_org=False)
             _dict['loss'] = loss
             _dict['attack_acc'] = attack_acc
             _dict['attack_loss'] = attack_loss
             _dict['mask'] = to_numpy(mask)
             _dict['mark'] = to_numpy(mark)
             _dict['norm'] = float(mask.norm(p=1))
             score = attack_loss + 7e-2 * float(mask.norm(p=1))
             if score < best_score:
                 best_score = score
                 result_dict[label] = _dict
             if attack_acc > 90:
                 losses.update(loss)
                 norms.update(mask.norm(p=1))
             _str = f'    layer: {layer:20s}    neuron: {neuron:5d}    value: {value:.3f}'
             _str += f'    loss: {loss:10.3f}'
             _str += f'    ATK Acc: {attack_acc:.3f}'
             _str += f'    ATK Loss: {attack_loss:10.3f}'
             _str += f'    Norm: {mask.norm(p=1):.3f}'
             _str += f'    Score: {score:.3f}'
             if not self.attack.mark.random_pos:
                 overlap = jaccard_idx(mask, self.real_mask)
                 _dict['jaccard'] = overlap
                 _str += f'    Jaccard: {overlap:.3f}'
                 if attack_acc > 90:
                     jaccard.update(overlap)
             else:
                 _dict['jaccard'] = 0.0
             print(_str)
             if not os.path.exists(self.folder_path):
                 os.makedirs(self.folder_path)
             np.save(
                 self.folder_path +
                 self.get_filename(target_class=self.target_class) + '.npy',
                 neuron_dict)
             np.save(
                 self.folder_path +
                 self.get_filename(target_class=self.target_class) +
                 '_best.npy', result_dict)
         print(
             f'Label: {label:3d}  loss: {result_dict[label]["loss"]:10.3f}  ATK loss: {result_dict[label]["attack_loss"]:10.3f}  Norm: {result_dict[label]["norm"]:10.3f}  Jaccard: {result_dict[label]["jaccard"]:10.3f}  Score: {best_score:.3f}'
         )
         score_list[label] = best_score
     print('Score: ', score_list)
     print('Score MAD: ', normalize_mad(score_list))
     return neuron_dict
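normalize_mad is presumably the median-absolute-deviation anomaly index from Neural-Cleanse-style detection: labels whose score sits several MADs from the median are flagged as likely targets. A minimal sketch:

    import torch

    def normalize_mad(values) -> torch.Tensor:
        # anomaly index: |x - median| / (1.4826 * median absolute deviation)
        x = torch.as_tensor(values, dtype=torch.float)
        med = x.median()
        mad = (x - med).abs().median() * 1.4826  # consistency constant for normal data
        return (x - med).abs() / mad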
Example #11
    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               validate_interval: int = 10, save: bool = False, amp: bool = False, verbose: bool = True, indent: int = 0,
               loader_train: torch.utils.data.DataLoader = None, loader_valid: torch.utils.data.DataLoader = None,
               get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               validate_func: Callable[..., tuple[float, ...]] = None, epoch_func: Callable[[], None] = None,
               save_fn: Callable = None, file_path: str = None, folder_path: str = None, suffix: str = None, **kwargs):
        loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
        get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
        loss_fn = loss_fn if loss_fn is not None else self.loss
        validate_func = validate_func if validate_func is not None else self._validate
        save_fn = save_fn if save_fn is not None else self.save

        scaler: torch.cuda.amp.GradScaler = None
        if amp and env['num_gpus']:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                       verbose=verbose, indent=indent, **kwargs)
        losses = AverageMeter('Loss')
        top1 = AverageMeter('Acc@1')
        top5 = AverageMeter('Acc@5')
        params: list[list[nn.Parameter]] = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if epoch_func is not None:
                self.activate_params([])
                epoch_func()
                self.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            loader = loader_train
            if verbose and env['tqdm']:
                loader = tqdm(loader_train)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for data in loader:
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                if amp and env['num_gpus']:
                    loss = loss_fn(_input, _label, amp=True)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss = loss_fn(_input, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.get_logits(_input)
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
                empty_cache()
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.eval()
            self.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
                    f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                                  verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path, folder_path=folder_path, suffix=suffix, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.zero_grad()
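The amp branch hides autocast inside loss_fn(..., amp=True). For reference, the canonical torch.cuda.amp pattern it corresponds to, condensed into one sketch function:

    import torch

    def amp_train_step(model, criterion, optimizer, scaler, x, y):
        with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()    # scale up to avoid fp16 gradient underflow
        scaler.step(optimizer)           # unscales first; skips the step on inf/nan gradients
        scaler.update()
        optimizer.zero_grad()
        return loss.item()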
Example #12
    def remask(self, label: int) -> tuple[float, torch.Tensor]:
        generator = Generator(self.noise_dim, self.dataset.num_classes,
                              self.dataset.data_shape)
        for param in generator.parameters():
            param.requires_grad_()
        optimizer = optim.Adam(generator.parameters(), lr=self.remask_lr)
        optimizer.zero_grad()
        # mask = self.attack.mark.mask

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')
        torch.manual_seed(env['seed'])
        noise = torch.rand(1, self.noise_dim, device=env['device'])
        mark = torch.zeros(self.dataset.data_shape, device=env['device'])
        for _epoch in range(self.remask_epoch):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.loader
            if env['tqdm']:
                loader = tqdm(loader)
            for data in loader:
                _input, _label = self.model.get_data(data)
                batch_size = _label.size(0)
                poison_label = label * torch.ones_like(_label)
                mark = generator(
                    noise,
                    torch.tensor([label],
                                 device=poison_label.device,
                                 dtype=poison_label.dtype))
                poison_input = (_input + mark).clamp(0, 1)
                _output = self.model(poison_input)

                batch_acc = poison_label.eq(_output.argmax(1)).float().mean()
                batch_entropy = self.model.criterion(_output, poison_label)
                batch_norm = mark.flatten(start_dim=1).norm(p=1, dim=1).mean()
                batch_loss = batch_entropy + self.gamma_2 * batch_norm

                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, self.remask_epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str,
                   _str,
                   prefix='{upline}{clear_line}'.format(
                       **ansi) if env['tqdm'] else '',
                   indent=4)

        def get_data_fn(data, **kwargs):
            _input, _label = self.model.get_data(data)
            poison_label = torch.ones_like(_label) * label
            poison_input = (_input + mark).clamp(0, 1)
            return poison_input, poison_label

        self.model._validate(print_prefix='Validate Trigger Tgt',
                             get_data_fn=get_data_fn,
                             indent=4)

        if not self.attack.mark.random_pos:
            overlap = jaccard_idx(mark.mean(dim=0),
                                  self.real_mask,
                                  select_num=self.attack.mark.mark_height *
                                  self.attack.mark.mark_width)
            print(f'    Jaccard index: {overlap:.3f}')

        for param in generator.parameters():
            param.requires_grad = False
        return losses.avg, mark
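Generator is defined elsewhere; judging from the call sites it maps a noise vector plus a target class to an image-shaped trigger. A hypothetical minimal version consistent with that signature (layer sizes are illustrative):

    import math

    import torch
    import torch.nn as nn

    class Generator(nn.Module):
        """Maps (noise, class label) to an image-shaped trigger pattern."""

        def __init__(self, noise_dim: int, num_classes: int, data_shape: list[int]):
            super().__init__()
            self.data_shape = data_shape
            self.embed = nn.Embedding(num_classes, noise_dim)  # condition on the class
            self.net = nn.Sequential(
                nn.Linear(2 * noise_dim, 128), nn.ReLU(),
                nn.Linear(128, math.prod(data_shape)), nn.Tanh())

        def forward(self, noise: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
            h = torch.cat([noise, self.embed(label)], dim=1)
            # drop the batch dim for the single-trigger call pattern used above
            return self.net(h).view(-1, *self.data_shape).squeeze(0)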