def __init__(self, data_shape: list[int], edge_color: Union[str, torch.Tensor] = 'auto',
             mark_path: str = 'trojanzoo/data/mark/square_white.png', mark_alpha: float = 0.0,
             mark_height: int = None, mark_width: int = None,
             height_offset: int = 0, width_offset: int = 0,
             random_pos=False, random_init=False, mark_distributed=False,
             add_mark_fn=None, **kwargs):
    """Build the watermark (trigger mark) and its masks.

    Two modes are supported:

    * ``mark_distributed=True``: ``self.mark`` is a random tensor of shape
      ``data_shape`` and ``self.mask`` marks ``mark_height * mark_width``
      randomly chosen spatial positions (sampled without replacement).
    * otherwise: the mark image at ``mark_path`` is loaded via
      ``self.load_img`` and converted with ``byte2float``. It becomes
      ``self.org_mark``, and its masks are built with ``self.org_mask_mark``.
      The placed mark (``self.mark``/``self.mask``/``self.alpha_mask``) is
      only built here when ``random_pos`` is false.

    Args:
        data_shape: Shape of one input. ``data_shape[0]`` is passed to
            ``load_img`` as the channel count and ``data_shape[-2:]`` is used
            as the spatial mask shape (presumably ``[C, H, W]`` — confirm
            against callers).
        edge_color: Forwarded to ``self.get_edge_color``. Unused in
            distributed mode, where ``self.edge_color`` is set to ``None``.
        mark_path: Mark image file (non-distributed mode only).
        mark_alpha: Stored as ``self.mark_alpha``. In distributed mode the
            alpha mask is ``mask * (1 - mark_alpha)``.
        mark_height: Mark height; must be a positive integer.
        mark_width: Mark width; must be a positive integer.
        height_offset: Mark row offset. It is stored only in non-distributed
            mode with ``random_pos`` false.
        width_offset: Mark column offset. Like ``height_offset``, it is stored
            only in non-distributed mode with ``random_pos`` false.
        random_pos: If true, offsets are not stored and the placed mark is not
            built here.
        random_init: If true, ``self.org_mark`` is replaced by
            ``self.random_init_mark(self.org_mark, self.org_mask)``.
        mark_distributed: Select the distributed (random pixel) mode.
        add_mark_fn: Stored as ``self.add_mark_fn``.
        **kwargs: Accepted and not used by this method.

    Raises:
        ValueError: If ``mark_height`` or ``mark_width`` is missing or not
            positive.
    """
    self.param_list: dict[str, list[str]] = {}
    self.param_list['mark'] = ['mark_path', 'data_shape', 'edge_color',
                               'mark_alpha', 'mark_height', 'mark_width',
                               'random_pos', 'random_init']
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    # With the default `None` sizes, the old `None > 0` comparison raised an
    # opaque TypeError instead of a clear message.
    if (mark_height is None or mark_width is None
            or mark_height <= 0 or mark_width <= 0):
        raise ValueError('mark_height and mark_width must be positive integers, '
                         f'got mark_height={mark_height!r}, mark_width={mark_width!r}')
    # --------------------------------------------------- #

    # WaterMark Image Parameters
    self.mark_alpha: float = mark_alpha
    self.data_shape: list[int] = data_shape
    self.mark_path: str = mark_path
    self.mark_height: int = mark_height
    self.mark_width: int = mark_width
    self.random_pos = random_pos
    self.random_init = random_init
    self.mark_distributed = mark_distributed
    self.add_mark_fn: Callable = add_mark_fn
    # --------------------------------------------------- #

    if self.mark_distributed:
        # Random full-size mark; only mark_height * mark_width spatial
        # positions, chosen without replacement, are switched on in the mask.
        self.mark = torch.rand(data_shape, dtype=torch.float, device=env['device'])
        mask = torch.zeros(data_shape[-2:], dtype=torch.bool, device=env['device']).flatten()
        idx = np.random.choice(len(mask), self.mark_height * self.mark_width, replace=False).tolist()
        mask[idx] = True
        mask = mask.view(data_shape[-2:])
        self.mask = mask
        self.alpha_mask = self.mask * (1 - mark_alpha)
        self.edge_color = None
    else:
        org_mark_img: Image.Image = self.load_img(img_path=mark_path, height=mark_height,
                                                  width=mark_width, channel=data_shape[0])
        self.org_mark: torch.Tensor = byte2float(org_mark_img)
        self.edge_color: torch.Tensor = self.get_edge_color(
            self.org_mark, data_shape, edge_color)
        self.org_mask, self.org_alpha_mask = self.org_mask_mark(
            self.org_mark, self.edge_color, self.mark_alpha)
        if random_init:
            self.org_mark = self.random_init_mark(self.org_mark, self.org_mask)
        if not random_pos:
            # A fixed position is part of the mark's configuration, so record
            # the offsets and build the placed mark once, up front.
            self.param_list['mark'].extend(['height_offset', 'width_offset'])
            self.height_offset: int = height_offset
            self.width_offset: int = width_offset
            self.mark, self.mask, self.alpha_mask = self.mask_mark()
def get_mark(self, conv_ref_img: torch.Tensor):
    '''Install a convolved reflection image as the watermark in ``self.mark``.

    The input is already the same shape as the input images. This method
    converts it to a PIL image and resizes it to the mark's
    ``(mark_width, mark_height)``. It then converts it back with
    ``byte2float``, stores it as ``self.mark.org_mark``, and rebuilds the
    mark's masks and placed mark from it.

    Args:
        conv_ref_img: Convolved reflection image. It must be accepted by
            ``to_pil_image`` (presumably a CHW float tensor with values in
            [0, 1] — confirm against callers).
    '''
    mark = self.mark
    org_mark_img: Image.Image = to_pil_image(conv_ref_img)
    # Image.ANTIALIAS was removed in Pillow 10. It was an alias for LANCZOS,
    # which selects the same filter and exists in older Pillow releases too.
    org_mark_img = org_mark_img.resize(
        (mark.mark_width, mark.mark_height), Image.LANCZOS)
    mark.org_mark = byte2float(org_mark_img)
    mark.org_mask, mark.org_alpha_mask = mark.org_mask_mark(
        mark.org_mark, mark.edge_color, mark.mark_alpha)
    # NOTE(review): assumes the mark has height_offset / width_offset
    # attributes; a mark configured with random positions may not — confirm.
    mark.mark, mark.mask, mark.alpha_mask = mark.mask_mark(
        height_offset=mark.height_offset, width_offset=mark.width_offset)