Example #1
    def get_datasets(
            self) -> tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
        r"""Get clean and poison datasets.

        Returns:
            (torch.utils.data.Dataset, torch.utils.data.Dataset):
                Clean training dataset and poison training dataset.
        """
        if self.attack.poison_set is None:
            self.attack.poison_set = self.attack.get_poison_dataset(
                poison_num=len(self.dataset.loader['train'].dataset))
        if not self.defense_input_num:
            return self.dataset.loader['train'].dataset, self.attack.poison_set
        if self.attack.train_mode != 'dataset':
            poison_num = int(self.defense_input_num *
                             self.attack.poison_percent)
            clean_num = self.defense_input_num - poison_num
            clean_input, clean_label = sample_batch(
                self.dataset.loader['train'].dataset, batch_size=clean_num)
            trigger_input, trigger_label = sample_batch(self.attack.poison_set,
                                                        batch_size=poison_num)
            clean_set = TensorListDataset(clean_input, clean_label.tolist())
            poison_set = TensorListDataset(trigger_input,
                                           trigger_label.tolist())
        else:
            # Assumption: in 'dataset' train mode the attack already provides
            # full datasets, so fall back to them directly. (As originally
            # written, this branch fell through to an undefined `clean_set`.)
            clean_set = self.dataset.loader['train'].dataset
            poison_set = self.attack.poison_set
        return clean_set, poison_set
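
All of these examples lean on a couple of trojanzoo helpers. A minimal sketch of what `sample_batch` and `TensorListDataset` are assumed to do, with names and signatures inferred from the call sites above (the real implementations live in trojanzoo):

import torch
import torch.utils.data


def sample_batch(dataset: torch.utils.data.Dataset, batch_size: int = None,
                 idx: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
    # Draw `batch_size` random samples, or exactly the samples at `idx`.
    if idx is None:
        idx = torch.randperm(len(dataset))[:batch_size]
    inputs, labels = zip(*(dataset[i] for i in idx.tolist()))
    return torch.stack(inputs), torch.as_tensor(labels)


class TensorListDataset(torch.utils.data.Dataset):
    # Pairs a stacked input tensor with a plain list of integer labels.
    def __init__(self, data: torch.Tensor, targets: list[int]):
        self.data, self.targets = data, targets

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
        return self.data[index], self.targets[index]

    def __len__(self) -> int:
        return len(self.targets)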
Example #2
    def sample_data(self) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
        r"""Sample data from each class. The returned data dict is:

        * ``'other'``: ``(input, label)`` from source classes with batch size
          ``self.class_sample_num * len(source_class)``.
        * ``'target'``: ``(input, label)`` from target class with batch size
          ``self.class_sample_num``.

        Returns:
            dict[str, tuple[torch.Tensor, torch.Tensor]]: Data dict.
        """
        source_class = self.source_class or list(
            range(self.dataset.num_classes))
        source_class = source_class.copy()
        if self.target_class in source_class:
            source_class.remove(self.target_class)
        other_x, other_y = [], []
        dataset = self.dataset.get_dataset('train')
        for _class in source_class:
            class_set = self.dataset.get_class_subset(dataset,
                                                      class_list=[_class])
            _input, _label = sample_batch(class_set,
                                          batch_size=self.class_sample_num)
            other_x.append(_input)
            other_y.append(_label)
        other_x = torch.cat(other_x)
        other_y = torch.cat(other_y)
        target_set = self.dataset.get_class_subset(
            dataset, class_list=[self.target_class])
        target_x, target_y = sample_batch(target_set,
                                          batch_size=self.class_sample_num)
        data = {'other': (other_x, other_y), 'target': (target_x, target_y)}
        return data
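
The returned batch sizes follow directly from the docstring; a hedged usage sketch (`defense` is a hypothetical instance, not from the source):

# Assuming class_sample_num = 16 and 10 classes, target_class among them:
data = defense.sample_data()                 # `defense` is hypothetical
assert data['other'][0].shape[0] == 16 * 9   # 9 source classes x 16 samples
assert data['target'][0].shape[0] == 16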
Example #3
    def validate_mask_generator(self):
        r"""Compute the diversity and norm losses of :attr:`self.mask_generator`
        on the validation set and log them."""
        loader = self.dataset.loader['valid']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        idx = torch.randperm(len(dataset))
        pos = 0

        print_prefix = 'Validate'
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        for data in logger.log_every(loader, header=header):
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)

            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()

            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
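
The two losses logged above can be reproduced in isolation; a self-contained sketch on random tensors (the constants `mask_density`, `lambda_div` and `lambda_norm` are illustrative, not necessarily the attack's defaults):

import torch

x1, x2 = torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32)  # two input batches
m1, m2 = torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32)  # their generated masks
mask_density, lambda_div, lambda_norm = 0.032, 1.0, 100.0    # illustrative values

input_dist = (x1 - x2).flatten(1).norm(p=2, dim=1)
mask_dist = (m1 - m2).flatten(1).norm(p=2, dim=1) + 1e-5     # avoid division by zero
loss_div = input_dist.div(mask_dist).mean()     # large when distinct inputs share masks
loss_norm = m1.sub(mask_density).relu().mean()  # penalize masks denser than the budget
loss = lambda_norm * loss_norm + lambda_div * loss_div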
Example #4
    def gen_seed_data(self) -> dict[str, np.ndarray]:
        r"""Generate seed data.

        Returns:
            dict[str, numpy.ndarray]:
                Seed data dict with keys ``'input'`` and ``'label'``.
        """
        torch.manual_seed(env['seed'])
        if self.seed_data_num % self.model.num_classes:
            raise ValueError(
                f'seed_data_num({self.seed_data_num:d}) % num_classes({self.model.num_classes:d}) should be 0.'
            )
        seed_class_num: int = self.seed_data_num // self.model.num_classes
        x, y = [], []
        for _class in range(self.model.num_classes):
            class_set = self.dataset.get_dataset(mode='train',
                                                 class_list=[_class])
            _input, _label = sample_batch(class_set, batch_size=seed_class_num)
            x.append(_input)
            y.append(_label)
        x = torch.cat(x).numpy()
        y = torch.cat(y).numpy()
        seed_data = {'input': x, 'label': y}
        seed_path = os.path.join(self.folder_path,
                                 f'seed_{self.seed_data_num}.npz')
        np.savez(seed_path, **seed_data)
        print('seed data saved at:', seed_path)
        return seed_data
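
The saved archive can be reloaded with `numpy.load`; a quick sketch (the path is illustrative):

import numpy as np

seed_data = np.load('seed_100.npz')   # illustrative path
x, y = seed_data['input'], seed_data['label']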
Example #5
    def get_poison_dataset(self, poison_num: int = None, load_mark: bool = True,
                           seed: int = None) -> torch.utils.data.Dataset:
        r"""Get poison dataset from target class (no clean data).

        Args:
            poison_num (int): Number of poison data.
                Defaults to ``self.poison_num``.
            load_mark (bool): Whether to load previously saved watermark.
                This should be ``False`` during attack.
                Defaults to ``True``.
            seed (int): Random seed to sample poison input indices.
                Defaults to ``env['data_seed']``.

        Returns:
            torch.utils.data.Dataset:
                Poison dataset from target class (no clean data).
        """
        file_path = os.path.join(self.folder_path, self.get_filename() + '.npy')
        if load_mark:
            if os.path.isfile(file_path):
                self.load_mark = False
                self.mark.load_mark(file_path, already_processed=True)
            else:
                raise FileNotFoundError(file_path)
        if seed is None:
            seed = env['data_seed']
        torch.random.manual_seed(seed)
        poison_num = min(poison_num or self.poison_num, len(self.target_set))
        _input, _label = sample_batch(self.target_set, batch_size=poison_num)
        _label = _label.tolist()
        trigger_input = self.add_mark(_input)
        return TensorListDataset(trigger_input, _label)
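
Seeding `torch.random` before sampling is what makes the poison indices reproducible; a tiny sketch:

import torch

torch.random.manual_seed(42)
a = torch.randperm(10)
torch.random.manual_seed(42)
b = torch.randperm(10)
assert torch.equal(a, b)   # same seed, same sampled indices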
Example #6
    def get_poison_dataset(self, poison_label: bool = True,
                           poison_num: int = None,
                           seed: int = None
                           ) -> torch.utils.data.Dataset:
        r"""Get poison dataset (no clean data).

        Args:
            poison_label (bool):
                Whether to use target poison label for poison data.
                Defaults to ``True``.
            poison_num (int): Number of poison data.
                Defaults to ``round(self.poison_ratio * len(train_set))``.
            seed (int): Random seed to sample poison input indices.
                Defaults to ``env['data_seed']``.

        Returns:
            torch.utils.data.Dataset:
                Poison dataset (no clean data).
        """
        if seed is None:
            seed = env['data_seed']
        torch.random.manual_seed(seed)
        train_set = self.dataset.loader['train'].dataset
        poison_num = poison_num or round(self.poison_ratio * len(train_set))
        _input, _label = sample_batch(train_set, batch_size=poison_num)
        _label = _label.tolist()

        if poison_label:
            _label = [self.target_class] * len(_label)
        trigger_input = self.add_mark(_input)
        return TensorListDataset(trigger_input, _label)
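
The default poison budget is a fixed fraction of the training set; a worked sketch with illustrative numbers:

poison_ratio, train_len = 0.01, 50_000        # illustrative values
poison_num = round(poison_ratio * train_len)  # -> 500 poison samples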
Example #7
    def _get_cross_data(self, data: tuple[torch.Tensor, torch.Tensor],
                        **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Get cross-trigger mode data.

        Sample another batch from the train set
        and apply its marks and masks to the current batch.
        """
        _input, _label = self.model.get_data(data)
        batch_size = len(_input)
        data2 = sample_batch(self.train_set, idx=self.idx[self.pos:self.pos + batch_size])
        _input2, _label2 = self.model.get_data(data2)
        self.pos += batch_size
        if self.pos >= len(self.idx):
            self.pos = 0
            self.idx = torch.randperm(len(self.idx))
        mark, mask = self.get_mark(_input2), self.get_mask(_input2)
        _input = _input + mask * (mark - _input)
        return _input, _label
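
The blend on the last line is a per-pixel convex combination; a standalone sketch:

import torch

x = torch.rand(2, 3, 8, 8)         # current batch
mark2 = torch.rand(2, 3, 8, 8)     # marks generated from the other batch
mask2 = torch.rand(2, 1, 8, 8)     # masks generated from the other batch
x_cross = x + mask2 * (mark2 - x)  # == (1 - mask2) * x + mask2 * mark2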
Example #8
    def train_mask_generator(self, verbose: bool = True):
        r"""Train :attr:`self.mask_generator`."""
        optimizer = torch.optim.Adam(self.mask_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_mask_epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        print_prefix = 'Mask Epoch'
        for _epoch in range(1, self.train_mask_epochs + 1):
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
            header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
            self.mask_generator.train()
            for data in logger.log_every(loader, header=header) if verbose else loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size

                _mask = self.get_mask(_input)
                _mask2 = self.get_mask(_input2)

                input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
                mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_div = input_dist.div(mask_dist).mean()
                loss_norm = _mask.sub(self.mask_density).relu().mean()

                loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
                loss.backward()
                optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
            lr_scheduler.step()
            self.mask_generator.eval()
            if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0 or _epoch == self.train_mask_epochs):
                self.validate_mask_generator()
        optimizer.zero_grad()
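
The optimizer/scheduler pattern above, Adam with cosine decay stepped once per epoch, in isolation (the network and epoch count are illustrative):

import torch

net = torch.nn.Linear(4, 2)   # stand-in for the mask generator
opt = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.5, 0.9))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
for _epoch in range(10):
    ...   # inner batch loop: zero_grad / backward / step
    sched.step()   # decay the learning rate once per epoch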
Example #9
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer, **kwargs):
        model_dict = copy.deepcopy(self.model.state_dict())
        W = torch.ones(len(self.reflect_imgs))
        refool_optimizer = torch.optim.SGD(optimizer.param_groups[0]['params'],
                                           lr=self.refool_lr, momentum=0.9,
                                           weight_decay=5e-4)
        # logger = MetricLogger(meter_length=35)
        # logger.create_meters(asr='{median:.3f} ({min:.3f}  {max:.3f})')
        # iterator = logger.log_every(range(self.rank_iter))
        for _iter in range(self.rank_iter):
            print('Select iteration:', output_iter(_iter + 1, self.rank_iter))
            # prepare data
            idx = random.choices(range(len(W)), weights=W.tolist(), k=self.refool_sample_num)
            mark = torch.ones_like(self.mark.mark).expand(self.refool_sample_num, -1, -1, -1).clone()
            mark[:, :-1] = self.reflect_imgs[idx]
            clean_input, _ = sample_batch(self.target_set, self.refool_sample_num)
            trigger_input = self.add_mark(clean_input, mark=mark)
            dataset = TensorListDataset(trigger_input, [self.target_class] * len(trigger_input))
            loader = self.dataset.get_dataloader(mode='train', dataset=dataset)
            # train
            self.model._train(self.refool_epochs, optimizer=refool_optimizer,
                              loader_train=loader, validate_interval=0,
                              output_freq='epoch', indent=4)
            self.model._validate(indent=4)
            # test
            select_idx = list(set(idx))
            marks = self.reflect_imgs[select_idx]
            asr_result = self._get_asr_result(marks)
            # update W
            W[select_idx] = asr_result
            other_idx = list(set(range(len(W))) - set(idx))
            W[other_idx] = asr_result.median()

            # logger.reset().update_list(asr=asr_result)
            self.model.load_state_dict(model_dict)
        self.mark.mark[:-1] = self.reflect_imgs[W.argmax().item()]
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs=epochs, optimizer=optimizer, **kwargs)
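
The weight update keeps one score per candidate reflection image: tested candidates get their measured attack success rate, untested ones the median, so later rounds favor strong reflections. A sketch with dummy numbers:

import torch

W = torch.ones(5)                       # one weight per candidate reflection
idx = [0, 2]                            # candidates trained/tested this round
asr_result = torch.tensor([0.9, 0.3])   # their measured attack success rates
W[idx] = asr_result
W[[1, 3, 4]] = asr_result.median()      # untested candidates get the median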
Example #10
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               validate_interval: int = 1, save: bool = False,
               verbose: bool = True, **kwargs):
        if verbose:
            print('train mask generator')
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_()
        self.model.requires_grad_(False)
        self.train_mask_generator(verbose=verbose)
        if verbose:
            print()
            print('train mark generator and model')

        self.mark_generator.requires_grad_()
        self.mask_generator.requires_grad_(False)
        if not self.natural:
            params: list[nn.Parameter] = []
            for param_group in optimizer.param_groups:
                params.extend(param_group['params'])
            self.model.activate_params(params)

        mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mark_optimizer, T_max=epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, ce=None)

        # Initialize so the final return is defined even when validation is disabled.
        best_validate_result = None
        best_asr = 0.0
        if validate_interval != 0:
            best_validate_result = self.validate_fn(verbose=verbose)
            best_asr = best_validate_result[0]
        for _epoch in range(1, epochs + 1):
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            if not self.natural:
                self.model.train()
            self.mark_generator.train()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            for data in logger.log_every(loader, header=header) if verbose else loader:
                if not self.natural:
                    optimizer.zero_grad()
                mark_optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size
                final_input, final_label = _input.clone(), _label.clone()

                # generate trigger input
                trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
                trigger_int = int(trigger_int)
                if random.uniform(0, 1) < trigger_dec:
                    trigger_int += 1
                x = _input[:trigger_int]
                trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
                trigger_input = x + trigger_mask * (trigger_mark - x)
                final_input[:trigger_int] = trigger_input
                final_label[:trigger_int] = self.target_class

                # generate cross input
                cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
                cross_int = int(cross_int)
                if random.uniform(0, 1) < cross_dec:
                    cross_int += 1
                x = _input[trigger_int:trigger_int + cross_int]
                x2 = _input2[trigger_int:trigger_int + cross_int]
                cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
                cross_input = x + cross_mask * (cross_mark - x)
                final_input[trigger_int:trigger_int + cross_int] = cross_input

                # div loss
                if len(trigger_input) <= len(cross_input):
                    length = len(trigger_input)
                    cross_input = cross_input[:length]
                    cross_mark = cross_mark[:length]
                    cross_mask = cross_mask[:length]
                else:
                    length = len(cross_input)
                    trigger_input = trigger_input[:length]
                    trigger_mark = trigger_mark[:length]
                    trigger_mask = trigger_mask[:length]
                input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
                mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_ce = self.model.loss(final_input, final_label)
                loss_div = input_dist.div(mark_dist).mean()

                loss = loss_ce + self.lambda_div * loss_div
                loss.backward()
                if not self.natural:
                    optimizer.step()
                mark_optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), ce=loss_ce.item())
            if not self.natural and lr_scheduler:
                lr_scheduler.step()
            mark_scheduler.step()
            if not self.natural:
                self.model.eval()
            self.mark_generator.eval()
            if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
                validate_result = self.validate_fn(verbose=verbose)
                cur_asr = validate_result[0]
                if cur_asr >= best_asr:
                    best_validate_result = validate_result
                    best_asr = cur_asr
                    if save:
                        self.save()
        if not self.natural:
            optimizer.zero_grad()
        mark_optimizer.zero_grad()
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_(False)
        self.model.requires_grad_(False)
        return best_validate_result
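
The fractional poison count per batch is rounded stochastically via `math.modf`, so the expected rate matches `poison_percent` exactly; a worked sketch:

import math
import random

dec, whole = math.modf(64 * 0.1)               # 6.4 poison samples per batch of 64
n = int(whole) + (random.uniform(0, 1) < dec)  # 6 or 7; 7 with probability 0.4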
Example #11
    def attack(self, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs):

        target_class_set = self.dataset.get_dataset('train', class_list=[self.target_class])
        target_imgs, _ = sample_batch(target_class_set, batch_size=self.poison_num)
        target_imgs = target_imgs.to(env['device'])

        full_set = self.dataset.get_dataset('train')
        poison_set: TensorListDataset = None    # TODO
        match self.poison_generation_method:
            case 'pgd':
                trigger_label = self.target_class * torch.ones(
                    len(target_imgs), dtype=torch.long, device=target_imgs.device)
                result = []
                for data in zip(target_imgs.chunk(self.dataset.batch_size),
                                trigger_label.chunk(self.dataset.batch_size)):
                    poison_img, _ = self.model.remove_misclassify(data)
                    poison_img, _ = self.pgd.optimize(poison_img)
                    poison_img = self.add_mark(poison_img).cpu()
                    result.append(poison_img)
                poison_imgs = torch.cat(result)

                poison_set = TensorListDataset(poison_imgs, [self.target_class] * len(poison_imgs))
                # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])

            case 'gan':
                other_classes = list(range(self.dataset.num_classes))
                other_classes.pop(self.target_class)
                x_list = []
                y_list = []
                for source_class in other_classes:
                    print('Processing source class:', source_class)
                    source_class_dataset = self.dataset.get_dataset(mode='train', class_list=[source_class])
                    sample_source_class_dataset, _ = self.dataset.split_dataset(
                        source_class_dataset, self.poison_num)
                    source_imgs = dataset_to_tensor(sample_source_class_dataset)[0].to(device=env['device'])

                    g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
                    d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
                    if os.path.exists(g_path) and os.path.exists(d_path) and not self.train_gan:
                        self.wgan.G.load_state_dict(torch.load(g_path, map_location=env['device']))
                        self.wgan.D.load_state_dict(torch.load(d_path, map_location=env['device']))
                        print(f'    load model from:\n        {g_path}\n        {d_path}')
                    else:
                        self.train_gan = True
                        self.wgan.reset_parameters()
                        gan_dataset = torch.utils.data.ConcatDataset([source_class_dataset, target_class_set])
                        gan_dataloader = self.dataset.get_dataloader(
                            mode='train', dataset=gan_dataset, batch_size=self.dataset.batch_size, num_workers=0)
                        self.wgan.train(gan_dataloader)
                        torch.save(self.wgan.G.state_dict(), g_path)
                        torch.save(self.wgan.D.state_dict(), d_path)
                        print(f'GAN model saved at:\n{g_path}\n{d_path}')
                        continue

                    for source_chunk, target_chunk in zip(source_imgs.chunk(self.dataset.batch_size),
                                                          target_imgs.chunk(self.dataset.batch_size)):
                        source_encode = self.wgan.get_encode_value(source_chunk).detach()
                        target_encode = self.wgan.get_encode_value(target_chunk).detach()
                        interpolation_encode = source_encode * self.tau + target_encode * (1 - self.tau)
                        poison_imgs = self.wgan.G(interpolation_encode).detach()
                        poison_imgs = self.add_mark(poison_imgs)

                        poison_imgs = poison_imgs.cpu()
                        x_list.append(poison_imgs)
                    y_list.extend([self.target_class] * len(source_imgs))
                # If any GAN was (re)trained this run, poison generation was
                # skipped above via `continue`; stop here and rerun with the
                # saved weights.
                assert not self.train_gan
                x_list = torch.cat(x_list)
                poison_set = TensorListDataset(x_list, y_list)
                # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])
            case _:
                raise ValueError('unsupported poison_generation_method: '
                                 f'{self.poison_generation_method!r}')
        final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
        # final_set = poison_set
        final_loader = self.dataset.get_dataloader(mode='train', dataset=final_set, num_workers=0)
        self.model._train(optimizer=optimizer, lr_scheduler=lr_scheduler, save_fn=self.save,
                          loader_train=final_loader, validate_fn=self.validate_fn, **kwargs)
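
The GAN branch poisons by interpolating latent codes between source and target images; the arithmetic in isolation (`tau` and the dimensions are illustrative):

import torch

src, tgt = torch.randn(4, 128), torch.randn(4, 128)  # encoded source/target batches
tau = 0.4                                            # illustrative interpolation weight
interp = src * tau + tgt * (1 - tau)                 # closer to target as tau shrinks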