def get_pred_labels(self) -> torch.Tensor:
    r"""Get predicted labels for test inputs.

    Returns:
        torch.Tensor: ``torch.BoolTensor`` with shape ``(2 * defense_input_num)``.
    """
    logger = MetricLogger(meter_length=40)
    str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
    logger.create_meters(clean_score=str_format, poison_score=str_format)
    test_set = TensorListDataset(self.test_input, self.test_label)
    test_loader = self.dataset.get_dataloader(mode='valid', dataset=test_set)
    for data in logger.log_every(test_loader):
        _input, _label = self.model.get_data(data)
        trigger_input = self.attack.add_mark(_input)
        logger.meters['clean_score'].update_list(
            self.get_score(_input).tolist())
        logger.meters['poison_score'].update_list(
            self.get_score(trigger_input).tolist())
    clean_score = torch.as_tensor(logger.meters['clean_score'].deque)
    poison_score = torch.as_tensor(logger.meters['poison_score'].deque)
    clean_score_sorted = clean_score.msort()
    threshold_low = float(clean_score_sorted[int(self.strip_fpr * len(poison_score))])
    entropy = torch.cat((clean_score, poison_score))
    print(f'Threshold: {threshold_low:5.3f}')
    return torch.where(entropy < threshold_low,
                       torch.ones_like(entropy).bool(),
                       torch.zeros_like(entropy).bool())
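
# Usage sketch (hypothetical `defense` instance of this STRIP-style class):
# the first half of the returned BoolTensor covers the clean test inputs, the
# second half their triggered counterparts, so detection rates fall out of a
# simple chunk.
def strip_detection_rates(defense) -> tuple[float, float]:
    pred = defense.get_pred_labels()         # shape (2 * defense_input_num,)
    clean_pred, poison_pred = pred.chunk(2)
    fpr = clean_pred.float().mean().item()   # clean inputs flagged as poisoned
    tpr = poison_pred.float().mean().item()  # triggered inputs detected
    return fpr, tpr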
def get_pred_label(self, img: torch.Tensor, logger: MetricLogger = None) -> bool:
    r"""Get the prediction label of one certain image (poisoned or not).

    Args:
        img (torch.Tensor): Image tensor (on GPU) with shape ``(C, H, W)``.
        logger (trojanzoo.utils.logger.MetricLogger): output logger.
            Defaults to ``None``.

    Returns:
        bool: Whether the image tensor :attr:`img` is poisoned.
    """
    # get dominant color
    dom_c = self.get_dominant_color(img).unsqueeze(-1).unsqueeze(-1)  # (C, 1, 1)
    # generate random positions
    height, width = img.shape[-2:]
    pos_height = torch.randint(low=0, high=height - self.mark_size[0],
                               size=[self.neo_sample_num, 1])
    pos_width = torch.randint(low=0, high=width - self.mark_size[1],
                              size=[self.neo_sample_num, 1])
    pos_list = torch.cat([pos_height, pos_width], dim=1)  # (neo_sample_num, 2)
    # block potential triggers on _input
    block_input = img.repeat(self.neo_sample_num, 1, 1, 1)  # (neo_sample_num, C, H, W)
    for i in range(self.neo_sample_num):
        x = pos_list[i][0]
        y = pos_list[i][1]
        block_input[i, :, x:x + self.mark_size[0],
                    y:y + self.mark_size[1]] = dom_c
    # get potential triggers
    org_class = self.model.get_class(img.unsqueeze(0)).item()  # (1)
    block_class = self.model.get_class(block_input).cpu()  # (neo_sample_num)
    # confirm triggers
    pos_pairs = pos_list[block_class != org_class]  # (*, 2)
    for pos in pos_pairs:
        self.attack.mark.mark_height_offset = pos[0]
        self.attack.mark.mark_width_offset = pos[1]
        self.attack.mark.mark.fill_(1.0)
        self.attack.mark.mark[:-1] = img[..., pos[0]:pos[0] + self.mark_size[0],
                                         pos[1]:pos[1] + self.mark_size[1]]
        cls_diff = self.get_cls_diff()
        if cls_diff > self.neo_asr_threshold:
            jaccard_idx = mask_jaccard(self.attack.mark.get_mask(),
                                       self.real_mask,
                                       select_num=self.select_num)
            if logger is not None:
                logger.update(cls_diff=cls_diff, jaccard_idx=jaccard_idx)
            return True
    return False
def compare(module1: nn.Module, module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate', indent=0, verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
            criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = nn.CrossEntropyLoss(),
            **kwargs) -> float:
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    logger = MetricLogger()
    logger.create_meters(loss=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            loader_epoch = tqdm(loader_epoch, leave=False)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, **kwargs)
        with torch.no_grad():  # evaluation only; no gradients needed
            _output1: torch.Tensor = module1(_input)
            _output2: torch.Tensor = module2(_input)
            loss = criterion(_output1, _output2.softmax(1)).item()
        batch_size = int(_label.size(0))
        logger.update(n=batch_size, loss=loss)
    return logger.meters['loss'].global_avg
def validate(module: nn.Module, num_classes: int,
             loader: torch.utils.data.DataLoader,
             print_prefix: str = 'Validate', indent: int = 0,
             verbose: bool = True,
             get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
             forward_fn: Callable[..., torch.Tensor] = None,
             loss_fn: Callable[..., torch.Tensor] = None,
             writer=None, main_tag: str = 'valid',
             tag: str = '', _epoch: int = None,
             accuracy_fn: Callable[..., list[float]] = None,
             **kwargs) -> tuple[float, float]:
    r"""Evaluate the model.

    Returns:
        (float, float): Accuracy and loss.
    """
    module.eval()
    # the default get_data_fn must absorb the `mode` keyword passed below
    get_data_fn = get_data_fn or (lambda x, **kw: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    accuracy_fn = accuracy_fn or accuracy
    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        loader_epoch = logger.log_every(loader, header=header,
                                        tqdm_header='Batch', indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = forward_fn(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = accuracy_fn(_output, _label,
                                     num_classes=num_classes, topk=(1, 5))
        batch_size = int(_label.size(0))
        logger.update(n=batch_size, loss=loss, top1=acc1, top5=acc5)
    acc, loss = (logger.meters['top1'].global_avg,
                 logger.meters['loss'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc}, global_step=_epoch)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss}, global_step=_epoch)
    return acc, loss
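
# Usage sketch for `validate` (hypothetical `model` and `val_loader`; a
# 10-class task is assumed). Returns (top-1 accuracy, average loss) and
# reports running meters through MetricLogger.
acc, loss = validate(model, num_classes=10, loader=val_loader,
                     print_prefix='Clean Validate')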
def validate_mask_generator(self):
    loader = self.dataset.loader['valid']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)
    idx = torch.randperm(len(dataset))
    pos = 0

    print_prefix = 'Validate'
    header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
    header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
    for data in logger.log_every(loader, header=header):
        _input, _label = self.model.get_data(data)
        batch_size = len(_input)
        data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
        _input2, _label2 = self.model.get_data(data2)
        pos += batch_size

        _mask = self.get_mask(_input)
        _mask2 = self.get_mask(_input2)
        input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
        mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

        loss_div = input_dist.div(mask_dist).mean()
        loss_norm = _mask.sub(self.mask_density).relu().mean()
        loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
        logger.update(n=batch_size, loss=loss.item(),
                      div=loss_div.item(), norm=loss_norm.item())
def get_pred_labels(self) -> torch.Tensor:
    logger = MetricLogger(meter_length=40)
    str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
    logger.create_meters(cls_diff=str_format, jaccard_idx=str_format)
    test_set = TensorListDataset(self.test_input, self.test_label)
    test_loader = self.dataset.get_dataloader(mode='valid', dataset=test_set,
                                              batch_size=1)
    clean_list = []
    poison_list = []
    for data in logger.log_every(test_loader):
        _input: torch.Tensor = data[0]
        _input = _input.to(env['device'], non_blocking=True)
        trigger_input = self.attack.add_mark(_input)
        clean_list.append(self.get_pred_label(_input[0], logger=logger))
        poison_list.append(self.get_pred_label(trigger_input[0], logger=logger))
    return torch.as_tensor(clean_list + poison_list, dtype=torch.bool)
def train_mask_generator(self, verbose: bool = True):
    r"""Train :attr:`self.mask_generator`."""
    optimizer = torch.optim.Adam(self.mask_generator.parameters(),
                                 lr=1e-2, betas=(0.5, 0.9))
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.train_mask_epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)
    print_prefix = 'Mask Epoch'
    for _epoch in range(self.train_mask_epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        self.mask_generator.train()
        loader_epoch = logger.log_every(loader, header=header) if verbose else loader
        for data in loader_epoch:
            optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)
            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()
            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            loss.backward()
            optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), norm=loss_norm.item())
        lr_scheduler.step()
        self.mask_generator.eval()
        if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0
                        or _epoch == self.train_mask_epochs):
            self.validate_mask_generator()
    optimizer.zero_grad()
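
# The two regularizers above follow the input-aware ("dynamic") backdoor
# recipe: a diversity term that forces distinct inputs to receive distinct
# masks, and a norm term that caps mask density (eps = 1e-5 guards the
# division):
#
#   L_div  = mean_i( ||x_i - x'_i||_2 / (||m(x_i) - m(x'_i)||_2 + eps) )
#   L_norm = mean( relu(m(x) - mask_density) )
#   L      = lambda_norm * L_norm + lambda_div * L_div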
def compare(module1: nn.Module, module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate', indent=0, verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs) -> float:
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)

    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg
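
# Side note (not in the original helper): `cross_entropy(p, q)` above is the
# soft-label cross entropy H(softmax(p), softmax(q)). Since PyTorch 1.10,
# `F.cross_entropy` accepts probability targets directly, so an equivalent
# one-liner is
#
#   loss = F.cross_entropy(_output2, _output1.softmax(1))
#
# which matches the `criterion(_output1, _output2.softmax(1))` call in the
# other `compare` variant (with the roles of the two outputs swapped).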
def _validate(self, full=True, print_prefix='Validate', indent=0,
              verbose=True,
              loader: torch.utils.data.DataLoader = None,
              get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
              loss_fn: Callable[..., torch.Tensor] = None,
              writer=None, main_tag: str = 'valid',
              tag: str = '', _epoch: int = None,
              **kwargs) -> tuple[float, float]:
    self.eval()
    if loader is None:
        loader = self.dataset.loader['valid'] if full \
            else self.dataset.loader['valid2']
    get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
    loss_fn = loss_fn if loss_fn is not None else self.loss
    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    logger.meters['top1'] = SmoothedValue()
    logger.meters['top5'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = self(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
        batch_size = int(_label.size(0))
        logger.meters['loss'].update(loss, batch_size)
        logger.meters['top1'].update(acc1, batch_size)
        logger.meters['top5'].update(acc5, batch_size)
    loss, acc = (logger.meters['loss'].global_avg,
                 logger.meters['top1'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss}, global_step=_epoch)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc}, global_step=_epoch)
    return loss, acc
def _get_asr_result(self, marks: torch.Tensor) -> torch.Tensor:
    r"""Get attack succ rate result for each mark in :attr:`marks`.

    Args:
        marks (torch.Tensor): Marks tensor with shape ``(N, C, H, W)``.

    Returns:
        torch.Tensor: Attack succ rate tensor with shape ``(N)``.
    """
    asr_list = []
    logger = MetricLogger(meter_length=35, indent=4)
    logger.create_meters(asr='{median:.3f} ({min:.3f} {max:.3f})')
    for mark in logger.log_every(marks, header='mark', tqdm_header='mark'):
        self.mark.mark[:-1] = mark
        asr, _ = self.model._validate(get_data_fn=self.get_data,
                                      keep_org=False, poison_label=True,
                                      verbose=False, loader=self.loader_valid)
        # Original code considers an untargeted-like attack scenario.
        # org_acc, _ = self.model._validate(get_data_fn=self.get_data,
        #                                   keep_org=False, poison_label=False,
        #                                   verbose=False)
        # asr = 100 - org_acc
        logger.update(asr=asr)
        asr_list.append(asr)
    return torch.tensor(asr_list)
def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
           validate_interval: int = 1, save: bool = False,
           verbose: bool = True, **kwargs):
    if verbose:
        print('train mask generator')
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_()
    self.model.requires_grad_(False)
    self.train_mask_generator(verbose=verbose)
    if verbose:
        print()
        print('train mark generator and model')

    self.mark_generator.requires_grad_()
    self.mask_generator.requires_grad_(False)
    if not self.natural:
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        self.model.activate_params(params)

    mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(),
                                      lr=1e-2, betas=(0.5, 0.9))
    mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        mark_optimizer, T_max=epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, ce=None)

    if validate_interval != 0:
        best_validate_result = self.validate_fn(verbose=verbose)
        best_asr = best_validate_result[0]
    for _epoch in range(epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        if not self.natural:
            self.model.train()
        self.mark_generator.train()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            'Epoch', output_iter(_epoch, epochs), **ansi)
        header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
        loader_epoch = logger.log_every(loader, header=header) if verbose else loader
        for data in loader_epoch:
            if not self.natural:
                optimizer.zero_grad()
            mark_optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size
            final_input, final_label = _input.clone(), _label.clone()

            # generate trigger input
            trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
            trigger_int = int(trigger_int)
            if random.uniform(0, 1) < trigger_dec:
                trigger_int += 1
            x = _input[:trigger_int]
            trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
            trigger_input = x + trigger_mask * (trigger_mark - x)
            final_input[:trigger_int] = trigger_input
            final_label[:trigger_int] = self.target_class

            # generate cross input
            cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
            cross_int = int(cross_int)
            if random.uniform(0, 1) < cross_dec:
                cross_int += 1
            x = _input[trigger_int:trigger_int + cross_int]
            x2 = _input2[trigger_int:trigger_int + cross_int]
            cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
            cross_input = x + cross_mask * (cross_mark - x)
            final_input[trigger_int:trigger_int + cross_int] = cross_input

            # div loss
            if len(trigger_input) <= len(cross_input):
                length = len(trigger_input)
                cross_input = cross_input[:length]
                cross_mark = cross_mark[:length]
                cross_mask = cross_mask[:length]
            else:
                length = len(cross_input)
                trigger_input = trigger_input[:length]
                trigger_mark = trigger_mark[:length]
                trigger_mask = trigger_mask[:length]
            input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
            mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_ce = self.model.loss(final_input, final_label)
            loss_div = input_dist.div(mark_dist).mean()
            loss = loss_ce + self.lambda_div * loss_div
            loss.backward()
            if not self.natural:
                optimizer.step()
            mark_optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), ce=loss_ce.item())
        if not self.natural and lr_scheduler:
            lr_scheduler.step()
        mark_scheduler.step()
        if not self.natural:
            self.model.eval()
        self.mark_generator.eval()
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = self.validate_fn(verbose=verbose)
            cur_asr = validate_result[0]
            if cur_asr >= best_asr:
                best_validate_result = validate_result
                best_asr = cur_asr
                if save:
                    self.save()
    if not self.natural:
        optimizer.zero_grad()
    mark_optimizer.zero_grad()
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_(False)
    self.model.requires_grad_(False)
    return best_validate_result
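
# Aside (illustrative helper, not part of the original attack): the
# `math.modf` pattern above is unbiased stochastic rounding of the poisoned
# sample count, so E[trigger_int] equals len(_label) * poison_percent even
# when that product is fractional.
import math
import random

def stochastic_round(value: float) -> int:
    frac, whole = math.modf(value)    # e.g. 12.3 -> (0.3, 12.0)
    n = int(whole)
    if random.uniform(0, 1) < frac:   # round up with probability `frac`
        n += 1
    return n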
def optimize_mark(self, label: int,
                  loader: Iterable = None,
                  logger_header: str = '',
                  verbose: bool = True,
                  **kwargs) -> tuple[torch.Tensor, float]:
    r"""
    Args:
        label (int): The class label to optimize.
        loader (collections.abc.Iterable):
            Data loader to optimize trigger.
            Defaults to ``self.dataset.loader['train']``.
        logger_header (str): Header string of logger.
            Defaults to ``''``.
        verbose (bool): Whether to use logger for output.
            Defaults to ``True``.
        **kwargs: Keyword arguments passed to :meth:`loss()`.

    Returns:
        (torch.Tensor, float):
            Optimized mark tensor with shape ``(C + 1, H, W)``
            and the corresponding loss value.
    """
    atanh_mark = torch.randn_like(self.attack.mark.mark, requires_grad=True)
    optimizer = optim.Adam([atanh_mark], lr=self.defense_remask_lr,
                           betas=(0.5, 0.9))
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.defense_remask_epoch)
    optimizer.zero_grad()
    loader = loader or self.dataset.loader['train']

    # best optimization results
    norm_best: float = float('inf')
    mark_best: torch.Tensor = None
    loss_best: float = None

    logger = MetricLogger(indent=4)
    logger.create_meters(loss='{last_value:.3f}',
                         acc='{last_value:.3f}',
                         norm='{last_value:.3f}',
                         entropy='{last_value:.3f}')
    batch_logger = MetricLogger()
    batch_logger.create_meters(loss=None, acc=None, entropy=None)

    iterator = range(self.defense_remask_epoch)
    if verbose:
        iterator = logger.log_every(iterator, header=logger_header)
    for _ in iterator:
        batch_logger.reset()
        for data in loader:
            self.attack.mark.mark = tanh_func(atanh_mark)  # (c+1, h, w)
            _input, _label = self.model.get_data(data)
            trigger_input = self.attack.add_mark(_input)
            trigger_label = label * torch.ones_like(_label)
            trigger_output = self.model(trigger_input)

            batch_acc = trigger_label.eq(
                trigger_output.argmax(1)).float().mean()
            batch_entropy = self.loss(_input, _label,
                                      target=label,
                                      trigger_output=trigger_output,
                                      **kwargs)
            batch_norm: torch.Tensor = self.attack.mark.mark[-1].norm(p=1)
            batch_loss = batch_entropy + self.cost * batch_norm

            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch_size = _label.size(0)
            batch_logger.update(n=batch_size,
                                loss=batch_loss.item(),
                                acc=batch_acc.item(),
                                entropy=batch_entropy.item())
        lr_scheduler.step()
        self.attack.mark.mark = tanh_func(atanh_mark)  # (c+1, h, w)

        # check to save best mask or not
        loss = batch_logger.meters['loss'].global_avg
        acc = batch_logger.meters['acc'].global_avg
        norm = float(self.attack.mark.mark[-1].norm(p=1))
        entropy = batch_logger.meters['entropy'].global_avg
        if norm < norm_best:
            norm_best = norm
            mark_best = self.attack.mark.mark.detach().clone()
            loss_best = loss
        logger.update(loss=loss, acc=acc, norm=norm, entropy=entropy)

        if self.check_early_stop(loss=loss, acc=acc, norm=norm,
                                 entropy=entropy):
            print('early stop')
            break
    atanh_mark.requires_grad_(False)
    self.attack.mark.mark = mark_best
    return mark_best, loss_best
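
# Aside (assumed definition, for clarity): `tanh_func` is trojanzoo's usual
# reparameterization into (0, 1), along the lines of
#
#   def tanh_func(x: torch.Tensor) -> torch.Tensor:
#       return x.tanh().add(1).mul(0.5)   # R -> (0, 1), elementwise
#
# so Adam optimizes `atanh_mark` over an unconstrained space and the mark
# itself never needs projection back into the valid pixel range.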
def train(module: nn.Module, num_classes: int,
          epochs: int, optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train', start_epoch: int = 0, resume: int = 0,
          validate_interval: int = 10, save: bool = False, amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None, file_path: str = None,
          folder_path: str = None, suffix: str = None,
          writer=None, main_tag: str = 'train', tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True, output_freq: str = 'iter', indent: int = 0,
          change_train_eval: bool = True, lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model."""
    if epochs <= 0:
        return
    # defaults must absorb the extra keywords (mode, amp) passed below
    get_data_fn = get_data_fn or (lambda x, **kw: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kw:
                          F.cross_entropy(_output if _output is not None
                                          else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(module=module,
                                           num_classes=num_classes,
                                           loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn, writer=None,
                                           tag=tag, _epoch=start_epoch,
                                           verbose=verbose, indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs), header=header,
                                    tqdm_header='Epoch', indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epochs=epochs, start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train, header=header,
                                            tqdm_header='Batch', indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()
                if model_ema and i % model_ema_steps == 0:
                    model_ema.update_parameters(module)
                    if _epoch <= lr_warmup_epochs:
                        # Reset ema buffer to keep copying weights
                        # during warmup period
                        model_ema.n_averaged.fill_(0)
            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output, _label,
                                     num_classes=num_classes, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss),
                          top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer, tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose, indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} '
                           f'Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path, folder_path=folder_path,
                            suffix=suffix, verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
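
# Usage sketch for `train` (hypothetical `model`, `train_loader`,
# `val_loader`; `to_device` is a stand-in batch-transfer helper, and the
# `forward_fn` lambda drops the extra `amp`/`parallel` keywords that
# trojanzoo-style models would otherwise accept).
def to_device(data, mode=None, **kwargs):
    _input, _label = data
    return _input.to(env['device']), _label.to(env['device'])

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
best_acc, best_loss = train(model, num_classes=10, epochs=100,
                            optimizer=optimizer, lr_scheduler=scheduler,
                            loader_train=train_loader,
                            loader_valid=val_loader,
                            get_data_fn=to_device,
                            forward_fn=lambda _input, **kw: model(_input),
                            validate_interval=10)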
def _train(self, epoch: int, optimizer: Optimizer,
           lr_scheduler: _LRScheduler = None,
           grad_clip: float = None,
           print_prefix: str = 'Epoch', start_epoch: int = 0,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           epoch_fn: Callable[..., None] = None,
           get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           after_loss_fn: Callable[..., None] = None,
           validate_fn: Callable[..., tuple[float, float]] = None,
           save_fn: Callable[..., None] = None, file_path: str = None,
           folder_path: str = None, suffix: str = None,
           writer=None, main_tag: str = 'train', tag: str = '',
           verbose: bool = True, indent: int = 0, **kwargs):
    loader_train = loader_train if loader_train is not None \
        else self.dataset.loader['train']
    get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
    loss_fn = loss_fn if callable(loss_fn) else self.loss
    validate_fn = validate_fn if callable(validate_fn) else self._validate
    save_fn = save_fn if callable(save_fn) else self.save
    # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
    #     iter_fn = getattr(self, 'iter_fn')
    if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
        epoch_fn = getattr(self, 'epoch_fn')
    if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
        after_loss_fn = getattr(self, 'after_loss_fn')

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    _, best_acc = validate_fn(loader=loader_valid, get_data_fn=get_data_fn,
                              loss_fn=loss_fn, writer=None, tag=tag,
                              _epoch=start_epoch, verbose=verbose,
                              indent=indent, **kwargs)

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    total_iter = epoch * len(loader_train)
    for _epoch in range(epoch):
        _epoch += 1
        if callable(epoch_fn):
            self.activate_params([])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epoch=epoch, start_epoch=start_epoch)
            self.activate_params(params)
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader_train
        if verbose:
            header = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, epoch), **ansi)
            header = header.ljust(30 + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch, header=header,
                                            indent=indent)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len(loader_train) + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            _output = self(_input, amp=amp)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if amp:
                scaler.scale(loss).backward()
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(params, grad_clip)
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                    # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                optimizer.step()
            optimizer.zero_grad()
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(float(loss), batch_size)
            logger.meters['top1'].update(acc1, batch_size)
            logger.meters['top5'].update(acc5, batch_size)
            empty_cache()  # TODO: should it be outside of the dataloader loop?
        self.eval()
        self.activate_params([])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if _epoch % validate_interval == 0 or _epoch == epoch:
                _, cur_acc = validate_fn(loader=loader_valid,
                                         get_data_fn=get_data_fn,
                                         loss_fn=loss_fn, writer=writer,
                                         tag=tag, _epoch=_epoch + start_epoch,
                                         verbose=verbose, indent=indent,
                                         **kwargs)
                if cur_acc >= best_acc:
                    if verbose:
                        prints('{green}best result update!{reset}'.format(**ansi),
                               indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f} '
                               f'Previous Best Acc: {best_acc:.3f}',
                               indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    prints('-' * 50, indent=indent)
    self.zero_grad()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--voc_root', default='~/voc')
    parser.add_argument('--tar_path', default='~/reflection.tar')
    kwargs = parser.parse_args().__dict__
    voc_root: str = kwargs['voc_root']
    tar_path: str = kwargs['tar_path']

    print('get image paths')
    datasets = [torchvision.datasets.VOCDetection(voc_root, year=year,
                                                  image_set=image_set,
                                                  download=True)
                for year, image_set in sets]
    background_paths = get_img_paths(datasets,
                                     positive_class=background_class,
                                     negative_class=reflect_class)
    reflect_paths = get_img_paths(datasets,
                                  positive_class=reflect_class,
                                  negative_class=background_class)
    print()
    print('background: ', len(background_paths))
    print('reflect: ', len(reflect_paths))
    print()

    print('load images')
    reflect_imgs = [read_tensor(fp) for fp in reflect_paths]
    background_imgs = [read_tensor(fp)
                       for i, fp in enumerate(background_paths)
                       if i < NUM_ATTACK]

    print('writing tar file: ', tar_path)
    tf = tarfile.open(tar_path, mode='w')
    trojanzoo.environ.create(color=True, tqdm=True)
    logger = MetricLogger(meter_length=35)
    logger.create_meters(
        reflect_num='{count:3d}',
        reflect_mean='{global_avg:.3f} ({min:.3f} {max:.3f})',
        diff_mean='{global_avg:.3f} ({min:.3f} {max:.3f})',
        blended_max='{global_avg:.3f} ({min:.3f} {max:.3f})',
        ssim='{global_avg:.3f} ({min:.3f} {max:.3f})')
    candidates: set[int] = set()
    for background_img in logger.log_every(background_imgs):
        for i, reflect_img in enumerate(reflect_imgs):
            if i in candidates:
                continue
            blended, background_layer, reflection_layer = blend_images(
                background_img, reflect_img, ghost_rate=0.39)
            reflect_mean: float = reflection_layer.mean().item()
            diff_mean: float = (blended - reflection_layer).mean().item()
            blended_max: float = blended.max().item()
            logger.update(reflect_mean=reflect_mean, diff_mean=diff_mean,
                          blended_max=blended_max)
            if reflect_mean < 0.8 * diff_mean and blended_max > 0.1:
                ssim: float = skimage.metrics.structural_similarity(
                    blended.numpy(), background_layer.numpy(),
                    channel_axis=0)
                logger.update(ssim=ssim)
                if 0.7 < ssim < 0.85:
                    logger.update(reflect_num=1)
                    candidates.add(i)
                    filename = os.path.basename(reflect_paths[i])
                    bytes_io = io.BytesIO()
                    format = os.path.splitext(filename)[1][1:].lower().replace('jpg', 'jpeg')
                    F.to_pil_image(reflection_layer).save(bytes_io, format=format)
                    bytes_data = bytes_io.getvalue()
                    tarinfo = tarfile.TarInfo(name=filename)
                    tarinfo.size = len(bytes_data)
                    tf.addfile(tarinfo, io.BytesIO(bytes_data))
                    break
    tf.close()
def attack(self, epochs: int, save: bool = False, **kwargs):
    # train generator
    match self.generator_mode:
        case 'resnet':
            resnet_model = torchvision.models.resnet18(pretrained=True).to(device=env['device'])
            model_extractor = nn.Sequential(*list(resnet_model.children())[:5])
        case _:
            resnet_model = trojanvision.models.create('resnet18_comp',
                                                      dataset=self.dataset,
                                                      pretrained=True)
            model_extractor = nn.Sequential(*list(resnet_model._model.features.children())[:4])
    model_extractor.requires_grad_(False)
    model_extractor.train()

    self.generator.train()
    if self.generator_mode == 'default':
        self.generator.requires_grad_()
        parameters = self.generator.parameters()
    else:
        self.generator.bottleneck.requires_grad_()
        self.generator.decoder.requires_grad_()
        parameters = self.generator[1:].parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.train_generator_epochs, eta_min=1e-5)

    logger = MetricLogger()
    logger.create_meters(loss=None, acc=None)
    for _epoch in range(self.train_generator_epochs):
        _epoch += 1
        logger.reset()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            'Epoch', output_iter(_epoch, self.train_generator_epochs), **ansi)
        header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
        loader = logger.log_every(self.dataset.loader['train'],
                                  header=header, tqdm_header='Batch')
        for data in loader:
            optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            adv_input = (self.generator(_input) + 1) / 2
            _feats = model_extractor(_input)
            adv_feats = model_extractor(adv_input)
            loss = F.l1_loss(_feats, self.noise_coeff * adv_feats)
            loss.backward()
            optimizer.step()
            batch_size = len(_label)
            with torch.no_grad():
                org_class = self.model.get_class(_input)
                adv_class = self.model.get_class(adv_input)
                acc = (org_class == adv_class).float().sum().item() * 100.0 / batch_size
            logger.update(n=batch_size, loss=loss.item(), acc=acc)
        lr_scheduler.step()
    self.save_generator()
    optimizer.zero_grad()
    self.generator.eval()
    self.generator.requires_grad_(False)

    self.mark.mark[:-1] = (self.generator(self.mark.mark[:-1].unsqueeze(0))[0] + 1) / 2
    self.poison_set = self.get_poison_dataset(load_mark=False)
    return super().attack(epochs, save=save, **kwargs)
def gen_reflect_imgs(tar_path: str, voc_root: str, num_attack: int = 160,
                     reflect_class: set[str] = {'cat'},
                     background_class: set[str] = {'person'}):
    r"""Generate a tar file containing reflect images.

    Args:
        tar_path (str): Tar file path to save.
        voc_root (str): VOC dataset root path.
        num_attack (int): Number of reflect images to generate.
        reflect_class (set[str]): Set of reflect classes.
        background_class (set[str]): Set of background classes.
    """
    print('get image paths')
    if not os.path.isdir(voc_root):
        os.makedirs(voc_root)
    datasets = [torchvision.datasets.VOCDetection(voc_root, year=year,
                                                  image_set=image_set,
                                                  download=True)
                for year, image_set in sets]
    background_paths = _get_img_paths(datasets,
                                      positive_class=background_class,
                                      negative_class=reflect_class)
    reflect_paths = _get_img_paths(datasets,
                                   positive_class=reflect_class,
                                   negative_class=background_class)
    print()
    print('background: ', len(background_paths))
    print('reflect: ', len(reflect_paths))
    print()

    print('load images')
    reflect_imgs = [read_tensor(fp) for fp in reflect_paths]

    print('writing tar file: ', tar_path)
    tf = tarfile.open(tar_path, mode='w')
    logger = MetricLogger(meter_length=35)
    logger.create_meters(
        reflect_num=f'[ {{count:3d}} / {num_attack:3d} ]',
        reflect_mean='{global_avg:.3f} ({min:.3f} {max:.3f})',
        diff_mean='{global_avg:.3f} ({min:.3f} {max:.3f})',
        blended_max='{global_avg:.3f} ({min:.3f} {max:.3f})',
        ssim='{global_avg:.3f} ({min:.3f} {max:.3f})')
    candidates: set[int] = set()
    for fp in logger.log_every(background_paths):
        background_img = read_tensor(fp)
        for i, reflect_img in enumerate(reflect_imgs):
            if i in candidates:
                continue
            blended, background_layer, reflection_layer = blend_images(
                background_img, reflect_img, ghost_rate=0.39)
            reflect_mean: float = reflection_layer.mean().item()
            diff_mean: float = (blended - reflection_layer).mean().item()
            blended_max: float = blended.max().item()
            logger.update(reflect_mean=reflect_mean, diff_mean=diff_mean,
                          blended_max=blended_max)
            if reflect_mean < 0.8 * diff_mean and blended_max > 0.1:
                ssim: float = skimage.metrics.structural_similarity(
                    blended.numpy(), background_layer.numpy(),
                    channel_axis=0)
                logger.update(ssim=ssim)
                if 0.7 < ssim < 0.85:
                    logger.update(reflect_num=1)
                    candidates.add(i)
                    filename = os.path.basename(reflect_paths[i])
                    bytes_io = io.BytesIO()
                    format = os.path.splitext(filename)[1][1:].lower().replace('jpg', 'jpeg')
                    F.to_pil_image(reflection_layer).save(bytes_io, format=format)
                    bytes_data = bytes_io.getvalue()
                    tarinfo = tarfile.TarInfo(name=filename)
                    tarinfo.size = len(bytes_data)
                    tf.addfile(tarinfo, io.BytesIO(bytes_data))
                    break
        if len(candidates) == num_attack:
            break
    else:
        raise RuntimeError('Can not generate enough images')
    tf.close()
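
# Usage sketch (illustrative paths): build the reflection image tar used by
# the reflection-backdoor (Refool-style) attack from VOC 'cat' images
# blended over 'person' backgrounds.
gen_reflect_imgs(tar_path='./reflection.tar', voc_root='./data/voc',
                 num_attack=160)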