Ejemplo n.º 1
0
    def optimize_mark(self, loss_fn=None, **kwargs):
        """Re-fit the watermark from a fresh random start on poisoned batches.

        A new unconstrained tensor is sampled (Gaussian noise multiplied by
        ``self.mark.mask``), mapped through ``tanh_func`` to obtain
        ``self.mark.mark``, and tuned with Adam so that inputs stamped via
        ``self.mark.add_mark`` minimise ``loss_fn`` against
        ``self.target_class``.

        Args:
            loss_fn: Callable invoked as ``loss_fn(poison_x, target_labels)``.
                Falls back to ``self.model.loss`` when ``None``.
            **kwargs: Accepted but not used by this method.
        """
        raw_mark = torch.randn_like(self.mark.mark) * self.mark.mask
        raw_mark.requires_grad_()
        self.mark.mark = tanh_func(raw_mark)
        adam = optim.Adam([raw_mark], lr=self.inner_lr)
        adam.zero_grad()

        criterion = self.model.loss if loss_fn is None else loss_fn

        # Running loss is accumulated here but never reported by this method.
        loss_meter = AverageMeter('Loss', ':.4e')
        last_batch_idx = 20  # batches with a larger index are not used
        for _ in range(self.inner_iter):
            for batch_idx, batch in enumerate(self.dataset.loader['train']):
                if batch_idx > last_batch_idx:
                    break
                inputs, labels = self.model.get_data(batch)
                stamped = self.mark.add_mark(inputs)
                # Every sample is relabelled with the attack's target class.
                target = self.target_class * torch.ones_like(labels)
                batch_loss = criterion(stamped, target)
                batch_loss.backward()
                adam.step()
                adam.zero_grad()
                # Rebuild the mark so the next batch uses the updated values.
                self.mark.mark = tanh_func(raw_mark)
                loss_meter.update(batch_loss.item(), n=len(labels))
        # Freeze the result: no further gradients flow into the mark.
        raw_mark.requires_grad = False
        self.mark.mark.detach_()
Ejemplo n.º 2
0
    def preprocess_mark(self, data: dict[str, tuple[torch.Tensor,
                                                    torch.Tensor]]):
        """Fit the watermark on the ``'other'`` inputs using ``self.loss_mse``.

        The first element of ``data['other']`` is wrapped in a
        ``TensorDataset`` and served through ``self.dataset.get_dataloader``.
        A fresh unconstrained tensor (Gaussian noise multiplied by
        ``self.mark.mask``) is mapped through ``tanh_func`` to produce
        ``self.mark.mark`` and tuned with Adam for ``self.preprocess_epoch``
        passes over that loader.

        Args:
            data: Mapping whose ``'other'`` entry is a pair; only its first
                element (the inputs) is used.
        """
        inputs, _ = data['other']
        loader = self.dataset.get_dataloader(mode='train',
                                             dataset=TensorDataset(inputs),
                                             num_workers=0)

        raw_mark = torch.randn_like(self.mark.mark) * self.mark.mask
        raw_mark.requires_grad_()
        self.mark.mark = tanh_func(raw_mark)
        adam = optim.Adam([raw_mark], lr=self.preprocess_lr)
        adam.zero_grad()

        # Running loss is tracked but not reported anywhere in this method.
        meter = AverageMeter('Loss', ':.4e')
        for _ in range(self.preprocess_epoch):
            for batch in loader:
                (batch_inputs,) = batch
                stamped = self.mark.add_mark(to_tensor(batch_inputs))
                mse = self.loss_mse(stamped)
                mse.backward()
                adam.step()
                adam.zero_grad()
                # Rebuild the mark so the next batch uses the updated values.
                self.mark.mark = tanh_func(raw_mark)
                meter.update(mse.item(), n=len(batch_inputs))
        # Freeze the result: no further gradients flow into the mark.
        raw_mark.requires_grad = False
        self.mark.mark.detach_()
Ejemplo n.º 3
0
 def transform_func(self, x: torch.Tensor) -> torch.Tensor:
     """Apply the configured input transform to ``x``.

     ``self.input_transform`` is either a name selecting one of the built-in
     transforms (``'tanh'``; ``'atan'``/``'arctan'``;
     ``'sigmoid'``/``'logistic'``) or an object that is called directly
     with ``x``.

     Raises:
         NotImplementedError: If the transform is a string that is not one
             of the recognised names.
     """
     transform = self.input_transform
     # Anything that is not a string is invoked as the transform itself.
     if not isinstance(transform, str):
         return transform(x)
     if transform == 'tanh':
         return tanh_func(x)
     if transform in ('atan', 'arctan'):
         return atan_func(x)
     if transform in ('sigmoid', 'logistic'):
         return torch.sigmoid(x)
     raise NotImplementedError(transform)
