Example #1
0
    def validate_confidence(self, mode: str = 'valid', success_only: bool = True) -> float:
        r"""Get :attr:`self.target_class` confidence on dataset of :attr:`mode`.

        Args:
            mode (str): Dataset mode. Defaults to ``'valid'``.
            success_only (bool): Whether to only measure confidence
                on attack-successful inputs.
                Defaults to ``True``.

        Returns:
            float: Average confidence of :attr:`self.target_class`.
        """
        # Evaluate on source classes only (all classes when
        # self.source_class is unset), excluding the target class.
        source_class = self.source_class or list(range(self.dataset.num_classes))
        source_class = source_class.copy()  # avoid mutating the shared attribute
        if self.target_class in source_class:
            source_class.remove(self.target_class)
        loader = self.dataset.get_dataloader(mode=mode, class_list=source_class)

        confidence = SmoothedValue()
        # Inference only: disable autograd to avoid building graphs.
        with torch.no_grad():
            for data in loader:
                _input, _label = self.model.get_data(data)
                trigger_input = self.add_mark(_input)
                trigger_label = self.model.get_class(trigger_input)
                if success_only:
                    # Keep only inputs the attack flips to target_class.
                    trigger_input = trigger_input[trigger_label == self.target_class]
                    if len(trigger_input) == 0:
                        continue
                batch_conf = self.model.get_prob(trigger_input)[:, self.target_class].mean()
                # Cast tensor -> python float before accumulating,
                # consistent with how losses are accumulated elsewhere.
                confidence.update(float(batch_conf), len(trigger_input))
        return confidence.global_avg
Example #2
0
 def _validate(self,
               full=True,
               print_prefix='Validate',
               indent=0,
               verbose=True,
               loader: torch.utils.data.DataLoader = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               writer=None,
               main_tag: str = 'valid',
               tag: str = '',
               _epoch: int = None,
               **kwargs) -> tuple[float, float]:
     r"""Evaluate the model on a validation loader.

     Computes the average loss and top-1/top-5 accuracy, optionally
     logging console progress and tensorboard scalars.

     Returns:
         tuple[float, float]: (average loss, top-1 accuracy).
     """
     self.eval()
     if loader is None:
         # Full validation split, or the smaller 'valid2' split.
         split = 'valid' if full else 'valid2'
         loader = self.dataset.loader[split]
     if get_data_fn is None:
         get_data_fn = self.get_data
     if loss_fn is None:
         loss_fn = self.loss
     logger = MetricLogger()
     for meter_name in ('loss', 'top1', 'top5'):
         logger.meters[meter_name] = SmoothedValue()
     loader_epoch = loader
     if verbose:
         header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
         width = max(len(print_prefix), 30) + get_ansi_len(header)
         header = header.ljust(width)
         if env['tqdm']:
             header = '{upline}{clear_line}'.format(**ansi) + header
             loader_epoch = tqdm(loader_epoch)
         loader_epoch = logger.log_every(loader_epoch,
                                         header=header,
                                         indent=indent)
     for data in loader_epoch:
         _input, _label = get_data_fn(data, mode='valid', **kwargs)
         with torch.no_grad():
             _output = self(_input)
             loss = float(loss_fn(_input, _label, _output=_output,
                                  **kwargs))
             acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
             batch_size = int(_label.size(0))
             logger.meters['loss'].update(loss, batch_size)
             logger.meters['top1'].update(acc1, batch_size)
             logger.meters['top5'].update(acc5, batch_size)
     loss = logger.meters['loss'].global_avg
     acc = logger.meters['top1'].global_avg
     if writer is not None and _epoch is not None and main_tag:
         from torch.utils.tensorboard import SummaryWriter
         assert isinstance(writer, SummaryWriter)
         for scalar_prefix, value in (('Loss/', loss), ('Acc/', acc)):
             writer.add_scalars(main_tag=scalar_prefix + main_tag,
                                tag_scalar_dict={tag: value},
                                global_step=_epoch)
     return loss, acc
