def validate_confidence(self, mode: str = 'valid', success_only: bool = True) -> float:
    r"""Get :attr:`self.target_class` confidence on dataset of :attr:`mode`.

    Only source classes are loaded (:attr:`self.target_class` is removed from
    the class list). The trigger mark is added to every input, and the model's
    probability for :attr:`self.target_class` is averaged over the batches.

    Args:
        mode (str): Dataset mode. Defaults to ``'valid'``.
        success_only (bool): Whether to only measure confidence
            on attack-successful inputs. Defaults to ``True``.

    Returns:
        float: Average confidence of :attr:`self.target_class`.
    """
    source_class = self.source_class or list(range(self.dataset.num_classes))
    # Work on a copy so that ``self.source_class`` itself is never mutated.
    source_class = source_class.copy()
    if self.target_class in source_class:
        source_class.remove(self.target_class)
    loader = self.dataset.get_dataloader(mode=mode, class_list=source_class)
    confidence = SmoothedValue()
    # Evaluation only: disable autograd. Otherwise a graph is built for every
    # forward pass, and the values handed to ``confidence`` carry grad history.
    with torch.no_grad():
        for data in loader:
            _input, _label = self.model.get_data(data)
            trigger_input = self.add_mark(_input)
            trigger_label = self.model.get_class(trigger_input)
            if success_only:
                # Keep only inputs that the trigger actually flips to the target class.
                trigger_input = trigger_input[trigger_label == self.target_class]
            if len(trigger_input) == 0:
                continue
            batch_conf = self.model.get_prob(trigger_input)[:, self.target_class].mean()
            confidence.update(batch_conf, len(trigger_input))
    return confidence.global_avg
def _validate(self, full=True, print_prefix='Validate', indent=0, verbose=True,
              loader: torch.utils.data.DataLoader = None,
              get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
              loss_fn: Callable[..., torch.Tensor] = None,
              writer=None, main_tag: str = 'valid', tag: str = '',
              _epoch: int = None, **kwargs) -> tuple[float, float]:
    r"""Evaluate the model and return its average loss and top-1 accuracy.

    Args:
        full (bool): When ``loader`` is not given, use ``dataset.loader['valid']``
            if ``True``, otherwise ``dataset.loader['valid2']``.
        print_prefix (str): Label shown in the progress header.
        indent (int): Indentation of the progress output.
        verbose (bool): Whether to wrap the loader with progress logging.
        loader (torch.utils.data.DataLoader): Data to evaluate on.
        get_data_fn (Callable): Turns a raw batch into ``(input, label)``;
            defaults to ``self.get_data``. Receives ``mode='valid'`` and ``**kwargs``.
        loss_fn (Callable): Loss function; defaults to ``self.loss``.
            Receives ``_output=`` and ``**kwargs``.
        writer: Optional tensorboard ``SummaryWriter``. Scalars are written only
            when ``writer``, ``_epoch`` and ``main_tag`` are all set.
        main_tag (str): Suffix of the ``'Loss/'`` and ``'Acc/'`` scalar groups.
        tag (str): Key of the scalar inside each group.
        _epoch (int): Global step used for the tensorboard scalars.

    Returns:
        (float, float): Batch-size-weighted average loss and top-1 accuracy.
    """
    self.eval()
    if loader is None:
        loader = self.dataset.loader['valid' if full else 'valid2']
    if get_data_fn is None:
        get_data_fn = self.get_data
    if loss_fn is None:
        loss_fn = self.loss

    logger = MetricLogger()
    for meter_name in ('loss', 'top1', 'top5'):
        logger.meters[meter_name] = SmoothedValue()

    batches = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        # Pad the visible text to at least 30 columns, ignoring ANSI escape codes.
        width = max(len(print_prefix), 30) + get_ansi_len(header)
        header = header.ljust(width)
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            batches = tqdm(batches)
        batches = logger.log_every(batches, header=header, indent=indent)

    for batch in batches:
        x, y = get_data_fn(batch, mode='valid', **kwargs)
        with torch.no_grad():
            logits = self(x)
        batch_loss = float(loss_fn(x, y, _output=logits, **kwargs))
        acc1, acc5 = self.accuracy(logits, y, topk=(1, 5))
        n = int(y.size(0))
        for meter_name, value in (('loss', batch_loss), ('top1', acc1), ('top5', acc5)):
            logger.meters[meter_name].update(value, n)

    meters = logger.meters
    loss = meters['loss'].global_avg
    acc = meters['top1'].global_avg
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        for group, value in (('Loss/', loss), ('Acc/', acc)):
            writer.add_scalars(main_tag=group + main_tag,
                               tag_scalar_dict={tag: value},
                               global_step=_epoch)
    return loss, acc