Ejemplo n.º 4
0
    def preprocess_mark(self, data: dict[str, tuple[torch.Tensor,
                                                    torch.Tensor]]):
        """Optimise ``self.mark.mark`` against ``self.loss_mse`` on ``data['other']``.

        Only the inputs (first element of the ``'other'`` pair) are used.
        They are batched via ``self.dataset.get_dataloader``; each batch is
        stamped with the current mark through ``self.mark.add_mark`` and the
        resulting ``self.loss_mse`` value is minimised with Adam over
        ``self.preprocess_epoch`` epochs. The mark is re-initialised from
        masked Gaussian noise on every call.

        Args:
            data: Mapping with an ``'other'`` entry holding a two-element
                tuple; the second element is ignored.
        """
        other_inputs, _ignored = data['other']
        dataset_other = TensorDataset(other_inputs)
        batches = self.dataset.get_dataloader(mode='train',
                                              dataset=dataset_other,
                                              num_workers=0)

        # Unconstrained parameter; tanh_func maps it to the actual mark.
        free_mark = torch.randn_like(self.mark.mark) * self.mark.mask
        free_mark.requires_grad_()
        self.mark.mark = tanh_func(free_mark)
        opt = optim.Adam([free_mark], lr=self.preprocess_lr)
        opt.zero_grad()

        # The meter collects the running loss; per-epoch progress output is
        # not emitted by this method.
        loss_meter = AverageMeter('Loss', ':.4e')
        for _ in range(self.preprocess_epoch):
            for (chunk, ) in batches:
                marked = self.mark.add_mark(to_tensor(chunk))
                step_loss = self.loss_mse(marked)
                step_loss.backward()
                opt.step()
                opt.zero_grad()
                # Recompute the mark from the freshly updated parameter.
                self.mark.mark = tanh_func(free_mark)
                loss_meter.update(step_loss.item(), n=len(chunk))
        # Stop tracking gradients once optimisation is finished.
        free_mark.requires_grad = False
        self.mark.mark.detach_()