Example #3
0
    def get_cls_diff(self):
        r"""Get classification difference between
        original inputs and trigger inputs.

        Returns:
            float: Classification difference percentage.
        """
        diff = SmoothedValue()
        # Inference only: disable autograd to save memory/compute.
        with torch.no_grad():
            for data in self.dataset.loader['valid']:
                _input, _ = self.model.get_data(data)
                _class = self.model.get_class(_input)
                trigger_input = self.attack.add_mark(_input)
                trigger_class = self.model.get_class(trigger_input)
                # Fraction of samples whose prediction changes after marking.
                result = _class.not_equal(trigger_class)
                diff.update(result.float().mean().item(), len(_input))
        return diff.global_avg
Example #4
0
 def validate_confidence(self) -> float:
     r"""Return the average target-class confidence on trigger inputs
     that successfully fool the model.

     Only non-target-class samples from the valid set are considered;
     among those, only samples whose marked version is classified as
     :attr:`self.target_class` contribute to the average.

     Returns:
         float: Average :attr:`self.target_class` probability.
     """
     confidence = SmoothedValue()
     with torch.no_grad():
         for data in self.dataset.loader['valid']:
             _input, _label = self.model.get_data(data)
             # Drop samples already belonging to the target class.
             idx1 = _label != self.target_class
             _input = _input[idx1]
             _label = _label[idx1]
             if len(_input) == 0:
                 continue
             poison_input = self.add_mark(_input)
             poison_label = self.model.get_class(poison_input)
             # Keep only attack-successful samples.
             idx2 = poison_label == self.target_class
             poison_input = poison_input[idx2]
             if len(poison_input) == 0:
                 continue
             batch_conf = self.model.get_prob(poison_input)[:, self.target_class].mean()
             # Cast tensor -> python float before accumulating,
             # consistent with how losses are accumulated elsewhere.
             confidence.update(float(batch_conf), len(poison_input))
     return confidence.global_avg
Example #5
0
def compare(module1: nn.Module,
            module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate',
            indent=0,
            verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor,
                                             torch.Tensor]] = None,
            **kwargs) -> float:
    r"""Measure the average cross entropy between the output
    distributions of two modules over a dataloader.

    Args:
        module1 (nn.Module): Module providing the reference distribution ``p``.
        module2 (nn.Module): Module providing the compared distribution ``q``.
        loader (torch.utils.data.DataLoader): Data to evaluate on.
        get_data_fn (Callable): Optional ``data -> (_input, _label)``
            transform. Defaults to identity.

    Returns:
        float: Sample-count-weighted average cross entropy.
    """
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)
    module1.eval()
    module2.eval()
    # Bug fix: the identity default must accept **kwargs because the call
    # below always forwards them; a bare `lambda x: x` raises TypeError
    # whenever extra keyword arguments are supplied.
    get_data_fn = get_data_fn if get_data_fn is not None else (lambda x, **kw: x)

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        # H(p, q) = -sum_i softmax(p)_i * logsoftmax(q)_i, averaged over batch.
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch,
                                        header=header,
                                        indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg
