def createOutputStep(self):
        """ Register the particles from 'expanded_particles.star' as output.

        The star file in the extra folder is rewritten without the
        rlnImageId column, read into a new set of particles (projection
        alignment) and that set is defined as 'outputParticles', with the
        input particles as its source.
        """
        starPath = self._getExtraPath('expanded_particles.star')
        inputSet = self.inputParticles.get()
        outputSet = self._createSetOfParticles()
        outputSet.copyInfo(inputSet)

        # When Plugin.IS_GT30() holds, particles live in a named table and
        # the optics table must be loaded before the file is overwritten.
        if Plugin.IS_GT30():
            particlesName = 'particles'
            opticsTable = Table(fileName=starPath, tableName='optics')
        else:
            particlesName = ''
            opticsTable = None

        # Drop the repeated rlnImageId column
        particlesTable = Table(fileName=starPath, tableName=particlesName)
        particlesTable.removeColumns("rlnImageId")

        with open(starPath, "w") as starFile:
            particlesTable.writeStar(starFile, tableName=particlesName)
            if opticsTable is not None:
                opticsTable.writeStar(starFile, tableName='optics')

        convert.createReader().readSetOfParticles(
            starPath, outputSet,
            alignType=ALIGN_PROJ,
            postprocessImageRow=self._postprocessImageRow)

        self._defineOutputs(outputParticles=outputSet)
        self._defineSourceRelation(inputSet, outputSet)
Exemplo n.º 2
0
    def _defineParamDict(self):
        """ Fill self.paramDict with the relion_postprocess arguments. """
        # It seems that in Relion3 the input should now be the map
        # filename and not the prefix as before
        halfMapFn = self._getFileName('half1')

        self.paramDict = {
            '--i': halfMapFn,
            '--o': self._getExtraPath('postprocess'),
            '--angpix': self._getOutputPixelSize(),
            # Expert params
            '--filter_edge_width': self.filterEdgeWidth.get(),
            '--randomize_at_fsc': self.randomizeAtFsc.get(),
            '--mask': self._getFileName('mask'),
        }
        args = self.paramDict

        mtfPath = self.mtf.get()
        if mtfPath:
            args['--mtf'] = mtfPath

        # B-factor: either estimated automatically or provided by the user
        if self.doAutoBfactor:
            args.update({
                '--auto_bfac': '',
                '--autob_lowres': self.bfactorLowRes.get(),
                '--autob_highres': self.bfactorHighRes.get(),
            })
        else:
            args['--adhoc_bfac'] = self.bfactor.get()

        if self.skipFscWeighting:
            args.update({
                '--skip_fsc_weighting': '',
                '--low_pass': self.lowRes.get(),
            })

        # -1.0 is treated as "no original pixel size given"
        if Plugin.IS_GT30():
            if self.origPixelSize.get() != -1.0:
                args['--mtf_angpix'] = self.origPixelSize.get()
Exemplo n.º 3
0
    def _insertReconstructStep(self):
        """ Compose the command-line arguments and insert 'reconstructStep'.

        The arguments are built from the input particles and the protocol
        parameters (symmetry, max resolution, padding, subset, optional
        class number, CTF flags and the user's extra parameters) and are
        passed as a single string to the inserted function step.
        """
        imgSet = self.inputParticles.get()

        # The enum value is read once with get(), like every other param
        # here, so that '%d' always receives a plain int. The value 0 is
        # passed on the command line as -1.
        subset = self.subset.get()
        if subset == 0:
            subset = -1

        args = [
            ' --i %s' % self._getFileName('input_particles'),
            ' --o %s' % self._getFileName('output_volume'),
            ' --sym %s' % self.symmetryGroup.get(),
            ' --angpix %0.5f' % imgSet.getSamplingRate(),
            ' --maxres %0.3f' % self.maxRes.get(),
            ' --pad %0.3f' % self.pad.get(),
            ' --subset %d' % subset,
        ]

        if Plugin.IS_GT30():
            args.append(' --class %d' % self.classNum.get())

        if self.doCTF:
            args.append(' --ctf')
            if self.ctfIntactFirstPeak:
                args.append(' --ctf_intact_first_peak')

            if imgSet.isPhaseFlipped():
                args.append(' --ctf_phase_flipped')

        if self.extraParams.hasValue():
            args.append(" " + self.extraParams.get())

        self._insertFunctionStep('reconstructStep', ''.join(args))
