def get_sample(self, idx, query=None, color_augm=None, space_augm=None):
        """Assemble one sample dict for dataset item ``idx``.

        Args:
            idx: index into ``self.pose_dataset``.
            query: iterable of ``BaseQueries``/``TransQueries`` members to
                populate; defaults to ``self.queries``.
            color_augm: optional dict with keys ``"bright"``, ``"contrast"``,
                ``"sat"``, ``"hue"`` to replay a previous color jitter instead
                of sampling a new one.
            space_augm: optional dict with keys ``"center"``, ``"scale"``,
                ``"rot"`` to replay a previous spatial augmentation instead of
                sampling a new one.

        Returns:
            dict mapping each requested query to its (possibly augmented)
            value; also records the applied ``"space_augm"`` and, for
            transformed images in train mode, ``"color_augm"`` parameters so
            they can be replayed on a paired sample.
        """
        if query is None:
            query = self.queries
        sample = {}

        # center/scale are only needed when an image crop will be produced
        if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
            center, scale = self.pose_dataset.get_center_scale(idx)
            needs_center_scale = True
        else:
            needs_center_scale = False

        if BaseQueries.JOINTVIS in query:
            jointvis = self.pose_dataset.get_jointvis(idx)
            sample[BaseQueries.JOINTVIS] = jointvis

        # Get sides; flip is True when the hand must be mirrored to match
        # the side the model is trained on
        if BaseQueries.SIDE in query:
            hand_side = self.pose_dataset.get_sides(idx)
            hand_side, flip = datutils.flip_hand_side(self.sides, hand_side)
            sample[BaseQueries.SIDE] = hand_side
        else:
            flip = False

        # Get original image
        if BaseQueries.IMAGE in query or TransQueries.IMAGE in query:
            img = self.pose_dataset.get_image(idx)
            if flip:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

            if BaseQueries.IMAGE in query:
                sample[BaseQueries.IMAGE] = np.array(img)

        # Mirror the crop center so it follows the flipped image
        if flip:
            center[0] = img.size[0] - center[0]
        # Data augmentation: replay provided params, sample fresh ones in
        # train mode, or fall back to the identity transform
        if space_augm is not None:
            center = space_augm["center"]
            scale = space_augm["scale"]
            rot = space_augm["rot"]
        elif self.train and needs_center_scale:
            # Randomly jitter center
            # Center is located in square of size 2*center_jitter_factor
            # in center of cropped image
            center_jit = Uniform(low=-1, high=1).sample((2, )).numpy()
            center_offsets = self.center_jittering * scale * center_jit
            center = center + center_offsets.astype(int)

            # Scale jittering, clipped to [1 - s, 1 + s]
            scale_jit = Normal(0, 1).sample().item() + 1
            scale_jittering = self.scale_jittering * scale_jit
            scale_jittering = np.clip(scale_jittering,
                                      1 - self.scale_jittering,
                                      1 + self.scale_jittering)
            scale = scale * scale_jittering

            rot = Uniform(low=-self.max_rot, high=self.max_rot).sample().item()
        else:
            rot = 0
        if self.block_rot:
            rot = 0
        space_augm = {"rot": rot, "scale": scale, "center": center}
        sample["space_augm"] = space_augm
        # In-plane rotation matrix; applied to 3D points about the z axis
        rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0],
                            [np.sin(rot), np.cos(rot), 0],
                            [0, 0, 1]]).astype(np.float32)

        # Affine transform used for every 2D query and the image crop
        if (TransQueries.JOINTS2D in query) or (TransQueries.IMAGE in query):
            affinetrans, post_rot_trans = handutils.get_affine_transform(
                center, scale, self.inp_res, rot=rot)
            if TransQueries.AFFINETRANS in query:
                sample[TransQueries.AFFINETRANS] = affinetrans
        # Get 2D hand joints
        if BaseQueries.JOINTS2D in query or TransQueries.JOINTS2D in query:
            joints2d = self.pose_dataset.get_joints2d(idx)
            if flip:
                joints2d = joints2d.copy()
                joints2d[:, 0] = img.size[0] - joints2d[:, 0]
            if BaseQueries.JOINTS2D in query:
                sample[BaseQueries.JOINTS2D] = joints2d.astype(np.float32)
        if TransQueries.JOINTS2D in query:
            rows = handutils.transform_coords(joints2d, affinetrans)
            sample[TransQueries.JOINTS2D] = np.array(rows).astype(np.float32)

        if BaseQueries.CAMINTR in query or TransQueries.CAMINTR in query:
            camintr = self.pose_dataset.get_camintr(idx)
            if BaseQueries.CAMINTR in query:
                sample[BaseQueries.CAMINTR] = camintr.astype(np.float32)
            if TransQueries.CAMINTR in query:
                # Rotation is applied as extr transform
                new_camintr = post_rot_trans.dot(camintr)
                sample[TransQueries.CAMINTR] = new_camintr.astype(np.float32)

        # Get 2D object points
        if BaseQueries.OBJVERTS2D in query or (TransQueries.OBJVERTS2D
                                               in query):
            objverts2d = self.pose_dataset.get_objverts2d(idx)
            if flip:
                objverts2d = objverts2d.copy()
                objverts2d[:, 0] = img.size[0] - objverts2d[:, 0]
            if BaseQueries.OBJVERTS2D in query:
                sample[BaseQueries.OBJVERTS2D] = objverts2d.astype(np.float32)
            if TransQueries.OBJVERTS2D in query:
                transobjverts2d = handutils.transform_coords(
                    objverts2d, affinetrans)
                sample[TransQueries.OBJVERTS2D] = np.array(
                    transobjverts2d).astype(np.float32)
            if BaseQueries.OBJVIS2D in query:
                objvis2d = self.pose_dataset.get_objvis2d(idx)
                sample[BaseQueries.OBJVIS2D] = objvis2d

        # Get 2D object corners
        if BaseQueries.OBJCORNERS2D in query or (TransQueries.OBJCORNERS2D
                                                 in query):
            objcorners2d = self.pose_dataset.get_objcorners2d(idx)
            if flip:
                objcorners2d = objcorners2d.copy()
                objcorners2d[:, 0] = img.size[0] - objcorners2d[:, 0]
            if BaseQueries.OBJCORNERS2D in query:
                sample[BaseQueries.OBJCORNERS2D] = np.array(objcorners2d)
            if TransQueries.OBJCORNERS2D in query:
                transobjcorners2d = handutils.transform_coords(
                    objcorners2d, affinetrans)
                sample[TransQueries.OBJCORNERS2D] = np.array(transobjcorners2d)

        # Get 2D hand points
        if BaseQueries.HANDVERTS2D in query or (TransQueries.HANDVERTS2D
                                                in query):
            handverts2d = self.pose_dataset.get_hand_verts2d(idx)
            if flip:
                handverts2d = handverts2d.copy()
                handverts2d[:, 0] = img.size[0] - handverts2d[:, 0]
            if BaseQueries.HANDVERTS2D in query:
                sample[BaseQueries.HANDVERTS2D] = handverts2d
            if TransQueries.HANDVERTS2D in query:
                transhandverts2d = handutils.transform_coords(
                    handverts2d, affinetrans)
                sample[TransQueries.HANDVERTS2D] = np.array(transhandverts2d)
            if BaseQueries.HANDVIS2D in query:
                handvis2d = self.pose_dataset.get_handvis2d(idx)
                sample[BaseQueries.HANDVIS2D] = handvis2d

        # Get 3D hand joints
        if ((BaseQueries.JOINTS3D in query) or (TransQueries.JOINTS3D in query)
                or (TransQueries.HANDVERTS3D in query)
                or (TransQueries.OBJVERTS3D in query)):
            # Center on root joint
            center3d_queries = [
                TransQueries.JOINTS3D, BaseQueries.JOINTS3D,
                TransQueries.HANDVERTS3D
            ]
            if one_query_in([TransQueries.OBJVERTS3D] + center3d_queries,
                            query):
                joints3d = self.pose_dataset.get_joints3d(idx)
                if flip:
                    # NOTE(review): mutates the array returned by the dataset
                    # in place — assumes get_joints3d returns a fresh copy;
                    # confirm, otherwise add .copy() here
                    joints3d[:, 0] = -joints3d[:, 0]

                if BaseQueries.JOINTS3D in query:
                    sample[BaseQueries.JOINTS3D] = joints3d.astype(np.float32)
                # Rotation augmentation is only applied at train time
                if self.train:
                    joints3d = rot_mat.dot(joints3d.transpose(1,
                                                              0)).transpose()
                # Compute 3D center: -1 means midpoint of joints 9 and 0,
                # otherwise the joint at center_idx
                if self.center_idx is not None:
                    if self.center_idx == -1:
                        center3d = (joints3d[9] + joints3d[0]) / 2
                    else:
                        center3d = joints3d[self.center_idx]
                if TransQueries.JOINTS3D in query and (self.center_idx
                                                       is not None):
                    joints3d = joints3d - center3d
                if TransQueries.JOINTS3D in query:
                    sample[TransQueries.JOINTS3D] = joints3d.astype(np.float32)

        # Get 3D hand vertices
        if TransQueries.HANDVERTS3D in query or BaseQueries.HANDVERTS3D in query:
            hand_verts3d = self.pose_dataset.get_hand_verts3d(idx)
            if flip:
                hand_verts3d[:, 0] = -hand_verts3d[:, 0]
            # Fixed: was gated on BaseQueries.OBJVERTS3D (copy-paste bug),
            # which silently dropped the hand vertices
            if BaseQueries.HANDVERTS3D in query:
                sample[BaseQueries.HANDVERTS3D] = hand_verts3d.astype(
                    np.float32)
            if TransQueries.HANDVERTS3D in query:
                hand_verts3d = rot_mat.dot(hand_verts3d.transpose(
                    1, 0)).transpose()
                if self.center_idx is not None:
                    hand_verts3d = hand_verts3d - center3d
                sample[TransQueries.HANDVERTS3D] = hand_verts3d.astype(
                    np.float32)

        # Get 3D obj vertices
        if TransQueries.OBJVERTS3D in query or BaseQueries.OBJVERTS3D in query:
            obj_verts3d = self.pose_dataset.get_obj_verts_trans(idx)
            if flip:
                obj_verts3d[:, 0] = -obj_verts3d[:, 0]
            if BaseQueries.OBJVERTS3D in query:
                sample[BaseQueries.OBJVERTS3D] = obj_verts3d
            if TransQueries.OBJVERTS3D in query:
                origin_trans_mesh = rot_mat.dot(obj_verts3d.transpose(
                    1, 0)).transpose()
                if self.center_idx is not None:
                    origin_trans_mesh = origin_trans_mesh - center3d
                sample[TransQueries.OBJVERTS3D] = origin_trans_mesh.astype(
                    np.float32)

        # Get canonical-rotated 3D obj vertices
        if TransQueries.OBJCANROTVERTS in query or BaseQueries.OBJCANROTVERTS in query:
            obj_canverts3d = self.pose_dataset.get_obj_verts_can_rot(idx)
            if flip:
                obj_canverts3d[:, 0] = -obj_canverts3d[:, 0]
            if BaseQueries.OBJCANROTVERTS in query:
                sample[BaseQueries.OBJCANROTVERTS] = obj_canverts3d
            if TransQueries.OBJCANROTVERTS in query:
                can_rot_mesh = rot_mat.dot(obj_canverts3d.transpose(
                    1, 0)).transpose()
                sample[TransQueries.OBJCANROTVERTS] = can_rot_mesh

        # Get canonical-rotated 3D obj corners
        if TransQueries.OBJCANROTCORNERS in query or BaseQueries.OBJCANROTCORNERS in query:
            obj_cancorners3d = self.pose_dataset.get_obj_corners_can_rot(idx)
            if flip:
                obj_cancorners3d[:, 0] = -obj_cancorners3d[:, 0]
            if BaseQueries.OBJCANROTCORNERS in query:
                sample[BaseQueries.OBJCANROTCORNERS] = obj_cancorners3d
            if TransQueries.OBJCANROTCORNERS in query:
                can_rot_corners = rot_mat.dot(obj_cancorners3d.transpose(
                    1, 0)).transpose()
                sample[TransQueries.OBJCANROTCORNERS] = can_rot_corners

        if BaseQueries.OBJFACES in query:
            obj_faces = self.pose_dataset.get_obj_faces(idx)
            sample[BaseQueries.OBJFACES] = obj_faces
        if BaseQueries.OBJCANVERTS in query:
            obj_canverts, obj_cantrans, obj_canscale = self.pose_dataset.get_obj_verts_can(
                idx)
            if flip:
                obj_canverts[:, 0] = -obj_canverts[:, 0]
            sample[BaseQueries.OBJCANVERTS] = obj_canverts
            sample[BaseQueries.OBJCANSCALE] = obj_canscale
            sample[BaseQueries.OBJCANTRANS] = obj_cantrans

        # Get 3D obj corners
        if BaseQueries.OBJCORNERS3D in query or TransQueries.OBJCORNERS3D in query:
            obj_corners3d = self.pose_dataset.get_obj_corners3d(idx)
            if flip:
                obj_corners3d[:, 0] = -obj_corners3d[:, 0]
            if BaseQueries.OBJCORNERS3D in query:
                sample[BaseQueries.OBJCORNERS3D] = obj_corners3d
            if TransQueries.OBJCORNERS3D in query:
                origin_trans_corners = rot_mat.dot(
                    obj_corners3d.transpose(1, 0)).transpose()
                if self.center_idx is not None:
                    origin_trans_corners = origin_trans_corners - center3d
                sample[TransQueries.OBJCORNERS3D] = origin_trans_corners
        if BaseQueries.OBJCANCORNERS in query:
            # Fixed: original flipped `obj_canverts` here (wrong variable,
            # undefined unless OBJCANVERTS was also queried, and double-flipped
            # the already-stored canonical verts); flip the corners instead
            obj_cancorners = self.pose_dataset.get_obj_corners_can(idx)
            if flip:
                obj_cancorners[:, 0] = -obj_cancorners[:, 0]
            sample[BaseQueries.OBJCANCORNERS] = obj_cancorners

        if TransQueries.CENTER3D in query:
            # NOTE(review): center3d is only defined when a 3D joints query
            # ran and center_idx is set — querying CENTER3D alone would raise
            # NameError; confirm callers always pair it with JOINTS3D
            sample[TransQueries.CENTER3D] = center3d

        # Get rgb image
        if TransQueries.IMAGE in query:
            # Color augmentation (train only): blur + brightness/contrast/
            # saturation/hue jitter, replayable through color_augm
            if self.train:
                blur_radius = Uniform(
                    low=0, high=1).sample().item() * self.blur_radius
                img = img.filter(ImageFilter.GaussianBlur(blur_radius))
                if color_augm is None:
                    bright, contrast, sat, hue = colortrans.get_color_params(
                        brightness=self.brightness,
                        saturation=self.saturation,
                        hue=self.hue,
                        contrast=self.contrast,
                    )
                else:
                    sat = color_augm["sat"]
                    contrast = color_augm["contrast"]
                    hue = color_augm["hue"]
                    bright = color_augm["bright"]
                img = colortrans.apply_jitter(img,
                                              brightness=bright,
                                              saturation=sat,
                                              hue=hue,
                                              contrast=contrast)
                sample["color_augm"] = {
                    "sat": sat,
                    "bright": bright,
                    "contrast": contrast,
                    "hue": hue
                }
            else:
                sample["color_augm"] = None
            # Create buffer white image if needed (used to mask out the
            # out-of-image region introduced by the affine crop)
            if TransQueries.JITTERMASK in query:
                whiteimg = Image.new("RGB", img.size, (255, 255, 255))
            # Transform and crop
            img = handutils.transform_img(img, affinetrans, self.inp_res)
            img = img.crop((0, 0, self.inp_res[0], self.inp_res[1]))

            # Tensorize and normalize_img
            img = func_transforms.to_tensor(img).float()
            if self.normalize_img:
                img = func_transforms.normalize(img, self.mean, self.std)
            else:
                img = func_transforms.normalize(img, [0.5, 0.5, 0.5],
                                                [1, 1, 1])
            sample[TransQueries.IMAGE] = img
            if TransQueries.JITTERMASK in query:
                jittermask = handutils.transform_img(whiteimg, affinetrans,
                                                     self.inp_res)
                jittermask = jittermask.crop(
                    (0, 0, self.inp_res[0], self.inp_res[1]))
                jittermask = func_transforms.to_tensor(jittermask).float()
                sample[TransQueries.JITTERMASK] = jittermask
        if self.pose_dataset.has_dist2strong and self.has_dist2strong:
            dist2strong = self.pose_dataset.get_dist2strong(idx)
            sample["dist2strong"] = dist2strong

        return sample
