Exemplo n.º 1
0
def back_projection(image, angle, interpolation=None):
    '''
    back_projection: Smear a 2D projection along a third axis and rotate the
    resulting volume by the tilt angle.

    The image is stacked (3 * dimx) // 2 times along a new last axis, the
    stack is rotated with rotation=(0, angle, 0) in 'sxyz' order around its
    centre, and the central 2 * (dimx // 2) slices of the last axis are
    returned.

    @param image: 2d projection image
    @type image: 2d numpy/cupy array of floats
    @param angle: Tilt angle of projection
    @type angle: float32
    @param interpolation: interpolation type used in the rotation. filtered bspline is recommended.
    @return: copy of the central part of the rotated stack
    @rtype: 3d numpy/cupy array
    @author Gijs van der Schot
    '''
    size_x, size_y = image.shape

    # The stack is made 1.5x deeper than the image is wide before rotating.
    depth = (size_x * 3) // 2
    mid_z = depth // 2
    half_x = size_x // 2

    smeared = xp.dstack([image for _ in range(depth)])
    rotated = xp.zeros_like(smeared)

    rotation_center = numpy.array([half_x, size_y // 2, mid_z])
    transform(smeared,
              output=rotated,
              rotation=(0, angle, 0),
              rotation_order='sxyz',
              interpolation=interpolation,
              center=rotation_center,
              device=device)

    # Keep only the slab centred on the middle of the stack.
    z_window = slice(mid_z - half_x, mid_z + half_x)
    return rotated[:, :, z_window].copy()
Exemplo n.º 2
0
def subPixelShifts(volume, scale=[1, 1, 1], cutout=20, shifts=True):
    """
    subPixelShifts: Refine the position of the maximum of a volume by
    resampling a small box around it.

    A box of edge length 2 * (cutout // 2) around the coordinates returned by
    find_coords_max_ccmap is pasted into the centre of a zero volume of shape
    ceil(cutout / scale), resampled with voltools' transform (filt_bspline)
    and searched again with find_coords_max_ccmap.

    @param volume: volume to search (presumably a cross-correlation map - TODO confirm)
    @type volume: 3d numpy/cupy array
    @param scale: per-axis scale passed to transform; must have three entries
    @param cutout: edge length of the box cut around the peak
    @param shifts: not used; the name is overwritten inside the function
    @return: [peak value, shift of the peak relative to the volume centre].
        If the box does not fit inside the volume, the volume maximum and the
        integer offset of the peak from the centre are returned instead.
    """
    from pytom.voltools import transform
    import numpy as np

    half = cutout // 2
    dim_x, dim_y, dim_z = volume.shape
    volume_center = xp.array([dim_x // 2, dim_y // 2, dim_z // 2])
    cx, cy, cz = peak = find_coords_max_ccmap(volume)

    # Fall back to the integer peak when the box would leave the volume.
    box_fits = all(c - half >= 0 for c in (cx, cy, cz)) and all(
        c + half < n for c, n in zip((cx, cy, cz), (dim_x, dim_y, dim_z)))
    if not box_fits:
        integer_offset = xp.array(
            [cx - dim_x // 2, cy - dim_y // 2, cz - dim_z // 2])
        return [volume.max(), integer_offset]

    peak_region = volume[cx - half:cx + half,
                         cy - half:cy + half,
                         cz - half:cz + half]

    nx, ny, nz = [int(np.ceil(cutout / factor)) for factor in scale]
    padded_center = xp.array([nx // 2, ny // 2, nz // 2])
    padded = xp.zeros((nx, ny, nz), dtype=xp.float32)
    padded[nx // 2 - half:nx // 2 + half,
           ny // 2 - half:ny // 2 + half,
           nz // 2 - half:nz // 2 + half] = peak_region

    resampled = xp.zeros_like(padded)
    transform(padded,
              output=resampled,
              scale=scale,
              interpolation='filt_bspline',
              device=device)

    fine_peak = find_coords_max_ccmap(resampled)
    coarse_part = xp.array(peak) - volume_center
    fine_part = (xp.array(fine_peak) - xp.array(padded_center)) * xp.array(scale)
    sub_pixel_shift = coarse_part + fine_part

    return [resampled.max(), sub_pixel_shift]
Exemplo n.º 3
0
def subPixelPeak(inp, k=1, ignore_border=1):
    """
    subPixelPeak: Locate the maximum of a 2D or 3D array, recentre and rescale
    the array on that maximum and report the peak of the resampled data.

    The integer maximum is searched while ignoring `ignore_border` voxels on
    every side. The array is then resampled with voltools' transform using a
    scale of k along each spatial axis and a translation that moves the
    integer maximum to the array centre.

    @param inp: input data; an array that squeezes to two dimensions is treated as 2D
    @type inp: numpy/cupy array
    @param k: scale factor passed to transform for each spatial axis
    @param ignore_border: number of voxels skipped on every side during the
        integer peak search (0 searches the whole array)
    @return: [value at the maximum of the resampled array, indices of that
        maximum (two entries for 2D input, three for 3D input)]
    """
    from pytom.voltools import transform

    squeezed = inp.squeeze()
    is2d = squeezed.ndim == 2

    # "-0" would produce an empty slice, so use an open end for a zero border.
    inner = slice(ignore_border, -ignore_border or None)

    if is2d:
        search_area = squeezed[inner, inner]
        scale = (k, k, 1)
        x, y = xp.array(
            xp.unravel_index(search_area.argmax(),
                             search_area.shape)) + ignore_border
        # transform works on 3D data, so add a singleton third axis.
        out = xp.expand_dims(xp.array(squeezed, dtype=xp.float32), 2)
        translation = [
            squeezed.shape[0] // 2 - x, squeezed.shape[1] // 2 - y, 0
        ]
    else:
        search_area = inp[inner, inner, inner]
        scale = (k, k, k)
        out = xp.array(inp, dtype=xp.float32)
        x, y, z = xp.array(
            xp.unravel_index(search_area.argmax(),
                             search_area.shape)) + ignore_border
        translation = [
            inp.shape[0] // 2 - x, inp.shape[1] // 2 - y, inp.shape[2] // 2 - z
        ]

    zoomed = xp.zeros_like(out, dtype=xp.float32)
    # NOTE(review): the device is hard-coded to 'gpu:0' here - confirm that
    # this function is only meant to run on the first GPU.
    transform(out,
              output=zoomed,
              scale=scale,
              translation=translation,
              device='gpu:0',
              interpolation='filt_bspline')

    peak_index = xp.unravel_index(zoomed.argmax(), zoomed.shape)
    x2, y2, z2 = peak_index
    n_axes = 2 if is2d else 3

    return [zoomed[x2, y2, z2], peak_index[:n_axes]]
Exemplo n.º 4
0
def rotateWeighting(weighting, rotation, mask=None, binarize=False):
    """
    rotateWeighting: Rotates a frequency weighting volume around the center.
    The reduced complex volume is expanded to full size, fftshifted, rotated,
    ifftshifted and converted back to reduced size.

    @param weighting: A weighting volume in reduced complex convention
    @type weighting: cupy or numpy array
    @param rotation: rotation angles in zxz order
    @type rotation: list
    @param mask: optional mask that is multiplied onto the rotated, centred
        weighting. No mask is applied when None.
    @type mask: cupy or numpy ndarray
    @param binarize: if True, values below 0.5 are set to 0 and values of 0.5
        and above are set to 1 in the returned volume
    @type binarize: bool
    @return: rotated weighting as reduced complex volume
    @rtype: cupy or numpy array
    """
    from pytom_volume import vol, vol_comp
    from pytom.voltools import transform
    from pytom.tompy.transform import fourier_reduced2full, fourier_full2reduced

    # NOTE(review): this check demands a pytom_volume vol/vol_comp, yet the
    # code below reads `.shape` and uses array functions - confirm which
    # input type is really intended. Kept as in the original.
    assert type(weighting) == vol or type(
        weighting
    ) == vol_comp, "rotateWeighting: input neither vol nor vol_comp"

    weighting = fourier_reduced2full(weighting,
                                     isodd=weighting.shape[0] % 2 == 1)
    # Move the zero frequency to the array centre so the rotation is about it.
    weighting = xp.fft.fftshift(weighting)

    weightingRotated = xp.zeros_like(weighting)

    transform(weighting,
              output=weightingRotated,
              rotation=rotation,
              rotation_order='rzxz',
              device=device,
              interpolation='filt_bspline')

    if mask is not None:
        weightingRotated *= mask

    # Undo the centring. ifftshift is the exact inverse of fftshift; a second
    # fftshift only matches it for even-sized axes.
    weightingRotated = xp.fft.ifftshift(weightingRotated)
    returnVolume = fourier_full2reduced(weightingRotated)

    if binarize:
        returnVolume[returnVolume < 0.5] = 0
        returnVolume[returnVolume >= 0.5] = 1

    return returnVolume
Exemplo n.º 5
0
def cut_patch(projection,
              ang,
              pick_position,
              vol_size=200,
              binning=8,
              dimz=None,
              offset=[0, 0, 0],
              projection_name=None):
    """
    cut_patch: Cut a patch around a picked position out of a projection and
    shift it by the fractional part of the patch centre.

    @param projection: projection to cut from; indexed with three axes
    @type projection: numpy/cupy array
    @param ang: tilt angle in degrees (converted to radians internally)
    @param pick_position: (x, y, z) position; offset is added and the sum is
        multiplied by binning
    @param vol_size: edge length of the patch
    @param binning: factor applied to the offset-corrected position
    @param dimz: extent used to centre z; defaults to projection.shape[0]
    @param offset: (x, y, z) offset added to pick_position before binning
    @param projection_name: not used
    @return: the shifted patch with singleton axes removed
    """
    from pytom.tompy.tools import taper_edges
    #from pytom.gpu.initialize import xp
    from pytom.voltools import transform
    from pytom.tompy.transform import rotate3d, rotate_axis
    from pytom.tompy.transform import cut_from_projection
    from pytom.tompy.io import read

    # Extent of the projection along x; z falls back to it when not given.
    dim_x = projection.shape[0]
    dim_z = dimz if dimz is not None else dim_x

    px, py, pz = pick_position
    x = (px + offset[0]) * binning
    y = (py + offset[1]) * binning
    z = (pz + offset[2]) * binning

    # Rotate (x, z) about the volume centre; y is taken as the rotation axis
    # and therefore stays unchanged.
    theta = ang * xp.pi / 180
    yy = y
    xx = (xp.cos(theta) * (x - dim_x / 2) -
          xp.sin(theta) * (z - dim_z / 2)) + dim_x / 2

    half = vol_size // 2
    x_floor = int(xp.floor(xx))
    y_int = int(yy)

    # Only the lower x bound is clamped to the array.
    x_window = slice(max(0, x_floor - half), x_floor + half)
    y_window = slice(y_int - half, y_int + half)
    patch = projection[x_window, y_window, :]

    shifted = xp.zeros_like(patch)
    transform(patch,
              output=shifted,
              translation=[xx % 1, yy % 1, 0],
              interpolation='filt_bspline',
              device=device)
    return shifted.squeeze()
Exemplo n.º 6
0
def recenterVolume(volume, densityNegative=False):
    """
    recenterVolume: Move the centre of mass of a volume to the volume centre.

    The largest box that is symmetric around the (integer) centre of mass and
    still fits inside the volume is cut out and pasted into the centre of an
    empty volume of the original shape.

    @param volume: volume to recentre
    @type volume: pytom volume (anything vol2npy accepts) or numpy/cupy array
    @param densityNegative: if True the sign of the data is flipped before the
        centre of mass is computed and flipped back in the result
    @type densityNegative: bool
    @return: recentred volume; read back through pytom.basic.files.read when
        the input could be converted by vol2npy, otherwise an array
    """
    from scipy.ndimage import center_of_mass
    from pytom.tompy.io import write
    from pytom.tompy.tools import paste_in_center
    from pytom.gpu.initialize import xp
    from pytom_numpy import vol2npy
    import os

    try:
        a = vol2npy(volume).copy()
        is_pytom_vol = True
    except Exception:
        # Not convertible by vol2npy: treat the input as an array.
        a = volume
        is_pytom_vol = False

    if densityNegative:
        # Build a new array so the caller's data is not modified in place.
        a = a * -1

    x, y, z = list(map(int, center_of_mass(a)))

    # Half-widths of the largest box around the centre of mass that stays
    # inside the volume, measured per axis.
    sx = min(x, a.shape[0] - x)
    sy = min(y, a.shape[1] - y)
    sz = min(z, a.shape[2] - z)

    ac = a[x - sx:x + sx, y - sy:y + sy, z - sz:z + sz]
    b = paste_in_center(ac, xp.zeros_like(a))

    if densityNegative:
        b *= -1

    if not is_pytom_vol:
        return b

    # Round-trip through a file to obtain a pytom volume again.
    from pytom.basic.files import read
    tmp_name = 'recenteredDBV21.em'
    write(tmp_name, b)
    try:
        return read(tmp_name)
    finally:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
Exemplo n.º 7
0
def profile2FourierVol(profile, dim=None, reduced=False):
    """
    create Volume from 1d radial profile, e.g., to modulate signal with
    specific function such as CTF or FSC. Simple linear interpolation is used
    for sampling.

    The value at radius r is interpolated linearly between profile[floor(r)]
    and profile[floor(r) + 1]; everything beyond the last profile index is
    set to zero. The zero frequency ends up at index 0 of every full axis.

    @param profile: profile
    @type profile: 1-d array-like
    @param dim: output dimensions: a sequence of three (3D) or two (2D)
        sizes, or a single integer for a cubic volume. Defaults to a cube
        with twice the profile length.
    @type dim: L{int} or sequence of L{int}
    @param reduced: If true, the last axis only holds the non-negative
        frequencies (N // 2 + 1 samples) and is not shifted.
    @type reduced: L{bool}

    @return: real-valued array with radially symmetric profile
    @rtype: numpy/cupy array of float64
    @raise ValueError: if the profile is empty or dim has fewer than two entries
    @author: FF
    """
    try:
        profile = xp.array(profile)
    except Exception:
        import numpy
        profile = xp.array(numpy.array(profile))
    profile = profile.astype(xp.float64)

    n_samples = profile.shape[0]
    if n_samples == 0:
        raise ValueError('profile2FourierVol: profile is empty')
    r_max = n_samples - 1

    if dim is None:
        dim = [2 * n_samples] * 3
    else:
        try:
            dim = [int(d) for d in dim]
        except TypeError:
            # A single integer describes a cubic volume.
            dim = [int(dim)] * 3
    if len(dim) < 2:
        raise ValueError('profile2FourierVol: dim needs at least two entries')

    sizes = dim[:3] if len(dim) == 3 else dim[:2]

    def centered(n):
        # n integer coordinates with 0 at index n // 2, for even and odd n.
        return xp.arange(-(n // 2), n - n // 2)

    axes = [centered(n) for n in sizes]
    if reduced:
        axes[-1] = xp.arange(sizes[-1] // 2 + 1)

    # 'ij' indexing keeps the output shape in the order given by dim.
    grids = xp.meshgrid(*axes, indexing='ij')
    R = xp.sqrt(sum(g.astype(xp.float64)**2 for g in grids))

    IR = xp.floor(R).astype(xp.int64)
    frac = R - IR

    # Clamp the indices so radii beyond the profile stay addressable; those
    # voxels are zeroed afterwards.
    lower = profile[xp.minimum(IR, r_max)]
    upper = profile[xp.minimum(IR + 1, r_max)]

    fkernel = (1 - frac) * lower + frac * upper
    fkernel[R > r_max] = 0

    # Move the zero frequency from the centre to index 0. The reduced axis
    # already starts at zero and must not be shifted.
    shift_axes = tuple(range(len(axes) - 1)) if reduced else None
    return xp.fft.ifftshift(fkernel, axes=shift_axes)
Exemplo n.º 8
0
def weightedXCF(volume, reference, numberOfBands, wedgeAngle=-1, gpu=False):
    """
    weightedXCF: Determines the weighted correlation function for volume and reference

    The frequency range is split into numberOfBands bands. For every band the
    band-limited correlation from bandCF is combined with the band (and
    optional wedge) filter and accumulated; the sum is scaled by the total
    of the per-band N values.

    @param volume: A volume
    @param reference: A reference
    @param numberOfBands: Number of bands
    @param wedgeAngle: An optional wedge angle; a wedge filter is only built
        when the value is >= 0
    @param gpu: If True cupy is used as array module, otherwise numpy
    @return: The weighted correlation function
    @rtype: numpy/cupy array (same type as xp.zeros_like(volume))
    @author: Thomas Hrabe
    @todo: does not work yet -> test is disabled
    """

    if gpu:
        import cupy as xp
    else:
        import numpy as xp

    from pytom.tompy.correlation import bandCF
    # NOTE(review): other functions import from 'pytom.tompy.transform'
    # (without the trailing 's') - confirm this module path exists.
    from pytom.tompy.transforms import fourier_reduced2full
    from math import sqrt
    import pytom_freqweight

    result = xp.zeros_like(volume)

    q = 0

    if wedgeAngle >= 0:
        # NOTE(review): sizeX()/sizeY()/sizeZ() are called on `volume` here
        # while the rest of the function uses `.shape`/`.size` - confirm
        # which volume type is expected.
        wedgeFilter = pytom_freqweight.weight(wedgeAngle, 0, volume.sizeX(),
                                              volume.sizeY(), volume.sizeZ())
        wedgeVolume = wedgeFilter.getWeightVolume(True)
    else:
        wedgeVolume = xp.ones_like(volume)

    w = xp.sqrt(1 / float(volume.size))

    numberVoxels = 0

    for i in range(numberOfBands):
        """
        notation according Steward/Grigorieff paper
        """
        # Band limits are equal fractions of the first volume dimension.
        band = [0, 0]
        band[0] = i * volume.shape[0] / numberOfBands
        band[1] = (i + 1) * volume.shape[0] / numberOfBands

        r = bandCF(volume, reference, band, gpu=gpu)

        cc = r[0]

        filter = r[1]
        #get bandVolume
        bandVolume = filter.getWeightVolume(True)

        filterVolumeReduced = bandVolume * wedgeVolume
        filterVolume = fourier_reduced2full(filterVolumeReduced)
        #determine number of voxels != 0
        # NOTE(review): this sums the filter values whose magnitude is below
        # 1, which is not a count of non-zero voxels - verify the intent.
        N = filterVolume[abs(filterVolume) < 1].sum()

        #add to number of total voxels
        numberVoxels = numberVoxels + N

        cc2 = r[0].copy()

        cc2 = cc2**2

        # NOTE(review): shiftscale is not imported in this function; it must
        # come from the enclosing module - confirm.
        cc = shiftscale(cc, w, 1)
        ccdiv = cc2 / (cc)

        ccdiv = ccdiv**3

        #abs(ccdiv); as suggested by grigorief
        ccdiv = shiftscale(ccdiv, 0, N)

        result = result + ccdiv

    result = shiftscale(result, 0, 1 / float(numberVoxels))

    return result
Exemplo n.º 9
0
def averageGPU(particleList,
               averageName,
               showProgressBar=False,
               verbose=False,
               createInfoVolumes=False,
               weighting=False,
               norm=False,
               gpuId=None,
               profile=True):
    """
    average : Creates new average from a particleList on a GPU

    Every particle is read, optionally normalised, filtered with the
    unrotated wedge, rotated and shifted back and added to the sum. The
    rotated wedges are summed as well and the inverted wedge sum is applied
    to the particle sum before the final low-pass filter.

    Files written: <averageName without its last 3 characters>-PreWedge.em
    and -WedgeSumUnscaled.em (always), averageName itself, and with
    createInfoVolumes also -WedgeSumInverted.em and -INV.em.

    @param particleList: The particles
    @param averageName: Filename of new average
    @param showProgressBar: Creates a progress bar; the per-particle update is
        currently commented out.
    @param verbose: Not used in this function.
    @param createInfoVolumes: Create info data (inverted wedge sum, inverted density) too? False by default.
    @param weighting: Only used for the pre-check that the scores do not sum
        to (almost) zero; no score weighting is applied in the loop.
    @param norm: apply normalization for each particle
    @param gpuId: index of the GPU to use; required
    @param profile: print timings measured with CUDA events
    @return: A new Reference object
    @rtype: L{pytom.basic.structures.Reference}
    @raise Exception: if gpuId is None
    @raise RuntimeError: if particleList is empty
    @author: Thomas Hrabe
    @change: limit for wedgeSum set to 1% or particles to avoid division by small numbers - FF
    """
    import time
    from pytom.tompy.io import read, write, read_size
    from pytom.tompy.filter import bandpass as lowpassFilter, rotateWeighting, applyFourierFilter, applyFourierFilterFull, create_wedge
    from pytom.voltools import transform, StaticVolume
    from pytom.basic.structures import Reference
    from pytom.tompy.normalise import mean0std1
    from pytom.tompy.tools import volumesSameSize, invert_WedgeSum, create_sphere
    from pytom.tompy.transform import fourier_full2reduced, fourier_reduced2full
    from cupyx.scipy.fftpack.fft import fftn as fftnP
    from cupyx.scipy.fftpack.fft import ifftn as ifftnP
    from cupyx.scipy.fftpack.fft import get_fft_plan
    from pytom.tools.ProgressBar import FixedProgBar
    from multiprocessing import RawArray
    import numpy as np
    import cupy as xp

    # A GPU id is mandatory: select the device or abort.
    if not gpuId is None:
        device = f'gpu:{gpuId}'
        xp.cuda.Device(gpuId).use()
    else:
        print(gpuId)
        raise Exception('Running gpu code on non-gpu device')
    print(device)
    cstream = xp.cuda.Stream()
    if profile:
        # Timing uses events recorded on the null stream.
        stream = xp.cuda.Stream.null
        t_start = stream.record()

    # from pytom.tools.ProgressBar import FixedProgBar
    from math import exp
    import os

    if len(particleList) == 0:
        raise RuntimeError('The particle list is empty. Aborting!')

    if showProgressBar:
        progressBar = FixedProgBar(0, len(particleList), 'Particles averaged ')
        progressBar.update(0)
        numberAlignedParticles = 0

    # pre-check that scores != 0
    if weighting:
        wsum = 0.
        for particleObject in particleList:
            wsum += particleObject.getScore().getValue()
        if wsum < 0.00001:
            weighting = False
            print("Warning: all scores have been zero - weighting not applied")
    import time
    # Volume size and wedge are taken from the first particle only.
    sx, sy, sz = read_size(particleList[0].getFilename())
    wedgeInfo = particleList[0].getWedge().convert2numpy()
    print('angle: ', wedgeInfo.getWedgeAngle())
    wedgeZero = xp.fft.fftshift(
        xp.array(wedgeInfo.returnWedgeVolume(sx, sy, sz, True).get(),
                 dtype=xp.float32))
    # wedgeZeroReduced = fourier_full2reduced(wedgeZero)
    wedge = xp.zeros_like(wedgeZero, dtype=xp.float32)
    wedgeSum = xp.zeros_like(wedge, dtype=xp.float32)
    print('init texture')
    # The wedge is shifted once more before it is handed to StaticVolume,
    # which is then used to rotate it for every particle.
    wedgeText = StaticVolume(xp.fft.fftshift(wedgeZero),
                             device=device,
                             interpolation='filt_bspline')

    newParticle = xp.zeros((sx, sy, sz), dtype=xp.float32)

    centerX = sx // 2
    centerY = sy // 2
    centerZ = sz // 2

    result = xp.zeros((sx, sy, sz), dtype=xp.float32)

    # One FFT plan is created up front and reused for every particle.
    fftplan = get_fft_plan(wedge.astype(xp.complex64))

    n = 0

    total = len(particleList)
    # total = int(np.floor((11*1024**3 - mempool.total_bytes())/(sx*sy*sz*4)))
    # total = 128
    #
    #
    # particlesNP = np.zeros((total, sx, sy, sz),dtype=np.float32)
    # particles = []
    # mask = create_sphere([sx,sy,sz], sx//2-6, 2)
    # raw = RawArray('f', int(particlesNP.size))
    # shared_array = np.ctypeslib.as_array(raw)
    # shared_array[:] = particlesNP.flatten()
    # procs = allocateProcess(particleList, shared_array, n, total, wedgeZero.size)
    # del particlesNP

    if profile:
        t_end = stream.record()
        t_end.synchronize()

        time_took = xp.cuda.get_elapsed_time(t_start, t_end)
        print(f'startup time {n:5d}: \t{time_took:.3f}ms')
        t_start = stream.record()

    for particleObject in particleList:

        rotation = particleObject.getRotation()
        rotinvert = rotation.invert()
        shiftV = particleObject.getShift()

        # if n % total == 0:
        #     while len(procs):
        #         procs =[proc for proc in procs if proc.is_alive()]
        #         time.sleep(0.1)
        #         print(0.1)
        #     # del particles
        #     # xp._default_memory_pool.free_all_blocks()
        #     # pinned_mempool.free_all_blocks()
        #     particles = xp.array(shared_array.reshape(total, sx, sy, sz), dtype=xp.float32)
        #     procs = allocateProcess(particleList, shared_array, n, total, size=wedgeZero.size)
        #     #pinned_mempool.free_all_blocks()
        #     #print(mempool.total_bytes()/1024**3)

        particle = read(particleObject.getFilename(), deviceID=device)

        #particle = particles[n%total]

        if norm:  # normalize the particle
            mean0std1(particle)  # happen inplace

        # apply its wedge to
        #particle = applyFourierFilter(particle, wedgeZeroReduced)
        #particle = (xp.fft.ifftn( xp.fft.fftn(particle) * wedgeZero)).real
        particle = (ifftnP(fftnP(particle, plan=fftplan) * wedgeZero,
                           plan=fftplan)).real

        ### create spectral wedge weighting

        # Reset the reused output buffer before the rotated wedge is written.
        wedge *= 0

        # Note the angle order passed here: [0], [2], [1] of rotinvert.
        wedgeText.transform(
            rotation=[rotinvert[0], rotinvert[2], rotinvert[1]],
            rotation_order='rzxz',
            output=wedge)
        #wedge = xp.fft.fftshift(fourier_reduced2full(create_wedge(30, 30, 21, 42, 42, 42, rotation=[rotinvert[0],rotinvert[2], rotinvert[1]])))
        # if analytWedge:
        #     # > analytical buggy version
        # wedge = wedgeInfo.returnWedgeVolume(sx, sy, sz, True, rotinvert)
        # else:
        #     # > FF: interpol bugfix

        # wedge = rotateWeighting(weighting=wedgeInfo.returnWedgeVolume(sx, sy, sz, True), rotation=[rotinvert[0], rotinvert[2], rotinvert[1]])
        #     # < FF
        #     # > TH bugfix
        #     # wedgeVolume = wedgeInfo.returnWedgeVolume(wedgeSizeX=sizeX, wedgeSizeY=sizeY, wedgeSizeZ=sizeZ,
        #     #                                    humanUnderstandable=True, rotation=rotinvert)
        #     # wedge = rotate(volume=wedgeVolume, rotation=rotinvert, imethod='linear')
        #     # < TH

        ### shift and rotate particle

        # Reset the reused output buffer before the transformed particle is written.
        newParticle *= 0
        transform(particle,
                  output=newParticle,
                  rotation=[-rotation[1], -rotation[2], -rotation[0]],
                  center=[centerX, centerY, centerZ],
                  translation=[-shiftV[0], -shiftV[1], -shiftV[2]],
                  device=device,
                  interpolation='filt_bspline',
                  rotation_order='rzxz')

        #write(f'trash/GPU_{n}.em', newParticle)
        # print(rotation.toVector())
        # break
        result += newParticle
        wedgeSum += xp.fft.fftshift(wedge)
        # if showProgressBar:
        #     numberAlignedParticles = numberAlignedParticles + 1
        #     progressBar.update(numberAlignedParticles)

        # total == len(particleList) and n < total inside the loop, so this
        # condition only holds for the first particle (n == 0).
        if n % total == 0:
            if profile:
                t_end = stream.record()
                t_end.synchronize()

                time_took = xp.cuda.get_elapsed_time(t_start, t_end)
                print(f'total time {n:5d}: \t{time_took:.3f}ms')
                t_start = stream.record()
        cstream.synchronize()
        n += 1

    print('averaged particles')
    ###apply spectral weighting to sum

    result = lowpassFilter(result, high=sx / 2 - 1, sigma=0)
    # if createInfoVolumes:
    # The next two files are written regardless of createInfoVolumes.
    write(averageName[:len(averageName) - 3] + '-PreWedge.em', result)
    write(averageName[:len(averageName) - 3] + '-WedgeSumUnscaled.em',
          fourier_full2reduced(wedgeSum))

    # Lower limit and replacement value are 5% of the number of particles.
    wedgeSumINV = invert_WedgeSum(wedgeSum,
                                  r_max=sx // 2 - 2.,
                                  lowlimit=.05 * len(particleList),
                                  lowval=.05 * len(particleList))
    wedgeSumINV = wedgeSumINV

    #print(wedgeSum.mean(), wedgeSum.std())
    if createInfoVolumes:
        write(averageName[:len(averageName) - 3] + '-WedgeSumInverted.em',
              xp.fft.fftshift(wedgeSumINV))

    result = applyFourierFilterFull(result, xp.fft.fftshift(wedgeSumINV))

    # do a low pass filter
    # NOTE(review): this call indexes the return value with [0] while the
    # earlier lowpassFilter call does not - confirm what bandpass returns.
    result = lowpassFilter(result, sx / 2 - 2, (sx / 2 - 1) / 10.)[0]
    write(averageName, result)

    if createInfoVolumes:
        resultINV = result * -1
        # write sign inverted result to disk (good for chimera viewing ... )
        write(averageName[:len(averageName) - 3] + '-INV.em', resultINV)

    newReference = Reference(averageName, particleList)

    return newReference
Exemplo n.º 10
0
def alignImagesUsingAlignmentResultFile(alignmentResultsFile,
                                        weighting=None,
                                        lowpassFilter=0.9,
                                        binning=1,
                                        circleFilter=False):
    """
    alignImagesUsingAlignmentResultFile: Create a stack of aligned (and
    optionally weighted) images from an alignment results file.

    Each image is read, binned by striding, optionally low-pass filtered,
    normalised to its mean, tapered, transformed (in-plane rotation,
    translation, magnification) and optionally weighted in Fourier space.

    @param alignmentResultsFile: file readable by loadstar with dtype datatypeAR
    @param weighting: None or 0 for no weighting, a value below -0.001 for the
        ramp filter, a value above 0 for the exact filter. Values that
        float() accepts (e.g. strings) are converted.
    @param lowpassFilter: low-pass cutoff as fraction of Nyquist; values above
        1 are set to 1, a falsy value disables the filter
    @param binning: stride used to bin the images
    @param circleFilter: if True a circle filter is multiplied onto the weighting
    @return: list of four pytom volumes (npy2vol): image stack, offset stack,
        theta stack and phi stack (the phi stack is left at zero)
    """
    import numpy as np
    from pytom_numpy import npy2vol
    from pytom.gui.guiFunctions import datatypeAR, loadstar
    from pytom.reconstruction.reconstructionStructures import Projection, ProjectionList
    from pytom.tompy.io import read, read_size
    from pytom.tompy.tools import taper_edges, create_circle
    from pytom.tompy.filter import circle_filter, ramp_filter, exact_filter
    import pytom.voltools as vt
    from pytom.gpu.initialize import xp, device
    print("Create aligned images from alignResults.txt")

    alignmentResults = loadstar(alignmentResultsFile, dtype=datatypeAR)
    imageList = alignmentResults['FileName']
    tilt_angles = alignmentResults['TiltAngle']

    imdim = read_size(imageList[0], 'x')
    if binning > 1:
        imdim = int(float(imdim) / float(binning) + .5)

    sliceWidth = imdim

    # Decide once which weighting is used so that building and applying the
    # filter can never disagree.
    if weighting is not None:
        weighting = float(weighting)
    use_ramp = weighting is not None and weighting < -0.001
    use_exact = weighting is not None and weighting > 0

    if use_ramp:
        weightSlice = xp.fft.fftshift(ramp_filter(imdim, imdim))

    if circleFilter:
        circleFilterRadius = imdim // 2
        circleSlice = xp.fft.fftshift(
            circle_filter(imdim, imdim, circleFilterRadius))
    else:
        circleSlice = xp.ones((imdim, imdim))

    # design lowpass filter
    if lowpassFilter:
        if lowpassFilter > 1.:
            lowpassFilter = 1.
            print("Warning: lowpassFilter > 1 - set to 1 (=Nyquist)")

        # weighting filter: arguments: (dimx, dimy), cutoff radius, sigma
        lpf = xp.fft.fftshift(
            create_circle((imdim, imdim),
                          lowpassFilter * (imdim // 2),
                          sigma=0.4 * lowpassFilter * (imdim // 2)))

    projectionList = ProjectionList()
    for n, filename in enumerate(imageList):
        projection = Projection(
            filename,
            tiltAngle=tilt_angles[n],
            alignmentTransX=alignmentResults['AlignmentTransX'][n] / binning,
            alignmentTransY=alignmentResults['AlignmentTransY'][n] / binning,
            alignmentRotation=alignmentResults['InPlaneRotation'][n],
            alignmentMagnification=1 / alignmentResults['Magnification'][n])
        projectionList.append(projection)

    n_images = len(imageList)
    stack = xp.zeros((imdim, imdim, n_images), dtype=xp.float32)
    phiStack = xp.zeros((1, 1, n_images), dtype=xp.float32)
    thetaStack = xp.zeros((1, 1, n_images), dtype=xp.float32)
    offsetStack = xp.zeros((1, 2, n_images), dtype=xp.float32)

    for (ii, projection) in enumerate(projectionList):
        print(f'Align {projection._filename}')
        image = read(str(projection._filename))[::binning, ::binning].squeeze()

        if lowpassFilter:
            image = xp.abs((xp.fft.ifftn(xp.fft.fftn(image) * lpf)))

        tiltAngle = projection._tiltAngle

        # normalize to contrast - subtract mean and norm to mean
        immean = image.mean()
        image = (image - immean) / immean

        # smoothen borders to prevent high contrast oscillations
        image = taper_edges(image, imdim // 30)[0]

        # transform projection according to tilt alignment
        # NOTE(review): the translations were already divided by binning when
        # the Projection was created above; if Projection stores them
        # unchanged they are divided twice here - confirm.
        transX = projection._alignmentTransX / binning
        transY = projection._alignmentTransY / binning
        rot = float(projection._alignmentRotation)
        mag = float(projection._alignmentMagnification)

        # voltools works on 3D data, so add a singleton third axis.
        inputImage = xp.expand_dims(image, 2).copy()
        outputImage = xp.zeros_like(inputImage, dtype=xp.float32)

        vt.transform(inputImage.astype(xp.float32),
                     rotation=[0, 0, rot],
                     rotation_order='rxyz',
                     output=outputImage,
                     device=device,
                     translation=[transX, transY, 0],
                     scale=[mag, mag, 1],
                     interpolation='filt_bspline')

        image = outputImage.squeeze()

        # smoothen once more to avoid edges
        image = taper_edges(image, imdim // 30)[0]

        if use_ramp:
            # analytical weighting
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real
        elif use_exact:
            # exact weighting depends on the tilt angle of this projection
            weightSlice = xp.fft.fftshift(
                exact_filter(tilt_angles, tiltAngle, imdim, imdim, sliceWidth))
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real

        thetaStack[0, 0, ii] = int(round(projection.getTiltAngle()))
        offsetStack[0, :, ii] = xp.array([
            int(round(projection.getOffsetX())),
            int(round(projection.getOffsetY()))
        ])

        stack[:, :, ii] = image

    arrays = []
    for arr in (stack, offsetStack, thetaStack, phiStack):
        if 'gpu' in device:
            # Copy from the GPU to host memory before converting.
            arr = arr.get()
        arrays.append(npy2vol(np.array(arr, dtype='float32', order='F'), 3))

    return arrays
Exemplo n.º 11
0
def alignImageUsingAlignmentResultFile(alignmentResultsFile,
                                       indexImage,
                                       weighting=None,
                                       lowpassFilter=0.9,
                                       binning=1,
                                       circleFilter=False):
    """
    alignImageUsingAlignmentResultFile: Read one projection listed in an
    alignment results file and return it normalized, aligned, optionally
    low-pass filtered and optionally weighted.

    @param alignmentResultsFile: file loaded with loadstar(..., dtype=datatypeAR). \
    The columns FileName, TiltAngle, AlignmentTransX, AlignmentTransY, \
    InPlaneRotation and Magnification are read from it.
    @param indexImage: row index (0-based) of the projection to process
    @type indexImage: int
    @param weighting: None or 0 applies no weighting, a negative value applies a \
    ramp filter, a positive value applies the exact filter. The value is \
    converted with float(), so numeric strings are accepted.
    @param lowpassFilter: cutoff as a fraction of Nyquist; values > 1 are clamped \
    to 1 with a printed warning. A falsy value disables the low-pass step.
    @param binning: images are resized by 1 / binning and the alignment \
    translations are divided by binning when binning > 1
    @param circleFilter: if True an ellipse filter with radii of half the image \
    dimensions is multiplied in Fourier space. It is only applied together with \
    the weighting step.
    @return: the processed image as float32 array, or None if indexImage does \
    not match any row of the alignment results file
    @note: as a side effect the result is written to 'inputImage_<index>.mrc' \
    in the current working directory.
    """
    from pytom.gui.guiFunctions import datatypeAR, loadstar
    from pytom.reconstruction.reconstructionStructures import Projection, ProjectionList
    from pytom.tompy.io import read, write, read_size
    from pytom.tompy.tools import taper_edges
    from pytom.tompy.filter import ramp_filter, exact_filter, ellipse_filter
    import pytom.voltools as vt
    from pytom.gpu.initialize import xp, device

    alignmentResults = loadstar(alignmentResultsFile, dtype=datatypeAR)
    imageList = alignmentResults['FileName']
    tilt_angles = alignmentResults['TiltAngle']

    # Image dimensions are taken from the first file of the list.
    imdimX = read_size(imageList[0], 'x')
    imdimY = read_size(imageList[0], 'y')

    if binning > 1:
        imdimX = int(float(imdimX) / float(binning) + .5)
        imdimY = int(float(imdimY) / float(binning) + .5)

    sliceWidth = imdimX
    imdim = min(imdimY, imdimX)

    # Normalize the weighting once so that the decision which filter to build
    # and the decision which filter to apply can never disagree. Previously
    # the ramp filter was only built for weighting < -0.001 but applied for
    # any weighting < 0, which raised a NameError for tiny negative values.
    if weighting is not None:
        weighting = float(weighting)
    use_ramp = weighting is not None and weighting < 0
    use_exact = weighting is not None and weighting > 0

    if use_ramp:
        weightSlice = xp.fft.fftshift(ramp_filter(imdimY, imdimX))

    if circleFilter:
        circleFilterRadiusX = imdimX // 2
        circleFilterRadiusY = imdimY // 2

        circleSlice = xp.fft.fftshift(
            ellipse_filter(imdimX, imdimY, circleFilterRadiusX,
                           circleFilterRadiusY))
    else:
        circleSlice = xp.ones((imdimX, imdimY))

    # design lowpass filter: the cutoff cannot exceed Nyquist
    if lowpassFilter:
        if lowpassFilter > 1.:
            lowpassFilter = 1.
            print("Warning: lowpassFilter > 1 - set to 1 (=Nyquist)")

    projectionList = ProjectionList()
    for n, filename in enumerate(imageList):
        projectionList.append(
            Projection(filename,
                       tiltAngle=tilt_angles[n],
                       alignmentTransX=alignmentResults['AlignmentTransX'][n],
                       alignmentTransY=alignmentResults['AlignmentTransY'][n],
                       alignmentRotation=alignmentResults['InPlaneRotation'][n],
                       # the inverse of the stored magnification is used as scale
                       alignmentMagnification=1 / (alignmentResults['Magnification'][n])))

    for ii, projection in enumerate(projectionList):
        if ii != indexImage:
            continue

        image = read(str(projection._filename)).squeeze()

        if binning > 1:
            from pytom.tompy.transform import resize
            image = resize(image, 1 / binning)

        tiltAngle = projection._tiltAngle

        # 1 -- normalize to contrast - subtract mean and norm to mean
        immean = image.mean()
        image = (image - immean) / immean

        # 2 -- smoothen borders to prevent high contrast oscillations
        image = taper_edges(image, imdim // 30)[0]

        # 3 -- transform projection according to tilt alignment
        transX = projection._alignmentTransX / binning
        transY = projection._alignmentTransY / binning
        rot = float(projection._alignmentRotation)
        mag = float(projection._alignmentMagnification)

        # the transform operates on 3D data, so add a singleton z-axis
        inputImage = xp.expand_dims(image, 2).copy()
        outputImage = xp.zeros_like(inputImage, dtype=xp.float32)

        vt.transform(
            inputImage.astype(xp.float32),
            rotation=[0, 0, rot],
            rotation_order='rxyz',
            output=outputImage,
            center=[inputImage.shape[0] // 2, inputImage.shape[1] // 2, 0],
            device=device,
            translation=[transX, transY, 0],
            scale=[mag, mag, 1],
            interpolation='filt_bspline')

        del image
        image = outputImage.squeeze()

        # 4 -- Optional Low Pass Filter
        if lowpassFilter:
            from pytom.tompy.filter import bandpass_circle

            image = bandpass_circle(
                image,
                high=lowpassFilter * (min(imdimX, imdimY) // 2),
                sigma=0.4 * lowpassFilter * (min(imdimX, imdimY) // 2))

        # 5 -- smoothen once more to avoid edges
        image = taper_edges(image, imdim // 30)[0]

        # 6 -- analytical weighting
        if use_ramp:
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice.T * circleSlice).real
        elif use_exact:
            weightSlice = xp.fft.fftshift(
                exact_filter(tilt_angles, tiltAngle, imdim, imdim, sliceWidth))
            image = xp.fft.ifftn(
                xp.fft.fftn(image) * weightSlice * circleSlice).real

        # drop references to the intermediate buffers before writing
        del inputImage, outputImage, circleSlice

        write(f'inputImage_{ii}.mrc', image)

        return image.astype(xp.float32)

    # indexImage did not correspond to any projection in the file
    return None
Exemplo n.º 12
0
def scale(volume, factor, interpolation='Spline'):
    """
    scale: Magnify or de-magnify a volume (or image) in REAL SPACE. See the \
    resize function for a more accurate operation in Fourier space.
    @param volume: input data; it is squeezed first and must then have at least \
    two axes. Two-dimensional input is given a trailing singleton axis.
    @type volume: array providing squeeze() and shape (presumably numpy/cupy - confirm with callers)
    @param factor: a factor > 0. Factors < 1 will de-magnify the volume, factors > 1 will magnify.
    @type factor: L{float}
    @param interpolation: 'Spline' (default), 'Cubic' or 'filt_bspline' select \
    filtered b-spline interpolation; 'Linear' or 'linear' select linear interpolation.
    @type interpolation: L{str}

    @return: The scaled volume with three axes; each scaled axis has length \
    floor(size * factor + 0.5). For factor == 1 the (squeezed) input is returned.
    @raise RuntimeError: if factor is not > 0
    @author: Thomas Hrabe
    """

    from pytom.voltools import transform
    from pytom.tompy.tools import paste_in_center

    if factor <= 0:
        raise RuntimeError('Scaling factor must be > 0!')

    # user facing interpolation names -> names understood by transform
    mode_lookup = {
        'Spline': 'filt_bspline',
        'Linear': 'linear',
        'Cubic': 'filt_bspline',
        'filt_bspline': 'filt_bspline',
        'linear': 'linear'
    }
    interpolation = mode_lookup[interpolation]

    def rounded_extent(extent, multiplier):
        # round half up to the nearest integer size
        return int(xp.floor(extent * multiplier + 0.5))

    volume = volume.squeeze()
    shape = volume.shape

    targetX = rounded_extent(shape[0], float(factor))
    targetY = rounded_extent(shape[1], float(factor))

    inverse = 1 / factor
    if len(shape) == 3:
        targetZ = rounded_extent(shape[2], factor)
        scaleF = [inverse, inverse, inverse]
    else:
        # image input: keep the z-axis untouched and make the data 3D
        targetZ = 1
        scaleF = [inverse, inverse, 1]
        volume = xp.expand_dims(volume, 2)

    if factor == 1:
        return volume

    target_shape = (targetX, targetY, targetZ)

    if factor > 1:
        # magnify: embed the input in the larger box first, then transform
        padded = paste_in_center(volume,
                                 xp.zeros(target_shape, dtype=volume.dtype))
        magnified = xp.zeros_like(padded)
        transform(padded,
                  scale=scaleF,
                  output=magnified,
                  device=device,
                  interpolation=interpolation)
        return magnified

    # de-magnify: transform at full size, then cut out the smaller box
    shrunk = xp.zeros_like(volume)
    transform(volume,
              scale=scaleF,
              output=shrunk,
              device=device,
              interpolation=interpolation)
    return paste_in_center(shrunk, xp.zeros(target_shape, dtype=volume.dtype))