Ejemplo n.º 1
0
def bandCF(volume, reference, band=None):
    """
    bandCF: Correlate two volumes restricted to a frequency band.

    Both volumes are band-pass filtered in Fourier space, the spectrum of the
    volume is multiplied with the complex conjugate of the reference spectrum,
    the product is transformed back to real space, iftshifted and scaled by
    1/sqrt(sumV * sumR), where sumV/sumR are the summed squared magnitudes of
    the filtered (full-size) spectra.

    @param volume: The volume
    @param reference: The reference
    @param band: [a,b] - specify the lower and upper end of band. [0,100] if not set.
    @return: First parameter - The correlation of the two volumes in the specified ring.
             Second parameter - The bandpass filter used.
    @rtype: List - [L{pytom_volume.vol},L{pytom_freqweight.weight}]
    @author: Thomas Hrabe
    @todo: does not work yet -> test is disabled
    """
    import pytom_volume
    from math import sqrt
    from pytom.basic import fourier
    from pytom.basic.filter import bandpassFilter

    # avoid a shared mutable default argument
    if band is None:
        band = [0, 100]

    # band-pass both volumes in Fourier space; the reference is filtered with
    # the filter returned for the volume (vf[1]) so both use the same band
    vf = bandpassFilter(volume, band[0], band[1], fourierOnly=True)
    rf = bandpassFilter(reference, band[0], band[1], vf[1], fourierOnly=True)

    # summed squared magnitude of each filtered spectrum over the full volume
    v = pytom_volume.reducedToFull(vf[0])
    r = pytom_volume.reducedToFull(rf[0])

    absV = pytom_volume.abs(v)
    absR = pytom_volume.abs(r)

    pytom_volume.power(absV, 2)
    pytom_volume.power(absR, 2)

    sumV = abs(pytom_volume.sum(absV))
    sumR = abs(pytom_volume.sum(absR))

    # guard against division by zero for an empty band
    if sumV == 0:
        sumV = 1

    if sumR == 0:
        sumR = 1

    # conjugate() presumably works in place on rf[0] (its return value is unused)
    pytom_volume.conjugate(rf[0])

    fresult = vf[0] * rf[0]

    # transform back to real space
    result = fourier.ifft(fresult)

    fourier.iftshift(result)

    result.shiftscale(0, 1 / float(sqrt(sumV * sumR)))

    return [result, vf[1]]
Ejemplo n.º 2
0
def powerspectrum(volume):
    """
    Compute the power spectrum of a volume.

    The volume is Fourier transformed, expanded from the reduced to the full
    complex layout and ftshifted. Every voxel of the returned volume holds the
    real part of F * conj(F) scaled by 1 / (number of voxels).

    @param volume: input volume
    @type volume: L{pytom_volume.vol}
    @return: power spectrum of vol
    @rtype: L{pytom_volume.vol}

    @author: FF
    """
    from pytom.basic.fourier import fft, ftshift
    from pytom_volume import vol
    from pytom_volume import reducedToFull

    shifted = ftshift(reducedToFull(fft(volume)), inplace=False)
    dimX = shifted.sizeX()
    dimY = shifted.sizeY()
    dimZ = shifted.sizeZ()

    spectrum = vol(dimX, dimY, dimZ)
    norm = 1. / (dimX * dimY * dimZ)

    # voxel-wise |F|^2 written into a real-valued volume
    for i in range(dimX):
        for j in range(dimY):
            for k in range(dimZ):
                coeff = shifted.getV(i, j, k)
                power = coeff * coeff.conjugate() * norm
                spectrum.setV(float(power.real), i, j, k)

    return spectrum
Ejemplo n.º 3
0
def rotateWeighting(weighting, z1, z2, x, mask=None, isReducedComplex=None, returnReducedComplex=False, binarize=False):
    """
    rotateWeighting: Rotates a frequency weighting volume around the center. If the volume provided is reduced complex,
    it will be rescaled to full size, ftshifted, rotated, iftshifted and scaled back to reduced size.
    @param weighting: A weighting volume
    @type weighting: L{pytom_volume.vol}
    @param z1: Z1 rotation angle
    @type z1: float
    @param z2: Z2 rotation angle
    @type z2: float
    @param x: X rotation angle
    @type x: float
    @param mask: Rotation mask multiplied with the rotated weighting. If None, a mask with all = 1 will be generated.
        Such mask should be provided anyway.
    @type mask: L{pytom_volume.vol}
    @param isReducedComplex: Either set to True or False. If None (default), it is determined from the dimensions
        (sizeZ == sizeX // 2 + 1 means reduced complex).
    @type isReducedComplex: bool
    @param returnReducedComplex: Return as reduced complex? (Default is False)
    @type returnReducedComplex: bool
    @param binarize: binarize weighting (threshold at 0.5 via pytom_volume.limit)
    @type binarize: bool
    @return: The rotated weighting; converted with fullToReduced when returnReducedComplex is True,
        otherwise the full-size rotated volume.
    @rtype: L{pytom_volume.vol}
    """
    from pytom_volume import vol, limit, vol_comp
    from pytom_volume import rotate
    assert isinstance(weighting, (vol, vol_comp)), "rotateWeighting: input neither vol nor vol_comp"

    # honour an explicit True/False; only auto-detect when the caller left it as None
    if isReducedComplex is None:
        isReducedComplex = weighting.sizeX() // 2 + 1 == weighting.sizeZ()

    if isReducedComplex:
        # scale weighting to full size
        from pytom_fftplan import fftShift
        from pytom_volume import reducedToFull
        weighting = reducedToFull(weighting)
        fftShift(weighting, True)

    sizeX, sizeY, sizeZ = weighting.sizeX(), weighting.sizeY(), weighting.sizeZ()

    if mask is None:
        mask = vol(sizeX, sizeY, sizeZ)
        mask.setAll(1)

    weightingRotated = vol(sizeX, sizeY, sizeZ)

    rotate(weighting, weightingRotated, z1, z2, x)
    weightingRotated = weightingRotated * mask

    if returnReducedComplex:
        from pytom_fftplan import fftShift
        from pytom_volume import fullToReduced
        fftShift(weightingRotated, True)
        returnVolume = fullToReduced(weightingRotated)
    else:
        returnVolume = weightingRotated

    if binarize:
        limit(returnVolume, 0.5, 0, 0.5, 1, True, True)

    return returnVolume
Ejemplo n.º 4
0
def shift(volume, shiftX, shiftY, shiftZ, imethod='fourier', twice=False):
    """
    shift: Performs a shift on a volume
    @param volume: the volume
    @param shiftX: shift in x direction
    @param shiftY: shift in y direction
    @param shiftZ: shift in z direction
    @param imethod: Select interpolation method. Real space : linear, cubic, spline . Fourier space: fourier
    @param twice: Zero pad volume into a twice sized volume and perform calculation there.
        NOTE: currently ignored by this implementation.
    @return: The shifted volume.
    @raise ValueError: if imethod is not one of 'fourier', 'linear', 'cubic', 'spline'.
    @author: Yuxiang Chen and Thomas Hrabe
    """
    if imethod == 'fourier':
        from pytom_volume import vol_comp, reducedToFull, fullToReduced, shiftFourier
        from pytom.basic.fourier import fft, ifft

        fullFVolume = reducedToFull(fft(volume))

        destFourier = vol_comp(fullFVolume.sizeX(), fullFVolume.sizeY(), fullFVolume.sizeZ())

        shiftFourier(fullFVolume, destFourier, shiftX, shiftY, shiftZ)

        resFourier = fullToReduced(destFourier)

        # the division by the voxel count presumably compensates for an unnormalised ifft
        return ifft(resFourier) / volume.numelem()

    if imethod == 'linear':
        from pytom_volume import transform
    elif imethod == 'cubic':
        from pytom_volume import transformCubic as transform
    elif imethod == 'spline':
        from pytom_volume import transformSpline as transform
    else:
        raise ValueError("shift: unknown interpolation method '%s' "
                         "(expected 'fourier', 'linear', 'cubic' or 'spline')" % imethod)

    from pytom_volume import vol

    # integer centre, consistent with the python2 integer-division results
    centerX = volume.sizeX() // 2
    centerY = volume.sizeY() // 2
    centerZ = volume.sizeZ() // 2

    res = vol(volume.sizeX(), volume.sizeY(), volume.sizeZ())
    transform(volume, res, 0, 0, 0, centerX, centerY, centerZ, shiftX, shiftY, shiftZ, 0, 0, 0)

    return res
