Example #1
    def get_pred_labels(self) -> torch.Tensor:
        r"""Get predicted labels for test inputs.

        Returns:
            torch.Tensor: ``torch.BoolTensor`` with shape ``(2 * defense_input_num)``.
        """
        logger = MetricLogger(meter_length=40)
        str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
        logger.create_meters(clean_score=str_format, poison_score=str_format)
        test_set = TensorListDataset(self.test_input, self.test_label)
        test_loader = self.dataset.get_dataloader(mode='valid',
                                                  dataset=test_set)
        for data in logger.log_every(test_loader):
            _input, _label = self.model.get_data(data)
            trigger_input = self.attack.add_mark(_input)
            logger.meters['clean_score'].update_list(
                self.get_score(_input).tolist())
            logger.meters['poison_score'].update_list(
                self.get_score(trigger_input).tolist())
        clean_score = torch.as_tensor(logger.meters['clean_score'].deque)
        poison_score = torch.as_tensor(logger.meters['poison_score'].deque)
        clean_score_sorted = clean_score.msort()
        threshold_low = float(clean_score_sorted[int(self.strip_fpr *
                                                     len(clean_score_sorted))])
        entropy = torch.cat((clean_score, poison_score))
        print(f'Threshold: {threshold_low:5.3f}')
        # a low score under perturbation indicates a trigger input
        return entropy < threshold_low
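The pattern above registers meters with a format string, streams per-sample scores into them with update_list, and later reads the raw values back from each meter's deque. A minimal standalone sketch of that usage, assuming the trojanzoo.utils.logger.MetricLogger API as exercised throughout these examples:

from trojanzoo.utils.logger import MetricLogger

logger = MetricLogger(meter_length=40)
str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
logger.create_meters(clean_score=str_format)
logger.meters['clean_score'].update_list([0.12, 0.30, 0.25])  # one batch of scores
raw_scores = list(logger.meters['clean_score'].deque)  # values kept for post-processing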
Example #2
    def get_pred_label(self,
                       img: torch.Tensor,
                       logger: MetricLogger = None) -> bool:
        r"""Get the prediction label of one certain image (poisoned or not).

        Args:
            img (torch.Tensor): Image tensor (on GPU) with shape ``(C, H, W)``.
            logger (trojanzoo.utils.logger.MetricLogger):
                output logger.
                Defaults to ``None``.

        Returns:
            bool: Whether the image tensor :attr:`img` is poisoned.
        """
        # get dominant color
        dom_c = self.get_dominant_color(img).unsqueeze(-1).unsqueeze(
            -1)  # (C, 1, 1)

        # generate random numbers
        height, width = img.shape[-2:]
        pos_height = torch.randint(low=0,
                                   high=height - self.mark_size[0],
                                   size=[self.neo_sample_num, 1])
        pos_width = torch.randint(low=0,
                                  high=width - self.mark_size[1],
                                  size=[self.neo_sample_num, 1])
        pos_list = torch.stack([pos_height, pos_width],
                               dim=1)  # (neo_sample_num, 2)
        # block potential triggers on _input
        block_input = img.repeat(self.neo_sample_num, 1, 1,
                                 1)  # (neo_sample_num, C, H, W)
        for i in range(self.neo_sample_num):
            x = pos_list[i][0]
            y = pos_list[i][1]
            block_input[i, :, x:x + self.mark_size[0],
                        y:y + self.mark_size[1]] = dom_c
        # get potential triggers
        org_class = self.model.get_class(img.unsqueeze(0)).item()  # (1)
        block_class = self.model.get_class(
            block_input).cpu()  # (neo_sample_num)

        # confirm triggers
        pos_pairs = pos_list[block_class != org_class]  # (*, 2)
        for pos in pos_pairs:
            self.attack.mark.mark_height_offset = pos[0]
            self.attack.mark.mark_width_offset = pos[1]
            self.attack.mark.mark.fill_(1.0)
            self.attack.mark.mark[:-1] = img[...,
                                             pos[0]:pos[0] + self.mark_size[0],
                                             pos[1]:pos[1] + self.mark_size[1]]
            cls_diff = self.get_cls_diff()
            if cls_diff > self.neo_asr_threshold:
                jaccard_idx = mask_jaccard(self.attack.mark.get_mask(),
                                           self.real_mask,
                                           select_num=self.select_num)
                if logger is not None:
                    logger.update(cls_diff=cls_diff, jaccard_idx=jaccard_idx)
                return True
        return False
Example #3
def compare(module1: nn.Module,
            module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate',
            indent=0,
            verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor,
                                             torch.Tensor]] = None,
            criterion: Callable[[torch.Tensor, torch.Tensor],
                                torch.Tensor] = nn.CrossEntropyLoss(),
            **kwargs) -> float:
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    logger = MetricLogger()
    logger.create_meters(loss=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            loader_epoch = tqdm(loader_epoch, leave=False)
        loader_epoch = logger.log_every(loader_epoch,
                                        header=header,
                                        indent=indent)
    with torch.no_grad():  # inference only, as in the variant in Example #8
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1: torch.Tensor = module1(_input)
            _output2: torch.Tensor = module2(_input)
            loss = criterion(_output1, _output2.softmax(1)).item()
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=loss)
    return logger.meters['loss'].global_avg
Example #4
def validate(module: nn.Module,
             num_classes: int,
             loader: torch.utils.data.DataLoader,
             print_prefix: str = 'Validate',
             indent: int = 0,
             verbose: bool = True,
             get_data_fn: Callable[..., tuple[torch.Tensor,
                                              torch.Tensor]] = None,
             forward_fn: Callable[..., torch.Tensor] = None,
             loss_fn: Callable[..., torch.Tensor] = None,
             writer=None,
             main_tag: str = 'valid',
             tag: str = '',
             _epoch: int = None,
             accuracy_fn: Callable[..., list[float]] = None,
             **kwargs) -> tuple[float, float]:
    r"""Evaluate the model.

    Returns:
        (float, float): Accuracy and loss.
    """
    module.eval()
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    accuracy_fn = accuracy_fn or accuracy
    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        loader_epoch = logger.log_every(loader,
                                        header=header,
                                        tqdm_header='Batch',
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = forward_fn(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
    acc, loss = (logger.meters['top1'].global_avg,
                 logger.meters['loss'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc},
                           global_step=_epoch)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss},
                           global_step=_epoch)
    return acc, loss
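accuracy_fn defaults to trojanzoo's accuracy helper, which returns percentage top-k accuracies. A hypothetical stand-in implementing the standard top-k formula (not necessarily the library's exact code), useful for reading the acc1, acc5 handling above:

import torch

def topk_accuracy(output: torch.Tensor, target: torch.Tensor,
                  topk: tuple[int, ...] = (1, 5)) -> list[float]:
    # standard top-k: a sample counts as correct if the true label
    # appears among the k highest-scoring classes
    pred = output.topk(max(topk), dim=1).indices      # (N, maxk)
    correct = pred.eq(target.unsqueeze(1))            # (N, maxk)
    return [correct[:, :k].any(dim=1).float().mean().item() * 100
            for k in topk]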
Example #5
    def validate_mask_generator(self):
        loader = self.dataset.loader['valid']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        idx = torch.randperm(len(dataset))
        pos = 0

        print_prefix = 'Validate'
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        for data in logger.log_every(loader, header=header):
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)

            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()

            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
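Examples #5 and #7 share the same two regularizers: a diversity loss that divides input distance by mask distance (so dissimilar inputs are pushed toward dissimilar masks) and a norm loss that penalizes mask values above mask_density. A self-contained sketch with dummy tensors; the lambda_* and mask_density values are hypothetical:

import torch

lambda_div, lambda_norm, mask_density = 1.0, 100.0, 0.032  # hypothetical values

x1, x2 = torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32)  # two input batches
m1, m2 = torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32)  # their generated masks

input_dist = (x1 - x2).flatten(1).norm(p=2, dim=1)        # (8,)
mask_dist = (m1 - m2).flatten(1).norm(p=2, dim=1) + 1e-5  # (8,), eps avoids div by 0
loss_div = input_dist.div(mask_dist).mean()     # small when masks differ as inputs do
loss_norm = m1.sub(mask_density).relu().mean()  # penalize density above the target
loss = lambda_norm * loss_norm + lambda_div * loss_div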
Example #6
    def get_pred_labels(self) -> torch.Tensor:
        logger = MetricLogger(meter_length=40)
        str_format = '{global_avg:5.3f} ({min:5.3f}, {max:5.3f})'
        logger.create_meters(cls_diff=str_format, jaccard_idx=str_format)
        test_set = TensorListDataset(self.test_input, self.test_label)
        test_loader = self.dataset.get_dataloader(mode='valid',
                                                  dataset=test_set,
                                                  batch_size=1)
        clean_list = []
        poison_list = []
        for data in logger.log_every(test_loader):
            _input: torch.Tensor = data[0]
            _input = _input.to(env['device'], non_blocking=True)
            trigger_input = self.attack.add_mark(_input)
            clean_list.append(self.get_pred_label(_input[0], logger=logger))
            poison_list.append(
                self.get_pred_label(trigger_input[0], logger=logger))
        return torch.as_tensor(clean_list + poison_list, dtype=torch.bool)
Example #7
    def train_mask_generator(self, verbose: bool = True):
        r"""Train :attr:`self.mask_generator`."""
        optimizer = torch.optim.Adam(self.mask_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_mask_epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        print_prefix = 'Mask Epoch'
        for _epoch in range(self.train_mask_epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
            header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
            self.mask_generator.train()
            for data in logger.log_every(loader, header=header) if verbose else loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size

                _mask = self.get_mask(_input)
                _mask2 = self.get_mask(_input2)

                input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
                mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_div = input_dist.div(mask_dist).mean()
                loss_norm = _mask.sub(self.mask_density).relu().mean()

                loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
                loss.backward()
                optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
            lr_scheduler.step()
            self.mask_generator.eval()
            if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0 or _epoch == self.train_mask_epochs):
                self.validate_mask_generator()
        optimizer.zero_grad()
Example #8
def compare(module1: nn.Module,
            module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate',
            indent=0,
            verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor,
                                             torch.Tensor]] = None,
            **kwargs) -> float:
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch,
                                        header=header,
                                        indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg
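The hand-rolled cross_entropy above computes H(softmax(p), softmax(q)); since PyTorch 1.10, torch.nn.functional.cross_entropy accepts probabilistic targets and gives the same value, which is what Example #3 relies on. A quick illustrative check:

import torch
import torch.nn.functional as F

p = torch.randn(4, 10)  # logits of module1
q = torch.randn(4, 10)  # logits of module2
manual = (-p.softmax(1) * q.log_softmax(1)).sum(1).mean()
builtin = F.cross_entropy(q, p.softmax(1))  # soft-target form, PyTorch >= 1.10
assert torch.allclose(manual, builtin, atol=1e-5)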
Example #9
    def _validate(self,
                  full=True,
                  print_prefix='Validate',
                  indent=0,
                  verbose=True,
                  loader: torch.utils.data.DataLoader = None,
                  get_data_fn: Callable[..., tuple[torch.Tensor,
                                                   torch.Tensor]] = None,
                  loss_fn: Callable[..., torch.Tensor] = None,
                  writer=None,
                  main_tag: str = 'valid',
                  tag: str = '',
                  _epoch: int = None,
                  **kwargs) -> tuple[float, float]:
        self.eval()
        if loader is None:
            loader = (self.dataset.loader['valid'] if full
                      else self.dataset.loader['valid2'])
        get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
        loss_fn = loss_fn if loss_fn is not None else self.loss
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader
        if verbose:
            header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
            header = header.ljust(
                max(len(print_prefix), 30) + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch,
                                            header=header,
                                            indent=indent)
        for data in loader_epoch:
            _input, _label = get_data_fn(data, mode='valid', **kwargs)
            with torch.no_grad():
                _output = self(_input)
                loss = float(loss_fn(_input, _label, _output=_output,
                                     **kwargs))
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(loss, batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None and _epoch is not None and main_tag:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch)
        return loss, acc
Example #10
    def _get_asr_result(self, marks: torch.Tensor) -> torch.Tensor:
        r"""Get attack succ rate result for each mark in :attr:`marks`.

        Args:
            marks (torch.Tensor): Marks tensor with shape ``(N, C, H, W)``.

        Returns:
            torch.Tensor: Attack success rate tensor with shape ``(N)``.
        """
        asr_list = []
        logger = MetricLogger(meter_length=35, indent=4)
        logger.create_meters(asr='{median:.3f} ({min:.3f}  {max:.3f})')
        for mark in logger.log_every(marks, header='mark', tqdm_header='mark'):
            self.mark.mark[:-1] = mark
            asr, _ = self.model._validate(get_data_fn=self.get_data, keep_org=False,
                                          poison_label=True, verbose=False,
                                          loader=self.loader_valid)
            # Original code considers an untargeted-like attack scenario.
            # org_acc, _ = self.model._validate(get_data_fn=self.get_data, keep_org=False,
            #                                   poison_label=False, verbose=False)
            # asr = 100 - org_acc
            logger.update(asr=asr)
            asr_list.append(asr)
        return torch.tensor(asr_list)
Example #11
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               validate_interval: int = 1, save: bool = False,
               verbose: bool = True, **kwargs):
        if verbose:
            print('train mask generator')
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_()
        self.model.requires_grad_(False)
        self.train_mask_generator(verbose=verbose)
        if verbose:
            print()
            print('train mark generator and model')

        self.mark_generator.requires_grad_()
        self.mask_generator.requires_grad_(False)
        if not self.natural:
            params: list[nn.Parameter] = []
            for param_group in optimizer.param_groups:
                params.extend(param_group['params'])
            self.model.activate_params(params)

        mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mark_optimizer, T_max=epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, ce=None)

        best_validate_result = None
        best_asr = 0.0
        if validate_interval != 0:
            best_validate_result = self.validate_fn(verbose=verbose)
            best_asr = best_validate_result[0]
        for _epoch in range(epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            if not self.natural:
                self.model.train()
            self.mark_generator.train()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            for data in logger.log_every(loader, header=header) if verbose else loader:
                if not self.natural:
                    optimizer.zero_grad()
                mark_optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size
                final_input, final_label = _input.clone(), _label.clone()

                # generate trigger input
                trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
                trigger_int = int(trigger_int)
                if random.uniform(0, 1) < trigger_dec:
                    trigger_int += 1
                x = _input[:trigger_int]
                trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
                trigger_input = x + trigger_mask * (trigger_mark - x)
                final_input[:trigger_int] = trigger_input
                final_label[:trigger_int] = self.target_class

                # generate cross input
                cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
                cross_int = int(cross_int)
                if random.uniform(0, 1) < cross_dec:
                    cross_int += 1
                x = _input[trigger_int:trigger_int + cross_int]
                x2 = _input2[trigger_int:trigger_int + cross_int]
                cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
                cross_input = x + cross_mask * (cross_mark - x)
                final_input[trigger_int:trigger_int + cross_int] = cross_input

                # div loss
                if len(trigger_input) <= len(cross_input):
                    length = len(trigger_input)
                    cross_input = cross_input[:length]
                    cross_mark = cross_mark[:length]
                    cross_mask = cross_mask[:length]
                else:
                    length = len(cross_input)
                    trigger_input = trigger_input[:length]
                    trigger_mark = trigger_mark[:length]
                    trigger_mask = trigger_mask[:length]
                input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
                mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_ce = self.model.loss(final_input, final_label)
                loss_div = input_dist.div(mark_dist).mean()

                loss = loss_ce + self.lambda_div * loss_div
                loss.backward()
                if not self.natural:
                    optimizer.step()
                mark_optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), ce=loss_ce.item())
            if not self.natural and lr_scheduler:
                lr_scheduler.step()
            mark_scheduler.step()
            if not self.natural:
                self.model.eval()
            self.mark_generator.eval()
            if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
                validate_result = self.validate_fn(verbose=verbose)
                cur_asr = validate_result[0]
                if cur_asr >= best_asr:
                    best_validate_result = validate_result
                    best_asr = cur_asr
                    if save:
                        self.save()
        if not self.natural:
            optimizer.zero_grad()
        mark_optimizer.zero_grad()
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_(False)
        self.model.requires_grad_(False)
        return best_validate_result
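The trigger and cross batch sizing above uses math.modf for stochastic rounding: the integer part of len(_label) * percent is taken deterministically and the fractional part is realized with that probability, so the expected poisoned count per batch matches the target ratio exactly. A minimal restatement of the trick:

import math
import random

def stochastic_count(batch_size: int, ratio: float) -> int:
    frac, whole = math.modf(batch_size * ratio)
    return int(whole) + (1 if random.uniform(0, 1) < frac else 0)

# over many batches, stochastic_count(64, 0.1) averages to 6.4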
Example #12
    def optimize_mark(self,
                      label: int,
                      loader: Iterable = None,
                      logger_header: str = '',
                      verbose: bool = True,
                      **kwargs) -> tuple[torch.Tensor, float]:
        r"""
        Args:
            label (int): The class label to optimize.
            loader (collections.abc.Iterable):
                Data loader to optimize trigger.
                Defaults to ``self.dataset.loader['train']``.
            logger_header (str): Header string of logger.
                Defaults to ``''``.
            verbose (bool): Whether to use logger for output.
                Defaults to ``True``.
            **kwargs: Keyword arguments passed to :meth:`loss()`.

        Returns:
            (torch.Tensor, float):
                Optimized mark tensor with shape ``(C + 1, H, W)``
                and the corresponding loss value.
        """
        atanh_mark = torch.randn_like(self.attack.mark.mark,
                                      requires_grad=True)
        optimizer = optim.Adam([atanh_mark],
                               lr=self.defense_remask_lr,
                               betas=(0.5, 0.9))
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.defense_remask_epoch)
        optimizer.zero_grad()
        loader = loader or self.dataset.loader['train']

        # best optimization results
        norm_best: float = float('inf')
        mark_best: torch.Tensor = None
        loss_best: float = None

        logger = MetricLogger(indent=4)
        logger.create_meters(
            loss='{last_value:.3f}',
            acc='{last_value:.3f}',
            norm='{last_value:.3f}',
            entropy='{last_value:.3f}',
        )
        batch_logger = MetricLogger()
        batch_logger.create_meters(loss=None, acc=None, entropy=None)

        iterator = range(self.defense_remask_epoch)
        if verbose:
            iterator = logger.log_every(iterator, header=logger_header)
        for _ in iterator:
            batch_logger.reset()
            for data in loader:
                self.attack.mark.mark = tanh_func(atanh_mark)  # (c+1, h, w)
                _input, _label = self.model.get_data(data)
                trigger_input = self.attack.add_mark(_input)
                trigger_label = label * torch.ones_like(_label)
                trigger_output = self.model(trigger_input)

                batch_acc = trigger_label.eq(
                    trigger_output.argmax(1)).float().mean()
                batch_entropy = self.loss(_input,
                                          _label,
                                          target=label,
                                          trigger_output=trigger_output,
                                          **kwargs)
                batch_norm: torch.Tensor = self.attack.mark.mark[-1].norm(p=1)
                batch_loss = batch_entropy + self.cost * batch_norm

                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                batch_size = _label.size(0)
                batch_logger.update(n=batch_size,
                                    loss=batch_loss.item(),
                                    acc=batch_acc.item(),
                                    entropy=batch_entropy.item())
            lr_scheduler.step()
            self.attack.mark.mark = tanh_func(atanh_mark)  # (c+1, h, w)

            # check to save best mask or not
            loss = batch_logger.meters['loss'].global_avg
            acc = batch_logger.meters['acc'].global_avg
            norm = float(self.attack.mark.mark[-1].norm(p=1))
            entropy = batch_logger.meters['entropy'].global_avg
            if norm < norm_best:
                norm_best = norm
                mark_best = self.attack.mark.mark.detach().clone()
                loss_best = loss
                logger.update(loss=loss, acc=acc, norm=norm, entropy=entropy)

            if self.check_early_stop(loss=loss,
                                     acc=acc,
                                     norm=norm,
                                     entropy=entropy):
                print('early stop')
                break
        atanh_mark.requires_grad_(False)
        self.attack.mark.mark = mark_best
        return mark_best, loss_best
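optimize_mark keeps the mark unconstrained in atanh space and squashes it through tanh_func after every optimizer step, so Adam can move freely while the stored mark always holds valid pixel values. A sketch of the reparameterization, assuming tanh_func is the usual (tanh(x) + 1) / 2 mapping into [0, 1]:

import torch
from torch import optim

def tanh_func(x: torch.Tensor) -> torch.Tensor:
    return (x.tanh() + 1) * 0.5  # assumed definition: squashes R into [0, 1]

atanh_mark = torch.randn(4, 32, 32, requires_grad=True)  # (C + 1, H, W)
optimizer = optim.Adam([atanh_mark], lr=0.1, betas=(0.5, 0.9))

mark = tanh_func(atanh_mark)  # always a valid image, gradients flow to atanh_mark
loss = mark[-1].norm(p=1)     # e.g. shrink the mask channel, as in the code above
loss.backward()
optimizer.step()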
Example #13
def train(module: nn.Module,
          num_classes: int,
          epochs: int,
          optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train',
          start_epoch: int = 0,
          resume: int = 0,
          validate_interval: int = 10,
          save: bool = False,
          amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None,
          file_path: str = None,
          folder_path: str = None,
          suffix: str = None,
          writer=None,
          main_tag: str = 'train',
          tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True,
          output_freq: str = 'iter',
          indent: int = 0,
          change_train_eval: bool = True,
          lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model"""
    if epochs <= 0:
        return
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or (
        lambda _input, _label, _output=None: F.cross_entropy(
            _output if _output is not None else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None,
                                           tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose,
                                           indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=header,
                                    tqdm_header='Epoch',
                                    indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     _epoch=_epoch,
                     epochs=epochs,
                     start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train,
                                            header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()

            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)

            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose,
                                          indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f}    '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path,
                            folder_path=folder_path,
                            suffix=suffix,
                            verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
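Note the ordering in the AMP branch above: scaler.unscale_ must run before gradient clipping, otherwise the max-norm would be applied to loss-scaled gradients. The canonical torch.cuda.amp step that the branch follows, extracted as a hypothetical helper:

import torch
from torch import nn

def amp_step(loss: torch.Tensor,
             optimizer: torch.optim.Optimizer,
             scaler: torch.cuda.amp.GradScaler,
             grad_clip: float = None):
    # loss is assumed to be computed under torch.cuda.amp.autocast()
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    if grad_clip is not None:
        scaler.unscale_(optimizer)  # bring gradients back to true scale
        params = [p for group in optimizer.param_groups for p in group['params']]
        nn.utils.clip_grad_norm_(params, grad_clip)
    scaler.step(optimizer)  # skips the update if gradients overflowed
    scaler.update()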
Example #14
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        loader_train = (loader_train if loader_train is not None
                        else self.dataset.loader['train'])
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
        #     iter_fn = getattr(self, 'iter_fn')
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False
        if amp:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)
        for _epoch in range(epoch):
            _epoch += 1
            if callable(epoch_fn):
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                _iter = _epoch * len(loader_train) + i
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                    optimizer.step()
                optimizer.zero_grad()
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache()  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss, acc = (logger.meters['loss'].global_avg,
                         logger.meters['top1'].global_avg)
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(
                                **ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()
Example #15
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--voc_root', default='~/voc')
    parser.add_argument('--tar_path', default='~/reflection.tar')
    kwargs = parser.parse_args().__dict__
    voc_root: str = kwargs['voc_root']
    tar_path: str = kwargs['tar_path']

    print('get image paths')
    datasets = [
        torchvision.datasets.VOCDetection(voc_root,
                                          year=year,
                                          image_set=image_set,
                                          download=True)
        for year, image_set in sets
    ]
    background_paths = get_img_paths(datasets,
                                     positive_class=background_class,
                                     negative_class=reflect_class)
    reflect_paths = get_img_paths(datasets,
                                  positive_class=reflect_class,
                                  negative_class=background_class)
    print()
    print('background: ', len(background_paths))
    print('reflect: ', len(reflect_paths))
    print()
    print('load images')
    reflect_imgs = [read_tensor(fp) for fp in reflect_paths]
    background_imgs = [
        read_tensor(fp) for i, fp in enumerate(background_paths)
        if i < NUM_ATTACK
    ]

    print('writing tar file: ', tar_path)
    tf = tarfile.open(tar_path, mode='w')
    trojanzoo.environ.create(color=True, tqdm=True)
    logger = MetricLogger(meter_length=35)
    logger.create_meters(
        reflect_num='{count:3d}',
        reflect_mean='{global_avg:.3f} ({min:.3f}  {max:.3f})',
        diff_mean='{global_avg:.3f} ({min:.3f}  {max:.3f})',
        blended_max='{global_avg:.3f} ({min:.3f}  {max:.3f})',
        ssim='{global_avg:.3f} ({min:.3f}  {max:.3f})')
    candidates: set[int] = set()
    for background_img in logger.log_every(background_imgs):
        for i, reflect_img in enumerate(reflect_imgs):
            if i in candidates:
                continue
            blended, background_layer, reflection_layer = blend_images(
                background_img, reflect_img, ghost_rate=0.39)
            reflect_mean: float = reflection_layer.mean().item()
            diff_mean: float = (blended - reflection_layer).mean().item()
            blended_max: float = blended.max().item()
            logger.update(reflect_mean=reflect_mean,
                          diff_mean=diff_mean,
                          blended_max=blended_max)
            if reflect_mean < 0.8 * diff_mean and blended_max > 0.1:
                ssim: float = skimage.metrics.structural_similarity(
                    blended.numpy(), background_layer.numpy(), channel_axis=0)
                logger.update(ssim=ssim)
                if 0.7 < ssim < 0.85:
                    logger.update(reflect_num=1)
                    candidates.add(i)
                    filename = os.path.basename(reflect_paths[i])
                    bytes_io = io.BytesIO()
                    format = os.path.splitext(filename)[1][1:].lower().replace(
                        'jpg', 'jpeg')
                    F.to_pil_image(reflection_layer).save(bytes_io,
                                                          format=format)
                    bytes_data = bytes_io.getvalue()
                    tarinfo = tarfile.TarInfo(name=filename)
                    tarinfo.size = len(bytes_data)
                    tf.addfile(tarinfo, io.BytesIO(bytes_data))
                    break
    tf.close()
Example #16
    def attack(self, epochs: int, save: bool = False, **kwargs):
        # train generator
        match self.generator_mode:
            case 'resnet':
                resnet_model = torchvision.models.resnet18(pretrained=True).to(device=env['device'])
                model_extractor = nn.Sequential(*list(resnet_model.children())[:5])
            case _:
                resnet_model = trojanvision.models.create('resnet18_comp',
                                                          dataset=self.dataset,
                                                          pretrained=True)
                model_extractor = nn.Sequential(*list(resnet_model._model.features.children())[:4])
        model_extractor.requires_grad_(False)
        model_extractor.eval()  # frozen feature extractor: keep BN statistics fixed

        self.generator.train()
        if self.generator_mode == 'default':
            self.generator.requires_grad_()
            parameters = self.generator.parameters()
        else:
            self.generator.bottleneck.requires_grad_()
            self.generator.decoder.requires_grad_()
            parameters = self.generator[1:].parameters()
        optimizer = torch.optim.Adam(parameters, lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_generator_epochs,
            eta_min=1e-5)

        logger = MetricLogger()
        logger.create_meters(loss=None, acc=None)
        for _epoch in range(self.train_generator_epochs):
            _epoch += 1
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, self.train_generator_epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader = logger.log_every(self.dataset.loader['train'], header=header,
                                      tqdm_header='Batch')
            for data in loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                adv_input = (self.generator(_input) + 1) / 2

                _feats = model_extractor(_input)
                adv_feats = model_extractor(adv_input)

                loss = F.l1_loss(_feats, self.noise_coeff * adv_feats)
                loss.backward()
                optimizer.step()
                batch_size = len(_label)
                with torch.no_grad():
                    org_class = self.model.get_class(_input)
                    adv_class = self.model.get_class(adv_input)
                    acc = (org_class == adv_class).float().sum().item() * 100.0 / batch_size
                logger.update(n=batch_size, loss=loss.item(), acc=acc)
            lr_scheduler.step()
        self.save_generator()
        optimizer.zero_grad()
        self.generator.eval()
        self.generator.requires_grad_(False)

        self.mark.mark[:-1] = (self.generator(self.mark.mark[:-1].unsqueeze(0))[0] + 1) / 2
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs, save=save, **kwargs)
Example #17
def gen_reflect_imgs(tar_path: str, voc_root: str, num_attack: int = 160,
                     reflect_class: set[str] = {'cat'},
                     background_class: set[str] = {'person'}):
    r"""Generate a tar file containing reflect images.

    Args:
        tar_path (str): Tar file path to save.
        voc_root (str): VOC dataset root path.
        num_attack (int): Number of reflect images to generate.
        reflect_class (set[str]): Set of reflect classes.
        background_class (set[str]): Set of background classes.
    """
    print('get image paths')
    if not os.path.isdir(voc_root):
        os.makedirs(voc_root)
    datasets = [torchvision.datasets.VOCDetection(voc_root, year=year, image_set=image_set,
                                                  download=True) for year, image_set in sets]
    background_paths = _get_img_paths(datasets, positive_class=background_class, negative_class=reflect_class)
    reflect_paths = _get_img_paths(datasets, positive_class=reflect_class, negative_class=background_class)
    print()
    print('background: ', len(background_paths))
    print('reflect: ', len(reflect_paths))
    print()
    print('load images')
    reflect_imgs = [read_tensor(fp) for fp in reflect_paths]

    print('writing tar file: ', tar_path)
    tf = tarfile.open(tar_path, mode='w')
    logger = MetricLogger(meter_length=35)
    logger.create_meters(reflect_num=f'[ {{count:3d}} / {num_attack:3d} ]',
                         reflect_mean='{global_avg:.3f} ({min:.3f}  {max:.3f})',
                         diff_mean='{global_avg:.3f} ({min:.3f}  {max:.3f})',
                         blended_max='{global_avg:.3f} ({min:.3f}  {max:.3f})',
                         ssim='{global_avg:.3f} ({min:.3f}  {max:.3f})')
    candidates: set[int] = set()
    for fp in logger.log_every(background_paths):
        background_img = read_tensor(fp)
        for i, reflect_img in enumerate(reflect_imgs):
            if i in candidates:
                continue
            blended, background_layer, reflection_layer = blend_images(
                background_img, reflect_img, ghost_rate=0.39)
            reflect_mean: float = reflection_layer.mean().item()
            diff_mean: float = (blended - reflection_layer).mean().item()
            blended_max: float = blended.max().item()
            logger.update(reflect_mean=reflect_mean, diff_mean=diff_mean, blended_max=blended_max)
            if reflect_mean < 0.8 * diff_mean and blended_max > 0.1:
                ssim: float = skimage.metrics.structural_similarity(
                    blended.numpy(), background_layer.numpy(), channel_axis=0)
                logger.update(ssim=ssim)
                if 0.7 < ssim < 0.85:
                    logger.update(reflect_num=1)
                    candidates.add(i)
                    filename = os.path.basename(reflect_paths[i])
                    bytes_io = io.BytesIO()
                    format = os.path.splitext(filename)[1][1:].lower().replace('jpg', 'jpeg')
                    F.to_pil_image(reflection_layer).save(bytes_io, format=format)
                    bytes_data = bytes_io.getvalue()
                    tarinfo = tarfile.TarInfo(name=filename)
                    tarinfo.size = len(bytes_data)
                    tf.addfile(tarinfo, io.BytesIO(bytes_data))
                    break
        if len(candidates) == num_attack:
            break
    else:
        raise RuntimeError('Cannot generate enough images')
    tf.close()
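Both reflection scripts (Examples #15 and #17) apply the same acceptance test before archiving a reflection layer; pulled out as a predicate with the thresholds used above (the helper name is illustrative):

import torch

def accept_reflection(blended: torch.Tensor,
                      reflection_layer: torch.Tensor,
                      ssim: float) -> bool:
    reflect_mean = reflection_layer.mean().item()
    diff_mean = (blended - reflection_layer).mean().item()
    # the reflection must stay dim relative to the blend, and the blend visible
    if not (reflect_mean < 0.8 * diff_mean and blended.max().item() > 0.1):
        return False
    # blend moderately similar to the background: reflection present, not dominant
    return 0.7 < ssim < 0.85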