def get_cls_diff(self):
    r"""Get classification difference between original inputs and trigger inputs.

    For every batch of ``self.dataset.loader['valid']``, the model's predicted
    class on the clean input is compared with its predicted class on the same
    input after ``self.attack.add_mark``.

    Returns:
        float: Fraction (in ``[0, 1]``) of inputs whose predicted class changes,
        averaged over batches and weighted by batch size.
    """
    diff = SmoothedValue()
    # Inference only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for data in self.dataset.loader['valid']:
            _input, _ = self.model.get_data(data)
            _class = self.model.get_class(_input)
            trigger_input = self.attack.add_mark(_input)
            trigger_class = self.model.get_class(trigger_input)
            result = _class.not_equal(trigger_class)
            diff.update(result.float().mean().item(), len(_input))
    return diff.global_avg
def validate_confidence(self) -> float:
    r"""Average :attr:`self.target_class` probability on triggered validation inputs.

    Samples already labelled :attr:`self.target_class` are discarded. The trigger
    mark is added to the rest, and only the inputs the model then classifies as
    :attr:`self.target_class` contribute their target-class probability.

    Returns:
        float: Batch-size-weighted average target-class confidence.
    """
    meter = SmoothedValue()
    with torch.no_grad():
        for batch in self.dataset.loader['valid']:
            inputs, labels = self.model.get_data(batch)
            # Ignore samples that already belong to the target class.
            inputs = inputs[labels != self.target_class]
            if len(inputs) == 0:
                continue
            marked = self.add_mark(inputs)
            # Keep only samples the trigger successfully flips to the target class.
            hit = self.model.get_class(marked) == self.target_class
            marked = marked[hit]
            if len(marked) == 0:
                continue
            target_prob = self.model.get_prob(marked)[:, self.target_class]
            meter.update(target_prob.mean(), len(marked))
    return meter.global_avg