Exemplo n.º 4
0
    def _getMdOut(self, it, prefix, ref3d):
        """ Return the rows of the data star file that belong to a class.

        Params:
            it: iteration used to locate the data star file.
            prefix: prefix used to get the random set and the star file.
            ref3d: class number that the rows must match.
        Return:
            list of rows whose rlnClassNumber equals ref3d; when the random
            set is 1 or 2, rlnRandomSubset must match it as well.
        """
        halfId = self._getRandomSet(prefix)
        starFn = self._getDataStar(prefix, it)
        if Plugin.IS_GT30():
            tblName = 'particles'
        else:
            tblName = None

        def _isSelected(row):
            # The random-subset check comes first and only for sets 1 or 2
            if 0 < halfId < 3 and int(row.rlnRandomSubset) != halfId:
                return False
            return int(row.rlnClassNumber) == ref3d

        return [row for row in Table(fileName=starFn, tableName=tblName)
                if _isSelected(row)]
    def test_fromStar(self):
        """ Read optics groups from a particles star file and check the
        columns and values of the first group.
        """
        if not Plugin.IS_GT30():
            print("Skipping test (required Relion > 3.1)")
            return

        starPath = self.ds.getFile("Extract/job018/particles.star")
        print("<<< Reading optics groups from file: \n   %s\n" % starPath)

        groups = OpticsGroups.fromStar(starPath)
        firstGroup = groups.first()

        # Column name -> value expected in the first optics group
        expected = [('rlnMtfFileName', 'mtf_k2_200kV.star'),
                    ('rlnOpticsGroupName', 'opticsGroup1')]

        # test hasColumn method
        for column, _ in expected:
            self.assertTrue(groups.hasColumn(column))

        for column, value in expected:
            self.assertEqual(getattr(firstGroup, column), value)

        # Access by name must return the same group
        self.assertEqual(groups['opticsGroup1'], firstGroup)
Exemplo n.º 6
0
    def _defineParams(self, form):
        """ Define the protocol form: the 'Input' section, the 'CTF'
        section and the parallel section.
        """
        # Longer help texts are kept apart from the param declarations
        symmetryHelp = (
            'See [[Relion Symmetry][http://www2.mrc-lmb.cam.ac.uk/'
            'relion/index.php/Conventions_%26_File_formats#Symmetry]] '
            'page for a description of the symmetry format '
            'accepted by Relion')
        maxResHelp = (
            'Maximum resolution (in Angstrom) to consider \n'
            'in Fourier space (default Nyquist).')
        extraParamsHelp = (
            'Extra parameters to *relion_reconstruct* program. '
            'Address to Relion to see full list of options.')

        # ------------------------- Input section -------------------------
        form.addSection(label='Input')

        form.addParam('inputParticles', PointerParam,
                      pointerClass='SetOfParticles',
                      pointerCondition='hasAlignmentProj',
                      label="Input particles",
                      help='Select the input images from the project.')
        form.addParam('symmetryGroup', StringParam, default='c1',
                      label="Symmetry group",
                      help=symmetryHelp)
        form.addParam('maxRes', FloatParam, default=-1,
                      label="Maximum resolution (A)",
                      help=maxResHelp)
        form.addParam('pad', FloatParam, default=2,
                      label="Padding factor")
        form.addParam('subset', EnumParam, default=0,
                      choices=['all', 'half1', 'half2'],
                      display=EnumParam.DISPLAY_HLIST,
                      label='Subset to reconstruct',
                      help='Subset of images to consider.')

        # The class selection is only offered when Plugin.IS_GT30() holds
        if Plugin.IS_GT30():
            form.addParam('classNum', IntParam, default=-1,
                          label='Use only this class',
                          help='Consider only this class (-1: use all classes)')

        form.addParam('extraParams', StringParam, default='',
                      expertLevel=LEVEL_ADVANCED,
                      label='Extra parameters: ',
                      help=extraParamsHelp)

        # -------------------------- CTF section --------------------------
        form.addSection('CTF')
        form.addParam('doCTF', BooleanParam, default=False,
                      label='Apply CTF correction?')
        form.addParam('ctfIntactFirstPeak', BooleanParam, default=False,
                      condition='doCTF',
                      label='Leave CTFs intact until first peak?')

        form.addParallelSection(threads=0, mpi=1)
    def test_readSetOfParticles(self):
        """ Read particles from a star file (with one extra label) and
        check the first particle and the first optics group.
        """
        if not Plugin.IS_GT30():
            print("Skipping test (required Relion > 3.1)")
            return

        starPath = self.ds.getFile("Extract/job018/particles.star")
        particles = self.__readParticles(
            starPath, extraLabels=['rlnNrOfSignificantSamples'])
        particles.write()

        firstPart = particles.getFirstItem()
        firstPart.printAll()

        # Values expected for the first particle of the file
        self.assertAlmostEqual(firstPart.getSamplingRate(), 1.244531)
        self.assertEqual(firstPart.getClassId(), 4)
        # The requested extra label is stored with a leading underscore
        self.assertTrue(hasattr(firstPart, '_rlnNrOfSignificantSamples'))

        firstGroup = OpticsGroups.fromImages(particles).first()
        self.assertEqual(firstGroup.rlnMtfFileName, 'mtf_k2_200kV.star')
        self.assertEqual(firstGroup.rlnOpticsGroupName, 'opticsGroup1')
    def _importMovies(self):
        """ Launch the movies import and validate its output.

        Return:
            the import protocol when Plugin.IS_30() holds; otherwise the
            launched 'assign optics' protocol fed with the imported movies.
        """
        print(magentaStr("\n==> Importing data - movies:"))

        importArgs = dict(
            filesPath=self.ds.getFile('Movies/'),
            filesPattern='*.tiff',
            samplingRateMode=0,
            samplingRate=0.885,
            magnification=50000,
            scannedPixelSize=7.0,
            voltage=200,
            sphericalAberration=1.4,
            doseInitial=0.0,
            dosePerFrame=1.277,
            gainFile=self.ds.getFile("Movies/gain.mrc"),
        )
        protImport = self.newProtocol(ProtImportMovies, **importArgs)
        protImport.setObjLabel('import 24 movies')
        protImport.setObjComment('Relion 3 tutorial movies:\n\n'
                                 'Microscope Jeol Cryo-ARM 200\n'
                                 'Data courtesy of Takyuki Kato in the Namba '
                                 'group\n(Osaka University, Japan)')
        protImport = self.launchProtocol(protImport)

        # Validate output movies
        outMovies = getattr(protImport, 'outputMovies', None)
        self.assertIsNotNone(outMovies,
                             "No movies were generated from the import")
        self.assertEqual((3710, 3838, 24), outMovies.getDim())
        self.assertEqual(24, outMovies.getSize())

        if Plugin.IS_30():
            return protImport

        # Otherwise the movies also go through the optics group assignment
        print(magentaStr("\n==> Testing relion - assign optic groups:"))
        protAssign = self.newProtocol(ProtRelionAssignOpticsGroup,
                                      objLabel='assign optics',
                                      opticsGroupName='OpticsGroup1')
        protAssign.inputSet.set(protImport.outputMovies)
        return self.launchProtocol(protAssign)
