# Example #1
0
def load_tomo(fname, mmap=False):
    """
    Reads a tomogram stored as MRC, EM or VTI and returns its voxel data as
    a numpy array.

    Args:
        fname (str): path to the tomogram; its extension must be '.mrc',
            '.em' or '.vti'
        mmap (boolean, optional): when True (default False) the reader is
            asked for a memory-mapped array (numpy.memmap) instead of a
            numpy.ndarray, so the data are not loaded fully into memory; only
            worthwhile for very large tomograms and only accepted for the MRC
            and EM formats. CAUTION: numpy.memmap does not behave exactly like
            a numpy.ndarray subclass and interacts badly with some operations.

    Returns:
        numpy.ndarray or numpy.memmap: the voxel data. A 2D image is returned
        with an extra trailing axis of length 1, i.e. shape (nx, ny, 1).

    Raises:
        pexceptions.PySegInputError: if mmap is requested for a format other
            than MRC/EM, or if the extension is not one of the supported ones.
    """
    ext = os.path.splitext(fname)[1]

    # Memory mapping is only supported by the MRC and EM readers
    if mmap and ext not in ('.mrc', '.em'):
        raise pexceptions.PySegInputError(
            expr='load_tomo',
            msg='mmap option is only valid for MRC or EM formats, current '
                + ext)

    if ext in ('.mrc', '.em'):
        image = ImageIO()
        read = image.readMRC if ext == '.mrc' else image.readEM
        if mmap:
            read(fname, memmap=mmap)
        else:
            read(fname)
        data = image.data
    elif ext == '.vti':
        vti_reader = vtk.vtkXMLImageDataReader()
        vti_reader.SetFileName(fname)
        vti_reader.Update()
        data = vti_to_numpy(vti_reader.GetOutput())
    else:
        raise pexceptions.PySegInputError(
            expr='load_tomo', msg='%s is non valid format.' % ext)

    # Promote 2D images to single-slice 3D volumes so callers always get 3 axes
    if len(data.shape) == 2:
        rows, cols = data.shape
        data = np.reshape(data, (rows, cols, 1))

    return data
