예제 #1
0
class TestSmoothing(unittest.TestCase):
    """Tests for the smoothing methods of AmpObject."""

    ACCURACY = 5  # The number of decimal places to value accuracy for - needed due to floating point inaccuracies
    DELTA = 0.5  # Largest absolute change in estimated volume tolerated after smoothing

    def setUp(self):
        """Runs before each unit test.
        Sets up the AmpObject object using "stl_file_5.stl".
        """
        from AmpScan.core import AmpObject
        stl_path = get_path("stl_file_5.stl")
        self.amp = AmpObject(stl_path)

    def test_smoothing_nans(self):
        """Tests that NaNs are properly dealt with by smooth method"""
        # Test that smoothing runs
        self.amp.smoothValues()
        # TODO add test with NaNs

    def _estimate_volume(self):
        """Return analyse.est_volume of slices taken through self.amp.

        The slicing parameters (normalised intervals [0.001, 0.999], spacing
        0.001, axis 2) are shared so before/after volumes are comparable.
        """
        poly = analyse.create_slices(self.amp, [0.001, 0.999],
                                     0.001,
                                     typ='norm_intervals',
                                     axis=2)
        return analyse.est_volume(poly)

    def test_smoothing_volume(self):
        """Tests that smoothing affects the volume within given acceptable range"""
        # TODO check this is actually working properly
        volume_before = self._estimate_volume()
        self.amp.lp_smooth(1)
        volume_after = self._estimate_volume()
        self.assertAlmostEqual(volume_before,
                               volume_after,
                               delta=TestSmoothing.DELTA)
예제 #2
0
 def setUp(self):
     """Runs before each unit test.
     Builds self.amp, an AmpObject loaded from "stl_file.stl".
     """
     from AmpScan.core import AmpObject
     self.amp = AmpObject(get_path("stl_file.stl"))
예제 #3
0
 def setUp(self):
     """Runs before each unit test.
     Loads two sphere meshes: self.amp1 from "stl_file_5.stl" and
     self.amp2 from "stl_file_4.stl".
     """
     from AmpScan.core import AmpObject
     # Sphere of radius 1
     self.amp1 = AmpObject(get_path("stl_file_5.stl"))
     # Sphere of radius 1.2
     self.amp2 = AmpObject(get_path("stl_file_4.stl"))
예제 #4
0
 def __init__(self,
              moving,
              static,
              method='linPoint2Plane',
              *args,
              **kwargs):
     """Align a copy of ``moving`` to ``static`` by running ``runICP``.

     Parameters
     ----------
     moving: AmpObject
         Mesh to be aligned; its vert, faces and values are deep-copied into
         a new AmpObject stored as ``self.m``.
     static: AmpObject
         Mesh that the copy is aligned to; stored as ``self.s``.
     method: str, default 'linPoint2Plane'
         Alignment method forwarded to ``runICP``.
     *args, **kwargs:
         Further arguments forwarded to ``runICP``.
     """
     mData = dict(
         zip(['vert', 'faces', 'values'],
             [moving.vert, moving.faces, moving.values]))
     # Deep copy so the aligned mesh does not share arrays with ``moving``
     alData = copy.deepcopy(mData)
     self.m = AmpObject(alData, stype='reg')
     self.s = static
     # Pass ``method`` positionally: with ``method=method`` any positional
     # ``*args`` would also be bound to ``method`` and raise TypeError
     # ("got multiple values for argument 'method'").
     self.runICP(method, *args, **kwargs)
예제 #5
0
    def test_centre_static(self):
        """Tests the centreStatic method of AmpObject.

        An int or an empty list must raise TypeError; after centring on a
        second mesh the vertex means of both meshes must agree to 2 dp.
        """
        for bad_arg in (1, []):
            with self.assertRaises(TypeError):
                self.amp.centreStatic(bad_arg)

        # Load the second mesh to centre against
        from AmpScan.core import AmpObject
        other = AmpObject(get_path("stl_file_2.stl"))

        self.amp.centreStatic(other)

        moved_mean = self.amp.vert.mean(axis=0)
        target_mean = other.vert.mean(axis=0)
        # The method has a large inherent error, so compare only to 2 dp
        for axis in range(3):
            self.assertAlmostEqual(moved_mean[axis], target_mean[axis], 2)
