# Example #1
# 0
class align(object):
    r"""
    Automated alignment methods between two meshes

    The moving mesh is deep-copied and the copy is rigidly aligned to the
    static mesh with an Iterative Closest Point (ICP) scheme, see
    :meth:`runICP`.

    Parameters
    ----------
    moving: AmpObject
        The moving AmpObject that is to be aligned to the static object
    static: AmpObject
        The static AmpObject that the moving object will be aligned to
    method: str, default 'linPoint2Plane'
        A string of the method used for alignment, one of
        'linPoint2Plane', 'linPoint2Point' or 'optPoint2Point'
    *args:
        Positional arguments forwarded to :meth:`runICP` (``maxiter``,
        ``inlier``, ``initTransform``, then any extra arguments for the
        registration method)
    **kwargs:
        The keyword arguments forwarded to :meth:`runICP` and on to the
        registration method

    Returns
    -------
    m: AmpObject
        The aligned AmpObject; it has the same number of vertices and the
        same face array as the moving AmpObject.
        Access this using align.m

    Attributes
    ----------
    R: ndarray
        The 3x3 rotation of the final alignment
    T: ndarray
        The translation of the final alignment
    tForm: ndarray
        4x4 array with ``R`` in [:3, :3] and ``T`` in the last row [3, :3]
    rmse: float
        Root-mean-square distance of the inlier vertices to their closest
        static face centroids after the final iteration

    Examples
    --------
    >>> static = AmpObject(staticfh)
    >>> moving = AmpObject(movingfh)
    >>> al = align(moving, static).m

    """

    # Registration methods that runICP knows how to call
    _ICP_METHODS = ('linPoint2Point', 'linPoint2Plane', 'optPoint2Point')

    def __init__(self,
                 moving,
                 static,
                 method='linPoint2Plane',
                 *args,
                 **kwargs):
        # Work on a deep copy so the caller's moving mesh is left untouched
        mData = {'vert': moving.vert,
                 'faces': moving.faces,
                 'values': moving.values}
        alData = copy.deepcopy(mData)
        self.m = AmpObject(alData, stype='reg')
        self.s = static
        # ``method`` is passed positionally: passing it by keyword together
        # with a non-empty ``*args`` would bind args[0] to ``method`` as well
        # and raise a TypeError
        self.runICP(method, *args, **kwargs)

    def runICP(self,
               method='linPoint2Plane',
               maxiter=20,
               inlier=1.0,
               initTransform=None,
               *args,
               **kwargs):
        r"""
        The function to run the ICP algorithm, this function calls one of
        multiple methods to calculate the rigid transformation

        Parameters
        ----------
        method: str, default 'linPoint2Plane'
            A string of the method used for alignment, one of
            'linPoint2Point', 'linPoint2Plane' or 'optPoint2Point'
        maxiter: int, default 20
            Maximum number of iterations to run the ICP algorithm
        inlier: float, default 1.0
            The proportion of closest points to use to calculate the
            transformation, if < 1 then vertices with highest error are
            discounted
        initTransform: ndarray, default None
            Initial 4x4 transform applied before iterating. The rotation is
            read from [:3, :3] and the translation from the last row
            [3, :3], the same layout as ``self.tForm``. Identity if None.
        *args:
            The arguments used for the registration methods
        **kwargs:
            The keyword arguments used for the registration methods

        Raises
        ------
        KeyError
            If ``method`` is not a supported alignment method
        ValueError
            If ``inlier`` is not greater than 0

        Notes
        -----
        Transforms ``self.m`` in place and sets ``self.R``, ``self.T``,
        ``self.tForm`` and ``self.rmse``.

        """
        # Validate before touching the mesh; previously an unknown method
        # only failed inside the loop with an UnboundLocalError
        if method not in self._ICP_METHODS:
            raise KeyError('Not a supported alignment method')
        if inlier <= 0:
            raise ValueError('inlier must be greater than 0')
        # Cumulative rotation, translation and error after each iteration
        Rs = np.zeros([3, 3, maxiter + 1])
        Ts = np.zeros([3, maxiter + 1])
        err = np.zeros([maxiter + 1])
        if initTransform is None:
            initTransform = np.eye(4)
        Rs[:, :, 0] = initTransform[:3, :3]
        Ts[:, 0] = initTransform[3, :3]
        # Correspondences are searched among the static face centroids
        fC = self.s.vert[self.s.faces].mean(axis=1)
        kdTree = spatial.cKDTree(fC)
        self.m.rigidTransform(Rs[:, :, 0], Ts[:, 0])
        nInlier = math.ceil(self.m.vert.shape[0] * inlier)
        dist, idx, sort = self._closest(kdTree, self.m.vert, nInlier)
        err[0] = self._rms(dist)
        register = getattr(self, method)
        for i in range(maxiter):
            mv = self.m.vert[sort, :]
            sv = fC[idx, :]
            if method == 'linPoint2Plane':
                # Point-to-plane also needs the normals of the matched faces
                R, T = register(mv, sv, self.s.norm[idx, :], *args, **kwargs)
            else:
                R, T = register(mv, sv, *args, **kwargs)
            # Compose this step with the transform so far (x -> R x + T)
            Rs[:, :, i + 1] = np.dot(R, Rs[:, :, i])
            Ts[:, i + 1] = np.dot(R, Ts[:, i]) + T
            self.m.rigidTransform(R, T)
            dist, idx, sort = self._closest(kdTree, self.m.vert, nInlier)
            err[i + 1] = self._rms(dist)
        # Re-orthonormalise the accumulated rotation to remove numerical drift
        U, _, Vt = np.linalg.svd(Rs[:, :, -1])
        R = np.dot(U, Vt)
        self.tForm = np.r_[np.c_[R, np.zeros(3)],
                           np.append(Ts[:, -1], 1)[:, None].T]
        self.R = R
        self.T = Ts[:, -1]
        self.rmse = err[-1]

    @staticmethod
    def _closest(kdTree, vert, nInlier):
        r"""
        Match every vertex to its nearest point in ``kdTree`` and keep only
        the ``nInlier`` closest matches.

        Returns
        -------
        dist: ndarray
            Distances of the kept matches, in ascending order
        idx: ndarray
            Indices into the tree's points of the kept matches
        sort: ndarray
            Indices into ``vert`` of the kept matches
        """
        dist, idx = kdTree.query(vert, 1)
        sort = np.argsort(dist)[:nInlier]
        return dist[sort], idx[sort], sort

    @staticmethod
    def _rms(dist):
        r"""Root-mean-square of an array of distances."""
        return math.sqrt(np.mean(np.square(dist)))

    @staticmethod
    def _eulerRot(angles):
        r"""
        Rotation matrix ``Rz @ Ry @ Rx`` for ``angles = [angx, angy, angz]``
        in radians, i.e. rotate about x, then y, then z.
        """
        angx, angy, angz = angles
        cx, sx = np.cos(angx), np.sin(angx)
        cy, sy = np.cos(angy), np.sin(angy)
        cz, sz = np.cos(angz), np.sin(angz)
        Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
        Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
        Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
        return np.dot(np.dot(Rz, Ry), Rx)

    @staticmethod
    def linPoint2Plane(mv, sv, sn):
        r"""
        Point-to-Plane Iterative Closest Point step. The rotation is
        linearised (small-angle approximation) so that minimising the sum of
        squared point-to-plane distances becomes a linear least squares
        problem in the three angles and the translation.

        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved
        sv: ndarray
            The array of static vertices, these are the face centroids of the
            static mesh
        sn: ndarray
            The normals of the points in the static array, these are derived
            from the normals of the faces for each centroid

        Returns
        -------
        R: ndarray
            The optimal rotation array
        T: ndarray
            The optimal translation array

        References
        ----------
        .. [1] Besl, Paul J.; N.D. McKay (1992). "A Method for Registration of 3-D
           Shapes". IEEE Trans. on Pattern Analysis and Machine Intelligence (Los
           Alamitos, CA, USA: IEEE Computer Society) 14 (2): 239-256.

        .. [2] Chen, Yang; Gerard Medioni (1991). "Object modelling by registration of
           multiple range images". Image Vision Comput. (Newton, MA, USA:
           Butterworth-Heinemann): 145-155

        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='linPoint2Plane').m

        """
        # With R ~ I + [r]x the residual ((R mv + T) - sv) . sn becomes
        # (mv x sn) . r + sn . T - (sv - mv) . sn, which is linear in [r, T]
        cn = np.c_[np.cross(mv, sn), sn]
        C = np.dot(cn.T, cn)
        # Signed distance of each correspondence along the static normal
        d = ((sv - mv) * sn).sum(axis=1)
        b = np.dot(cn.T, d)
        X = np.linalg.lstsq(C, b, rcond=None)[0]
        R = align._eulerRot(X[:3])
        T = X[3:]
        return (R, T)

    @staticmethod
    def linPoint2Point(mv, sv):
        r"""
        Point-to-Point Iterative Closest Point algorithm which
        relies on using singular value decomposition on the centered arrays.

        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved
        sv: ndarray
            The array of static vertices, these are the face centroids of the
            static mesh

        Returns
        -------
        R: ndarray
            The optimal rotation array
        T: ndarray
            The optimal translation array

        References
        ----------
        .. [1] Besl, Paul J.; N.D. McKay (1992). "A Method for Registration of 3-D
           Shapes". IEEE Trans. on Pattern Analysis and Machine Intelligence (Los
           Alamitos, CA, USA: IEEE Computer Society) 14 (2): 239-256.

        .. [2] Chen, Yang; Gerard Medioni (1991). "Object modelling by registration of
           multiple range images". Image Vision Comput. (Newton, MA, USA:
           Butterworth-Heinemann): 145-155

        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='linPoint2Point').m

        """
        mMean = mv.mean(axis=0)
        sMean = sv.mean(axis=0)
        # Cross-covariance of the centred point sets
        C = np.dot((mv - mMean).T, sv - sMean)
        U, _, Vt = np.linalg.svd(C)
        # Flip the last singular direction when needed so that R is a proper
        # rotation (det = +1) rather than a reflection
        sign = np.eye(3)
        sign[2, 2] = np.sign(np.linalg.det(np.dot(U, Vt)))
        R = np.dot(np.dot(Vt.T, sign), U.T)
        T = sMean - np.dot(R, mMean)
        return (R, T)

    @staticmethod
    def optPoint2Point(mv, sv, opt='L-BFGS-B'):
        r"""
        Direct minimisation of the rmse between the points of the two meshes.
        This method enables access to all of Scipy's minimisation algorithms

        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved
        sv: ndarray
            The array of static vertices, these are the face centroids of the
            static mesh
        opt: str, default 'L-BFGS-B'
            The string of the scipy optimiser to use

        Returns
        -------
        R: ndarray
            The optimal rotation array
        T: ndarray
            The optimal translation array

        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='optPoint2Point', opt='SLSQP').m

        """
        X0 = np.zeros(6)
        # Bounds: rotations within +-pi/4 rad, translations within [-5, 5]
        lim = np.reshape([-np.pi / 4, np.pi / 4] * 3 + [-5, 5] * 3, [6, 2])
        try:
            res = minimize(align.optDistError,
                           X0,
                           args=(mv, sv),
                           bounds=lim,
                           method=opt)
        except Exception:
            # Best-effort fallback for optimisers that reject ``bounds``:
            # retry unbounded (no longer swallows KeyboardInterrupt etc.)
            res = minimize(align.optDistError, X0, args=(mv, sv), method=opt)
        R = align._eulerRot(res.x[:3])
        T = res.x[3:]
        return (R, T)

    @staticmethod
    def optDistError(X, mv, sv):
        r"""
        The function to minimise. It performs the rigid transformation then
        returns the rmse between the two vertex sets

        Parameters
        ----------
        X:  ndarray
            The rigid transformation corresponding to [Rx, Ry, Rz, Tx, Ty, Tz],
            with the angles in radians applied about x, then y, then z
        mv: ndarray
            The array of vertices to be moved
        sv: ndarray
            The array of static vertices, these are the face centroids of the
            static mesh

        Returns
        -------
        err: float
            The RMSE between the two meshes

        """
        R = align._eulerRot(X[:3])
        moved = np.dot(mv, R.T) + X[3:]
        dist = ((moved - sv)**2).sum(axis=1)
        return np.sqrt(dist.mean())

    @staticmethod
    def rot2quat(R):
        """
        Convert a rotation matrix to a quaternion

        Parameters
        ----------
        R: array_like
            The 3x3 rotation array to be converted to a quaternion

        Returns
        -------
        Q: ndarray
            The quaternion as [w, x, y, z]

        """
        [[Qxx, Qxy, Qxz], [Qyx, Qyy, Qyz], [Qzx, Qzy, Qzz]] = R
        t = Qxx + Qyy + Qzz
        if t >= 0:
            r = math.sqrt(1 + t)
            s = 0.5 / r
            w = 0.5 * r
            x = (Qzy - Qyz) * s
            y = (Qxz - Qzx) * s
            z = (Qyx - Qxy) * s
        else:
            # Negative trace: pivot on the largest diagonal element so the
            # square root argument stays well away from zero
            maxv = max([Qxx, Qyy, Qzz])
            if maxv == Qxx:
                r = math.sqrt(1 + Qxx - Qyy - Qzz)
                s = 0.5 / r
                w = (Qzy - Qyz) * s
                x = 0.5 * r
                y = (Qyx + Qxy) * s
                z = (Qxz + Qzx) * s
            elif maxv == Qyy:
                r = math.sqrt(1 + Qyy - Qxx - Qzz)
                s = 0.5 / r
                w = (Qxz - Qzx) * s
                x = (Qyx + Qxy) * s
                y = 0.5 * r
                z = (Qzy + Qyz) * s
            else:
                r = math.sqrt(1 + Qzz - Qxx - Qyy)
                s = 0.5 / r
                w = (Qyx - Qxy) * s
                x = (Qxz + Qzx) * s
                y = (Qzy + Qyz) * s
                z = 0.5 * r
        return np.array([w, x, y, z])

    def _newWindow(self):
        r"""
        Ensure both meshes have an actor and return a single-viewport
        vtkRenWin with a white background.
        """
        if not hasattr(self.s, 'actor'):
            self.s.addActor()
        if not hasattr(self.m, 'actor'):
            self.m.addActor()
        win = vtkRenWin()
        win.setnumViewports(1)
        win.setBackground([1, 1, 1])
        return win

    def _renderPair(self, win, azimuth):
        r"""
        Colour the static mesh red and the aligned mesh blue (both at 0.5
        opacity), render them, then apply ``azimuth`` to the active camera,
        switch it to parallel projection and render again.
        """
        self.s.actor.setColor([1.0, 0.0, 0.0])
        self.s.actor.setOpacity(0.5)
        self.m.actor.setColor([0.0, 0.0, 1.0])
        self.m.actor.setOpacity(0.5)
        win.renderActors([self.s.actor, self.m.actor])
        win.Render()
        camera = win.rens[0].GetActiveCamera()
        camera.Azimuth(azimuth)
        camera.SetParallelProjection(True)
        win.Render()

    def display(self):
        r"""
        Display the static mesh (red) and the aligned mesh (blue) within an
        interactive VTK window

        Returns
        -------
        win: vtkRenWin
            The render window. A trackball-camera interactor is attached to
            it but not started here (NOTE(review): presumably the caller
            starts it).

        """
        win = self._newWindow()
        renderWindowInteractor = vtk.vtkRenderWindowInteractor()
        renderWindowInteractor.SetRenderWindow(win)
        renderWindowInteractor.SetInteractorStyle(
            vtk.vtkInteractorStyleTrackballCamera())
        # Set camera projection
        win.setView()
        self._renderPair(win, 180)
        return win

    def genIm(self, crop=False):
        r"""
        Render the static mesh (red) and the aligned mesh (blue) off-screen
        in a 512x512 window and return the resulting image

        Parameters
        ----------
        crop: bool, default False
            If True, remove the image rows and columns in which every pixel
            equals 1 (presumably the white background - depends on what
            ``getImage`` returns; TODO confirm)

        Returns
        -------
        im: ndarray
            The image from ``win.getImage()``, cropped if requested
        win: vtkRenWin
            The off-screen render window

        """
        win = self._newWindow()
        # Set camera projection
        win.setView([0, -1, 0], 0)
        win.SetSize(512, 512)
        win.Modified()
        win.OffScreenRenderingOn()
        self._renderPair(win, 0)
        im = win.getImage()
        if crop is True:
            mask = np.all(im == 1, axis=2)
            mask = ~np.all(mask, axis=1)
            im = im[mask, :, :]
            mask = np.all(im == 1, axis=2)
            mask = ~np.all(mask, axis=0)
            im = im[:, mask, :]
        return im, win