Exemplo n.º 9
0
def convertMask(img, outputPath, newPix=None, newDim=None, threshold=True):
    """ Convert a mask to the mrc format read by Relion.

    The conversion is done by running relion_image_handler.

    Params:
        img: input image to be converted.
        outputPath: either a directory or a file path. For a directory,
            the output name is derived from the input name (with 'mrc'
            extension) and placed there; otherwise it is used as the
            output filename.
        newPix: output pixel size (no rescaling if None).
        newDim: output box size (box is not changed if None).
        threshold: if true, '--threshold_above 1 --threshold_below 0'
            is added to the command.
    Return:
        new file name of the mask.
    """
    srcIndex, srcFile = img.getLocation()
    srcFn = locationToRelion(srcIndex, srcFile)
    srcPix = img.getSamplingRate()

    dstFn = (os.path.join(outputPath, pwutils.replaceBaseExt(srcFn, 'mrc'))
             if os.path.isdir(outputPath) else outputPath)

    # Collect the argument groups; they are joined with single spaces
    argList = ['--i %s --o %s --angpix %0.5f' % (srcFn, dstFn, srcPix)]

    if newPix is not None:
        argList.append('--rescale_angpix %0.5f' % newPix)

    if newDim is not None:
        argList.append('--new_box %d' % newDim)

    if threshold:
        argList.append('--threshold_above 1 --threshold_below 0')

    pwutils.runJob(None,
                   'relion_image_handler',
                   ' '.join(argList),
                   env=Plugin.getEnviron())

    return dstFn
    def test_readSetOfParticlesAfterCtf(self):
        """ Check that extra particle labels and optics group labels are
        read from a star file and are still there after writing the
        particles to a new star file.
        """
        if not Plugin.IS_GT30():
            print("Skipping test (required Relion > 3.1)")
            return

        opticsLabels = ['rlnBeamTiltX', 'rlnBeamTiltY']
        particleLabels = ['rlnCtfBfactor', 'rlnCtfScalefactor',
                          'rlnPhaseShift']

        inputStar = self.ds.getFile(
            "CtfRefine/job023/particles_ctf_refine.star")
        inputReader = Table.Reader(inputStar, tableName='particles')
        inputRow = inputReader.getRow()

        particles = self.__readParticles(inputStar)
        firstPart = particles.getFirstItem()

        def _checkParticleLabels(starRow):
            # Each label of the first particle must match the star row
            for label in particleLabels:
                value = getattr(firstPart, '_%s' % label)
                self.assertIsNotNone(value, "Missing label: %s" % label)
                self.assertAlmostEqual(getattr(starRow, label), value)

        def _checkOpticsLabels(group):
            self.assertTrue(all(hasattr(group, label)
                                for label in opticsLabels))

        _checkParticleLabels(inputRow)
        _checkOpticsLabels(OpticsGroups.fromImages(particles).first())

        # Also test writing and preserving extra labels
        outputStar = self.getOutputPath('particles.star')
        print(">>> Writing to particles star: %s" % outputStar)
        writer = convert.createWriter()
        writer.writeSetOfParticles(particles, outputStar)

        _checkOpticsLabels(OpticsGroups.fromStar(outputStar).first())

        outputReader = Table.Reader(outputStar, tableName='particles')
        outputRow = outputReader.getRow()
        _checkParticleLabels(outputRow)
    def test_string(self):
        """ Create optics groups from keyword values and check that the
        first group can be updated both by id and by name.
        """
        if not Plugin.IS_GT30():
            print("Skipping test (required Relion > 3.1)")
            return

        groupName = 'opticsGroup1'
        groups = OpticsGroups.create(rlnMtfFileName='mtf_k2_200kV.star')

        firstGroup = groups.first()
        self.assertEqual(firstGroup.rlnMtfFileName, 'mtf_k2_200kV.star')
        self.assertEqual(firstGroup.rlnOpticsGroupName, groupName)
        self.assertEqual(groups[groupName], firstGroup)

        # Update once addressing the group by id and once by name
        groups.update(1, rlnMtfFileName="new_mtf_k2.star")
        groups.update(groupName, rlnImageSize=512)

        # Both updates must be visible in the first group
        firstGroup = groups.first()
        self.assertEqual(firstGroup.rlnMtfFileName, 'new_mtf_k2.star')
        self.assertEqual(firstGroup.rlnImageSize, 512)
        self.assertEqual(firstGroup.rlnOpticsGroupName, groupName)
        self.assertEqual(groups[groupName], firstGroup)