# Example #2
0
def gen_surface(tomo, lbl=1, mask=True, purge_ratio=1, field=False,
                mode_2d=False, verbose=False):
    """
    Generates a VTK PolyData surface from a segmented tomogram.

    The voxels equal to ``lbl`` are collected into a point cloud, a surface is
    reconstructed from it (vtkSurfaceReconstructionFilter + vtkContourFilter
    at value 0) and the result is translated/scaled onto the bounding box of
    the point cloud. Optionally, polygons far from the segmentation are
    removed and a signed (polarity) distance field is computed.

    Args:
        tomo (numpy.ndarray or str): the input segmentation as a 3D numpy
            ndarray or the file name in MRC, EM or VTI format
        lbl (int, optional): label for the foreground, default 1
        mask (boolean, optional): if True (default), every cell having at
            least one vertex farther than MAX_DIST_SURF (Euclidean distance
            transform, vertex position rounded to the nearest voxel) from the
            ``lbl`` voxels is removed from the surface
        purge_ratio (int, optional): if greater than 1 (default 1), the point
            cloud is randomly subsampled before reconstruction: each
            foreground voxel is kept with probability 1/(purge_ratio+1). The
            value is also used as the sample spacing of the reconstruction
            filter.
        field (boolean, optional): if True (default False), additionally
            returns the polarity distance scalar field
        mode_2d (boolean, optional): only used when field is True; if True
            (default False), the z components of the normals and positions are
            set to 0 when computing the polarity sign
        verbose (boolean, optional): if True (default False), prints out
            messages for checking the progress

    Returns:
        - output surface (vtk.vtkPolyData)
        - polarity distance scalar field (np.ndarray), if field is True

    Raises:
        pexceptions.PySegInputError: if tomo is a file name with an
            unsupported extension, or neither a str nor a numpy.ndarray
    """
    # Check input format
    if isinstance(tomo, str):
        fext = os.path.splitext(tomo)[1]
        if fext == '.mrc':
            hold = ImageIO()
            hold.readMRC(file=tomo)
            tomo = hold.data
        elif fext == '.em':
            hold = ImageIO()
            hold.readEM(file=tomo)
            tomo = hold.data
        elif fext == '.vti':
            reader = vtk.vtkXMLImageDataReader()
            reader.SetFileName(tomo)
            reader.Update()
            tomo = vti_to_numpy(reader.GetOutput())
        else:
            error_msg = 'Format %s not readable.' % fext
            raise pexceptions.PySegInputError(expr='gen_surface', msg=error_msg)
    elif not isinstance(tomo, np.ndarray):
        error_msg = 'Input must be either a file name or a ndarray.'
        raise pexceptions.PySegInputError(expr='gen_surface', msg=error_msg)

    # Load file with the cloud of points
    nx, ny, nz = tomo.shape
    seg = (tomo == lbl)  # foreground voxels, reused by the mask/field steps
    cloud = vtk.vtkPolyData()
    points = vtk.vtkPoints()
    cloud.SetPoints(points)

    keep = seg
    if purge_ratio > 1:
        # np.random.randint's upper bound is exclusive, so values are drawn
        # from [0, purge_ratio] and a voxel survives with probability
        # 1/(purge_ratio+1). One draw per voxel, in (x, y, z) C order.
        mx_value = purge_ratio - 1
        purge = np.random.randint(0, purge_ratio + 1, nx * ny * nz)
        keep = seg & (purge.reshape(nx, ny, nz) == mx_value)
    # argwhere returns indices in row-major order, i.e. the same order as
    # nested x, y, z loops; tolist() yields plain Python ints for VTK
    for x, y, z in np.argwhere(keep).tolist():
        points.InsertNextPoint(x, y, z)

    if verbose:
        print('Cloud of points loaded...')

    # Creating the isosurface
    surf = vtk.vtkSurfaceReconstructionFilter()
    surf.SetSampleSpacing(purge_ratio)
    surf.SetInputData(cloud)
    contf = vtk.vtkContourFilter()
    contf.SetInputConnection(surf.GetOutputPort())
    contf.SetValue(0, 0)

    # Sometimes the contouring algorithm can create a volume whose gradient
    # vector and ordering of polygon (using the right hand rule) are
    # inconsistent. vtkReverseSense cures this problem.
    reverse = vtk.vtkReverseSense()
    reverse.SetInputConnection(contf.GetOutputPort())
    reverse.ReverseCellsOn()
    reverse.ReverseNormalsOn()
    reverse.Update()
    rsurf = reverse.GetOutput()

    if verbose:
        print('Isosurfaces generated...')

    # Translate and scale the reconstructed surface onto the cloud bounds
    cloud.ComputeBounds()
    rsurf.ComputeBounds()
    xmin, xmax, ymin, ymax, zmin, zmax = cloud.GetBounds()
    rxmin, rxmax, rymin, rymax, rzmin, rzmax = rsurf.GetBounds()
    scale_x = (xmax - xmin) / (rxmax - rxmin)
    scale_y = (ymax - ymin) / (rymax - rymin)
    # Guard the z axis (presumably for flat, single-slice clouds): avoid a
    # division by zero and avoid collapsing the surface with a zero scale.
    num = zmax - zmin
    denom = rzmax - rzmin
    if (denom == 0) or (num == 0):
        scale_z = 1
    else:
        scale_z = num / denom
    transp = vtk.vtkTransform()
    transp.Translate(xmin, ymin, zmin)
    transp.Scale(scale_x, scale_y, scale_z)
    transp.Translate(-rxmin, -rymin, -rzmin)
    tpd = vtk.vtkTransformPolyDataFilter()
    tpd.SetInputData(rsurf)
    tpd.SetTransform(transp)
    tpd.Update()
    tsurf = tpd.GetOutput()

    if verbose:
        print('Rescaled and translated...')

    # Masking according to distance to the original segmentation
    if mask:
        tomod = scipy.ndimage.distance_transform_edt(np.invert(seg))
        for i in range(tsurf.GetNumberOfCells()):
            # Delete the polygon as soon as one of its vertices lies too far
            # from the segmentation (MAX_DIST_SURF is defined elsewhere)
            points_cell = tsurf.GetCell(i).GetPoints()
            for j in range(points_cell.GetNumberOfPoints()):
                x, y, z = points_cell.GetPoint(j)
                if (tomod[int(round(x)), int(round(y)), int(round(z))]
                        > MAX_DIST_SURF):
                    tsurf.DeleteCell(i)
                    break

        # Physically drop the cells marked as deleted
        tsurf.RemoveDeletedCells()

        if verbose:
            print('Mask applied...')

    # Field distance
    if field:

        # Get per-cell normal attributes
        norm_flt = vtk.vtkPolyDataNormals()
        norm_flt.SetInputData(tsurf)
        norm_flt.ComputeCellNormalsOn()
        norm_flt.AutoOrientNormalsOn()
        norm_flt.ConsistencyOn()
        norm_flt.Update()
        tsurf = norm_flt.GetOutput()
        array = tsurf.GetCellData().GetNormals()

        # Build membrane mask: surface voxels that are also foreground become
        # False in tomoh (the seeds of the distance transform) and store the
        # normal of the cell touching them in tomon
        tomoh = np.ones(shape=tomo.shape, dtype=bool)
        tomon = np.ones(shape=(nx, ny, nz, 3),
                        dtype=TypesConverter().vtk_to_numpy(array))
        for i in range(tsurf.GetNumberOfCells()):
            points_cell = tsurf.GetCell(i).GetPoints()
            for j in range(points_cell.GetNumberOfPoints()):
                x, y, z = points_cell.GetPoint(j)
                x, y, z = int(round(x)), int(round(y)), int(round(z))
                if seg[x, y, z]:
                    tomoh[x, y, z] = False
                    tomon[x, y, z, :] = array.GetTuple(i)

        # Distance to, and index of, the nearest surface voxel
        tomod, ids = scipy.ndimage.distance_transform_edt(
            tomoh, return_indices=True)

        # Compute polarity: sign the distance with dot_norm (defined
        # elsewhere). NOTE(review): the 2D branch passes (p, pnorm) while the
        # 3D branch passes (pnorm, p); confirm the sign conventions are meant
        # to differ.
        if mode_2d:
            for x, y, z in np.ndindex(nx, ny, nz):
                i_x, i_y, i_z = ids[0, x, y, z], ids[1, x, y, z], ids[2, x, y, z]
                # Work on a copy so the stored normals are not modified
                norm = np.array(tomon[i_x, i_y, i_z], dtype=float)
                norm[2] = 0
                dprod = dot_norm(np.asarray((x, y, 0), dtype=float),
                                 np.asarray((i_x, i_y, 0), dtype=float),
                                 norm)
                tomod[x, y, z] *= np.sign(dprod)
        else:
            for x, y, z in np.ndindex(nx, ny, nz):
                i_x, i_y, i_z = ids[0, x, y, z], ids[1, x, y, z], ids[2, x, y, z]
                dprod = dot_norm(np.asarray((i_x, i_y, i_z), dtype=float),
                                 np.asarray((x, y, z), dtype=float),
                                 np.asarray(tomon[i_x, i_y, i_z], dtype=float))
                tomod[x, y, z] *= np.sign(dprod)

        if verbose:
            print('Distance field generated...')

        return tsurf, tomod

    if verbose:
        print('Finished!')

    return tsurf