def discrim_train(self, epoch: int, D: nn.Sequential,
                  discrim_loader: torch.utils.data.DataLoader):
    r"""Train the discriminator ``D`` on detached final feature maps of ``self.model``."""
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    d_optimizer = optim.Adam(D.parameters(), lr=self.discrim_lr)
    d_optimizer.zero_grad()
    for _epoch in range(epoch):
        losses.reset()
        top1.reset()
        self.model.activate_params([D.parameters()])
        D.train()
        for data in discrim_loader:    # train D
            _input, _label = self.model.get_data(data)
            out_f = self.model.get_final_fm(_input).detach()    # freeze the feature extractor
            out_d = D(out_f)
            loss_d = self.model.criterion(out_d, _label)
            acc1 = self.model.accuracy(out_d, _label, topk=(1,))[0]
            batch_size = int(_label.size(0))
            losses.update(loss_d.item(), batch_size)
            top1.update(acc1, batch_size)
            loss_d.backward()
            d_optimizer.step()
            d_optimizer.zero_grad()
        print(f'Discriminator - epoch {_epoch:4d} / {epoch:4d} | '
              f'loss {losses.avg:.4f} | acc {top1.avg:.4f}')
    self.model.activate_params([])
    D.eval()
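# The sketch below is a minimal, self-contained version of the pattern in
# `discrim_train` above: the feature extractor is frozen (cf. `.detach()`),
# and only the discriminator head is optimized. `backbone`, `FEAT_DIM`, and
# `NUM_CLASSES` are illustrative assumptions, not names from this framework.
import torch
import torch.nn as nn
import torch.optim as optim

FEAT_DIM, NUM_CLASSES = 512, 10
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, FEAT_DIM))  # stand-in extractor
head = nn.Linear(FEAT_DIM, NUM_CLASSES)                                   # discriminator head

def fit_head(loader, epochs: int = 5, lr: float = 1e-3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(head.parameters(), lr=lr)
    head.train()
    for _ in range(epochs):
        for x, y in loader:
            with torch.no_grad():          # gradients never reach the backbone
                feats = backbone(x)
            loss = criterion(head(feats), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    head.eval()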
def adv_train(self, epoch: int, optimizer: optim.Optimizer,
              lr_scheduler: optim.lr_scheduler._LRScheduler = None,
              validate_interval: int = 10, save: bool = False,
              verbose: bool = True, indent: int = 0,
              epoch_fn: Callable = None, **kwargs):
    r"""Adversarially train ``self.model``: each batch alternates single PGD
    steps with model updates on the adversarial and poison samples."""
    loader_train = self.dataset.loader['train']
    file_path = self.folder_path + self.get_filename() + '.pth'
    _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    params = [param_group['params'] for param_group in optimizer.param_groups]
    for _epoch in range(epoch):
        if callable(epoch_fn):
            self.model.activate_params([])
            epoch_fn()
            self.model.activate_params(params)
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        loader = loader_train
        if verbose and env['tqdm']:
            loader = tqdm(loader_train)
        optimizer.zero_grad()
        for data in loader:
            _input, _label = self.model.get_data(data)
            noise = torch.zeros_like(_input)
            poison_input, poison_label = self.get_poison_data(data)

            def loss_fn(X: torch.FloatTensor):
                return -self.model.loss(X, _label)    # ascend the classification loss

            # first update on the (initially clean) batch
            adv_x = _input
            self.model.train()
            loss = self.model.loss(adv_x, _label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # alternate single PGD steps with model updates;
            # the perturbation state `noise` is carried across iterations
            for m in range(self.pgd.iteration):
                self.model.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                             loss_fn=loss_fn, iteration=1)
                optimizer.zero_grad()
                self.model.train()
                x = torch.cat((adv_x, poison_input))
                y = torch.cat((_label, poison_label))
                loss = self.model.loss(x, y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                _output = self.model.get_logits(_input)
            acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.model.eval()
        self.model.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi)
                   if env['tqdm'] else '', indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                if cur_acc >= best_acc:    # higher accuracy is better
                    prints('best result update!', indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                    best_acc = cur_acc
                    if save:
                        self.save()
                if verbose:
                    print('-' * 50)
    self.model.zero_grad()
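# A minimal PGD adversarial-training sketch in plain PyTorch, showing the
# inner-maximization / outer-minimization structure that `adv_train` delegates
# to the framework's `self.pgd.optimize` helper. `model`, `eps`, `alpha`, and
# the one-model-update-per-attack-step schedule are illustrative assumptions.
import torch
import torch.nn.functional as F

def pgd_step(model, x, y, x_adv, eps: float = 8 / 255, alpha: float = 2 / 255):
    x_adv = x_adv.detach().requires_grad_()
    loss = F.cross_entropy(model(x_adv), y)
    grad, = torch.autograd.grad(loss, x_adv)
    x_adv = x_adv + alpha * grad.sign()           # ascend the loss
    x_adv = x + (x_adv - x).clamp(-eps, eps)      # project into the eps-ball
    return x_adv.clamp(0, 1).detach()             # keep a valid image

def adv_train_batch(model, optimizer, x, y, iterations: int = 7):
    x_adv = x.clone()
    for _ in range(iterations):
        model.eval()                              # attack step on a frozen model
        x_adv = pgd_step(model, x, y, x_adv)
        model.train()                             # then one model update per PGD step
        loss = F.cross_entropy(model(x_adv), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()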
def remask(self, label: int):
    r"""Reconstruct a candidate trigger ``(mark, mask)`` for ``label`` by
    optimizing in unbounded atanh space, with a dynamically weighted
    L1-norm penalty on the mask."""
    epoch = self.epoch
    # no bound
    atanh_mark = torch.randn(self.data_shape, device=env['device'])
    atanh_mark.requires_grad_()
    atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
    atanh_mask.requires_grad_()
    mask = tanh_func(atanh_mask)    # (h, w)
    mark = tanh_func(atanh_mark)    # (c, h, w)

    optimizer = optim.Adam([atanh_mark, atanh_mask], lr=0.1, betas=(0.5, 0.9))
    optimizer.zero_grad()

    cost = self.init_cost
    cost_set_counter = 0
    cost_up_counter = 0
    cost_down_counter = 0
    cost_up_flag = False
    cost_down_flag = False

    # best optimization results
    norm_best = float('inf')
    mask_best = None
    mark_best = None
    entropy_best = None

    # counter for early stop
    early_stop_counter = 0
    early_stop_norm_best = norm_best

    losses = AverageMeter('Loss', ':.4e')
    entropy = AverageMeter('Entropy', ':.4e')
    norm = AverageMeter('Norm', ':.4e')
    acc = AverageMeter('Acc', ':6.2f')
    for _epoch in range(epoch):
        losses.reset()
        entropy.reset()
        norm.reset()
        acc.reset()
        epoch_start = time.perf_counter()
        loader = self.dataset.loader['train']
        if env['tqdm']:
            loader = tqdm(loader)
        for data in loader:
            _input, _label = self.model.get_data(data)
            batch_size = _label.size(0)
            X = _input + mask * (mark - _input)    # stamp the trigger
            Y = label * torch.ones_like(_label, dtype=torch.long)
            _output = self.model(X)

            batch_acc = Y.eq(_output.argmax(1)).float().mean()
            batch_entropy = self.loss_fn(_input, _label, Y, mask, mark, label)
            batch_norm = mask.norm(p=1)
            batch_loss = batch_entropy + cost * batch_norm

            acc.update(batch_acc.item(), batch_size)
            entropy.update(batch_entropy.item(), batch_size)
            norm.update(batch_norm.item(), batch_size)
            losses.update(batch_loss.item(), batch_size)

            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            mask = tanh_func(atanh_mask)    # (h, w)
            mark = tanh_func(atanh_mark)    # (c, h, w)
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        pre_str = '{blue_light}Epoch: {0}{reset}'.format(
            output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
        _str = ' '.join([
            f'Loss: {losses.avg:.4f},'.ljust(20),
            f'Acc: {acc.avg:.2f}, '.ljust(20),
            f'Norm: {norm.avg:.4f},'.ljust(20),
            f'Entropy: {entropy.avg:.4f},'.ljust(20),
            f'Time: {epoch_time},'.ljust(20),
        ])
        prints(pre_str, _str, prefix='{upline}{clear_line}'.format(
            **ansi) if env['tqdm'] else '', indent=4)

        # check to save best mask or not
        if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
            mask_best = mask.detach()
            mark_best = mark.detach()
            norm_best = norm.avg
            entropy_best = entropy.avg

        # check early stop
        if self.early_stop:
            # only terminate if a valid attack has been found
            if norm_best < float('inf'):
                if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                    early_stop_counter += 1
                else:
                    early_stop_counter = 0
            early_stop_norm_best = min(norm_best, early_stop_norm_best)
            if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                print('early stop')
                break

        # check cost modification
        if cost == 0 and acc.avg >= self.attack_succ_threshold:
            cost_set_counter += 1
            if cost_set_counter >= self.patience:
                cost = self.init_cost
                cost_up_counter = 0
                cost_down_counter = 0
                cost_up_flag = False
                cost_down_flag = False
                print(f'initialize cost to {cost:.2f}')
        else:
            cost_set_counter = 0

        if acc.avg >= self.attack_succ_threshold:
            cost_up_counter += 1
            cost_down_counter = 0
        else:
            cost_up_counter = 0
            cost_down_counter += 1
        if cost_up_counter >= self.patience:
            cost_up_counter = 0
            prints(f'up cost from {cost:.4f} to {cost * self.cost_multiplier_up:.4f}', indent=4)
            cost *= self.cost_multiplier_up
            cost_up_flag = True
        elif cost_down_counter >= self.patience:
            cost_down_counter = 0
            prints(f'down cost from {cost:.4f} to {cost / self.cost_multiplier_down:.4f}', indent=4)
            cost /= self.cost_multiplier_down
            cost_down_flag = True

    if mask_best is None:
        mask_best = tanh_func(atanh_mask).detach()
        mark_best = tanh_func(atanh_mark).detach()
        norm_best = norm.avg
        entropy_best = entropy.avg
    atanh_mark.requires_grad = False
    atanh_mask.requires_grad = False

    self.attack.mark.mark = mark_best
    self.attack.mark.alpha_mark = mask_best
    self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
    self.attack.validate_fn()
    return mark_best, mask_best, entropy_best
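# A self-contained sketch of the tanh reparameterization used in `remask`
# above: mask and mark live in unbounded atanh space and are squashed into
# (0, 1), so the stamped image X = input + mask * (mark - input) stays valid
# without explicit clipping. `tanh_func` is assumed to be the usual rescaled
# tanh; the loss below is a placeholder for entropy + cost * ||mask||_1.
import torch

def tanh_func(x: torch.Tensor) -> torch.Tensor:
    return (x.tanh() + 1) * 0.5            # map R -> (0, 1)

c, h, w = 3, 32, 32
atanh_mark = torch.randn(c, h, w, requires_grad=True)   # unbounded trigger pattern
atanh_mask = torch.randn(h, w, requires_grad=True)      # unbounded blending mask
optimizer = torch.optim.Adam([atanh_mark, atanh_mask], lr=0.1, betas=(0.5, 0.9))

x = torch.rand(8, c, h, w)                 # stand-in batch of images
mask, mark = tanh_func(atanh_mask), tanh_func(atanh_mark)
stamped = x + mask * (mark - x)            # blend the trigger into each image
loss = stamped.sum() + mask.norm(p=1)      # placeholder objective
loss.backward()
optimizer.step()
optimizer.zero_grad()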
def _train(self, epoch: int, optimizer: Optimizer,
           lr_scheduler: _LRScheduler = None,
           validate_interval: int = 10, save: bool = False,
           amp: bool = False, verbose: bool = True, indent: int = 0,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           validate_func: Callable[..., tuple[float, ...]] = None,
           epoch_func: Callable[[], None] = None,
           save_fn: Callable = None, file_path: str = None,
           folder_path: str = None, suffix: str = None, **kwargs):
    r"""Generic training loop with optional mixed precision (``amp``),
    periodic validation, and checkpointing of the best model."""
    loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
    get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
    loss_fn = loss_fn if loss_fn is not None else self.loss
    validate_func = validate_func if validate_func is not None else self._validate
    save_fn = save_fn if save_fn is not None else self.save

    scaler: torch.cuda.amp.GradScaler = None
    if amp and env['num_gpus']:
        scaler = torch.cuda.amp.GradScaler()
    _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn,
                                   loss_fn=loss_fn, verbose=verbose, indent=indent, **kwargs)

    losses = AverageMeter('Loss')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')
    params: list[list[nn.Parameter]] = [param_group['params']
                                        for param_group in optimizer.param_groups]
    for _epoch in range(epoch):
        if epoch_func is not None:
            self.activate_params([])
            epoch_func()
            self.activate_params(params)
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        loader = loader_train
        if verbose and env['tqdm']:
            loader = tqdm(loader_train)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for data in loader:
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if amp and env['num_gpus']:
                loss = loss_fn(_input, _label, amp=True)
                scaler.scale(loss).backward()    # scale to avoid fp16 gradient underflow
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = loss_fn(_input, _label)
                loss.backward()
                optimizer.step()
            optimizer.zero_grad()
            with torch.no_grad():
                _output = self.get_logits(_input)
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
            empty_cache()
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.eval()
        self.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
                f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi)
                   if env['tqdm'] else '', indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                _, cur_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn,
                                              loss_fn=loss_fn, verbose=verbose,
                                              indent=indent, **kwargs)
                if cur_acc >= best_acc:
                    prints('best result update!', indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    print('-' * 50)
    self.zero_grad()
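# A minimal sketch of the mixed-precision branch above, following the
# standard torch.cuda.amp recipe: forward under autocast, scale the loss
# before backward, and let GradScaler manage the optimizer step. `model`
# and `loader` are illustrative assumptions.
import torch
import torch.nn as nn

def train_amp(model: nn.Module, optimizer: torch.optim.Optimizer, loader,
              device: str = 'cuda'):
    scaler = torch.cuda.amp.GradScaler()
    criterion = nn.CrossEntropyLoss()
    model.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():      # run the forward pass in reduced precision
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()        # scale to avoid fp16 gradient underflow
        scaler.step(optimizer)               # unscales grads; skips the step on inf/NaN
        scaler.update()
        optimizer.zero_grad()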
def remask(self, label: int) -> tuple[float, torch.Tensor]:
    r"""Train a conditional generator to produce a trigger mark for ``label``
    and return the average loss together with the final mark."""
    generator = Generator(self.noise_dim, self.dataset.num_classes, self.dataset.data_shape)
    for param in generator.parameters():
        param.requires_grad_()
    optimizer = optim.Adam(generator.parameters(), lr=self.remask_lr)
    optimizer.zero_grad()
    # mask = self.attack.mark.mask

    losses = AverageMeter('Loss', ':.4e')
    entropy = AverageMeter('Entropy', ':.4e')
    norm = AverageMeter('Norm', ':.4e')
    acc = AverageMeter('Acc', ':6.2f')

    torch.manual_seed(env['seed'])
    noise = torch.rand(1, self.noise_dim, device=env['device'])
    mark = torch.zeros(self.dataset.data_shape, device=env['device'])
    for _epoch in range(self.remask_epoch):
        losses.reset()
        entropy.reset()
        norm.reset()
        acc.reset()
        epoch_start = time.perf_counter()
        loader = self.loader
        if env['tqdm']:
            loader = tqdm(loader)
        for data in loader:
            _input, _label = self.model.get_data(data)
            batch_size = _label.size(0)
            poison_label = label * torch.ones_like(_label)
            mark = generator(noise, torch.tensor([label], device=poison_label.device,
                                                 dtype=poison_label.dtype))
            poison_input = (_input + mark).clamp(0, 1)    # stamp and keep a valid image
            _output = self.model(poison_input)

            batch_acc = poison_label.eq(_output.argmax(1)).float().mean()
            batch_entropy = self.model.criterion(_output, poison_label)
            batch_norm = mark.flatten(start_dim=1).norm(p=1, dim=1).mean()
            batch_loss = batch_entropy + self.gamma_2 * batch_norm

            acc.update(batch_acc.item(), batch_size)
            entropy.update(batch_entropy.item(), batch_size)
            norm.update(batch_norm.item(), batch_size)
            losses.update(batch_loss.item(), batch_size)

            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        pre_str = '{blue_light}Epoch: {0}{reset}'.format(
            output_iter(_epoch + 1, self.remask_epoch), **ansi).ljust(64 if env['color'] else 35)
        _str = ' '.join([
            f'Loss: {losses.avg:.4f},'.ljust(20),
            f'Acc: {acc.avg:.2f}, '.ljust(20),
            f'Norm: {norm.avg:.4f},'.ljust(20),
            f'Entropy: {entropy.avg:.4f},'.ljust(20),
            f'Time: {epoch_time},'.ljust(20),
        ])
        prints(pre_str, _str, prefix='{upline}{clear_line}'.format(
            **ansi) if env['tqdm'] else '', indent=4)

    def get_data_fn(data, **kwargs):
        _input, _label = self.model.get_data(data)
        poison_label = torch.ones_like(_label) * label
        poison_input = (_input + mark).clamp(0, 1)
        return poison_input, poison_label

    self.model._validate(print_prefix='Validate Trigger Tgt',
                         get_data_fn=get_data_fn, indent=4)
    if not self.attack.mark.random_pos:
        overlap = jaccard_idx(mark.mean(dim=0), self.real_mask,
                              select_num=self.attack.mark.mark_height * self.attack.mark.mark_width)
        print(f'  Jaccard index: {overlap:.3f}')
    for param in generator.parameters():
        param.requires_grad = False
    return losses.avg, mark
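# A minimal sketch of the conditional trigger generator optimized above: it
# maps (noise, target label) to an additive mark, and training trades off the
# target-class cross-entropy against an L1 penalty (the `gamma_2` term). The
# `TriggerGenerator` architecture here is an illustrative assumption, not the
# framework's `Generator`.
import math
import torch
import torch.nn as nn

class TriggerGenerator(nn.Module):
    def __init__(self, noise_dim: int, num_classes: int, data_shape: tuple):
        super().__init__()
        self.embed = nn.Embedding(num_classes, noise_dim)    # condition on the target label
        self.net = nn.Sequential(
            nn.Linear(2 * noise_dim, 128), nn.ReLU(),
            nn.Linear(128, math.prod(data_shape)), nn.Sigmoid())
        self.data_shape = tuple(data_shape)

    def forward(self, noise: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        z = torch.cat((noise, self.embed(label)), dim=1)
        return self.net(z).view(-1, *self.data_shape)

noise_dim, num_classes, shape = 100, 10, (3, 32, 32)
G = TriggerGenerator(noise_dim, num_classes, shape)
noise = torch.rand(1, noise_dim)
mark = G(noise, torch.tensor([0]))           # trigger mark for target class 0
x = torch.rand(8, *shape)
poison_x = (x + mark).clamp(0, 1)            # stamp and keep a valid image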