예제 #6
0
class TestCore(unittest.TestCase):
    """Unit tests for the core transformation methods of AmpObject."""

    ACCURACY = 5  # The number of decimal places to value accuracy for - needed due to floating point inaccuracies

    def setUp(self):
        """Runs before each unit test.
        Sets up the AmpObject object using "stl_file.stl".
        """
        from AmpScan.core import AmpObject
        stl_path = get_path("stl_file.stl")
        self.amp = AmpObject(stl_path)

    def test_centre(self):
        """Test the centre method of AmpObject"""

        # Translate the mesh
        self.amp.translate([1, 0, 0])
        # Recenter the mesh
        self.amp.centre()
        centre = self.amp.vert.mean(axis=0)

        # Check that the mesh is centred correctly (to at least the number of
        # decimal places of ACCURACY). abs() is required: without it any
        # negative offset, however large, would pass.
        for i in range(3):
            self.assertLess(abs(centre[i]), 10**-TestCore.ACCURACY)

    def test_centre_static(self):
        """Test the centreStatic method of AmpObject"""

        with self.assertRaises(TypeError):
            self.amp.centreStatic(1)
        with self.assertRaises(TypeError):
            self.amp.centreStatic([])

        # Import second shape
        from AmpScan.core import AmpObject
        stl_path = get_path("stl_file_2.stl")
        amp2 = AmpObject(stl_path)

        self.amp.centreStatic(amp2)

        for i in range(3):
            # This method has a large degree of error so, it's only testing to 2 dp
            self.assertAlmostEqual(
                self.amp.vert.mean(axis=0)[i],
                amp2.vert.mean(axis=0)[i], 2)

    def test_rotate_ang(self):
        """Tests the rotateAng method of AmpObject"""

        # Test rotation on random node
        n = randrange(len(self.amp.vert))
        rot = [0, 0, np.pi / 3]
        before = self.amp.vert[n].copy()
        self.amp.rotateAng(rot)
        after_vert_pos = self.amp.vert[n].copy()
        # Use 2D rotation matrix formula to test rotate method on z axis
        expected = [
            np.cos(rot[2]) * before[0] - np.sin(rot[2]) * before[1],
            np.sin(rot[2]) * before[0] + np.cos(rot[2]) * before[1], before[2]
        ]
        # Check all coordinate dimensions are correct. A plain loop is needed:
        # assertAlmostEqual returns None, so all(...) stopped after the first.
        for i in range(3):
            self.assertAlmostEqual(expected[i], after_vert_pos[i],
                                   TestCore.ACCURACY)

        # Check single floats cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotateAng(7)

        # Check dictionaries cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotateAng(dict())

        # Tests that incorrect number of elements causes ValueError
        with self.assertRaises(ValueError):
            self.amp.rotateAng(rot, "test")
        with self.assertRaises(ValueError):
            self.amp.rotateAng(rot, [])

    def test_rotate(self):
        """Tests the rotate method of AmpObject"""
        # A test rotation (30 degrees about the x axis) given as a list
        m = [[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
             [0, -1 / 2, np.sqrt(3) / 2]]
        self.amp.rotate(m)

        # Check single floats cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotate(7)

        # Check dictionaries cause TypeError
        with self.assertRaises(TypeError):
            self.amp.rotate(dict())

        # Check invalid dimensions cause ValueError
        with self.assertRaises(ValueError):
            self.amp.rotate([])
        with self.assertRaises(ValueError):
            self.amp.rotate([[0, 0, 1]])
        with self.assertRaises(ValueError):
            self.amp.rotate([[], [], []])

    def test_translate(self):
        """Test translating method of AmpObject"""

        # Check that everything has been translated correctly to a certain accuracy
        start = self.amp.vert.mean(axis=0).copy()
        self.amp.translate([1, -1, 0])
        end = self.amp.vert.mean(axis=0).copy()
        self.assertAlmostEqual(start[0] + 1, end[0], places=TestCore.ACCURACY)
        self.assertAlmostEqual(start[1] - 1, end[1], places=TestCore.ACCURACY)
        self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY)

        # Check that translating raises TypeError when translating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.translate("")

        # Check that translating raises ValueError when translating with 2 dimensions
        with self.assertRaises(ValueError):
            self.amp.translate([0, 0])

        # Check that translating raises ValueError when translating with 4 dimensions
        with self.assertRaises(ValueError):
            self.amp.translate([0, 0, 0, 0])

    def test_rigid_transform(self):
        """Test the rigid transform method of AmpObject"""

        # Test if no transform is applied, vertices aren't affected. Every
        # coordinate is compared (all(...) around assertEqual only checked one).
        before_vert = self.amp.vert.copy()
        self.amp.rigidTransform(R=None, T=None)
        np.testing.assert_array_equal(self.amp.vert, before_vert)

        # A test rotation and translation
        m = [[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
             [0, -1 / 2, np.sqrt(3) / 2]]
        self.amp.rigidTransform(R=m, T=[1, 0, -1])

        # Check that translating raises TypeError when translating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.rigidTransform(T=dict())

        # Check that rotating raises TypeError when rotating with an invalid type
        with self.assertRaises(TypeError):
            self.amp.rigidTransform(R=7)

    def test_rot_matrix(self):
        """Tests the rotMatrix method in AmpObject"""

        # A rotation of 0 about every axis must give the identity matrix
        # (previously compared against 0 and the result was never asserted)
        self.assertTrue(
            np.allclose(np.asarray(self.amp.rotMatrix([0, 0, 0]), dtype=float),
                        np.eye(3)))

        # A rotation of pi/6 about the x axis
        expected = np.array([[1, 0, 0], [0, np.sqrt(3) / 2, 1 / 2],
                             [0, -1 / 2, np.sqrt(3) / 2]])
        rot_x = np.asarray(self.amp.rotMatrix([np.pi / 6, 0, 0]), dtype=float)
        # NOTE(review): the sign convention of rotMatrix is not confirmed here,
        # so either orientation of the x rotation is accepted
        self.assertTrue(
            np.allclose(rot_x, expected) or np.allclose(rot_x, expected.T))

        # Tests that string passed into rot causes TypeError
        with self.assertRaises(TypeError):
            self.amp.rotMatrix(" ")
        with self.assertRaises(TypeError):
            self.amp.rotMatrix(dict())

        # Tests that incorrect number of elements causes ValueError
        with self.assertRaises(ValueError):
            self.amp.rotMatrix([0, 1])
        with self.assertRaises(ValueError):
            self.amp.rotMatrix([0, 1, 3, 0])

    def test_flip(self):
        """Tests the flip method in AmpObject"""
        # Check invalid axis types cause TypeError
        with self.assertRaises(TypeError):
            self.amp.flip(" ")
        with self.assertRaises(TypeError):
            self.amp.flip(dict())

        # Check invalid axis values cause ValueError
        with self.assertRaises(ValueError):
            self.amp.flip(-1)
        with self.assertRaises(ValueError):
            self.amp.flip(3)
예제 #7
0
    def point2plane(self,
                    steps=1,
                    neigh=10,
                    inside=True,
                    subset=None,
                    scale=None,
                    smooth=1,
                    fixBrim=False,
                    error='norm'):
        r"""
        Point to Plane method for registration between the two meshes

        Each registered vertex is projected onto the planes of its ``neigh``
        nearest target faces (nearest by face centroid) and moved towards the
        closest candidate. The result is stored in ``self.reg`` and the
        displacement from the baseline in ``self.disp``.

        Parameters
        ----------
        steps: int, default 1
            Number of iterations. The loop counter ``step`` counts down from
            ``steps`` to 1 and each iteration adds 1/step of the current offset
            to the displacement, so the final iteration closes the gap.
        neigh: int, default 10
            Number of nearest neighbours to interrogate for each baseline point.
            NOTE(review): the einsum calls below index ``ind`` as 2-D; with
            neigh=1 cKDTree.query returns a 1-D index array, so neigh > 1 is
            presumably required — confirm.
        inside: bool, default True
            If True, a barycentric check is made to ensure the registered
            point lies within the target triangle
        subset: array_like, default None
            Indices of the baseline nodes to include in the registration.
            NOTE(review): this argument is not referenced in the method body,
            so all nodes are currently used regardless of its value.
        scale: float, default None
            If not None scale the baseline mesh to match the target mesh in the z-direction,
            the value of scale will be used as a plane from which the nodes are scaled.
            Nodes with a higher z value will not be scaled.
        smooth: int, default 1
            Number of Laplacian smooth steps applied to the displacement
            between iterations (not after the final iteration)
        fixBrim: bool, default False
            If True, the nodes on the brim line will not be included in the smooth
        error: str, default 'norm'
            Name of the distance method passed to ``calcError`` to fill
            ``self.reg.values`` (presumably 'norm', 'abs' or 'cent')
        """
        # Calc FaceCentroids
        fC = self.t.vert[self.t.faces].mean(axis=1)
        # Construct knn tree
        tTree = spatial.cKDTree(fC)
        bData = dict(
            zip(['vert', 'faces', 'values'],
                [self.b.vert, self.b.faces, self.b.values]))
        regData = copy.deepcopy(bData)
        self.reg = AmpObject(regData, stype='reg')
        # Displacement field: starts at zero and accumulates the motion of
        # each baseline vertex; reg.vert is always b.vert + disp.vert
        self.disp = AmpObject({
            'vert': np.zeros(self.reg.vert.shape),
            'faces': self.reg.faces,
            'values': self.reg.values
        })
        if scale is not None:
            # Linearly stretch z below the ``scale`` plane so that the lowest
            # registered z lands on the lowest target z
            tmin = self.t.vert.min(axis=0)[2]
            rmin = self.reg.vert.min(axis=0)[2]
            SF = ((tmin - scale) / (rmin - scale)) - 1
            logic = self.reg.vert[:, 2] < scale
            d = (self.reg.vert[logic, 2] - scale) * SF
            self.disp.vert[logic, 2] += d
            self.reg.vert = self.b.vert + self.disp.vert
        # Un-normalised target face normals and their squared magnitudes
        # (NOTE(review): zero-area faces would make mag zero below)
        normals = np.cross(
            self.t.vert[self.t.faces[:, 1]] - self.t.vert[self.t.faces[:, 0]],
            self.t.vert[self.t.faces[:, 2]] - self.t.vert[self.t.faces[:, 0]])
        mag = (normals**2).sum(axis=1)
        for step in np.arange(steps, 0, -1, dtype=float):
            # Index of the ``neigh`` centroids nearest to each registered vertex
            ind = tTree.query(self.reg.vert, neigh)[1]
            # Define normals for faces of nearest faces
            norms = normals[ind]
            # Get a point on each face
            fPoints = self.t.vert[self.t.faces[ind, 0]]
            # Calculate dot product between point on face and normals
            d = np.einsum('ijk, ijk->ij', norms, fPoints)
            # Signed parameter along each normal from the vertex to the face plane
            t = (d - np.einsum('ijk, ik->ij', norms, self.reg.vert)) / mag[ind]
            # Candidate positions: the vertex projected onto each face plane
            G = self.reg.vert[:, None, :] + np.einsum('ijk, ij->ijk', norms, t)
            # Turn candidates into offsets from the current vertex and pick the
            # shortest; when ``inside`` is True, __calcBarycentric first moves
            # candidates that fall outside their face
            if inside is False:
                G = G - self.reg.vert[:, None, :]
                GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G))
                GInd = GMag.argmin(axis=1)
            else:
                G, GInd = self.__calcBarycentric(self.reg.vert, G, ind)
            # Offset from the current registered vertex to its chosen candidate
            D = G[np.arange(len(G)), GInd, :]
            # Move only 1/step of the way; step == 1 on the last iteration
            self.disp.vert += D / step
            if smooth > 0 and step > 1:
                self.disp.lp_smooth(smooth, brim=fixBrim)
                self.reg.vert = self.b.vert + self.disp.vert
            else:
                self.reg.vert = self.b.vert + self.disp.vert
                self.reg.calcNorm()
        self.reg.calcStruct()
        self.reg.values[:] = self.calcError(error)
