Example #1
0
    def _createCluster(self):
        """ Launch a new cluster protocol built from the selected points.

        Called when the 'Create Cluster' button is pressed. Every point in
        SELECTED state is looked up in the protocol's input particles, the
        matching particles are written to a fresh temporary sqlite set,
        and a FlexBatchProtNMACluster protocol pointing at that file (and
        at the current protocol as its NMA dimred input) is launched.
        """
        prot = self.protocol
        project = prot.getProject()
        inputSet = prot.getInputParticles()

        # Dump the selected particles into a clean temporary sqlite file
        sqlitePath = prot._getTmpPath('cluster_particles.sqlite')
        cleanPath(sqlitePath)
        selectedParts = SetOfParticles(filename=sqlitePath)
        selectedParts.copyInfo(inputSet)
        selectedIds = (pt.getId() for pt in self.getData()
                       if pt.getState() == Point.SELECTED)
        for partId in selectedIds:
            selectedParts.append(inputSet[partId])
        selectedParts.write()
        selectedParts.close()

        from continuousflex.protocols.protocol_batch_cluster import FlexBatchProtNMACluster
        clusterProt = project.newProtocol(FlexBatchProtNMACluster)
        label = self.clusterWindow.getClusterName()
        if label:
            clusterProt.setObjLabel(label)
        clusterProt.inputNmaDimred.set(prot)
        clusterProt.sqliteFile.set(sqlitePath)

        project.launchProtocol(clusterProt)
Example #2
0
    def _getRandomSubset(self, imgSet, numOfParts):
        """ Return at most ``numOfParts`` randomly chosen images of imgSet.

        A SetOfClasses2D is first flattened into a set of its class
        representatives (renumbered from 1, sampling rate taken from each
        class). If the resulting set holds no more than ``numOfParts``
        items it is returned as is; otherwise the items are drawn in random
        order into a new set stored in the tmp file 'particles_tmp.sqlite'.
        """
        if not isinstance(imgSet, SetOfClasses2D):
            partSet = imgSet
        else:
            partSet = self._createSetOfParticles("_averages")
            for clsIndex, cls in enumerate(imgSet, start=1):
                rep = cls.getRepresentative()
                rep.setSamplingRate(cls.getSamplingRate())
                rep.setObjId(clsIndex)
                partSet.append(rep)

        if partSet.getSize() <= numOfParts:
            return partSet

        subset = SetOfParticles(
            filename=self._getTmpPath("particles_tmp.sqlite"))
        randomIter = partSet.iterItems(orderBy='RANDOM()', direction='ASC')
        for taken, part in enumerate(randomIter):
            if taken >= numOfParts:
                break
            subset.append(part)
        return subset
    def createSetOfParticles(self, setPartSqliteName, partFn, doCtf=False):
        """ Build and write ``self.partSet``, one particle per matrix in mList.

        Particle i (1-based index inside ``partFn``) gets the i-th matrix of
        the module-level ``mList`` as its projection transform. When
        ``doCtf`` is true it also gets a CTF built from the i-th entries of
        the module-level ``defocusList`` / ``defocusAngle`` tables. The set is
        stored in ``setPartSqliteName``.
        """
        partSet = self.partSet = SetOfParticles(filename=setPartSqliteName)
        partSet.setAlignment(ALIGN_PROJ)
        acquisition = Acquisition(voltage=300,
                                  sphericalAberration=2,
                                  amplitudeContrast=0.1,
                                  magnification=60000)
        partSet.setAcquisition(acquisition)
        partSet.setSamplingRate(samplingRate)
        # NOTE(review): hasCTF is True even when doCtf is False (particles
        # then carry no CTF) — confirm this is intended.
        partSet.setHasCTF(True)
        matrices = [np.array(m) for m in mList]

        for idx, matrix in enumerate(matrices):
            particle = Particle()
            if doCtf:
                # defocusU == defocusV: astigmatism-free CTF
                ctf = CTFModel(defocusU=defocusList[idx],
                               defocusV=defocusList[idx],
                               defocusAngle=defocusAngle[idx])
                ctf.standardize()
                particle.setCTF(ctf)

            particle.setLocation(idx + 1, partFn)
            particle.setTransform(Transform(matrix))
            partSet.append(particle)

        partSet.write()
Example #4
0
 def _loadInputParticleSet(self):
     """ Return a detached copy of the input particle set.

     The input set is re-opened from its sqlite file, all of its
     properties are loaded, and its content is copied into a new
     SetOfParticles. The file-backed set is always closed afterwards,
     even if loading or copying fails.

     :returns: the new SetOfParticles copy.
     """
     partSetFn = self.inputParticles.get().getFileName()
     updatedSet = SetOfParticles(filename=partSetFn)
     copyPartSet = SetOfParticles()
     try:
         updatedSet.loadAllProperties()
         copyPartSet.copy(updatedSet)
     finally:
         # Release the sqlite handle on both the success and error paths
         updatedSet.close()
     return copyPartSet
 def test_particlesImportToStar(self):
     """ Write the case2 imported particles sqlite out as a STAR file. """
     inputSqlite = self.ds.getFile("import/case2/particles.sqlite")
     particles = SetOfParticles(filename=inputSqlite)
     particles.loadAllProperties()
     starFn = self.getOutputPath("particles.star")
     print(">>> Writing to particles star: %s" % starFn)
     writer = convert.createWriter()
     writer.writeSetOfParticles(particles, starFn)
