def preprocess_mark(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]):
    """Warm up ``self.mark.mark`` by minimizing ``self.loss_mse`` on the ``'other'`` inputs.

    The mark is reparameterized as ``tanh_func(latent)``. The latent starts as
    Gaussian noise multiplied by ``self.mark.mask`` and is trained with Adam
    (``lr=self.preprocess_lr``) for ``self.preprocess_epoch`` passes over a
    train-mode loader built from the ``'other'`` inputs. Afterwards the
    latent stops requiring gradients and ``self.mark.mark`` is detached in place.

    Args:
        data: mapping whose ``'other'`` entry is an ``(inputs, labels)`` pair;
            only the inputs are used.
    """
    inputs_other, _ = data['other']
    loader_other = self.dataset.get_dataloader(
        mode='train', dataset=TensorDataset(inputs_other), num_workers=0)

    latent = torch.randn_like(self.mark.mark) * self.mark.mask
    latent.requires_grad_()
    self.mark.mark = tanh_func(latent)
    adam = optim.Adam([latent], lr=self.preprocess_lr)
    adam.zero_grad()

    # Running loss; accumulated here but not reported by this method.
    meter = AverageMeter('Loss', ':.4e')
    for _ in range(self.preprocess_epoch):
        for (x_batch, ) in loader_other:
            batch_loss = self.loss_mse(self.mark.add_mark(to_tensor(x_batch)))
            batch_loss.backward()
            adam.step()
            adam.zero_grad()
            # Refresh the mark from the updated latent so the next batch uses it.
            self.mark.mark = tanh_func(latent)
            meter.update(batch_loss.item(), n=len(x_batch))

    latent.requires_grad = False
    self.mark.mark.detach_()
def optimize_mark(self, loss_fn=None, **kwargs):
    """Optimize ``self.mark.mark`` so marked training inputs are classified as ``self.target_class``.

    The mark is parameterized as ``tanh_func(latent)``, with ``latent``
    initialized to ``randn_like(mark) * mask`` and trained with Adam
    (``lr=self.inner_lr``). Each of the ``self.inner_iter`` outer iterations
    walks batch indices 0..20 of ``self.dataset.loader['train']``. Afterwards
    the latent is frozen and the mark detached in place.

    Args:
        loss_fn: ``loss_fn(inputs, labels)`` to minimize; defaults to
            ``self.model.loss``.
        **kwargs: accepted and ignored.
    """
    latent = torch.randn_like(self.mark.mark) * self.mark.mask
    latent.requires_grad_()
    self.mark.mark = tanh_func(latent)
    adam = optim.Adam([latent], lr=self.inner_lr)
    adam.zero_grad()
    criterion = loss_fn if loss_fn is not None else self.model.loss

    meter = AverageMeter('Loss', ':.4e')
    for _ in range(self.inner_iter):
        for batch_idx, batch in enumerate(self.dataset.loader['train']):
            # Only batch indices 0..20 are used (21 batches per iteration).
            if batch_idx >= 21:
                break
            inputs, labels = self.model.get_data(batch)
            marked = self.mark.add_mark(inputs)
            targets = self.target_class * torch.ones_like(labels)
            batch_loss = criterion(marked, targets)
            batch_loss.backward()
            adam.step()
            adam.zero_grad()
            self.mark.mark = tanh_func(latent)
            meter.update(batch_loss.item(), n=len(labels))

    latent.requires_grad = False
    self.mark.mark.detach_()