Example #6
0
File: pgd.py  Project: ain-soph/trojanzoo
    def attack(self, verbose: int = 1, **kwargs) -> tuple[float, float]:
        r"""Run the adversarial attack over a 30% subset of the validation
        set and collect success-rate / confidence / iteration statistics.

        Args:
            verbose (int): Verbosity level. ``>=1`` prints the final
                success count and indices, ``>=2`` prints final confidence
                statistics, ``>=3``/``>=4`` additionally print per-batch
                progress and statistics.

        Returns:
            tuple[float, float]: (success rate, average iteration count
                where failed samples are counted as ``2 * self.iteration``).
        """
        # Attack a subset (percent=0.3) of the validation set, shuffled.
        validset = self.dataset.get_dataset('valid')
        testset, _ = self.dataset.split_dataset(validset, percent=0.3)
        loader = self.dataset.get_dataloader(mode='valid', dataset=testset,
                                             shuffle=True)
        fmt_str = '{global_avg:7.3f} ({min:7.3f}  {max:7.3f})'
        # Meter naming: '*_target_conf' = confidence on the attack target
        # class, '*_org_conf' = confidence on the original label;
        # 'total_*' covers all attacked samples, 'succ_*' successful ones.
        total_adv_target_conf = SmoothedValue(fmt=fmt_str)
        total_org_target_conf = SmoothedValue(fmt=fmt_str)
        succ_adv_target_conf = SmoothedValue(fmt=fmt_str)

        total_adv_org_conf = SmoothedValue(fmt=fmt_str)
        total_org_org_conf = SmoothedValue(fmt=fmt_str)
        succ_adv_org_conf = SmoothedValue(fmt=fmt_str)

        # Iteration-count meters; a failure is encoded as -1 in iter_list.
        total_iter_list = SmoothedValue(fmt=fmt_str)
        succ_iter_list = SmoothedValue(fmt=fmt_str)

        succ_idx_list: list[int] = []
        for data in loader:
            # Stop once self.test_num samples have been attacked.
            rest_length = self.test_num - total_adv_target_conf.count
            if rest_length <= 0:
                break
            # Only attack samples the model classifies correctly.
            _input, _label = self.model.remove_misclassify(data)
            if len(_label) == 0:
                continue

            if len(_label) > rest_length:
                _input = _input[:rest_length]
                _label = _label[:rest_length]
            # Untargeted mode generates per-sample targets; targeted mode
            # uses the fixed self.target_class for every sample.
            target = self.generate_target(_input, idx=self.target_idx) if self.target_class is None \
                else self.target_class * torch.ones_like(_label)
            adv_input = _input.clone().detach()
            # iter_list[i] == -1 marks sample i as not (yet) successful.
            iter_list = -torch.ones(len(_label), dtype=torch.long)
            current_idx = torch.arange(len(iter_list))
            # Random restarts: re-attack only the still-failing samples.
            for _ in range(max(self.num_restart, 1)):
                temp_adv_input, temp_iter_list = self.optimize(_input[current_idx],
                                                               target=target[current_idx], **kwargs)
                adv_input[current_idx] = temp_adv_input
                iter_list[current_idx] = temp_iter_list
                fail_idx = iter_list == -1
                if (~fail_idx).all():
                    break
                current_idx = current_idx[fail_idx]
            # Record global sample indices of successful attacks
            # (offset by the number of samples already processed).
            for i, _iter in enumerate(iter_list):
                if _iter != -1:
                    succ_idx_list.append(total_iter_list.count + i)
            adv_target_conf = self.model.get_target_prob(adv_input, target)
            adv_org_conf = self.model.get_target_prob(adv_input, _label)
            org_target_conf = self.model.get_target_prob(_input, target)
            org_org_conf = self.model.get_target_prob(_input, _label)

            total_adv_target_conf.update_list(adv_target_conf.detach().cpu().tolist())
            total_adv_org_conf.update_list(adv_org_conf.detach().cpu().tolist())
            succ_adv_target_conf.update_list(adv_target_conf[iter_list != -1].detach().cpu().tolist())
            succ_adv_org_conf.update_list(adv_org_conf[iter_list != -1].detach().cpu().tolist())
            total_org_target_conf.update_list(org_target_conf.detach().cpu().tolist())
            total_org_org_conf.update_list(org_org_conf.detach().cpu().tolist())

            # Failed samples count as 2 * self.iteration iterations.
            total_iter_list.update_list(torch.where(iter_list != -1, iter_list, 2 *
                                        self.iteration * torch.ones_like(iter_list)).tolist())
            succ_iter_list.update_list(iter_list[iter_list != -1].tolist())
            if verbose >= 3:
                prints(f'{ansi["green"]}{succ_iter_list.count} / {total_iter_list.count}{ansi["reset"]}')
            if verbose >= 4:
                prints(f'{total_iter_list=:}', indent=4)
                prints(f'{succ_iter_list=:}', indent=4)
                prints()
                prints('-------------------------------------------------', indent=4)
                prints(f'{ansi["yellow"]}Target Class:{ansi["reset"]}', indent=4)
                prints(f'{total_adv_target_conf=:}', indent=8)
                prints(f'{total_org_target_conf=:}', indent=8)
                prints(f'{succ_adv_target_conf=:}', indent=8)
                prints()
                prints('-------------------------------------------------', indent=4)
                prints(f'{ansi["yellow"]}Original Class:{ansi["reset"]}', indent=4)
                prints(f'{total_adv_org_conf=:}', indent=8)
                prints(f'{total_org_org_conf=:}', indent=8)
                prints(f'{succ_adv_org_conf=:}', indent=8)
        # Final summary printing.
        if verbose:
            prints(f'{ansi["green"]}{succ_iter_list.count} / {total_iter_list.count}{ansi["reset"]}')
            prints(succ_idx_list)
        if verbose >= 2:
            prints(f'{total_iter_list=:}', indent=4)
            prints(f'{succ_iter_list=:}', indent=4)
            prints()
            prints('-------------------------------------------------', indent=4)
            prints(f'{ansi["yellow"]}Target Class:{ansi["reset"]}', indent=4)
            prints(f'{total_adv_target_conf=:}', indent=8)
            prints(f'{total_org_target_conf=:}', indent=8)
            prints(f'{succ_adv_target_conf=:}', indent=8)
            prints()
            prints('-------------------------------------------------', indent=4)
            prints(f'{ansi["yellow"]}Original Class:{ansi["reset"]}', indent=4)
            prints(f'{total_adv_org_conf=:}', indent=8)
            prints(f'{total_org_org_conf=:}', indent=8)
            prints(f'{succ_adv_org_conf=:}', indent=8)
        return float(succ_iter_list.count) / total_iter_list.count, total_iter_list.global_avg
