Example #1
    def get_datasets(
            self) -> tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
        r"""Get clean and poison datasets.

        Returns:
            (torch.utils.data.Dataset, torch.utils.data.Dataset):
                Clean training dataset and poison training dataset.
        """
        if self.attack.poison_set is None:
            self.attack.poison_set = self.attack.get_poison_dataset(
                poison_num=len(self.dataset.loader['train'].dataset))
        if not self.defense_input_num:
            return self.dataset.loader['train'].dataset, self.attack.poison_set
        if self.attack.train_mode != 'dataset':
            poison_num = int(self.defense_input_num *
                             self.attack.poison_percent)
            clean_num = self.defense_input_num - poison_num
            clean_input, clean_label = sample_batch(
                self.dataset.loader['train'].dataset, batch_size=clean_num)
            trigger_input, trigger_label = sample_batch(self.attack.poison_set,
                                                        batch_size=poison_num)
            clean_set = TensorListDataset(clean_input, clean_label.tolist())
            poison_set = TensorListDataset(trigger_input,
                                           trigger_label.tolist())
            return clean_set, poison_set
        # 'dataset' train mode: the attack consumes the whole poison dataset
        # directly, so return both sets unsampled.
        return self.dataset.loader['train'].dataset, self.attack.poison_set
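
All of these examples revolve around `TensorListDataset`, which pairs a stacked input tensor with a plain Python list of labels. A minimal stand-in that captures the interface the snippets rely on (a sketch, not the trojanzoo implementation):

    import torch
    from torch.utils.data import Dataset

    class TensorListDataset(Dataset):
        # Stand-in sketch: tensor inputs paired with a list of int labels.
        def __init__(self, data: torch.Tensor, targets: list[int]):
            assert len(data) == len(targets)
            self.data, self.targets = data, targets

        def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
            return self.data[index], self.targets[index]

        def __len__(self) -> int:
            return len(self.data)

    # Usage mirroring the examples: stacked tensor inputs, list labels.
    dataset = TensorListDataset(torch.randn(8, 3, 32, 32), [0] * 8)
    print(len(dataset), dataset[0][1])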
Example #2
    def attack(self, epochs: int, **kwargs):
        trigger_x, trigger_y = self.syn_trigger_candidates()
        train_noise_x, train_noise_y = self.syn_random_noises(length=self.train_noise_num)
        valid_noise_x, valid_noise_y = self.syn_random_noises(length=self.valid_noise_num)
        train_set = TensorListDataset(torch.cat((trigger_x, train_noise_x)),
                                      trigger_y + train_noise_y)
        valid_set = TensorListDataset(torch.cat((trigger_x, valid_noise_x)),
                                      trigger_y + valid_noise_y)
        loader_train = self.dataset.get_dataloader('train', dataset=train_set)
        loader_valid = self.dataset.get_dataloader('valid', dataset=valid_set)

        trigger = trigger_x[self.target_class].view_as(self.mark.mark[0]).unsqueeze(0)
        mark = torch.cat((trigger, torch.ones_like(trigger)))
        self.mark.load_mark(mark, already_processed=True)
        self.mlp_model = ImageModel(name='mlpnet', model=_MLPNet,
                                    input_dim=self.all_point,
                                    output_dim=self.combination_number + 1,
                                    dataset=self.dataset,
                                    data_shape=[self.all_point],
                                    loss_weights=None)
        self.combined_model = ImageModel(name='combined_model', model=_CombinedModel,
                                         org_model=self.model._model,
                                         mlp_model=self.mlp_model._model,
                                         mark=self.mark, dataset=self.dataset,
                                         alpha=self.mlp_alpha,
                                         temperature=self.comb_temperature,
                                         amplify_rate=self.amplify_rate)
        optimizer = torch.optim.Adam(params=self.mlp_model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        self.mlp_model._train(epochs, optimizer=optimizer,
                              lr_scheduler=lr_scheduler,
                              loader_train=loader_train,
                              loader_valid=loader_valid,
                              save_fn=self.save)
        return self.validate_fn()
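
The optimizer setup above pairs Adam at a relatively large lr=0.1 with `CosineAnnealingLR` over `epochs`, so the learning rate decays smoothly toward zero by the end of training. The schedule in isolation (the linear model is a placeholder for the MLP trigger classifier):

    import torch

    model = torch.nn.Linear(16, 4)  # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    for epoch in range(10):
        # ... one pass over loader_train would go here ...
        lr_scheduler.step()  # cosine decay, one step per epoch
        print(f"epoch {epoch}: lr={optimizer.param_groups[0]['lr']:.4f}")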
Example #3
    def __init__(self, mix_image_num: int = 100, clean_image_ratio: float = 0.95,
                 retrain_epoch: int = 10, nb_clusters: int = 2,
                 clustering_method: str = "KMeans", nb_dims: int = 10,
                 reduce_method: str = "FastICA",
                 cluster_analysis: str = "exclusionary-reclassification",
                 **kwargs):
        super().__init__(**kwargs)

        self.mix_image_num = mix_image_num
        self.clean_image_ratio = clean_image_ratio
        self.clean_image_num = int(mix_image_num * clean_image_ratio)
        self.poison_image_num = self.mix_image_num - self.clean_image_num

        self.nb_clusters = nb_clusters
        self.clustering_method = clustering_method
        self.nb_dims = nb_dims
        self.reduce_method = reduce_method
        self.cluster_analysis = cluster_analysis

        self.retrain_epoch = retrain_epoch

        self.clean_dataset, _ = self.dataset.split_set(
            self.dataset.get_full_dataset(mode='train'), self.clean_image_num)

        poison_dataset, _ = self.dataset.split_set(
            self.dataset.get_full_dataset(mode='train'), self.poison_image_num)
        poison_dataloader = self.dataset.get_dataloader(
            mode='train', dataset=poison_dataset, batch_size=self.poison_image_num, num_workers=0)
        poison_imgs, _ = self.model.get_data(next(iter(poison_dataloader)))
        poison_imgs = self.attack.add_mark(poison_imgs).cpu()
        poison_label = [self.attack.target_class] * self.poison_image_num
        self.poison_dataset = TensorListDataset(poison_imgs, poison_label)
        self.mix_dataset = torch.utils.data.ConcatDataset([self.clean_dataset, self.poison_dataset])
        self.mix_dataloader = self.dataset.get_dataloader(
            mode='train', dataset=self.mix_dataset, batch_size=self.dataset.batch_size, num_workers=0)
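
The constructor builds a 95:5 clean/poison mixture (clean_image_ratio=0.95 of mix_image_num=100) and joins the two subsets with `torch.utils.data.ConcatDataset`. The mixing pattern on its own, with hypothetical stand-in data:

    import torch
    from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

    clean_set = TensorDataset(torch.randn(95, 3, 32, 32),
                              torch.zeros(95, dtype=torch.long))
    poison_set = TensorDataset(torch.randn(5, 3, 32, 32),
                               torch.full((5,), 7, dtype=torch.long))  # target class 7
    mix_set = ConcatDataset([clean_set, poison_set])
    mix_loader = DataLoader(mix_set, batch_size=32, shuffle=True)
    print(len(mix_set))  # 100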
Example #4
    def get_poison_dataset(self, poison_num: int = None, load_mark: bool = True,
                           seed: int = None) -> torch.utils.data.Dataset:
        r"""Get poison dataset from target class (no clean data).

        Args:
            poison_num (int): Number of poison data.
                Defaults to ``self.poison_num``.
            load_mark (bool): Whether to load previously saved watermark.
                This should be ``False`` during attack.
                Defaults to ``True``.
            seed (int): Random seed to sample poison input indices.
                Defaults to ``env['data_seed']``.

        Returns:
            torch.utils.data.Dataset:
                Poison dataset from target class (no clean data).
        """
        file_path = os.path.join(self.folder_path, self.get_filename() + '.npy')
        if load_mark:
            if os.path.isfile(file_path):
                self.load_mark = False
                self.mark.load_mark(file_path, already_processed=True)
            else:
                raise FileNotFoundError(file_path)
        if seed is None:
            seed = env['data_seed']
        torch.random.manual_seed(seed)
        poison_num = min(poison_num or self.poison_num, len(self.target_set))
        _input, _label = sample_batch(self.target_set, batch_size=poison_num)
        _label = _label.tolist()
        trigger_input = self.add_mark(_input)
        return TensorListDataset(trigger_input, _label)
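
Seeding `torch.random` before sampling (here and in Example #5) makes the poison subset reproducible: the same `data_seed` always selects the same indices. The pattern in isolation:

    import torch

    torch.random.manual_seed(40)
    first = torch.randperm(10)
    torch.random.manual_seed(40)
    second = torch.randperm(10)
    assert torch.equal(first, second)  # same seed, same sample indices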
Example #5
    def get_poison_dataset(self, poison_label: bool = True,
                           poison_num: int = None,
                           seed: int = None
                           ) -> torch.utils.data.Dataset:
        r"""Get poison dataset (no clean data).

        Args:
            poison_label (bool):
                Whether to use target poison label for poison data.
                Defaults to ``True``.
            poison_num (int): Number of poison data.
                Defaults to ``round(self.poison_ratio * len(train_set))``.
            seed (int): Random seed to sample poison input indices.
                Defaults to ``env['data_seed']``.

        Returns:
            torch.utils.data.Dataset:
                Poison dataset (no clean data).
        """
        if seed is None:
            seed = env['data_seed']
        torch.random.manual_seed(seed)
        train_set = self.dataset.loader['train'].dataset
        poison_num = poison_num or round(self.poison_ratio * len(train_set))
        _input, _label = sample_batch(train_set, batch_size=poison_num)
        _label = _label.tolist()

        if poison_label:
            _label = [self.target_class] * len(_label)
        trigger_input = self.add_mark(_input)
        return TensorListDataset(trigger_input, _label)
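
`sample_batch` is used throughout to draw a random batch of stacked (input, label) tensors from a dataset. A rough stand-in under that assumption (not the trojanzoo implementation):

    import torch
    from torch.utils.data import Dataset

    def sample_batch(dataset: Dataset, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical sketch: pick batch_size random items and stack them.
        idx = torch.randperm(len(dataset))[:batch_size]
        inputs, labels = zip(*(dataset[int(i)] for i in idx))
        return torch.stack(inputs), torch.as_tensor(labels)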
Example #6
File: strip.py Project: ain-soph/trojanzoo
    def get_pred_labels(self) -> torch.Tensor:
        r"""Get predicted labels for test inputs.

        Returns:
            torch.Tensor: ``torch.BoolTensor`` with shape ``(2 * defense_input_num)``.
        """
        logger = MetricLogger(meter_length=40)
        str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
        logger.create_meters(clean_score=str_format, poison_score=str_format)
        test_set = TensorListDataset(self.test_input, self.test_label)
        test_loader = self.dataset.get_dataloader(mode='valid',
                                                  dataset=test_set)
        for data in logger.log_every(test_loader):
            _input, _label = self.model.get_data(data)
            trigger_input = self.attack.add_mark(_input)
            logger.meters['clean_score'].update_list(
                self.get_score(_input).tolist())
            logger.meters['poison_score'].update_list(
                self.get_score(trigger_input).tolist())
        clean_score = torch.as_tensor(logger.meters['clean_score'].deque)
        poison_score = torch.as_tensor(logger.meters['poison_score'].deque)
        clean_score_sorted = clean_score.msort()
        threshold_low = float(clean_score_sorted[int(self.strip_fpr *
                                                     len(poison_score))])
        entropy = torch.cat((clean_score, poison_score))
        print(f'Threshold: {threshold_low:5.3f}')
        return entropy < threshold_low  # low entropy => flagged as poison
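
The detection rule here is STRIP-style: the threshold is the `strip_fpr` quantile of the clean score distribution, and any input whose entropy falls below it is flagged as poisoned. With hypothetical scores:

    import torch

    clean_score = torch.rand(100) + 0.5   # clean inputs: entropy stays high
    poison_score = torch.rand(100) * 0.3  # trigger inputs: entropy collapses
    strip_fpr = 0.05

    threshold_low = clean_score.sort().values[int(strip_fpr * len(clean_score))].item()
    entropy = torch.cat((clean_score, poison_score))
    pred = entropy < threshold_low  # True = flagged as poisoned
    print(f'flagged {int(pred.sum())} of {len(pred)} inputs')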
Example #7
 def mix_dataset(self, poison_label: bool = True) -> torch.utils.data.Dataset:
     clean_dataset = self.dataset.loader['train'].dataset
     _input, _label = dataset_to_list(clean_dataset)
     poison_num = int(self.poison_percent * len(clean_dataset))
     _input = torch.stack(_input[:poison_num])  # poison the first poison_num inputs
     _label = _label[:poison_num]
     if poison_label:
         _label = [self.target_class] * poison_num
     poison_input = self.add_mark(_input)
     poison_dataset = TensorListDataset(poison_input, _label)
     return torch.utils.data.ConcatDataset([clean_dataset, poison_dataset])
Example #8
    def mix_dataset(self) -> torch.utils.data.Dataset:
        clean_set = self.dataset.loader['train'].dataset
        subset, other_set = ImageSet.split_dataset(clean_set,
                                                   percent=self.poison_percent)
        if not len(subset):
            return clean_set
        _input, _label = dataset_to_tensor(subset)

        _label += torch.randint_like(_label,
                                     low=1,
                                     high=self.model.num_classes)
        _label %= self.model.num_classes
        poison_set = TensorListDataset(_input, _label.tolist())
        return torch.utils.data.ConcatDataset([poison_set, other_set])
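
The offset-then-modulo step guarantees every poisoned label differs from the original: adding a random offset drawn from [1, num_classes) and wrapping modulo num_classes can never map a class back to itself. A standalone check:

    import torch

    num_classes = 10
    labels = torch.randint(0, num_classes, (1000,))
    offset = torch.randint_like(labels, low=1, high=num_classes)
    shifted = (labels + offset) % num_classes
    assert not (shifted == labels).any()  # offset in [1, C) never round-trips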
Example #9
 def get_pred_labels(self) -> torch.Tensor:
     logger = MetricLogger(meter_length=40)
     str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
     logger.create_meters(cls_diff=str_format, jaccard_idx=str_format)
     test_set = TensorListDataset(self.test_input, self.test_label)
     test_loader = self.dataset.get_dataloader(mode='valid',
                                               dataset=test_set,
                                               batch_size=1)
     clean_list = []
     poison_list = []
     for data in logger.log_every(test_loader):
         _input: torch.Tensor = data[0]
         _input = _input.to(env['device'], non_blocking=True)
         trigger_input = self.attack.add_mark(_input)
         clean_list.append(self.get_pred_label(_input[0], logger=logger))
         poison_list.append(
             self.get_pred_label(trigger_input[0], logger=logger))
     return torch.as_tensor(clean_list + poison_list, dtype=torch.bool)
Example #10
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer, **kwargs):
        model_dict = copy.deepcopy(self.model.state_dict())
        W = torch.ones(len(self.reflect_imgs))
        refool_optimizer = torch.optim.SGD(optimizer.param_groups[0]['params'],
                                           lr=self.refool_lr, momentum=0.9,
                                           weight_decay=5e-4)
        for _iter in range(self.rank_iter):
            print('Select iteration: ', output_iter(_iter + 1, self.rank_iter))
            # prepare data
            idx = random.choices(range(len(W)), weights=W.tolist(), k=self.refool_sample_num)
            mark = torch.ones_like(self.mark.mark).expand(self.refool_sample_num, -1, -1, -1).clone()
            mark[:, :-1] = self.reflect_imgs[idx]
            clean_input, _ = sample_batch(self.target_set, self.refool_sample_num)
            trigger_input = self.add_mark(clean_input, mark=mark)
            dataset = TensorListDataset(trigger_input, [self.target_class] * len(trigger_input))
            loader = self.dataset.get_dataloader(mode='train', dataset=dataset)
            # train
            self.model._train(self.refool_epochs, optimizer=refool_optimizer,
                              loader_train=loader, validate_interval=0,
                              output_freq='epoch', indent=4)
            self.model._validate(indent=4)
            # test
            select_idx = list(set(idx))
            marks = self.reflect_imgs[select_idx]
            asr_result = self._get_asr_result(marks)
            # update W
            W[select_idx] = asr_result
            other_idx = list(set(range(len(W))) - set(idx))
            W[other_idx] = asr_result.median()

            self.model.load_state_dict(model_dict)
        self.mark.mark[:-1] = self.reflect_imgs[W.argmax().item()]
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs=epochs, optimizer=optimizer, **kwargs)
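
The candidate-selection loop maintains a weight vector `W` over reflection images and draws candidates in proportion to it with `random.choices`, so triggers that achieved a higher attack success rate are revisited more often. The weighted-sampling primitive by itself:

    import random
    import torch

    W = torch.tensor([0.1, 0.5, 0.2, 1.0])  # per-candidate weights, e.g. measured ASR
    idx = random.choices(range(len(W)), weights=W.tolist(), k=8)
    print(idx)  # duplicates allowed; biased toward high-weight candidates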
Example #11
 def attack(self, epoch: int, save=False, **kwargs):
     if self.train_mode == 'batch':
         self.model._train(epoch,
                           save=save,
                           validate_func=self.validate_func,
                           get_data_fn=self.get_data,
                           save_fn=self.save,
                           **kwargs)
     elif self.train_mode == 'dataset':
         clean_dataset = self.dataset.loader['train'].dataset
         _input, _label = next(
             iter(
                 self.dataset.get_dataloader(
                     'train',
                     batch_size=int(self.poison_percent *
                                    len(clean_dataset)))))
         _label = torch.ones_like(_label) * self.target_class
         _label = _label.tolist()
         poison_input = self.add_mark(_input)
         poison_dataset = TensorListDataset(poison_input, _label)
         dataset = torch.utils.data.ConcatDataset(
             [clean_dataset, poison_dataset])
         loader = self.dataset.get_dataloader('train', dataset=dataset)
         self.model._train(epoch,
                           save=save,
                           validate_func=self.validate_func,
                           loader_train=loader,
                           save_fn=self.save,
                           **kwargs)
     elif self.train_mode == 'loss':
         self.model._train(epoch,
                           save=save,
                           validate_func=self.validate_func,
                           loss_fn=self.loss_fn,
                           save_fn=self.save,
                           **kwargs)
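
In the 'dataset' branch above, `torch.ones_like(_label) * self.target_class` fills the label tensor with the target class; `torch.full_like` is the equivalent one-call idiom:

    import torch

    _label = torch.randint(0, 10, (6,))
    target_class = 3
    a = torch.ones_like(_label) * target_class
    b = torch.full_like(_label, target_class)
    assert torch.equal(a, b)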
Example #12
    def attack(self, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs):

        target_class_dataset = self.dataset.get_dataset(
            'train', full=True, classes=[self.target_class])

        sample_target_class_dataset, target_original_dataset = self.dataset.split_set(
            target_class_dataset, self.poison_num)
        sample_target_dataloader = self.dataset.get_dataloader(
            mode='train',
            dataset=sample_target_class_dataset,
            batch_size=self.poison_num,
            num_workers=0)
        target_imgs, _ = self.model.get_data(
            next(iter(sample_target_dataloader)))

        full_set = self.dataset.get_dataset('train', full=False)
        poison_set: TensorListDataset = None  # TODO
        if self.poison_generation_method == 'pgd':
            poison_label = self.target_class * torch.ones(
                len(target_imgs), dtype=torch.long, device=target_imgs.device)

            poison_imgs, _ = self.model.remove_misclassify(data=(target_imgs,
                                                                 poison_label))
            poison_imgs, _ = self.pgd.craft_example(_input=poison_imgs)
            poison_imgs = self.add_mark(poison_imgs).cpu()

            poison_label = [self.target_class] * len(target_imgs)
            poison_set = TensorListDataset(poison_imgs, poison_label)
            # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])

        elif self.poison_generation_method == 'gan':
            other_classes = list(range(self.dataset.num_classes))
            other_classes.pop(self.target_class)
            x_list = []
            y_list = []
            for source_class in other_classes:
                print('Process data of Source Class: ', source_class)
                source_class_dataset = self.dataset.get_dataset(
                    mode='train', full=True, classes=[source_class])
                sample_source_class_dataset, _ = self.dataset.split_set(
                    source_class_dataset, self.poison_num)
                sample_source_class_dataloader = self.dataset.get_dataloader(
                    mode='train',
                    dataset=sample_source_class_dataset,
                    batch_size=self.poison_num,
                    num_workers=0)
                source_imgs, _ = self.model.get_data(
                    next(iter(sample_source_class_dataloader)))

                g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
                d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
                if os.path.exists(g_path) and os.path.exists(
                        d_path) and not self.train_gan:
                    self.wgan.G.load_state_dict(
                        torch.load(g_path, map_location=env['device']))
                    self.wgan.D.load_state_dict(
                        torch.load(d_path, map_location=env['device']))
                    print(
                        f'    load model from: \n        {g_path}\n        {d_path}',
                    )
                else:
                    self.train_gan = True
                    self.wgan.reset_parameters()
                    gan_dataset = torch.utils.data.ConcatDataset(
                        [source_class_dataset, target_class_dataset])
                    gan_dataloader = self.dataset.get_dataloader(
                        mode='train',
                        dataset=gan_dataset,
                        batch_size=self.dataset.batch_size,
                        num_workers=0)
                    self.wgan.train(gan_dataloader)
                    torch.save(self.wgan.G.state_dict(), g_path)
                    torch.save(self.wgan.D.state_dict(), d_path)
                    print(f'GAN Model Saved at : \n{g_path}\n{d_path}')
                    continue
                source_encode = self.wgan.get_encode_value(
                    source_imgs).detach()
                target_encode = self.wgan.get_encode_value(
                    target_imgs).detach()
                interpolation_encode = source_encode * self.tau + target_encode * (
                    1 - self.tau)
                poison_imgs = self.wgan.G(interpolation_encode).detach()
                poison_imgs = self.add_mark(poison_imgs)

                poison_label = [self.target_class] * len(poison_imgs)
                poison_imgs = poison_imgs.cpu()
                x_list.append(poison_imgs)
                y_list.extend(poison_label)
            assert not self.train_gan
            x_list = torch.cat(x_list)
            poison_set = TensorListDataset(x_list, y_list)
            # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])
        final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
        # final_set = poison_set
        final_loader = self.dataset.get_dataloader(mode='train',
                                                   dataset=final_set,
                                                   num_workers=0)
        self.model._train(optimizer=optimizer,
                          lr_scheduler=lr_scheduler,
                          save_fn=self.save,
                          loader_train=final_loader,
                          validate_fn=self.validate_fn,
                          **kwargs)
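
The GAN branch builds poisons by latent interpolation: with mixing weight tau, the poison code is tau * z_source + (1 - tau) * z_target, then decoded by the generator so the image keeps source-class content while sitting near the target class. The interpolation step in isolation (encoder and generator are stand-ins):

    import torch

    tau = 0.2
    z_source = torch.randn(4, 128)  # stand-in encodings of source-class images
    z_target = torch.randn(4, 128)  # stand-in encodings of target-class images
    z_poison = tau * z_source + (1 - tau) * z_target  # mostly target-like latent
    generator = torch.nn.Linear(128, 3 * 32 * 32)     # stand-in for wgan.G
    poison_imgs = generator(z_poison).view(-1, 3, 32, 32)
    print(poison_imgs.shape)  # torch.Size([4, 3, 32, 32])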
Example #13
    def attack(self, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs):

        target_class_set = self.dataset.get_dataset('train', class_list=[self.target_class])
        target_imgs, _ = sample_batch(target_class_set, batch_size=self.poison_num)
        target_imgs = target_imgs.to(env['device'])

        full_set = self.dataset.get_dataset('train')
        poison_set: TensorListDataset = None    # TODO
        match self.poison_generation_method:
            case 'pgd':
                trigger_label = self.target_class * torch.ones(
                    len(target_imgs), dtype=torch.long, device=target_imgs.device)
                result = []
                for data in zip(target_imgs.chunk(self.dataset.batch_size),
                                trigger_label.chunk(self.dataset.batch_size)):
                    poison_img, _ = self.model.remove_misclassify(data)
                    poison_img, _ = self.pgd.optimize(poison_img)
                    poison_img = self.add_mark(poison_img).cpu()
                    result.append(poison_img)
                poison_imgs = torch.cat(result)

                poison_set = TensorListDataset(poison_imgs, [self.target_class] * len(poison_imgs))
                # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])

            case 'gan':
                other_classes = list(range(self.dataset.num_classes))
                other_classes.pop(self.target_class)
                x_list = []
                y_list = []
                for source_class in other_classes:
                    print('Process data of Source Class: ', source_class)
                    source_class_dataset = self.dataset.get_dataset(mode='train', class_list=[source_class])
                    sample_source_class_dataset, _ = self.dataset.split_dataset(
                        source_class_dataset, self.poison_num)
                    source_imgs = dataset_to_tensor(sample_source_class_dataset)[0].to(device=env['device'])

                    g_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_g.pth'
                    d_path = f'{self.folder_path}gan_dim{self.noise_dim}_class{source_class}_d.pth'
                    if os.path.exists(g_path) and os.path.exists(d_path) and not self.train_gan:
                        self.wgan.G.load_state_dict(torch.load(g_path, map_location=env['device']))
                        self.wgan.D.load_state_dict(torch.load(d_path, map_location=env['device']))
                        print(f'    load model from: \n        {g_path}\n        {d_path}')
                    else:
                        self.train_gan = True
                        self.wgan.reset_parameters()
                        gan_dataset = torch.utils.data.ConcatDataset([source_class_dataset, target_class_set])
                        gan_dataloader = self.dataset.get_dataloader(
                            mode='train', dataset=gan_dataset, batch_size=self.dataset.batch_size, num_workers=0)
                        self.wgan.train(gan_dataloader)
                        torch.save(self.wgan.G.state_dict(), g_path)
                        torch.save(self.wgan.D.state_dict(), d_path)
                        print(f'GAN Model Saved at : \n{g_path}\n{d_path}')
                        continue

                    for source_chunk, target_chunk in zip(source_imgs.chunk(self.dataset.batch_size),
                                                          target_imgs.chunk(self.dataset.batch_size)):
                        source_encode = self.wgan.get_encode_value(source_chunk).detach()
                        target_encode = self.wgan.get_encode_value(target_chunk).detach()
                        interpolation_encode = source_encode * self.tau + target_encode * (1 - self.tau)
                        poison_imgs = self.wgan.G(interpolation_encode).detach()
                        poison_imgs = self.add_mark(poison_imgs)

                        poison_imgs = poison_imgs.cpu()
                        x_list.append(poison_imgs)
                    y_list.extend([self.target_class] * len(source_imgs))
                assert not self.train_gan
                x_list = torch.cat(x_list)
                poison_set = TensorListDataset(x_list, y_list)
                # poison_set = torch.utils.data.ConcatDataset([poison_set, target_original_dataset])
        final_set = torch.utils.data.ConcatDataset([poison_set, full_set])
        # final_set = poison_set
        final_loader = self.dataset.get_dataloader(mode='train', dataset=final_set, num_workers=0)
        self.model._train(optimizer=optimizer, lr_scheduler=lr_scheduler, save_fn=self.save,
                          loader_train=final_loader, validate_fn=self.validate_fn, **kwargs)
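
One subtlety in the chunked loops of Examples #12/#13: `Tensor.chunk(n)` splits a tensor into at most n pieces, while `Tensor.split(n)` yields pieces of size n, so `target_imgs.chunk(self.dataset.batch_size)` produces batch_size-many chunks rather than batches of batch_size. A quick comparison:

    import torch

    x = torch.arange(10)
    print([t.tolist() for t in x.chunk(4)])  # 4 pieces: sizes 3, 3, 3, 1
    print([t.tolist() for t in x.split(4)])  # size-4 pieces: 4, 4, 2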