def discrim_train(self, epoch: int, D: nn.Sequential, discrim_loader: torch.utils.data.DataLoader):
    """Train discriminator ``D`` on the model's detached final feature maps.

    Args:
        epoch: number of passes over ``discrim_loader``.
        D: discriminator applied to ``self.model.get_final_fm(input)``.
        discrim_loader: loader whose batches are accepted by ``self.model.get_data``.

    The feature maps are detached, so gradients only reach ``D``. Each epoch
    calls ``self.model.activate_params([D.parameters()])`` before training and
    ``self.model.activate_params([])`` / ``D.eval()`` after it, and prints the
    epoch's average loss and top-1 accuracy.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    d_optimizer = optim.Adam(D.parameters(), lr=self.discrim_lr)
    d_optimizer.zero_grad()
    for _epoch in range(epoch):
        losses.reset()
        top1.reset()
        self.model.activate_params([D.parameters()])
        D.train()
        for data in discrim_loader:
            _input, _label = self.model.get_data(data)
            # Features are fixed inputs for D; no gradient flows into the model.
            out_f = self.model.get_final_fm(_input).detach()
            out_d = D(out_f)
            loss_d = self.model.criterion(out_d, _label)
            acc1 = self.model.accuracy(out_d, _label, topk=(1, ))[0]
            batch_size = int(_label.size(0))
            losses.update(loss_d.item(), batch_size)
            top1.update(acc1, batch_size)
            loss_d.backward()
            d_optimizer.step()
            d_optimizer.zero_grad()
        # 1-based epoch index so the final epoch reads "N / N" (was "N-1 / N").
        print(f'Discriminator - epoch {_epoch + 1:4d} / {epoch:4d} | loss {losses.avg:.4f} | acc {top1.avg:.4f}')
        self.model.activate_params([])
        D.eval()
def confirm_backdoor(self):
    """Return how often (in percent) adding the attack's mark changes the predicted class.

    For every validation batch, the model's class prediction on the clean
    input is compared with its prediction on ``self.attack.add_mark(input)``.
    The per-batch flip rate (0-100) is averaged, weighted by batch size.
    """
    flip_meter = AverageMeter('Acc@1', ':6.2f')
    valid_loader = self.dataset.loader['valid']
    for batch in valid_loader:
        clean, _ = self.model.get_data(batch, mode='valid')
        marked = self.attack.add_mark(clean)
        with torch.no_grad():
            clean_pred = self.model.get_class(clean)
            marked_pred = self.model.get_class(marked)
        changed = clean_pred.ne(marked_pred)
        rate = changed.float().sum() / changed.numel() * 100
        flip_meter.update(rate.item(), len(clean))
    return flip_meter.avg
def validate_confidence(self) -> float:
    """Average target-class probability on validation inputs that the mark successfully flips.

    Samples whose label already equals ``self.target_class`` are discarded.
    The rest are marked with ``self.add_mark``, and only those the model then
    assigns to the target class contribute. Their mean
    ``self.model.get_prob(...)[:, self.target_class]`` is averaged across
    batches, weighted by the number of contributing samples.
    """
    meter = AverageMeter('Confidence', ':.4e')
    target = self.target_class
    with torch.no_grad():
        for batch in self.dataset.loader['valid']:
            inputs, labels = self.model.get_data(batch)
            # Ignore samples that are already of the target class.
            inputs = inputs[labels != target]
            if len(inputs) == 0:
                continue
            marked = self.add_mark(inputs)
            # Keep only the samples where the trigger succeeded.
            marked = marked[self.model.get_class(marked) == target]
            if len(marked) == 0:
                continue
            probs = self.model.get_prob(marked)
            meter.update(probs[:, target].mean(), len(marked))
    return float(meter.avg)
def preprocess_mark(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]):
    """Pre-train the watermark on the ``'other'`` split before the attack proper.

    ``self.mark.mark`` is kept equal to ``tanh_func(w)`` for a latent ``w``
    initialized to ``randn_like(mark) * mask``. Each batch of ``'other'``
    inputs is marked and scored with ``self.loss_mse``, and ``w`` is updated
    by Adam (``lr=self.preprocess_lr``). This repeats for
    ``self.preprocess_epoch`` epochs. Afterwards ``w`` is frozen and the mark
    is detached in place.

    Args:
        data: mapping with an ``'other'`` key holding an ``(inputs, labels)``
            pair; the labels are not used.
    """
    x_other, _ = data['other']
    other_set = TensorDataset(x_other)
    dataloader = self.dataset.get_dataloader(mode='train', dataset=other_set, num_workers=0)

    w = torch.randn_like(self.mark.mark) * self.mark.mask
    w.requires_grad_()

    def refresh_mark():
        # Derive the mark from the unconstrained latent.
        self.mark.mark = tanh_func(w)

    refresh_mark()
    opt = optim.Adam([w], lr=self.preprocess_lr)
    opt.zero_grad()

    # Accumulated but not reported by this method.
    loss_meter = AverageMeter('Loss', ':.4e')
    for _ in range(self.preprocess_epoch):
        for (xb, ) in dataloader:
            marked = self.mark.add_mark(to_tensor(xb))
            mse = self.loss_mse(marked)
            mse.backward()
            opt.step()
            opt.zero_grad()
            refresh_mark()
            loss_meter.update(mse.item(), n=len(xb))

    w.requires_grad = False
    self.mark.mark.detach_()
def _validate(self, full=True, print_prefix='Validate', indent=0, verbose=True,
              loader: torch.utils.data.DataLoader = None,
              get_data_fn: Callable = None,
              loss_fn: Callable[..., float] = None,
              **kwargs) -> tuple[float, ...]:
    """Evaluate the model and return ``(avg_loss, top1_acc, top5_acc)``.

    Args:
        full: when no ``loader`` is given, use ``loader['valid']`` if true,
            else ``loader['valid2']``.
        print_prefix: label printed before the summary line.
        indent: indentation forwarded to ``prints``.
        verbose: wrap the loader in tqdm (when ``env['tqdm']``) and print a summary.
        loader: explicit evaluation loader.
        get_data_fn: ``get_data_fn(batch, mode='valid', **kwargs)``; default ``self.get_data``.
        loss_fn: ``loss_fn(input, label)``; default ``self.loss``.
        **kwargs: forwarded to ``get_data_fn``.
    """
    self.eval()
    if loader is None:
        split = 'valid' if full else 'valid2'
        loader = self.dataset.loader[split]
    if get_data_fn is None:
        get_data_fn = self.get_data
    if loss_fn is None:
        loss_fn = self.loss

    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    start = time.perf_counter()
    if verbose and env['tqdm']:
        loader = tqdm(loader)
    for batch in loader:
        inputs, labels = get_data_fn(batch, mode='valid', **kwargs)
        with torch.no_grad():
            loss = loss_fn(inputs, labels)
            logits = self.get_logits(inputs)
        n = int(labels.size(0))
        loss_meter.update(loss.item(), n)
        top1_acc, top5_acc = self.accuracy(logits, labels, topk=(1, 5))
        acc1_meter.update(top1_acc, n)
        acc5_meter.update(top5_acc, n)
    elapsed = str(datetime.timedelta(seconds=int(time.perf_counter() - start)))

    if verbose:
        header = '{yellow}{0}:{reset}'.format(print_prefix, **ansi).ljust(35)
        summary = ' '.join([
            f'Loss: {loss_meter.avg:.4f},'.ljust(20),
            f'Top1 Acc: {acc1_meter.avg:.3f}, '.ljust(20),
            f'Top5 Acc: {acc5_meter.avg:.3f},'.ljust(20),
            f'Time: {elapsed},'.ljust(20),
        ])
        line_prefix = '{upline}{clear_line}'.format(**ansi) if env['tqdm'] else ''
        prints(header, summary, prefix=line_prefix, indent=indent)
    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg
def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
              validate_interval=10, save=False, verbose=True, indent=0, epoch_fn: Callable = None, **kwargs):
    """Adversarially fine-tune the model while mixing in poisoned samples.

    For every training batch the model first takes one optimizer step on the
    clean batch. It then repeats ``self.pgd.iteration`` times: craft an
    adversarial batch with one call to ``self.pgd.optimize(..., iteration=1)``
    (reusing the same ``noise`` tensor; presumably updated in place by PGD,
    TODO confirm), then step on the concatenation of the adversarial batch and
    the poisoned batch from ``self.get_poison_data``.

    Args:
        epoch: number of training epochs.
        optimizer: optimizer to step; its param groups are activated on the
            model at the start of every epoch.
        lr_scheduler: optional scheduler stepped once per epoch.
        validate_interval: call ``self.validate_fn`` every this many epochs and
            after the last one; ``0`` disables periodic validation.
        save: call ``self.save()`` whenever the tracked validation score improves.
        verbose: print per-epoch statistics (tqdm progress if ``env['tqdm']``).
        indent: indentation forwarded to ``prints`` / ``validate_fn``.
        epoch_fn: optional callable run at the start of every epoch with all
            parameters deactivated.
        **kwargs: forwarded to ``self.validate_fn``.

    Note: the second value returned by ``self.validate_fn`` counts as an
    improvement when it DECREASES (presumably an attack success rate; confirm).
    """
    loader_train = self.dataset.loader['train']
    _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    params = [param_group['params'] for param_group in optimizer.param_groups]
    for _epoch in range(epoch):
        if callable(epoch_fn):
            self.model.activate_params([])
            epoch_fn()
        # Activate unconditionally: each epoch ends with activate_params([]),
        # so without this nothing is trainable from epoch 2 on unless epoch_fn
        # is given (mirrors ``_train``).
        self.model.activate_params(params)
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        # Wrap a per-epoch local: re-assigning ``loader_train`` itself nested
        # one more tqdm wrapper around the loader every epoch.
        loader = loader_train
        if verbose and env['tqdm']:
            loader = tqdm(loader_train)
        optimizer.zero_grad()
        for data in loader:
            _input, _label = self.model.get_data(data)
            noise = torch.zeros_like(_input)
            poison_input, poison_label = self.get_poison_data(data)

            def loss_fn(X: torch.FloatTensor):
                # Negated model loss on the true labels, handed to PGD.
                return -self.model.loss(X, _label)

            # One plain step on the clean batch.
            self.model.train()
            loss = self.model.loss(_input, _label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            for _ in range(self.pgd.iteration):
                self.model.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)
                # Discard parameter gradients produced while crafting adv_x.
                optimizer.zero_grad()
                self.model.train()
                x = torch.cat((adv_x, poison_input))
                y = torch.cat((_label, poison_label))
                loss = self.model.loss(x, y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                _output = self.model.get_logits(_input)
            acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.model.eval()
        self.model.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '', indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0 and ((_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1):
            _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
            if cur_acc < best_acc:
                prints('best result update!', indent=indent)
                prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}', indent=indent)
                best_acc = cur_acc
                if save:
                    self.save()
            if verbose:
                print('-' * 50)
    self.model.zero_grad()
def remask(self, label: int):
    """Reverse-engineer a candidate trigger (mark + mask) that sends inputs to ``label``.

    Latents for ``mask`` (shape ``data_shape[1:]``) and ``mark`` (shape
    ``data_shape``) are optimized with Adam so that
    ``X = input + mask * (mark - input)`` is classified as ``label``. The loss
    is ``self.loss_fn(...) + cost * ||mask||_1``. ``cost`` is adapted once per
    epoch: it is raised after ``self.patience`` epochs at or above
    ``self.attack_succ_threshold`` and lowered after ``self.patience`` epochs
    below it. Optional early stopping is controlled by ``self.early_stop``.

    Side effects: installs the chosen mark/mask on ``self.attack.mark`` and
    calls ``self.attack.validate_fn()``.

    Args:
        label: target class the recovered trigger should induce.

    Returns:
        ``(mark_best, mask_best, entropy_best)``: the detached mark and mask
        from the smallest-norm epoch that reached the success threshold, plus
        that epoch's average ``self.loss_fn`` value. If no epoch qualified, the
        final mark/mask and the last epoch's average are returned instead.
    """
    epoch = self.epoch
    # Optimize unbounded latents; mask/mark are derived through tanh_func.
    atanh_mark = torch.randn(self.data_shape, device=env['device'])
    atanh_mark.requires_grad_()
    atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
    atanh_mask.requires_grad_()
    mask = tanh_func(atanh_mask)    # (h, w)
    mark = tanh_func(atanh_mark)    # (c, h, w)

    optimizer = optim.Adam([atanh_mark, atanh_mask], lr=0.1, betas=(0.5, 0.9))
    optimizer.zero_grad()

    # Adaptive weight of the mask-norm term and its schedule bookkeeping.
    cost = self.init_cost
    cost_set_counter = 0
    cost_up_counter = 0
    cost_down_counter = 0
    cost_up_flag = False
    cost_down_flag = False

    # Best result among epochs that reached the attack-success threshold.
    norm_best = float('inf')
    mask_best = None
    mark_best = None
    entropy_best = None

    # Early-stop bookkeeping.
    early_stop_counter = 0
    early_stop_norm_best = norm_best

    losses = AverageMeter('Loss', ':.4e')
    entropy = AverageMeter('Entropy', ':.4e')
    norm = AverageMeter('Norm', ':.4e')
    acc = AverageMeter('Acc', ':6.2f')

    for _epoch in range(epoch):
        losses.reset()
        entropy.reset()
        norm.reset()
        acc.reset()
        epoch_start = time.perf_counter()
        loader = self.dataset.loader['train']
        if env['tqdm']:
            loader = tqdm(loader)
        for data in loader:
            _input, _label = self.model.get_data(data)
            batch_size = _label.size(0)
            # Blend the candidate mark into the input through the mask.
            X = _input + mask * (mark - _input)
            Y = label * torch.ones_like(_label, dtype=torch.long)
            _output = self.model(X)

            batch_acc = Y.eq(_output.argmax(1)).float().mean()
            batch_entropy = self.loss_fn(_input, _label, Y, mask, mark, label)
            batch_norm = mask.norm(p=1)
            batch_loss = batch_entropy + cost * batch_norm

            acc.update(batch_acc.item(), batch_size)
            entropy.update(batch_entropy.item(), batch_size)
            norm.update(batch_norm.item(), batch_size)
            losses.update(batch_loss.item(), batch_size)

            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            mask = tanh_func(atanh_mask)    # (h, w)
            mark = tanh_func(atanh_mark)    # (c, h, w)
        epoch_time = str(
            datetime.timedelta(seconds=int(time.perf_counter() - epoch_start)))
        pre_str = '{blue_light}Epoch: {0}{reset}'.format(
            output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
        _str = ' '.join([
            f'Loss: {losses.avg:.4f},'.ljust(20),
            f'Acc: {acc.avg:.2f}, '.ljust(20),
            f'Norm: {norm.avg:.4f},'.ljust(20),
            f'Entropy: {entropy.avg:.4f},'.ljust(20),
            f'Time: {epoch_time},'.ljust(20),
        ])
        prints(pre_str, _str, prefix='{upline}{clear_line}'.format(
            **ansi) if env['tqdm'] else '', indent=4)

        # check to save best mask or not
        if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
            mask_best = mask.detach()
            mark_best = mark.detach()
            norm_best = norm.avg
            entropy_best = entropy.avg

        # check early stop
        if self.early_stop:
            # only terminate if a valid attack has been found
            if norm_best < float('inf'):
                if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                    early_stop_counter += 1
                else:
                    early_stop_counter = 0
            early_stop_norm_best = min(norm_best, early_stop_norm_best)
            if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                print('early stop')
                break

        # check cost modification
        if cost == 0 and acc.avg >= self.attack_succ_threshold:
            cost_set_counter += 1
            if cost_set_counter >= self.patience:
                cost = self.init_cost
                cost_up_counter = 0
                cost_down_counter = 0
                cost_up_flag = False
                cost_down_flag = False
                print('initialize cost to %.2f' % cost)
        else:
            cost_set_counter = 0

        if acc.avg >= self.attack_succ_threshold:
            cost_up_counter += 1
            cost_down_counter = 0
        else:
            cost_up_counter = 0
            cost_down_counter += 1

        if cost_up_counter >= self.patience:
            cost_up_counter = 0
            prints('up cost from %.4f to %.4f' % (cost, cost * self.cost_multiplier_up), indent=4)
            cost *= self.cost_multiplier_up
            cost_up_flag = True
        elif cost_down_counter >= self.patience:
            cost_down_counter = 0
            prints('down cost from %.4f to %.4f' % (cost, cost / self.cost_multiplier_down), indent=4)
            cost /= self.cost_multiplier_down
            cost_down_flag = True

    if mask_best is None:
        # No epoch reached the threshold: fall back to the final trigger.
        mask_best = tanh_func(atanh_mask).detach()
        mark_best = tanh_func(atanh_mark).detach()
        entropy_best = entropy.avg
    atanh_mark.requires_grad = False
    atanh_mask.requires_grad = False

    self.attack.mark.mark = mark_best
    # Previously assigned ``alpha_mark`` (typo), so the recovered mask was never
    # installed; ``alpha_mask`` is the attribute used elsewhere for this.
    self.attack.mark.alpha_mask = mask_best
    self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
    self.attack.validate_fn()
    return mark_best, mask_best, entropy_best
def get_potential_triggers(self, neuron_dict: dict[int, list[dict]], _input: torch.Tensor, _label: torch.Tensor, use_mask=True) -> dict[int, list[dict]]:
    """Reverse-engineer and score a trigger for every candidate neuron.

    For each ``label``, candidates are processed in reverse list order. Each
    one is passed to ``self.remask``; the resulting trigger is installed on
    ``self.attack`` (with ``target_class = label``) and evaluated with
    ``self.model._validate``. Results are written back into the candidate
    dict (keys ``loss``, ``attack_acc``, ``attack_loss``, ``mask``, ``mark``,
    ``norm``, ``jaccard``). The candidate with the lowest
    ``score = attack_loss + 0.07 * ||mask||_1`` becomes the label's "best".
    Both dicts are written with ``np.save`` after every candidate.

    Args:
        neuron_dict: label -> list of candidate dicts with keys ``layer``,
            ``neuron`` and ``value``. Mutated in place.
        _input: forwarded to ``self.remask``.
        _label: not used by this method.
        use_mask: forwarded to ``self.remask``.

    Returns:
        ``neuron_dict`` itself, annotated in place.
    """
    losses = AverageMeter('Loss', ':.4e')
    norms = AverageMeter('Norm', ':6.2f')
    jaccard = AverageMeter('Jaccard Idx', ':6.2f')
    # NOTE(review): indexing by ``label`` assumes keys are 0..len(neuron_dict)-1.
    score_list = [0.0] * len(neuron_dict)
    result_dict = {}
    for label, label_list in neuron_dict.items():
        print('label: ', label)
        # Start at +inf so the best candidate is always recorded; the former
        # sentinel 100.0 left result_dict[label] unset (KeyError) when every
        # score was >= 100.
        best_score = float('inf')
        for _dict in reversed(label_list):
            layer = _dict['layer']
            neuron = _dict['neuron']
            value = _dict['value']
            mark, mask, loss = self.remask(_input, layer=layer, neuron=neuron, label=label, use_mask=use_mask)
            self.attack.mark.mark = mark
            self.attack.mark.alpha_mask = mask
            self.attack.mark.mask = torch.ones_like(mark, dtype=torch.bool)
            self.attack.target_class = label
            # ``_validate`` may return (loss, top1) or (loss, top1, top5);
            # only the first two values are needed here.
            attack_loss, attack_acc = self.model._validate(
                verbose=False, get_data_fn=self.attack.get_data, keep_org=False)[:2]
            mask_norm = float(mask.norm(p=1))
            _dict['loss'] = loss
            _dict['attack_acc'] = attack_acc
            _dict['attack_loss'] = attack_loss
            _dict['mask'] = to_numpy(mask)
            _dict['mark'] = to_numpy(mark)
            _dict['norm'] = mask_norm
            score = attack_loss + 7e-2 * mask_norm
            if score < best_score:
                best_score = score
                result_dict[label] = _dict
            if attack_acc > 90:
                losses.update(loss)
                norms.update(mask_norm)
            _str = f' layer: {layer:20s} neuron: {neuron:5d} value: {value:.3f}'
            _str += (f' loss: {loss:10.3f}'
                     f' ATK Acc: {attack_acc:.3f}'
                     f' ATK Loss: {attack_loss:10.3f}'
                     f' Norm: {mask_norm:.3f}'
                     f' Score: {score:.3f}')
            if not self.attack.mark.random_pos:
                overlap = jaccard_idx(mask, self.real_mask)
                _dict['jaccard'] = overlap
                _str += f' Jaccard: {overlap:.3f}'
                if attack_acc > 90:
                    jaccard.update(overlap)
            else:
                _dict['jaccard'] = 0.0
            print(_str)
            os.makedirs(self.folder_path, exist_ok=True)
            np.save(
                self.folder_path + self.get_filename(target_class=self.target_class) + '.npy', neuron_dict)
            np.save(
                self.folder_path + self.get_filename(target_class=self.target_class) + '_best.npy', result_dict)
        # Missing only when label_list is empty.
        if label in result_dict:
            best = result_dict[label]
            print(
                f'Label: {label:3d} loss: {best["loss"]:10.3f} ATK loss: {best["attack_loss"]:10.3f} Norm: {best["norm"]:10.3f} Jaccard: {best["jaccard"]:10.3f} Score: {best_score:.3f}'
            )
        score_list[label] = best_score
    print('Score: ', score_list)
    print('Score MAD: ', normalize_mad(score_list))
    return neuron_dict
def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           verbose: bool = True, indent: int = 0,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           validate_func: Callable[..., tuple[float, ...]] = None,
           epoch_func: Callable[[], None] = None,
           save_fn: Callable = None, file_path: str = None,
           folder_path: str = None, suffix: str = None, **kwargs):
    """Generic supervised training loop with periodic validation and checkpointing.

    Args:
        epoch: number of epochs to train.
        optimizer: optimizer to step; its param groups are activated on the
            model during every epoch.
        lr_scheduler: optional scheduler, stepped once per epoch.
        validate_interval: run ``validate_func`` every this many epochs and
            after the last one; ``0`` disables validation.
        save: when validation accuracy is >= the best so far, call
            ``save_fn(file_path=..., folder_path=..., suffix=..., verbose=...)``.
        amp: use a ``torch.cuda.amp.GradScaler`` when ``env['num_gpus']`` is
            truthy; ``loss_fn`` is then called with ``amp=True``.
        verbose: print per-epoch statistics; wrap the loader in tqdm when
            ``env['tqdm']`` is set.
        indent: indentation forwarded to ``prints`` and ``validate_func``.
        loader_train: training loader (default ``self.dataset.loader['train']``).
        loader_valid: forwarded to ``validate_func`` as ``loader``.
        get_data_fn: ``get_data_fn(batch, mode=...)``; default ``self.get_data``.
        loss_fn: ``loss_fn(input, label)`` to minimize; default ``self.loss``.
        validate_func: returns a 3-tuple whose middle element is the tracked
            accuracy; default ``self._validate``.
        epoch_func: optional hook run at the start of every epoch with all
            parameters deactivated.
        save_fn: checkpoint function; default ``self.save``.
        file_path, folder_path, suffix: forwarded to ``save_fn``.
        **kwargs: forwarded to ``validate_func``.
    """
    if loader_train is None:
        loader_train = self.dataset.loader['train']
    if get_data_fn is None:
        get_data_fn = self.get_data
    if loss_fn is None:
        loss_fn = self.loss
    if validate_func is None:
        validate_func = self._validate
    if save_fn is None:
        save_fn = self.save

    scaler: torch.cuda.amp.GradScaler = None
    if amp and env['num_gpus']:
        scaler = torch.cuda.amp.GradScaler()

    def run_validation():
        # Only the middle element (accuracy) of the 3-tuple is tracked.
        _, acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                  verbose=verbose, indent=indent, **kwargs)
        return acc

    best_acc = run_validation()

    loss_meter = AverageMeter('Loss')
    acc1_meter = AverageMeter('Acc@1')
    acc5_meter = AverageMeter('Acc@5')
    params: list[list[nn.Parameter]] = [group['params'] for group in optimizer.param_groups]
    for _epoch in range(epoch):
        if epoch_func is not None:
            self.activate_params([])
            epoch_func()
            self.activate_params(params)
        loss_meter.reset()
        acc1_meter.reset()
        acc5_meter.reset()
        epoch_start = time.perf_counter()
        loader = tqdm(loader_train) if verbose and env['tqdm'] else loader_train
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for batch in loader:
            inputs, labels = get_data_fn(batch, mode='train')
            if scaler is not None:
                loss = loss_fn(inputs, labels, amp=True)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = loss_fn(inputs, labels)
                loss.backward()
                optimizer.step()
            optimizer.zero_grad()
            with torch.no_grad():
                logits = self.get_logits(inputs)
            top1_acc, top5_acc = self.accuracy(logits, labels, topk=(1, 5))
            n = int(labels.size(0))
            loss_meter.update(loss.item(), n)
            acc1_meter.update(top1_acc, n)
            acc5_meter.update(top5_acc, n)
            empty_cache()
        elapsed = str(datetime.timedelta(seconds=int(time.perf_counter() - epoch_start)))
        self.eval()
        self.activate_params([])

        if verbose:
            header = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
            summary = ' '.join([
                f'Loss: {loss_meter.avg:.4f},'.ljust(20),
                f'Top1 Acc: {acc1_meter.avg:.3f}, '.ljust(20),
                f'Top5 Acc: {acc5_meter.avg:.3f},'.ljust(20),
                f'Time: {elapsed},'.ljust(20),
            ])
            line_prefix = '{upline}{clear_line}'.format(**ansi) if env['tqdm'] else ''
            prints(header, summary, prefix=line_prefix, indent=indent)
        if lr_scheduler:
            lr_scheduler.step()

        # ``validate_interval != 0`` is tested first so the modulo never divides by zero.
        if validate_interval != 0 and ((_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1):
            cur_acc = run_validation()
            if cur_acc >= best_acc:
                prints('best result update!', indent=indent)
                prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}', indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path, folder_path=folder_path, suffix=suffix, verbose=verbose)
            if verbose:
                print('-' * 50)
    self.zero_grad()
def remask(self, label: int) -> tuple[float, torch.Tensor]:
    """Train a conditional ``Generator`` to emit an additive trigger for ``label``.

    A fixed noise vector is drawn right after ``torch.manual_seed(env['seed'])``
    (which reseeds the global RNG). It is fed to the generator together with
    ``label``, and the generated ``mark`` is added to every batch from
    ``self.loader`` and clamped to ``[0, 1]``. The generator minimizes
    ``self.model.criterion(output, label) + self.gamma_2 * mean_L1(mark)`` for
    ``self.remask_epoch`` epochs. The resulting mark is then evaluated with
    ``self.model._validate``. If the attack's mark is not randomly positioned,
    it is also compared with ``self.real_mask`` via ``jaccard_idx``.

    Args:
        label: target class the trigger should induce.

    Returns:
        ``(loss, mark)``: the last epoch's average training loss and the mark
        generated for the last training batch (before the final optimizer
        step), detached from the generator's autograd graph. If no batch was
        processed, ``mark`` is all zeros with shape ``self.dataset.data_shape``.
    """
    generator = Generator(self.noise_dim, self.dataset.num_classes, self.dataset.data_shape)
    for param in generator.parameters():
        param.requires_grad_()
    optimizer = optim.Adam(generator.parameters(), lr=self.remask_lr)
    optimizer.zero_grad()

    losses = AverageMeter('Loss', ':.4e')
    entropy = AverageMeter('Entropy', ':.4e')
    norm = AverageMeter('Norm', ':.4e')
    acc = AverageMeter('Acc', ':6.2f')

    # Reseed so the fixed noise input is reproducible.
    torch.manual_seed(env['seed'])
    noise = torch.rand(1, self.noise_dim, device=env['device'])
    mark = torch.zeros(self.dataset.data_shape, device=env['device'])
    for _epoch in range(self.remask_epoch):
        losses.reset()
        entropy.reset()
        norm.reset()
        acc.reset()
        epoch_start = time.perf_counter()
        loader = self.loader
        if env['tqdm']:
            loader = tqdm(loader)
        for data in loader:
            _input, _label = self.model.get_data(data)
            batch_size = _label.size(0)
            poison_label = label * torch.ones_like(_label)
            mark = generator(
                noise, torch.tensor([label], device=poison_label.device, dtype=poison_label.dtype))
            poison_input = (_input + mark).clamp(0, 1)
            _output = self.model(poison_input)

            batch_acc = poison_label.eq(_output.argmax(1)).float().mean()
            batch_entropy = self.model.criterion(_output, poison_label)
            batch_norm = mark.flatten(start_dim=1).norm(p=1, dim=1).mean()
            batch_loss = batch_entropy + self.gamma_2 * batch_norm

            acc.update(batch_acc.item(), batch_size)
            entropy.update(batch_entropy.item(), batch_size)
            norm.update(batch_norm.item(), batch_size)
            losses.update(batch_loss.item(), batch_size)

            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        epoch_time = str(
            datetime.timedelta(seconds=int(time.perf_counter() - epoch_start)))
        pre_str = '{blue_light}Epoch: {0}{reset}'.format(
            output_iter(_epoch + 1, self.remask_epoch), **ansi).ljust(64 if env['color'] else 35)
        _str = ' '.join([
            f'Loss: {losses.avg:.4f},'.ljust(20),
            f'Acc: {acc.avg:.2f}, '.ljust(20),
            f'Norm: {norm.avg:.4f},'.ljust(20),
            f'Entropy: {entropy.avg:.4f},'.ljust(20),
            f'Time: {epoch_time},'.ljust(20),
        ])
        prints(pre_str, _str, prefix='{upline}{clear_line}'.format(
            **ansi) if env['tqdm'] else '', indent=4)

    # The mark is a fixed trigger from here on: drop its link to the
    # generator's autograd graph before validating and returning it.
    mark = mark.detach()

    def get_data_fn(data, **kwargs):
        _input, _label = self.model.get_data(data)
        poison_label = torch.ones_like(_label) * label
        poison_input = (_input + mark).clamp(0, 1)
        return poison_input, poison_label

    self.model._validate(print_prefix='Validate Trigger Tgt', get_data_fn=get_data_fn, indent=4)
    if not self.attack.mark.random_pos:
        overlap = jaccard_idx(mark.mean(dim=0), self.real_mask,
                              select_num=self.attack.mark.mark_height * self.attack.mark.mark_width)
        print(f' Jaccard index: {overlap:.3f}')

    for param in generator.parameters():
        param.requires_grad = False
    return losses.avg, mark