Ejemplo n.º 5
0
    def remask(self, label: int):
        """Optimise a trigger (mark and mask) that pushes inputs towards ``label``.

        Two unconstrained tensors are drawn from a standard normal and mapped
        through ``tanh_func`` to give the mark (shape ``self.data_shape``) and
        the mask (shape ``self.data_shape[1:]``). Adam updates both over the
        training loader to minimise ``self.loss_fn(...) + cost * ||mask||_1``.
        After every epoch ``cost`` is raised or lowered depending on whether
        the epoch's average attack accuracy reaches
        ``self.attack_succ_threshold``.

        Side effects: the best mark and mask are written into
        ``self.attack.mark`` and ``self.attack.validate_fn()`` is called.

        Args:
            label: Target class assigned to every sample during optimisation.

        Returns:
            ``(mark_best, mask_best, entropy_best)``: the detached best mark
            and mask, and the epoch-average of ``self.loss_fn`` recorded when
            they were stored.
        """
        epoch = self.epoch
        # no bound
        atanh_mark = torch.randn(self.data_shape, device=env['device'])
        atanh_mark.requires_grad_()
        atanh_mask = torch.randn(self.data_shape[1:], device=env['device'])
        atanh_mask.requires_grad_()
        mask = tanh_func(atanh_mask)  # (h, w)
        mark = tanh_func(atanh_mark)  # (c, h, w)

        optimizer = optim.Adam([atanh_mark, atanh_mask],
                               lr=0.1,
                               betas=(0.5, 0.9))
        optimizer.zero_grad()

        # Weight of the mask-norm penalty and the bookkeeping that adapts it.
        cost = self.init_cost
        cost_set_counter = 0
        cost_up_counter = 0
        cost_down_counter = 0
        cost_up_flag = False
        cost_down_flag = False

        # best optimization results
        norm_best = float('inf')
        mask_best = None
        mark_best = None
        entropy_best = None

        # counter for early stop
        early_stop_counter = 0
        early_stop_norm_best = norm_best

        losses = AverageMeter('Loss', ':.4e')
        entropy = AverageMeter('Entropy', ':.4e')
        norm = AverageMeter('Norm', ':.4e')
        acc = AverageMeter('Acc', ':6.2f')

        for _epoch in range(epoch):
            losses.reset()
            entropy.reset()
            norm.reset()
            acc.reset()
            epoch_start = time.perf_counter()
            loader = self.dataset.loader['train']
            if env['tqdm']:
                loader = tqdm(loader)
            for data in loader:
                _input, _label = self.model.get_data(data)
                batch_size = _label.size(0)
                # Blend the mark into the input, weighted by the mask.
                X = _input + mask * (mark - _input)
                Y = label * torch.ones_like(_label, dtype=torch.long)
                _output = self.model(X)

                batch_acc = Y.eq(_output.argmax(1)).float().mean()
                batch_entropy = self.loss_fn(_input, _label, Y, mask, mark,
                                             label)
                batch_norm = mask.norm(p=1)
                batch_loss = batch_entropy + cost * batch_norm

                acc.update(batch_acc.item(), batch_size)
                entropy.update(batch_entropy.item(), batch_size)
                norm.update(batch_norm.item(), batch_size)
                losses.update(batch_loss.item(), batch_size)

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # Recompute mask/mark from the updated parameters.
                mask = tanh_func(atanh_mask)  # (h, w)
                mark = tanh_func(atanh_mark)  # (c, h, w)
            epoch_time = str(
                datetime.timedelta(seconds=int(time.perf_counter() -
                                               epoch_start)))
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Acc: {acc.avg:.2f}, '.ljust(20),
                f'Norm: {norm.avg:.4f},'.ljust(20),
                f'Entropy: {entropy.avg:.4f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str,
                   _str,
                   prefix='{upline}{clear_line}'.format(
                       **ansi) if env['tqdm'] else '',
                   indent=4)

            # check to save best mask or not
            if acc.avg >= self.attack_succ_threshold and norm.avg < norm_best:
                mask_best = mask.detach()
                mark_best = mark.detach()
                norm_best = norm.avg
                entropy_best = entropy.avg

            # check early stop
            if self.early_stop:
                # only terminate if a valid attack has been found
                if norm_best < float('inf'):
                    if norm_best >= self.early_stop_threshold * early_stop_norm_best:
                        early_stop_counter += 1
                    else:
                        early_stop_counter = 0
                early_stop_norm_best = min(norm_best, early_stop_norm_best)

                # Stop only once the cost has been moved in both directions
                # and the best norm has stalled for long enough.
                if cost_down_flag and cost_up_flag and early_stop_counter >= self.early_stop_patience:
                    print('early stop')
                    break

            # check cost modification
            if cost == 0 and acc.avg >= self.attack_succ_threshold:
                cost_set_counter += 1
                if cost_set_counter >= self.patience:
                    cost = self.init_cost
                    cost_up_counter = 0
                    cost_down_counter = 0
                    cost_up_flag = False
                    cost_down_flag = False
                    print('initialize cost to %.2f' % cost)
            else:
                cost_set_counter = 0

            # Count consecutive successful / unsuccessful epochs.
            if acc.avg >= self.attack_succ_threshold:
                cost_up_counter += 1
                cost_down_counter = 0
            else:
                cost_up_counter = 0
                cost_down_counter += 1

            if cost_up_counter >= self.patience:
                cost_up_counter = 0
                prints('up cost from %.4f to %.4f' %
                       (cost, cost * self.cost_multiplier_up),
                       indent=4)
                cost *= self.cost_multiplier_up
                cost_up_flag = True
            elif cost_down_counter >= self.patience:
                cost_down_counter = 0
                prints('down cost from %.4f to %.4f' %
                       (cost, cost / self.cost_multiplier_down),
                       indent=4)
                cost /= self.cost_multiplier_down
                cost_down_flag = True
            # Fallback: if nothing has been stored yet, keep the current
            # mask/mark so that a result is always available.
            if mask_best is None:
                mask_best = tanh_func(atanh_mask).detach()
                mark_best = tanh_func(atanh_mark).detach()
                norm_best = norm.avg
                entropy_best = entropy.avg
        atanh_mark.requires_grad = False
        atanh_mask.requires_grad = False

        # Install the recovered trigger on the attack's watermark before
        # validating; the blending mask goes into ``alpha_mask``.
        self.attack.mark.mark = mark_best
        self.attack.mark.alpha_mask = mask_best
        self.attack.mark.mask = torch.ones_like(mark_best, dtype=torch.bool)
        self.attack.validate_fn()
        return mark_best, mask_best, entropy_best