# --- Example #2 (source-concatenation artifact; not executable code) ---
    def forward(self, sample, no_loss=False, step=0, preparams=None):
        """Encode the input image once, then run hand and object branches.

        Args:
            sample: dict of query -> tensor; must contain TransQueries.IMAGE.
            no_loss: if True, branches skip their loss computation.
            step: training step counter (unused here; kept for caller
                compatibility).
            preparams: optional dict of precomputed branch parameters
                ("hand_prescale", "pose", "shape", "hand_pretrans",
                "obj_prescale", "obj_prerot", "obj_pretrans").

        Returns:
            (total_loss, results, losses): accumulated loss tensor, merged
            branch outputs, and merged per-term losses (with "total_loss").
        """
        total_loss = torch.Tensor([0]).cuda()
        results = {}
        losses = {}
        # Get input and feed it into the shared encoder
        image = sample[TransQueries.IMAGE].cuda()
        encoder_output, encoder_features = self.base_net(image)

        # Hand branch: run only when hand supervision is present and enabled.
        # Fixed: a leftover `if True or (...)` forced this branch on and made
        # has_mano_super dead code.
        has_mano_super = one_query_in(
            sample.keys(),
            [
                TransQueries.JOINTS3D,
                TransQueries.JOINTS2D,
                TransQueries.HANDVERTS2D,
                TransQueries.HANDVERTS3D,
            ],
        )
        if has_mano_super and self.mano_lambdas:
            if preparams is not None:
                hand_scale = preparams["hand_prescale"]
                hand_pose = preparams["pose"]
                hand_shape = preparams["shape"]
                hand_trans = preparams["hand_pretrans"]
            else:
                hand_scale = None
                hand_pose = None
                hand_shape = None
                hand_trans = None
            mano_results, total_loss, mano_losses = self.recover_mano(
                sample,
                encoder_output=encoder_output,
                no_loss=no_loss,
                total_loss=total_loss,
                trans=hand_trans,
                scale=hand_scale,
                pose=hand_pose,
                shape=hand_shape,
            )
            losses.update(mano_losses)
            results.update(mano_results)

        # Object branch: run only when object supervision is present and
        # enabled
        has_obj_super = one_query_in(
            sample.keys(), [TransQueries.OBJVERTS2D, TransQueries.OBJVERTS3D])
        if has_obj_super and self.obj_lambdas:
            if preparams is not None:
                obj_scale = preparams["obj_prescale"]
                obj_rot = preparams["obj_prerot"]
                obj_trans = preparams["obj_pretrans"]
            else:
                obj_scale = None
                obj_rot = None
                obj_trans = None
            obj_results, total_loss, obj_losses = self.recover_object(
                sample,
                image,
                encoder_output,
                encoder_features,
                no_loss=no_loss,
                total_loss=total_loss,
                scale=obj_scale,
                trans=obj_trans,
                rotaxisang=obj_rot)
            losses.update(obj_losses)
            results.update(obj_results)

        # Branches may replace total_loss with None (e.g. no_loss mode);
        # either way the value is recorded as-is
        losses["total_loss"] = total_loss
        return total_loss, results, losses