def runRelionProgram(cmd):
    """ Run a command line with the plugin environment and wait for it.

    Params:
        cmd: the whole command line as a single string.
    Return:
        the return code of the process.
    """
    import shlex

    print(">>>", cmd)
    # shlex follows shell quoting rules, so a quoted argument containing
    # whitespace stays a single argument (str.split would break it apart).
    args = shlex.split(cmd)
    # The context manager releases the process resources on exit
    with subprocess.Popen(args, env=Plugin.getEnviron()) as p:
        return p.wait()
Exemplo n.º 13
0
class ProtRelionRefine3D(ProtRefine3D, ProtRelionBase):
    """ Protocol to refine a 3D map using Relion.

Relion employs an empirical Bayesian approach to refinement
of (multiple) 3D reconstructions
or 2D class averages in electron cryo-microscopy (cryo-EM). Many
parameters of a statistical model are learned from the data, which
leads to objective and high-quality results.
    """
    _label = '3D auto-refine'
    IS_CLASSIFY = False
    # The name of the translation accuracy label depends on Plugin.IS_GT30()
    CHANGE_LABELS = ['rlnChangesOptimalOrientations',
                     'rlnChangesOptimalOffsets',
                     'rlnOverallAccuracyRotations',
                     'rlnOverallAccuracyTranslationsAngst' if Plugin.IS_GT30() else 'rlnOverallAccuracyTranslations']

    PREFIXES = ['half1_', 'half2_']

    def __init__(self, **args):
        ProtRelionBase.__init__(self, **args)

    def _initialize(self):
        """ This function is meant to be called after the
        working dir for the protocol has been set
        (maybe after recovery from mapper).
        """
        ProtRelionBase._initialize(self)
        self.ClassFnTemplate = '%(ref)03d@%(rootDir)s/relion_it%(iter)03d_classes.mrcs'

    # -------------------------- INSERT steps functions -----------------------
    def _setSamplingArgs(self, args):
        """ Set sampling related params in the 'args' dict.

        The auto-refine specific arguments are only added when this is
        not a continue run.
        """
        args['--auto_local_healpix_order'] = self.localSearchAutoSamplingDeg.get()

        if not self.doContinue:
            args['--healpix_order'] = self.angularSamplingDeg.get()
            args['--offset_range'] = self.offsetSearchRangePix.get()
            f = self._getSamplingFactor()
            args['--offset_step'] = self.offsetSearchStepPix.get() * f
            args['--auto_refine'] = ''
            args['--split_random_halves'] = ''

            # Do not override the option if the user already passed it in
            # the extra params; an unset value is treated as empty text.
            joinHalves = "--low_resol_join_halves"
            if joinHalves not in (self.extraParams.get() or ''):
                args['--low_resol_join_halves'] = 40

            if self.IS_GT30() and self.useFinerSamplingFaster:
                args['--auto_ignore_angles'] = ''
                args['--auto_resol_angles'] = ''

    # -------------------------- STEPS functions ------------------------------
    def createOutputStep(self):
        """ Define the outputs: the refined volume (with its half maps),
        the particles filled from the last iteration and the FSC read
        from 'relion_model.star'.
        """
        imgSet = self._getInputParticles()
        vol = Volume()
        vol.setFileName(self._getExtraPath('relion_class001.mrc'))
        vol.setSamplingRate(imgSet.getSamplingRate())
        half1 = self._getFileName("final_half1_volume", ref3d=1)
        half2 = self._getFileName("final_half2_volume", ref3d=1)
        vol.setHalfMaps([half1, half2])

        outImgSet = self._createSetOfParticles()
        outImgSet.copyInfo(imgSet)
        self._fillDataFromIter(outImgSet, self._lastIter())

        self._defineOutputs(outputVolume=vol)
        self._defineSourceRelation(self.inputParticles, vol)
        self._defineOutputs(outputParticles=outImgSet)
        self._defineTransformRelation(self.inputParticles, outImgSet)

        fsc = FSC(objLabel=self.getRunName())
        fn = self._getExtraPath("relion_model.star")
        table = Table(fileName=fn, tableName='model_class_1')
        resolution_inv = table.getColumnValues('rlnResolution')
        frc = table.getColumnValues('rlnGoldStandardFsc')
        fsc.setData(resolution_inv, frc)

        self._defineOutputs(outputFSC=fsc)
        self._defineSourceRelation(vol, fsc)

    # -------------------------- INFO functions -------------------------------
    def _validateNormal(self):
        """ Return the list of validation errors for a normal run. """
        errors = []

        if self.IS_3D and self.solventFscMask and not self.referenceMask.get():
            errors.append('When using solvent-corrected FSCs, '
                          'please provide a reference mask.')

        return errors

    def _validateContinue(self):
        """ Return the list of validation errors for a continue run: the
        requested iteration cannot be greater than the last iteration of
        the run to continue from.
        """
        errors = []
        continueRun = self.continueRun.get()
        continueRun._initialize()
        lastIter = continueRun._lastIter()

        if self.continueIter.get() == 'last':
            continueIter = lastIter
        else:
            continueIter = int(self.continueIter.get())

        if continueIter > lastIter:
            errors += ["You can continue only from the iteration %01d or less" % lastIter]

        return errors

    def _summaryNormal(self):
        """ Return the summary lines with the current or final resolution,
        read from the 'model_general' table of the model star file.
        """
        summary = []
        if not hasattr(self, 'outputVolume'):
            summary.append("Output volume not ready yet.")
            it = self._lastIter() or -1
            if it >= 1 and it > self._getContinueIter():
                table = Table(fileName=self._getFileName('half1_model', iter=it),
                              tableName='model_general')
                row = table[0]
                resol = float(row.rlnCurrentResolution)
                summary.append("Current resolution: *%0.2f A*" % resol)
        else:
            table = Table(fileName=self._getFileName('modelFinal'),
                          tableName='model_general')
            row = table[0]
            resol = float(row.rlnCurrentResolution)
            summary.append("Final resolution: *%0.2f A*" % resol)

        return summary

    def _summaryContinue(self):
        """ Return the summary lines for a continue run. """
        return ["Continue from iteration %01d" % self._getContinueIter()]

    # -------------------------- UTILS functions ------------------------------
    def _fillDataFromIter(self, imgSet, iteration):
        """ Fill imgSet with the input particles, updated from the rows of
        the data star file of the given iteration.
        """
        tableName = 'particles@' if self.IS_GT30() else ''
        outImgsFn = self._getFileName('data', iter=iteration)
        imgSet.setAlignmentProj()
        self.reader = convert.createReader(alignType=ALIGN_PROJ,
                                           pixelSize=imgSet.getSamplingRate())
        # A new reader has no extra labels created yet, so the first
        # particle updated with it must create them again.
        self._updatingFirst = True

        mdIter = Table.iterRows(tableName + outImgsFn, key='rlnImageId')
        imgSet.copyItems(self._getInputParticles(), doClone=False,
                         updateItemCallback=self._updateParticle,
                         itemDataIterator=mdIter)

    def _updateParticle(self, particle, row):
        """ Callback to update a particle from its star file row. The extra
        labels are created with the first particle and only set afterwards.
        """
        self.reader.setParticleTransform(particle, row)

        # NOTE: the flag uses a single underscore on purpose. An attribute
        # written as 'self.__name' inside the class is name-mangled, while
        # the string given to getattr is not, so such a flag was never
        # found and the extra labels were created for every particle.
        if getattr(self, '_updatingFirst', True):
            self.reader.createExtraLabels(particle, row, PARTICLE_EXTRA_LABELS)
            self._updatingFirst = False
        else:
            self.reader.setExtraLabels(particle, row)