Ejemplo n.º 5
0
def _sum_and_remove_half_averages(avgs):
    """Sum (average, weighting) pairs and delete their files from disk.

    @param avgs: sequence of (average, weighting) pairs; each element must provide
        getVolume() and getFilename().
    @return: (summed average volume, summed weighting volume), or (None, None) if avgs is empty.
    """
    import os

    sumA, sumW = None, None
    for a, w in avgs:
        if sumA is None:
            sumA = a.getVolume()
            sumW = w.getVolume()
        else:
            sumA += a.getVolume()
            sumW += w.getVolume()
        # the partial averages are intermediate files; remove them once read
        os.remove(a.getFilename())
        os.remove(w.getFilename())
    return sumA, sumW


def _load_binned_mask(maskFile, binning, reference):
    """Read maskFile with the given binning; paste it centred into a zero volume of reference's size if sizes differ.

    @param maskFile: filename of the mask
    @param binning: binning factor applied while reading (x, y, z)
    @param reference: volume whose dimensions the mask must match
    @return: the mask volume
    """
    from pytom_volume import read, pasteCenter, vol

    maskBin = read(maskFile, 0, 0, 0, 0, 0, 0, 0, 0, 0, binning, binning, binning)
    if (reference.sizeX() != maskBin.sizeX() or reference.sizeY() != maskBin.sizeY()
            or reference.sizeZ() != maskBin.sizeZ()):
        mask = vol(reference.sizeX(), reference.sizeY(), reference.sizeZ())
        mask.setAll(0)
        pasteCenter(maskBin, mask)
        return mask
    return maskBin


