Esempio n. 1
0
def taper_edges(image, width):
    """
    taper edges of image (or volume) with cos function

    Every axis gets a 1D profile that is 1 in the interior and falls off as a
    quarter cosine over the outermost `width` samples (reaching cos(pi/2) ~ 0
    at the border). The mask is the element-wise minimum of the per-axis
    profiles, built with 'ij' indexing so it has the same shape as `image`.

    @param image: input image (2D) or volume (3D)
    @type image: ndarray
    @param width: width of edge in pixels; values <= 0 disable tapering
    @type width: int

    @return: image with smoothened edges, taper_mask
    @rtype: tuple of array-like

    @author: GvdS
    """
    width = int(width)

    # cos(k * pi / (2 * width)) for k = 1..width, decreasing from ~1 to 0.
    # Skipped for width <= 0, which previously divided by zero.
    ramp = (xp.cos(xp.arange(1, width + 1) * xp.pi / (2. * width))
            if width > 0 else None)

    def _profile(n):
        # 1D taper of length n; float32 on every axis (previously only x was).
        taper = xp.ones((n,), dtype=xp.float32)
        if width > 0:
            taper[:width] = ramp[::-1]
            taper[-width:] = ramp
        return taper

    shape = image.shape
    if len(shape) > 2 and shape[2] > 1:
        # 'ij' indexing keeps the mask in (nx, ny, nz) order, so non-cubic
        # volumes broadcast correctly against the image.
        X, Y, Z = xp.meshgrid(_profile(shape[0]), _profile(shape[1]),
                              _profile(shape[2]), indexing='ij')
        taper_mask = xp.minimum(xp.minimum(X, Y), Z)
    else:
        X, Y = xp.meshgrid(_profile(shape[0]), _profile(shape[1]),
                           indexing='ij')
        taper_mask = xp.minimum(X, Y)
        if len(shape) > 2:
            # (nx, ny, 1) input: keep the singleton axis so the product does
            # not broadcast into an (nx, nx, ny)-like array.
            taper_mask = taper_mask[:, :, None]

    return image * taper_mask, taper_mask
Esempio n. 2
0
def maxIndex(volume, num_threads=1024):
    """
    Return the N-d index of the maximum of `volume`, using the `argmax` kernel.

    The kernel is launched on a (n_blocks, 1) grid with (num_threads, 1, 1)
    threads per block and 16 * num_threads bytes of shared memory. It fills
    one value and one flat index per block (presumably each block's local
    maximum -- confirm against the kernel source). The flat index belonging
    to the largest value is then clamped to volume.size - 1 and unravelled.

    @param volume: array to search
    @param num_threads: threads per block for the kernel launch
    @return: tuple of indices as returned by xp.unravel_index
    """
    n_voxels = volume.size
    n_blocks = int(xp.ceil(n_voxels / num_threads / 2))

    # per-block result buffers, values pre-filled with a large negative sentinel
    block_max = xp.full((n_blocks,), -1000000, dtype=xp.float32)
    block_idx = xp.zeros((n_blocks,), dtype=xp.int32)

    grid = (n_blocks, 1)
    block = (num_threads, 1, 1)
    kernel_args = (volume, block_max, block_idx, n_voxels)
    argmax(grid, block, kernel_args, shared_mem=16 * num_threads)

    winner = block_idx[block_max.argmax()]
    # clamp so an index past the end can never reach unravel_index
    flat_index = min(winner, n_voxels - 1)
    return xp.unravel_index(flat_index, volume.shape)
