예제 #1
0
    def __init__(self, data_shape: list[int], edge_color: Union[str, torch.Tensor] = 'auto',
                 mark_path: str = 'trojanzoo/data/mark/square_white.png', mark_alpha: float = 0.0,
                 mark_height: int = None, mark_width: int = None,
                 height_offset: int = 0, width_offset: int = 0,
                 random_pos=False, random_init=False, mark_distributed=False,
                 add_mark_fn=None, **kwargs):
        '''
        Build a watermark mark for inputs of shape ``data_shape``.

        Two construction modes are supported:

        * ``mark_distributed=True``: ``self.mark`` is a uniformly random
          tensor of shape ``data_shape``. ``self.mask`` is a bool mask over
          the last two dimensions of ``data_shape`` in which
          ``mark_height * mark_width`` distinct positions are set, sampled
          without replacement.
        * otherwise: the image at ``mark_path`` is loaded through
          ``self.load_img`` at ``mark_height`` x ``mark_width`` with
          ``data_shape[0]`` channels. It is converted with ``byte2float``,
          and its original mask and alpha mask are derived via
          ``self.get_edge_color`` / ``self.org_mask_mark``. When
          ``random_pos`` is False, the placed mark, mask and alpha mask are
          built immediately with ``self.mask_mark()``.

        Args:
            data_shape: Shape of the inputs the mark is applied to. Index 0 is
                passed as the channel count and the last two entries are used
                as the spatial size (presumably ``[C, H, W]`` -- confirm).
            edge_color: Forwarded to ``self.get_edge_color`` (image mode only).
            mark_path: Path of the mark image (image mode only).
            mark_alpha: Stored as ``self.mark_alpha``. In distributed mode the
                alpha mask is ``mask * (1 - mark_alpha)``; in image mode it is
                forwarded to ``self.org_mask_mark``.
            mark_height: Mark height in pixels. Required despite the ``None``
                default; must be positive.
            mark_width: Mark width in pixels. Required despite the ``None``
                default; must be positive.
            height_offset: Fixed vertical placement. Only stored when not
                distributed and ``random_pos`` is False.
            width_offset: Fixed horizontal placement. Same conditions as
                ``height_offset``.
            random_pos: If True (image mode), no offsets are stored and no
                placed mark/mask is built here.
            random_init: If True (image mode), the loaded mark is passed
                through ``self.random_init_mark`` together with its mask.
            mark_distributed: Select the scattered-pixel mode described above.
            add_mark_fn: Stored as ``self.add_mark_fn``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``mark_height`` or ``mark_width`` is missing or not
                positive.
        '''
        self.param_list: dict[str, list[str]] = {}
        self.param_list['mark'] = ['mark_path', 'data_shape', 'edge_color',
                                   'mark_alpha', 'mark_height', 'mark_width',
                                   'random_pos', 'random_init']
        # Validate with an explicit raise instead of ``assert``. An assert is
        # stripped under ``python -O``, and comparing the ``None`` defaults
        # against 0 would otherwise surface as an opaque TypeError.
        if (mark_height is None or mark_width is None
                or mark_height <= 0 or mark_width <= 0):
            raise ValueError('mark_height and mark_width must be positive integers, '
                             f'got mark_height={mark_height!r}, mark_width={mark_width!r}')
        # --------------------------------------------------- #

        # WaterMark Image Parameters
        self.mark_alpha: float = mark_alpha
        self.data_shape: list[int] = data_shape
        self.mark_path: str = mark_path
        self.mark_height: int = mark_height
        self.mark_width: int = mark_width
        self.random_pos = random_pos
        self.random_init = random_init
        self.mark_distributed = mark_distributed
        self.add_mark_fn: Callable = add_mark_fn
        # --------------------------------------------------- #

        if self.mark_distributed:
            # Random-valued mark over the whole input. The mask keeps
            # mark_height * mark_width distinct positions, sampled without
            # replacement from the flattened last two dimensions.
            self.mark = torch.rand(data_shape, dtype=torch.float, device=env['device'])
            mask = torch.zeros(data_shape[-2:], dtype=torch.bool, device=env['device']).flatten()
            idx = np.random.choice(len(mask), self.mark_height * self.mark_width, replace=False).tolist()
            mask[idx] = True
            mask = mask.view(data_shape[-2:])
            self.mask = mask
            # Bool mask times a float yields a float tensor:
            # (1 - mark_alpha) at selected positions, 0 elsewhere.
            self.alpha_mask = self.mask * (1 - mark_alpha)
            # No image is loaded in this mode, so there is no edge color.
            self.edge_color = None
        else:
            org_mark_img: Image.Image = self.load_img(img_path=mark_path,
                                                      height=mark_height, width=mark_width, channel=data_shape[0])
            self.org_mark: torch.Tensor = byte2float(org_mark_img)
            self.edge_color: torch.Tensor = self.get_edge_color(
                self.org_mark, data_shape, edge_color)
            self.org_mask, self.org_alpha_mask = self.org_mask_mark(self.org_mark, self.edge_color, self.mark_alpha)
            if random_init:
                self.org_mark = self.random_init_mark(self.org_mark, self.org_mask)
            if not random_pos:
                # Fixed placement: record the offsets (also listed in
                # param_list) and build the placed mark/mask right away.
                self.param_list['mark'].extend(['height_offset', 'width_offset'])
                self.height_offset: int = height_offset
                self.width_offset: int = width_offset
                self.mark, self.mask, self.alpha_mask = self.mask_mark()
예제 #2
0
    def get_mark(self, conv_ref_img: torch.Tensor):
        '''
        Install a convolved reflection image as the current mark.

        ``conv_ref_img`` is converted to a PIL image with ``to_pil_image`` and
        resized to ``(self.mark.mark_width, self.mark.mark_height)``. It is
        then converted with ``byte2float`` and stored as
        ``self.mark.org_mark``. Next, ``self.mark.org_mask`` and
        ``self.mark.org_alpha_mask`` are recomputed from it. Finally,
        ``self.mark.mark``, ``self.mark.mask`` and ``self.mark.alpha_mask``
        are rebuilt at the mark's ``height_offset`` / ``width_offset``.

        Args:
            conv_ref_img: Convolved reflection image. Presumably in a layout
                ``to_pil_image`` accepts (e.g. a CHW tensor) -- confirm
                against callers.
        '''
        org_mark_img: Image.Image = to_pil_image(conv_ref_img)
        # LANCZOS is the filter that ANTIALIAS aliased. ANTIALIAS was removed
        # in Pillow 10; LANCZOS works on old and new Pillow releases alike.
        org_mark_img = org_mark_img.resize(
            (self.mark.mark_width, self.mark.mark_height), Image.LANCZOS)
        self.mark.org_mark = byte2float(org_mark_img)

        # Recompute the unplaced masks from the new mark image, then place it.
        self.mark.org_mask, self.mark.org_alpha_mask = self.mark.org_mask_mark(
            self.mark.org_mark, self.mark.edge_color, self.mark.mark_alpha)
        self.mark.mark, self.mark.mask, self.mark.alpha_mask = self.mark.mask_mark(
            height_offset=self.mark.height_offset,
            width_offset=self.mark.width_offset)