Example #7
0
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        r"""Train the model for ``epoch`` epochs with periodic validation.

        Supports mixed-precision training (``amp``), gradient clipping
        (``grad_clip``), tensorboard scalar logging (``writer``) and
        saving the best checkpoint (``save`` / ``save_fn``).

        Args:
            epoch (int): Number of epochs to train.
            optimizer (Optimizer): Optimizer over trainable parameters.
            lr_scheduler (_LRScheduler): Stepped once per epoch if given.
            grad_clip (float): Max gradient norm; ``None`` disables clipping.
            validate_interval (int): Validate every N epochs
                (``0`` disables periodic validation).

        The callback parameters (``epoch_fn``, ``get_data_fn``, ``loss_fn``,
        ``after_loss_fn``, ``validate_fn``, ``save_fn``) fall back to the
        corresponding ``self`` attributes when not provided.
        """
        loader_train = loader_train if loader_train is not None else self.dataset.loader[
            'train']
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
        #     iter_fn = getattr(self, 'iter_fn')
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False  # amp requires CUDA
        if amp:
            scaler = torch.cuda.amp.GradScaler()
        # Initial validation establishes the best accuracy so far
        # (no tensorboard logging for this baseline run).
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)
        for _epoch in range(epoch):
            _epoch += 1  # 1-based epoch index
            if callable(epoch_fn):
                # Freeze params while the per-epoch hook runs.
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                # NOTE(review): _epoch is already 1-based here, so _iter
                # starts at len(loader_train) rather than 0 — confirm
                # intended for schedules that compare _iter to total_iter.
                _iter = _epoch * len(loader_train) + i
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)
                if amp:
                    scaler.scale(loss).backward()
                    # NOTE(review): grad_clip is not applied on the amp
                    # path; clipping scaled grads requires
                    # scaler.unscale_(optimizer) first — confirm intent.
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        # Bug fix: clip_grad_norm_ requires max_norm;
                        # the original call omitted it and would raise
                        # TypeError whenever grad_clip was set.
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                    optimizer.step()
                optimizer.zero_grad()
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache(
                )  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss, acc = logger.meters['loss'].global_avg, logger.meters[
                'top1'].global_avg
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                # Validate every validate_interval epochs and on the last.
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(
                                **ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()