예제 #8
0
class registration(object):
    r"""
    Registration methods between two AmpObject meshes. This function morphs the baseline
    vertices onto the surface of the target and returns a new AmpObject

    Parameters
    ----------
    baseline: AmpObject
        The baseline AmpObject, the vertices from this will be morphed onto the target
    target: AmpObject
        The target AmpObject, the shape that the baseline attempts to morph onto
    method: str: default 'point2plane'
        A string of the method used for registration; if None no registration
        is run on construction
    *args:
        The arguments used for the registration methods
    **kwargs:
        The keyword arguments used for the registration methods

    Returns
    -------
    reg: AmpObject
        The registered AmpObject, the vertices of this are on the surface of the target
        and it has the same number of vertices and face array as the baseline AmpObject
        Access this using registration.reg

    Examples
    --------
    >>> from AmpScan.core import AmpObject
    >>> baseline = AmpObject(basefh)
    >>> target = AmpObject(targfh)
    >>> reg = registration(baseline, target, steps=10, neigh=10, smooth=1).reg

    """
    def __init__(self,
                 baseline,
                 target,
                 method='point2plane',
                 *args,
                 **kwargs):
        self.b = baseline
        self.t = target
        if method is not None:
            getattr(self, method)(*args, **kwargs)

    def point2plane(self,
                    steps=1,
                    neigh=10,
                    inside=True,
                    subset=None,
                    scale=None,
                    smooth=1,
                    fixBrim=False,
                    error='norm'):
        r"""
        Point to Plane method for registration between the two meshes

        Each registered vertex is projected onto the planes of its ``neigh``
        nearest target faces (nearest by face centroid) and moved towards the
        closest candidate. The result is stored in ``self.reg`` and the
        displacement from the baseline in ``self.disp``.

        Parameters
        ----------
        steps: int, default 1
            Number of iterations. The loop counter ``step`` counts down from
            ``steps`` to 1 and each iteration adds 1/step of the current offset
            to the displacement, so the final iteration closes the gap.
        neigh: int, default 10
            Number of nearest neighbours to interrogate for each baseline point
        inside: bool, default True
            If True, a barycentric check is made to ensure the registered
            point lies within the target triangle
        subset: array_like, default None
            Indices of the baseline nodes to include in the registration.
            NOTE: currently not used; all nodes are registered.
        scale: float, default None
            If not None scale the baseline mesh to match the target mesh in the z-direction,
            the value of scale will be used as a plane from which the nodes are scaled.
            Nodes with a higher z value will not be scaled.
        smooth: int, default 1
            Number of Laplacian smooth steps applied to the displacement
            between iterations (not after the final iteration)
        fixBrim: bool, default False
            If True, the nodes on the brim line will not be included in the smooth
        error: str, default 'norm'
            Distance method passed to ``calcError`` ('norm', 'abs' or 'cent')
            used to fill ``self.reg.values``
        """
        # Calc FaceCentroids
        fC = self.t.vert[self.t.faces].mean(axis=1)
        # Construct knn tree
        tTree = spatial.cKDTree(fC)
        bData = dict(
            zip(['vert', 'faces', 'values'],
                [self.b.vert, self.b.faces, self.b.values]))
        regData = copy.deepcopy(bData)
        self.reg = AmpObject(regData, stype='reg')
        # Displacement field: reg.vert is always b.vert + disp.vert
        self.disp = AmpObject({
            'vert': np.zeros(self.reg.vert.shape),
            'faces': self.reg.faces,
            'values': self.reg.values
        })
        if scale is not None:
            # Linearly stretch z below the ``scale`` plane so that the lowest
            # registered z lands on the lowest target z
            tmin = self.t.vert.min(axis=0)[2]
            rmin = self.reg.vert.min(axis=0)[2]
            SF = ((tmin - scale) / (rmin - scale)) - 1
            logic = self.reg.vert[:, 2] < scale
            d = (self.reg.vert[logic, 2] - scale) * SF
            self.disp.vert[logic, 2] += d
            self.reg.vert = self.b.vert + self.disp.vert
        # Un-normalised target face normals and their squared magnitudes
        normals = np.cross(
            self.t.vert[self.t.faces[:, 1]] - self.t.vert[self.t.faces[:, 0]],
            self.t.vert[self.t.faces[:, 2]] - self.t.vert[self.t.faces[:, 0]])
        mag = (normals**2).sum(axis=1)
        for step in np.arange(steps, 0, -1, dtype=float):
            # Index of the ``neigh`` centroids nearest to each registered vertex
            ind = tTree.query(self.reg.vert, neigh)[1]
            if ind.ndim == 1:
                # cKDTree.query drops the neighbour axis when k == 1; restore it
                # so the (vertex, neighbour, xyz) indexing below works
                ind = ind[:, None]
            # Define normals for faces of nearest faces
            norms = normals[ind]
            # Get a point on each face
            fPoints = self.t.vert[self.t.faces[ind, 0]]
            # Calculate dot product between point on face and normals
            d = np.einsum('ijk, ijk->ij', norms, fPoints)
            # Signed parameter along each normal from the vertex to the face plane
            t = (d - np.einsum('ijk, ik->ij', norms, self.reg.vert)) / mag[ind]
            # Candidate positions: the vertex projected onto each face plane
            G = self.reg.vert[:, None, :] + np.einsum('ijk, ij->ijk', norms, t)
            # Turn candidates into offsets from the current vertex and pick the
            # shortest; with ``inside`` candidates outside their face are first
            # moved to the nearest face vertex
            if inside is False:
                G = G - self.reg.vert[:, None, :]
                GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G))
                GInd = GMag.argmin(axis=1)
            else:
                G, GInd = self.__calcBarycentric(self.reg.vert, G, ind)
            # Offset from the current registered vertex to its chosen candidate
            D = G[np.arange(len(G)), GInd, :]
            # Move only 1/step of the way; step == 1 on the last iteration
            self.disp.vert += D / step
            if smooth > 0 and step > 1:
                self.disp.lp_smooth(smooth, brim=fixBrim)
                self.reg.vert = self.b.vert + self.disp.vert
            else:
                self.reg.vert = self.b.vert + self.disp.vert
                self.reg.calcNorm()
        self.reg.calcStruct()
        self.reg.values[:] = self.calcError(error)

    def calcError(self, method='norm'):
        r"""
        Calculate the magnitude of distances between the baseline and registered array

        Parameters
        ----------
        method: str, default 'norm'
            The method used to calculate the distances. 'abs' returns the absolute
            distance. 'cent' calculates polarity based upon distance from centroid.
            'norm' calculates dot product between baseline vertex normal and distance
            normal

        Returns
        -------
        values: array_like
            Magnitude of distances

        Raises
        ------
        ValueError
            If ``method`` does not name one of the distance methods
        """
        dist_func = None
        if isinstance(method, str):
            # getattr does not apply name mangling, so spell out the private name
            dist_func = getattr(self, '_registration__' + method + 'Dist', None)
        if dist_func is None:
            raise ValueError('"%s" is not a method, try "abs", "cent" or "norm"' %
                             (method,))
        # Errors raised by the distance method itself are not swallowed
        return dist_func()

    def __absDist(self):
        r"""
        Return the error based upon the absolute distance

        Returns
        -------
        values: array_like
            Magnitude of distances

        """
        return np.linalg.norm(self.reg.vert - self.b.vert, axis=1)

    def __centDist(self):
        r"""
        Return the error based upon distance from centroid

        The distance is negative where the registered vertex is closer to the
        baseline centroid than the baseline vertex is.

        Returns
        -------
        values: array_like
            Magnitude of distances

        """
        values = np.linalg.norm(self.reg.vert - self.b.vert, axis=1)
        cent = self.b.vert.mean(axis=0)
        r = np.linalg.norm(self.reg.vert - cent, axis=1)
        b = np.linalg.norm(self.b.vert - cent, axis=1)
        polarity = np.ones([self.reg.vert.shape[0]])
        polarity[r < b] = -1
        return values * polarity

    def __normDist(self):
        r"""
        Returns error based upon scalar product of normal

        The distance is negative where the displacement points against the
        baseline vertex normal.

        Returns
        -------
        values: array_like
            Magnitude of distances

        """
        self.b.calcVNorm()
        D = self.reg.vert - self.b.vert
        n = self.b.vNorm
        values = np.linalg.norm(D, axis=1)
        polarity = np.sum(n * D, axis=1) < 0
        values[polarity] *= -1.0
        return values

    def __calcBarycentric(self, vert, G, ind):
        r"""
        Calculate the barycentric co-ordinates of each target face and the registered vertex,
        this ensures that the registered vertex is within the bounds of the target face. If not
        the registered vertex is moved to the nearest vertex on the target face

        Parameters
        ----------
        vert: array_like
            The array of registered vertices, shape (N, 3)
        G: array_like
            The array of candidates for registered vertices, shape (N, neigh, 3);
            axis 1 corresponds to the nearest neighbours selected. Modified in place.
        ind: array_like
            The index of the nearest faces to the vertices, shape (N, neigh)

        Returns
        -------
        G: array_like
            The offsets from each vertex to its candidates; every candidate lies
            within (or on the boundary of) its target face
        GInd: array_like
            The index of the shortest offset for each vertex

        """
        P0 = self.t.vert[self.t.faces[ind, 0]]
        P1 = self.t.vert[self.t.faces[ind, 1]]
        P2 = self.t.vert[self.t.faces[ind, 2]]

        v0 = P2 - P0
        v1 = P1 - P0
        v2 = G - P0

        d00 = np.einsum('ijk, ijk->ij', v0, v0)
        d01 = np.einsum('ijk, ijk->ij', v0, v1)
        d02 = np.einsum('ijk, ijk->ij', v0, v2)
        d11 = np.einsum('ijk, ijk->ij', v1, v1)
        d12 = np.einsum('ijk, ijk->ij', v1, v2)

        denom = d00 * d11 - d01 * d01
        u = (d11 * d02 - d01 * d12) / denom
        v = (d00 * d12 - d01 * d02) / denom
        # Test if inside; all three bounds are inclusive so points on the
        # P1-P2 edge are not snapped to a vertex
        logic = (u >= 0) & (v >= 0) & (u + v <= 1)

        # Nearest face vertex for every candidate, used for those outside
        P = np.stack([P0, P1, P2], axis=3)
        pg = G[:, :, :, None] - P
        pd = np.linalg.norm(pg, axis=2)
        pdx = pd.argmin(axis=2)
        i, j = np.meshgrid(np.arange(P.shape[0]), np.arange(P.shape[1]))
        nearP = P[i.T, j.T, :, pdx]
        G[~logic, :] = nearP[~logic, :]
        G = G - vert[:, None, :]
        GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G))
        GInd = GMag.argmin(axis=1)
        return G, GInd

    def plotResults(self, name=None, xrange=None, color=None, alpha=None):
        r"""
        Function to generate a mpl figure containing a 50-bin density
        histogram of the registration values (``self.reg.values``), titled
        with their mean and standard deviation

        Parameters
        ----------
        name: str, default None
            If not None, the figure is saved to this path at 300 dpi
        xrange: tuple, default None
            Lower and upper range of the bins, passed to ``hist`` as ``range``
        color: default None
            Colour of the bars, passed to ``hist``
        alpha: float, default None
            Transparency of the bars, passed to ``hist``

        Returns
        -------
        ax: matplotlib Axes
            The axes containing the histogram
        n: array_like
            The density-normalised histogram values
        bins: array_like
            The edges of the histogram bins

        """
        fig, ax = plt.subplots(1)
        n, bins, _ = ax.hist(self.reg.values,
                             50,
                             density=True,
                             range=xrange,
                             color=color,
                             alpha=alpha)
        mean = self.reg.values.mean()
        stdev = self.reg.values.std()
        # Both literals are raw so the LaTeX backslashes are not parsed as
        # (invalid) string escape sequences
        ax.set_title(r'Distribution of shape variance, '
                     r'$\mu=%.2f$, $\sigma=%.2f$' % (mean, stdev))
        ax.set_xlim(None)
        if name is not None:
            # Save this figure explicitly rather than whichever is current
            fig.savefig(name, dpi=300)
        return ax, n, bins