def compare(module1: nn.Module, module2: nn.Module, loader: torch.utils.data.DataLoader,
            print_prefix='Validate', indent=0, verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs) -> float:
    r"""Average soft cross entropy between the outputs of two modules.

    For each batch, computes ``-(softmax(out1) * log_softmax(out2)).sum(1).mean()``,
    where ``out1 = module1(input)`` and ``out2 = module2(input)``. Both modules are
    switched to eval mode, and all forward passes run under ``torch.no_grad()``.

    Args:
        module1 (nn.Module): Module whose softmax output is the reference distribution.
        module2 (nn.Module): Module whose log-softmax output is compared against it.
        loader (torch.utils.data.DataLoader): Data to evaluate on.
        print_prefix (str): Label shown in the progress header.
        indent (int): Indentation of the progress output.
        verbose (bool): Whether to wrap the loader with progress logging.
        get_data_fn (Callable): Turns a raw batch into ``(input, label)``;
            receives ``**kwargs``. Defaults to returning the batch unchanged.
        **kwargs: Forwarded to ``get_data_fn``.

    Returns:
        float: Cross entropy averaged over batches, weighted by batch size.
    """
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)
    module1.eval()
    module2.eval()
    if get_data_fn is None:
        # The batch is assumed to already be an ``(input, label)`` pair. Extra
        # keyword arguments must be accepted (and ignored), because they are
        # forwarded below; the previous ``lambda x: x`` raised TypeError for them.
        get_data_fn = lambda data, **_: data  # noqa: E731

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch, header=header, indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg
def attack(self, verbose: int = 1, **kwargs) -> tuple[float, float]:
    r"""Run the attack on a shuffled 30% split of the validation set.

    Inputs the model already misclassifies are removed with
    ``self.model.remove_misclassify``. Processing stops once ``self.test_num``
    samples have been attacked. Each batch is optimized up to
    ``max(self.num_restart, 1)`` times, and each restart re-runs only the
    samples that are still failing (``iter_list == -1``).

    Args:
        verbose (int): ``>= 1`` prints the final success count and the indices of
            successful samples; ``>= 2`` also prints the final statistics;
            ``>= 3`` prints the running success count after every batch;
            ``>= 4`` also prints the running statistics. Defaults to ``1``.
        **kwargs: Forwarded to :meth:`optimize`.

    Returns:
        (float, float): Success rate (successful / attacked samples) and the
        average iteration count, where failed samples count as
        ``2 * self.iteration``.
    """
    validset = self.dataset.get_dataset('valid')
    testset, _ = self.dataset.split_dataset(validset, percent=0.3)
    loader = self.dataset.get_dataloader(mode='valid', dataset=testset, shuffle=True)

    fmt_str = '{global_avg:7.3f} ({min:7.3f} {max:7.3f})'
    total_adv_target_conf = SmoothedValue(fmt=fmt_str)
    total_org_target_conf = SmoothedValue(fmt=fmt_str)
    succ_adv_target_conf = SmoothedValue(fmt=fmt_str)

    total_adv_org_conf = SmoothedValue(fmt=fmt_str)
    total_org_org_conf = SmoothedValue(fmt=fmt_str)
    succ_adv_org_conf = SmoothedValue(fmt=fmt_str)

    total_iter_list = SmoothedValue(fmt=fmt_str)
    succ_iter_list = SmoothedValue(fmt=fmt_str)

    succ_idx_list: list[int] = []

    def _print_details():
        # Prints the detailed statistics. The ``{name=:}`` f-strings print the
        # variable names themselves as labels, so these names must not change.
        prints(f'{total_iter_list=:}', indent=4)
        prints(f'{succ_iter_list=:}', indent=4)
        prints()
        prints('-------------------------------------------------', indent=4)
        prints(f'{ansi["yellow"]}Target Class:{ansi["reset"]}', indent=4)
        prints(f'{total_adv_target_conf=:}', indent=8)
        prints(f'{total_org_target_conf=:}', indent=8)
        prints(f'{succ_adv_target_conf=:}', indent=8)
        prints()
        prints('-------------------------------------------------', indent=4)
        prints(f'{ansi["yellow"]}Original Class:{ansi["reset"]}', indent=4)
        prints(f'{total_adv_org_conf=:}', indent=8)
        prints(f'{total_org_org_conf=:}', indent=8)
        prints(f'{succ_adv_org_conf=:}', indent=8)

    for data in loader:
        rest_length = self.test_num - total_adv_target_conf.count
        if rest_length <= 0:
            break
        _input, _label = self.model.remove_misclassify(data)
        if len(_label) == 0:
            continue
        if len(_label) > rest_length:
            _input = _input[:rest_length]
            _label = _label[:rest_length]
        target = (self.generate_target(_input, idx=self.target_idx)
                  if self.target_class is None
                  else self.target_class * torch.ones_like(_label))

        adv_input = _input.clone().detach()
        iter_list = -torch.ones(len(_label), dtype=torch.long)  # -1 marks failure
        current_idx = torch.arange(len(iter_list))
        for _ in range(max(self.num_restart, 1)):
            temp_adv_input, temp_iter_list = self.optimize(
                _input[current_idx], target=target[current_idx], **kwargs)
            adv_input[current_idx] = temp_adv_input
            iter_list[current_idx] = temp_iter_list
            fail_idx = iter_list == -1
            if (~fail_idx).all():
                break
            # ``fail_idx`` is a mask over the WHOLE batch. Rebuild the index from
            # it; masking the already-shrunk ``current_idx`` with it raised a
            # shape-mismatch IndexError on the second restart.
            current_idx = torch.arange(len(iter_list))[fail_idx]

        succ_mask = iter_list != -1
        # Indices of successful samples, counted across all batches so far.
        offset = total_iter_list.count
        succ_idx_list.extend(offset + i for i, ok in enumerate(succ_mask.tolist()) if ok)

        adv_target_conf = self.model.get_target_prob(adv_input, target)
        adv_org_conf = self.model.get_target_prob(adv_input, _label)
        org_target_conf = self.model.get_target_prob(_input, target)
        org_org_conf = self.model.get_target_prob(_input, _label)

        total_adv_target_conf.update_list(adv_target_conf.detach().cpu().tolist())
        total_adv_org_conf.update_list(adv_org_conf.detach().cpu().tolist())
        succ_adv_target_conf.update_list(adv_target_conf[succ_mask].detach().cpu().tolist())
        succ_adv_org_conf.update_list(adv_org_conf[succ_mask].detach().cpu().tolist())
        total_org_target_conf.update_list(org_target_conf.detach().cpu().tolist())
        total_org_org_conf.update_list(org_org_conf.detach().cpu().tolist())
        # Failed samples are charged twice the iteration budget in the average.
        total_iter_list.update_list(
            torch.where(succ_mask, iter_list,
                        2 * self.iteration * torch.ones_like(iter_list)).tolist())
        succ_iter_list.update_list(iter_list[succ_mask].tolist())
        if verbose >= 3:
            prints(f'{ansi["green"]}{succ_iter_list.count} / {total_iter_list.count}{ansi["reset"]}')
            if verbose >= 4:
                _print_details()
    if verbose:
        prints(f'{ansi["green"]}{succ_iter_list.count} / {total_iter_list.count}{ansi["reset"]}')
        prints(succ_idx_list)
        if verbose >= 2:
            _print_details()
    # NOTE(review): raises ZeroDivisionError if no sample was attacked
    # (e.g. every input was misclassified) — confirm whether callers expect that.
    return float(succ_iter_list.count) / total_iter_list.count, total_iter_list.global_avg