Example #6
0
    def converParticlesStep(self, run):
        """ Write a random subset of the input particles for ``run``.

        At most ``subsetSize * numOfVols`` particles are taken in random
        order; when that product is not positive every particle is taken.
        Each chosen particle is passed to ``_scaleImages`` with a 1-based
        running index before being added to an in-memory subset, which is
        then written to the run's 'input_star' file. CTF groups are split
        afterwards when manual CTF grouping is enabled.
        """
        self._setrLev(run)
        self._imgFnList = []
        imgSet = self._getInputParticles()
        imgStar = self._getFileName('input_star', run=run)
        os.makedirs(self._getExtraPath('run_%02d' % run))

        subset = SetOfParticles(filename=":memory:")
        subset.copyInfo(imgSet)
        subset.setSamplingRate(self._getPixeSize())

        subsetSize = self.subsetSize.get() * self.numOfVols.get()
        minSize = min(subsetSize, imgSet.getSize())
        print("Values MinParticles: ", minSize, subsetSize)
        # A non-positive requested size means "no limit"
        limited = subsetSize > 0
        randomImgs = imgSet.iterItems(orderBy='RANDOM()', direction='ASC')
        for newIndex, img in enumerate(randomImgs, start=1):
            self._scaleImages(newIndex, img)
            subset.append(img)
            if limited and subset.getSize() == minSize:
                break

        writeSetOfParticles(subset, imgStar,
                            outputDir=self._getExtraPath(),
                            alignType=ALIGN_NONE,
                            postprocessImageRow=self._postprocessParticleRow)
        if self.doCtfManualGroups:
            self._splitInCTFGroups(imgStar)
    def _loadInput(self):
        """ Export the input particles to Xmipp metadata, split by creation.

        On the first call (``self.check`` is None) every particle is written
        to ``self.fnInputMd``. On later calls, particles created after
        ``self.check`` go to ``self.fnInputMd`` and those created before it
        go to ``self.fnInputOldMd``. ``self.check`` is then advanced to the
        creation stamp of the newest particle. The sqlite-backed input set
        is always closed, even if an export fails.

        :returns: tuple ``(inputSize, streamClosed)`` of the input set.
        """
        self.lastCheck = datetime.now()
        partsFile = self.inputParticles.get().getFileName()
        inPartsSet = SetOfParticles(filename=partsFile)
        try:
            inPartsSet.loadAllProperties()

            # Creation stamp of the most recent particle (None if empty)
            check = None
            for p in inPartsSet.iterItems(orderBy='creation',
                                          direction='DESC'):
                check = p.getObjCreation()
                break
            if self.check is None:
                writeSetOfParticles(inPartsSet, self.fnInputMd,
                                    alignType=ALIGN_NONE, orderBy='creation')
            else:
                # NOTE(review): particles whose creation equals self.check
                # match neither filter below — confirm this is intended.
                writeSetOfParticles(inPartsSet, self.fnInputMd,
                                    alignType=ALIGN_NONE, orderBy='creation',
                                    where='creation>"' + str(self.check) + '"')
                writeSetOfParticles(inPartsSet, self.fnInputOldMd,
                                    alignType=ALIGN_NONE, orderBy='creation',
                                    where='creation<"' + str(self.check) + '"')
            self.check = check

            streamClosed = inPartsSet.isStreamClosed()
            inputSize = inPartsSet.getSize()
        finally:
            # Release the sqlite handle on both the success and error paths
            inPartsSet.close()

        return inputSize, streamClosed
    def test_mrcsLink(self):
        """ Converting an .mrc particle stack for Relion should only create
        a link with the .mrcs extension; the resulting file mapping is
        printed.
        """
        print(magentaStr("\n==> Testing relion - link mrc stack to mrcs:"))
        stackFn = self.dsEmx.getFile('particles/particles.mrc')
        partSet = SetOfParticles(filename=':memory:')

        # Images 1..9 of the same stack
        for stackIndex in range(1, 10):
            item = Particle()
            item.setLocation(stackIndex, stackFn)
            partSet.append(item)

        filesDict = convert.convertBinaryFiles(partSet, self.getOutputPath())
        print(filesDict)
    def retrieveTrainSets(self):
        """ Retrieve, link and return a SetOfParticles with the precompiled
        DeepConsensus negative training particles.

        The model stack name encodes two extraction conditions, in this
        order: phase flipping and contrast inversion. The chosen stack is
        linked into the tmp folder and each of its images becomes a
        particle (micId 9999) of the returned set, whose alignment is
        ALIGN_NONE.
        """
        prefixYES = ''
        prefixNO = 'no'
        # The template lists PhaseFlip first and Invert second, so the
        # prefixes must be given in that order. Assumption (confirm):
        # particles are phase-flipped exactly when the CTF is NOT ignored.
        phaseFlipPrefix = prefixNO if self.ignoreCTF.get() else prefixYES
        invertPrefix = prefixYES if self.doInvert.get() else prefixNO
        modelType = "negativeTrain_%sPhaseFlip_%sInvert.mrcs" % (
            phaseFlipPrefix, invertPrefix)
        modelPath = xmipp3.Plugin.getModel("deepConsensus", modelType)
        print("Precompiled negative particles found at %s" % (modelPath))
        modelFn = self._getTmpPath(modelType)
        pwutils.createLink(modelPath, modelFn)

        tmpSqliteSuff = "AddTrain"
        partSet = self._createSetOfParticles(tmpSqliteSuff)
        img = SetOfParticles.ITEM_TYPE()

        imgh = ImageHandler()
        _, _, _, n = imgh.getDimensions(modelFn)
        # NOTE(review): a stack holding a single image (n == 1) produces an
        # empty set because of this guard — confirm that is intended.
        if n > 1:
            for index in range(1, n + 1):
                # One item object is reused; its id is cleared before each
                # append (presumably so the set assigns a fresh id).
                img.cleanObjId()
                img.setMicId(9999)
                img.setFileName(modelFn)
                img.setIndex(index)
                partSet.append(img)
        partSet.setAlignment(ALIGN_NONE)

        # Presumably the sqlite file _createSetOfParticles created for this
        # temporary set; removed so it does not persist in the run folder.
        cleanPath(self._getPath("particles%s.sqlite" % tmpSqliteSuff))
        return partSet
 def __readParticles(self, partsStar, outputSqlite=None, **kwargs):
     """ Read ``partsStar`` into a new SetOfParticles stored at outputSqlite.

     ``outputSqlite`` defaults to 'particles.sqlite' under the test output
     path; any existing file there is removed first. Extra keyword
     arguments are forwarded to ``convert.readSetOfParticles``.
     """
     if not outputSqlite:
         outputSqlite = self.getOutputPath('particles.sqlite')
     print("<<< Reading star file: \n   %s\n" % partsStar)
     cleanPath(outputSqlite)
     print(">>> Writing to particles db: \n   %s\n" % outputSqlite)
     particles = SetOfParticles(filename=outputSqlite)
     convert.readSetOfParticles(partsStar, particles, **kwargs)
     return particles
    def importParticles(self):
        """ Import particles and their alignment from a Frealign par file.

        The input stack is linked into the protocol's extra folder and read
        into a temporary in-memory set holding only image locations. The
        alignment parameters from ``self.parFile`` are then read into the
        output set, which is marked as having CTF and registered as
        ``outputParticles``.
        """
        protocol = self.protocol
        partSet = protocol._createSetOfParticles()
        partSet.setObjComment('Particles imported from Frealign parfile:\n%s' %
                              self.parFile)

        # Local link to the input stack, inside the protocol folder
        stackLink = protocol._getExtraPath(os.path.basename(self.stackFile))
        pwutils.createLink(self.stackFile, stackLink)

        # Temporary set that only knows where each image lives
        locationSet = SetOfParticles(filename=':memory:')
        locationSet.readStack(stackLink)
        self._setupSet(locationSet)

        # Sampling rate and acquisition come from the protocol form
        self._setupSet(partSet)
        readSetOfParticles(locationSet, partSet, self.parFile)
        partSet.setHasCTF(True)

        protocol._defineOutputs(outputParticles=partSet)
