Example #1
0
    def test_FRM(self):
        """Round-trip check of ``frm_align`` on a synthetic point volume.

        A 32^3 volume containing four isolated unit voxels is rotated by a
        known Euler triple and shifted by a known vector. ``frm_align`` must
        recover the rotation (``rotation_distance`` below 5) and every shift
        component to within half a voxel, in either direction.
        """
        # NOTE(review): swig_frm is not used below; presumably imported to
        # fail fast when the compiled FRM extension is missing.
        import swig_frm
        from sh_alignment.frm import frm_align
        from pytom_volume import vol, rotate, shift
        from pytom.basic.structures import Rotation, Shift
        from pytom.tools.maths import rotation_distance

        size = 32
        center = size // 2

        v = vol(size, size, size)
        v.setAll(0)
        vMod = vol(size, size, size)
        vRot = vol(size, size, size)
        # a few isolated voxels form an asymmetric pattern that can be aligned
        v.setV(1, 10, 10, 10)
        v.setV(1, 20, 20, 20)
        v.setV(1, 15, 15, 15)
        v.setV(1, 7, 21, 7)

        rotation = Rotation(10, 20, 30)
        shiftO = Shift(1, -3, 5)

        rotate(v, vRot, rotation.getPhi(), rotation.getPsi(), rotation.getTheta())
        shift(vRot, vMod, shiftO.getX(), shiftO.getY(), shiftO.getZ())

        pos, ang, score = frm_align(vMod, None, v, None, [4, 64], 10)
        rotdist = rotation_distance(ang1=rotation, ang2=ang)
        # frm_align returns a peak position; the shift is measured from the center
        diffx = shiftO[0] - (pos[0] - center)
        diffy = shiftO[1] - (pos[1] - center)
        diffz = shiftO[2] - (pos[2] - center)

        self.assertTrue(rotdist < 5., msg='Rotations are different')
        # compare magnitudes so that a large negative error fails as well
        self.assertTrue(abs(diffx) < .5, msg='x-difference > .5')
        self.assertTrue(abs(diffy) < .5, msg='y-difference > .5')
        self.assertTrue(abs(diffz) < .5, msg='z-difference > .5')