Exemplo n.º 14
0
 def IS_GT30(self):
     """ Return the result of Plugin.IS_GT30(). """
     return Plugin.IS_GT30()
Exemplo n.º 15
0
    def _defineParams(self, form):
        """ Define the protocol form with the 'Input', 'Sharpening' and
        'Filtering' sections plus the parallel section.
        """
        # ------------------------- Input section -------------------------
        form.addSection(label='Input')

        refineHelp = ('Select any previous refinement protocol to get the '
                      '3D half maps. Note that it is recommended that the '
                      'refinement protocol uses a gold-standard method.')
        form.addParam('protRefine', params.PointerParam,
                      pointerClass="ProtRefine3D",
                      label='Select a previous refinement protocol',
                      help=refineHelp)

        maskHelp = ("Provide a soft mask where the protein is white "
                    "(1) and the solvent is black (0). Often, the "
                    "softer the mask the higher resolution estimates "
                    "you will get. A soft edge of 5-10 pixels is often "
                    "a good edge width.")
        form.addParam('solventMask', params.PointerParam,
                      pointerClass="VolumeMask",
                      label='Solvent mask',
                      help=maskHelp)

        pixelSizeHelp = ("Provide the final, calibrated pixel size in "
                         "Angstroms. If 0, the input pixel size will be used. "
                         "This value may be different from the pixel-size "
                         "used thus far, e.g. when you have recalibrated "
                         "the pixel size using the fit to a PDB model. "
                         "The X-axis of the output FSC plot will use this "
                         "calibrated value.")
        form.addParam('calibratedPixelSize', params.FloatParam, default=0,
                      label='Calibrated pixel size (A)',
                      help=pixelSizeHelp)

        # ----------------------- Sharpening section ----------------------
        form.addSection(label='Sharpening')
        mtfGroup = form.addGroup('MTF')

        mtfHelp = ('User-provided STAR-file with the MTF-curve '
                   'of the detector. Use the wizard to load one '
                   'of the predefined ones provided at:\n'
                   '- [[https://www3.mrc-lmb.cam.ac.uk/relion/index.php/'
                   'FAQs#Where_can_I_find_MTF_curves_for_typical_detectors.3F]'
                   '[Relion\'s Wiki FAQs]]\n'
                   ' - [[https://www.gatan.com/techniques/cryo-em#MTF][Gatan\'s website]]\n\n'
                   'Relion param: *--mtf*')
        mtfGroup.addParam('mtf', params.FileParam,
                          label='MTF of the detector',
                          help=mtfHelp)

        # The detector pixel size is only offered when Plugin.IS_GT30() holds
        if Plugin.IS_GT30():
            origPixHelp = ('This is the original pixel size (in Angstroms)'
                           ' in the raw (non-super-resolution!) micrographs')
            mtfGroup.addParam('origPixelSize', params.FloatParam,
                              default=-1.0,
                              label='Original detector pixel size (A)',
                              help=origPixHelp)

        autoBfactorHelp = ('If set to Yes, then the program will use the '
                           'automated procedure described by Rosenthal and '
                           'Henderson (2003, JMB) to estimate an overall '
                           'B-factor for your map, and sharpen it accordingly.')
        form.addParam('doAutoBfactor', params.BooleanParam, default=True,
                      label='Estimate B-factor automatically?',
                      help=autoBfactorHelp)

        # Resolution range, shown only for the automatic B-factor
        bfactorLineHelp = ('There are the frequency (in Angstroms), '
                           'lowest and highest, that will be included in '
                           'the linear fit of the Guinier plot as '
                           'described in Rosenthal and Henderson '
                           '(2003, JMB).')
        bfactorLine = form.addLine('B-factor resolution (A): ',
                                   condition='doAutoBfactor',
                                   help=bfactorLineHelp)
        bfactorLine.addParam('bfactorLowRes', params.FloatParam,
                             default=10.0, label='low')
        bfactorLine.addParam('bfactorHighRes', params.FloatParam,
                             default=0.0, label='high')

        bfactorHelp = ('User-provided B-factor (in A^2) for map '
                       'sharpening, e.g. -400. Use negative values for '
                       'sharpening. Be careful: if you over-sharpen\n'
                       'your map, you may end up interpreting noise for '
                       'signal!\n'
                       'Relion param: *--adhoc_bfac*')
        form.addParam('bfactor', params.FloatParam, default=-350,
                      condition='not doAutoBfactor',
                      label='Provide B-factor:',
                      help=bfactorHelp)

        # ----------------------- Filtering section -----------------------
        form.addSection(label='Filtering')

        skipFscHelp = ('If set to No (the default), then the output map '
                       'will be low-pass filtered according to the '
                       'mask-corrected, gold-standard FSC-curve. '
                       'Sometimes, it is also useful to provide an ad-hoc '
                       'low-pass filter (option below), as due to local '
                       'resolution variations some parts of the map may '
                       'be better and other parts may be worse than the '
                       'overall resolution as measured by the FSC. In '
                       'such  cases, set this option to Yes and provide '
                       'an ad-hoc filter as described below.')
        form.addParam('skipFscWeighting', params.BooleanParam, default=False,
                      label='Skip FSC-weighting for sharpening?',
                      help=skipFscHelp)

        lowResHelp = ('This option allows one to low-pass filter the map '
                      'at a user-provided frequency (in Angstroms). When '
                      'using a resolution that is higher than the '
                      'gold-standard FSC-reported resolution, take care '
                      'not to interpret noise in the map for signal...')
        form.addParam('lowRes', params.FloatParam, default=5,
                      condition='skipFscWeighting',
                      label='Ad-hoc low-pass filter (A):',
                      help=lowResHelp)

        edgeWidthHelp = ('Width of the raised cosine on the low-pass filter '
                         'edge (in resolution shells)\n'
                         'Relion param: *--filter_edge_width*')
        form.addParam('filterEdgeWidth', params.IntParam, default=2,
                      expertLevel=params.LEVEL_ADVANCED,
                      label='Low-pass filter edge width:',
                      help=edgeWidthHelp)

        randomizeHelp = ('Randomize phases from the resolution where FSC '
                         'drops below this value\n'
                         'Relion param: *--randomize_at_fsc*')
        form.addParam('randomizeAtFsc', params.FloatParam, default=0.8,
                      expertLevel=params.LEVEL_ADVANCED,
                      label='Randomize phases threshold',
                      help=randomizeHelp)

        form.addParallelSection(threads=0, mpi=1)
