Example #1
    def _get_projected_positions_of_sphere_points(self, sphere_points, rotation, translation):
        """
        For the given points on the unit sphere, computes the corresponding 3D
        coordinates on the mesh template and projects them back onto the image plane.

        :param sphere_points: A (B X 3 X H X W) tensor containing the predicted points on the sphere
        :param rotation: A (B X CP X 3 X 3) camera rotation tensor
        :param translation: A (B X CP X 3) camera translation tensor
        :return: A tuple (xy, z, uv, uv_3d)
            - xy - (B X CP X 2 X H X W) x, y values of the 3D points after projection onto the image plane
            - z - (B X CP X 1 X H X W) z value (depth) of the projection
            - uv - (B X 2 X H X W) UV values of the sphere coordinates
            - uv_3d - (B X H X W X 3) tensor with the 3D coordinates on the mesh
                template for the given sphere coordinates
        """

        uv = convert_3d_to_uv_coordinates(sphere_points.permute(0, 2, 3, 1))
        batch_size = uv.size(0)
        height = uv.size(1)
        width = uv.size(2)
        num_poses = rotation.size(1)

        uv_flatten = uv.view(-1, 2)
        # Lift UV coordinates to 3D points on the mesh template, replicated once per camera pose
        uv_3d = self.uv_to_3d(uv_flatten).view(batch_size, 1, -1, 3)
        uv_3d = uv_3d.repeat(1, num_poses, 1, 1).view(batch_size*num_poses, -1, 3)

        cameras = OpenGLOrthographicCameras(device=sphere_points.device,
                                            R=rotation.view(-1, 3, 3),
                                            T=translation.view(-1, 3))
        # Camera-space coordinates give the depth (z) of every projected point
        xyz_cam = cameras.get_world_to_view_transform().transform_points(uv_3d)
        z = xyz_cam[:, :, 2:].view(batch_size, num_poses, height, width, 1)
        # Full projection to screen space; keep only the x, y coordinates
        xy = cameras.transform_points(uv_3d)[:, :, :2].view(batch_size, num_poses, height, width, 2)

        # Move coordinates to channel-first layout; flip(2) reverses the coordinate channel to (y, x) ordering
        xy = xy.permute(0, 1, 4, 2, 3).flip(2)
        z = z.permute(0, 1, 4, 2, 3)
        uv = uv.permute(0, 3, 1, 2)
        # uv_3d does not depend on the camera pose, so keep the first pose only
        uv_3d = uv_3d.view(batch_size, num_poses, height, width, 3)[:, 0, :, :, :].squeeze()

        return xy, z, uv, uv_3d
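
The projection above is a two-step PyTorch3D pattern: camera-space depth via get_world_to_view_transform, then screen x, y via transform_points. Below is a minimal, self-contained sketch of the same pattern on dummy tensors whose shapes mirror the snippet; note that newer PyTorch3D releases deprecate OpenGLOrthographicCameras in favor of FoVOrthographicCameras.

    import torch
    from pytorch3d.renderer import OpenGLOrthographicCameras

    B, CP, N = 2, 4, 16                                # batch size, camera poses, points
    rotation = torch.eye(3).expand(B * CP, 3, 3)       # identity rotations for the sketch
    translation = torch.zeros(B * CP, 3)
    points = torch.randn(B * CP, N, 3)                 # stand-in for the uv_3d template points

    cameras = OpenGLOrthographicCameras(R=rotation, T=translation)
    xyz_cam = cameras.get_world_to_view_transform().transform_points(points)
    z = xyz_cam[:, :, 2:]                              # camera-space depth
    xy = cameras.transform_points(points)[:, :, :2]    # projected x, y
    print(xy.shape, z.shape)                           # torch.Size([8, 16, 2]) torch.Size([8, 16, 1])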
Example #2
            scale, trans, quat, True)
        camera = OpenGLOrthographicCameras(device=device,
                                           R=rotation,
                                           T=translation)

        # Keypoints arrive in normalized [-1, 1] coordinates; map them to [0, 255] pixel coordinates
        kps = (((data['kp'].to(device, dtype=torch.float) + 1) / 2) * 255).to(
            torch.int32)
        kp = draw_key_points(img, kps, key_point_colors)
        fig = plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(kp[0].permute(1, 2, 0).cpu())

        # Plot the key points directly by projecting the 3D keypoints onto the image plane
        kp_3d = torch.from_numpy(dataset.kp_3d).to(
            device, dtype=torch.float32).unsqueeze(0)
        xyz = camera.transform_points(kp_3d)
        xy = (((xyz[:, :, :2] + 1) / 2) * 255).to(torch.int32)
        kp_xy = torch.cat((xy, kps[:, :, 2:]), dim=2)
        kp_pred_direct = draw_key_points(img, kp_xy, key_point_colors)
        plt.subplot(2, 2, 2)
        plt.imshow(kp_pred_direct[0].permute(1, 2, 0).cpu())

        # Draw key points by converting the 3D keypoints to UV values and then back to 3D via uv_to_3d
        kp_uv = convert_3d_to_uv_coordinates(kp_3d)
        uv_flatten = kp_uv.view(-1, 2)
        uv_3d = uv_to_3d(uv_flatten).view(1, -1, 3)
        xyz = camera.transform_points(uv_3d)
        xy = (((xyz[:, :, :2] + 1) / 2) * 255).to(torch.int32)
        kp_xy = torch.cat((xy, kps[:, :, 2:]), dim=2)
        kp_pred = draw_key_points(img, kp_xy, key_point_colors)
        plt.subplot(2, 2, 3)
        plt.imshow(kp_pred[0].permute(1, 2, 0).cpu())
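
Example #2 visually checks that the 3D -> UV -> 3D round trip preserves the keypoints. The same check can be sketched numerically. The functions below are hypothetical stand-ins for convert_3d_to_uv_coordinates and uv_to_3d, using a standard azimuth/inclination parameterization of the unit sphere normalized to [-1, 1]; the repository's exact convention may differ, and its uv_to_3d maps onto a mesh template rather than back onto the sphere.

    import math
    import torch

    def sphere_to_uv(points):
        # Hypothetical stand-in for convert_3d_to_uv_coordinates:
        # azimuth/inclination of unit-sphere points, normalized to [-1, 1]
        phi = torch.atan2(points[..., 1], points[..., 0])    # azimuth in [-pi, pi]
        theta = torch.acos(points[..., 2].clamp(-1.0, 1.0))  # inclination in [0, pi]
        return torch.stack([phi / math.pi, 2.0 * theta / math.pi - 1.0], dim=-1)

    def uv_to_sphere(uv):
        # Inverse mapping from normalized UV back to unit-sphere points
        phi = uv[..., 0] * math.pi
        theta = (uv[..., 1] + 1.0) * math.pi / 2.0
        return torch.stack([torch.sin(theta) * torch.cos(phi),
                            torch.sin(theta) * torch.sin(phi),
                            torch.cos(theta)], dim=-1)

    # Random unit-sphere points should survive the 3D -> UV -> 3D round trip
    pts = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
    assert torch.allclose(uv_to_sphere(sphere_to_uv(pts)), pts, atol=1e-4)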