def add_load_img(ax, img_path, title=None, transform=None, crop_res=None):
    """Load an image, optionally warp it, and display it on a matplotlib axis.

    Args:
        ax: matplotlib axis to draw on.
        img_path: path to the image file.
        title: optional axis title.
        transform: optional affine transform passed to transform_img;
            skipped when None.
        crop_res: crop resolution forwarded to transform_img.

    Returns:
        The (possibly transformed) PIL image that was displayed.
    """
    if title is not None:
        ax.set_title(title)
    loaded = Image.open(img_path)
    if transform is not None:
        loaded = transform_img(loaded, transform, crop_res)
    ax.imshow(loaded)
    ax.axis("off")
    return loaded
def get_sample(self, idx, query=None, color_augm=None, space_augm=None):
    """Assemble one sample dict for dataset item *idx*.

    Fetches only the annotations listed in *query* (defaults to
    self.queries), optionally mirrors them to match the requested hand
    side, and — in train mode — applies spatial (center/scale jitter,
    in-plane rotation) and color (blur/brightness/contrast/sat/hue)
    augmentations.

    Args:
        idx: index into self.pose_dataset.
        query: iterable of BaseQueries/TransQueries members to produce;
            None means use self.queries.
        color_augm: optional dict with keys "bright", "contrast", "sat",
            "hue" to replay a previously sampled color augmentation.
        space_augm: optional dict with keys "center", "scale", "rot" to
            replay a previously sampled spatial augmentation.

    Returns:
        dict mapping query members (plus "space_augm", "color_augm" and
        optionally "dist2strong") to numpy arrays / tensors / PIL data.
    """
    if query is None:
        query = self.queries
    sample = {}

    # Center/scale of the crop are only needed when an image is produced.
    if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
        center, scale = self.pose_dataset.get_center_scale(idx)
        needs_center_scale = True
    else:
        needs_center_scale = False

    if BaseQueries.JOINTVIS in query:
        jointvis = self.pose_dataset.get_jointvis(idx)
        sample[BaseQueries.JOINTVIS] = jointvis

    # Get sides: flip is True when the annotated side must be mirrored
    # to match the hand side(s) this dataset serves.
    if BaseQueries.SIDE in query:
        hand_side = self.pose_dataset.get_sides(idx)
        hand_side, flip = datutils.flip_hand_side(self.sides, hand_side)
        sample[BaseQueries.SIDE] = hand_side
    else:
        flip = False

    # Get original image
    if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
        img = self.pose_dataset.get_image(idx)
        if flip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if BaseQueries.IMAGE in query:
            sample[BaseQueries.IMAGE] = np.array(img)
        # Mirror the crop center horizontally to follow the flipped image
        if flip:
            center[0] = img.size[0] - center[0]

    # Spatial augmentation: replay a provided one, sample a fresh one in
    # train mode, otherwise no rotation.
    if space_augm is not None:
        center = space_augm["center"]
        scale = space_augm["scale"]
        rot = space_augm["rot"]
    elif self.train and needs_center_scale:
        # Randomly jitter center
        # Center is located in square of size 2*center_jitter_factor
        # in center of cropped image
        center_jit = Uniform(low=-1, high=1).sample((2, )).numpy()
        center_offsets = self.center_jittering * scale * center_jit
        center = center + center_offsets.astype(int)
        # Scale jittering, clipped to [1 - s, 1 + s]
        scale_jit = Normal(0, 1).sample().item() + 1
        scale_jittering = self.scale_jittering * scale_jit
        scale_jittering = np.clip(scale_jittering, 1 - self.scale_jittering,
                                  1 + self.scale_jittering)
        scale = scale * scale_jittering
        rot = Uniform(low=-self.max_rot, high=self.max_rot).sample().item()
    else:
        rot = 0
    if self.block_rot:
        rot = 0
    space_augm = {"rot": rot, "scale": scale, "center": center}
    sample["space_augm"] = space_augm
    # In-plane rotation matrix, applied to 3D points around the z axis
    rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0],
                        [np.sin(rot), np.cos(rot), 0],
                        [0, 0, 1]]).astype(np.float32)

    # Affine transform from full image to the augmented crop.
    # NOTE(review): affinetrans is also used by the OBJ/HAND 2D branches
    # below, which assume TransQueries.JOINTS2D or TransQueries.IMAGE is
    # queried alongside them — preserved as-is.
    if (TransQueries.JOINTS2D in query) or (TransQueries.IMAGE in query):
        affinetrans, post_rot_trans = handutils.get_affine_transform(
            center, scale, self.inp_res, rot=rot)
        if TransQueries.AFFINETRANS in query:
            sample[TransQueries.AFFINETRANS] = affinetrans

    # Get 2D hand joints
    if BaseQueries.JOINTS2D in query or TransQueries.JOINTS2D in query:
        joints2d = self.pose_dataset.get_joints2d(idx)
        if flip:
            joints2d = joints2d.copy()
            joints2d[:, 0] = img.size[0] - joints2d[:, 0]
        if BaseQueries.JOINTS2D in query:
            sample[BaseQueries.JOINTS2D] = joints2d.astype(np.float32)
        if TransQueries.JOINTS2D in query:
            rows = handutils.transform_coords(joints2d, affinetrans)
            sample[TransQueries.JOINTS2D] = np.array(rows).astype(np.float32)

    if BaseQueries.CAMINTR in query or TransQueries.CAMINTR in query:
        camintr = self.pose_dataset.get_camintr(idx)
        if BaseQueries.CAMINTR in query:
            sample[BaseQueries.CAMINTR] = camintr.astype(np.float32)
        if TransQueries.CAMINTR in query:
            # Rotation is applied as extr transform
            new_camintr = post_rot_trans.dot(camintr)
            sample[TransQueries.CAMINTR] = new_camintr.astype(np.float32)

    # Get 2D object vertices
    if BaseQueries.OBJVERTS2D in query or (TransQueries.OBJVERTS2D in query):
        objverts2d = self.pose_dataset.get_objverts2d(idx)
        if flip:
            objverts2d = objverts2d.copy()
            objverts2d[:, 0] = img.size[0] - objverts2d[:, 0]
        if BaseQueries.OBJVERTS2D in query:
            sample[BaseQueries.OBJVERTS2D] = objverts2d.astype(np.float32)
        if TransQueries.OBJVERTS2D in query:
            transobjverts2d = handutils.transform_coords(
                objverts2d, affinetrans)
            sample[TransQueries.OBJVERTS2D] = np.array(
                transobjverts2d).astype(np.float32)

    if BaseQueries.OBJVIS2D in query:
        objvis2d = self.pose_dataset.get_objvis2d(idx)
        sample[BaseQueries.OBJVIS2D] = objvis2d

    # Get 2D object corners
    if BaseQueries.OBJCORNERS2D in query or (TransQueries.OBJCORNERS2D in query):
        objcorners2d = self.pose_dataset.get_objcorners2d(idx)
        if flip:
            objcorners2d = objcorners2d.copy()
            objcorners2d[:, 0] = img.size[0] - objcorners2d[:, 0]
        if BaseQueries.OBJCORNERS2D in query:
            sample[BaseQueries.OBJCORNERS2D] = np.array(objcorners2d)
        if TransQueries.OBJCORNERS2D in query:
            transobjcorners2d = handutils.transform_coords(
                objcorners2d, affinetrans)
            sample[TransQueries.OBJCORNERS2D] = np.array(transobjcorners2d)

    # Get 2D hand vertices
    if BaseQueries.HANDVERTS2D in query or (TransQueries.HANDVERTS2D in query):
        handverts2d = self.pose_dataset.get_hand_verts2d(idx)
        if flip:
            handverts2d = handverts2d.copy()
            handverts2d[:, 0] = img.size[0] - handverts2d[:, 0]
        if BaseQueries.HANDVERTS2D in query:
            sample[BaseQueries.HANDVERTS2D] = handverts2d
        if TransQueries.HANDVERTS2D in query:
            transhandverts2d = handutils.transform_coords(
                handverts2d, affinetrans)
            sample[TransQueries.HANDVERTS2D] = np.array(transhandverts2d)

    if BaseQueries.HANDVIS2D in query:
        handvis2d = self.pose_dataset.get_handvis2d(idx)
        sample[BaseQueries.HANDVIS2D] = handvis2d

    # Get 3D hand joints
    if ((BaseQueries.JOINTS3D in query) or (TransQueries.JOINTS3D in query)
            or (TransQueries.HANDVERTS3D in query)
            or (TransQueries.OBJVERTS3D in query)):
        # Center on root joint
        center3d_queries = [
            TransQueries.JOINTS3D, BaseQueries.JOINTS3D,
            TransQueries.HANDVERTS3D
        ]
        if one_query_in([TransQueries.OBJVERTS3D] + center3d_queries, query):
            joints3d = self.pose_dataset.get_joints3d(idx)
            if flip:
                joints3d[:, 0] = -joints3d[:, 0]
            if BaseQueries.JOINTS3D in query:
                sample[BaseQueries.JOINTS3D] = joints3d.astype(np.float32)
            if self.train:
                joints3d = rot_mat.dot(joints3d.transpose(1, 0)).transpose()
            # Compute 3D center; center_idx == -1 means midpoint of
            # joints 9 and 0
            if self.center_idx is not None:
                if self.center_idx == -1:
                    center3d = (joints3d[9] + joints3d[0]) / 2
                else:
                    center3d = joints3d[self.center_idx]
            if TransQueries.JOINTS3D in query and (self.center_idx is not None):
                joints3d = joints3d - center3d
            if TransQueries.JOINTS3D in query:
                sample[TransQueries.JOINTS3D] = joints3d.astype(np.float32)

    # Get 3D hand vertices
    if TransQueries.HANDVERTS3D in query or BaseQueries.HANDVERTS3D in query:
        hand_verts3d = self.pose_dataset.get_hand_verts3d(idx)
        if flip:
            hand_verts3d[:, 0] = -hand_verts3d[:, 0]
        # BUGFIX: guard used to test BaseQueries.OBJVERTS3D although the
        # branch stores the hand vertices
        if BaseQueries.HANDVERTS3D in query:
            sample[BaseQueries.HANDVERTS3D] = hand_verts3d.astype(
                np.float32)
        if TransQueries.HANDVERTS3D in query:
            hand_verts3d = rot_mat.dot(hand_verts3d.transpose(
                1, 0)).transpose()
            if self.center_idx is not None:
                hand_verts3d = hand_verts3d - center3d
            sample[TransQueries.HANDVERTS3D] = hand_verts3d.astype(
                np.float32)

    # Get 3D obj vertices (translated into camera frame)
    if TransQueries.OBJVERTS3D in query or BaseQueries.OBJVERTS3D in query:
        obj_verts3d = self.pose_dataset.get_obj_verts_trans(idx)
        if flip:
            obj_verts3d[:, 0] = -obj_verts3d[:, 0]
        if BaseQueries.OBJVERTS3D in query:
            sample[BaseQueries.OBJVERTS3D] = obj_verts3d
        if TransQueries.OBJVERTS3D in query:
            origin_trans_mesh = rot_mat.dot(obj_verts3d.transpose(
                1, 0)).transpose()
            if self.center_idx is not None:
                origin_trans_mesh = origin_trans_mesh - center3d
            sample[TransQueries.OBJVERTS3D] = origin_trans_mesh.astype(
                np.float32)

    # Get canonical rotated 3D obj vertices
    if TransQueries.OBJCANROTVERTS in query or BaseQueries.OBJCANROTVERTS in query:
        obj_canverts3d = self.pose_dataset.get_obj_verts_can_rot(idx)
        if flip:
            obj_canverts3d[:, 0] = -obj_canverts3d[:, 0]
        if BaseQueries.OBJCANROTVERTS in query:
            sample[BaseQueries.OBJCANROTVERTS] = obj_canverts3d
        if TransQueries.OBJCANROTVERTS in query:
            can_rot_mesh = rot_mat.dot(obj_canverts3d.transpose(
                1, 0)).transpose()
            sample[TransQueries.OBJCANROTVERTS] = can_rot_mesh

    # Get canonical rotated 3D obj corners
    if TransQueries.OBJCANROTCORNERS in query or BaseQueries.OBJCANROTCORNERS in query:
        obj_cancorners3d = self.pose_dataset.get_obj_corners_can_rot(idx)
        if flip:
            obj_cancorners3d[:, 0] = -obj_cancorners3d[:, 0]
        if BaseQueries.OBJCANROTCORNERS in query:
            sample[BaseQueries.OBJCANROTCORNERS] = obj_cancorners3d
        if TransQueries.OBJCANROTCORNERS in query:
            can_rot_corners = rot_mat.dot(obj_cancorners3d.transpose(
                1, 0)).transpose()
            sample[TransQueries.OBJCANROTCORNERS] = can_rot_corners

    if BaseQueries.OBJFACES in query:
        obj_faces = self.pose_dataset.get_obj_faces(idx)
        sample[BaseQueries.OBJFACES] = obj_faces

    if BaseQueries.OBJCANVERTS in query:
        obj_canverts, obj_cantrans, obj_canscale = self.pose_dataset.get_obj_verts_can(
            idx)
        if flip:
            obj_canverts[:, 0] = -obj_canverts[:, 0]
        sample[BaseQueries.OBJCANVERTS] = obj_canverts
        sample[BaseQueries.OBJCANSCALE] = obj_canscale
        sample[BaseQueries.OBJCANTRANS] = obj_cantrans

    # Get 3D obj corners
    if BaseQueries.OBJCORNERS3D in query or TransQueries.OBJCORNERS3D in query:
        obj_corners3d = self.pose_dataset.get_obj_corners3d(idx)
        if flip:
            obj_corners3d[:, 0] = -obj_corners3d[:, 0]
        if BaseQueries.OBJCORNERS3D in query:
            sample[BaseQueries.OBJCORNERS3D] = obj_corners3d
        if TransQueries.OBJCORNERS3D in query:
            origin_trans_corners = rot_mat.dot(
                obj_corners3d.transpose(1, 0)).transpose()
            if self.center_idx is not None:
                origin_trans_corners = origin_trans_corners - center3d
            sample[TransQueries.OBJCORNERS3D] = origin_trans_corners

    if BaseQueries.OBJCANCORNERS in query:
        obj_cancorners = self.pose_dataset.get_obj_corners_can(idx)
        # BUGFIX: used to mirror obj_canverts here (a NameError when
        # OBJCANVERTS was not queried) instead of the corners themselves
        if flip:
            obj_cancorners[:, 0] = -obj_cancorners[:, 0]
        sample[BaseQueries.OBJCANCORNERS] = obj_cancorners

    if TransQueries.CENTER3D in query:
        # center3d is only computed in the 3D-joints branch above;
        # querying CENTER3D alone is not supported (preserved behavior)
        sample[TransQueries.CENTER3D] = center3d

    # Get rgb image: augmented, warped to the crop, tensorized
    if TransQueries.IMAGE in query:
        # Color augmentation (train only)
        if self.train:
            blur_radius = Uniform(
                low=0, high=1).sample().item() * self.blur_radius
            img = img.filter(ImageFilter.GaussianBlur(blur_radius))
            if color_augm is None:
                bright, contrast, sat, hue = colortrans.get_color_params(
                    brightness=self.brightness,
                    saturation=self.saturation,
                    hue=self.hue,
                    contrast=self.contrast,
                )
            else:
                sat = color_augm["sat"]
                contrast = color_augm["contrast"]
                hue = color_augm["hue"]
                bright = color_augm["bright"]
            img = colortrans.apply_jitter(img,
                                          brightness=bright,
                                          saturation=sat,
                                          hue=hue,
                                          contrast=contrast)
            sample["color_augm"] = {
                "sat": sat,
                "bright": bright,
                "contrast": contrast,
                "hue": hue
            }
        else:
            sample["color_augm"] = None
        # Create buffer white image if needed: it is warped with the same
        # affine transform so the valid crop region can be recovered
        if TransQueries.JITTERMASK in query:
            whiteimg = Image.new("RGB", img.size, (255, 255, 255))
        # Transform and crop
        img = handutils.transform_img(img, affinetrans, self.inp_res)
        img = img.crop((0, 0, self.inp_res[0], self.inp_res[1]))
        # Tensorize and normalize_img
        img = func_transforms.to_tensor(img).float()
        if self.normalize_img:
            img = func_transforms.normalize(img, self.mean, self.std)
        else:
            img = func_transforms.normalize(img, [0.5, 0.5, 0.5], [1, 1, 1])
        if TransQueries.IMAGE in query:
            sample[TransQueries.IMAGE] = img
        if TransQueries.JITTERMASK in query:
            jittermask = handutils.transform_img(whiteimg, affinetrans,
                                                 self.inp_res)
            jittermask = jittermask.crop(
                (0, 0, self.inp_res[0], self.inp_res[1]))
            jittermask = func_transforms.to_tensor(jittermask).float()
            sample[TransQueries.JITTERMASK] = jittermask

    if self.pose_dataset.has_dist2strong and self.has_dist2strong:
        dist2strong = self.pose_dataset.get_dist2strong(idx)
        sample["dist2strong"] = dist2strong
    return sample