class ProtRelionClassify3D(ProtClassify3D, ProtRelionBase):
    """
    Protocol to classify 3D using Relion Bayesian approach.
    Relion employs an empirical Bayesian approach to refinement of (multiple)
    3D reconstructions or 2D class averages in electron cryo-EM. Many
    parameters of a statistical model are learned from the data, which
    leads to objective and high-quality results.
    """

    _label = '3D classification'
    # The name of the translation accuracy label depends on Plugin.IS_GT30()
    CHANGE_LABELS = ['rlnChangesOptimalOrientations',
                     'rlnChangesOptimalOffsets',
                     'rlnOverallAccuracyRotations',
                     'rlnOverallAccuracyTranslationsAngst' if Plugin.IS_GT30() else 'rlnOverallAccuracyTranslations',
                     'rlnChangesOptimalClasses']

    def __init__(self, **args):
        ProtRelionBase.__init__(self, **args)

    def _initialize(self):
        """ This function is meant to be called after the
        working dir for the protocol has been set
        (maybe after recovery from mapper).
        """
        ProtRelionBase._initialize(self)

    # -------------------------- INSERT steps functions -----------------------
    def _setSamplingArgs(self, args):
        """ Set sampling related params in the 'args' dict. When image
        alignment is disabled only '--skip_align' is added.
        """
        if self.doImageAlignment:
            args['--healpix_order'] = self.angularSamplingDeg.get()
            args['--offset_range'] = self.offsetSearchRangePix.get()
            args['--offset_step'] = self.offsetSearchStepPix.get() * self._getSamplingFactor()

            if self.localAngularSearch:
                args['--sigma_ang'] = self.localAngularSearchRange.get() / 3.

            # Use the same Plugin name as the class body (CHANGE_LABELS)
            # instead of going through the 'relion' module name.
            if Plugin.IS_GT30() and self.allowCoarserSampling:
                args['--allow_coarser_sampling'] = ''

        else:
            args['--skip_align'] = ''

    # -------------------------- STEPS functions ------------------------------
    def createOutputStep(self):
        """ Define the outputs: the 3D classes filled from the last
        iteration and a set of volumes with the class representatives.
        """
        partSet = self.inputParticles.get()
        classes3D = self._createSetOfClasses3D(partSet)
        self._fillClassesFromIter(classes3D, self._lastIter())

        self._defineOutputs(outputClasses=classes3D)
        self._defineSourceRelation(self.inputParticles, classes3D)

        # create a SetOfVolumes and define its relations
        volumes = self._createSetOfVolumes()
        volumes.setSamplingRate(partSet.getSamplingRate())

        for class3D in classes3D:
            vol = class3D.getRepresentative()
            # Each volume takes the id of the class it represents
            vol.setObjId(class3D.getObjId())
            volumes.append(vol)

        self._defineOutputs(outputVolumes=volumes)
        self._defineSourceRelation(self.inputParticles, volumes)

        if not self.doContinue:
            self._defineSourceRelation(self.referenceVolume, classes3D)
            self._defineSourceRelation(self.referenceVolume, volumes)

    # -------------------------- INFO functions -------------------------------
    def _validateNormal(self):
        """ Return the list of validation errors for a normal run (there
        are no checks, so it is always empty).
        """
        errors = []
        return errors

    def _validateContinue(self):
        """ Return the list of validation errors for a continue run: the
        requested iteration cannot be greater than the last iteration of
        the run to continue from.
        """
        errors = []
        continueRun = self.continueRun.get()
        continueRun._initialize()
        lastIter = continueRun._lastIter()

        if self.continueIter.get() == 'last':
            continueIter = lastIter
        else:
            continueIter = int(self.continueIter.get())

        if continueIter > lastIter:
            errors += ["You can continue only from the iteration %01d or less" % lastIter]

        return errors

    def _summaryNormal(self):
        """ Return the summary lines: the current resolution (when at least
        one iteration is available), the input size and the number of
        classes.
        """
        summary = []
        it = self._lastIter() or -1
        if it >= 1:
            table = Table(fileName=self._getFileName('model', iter=it),
                          tableName='model_general')
            row = table[0]
            resol = float(row.rlnCurrentResolution)
            summary.append("Current resolution: *%0.2f A*" % resol)

        inputParts = self.inputParticles.get()
        sizeStr = 'None' if inputParts is None else inputParts.getSize()
        # The param value is read with get(), as done in _methods, so that
        # '%d' receives a plain int.
        summary.append("Input Particles: *%s*\n"
                       "Classified into *%d* 3D classes\n"
                       % (sizeStr, self.numberOfClasses.get()))

        return summary

    def _summaryContinue(self):
        """ Return the summary lines for a continue run. """
        summary = list()
        summary.append("Continue from iteration %01d" % self._getContinueIter())
        return summary

    def _methods(self):
        """ Return the methods text, which stays empty until the output
        classes are available.
        """
        strline = ''
        if hasattr(self, 'outputClasses'):
            strline += 'We classified %d particles into %d 3D classes using Relion Classify3d. ' %\
                           (self.inputParticles.get().getSize(), self.numberOfClasses.get())
        return [strline]

    # -------------------------- UTILS functions ------------------------------
    def _fillClassesFromIter(self, clsSet, iteration):
        """ Create the SetOfClasses3D from a given iteration. """
        classLoader = convert.ClassesLoader(self, ALIGN_PROJ)
        classLoader.fillClassesFromIter(clsSet, iteration)
