def optimize_mark(self, loss_fn=None, **kwargs):
    """Optimize the watermark pattern so marked inputs are classified as
    ``self.target_class``.

    The mark is parameterized in atanh-space (``atanh_mark``) and projected
    into valid pixel range via ``tanh_func`` after every optimizer step.

    Args:
        loss_fn: loss callable ``(input, label) -> Tensor``; defaults to
            ``self.model.loss``.
        **kwargs: ignored (accepted for interface compatibility).
    """
    # Unconstrained parameter; multiplying by the mask zeroes positions
    # outside the trigger region at initialization.
    atanh_mark = torch.randn_like(self.mark.mark) * self.mark.mask
    atanh_mark.requires_grad_()
    self.mark.mark = tanh_func(atanh_mark)
    optimizer = optim.Adam([atanh_mark], lr=self.inner_lr)
    optimizer.zero_grad()
    if loss_fn is None:
        loss_fn = self.model.loss
    losses = AverageMeter('Loss', ':.4e')
    for _epoch in range(self.inner_iter):
        for i, data in enumerate(self.dataset.loader['train']):
            if i > 20:  # cap each inner epoch at 21 batches
                break
            _input, _label = self.model.get_data(data)
            poison_x = self.mark.add_mark(_input)
            # All samples are pushed toward the target class.
            loss = loss_fn(poison_x, self.target_class * torch.ones_like(_label))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Re-project the updated atanh parameters back into mark space
            # so the next batch uses the new mark.
            self.mark.mark = tanh_func(atanh_mark)
            losses.update(loss.item(), n=len(_label))
    atanh_mark.requires_grad = False
    # Detach so the stored mark no longer carries an autograd graph.
    self.mark.mark.detach_()
def preprocess_mark(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]):
    """Pre-optimize the watermark on the ``'other'`` (non-target) samples by
    minimizing ``self.loss_mse`` over marked inputs.

    The mark is parameterized in atanh-space and projected with ``tanh_func``
    after every optimizer step.

    Args:
        data: dict with key ``'other'`` mapping to an ``(inputs, labels)``
            tuple; only the inputs are used here.
    """
    other_x, _ = data['other']
    other_set = TensorDataset(other_x)
    other_loader = self.dataset.get_dataloader(mode='train', dataset=other_set,
                                               num_workers=0)
    # Unconstrained parameter; the mask zeroes positions outside the trigger
    # region at initialization.
    atanh_mark = torch.randn_like(self.mark.mark) * self.mark.mask
    atanh_mark.requires_grad_()
    self.mark.mark = tanh_func(atanh_mark)
    optimizer = optim.Adam([atanh_mark], lr=self.preprocess_lr)
    optimizer.zero_grad()
    losses = AverageMeter('Loss', ':.4e')
    for _epoch in range(self.preprocess_epoch):
        loader = other_loader
        for (batch_x, ) in loader:
            poison_x = self.mark.add_mark(to_tensor(batch_x))
            loss = self.loss_mse(poison_x)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Refresh the mark from the updated atanh parameters before the
            # next batch is marked.
            self.mark.mark = tanh_func(atanh_mark)
            losses.update(loss.item(), n=len(batch_x))
    atanh_mark.requires_grad = False
    # Detach so the stored mark no longer carries an autograd graph.
    self.mark.mark.detach_()
def transform_func(self, x: torch.Tensor) -> torch.Tensor:
    """Map the raw tensor ``x`` through the configured input transform.

    ``self.input_transform`` is either a string naming a built-in squashing
    function (``'tanh'``, ``'atan'``/``'arctan'``, ``'sigmoid'``/``'logistic'``)
    or a callable applied directly to ``x``.

    Raises:
        NotImplementedError: if the string names an unknown transform.
    """
    transform = self.input_transform
    if not isinstance(transform, str):
        # assumed callable (see original assert note)
        return transform(x)
    if transform == 'tanh':
        return tanh_func(x)
    if transform in ('atan', 'arctan'):
        return atan_func(x)
    if transform in ('sigmoid', 'logistic'):
        return torch.sigmoid(x)
    raise NotImplementedError(self.input_transform)