예제 #9
0
class TestTrim(unittest.TestCase):
    """Tests for the planarTrim method of AmpObject."""

    def setUp(self):
        """Runs before each unit test
        Loads "stl_file.stl" into self.amp as an AmpObject
        """
        from AmpScan.core import AmpObject
        self.amp = AmpObject(get_path("stl_file.stl"))

    def test_trim(self):
        """Tests the trim method of AmpObject for TypeErrors"""
        # The method should run with valid arguments
        self.amp.planarTrim(0.6, plane=2)

        # Each invalid (height, plane) combination must raise TypeError
        invalid_calls = ((0.6, []), (0.6, 0.9), ([], []))
        for height, plane in invalid_calls:
            with self.assertRaises(TypeError):
                self.amp.planarTrim(height, plane=plane)

    def test_trim_2(self):
        """Tests the trim method of AmpObject by checking no vertices are above trim line"""
        # Trim successively at z=10 then z=0; after each trim no vertex may
        # lie above the trim height
        for height in (10, 0):
            self.amp.planarTrim(height, plane=2)
            self.assertLessEqual(self.amp.vert[:, 2].max(), height)
예제 #10
0
class align(object):
    r"""
    Automated alignment methods between two meshes
    
    Parameters
    ----------
    moving: AmpObject
        The moving AmpObject that is to be aligned to the static object
    static: AmpObject
        The static AmpObject that the moving AmpObject will be aligned to
    method: str, default 'linPoint2Plane'
        A string of the method used for alignment
    *args:
        Extra positional arguments forwarded to runICP
    **kwargs:
        Extra keyword arguments forwarded to runICP (e.g. maxiter, inlier,
        initTransform); those not consumed by runICP are passed on to the
        registration method

    Returns
    -------
    m: AmpObject
        The aligned AmpObject, with the same number of vertices and the same
        face array as the moving AmpObject
        Access this using align.m

    Examples
    --------
    >>> static = AmpObject(staticfh)
    >>> moving = AmpObject(movingfh)
    >>> al = align(moving, static).m

    """
    def __init__(self,
                 moving,
                 static,
                 method='linPoint2Plane',
                 *args,
                 **kwargs):
        """Copy the moving mesh, keep a reference to the static mesh and run
        the ICP alignment; see the class docstring for the parameters."""
        # Deep-copy the moving mesh data so the transformations applied to
        # self.m do not touch the arrays of `moving`
        mData = {'vert': moving.vert,
                 'faces': moving.faces,
                 'values': moving.values}
        self.m = AmpObject(copy.deepcopy(mData), stype='reg')
        self.s = static
        # Pass `method` positionally: combining `method=method` with a
        # non-empty *args bound the first extra positional to `method` as
        # well, raising "got multiple values for argument 'method'".
        self.runICP(method, *args, **kwargs)

    def runICP(self,
               method='linPoint2Plane',
               maxiter=20,
               inlier=1.0,
               initTransform=None,
               *args,
               **kwargs):
        r"""
        The function to run the ICP algorithm, this function calls one of 
        multiple methods to calculate the affine transformation 
        
        Parameters
        ----------
        method: str, default 'linPoint2Plane'
            A string of the method used for alignment
        maxiter: int, default 20
            Maximum number of iterations to run the ICP algorithm
        inlier: float, default 1.0
            The proportion of closest points to use to calculate the 
            transformation, if < 1 then vertices with highest error are 
            discounted
        *args:
        	The arguments used for the registration methods
        **kwargs:
        	The keyword arguments used for the registration methods
        
        """
        # Define the rotation, translation, error and quaterion arrays
        Rs = np.zeros([3, 3, maxiter + 1])
        Ts = np.zeros([3, maxiter + 1])
        #        qs = np.r_[np.ones([1, maxiter+1]),
        #                   np.zeros([6, maxiter+1])]
        #        dq  = np.zeros([7, maxiter+1])
        #        dTheta = np.zeros([maxiter+1])
        err = np.zeros([maxiter + 1])
        if initTransform is None:
            initTransform = np.eye(4)
        Rs[:, :, 0] = initTransform[:3, :3]
        Ts[:, 0] = initTransform[3, :3]
        #        qs[:4, 0] = self.rot2quat(Rs[:, :, 0])
        #        qs[4:, 0] = Ts[:, 0]
        # Define
        fC = self.s.vert[self.s.faces].mean(axis=1)
        kdTree = spatial.cKDTree(fC)
        self.m.rigidTransform(Rs[:, :, 0], Ts[:, 0])
        inlier = math.ceil(self.m.vert.shape[0] * inlier)
        [dist, idx] = kdTree.query(self.m.vert, 1)
        # Sort by distance
        sort = np.argsort(dist)
        # Keep only those within the inlier fraction
        [dist, idx] = [dist[sort], idx[sort]]
        [dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier]
        err[0] = math.sqrt(dist.mean())
        for i in range(maxiter):
            if method == 'linPoint2Point':
                [R, T] = getattr(self, method)(self.m.vert[sort, :],
                                               fC[idx, :], *args, **kwargs)
            elif method == 'linPoint2Plane':
                [R, T] = getattr(self,
                                 method)(self.m.vert[sort, :], fC[idx, :],
                                         self.s.norm[idx, :], *args, **kwargs)
            elif method == 'optPoint2Point':
                [R, T] = getattr(self, method)(self.m.vert[sort, :],
                                               fC[idx, :], *args, **kwargs)
            else:
                KeyError('Not a supported alignment method')
            Rs[:, :, i + 1] = np.dot(R, Rs[:, :, i])
            Ts[:, i + 1] = np.dot(R, Ts[:, i]) + T
            self.m.rigidTransform(R, T)
            [dist, idx] = kdTree.query(self.m.vert, 1)
            sort = np.argsort(dist)
            [dist, idx] = [dist[sort], idx[sort]]
            [dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier]
            err[i + 1] = math.sqrt(dist.mean())