# Example #2
# 0
class TestCore(unittest.TestCase):
    """Unit tests for the transformation methods of ampscan's AmpObject."""

    # The number of decimal places to value accuracy for - needed due to
    # floating point inaccuracies
    ACCURACY = 5

    def setUp(self):
        """Runs before each unit test.
        Sets up the AmpObject object using "stl_file.stl".
        """
        from ampscan.core import AmpObject
        stl_path = get_path("stl_file.stl")
        self.amp = AmpObject(stl_path)

    def test_centre(self):
        """Test the centre method of AmpObject"""

        # Translate the mesh
        self.amp.translate([1, 0, 0])
        # Recenter the mesh
        self.amp.centre()
        centre = self.amp.vert.mean(axis=0)

        # Check that the mesh is centred correctly (to at least the number of
        # decimal places of ACCURACY). abs() is required: without it any
        # negative offset, however large, would pass.
        for value in centre:
            self.assertLess(abs(value), 10**-TestCore.ACCURACY)

    def test_centre_static(self):
        """Test the centreStatic method of AmpObject"""

        with self.assertRaises(TypeError):
            self.amp.centreStatic(1)
        with self.assertRaises(TypeError):
            self.amp.centreStatic([])

        # Import second shape
        from ampscan.core import AmpObject
        stl_path = get_path("stl_file_2.stl")
        amp2 = AmpObject(stl_path)

        self.amp.centreStatic(amp2)

        mean = self.amp.vert.mean(axis=0)
        mean2 = amp2.vert.mean(axis=0)
        for i in range(3):
            # This method has a large degree of error so, it's only testing to 2 dp
            self.assertAlmostEqual(mean[i], mean2[i], 2)

    def test_rotate_ang(self):
        """Tests the rotateAng method of AmpObject"""

        # Test rotation on random node
        n = randrange(len(self.amp.vert))
        rot = [0, 0, np.pi / 3]
        before = self.amp.vert[n].copy()
        self.amp.rotateAng(rot)
        after_vert_pos = self.amp.vert[n].copy()
        # Use 2D rotation matrix formula to test rotate method on z axis
        expected = [
            np.cos(rot[2]) * before[0] - np.sin(rot[2]) * before[1],
            np.sin(rot[2]) * before[0] + np.cos(rot[2]) * before[1], before[2]
        ]
        # Check all coordinate dimensions are correct. A plain loop is needed:
        # assertAlmostEqual returns None, so wrapping it in all() stopped
        # after the first coordinate.
        for i in range(3):
            self.assertAlmostEqual(expected[i], after_vert_pos[i],
                                   TestCore.ACCURACY)

        # Check single floats cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotateAng(7)

        # Check dictionaries cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotateAng(dict())

        # Tests that incorrect number of elements causes ValueError
        with self.assertRaises(ValueError):
            self.amp.rotateAng(rot, "test")
        with self.assertRaises(ValueError):
            self.amp.rotateAng(rot, [])

    def test_rotate(self):
        """Tests the rotate method of AmpObject"""
        # A test rotation using a list
        m = [[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
             [0, -1 / 2, np.sqrt(3) / 2]]
        self.amp.rotate(m)

        # Check single floats cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotate(7)

        # Check dictionaries cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotate(dict())

        # Check invalid dimensions cause ValueError
        with self.assertRaises(ValueError):
            self.amp.rotate([])
        with self.assertRaises(ValueError):
            self.amp.rotate([[0, 0, 1]])
        with self.assertRaises(ValueError):
            self.amp.rotate([[], [], []])

    def test_translate(self):
        """Test translating method of AmpObject"""

        # Check that everything has been translated correctly to a certain accuracy
        start = self.amp.vert.mean(axis=0).copy()
        self.amp.translate([1, -1, 0])
        end = self.amp.vert.mean(axis=0).copy()
        self.assertAlmostEqual(start[0] + 1, end[0], places=TestCore.ACCURACY)
        self.assertAlmostEqual(start[1] - 1, end[1], places=TestCore.ACCURACY)
        self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY)

        # Check that translating raises TypeError when translating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.translate("")

        # Check that translating raises ValueError when translating with 2 dimensions
        with self.assertRaises(ValueError):
            self.amp.translate([0, 0])

        # Check that translating raises ValueError when translating with 4 dimensions
        with self.assertRaises(ValueError):
            self.amp.translate([0, 0, 0, 0])

    def test_rigid_transform(self):
        """Test the rigid transform method of AmpObject"""

        # Test if no transform is applied, vertices aren't affected. Compare
        # whole arrays: all(self.assertEqual(...)) only checked one element
        # because assertEqual returns None.
        before_vert = self.amp.vert.copy()
        self.amp.rigidTransform(R=None, T=None)
        np.testing.assert_array_equal(self.amp.vert, before_vert)

        # A test rotation and translation
        m = [[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
             [0, -1 / 2, np.sqrt(3) / 2]]
        self.amp.rigidTransform(R=m, T=[1, 0, -1])

        # Check that translating raises TypeError when translating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.rigidTransform(T=dict())

        # Check that rotating raises TypeError when rotating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.rigidTransform(R=7)

    def test_rot_matrix(self):
        """Tests the rotMatrix method in AmpObject"""

        # A rotation of 0 about every axis is the identity matrix (the
        # previous all(...) expression was never asserted)
        np.testing.assert_array_almost_equal(
            self.amp.rotMatrix([0, 0, 0]), np.eye(3),
            decimal=TestCore.ACCURACY)

        # A rotation of pi/6 about x. The sign of the sin terms depends on
        # whether rotMatrix returns the active or the passive form, so either
        # ``expected`` or its transpose is accepted.
        # NOTE(review): pin one convention once confirmed against rotMatrix.
        expected = np.array([[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
                             [0, -1 / 2, np.sqrt(3) / 2]])
        actual = np.asarray(self.amp.rotMatrix([np.pi / 6, 0, 0]))
        self.assertTrue(
            np.allclose(actual, expected) or np.allclose(actual, expected.T),
            msg="unexpected rotation matrix:\n{}".format(actual))

        # Tests that string passed into rot causes TypeError
        with self.assertRaises(TypeError):
            self.amp.rotMatrix(" ")
        with self.assertRaises(TypeError):
            self.amp.rotMatrix(dict())

        # Tests that incorrect number of elements causes ValueError
        with self.assertRaises(ValueError):
            self.amp.rotMatrix([0, 1])
        with self.assertRaises(ValueError):
            self.amp.rotMatrix([0, 1, 3, 0])

    def test_flip(self):
        """Tests the flip method in AmpObject"""
        # Check invalid axis types cause TypeError
        with self.assertRaises(TypeError):
            self.amp.flip(" ")
        with self.assertRaises(TypeError):
            self.amp.flip(dict())

        # Check invalid axis values cause ValueError
        with self.assertRaises(ValueError):
            self.amp.flip(-1)
        with self.assertRaises(ValueError):
            self.amp.flip(3)