Exemplo n.º 17
0
    def _plotClassDistribution(self, paramName=None):
        """ Plot the class distribution over iterations as stacked bars.

        For every iteration from self.firstIter to self.lastIter the
        'model_classes' table of the model star file is read and the
        values of a few labels are collected per class. Only the first
        label (rlnClassDistribution) is drawn: one bar per iteration, with
        one stacked segment per class.

        Params:
            paramName: not used by this method.
        Return:
            a list with the plotter that contains the bar chart.
        """
        labels = ["rlnClassDistribution", "rlnAccuracyRotations"]
        if Plugin.IS_GT30():
            labels.append("rlnAccuracyTranslationsAngst")
        else:
            labels.append("rlnAccuracyTranslations")

        iterations = range(self.firstIter, self.lastIter + 1)
        classInfo = {}

        for it in iterations:
            modelStar = self.protocol._getFileName('model', iter=it)
            table = Table(fileName=modelStar, tableName='model_classes')
            for row in table:
                i, fn = relionToLocation(row.rlnReferenceImage)
                if i == NO_INDEX:  # the case for 3D classes
                    # NOTE: Since there is not a proper ID value in
                    # the classes metadata, we are assuming that class X
                    # has a filename *_classXXX.mrc (as it is in Relion)
                    # and we take the ID from there
                    index = int(fn[-7:-4])
                else:
                    index = i

                if index not in classInfo:
                    classInfo[index] = {label: [] for label in labels}

                for label in labels:
                    classInfo[index][label].append(float(getattr(row, label)))

        xplotter = RelionPlotter()
        xplotter.createSubPlot("Classes distribution over iterations",
                               "Iterations", "Classes Distribution")

        # Only the class distribution is plotted
        distLabel = labels[0]

        ax = xplotter.getLastSubPlot()

        n = len(iterations)
        ind = range(n)
        bottomValues = [0] * n
        width = 0.45  # the width of the bars: can also be len(x) sequence

        def get_cmap(N):
            """Returns a function that maps each index in 0, 1, ... N-1 to a distinct
            RGB color."""
            import matplotlib.cm as cmx
            import matplotlib.colors as colors
            color_norm = colors.Normalize(vmin=0, vmax=N)  # -1)
            scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

            def map_index_to_rgb_color(ind):
                return scalar_map.to_rgba(ind)

            return map_index_to_rgb_color

        cmap = get_cmap(len(classInfo))

        for classId in sorted(classInfo.keys()):
            values = classInfo[classId][distLabel]
            ax.bar(ind,
                   values,
                   width,
                   label='class %s' % classId,
                   bottom=bottomValues,
                   color=cmap(classId))
            # Stack the next class on top of the ones already drawn
            bottomValues = [a + b for a, b in zip(bottomValues, values)]

        ax.get_xaxis().set_ticks([i + 0.25 for i in ind])
        # Label each bar with its iteration number, not its 0-based position
        ax.get_xaxis().set_ticklabels([str(it) for it in iterations])
        ax.legend(loc='upper left', fontsize='xx-small')

        return [xplotter]