#            qs[:, i+1] = np.r_[self.rot2quat(R), T]
        R = Rs[:, :, -1]
        #Simpl
        [U, s, V] = np.linalg.svd(R)
        R = np.dot(U, V)
        self.tForm = np.r_[np.c_[R, np.zeros(3)],
                           np.append(Ts[:, -1], 1)[:, None].T]
        self.R = R
        self.T = Ts[:, -1]
        self.rmse = err[-1]

    @staticmethod
    def linPoint2Plane(mv, sv, sn):
        r"""
        Iterative Closest Point algorithm which relies on using least squares
        method from converting the minimisation problem into a set of linear 
        equations. This uses a 
        
        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved 
        sv: ndarray
            The array of static vertices, these are the face centroids of the 
            static mesh
        sn: ndarray
            The normals of the point in teh static array, these are derived 
            from the normals of the faces for each centroid
        
        Returns
        -------
        R: ndarray
            The optimal rotation array 
        T: ndarray
            The optimal translation array
        
        References
        ----------
        .. [1] Besl, Paul J.; N.D. McKay (1992). "A Method for Registration of 3-D
           Shapes". IEEE Trans. on Pattern Analysis and Machine Intelligence (Los
           Alamitos, CA, USA: IEEE Computer Society) 14 (2): 239-256.
        
        .. [2] Chen, Yang; Gerard Medioni (1991). "Object modelling by registration of
           multiple range images". Image Vision Comput. (Newton, MA, USA:
           Butterworth-Heinemann): 145-155

        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='linPoint2Plane').m
        
        """
        cn = np.c_[np.cross(mv, sn), sn]
        C = np.dot(cn.T, cn)
        v = sv - mv
        b = np.zeros([6])
        for i, col in enumerate(cn.T):
            b[i] = (v * np.repeat(col[:, None], 3, axis=1) * sn).sum()
        X = np.linalg.lstsq(C, b, rcond=None)[0]
        [cx, cy, cz] = np.cos(X[:3])
        [sx, sy, sz] = np.sin(X[:3])
        R = np.array(
            [[cy * cz, sx * sy * cz - cx * sz, cx * sy * cz + sx * sz],
             [cy * sz, cx * cz + sx * sy * sz, cx * sy * sz - sx * cz],
             [-sy, sx * cy, cx * cy]])
        T = X[3:]
        return (R, T)

    @staticmethod
    def linPoint2Point(mv, sv):
        r"""
        Point-to-Point Iterative Closest Point algorithm which 
        relies on using singular value decomposition on the centered arrays.  
        
        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved 
        sv: ndarray
            The array of static vertices, these are the face centroids of the 
            static mesh
        
        Returns
        -------
        R: ndarray
            The optimal rotation array 
        T: ndarray
            The optimal translation array
        
        References
        ----------
        .. [1] Besl, Paul J.; N.D. McKay (1992). "A Method for Registration of 3-D
           Shapes". IEEE Trans. on Pattern Analysis and Machine Intelligence (Los
           Alamitos, CA, USA: IEEE Computer Society) 14 (2): 239-256.
        
        .. [2] Chen, Yang; Gerard Medioni (1991). "Object modelling by registration of
           multiple range images". Image Vision Comput. (Newton, MA, USA:
           Butterworth-Heinemann): 145-155

        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='linPoint2Point').m

        """
        mCent = mv - mv.mean(axis=0)
        sCent = sv - sv.mean(axis=0)
        C = np.dot(mCent.T, sCent)
        [U, _, V] = np.linalg.svd(C)
        det = np.linalg.det(np.dot(U, V))
        sign = np.eye(3)
        sign[2, 2] = np.sign(det)
        R = np.dot(V.T, sign)
        R = np.dot(R, U.T)
        T = sv.mean(axis=0) - np.dot(R, mv.mean(axis=0))
        return (R, T)

    @staticmethod
    def optPoint2Point(mv, sv, opt='L-BFGS-B'):
        r"""
        Direct minimisation of the rmse between the points of the two meshes. This 
        method enables access to all of Scipy's minimisation algorithms 
        
        Parameters
        ----------
        mv: ndarray
            The array of vertices to be moved 
        sv: ndarray
            The array of static vertices, these are the face centroids of the 
            static mesh
        opt: str, default 'L_BFGS-B'
            The string of the scipy optimiser to use 
        
        Returns
        -------
        R: ndarray
            The optimal rotation array 
        T: ndarray
            The optimal translation array
            
        Examples
        --------
        >>> static = AmpObject(staticfh)
        >>> moving = AmpObject(movingfh)
        >>> al = align(moving, static, method='optPoint2Point', opt='SLSQP').m
            
        """
        X = np.zeros(6)
        lim = [-np.pi / 4, np.pi / 4] * 3 + [-5, 5] * 3
        lim = np.reshape(lim, [6, 2])
        try:
            X = minimize(align.optDistError,
                         X,
                         args=(mv, sv),
                         bounds=lim,
                         method=opt)
        except:
            X = minimize(align.optDistError, X, args=(mv, sv), method=opt)
        [angx, angy, angz] = X.x[:3]
        Rx = np.array([[1, 0, 0], [0, np.cos(angx), -np.sin(angx)],
                       [0, np.sin(angx), np.cos(angx)]])
        Ry = np.array([[np.cos(angy), 0, np.sin(angy)], [0, 1, 0],
                       [-np.sin(angy), 0, np.cos(angy)]])
        Rz = np.array([[np.cos(angz), -np.sin(angz), 0],
                       [np.sin(angz), np.cos(angz), 0], [0, 0, 1]])
        R = np.dot(np.dot(Rz, Ry), Rx)
        T = X.x[3:]
        return (R, T)

    @staticmethod
    def optDistError(X, mv, sv):
        r"""
        The function to minimise. It performs the affine transformation then returns 
        the rmse between the two vertex sets
        
        Parameters
        ----------
        X:  ndarray
            The affine transformation corresponding to [Rx, Ry, Rz, Tx, Ty, Tz]
        mv: ndarray
            The array of vertices to be moved 
        sv: ndarray
            The array of static vertices, these are the face centroids of the 
            static mesh

        Returns
        -------
        err: float
            The RMSE between the two meshes
        
        """
        [angx, angy, angz] = X[:3]
        Rx = np.array([[1, 0, 0], [0, np.cos(angx), -np.sin(angx)],
                       [0, np.sin(angx), np.cos(angx)]])
        Ry = np.array([[np.cos(angy), 0, np.sin(angy)], [0, 1, 0],
                       [-np.sin(angy), 0, np.cos(angy)]])
        Rz = np.array([[np.cos(angz), -np.sin(angz), 0],
                       [np.sin(angz), np.cos(angz), 0], [0, 0, 1]])
        R = np.dot(np.dot(Rz, Ry), Rx)
        moved = np.dot(mv, R.T)
        moved += X[3:]
        dist = (moved - sv)**2
        dist = dist.sum(axis=1)
        err = np.sqrt(dist.mean())
        return err

    @staticmethod
    def rot2quat(R):
        """
        Convert a rotation matrix to a quaternionic matrix
        
        Parameters
        ----------
        R: array_like
            The 3x3 rotation array to be converted to a quaternionic matrix
        
        Returns
        -------
        Q: ndarray
            The quaternionic matrix

        """
        [[Qxx, Qxy, Qxz], [Qyx, Qyy, Qyz], [Qzx, Qzy, Qzz]] = R
        t = Qxx + Qyy + Qzz
        if t >= 0:
            r = math.sqrt(1 + t)
            s = 0.5 / r
            w = 0.5 * r
            x = (Qzy - Qyz) * s
            y = (Qxz - Qzx) * s
            z = (Qyx - Qxy) * s
        else:
            maxv = max([Qxx, Qyy, Qzz])
            if maxv == Qxx:
                r = math.sqrt(1 + Qxx - Qyy - Qzz)
                s = 0.5 / r
                w = (Qzy - Qyz) * s
                x = 0.5 * r
                y = (Qyx + Qxy) * s
                z = (Qxz + Qzx) * s
            elif maxv == Qyy:
                r = math.sqrt(1 + Qyy - Qxx - Qzz)
                s = 0.5 / r
                w = (Qxz - Qzx) * s
                x = (Qyx + Qxy) * s
                y = 0.5 * r
                z = (Qzy + Qyz) * s
            else:
                r = math.sqrt(1 + Qzz - Qxx - Qyy)
                s = 0.5 / r
                w = (Qyx - Qxy) * s
                x = (Qxz + Qzx) * s
                y = (Qzy + Qyz) * s
                z = 0.5 * r
        return np.array([w, x, y, z])

    def display(self):
        r"""
        Render the static mesh (red) and the aligned mesh (blue), both at
        0.5 opacity, in a VTK render window with a trackball-camera
        interactor, and return that window.

        Returns
        -------
        win: vtkRenWin
            The render window holding both actors

        """
        # Ensure both meshes have an actor to draw (static first, then moving)
        for mesh in (self.s, self.m):
            if not hasattr(mesh, 'actor'):
                mesh.addActor()
        # Single-viewport window with background colour [1, 1, 1]
        win = vtkRenWin()
        win.setnumViewports(1)
        win.setBackground([1, 1, 1])
        # Attach a trackball-camera interactor to the window
        interactor = vtk.vtkRenderWindowInteractor()
        interactor.SetRenderWindow(win)
        interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())
        win.setView()
        # Colour and opacity for each actor
        styling = ((self.s, [1.0, 0.0, 0.0]), (self.m, [0.0, 0.0, 1.0]))
        for mesh, colour in styling:
            mesh.actor.setColor(colour)
            mesh.actor.setOpacity(0.5)
        win.renderActors([self.s.actor, self.m.actor])
        win.Render()
        # Turn the camera by 180 degrees (azimuth), switch to parallel
        # projection, then render again
        camera = win.rens[0].GetActiveCamera()
        camera.Azimuth(180)
        camera.SetParallelProjection(True)
        win.Render()
        return win