Example #12
0
 def _checkNewInput(self):
     """ Insert processing steps for input particles not yet processed.

     Re-opens the input particle set, caches clones of its particles in
     ``self.SetOfParticles`` and its stream state in ``self.streamClosed``,
     then compares their ids with the particles listed in the extra file
     'allDone.xmd'. If any input particle is missing there, new steps are
     inserted, made prerequisites of the first join step (if any) and the
     step list is updated.
     """
     inputFn = self.inputParticles.get().getFileName()
     inputSet = SetOfParticles(filename=inputFn)
     inputSet.loadAllProperties()
     # NOTE(review): this attribute name shadows the SetOfParticles class
     # name; kept because other methods may read it.
     self.SetOfParticles = [part.clone() for part in inputSet]
     self.streamClosed = inputSet.isStreamClosed()
     inputSet.close()

     doneSet = self._createSetOfParticles()
     readSetOfParticles(self._getExtraPath("allDone.xmd"), doneSet)
     hasNewParts = any(part.getObjId() not in doneSet
                       for part in self.SetOfParticles)
     outputStep = self._getFirstJoinStep()
     if not hasNewParts:
         return
     newDeps = self._insertNewPartsSteps(self.insertedDict,
                                         self.SetOfParticles)
     if outputStep is not None:
         outputStep.addPrerequisites(*newDeps)
     self.updateSteps()
    def _allParticles(self, iterate=False):
        # A handler function to iterate over the particles
        inputSet = self.inputSet.get()

        if self.isInputClasses():
            iterParticles = inputSet.iterClassItems()
            if iterate:
                return iterParticles
            else:
                particles = SetOfParticles(filename=":memory:")
                particles.copyInfo(inputSet.getFirstItem())
                particles.copyItems(iterParticles)
                return particles
        else:
            if iterate:
                return inputSet.iterItems()
            else:
                return inputSet
    def test_hdfToStk(self):
        """ HDF particle stacks must be converted to a Relion-readable
        format. The progress message says mrcs, while an earlier note said
        spider .stk; TODO confirm. The resulting file mapping is printed.
        """
        print(magentaStr("\n==> Testing relion - convert hdf files to mrcs:"))
        hdfNames = ['BPV_1386_ptcls.hdf',
                    'BPV_1387_ptcls.hdf',
                    'BPV_1388_ptcls.hdf']

        partSet = SetOfParticles(filename=':memory:')
        # First image of each stack
        for hdfName in hdfNames:
            item = Particle()
            item.setLocation(1, self.ds.getFile('particles/%s' % hdfName))
            partSet.append(item)

        filesDict = convert.convertBinaryFiles(partSet, self.getOutputPath())
        partSet.close()
        print(filesDict)
    def createOutputStep(self):
        """ Build the untilted/tilted particle outputs and their pairing.

        Reads the extracted particles from the extra files
        'images_untilted.xmd' and 'images_tilted.xmd' (disabled particles
        included), keeps only pairs whose untilted AND tilted images are
        both enabled, attaches the (optionally rescaled) input coordinates
        to each kept image, and defines a ParticlesTiltPair output
        'outputParticlesTiltPair' related to the input coordinate pairs.

        :raises Exception: if an untilted image and its untilted coordinate
            do not share the same objId.
        """
        fnTilted = self._getExtraPath("images_tilted.xmd")
        fnUntilted = self._getExtraPath("images_untilted.xmd")

        # Create outputs SetOfParticles both for tilted and untilted
        imgSetU = self._createSetOfParticles(suffix="Untilted")
        imgSetU.copyInfo(self.uMics)
        imgSetT = self._createSetOfParticles(suffix="Tilted")
        imgSetT.copyInfo(self.tMics)

        # Output sampling: the micrograph sampling when _micsOther() is
        # true, otherwise the coordinates' sampling; multiplied by the
        # downsampling factor when downsampling was applied.
        sampling = self.getMicSampling() if self._micsOther(
        ) else self.getCoordSampling()
        if self._doDownsample():
            sampling *= self.downFactor.get()
        imgSetU.setSamplingRate(sampling)
        imgSetT.setSamplingRate(sampling)

        # set coords from the input, will update later if needed
        imgSetU.setCoordinates(
            self.inputCoordinatesTiltedPairs.get().getUntilted())
        imgSetT.setCoordinates(
            self.inputCoordinatesTiltedPairs.get().getTilted())

        # Read untilted and tilted particles into temporary in-memory sets
        # (disabled particles included, so they can be filtered pairwise)
        imgSetAuxU = SetOfParticles(filename=':memory:')
        imgSetAuxU.copyInfo(imgSetU)
        readSetOfParticles(fnUntilted, imgSetAuxU, removeDisabled=False)

        imgSetAuxT = SetOfParticles(filename=':memory:')
        imgSetAuxT.copyInfo(imgSetT)
        readSetOfParticles(fnTilted, imgSetAuxT, removeDisabled=False)

        # Scale factor applied to the coordinates:
        # 1 / samplingFactor, further divided by downFactor when downsampling
        factor = 1 / self.samplingFactor
        if self._doDownsample():
            factor /= self.downFactor.get()

        coordsT = self.getCoords().getTilted()
        # Walk untilted images and untilted coordinates in lockstep; the
        # tilted image/coordinate are looked up by the same objId.
        for imgU, coordU in izip(imgSetAuxU, self.getCoords().getUntilted()):
            # FIXME: Remove this check when sure that objIds are equal
            # NOTE(review): `id` shadows the builtin of the same name.
            id = imgU.getObjId()
            if id != coordU.getObjId():
                raise Exception(
                    'ObjIds in untilted img and coord are not the same!!!!')
            imgT = imgSetAuxT[id]
            coordT = coordsT[id]

            # If both particles are enabled append them
            if imgU.isEnabled() and imgT.isEnabled():
                # Coordinates are rescaled in place before being attached
                if self._micsOther() or self._doDownsample():
                    coordU.scale(factor)
                    coordT.scale(factor)
                imgU.setCoordinate(coordU)
                imgSetU.append(imgU)
                imgT.setCoordinate(coordT)
                imgSetT.append(imgT)

        if self.doFlip:
            # Phase-flipped / has-CTF flags follow whether a CTF input was
            # given for each tilt
            imgSetU.setIsPhaseFlipped(self.ctfUntilt.hasValue())
            imgSetU.setHasCTF(self.ctfUntilt.hasValue())
            imgSetT.setIsPhaseFlipped(self.ctfTilt.hasValue())
            imgSetT.setHasCTF(self.ctfTilt.hasValue())
        # Persist both sets before they are iterated for the pairing below
        imgSetU.write()
        imgSetT.write()

        # Define output ParticlesTiltPair: pairs are formed positionally,
        # relying on both sets having been filled in the same order above
        outputset = ParticlesTiltPair(
            filename=self._getPath('particles_pairs.sqlite'))
        outputset.setTilted(imgSetT)
        outputset.setUntilted(imgSetU)
        for imgU, imgT in izip(imgSetU, imgSetT):
            outputset.append(TiltPair(imgU, imgT))

        outputset.setCoordsPair(self.inputCoordinatesTiltedPairs.get())
        self._defineOutputs(outputParticlesTiltPair=outputset)
        self._defineSourceRelation(self.inputCoordinatesTiltedPairs, outputset)