def preprocess_mark(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]):
    """Pre-optimize the watermark on the ``'other'`` (non-target) samples by
    minimizing ``self.loss_mse`` over marked inputs.

    The mark is parameterized in atanh-space and projected with ``tanh_func``
    after every optimizer step.

    Args:
        data: dict with key ``'other'`` mapping to an ``(inputs, labels)``
            tuple; only the inputs are used here.
    """
    other_x, _ = data['other']
    other_set = TensorDataset(other_x)
    other_loader = self.dataset.get_dataloader(mode='train', dataset=other_set,
                                               num_workers=0)
    # Unconstrained parameter; the mask zeroes positions outside the trigger
    # region at initialization.
    atanh_mark = torch.randn_like(self.mark.mark) * self.mark.mask
    atanh_mark.requires_grad_()
    self.mark.mark = tanh_func(atanh_mark)
    optimizer = optim.Adam([atanh_mark], lr=self.preprocess_lr)
    optimizer.zero_grad()
    losses = AverageMeter('Loss', ':.4e')
    for _epoch in range(self.preprocess_epoch):
        # epoch_start = time.perf_counter()
        loader = other_loader
        # if env['tqdm']:
        #     loader = tqdm(loader)
        for (batch_x, ) in loader:
            poison_x = self.mark.add_mark(to_tensor(batch_x))
            loss = self.loss_mse(poison_x)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Refresh the mark from the updated atanh parameters before the
            # next batch is marked.
            self.mark.mark = tanh_func(atanh_mark)
            losses.update(loss.item(), n=len(batch_x))
        # epoch_time = str(datetime.timedelta(seconds=int(
        #     time.perf_counter() - epoch_start)))
        # pre_str = '{blue_light}Epoch: {0}{reset}'.format(
        #     output_iter(_epoch + 1, self.preprocess_epoch), **ansi).ljust(64 if env['color'] else 35)
        # _str = ' '.join([
        #     f'Loss: {losses.avg:.4f},'.ljust(20),
        #     f'Time: {epoch_time},'.ljust(20),
        # ])
        # prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '', indent=4)
    atanh_mark.requires_grad = False
    # Detach so the stored mark no longer carries an autograd graph.
    self.mark.mark.detach_()