def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
           grad_clip: float = None,
           print_prefix: str = 'Epoch', start_epoch: int = 0,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           epoch_fn: Callable[..., None] = None,
           get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           after_loss_fn: Callable[..., None] = None,
           validate_fn: Callable[..., tuple[float, float]] = None,
           save_fn: Callable[..., None] = None, file_path: str = None,
           folder_path: str = None, suffix: str = None,
           writer=None, main_tag: str = 'train', tag: str = '',
           verbose: bool = True, indent: int = 0, **kwargs):
    r"""Train the model for ``epoch`` epochs, validating and optionally saving.

    A validation run before training sets the initial best accuracy. After every
    ``validate_interval`` epochs (and after the last epoch), the model is
    validated again. When accuracy does not drop, the best accuracy is updated
    and, if ``save``, ``save_fn`` is called.

    Args:
        epoch (int): Number of epochs to train.
        optimizer (Optimizer): Optimizer. Its parameter groups define the
            parameters that are activated and (optionally) clipped.
        lr_scheduler (_LRScheduler): Stepped once per epoch when given.
        grad_clip (float): Max gradient norm passed to
            ``nn.utils.clip_grad_norm_``. ``None`` disables clipping.
        print_prefix (str): Label shown in the progress header.
        start_epoch (int): Offset added to the epoch for tensorboard steps and
            validation.
        validate_interval (int): Validate every this many epochs; ``0`` disables
            periodic validation.
        save (bool): Whether to call ``save_fn`` on a new best accuracy.
        amp (bool): Use mixed precision with ``torch.cuda.amp.GradScaler``.
            Forced off when ``env['num_gpus']`` is falsy.
        loader_train / loader_valid: Training / validation loaders;
            ``loader_train`` defaults to ``self.dataset.loader['train']``.
        epoch_fn, get_data_fn, loss_fn, after_loss_fn, validate_fn, save_fn:
            Optional hooks. ``epoch_fn`` and ``after_loss_fn`` fall back to
            same-named attributes of ``self``; the others fall back to
            ``self.get_data``, ``self.loss``, ``self._validate``, ``self.save``.
        file_path, folder_path, suffix: Forwarded to ``save_fn``.
        writer: Optional tensorboard ``SummaryWriter`` for per-epoch loss/accuracy.
        main_tag, tag: Tensorboard scalar group suffix and key.
        verbose (bool): Whether to print progress.
        indent (int): Indentation of the printed output.
        **kwargs: Forwarded to ``validate_fn``.
    """
    loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
    get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
    loss_fn = loss_fn if callable(loss_fn) else self.loss
    validate_fn = validate_fn if callable(validate_fn) else self._validate
    save_fn = save_fn if callable(save_fn) else self.save
    if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
        epoch_fn = getattr(self, 'epoch_fn')
    if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
        after_loss_fn = getattr(self, 'after_loss_fn')

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    _, best_acc = validate_fn(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                              writer=None, tag=tag, _epoch=start_epoch,
                              verbose=verbose, indent=indent, **kwargs)

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    total_iter = epoch * len(loader_train)
    for _epoch in range(epoch):
        _epoch += 1  # 1-based from here on
        if callable(epoch_fn):
            self.activate_params([])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epoch=epoch, start_epoch=start_epoch)
            self.activate_params(params)
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader_train
        if verbose:
            header = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, epoch), **ansi)
            header = header.ljust(30 + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch, header=header, indent=indent)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for i, data in enumerate(loader_epoch):
            # NOTE(review): ``_epoch`` is already 1-based here, so ``_iter`` starts
            # at ``len(loader_train)`` and can exceed ``total_iter`` in the last
            # epoch. Kept unchanged for existing ``after_loss_fn`` users — confirm.
            _iter = _epoch * len(loader_train) + i
            _input, _label = get_data_fn(data, mode='train')
            _output = self(_input, amp=amp)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if amp:
                scaler.scale(loss).backward()
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label, _output=_output,
                                  loss=loss, optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                if grad_clip is not None:
                    # Gradients are still multiplied by the loss scale. Unscale
                    # them first so the clip threshold applies to the true norm.
                    # This runs after ``after_loss_fn`` so that any backward
                    # passes it adds are included.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(params, grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(params, grad_clip)
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label, _output=_output,
                                  loss=loss, optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                optimizer.step()
            optimizer.zero_grad()
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(float(loss), batch_size)
            logger.meters['top1'].update(acc1, batch_size)
            logger.meters['top5'].update(acc5, batch_size)
            empty_cache()  # TODO: should it be outside of the dataloader loop?
        self.eval()
        self.activate_params([])
        loss, acc = logger.meters['loss'].global_avg, logger.meters['top1'].global_avg
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag, tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag, tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if _epoch % validate_interval == 0 or _epoch == epoch:
                _, cur_acc = validate_fn(loader=loader_valid, get_data_fn=get_data_fn,
                                         loss_fn=loss_fn, writer=writer, tag=tag,
                                         _epoch=_epoch + start_epoch,
                                         verbose=verbose, indent=indent, **kwargs)
                if cur_acc >= best_acc:
                    if verbose:
                        prints('{green}best result update!{reset}'.format(**ansi), indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}',
                               indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    prints('-' * 50, indent=indent)
    self.zero_grad()