def get_data(self, data: tuple[torch.Tensor, torch.Tensor], org: bool = False,
             keep_org: bool = True, poison_label: bool = True,
             **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    if org:
        _input, _label = self.model.get_data(data)
    else:
        _input, _label = self.attack.get_data(data=data, keep_org=keep_org,
                                              poison_label=poison_label, **kwargs)
    h, w = _input.shape[-2], _input.shape[-1]
    _input_list = []
    for single_input in _input:
        image = to_pil_image(single_input)
        # Downscale by resize_ratio, then restore the original resolution:
        # this low-pass filters each (possibly poisoned) input image.
        # Note: Image.ANTIALIAS is an alias of Image.LANCZOS and is removed
        # in Pillow >= 10.
        image = F.resize(image, (int(h * self.resize_ratio),
                                 int(w * self.resize_ratio)), Image.ANTIALIAS)
        image = F.resize(image, (h, w))
        _input_list.append(to_tensor(image))
    return torch.stack(_input_list), _label
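# A minimal standalone sketch (not part of the original class) of the
# down-/up-resize transform used above, so the effect of ``resize_ratio``
# can be checked in isolation. The helper name ``shrink_restore`` and the
# 0.5 ratio are illustrative assumptions.
#
#     import torch
#     import torchvision.transforms.functional as F
#     from torchvision.transforms.functional import to_pil_image, to_tensor
#
#     def shrink_restore(img: torch.Tensor, resize_ratio: float = 0.5) -> torch.Tensor:
#         h, w = img.shape[-2], img.shape[-1]
#         pil = to_pil_image(img)
#         pil = F.resize(pil, (int(h * resize_ratio), int(w * resize_ratio)))
#         pil = F.resize(pil, (h, w))
#         return to_tensor(pil)
#
#     x = torch.rand(3, 32, 32)                    # fake CIFAR-sized image
#     assert shrink_restore(x).shape == x.shape    # shape is preserved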
def get_mark(self, conv_ref_img: torch.Tensor):
    '''The input is a convolved reflection image that already has the same
    shape as any model input. This method resizes it to the legal watermark
    shape, assigns it to ``self.mark.mark``, and rebuilds the corresponding
    masks.
    '''
    org_mark_img: Image.Image = to_pil_image(conv_ref_img)
    org_mark_img = org_mark_img.resize(
        (self.mark.mark_width, self.mark.mark_height), Image.ANTIALIAS)
    self.mark.org_mark = byte2float(org_mark_img)
    self.mark.org_mask, self.mark.org_alpha_mask = self.mark.org_mask_mark(
        self.mark.org_mark, self.mark.edge_color, self.mark.mark_alpha)
    self.mark.mark, self.mark.mask, self.mark.alpha_mask = self.mark.mask_mark(
        height_offset=self.mark.height_offset, width_offset=self.mark.width_offset)
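# Usage sketch (illustrative assumption, not from the original file): once a
# candidate reflection image has been convolved down to the input resolution,
# ``get_mark`` installs it as the watermark trigger. ``attack`` stands for an
# instance of the class these methods belong to.
#
#     conv_ref_img = torch.rand(3, 32, 32)    # hypothetical convolved reflection
#     attack.get_mark(conv_ref_img)
#     # attack.mark.mark / .mask / .alpha_mask are now populated and will be
#     # stamped onto inputs at (height_offset, width_offset).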