Example #2
0
def frm_proxy(p, ref, freq, offset, binning, mask):
    """Align a single particle against a reference with ``frm_align``.

    @param p: particle; its volume is fetched with ``p.getVolume(binning)``
    @param ref: reference object exposing ``getVolume()``
    @param freq: maximal frequency passed to ``frm_align``
    @param offset: peak offset passed to ``frm_align``
    @param binning: binning factor used when reading the particle and, if
        ``mask`` is a filename, the mask
    @param mask: mask volume passed to ``frm_align`` as-is, or the path of a
        mask file that is read with the same binning
    @return: tuple ``(Shift relative to the volume center, Rotation, score,
        particle filename)``
    """
    from pytom_volume import read, pasteCenter, vol
    from pytom.basic.structures import Shift, Rotation
    from sh_alignment.frm import frm_align

    v = p.getVolume(binning)
    size_x, size_y, size_z = v.sizeX(), v.sizeY(), v.sizeZ()

    if isinstance(mask, str):
        # read the mask binned like the particle
        maskBin = read(mask, 0, 0, 0, 0, 0, 0, 0, 0, 0, binning, binning,
                       binning)
        if (maskBin.sizeX(), maskBin.sizeY(), maskBin.sizeZ()) != (size_x, size_y, size_z):
            # sizes differ: paste the binned mask into the center of an empty volume
            mask = vol(size_x, size_y, size_z)
            mask.setAll(0)
            pasteCenter(maskBin, mask)
        else:
            mask = maskBin

    pos, angle, score = frm_align(v, p.getWedge(), ref.getVolume(), None,
                                  [4, 64], freq, offset, mask)

    # convert the peak position into a shift relative to the volume center
    center_shift = Shift([pos[0] - size_x // 2,
                          pos[1] - size_y // 2,
                          pos[2] - size_z // 2])
    return (center_shift, Rotation(angle), score, p.getFilename())
Example #3
0
    def run(self, verbose=False):
        """Worker loop: align every particle of each received job with FRM.

        Jobs are fetched with ``self.get_job()`` until it raises (a non-job
        message ends the loop). For every particle, the shift, rotation and
        score returned by ``frm_align`` are stored on the particle. The
        aligned list is then averaged with ``self.average_sub_pl`` and an
        ``FRMResult`` is sent back. ``pytom_mpi.finalise()`` is called on exit.

        @param verbose: print progress information
        """
        from sh_alignment.frm import frm_align
        from pytom.basic.structures import Shift, Rotation
        from pytom.tools.ProgressBar import FixedProgBar

        while True:
            # get the job
            try:
                job = self.get_job()
            except Exception:
                # a non-job message ends the worker; KeyboardInterrupt and
                # SystemExit are deliberately not swallowed
                if verbose:
                    print(self.node_name + ': end')
                break

            if verbose:
                prog = FixedProgBar(0,
                                    len(job.particleList) - 1,
                                    self.node_name + ':')

            ref = job.reference[0].getVolume()
            # run the job
            for i, p in enumerate(job.particleList):
                if verbose:
                    prog.update(i)
                v = p.getVolume()

                pos, angle, score = frm_align(v, p.getWedge(), ref, None,
                                              job.bw_range, job.freq,
                                              job.peak_offset,
                                              job.mask.getVolume())

                # peak position -> shift from the integer volume center
                # (floor division keeps odd box sizes on the voxel grid)
                p.setShift(
                    Shift([
                        pos[0] - v.sizeX() // 2, pos[1] - v.sizeY() // 2,
                        pos[2] - v.sizeZ() // 2
                    ]))
                p.setRotation(Rotation(angle))
                p.setScore(FRMScore(score))

            # average the particle list
            name_prefix = os.path.join(
                self.destination, self.node_name + '_' + str(job.max_iter))
            self.average_sub_pl(job.particleList, name_prefix, job.weighting)

            # send back the result
            self.send_result(
                FRMResult(name_prefix, job.particleList, self.mpi_id))

        pytom_mpi.finalise()
Example #4
0
def FRMAlignmentWrapper(particle, wedgeParticle, reference, wedgeReference,
                        bandwidth, highestFrequency, mask=None, peakPrior=None):
    """
    FRMAlignmentWrapper: Wrapper for frm_align to handle PyTom objects.
    @param particle: The particle (Particle object or its volume)
    @type particle: L{pytom.basic.structures.Particle}
    @param wedgeParticle: Wedge object of particle, or None for no wedge
    @type wedgeParticle: L{pytom.basic.structures.Wedge}
    @param reference: Reference used for alignment (Reference object or its volume)
    @type reference: L{pytom.basic.structures.Reference}
    @param wedgeReference: Information about reference wedge, or None for no wedge
    @type wedgeReference: L{pytom.basic.structures.Wedge}
    @param bandwidth: The bandwidth of the spherical harmonics - lowestBand used, highestBand used
    @type bandwidth: [lowestBand,highestBand]
    @param highestFrequency: Highest frequency for lowpass filter in fourierspace
    @type highestFrequency: int
    @param mask: Mask multiplied onto the reference volume before alignment.
        Any value that is not a Mask object (including the default None) means no mask.
    @type mask: L{pytom.basic.structures.Mask}
    @param peakPrior: Maximum distance of peak from origin; None, 0 or a
        value below 0.0001 disables the restriction
    @type peakPrior: L{pytom.score.score.PeakPrior} or an integer
    @return: Returns a list of [L{pytom.basic.structures.Shift}, L{pytom.basic.structures.Rotation}, scoreValue]
    @raise RuntimeError: if bandwidth is a list whose length is not 2
    """
    from pytom.basic.structures import Particle, Reference, Mask, Shift, Rotation
    from pytom.score.score import PeakPrior
    from sh_alignment.frm import frm_align

    if isinstance(particle, Particle):
        particle = particle.getVolume()

    if isinstance(reference, Reference):
        reference = reference.getVolume()

    if isinstance(mask, Mask):
        mask = mask.getVolume()
    else:
        mask = None

    if isinstance(bandwidth, list) and len(bandwidth) != 2:
        raise RuntimeError('Bandwidth parameter must be a list of two integers!')

    if isinstance(peakPrior, PeakPrior):
        peakPrior = int(peakPrior.getRadius())

    if not peakPrior or peakPrior < 0.0001:
        peakPrior = None

    # without a mask the reference is used unchanged (multiplying by None would raise)
    if mask is not None:
        reference = reference * mask

    # frm_align accepts None for "no wedge"
    particleWedge = wedgeParticle.getWedgeObject() if wedgeParticle is not None else None
    referenceWedge = wedgeReference.getWedgeObject() if wedgeReference is not None else None

    pos, angle, score = frm_align(particle, particleWedge, reference, referenceWedge,
                                  bandwidth, int(highestFrequency), peakPrior)

    # peak position -> shift from the integer volume center
    return [Shift([pos[0] - particle.sizeX() // 2,
                   pos[1] - particle.sizeY() // 2,
                   pos[2] - particle.sizeZ() // 2]),
            Rotation(angle), score]
Example #5
0
def calculate_difference_map(v1,
                             band1,
                             v2,
                             band2,
                             mask=None,
                             focus_mask=None,
                             align=True,
                             sigma=None,
                             threshold=0.4):
    """Compute binarised difference (STD) maps between two volumes.

    Both volumes are low-pass filtered. If ``align`` is set, ``v2`` is aligned
    onto ``v1`` with FRM (``v1`` is the reference). The voxel-wise standard
    deviation of the two normalised volumes is then thresholded, low-pass
    filtered and binarised.

    @param v1: first volume (alignment reference)
    @param band1: low-pass band for ``v1``
    @param v2: second volume
    @param band2: low-pass band for ``v2``
    @param mask: mask used for the alignment only
    @param focus_mask: optional mask restricting the difference map
    @param align: align ``v2`` onto ``v1`` before comparing
    @param sigma: if set and non-zero, only density beyond ``sigma`` (after
        normalisation) is considered; the sign selects positive or negative
        density
    @param threshold: fraction of the way from the mean to the maximum of the
        STD map used as cut-off
    @return: tuple ``(std_map1, std_map2)``. ``std_map2`` is ``std_map1``
        transformed with the alignment parameters when ``align`` is True;
        otherwise it is the same object.
    """
    # NOTE: max/min/mean here are the pytom_volume versions (they shadow builtins)
    from pytom_volume import vol, power, limit, transformSpline, mean, max, min
    from pytom.basic.normalise import mean0std1
    from pytom.basic.filter import lowpassFilter

    # do lowpass filtering first
    lv1 = lowpassFilter(v1, band1, band1 / 10.)[0]
    lv2 = lowpassFilter(v2, band2, band2 / 10.)[0]

    # do alignment of two volumes, if required. v1 is used as reference.
    if align:
        from sh_alignment.frm import frm_align
        # builtin min() is shadowed by pytom_volume.min, hence the conditional
        band = int(band1 if band1 < band2 else band2)
        pos, angle, score = frm_align(lv2, None, lv1, None, [4, 64], band,
                                      lv1.sizeX() // 4, mask)
        shift = [
            pos[0] - v1.sizeX() // 2, pos[1] - v1.sizeY() // 2,
            pos[2] - v1.sizeZ() // 2
        ]

        # bring lv2 into the frame of lv1 using the negated alignment parameters
        lvv2 = vol(lv2)
        transformSpline(lv2, lvv2, -angle[1], -angle[0], -angle[2],
                        lv2.sizeX() // 2,
                        lv2.sizeY() // 2,
                        lv2.sizeZ() // 2, -shift[0], -shift[1], -shift[2], 0,
                        0, 0)
    else:
        lvv2 = lv2

    # do normalization
    mean0std1(lv1)
    mean0std1(lvv2)

    # only consider the density beyond certain sigma
    if sigma is None or sigma == 0:
        pass
    elif sigma < 0:  # negative density counts
        assert min(lv1) < sigma, 'v1 has no density below sigma'
        assert min(lvv2) < sigma, 'v2 has no density below sigma'
        limit(lv1, 0, 0, sigma, 0, False, True)
        limit(lvv2, 0, 0, sigma, 0, False, True)
    else:  # positive density counts
        assert max(lv1) > sigma, 'v1 has no density above sigma'
        assert max(lvv2) > sigma, 'v2 has no density above sigma'
        limit(lv1, sigma, 0, 0, 0, True, False)
        limit(lvv2, sigma, 0, 0, 0, True, False)

    # if we want to focus on specific area only
    if focus_mask is not None:
        lv1 *= focus_mask
        lvv2 *= focus_mask

    # calculate the STD map
    avg = (lv1 + lvv2) / 2
    var1 = avg - lv1
    power(var1, 2)
    var2 = avg - lvv2
    power(var2, 2)

    std_map = var1 + var2
    power(std_map, 0.5)

    if focus_mask is not None:
        std_map *= focus_mask

    # threshold the STD map
    mv = mean(std_map)
    cutoff = mv + (max(std_map) - mv) * threshold
    limit(std_map, cutoff, 0, cutoff, 1, True, True)

    # do a lowpass filtering
    std_map1 = lowpassFilter(std_map, v1.sizeX() // 4, v1.sizeX() / 40.)[0]

    if align:
        # transform the map with the (non-negated) alignment parameters
        std_map2 = vol(std_map)
        transformSpline(std_map1, std_map2, angle[0], angle[1], angle[2],
                        v1.sizeX() // 2,
                        v1.sizeY() // 2,
                        v1.sizeZ() // 2, 0, 0, 0, shift[0], shift[1], shift[2])
    else:
        std_map2 = std_map1

    limit(std_map1, 0.5, 0, 1, 1, True, True)
    limit(std_map2, 0.5, 0, 1, 1, True, True)

    # return the respective difference maps
    return (std_map1, std_map2)
Example #6
0
    def run(self, verbose=False):
        """Worker loop for (optionally binned / angle-constrained) FRM alignment.

        Jobs are fetched with ``self.get_job()`` until it raises. For
        ``job.binning > 1``, the reference, particles and mask are downsampled
        before alignment, and the resulting shifts are scaled back by the
        binning factor. If ``job.constraint`` is set,
        ``frm_constrained_align`` is used instead of ``frm_align``. Each
        aligned list is averaged and an ``FRMResult`` is sent back.
        ``pytom_mpi.finalise()`` is called on exit.

        @param verbose: print progress information
        """
        from sh_alignment.frm import frm_align
        from sh_alignment.constrained_frm import frm_constrained_align, AngularConstraint
        from pytom.basic.structures import Shift, Rotation
        from pytom.tools.ProgressBar import FixedProgBar
        from pytom.basic.transformations import resize
        binningType = 'Fourier'

        def _bin(volume, binning, interpolation):
            # downsample by `binning`; resize may return a tuple led by the volume
            binned = resize(volume=volume,
                            factor=1. / binning,
                            interpolation=interpolation)
            return binned[0] if isinstance(binned, tuple) else binned

        while True:
            # get the job
            try:
                job = self.get_job()
            except Exception:
                # a non-job message ends the worker; KeyboardInterrupt and
                # SystemExit are deliberately not swallowed
                if verbose:
                    print(self.node_name + ': end')
                break

            binning = job.binning
            if verbose:
                prog = FixedProgBar(0,
                                    len(job.particleList) - 1,
                                    self.node_name + ':')
            ref = job.reference.getVolume()
            if binning > 1:
                ref = _bin(ref, binning, binningType)
            # re-set max frequency in case it exceeds Nyquist - a bit brute force
            job.freq = min(job.freq, ref.sizeX() // 2 - 1)

            # run the job
            for i, p in enumerate(job.particleList):
                if verbose:
                    prog.update(i)
                v = p.getVolume()
                if binning > 1:
                    v = _bin(v, binning, binningType)
                mask = job.mask.getVolume()
                if binning > 1:
                    mask = _bin(mask, binning, 'Spline')

                if job.constraint:
                    constraint = job.constraint
                    if constraint.type == AngularConstraint.ADP_ANGLE:  # set the constraint around certain angle
                        rot = p.getRotation()
                        constraint.setAngle(rot.getPhi(), rot.getPsi(),
                                            rot.getTheta())
                    # scale the peak offset to the binned grid
                    peak_offset = job.peak_offset / binning if binning > 1 else job.peak_offset
                    pos, angle, score = frm_constrained_align(
                        v, p.getWedge(), ref, None, job.bw_range, job.freq,
                        peak_offset, mask, constraint)
                else:
                    # NOTE(review): unlike the constrained branch, peak_offset is
                    # not divided by the binning here (an earlier attempt to do so
                    # was disabled) -- confirm this is intended
                    pos, angle, score = frm_align(v, p.getWedge(), ref, None,
                                                  job.bw_range, job.freq,
                                                  job.peak_offset, mask)

                # peak position -> shift from the integer volume center,
                # scaled back to unbinned voxels when binning was applied
                center_shift = [pos[0] - v.sizeX() // 2,
                                pos[1] - v.sizeY() // 2,
                                pos[2] - v.sizeZ() // 2]
                if binning > 1:
                    center_shift = [binning * s for s in center_shift]
                p.setShift(Shift(center_shift))
                p.setRotation(Rotation(angle))
                p.setScore(FRMScore(score))

            # average the particle list
            name_prefix = os.path.join(
                job.destination, self.node_name + '_' + str(job.max_iter))
            self.average_sub_pl(job.particleList, name_prefix, job.weighting)

            # send back the result
            self.send_result(
                FRMResult(name_prefix, job.particleList, self.mpi_id))

        pytom_mpi.finalise()
Example #7
0
    def start(self, job, verbose=False):
        """Run iterative FRM alignment with two half sets (even/odd).

        On the node with ``mpi_id == 0`` this drives the refinement:
        - It assigns random half-set classes unless the particle list already
          splits into exactly two classes.
        - For up to ``job.max_iter`` iterations it distributes an ``FRMJob`` to
          the workers and collects their partial sums.
        - It builds even and odd averages (with symmetries applied), aligns the
          odd average to the even one with ``frm_align`` and transforms the odd
          set accordingly.
        - It determines the resolution, writes half-set and combined averages
          to ``self.destination``, and raises the alignment frequency.
        Every other node calls ``self.run(verbose)`` instead.

        @param job: job description (particle list, reference, mask, frequency,
            symmetries, FSC criterion, ...)
        @param verbose: print progress information
        """
        if self.mpi_id == 0:
            from pytom.basic.structures import ParticleList, Reference
            from pytom.basic.resolution import bandToAngstrom
            from pytom.basic.filter import lowpassFilter
            from math import ceil

            # randomly split the particle list into 2 half sets
            # (skipped when the list already splits into exactly two classes)
            if len(job.particleList.splitByClass()) != 2:
                import numpy as np
                n = len(job.particleList)
                labels = np.random.randint(2, size=(n, ))
                print(self.node_name + ': Number of 1st half set:',
                      n - np.sum(labels), 'Number of 2nd half set:',
                      np.sum(labels))
                for i in range(n):
                    p = job.particleList[i]
                    p.setClass(labels[i])

            self.destination = job.destination
            new_reference = job.reference
            old_freq = job.freq
            new_freq = job.freq
            # main node
            for i in range(job.max_iter):
                if verbose:
                    print(self.node_name + ': starting iteration %d ...' % i)

                # construct a new job by updating the reference and the frequency
                new_job = FRMJob(job.particleList, new_reference, job.mask,
                                 job.peak_offset, job.sampleInformation,
                                 job.bw_range, new_freq, job.destination,
                                 job.max_iter - i, job.r_score, job.weighting)

                # distribute it
                self.distribute_job(new_job, verbose)

                # get the result back
                all_even_pre = None  # the 1st set
                all_even_wedge = None
                all_odd_pre = None  # the 2nd set
                all_odd_wedge = None
                pl = ParticleList()
                # one result per worker; self.assignment[worker_id] == 0 marks
                # the even set, anything else the odd set
                for j in range(self.num_workers):
                    result = self.get_result()
                    pl += result.pl
                    pre, wedge = self.retrieve_res_vols(result.name)

                    # NOTE(review): truthiness of a volume object is used as a
                    # "not None yet" check here -- presumably equivalent
                    if self.assignment[result.worker_id] == 0:
                        if all_even_pre:
                            all_even_pre += pre
                            all_even_wedge += wedge
                        else:
                            all_even_pre = pre
                            all_even_wedge = wedge
                    else:
                        if all_odd_pre:
                            all_odd_pre += pre
                            all_odd_wedge += wedge
                        else:
                            all_odd_pre = pre
                            all_odd_wedge = wedge

                # write the new particle list to the disk
                # NOTE(review): relative path, i.e. written to the current
                # working directory rather than self.destination
                pl.toXMLFile('aligned_pl_iter' + str(i) + '.xml')

                # create the averages separately
                if verbose:
                    print(self.node_name + ': determining the resolution ...')
                even = self.create_average(all_even_pre, all_even_wedge)
                odd = self.create_average(all_odd_pre, all_odd_wedge)

                # apply symmetries if any
                even = job.symmetries.applyToParticle(even)
                odd = job.symmetries.applyToParticle(odd)

                # determine the transformation between even and odd
                # here we assume the wedge from both sets are fully sampled
                # (odd is passed as the particle and even as the reference)
                from sh_alignment.frm import frm_align
                pos, angle, score = frm_align(odd, None, even, None,
                                              job.bw_range, new_freq,
                                              job.peak_offset)
                print(self.node_name +
                      'Transform of even set to match the odd set - shift: ' +
                      str(pos) + ' rotation: ' + str(angle))

                # transform the odd set accordingly
                # (the pre-sum, the wedge weighting and the odd average are all
                # transformed with the negated alignment parameters)
                # NOTE(review): `size / 2` is true division in Python 3; for odd
                # box sizes this differs from the `size // 2` center used
                # elsewhere -- confirm
                from pytom_volume import vol, transformSpline
                from pytom.basic.fourier import ftshift
                from pytom_volume import reducedToFull
                from pytom_freqweight import weight
                transformed_odd_pre = vol(odd.sizeX(), odd.sizeY(),
                                          odd.sizeZ())
                full_all_odd_wedge = reducedToFull(all_odd_wedge)
                ftshift(full_all_odd_wedge)
                odd_weight = weight(
                    full_all_odd_wedge)  # the funny part of pytom
                transformed_odd = vol(odd.sizeX(), odd.sizeY(), odd.sizeZ())

                transformSpline(all_odd_pre, transformed_odd_pre, -angle[1],
                                -angle[0], -angle[2],
                                odd.sizeX() / 2,
                                odd.sizeY() / 2,
                                odd.sizeZ() / 2, -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)
                odd_weight.rotate(-angle[1], -angle[0], -angle[2])
                transformed_odd_wedge = odd_weight.getWeightVolume(True)
                transformSpline(odd, transformed_odd, -angle[1], -angle[0],
                                -angle[2],
                                odd.sizeX() / 2,
                                odd.sizeY() / 2,
                                odd.sizeZ() / 2, -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)

                all_odd_pre = transformed_odd_pre
                all_odd_wedge = transformed_odd_wedge
                odd = transformed_odd

                # determine resolution
                resNyquist, resolutionBand, numberBands = self.determine_resolution(
                    even, odd, job.fsc_criterion, None, job.mask, verbose)

                # write the half set to the disk
                even.write(
                    os.path.join(self.destination,
                                 'fsc_' + str(i) + '_even.em'))
                odd.write(
                    os.path.join(self.destination,
                                 'fsc_' + str(i) + '_odd.em'))

                current_resolution = bandToAngstrom(
                    resolutionBand, job.sampleInformation.getPixelSize(),
                    numberBands, 1)
                if verbose:
                    print(
                        self.node_name + ': current resolution ' +
                        str(current_resolution), resNyquist)

                # create new average
                # (the even accumulators are reused for the combined sum)
                all_even_pre += all_odd_pre
                all_even_wedge += all_odd_wedge
                average = self.create_average(all_even_pre, all_even_wedge)

                # apply symmetries
                average = job.symmetries.applyToParticle(average)

                # filter average to resolution
                # (the unfiltered average is written here; the filtered one below)
                average_name = os.path.join(self.destination,
                                            'average_iter' + str(i) + '.em')
                average.write(average_name)

                # update the references
                # (next iteration uses the two half-set averages written above)
                new_reference = [
                    Reference(
                        os.path.join(self.destination,
                                     'fsc_' + str(i) + '_even.em')),
                    Reference(
                        os.path.join(self.destination,
                                     'fsc_' + str(i) + '_odd.em'))
                ]

                # low pass filter the reference and write it to the disk
                filtered = lowpassFilter(average, ceil(resolutionBand),
                                         ceil(resolutionBand) / 10)
                filtered_ref_name = os.path.join(
                    self.destination, 'average_iter' + str(i) + '_res' +
                    str(current_resolution) + '.em')
                filtered[0].write(filtered_ref_name)

                # if the position/orientation is not improved, break it

                # change the frequency to a higher value
                new_freq = int(ceil(resolutionBand)) + 1
                if new_freq <= old_freq:
                    # NOTE(review): old_freq is not updated in the adaptive_res
                    # branch, unlike the +1 branch -- confirm this is intended
                    if job.adaptive_res is not False:  # two different strategies
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Include additional %f percent frequency to be aligned!'
                            % job.adaptive_res)
                        new_freq = int((1 + job.adaptive_res) * old_freq)
                    else:  # always increase by 1
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Increase the frequency to be aligned by 1!'
                        )
                        new_freq = old_freq + 1
                        old_freq = new_freq
                else:
                    old_freq = new_freq
                # stop once the frequency reaches the number of bands
                if new_freq >= numberBands:
                    print(self.node_name +
                          ': New frequency too high. Terminate!')
                    break

                if verbose:
                    print(self.node_name + ': change the frequency to ' +
                          str(new_freq))

            # send end signal to other nodes and terminate itself
            self.end(verbose)
        else:
            # other nodes
            self.run(verbose)
Example #8
0
    def run(self, verbose=False):
        """Worker loop for multi-defocus FRM alignment with two half-set references.

        Jobs are fetched with ``self.get_job()`` until it raises.
        - Both references are optionally convolved with the square root of the
          summed squared CTF (``job.sum_ctf_sqr``).
        - Particles are optionally convolved with ``1 / bfactor_kernel``
          (``job.bfactor``).
        - Each particle is aligned against the reference selected by its class.
        - The CTF-convolved particle list is averaged and an ``FRMResult`` is
          sent back.
        ``pytom_mpi.finalise()`` is called on exit.

        @param verbose: print progress information
        """
        from sh_alignment.frm import frm_align
        from pytom.basic.structures import Shift, Rotation
        from pytom.tools.ProgressBar import FixedProgBar
        from pytom.basic.fourier import convolute
        from pytom_volume import read, power, vol

        while True:
            # get the job
            try:
                job = self.get_job()
            except Exception:
                # a non-job message ends the worker; KeyboardInterrupt and
                # SystemExit are deliberately not swallowed
                if verbose:
                    print(self.node_name + ': end')
                break

            if verbose:
                prog = FixedProgBar(0,
                                    len(job.particleList) - 1,
                                    self.node_name + ':')

            # one reference per half set; particles select theirs via getClass()
            ref = [job.reference[0].getVolume(), job.reference[1].getVolume()]

            # convolute with the approximation of the CTF
            if job.sum_ctf_sqr:
                ctf = read(job.sum_ctf_sqr)
                power(ctf, 0.5)  # the number of CTFs should not matter, should it?
                ref = [convolute(r, ctf, True) for r in ref]

            # job.bfactor holds a kernel filename; the string 'None' also disables it
            has_bfactor = bool(job.bfactor) and job.bfactor != 'None'
            if has_bfactor:
                bfactor_kernel = read(job.bfactor)
                unit = vol(bfactor_kernel)
                unit.setAll(1)
                restore_kernel = unit / bfactor_kernel

            # run the job
            for i, p in enumerate(job.particleList):
                if verbose:
                    prog.update(i)
                v = p.getVolume()

                if has_bfactor:
                    v = convolute(v, restore_kernel, True)  # if bfactor is set, restore it

                pos, angle, score = frm_align(v, p.getWedge(),
                                              ref[int(p.getClass())], None,
                                              job.bw_range, job.freq,
                                              job.peak_offset,
                                              job.mask.getVolume())

                # peak position -> shift from the integer volume center
                p.setShift(
                    Shift([
                        pos[0] - v.sizeX() // 2, pos[1] - v.sizeY() // 2,
                        pos[2] - v.sizeZ() // 2
                    ]))
                p.setRotation(Rotation(angle))
                p.setScore(FRMScore(score))

            # average the particle list
            name_prefix = self.node_name + '_' + str(job.max_iter)
            pair = ParticleListPair('', job.ctf_conv_pl, None, None)
            pair.set_phase_flip_pl(job.particleList)
            # operate on the CTF convoluted projection!
            self.average_sub_pl(pair.get_ctf_conv_pl(), name_prefix)

            # send back the result
            self.send_result(
                FRMResult(name_prefix, job.particleList, self.mpi_id))

        pytom_mpi.finalise()
Example #9
0
    def start(self, job, verbose=False):
        if self.mpi_id == 0:
            from pytom.basic.structures import ParticleList, Reference
            from pytom.basic.resolution import bandToAngstrom
            from pytom.basic.filter import lowpassFilter
            from math import ceil
            from pytom.basic.fourier import convolute
            from pytom_volume import vol, power, read

            # randomly split the particle list into 2 half sets
            import numpy as np
            num_pairs = len(job.particleList.pairs)
            for i in range(num_pairs):
                # randomize the class labels to indicate the two half sets
                pl = job.particleList.pairs[i].get_phase_flip_pl()
                n = len(pl)
                labels = np.random.randint(2, size=(n, ))
                print(self.node_name + ': Number of 1st half set:',
                      n - np.sum(labels), 'Number of 2nd half set:',
                      np.sum(labels))
                for j in range(n):
                    p = pl[j]
                    p.setClass(labels[j])

            new_reference = job.reference
            old_freq = job.freq
            new_freq = job.freq
            # main node
            for i in range(job.max_iter):
                if verbose:
                    print(self.node_name + ': starting iteration %d ...' % i)

                # construct a new job by updating the reference and the frequency
                # here the job.particleList is actually ParticleListSet
                new_job = MultiDefocusJob(job.particleList, new_reference,
                                          job.mask, job.peak_offset,
                                          job.sampleInformation, job.bw_range,
                                          new_freq, job.destination,
                                          job.max_iter - i, job.r_score,
                                          job.weighting, job.bfactor)

                # distribute it
                num_all_particles = self.distribute_job(new_job, verbose)

                # calculate the denominator
                sum_ctf_squared = None
                for pair in job.particleList.pairs:
                    if sum_ctf_squared is None:
                        sum_ctf_squared = pair.get_ctf_sqr_vol() * pair.snr
                    else:
                        sum_ctf_squared += pair.get_ctf_sqr_vol() * pair.snr

                # get the result back
                all_even_pre = None
                all_even_wedge = None
                all_odd_pre = None
                all_odd_wedge = None
                pls = []
                for j in range(len(job.particleList.pairs)):
                    pls.append(ParticleList())

                for j in range(self.num_workers):
                    result = self.get_result()

                    pair_id = self.assignment[result.worker_id]
                    pair = job.particleList.pairs[pair_id]

                    pl = pls[pair_id]
                    pl += result.pl
                    even_pre, even_wedge, odd_pre, odd_wedge = self.retrieve_res_vols(
                        result.name)

                    if all_even_pre:
                        all_even_pre += even_pre * pair.snr
                        all_even_wedge += even_wedge
                        all_odd_pre += odd_pre * pair.snr
                        all_odd_wedge += odd_wedge
                    else:
                        all_even_pre = even_pre * pair.snr
                        all_even_wedge = even_wedge
                        all_odd_pre = odd_pre * pair.snr
                        all_odd_wedge = odd_wedge

                # write the new particle list to the disk
                for j in range(len(job.particleList.pairs)):
                    pls[j].toXMLFile('aligned_pl' + str(j) + '_iter' + str(i) +
                                     '.xml')

                # correct for the number of particles in wiener filter
                sum_ctf_squared = sum_ctf_squared / num_all_particles
                #                all_even_pre = all_even_pre/(num_all_particles/2)
                #                all_odd_pre = all_odd_pre/(num_all_particles/2)

                # bfactor
                if job.bfactor and job.bfactor != 'None':
                    #                    bfactor_kernel = create_bfactor_vol(sum_ctf_squared.sizeX(), job.sampleInformation.getPixelSize(), job.bfactor)
                    bfactor_kernel = read(job.bfactor)
                    bfactor_kernel_sqr = vol(bfactor_kernel)
                    power(bfactor_kernel_sqr, 2)
                    all_even_pre = convolute(all_even_pre, bfactor_kernel,
                                             True)
                    all_odd_pre = convolute(all_odd_pre, bfactor_kernel, True)
                    sum_ctf_squared = sum_ctf_squared * bfactor_kernel_sqr

                # create averages of two sets
                if verbose:
                    print(self.node_name + ': determining the resolution ...')
                even = self.create_average(
                    all_even_pre, sum_ctf_squared, all_even_wedge
                )  # assume that the CTF sum is the same for the even and odd
                odd = self.create_average(all_odd_pre, sum_ctf_squared,
                                          all_odd_wedge)

                # determine the transformation between even and odd
                # here we assume the wedge from both sets are fully sampled
                from sh_alignment.frm import frm_align
                pos, angle, score = frm_align(odd, None, even, None,
                                              job.bw_range, new_freq,
                                              job.peak_offset)
                print(
                    self.node_name +
                    ': transform of even set to match the odd set - shift: ' +
                    str(pos) + ' rotation: ' + str(angle))

                # transform the odd set accordingly
                from pytom_volume import vol, transformSpline
                from pytom.basic.fourier import ftshift
                from pytom_volume import reducedToFull
                from pytom_freqweight import weight
                transformed_odd_pre = vol(odd.sizeX(), odd.sizeY(),
                                          odd.sizeZ())
                full_all_odd_wedge = reducedToFull(all_odd_wedge)
                ftshift(full_all_odd_wedge)
                odd_weight = weight(
                    full_all_odd_wedge)  # the funny part of pytom
                transformed_odd = vol(odd.sizeX(), odd.sizeY(), odd.sizeZ())

                transformSpline(all_odd_pre, transformed_odd_pre, -angle[1],
                                -angle[0], -angle[2], int(odd.sizeX() / 2),
                                int(odd.sizeY() / 2), int(odd.sizeZ() / 2),
                                -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)
                odd_weight.rotate(-angle[1], -angle[0], -angle[2])
                transformed_odd_wedge = odd_weight.getWeightVolume(True)
                transformSpline(odd, transformed_odd, -angle[1], -angle[0],
                                -angle[2], int(odd.sizeX() / 2),
                                int(odd.sizeY() / 2), int(odd.sizeZ() / 2),
                                -(pos[0] - odd.sizeX() / 2),
                                -(pos[1] - odd.sizeY() / 2),
                                -(pos[2] - odd.sizeZ() / 2), 0, 0, 0)

                all_odd_pre = transformed_odd_pre
                all_odd_wedge = transformed_odd_wedge
                odd = transformed_odd

                # apply symmetries before determine resolution
                # with gold standard you should be careful about applying the symmetry!
                even = job.symmetries.applyToParticle(even)
                odd = job.symmetries.applyToParticle(odd)
                resNyquist, resolutionBand, numberBands = self.determine_resolution(
                    even, odd, job.fsc_criterion, None, job.mask, verbose)

                # write the half set to the disk
                even.write('fsc_' + str(i) + '_even.em')
                odd.write('fsc_' + str(i) + '_odd.em')

                current_resolution = bandToAngstrom(
                    resolutionBand, job.sampleInformation.getPixelSize(),
                    numberBands, 1)
                if verbose:
                    print(
                        self.node_name + ': current resolution ' +
                        str(current_resolution), resNyquist)

                # create new average
                all_even_pre += all_odd_pre
                all_even_wedge += all_odd_wedge
                #                all_even_pre = all_even_pre/2 # correct for the number of particles in wiener filter
                average = self.create_average(all_even_pre, sum_ctf_squared,
                                              all_even_wedge)

                # apply symmetries
                average = job.symmetries.applyToParticle(average)

                # filter average to resolution and update the new reference
                average_name = 'average_iter' + str(i) + '.em'
                average.write(average_name)

                # update the references
                new_reference = [
                    Reference('fsc_' + str(i) + '_even.em'),
                    Reference('fsc_' + str(i) + '_odd.em')
                ]

                # low pass filter the reference and write it to the disk
                filtered = lowpassFilter(average, ceil(resolutionBand),
                                         ceil(resolutionBand) / 10)
                filtered_ref_name = 'average_iter' + str(i) + '_res' + str(
                    current_resolution) + '.em'
                filtered[0].write(filtered_ref_name)

                # change the frequency to a higher value
                new_freq = int(ceil(resolutionBand)) + 1
                if new_freq <= old_freq:
                    if job.adaptive_res is not False:  # two different strategies
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Include additional %f percent frequency to be aligned!'
                            % job.adaptive_res)
                        new_freq = int((1 + job.adaptive_res) * old_freq)
                    else:  # always increase by 1
                        print(
                            self.node_name +
                            ': Determined resolution gets worse. Increase the frequency to be aligned by 1!'
                        )
                        new_freq = old_freq + 1
                        old_freq = new_freq
                else:
                    old_freq = new_freq
                if new_freq >= numberBands:
                    print(self.node_name +
                          ': Determined frequency too high. Terminate!')
                    break

                if verbose:
                    print(self.node_name + ': change the frequency to ' +
                          str(new_freq))

            # send end signal to other nodes and terminate itself
            self.end(verbose)
        else:
            # other nodes
            self.run(verbose)