Ejemplo n.º 1
0
    def run(self, input, scalefactor, path):
        """Reconstruct ``input`` after searching for its best orientation.

        The input is centred with ``my_utils.center``. A point cloud taken
        from it is used for the search: 10000 randomly sampled vertices
        (with replacement) when ``self.LR_input`` is set, otherwise all
        vertices. That cloud is rotated over a (theta, phi) grid and
        re-centred on its bounding box. Each result is reconstructed with
        ``self.Tran_points_test`` and scored with the Chamfer distance.

        The full-resolution input is then rotated by the best
        (theta, phi) pair, re-centred and passed to ``self.regress``. The
        result is un-centred and unrotated.

        :param input: input mesh to reconstruct optimally.
        :param scalefactor: output vertices are shifted back by the centring
            translation and then divided by this value.
        :param path: path of the input file; slices of it name the exported
            downsampled input and the loss-surface plot.
        :return: ``(mesh, meshReg)`` -- the best reconstruction found during
            the angle search, and the final reconstruction after regression.
        :raises RuntimeError: if no grid orientation yields a finite loss.
        """
        self.load_template()
        input, translation = my_utils.center(input)

        # Full-resolution points on the GPU, laid out as (1, 3, N).
        points = torch.from_numpy(input.vertices.astype(
            np.float32)).contiguous().unsqueeze(0)
        points = points.transpose(2, 1).contiguous().cuda()

        # Point cloud used for the orientation search, laid out as (1, N_LR, 3).
        if self.LR_input:
            print("Using a Low_res input")
            # TODO : remove random here
            random_sample = np.random.choice(input.vertices.shape[0],
                                             size=10000)
            points_LR = torch.from_numpy(input.vertices[random_sample].astype(
                np.float32)).contiguous().unsqueeze(0)
        else:
            print("Using a High_res input")
            points_LR = torch.from_numpy(input.vertices.astype(
                np.float32)).contiguous().unsqueeze(0)

        input_LR_mesh = trimesh.Trimesh(
            vertices=(points_LR.squeeze().data.cpu().numpy() + translation) /
            scalefactor,
            faces=np.array([1, 2, 3]),
            process=False)
        if self.save_path is None:
            input_LR_mesh.export(path[:-4] + "DownsampledInput.ply")
        else:
            input_LR_mesh.export(
                os.path.join(self.save_path,
                             path[-8:-4] + "DownsampledInput.ply"))

        points_LR = points_LR.transpose(2, 1).contiguous().cuda()

        # Reference loss of the cloud as given (no rotation, no re-centring).
        pointsReconstructed = self.Tran_points_test(points_LR)
        dist1, dist2 = distChamfer(
            points_LR.transpose(2, 1).contiguous(), pointsReconstructed)
        loss_net = torch.mean(dist1) + torch.mean(dist2)
        print("loss without rotation: ", loss_net.item(), 0)

        # ---- Search the (theta, phi) grid for the best reconstruction ----
        THETA, PHI = np.meshgrid(*self._angle_grid(self.num_angles))
        Z = np.empty(THETA.shape)
        best_loss = float("inf")
        best_theta = best_phi = bestPoints = None
        for j in range(THETA.shape[1]):
            for i in range(THETA.shape[0]):
                theta = THETA[i, j]
                phi = PHI[i, j]

                # Rotate the cloud, then re-centre it on its bounding box.
                rot_matrix = torch.from_numpy(
                    self._rotation_matrix(theta, phi)).float().cuda()
                points2, norma = self._center_on_bbox(
                    torch.matmul(rot_matrix, points_LR))

                pointsReconstructed = self.Tran_points_test(points2)
                dist1, dist2 = distChamfer(
                    points2.transpose(2, 1).contiguous(), pointsReconstructed)
                loss_net = torch.mean(dist1) + torch.mean(dist2)
                loss_value = loss_net.item()
                Z[i, j] = loss_value

                # Plain float comparison against +inf: the first finite loss
                # is always accepted and NaN losses are never selected.
                if loss_value < best_loss:
                    print(theta, phi, loss_net)
                    best_loss = loss_value
                    best_theta = theta
                    best_phi = phi
                    # Undo the centring, then the rotation. For row vectors,
                    # ``rows @ R`` applies R^T, the inverse of R.
                    pointsReconstructed[0] = pointsReconstructed[0] + norma
                    bestPoints = torch.matmul(pointsReconstructed, rot_matrix)

        if best_theta is None:
            raise RuntimeError(
                "Orientation search found no finite Chamfer loss.")

        self._save_loss_surface(THETA, PHI, Z, best_theta, best_phi,
                                best_loss, path)

        print("best loss and theta and phi : ", best_loss, best_theta,
              best_phi)

        if self.HR:
            faces_tosave = self.network.mesh_HR.faces
        else:
            faces_tosave = self.network.mesh.faces

        # create initial guess
        mesh = trimesh.Trimesh(
            vertices=(bestPoints[0].data.cpu().numpy() + translation) /
            scalefactor,
            faces=self.network.mesh.faces,
            process=False)

        # START REGRESSION on the full-resolution input
        print("start regression...")

        # Use the same (theta, phi) rotation the search selected.
        rot_matrix = torch.from_numpy(
            self._rotation_matrix(best_theta, best_phi)).float().cuda()
        points2, norma = self._center_on_bbox(torch.matmul(rot_matrix, points))
        pointsReconstructed1 = self.regress(points2)
        # Undo the centring, then the rotation.
        pointsReconstructed1[0] = pointsReconstructed1[0] + norma
        pointsReconstructed1 = torch.matmul(pointsReconstructed1, rot_matrix)

        # create optimal reconstruction
        meshReg = trimesh.Trimesh(
            vertices=(pointsReconstructed1[0].data.cpu().numpy() + translation)
            / scalefactor,
            faces=faces_tosave,
            process=False)

        print("... Done!")
        return mesh, meshReg

    @staticmethod
    def _angle_grid(num_angles):
        """Return the ``(thetas, phis)`` 1-D samples for the orientation search.

        theta spans [-pi/2, pi/2] with ``num_angles`` samples. phi spans
        [-pi/4, pi/4] with ``num_angles // 4`` samples. An axis that would
        otherwise be empty gets the single sample 0. ``num_angles == 1``
        means "no rotation" (theta = phi = 0).
        """
        if num_angles == 1:
            return np.zeros(1), np.zeros(1)
        thetas = np.linspace(-np.pi / 2, np.pi / 2, num_angles)
        phis = np.linspace(-np.pi / 4, np.pi / 4, num_angles // 4)
        if thetas.size == 0:
            thetas = np.zeros(1)
        if phis.size == 0:
            phis = np.zeros(1)
        return thetas, phis

    @staticmethod
    def _rotation_matrix(theta, phi):
        """Return ``R_phi @ R_theta`` as a (3, 3) float64 numpy array.

        ``R_theta`` rotates about the Y axis by ``theta``. ``R_phi`` acts in
        the XY plane (about the Z axis) by ``phi``. Sign conventions match
        the matrices this class has always used.
        """
        r_theta = np.array([[np.cos(theta), 0, np.sin(theta)],
                            [0, 1, 0],
                            [-np.sin(theta), 0, np.cos(theta)]])
        r_phi = np.array([[np.cos(phi), np.sin(phi), 0],
                          [-np.sin(phi), np.cos(phi), 0],
                          [0, 0, 1]])
        return r_phi @ r_theta

    @staticmethod
    def _center_on_bbox(points):
        """Centre the first cloud of ``points`` on its bounding box, in place.

        :param points: tensor laid out as (B, 3, N); only ``points[0]`` is
            translated.
        :return: ``(points, center)`` where ``center`` has shape (3,) and is
            the bounding-box midpoint that was subtracted.
        """
        cloud = points[0]
        # Detached, like the CPU/numpy bbox computation this replaces.
        coords = cloud.detach()
        center = (torch.max(coords, dim=1)[0] + torch.min(coords, dim=1)[0]) / 2
        points[0] = cloud - center.unsqueeze(1)
        return points, center

    def _save_loss_surface(self, THETA, PHI, Z, best_theta, best_phi,
                           best_loss, path):
        """Best-effort 3D plot of the Chamfer loss over the angle grid.

        The plot is written to ``3Dcurve.png`` in the working directory and
        next to the input (or under ``self.save_path``). It is also written
        to ``curve.png`` (legacy output). Plotting failures are reported
        rather than raised, and the figure is always closed so repeated
        calls do not accumulate open figures.
        """
        fig = None
        try:
            fig = plt.figure()
            ax = plt.axes(projection='3d')
            ax.plot_surface(THETA,
                            PHI,
                            -Z,
                            rstride=1,
                            cstride=1,
                            cmap='magma',
                            edgecolor='none',
                            alpha=0.8)

            ax.set_xlabel('THETA', fontsize=20)
            ax.set_ylabel('PHI', fontsize=20)
            ax.set_zlabel('CHAMFER', fontsize=20)
            ax.scatter(best_theta,
                       best_phi,
                       -best_loss,
                       marker='*',
                       c="red",
                       s=100,
                       alpha=1)
            ax.scatter(best_theta,
                       best_phi,
                       np.min(-Z),
                       marker='*',
                       c="red",
                       s=100,
                       alpha=1)
            ax.view_init(elev=45., azim=45)
            plt.savefig("3Dcurve.png")
            if self.save_path is not None:
                plt.savefig(
                    os.path.join(self.save_path, path[-8:-4] + "3Dcurve.png"))
            else:
                plt.savefig(path[:-4] + "3Dcurve.png")
            plt.savefig("curve.png")
        except Exception as err:  # plotting is optional; never abort the run
            print("Could not save the loss surface plot:", err)
        finally:
            if fig is not None:
                plt.close(fig)
Ejemplo n.º 2
0
    def run(self, input, scalefactor, path):
        """Reconstruct ``input`` after searching for its best orientation.

        The input is centred with ``my_utils.center``. A point cloud taken
        from it is used for the search: 10000 randomly sampled vertices
        (with replacement) when ``self.LR_input`` is set, otherwise all
        vertices. For each (theta, phi) grid point, the cloud is transformed
        by ``pointcloud_processor.RotateCenterPointCloud.rotate_center``,
        reconstructed with ``self.network`` and scored with the Chamfer
        distance.

        The full-resolution input is then transformed with the best
        (phi, theta) pair and passed to ``self.regress``. The result is
        mapped back with ``RotateCenterPointCloud.back``.

        :param input: input mesh to reconstruct optimally.
        :param scalefactor: output vertices are shifted back by the centring
            translation and then divided by this value.
        :param path: path of the input file; slices of it name the exported
            downsampled input and the loss-surface plot.
        :return: ``(mesh, meshReg)`` -- the best reconstruction found during
            the angle search, and the final reconstruction after regression.
        :raises RuntimeError: if no grid orientation yields a finite loss.
        """
        input, translation = my_utils.center(input)

        # Full-resolution points on the GPU, laid out as (1, 3, N).
        points = torch.from_numpy(input.vertices.astype(
            np.float32)).contiguous().unsqueeze(0)
        points = points.transpose(2, 1).contiguous().cuda()

        # Point cloud used for the orientation search, laid out as (1, N_LR, 3).
        if self.LR_input:
            print("Using a Low_res input")
            # TODO : remove random here
            random_sample = np.random.choice(input.vertices.shape[0],
                                             size=10000)
            points_LR = torch.from_numpy(input.vertices[random_sample].astype(
                np.float32)).contiguous().unsqueeze(0)
        else:
            print("Using a High_res input")
            points_LR = torch.from_numpy(input.vertices.astype(
                np.float32)).contiguous().unsqueeze(0)

        input_LR_mesh = trimesh.Trimesh(
            vertices=(points_LR.squeeze().data.cpu().numpy() + translation) /
            scalefactor,
            faces=np.array([1, 2, 3]),
            process=False)
        if self.save_path is None:
            input_LR_mesh.export(path[:-4] + "DownsampledInput.ply")
        else:
            input_LR_mesh.export(
                os.path.join(self.save_path,
                             path[-8:-4] + "DownsampledInput.ply"))

        points_LR = points_LR.transpose(2, 1).contiguous().cuda()

        # Reference loss of the cloud as given (no rotation, no centring).
        pointsReconstructed = self.network(points_LR)
        pointsReconstructed = pointsReconstructed.view(
            pointsReconstructed.size(0), -1, 3)
        dist1, dist2 = distChamfer(
            points_LR.transpose(2, 1).contiguous(), pointsReconstructed)
        loss_net = torch.mean(dist1) + torch.mean(dist2)
        print("loss without rotation: ", loss_net.item(), 0)

        # ---- Search the (theta, phi) grid for the best reconstruction ----
        THETA, PHI = np.meshgrid(*self._angle_grid(self.num_angles))
        Z = np.empty(THETA.shape)
        best_loss = float("inf")
        best_theta = best_phi = bestPoints = None

        rotateCenterPointCloud = pointcloud_processor.RotateCenterPointCloud(
            points_LR)
        for j in range(THETA.shape[1]):
            for i in range(THETA.shape[0]):
                theta = THETA[i, j]
                phi = PHI[i, j]
                rotateCenterPointCloud.rotate_center(phi, theta)
                input_network = rotateCenterPointCloud.centered_points
                pointsReconstructed = self.network(input_network)
                pointsReconstructed = pointsReconstructed.view(
                    pointsReconstructed.size(0), -1, 3)
                dist1, dist2 = distChamfer(
                    input_network.transpose(2, 1).contiguous(),
                    pointsReconstructed)
                loss_net = torch.mean(dist1) + torch.mean(dist2)
                loss_value = loss_net.item()
                Z[i, j] = loss_value

                # Plain float comparison against +inf: the first finite loss
                # is always accepted and NaN losses are never selected.
                if loss_value < best_loss:
                    print(theta, phi, loss_value)
                    best_loss = loss_value
                    best_theta = theta
                    best_phi = phi
                    # Map the reconstruction back to the input's frame.
                    pointsReconstructed[0] = rotateCenterPointCloud.back(
                        pointsReconstructed[0])
                    bestPoints = pointsReconstructed

        if best_theta is None:
            raise RuntimeError(
                "Orientation search found no finite Chamfer loss.")

        self._save_loss_surface(THETA, PHI, Z, best_theta, best_phi,
                                best_loss, path)

        print("best loss and theta and phi : ", best_loss, best_theta,
              best_phi)

        if self.HR:
            faces_tosave = self.network.template[0].mesh_HR.faces
        else:
            faces_tosave = self.network.template[0].mesh.faces

        # create initial guess
        mesh = trimesh.Trimesh(
            vertices=(bestPoints[0].data.cpu().numpy() + translation) /
            scalefactor,
            faces=self.network.template[0].mesh.faces,
            process=False)

        # START REGRESSION on high rez input
        print("start regression...")

        rotateCenterPointCloud = pointcloud_processor.RotateCenterPointCloud(
            points)
        rotateCenterPointCloud.rotate_center(best_phi, best_theta)
        input_network = rotateCenterPointCloud.centered_points
        pointsReconstructed1 = self.regress(input_network)
        pointsReconstructed1[0] = rotateCenterPointCloud.back(
            pointsReconstructed1[0])
        # create optimal reconstruction
        meshReg = trimesh.Trimesh(
            vertices=(pointsReconstructed1[0].data.cpu().numpy() + translation)
            / scalefactor,
            faces=faces_tosave,
            process=False)

        print("... Done!")
        return mesh, meshReg

    @staticmethod
    def _angle_grid(num_angles):
        """Return the ``(thetas, phis)`` 1-D samples for the orientation search.

        theta spans [-pi/2, pi/2] with ``num_angles`` samples. phi spans
        [-pi/4, pi/4] with ``num_angles // 4`` samples. An axis that would
        otherwise be empty gets the single sample 0. ``num_angles == 1``
        means "no rotation" (theta = phi = 0).
        """
        if num_angles == 1:
            return np.zeros(1), np.zeros(1)
        thetas = np.linspace(-np.pi / 2, np.pi / 2, num_angles)
        phis = np.linspace(-np.pi / 4, np.pi / 4, num_angles // 4)
        if thetas.size == 0:
            thetas = np.zeros(1)
        if phis.size == 0:
            phis = np.zeros(1)
        return thetas, phis

    def _save_loss_surface(self, THETA, PHI, Z, best_theta, best_phi,
                           best_loss, path):
        """Best-effort 3D plot of the Chamfer loss over the angle grid.

        The plot is written to ``3Dcurve.png`` in the working directory and
        next to the input (or under ``self.save_path``). Plotting failures
        are reported rather than raised, and the figure is always closed so
        repeated calls do not accumulate open figures.
        """
        fig = None
        try:
            fig = plt.figure()
            ax = plt.axes(projection='3d')
            ax.plot_surface(THETA,
                            PHI,
                            -Z,
                            rstride=1,
                            cstride=1,
                            cmap='magma',
                            edgecolor='none',
                            alpha=0.8)

            ax.set_xlabel('THETA', fontsize=20)
            ax.set_ylabel('PHI', fontsize=20)
            ax.set_zlabel('CHAMFER', fontsize=20)
            ax.scatter(best_theta,
                       best_phi,
                       -best_loss,
                       marker='*',
                       c="red",
                       s=100,
                       alpha=1)
            ax.scatter(best_theta,
                       best_phi,
                       np.min(-Z),
                       marker='*',
                       c="red",
                       s=100,
                       alpha=1)
            ax.view_init(elev=45., azim=45)
            plt.savefig("3Dcurve.png")
            if self.save_path is not None:
                plt.savefig(
                    os.path.join(self.save_path, path[-8:-4] + "3Dcurve.png"))
            else:
                plt.savefig(path[:-4] + "3Dcurve.png")
        except Exception as err:  # plotting is optional; never abort the run
            print("Could not save the loss surface plot:", err)
        finally:
            if fig is not None:
                plt.close(fig)