Esempio n. 3
0
    def initVolume(self, sizeX, sizeY, sizeZ):
        """
        initVolume: build the weight volume and store it in self._weight.

        If either self._radius or self._smooth is positive, the weight is
        create_sphere((sizeX, sizeY, sizeZ), self._radius, self._smooth);
        otherwise it is a float32 volume of ones with that shape.

        @param sizeX: size of the weight volume along the first axis
        @param sizeY: size of the weight volume along the second axis
        @param sizeZ: size of the weight volume along the third axis
        @return: None (the result is stored in self._weight)
        @author: Thomas Hrabe
        """
        from pytom.tompy.tools import create_sphere

        shape = (sizeX, sizeY, sizeZ)
        wants_sphere = self._radius > 0 or self._smooth > 0
        if not wants_sphere:
            self._weight = xp.ones(shape, dtype=xp.float32)
            return
        self._weight = create_sphere(shape, self._radius, self._smooth)
Esempio n. 4
0
def exact_filter(tilt_angles, tiltAngle, sX, sY, sliceWidth=1, arr=None):
    """
    exactFilter: Generates the exact weighting function required for weighted backprojection - y-axis is tilt axis
    Reference : Optik, Exact filters for general geometry three dimensional reconstruction, vol.73,146,1986.

    @param tilt_angles: list of all the tilt angles (degrees) in one tilt series
    @param tiltAngle: tilt angle (degrees) for which the exact weighting function is calculated
    @param sX: size of weighted image in X
    @param sY: size of weighted image in Y
    @param sliceWidth: slice width used for the Crowther frequency
        ceil(sliceWidth / sin(smallest angular distance to another tilt))
    @param arr: unused; kept only for backward compatibility of the signature

    @return: float32 numpy filter of shape (sX, sY); columns
        sX//2 - crowtherFreq ... hold the weights, all other entries are 1.
        NOTE(review): the weight columns are indexed with sX on the second axis
        and stacked sY times, so only square sizes (sX == sY) work -- callers
        in this file always pass equal sizes.
    """
    import numpy as xp

    # Calculate the relative angles in radians.
    diffAngles = (xp.array(tilt_angles) - tiltAngle) * xp.pi / 180.

    # Closest angle to tiltAngle (but not tiltAngle) sets the maximal frequency of overlap (Crowther's frequency).
    # Weights only need to be calculated up to this frequency.
    otherAngles = xp.abs(diffAngles)[xp.abs(diffAngles) > 0.001]
    if otherAngles.size:
        sampling = otherAngles.min()
        crowtherFreq = int(
            min(sX // 2, xp.int32(xp.ceil(sliceWidth / xp.sin(sampling)))))
    else:
        # No other projection overlaps: Crowther's frequency is unbounded,
        # so clamp it to Nyquist like the min() above would.
        crowtherFreq = sX // 2

    upper = min(sX // 2, crowtherFreq + 1)
    arrCrowther = xp.abs(xp.arange(-crowtherFreq, upper))

    # Outer product |sin(dAngle_i)| * |f_j|; sum the clipped overlap over all projections.
    overlap = xp.outer(xp.abs(xp.sin(diffAngles)), arrCrowther)
    wfuncCrowther = 1. / xp.clip(1 - overlap**2, 0, 2).sum(axis=0)

    # Create full with weightFunc
    wfunc = xp.ones((sX, sY), dtype=xp.float32)

    # sY identical rows of the 1D weights, shape (sY, len(wfuncCrowther))
    weightingFunc = xp.tile(wfuncCrowther, (sY, 1))

    wfunc[:, sX // 2 - crowtherFreq:sX // 2 + upper] = weightingFunc

    return wfunc
Esempio n. 5
0
def alignImagesUsingAlignmentResultFile(alignmentResultsFile,
                                        weighting=None,
                                        lowpassFilter=0.9,
                                        binning=1,
                                        circleFilter=False):
    """
    Create an aligned (and optionally weighted) projection stack from an alignment results file.

    Each projection listed in the file is read, binned by keeping every
    `binning`-th pixel, optionally low-pass filtered, normalised to contrast,
    edge-tapered, transformed (in-plane rotation, translation, magnification),
    tapered again and optionally weighted in Fourier space.

    @param alignmentResultsFile: alignment results file, read with loadstar(dtype=datatypeAR)
    @param weighting: None -> no weighting; < 0 -> ramp filter; > 0 -> exact filter per tilt angle
    @param lowpassFilter: low-pass cutoff as a fraction of Nyquist (values > 1 are set to 1);
        a falsy value disables the low-pass filter
    @param binning: integer binning factor
    @param circleFilter: if True, multiply the weighting with a circle filter of radius imdim // 2
    @return: list [stack, offsetStack, thetaStack, phiStack], each converted with npy2vol
    """
    import os
    import numpy as np
    from pytom.basic.files import read as readCVol
    from pytom_numpy import vol2npy, npy2vol
    from pytom.gui.guiFunctions import fmtAR, headerAlignmentResults, datatype, datatypeAR, loadstar
    from pytom.reconstruction.reconstructionStructures import Projection, ProjectionList
    from pytom.tompy.io import read, write, read_size
    from pytom.tompy.tools import taper_edges, create_circle
    from pytom.tompy.filter import circle_filter, ramp_filter, exact_filter
    import pytom.voltools as vt
    from pytom.gpu.initialize import xp, device
    print("Create aligned images from alignResults.txt")

    alignmentResults = loadstar(alignmentResultsFile, dtype=datatypeAR)
    imageList = alignmentResults['FileName']
    tilt_angles = alignmentResults['TiltAngle']

    # Only the x size of the first image is read: projections are assumed square.
    imdim = read_size(imageList[0], 'x')

    if binning > 1:
        # Binning below slices with [::binning], which yields ceil(imdim / binning)
        # pixels; all filters and the stack must use that same size.
        imdim = -(-int(imdim) // int(binning))

    sliceWidth = imdim

    # Parse the weighting mode once so that the filter that is prepared and the
    # filter that is applied are selected by the same thresholds.
    weightValue = None if weighting is None else float(weighting)
    useRamp = weightValue is not None and weightValue < 0
    useExact = weightValue is not None and weightValue > 0

    if useRamp:
        weightSlice = xp.fft.fftshift(ramp_filter(imdim, imdim))

    if circleFilter:
        circleFilterRadius = imdim // 2
        circleSlice = xp.fft.fftshift(
            circle_filter(imdim, imdim, circleFilterRadius))
    else:
        circleSlice = xp.ones((imdim, imdim))

    # design lowpass filter
    if lowpassFilter:
        if lowpassFilter > 1.:
            lowpassFilter = 1.
            print("Warning: lowpassFilter > 1 - set to 1 (=Nyquist)")

        # weighting filter: arguments: (dimx, dimy), cutoff radius, sigma
        lpf = xp.fft.fftshift(
            create_circle((imdim, imdim),
                          lowpassFilter * (imdim // 2),
                          sigma=0.4 * lowpassFilter * (imdim // 2)))

    projectionList = ProjectionList()
    for n, fileName in enumerate(imageList):
        # Translations are stored unbinned; they are divided by `binning`
        # exactly once when the transform is applied below.
        projection = Projection(fileName,
                                tiltAngle=tilt_angles[n],
                                alignmentTransX=alignmentResults['AlignmentTransX'][n],
                                alignmentTransY=alignmentResults['AlignmentTransY'][n],
                                alignmentRotation=alignmentResults['InPlaneRotation'][n],
                                alignmentMagnification=1 / alignmentResults['Magnification'][n])
        projectionList.append(projection)

    stack = xp.zeros((imdim, imdim, len(imageList)), dtype=xp.float32)
    phiStack = xp.zeros((1, 1, len(imageList)), dtype=xp.float32)
    thetaStack = xp.zeros((1, 1, len(imageList)), dtype=xp.float32)
    offsetStack = xp.zeros((1, 2, len(imageList)), dtype=xp.float32)

    for ii, projection in enumerate(projectionList):
        print(f'Align {projection._filename}')
        image = read(str(projection._filename))[::binning, ::binning].squeeze()

        if lowpassFilter:
            image = xp.abs(xp.fft.ifftn(xp.fft.fftn(image) * lpf))

        tiltAngle = projection._tiltAngle

        # normalize to contrast - subtract mean and norm to mean
        immean = image.mean()
        image = (image - immean) / immean

        # smoothen borders to prevent high contrast oscillations
        image = taper_edges(image, imdim // 30)[0]

        # transform projection according to tilt alignment
        transX = projection._alignmentTransX / binning
        transY = projection._alignmentTransY / binning
        rot = float(projection._alignmentRotation)
        mag = float(projection._alignmentMagnification)

        inputImage = xp.expand_dims(image, 2).copy()
        outputImage = xp.zeros_like(inputImage, dtype=xp.float32)

        vt.transform(inputImage.astype(xp.float32),
                     rotation=[0, 0, rot],
                     rotation_order='rxyz',
                     output=outputImage,
                     device=device,
                     translation=[transX, transY, 0],
                     scale=[mag, mag, 1],
                     interpolation='filt_bspline')

        image = outputImage.squeeze()

        # smoothen once more to avoid edges
        image = taper_edges(image, imdim // 30)[0]

        # analytical weighting; ifftn is complex, keep the real part explicitly
        # instead of relying on an implicit complex -> float32 cast into `stack`
        if useRamp:
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real

        elif useExact:
            weightSlice = xp.fft.fftshift(
                exact_filter(tilt_angles, tiltAngle, imdim, imdim, sliceWidth))
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real

        thetaStack[0, 0, ii] = int(round(projection.getTiltAngle()))
        offsetStack[0, :, ii] = xp.array([
            int(round(projection.getOffsetX())),
            int(round(projection.getOffsetY()))
        ])

        stack[:, :, ii] = image

    arrays = []
    for arr in (stack, offsetStack, thetaStack, phiStack):
        if 'gpu' in device:
            arr = arr.get()
        arrays.append(npy2vol(np.array(arr, dtype='float32', order='F'), 3))

    return arrays
Esempio n. 6
0
def alignImageUsingAlignmentResultFile(alignmentResultsFile,
                                       indexImage,
                                       weighting=None,
                                       lowpassFilter=0.9,
                                       binning=1,
                                       circleFilter=False):
    """
    Align (and optionally low-pass filter and weight) one projection of an alignment results file.

    Steps for the projection with index `indexImage`: read, resize by 1/binning,
    normalise to contrast, taper edges, apply the in-plane rotation /
    translation / magnification, optional low-pass, taper again, optional
    Fourier-space weighting.

    @param alignmentResultsFile: alignment results file, read with loadstar(dtype=datatypeAR)
    @param indexImage: index (in file order) of the projection to process
    @param weighting: None -> no weighting; < 0 -> ramp filter; > 0 -> exact filter for this tilt
        NOTE(review): the exact filter is built as (imdim, imdim) with imdim = min(x, y);
        non-square projections will not broadcast -- confirm intended behaviour.
    @param lowpassFilter: low-pass cutoff as a fraction of Nyquist (values > 1 are set to 1);
        a falsy value disables the low-pass filter
    @param binning: binning factor
    @param circleFilter: if True, multiply the weighting with an ellipse filter of half the image size
    @return: the processed image as float32, or None when no projection has index indexImage
    """
    import pytom_freqweight
    from pytom_numpy import vol2npy
    from pytom.gui.guiFunctions import fmtAR, headerAlignmentResults, datatype, datatypeAR, loadstar
    from pytom.reconstruction.reconstructionStructures import Projection, ProjectionList
    from pytom.tompy.io import read, write, read_size
    from pytom.tompy.tools import taper_edges, create_circle
    from pytom.tompy.filter import circle_filter, ramp_filter, exact_filter, ellipse_filter
    import pytom.voltools as vt
    from pytom.gpu.initialize import xp, device

    alignmentResults = loadstar(alignmentResultsFile, dtype=datatypeAR)
    imageList = alignmentResults['FileName']
    tilt_angles = alignmentResults['TiltAngle']

    imdimX = read_size(imageList[0], 'x')
    imdimY = read_size(imageList[0], 'y')

    if binning > 1:
        imdimX = int(float(imdimX) / float(binning) + .5)
        imdimY = int(float(imdimY) / float(binning) + .5)

    sliceWidth = imdimX

    # Parse the weighting mode once so that the filter that is prepared and the
    # filter that is applied are selected by the same thresholds.
    weightValue = None if weighting is None else float(weighting)
    useRamp = weightValue is not None and weightValue < 0
    useExact = weightValue is not None and weightValue > 0

    if useRamp:
        weightSlice = xp.fft.fftshift(ramp_filter(imdimY, imdimX))

    if circleFilter:
        circleFilterRadiusX = imdimX // 2
        circleFilterRadiusY = imdimY // 2

        circleSlice = xp.fft.fftshift(
            ellipse_filter(imdimX, imdimY, circleFilterRadiusX,
                           circleFilterRadiusY))
    else:
        circleSlice = xp.ones((imdimX, imdimY))

    # clamp the low-pass cutoff to Nyquist
    if lowpassFilter and lowpassFilter > 1.:
        lowpassFilter = 1.
        print("Warning: lowpassFilter > 1 - set to 1 (=Nyquist)")

    projectionList = ProjectionList()
    for n, fileName in enumerate(imageList):
        # Translations are stored unbinned and divided by `binning` once below.
        projection = Projection(fileName,
                                tiltAngle=tilt_angles[n],
                                alignmentTransX=alignmentResults['AlignmentTransX'][n],
                                alignmentTransY=alignmentResults['AlignmentTransY'][n],
                                alignmentRotation=alignmentResults['InPlaneRotation'][n],
                                alignmentMagnification=1 / (alignmentResults['Magnification'][n]))
        projectionList.append(projection)

    imdim = min(imdimY, imdimX)

    for ii, projection in enumerate(projectionList):
        if ii != indexImage:
            continue
        from pytom.tompy.transform import resize

        image = read(str(projection._filename)).squeeze()

        if binning > 1:
            image = resize(image, 1 / binning)

        tiltAngle = projection._tiltAngle

        # 1 -- normalize to contrast - subtract mean and norm to mean
        immean = image.mean()
        image = (image - immean) / immean

        # 2 -- smoothen borders to prevent high contrast oscillations
        image = taper_edges(image, imdim // 30)[0]

        # 3 -- transform projection according to tilt alignment
        transX = projection._alignmentTransX / binning
        transY = projection._alignmentTransY / binning
        rot = float(projection._alignmentRotation)
        mag = float(projection._alignmentMagnification)

        inputImage = xp.expand_dims(image, 2).copy()
        outputImage = xp.zeros_like(inputImage, dtype=xp.float32)

        vt.transform(
            inputImage.astype(xp.float32),
            rotation=[0, 0, rot],
            rotation_order='rxyz',
            output=outputImage,
            center=[inputImage.shape[0] // 2, inputImage.shape[1] // 2, 0],
            device=device,
            translation=[transX, transY, 0],
            scale=[mag, mag, 1],
            interpolation='filt_bspline')

        # release the pre-transform buffers early
        del image, inputImage
        image = outputImage.squeeze()

        # 4 -- Optional Low Pass Filter
        if lowpassFilter:
            from pytom.tompy.filter import bandpass_circle

            image = bandpass_circle(
                image,
                high=lowpassFilter * (min(imdimX, imdimY) // 2),
                sigma=0.4 * lowpassFilter * (min(imdimX, imdimY) // 2))

        # 5 -- smoothen once more to avoid edges
        image = taper_edges(image, imdim // 30)[0]

        # 6 -- analytical weighting
        if useRamp:
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice.T * circleSlice).real

        elif useExact:
            weightSlice = xp.fft.fftshift(
                exact_filter(tilt_angles, tiltAngle, imdim, imdim, sliceWidth))
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real

        return image.astype(xp.float32)

    # indexImage does not match any projection listed in the file
    return None