class TestSubProj(BaseTest):
    @classmethod
    def setUpClass(cls):
        """ Run setupTestProject once for the whole test class. """
        setupTestProject(cls)

    def createSetOfParticles(self, setPartSqliteName, partFn, doCtf=False):
        """ Build and write ``self.partSet``, one particle per matrix in mList.

        Particle i (1-based index inside ``partFn``) gets the i-th matrix of
        the module-level ``mList`` as its projection transform. When
        ``doCtf`` is true it also gets a CTF built from the i-th entries of
        the module-level ``defocusList`` / ``defocusAngle`` tables. The set is
        stored in ``setPartSqliteName``.
        """
        partSet = self.partSet = SetOfParticles(filename=setPartSqliteName)
        partSet.setAlignment(ALIGN_PROJ)
        acquisition = Acquisition(voltage=300,
                                  sphericalAberration=2,
                                  amplitudeContrast=0.1,
                                  magnification=60000)
        partSet.setAcquisition(acquisition)
        partSet.setSamplingRate(samplingRate)
        # NOTE(review): hasCTF is True even when doCtf is False (particles
        # then carry no CTF) — confirm this is intended.
        partSet.setHasCTF(True)
        matrices = [np.array(m) for m in mList]

        for idx, matrix in enumerate(matrices):
            particle = Particle()
            if doCtf:
                # defocusU == defocusV: astigmatism-free CTF
                ctf = CTFModel(defocusU=defocusList[idx],
                               defocusV=defocusList[idx],
                               defocusAngle=defocusAngle[idx])
                ctf.standardize()
                particle.setCTF(ctf)

            particle.setLocation(idx + 1, partFn)
            particle.setTransform(Transform(matrix))
            partSet.append(particle)

        partSet.write()

    def createProjection(self, proj, num, baseName):
        """ Write a synthetic 2D projection as image ``num`` of ``baseName``.

        The projSize x projSize float image starts at zero. For every entry
        of ``proj``, entry[4] is ADDED to the pixel addressed by
        entry[0..3], so repeated positions accumulate.
        """
        image = emlib.Image()
        image.setDataType(emlib.DT_FLOAT)
        image.resize(projSize, projSize)
        image.initConstant(0.)

        for entry in proj:
            i0, i1, i2, i3 = entry[0], entry[1], entry[2], entry[3]
            accumulated = image.getPixel(i0, i1, i2, i3) + entry[4]
            image.setPixel(i0, i1, i2, i3, accumulated)

        image.write("%d@" % num + baseName)

    def createVol(self, volume):
        """ Write a synthetic projSize^3 float volume to ``self.volBaseFn``.

        The volume starts at zero. Each entry of ``volume`` SETS the voxel
        addressed by entry[0..3] to entry[4], so a later entry for the same
        position overwrites an earlier one.
        """
        image = emlib.Image()
        image.setDataType(emlib.DT_FLOAT)
        image.resize(projSize, projSize, projSize)
        image.initConstant(0.)

        for entry in volume:
            image.setPixel(entry[0], entry[1], entry[2], entry[3], entry[4])

        image.write(self.volBaseFn)

    def createMask(self, _maskName):
        """ Write a binary spherical mask of size projSize^3 to ``_maskName``.

        Voxels whose squared offset from the centre index
        (projSize // 2 on each axis) is below maskRadius**2 are set to 1.0.
        All other voxels stay at 0.0.
        """
        mask = emlib.Image()
        mask.setDataType(emlib.DT_FLOAT)
        mask.resize(projSize, projSize, projSize)
        mask.initConstant(0.0)

        half = int(projSize / 2)
        radius2 = maskRadius * maskRadius
        offsets = range(-half, half)
        for i in offsets:
            for j in offsets:
                for k in offsets:
                    if i * i + j * j + k * k < radius2:
                        mask.setPixel(0, k + half, i + half, j + half, 1.)

        mask.write(_maskName)

    def applyCTF(self, setPartMd):
        """ Write ``self.partSet`` to ``setPartMd`` and CTF-filter its images.

        A single CTF parameter file ('kk' in the project tmp folder) is
        written from the CTF of the LAST particle of ``self.partSet``,
        including that particle's own defocus angle. ``xmipp_transform_filter``
        is then run on ``setPartMd`` twice: with filter 'ctf' (output
        ``setPartCtfName``) and with filter 'ctfpos' (output
        ``setPartCtfPosName``).

        :raises ValueError: if ``self.partSet`` is empty, since there is no
            CTF to apply.
        """
        writeSetOfParticles(self.partSet, setPartMd)

        # Single CTF parameter file shared by both filter runs
        baseFnCtf = self.proj.getTmpPath("kk")
        md1 = emlib.MetaData()
        md1.setColumnFormat(False)
        idctf = md1.addObject()
        # Values that do not depend on the particle
        # NOTE(review): 200 kV / Q0 0.07 differ from the set's acquisition
        # (300 kV / amplitude contrast 0.1) — confirm this is intended.
        md1.setValue(emlib.MDL_CTF_SAMPLING_RATE, samplingRate, idctf)
        md1.setValue(emlib.MDL_CTF_VOLTAGE, 200., idctf)
        md1.setValue(emlib.MDL_CTF_CS, 2., idctf)
        md1.setValue(emlib.MDL_CTF_Q0, 0.07, idctf)
        md1.setValue(emlib.MDL_CTF_K, 1., idctf)

        # Only one row exists, so the last particle's CTF is what remains;
        # all images are filtered with that single CTF.
        hasParticles = False
        for part in self.partSet:
            hasParticles = True
            ctf = part.getCTF()
            md1.setValue(emlib.MDL_CTF_DEFOCUSU, ctf.getDefocusU(), idctf)
            md1.setValue(emlib.MDL_CTF_DEFOCUSV, ctf.getDefocusV(), idctf)
            md1.setValue(emlib.MDL_CTF_DEFOCUS_ANGLE, ctf.getDefocusAngle(),
                         idctf)
        if not hasParticles:
            raise ValueError("applyCTF: the particle set is empty, "
                             "there is no CTF to apply")
        md1.write(baseFnCtf)

        # apply full ctf
        args = " -i %s" % setPartMd
        args += " -o %s" % self.proj.getTmpPath(setPartCtfName)
        args += " -f ctf %s" % baseFnCtf
        args += " --sampling %f" % samplingRate
        Plugin.runXmippProgram("xmipp_transform_filter", args)

        # apply 'ctfpos' filter
        args = " -i %s" % setPartMd
        args += " -o %s" % self.proj.getTmpPath(setPartCtfPosName)
        args += " -f ctfpos %s" % baseFnCtf
        args += " --sampling %f" % samplingRate
        Plugin.runXmippProgram("xmipp_transform_filter", args)

    def importData(self, baseFn, objLabel, protType, importFrom):
        """ Create, launch and return an import protocol reading ``baseFn``.

        ``baseFn`` is passed as filesPath, maskPath and sqliteFile alike, so
        the same helper serves the particle, volume and mask imports.
        """
        params = dict(objLabel=objLabel,
                      filesPath=baseFn,
                      maskPath=baseFn,
                      sqliteFile=baseFn,
                      haveDataBeenPhaseFlipped=False,
                      magnification=10000,
                      samplingRate=samplingRate,
                      importFrom=importFrom)
        importProt = self.newProtocol(protType, **params)
        self.launchProtocol(importProt)
        return importProt

    def test_pattern(self):
        """ End-to-end projection-subtraction test on synthetic data.

        Generates projections, three particle sets (plain, 'ctf'-filtered
        and 'ctfpos'-filtered), a volume and a spherical mask in the project
        tmp folder, and imports them. It then runs XmippProtSubtractProjection
        with each correction mode, followed by a Relion 3D refinement, a
        Relion mask and ProtRelionSubtract. Only the last protocol's output
        particles are asserted.
        """
        #1) resolve tmp-folder paths for all generated files
        #output stack
        self.setPartName = self.proj.getTmpPath(setPartName)
        self.setPartSqliteName = self.proj.getTmpPath(setPartSqliteName)
        self.setPartSqliteCtfName = self.proj.getTmpPath(setPartSqliteCtfName)
        self.setPartSqliteCTfPosName = self.proj.getTmpPath(
            setPartSqliteCTfPosName)
        self.kksqlite = self.proj.getTmpPath("kk.sqlite")
        self.setPartMd = self.proj.getTmpPath(setPartNameMd)
        self.volBaseFn = self.proj.getTmpPath(volName)
        self.maskName = self.proj.getTmpPath(maskName)

        #2) create projections and sets of particles
        self.createProjection(proj1, 1, self.setPartName)
        self.createProjection(proj2, 2, self.setPartName)
        self.createProjection(proj3, 3, self.setPartName)
        # The CTF / CTF-pos sets point at stacks that applyCTF writes below
        self.createSetOfParticles(self.setPartSqliteCTfPosName,
                                  self.proj.getTmpPath(setPartCtfPosName),
                                  True)
        self.createSetOfParticles(self.setPartSqliteCtfName,
                                  self.proj.getTmpPath(setPartCtfName), True)
        self.createSetOfParticles(self.setPartSqliteName, self.setPartName,
                                  False)
        #3) create auxiliary set of particles; being the last one created,
        # it is the self.partSet that applyCTF uses
        self.createSetOfParticles(self.kksqlite, self.setPartName, True)
        #4) apply CTF
        self.applyCTF(self.setPartMd)
        #5) create volume
        self.createVol(vol1)
        #6) create mask
        self.createMask(self.maskName)

        #import three projection datasets, volume and mask
        protPlainProj = self.importData(
            self.setPartSqliteName, "plain projection", ProtImportParticles,
            ProtImportParticles.IMPORT_FROM_SCIPION)
        protCTFProj = self.importData(self.setPartSqliteCtfName,
                                      "ctf projection", ProtImportParticles,
                                      ProtImportParticles.IMPORT_FROM_SCIPION)
        protCTFposProj = self.importData(
            self.setPartSqliteCTfPosName, "pos ctf projection",
            ProtImportParticles, ProtImportParticles.IMPORT_FROM_SCIPION)
        _protImportVol = self.importData(
            os.path.abspath(self.proj.getTmpPath(volName)), "3D reference",
            ProtImportVolumes, ProtImportParticles.IMPORT_FROM_FILES)
        _protImportMask = self.importData(
            self.proj.getTmpPath(maskName), "3D mask", ProtImportMask,
            ProtImportParticles.IMPORT_FROM_FILES)
        # NOTE(review): `mask` is built but never used below
        mask = VolumeMask()
        mask.setFileName(self.proj.getTmpPath(maskName))
        mask.setSamplingRate(samplingRate)

        #launch Xmipp subtract protocol, once per CTF correction mode
        protSubtract = self.newProtocol(XmippProtSubtractProjection)
        protSubtract.inputParticles.set(protPlainProj.outputParticles)
        protSubtract.inputVolume.set(_protImportVol.outputVolume)
        protSubtract.refMask.set(_protImportMask.outputMask)
        protSubtract.projType.set(XmippProtSubtractProjection.CORRECT_NONE)
        self.launchProtocol(protSubtract)

        protSubtractCTF = self.newProtocol(XmippProtSubtractProjection)
        protSubtractCTF.inputParticles.set(protCTFProj.outputParticles)
        protSubtractCTF.inputVolume.set(_protImportVol.outputVolume)
        protSubtractCTF.refMask.set(_protImportMask.outputMask)
        protSubtractCTF.projType.set(
            XmippProtSubtractProjection.CORRECT_FULL_CTF)
        self.launchProtocol(protSubtractCTF)

        protSubtractCTFpos = self.newProtocol(XmippProtSubtractProjection)
        protSubtractCTFpos.inputParticles.set(protCTFposProj.outputParticles)
        protSubtractCTFpos.inputVolume.set(_protImportVol.outputVolume)
        protSubtractCTFpos.refMask.set(_protImportMask.outputMask)
        protSubtractCTFpos.projType.set(
            XmippProtSubtractProjection.CORRECT_PHASE_FLIP)
        self.launchProtocol(protSubtractCTFpos)

        # Relion pipeline: refine -> mask -> subtract
        protRelionRefine3D = self.newProtocol(ProtRelionRefine3D,
                                              doCTF=False,
                                              runMode=1,
                                              maskDiameterA=340,
                                              symmetryGroup="c1",
                                              numberOfMpi=3,
                                              numberOfThreads=2)
        protRelionRefine3D.inputParticles.set(protCTFProj.outputParticles)
        protRelionRefine3D.referenceVolume.set(_protImportVol.outputVolume)
        protRelionRefine3D.doGpu.set(False)
        self.launchProtocol(protRelionRefine3D)

        protMask = self.newProtocol(ProtRelionCreateMask3D, threshold=0.045)
        protMask.inputVolume.set(protRelionRefine3D.outputVolume)
        self.launchProtocol(protMask)

        # NOTE: rebinds protSubtract (the Xmipp one above is no longer needed)
        protSubtract = self.newProtocol(ProtRelionSubtract,
                                        refMask=protMask.outputMask,
                                        numberOfMpi=2)
        protSubtract.inputProtocol.set(protRelionRefine3D)
        self.launchProtocol(protSubtract)
        self.assertIsNotNone(protSubtract.outputParticles,
                             "There was a problem with subtract projection")

        # NOTE(review): always passes; asserts nothing
        self.assertTrue(True)
    def _createSetOfParts(self, nMics=10, nOptics=2, partsPerMic=10):
        """ Create and write a fake SetOfParticles for tests.

        ``nMics`` micrographs in ``nOptics`` optics groups are created via
        ``_createSetOfMics``. ``partsPerMic`` particles are added per
        micrograph, located at indices 1..partsPerMic of a stack named after
        the micrograph with a '.mrcs' extension, each with a random
        coordinate in [0, 1000) on both axes.

        :returns: the written SetOfParticles, stored as 'particles.sqlite'
            under the test output path.
        """
        import os

        micSet = self._createSetOfMics(nMics, nOptics)
        outputSqlite = self.getOutputPath('particles.sqlite')
        cleanPath(outputSqlite)
        print(">>> Writing to particles db: %s" % outputSqlite)
        outputParts = SetOfParticles(filename=outputSqlite)
        outputParts.setSamplingRate(1.234)
        outputParts.setAcquisition(micSet.getAcquisition())

        # One item/coordinate pair is reused for every append
        part = SetOfParticles.ITEM_TYPE()
        coord = Coordinate()

        for mic in micSet:
            micFn = mic.getFileName()
            root, ext = os.path.splitext(micFn)
            # Swap only the extension: replacing every 'mrc' substring would
            # also rewrite directory names that happen to contain 'mrc'.
            if ext == '.mrc':
                stackFn = root + '.mrcs'
            else:
                stackFn = micFn.replace('mrc', 'mrcs')
            micAcquisition = mic.getAcquisition()

            for i in range(1, partsPerMic + 1):
                part.setLocation(i, stackFn)
                coord.setPosition(x=np.random.randint(0, 1000),
                                  y=np.random.randint(0, 1000))
                coord.setMicrograph(mic)
                # Clear the id so the set assigns a new one on append
                part.setObjId(None)
                part.setCoordinate(coord)
                part.setAcquisition(micAcquisition)
                outputParts.append(part)

        outputParts.write()

        return outputParts
    def launchTest(self, fileKey, mList, alignType=None, **kwargs):
        """ Helper function to launch similar alignment tests
        given the EMX transformation matrix.

        Builds a particle (ALIGN_2D / ALIGN_PROJ) or volume set from the
        stack and matrices, writes it to a STAR file and reads it back, runs
        ``self.CMD`` and compares its output against a gold file when one
        exists.

        Params:
            fileKey: the file where to grab the input stack images.
            mList: the matrix list of transformations
                (should be the same length of the stack of images)
            alignType: alignment type of the set; ALIGN_2D and ALIGN_PROJ
                build particles, anything else builds volumes.
            kwargs: accepted but not used here.

        NOTE: ``self.CMD`` is formatted with ``locals()``, so local variable
        names in this method are part of that contract — do not rename them.
        """
        print("\n")
        print("*" * 80)
        print("* Launching test: ", fileKey)
        print("*" * 80)

        is2D = alignType == ALIGN_2D

        if fileKey == 'alignShiftRotExp':
            # relion requires mrcs stacks
            origFn = self.dataset.getFile(fileKey)
            stackFn = replaceExt(origFn, ".mrcs")
            createLink(origFn, stackFn)
        else:
            stackFn = self.dataset.getFile(fileKey)

        partFn1 = self.getOutputPath(fileKey + "_particles1.sqlite")
        mdFn = self.getOutputPath(fileKey + "_particles.star")
        partFn2 = self.getOutputPath(fileKey + "_particles2.sqlite")

        # Output and gold reference files differ for alignment vs
        # reconstruction tests
        if self.IS_ALIGNMENT:
            outputFn = self.getOutputPath(fileKey + "_output.mrcs")
            outputFnRelion = self.getOutputPath(fileKey + "_output")
            goldFn = self.dataset.getFile(fileKey + '_Gold_output_relion.mrcs')
        else:
            outputFn = self.getOutputPath(fileKey + "_output.vol")
            goldFn = self.dataset.getFile("reconstruction/gold/" + fileKey +
                                          '_Gold_rln_output.vol')

        if PRINT_FILES:
            print("BINARY DATA: ", stackFn)
            print("SET1:        ", partFn1)
            print("  MD:        ", mdFn)
            print("SET2:        ", partFn2)
            print("OUTPUT:      ", outputFn)
            print("GOLD:        ", goldFn)

        if alignType == ALIGN_2D or alignType == ALIGN_PROJ:
            partSet = SetOfParticles(filename=partFn1)
        else:
            partSet = SetOfVolumes(filename=partFn1)
        partSet.setAlignment(alignType)

        acq = Acquisition(voltage=300,
                          sphericalAberration=2,
                          amplitudeContrast=0.1,
                          magnification=60000)
        og = OpticsGroups.create(rlnMtfFileName="mtfFile1.star",
                                 rlnImageSize=128)
        partSet.setSamplingRate(1.0)
        partSet.setAcquisition(acq)
        og.toImages(partSet)
        # Populate the SetOfParticles with images
        # taken from images.mrc file
        # and setting the previous alignment parameters
        aList = [np.array(m) for m in mList]
        for i, a in enumerate(aList):
            p = Particle()
            p.setLocation(i + 1, stackFn)
            p.setTransform(Transform(a))
            partSet.append(p)
        # Write out the .sqlite file and check that are correctly aligned
        print("Partset", partFn1)
        partSet.printAll()
        partSet.write()
        # Convert to a STAR file (mdFn) so the alignment round-trip can
        # be checked
        if alignType == ALIGN_2D or alignType == ALIGN_PROJ:
            starWriter = convert.createWriter()
            starWriter.writeSetOfParticles(partSet, mdFn, alignType=alignType)
            partSet2 = SetOfParticles(filename=partFn2)
        else:
            convert.writeSetOfVolumes(partSet, mdFn, alignType=alignType)
            partSet2 = SetOfVolumes(filename=partFn2)
        # Let's create now another SetOfImages reading back the written
        # STAR file and check one more time.
        partSet2.copyInfo(partSet)
        if alignType == ALIGN_2D or alignType == ALIGN_PROJ:
            convert.readSetOfParticles(mdFn, partSet2, alignType=alignType)
        else:
            convert.readSetOfParticles(mdFn,
                                       partSet2,
                                       rowToFunc=convert.rowToVolume,
                                       alignType=alignType)

        partSet2.write()

        # Optionally compare the read-back matrices with the originals
        if PRINT_MATRIX:
            for i, img in enumerate(partSet2):
                m1 = aList[i]
                m2 = img.getTransform().getMatrix()
                print("-" * 5)
                print(img.getFileName(), img.getIndex())
                print('m1:\n', m1, convert.geometryFromMatrix(m1, False))

                print('m2:\n', m2, convert.geometryFromMatrix(m2, False))
                self.assertTrue(np.allclose(m1, m2, rtol=1e-2))

        # Launch apply transformation and check result images
        # (CMD is a %-template filled from this method's local variables)
        runRelionProgram(self.CMD % locals())

        if SHOW_IMAGES:
            runRelionProgram('scipion show %(outputFn)s' % locals())

        # Compare against the gold output only when it is available
        if os.path.exists(goldFn):
            self.assertTrue(
                ImageHandler().compareData(goldFn, outputFn, tolerance=0.001),
                "Different data files:\n>%s\n<%s" % (goldFn, outputFn))

        if CLEAN_IMAGES:
            cleanPath(outputFn)
    def _runRelionStep(self, fnOut, fnBlock, fnDir, imgNo, fnGallery):
        """ Run one Relion job on a block of particles and flag each
        particle by the distribution of the class it ends up in.

        :param fnOut: Relion output root (passed as ``--o``).
        :param fnBlock: Xmipp metadata with the particles to process.
        :param fnDir: not used by this method.
        :param imgNo: block identifier. It names the Relion input
            'relion_<imgNo>.star' and, when ``referenceClassification`` is
            set, selects reference ``imgNo@fnGallery``.
        :param fnGallery: stack holding the reference images.

        After the run, every particle of the final iteration's data.star is
        appended to 'output_particles.xmd' with MDL_ENABLED = 1 if the
        RLN_MLMODEL_PDF_CLASS value of its class is >= ``thresholdValue``,
        and -1 otherwise.
        """
        relPart = SetOfParticles(filename=":memory:")
        convXmp.readSetOfParticles(fnBlock, relPart)

        # Keep the input alignment only when requested; it is passed on
        # unchanged.
        # NOTE(review): confirm whether a non-NONE alignment should be
        # remapped (e.g. to ALIGN_PROJ) before writing.
        if self.copyAlignment.get():
            alignType = relPart.getAlignment()
        else:
            alignType = ALIGN_NONE

        alignToPrior = getattr(self, 'alignmentAsPriors', False)
        fillRandomSubset = getattr(self, 'fillRandomSubset', False)
        fnRelion = self._getExtraPath('relion_%s.star' % imgNo)

        writeSetOfParticles(relPart, fnRelion, self._getExtraPath(),
                            alignType=alignType,
                            postprocessImageRow=self._postprocessParticleRow,
                            fillRandomSubset=fillRandomSubset)
        if alignToPrior:
            mdParts = md.MetaData(fnRelion)
            self._copyAlignAsPriors(mdParts, alignType)
            mdParts.write(fnRelion)
        if self.doCtfManualGroups:
            self._splitInCTFGroups(fnRelion)

        args = {}
        self._setNormalArgs(args)
        args['--i'] = fnRelion
        args['--o'] = fnOut
        if self.referenceClassification.get():
            args['--ref'] = "%s@%s" % (imgNo, fnGallery)
        self._setComputeArgs(args)

        params = ' '.join('%s %s' % (k, str(v)) for k, v in args.items())
        # Debug trace (message text: "Let's run relion")
        print('Vamos a correr relion', params)
        self.runJob(self._getRelionProgram(), params)

        # Per-class values from the final iteration's model.star
        model = '_it%03d_' % self.numberOfIterations.get()
        fnModel = fnOut + model + 'model.star'
        # NOTE(review): assumes the second block of model.star is the
        # per-class table — confirm.
        classBlock = md.getBlocksInMetaDataFile(fnModel)[1]
        mdBlocks = md.MetaData("%s@%s" % (classBlock, fnModel))
        clsDistList = [mdBlocks.getValue(md.RLN_MLMODEL_PDF_CLASS, objId)
                       for objId in mdBlocks]
        fnData = fnOut + model + 'data.star'

        fnOutput = self._getPath('output_particles.xmd')
        mdOutput = self._getMetadata(fnOutput)
        threshold = self.thresholdValue.get()

        mdParticles = md.MetaData(fnData)
        for row in md.iterRows(mdParticles):
            objId = mdOutput.addObject()
            clsNum = row.getValue('rlnClassNumber')
            # Class numbers are 1-based, the list is 0-based
            clsDist = clsDistList[clsNum - 1]
            row.setValue(md.MDL_ENABLED, 1 if clsDist >= threshold else -1)
            row.writeToMd(mdOutput, objId)
        mdOutput.write(fnOutput)