def calculate_averages(pl, binning, mask, outdir='./'):
    """
    calculate averages for particle lists

    For every class (except the trash class '-1') the particles are averaged in
    parallel (via the module-level `mpi` object and `paverage`), split into two
    interleaved half sets, the resolution is determined by FSC of the two halves,
    and the full average is divided by the summed weighting in Fourier space.

    @param pl: particle list
    @type pl: L{pytom.basic.structures.ParticleList}
    @param binning: binning factor
    @type binning: C{int}
    @param mask: mask used for the FSC; either a filename (read with binning) or a volume/None
    @param outdir: output directory passed on to paverage
    @return: (res, freqs, wedgeSum) - dicts keyed by class label holding the weighted average,
        the resolution band from determineResolution and the full-size, fftShifted weighting sum
    @raise ParameterError: if the averages cannot be split into odd/even halves

    last change: Jan 18 2020: error message for too few processes, FF
    """
    from pytom_volume import complexDiv
    from pytom.basic.fourier import fft, ifft
    from pytom.basic.correlation import FSC, determineResolution
    from pytom_fftplan import fftShift
    from pytom_volume import reducedToFull

    pls = pl.copy().splitByClass()
    res = {}
    freqs = {}
    wedgeSum = {}

    for pp in pls:
        # ignore the -1 class, which is used for storing the trash class
        class_label = pp[0].getClass()
        if class_label == '-1':
            continue

        assert len(pp) > 3
        if len(pp) >= 4 * mpi.size:
            spp = mpi._split_seq(pp, mpi.size)
        else:  # not enough particle to do averaging on one node
            spp = [pp[:len(pp) // 2], pp[len(pp) // 2:]]

        n = len(spp)
        args = list(zip(spp, [True] * n, [binning] * n, [False] * n, [outdir] * n))
        avgs = mpi.parfor(paverage, args)

        even_a, even_w = _sum_and_remove_half_averages(avgs[1::2])
        odd_a, odd_w = _sum_and_remove_half_averages(avgs[::2])

        # determine the resolution
        # raise error message in case even_a is None - only one processor used
        if even_a is None:
            from pytom.basic.exceptions import ParameterError
            raise ParameterError(
                'cannot split odd / even. Likely you used only one processor - use: mpirun -np 2 (or higher!)?!'
            )

        # a mask filename is loaded once; later classes reuse the loaded volume
        if isinstance(mask, str) and mask:
            mask = _load_binned_mask(mask, binning, even_a)

        fsc = FSC(even_a, odd_a, int(even_a.sizeX() // 2), mask)
        band = determineResolution(fsc, 0.5)[1]

        aa = even_a + odd_a
        ww = even_w + odd_w
        fa = fft(aa)
        r = complexDiv(fa, ww)
        rr = ifft(r)
        rr.shiftscale(0.0, 1. / (rr.sizeX() * rr.sizeY() * rr.sizeZ()))

        res[class_label] = rr
        freqs[class_label] = band

        ww2 = reducedToFull(ww)
        fftShift(ww2, True)
        wedgeSum[class_label] = ww2
    print('done')
    return res, freqs, wedgeSum
Ejemplo n.º 6
0
    def start(self, job, verbose=False):
        """Run the FRM alignment job on this MPI rank.

        On rank 0 (the master) this method:
          1. randomly assigns class labels 0/1 to the particles unless the particle
             list already splits into exactly two classes,
          2. for up to ``job.max_iter`` iterations builds a new FRMJob with the current
             reference(s) and frequency, distributes it, and sums the per-worker
             results into even/odd pre-averages and wedges (by ``self.assignment``),
          3. writes the collected particle list to 'aligned_pl_iter<i>.xml' (relative
             to the current working directory, not ``self.destination``),
          4. creates the even and odd averages, applies symmetries, aligns the odd half
             to the even half with ``frm_align`` and transforms the odd sums accordingly,
          5. determines the resolution, writes the half-set averages, the combined
             average and a low-pass filtered average into ``self.destination``,
          6. uses the two half-set files as new references and raises the alignment
             frequency, stopping early once it reaches ``numberBands``,
        and finally calls ``self.end(verbose)``. All other ranks call ``self.run(verbose)``.

        @param job: the FRM job (particle list, reference, mask, frequency, ...)
        @param verbose: print progress messages
        """
        if self.mpi_id == 0:
            from pytom.basic.structures import ParticleList, Reference
            from pytom.basic.resolution import bandToAngstrom
            from pytom.basic.filter import lowpassFilter
            from math import ceil

            # randomly split the particle list into 2 half sets
            if len(job.particleList.splitByClass()) != 2:
                import numpy as np
                n = len(job.particleList)
                labels = np.random.randint(2, size=(n, ))
                print(self.node_name + ': Number of 1st half set:',
                      n - np.sum(labels), 'Number of 2nd half set:',
                      np.sum(labels))
                for i in range(n):
                    p = job.particleList[i]
                    p.setClass(labels[i])

            self.destination = job.destination
            new_reference = job.reference
            old_freq = job.freq
            new_freq = job.freq
            # main node
            for i in range(job.max_iter):
                if verbose:
                    print(self.node_name + ': starting iteration %d ...' % i)

                # construct a new job by updating the reference and the frequency
                new_job = FRMJob(job.particleList, new_reference, job.mask,
                                 job.peak_offset, job.sampleInformation,
                                 job.bw_range, new_freq, job.destination,
                                 job.max_iter - i, job.r_score, job.weighting)

                # distribute it
                self.distribute_job(new_job, verbose)

                # get the result back
                all_even_pre = None  # the 1st set
                all_even_wedge = None
                all_odd_pre = None  # the 2nd set
                all_odd_wedge = None
                pl = ParticleList()
                for j in range(self.num_workers):
                    result = self.get_result()
                    pl += result.pl
                    pre, wedge = self.retrieve_res_vols(result.name)

                    # NOTE(review): the `if all_even_pre:` / `if all_odd_pre:` checks rely on the
                    # truthiness of the volume object; `is not None` would be more explicit.
                    if self.assignment[result.worker_id] == 0:
                        if all_even_pre:
                            all_even_pre += pre
                            all_even_wedge += wedge
                        else:
                            all_even_pre = pre
                            all_even_wedge = wedge
                    else:
                        if all_odd_pre:
                            all_odd_pre += pre
                            all_odd_wedge += wedge
                        else:
                            all_odd_pre = pre
                            all_odd_wedge = wedge

                # write the new particle list to the disk
                # (relative to the current working directory, not self.destination)
                pl.toXMLFile('aligned_pl_iter' + str(i) + '.xml')

                # create the averages separately
                if verbose:
                    print(self.node_name + ': determining the resolution ...')
                even = self.create_average(all_even_pre, all_even_wedge)
                odd = self.create_average(all_odd_pre, all_odd_wedge)

                # apply symmetries if any
                even = job.symmetries.applyToParticle(even)
                odd = job.symmetries.applyToParticle(odd)

                # determine the transformation between even and odd
                # here we assume the wedge from both sets are fully sampled
                from sh_alignment.frm import frm_align
                pos, angle, score = frm_align(odd, None, even, None,
                                              job.bw_range, new_freq,
                                              job.peak_offset)
                # NOTE(review): this message lacks the ': ' separator used by the other messages.
                print(self.node_name +
                      'Transform of even set to match the odd set - shift: ' +
                      str(pos) + ' rotation: ' + str(angle))

                # transform the odd set accordingly
                # (the odd pre-average, its wedge and the odd average all get the same
                # inverse rotation; the wedge is rotated through a pytom_freqweight.weight)
                from pytom_volume import vol, transformSpline
                from pytom.basic.fourier import ftshift
                from pytom_volume import reducedToFull
                from pytom_freqweight import weight
                transformed_odd_pre = vol(odd.sizeX(), odd.sizeY(),
                                          odd.sizeZ())
                full_all_odd_wedge = reducedToFull(all_odd_wedge)
                ftshift(full_all_odd_wedge)
                odd_weight = weight(
                    full_all_odd_wedge)  # the funny part of pytom
                transformed_odd = vol(odd.sizeX(), odd.sizeY(), odd.sizeZ())

                # NOTE(review): odd.sizeX() / 2 is true division in Python 3, so odd box sizes get a
                # half-voxel centre/shift (Python 2 used integer division) - confirm against frm_align.
                transformSpline(all_odd_pre, transformed_odd_pre, -angle[1],
                                -angle[0], -angle[2],
                                odd.sizeX() / 2,
                                odd.sizeY() / 2,
                                odd.sizeZ() / 2, -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)
                odd_weight.rotate(-angle[1], -angle[0], -angle[2])
                transformed_odd_wedge = odd_weight.getWeightVolume(True)
                transformSpline(odd, transformed_odd, -angle[1], -angle[0],
                                -angle[2],
                                odd.sizeX() / 2,
                                odd.sizeY() / 2,
                                odd.sizeZ() / 2, -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)

                all_odd_pre = transformed_odd_pre
                all_odd_wedge = transformed_odd_wedge
                odd = transformed_odd

                # determine resolution
                resNyquist, resolutionBand, numberBands = self.determine_resolution(
                    even, odd, job.fsc_criterion, None, job.mask, verbose)

                # write the half set to the disk
                # NOTE(review): `os` is not imported in this method; assumed to be imported at module level.
                even.write(
                    os.path.join(self.destination,
                                 'fsc_' + str(i) + '_even.em'))
                odd.write(
                    os.path.join(self.destination,
                                 'fsc_' + str(i) + '_odd.em'))

                current_resolution = bandToAngstrom(
                    resolutionBand, job.sampleInformation.getPixelSize(),
                    numberBands, 1)
                if verbose:
                    print(
                        self.node_name + ': current resolution ' +
                        str(current_resolution), resNyquist)

                # create new average
                # (the transformed odd sums are added into the even sums in place)
                all_even_pre += all_odd_pre
                all_even_wedge += all_odd_wedge
                average = self.create_average(all_even_pre, all_even_wedge)

                # apply symmetries
                average = job.symmetries.applyToParticle(average)

                # write the (unfiltered) combined average; the low-pass filtered
                # copy is written further below
                average_name = os.path.join(self.destination,
                                            'average_iter' + str(i) + '.em')
                average.write(average_name)

                # update the references
                # (the two half-set averages written above become the next references)
                new_reference = [
                    Reference(
                        os.path.join(self.destination,
                                     'fsc_' + str(i) + '_even.em')),
                    Reference(
                        os.path.join(self.destination,
                                     'fsc_' + str(i) + '_odd.em'))
                ]

                # low pass filter the reference and write it to the disk
                filtered = lowpassFilter(average, ceil(resolutionBand),
                                         ceil(resolutionBand) / 10)
                filtered_ref_name = os.path.join(
                    self.destination, 'average_iter' + str(i) + '_res' +
                    str(current_resolution) + '.em')
                filtered[0].write(filtered_ref_name)

                # if the position/orientation is not improved, break it

                # change the frequency to a higher value
                new_freq = int(ceil(resolutionBand)) + 1
                if new_freq <= old_freq:
                    if job.adaptive_res is not False:  # two different strategies
                        # NOTE(review): unlike the increment-by-1 branch below, old_freq is not
                        # updated here - confirm this is intended.
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Include additional %f percent frequency to be aligned!'
                            % job.adaptive_res)
                        new_freq = int((1 + job.adaptive_res) * old_freq)
                    else:  # always increase by 1
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Increase the frequency to be aligned by 1!'
                        )
                        new_freq = old_freq + 1
                        old_freq = new_freq
                else:
                    old_freq = new_freq
                if new_freq >= numberBands:
                    print(self.node_name +
                          ': New frequency too high. Terminate!')
                    break

                if verbose:
                    print(self.node_name + ': change the frequency to ' +
                          str(new_freq))

            # send end signal to other nodes and terminate itself
            self.end(verbose)
        else:
            # other nodes
            self.run(verbose)
Ejemplo n.º 7
0
    import numpy as np

    # command line: num_angles size size2 csize [start end]
    num_angles, size = map(int, sys.argv[1:3])
    size2 = int(sys.argv[3])
    csize = int(sys.argv[4])

    # optional start/end (argv[5], argv[6]); missing or non-integer values fall
    # back to the full range [0, csize). Only ValueError is caught so that other
    # errors are not silently swallowed.
    try:
        start, end = map(int, sys.argv[5:7])
    except ValueError:
        start, end = 0, csize

    wedgeAngle = 30
    wedgeFilter = weight(wedgeAngle, 0, end - start, size, size)
    wedgeVolume = wedgeFilter.getWeightVolume(True)

    filterVolume = pytom_volume.reducedToFull(wedgeVolume)
    wedgeV = vol2npy(filterVolume).copy()

    from pytom.tompy.io import read
    import mrcfile

    # NDARRAYS
    # copy the data out and close the MRC file handle right away
    with mrcfile.open('tomo.mrc', permissive=True) as mrc:
        voluNDA = mrc.data.copy()
    tempNDA = read('template.em')
    maskNDA = read("mask.em")

    sox, soy, soz = tempNDA.shape
    spx, spy, spz = voluNDA.shape

    # GPUARRAYS
    voluGPU = gu.to_gpu(voluNDA.astype(np.float32))
Ejemplo n.º 8
0
def frm_correlate(vf,
                  wf,
                  vg,
                  wg,
                  b,
                  max_freq,
                  weights=None,
                  ps=False,
                  denominator1=None,
                  denominator2=None,
                  return_score=True):
    """Calculate the correlation of two volumes as a function of Euler angle.

    Parameters
    ----------
    vf: Volume Nr. 1
        pytom_volume.vol

    wf: Mask of vf in Fourier space.
        pytom.basic.structures.Wedge

    vg: Volume Nr. 2 / Reference
        pytom_volume.vol

    wg: Mask of vg in Fourier space.
        pytom.basic.structures.Wedge

    b: Bandwidth range of spherical harmonics.
       None -> [4, 64]
       List -> [b_min, b_max]
       Integer -> [b, b]

    max_freq: Maximal frequency involved in calculation.
              Integer.

    weights: Obsolete. Per-shell weights (length >= max_freq); all ones if not given.

    ps: Calculation based on only the power spectrum of two volumes or not.
        Boolean. Default is False.

    denominator1: If the denominator1 is provided or not. If yes, do not have to re-calculate it again.
                  This field is used out of computation efficiency consideration.
                  Default is None, not provided.

    denominator2: If the denominator2 is provided or not. If yes, do not have to re-calculate it again.
                  This field is used out of computation efficiency consideration.
                  Default is None, not provided.
                  Note: a provided denominator1 is only reused when denominator2 is provided too;
                  otherwise both are recalculated.

    return_score: Return the correlation score or return the intermediate result (numerator, denominator1, denominator2).
                  Boolean, default is True.

    Returns
    -------
    If return_score is set to True, return the correlation function numerator / sqrt(denominator1 * denominator2);
    otherwise return the intermediate result (numerator, denominator1, denominator2).
    """
    if not weights:  # weights, not used yet
        weights = [1] * max_freq

    from pytom.basic.fourier import fft, ftshift, iftshift
    from pytom_volume import reducedToFull, real, imag
    # renamed on import so the builtin abs() is not shadowed
    from pytom_volume import abs as vol_abs
    from vol2sf import vol2sf

    # IMPORTANT!!! Should firstly do the IFFTSHIFT on the volume data (NOT FFTSHIFT since for odd-sized data it matters!),
    # and then followed by the FFT.
    vf = ftshift(reducedToFull(fft(iftshift(vf, inplace=False))),
                 inplace=False)
    vg = ftshift(reducedToFull(fft(iftshift(vg, inplace=False))),
                 inplace=False)

    if ps:  # power spectrum only
        ff = real(vol_abs(vf))
        gg = real(vol_abs(vg))
    else:  # use spline interpolation on the real/imaginary parts. Can be done better, but now it suffices.
        vfr = real(vf)
        vfi = imag(vf)
        vgr = real(vg)
        vgi = imag(vg)

    # to_calculate: 0 -> compute both denominators,
    #               1 -> both denominators provided, compute neither,
    #               2 -> only denominator2 provided, compute denominator1
    numerator = None
    if denominator1 is not None and denominator2 is not None:
        to_calculate = 1
    elif denominator1 is None and denominator2 is not None:
        to_calculate = 2
    else:
        to_calculate = 0

    _last_bw = 0
    # might be a better idea to start from 2 due to the bad interpolation around 0 frequency!
    # this can be better solved by NFFT!
    for r in range(1, max_freq + 1):
        # calculate the appropriate bw
        bw = get_adaptive_bw(r, b)

        # construct the wedge masks for this shell (done by the Wedge objects in Pytom)
        mf = wf.toSphericalFunc(bw, r)
        mg = wg.toSphericalFunc(bw, r)

        if ps:
            corr1, corr2, corr3 = sph_correlate_ps(vol2sf(ff, r, bw), mf,
                                                   vol2sf(gg, r, bw), mg,
                                                   to_calculate)
        else:
            corr1, corr2, corr3 = sph_correlate_fourier(
                vol2sf(vfr, r, bw), vol2sf(vfi, r, bw), mf, vol2sf(vgr, r, bw),
                vol2sf(vgi, r, bw), mg, to_calculate)

        if _last_bw != bw:  # size is different, have to do enlarge
            if numerator is None:
                numerator = np.zeros((2 * bw, 2 * bw, 2 * bw), dtype='double')
                if to_calculate == 2:
                    denominator1 = np.zeros((2 * bw, 2 * bw, 2 * bw),
                                            dtype='double')
                elif to_calculate == 0:
                    denominator1 = np.zeros((2 * bw, 2 * bw, 2 * bw),
                                            dtype='double')
                    denominator2 = np.zeros((2 * bw, 2 * bw, 2 * bw),
                                            dtype='double')
            else:
                numerator = enlarge2(numerator)
                if to_calculate == 2:
                    denominator1 = enlarge2(denominator1)
                elif to_calculate == 0:
                    denominator1 = enlarge2(denominator1)
                    denominator2 = enlarge2(denominator2)

        # accumulate the shell contribution weighted by r^2
        shell_weight = (r**2) * weights[r - 1]
        numerator += corr1 * shell_weight
        if to_calculate == 2:
            denominator1 += corr2 * shell_weight
        elif to_calculate == 0:
            denominator1 += corr2 * shell_weight
            denominator2 += corr3 * shell_weight

        _last_bw = bw

    if return_score:
        res = numerator / (denominator1 * denominator2)**0.5
        return res
    else:
        return (numerator, denominator1, denominator2)
Ejemplo n.º 9
0
def average(particleList,
            averageName,
            showProgressBar=False,
            verbose=False,
            createInfoVolumes=False,
            weighting=False,
            norm=False,
            gpuId=None,
            analytWedge=False):
    """
    average : Creates new average from a particleList
    @param particleList: The particles. Particles whose file does not exist are skipped.
    @param averageName: Filename of new average. Auxiliary files are named by replacing the last
        3 characters (e.g. '.em') with '-PreWedge.em', '-WedgeSumUnscaled.em', ...
    @param showProgressBar: show a progress bar while averaging
    @param verbose: currently unused
    @param createInfoVolumes: Create info data (inverted wedge sum, sign-inverted density) too? False by default.
    @param weighting: apply weighting to each particle according to its correlation score
        (weight = exp(score - 1)); ignored if all scores sum to (almost) zero
    @param norm: apply normalization for each particle
    @param gpuId: currently unused
    @param analytWedge: use the analytically rotated wedge (returnWedgeVolume with rotation) instead of
        rotating the wedge volume with rotateWeighting. False by default.
    @return: A new Reference object
    @rtype: L{pytom.basic.structures.Reference}
    @raise RuntimeError: if the particle list is empty or none of its particle files exist
    @author: Thomas Hrabe
    @change: limit for wedgeSum set to 5% of the number of particles to avoid division by small numbers - FF
    """
    from pytom_volume import read, vol, reducedToFull
    from pytom.basic.filter import lowpassFilter, rotateWeighting
    from pytom_volume import transformSpline as transform
    from pytom.basic.fourier import convolute
    from pytom.basic.structures import Reference
    from pytom.basic.normalise import mean0std1
    from pytom.tools.ProgressBar import FixedProgBar
    from math import exp
    import os
    # renamed on import so the builtin filter() is not shadowed
    from pytom.basic.filter import filter as filterVolume

    if len(particleList) == 0:
        raise RuntimeError('The particle list is empty. Aborting!')

    if showProgressBar:
        progressBar = FixedProgBar(0, len(particleList), 'Particles averaged ')
        progressBar.update(0)
        numberAlignedParticles = 0

    # pre-check that scores != 0
    if weighting:
        wsum = sum(p.getScore().getValue() for p in particleList)
        if wsum < 0.00001:
            weighting = False
            print("Warning: all scores have been zero - weighting not applied")

    result = None
    wedgeSum = None
    newParticle = None

    for particleObject in particleList:
        if not os.path.exists(particleObject.getFilename()):
            continue
        particle = read(particleObject.getFilename())
        if norm:  # normalize the particle
            mean0std1(particle)  # happens in place

        wedgeInfo = particleObject.getWedge()
        rotation = particleObject.getRotation()
        rotinvert = rotation.invert()

        if result is None:
            # first readable particle: allocate the accumulators from its dimensions
            sizeX = particle.sizeX()
            sizeY = particle.sizeY()
            sizeZ = particle.sizeZ()

            newParticle = vol(sizeX, sizeY, sizeZ)

            centerX = sizeX // 2
            centerY = sizeY // 2
            centerZ = sizeZ // 2

            result = vol(sizeX, sizeY, sizeZ)
            result.setAll(0.0)

            # only the (reduced) shape matters here, the content is zeroed below
            wedgeSum = wedgeInfo.returnWedgeVolume(sizeX, sizeY, sizeZ)
            assert (wedgeSum.sizeX() == sizeX and wedgeSum.sizeY() == sizeY
                    and wedgeSum.sizeZ() == sizeZ // 2 + 1), \
                "wedge initialization result in wrong dims :("
            wedgeSum.setAll(0)

        # apply the particle's own wedge to itself
        wedgeFilter = wedgeInfo.returnWedgeFilter(particle.sizeX(),
                                                  particle.sizeY(),
                                                  particle.sizeZ())
        particle = list(filterVolume(particle, wedgeFilter))[0]

        ### create spectral wedge weighting
        if analytWedge:
            # analytical (buggy) version
            wedge = wedgeInfo.returnWedgeVolume(sizeX, sizeY, sizeZ, False,
                                                rotinvert)
        else:
            # FF: interpolation bugfix - rotate the unrotated wedge volume
            wedge = rotateWeighting(weighting=wedgeInfo.returnWedgeVolume(
                sizeX, sizeY, sizeZ, False),
                                    z1=rotinvert[0],
                                    z2=rotinvert[1],
                                    x=rotinvert[2],
                                    mask=None,
                                    isReducedComplex=True,
                                    returnReducedComplex=True)

        ### shift and rotate particle
        shiftV = particleObject.getShift()
        newParticle.setAll(0)

        transform(particle, newParticle, -rotation[1], -rotation[0],
                  -rotation[2], centerX, centerY, centerZ, -shiftV[0],
                  -shiftV[1], -shiftV[2], 0, 0, 0)

        if weighting:
            # score 1 -> weight 1; lower scores decay exponentially
            weight = exp(-1. * (1. - particleObject.getScore().getValue()))
            result = result + newParticle * weight
            wedgeSum = wedgeSum + wedge * weight
        else:
            result = result + newParticle
            wedgeSum = wedgeSum + wedge

        if showProgressBar:
            numberAlignedParticles = numberAlignedParticles + 1
            progressBar.update(numberAlignedParticles)

    if result is None:
        raise RuntimeError('None of the particle files in the particle list exist. Aborting!')

    ###apply spectral weighting to sum
    result = lowpassFilter(result, sizeX / 2 - 1, 0.)[0]

    # strip the 3-character suffix (e.g. '.em') to name the auxiliary files
    baseName = averageName[:len(averageName) - 3]

    result.write(baseName + '-PreWedge.em')

    wedgeSum.write(baseName + '-WedgeSumUnscaled.em')
    # invert_WedgeSum is defined elsewhere in the module; its return value is unused,
    # so it presumably modifies wedgeSum in place
    invert_WedgeSum(invol=wedgeSum,
                    r_max=sizeX / 2 - 2.,
                    lowlimit=.05 * len(particleList),
                    lowval=.05 * len(particleList))

    if createInfoVolumes:
        w1 = reducedToFull(wedgeSum)
        w1.write(baseName + '-WedgeSumInverted.em')

    result = convolute(v=result, k=wedgeSum, kernel_in_fourier=True)

    result.write(averageName)

    if createInfoVolumes:
        resultINV = result * -1
        # write sign inverted result to disk (good for chimera viewing ... )
        resultINV.write(baseName + '-INV.em')
    newReference = Reference(averageName, particleList)

    return newReference
Ejemplo n.º 10
0
b = 16
r = 4

m = np.ones(4 * b**2)
# w = m
w = create_wedge_sf(-60, 60, b)

dist = []
dist2 = []
dist3 = []
# range works on both Python 2 and 3 (xrange does not exist in Python 3)
for i in range(100):
    # random Euler angles in degrees for this trial
    phi = np.random.randint(360)
    psi = np.random.randint(360)
    the = np.random.randint(180)

    fv1 = ftshift(reducedToFull(fft(iftshift(v, inplace=False))), inplace=False)

    # 1. rotate in real space and use the frm_fourier_corr to find the angle
    # This is the least accurate way, since the interpolation happens in real space.
    # rotateSpline(v, v2, phi, psi, the)
    # fv2 = ftshift(reducedToFull(fft(iftshift(v2))))
    # res = frm_fourier_corr(vol2sf(real(fv2), r, b), vol2sf(imag(fv2), r, b), vol2sf(real(fv1), r, b), vol2sf(imag(fv1), r, b))

    # 2. rotate real and imag parts separately and feed into the frm_fourier_corr
    fr = real(fv1)
    fi = imag(fv1)

    # rotateSpline(v, v2, phi, psi, the)
    # fv2 = ftshift(reducedToFull(fft(iftshift(v2))))
    # fr2 = real(fv2)
    # fi2 = imag(fv2)
Ejemplo n.º 11
0
def weightedXCF(volume,reference,numberOfBands,wedgeAngle=-1):
    """
    weightedXCF: Determines the weighted correlation function for volume and reference
    (band weighted; notation according to the Stewart/Grigorieff paper).
    @param volume: A volume
    @type volume: L{pytom_volume.vol}
    @param reference: A reference of the same size as volume
    @type reference: L{pytom_volume.vol}
    @param numberOfBands: Number of bands the x-dimension is split into
    @param wedgeAngle: An optional wedge angle. A negative value (default) means no
                       wedge: an all-ones reduced weight volume is used instead.
    @return: The weighted correlation function
    @rtype: L{pytom_volume.vol}
    @author: Thomas Hrabe
    @todo: does not work yet -> test is disabled
    """
    from pytom.basic.correlation import bandCF
    import pytom_volume
    from math import sqrt
    import pytom_freqweight

    sizeX, sizeY, sizeZ = volume.sizeX(), volume.sizeY(), volume.sizeZ()

    # accumulator for the weighted correlation function
    result = pytom_volume.vol(sizeX, sizeY, sizeZ)
    result.setAll(0)
    # scratch volume reused for the squared band correlation
    cc2 = pytom_volume.vol(sizeX, sizeY, sizeZ)
    cc2.setAll(0)

    if wedgeAngle >= 0:
        wedgeFilter = pytom_freqweight.weight(wedgeAngle, 0, sizeX, sizeY, sizeZ)
        wedgeVolume = wedgeFilter.getWeightVolume(True)
    else:
        # no wedge: reduced (sizeZ//2 + 1) weight volume filled with ones
        wedgeVolume = pytom_volume.vol(sizeX, sizeY, sizeZ // 2 + 1)
        wedgeVolume.setAll(1.0)

    w = sqrt(1/float(sizeX*sizeY*sizeZ))

    numberVoxels = 0

    for i in range(numberOfBands):
        # notation according to the Stewart/Grigorieff paper
        # NOTE(review): true division gives float band limits under Python 3
        # (Python 2 gave ints) — confirm bandpassFilter accepts floats.
        band = [i*sizeX/numberOfBands, (i+1)*sizeX/numberOfBands]

        cc, bandFilter = bandCF(volume, reference, band)[:2]

        # band filter restricted by the wedge, expanded to the full Fourier volume
        bandVolume = bandFilter.getWeightVolume(True)
        filterVolumeReduced = bandVolume * wedgeVolume
        filterVolume = pytom_volume.reducedToFull(filterVolumeReduced)
        # number of voxels != 0 that contribute to this band
        N = pytom_volume.numberSetVoxels(filterVolume)
        numberVoxels = numberVoxels + N

        # cc2 = cc^2, taken before cc is shifted below
        cc2.copyVolume(cc)
        pytom_volume.power(cc2, 2)

        # ccdiv = (cc^2 / (cc + w))^3 — shiftscale(w, 1) modifies cc in place
        cc.shiftscale(w, 1)
        ccdiv = cc2/(cc)
        pytom_volume.power(ccdiv, 3)

        # abs(ccdiv) as suggested by Grigorieff is intentionally not applied;
        # weight the band by its number of voxels
        ccdiv.shiftscale(0, N)

        result = result + ccdiv

    # normalise by the total number of contributing voxels
    result.shiftscale(0, 1/float(numberVoxels))

    return result
Ejemplo n.º 12
0
def weightedXCC(volume,reference,numberOfBands,wedgeAngle=-1):
    """
    weightedXCC: Determines the band weighted correlation coefficient for a volume and reference. Notation according to the Stewart/Grigorieff paper
    @param volume: A volume
    @type volume: L{pytom_volume.vol}
    @param reference: A reference of same size as volume
    @type reference: L{pytom_volume.vol}
    @param numberOfBands: Number of bands the half x-dimension (sizeX // 2) is split into
    @param wedgeAngle: An optional wedge angle
    @return: The weighted correlation coefficient
    @rtype: float
    @raise ValueError: if numberOfBands is larger than sizeX // 2 (band width would be 0)
    @author: Thomas Hrabe
    """
    import pytom_volume
    from math import sqrt
    from pytom.basic.correlation import bandCC
    from pytom.basic.structures import WedgeInfo

    result = 0
    numberVoxels = 0

    wedge = WedgeInfo(wedgeAngle)
    wedgeVolume = wedge.returnWedgeVolume(volume.sizeX(), volume.sizeY(), volume.sizeZ())

    # Floor division keeps the range() arguments integral under Python 3.
    halfSizeX = volume.sizeX() // 2
    increment = halfSizeX // numberOfBands
    if increment == 0:
        raise ValueError('weightedXCC: numberOfBands (%s) must not exceed half the volume size (%d)'
                         % (numberOfBands, halfSizeX))

    for lower in range(0, halfSizeX, increment):
        # fresh list per band so bandCC never sees a list mutated afterwards
        band = [lower, lower + increment]

        r = bandCC(volume, reference, band)
        cc = r[0]
        bandFilter = r[1]

        # band filter restricted by the wedge, expanded to the full Fourier volume
        bandVolume = bandFilter.getWeightVolume(True)
        filterVolumeReduced = bandVolume * wedgeVolume
        filterVolume = pytom_volume.reducedToFull(filterVolumeReduced)

        # number of voxels != 0 that contribute to this band
        N = pytom_volume.numberSetVoxels(filterVolume)

        w = sqrt(1/float(N))

        # add to number of total voxels
        numberVoxels = numberVoxels + N

        cc2 = cc*cc
        if cc <= 0.0:
            cc = cc2
        else:
            cc = cc2/(cc+w)

        cc = cc*cc*cc  # no abs

        # weight the band contribution by its number of voxels
        result = result + cc*N

    return result*(1/float(numberVoxels))
Ejemplo n.º 13
0
    def start(self, job, verbose=False):
        """Run the iterative two-half-set alignment for job (MPI entry point).

        On the node with mpi_id 0 this randomly splits every particle-list
        pair into two half sets (class labels 0/1). It then loops for up to
        job.max_iter iterations:
          - distribute a MultiDefocusJob to the workers and collect their
            even/odd partial sums and wedge sums;
          - build both half-set averages and align the odd half to the even
            half with frm_align;
          - determine the resolution from the two halves;
          - write the particle lists, half sets and averages to disk;
          - update the reference and the frequency used in the next round.
        The loop stops early once the new frequency reaches numberBands.
        Afterwards self.end(verbose) is called. Any other node runs the
        worker loop via self.run(verbose).

        @param job: the alignment job; job.particleList is treated as a
                    ParticleListSet (see comment in the loop) — TODO confirm type
        @param verbose: print progress messages
        """
        if self.mpi_id == 0:
            from pytom.basic.structures import ParticleList, Reference
            from pytom.basic.resolution import bandToAngstrom
            from pytom.basic.filter import lowpassFilter
            from math import ceil
            from pytom.basic.fourier import convolute
            from pytom_volume import vol, power, read

            # randomly split the particle list into 2 half sets
            import numpy as np
            num_pairs = len(job.particleList.pairs)
            for i in range(num_pairs):
                # randomize the class labels to indicate the two half sets
                pl = job.particleList.pairs[i].get_phase_flip_pl()
                n = len(pl)
                labels = np.random.randint(2, size=(n, ))
                print(self.node_name + ': Number of 1st half set:',
                      n - np.sum(labels), 'Number of 2nd half set:',
                      np.sum(labels))
                for j in range(n):
                    p = pl[j]
                    p.setClass(labels[j])

            new_reference = job.reference
            # old_freq: last frequency that was accepted; new_freq: frequency for the next round
            old_freq = job.freq
            new_freq = job.freq
            # main node
            for i in range(job.max_iter):
                if verbose:
                    print(self.node_name + ': starting iteration %d ...' % i)

                # construct a new job by updating the reference and the frequency
                # here the job.particleList is actually ParticleListSet
                new_job = MultiDefocusJob(job.particleList, new_reference,
                                          job.mask, job.peak_offset,
                                          job.sampleInformation, job.bw_range,
                                          new_freq, job.destination,
                                          job.max_iter - i, job.r_score,
                                          job.weighting, job.bfactor)

                # distribute it
                num_all_particles = self.distribute_job(new_job, verbose)

                # calculate the denominator (Wiener filter): SNR-weighted sum of the squared CTF volumes of all pairs
                sum_ctf_squared = None
                for pair in job.particleList.pairs:
                    if sum_ctf_squared is None:
                        sum_ctf_squared = pair.get_ctf_sqr_vol() * pair.snr
                    else:
                        sum_ctf_squared += pair.get_ctf_sqr_vol() * pair.snr

                # get the result back
                all_even_pre = None
                all_even_wedge = None
                all_odd_pre = None
                all_odd_wedge = None
                # one aligned ParticleList per pair, filled from the worker results
                pls = []
                for j in range(len(job.particleList.pairs)):
                    pls.append(ParticleList())

                for j in range(self.num_workers):
                    result = self.get_result()

                    # which pair this worker processed
                    pair_id = self.assignment[result.worker_id]
                    pair = job.particleList.pairs[pair_id]

                    pl = pls[pair_id]
                    # NOTE(review): pls[pair_id] is only updated if ParticleList's
                    # += / __add__ mutates the list in place — confirm.
                    pl += result.pl
                    even_pre, even_wedge, odd_pre, odd_wedge = self.retrieve_res_vols(
                        result.name)

                    # NOTE(review): relies on the truthiness of a vol; `is not None`
                    # would state the intent (first result vs. later ones) explicitly.
                    if all_even_pre:
                        # the pre-averages are SNR-weighted; wedge sums are not
                        all_even_pre += even_pre * pair.snr
                        all_even_wedge += even_wedge
                        all_odd_pre += odd_pre * pair.snr
                        all_odd_wedge += odd_wedge
                    else:
                        all_even_pre = even_pre * pair.snr
                        all_even_wedge = even_wedge
                        all_odd_pre = odd_pre * pair.snr
                        all_odd_wedge = odd_wedge

                # write the new particle list to the disk
                for j in range(len(job.particleList.pairs)):
                    pls[j].toXMLFile('aligned_pl' + str(j) + '_iter' + str(i) +
                                     '.xml')

                # correct for the number of particles in wiener filter
                sum_ctf_squared = sum_ctf_squared / num_all_particles
                #                all_even_pre = all_even_pre/(num_all_particles/2)
                #                all_odd_pre = all_odd_pre/(num_all_particles/2)

                # bfactor: job.bfactor is read as a kernel volume file path here
                # (the string 'None' is treated as "no bfactor")
                if job.bfactor and job.bfactor != 'None':
                    #                    bfactor_kernel = create_bfactor_vol(sum_ctf_squared.sizeX(), job.sampleInformation.getPixelSize(), job.bfactor)
                    bfactor_kernel = read(job.bfactor)
                    bfactor_kernel_sqr = vol(bfactor_kernel)
                    power(bfactor_kernel_sqr, 2)
                    # apply the kernel to both half sets (kernel given in Fourier space)
                    all_even_pre = convolute(all_even_pre, bfactor_kernel,
                                             True)
                    all_odd_pre = convolute(all_odd_pre, bfactor_kernel, True)
                    # the denominator is scaled with the squared kernel
                    sum_ctf_squared = sum_ctf_squared * bfactor_kernel_sqr

                # create averages of two sets
                if verbose:
                    print(self.node_name + ': determining the resolution ...')
                even = self.create_average(
                    all_even_pre, sum_ctf_squared, all_even_wedge
                )  # assume that the CTF sum is the same for the even and odd
                odd = self.create_average(all_odd_pre, sum_ctf_squared,
                                          all_odd_wedge)

                # determine the transformation between even and odd
                # here we assume the wedge from both sets are fully sampled
                from sh_alignment.frm import frm_align
                pos, angle, score = frm_align(odd, None, even, None,
                                              job.bw_range, new_freq,
                                              job.peak_offset)
                print(
                    self.node_name +
                    ': transform of even set to match the odd set - shift: ' +
                    str(pos) + ' rotation: ' + str(angle))

                # transform the odd set accordingly (inverse of the frm_align result,
                # applied to the odd pre-average, its wedge sum and the odd average)
                from pytom_volume import vol, transformSpline
                from pytom.basic.fourier import ftshift
                from pytom_volume import reducedToFull
                from pytom_freqweight import weight
                transformed_odd_pre = vol(odd.sizeX(), odd.sizeY(),
                                          odd.sizeZ())
                full_all_odd_wedge = reducedToFull(all_odd_wedge)
                ftshift(full_all_odd_wedge)
                odd_weight = weight(
                    full_all_odd_wedge)  # the funny part of pytom
                transformed_odd = vol(odd.sizeX(), odd.sizeY(), odd.sizeZ())

                transformSpline(all_odd_pre, transformed_odd_pre, -angle[1],
                                -angle[0], -angle[2], int(odd.sizeX() / 2),
                                int(odd.sizeY() / 2), int(odd.sizeZ() / 2),
                                -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)
                odd_weight.rotate(-angle[1], -angle[0], -angle[2])
                transformed_odd_wedge = odd_weight.getWeightVolume(True)
                transformSpline(odd, transformed_odd, -angle[1], -angle[0],
                                -angle[2], int(odd.sizeX() / 2),
                                int(odd.sizeY() / 2), int(odd.sizeZ() / 2),
                                -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)

                all_odd_pre = transformed_odd_pre
                all_odd_wedge = transformed_odd_wedge
                odd = transformed_odd

                # apply symmetries before determine resolution
                # with gold standard you should be careful about applying the symmetry!
                even = job.symmetries.applyToParticle(even)
                odd = job.symmetries.applyToParticle(odd)
                resNyquist, resolutionBand, numberBands = self.determine_resolution(
                    even, odd, job.fsc_criterion, None, job.mask, verbose)

                # write the half set to the disk
                even.write('fsc_' + str(i) + '_even.em')
                odd.write('fsc_' + str(i) + '_odd.em')

                current_resolution = bandToAngstrom(
                    resolutionBand, job.sampleInformation.getPixelSize(),
                    numberBands, 1)
                if verbose:
                    print(
                        self.node_name + ': current resolution ' +
                        str(current_resolution), resNyquist)

                # create new average (even + transformed odd)
                all_even_pre += all_odd_pre
                all_even_wedge += all_odd_wedge
                #                all_even_pre = all_even_pre/2 # correct for the number of particles in wiener filter
                average = self.create_average(all_even_pre, sum_ctf_squared,
                                              all_even_wedge)

                # apply symmetries
                average = job.symmetries.applyToParticle(average)

                # filter average to resolution and update the new reference
                average_name = 'average_iter' + str(i) + '.em'
                average.write(average_name)

                # update the references: the next round aligns against the two half sets
                new_reference = [
                    Reference('fsc_' + str(i) + '_even.em'),
                    Reference('fsc_' + str(i) + '_odd.em')
                ]

                # low pass filter the reference and write it to the disk
                filtered = lowpassFilter(average, ceil(resolutionBand),
                                         ceil(resolutionBand) / 10)
                filtered_ref_name = 'average_iter' + str(i) + '_res' + str(
                    current_resolution) + '.em'
                filtered[0].write(filtered_ref_name)

                # change the frequency to a higher value
                new_freq = int(ceil(resolutionBand)) + 1
                if new_freq <= old_freq:
                    if job.adaptive_res is not False:  # two different strategies
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Include additional %f percent frequency to be aligned!'
                            % job.adaptive_res)
                        # NOTE(review): old_freq is not updated in this branch (unlike the
                        # branch below), so repeated non-improvement yields the same
                        # new_freq each round — verify this is intended.
                        new_freq = int((1 + job.adaptive_res) * old_freq)
                    else:  # always increase by 1
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Increase the frequency to be aligned by 1!'
                        )
                        new_freq = old_freq + 1
                        old_freq = new_freq
                else:
                    old_freq = new_freq
                # stop once the frequency reaches the number of resolution bands
                if new_freq >= numberBands:
                    print(self.node_name +
                          ': Determined frequency too high. Terminate!')
                    break

                if verbose:
                    print(self.node_name + ': change the frequency to ' +
                          str(new_freq))

            # send end signal to other nodes and terminate itself
            self.end(verbose)
        else:
            # other nodes
            self.run(verbose)
Ejemplo n.º 14
0
Archivo: misc.py Proyecto: xmzzaa/PyTom
def frm_fourier_adaptive_wedge_vol_rscore(vf, wf, vg, wg, b, radius=None, weights=None):
    """Obsolete.

    Accumulate the wedge-constrained Fourier-space FRM correlation of vf and vg
    over the shells r = 1..radius. Each shell is weighted by r**2 * weights[r-1].
    The spherical-harmonics bandwidth grows with r and is clamped to [b_min, b_max].

    Parameters
    ----------
    vf, vg: The two volumes (pytom_volume.vol). They are iftshift-ed, Fourier
            transformed and ftshift-ed before sampling.

    wf, wg: Wedge information [angle1, angle2] of vf and vg, passed to create_wedge_sf.

    b: Bandwidth. A falsy value means adaptive in [4, 128]. A tuple/list means
       [b_min, b_max]. An int means a fixed bandwidth.

    radius: Maximal shell. Defaults to vf.sizeX() // 2.

    weights: Per-shell weights (at least radius entries). Defaults to all ones.

    Returns
    -------
    numpy array of shape (2*bw, 2*bw, 2*bw) for the last bandwidth used, or None if radius < 1.

    Raises
    ------
    RuntimeError if b is not valid.
    """
    if not radius: # set the radius
        # Floor division: radius is a range() bound, which must be an int under Python 3.
        radius = vf.sizeX() // 2
    if not weights: # set the weights
        weights = [1] * radius

    if not b: # set the bandwidth adaptively
        b_min = 4
        b_max = 128
    elif isinstance(b, (tuple, list)):
        b_min = b[0]
        b_max = b[1]
    elif isinstance(b, int): # fixed bandwidth
        b_min = b
        b_max = b
    else:
        raise RuntimeError("Argument b is not valid!")

    from pytom.basic.fourier import fft, ftshift, iftshift
    from pytom_volume import reducedToFull, real, imag
    from .vol2sf import vol2sf
    from math import log, ceil, pow

    # IMPORTANT!!! Should firstly do the IFFTSHIFT on the volume data (NOT FFTSHIFT since for odd-sized data it matters!),
    # and then followed by the FFT.
    vf = ftshift(reducedToFull(fft(iftshift(vf, inplace=False))), inplace=False)
    vg = ftshift(reducedToFull(fft(iftshift(vg, inplace=False))), inplace=False)

    vfr = real(vf)
    vfi = imag(vf)
    vgr = real(vg)
    vgi = imag(vg)

    # smallest power of two >= 2*x
    get_bw = lambda x: int(pow(2, int(ceil(log(2*x, 2)))))

    res = None
    _last_bw = 0
    for r in range(1, radius+1):
        # calculate the appropriate bw, clamped to [b_min, b_max]
        bw = min(max(get_bw(r), b_min), b_max)

        # construct the wedge masks accordingly (only when the bandwidth changes)
        if _last_bw != bw:
            mf = create_wedge_sf(wf[0], wf[1], bw)
            mg = create_wedge_sf(wg[0], wg[1], bw)

        corr = frm_fourier_constrained_corr(vol2sf(vfr, r, bw), vol2sf(vfi, r, bw), mf, vol2sf(vgr, r, bw), vol2sf(vgi, r, bw), mg, True, False, True)

        # (re)size the accumulator to the current bandwidth
        if _last_bw != bw:
            if res is None:
                res = np.zeros((2*bw, 2*bw, 2*bw), dtype='double')
            else:
                res = enlarge2(res)

        res += corr*(r**2)*weights[r-1]

        _last_bw = bw

    return res
Ejemplo n.º 15
0
Archivo: misc.py Proyecto: xmzzaa/PyTom
def frm_determine_orientation(vf, wf, vg, wg, b, radius=None, weights=None, r_score=False, norm=False):
    """Auxiliary function for xu_align_vol. Find the angle to rotate vg to match vf, using only their power spectrums.

    Parameters
    ----------
    vf: The volume you want to match.
        pytom_volume.vol

    wf: The single tilt wedge information of volume vf.
        [missing_wedge_angle1, missing_wedge_angle2]. Note this is defined differently from frm_align in frm.py!

    vg: The reference volume.
        pytom_volume.vol

    wg: The single tilt wedge information of volume vg.
        [missing_wedge_angle1, missing_wedge_angle2]. Note this is defined differently from frm_align in frm.py!

    b: The adaptive bandwidth of spherical harmonics.
       List [min_bandwidth, max_bandwidth], min_bandwidth, max_bandwidth in the range [4, 64].
       Or integer, which would then mean to use fixed bandwidth: min_bandwidth = max_bandwidth = integer.
       A falsy value selects the adaptive range [4, 128].

    radius: The maximal radius in the Fourier space, which is equal to say the maximal frequency involved in calculation.
            Integer. By default is half of the volume size (vf.sizeX() // 2).

    weights: Obsolete. (Still applied as per-shell weights; defaults to all ones.)

    r_score: Obsolete.

    norm: Obsolete. (Still forwarded to frm_constrained_corr.)

    Returns
    -------
    The angle (Euler angle, ZXZ convention [Phi, Psi, Theta]) to rotate vg to match vf.

    Raises
    ------
    RuntimeError if b is not valid.
    """
    if not radius: # set the radius
        # Floor division: radius is a range() bound, which must be an int under Python 3.
        radius = vf.sizeX() // 2
    if not weights: # set the weights
        weights = [1] * radius

    if not b: # set the bandwidth adaptively
        b_min = 4
        b_max = 128
    elif isinstance(b, (tuple, list)):
        b_min = b[0]
        b_max = b[1]
    elif isinstance(b, int): # fixed bandwidth
        b_min = b
        b_max = b
    else:
        raise RuntimeError("Argument b is not valid!")

    from pytom.basic.fourier import fft, ftshift, iftshift
    from pytom_volume import reducedToFull, abs, real
    from .vol2sf import vol2sf
    from math import log, ceil, pow

    # IMPORTANT!!! Should firstly do the IFFTSHIFT on the volume data (NOT FFTSHIFT since for odd-sized data it matters!),
    # and then followed by the FFT.
    vf = ftshift(reducedToFull(fft(iftshift(vf, inplace=False))), inplace=False)
    vg = ftshift(reducedToFull(fft(iftshift(vg, inplace=False))), inplace=False)

    # Fourier amplitudes (abs from pytom_volume, then real part)
    ff = real(abs(vf))
    gg = real(abs(vg))

    # smallest power of two >= 2*x
    get_bw = lambda x: int(pow(2, int(ceil(log(2*x, 2)))))

    numerator = None
    denominator1 = None
    denominator2 = None
    _last_bw = 0
    for r in range(1, radius+1):
        # calculate the appropriate bw, clamped to [b_min, b_max]
        bw = min(max(get_bw(r), b_min), b_max)

        # construct the wedge masks accordingly (only when the bandwidth changes)
        if _last_bw != bw:
            mf = create_wedge_sf(wf[0], wf[1], bw)
            mg = create_wedge_sf(wg[0], wg[1], bw)

        corr1, corr2, corr3 = frm_constrained_corr(vol2sf(ff, r, bw), mf, vol2sf(gg, r, bw), mg, norm, return_score=False)

        # (re)size the accumulators to the current bandwidth
        if _last_bw != bw:
            if numerator is None:
                numerator = np.zeros((2*bw, 2*bw, 2*bw), dtype='double')
                denominator1 = np.zeros((2*bw, 2*bw, 2*bw), dtype='double')
                denominator2 = np.zeros((2*bw, 2*bw, 2*bw), dtype='double')
            else:
                numerator = enlarge2(numerator)
                denominator1 = enlarge2(denominator1)
                denominator2 = enlarge2(denominator2)

        shell_weight = (r**2)*weights[r-1]
        numerator += corr1*shell_weight
        denominator1 += corr2*shell_weight
        denominator2 += corr3*shell_weight

        _last_bw = bw

    # normalised correlation over all shells
    res = numerator/(denominator1 * denominator2)**0.5

    return frm_find_topn_angles_interp2(res)
Ejemplo n.º 16
0
Archivo: misc.py Proyecto: xmzzaa/PyTom
def bart_align_vol(vf, wf, vg, wg, b, radius=None, peak_offset=None):
    """Implementation of Bartesaghi's approach for alignment. For detail, please check the paper.

    Parameters
    ----------
    vf: The volume you want to match.
        pytom_volume.vol

    wf: The single tilt wedge information of volume vf.
        [missing_wedge_angle1, missing_wedge_angle2]. Note this is defined differently from frm_align in frm.py!

    vg: The reference volume.
        pytom_volume.vol

    wg: The single tilt wedge information of volume vg.
        [missing_wedge_angle1, missing_wedge_angle2]. Note this is defined differently from frm_align in frm.py!

    b: The bandwidth of spherical harmonics.
       Integer in the range [4, 64]

    radius: The maximal radius in the Fourier space, which is equal to say the maximal frequency involved in calculation.
            Integer. By default is half of the volume size (vf.sizeX() // 2).

    peak_offset: The maximal offset which allows the peak of the score to be.
                 Or simply speaking, the maximal distance allowed to shift vg to match vf.
                 This parameter is needed to prevent shifting the reference volume out of the frame.
                 Integer (sphere radius) or a pytom_volume.vol mask. By default is half of the volume size.

    Returns
    -------
    The best translation and rotation (Euler angle, ZXZ convention [Phi, Psi, Theta]) to transform vg to match vf.
    (best_translation, best_rotation, correlation_score)

    Raises
    ------
    RuntimeError if peak_offset has an unsupported type.
    """
    from pytom_volume import vol, rotateSpline, peak, initSphere, reducedToFull, abs, real
    from pytom.basic.correlation import nXcf
    from pytom.basic.filter import lowpassFilter
    from pytom.basic.structures import WedgeInfo
    from pytom.basic.fourier import fft, ftshift, iftshift
    from .vol2sf import vol2sf
    from pytom_numpy import vol2npy

    if not radius: # set the radius
        # Floor division: radius is a range() bound, which must be an int under Python 3.
        radius = vf.sizeX() // 2

    # peak_offset becomes a spherical mask volume restricting where the CC peak may lie
    if peak_offset is None:
        peak_offset = vol(vf.sizeX(), vf.sizeY(), vf.sizeZ())
        initSphere(peak_offset, vf.sizeX()/2, 0,0, vf.sizeX()/2,vf.sizeY()/2,vf.sizeZ()/2)
    elif isinstance(peak_offset, int):
        peak_radius = peak_offset
        peak_offset = vol(vf.sizeX(), vf.sizeY(), vf.sizeZ())
        initSphere(peak_offset, peak_radius, 0,0, vf.sizeX()/2,vf.sizeY()/2,vf.sizeZ()/2)
    elif isinstance(peak_offset, vol):
        pass
    else:
        raise RuntimeError('Peak offset is given wrong!')

    # IMPORTANT!!! Should firstly do the IFFTSHIFT on the volume data (NOT FFTSHIFT since for odd-sized data it matters!),
    # and then followed by the FFT.
    ff = real(abs(ftshift(reducedToFull(fft(iftshift(vf, inplace=False))), inplace=False)))
    gg = real(abs(ftshift(reducedToFull(fft(iftshift(vg, inplace=False))), inplace=False)))

    mf = create_wedge_sf(wf[0], wf[1], b)
    mg = create_wedge_sf(wg[0], wg[1], b)

    # sum the spherical samples of the amplitude spectra over shells 3..radius
    sf = None
    sg = None
    for r in range(3, radius+1): # Should start from 3 since the interpolation in the first 2 bands is not accurate.
        if sf is None:
            sf = vol2sf(ff, r, b)
            sg = vol2sf(gg, r, b)
        else:
            sf += vol2sf(ff, r, b)
            sg += vol2sf(gg, r, b)

    # rotation from the wedge-constrained correlation of the summed spectra
    corr = frm_constrained_corr(sf, mf, sg, mg)
    ang, val = frm_find_best_angle_interp(corr)

    # translation: rotate vg, apply each volume's counterpart wedge, low-pass, cross-correlate
    tmp = vol(vg.sizeX(),vg.sizeY(),vg.sizeZ())
    rotateSpline(vg, tmp, ang[0], ang[1], ang[2])
    wedge_f = WedgeInfo(90+wf[0], 90-wf[1])
    wedge_g = WedgeInfo(90+wg[0], 90-wg[1])
    cc = nXcf(lowpassFilter(wedge_g.apply(vf), radius, 0)[0], lowpassFilter(wedge_f.apply(tmp), radius, 0)[0])
    pos = peak(cc, peak_offset)
    pos, score = find_subpixel_peak_position(vol2npy(cc), pos)

    return (pos, ang, score)