def remask(self, label: int):
    """Reverse-engineer a trigger (mark + mask) that flips inputs to ``label``,
    Neural-Cleanse style: minimize classification entropy plus ``cost`` times
    the L1 norm of the mask, with dynamic cost adjustment and early stop.

    Fix: the recovered mask is stored on ``self.attack.mark.alpha_mask``
    (the blend-mask attribute also used by the sibling ``remask`` in this
    file), not the previous misspelling ``alpha_mark`` which silently created
    a stray attribute.

    Args:
        label: target class the trigger should induce.

    Returns:
        tuple ``(mark_best, mask_best, entropy_best)``.
    """
    epoch = self.epoch  # no bound
    # Optimize mark and mask in atanh-space; tanh_func projects them into
    # valid range after each step.
    atanh_mark = torch.randn(self.data_shape, device=env['device'])
    atanh_mark.requires_grad_()
    atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
    atanh_mask.requires_grad_()
    mask = tanh_func(atanh_mask)    # (h, w)
    mark = tanh_func(atanh_mark)    # (c, h, w)
    optimizer = optim.Adam([atanh_mark, atanh_mask], lr=0.1, betas=(0.5, 0.9))
    optimizer.zero_grad()
    cost = self.init_cost
    cost_set_counter = 0
    cost_up_counter = 0
    cost_down_counter = 0
    cost_up_flag = False
    cost_down_flag = False
    # best optimization results
    norm_best = float('inf')
    mask_best = None
    mark_best = None
    entropy_best = None
    # counter for early stop
    early_stop_counter = 0
    early_stop_norm_best = norm_best
    losses = AverageMeter('Loss', ':.4e')
    entropy = AverageMeter('Entropy', ':.4e')
    norm = AverageMeter('Norm', ':.4e')
    acc = AverageMeter('Acc', ':6.2f')
    for _epoch in range(epoch):
        losses.reset()
        entropy.reset()
        norm.reset()
        acc.reset()
        epoch_start = time.perf_counter()
        loader = self.dataset.loader['train']
        if env['tqdm']:
            loader = tqdm(loader)
        for data in loader:
            _input, _label = self.model.get_data(data)
            batch_size = _label.size(0)
            # Alpha-blend the candidate trigger onto the clean inputs.
            X = _input + mask * (mark - _input)
            Y = label * torch.ones_like(_label, dtype=torch.long)
            _output = self.model(X)
            batch_acc = Y.eq(_output.argmax(1)).float().mean()
            batch_entropy = self.loss_fn(_input, _label, Y, mask, mark, label)
            batch_norm = mask.norm(p=1)
            # Total objective: attack loss + sparsity penalty on the mask.
            batch_loss = batch_entropy + cost * batch_norm
            acc.update(batch_acc.item(), batch_size)
            entropy.update(batch_entropy.item(), batch_size)
            norm.update(batch_norm.item(), batch_size)
            losses.update(batch_loss.item(), batch_size)
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Re-project after the step so the next batch sees updated values.
            mask = tanh_func(atanh_mask)    # (h, w)
            mark = tanh_func(atanh_mark)    # (c, h, w)
        epoch_time = str(
            datetime.timedelta(seconds=int(time.perf_counter() - epoch_start)))
        pre_str = '{blue_light}Epoch: {0}{reset}'.format(
            output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
        _str = ' '.join([
            f'Loss: {losses.avg:.4f},'.ljust(20),
            f'Acc: {acc.avg:.2f}, '.ljust(20),
            f'Norm: {norm.avg:.4f},'.ljust(20),
            f'Entropy: {entropy.avg:.4f},'.ljust(20),
            f'Time: {epoch_time},'.ljust(20),
        ])
        prints(pre_str, _str, prefix='{upline}{clear_line}'.format(
            **ansi) if env['tqdm'] else '', indent=4)
        # check to save best mask or not
        if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
            mask_best = mask.detach()
            mark_best = mark.detach()
            norm_best = norm.avg
            entropy_best = entropy.avg
        # check early stop
        if self.early_stop:
            # only terminate if a valid attack has been found
            if norm_best < float('inf'):
                if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                    early_stop_counter += 1
                else:
                    early_stop_counter = 0
            early_stop_norm_best = min(norm_best, early_stop_norm_best)
            if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                print('early stop')
                break
        # check cost modification
        if cost == 0 and acc.avg >= self.attack_succ_threshold:
            cost_set_counter += 1
            if cost_set_counter >= self.patience:
                cost = self.init_cost
                cost_up_counter = 0
                cost_down_counter = 0
                cost_up_flag = False
                cost_down_flag = False
                print('initialize cost to %.2f' % cost)
        else:
            cost_set_counter = 0
        if acc.avg >= self.attack_succ_threshold:
            cost_up_counter += 1
            cost_down_counter = 0
        else:
            cost_up_counter = 0
            cost_down_counter += 1
        if cost_up_counter >= self.patience:
            cost_up_counter = 0
            prints('up cost from %.4f to %.4f' %
                   (cost, cost * self.cost_multiplier_up), indent=4)
            cost *= self.cost_multiplier_up
            cost_up_flag = True
        elif cost_down_counter >= self.patience:
            cost_down_counter = 0
            prints('down cost from %.4f to %.4f' %
                   (cost, cost / self.cost_multiplier_down), indent=4)
            cost /= self.cost_multiplier_down
            cost_down_flag = True
    if mask_best is None:
        # No epoch met the success threshold: fall back to the final values.
        mask_best = tanh_func(atanh_mask).detach()
        mark_best = tanh_func(atanh_mark).detach()
        norm_best = norm.avg
        entropy_best = entropy.avg
    atanh_mark.requires_grad = False
    atanh_mask.requires_grad = False
    self.attack.mark.mark = mark_best
    # BUGFIX: was `alpha_mark`, which set a stray attribute; `alpha_mask`
    # matches the attribute used by the other `remask` in this file.
    self.attack.mark.alpha_mask = mask_best
    self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
    self.attack.validate_fn()
    return mark_best, mask_best, entropy_best
def remask(self, _input: torch.Tensor, layer: str, neuron: int,
           label: int, use_mask: bool = True, validate_interval: int = 100,
           verbose: bool = False) -> tuple[torch.Tensor, torch.Tensor, float]:
    """Reverse-engineer a trigger that maximally activates one neuron
    (ABS-style), minimizing ``self.abs_loss`` over an atanh-parameterized
    mark (and optionally mask).

    Args:
        _input: batch of clean inputs to stamp the candidate trigger onto.
            # assumes shape (batch, c, h, w) matching self.data_shape -- TODO confirm
        layer: name of the layer whose neuron is being stimulated.
        neuron: index of the target neuron within ``layer``.
        label: class used to measure attack accuracy and to validate.
        use_mask: also optimize a blend mask; otherwise the mask stays all-ones.
        validate_interval: validate every N epochs (0 disables; only runs
            when ``verbose`` is True).
        verbose: print per-epoch statistics and run periodic validation.

    Returns:
        tuple ``(mark_best, mask_best, loss_best)``; ``mask_best`` is None
        when ``use_mask`` is False or no epoch improved the loss.
    """
    atanh_mark = torch.randn(self.data_shape, device=env['device'])
    atanh_mark.requires_grad_()
    parameters: list[torch.Tensor] = [atanh_mark]
    mask = torch.ones(self.data_shape[1:], device=env['device'])
    atanh_mask = torch.ones(self.data_shape[1:], device=env['device'])
    if use_mask:
        atanh_mask.requires_grad_()
        parameters.append(atanh_mask)
        mask = tanh_func(atanh_mask)    # (h, w)
    mark = tanh_func(atanh_mark)    # (c, h, w)
    # Smaller lr when the mask is fixed (only the mark is optimized).
    optimizer = optim.Adam(parameters,
                           lr=self.remask_lr if use_mask else 0.01 * self.remask_lr)
    optimizer.zero_grad()
    # best optimization results
    mark_best = None
    loss_best = float('inf')
    mask_best = None
    for _epoch in range(self.remask_epoch):
        epoch_start = time.perf_counter()
        loss = self.abs_loss(_input, mask, mark,
                             layer=layer, neuron=neuron, use_mask=use_mask)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Re-project the updated atanh parameters into valid range.
        mark = tanh_func(atanh_mark)    # (c, h, w)
        if use_mask:
            mask = tanh_func(atanh_mask)    # (h, w)
        with torch.no_grad():
            # Alpha-blend the candidate trigger and measure attack accuracy.
            X = _input + mask * (mark - _input)
            _output = self.model(X)
        acc = float(_output.argmax(dim=1).eq(label).float().mean()) * 100
        loss = float(loss)
        if verbose:
            norm = mask.norm(p=1)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() - epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, self.remask_epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {loss:10.3f},'.ljust(20),
                f'Acc: {acc:.3f}, '.ljust(20),
                f'Norm: {norm:.3f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str, indent=8)
        if loss < loss_best:
            loss_best = loss
            mark_best = mark
            if use_mask:
                mask_best = mask
        if validate_interval != 0 and verbose:
            if (_epoch + 1) % validate_interval == 0 or _epoch == self.remask_epoch - 1:
                # Temporarily install the current trigger on the attack
                # object so the shared validation path can evaluate it.
                self.attack.mark.mark = mark
                self.attack.mark.alpha_mask = mask
                self.attack.mark.mask = torch.ones_like(mark, dtype=torch.bool)
                self.attack.target_class = label
                self.model._validate(print_prefix='Validate Trigger Tgt',
                                     get_data_fn=self.attack.get_data,
                                     keep_org=False, indent=8)
                print()
    atanh_mark.requires_grad = False
    if use_mask:
        atanh_mask.requires_grad = False
    return mark_best, mask_best, loss_best