import numpy as np
import torch
from PIL import Image, ImageFilter
from torch.distributions import Normal, Uniform
from torchvision.transforms import functional as func_transforms

# Project-local helpers are assumed to be importable at module level:
# BaseQueries, TransQueries, one_query_in (query definitions) and
# handutils, datutils, colortrans (transform utilities).


def get_sample(self, idx, query=None, color_augm=None, space_augm=None):
    if query is None:
        query = self.queries
    sample = {}

    if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
        center, scale = self.pose_dataset.get_center_scale(idx)
        needs_center_scale = True
    else:
        needs_center_scale = False

    if BaseQueries.JOINTVIS in query:
        jointvis = self.pose_dataset.get_jointvis(idx)
        sample[BaseQueries.JOINTVIS] = jointvis

    # Get sides
    if BaseQueries.SIDE in query:
        hand_side = self.pose_dataset.get_sides(idx)
        hand_side, flip = datutils.flip_hand_side(self.sides, hand_side)
        sample[BaseQueries.SIDE] = hand_side
    else:
        flip = False

    # Get original image
    if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
        img = self.pose_dataset.get_image(idx)
        # img = img.resize((480, 270), Image.BILINEAR)
        if flip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if BaseQueries.IMAGE in query:
            sample[BaseQueries.IMAGE] = np.array(img)
        # Mirror the crop center horizontally if the image was flipped
        if flip:
            center[0] = img.size[0] - center[0]

    # Data augmentation
    if space_augm is not None:
        # Replay a previously sampled spatial augmentation
        center = space_augm["center"]
        scale = space_augm["scale"]
        rot = space_augm["rot"]
    elif self.train and needs_center_scale:
        # Randomly jitter the crop center inside a square of half-size
        # center_jittering * scale around the annotated center
        center_jit = Uniform(low=-1, high=1).sample((2,)).numpy()
        center_offsets = self.center_jittering * scale * center_jit
        center = center + center_offsets.astype(int)

        # Scale jittering: multiplicative noise around 1, clipped so the
        # factor stays in [1 - scale_jittering, 1 + scale_jittering]
        scale_jit = Normal(0, 1).sample().item()
        scale_jittering = self.scale_jittering * scale_jit + 1
        scale_jittering = np.clip(scale_jittering, 1 - self.scale_jittering,
                                  1 + self.scale_jittering)
        scale = scale * scale_jittering

        rot = Uniform(low=-self.max_rot, high=self.max_rot).sample().item()
    else:
        rot = 0
    if self.block_rot:
        rot = 0
    space_augm = {"rot": rot, "scale": scale, "center": center}
    sample["space_augm"] = space_augm

    # In-plane rotation around the z (camera) axis
    rot_mat = np.array([
        [np.cos(rot), -np.sin(rot), 0],
        [np.sin(rot), np.cos(rot), 0],
        [0, 0, 1],
    ]).astype(np.float32)

    # Get affine transform to the network input resolution
    if (TransQueries.JOINTS2D in query) or (TransQueries.IMAGE in query):
        affinetrans, post_rot_trans = handutils.get_affine_transform(
            center, scale, self.inp_res, rot=rot)
        if TransQueries.AFFINETRANS in query:
            sample[TransQueries.AFFINETRANS] = affinetrans

    # Get 2D hand joints
    if BaseQueries.JOINTS2D in query or TransQueries.JOINTS2D in query:
        joints2d = self.pose_dataset.get_joints2d(idx)
        if flip:
            joints2d = joints2d.copy()
            joints2d[:, 0] = img.size[0] - joints2d[:, 0]
        if BaseQueries.JOINTS2D in query:
            sample[BaseQueries.JOINTS2D] = joints2d.astype(np.float32)
        if TransQueries.JOINTS2D in query:
            rows = handutils.transform_coords(joints2d, affinetrans)
            sample[TransQueries.JOINTS2D] = np.array(rows).astype(np.float32)

    if BaseQueries.CAMINTR in query or TransQueries.CAMINTR in query:
        camintr = self.pose_dataset.get_camintr(idx)
        if BaseQueries.CAMINTR in query:
            sample[BaseQueries.CAMINTR] = camintr.astype(np.float32)
        if TransQueries.CAMINTR in query:
            # Rotation is applied as an extrinsic transform
            new_camintr = post_rot_trans.dot(camintr)
            sample[TransQueries.CAMINTR] = new_camintr.astype(np.float32)

    # Get 2D object vertices
    if BaseQueries.OBJVERTS2D in query or (TransQueries.OBJVERTS2D in query):
        objverts2d = self.pose_dataset.get_objverts2d(idx)
        if flip:
            objverts2d = objverts2d.copy()
            objverts2d[:, 0] = img.size[0] - objverts2d[:, 0]
        if BaseQueries.OBJVERTS2D in query:
            sample[BaseQueries.OBJVERTS2D] = objverts2d.astype(np.float32)
        if TransQueries.OBJVERTS2D in query:
            transobjverts2d = handutils.transform_coords(objverts2d, affinetrans)
            sample[TransQueries.OBJVERTS2D] = np.array(transobjverts2d).astype(np.float32)
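    # `affinetrans` maps original-image pixel coordinates to the cropped,
    # rotated network input; every 2D annotation in this method is pushed
    # through handutils.transform_coords with this same matrix so that
    # keypoints stay aligned with the transformed image.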
    if BaseQueries.OBJVIS2D in query:
        objvis2d = self.pose_dataset.get_objvis2d(idx)
        sample[BaseQueries.OBJVIS2D] = objvis2d

    # Get 2D object corners
    if BaseQueries.OBJCORNERS2D in query or (TransQueries.OBJCORNERS2D in query):
        objcorners2d = self.pose_dataset.get_objcorners2d(idx)
        if flip:
            objcorners2d = objcorners2d.copy()
            objcorners2d[:, 0] = img.size[0] - objcorners2d[:, 0]
        if BaseQueries.OBJCORNERS2D in query:
            sample[BaseQueries.OBJCORNERS2D] = np.array(objcorners2d)
        if TransQueries.OBJCORNERS2D in query:
            transobjcorners2d = handutils.transform_coords(objcorners2d, affinetrans)
            sample[TransQueries.OBJCORNERS2D] = np.array(transobjcorners2d)

    # Get 2D hand vertices
    if BaseQueries.HANDVERTS2D in query or (TransQueries.HANDVERTS2D in query):
        handverts2d = self.pose_dataset.get_hand_verts2d(idx)
        if flip:
            handverts2d = handverts2d.copy()
            handverts2d[:, 0] = img.size[0] - handverts2d[:, 0]
        if BaseQueries.HANDVERTS2D in query:
            sample[BaseQueries.HANDVERTS2D] = handverts2d
        if TransQueries.HANDVERTS2D in query:
            transhandverts2d = handutils.transform_coords(handverts2d, affinetrans)
            sample[TransQueries.HANDVERTS2D] = np.array(transhandverts2d)

    if BaseQueries.HANDVIS2D in query:
        handvis2d = self.pose_dataset.get_handvis2d(idx)
        sample[BaseQueries.HANDVIS2D] = handvis2d

    # Get 3D hand joints
    if ((BaseQueries.JOINTS3D in query) or (TransQueries.JOINTS3D in query)
            or (TransQueries.HANDVERTS3D in query)
            or (TransQueries.OBJVERTS3D in query)):
        # Center on root joint
        center3d_queries = [
            TransQueries.JOINTS3D, BaseQueries.JOINTS3D,
            TransQueries.HANDVERTS3D
        ]
        if one_query_in([TransQueries.OBJVERTS3D] + center3d_queries, query):
            joints3d = self.pose_dataset.get_joints3d(idx)
            if flip:
                joints3d[:, 0] = -joints3d[:, 0]
            if BaseQueries.JOINTS3D in query:
                sample[BaseQueries.JOINTS3D] = joints3d.astype(np.float32)
            if self.train:
                joints3d = rot_mat.dot(joints3d.transpose(1, 0)).transpose()
            # Compute 3D center
            if self.center_idx is not None:
                if self.center_idx == -1:
                    # center_idx == -1: use the midpoint of joints 0 and 9
                    center3d = (joints3d[9] + joints3d[0]) / 2
                else:
                    center3d = joints3d[self.center_idx]
            if TransQueries.JOINTS3D in query and (self.center_idx is not None):
                joints3d = joints3d - center3d
            if TransQueries.JOINTS3D in query:
                sample[TransQueries.JOINTS3D] = joints3d.astype(np.float32)

    # Get 3D hand vertices
    if TransQueries.HANDVERTS3D in query or BaseQueries.HANDVERTS3D in query:
        hand_verts3d = self.pose_dataset.get_hand_verts3d(idx)
        if flip:
            hand_verts3d[:, 0] = -hand_verts3d[:, 0]
        if BaseQueries.HANDVERTS3D in query:
            sample[BaseQueries.HANDVERTS3D] = hand_verts3d.astype(np.float32)
        if TransQueries.HANDVERTS3D in query:
            hand_verts3d = rot_mat.dot(hand_verts3d.transpose(1, 0)).transpose()
            if self.center_idx is not None:
                hand_verts3d = hand_verts3d - center3d
            sample[TransQueries.HANDVERTS3D] = hand_verts3d.astype(np.float32)

    # Get 3D object vertices
    if TransQueries.OBJVERTS3D in query or BaseQueries.OBJVERTS3D in query:
        obj_verts3d = self.pose_dataset.get_obj_verts_trans(idx)
        if flip:
            obj_verts3d[:, 0] = -obj_verts3d[:, 0]
        if BaseQueries.OBJVERTS3D in query:
            sample[BaseQueries.OBJVERTS3D] = obj_verts3d
        if TransQueries.OBJVERTS3D in query:
            origin_trans_mesh = rot_mat.dot(obj_verts3d.transpose(1, 0)).transpose()
            if self.center_idx is not None:
                origin_trans_mesh = origin_trans_mesh - center3d
            sample[TransQueries.OBJVERTS3D] = origin_trans_mesh.astype(np.float32)
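    # All 3D annotations follow the same recipe: mirror the x coordinate when
    # the hand side was flipped, rotate with rot_mat for in-plane augmentation,
    # and, for camera-frame trans queries, subtract center3d so the chosen
    # root joint becomes the origin (canonical-frame queries skip centering).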
    # Get 3D object vertices in the canonical rotated frame
    if TransQueries.OBJCANROTVERTS in query or BaseQueries.OBJCANROTVERTS in query:
        obj_canverts3d = self.pose_dataset.get_obj_verts_can_rot(idx)
        if flip:
            obj_canverts3d[:, 0] = -obj_canverts3d[:, 0]
        if BaseQueries.OBJCANROTVERTS in query:
            sample[BaseQueries.OBJCANROTVERTS] = obj_canverts3d
        if TransQueries.OBJCANROTVERTS in query:
            can_rot_mesh = rot_mat.dot(obj_canverts3d.transpose(1, 0)).transpose()
            sample[TransQueries.OBJCANROTVERTS] = can_rot_mesh

    # Get 3D object corners in the canonical rotated frame
    if TransQueries.OBJCANROTCORNERS in query or BaseQueries.OBJCANROTCORNERS in query:
        obj_cancorners3d = self.pose_dataset.get_obj_corners_can_rot(idx)
        if flip:
            obj_cancorners3d[:, 0] = -obj_cancorners3d[:, 0]
        if BaseQueries.OBJCANROTCORNERS in query:
            sample[BaseQueries.OBJCANROTCORNERS] = obj_cancorners3d
        if TransQueries.OBJCANROTCORNERS in query:
            can_rot_corners = rot_mat.dot(obj_cancorners3d.transpose(1, 0)).transpose()
            sample[TransQueries.OBJCANROTCORNERS] = can_rot_corners

    if BaseQueries.OBJFACES in query:
        obj_faces = self.pose_dataset.get_obj_faces(idx)
        sample[BaseQueries.OBJFACES] = obj_faces

    if BaseQueries.OBJCANVERTS in query:
        obj_canverts, obj_cantrans, obj_canscale = self.pose_dataset.get_obj_verts_can(idx)
        if flip:
            obj_canverts[:, 0] = -obj_canverts[:, 0]
        sample[BaseQueries.OBJCANVERTS] = obj_canverts
        sample[BaseQueries.OBJCANSCALE] = obj_canscale
        sample[BaseQueries.OBJCANTRANS] = obj_cantrans

    # Get 3D object corners
    if BaseQueries.OBJCORNERS3D in query or TransQueries.OBJCORNERS3D in query:
        obj_corners3d = self.pose_dataset.get_obj_corners3d(idx)
        if flip:
            obj_corners3d[:, 0] = -obj_corners3d[:, 0]
        if BaseQueries.OBJCORNERS3D in query:
            sample[BaseQueries.OBJCORNERS3D] = obj_corners3d
        if TransQueries.OBJCORNERS3D in query:
            origin_trans_corners = rot_mat.dot(obj_corners3d.transpose(1, 0)).transpose()
            if self.center_idx is not None:
                origin_trans_corners = origin_trans_corners - center3d
            sample[TransQueries.OBJCORNERS3D] = origin_trans_corners

    if BaseQueries.OBJCANCORNERS in query:
        obj_cancorners = self.pose_dataset.get_obj_corners_can(idx)
        if flip:
            obj_cancorners[:, 0] = -obj_cancorners[:, 0]
        sample[BaseQueries.OBJCANCORNERS] = obj_cancorners

    if TransQueries.CENTER3D in query:
        sample[TransQueries.CENTER3D] = center3d

    # Get rgb image
    if TransQueries.IMAGE in query:
        # Data augmentation
        if self.train:
            blur_radius = Uniform(low=0, high=1).sample().item() * self.blur_radius
            img = img.filter(ImageFilter.GaussianBlur(blur_radius))
            if color_augm is None:
                bright, contrast, sat, hue = colortrans.get_color_params(
                    brightness=self.brightness,
                    saturation=self.saturation,
                    hue=self.hue,
                    contrast=self.contrast,
                )
            else:
                # Replay a previously sampled color augmentation
                sat = color_augm["sat"]
                contrast = color_augm["contrast"]
                hue = color_augm["hue"]
                bright = color_augm["bright"]
            img = colortrans.apply_jitter(
                img, brightness=bright, saturation=sat, hue=hue, contrast=contrast)
            sample["color_augm"] = {
                "sat": sat,
                "bright": bright,
                "contrast": contrast,
                "hue": hue,
            }
        else:
            sample["color_augm"] = None

        # Create buffer white image if needed
        if TransQueries.JITTERMASK in query:
            whiteimg = Image.new("RGB", img.size, (255, 255, 255))

        # Transform and crop
        img = handutils.transform_img(img, affinetrans, self.inp_res)
        img = img.crop((0, 0, self.inp_res[0], self.inp_res[1]))
        # img = img.resize((480, 270), Image.BILINEAR)

        # Tensorize and normalize image
        img = func_transforms.to_tensor(img).float()
        if self.normalize_img:
            img = func_transforms.normalize(img, self.mean, self.std)
        else:
            img = func_transforms.normalize(img, [0.5, 0.5, 0.5], [1, 1, 1])
        sample[TransQueries.IMAGE] = img
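        # The jitter mask below is a white image pushed through the same
        # affine transform as the RGB crop; pixels that fall outside the
        # original image after augmentation presumably come out as zeros,
        # marking regions that carry no image content.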
        if TransQueries.JITTERMASK in query:
            jittermask = handutils.transform_img(whiteimg, affinetrans,
                                                 self.inp_res)
            jittermask = jittermask.crop((0, 0, self.inp_res[0], self.inp_res[1]))
            jittermask = func_transforms.to_tensor(jittermask).float()
            sample[TransQueries.JITTERMASK] = jittermask

    if self.pose_dataset.has_dist2strong and self.has_dist2strong:
        dist2strong = self.pose_dataset.get_dist2strong(idx)
        sample["dist2strong"] = dist2strong

    return sample
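
# A self-contained, runnable restatement (illustrative sketch, not part of the
# original file) of the 3D augmentation step used throughout get_sample:
# rotate points with the z-axis rotation matrix, then root-center them.
# The joint count (21) and the default angle are assumptions for the demo only.
def _demo_rotate_and_center(rot_deg=30.0, center_idx=0):
    rot = np.deg2rad(rot_deg)
    rot_mat = np.array([
        [np.cos(rot), -np.sin(rot), 0],
        [np.sin(rot), np.cos(rot), 0],
        [0, 0, 1],
    ], dtype=np.float32)
    joints3d = np.random.rand(21, 3).astype(np.float32)  # dummy hand joints
    rotated = rot_mat.dot(joints3d.transpose(1, 0)).transpose()
    # Subtract the chosen root so it sits at the origin, as get_sample does
    # when self.center_idx is not None
    return rotated - rotated[center_idx]
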
def forward(self, sample, no_loss=False, step=0, preparams=None):
    total_loss = torch.Tensor([0]).cuda()
    results = {}
    losses = {}

    # Get input
    image = sample[TransQueries.IMAGE].cuda()
    # Feed input into shared encoder
    encoder_output, encoder_features = self.base_net(image)

    has_mano_super = one_query_in(
        sample.keys(),
        [
            TransQueries.JOINTS3D,
            TransQueries.JOINTS2D,
            TransQueries.HANDVERTS2D,
            TransQueries.HANDVERTS3D,
        ],
    )
    # The hand branch always runs: the "True or" bypasses the supervision check
    if True or (has_mano_super and self.mano_lambdas):
        if preparams is not None:
            hand_scale = preparams["hand_prescale"]
            hand_pose = preparams["pose"]
            hand_shape = preparams["shape"]
            hand_trans = preparams["hand_pretrans"]
        else:
            hand_scale = None
            hand_pose = None
            hand_shape = None
            hand_trans = None
        # Hand branch
        mano_results, total_loss, mano_losses = self.recover_mano(
            sample,
            encoder_output=encoder_output,
            no_loss=no_loss,
            total_loss=total_loss,
            trans=hand_trans,
            scale=hand_scale,
            pose=hand_pose,
            shape=hand_shape,
        )
        losses.update(mano_losses)
        results.update(mano_results)

    has_obj_super = one_query_in(
        sample.keys(), [TransQueries.OBJVERTS2D, TransQueries.OBJVERTS3D])
    if has_obj_super and self.obj_lambdas:
        if preparams is not None:
            obj_scale = preparams["obj_prescale"]
            obj_rot = preparams["obj_prerot"]
            obj_trans = preparams["obj_pretrans"]
        else:
            obj_scale = None
            obj_rot = None
            obj_trans = None
        # Object branch
        obj_results, total_loss, obj_losses = self.recover_object(
            sample,
            image,
            encoder_output,
            encoder_features,
            no_loss=no_loss,
            total_loss=total_loss,
            scale=obj_scale,
            trans=obj_trans,
            rotaxisang=obj_rot,
        )
        losses.update(obj_losses)
        results.update(obj_results)

    losses["total_loss"] = total_loss
    return total_loss, results, losses
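
# A minimal stand-in (assumption: the real one_query_in lives in the project's
# query utilities; this sketch only captures the semantics implied by the call
# sites above) showing how the branch gating works: the helper returns True
# when any element of its first argument appears in its second.
def _one_query_in_sketch(candidate_queries, base_queries):
    return any(candidate in base_queries for candidate in candidate_queries)

# Hypothetical training step consuming forward's outputs (model, optimizer
# and batch are placeholders, not names from this file):
#     total_loss, results, losses = model(batch)
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()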