Ejemplo n.º 6
0
    def remask(self,
               _input: torch.Tensor,
               layer: str,
               neuron: int,
               label: int,
               use_mask: bool = True,
               validate_interval: int = 100,
               verbose=False) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Optimise a trigger on ``_input`` by minimising ``self.abs_loss``.

        An unconstrained mark parameter (and, when ``use_mask`` is set, an
        unconstrained mask parameter) is mapped through ``tanh_func`` and
        updated with Adam for ``self.remask_epoch`` steps. The mark/mask
        seen right after the step with the lowest loss are kept.

        Args:
            _input: Batch of inputs the trigger is blended into.
            layer: Forwarded to ``self.abs_loss`` as ``layer``.
            neuron: Forwarded to ``self.abs_loss`` as ``neuron``.
            label: Class used to measure accuracy of the stamped inputs and
                set as ``self.attack.target_class`` during validation.
            use_mask: When ``True`` the mask is optimised as well; otherwise
                an all-ones mask is used and the learning rate is
                ``0.01 * self.remask_lr``.
            validate_interval: With ``verbose`` set and a non-zero interval,
                the current trigger is installed on ``self.attack.mark`` and
                validated every ``validate_interval`` epochs and on the
                last epoch.
            verbose: Print per-epoch statistics and enable validation.

        Returns:
            ``(mark_best, mask_best, loss_best)``. The tensors are detached
            from the autograd graph. ``mask_best`` is ``None`` when
            ``use_mask`` is ``False``; ``mark_best`` is ``None`` if no step
            produced a loss below ``inf``.
        """
        atanh_mark = torch.randn(self.data_shape, device=env['device'])
        atanh_mark.requires_grad_()
        parameters: list[torch.Tensor] = [atanh_mark]
        mask = torch.ones(self.data_shape[1:], device=env['device'])
        atanh_mask = torch.ones(self.data_shape[1:], device=env['device'])
        if use_mask:
            atanh_mask.requires_grad_()
            parameters.append(atanh_mask)
            mask = tanh_func(atanh_mask)  # (h, w)
        mark = tanh_func(atanh_mark)  # (c, h, w)

        optimizer = optim.Adam(parameters,
                               lr=self.remask_lr if use_mask else 0.01 *
                               self.remask_lr)
        optimizer.zero_grad()

        # best optimization results
        mark_best = None
        loss_best = float('inf')
        mask_best = None

        for _epoch in range(self.remask_epoch):
            epoch_start = time.perf_counter()

            loss = self.abs_loss(_input,
                                 mask,
                                 mark,
                                 layer=layer,
                                 neuron=neuron,
                                 use_mask=use_mask)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Recompute mark/mask from the updated parameters.
            mark = tanh_func(atanh_mark)  # (c, h, w)
            if use_mask:
                mask = tanh_func(atanh_mask)  # (h, w)

            with torch.no_grad():
                X = _input + mask * (mark - _input)
                _output = self.model(X)
            acc = float(_output.argmax(dim=1).eq(label).float().mean()) * 100
            loss = float(loss)

            if verbose:
                norm = mask.norm(p=1)
                epoch_time = str(
                    datetime.timedelta(seconds=int(time.perf_counter() -
                                                   epoch_start)))
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, self.remask_epoch),
                    **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {loss:10.3f},'.ljust(20),
                    f'Acc: {acc:.3f}, '.ljust(20),
                    f'Norm: {norm:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, indent=8)
            if loss < loss_best:
                loss_best = loss
                # Detach so the stored results neither require grad nor keep
                # the autograd history of this step alive.
                mark_best = mark.detach()
                if use_mask:
                    mask_best = mask.detach()
            if validate_interval != 0 and verbose:
                if (
                        _epoch + 1
                ) % validate_interval == 0 or _epoch == self.remask_epoch - 1:
                    # Install the current trigger on the attack (detached,
                    # so the attack does not hold grad-tracking tensors).
                    self.attack.mark.mark = mark.detach()
                    self.attack.mark.alpha_mask = mask.detach()
                    self.attack.mark.mask = torch.ones_like(mark,
                                                            dtype=torch.bool)
                    self.attack.target_class = label
                    self.model._validate(print_prefix='Validate Trigger Tgt',
                                         get_data_fn=self.attack.get_data,
                                         keep_org=False,
                                         indent=8)
                    print()
        atanh_mark.requires_grad = False
        if use_mask:
            atanh_mask.requires_grad = False
        return mark_best, mask_best, loss_best