Пример #1
0
	def set_shrink(self,shrink):
		"""Load the particle and projection images at the given shrink factor and display them.

		Despite its name, this method performs all of the data loading:

		* Particles: reads at most the first 800 images of self.particle_file, drops any
		  that came back as None, mean-shrinks them (only when shrink > 1) and edge-mean
		  normalizes them, then shows them in self.ptcl_display. That viewer is created on
		  first use, put in "App" mouse mode, and connected to self.ptcl_selected and
		  self.on_mx_display_closed.
		* Projections: reads every image of self.projection_file and processes it the same
		  way, passes each image's "xform.projection" to self.specify_eulers(), sets each
		  image's "cmp" attribute to 0.0 and hands the list to
		  self.set_emdata_list_as_data(..., "cmp").

		shrink -- mean-shrink factor; values <= 1 keep the images at full size.
		"""
		self.shrink=shrink

		def prepare(images):
			# Shrink (only if requested), then normalize, every image in place.
			for img in images:
				if self.shrink>1:
					img.process_inplace("math.meanshrink",{"n":self.shrink})
				img.process_inplace("normalize.edgemean",{})

		# Deal with particles (at most the first 800 are read)
		n=min(EMUtil.get_image_count(self.particle_file),800)
		# Identity test rather than "!=" so no comparison overload on EMData is invoked.
		self.ptcl_data=[img for img in EMData.read_images(self.particle_file,list(range(n))) if img is not None]
		prepare(self.ptcl_data)

		if self.ptcl_display is None:
			self.ptcl_display = EMImageMXWidget()
			self.ptcl_display.set_mouse_mode("App")
			self.ptcl_display.mx_image_selected.connect(self.ptcl_selected)
			self.ptcl_display.module_closed.connect(self.on_mx_display_closed)
		self.ptcl_display.set_data(self.ptcl_data)

		# Deal with projections
		self.proj_data=EMData.read_images(self.projection_file)
		prepare(self.proj_data)

		eulers = [img["xform.projection"] for img in self.proj_data]
		self.specify_eulers(eulers)

		# Start every projection with a zero comparison score
		for img in self.proj_data:
			img["cmp"]=0.0
		self.set_emdata_list_as_data(self.proj_data,"cmp")
Пример #2
0
	def object_picked(self,object_number):
		"""Handle the selection of projection `object_number` in the 3-D view.

		Picking the already-current projection does nothing. Otherwise the projection is
		recorded as self.current_projection, the particle matrix viewer (self.mx_display)
		is created on first use, and self.update_display(False) refreshes it. A newly
		created viewer is shown and optimally resized; an existing one is just repainted.
		Finally, if the picked object differs from self.special_euler it becomes the
		special orientation and self.regen_dl() is called.
		"""
		if object_number == self.current_projection:
			return
		self.current_projection = object_number

		# A freshly created viewer must be shown and sized; an existing one only needs a repaint.
		resize_necessary = self.mx_display is None
		if resize_necessary:
			self.mx_display = EMImageMXWidget()
			self.mx_display.module_closed.connect(self.on_mx_display_closed)

		self.update_display(False)

		if resize_necessary:
			get_application().show_specific(self.mx_display)
			self.mx_display.optimally_resize()
		else:
			self.mx_display.updateGL()

		if object_number != self.special_euler:
			self.special_euler = object_number
			self.regen_dl()
Пример #3
0
    def __init__(self, rctwidget):
        """Create the "Particles" matrix viewer for `rctwidget` and reset the particle-list state.

        rctwidget -- owning widget; its parent_window is passed to EMImageMXWidget as the
                     application.
        """
        self.rctwidget = rctwidget

        # Matrix viewer configured to show each image's "tilt" and "PImg#" values
        self.window = EMImageMXWidget(application=self.rctwidget.parent_window)
        viewer = self.window
        viewer.set_display_values(["tilt", "PImg#"])
        viewer.set_mouse_mode("App")
        viewer.setWindowTitle("Particles")
        viewer.optimally_resize()

        self.connect_signals()

        # Bookkeeping for the particle lists held by this object
        self.listsofparts, self.numlists, self.closed = [], 0, False
Пример #4
0
        def __init__(self, app, classes):
            """Build the interactive (effectively modal) dialog for choosing masking parameters.

            app     -- application object, kept as self.app.
            classes -- list of class-average images shown in the viewer; the "nx" of the
                       first image sets the default number of shells (nx // 8).
            """
            self.app = app
            QtGui.QWidget.__init__(self, None)
            nx = classes[0]["nx"]

            self.classes = classes
            self.classview = EMImageMXWidget(self, classes)

            # Image matrix on top, one row of controls below it
            self.vbl = QtGui.QVBoxLayout(self)
            self.vbl.addWidget(self.classview)
            self.hbl = QtGui.QHBoxLayout()

            self.cmode = CheckBox(self, "orig", value=1)
            self.hbl.addWidget(self.cmode)

            self.slpres = ValSlider(self, (0.001, 0.2), "Low-pass Filter:", 0.03, 90)
            self.hbl.addWidget(self.slpres)

            # The next two sliders only take integer values
            self.snmax = ValSlider(self, (0, 20), "NMax:", 5, 90)
            self.snmax.intonly = 1
            self.hbl.addWidget(self.snmax)

            self.sshells = ValSlider(self, (0, 40), "NShells:", nx // 8, 90)
            self.sshells.intonly = 1
            self.hbl.addWidget(self.sshells)

            self.ssigma = ValSlider(self, (0, 2), "Sigma:", 0.333, 90)
            self.hbl.addWidget(self.ssigma)

            self.bok = QtGui.QPushButton("OK")
            self.hbl.addWidget(self.bok)

            self.vbl.addLayout(self.hbl)

            # Every parameter control calls self.newParm on change; OK closes the dialog.
            for control in (self.cmode, self.slpres, self.snmax, self.sshells, self.ssigma):
                QtCore.QObject.connect(control, QtCore.SIGNAL("valueChanged"), self.newParm)
            QtCore.QObject.connect(self.bok, QtCore.SIGNAL("clicked(bool)"), self.close)

            # Compute the initial display
            self.newParm()
Пример #5
0
	def __init__(self,application,options):
		"""Build the main window: an iteration/loss list, a 2-D image view with landmark boxes, and a landmark viewer.

		application -- the running application; only a weak reference to it is kept (self.app).
		options     -- parsed command-line options. options.path is passed to
		               self.check_path(), and self.get_data(0) runs before any widget
		               is created.
		"""
		self.options=options
		self.check_path(options.path)
		# Must run before the widgets below, which read self.tomoname, self.losses,
		# self.pks2d, self.dirs and self.datafile (presumably set by these calls -- confirm).
		self.get_data(0)

		QtWidgets.QWidget.__init__(self)
		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())

		# Name label spanning both columns of grid row 0
		self.lb_name=QtWidgets.QLabel(self.tomoname)
		self.lb_name.setWordWrap(True)
		self.gbl.addWidget(self.lb_name, 0,0,1,2)

		self.iterlst=QtWidgets.QListWidget()
		# NOTE(review): this only attaches a Python attribute to the widget; it does not
		# change the flags of any list item. Confirm whether anything reads it.
		self.iterlst.itemflags=Qt.ItemFlags(Qt.ItemIsSelectable)

		# One row per iteration, in sorted iteration order, showing that iteration's loss
		for i in sorted(self.losses.keys()):
			txt="{:d}  :  loss = {:.1f}".format(i, self.losses[i])
			self.iterlst.addItem(QtWidgets.QListWidgetItem(txt))

		# Changing the current row calls self.update_list with the new row index
		self.iterlst.currentRowChanged[int].connect(self.update_list)
		self.gbl.addWidget(self.iterlst,1,0,1,2)

		self.app=weakref.ref(application)

		# 2-D view of the data, created once, with the landmark boxes as its only shape
		# set; mouse-up events are routed to self.on_mouseup.
		self.imgview = EMImage2DWidget()
		self.boxes=Boxes(self.imgview, self.pks2d, self.dirs)
		self.shape_index = 0

		self.imgview.set_data(self.datafile)
		self.imgview.shapes = {0:self.boxes}
		self.imgview.show()
		self.imgview.mouseup.connect(self.on_mouseup)

		# Separate matrix viewer titled "Landmarks"
		self.boxesviewer=EMImageMXWidget()
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Landmarks")
		self.boxesviewer.rzonce=True
Пример #6
0
    def fileUpdate(self):
        """Refresh all displays after a class-average file is selected, or when a complete refresh is needed.

        * Loads the selected class-average file into the "Classes" viewer. On first use the
          viewer is created in "App" mouse mode, with single clicks going to
          self.classSelect and double clicks to self.classDouble. An empty "evalptcl"
          selection set is then enabled.
        * In the particle-file combo box, selects the file named by the "class_ptcl_src"
          header attribute of the first class-average, inserting it at the top if it is
          not listed yet. Warns the user if that attribute cannot be read.
        * Ensures the "Included Particles" and "Excluded Particles" viewers exist, then
          shows all three viewers.

        A busy cursor is displayed while this runs.
        """
        QtGui.qApp.setOverrideCursor(Qt.BusyCursor)
        try:
            if self.vclasses is None:
                self.vclasses = EMImageMXWidget()
                self.vclasses.set_mouse_mode("App")
                QtCore.QObject.connect(self.vclasses,
                                       QtCore.SIGNAL("mx_image_selected"),
                                       self.classSelect)
                QtCore.QObject.connect(self.vclasses,
                                       QtCore.SIGNAL("mx_image_double"),
                                       self.classDouble)

            self.vclasses.set_title("Classes")
            self.vclasses.set_data(self.curFile(), self.curFile())
            # The mouse mode is set again after loading new data
            self.vclasses.set_mouse_mode("App")
            self.vclasses.enable_set("evalptcl", [])

            # This makes sure the particle file is in the list of choices and is selected
            try:
                ptclfile = EMData(self.curFile(), 0, True)["class_ptcl_src"]
                i = self.wptclfile.findText(ptclfile)
                if i == -1:
                    self.wptclfile.insertItem(0, ptclfile)
                    self.wptclfile.setCurrentIndex(0)
                else:
                    self.wptclfile.setCurrentIndex(i)
            except Exception:
                QtGui.QMessageBox.warning(
                    self, "Error !",
                    "This image does not appear to be a class average. (No class_ptcl_src, etc.)"
                )

            # Make sure our display widgets exist
            if self.vgoodptcl is None:
                self.vgoodptcl = EMImageMXWidget()
            self.vgoodptcl.set_title("Included Particles")

            if self.vbadptcl is None:
                self.vbadptcl = EMImageMXWidget()
            self.vbadptcl.set_title("Excluded Particles")

            self.vclasses.show()
            self.vgoodptcl.show()
            self.vbadptcl.show()
        finally:
            # Pop the busy cursor pushed above. Qt keeps override cursors on a stack, so
            # pushing an arrow cursor instead would grow the stack on every refresh and
            # leave the application cursor permanently overridden.
            QtGui.qApp.restoreOverrideCursor()
Пример #7
0
	class GUImask(QtWidgets.QWidget):
		"""Dialog that previews automatic 2-D masking of class-averages while its parameters are adjusted."""

		def __init__(self,app,classes):
			"""Build the (effectively modal) masking-parameter dialog and show the initial preview.

			app     -- application object; used by quit() to close this dialog.
			classes -- list of class-average images; the "nx" of the first one sets the
			           default number of shells (nx // 8).
			"""
			self.app = app
			QtWidgets.QWidget.__init__(self, None)
			nx = classes[0]["nx"]

			self.classes = classes
			self.classview = EMImageMXWidget(self, classes)

			# Image matrix on top, one row of controls underneath
			self.vbl = QtWidgets.QVBoxLayout(self)
			self.vbl.addWidget(self.classview)
			self.hbl = QtWidgets.QHBoxLayout()

			self.cmode = CheckBox(self, "orig", value=1)
			self.hbl.addWidget(self.cmode)

			self.slpres = ValSlider(self, (0.001, 0.2), "Low-pass Filter:", 0.03, 90)
			self.hbl.addWidget(self.slpres)

			# The next two sliders only take integer values
			self.snmax = ValSlider(self, (0, 20), "NMax:", 5, 90)
			self.snmax.intonly = 1
			self.hbl.addWidget(self.snmax)

			self.sshells = ValSlider(self, (0, 40), "NShells:", nx // 8, 90)
			self.sshells.intonly = 1
			self.hbl.addWidget(self.sshells)

			self.ssigma = ValSlider(self, (0, 2), "Sigma:", 0.333, 90)
			self.hbl.addWidget(self.ssigma)

			self.bok = QtWidgets.QPushButton("OK")
			self.hbl.addWidget(self.bok)

			self.vbl.addLayout(self.hbl)

			# Any parameter change recomputes the preview; OK closes the dialog.
			for control in (self.cmode, self.slpres, self.snmax, self.sshells, self.ssigma):
				control.valueChanged.connect(self.newParm)
			self.bok.clicked[bool].connect(self.close)

			self.newParm()

		def quit(self):
			"""Ask the owning application to close this dialog."""
			self.app.close_specific(self)

		def newParm(self):
			"""Recompute self.masked from the current control values and display it.

			With the "orig" box checked the unmodified class-averages are shown instead.
			"""
			if self.cmode.getValue():
				self.classview.set_data(self.classes)
				return

			# Low-pass filtered copies of each class-average, using the slider cutoff
			self.masked = [cls.process("filter.lowpass.gauss", {"cutoff_freq": self.slpres.value}) for cls in self.classes]
			nx = self.masked[0]["nx"]
			for masked_img, orig_img in zip(self.masked, self.classes):
				# return_mask=1: presumably turns the filtered copy into the mask itself -- confirm
				masked_img.process_inplace("mask.auto2d", {"nmaxseed": int(self.snmax.value), "nshells": int(self.sshells.value), "radius": old_div(nx, 10), "return_mask": 1, "sigma": self.ssigma.value})
				# Soften the mask edges, then apply it to the original class-average
				masked_img.process_inplace("filter.lowpass.gauss", {"cutoff_freq": 0.03})
				masked_img.mult(orig_img)

			self.classview.set_data(self.masked)
Пример #8
0
class EMClassPtclTool(QtWidgets.QWidget):
    """This class is a tab widget for inspecting particles within class-averages.

    The user chooses a class-average file, selects class-averages in the "Classes"
    viewer, and can then view the particles included in / excluded from them, mark
    those particles as bad or good in the per-micrograph info files, append them to a
    .lst set, or save their particle numbers to a text file.
    """

    def __init__(self, extrafiles=None):
        """Build the control panel and fill the file lists.

        extrafiles -- optional iterable of additional class-average file paths, listed
                      after the ones found in the project's refinement directories.
        """
        QtWidgets.QWidget.__init__(self)
        self.vbl = QtWidgets.QVBoxLayout(self)

        self.extrafiles = extrafiles

        # A listwidget for selecting which class-average file we're looking at
        self.wclassfilel = QtWidgets.QLabel("Class-average File:")
        self.vbl.addWidget(self.wclassfilel)

        self.wfilesel = QtWidgets.QListWidget()
        self.vbl.addWidget(self.wfilesel)
        self.vbl.addSpacing(5)

        # A widget containing the current particle filename, editable by the user
        # If edited it will also impact set generation !
        self.wptclfilel = QtWidgets.QLabel("Particle Data File:")
        self.vbl.addWidget(self.wptclfilel)

        self.wptclfile = QtWidgets.QComboBox(self)
        self.vbl.addWidget(self.wptclfile)
        self.vbl.addSpacing(5)

        # Selection tools
        self.wselectg = QtWidgets.QGroupBox("Class Selection", self)
        self.wselectg.setFlat(False)
        self.vbl.addWidget(self.wselectg)
        self.vbl.addSpacing(5)

        self.gbl0 = QtWidgets.QGridLayout(self.wselectg)

        self.wselallb = QtWidgets.QPushButton("All")
        self.gbl0.addWidget(self.wselallb, 0, 0)

        self.wselnoneb = QtWidgets.QPushButton("Clear")
        self.gbl0.addWidget(self.wselnoneb, 0, 1)

        self.wselrangeb = QtWidgets.QPushButton("Range")
        self.gbl0.addWidget(self.wselrangeb, 1, 0)

        self.wselinvertb = QtWidgets.QPushButton("Invert")
        self.gbl0.addWidget(self.wselinvertb, 0, 2)

        self.wsel3db = QtWidgets.QPushButton("From 3D")
        self.gbl0.addWidget(self.wsel3db, 1, 2)

        self.wprocessg = QtWidgets.QGroupBox("Process results", self)
        self.wprocessg.setFlat(False)
        self.vbl.addWidget(self.wprocessg)

        self.vbl2 = QtWidgets.QVBoxLayout(self.wprocessg)

        # Which particles of each selected class (included / excluded) the actions use
        self.wselused = CheckBox(None, "Included Ptcls", 1, 100)
        self.vbl2.addWidget(self.wselused)

        self.wselunused = CheckBox(None, "Excluded Ptcls", 1, 100)
        self.vbl2.addWidget(self.wselunused)

        # Mark particles in selected classes as bad
        self.wmarkbut = QtWidgets.QPushButton("Mark as Bad")
        self.vbl2.addWidget(self.wmarkbut)

        # Mark particles in selected classes as good
        self.wmarkgoodbut = QtWidgets.QPushButton("Mark as Good")
        self.vbl2.addWidget(self.wmarkgoodbut)

        # Make a new set from selected classes
        self.wmakebut = QtWidgets.QPushButton("Make New Set")
        self.vbl2.addWidget(self.wmakebut)

        # Save list
        self.wsavebut = QtWidgets.QPushButton("Save Particle List")
        self.vbl2.addWidget(self.wsavebut)

        # Save micrograph dereferenced lists
        self.wsaveorigbut = QtWidgets.QPushButton("Save CCD-based List")
        self.vbl2.addWidget(self.wsaveorigbut)

        self.wfilesel.itemSelectionChanged.connect(self.fileUpdate)
        self.wptclfile.currentIndexChanged[int].connect(self.ptclChange)
        self.wselallb.clicked[bool].connect(self.selAllClasses)
        self.wselnoneb.clicked[bool].connect(self.selNoClasses)
        self.wselrangeb.clicked[bool].connect(self.selRangeClasses)
        self.wselinvertb.clicked[bool].connect(self.selInvertClasses)
        self.wsel3db.clicked[bool].connect(self.sel3DClasses)
        self.wmakebut.clicked[bool].connect(self.makeNewSet)
        self.wmarkbut.clicked[bool].connect(self.markBadPtcl)
        self.wmarkgoodbut.clicked[bool].connect(self.markGoodPtcl)
        self.wsavebut.clicked[bool].connect(self.savePtclNum)
        self.wsaveorigbut.clicked[bool].connect(self.saveOrigPtclNum)

        # View windows, one for class-averages, one for good particles and one for bad particles
        # (created lazily by fileUpdate)
        self.vclasses = None
        self.vgoodptcl = None
        self.vbadptcl = None

        self.updateFiles()

    def _derefSelectedPtcls(self):
        """Dereference the selected particles through the current particle LST file.

        Returns a list of (origfile, orign, comment) tuples, one per particle yielded by
        curPtclIter() for the current Included/Excluded check boxes. Returns None, after
        warning the user, if any particle cannot be read from the LST file.
        """
        ptclfile = self.curPtclFile()
        lst = LSXFile(ptclfile)  # lst file for dereferencing
        include = []
        for n in self.curPtclIter(self.wselused.getValue(),
                                  self.wselunused.getValue()):
            try:
                # the original file/number dereferenced from the LST file
                orign, origfile, comment = lst.read(n)
            except Exception:
                QtWidgets.QMessageBox.warning(
                    self, "Error !",
                    "The data_source '%s' does not follow EMAN2.1 project conventions. Cannot find raw particles for set."
                    % ptclfile)
                return None
            include.append((origfile, orign, comment))
        return include

    def _selectedPtclsByFile(self):
        """Group the dereferenced selected particles by original file.

        Returns {origfile: [orign, ...]}, or None if dereferencing failed (the user has
        already been warned).
        """
        include = self._derefSelectedPtcls()
        if include is None:
            return None
        ptcls = {}
        for origfile, orign, _comment in include:
            ptcls.setdefault(origfile, []).append(orign)
        return ptcls

    def makeNewSet(self, x):
        """Makes a new particle set based on the selected class-averages.

        The user is prompted for a set name. ".lst" is appended if missing, and names
        without a "/" are placed in "sets/". Particles are dereferenced to their
        original file/number and appended to the set, sorted by (file, number).
        """
        setname, ok = QtWidgets.QInputDialog.getText(
            None, "Set Name",
            "Please specify the name for the set. If you specify an existing set, new particles will be added to the end"
        )
        if not ok or not setname:
            return
        if not setname.endswith(".lst"):
            setname = setname + ".lst"
        if "/" not in setname:
            setname = "sets/" + setname

        include = self._derefSelectedPtcls()
        if include is None:
            return

        # write the new set (only after every particle was dereferenced successfully)
        lstout = LSXFile(setname)
        for origfile, orign, comment in sorted(include):
            lstout.write(-1, orign, origfile, comment)

    def markBadPtcl(self, x):
        """Mark particles from the selected class-averages as bad in the set interface.

        After confirmation (there is no undo), each original micrograph file's info dict
        gets the selected particles added to sets["bad_particles"]. The number of newly
        marked particles and the resulting total in the affected files are printed.
        """
        r = QtWidgets.QMessageBox.question(
            None, "Are you sure ?",
            "WARNING: There is no undo for this operation. It will  mark all particles associated with the selected class-averages as bad. Are you sure you want to proceed ?",
            QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.Cancel)
        if r == QtWidgets.QMessageBox.Cancel:
            return

        ptcls = self._selectedPtclsByFile()
        if ptcls is None:
            return

        # now mark the particles as bad
        newbad = 0
        totbad = 0
        for origfile, nums in ptcls.items():
            js = js_open_dict(info_name(origfile))  # get the info dict for this file
            try:
                try:
                    sets = js["sets"]
                except Exception:
                    sets = {"bad_particles": []}
                try:
                    badset = set(sets["bad_particles"])
                except Exception:
                    badset = set()

                try:
                    newset = list(set(nums) | badset)
                    sets["bad_particles"] = newset  # update the set of bad particles for this image file
                    js["sets"] = sets
                    totbad += len(newset)  # total bad after this operation
                    newbad += len(newset) - len(badset)
                except Exception:
                    print("Error setting bad particles in ", origfile)
            finally:
                js_close_dict(info_name(origfile))

        print(newbad, " new particles marked as bad. Total of ", totbad,
              " in affected micrographs")

    def markGoodPtcl(self, x):
        """Mark particles from the selected class-averages as good in the set interface.

        After confirmation (there is no undo), the selected particles are removed from
        each original micrograph file's sets["bad_particles"]. Files without a
        bad-particle list are skipped. Before/after bad counts are printed.
        """
        r = QtWidgets.QMessageBox.question(
            None, "Are you sure ?",
            "WARNING: There is no undo for this operation. It will un-mark all particles associated with the selected class-averages as bad. Are you sure you want to proceed ?",
            QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.Cancel)
        if r == QtWidgets.QMessageBox.Cancel:
            return

        ptcls = self._selectedPtclsByFile()
        if ptcls is None:
            return

        # now mark the particles as good, i.e. drop them from each file's bad list
        badafter = 0
        badbefore = 0
        for origfile, nums in ptcls.items():
            js = js_open_dict(info_name(origfile))  # get the info dict for this file
            try:
                try:
                    sets = js["sets"]
                    badset = set(sets["bad_particles"])
                except Exception:
                    continue  # if no current bad particles, nothing to mark good

                try:
                    newset = list(badset - set(nums))
                    sets["bad_particles"] = newset  # update the set of bad particles for this image file
                    js["sets"] = sets
                except Exception:
                    continue
                badbefore += len(badset)
                badafter += len(newset)
            finally:
                js_close_dict(info_name(origfile))

        print(badbefore, " bad particles before processing, now ", badafter)

    def savePtclNum(self, x):
        """Saves a list of particles from marked classes into a text file (one number per line)."""
        filename, ok = QtWidgets.QInputDialog.getText(
            None, "Filename",
            "Please enter a filename for the particle list. The file will contain the particle number (within the particle file) for each particle associated with a selected class-average."
        )
        if not ok or not filename:
            return

        with open(filename, "w") as out:
            for i in self.curPtclIter(self.wselused.getValue(),
                                      self.wselunused.getValue()):
                out.write("%d\n" % i)

    def saveOrigPtclNum(self, x):
        """Saves a file containing micrograph-dereferenced particles.

        Each line holds "<original particle number>\\t<original file>", sorted by
        (file, number).
        """
        filename, ok = QtWidgets.QInputDialog.getText(
            None, "Filename",
            "Please enter a filename for the particle list. The file will contain particle number and image file, one per line. Image files will be referenced back to the original per-CCD frame stacks."
        )
        if not ok or not filename:
            return

        include = self._derefSelectedPtcls()
        if include is None:
            return

        # write the output file
        with open(filename, "w") as out:
            for origfile, orign, _comment in sorted(include):
                out.write("{}\t{}\n".format(orign, origfile))

    def selAllClasses(self, x):
        "Mark all classes as selected"
        self.vclasses.all_set()

    def selNoClasses(self, x):
        "Clear selection"
        self.vclasses.clear_set()

    def selRangeClasses(self, x):
        "Select a range of images (ask the user for the range)"
        rng, ok = QtWidgets.QInputDialog.getText(
            None, "Select Range",
            "Enter the range of particle values as first-last (inclusive). Merges with existing selection."
        )
        if not ok:
            return

        try:
            x0, x1 = rng.split("-")
            x0 = int(x0)
            x1 = int(x1) + 1  # inclusive upper bound
        except ValueError:
            QtWidgets.QMessageBox.warning(
                self, "Error !", "Invalid range specified. Use: min-max")
            return

        self.vclasses.subset_set(list(range(x0, x1)))

    def selInvertClasses(self, x):
        "Inverts the current selection set"
        self.vclasses.invert_set()

    def sel3DClasses(self, x):
        """Select the classes used in the 3-D reconstruction associated with this classes file.

        The current selection is cleared first. The matching "#threed_<n>" image's
        "threed_ptcl_idxs" header attribute supplies the class numbers.
        """
        f = self.curFile()
        if '#classes_' not in f:
            QtWidgets.QMessageBox.warning(
                self, "Error !",
                "A classes_xx file from a refine_xx directory is not currently selected"
            )
            return

        # construct the path to the threed_xx file
        num = f.split("_")[-1]
        pre = f.split("#")[0]
        d3path = "%s#threed_%s" % (pre, num)
        try:
            a = EMData(d3path, 0, True)
            goodptcl = a["threed_ptcl_idxs"]
        except Exception:
            QtWidgets.QMessageBox.warning(self, "Error !",
                                          "Cannot read classes from " + d3path)
            return

        self.vclasses.clear_set()
        self.vclasses.subset_set(goodptcl)

    def ptclChange(self, value):
        "Called when a change of particle data file occurs to zero out the display"
        try:
            self.vgoodptcl.set_data(None)
            self.vbadptcl.set_data(None)
        except Exception:
            pass  # viewers not created yet (None) -- nothing to clear

    def updateFiles(self):
        """Populate the class-average file list and the particle-set combo box.

        Class-average files are any entries containing "classes_" or "allrefs_" inside
        sub-directories whose names contain one of the refinement markers below. These
        are followed by self.extrafiles, if given. The combo box lists every entry of
        "sets/" (empty if that directory does not exist).
        """
        markers = ("r2d_", "r2db_", "relion2d_", "refine_", "multi_", "multinoali_")
        subdirs = sorted(i for i in os.listdir(e2getcwd())
                         if any(m in i for m in markers))
        for d in subdirs:
            try:
                dbs = sorted(os.listdir(d))
            except OSError:
                continue  # not a readable directory
            for db in dbs:
                if "classes_" in db or "allrefs_" in db:
                    self.wfilesel.addItem("%s/%s" % (d, db))

        # extrafiles defaults to None, which cannot be iterated
        for f in self.extrafiles or ():
            self.wfilesel.addItem(f)

        try:
            setfiles = sorted(os.listdir("sets"))
        except OSError:
            setfiles = []  # project has no sets/ directory yet
        for db in setfiles:
            self.wptclfile.addItem("sets/" + db)

    @staticmethod
    def _asIdxList(idxs):
        """Return a particle-index header value as a list, wrapping a lone int."""
        if isinstance(idxs, int):
            # This should not happen, but seems to sometimes for some reason
            return [idxs]
        return list(idxs)

    def curPtclIter(self, included=True, excluded=True):
        """Yield particle numbers (in self.curPtclFile()) for all particles of the selected classes.

        included -- yield each class's "class_ptcl_idxs" particles.
        excluded -- also yield its "exc_class_ptcl_idxs" particles, when that attribute exists.

        A class whose header cannot be read is reported and skipped entirely.
        """
        for ci in self.curSet():
            try:
                c = EMData(self.curFile(), ci, True)  # read header for current class average
                idxs = []
                if included:
                    idxs.extend(self._asIdxList(c["class_ptcl_idxs"]))
                if excluded and c.has_attr("exc_class_ptcl_idxs"):
                    idxs.extend(self._asIdxList(c["exc_class_ptcl_idxs"]))
            except Exception:
                print("Problem with class %d (%s). Skipping" %
                      (ci, self.curFile()))
                traceback.print_exc()
                continue
            # Yield outside the try: a catch-all around a yield would also swallow the
            # GeneratorExit raised when a consumer stops iterating early.
            for i in idxs:
                yield i

    def curFile(self):
        "return the currently selected file as a readable path"
        return str(self.wfilesel.item(self.wfilesel.currentRow()).text())

    def curSet(self):
        "return a set (integers) of the currently selected class-averages"
        return self.vclasses.get_set("evalptcl")

    def curPtclFile(self):
        "return the particle file associated with the currently selected classes file"
        return str(self.wptclfile.currentText())

    def fileUpdate(self):
        """Refresh all displays after a class-average file is selected, or when a complete refresh is needed.

        Loads the selected file into the "Classes" viewer (created on first use), enables
        an empty "evalptcl" set, and selects the particle file named by the first image's
        "class_ptcl_src" header attribute (inserting it if absent). It then ensures the
        included/excluded particle viewers exist and shows all three. A busy cursor is
        displayed meanwhile.
        """
        QtWidgets.qApp.setOverrideCursor(Qt.BusyCursor)
        try:
            if self.vclasses is None:
                self.vclasses = EMImageMXWidget()
                self.vclasses.set_mouse_mode("App")
                self.vclasses.mx_image_selected.connect(self.classSelect)
                self.vclasses.mx_image_double.connect(self.classDouble)

            self.vclasses.set_title("Classes")
            self.vclasses.set_data(self.curFile(), self.curFile())
            # The mouse mode is set again after loading new data
            self.vclasses.set_mouse_mode("App")
            self.vclasses.enable_set("evalptcl", [])

            # This makes sure the particle file is in the list of choices and is selected
            try:
                ptclfile = EMData(self.curFile(), 0, True)["class_ptcl_src"]
                i = self.wptclfile.findText(ptclfile)
                if i == -1:
                    self.wptclfile.insertItem(0, ptclfile)
                    self.wptclfile.setCurrentIndex(0)
                else:
                    self.wptclfile.setCurrentIndex(i)
            except Exception:
                QtWidgets.QMessageBox.warning(
                    self, "Error !",
                    "This image does not appear to be a class average. (No class_ptcl_src, etc.)"
                )

            # Make sure our display widgets exist
            if self.vgoodptcl is None:
                self.vgoodptcl = EMImageMXWidget()
            self.vgoodptcl.set_title("Included Particles")

            if self.vbadptcl is None:
                self.vbadptcl = EMImageMXWidget()
            self.vbadptcl.set_title("Excluded Particles")

            self.vclasses.show()
            self.vgoodptcl.show()
            self.vbadptcl.show()
        finally:
            # Pop the busy cursor; Qt override cursors form a stack, so pushing an arrow
            # cursor instead would leave the application cursor overridden forever.
            QtWidgets.qApp.restoreOverrideCursor()

    def classSelect(self, event, lc):
        """Single clicked class particle. lc=(img#,x,y,image_dict)

        Shows the class's included particles in the "Included Particles" viewer and its
        excluded particles (if any) in the "Excluded Particles" viewer.
        """
        QtWidgets.qApp.setOverrideCursor(Qt.BusyCursor)
        try:
            ptclfile = self.curPtclFile()
            try:
                ptclgood = lc[3]["class_ptcl_idxs"]
                self.vgoodptcl.set_data(EMData.read_images(ptclfile, ptclgood))
            except Exception:
                QtWidgets.QMessageBox.warning(
                    self, "Error !",
                    "This image does not appear to be a class average. (No class_ptcl_src, etc.)"
                )
                return
            try:
                ptclbad = lc[3]["exc_class_ptcl_idxs"]
                self.vbadptcl.set_data(EMData.read_images(ptclfile, ptclbad))
            except Exception:
                self.vbadptcl.set_data(None)  # no excluded particles for this class

            self.vgoodptcl.show()
            self.vbadptcl.show()
        finally:
            QtWidgets.qApp.restoreOverrideCursor()

    def classDouble(self, event, lc):
        "Double-clicked class-average: associate image lc[0] with the active set and redraw"
        self.vclasses.image_set_associate(lc[0], update_gl=True)

    def closeEvent(self, event):
        """Commit the class selection sets and close all viewer windows (best effort)."""
        # Each viewer may still be None if no file was ever selected; ignore such failures.
        try:
            self.vclasses.commit_sets()
            self.vclasses.close()
        except Exception:
            pass
        try:
            self.vgoodptcl.close()
        except Exception:
            pass
        try:
            self.vbadptcl.close()
        except Exception:
            pass

        QtWidgets.QWidget.closeEvent(self, event)
Пример #9
0
class EMTomoBoxer(QtWidgets.QMainWindow):
	"""Main window of the e2spt_boxer.py particle picker.

	Shows three orthogonal slice views of a tomogram (xy, xz and zy). The user
	can add, drag and delete particle boxes organised in numbered sets.
	Boxes and set definitions are stored in the tomogram's _info.json file.

	Every entry of ``self.boxes`` is a list ``[x, y, z, method, score, set_id]``
	in voxel coordinates of the loaded tomogram, before any "Flatten" rotation.
	"""
	keypress = QtCore.pyqtSignal(QtGui.QKeyEvent)
	module_closed = QtCore.pyqtSignal()

	def __init__(self,application,options,datafile):
		"""Build the GUI, load ``datafile`` and restore boxes/sets from its info file.

		application: the running application (only a weak reference is kept).
		options: parsed command line options. ``apix`` and ``mode`` are used;
		  mode "3D" gives circular 3-D boxes, anything else gives rectangular 2-D boxes.
		datafile: path of the tomogram to open.
		"""
		QtWidgets.QWidget.__init__(self)
		self.initialized=False
		self.app=weakref.ref(application)
		self.options=options
		self.apix=options.apix
		self.currentset=0
		self.shrink=1	# options.shrink is ignored; a factor of 1 leaves coordinates unscaled
		self.setWindowTitle("Main Window (e2spt_boxer.py)")
		if options.mode=="3D":
			self.boxshape="circle"
		else:
			self.boxshape="rect"

		# Rotation applied by "Flatten"; the identity means the raw tomogram is shown
		self.globalxf=Transform()

		# Menu Bar
		self.mfile=self.menuBar().addMenu("File")
		self.mfile_read_boxloc=self.mfile.addAction("Read Box Coord")
		self.mfile_save_boxloc=self.mfile.addAction("Save Box Coord")
		self.mfile_save_boxpdb=self.mfile.addAction("Save Coord as PDB")
		self.mfile_save_boxes_stack=self.mfile.addAction("Save Boxes as Stack")

		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())

		# relative stretch factors
		self.gbl.setColumnMinimumWidth(0,200)
		self.gbl.setRowMinimumHeight(0,200)
		self.gbl.setColumnStretch(0,0)
		self.gbl.setColumnStretch(1,100)
		self.gbl.setColumnStretch(2,0)
		self.gbl.setRowStretch(1,0)
		self.gbl.setRowStretch(0,100)

		# 3 orthogonal restricted projection views
		self.xyview = EMImage2DWidget(sizehint=(1024,1024))
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget(sizehint=(1024,256))
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget(sizehint=(256,1024))
		self.gbl.addWidget(self.zyview,0,0)

		# Select Z for xy view
		self.wdepth = QtWidgets.QSlider()
		self.gbl.addWidget(self.wdepth,1,2)

		### Control panel area in the lower left corner (grid row 1, column 0)
		self.gbl2 = QtWidgets.QGridLayout()
		self.gbl.addLayout(self.gbl2,1,0)

		# box size
		self.wboxsize=ValBox(label="Box Size:",value=0)
		self.gbl2.addWidget(self.wboxsize,2,0)

		# number of slices averaged in each view
		label0=QtWidgets.QLabel("Thickness")
		self.gbl2.addWidget(label0,3,0)

		self.wnlayers=QtWidgets.QSpinBox()
		self.wnlayers.setMinimum(1)
		self.wnlayers.setMaximum(256)
		self.wnlayers.setValue(1)
		self.gbl2.addWidget(self.wnlayers,3,1)

		# Local boxes in side view
		self.wlocalbox=QtWidgets.QCheckBox("Limit Side Boxes")
		self.gbl2.addWidget(self.wlocalbox,4,0)
		self.wlocalbox.setChecked(True)

		self.button_flat = QtWidgets.QPushButton("Flatten")
		self.gbl2.addWidget(self.button_flat,5,0)
		self.button_reset = QtWidgets.QPushButton("Reset")
		self.gbl2.addWidget(self.button_reset,5,1)

		# 2-D filters
		self.wfilt = ValSlider(rng=(0,150),label="Filt",value=0.0)
		self.gbl2.addWidget(self.wfilt,6,0,1,2)

		self.curbox=-1

		self.boxes=[]						# array of box info, each is [x,y,z,method,score,set_id]
		self.boxesimgs=[]					# thumbnail image of each box, same order as self.boxes
		self.dragging=-1					# index of the box being dragged, -1 if none

		##coordinate display
		self.wcoords=QtWidgets.QLabel("")
		self.gbl2.addWidget(self.wcoords, 1, 0, 1, 2)

		self.button_flat.clicked[bool].connect(self.flatten_tomo)
		self.button_reset.clicked[bool].connect(self.reset_flatten_tomo)

		# file menu
		self.mfile_read_boxloc.triggered[bool].connect(self.menu_file_read_boxloc)
		self.mfile_save_boxloc.triggered[bool].connect(self.menu_file_save_boxloc)
		self.mfile_save_boxpdb.triggered[bool].connect(self.menu_file_save_boxpdb)

		self.mfile_save_boxes_stack.triggered[bool].connect(self.save_boxes)

		# all other widgets
		self.wdepth.valueChanged[int].connect(self.event_depth)
		self.wnlayers.valueChanged[int].connect(self.event_nlayers)
		self.wboxsize.valueChanged.connect(self.event_boxsize)
		self.wfilt.valueChanged.connect(self.event_filter)
		self.wlocalbox.stateChanged[int].connect(self.event_localbox)

		self.xyview.mousemove.connect(self.xy_move)
		self.xyview.mousedown.connect(self.xy_down)
		self.xyview.mousedrag.connect(self.xy_drag)
		self.xyview.mouseup.connect(self.mouse_up)
		self.xyview.mousewheel.connect(self.xy_wheel)
		self.xyview.signal_set_scale.connect(self.event_scale)
		self.xyview.origin_update.connect(self.xy_origin)

		self.xzview.mousedown.connect(self.xz_down)
		self.xzview.mousedrag.connect(self.xz_drag)
		self.xzview.mouseup.connect(self.mouse_up)
		self.xzview.mousewheel.connect(self.xz_wheel)
		self.xzview.signal_set_scale.connect(self.event_scale)
		self.xzview.origin_update.connect(self.xz_origin)
		self.xzview.mousemove.connect(self.xz_move)

		self.zyview.mousedown.connect(self.zy_down)
		self.zyview.mousedrag.connect(self.zy_drag)
		self.zyview.mouseup.connect(self.mouse_up)
		self.zyview.mousewheel.connect(self.zy_wheel)
		self.zyview.signal_set_scale.connect(self.event_scale)
		self.zyview.origin_update.connect(self.zy_origin)
		self.zyview.mousemove.connect(self.zy_move)

		self.xyview.keypress.connect(self.key_press)
		self.datafilename=datafile
		self.basename=base_name(datafile)
		# filetag is the "__suffix_" part of the file name (up to the extension), or "__"
		p0=datafile.find('__')
		if p0>0:
			p1=datafile.rfind('.')
			self.filetag=datafile[p0:p1]
			if self.filetag[-1]!='_':
				self.filetag+='_'
		else:
			self.filetag="__"

		data=EMData(datafile)
		self.set_data(data)

		# Boxes Viewer (thumbnails of all visible boxes)
		self.boxesviewer=EMImageMXWidget()
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Particle List")
		self.boxesviewer.rzonce=True

		self.setspanel=EMTomoSetsPanel(self)

		self.optionviewer=EMTomoBoxerOptions(self)
		self.optionviewer.add_panel(self.setspanel,"Sets")

		self.optionviewer.show()

		self.boxesviewer.mx_image_selected.connect(self.img_selected)

		##################
		#### deal with metadata in the _info.json file...

		self.jsonfile=info_name(datafile)
		info=js_open_dict(self.jsonfile)

		#### read particle classes (sets): id -> name, id -> box size
		self.sets={}
		self.boxsize={}
		if "class_list" in info:
			clslst=info["class_list"]
			for k in sorted(clslst.keys()):
				if isinstance(clslst[k],dict):
					self.sets[int(k)]=str(clslst[k]["name"])
					self.boxsize[int(k)]=int(clslst[k]["boxsize"])
				else:
					# older format: the entry is just the set name
					self.sets[int(k)]=str(clslst[k])
					self.boxsize[int(k)]=64

		clr=QtGui.QColor
		self.setcolors=[QtGui.QBrush(clr("blue")),QtGui.QBrush(clr("green")),QtGui.QBrush(clr("red")),QtGui.QBrush(clr("cyan")),QtGui.QBrush(clr("purple")),QtGui.QBrush(clr("orange")), QtGui.QBrush(clr("yellow")),QtGui.QBrush(clr("hotpink")),QtGui.QBrush(clr("gold"))]
		self.sets_visible={}

		#### read boxes
		if "boxes_3d" in info:
			for b in info["boxes_3d"]:
				#### X-center,Y-center,Z-center,method,[score,[class #]]
				bdf=[0,0,0,"manual",0.0, 0]
				for j,bi in enumerate(b[:len(bdf)]):
					bdf[j]=bi
				# the set id is used as a dict key and a colour-table index, so keep it an int
				bdf[5]=int(bdf[5])

				if bdf[5] not in self.sets:
					self.sets[bdf[5]]="particles_{:02d}".format(bdf[5])
					self.boxsize[bdf[5]]=64

				self.boxes.append(bdf)

		###### this is the new (2018-09) metadata standard..
		### now we use coordinates at full size from center of tomogram so it works for different binning and clipping
		### have to make it compatible with older versions though..
		if "apix_unbin" in info:
			self.apix_unbin=info["apix_unbin"]
			self.apix_cur=apix=data["apix_x"]
			for b in self.boxes:
				b[0]=b[0]/apix*self.apix_unbin+data["nx"]//2
				b[1]=b[1]/apix*self.apix_unbin+data["ny"]//2
				b[2]=b[2]/apix*self.apix_unbin+data["nz"]//2

			for k in self.boxsize.keys():
				self.boxsize[k]=int(np.round(self.boxsize[k]*self.apix_unbin/apix))
		else:
			self.apix_unbin=-1

		info.close()

		E2loadappwin("e2sptboxer","main",self)
		E2loadappwin("e2sptboxer","boxes",self.boxesviewer.qt_parent)
		E2loadappwin("e2sptboxer","option",self.optionviewer)

		#### particle classes
		if len(self.sets)==0:
			self.new_set("particles_00")
		# select the lowest-numbered set and make sure that same set is visible
		self.currentset=sorted(self.sets.keys())[0]
		self.sets_visible[self.currentset]=0
		self.setspanel.update_sets()
		self.wboxsize.setValue(self.get_boxsize())

		for i in range(len(self.boxes)):
			self.update_box(i)

		self.update_all()
		self.initialized=True

	def set_data(self,data):
		"""Use ``data`` (an EMData volume) as the tomogram; clears all boxes and centres the views."""
		self.data=data
		self.apix=data["apix_x"]

		self.datasize=(data["nx"],data["ny"],data["nz"])
		self.x_loc, self.y_loc, self.z_loc=data["nx"]//2,data["ny"]//2,data["nz"]//2

		self.gbl.setRowMinimumHeight(1,max(250,data["nz"]))
		self.gbl.setColumnMinimumWidth(0,max(250,data["nz"]))
		print(data["nx"],data["ny"],data["nz"])

		self.wdepth.setRange(0,data["nz"]-1)
		self.wdepth.setValue(data["nz"]//2)
		self.boxes=[]
		self.curbox=-1

		if self.initialized:
			self.update_all()

	def eraser_width(self):
		"""Current eraser radius from the option panel, as int."""
		return int(self.optionviewer.eraser_radius.getValue())

	def get_cube(self,x,y,z, centerslice=False, boxsz=-1):
		"""Return a box-sized region of the raw data centred at (x, y, z).

		boxsz<0 uses the current set's box size. With centerslice=True only the
		single central z slice is extracted. If the centre lies more than half
		a box outside the volume, a new EMData of the requested size is returned.
		"""
		if boxsz<0:
			bs=self.get_boxsize()
		else:
			bs=boxsz

		if centerslice:
			bz=1
		else:
			bz=bs

		if ((x<-bs//2) or (y<-bs//2) or (z<-bz//2)
			or (x>self.data["nx"]+bs//2) or (y>self.data["ny"]+bs//2) or (z>self.data["nz"]+bz//2) ):
			r=EMData(bs,bs,bz)
		else:
			r=self.data.get_clip(Region(x-bs//2,y-bs//2,z-bz//2,bs,bs,bz))

		if self.apix!=0 :
			r["apix_x"]=r["apix_y"]=r["apix_z"]=self.apix

		return r

	def get_slice(self,idx,thk=1,axis="z"):
		"""Average of 2*thk-1 slices around index ``idx`` along ``axis``.

		Uses the flattened volume (self.dataxf) when a Flatten rotation is active.
		"""
		if self.globalxf.is_identity():
			data=self.data
		else:
			data=self.dataxf

		t=int(thk-1)
		idx=int(idx)
		r=data.process("misc.directional_sum",{"axis":axis,"first":idx-t,"last":idx+t})
		r.div(t*2+1)

		if self.apix!=0 :
			r["apix_x"]=r["apix_y"]=r["apix_z"]=self.apix
		return r

	def event_boxsize(self):
		"""Box-size field changed: store it for the current set and redraw that set's boxes."""
		if self.get_boxsize()==int(self.wboxsize.getValue()):
			return

		self.boxsize[self.currentset]=int(self.wboxsize.getValue())

		# suppress per-box redraws/saves, then refresh everything once
		self.initialized=False
		for i in range(len(self.boxes)):
			if self.boxes[i][5]==self.currentset:
				self.update_box(i)
		self.initialized=True
		self.update_all()

	def event_scale(self,newscale):
		"""Keep the three views at the same zoom."""
		self.xyview.set_scale(newscale)
		self.xzview.set_scale(newscale)
		self.zyview.set_scale(newscale)

	def event_depth(self):
		"""Depth slider moved: follow it with the xy slice."""
		if self.z_loc!=self.wdepth.value():
			self.z_loc=self.wdepth.value()
		if self.initialized:
			self.update_sliceview()

	def event_nlayers(self):
		"""Thickness changed: redraw everything."""
		self.update_all()

	def event_filter(self):
		"""Low-pass filter changed: redraw everything."""
		self.update_all()

	def event_localbox(self,tog):
		"""'Limit Side Boxes' toggled: redraw the two side views."""
		self.update_sliceview(['x','y'])

	def get_boxsize(self, clsid=-1):
		"""Box size of set ``clsid`` (current set if negative); 32 if none is stored for it."""
		if clsid<0:
			return int(self.boxsize[self.currentset])
		else:
			try:
				ret= int(self.boxsize[clsid])
			except (KeyError, TypeError, ValueError):
				print("No box size saved for {}..".format(clsid))
				ret=32
			return ret

	def nlayers(self):
		"""Slab thickness from the spin box."""
		return int(self.wnlayers.value())

	def menu_file_read_boxloc(self):
		"""Append boxes from a text file to the current set.

		The first (up to) three whitespace-separated columns of each line are
		x, y, z. Blank lines are skipped.
		"""
		fsp=str(QtWidgets.QFileDialog.getOpenFileName(self, "Select input text file")[0])
		if len(fsp)==0:
			return	# dialog cancelled

		if not os.path.isfile(fsp):
			print("file does not exist")
			return

		with open(fsp,"r") as f:
			for line in f:
				fields=line.split()[:3]
				if len(fields)==0:
					continue	# a blank line must not create a box at the origin
				b2=[old_div(int(float(i)),self.shrink) for i in fields]
				bdf=[0,0,0,"manual",0.0, self.currentset]
				bdf[:len(b2)]=b2
				self.boxes.append(bdf)
				self.update_box(len(self.boxes)-1)

	def menu_file_save_boxloc(self):
		"""Write every box centre as a tab-separated integer "x y z" line to a chosen file."""
		shrinkf=self.shrink

		fsp=str(QtWidgets.QFileDialog.getSaveFileName(self, "Select output text file")[0])
		if len(fsp)==0:
			return

		with open(fsp,"w") as out:
			for b in self.boxes:
				out.write("%d\t%d\t%d\n"%(b[0]*shrinkf,b[1]*shrinkf,b[2]*shrinkf))

	def menu_file_save_boxpdb(self):
		"""Save centres of boxes in visible sets as a PDB file (coordinates divided by 10)."""
		fsp=str(QtWidgets.QFileDialog.getSaveFileName(self, "Select output PDB file", filter="PDB (*.pdb)")[0])
		if len(fsp)==0:
			return
		if fsp[-4:].lower()!=".pdb" :
			fsp+=".pdb"
		clsid=set(self.sets_visible.keys())
		if len(clsid)==0:
			print("No visible particles to save")
			return

		bxs=np.array([[b[0], b[1], b[2]] for b in self.boxes if int(b[5]) in clsid])/10
		if len(bxs)==0:
			print("No visible particles to save")
			return

		numpy2pdb(bxs, fsp)
		print("PDB saved to {}. Use voxel size 0.1".format(fsp))

	def save_boxes(self, clsid=None):
		"""Extract the boxed particles and write them to one HDF stack.

		clsid: list of set ids to save. Empty or None saves every box.
		  The menu action's triggered[bool] signal passes its ``checked`` flag
		  here, so a bool is treated the same as "no selection".

		Particles go to particles3d/ (3-D mode, normalized and contrast-inverted
		cubes) or particles/ (2-D mode, the centre slice of each box) as
		<basename><filetag><suffix>.hdf. An existing file is overwritten.
		"""
		if isinstance(clsid, bool) or not clsid:
			clsid=[]
		else:
			clsid=list(clsid)

		if len(clsid)==0:
			defaultname="ptcls.hdf"
		else:
			defaultname="_".join([self.sets[i] for i in clsid])+".hdf"

		name,ok=QtWidgets.QInputDialog.getText( self, "Save particles", "Filename suffix:", text=defaultname)
		if not ok:
			return
		name=self.filetag+str(name)
		if name[-4:].lower()!=".hdf" :
			name+=".hdf"

		if self.options.mode=="3D":
			dr="particles3d"
			is2d=False
		else:
			dr="particles"
			is2d=True

		if not os.path.isdir(dr):
			os.mkdir(dr)

		fsp=os.path.join(dr,self.basename)+name

		print("Saving {} particles to {}".format(self.options.mode, fsp))

		if os.path.isfile(fsp):
			print("{} exist. Overwritting...".format(fsp))
			os.remove(fsp)

		progress = QtWidgets.QProgressDialog("Saving", "Abort", 0, len(self.boxes),None)

		# source coordinates are stored relative to the tomogram centre
		sz=[s//2 for s in self.datasize]
		boxsz=-1
		for i,b in enumerate(self.boxes):
			if len(clsid)>0 and int(b[5]) not in clsid:
				continue

			# all particles in one stack must share the first box size encountered
			bs=self.get_boxsize(b[5])
			if boxsz<0:
				boxsz=bs
			elif boxsz!=bs:
				print("Inconsistant box size in the particles to save.. Using {:d}..".format(boxsz))
				bs=boxsz

			img=self.get_cube(b[0], b[1], b[2], centerslice=is2d, boxsz=bs)
			if not is2d:
				img.process_inplace('normalize')

			img["ptcl_source_image"]=self.datafilename
			img["ptcl_source_coord"]=(b[0]-sz[0], b[1]-sz[1], b[2]-sz[2])

			if not is2d: #### do not invert contrast for 2D images
				img.mult(-1)

			img.write_image(fsp,-1)

			progress.setValue(i+1)
			if progress.wasCanceled():
				break

		# boxes skipped by the set filter never advance the bar to its maximum,
		# so close the dialog explicitly
		progress.close()

	def update_sliceview(self, axis=('x','y','z')):
		"""Redraw the slice image and box outlines of the views named in ``axis``.

		'z' is the xy view, 'y' the xz view and 'x' the zy view. Boxes are shown
		only if their set is visible and the slice cuts through them. With
		'Limit Side Boxes' unchecked, the side views show every box.
		"""
		boxes=self.get_rotated_boxes()

		allside=(not self.wlocalbox.isChecked())

		pms={'z':[2, self.xyview, self.z_loc],
		     'y':[1, self.xzview, self.y_loc],
		     'x':[0, self.zyview, self.x_loc]}

		# index of the line width in the EMShape list built by update_box_shape
		if self.boxshape=="circle": lwi=7
		else: lwi=8

		for ax in axis:
			ia, view, loc=pms[ax]
			shp=view.get_shapes()
			if len(shp)!=len(boxes):
				### something changes the box shapes...
				for i,b in enumerate(boxes):
					self.update_box_shape(i,b)

		for ax in axis:
			ia, view, loc=pms[ax]

			## update the box shapes
			shp=view.get_shapes()
			for i,b in enumerate(boxes):
				bs=self.get_boxsize(b[5])
				dst=abs(b[ia] - loc)

				inplane=dst<bs//2
				rad=bs//2-dst

				if ax!='z' and allside:
					## display all side view boxes in this mode
					inplane=True
					rad=bs//2

				if ax=='z' and self.options.mode=="2D":
					## boxes are 1 slice thick in 2d mode
					inplane=dst<1

				if inplane and (b[5] in self.sets_visible):
					shp[i][0]=self.boxshape
					## selected box is slightly thicker
					if self.curbox==i:
						shp[i][lwi]=3
					else:
						shp[i][lwi]=2
					if self.options.mode=="3D":
						# circle radius = cross-section of the sphere at this slice
						shp[i][6]=rad
				else:
					shp[i][0]="hidden"

			view.shapechange=1
			img=self.get_slice(loc, self.nlayers(), ax)
			if self.wfilt.getValue()!=0.0:
				img.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/self.wfilt.getValue(),"apix":self.apix})

			view.set_data(img)

		self.update_coords()

	def get_rotated_boxes(self):
		"""Boxes in display coordinates: self.boxes rotated by self.globalxf about the volume centre."""
		if len(self.boxes)==0:
			return []
		if self.globalxf.is_identity():
			boxes=self.boxes
		else:
			cnt=[self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2]
			pts=np.array([b[:3] for b in self.boxes])-cnt
			pts=np.array([self.globalxf.transform(p.tolist()) for p in pts])+cnt
			boxes=[]
			for i,b in enumerate(self.boxes):
				p=pts[i]
				boxes.append([p[0],p[1],p[2],b[3],b[4],b[5]])
		return boxes

	def update_all(self):
		"""redisplay of all widgets"""
		if self.data is None:
			return

		self.update_sliceview()
		self.update_boximgs()

	def update_coords(self):
		"""Show the current cursor location in the coordinate label."""
		self.wcoords.setText("X: {:d}\tY: {:d}\tZ: {:d}".format(int(self.x_loc), int(self.y_loc), int(self.z_loc)))

	def inside_box(self,n,x=-1,y=-1,z=-1):
		"""Checks to see if a point in image coordinates is inside box number n. If any value is negative, it will not be checked."""
		box=self.boxes[n]
		if box[5] not in self.sets_visible:
			return False
		bs=self.get_boxsize(box[5])/2
		if self.options.mode=="3D":
			rr=(x>=0)*((box[0]-x)**2) + (y>=0)*((box[1]-y) **2) + (z>=0)*((box[2]-z)**2)
		else:
			# 2-D boxes only match on their own slice: a different z pushes rr far out of range
			rr=(x>=0)*((box[0]-x)**2) + (y>=0)*((box[1]-y) **2) + (z>=0)*(box[2]!=z)*(1e3*bs**2)
		return rr<=bs**2

	def del_box(self, delids):
		"""Delete box index ``delids`` (int or list of ints) with its thumbnail and shapes."""
		if not isinstance(delids,list):
			delids=[delids]
		delset=set(delids)

		kpids=[i for i in range(len(self.boxes)) if i not in delset]
		self.boxes=[self.boxes[i] for i in kpids]
		self.boxesimgs=[self.boxesimgs[i] for i in kpids]
		# shapes are keyed by box index, so renumber them to match the compacted list
		self.xyview.shapes={i:self.xyview.shapes[k] for i,k in enumerate(kpids)}
		self.xzview.shapes={i:self.xzview.shapes[k] for i,k in enumerate(kpids)}
		self.zyview.shapes={i:self.zyview.shapes[k] for i,k in enumerate(kpids)}
		self.curbox=-1
		self.update_all()

	def update_box_shape(self,n, box):
		"""(Re)create the outline shapes of box ``n`` (display coordinates ``box``) in all three views."""
		bs2=self.get_boxsize(box[5])//2
		if n==self.curbox:
			lw=3
		else:
			lw=2
		color=self.setcolors[box[5]%len(self.setcolors)].color().getRgbF()
		if self.options.mode=="3D":
			self.xyview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[1],bs2,lw]))
			self.xzview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[2],bs2,lw]))
			self.zyview.add_shape(n,EMShape(("circle",color[0],color[1],color[2],box[2],box[1],bs2,lw)))
		else:
			self.xyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[0]-bs2,box[1]-bs2,box[0]+bs2,box[1]+bs2,2]))
			self.xzview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[0]-bs2,box[2]-1,box[0]+bs2,box[2]+1,2]))
			self.zyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[2]-1,box[1]-bs2,box[2]+1,box[1]+bs2,2]))

	def update_box(self,n,quiet=False):
		"""After adjusting box ``n``, refresh its shapes and, unless ``quiet``, its thumbnail.

		quiet=True is used while dragging: only the outline moves. The thumbnail
		and the info file are refreshed on mouse-up.
		"""
		if n<0 or n>=len(self.boxes):
			return

		box=self.boxes[n]
		self.update_box_shape(n,self.get_rotated_boxes()[n])

		if self.initialized:
			self.update_sliceview()

		if not quiet:
			# thumbnail: normalized centre slice of the box, cut from the unrotated data
			proj=self.get_cube(box[0], box[1], box[2], centerslice=True, boxsz=self.get_boxsize(box[5]))
			proj.process_inplace("normalize")

			# grow the thumbnail list so that index n exists
			while len(self.boxesimgs)<=n:
				self.boxesimgs.append(None)
			self.boxesimgs[n]=proj

			if self.initialized: self.SaveJson()

		if self.initialized:
			self.update_boximgs()

		self.curbox=n

	def update_boximgs(self):
		"""Show thumbnails of boxes in visible sets; self.boxids maps viewer index -> box index."""
		self.boxids=[im for im,m in enumerate(self.boxesimgs) if self.boxes[im][5] in self.sets_visible]
		self.boxesviewer.set_data([self.boxesimgs[i] for i in self.boxids])
		self.boxesviewer.update()

	def img_selected(self,event,lc):
		"""Thumbnail clicked.

		Shift deletes the box; Shift+Ctrl deletes it and every later box.
		Otherwise the box is selected and the views scroll to it.
		"""
		lci=self.boxids[lc[0]]
		if event.modifiers()&Qt.ShiftModifier:
			if event.modifiers()&Qt.ControlModifier:
				self.del_box(list(range(lci, len(self.boxes))))
			else:
				self.del_box(lci)
		else:
			self.curbox=lci
			box=self.boxes[lci]
			self.x_loc,self.y_loc,self.z_loc=self.rotate_coord([box[0], box[1], box[2]], inv=False)
			self.scroll_to(self.x_loc,self.y_loc,self.z_loc)

			self.update_sliceview()

	def rotate_coord(self, p, inv=True):
		"""Map point ``p`` between display and data coordinates.

		inv=True maps display to data (inverse of self.globalxf); inv=False maps
		data to display. The point is returned unchanged when no Flatten rotation
		is active.
		"""
		if not self.globalxf.is_identity():
			cnt=[self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2]
			p=[p[i]-cnt[i] for i in range(3)]
			xf=Transform(self.globalxf)
			if inv:
				xf.invert()
			p=xf.transform(p)+cnt
		return p

	def del_region_xy(self, x=-1, y=-1, z=-1, rad=-1):
		"""Eraser: delete visible boxes within ``rad`` of (x,y,z), ignoring negative coordinates."""
		if rad<0:
			rad=self.eraser_width()

		delids=[]
		boxes=self.get_rotated_boxes()
		for i,b in enumerate(boxes):
			if b[5] not in self.sets_visible:
				continue

			if (x>=0)*(b[0]-x)**2 + (y>=0)*(b[1]-y)**2 +(z>=0)*(b[2]-z)**2 < rad**2:
				delids.append(i)
		self.del_box(delids)

	def scroll_to(self, x,y,z, axis=""):
		"""Centre the views on (x,y,z), skipping the view perpendicular to ``axis``."""
		if axis!="z": self.xyview.scroll_to(x,y,True)
		if axis!="y": self.xzview.scroll_to(x,self.data["nz"]/2,True)
		if axis!="x": self.zyview.scroll_to(self.data["nz"]/2,y,True)

	#### mouse click
	def xy_down(self,event):
		x,y=self.xyview.scr_to_img((event.x(),event.y()))
		self.mouse_down(event, x,y,self.z_loc, "z")

	def xz_down(self,event):
		x,z=self.xzview.scr_to_img((event.x(),event.y()))
		self.mouse_down(event,x,self.y_loc,z, "y")

	def zy_down(self,event):
		z,y=self.zyview.scr_to_img((event.x(),event.y()))
		self.mouse_down(event,self.x_loc,y,z, "x")

	def mouse_down(self,event, x, y, z, axis):
		"""Click at display position (x,y,z) in the view perpendicular to ``axis``.

		In eraser mode, nearby boxes are deleted. Clicking a box starts a drag,
		or deletes the box with Shift. Clicking empty space without Shift adds
		a box to the current set.
		"""
		if min(x,y,z)<0: return

		xr,yr,zr=self.rotate_coord((x,y,z))

		if self.optionviewer.erasercheckbox.isChecked():
			self.del_region_xy(x,y,z)
			return

		for i in range(len(self.boxes)):
			if self.inside_box(i,xr,yr,zr):

				if event.modifiers()&Qt.ShiftModifier:  ## delete box
					self.del_box(i)

				else:  ## start dragging
					self.dragging=i
					self.curbox=i
					self.scroll_to(x,y,z,axis)

				break
		else:
			if not event.modifiers()&Qt.ShiftModifier: ## add box

				self.x_loc, self.y_loc, self.z_loc=x,y,z
				self.scroll_to(x,y,z,axis)
				self.curbox=len(self.boxes)
				self.boxes.append([xr,yr,zr, 'manual', 0.0, self.currentset])
				self.update_box(len(self.boxes)-1)
				self.dragging=len(self.boxes)-1

	#### eraser mode
	def xy_move(self,event):
		self.mouse_move(event, self.xyview)

	def xz_move(self,event):
		self.mouse_move(event, self.xzview)

	def zy_move(self,event):
		self.mouse_move(event, self.zyview)

	def mouse_move(self,event,view):
		"""Draw the eraser circle under the cursor in ``view`` while eraser mode is on."""
		if self.optionviewer.erasercheckbox.isChecked():
			self.xyview.eraser_shape=self.xzview.eraser_shape=self.zyview.eraser_shape=None
			x,y=view.scr_to_img((event.x(),event.y()))
			view.eraser_shape=EMShape(["circle",1,1,1,x,y,self.eraser_width(),2])
			view.shapechange=1
			view.update()
		else:
			view.eraser_shape=None

	#### dragging...
	def mouse_drag(self,x, y, z):
		"""Move the dragged box to display position (x,y,z) (quiet update until mouse-up)."""
		if self.dragging<0:
			return
		if min(x,y,z)<0:
			return

		self.x_loc, self.y_loc, self.z_loc=x,y,z
		x,y,z=self.rotate_coord((x,y,z))
		self.boxes[self.dragging][:3]= x,y,z
		self.update_box(self.dragging,True)

	def xy_drag(self,event):
		if self.dragging>=0:
			x,y=self.xyview.scr_to_img((event.x(),event.y()))
			self.mouse_drag(x,y,self.z_loc)

	def xz_drag(self,event):
		if self.dragging>=0:
			x,z=self.xzview.scr_to_img((event.x(),event.y()))
			self.mouse_drag(x,self.y_loc,z)

	def zy_drag(self,event):
		if self.dragging>=0:
			z,y=self.zyview.scr_to_img((event.x(),event.y()))
			self.mouse_drag(self.x_loc,y,z)

	def mouse_up(self,event):
		"""End a drag: do the full (non-quiet) update of the dragged box."""
		if self.dragging>=0:
			self.update_box(self.dragging)
		self.dragging=-1

	#### keep the same origin for the 3 views
	def xy_origin(self,newor):
		xzo=self.xzview.get_origin()
		self.xzview.set_origin(newor[0],xzo[1],True)

		zyo=self.zyview.get_origin()
		self.zyview.set_origin(zyo[0],newor[1],True)

	def xz_origin(self,newor):
		xyo=self.xyview.get_origin()
		self.xyview.set_origin(newor[0],xyo[1],True)

	def zy_origin(self,newor):
		xyo=self.xyview.get_origin()
		self.xyview.set_origin(xyo[0],newor[1],True)

	##### go up/down with shift+wheel: one slice per wheel step, valid range 0..n-1
	def xy_wheel(self, event):
		z=int(self.z_loc+ np.sign(event.angleDelta().y()))
		if 0<=z<self.data["nz"]:
			self.wdepth.setValue(z)

	def xz_wheel(self, event):
		y=int(self.y_loc+np.sign(event.angleDelta().y()))
		if 0<=y<self.data["ny"]:
			self.y_loc=y
			self.update_sliceview(['y'])

	def zy_wheel(self, event):
		x=int(self.x_loc+np.sign(event.angleDelta().y()))
		if 0<=x<self.data["nx"]:
			self.x_loc=x
			self.update_sliceview(['x'])

	########
	def set_current_set(self, name):
		"""Make the set identified by ``name`` current (new boxes go there)."""
		name=parse_setname(name)
		self.currentset=name
		self.wboxsize.setValue(self.get_boxsize())
		self.update_all()

	def hide_set(self, name):
		"""Hide the boxes of set ``name``."""
		name=parse_setname(name)

		if name in self.sets_visible: self.sets_visible.pop(name)

		if self.initialized:
			self.update_all()	# also refreshes the thumbnails

	def show_set(self, name):
		"""Show the boxes of set ``name``."""
		name=parse_setname(name)
		self.sets_visible[name]=0
		if self.initialized:
			self.update_all()	# also refreshes the thumbnails

	def delete_set(self, name):
		"""Delete set ``name`` together with all of its boxes."""
		name=parse_setname(name)
		delids=[i for i,b in enumerate(self.boxes) if b[5]==int(name)]
		self.del_box(delids)

		if name in self.sets_visible: self.sets_visible.pop(name)
		if name in self.sets: self.sets.pop(name)
		if name in self.boxsize: self.boxsize.pop(name)

		self.update_all()

	def rename_set(self, oldname,  newname):
		"""Give the set identified by ``oldname`` the display name ``newname``."""
		name=parse_setname(oldname)
		if name in self.sets:
			self.sets[name]=newname

	def new_set(self, name):
		"""Create a visible set called ``name`` under the lowest unused id.

		Default box size is 32 in 3D mode and 64 otherwise.
		"""
		for i in range(len(self.sets)+1):
			if i not in self.sets:
				break

		self.sets[i]=name
		self.sets_visible[i]=0
		if self.options.mode=="3D":
			self.boxsize[i]=32
		else:
			self.boxsize[i]=64

	def save_set(self):
		"""Save the particles of all visible sets."""
		self.save_boxes(list(self.sets_visible.keys()))

	def key_press(self,event):
		"""'`' / '1' step the xy slice up / down; other keys are re-emitted as keypress."""
		if event.key() == 96: ## "`" to move up a slice since arrow keys are occupied...
			# z_loc can be a float after a click in a side view; QSlider.setValue needs an int
			self.wdepth.setValue(int(self.z_loc)+1)

		elif event.key() == 49: ## "1" to move down a slice
			self.wdepth.setValue(int(self.z_loc)-1)
		else:
			self.keypress.emit(event)

	def flatten_tomo(self):
		"""Rotate the display so the plane through the visible boxes becomes horizontal.

		A plane is fitted to the visible box centres with PCA. Its normal (the
		last principal component, oriented toward +z) is turned into an
		inverted rotation with ztilt forced to 0. The result is stored in
		self.globalxf, and the rotated volume in self.dataxf.
		"""
		print("Flatten tomogram by particles coordinates")
		vis=set(self.sets_visible.keys())
		pts=[b[:3] for b in self.boxes if b[5] in vis]
		if len(pts)<3:
			print("Too few visible particles. Cannot flatten tomogram.")
			return
		pca=PCA(3)
		pca.fit(np.array(pts))
		normal=pca.components_[2]
		if normal[2]!=0:
			normal=normal*np.sign(normal[2])

		t=Transform()
		t.set_rotation(normal.tolist())
		t.invert()
		xyz=t.get_params("xyz")
		xyz["ztilt"]=0
		print("xtilt {:.02f}, ytilt {:.02f}".format(xyz["xtilt"], xyz["ytilt"]))
		t=Transform(xyz)
		self.globalxf=t
		self.dataxf=self.data.process("xform",{"transform":t})

		self.xyview.shapes={}
		self.zyview.shapes={}
		self.xzview.shapes={}

		boxes=self.get_rotated_boxes()
		for i,b in enumerate(boxes):
			self.update_box_shape(i,b)

		self.update_sliceview()
		print("Done")

	def reset_flatten_tomo(self, event=None):
		"""Undo Flatten: restore the identity transform and redraw all box shapes.

		``event`` receives the button's ``checked`` flag and is unused.
		"""
		self.globalxf=Transform()
		self.xyview.shapes={}
		self.zyview.shapes={}
		self.xzview.shapes={}

		boxes=self.get_rotated_boxes()
		for i,b in enumerate(boxes):
			self.update_box_shape(i,b)

		self.update_sliceview()

	def SaveJson(self):
		"""Write boxes and set definitions (name, box size) to the info file.

		If the file uses the "apix_unbin" convention, coordinates and box sizes
		are converted back to unbinned values relative to the tomogram centre,
		reversing the conversion done in __init__.
		"""
		info=js_open_dict(self.jsonfile)
		sx,sy,sz=(self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2)
		if "apix_unbin" in info:
			bxs=[]
			for b0 in self.boxes:
				b=[	(b0[0]-sx)*self.apix_cur/self.apix_unbin,
					(b0[1]-sy)*self.apix_cur/self.apix_unbin,
					(b0[2]-sz)*self.apix_cur/self.apix_unbin,
					b0[3], b0[4], b0[5]	]
				bxs.append(b)

			bxsz={}
			for k in self.boxsize.keys():
				bxsz[k]=np.round(self.boxsize[k]*self.apix_cur/self.apix_unbin)

		else:
			bxs=self.boxes
			bxsz=self.boxsize

		info["boxes_3d"]=bxs
		clslst={}
		for key in list(self.sets.keys()):
			clslst[int(key)]={
				"name":self.sets[key],
				"boxsize":int(bxsz[key]),
				}
		info["class_list"]=clslst
		info.close()

	def closeEvent(self,event):
		"""Save boxes and window geometry, close the helper windows, emit module_closed."""
		print("Exiting")
		self.SaveJson()

		E2saveappwin("e2sptboxer","main",self)
		E2saveappwin("e2sptboxer","boxes",self.boxesviewer.qt_parent)
		E2saveappwin("e2sptboxer","option",self.optionviewer)

		self.boxesviewer.close()
		self.optionviewer.close()

		self.module_closed.emit() # this signal is important when e2ctf is being used by a program running its own event loop
Пример #10
0
	def __init__(self,application,options,datafile):
		QtWidgets.QWidget.__init__(self)
		self.initialized=False
		self.app=weakref.ref(application)
		self.options=options
		self.apix=options.apix
		self.currentset=0
		self.shrink=1#options.shrink
		self.setWindowTitle("Main Window (e2spt_boxer.py)")
		if options.mode=="3D":
			self.boxshape="circle"
		else:
			self.boxshape="rect"

		self.globalxf=Transform()
		
		# Menu Bar
		self.mfile=self.menuBar().addMenu("File")
		#self.mfile_open=self.mfile.addAction("Open")
		self.mfile_read_boxloc=self.mfile.addAction("Read Box Coord")
		self.mfile_save_boxloc=self.mfile.addAction("Save Box Coord")
		self.mfile_save_boxpdb=self.mfile.addAction("Save Coord as PDB")
		self.mfile_save_boxes_stack=self.mfile.addAction("Save Boxes as Stack")
		#self.mfile_quit=self.mfile.addAction("Quit")


		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())

		# relative stretch factors
		self.gbl.setColumnMinimumWidth(0,200)
		self.gbl.setRowMinimumHeight(0,200)
		self.gbl.setColumnStretch(0,0)
		self.gbl.setColumnStretch(1,100)
		self.gbl.setColumnStretch(2,0)
		self.gbl.setRowStretch(1,0)
		self.gbl.setRowStretch(0,100)
		

		# 3 orthogonal restricted projection views
		self.xyview = EMImage2DWidget(sizehint=(1024,1024))
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget(sizehint=(1024,256))
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget(sizehint=(256,1024))
		self.gbl.addWidget(self.zyview,0,0)

		# Select Z for xy view
		self.wdepth = QtWidgets.QSlider()
		self.gbl.addWidget(self.wdepth,1,2)

		### Control panel area in upper left corner
		self.gbl2 = QtWidgets.QGridLayout()
		self.gbl.addLayout(self.gbl2,1,0)

		#self.wxpos = QtWidgets.QSlider(Qt.Horizontal)
		#self.gbl2.addWidget(self.wxpos,0,0)
		
		#self.wypos = QtWidgets.QSlider(Qt.Vertical)
		#self.gbl2.addWidget(self.wypos,0,3,6,1)
		
		# box size
		self.wboxsize=ValBox(label="Box Size:",value=0)
		self.gbl2.addWidget(self.wboxsize,2,0)

		# max or mean
		#self.wmaxmean=QtWidgets.QPushButton("MaxProj")
		#self.wmaxmean.setCheckable(True)
		#self.gbl2.addWidget(self.wmaxmean,3,0)

		# number slices
		label0=QtWidgets.QLabel("Thickness")
		self.gbl2.addWidget(label0,3,0)

		self.wnlayers=QtWidgets.QSpinBox()
		self.wnlayers.setMinimum(1)
		self.wnlayers.setMaximum(256)
		self.wnlayers.setValue(1)
		self.gbl2.addWidget(self.wnlayers,3,1)

		# Local boxes in side view
		self.wlocalbox=QtWidgets.QCheckBox("Limit Side Boxes")
		self.gbl2.addWidget(self.wlocalbox,4,0)
		self.wlocalbox.setChecked(True)
		
		self.button_flat = QtWidgets.QPushButton("Flatten")
		self.gbl2.addWidget(self.button_flat,5,0)
		self.button_reset = QtWidgets.QPushButton("Reset")
		self.gbl2.addWidget(self.button_reset,5,1)
		## scale factor
		#self.wscale=ValSlider(rng=(.1,2),label="Sca:",value=1.0)
		#self.gbl2.addWidget(self.wscale,4,0,1,2)

		# 2-D filters
		self.wfilt = ValSlider(rng=(0,150),label="Filt",value=0.0)
		self.gbl2.addWidget(self.wfilt,6,0,1,2)
		
		self.curbox=-1
		
		self.boxes=[]						# array of box info, each is (x,y,z,...)
		self.boxesimgs=[]					# z projection of each box
		self.dragging=-1

		##coordinate display
		self.wcoords=QtWidgets.QLabel("")
		self.gbl2.addWidget(self.wcoords, 1, 0, 1, 2)
		
		self.button_flat.clicked[bool].connect(self.flatten_tomo)
		self.button_reset.clicked[bool].connect(self.reset_flatten_tomo)

		# file menu
		#self.mfile_open.triggered[bool].connect(self.menu_file_open)
		self.mfile_read_boxloc.triggered[bool].connect(self.menu_file_read_boxloc)
		self.mfile_save_boxloc.triggered[bool].connect(self.menu_file_save_boxloc)
		self.mfile_save_boxpdb.triggered[bool].connect(self.menu_file_save_boxpdb)
		
		self.mfile_save_boxes_stack.triggered[bool].connect(self.save_boxes)
		#self.mfile_quit.triggered[bool].connect(self.menu_file_quit)

		# all other widgets
		self.wdepth.valueChanged[int].connect(self.event_depth)
		self.wnlayers.valueChanged[int].connect(self.event_nlayers)
		self.wboxsize.valueChanged.connect(self.event_boxsize)
		#self.wmaxmean.clicked[bool].connect(self.event_projmode)
		#self.wscale.valueChanged.connect(self.event_scale)
		self.wfilt.valueChanged.connect(self.event_filter)
		self.wlocalbox.stateChanged[int].connect(self.event_localbox)

		self.xyview.mousemove.connect(self.xy_move)
		self.xyview.mousedown.connect(self.xy_down)
		self.xyview.mousedrag.connect(self.xy_drag)
		self.xyview.mouseup.connect(self.mouse_up)
		self.xyview.mousewheel.connect(self.xy_wheel)
		self.xyview.signal_set_scale.connect(self.event_scale)
		self.xyview.origin_update.connect(self.xy_origin)

		self.xzview.mousedown.connect(self.xz_down)
		self.xzview.mousedrag.connect(self.xz_drag)
		self.xzview.mouseup.connect(self.mouse_up)
		self.xzview.mousewheel.connect(self.xz_wheel)
		self.xzview.signal_set_scale.connect(self.event_scale)
		self.xzview.origin_update.connect(self.xz_origin)
		self.xzview.mousemove.connect(self.xz_move)

		self.zyview.mousedown.connect(self.zy_down)
		self.zyview.mousedrag.connect(self.zy_drag)
		self.zyview.mouseup.connect(self.mouse_up)
		self.zyview.mousewheel.connect(self.zy_wheel)
		self.zyview.signal_set_scale.connect(self.event_scale)
		self.zyview.origin_update.connect(self.zy_origin)
		self.zyview.mousemove.connect(self.zy_move)
		
		self.xyview.keypress.connect(self.key_press)
		self.datafilename=datafile
		self.basename=base_name(datafile)
		p0=datafile.find('__')
		if p0>0:
			p1=datafile.rfind('.')
			self.filetag=datafile[p0:p1]
			if self.filetag[-1]!='_':
				self.filetag+='_'
		else:
			self.filetag="__"
			
		data=EMData(datafile)
		self.set_data(data)

		# Boxviewer subwidget (details of a single box)
		#self.boxviewer=EMBoxViewer()
		#self.app().attach_child(self.boxviewer)

		# Boxes Viewer (z projections of all boxes)
		self.boxesviewer=EMImageMXWidget()
		
		#self.app().attach_child(self.boxesviewer)
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Particle List")
		self.boxesviewer.rzonce=True
		
		self.setspanel=EMTomoSetsPanel(self)

		self.optionviewer=EMTomoBoxerOptions(self)
		self.optionviewer.add_panel(self.setspanel,"Sets")
		
		
		self.optionviewer.show()
		
		self.boxesviewer.mx_image_selected.connect(self.img_selected)
		
		##################
		#### deal with metadata in the _info.json file...
		
		self.jsonfile=info_name(datafile)
		info=js_open_dict(self.jsonfile)
		
		#### read particle classes
		self.sets={}
		self.boxsize={}
		if "class_list" in info:
			clslst=info["class_list"]
			for k in sorted(clslst.keys()):
				if type(clslst[k])==dict:
					self.sets[int(k)]=str(clslst[k]["name"])
					self.boxsize[int(k)]=int(clslst[k]["boxsize"])
				else:
					self.sets[int(k)]=str(clslst[k])
					self.boxsize[int(k)]=64
					
		clr=QtGui.QColor
		self.setcolors=[QtGui.QBrush(clr("blue")),QtGui.QBrush(clr("green")),QtGui.QBrush(clr("red")),QtGui.QBrush(clr("cyan")),QtGui.QBrush(clr("purple")),QtGui.QBrush(clr("orange")), QtGui.QBrush(clr("yellow")),QtGui.QBrush(clr("hotpink")),QtGui.QBrush(clr("gold"))]
		self.sets_visible={}
				
		#### read boxes
		if "boxes_3d" in info:
			box=info["boxes_3d"]
			for i,b in enumerate(box):
				#### X-center,Y-center,Z-center,method,[score,[class #]]
				bdf=[0,0,0,"manual",0.0, 0]
				for j,bi in enumerate(b):  bdf[j]=bi
				
				
				if bdf[5] not in list(self.sets.keys()):
					clsi=int(bdf[5])
					self.sets[clsi]="particles_{:02d}".format(clsi)
					self.boxsize[clsi]=64
				
				self.boxes.append(bdf)
		
		###### this is the new (2018-09) metadata standard..
		### now we use coordinates at full size from center of tomogram so it works for different binning and clipping
		### have to make it compatible with older versions though..
		if "apix_unbin" in info:
			self.apix_unbin=info["apix_unbin"]
			self.apix_cur=apix=data["apix_x"]
			for b in self.boxes:
				b[0]=b[0]/apix*self.apix_unbin+data["nx"]//2
				b[1]=b[1]/apix*self.apix_unbin+data["ny"]//2
				b[2]=b[2]/apix*self.apix_unbin+data["nz"]//2
				
			for k in self.boxsize.keys():
				self.boxsize[k]=int(np.round(self.boxsize[k]*self.apix_unbin/apix))
		else:
			self.apix_unbin=-1
			
		info.close()
		
		E2loadappwin("e2sptboxer","main",self)
		E2loadappwin("e2sptboxer","boxes",self.boxesviewer.qt_parent)
		E2loadappwin("e2sptboxer","option",self.optionviewer)
		
		#### particle classes
		if len(self.sets)==0:
			self.new_set("particles_00")
		self.sets_visible[list(self.sets.keys())[0]]=0
		self.currentset=sorted(self.sets.keys())[0]
		self.setspanel.update_sets()
		self.wboxsize.setValue(self.get_boxsize())

		#print(self.sets)
		for i in range(len(self.boxes)):
			self.update_box(i)
		
		self.update_all()
		self.initialized=True
Пример #11
0
class ParticlesWindow(object):
    """Image-matrix window showing the boxed particles of every RCT window.

    One particle list is registered per window in ``rctwidget.windowlist``
    (see ``addlist``).  ``update_particles`` interleaves the lists, so image
    ``k`` of the matrix belongs to window ``k % numlists`` and to particle
    pair ``k // numlists``.  Dragging or deleting an image edits the
    corresponding box(es) in the source window(s).
    """

    def __init__(self, rctwidget):
        self.rctwidget = rctwidget
        self.window = EMImageMXWidget(application=self.rctwidget.parent_window)
        self.window.set_display_values(["tilt", "PImg#"])
        self.window.set_mouse_mode("App")
        self.window.setWindowTitle("Particles")
        self.window.optimally_resize()

        self.listsofparts = []        # one [name, count, particles] entry per window
        self.numlists = 0
        self.totparts = []            # interleaved particles currently displayed
        self.moving_box_data = None   # [x, y, image index] at the start of a drag
        self.closed = False

        self.connect_signals()

    def addlist(self, name):
        """Register a new, initially empty particle list labelled *name*."""
        self.listsofparts.append([name, 0, []])
        self.numlists = len(self.listsofparts)

    def update_particles(self, particles, idx):
        """Replace list *idx* with *particles* and refresh the matrix display.

        Only as many pairs as the shortest list holds are displayed.  Each
        displayed image gets a "tilt" attribute (its list name) and a
        "PImg#" attribute (its pair index).
        """
        self.listsofparts[idx][1] = len(particles)
        self.listsofparts[idx][2] = particles

        # Number of complete pairs = length of the shortest list.
        listlength = min((lst[1] for lst in self.listsofparts), default=0)

        self.totparts = []
        for part in range(listlength):
            for name, _count, plist in self.listsofparts:
                ptcl = plist[part]
                ptcl.set_attr("tilt", name)
                ptcl.set_attr("PImg#", part)
                self.totparts.append(ptcl)

        if self.totparts:
            self.window.set_data(self.totparts)
            self.window.updateGL()

    def connect_signals(self):
        """Hook the matrix widget's mouse/close signals to this object's slots."""
        self.window.mx_image_selected.connect(self.box_selected)
        self.window.mx_mousedrag.connect(self.box_moved)
        self.window.mx_mouseup.connect(self.box_released)
        self.window.mx_boxdeleted.connect(self.box_image_deleted)
        self.window.module_closed.connect(self.on_module_closed)

    def box_selected(self, event, lc):
        """Remember the click position and the clicked image index."""
        if lc is None or lc[0] is None:
            # Click outside any image: drop the previous selection so that a
            # following drag does not move a stale box.
            self.moving_box_data = None
            return
        self.moving_box_data = [event.x(), event.y(), lc[0]]

    def box_moved(self, event, scale):
        """Move the selected particle's box in its source window while dragging."""
        if not self.moving_box_data:
            return
        startx, starty, imgidx = self.moving_box_data
        winidx = imgidx % self.numlists
        ppidx = int(imgidx) // self.numlists
        # Damped displacement; screen y grows downwards, image y upwards.
        dx = 0.2 * (event.x() - startx)
        dy = 0.2 * (starty - event.y())
        window = self.rctwidget.windowlist[winidx]
        window.boxes.move_box(ppidx, dx, dy)
        window.update_mainwin()
        window.update_particles()

    def box_released(self, event, lc):
        """Mouse-up slot; nothing to do."""
        pass

    def box_image_deleted(self, event, lc):
        """Delete the clicked particle pair from every window."""
        if lc is None or lc[0] is None:
            return

        ppidx = int(lc[0]) // self.numlists
        for window in self.rctwidget.windowlist:
            window.boxes.remove_box(ppidx, self.rctwidget.boxsize)
            window.update_mainwin()
            window.update_particles()

    def on_module_closed(self):
        """Store the window's app state via E2saveappwin when it is closed."""
        E2saveappwin("e2rctboxer", "particles", self.window.qt_parent)
Пример #12
0
def main():
    """Score every image of a stack/set for straight line features.

    Each image's Fourier transform is reduced to an azimuthal distribution
    (``calc_az_dist``) and scored as (max - mean) / std.  With ``--gui`` a
    histogram of all scores and a score-sorted subsample (every
    ``max(n // 500, 1)``-th image) are displayed.  With ``--newsets`` the
    input set is split by ``--threshold`` into
    ``<input without extension>_lines.lst`` and ``..._nolines.lst``.
    """
    usage = """e2findlines sets/img.lst
	
	** EXPERIMENTAL **
	this program looks for ~ straight line segments in images, such as wrinkles in graphene oxide films or possible C-film edges

	"""

    parser = EMArgumentParser(usage=usage, version=EMANVERSION)
    parser.add_argument("--threshold",
                        type=float,
                        help="Threshold for separating particles, default=3",
                        default=3.0)
    parser.add_argument("--newsets",
                        default=False,
                        action="store_true",
                        help="Split lines/nolines into 2 new sets")
    parser.add_argument("--gui",
                        default=False,
                        action="store_true",
                        help="show histogram of values")
    parser.add_argument(
        "--threads",
        default=4,
        type=int,
        help="Number of threads to run in parallel on the local computer")
    parser.add_argument(
        "--verbose",
        "-v",
        dest="verbose",
        action="store",
        metavar="n",
        type=int,
        default=0,
        help=
        "verbose level [0-9], higher number means higher level of verboseness")
    parser.add_argument(
        "--ppid",
        type=int,
        help="Set the PID of the parent process, used for cross platform PPID",
        default=-1)

    (options, args) = parser.parse_args()

    if len(args) < 1:
        parser.error("Please specify an input stack/set to operate on")

    E2n = E2init(sys.argv, options.ppid)

    options.threads += 1  # one extra thread for storing results

    # Only the header of the first image is needed to size the outer radius.
    im0 = EMData(args[0], 0, True)
    r2 = im0["ny"] / 4  # outer radius

    # We build up a list of 'Z scores' which should be larger for images containing
    # one or more parallel lines. If 2 lines aren't parallel the number may be lower,
    # even if the lines are strong, but should still be higher than images without
    # lines in most cases.
    n = EMUtil.get_image_count(args[0])
    step = max(n // 500, 1)  # subsampling step for the GUI image display
    Z = []
    im2d = []
    for i in range(n):
        im = EMData(args[0], i)
        d = np.array(im.do_fft().calc_az_dist(60, -88.5, 3, 4, r2))
        sd = d.std()
        # A flat azimuthal profile has no line signal; avoid 0/0 -> nan.
        z = (d.max() - d.mean()) / sd if sd > 0 else 0.0
        Z.append(z)
        # Images are only kept when they will actually be displayed.
        if options.gui and i % step == 0:
            im["zscore"] = z
            im2d.append(im)

    if options.gui:
        # GUI display of a histogram of the Z scores
        from eman2_gui.emhist import EMHistogramWidget
        from eman2_gui.emimagemx import EMImageMXWidget
        from eman2_gui.emapplication import EMApp
        app = EMApp()
        histw = EMHistogramWidget(application=app)
        histw.set_data(Z)
        app.show_specific(histw)
        imd = EMImageMXWidget(application=app)
        im2d.sort(key=lambda x: x["zscore"])
        imd.set_data(im2d)
        app.show_specific(imd)
        app.exec_()

    if options.newsets:
        base = args[0].rsplit(".", 1)[0]
        linesfsp = base + "_lines.lst"      # images with lines
        nolinesfsp = base + "_nolines.lst"  # images without lines

        # Start both outputs from scratch; a missing file is not an error.
        for fsp in (linesfsp, nolinesfsp):
            try:
                os.unlink(fsp)
            except OSError:
                pass

        lstin = LSXFile(args[0])
        lstlines = LSXFile(linesfsp)
        lstnolines = LSXFile(nolinesfsp)

        for i, z in enumerate(Z):
            if z > options.threshold:
                lstlines[-1] = lstin[i]
            else:
                lstnolines[-1] = lstin[i]

        # Close explicitly rather than relying on garbage collection.
        lstin.close()
        lstlines.close()
        lstnolines.close()

    E2end(E2n)
Пример #13
0
class EMDrawWindow(QtWidgets.QMainWindow):
    """Inspect landmark-based tilt-series alignment iteration by iteration.

    ``options.path`` must contain ``0_tomorecon_params.json`` and, for each
    iteration XX that is displayed, ``ali_XX.hdf``, ``landmarks_XX.txt`` and
    ``ptclali_XX.hdf``.  Only iterations that also have a ``loss_XX.txt``
    (XX in 0-4) appear in the iteration list.  The aligned tilt series is
    shown with the landmarks projected onto each image.  Clicking near a
    landmark lists its aligned particle images in a second window.
    """

    def __init__(self, application, options):
        QtWidgets.QMainWindow.__init__(self)

        self.options = options
        self.check_path(options.path)
        self.get_data(0)

        self.setCentralWidget(QtWidgets.QWidget())
        self.gbl = QtWidgets.QGridLayout(self.centralWidget())

        self.lb_name = QtWidgets.QLabel(self.tomoname)
        self.lb_name.setWordWrap(True)
        self.gbl.addWidget(self.lb_name, 0, 0, 1, 2)

        # One row per iteration that has a loss file, in iteration order.
        self.iterlst = QtWidgets.QListWidget()
        self.iterlst.itemflags = Qt.ItemFlags(Qt.ItemIsSelectable)
        for itr in sorted(self.losses):
            txt = "{:d}  :  loss = {:.1f}".format(itr, self.losses[itr])
            self.iterlst.addItem(QtWidgets.QListWidgetItem(txt))
        self.iterlst.currentRowChanged[int].connect(self.update_list)
        self.gbl.addWidget(self.iterlst, 1, 0, 1, 2)

        self.app = weakref.ref(application)

        # Tilt-series view; projected landmarks are drawn through self.boxes.
        self.imgview = EMImage2DWidget()
        self.boxes = Boxes(self.imgview, self.pks2d, self.dirs)
        self.shape_index = 0

        self.imgview.set_data(self.datafile)
        self.imgview.shapes = {0: self.boxes}
        self.imgview.show()
        self.imgview.mouseup.connect(self.on_mouseup)

        # Aligned particle images of the selected landmark.
        self.boxesviewer = EMImageMXWidget()
        self.boxesviewer.show()
        self.boxesviewer.set_mouse_mode("App")
        self.boxesviewer.setWindowTitle("Landmarks")
        self.boxesviewer.rzonce = True

    def check_path(self, path):
        """Read the tomogram name and per-iteration mean losses from *path*.

        Sets ``self.tomoname`` and ``self.losses``, a dict mapping iteration
        (0-4, only those with a ``loss_XX.txt``) to the mean of that file's
        second column.
        """
        js = js_open_dict(os.path.join(path, "0_tomorecon_params.json"))
        try:
            self.tomoname = base_name(js["inputname"])
        finally:
            js.close()

        self.losses = {}
        for itr in range(5):
            fname = os.path.join(path, "loss_{:02d}.txt".format(itr))
            if os.path.isfile(fname):
                # ndmin=2 keeps a single-row file indexable as [:, 1].
                loss = np.loadtxt(fname, ndmin=2)
                self.losses[itr] = np.mean(loss[:, 1])

    def update_list(self, row):
        """Display the iteration shown in row *row* of the iteration list."""
        if row < 0:
            # currentRowChanged reports -1 when there is no current row.
            return
        idx = sorted(self.losses)[row]
        if idx == self.cur_iter:
            return
        print("Showing iteration {:d}...".format(idx))

        self.get_data(idx)
        self.boxes.pks2d = self.pks2d
        self.boxes.dirs = self.dirs
        self.boxes.selid = -1  # no landmark selected in the new iteration
        self.shape_index = 0

        self.imgview.set_data(self.datafile)
        self.imgview.shapes = {0: self.boxes}
        self.imgview.shapechange = 1
        self.imgview.updateGL()
        self.boxesviewer.set_data([])
        self.boxesviewer.update()

    def select_landmark(self, x, y):
        """Select the landmark nearest (x, y) in the displayed image and show its particles.

        Nothing happens when there are no landmarks or when the nearest one
        is more than 64 (RMS over x and y offsets) away.
        """
        nid = self.imgview.list_idx
        pks = self.pks2d[nid]
        if len(pks) == 0:
            return
        # RMS over the x and y offsets, i.e. Euclidean distance / sqrt(2).
        d = np.sqrt(np.mean((pks - [x, y]) ** 2, axis=1))
        if np.min(d) > 64:
            return
        pid = np.argmin(d)
        print("Selecting landmark #{}...".format(pid))
        self.boxes.selid = pid
        self.imgview.shapechange = 1
        self.imgview.updateGL()

        aname = os.path.join(self.options.path,
                             "ptclali_{:02d}.hdf".format(self.cur_iter))
        self.ptcls = EMData.read_images(aname)
        self.boxesviewer.set_data([p for p in self.ptcls if p["pid"] == pid])
        self.boxesviewer.update()

    def on_mouseup(self, event):
        """Select the landmark under the released mouse position."""
        x, y = self.imgview.scr_to_img((event.x(), event.y()))
        self.select_landmark(x, y)

    def get_data(self, itr):
        """Load iteration *itr* from ``options.path``.

        Sets ``self.datafile`` (the aligned tilt series stacked into one
        nx x ny x nimg EMData), ``self.xfs`` (per-image "xform.projection"),
        ``self.pks`` (landmark coordinates from landmarks_XX.txt),
        ``self.pks2d`` (landmarks projected into each image, shape
        (nimg, nlandmarks, 2)), ``self.dirs`` and ``self.cur_iter``.
        """
        path = self.options.path
        fname = os.path.join(path, "ali_{:02d}.hdf".format(itr))
        pname = os.path.join(path, "landmarks_{:02d}.txt".format(itr))
        # ndmin=2 keeps a single-landmark file shaped (1, ncol).
        pks = np.loadtxt(pname, ndmin=2)
        imgs = EMData.read_images(fname)
        img = EMData(imgs[0]["nx"], imgs[0]["ny"], len(imgs))
        xfs = []
        for i, m in enumerate(imgs):
            img.insert_clip(m, (0, 0, i))
            xfs.append(m["xform.projection"])

        # Per (image, landmark) pair: the first two components of the
        # particle's "score" header value, rescaled by the pixel-size ratio
        # of the particle file to the tilt series.
        aname = os.path.join(path, "ptclali_{:02d}.hdf".format(itr))
        n = EMUtil.get_image_count(aname)
        dirs = np.zeros((len(imgs), len(pks), 2))
        hdr = None
        for i in range(n):
            hdr = EMData(aname, i, True)  # header-only read
            s = hdr["score"]
            dirs[hdr["nid"], hdr["pid"]] = [s[0], s[1]]
        if hdr is not None:
            dirs = dirs * hdr["apix_x"] / imgs[0]["apix_x"]
        self.dirs = dirs

        dx = img["nx"] // 2
        dy = img["ny"] // 2
        pks2d = []
        for xf in xfs:
            p2d = []
            for p in pks:
                ptx = xf.transform([p[0], p[1], p[2]])
                # NOTE(review): hard-coded factor 2, presumably the binning of
                # ali_XX.hdf relative to the landmark coordinates -- confirm.
                ptx = [c / 2 for c in ptx]
                p2d.append([ptx[0] + dx, ptx[1] + dy])
            pks2d.append(p2d)
        self.pks2d = np.array(pks2d)

        self.xfs = xfs
        self.pks = pks
        self.datafile = img
        self.cur_iter = itr

    def closeEvent(self, event):
        """Close the auxiliary viewer windows together with the main window."""
        self.imgview.close()
        self.boxesviewer.close()
Пример #14
0
    def setData(self, data):
        """Set the image(s) to be filtered and rebuild the viewer window.

        *data* may be:
          * None: sets ``self.data`` to None and returns immediately;
          * a filename: for a 3-D file the first volume is used; for a 2-D
            stack all images are used, or a regularly spaced subset of about
            20 images when the stack holds more than 20;
          * a single EMData, or a sequence of EMData.
        Afterwards nx/ny/nz, apix (when not already positive), the radial
        spectrum of the first image and the viewer window are updated.
        """
        if data is None:
            self.data = None
            return

        if isinstance(data, str):
            self.datafile = data
            self.nimg = EMUtil.get_image_count(data)
            first = EMData(data, 0)
            if first["nz"] == 1 and self.nimg > 20:
                # read regularly separated images from the file totalling ~20
                self.origdata = EMData.read_images(
                    data, list(range(0, self.nimg, self.nimg // 20)))
            elif first["nz"] == 1 and self.nimg > 1:
                self.origdata = EMData.read_images(data, list(range(self.nimg)))
            else:
                self.origdata = [first]
        else:
            self.datafile = None
            self.origdata = [data] if isinstance(data, EMData) else data

        ref = self.origdata[0]
        self.nx = ref["nx"]
        self.ny = ref["ny"]
        self.nz = ref["nz"]
        if self.apix <= 0.0:
            self.apix = ref["apix_x"]
        EMProcessorWidget.parmdefault["apix"] = (0, (0.2, 10.0), self.apix,
                                                 None)

        # Radial spectrum of the unfiltered first image and its frequency axis.
        origfft = ref.do_fft()
        self.pspecorig = origfft.calc_radial_dist(self.ny // 2, 0.0, 1.0, 1)
        ds = 1.0 / (self.apix * self.ny)
        self.pspecs = [ds * i for i in range(len(self.pspecorig))]

        if self.viewer is not None:
            for v in self.viewer:
                v.close()

        if self.nz == 1 or self.force2d:
            if len(self.origdata) > 1:
                self.viewer = [EMImageMXWidget()]
                self.mfile_save_stack.setEnabled(True)
                self.mfile_save_map.setEnabled(False)
            else:
                self.viewer = [EMImage2DWidget()]
                self.mfile_save_stack.setEnabled(False)
                self.mfile_save_map.setEnabled(True)
        else:
            self.mfile_save_stack.setEnabled(False)
            self.mfile_save_map.setEnabled(True)
            self.viewer = [EMScene3D()]
            self.sgdata = EMDataItem3D(test_image_3d(3), transform=Transform())
            self.viewer[0].insertNewNode('Data',
                                         self.sgdata,
                                         parentnode=self.viewer[0])
            isosurface = EMIsosurface(self.sgdata, transform=Transform())
            self.viewer[0].insertNewNode("Iso",
                                         isosurface,
                                         parentnode=self.sgdata)
            volslice = EMSliceItem3D(self.sgdata, transform=Transform())
            self.viewer[0].insertNewNode("Slice",
                                         volslice,
                                         parentnode=self.sgdata)

        E2loadappwin("e2filtertool", "image", self.viewer[0].qt_parent)

        self.procChange(-1)
Пример #15
0
class EMCmpExplorer(EM3DSymModel):
	"""Explore how one particle compares with every projection ("SimmxXplor").

	Projections are placed by their "xform.projection" orientations and are
	handed to the EM3DSymModel view together with a per-projection "cmp"
	value.  After a particle is selected every projection is aligned to it
	and "cmp" is set to the negated comparator score.  Picking a projection
	shows the aligned projection, the particle, a filtered projection and
	their difference in an image-matrix window.
	"""
	def __init__(self, gl_widget, projection_file=None,simmx_file=None,particle_file=None):
		self.init_lock = True # a lock indicating that we are still in the __init__ function
		self.au_data = None # This will be a dictionary, keys will be refinement directories, values will be something like available iterations for visual study
		self.ptcl_data = None # particle images, loaded by set_shrink
		self.proj_data = None # projection images, loaded by set_shrink
		EM3DSymModel.__init__(self,gl_widget)
		self.window_title = "SimmxXplor"

		self.projection_file = projection_file 	# a projection file produced by e2project3d
		self.particle_file = particle_file 		# A file containing particles to be examined

		self.current_particle = -1 				# keep track of the current particle
		self.current_projection = None 			# keep track of the current projection

		self.ptcl_display = None				# display all particles
		self.mx_display = None 					# mx display module for displaying projection and aligned particle
		self.lay = None 						# 2d plot for displaying comparison between particle and projection

		self.simcmp = "dot:normalize=1"
		self.align = "rotate_translate_flip"
		self.aligncmp = "dot"
		self.refine = "refine"
		self.refinecmp = "dot:normalize=1"
		self.shrink=1

	def set_data(self,projections,particles):
		'''
		Set the projection and particle files and load them.

		Raises RuntimeError naming the missing file if either does not exist.
		'''
		if not file_exists(projections): raise RuntimeError("%s does not exist" %projections)
		if not file_exists(particles): raise RuntimeError("%s does not exist" %particles)

		self.projection_file = projections 	# a projection file produced by e2project3d
		self.particle_file = particles 		# A file containing particles to be examined
		self.set_shrink(self.shrink)

	def _prepare_images(self,images):
		"""Mean-shrink (when self.shrink>1) and edge-mean normalise *images* in place."""
		for im in images:
			if self.shrink>1 : im.process_inplace("math.meanshrink",{"n":self.shrink})
			im.process_inplace("normalize.edgemean",{})

	def set_shrink(self,shrink):
		"""This actually loads the data, downsampled by *shrink*.

		At most the first 800 particles are read.  Particles go to the particle
		display; projections, with "cmp" reset to 0, go to the 3-D view.
		"""
		self.shrink=shrink

		# Deal with particles
		n=min(EMUtil.get_image_count(self.particle_file),800)
		self.ptcl_data=[i for i in EMData.read_images(self.particle_file,list(range(n))) if i is not None]
		self._prepare_images(self.ptcl_data)

		if self.ptcl_display is None :
			self.ptcl_display = EMImageMXWidget()
			self.ptcl_display.set_mouse_mode("App")
			self.ptcl_display.mx_image_selected.connect(self.ptcl_selected)
			self.ptcl_display.module_closed.connect(self.on_ptcl_display_closed)
		self.ptcl_display.set_data(self.ptcl_data)

		# deal with projections
		self.proj_data=EMData.read_images(self.projection_file)
		self._prepare_images(self.proj_data)

		eulers = [i["xform.projection"] for i in self.proj_data]
		self.specify_eulers(eulers)

		for i in self.proj_data : i["cmp"]=0.0
		self.set_emdata_list_as_data(self.proj_data,"cmp")

		# The freshly loaded projections carry no "ptcl.align2d", so any
		# previous particle selection must be made again before it is used.
		self.current_particle = -1

	def get_num_particles(self):
		"""Return the number of loaded particles (0 before data is loaded)."""
		if self.ptcl_data is None : return 0
		return len(self.ptcl_data)

	def render(self):
		"""Create the inspector on first use, then draw the 3-D view."""
		if self.inspector is None: self.get_inspector()

		EM3DSymModel.render(self)

	def object_picked(self,object_number):
		"""Show the comparison display for projection *object_number* and highlight it."""
		if object_number == self.current_projection: return
		self.current_projection = object_number
		resize_necessary = False
		if self.mx_display is None:
			self.mx_display = EMImageMXWidget()
			self.mx_display.module_closed.connect(self.on_mx_display_closed)
			resize_necessary = True

		self.update_display(False)

		if resize_necessary:
			get_application().show_specific(self.mx_display)
			self.mx_display.optimally_resize()
		else:
			self.mx_display.updateGL()

		if object_number != self.special_euler:
			self.special_euler = object_number
			self.regen_dl()

	def update_display(self,update=True):
		'''
		Uses self.current_particle and self.current_projection to update self.mx_display.
		'''
		if self.mx_display is None : return

		if self.current_particle<0 or self.current_projection is None : return

		dlist=[]
		dlist.append(self.proj_data[self.current_projection].copy())	# aligned projection
		dlist[0].transform(dlist[0]["ptcl.align2d"])
		tmp=dlist[0].process("threshold.notzero")						# mask of the aligned projection's nonzero pixels
		dlist.append(self.ptcl_data[self.current_particle].copy())		# original particle
		dlist[1].process_inplace("normalize.toimage",{"to":dlist[0]})
		dlist.append(self.proj_data[self.current_projection].copy())	# filtered projection
		dlist[2].process_inplace("filter.matchto",{"to":dlist[1]})
		dlist[2].mult(tmp)
		dlist[2].process_inplace("normalize.toimage",{"to":dlist[1]})
		dlist.append(dlist[2].copy())									# particle with projection subtracted
		dlist[3].sub(dlist[1])

		self.mx_display.set_data(dlist)

		if update: self.mx_display.updateGL()

	def on_mx_display_closed(self):
		"""Forget the comparison display once its window has been closed."""
		self.mx_display = None

	def on_ptcl_display_closed(self):
		"""Forget the particle display once its window has been closed."""
		self.ptcl_display = None

	def get_inspector(self):
		"""Return the inspector, creating it on first use."""
		if not self.inspector :
			self.inspector=EMSimmxXplorInspector(self)
		return self.inspector

	def ptcl_selected(self,event,lc):
		"""slot for image selection events from image mx"""
		if lc is None or lc[0] is None : return
		self.set_ptcl_idx(lc[0])

	def set_alignment(self,align,aligncmp,refine,refinecmp):
		"""sets alignment algorithms and recomputes"""
		self.align = str(align)
		self.aligncmp = str(aligncmp)
		self.refine = str(refine)
		self.refinecmp = str(refinecmp)

		self.update_align()

	def set_ptcl_idx(self,idx):
		"""Select the index of the current particle to use for comparisons"""
		if self.current_particle != idx:
			self.current_particle = idx
			self.update_align()

	def update_align(self):
		"""Align every projection to the current particle, then recompute similarities.

		Each projection's "ptcl.align2d" receives the resulting xform.align2d.
		On an alignment error a warning is shown and the update stops.
		"""
		if self.current_particle<0 : return
		ptcl=self.ptcl_data[self.current_particle]

		progress = QtWidgets.QProgressDialog("Computing alignments", "Abort", 0, len(self.proj_data),None)
		progress.show()
		# redetermines particle alignments
		# then we can quickly compute a series of different similarity values
		aopt=parsemodopt(self.align)
		acmp=parsemodopt(self.aligncmp)
		ropt=parsemodopt(self.refine)
		rcmp=parsemodopt(self.refinecmp)
		for i,p in enumerate(self.proj_data):
			try:
				ali=p.align(aopt[0],ptcl,aopt[1],acmp[0],acmp[1])
				if self.refine!="" :
					ropt[1]["xform.align2d"]=ali["xform.align2d"]
					ali=p.align(ropt[0],ptcl,ropt[1],rcmp[0],rcmp[1])
			except Exception:
				traceback.print_exc()
				QtWidgets.QMessageBox.warning(None,"Error","Problem with alignment parameters")
				progress.close()
				return
			p["ptcl.align2d"]=ali["xform.align2d"]
			progress.setValue(i)
			QtCore.QCoreApplication.instance().processEvents()

		progress.close()
		self.update_cmp()

	def set_cmp(self,cmpstring):
		"""Select the comparator. Passed as a standard name:attr=value:attr=value string"""
		self.simcmp=str(cmpstring)
		self.update_cmp()

	def update_cmp(self):
		"""Recompute each projection's "cmp" against the current particle and redraw.

		"cmp" is the negated comparator score between the particle and the
		projection transformed by its "ptcl.align2d".  Does nothing until a
		particle is selected, because "ptcl.align2d" is only set by update_align.
		"""
		if self.current_particle<0 : return
		cmpopt=parsemodopt(self.simcmp)

		progress = QtWidgets.QProgressDialog("Computing similarities", "Abort", 0, len(self.proj_data),None)
		progress.show()
		ptcl=self.ptcl_data[self.current_particle]
		for i,p in enumerate(self.proj_data):
			ali=p.copy()
			ali.transform(p["ptcl.align2d"])
			try : p["cmp"]=-ptcl.cmp(cmpopt[0],ali,cmpopt[1])
			except Exception:
				traceback.print_exc()
				QtWidgets.QMessageBox.warning(None,"Error","Invalid similarity metric string, or other comparison error")
				progress.close()
				return
			progress.setValue(i)
			QtCore.QCoreApplication.instance().processEvents()

		progress.close()
		self.set_emdata_list_as_data(self.proj_data,"cmp")
		EM3DSymModel.render(self)
Пример #16
0
class EMPtclClassify(QtWidgets.QMainWindow):
    """Window for interactively flagging bad particles with a small network.

    The user marks particles in the viewer's "bad_particles" set and trains
    the network (``train_nnet``) to separate them from presumed-good
    particles. The viewer is then re-sorted by network score, ascending.
    ``save_set`` writes every particle except the ``PtclThresh``
    lowest-scoring ones to ``<setname>_good.lst``.

    Once the viewer has been re-sorted, display positions and original
    particle indices differ. ``self.sortidx[pos]`` maps a display position
    to the original particle index.
    """

    def __init__(self, application, options, datafile=None):
        # Initialise through QMainWindow itself (not QWidget) so the
        # main-window part of the object exists before setCentralWidget().
        QtWidgets.QMainWindow.__init__(self)
        self.setMinimumSize(150, 100)
        self.setCentralWidget(QtWidgets.QWidget())
        self.gbl = QtWidgets.QGridLayout(self.centralWidget())

        self.bt_train = QtWidgets.QPushButton("Train")
        self.bt_train.setToolTip("Train neural network")
        self.gbl.addWidget(self.bt_train, 0, 0, 1, 2)

        self.bt_save = QtWidgets.QPushButton("Save")
        self.bt_save.setToolTip("Save particle set")
        self.gbl.addWidget(self.bt_save, 1, 0, 1, 2)

        self.bt_load = QtWidgets.QPushButton("Load")
        self.bt_load.setToolTip("Load neural network")
        self.gbl.addWidget(self.bt_load, 2, 0, 1, 2)

        self.bt_load.clicked[bool].connect(self.load_nnet)
        self.bt_train.clicked[bool].connect(self.train_nnet)
        self.bt_save.clicked[bool].connect(self.save_set)

        self.val_learnrate = TextBox("LearnRate", 1e-4)
        self.gbl.addWidget(self.val_learnrate, 0, 2, 1, 1)

        # Default threshold: discard the (1 - keep) fraction of the particles.
        nptcl = EMUtil.get_image_count(options.setname)
        self.val_ptclthr = TextBox("PtclThresh",
                                   int(nptcl * (1 - options.keep)))
        self.gbl.addWidget(self.val_ptclthr, 1, 2, 1, 1)

        self.val_niter = TextBox("Niter", 10)
        self.gbl.addWidget(self.val_niter, 2, 2, 1, 1)

        self.options = options
        self.app = weakref.ref(application)
        self.nnet = None
        self.trainset = []

        self.particles = EMData.read_images(options.setname)
        self.boxsz = self.particles[0]["nx"]

        self.ptclviewer = EMImageMXWidget()
        self.ptclviewer.setWindowTitle("Particles")
        # Identity mapping until the first network-based sort.
        self.sortidx = list(range(len(self.particles)))

        for boxview in [self.ptclviewer]:
            boxview.usetexture = False
            boxview.show()
            boxview.set_mouse_mode("Sets")
            boxview.rzonce = True

        self.ptclviewer.set_data(self.particles)
        self.ptclviewer.sets["bad_particles"] = set()
        self.ptclviewer.update()
        # Image stack in ORIGINAL particle order; scored after every training.
        self.ptclimg = get_image(self.particles, len(self.particles))
        global tf
        tf = import_tensorflow(options.gpuid)

    def train_nnet(self):
        """Optionally train the network, then re-sort the viewer by score.

        If Niter > 0, the network is trained on the marked "bad_particles"
        (label 0) versus up to 512 randomly chosen unmarked particles
        displayed at or after position PtclThresh (label 1). Afterwards every
        particle is scored. The viewer is re-ordered by ascending score, and
        the "bad_particles" marks are carried over to the new positions.
        """
        if self.nnet is None:
            self.nnet = NNet(self.boxsz)

        if int(self.val_niter.getval()) > 0:
            print("Preparing training set...")
            sets = self.ptclviewer.sets
            if "bad_particles" not in sets or len(sets["bad_particles"]) == 0:
                print("No references.")
                return

            # bids and gids are both DISPLAY positions, i.e. indices into the
            # (possibly re-sorted) viewer data.
            bids = sets["bad_particles"]
            ptcldata = self.ptclviewer.data
            badrefs = [ptcldata[i] for i in bids]
            thr = int(self.val_ptclthr.getval())
            gids = [i for i in range(thr, len(self.particles)) if i not in bids]
            np.random.shuffle(gids)
            nsample = 512
            gids = gids[:nsample]
            goodrefs = [ptcldata[i] for i in gids]
            gimgs = get_image(goodrefs, nsample)
            bimgs = get_image(badrefs, nsample)

            imgs = np.concatenate([gimgs, bimgs], axis=0)

            print(imgs.shape)
            # good particles first (label 1), bad ones after (label 0)
            labs = np.zeros(nsample * 2, dtype=np.float32)
            labs[:nsample] = 1.0

            self.trainset = (imgs, labs)

            dataset = tf.data.Dataset.from_tensor_slices(self.trainset)
            dataset = dataset.shuffle(500).batch(64)
            self.nnet.do_training(
                dataset,
                learnrate=self.val_learnrate.getval(),
                niter=int(self.val_niter.getval()),
            )

        # Convert the marked display positions to original particle indices
        # before the display order changes.
        bids = [self.sortidx[i] for i in self.ptclviewer.sets["bad_particles"]]
        score = self.nnet.apply_network(self.ptclimg).flatten()
        sid = np.argsort(score).tolist()
        self.sortidx = sid
        self.ptclviewer.set_data([self.particles[i] for i in sid])
        # Inverse permutation: original particle index -> new display position.
        newpos = {orig: pos for pos, orig in enumerate(sid)}
        idx = [newpos[i] for i in bids]
        self.ptclviewer.enable_set("bad_particles", idx, update=True)
        self.ptclviewer.commit_sets()
        self.ptclviewer.set_mouse_mode("Sets")
        self.ptclviewer.update()

    def save_set(self):
        """Write all but the PtclThresh lowest-scoring particles to ``<setname>_good.lst``.

        Any existing output file is removed first. Entries are copied from
        the input list file in their original order.
        """
        fname = self.options.setname
        oname = os.path.splitext(fname)[0] + "_good.lst"
        thr = int(self.val_ptclthr.getval())
        badi = set(self.sortidx[:thr])  # original indices to drop
        if os.path.isfile(oname):
            os.remove(oname)
        lst = LSXFile(fname, True)
        lout = LSXFile(oname, False)
        nwritten = 0
        for i in range(lst.n):
            if i in badi:
                continue
            l = lst.read(i)
            lout.write(-1, l[0], l[1], l[2])
            nwritten += 1

        # drop the references so the list files are released
        lst = lout = None
        print("{} particles written to {}".format(nwritten, oname))

    def load_nnet(self):
        """Load network weights from "nnet_classifycnn.h5" in the current directory."""
        if self.nnet is None:
            self.nnet = NNet(self.boxsz)

        self.nnet.model = tf.keras.models.load_model("nnet_classifycnn.h5",
                                                     compile=False)

    def closeEvent(self, event):
        """Close the particle viewer together with this window."""
        for b in [self.ptclviewer]:
            b.close()
Пример #17
0
    def on_mx_image_selected(self, event, lc):
        """Show the particles of the current class average in the particle viewer.

        Connected to the class/projection viewer's ``mx_image_selected``
        signal, and also called as ``on_mx_image_selected(None, None)`` to
        refresh. When ``lc`` is given, its first element becomes ``self.sel``.

        The particles listed in the average's "class_ptcl_idxs" and
        "exc_class_ptcl_idxs" header values are displayed and marked as the
        "Included" and "Excluded" sets. Unless ``self.sel`` is 0 or no
        alignment file is set, each particle is shown with its stored 2-D
        alignment (dx, dy, da, dflip for the matching class column) applied.

        The busy cursor set on entry is replaced by the arrow cursor on every
        exit path, including when no average is selected.
        """
        get_application().setOverrideCursor(Qt.BusyCursor)
        try:
            if lc is not None:
                self.sel = lc[0]
            if self.average is None:
                return

            from eman2_gui.emimagemx import ApplyAttribute

            included = []
            if self.average.has_attr("class_ptcl_idxs"):
                included = self.average["class_ptcl_idxs"]
            excluded = []
            if self.average.has_attr("exc_class_ptcl_idxs"):
                excluded = self.average["exc_class_ptcl_idxs"]

            # One cache entry per particle: [file, particle index, display ops].
            # Included particles come first, so display positions
            # 0..len(included)-1 are "Included" and the remainder "Excluded".
            bdata = [[self.particle_file, val, [ApplyAttribute("Img #", val)]]
                     for val in included]
            idx_included = list(range(len(bdata)))
            bdata += [[self.particle_file, val, [ApplyAttribute("Img #", val)]]
                      for val in excluded]
            idx_excluded = list(range(len(idx_included), len(bdata)))

            data = EMLightWeightParticleCache(bdata)

            first = self.particle_viewer is None
            if first:
                self.particle_viewer = EMImageMXWidget(
                    data=None, application=get_application())
                self.particle_viewer.set_mouse_mode("App")
                self.particle_viewer.module_closed.connect(
                    self.on_particle_mx_view_closed)
                self.particle_viewer.mx_image_selected.connect(
                    self.particle_selected)
                get_application().show_specific(self.particle_viewer)

            self.check_images_in_memory()

            if self.sel != 0 and self.alignment_file is not None:
                from eman2_gui.emimagemx import ApplyTransform
                for _name, idx, ops in bdata:
                    index = -1
                    if self.classes.get_xsize() == 1:
                        # just assume it's the first one - potentially fatal,
                        # but only in obscure situations
                        index = 0
                    else:
                        for j in range(self.classes.get_xsize()):
                            if int(self.classes.get(j, idx)) == self.class_idx:
                                index = j
                                break
                    if index == -1:
                        print("couldn't find")
                        return

                    x = self.dx.get(index, idx)
                    y = self.dy.get(index, idx)
                    a = self.da.get(index, idx)
                    m = self.dflip.get(index, idx)

                    t = Transform({"type": "2d", "alpha": a, "mirror": int(m)})
                    t.set_trans(x, y)
                    # ops is the list held inside bdata, so the cache sees it
                    ops.append(ApplyTransform(t))
            self.particle_viewer.set_data(data)

            if first:
                self.particle_viewer.updateGL()
                self.particle_viewer.optimally_resize()

            self.particle_viewer.clear_sets(False)
            self.particle_viewer.enable_set("Excluded", idx_excluded, True,
                                            False)
            self.particle_viewer.enable_set("Included", idx_included, False,
                                            False)
            self.particle_viewer.updateGL()

            self.updateGL()
        finally:
            get_application().setOverrideCursor(Qt.ArrowCursor)
Пример #18
0
    def au_point_selected(self, i, event=None):
        """React to the user picking Euler point ``i`` in the 3-D view.

        Does nothing when the explorer was built with ``read_from`` data.
        ``i`` is None when the click hit nothing. A shift-click on empty
        space then clears ``self.special_euler`` (and regenerates the display
        lists outside __init__).

        Otherwise ``self.euler_data[i]`` becomes ``self.average``. Its
        reference projection (header "projection_image_idx") is read from
        ``self.projection_file``, and the average is normalized to it. Three
        images are shown in the class viewers, which are created on first
        use: the projection, the average masked by the projection's support
        and filtered to match it, and the average itself. An open particle
        viewer is refreshed as well.
        """
        if self.readfrom:
            return
        if i is None:
            if event is not None and event.modifiers() & Qt.ShiftModifier:
                if self.special_euler is not None:
                    self.special_euler = None
                    if not self.init_lock:
                        self.regen_dl()
            return

        self.projection = None
        if not self.euler_data:
            return
        self.average = self.euler_data[i]
        proj_idx = self.average.get_attr("projection_image_idx")
        self.projection = EMData(self.projection_file, proj_idx)
        self.average.process_inplace("normalize.toimage",
                                     {"to": self.projection})
        try:
            self.class_idx = proj_idx
            print("%d (%d)" % (self.class_idx, self.average["ptcl_repr"]))
        except Exception:
            self.class_idx = -1

        first = self.proj_class_viewer is None
        if first:
            app = get_application()
            self.proj_class_viewer = EMImageMXWidget(data=None,
                                                     application=app)
            self.proj_class_viewer.module_closed.connect(
                self.on_mx_view_closed)
            self.proj_class_viewer.mx_image_selected.connect(
                self.on_mx_image_selected)
            app.show_specific(self.proj_class_viewer)

            self.proj_class_single = EMImage2DWidget(image=None,
                                                     application=app)
            self.proj_class_single.module_closed.connect(
                self.on_mx_view_closed)
            app.show_specific(self.proj_class_single)

        disp = []
        if self.projection is not None:
            disp.append(self.projection)
        if self.average is not None and self.projection is not None:
            for key in ("apix_x", "apix_y", "apix_z"):
                self.projection[key] = self.average[key]
            # Mask the average by the projection's non-zero region, then
            # filter it to match the projection. Presumably threshold.notzero
            # yields a 0/1 mask -- TODO confirm.
            filt = self.projection.process("threshold.notzero")
            filt.mult(self.average)
            filt.process_inplace("filter.matchto", {"to": self.projection})
            disp.append(filt)

        if self.average is not None:
            disp.append(self.average)

        self.proj_class_viewer.set_data(disp)
        self.proj_class_single.set_data(disp)

        self.proj_class_viewer.updateGL()
        self.proj_class_single.updateGL()
        if self.particle_viewer is not None:
            self.on_mx_image_selected(None, None)
        if first:
            self.proj_class_viewer.optimally_resize()

        if i != self.special_euler:
            self.special_euler = i
            self.force_update = True

        if not self.init_lock:
            self.updateGL()
Пример #19
0
class EMEulerExplorer(EM3DSymModel, Animator):
    point_selected = QtCore.pyqtSignal(int, QEvent)

    def mousePressEvent(self, event):
        """Start a pick in "inspect" mode, otherwise defer to EM3DSymModel.

        In "inspect" mode the object under the cursor is remembered in
        ``self.current_hit``. Only when nothing was hit does the press reach
        the base class (so the view can be rotated).
        """
        if self.events_mode == "inspect":
            self.current_hit = self.get_hit(event)
            if self.current_hit is not None:
                return
        EM3DSymModel.mousePressEvent(self, event)

    def mouseReleaseEvent(self, event):
        """Finish a pick in "inspect" mode, emitting point_selected on a click.

        A pick counts only if the release lands on the same object as the
        press. Otherwise the release goes to EM3DModel.mouseReleaseEvent,
        deliberately bypassing EM3DSymModel, whose release behaviour is not
        wanted here.
        """
        if self.events_mode != "inspect":
            EM3DModel.mouseReleaseEvent(self, event)
            return

        if self.current_hit is None:
            EM3DModel.mouseReleaseEvent(self, event)
        else:
            # A redraw is needed before picking works; the rendering is cheap.
            self.updateGL()
            released_on = self.get_hit(event)
            if released_on == self.current_hit:
                self.point_selected.emit(self.current_hit, event)
        self.current_hit = None

    def mouseMoveEvent(self, event):
        """Forward mouse motion to EM3DSymModel unless a pick is in progress.

        In "inspect" mode, a press that hit an object stores its index in
        ``self.current_hit``; motion is then swallowed so the view does not
        rotate. The test is ``is not None`` because index 0 is a valid hit
        and must not be treated as "no hit".
        """
        if self.events_mode == "inspect" and self.current_hit is not None:
            return
        EM3DSymModel.mouseMoveEvent(self, event)

    def get_hit(self, event):
        """Return the index of the object under the mouse, or None.

        Uses legacy OpenGL selection mode (GL_SELECT) with a 5x5 pixel pick
        region at the event position, re-rendering the scene in selection
        mode. A record whose first name is ``k`` maps to index ``k - 1``.
        Selection mode performs no depth test, so when several objects fall
        inside the pick region, the record with the smallest near depth (the
        one closest to the viewer) is returned.
        """
        viewport = self.vdtools.wview.tolist()
        self.get_gl_widget().makeCurrent()  # prevents a stack underflow

        glSelectBuffer(512)
        glRenderMode(GL_SELECT)
        glInitNames()
        glMatrixMode(GL_PROJECTION)
        glPushMatrix()
        glLoadIdentity()
        # Qt's y axis points down, GL's up; viewport[-1] is presumably the
        # viewport height ([x, y, w, h]) -- TODO confirm.
        gluPickMatrix(event.x(), viewport[-1] - event.y(), 5, 5, viewport)
        self.get_gl_widget().load_perspective()
        glMatrixMode(GL_MODELVIEW)
        glInitNames()
        self.render()
        glMatrixMode(GL_PROJECTION)
        glPopMatrix()
        glMatrixMode(GL_MODELVIEW)
        glFlush()

        nearest = None
        nearest_depth = None
        for near, _far, names in list(glRenderMode(GL_RENDER)):
            if len(names) == 0:
                continue
            if nearest_depth is None or near < nearest_depth:
                nearest, nearest_depth = names[0] - 1, near
        return nearest

    def keyPressEvent(self, event):
        """Handle F1 (web help) and F (toggle flattening); pass other keys on.

        F toggles ``self.flatten`` between 0.0 and 1.0, then rebuilds the
        display list and redraws.
        """
        key = event.key()
        if key == Qt.Key_F1:
            self.display_web_help(
                "http://blake.bcm.edu/emanwiki/EMAN2/Programs/e2eulerxplor")
            return
        if key == Qt.Key_F:
            self.flatten = 0.0 if self.flatten > 0 else 1.0
            self.generate_current_display_list(True)
            self.updateGL()
            return
        EM3DSymModel.keyPressEvent(self, event)

    def __init__(self,
                 gl_widget=None,
                 auto=True,
                 sparse_mode=False,
                 file_name="",
                 read_from=None):
        """Create the Euler explorer model.

        gl_widget   -- passed to EM3DSymModel.__init__.
        auto        -- if True, scan the current directory for refinement
                       directories (gen_refinement_data). The lexicographically
                       last directory is then loaded.
        sparse_mode -- if True, drawn Eulers are also rendered on the opposite
                       side of the sphere (self.mirror_eulers).
        file_name   -- Euler file name passed to EM3DSymModel as eulerfilename.
        read_from   -- optional sequence of class-average files to display
                       directly. When non-empty it disables ``auto`` and
                       ``file_name``. None or an empty sequence means unused.
        """
        self.current_hit = None
        self.events_mode_list = ["navigate", "inspect"]
        self.events_mode = self.events_mode_list[1]

        self.init_lock = True  # True while still inside __init__; suppresses redraws
        self.au_data = None  # dict: refinement directory -> list of per-iteration file-name entries
        # None and an empty sequence both mean "not reading from files"
        # (len() must not be applied to the default None).
        if not read_from:
            read_from = None
        self.readfrom = read_from
        if self.readfrom:
            file_name = ""
            auto = False
            datalst = []
            for rr in self.readfrom:
                datalst.append([rr, None, None, None, rr])
            self.au_data = {"data": datalst}

        if auto:  # search for refinement data and add it to the inspector automatically
            self.gen_refinement_data()

        EM3DSymModel.__init__(self, gl_widget, eulerfilename=file_name)
        Animator.__init__(self)
        self.height_scale = 8.0  # EM3DSymModel scale for the displayed cylinder heights; user-adjustable
        self.projection_file = None  # name of the projection images file
        self.average_file = None  # name of the class averages file
        self.proj_class_viewer = None  # EMImageMXWidget showing the class and/or projection
        self.particle_viewer = None  # EMImageMXWidget showing the particles in a class
        self.clsdb = None  # legacy class-membership database name; membership now lives in the header
        self.particle_file = None  # name of the particle file
        self.alignment_file = None  # name of the alignment-parameter file; needed to show aligned particles
        self.refine_dir = None  # name of the refinement directory being studied
        self.dx = None  # EMData of per-particle x shifts (from e2classaverage)
        self.dy = None  # EMData of per-particle y shifts (from e2classaverage)
        self.da = None  # EMData of per-particle alignment angles (from e2classaverage)
        self.dflip = None  # EMData of per-particle flip flags (from e2classaverage)
        self.classes = None  # EMData of the class(es) each particle belongs to (from e2classaverage)
        self.inclusions = None  # EMData of flags: particle included in the final average (from e2classaverage)

        self.average = None  # the current class average (EMData)
        self.projection = None  # the current projection (EMData)
        self.class_idx = None  # index of the class currently being studied

        self.previous_len = -1  # previous number of class averages, used to reselect the same class across iterations
        self.mirror_eulers = False
        if sparse_mode:
            self.mirror_eulers = True  # also render Eulers on the opposite side of the sphere (EM3DSymModel.make_sym_dl_list)

        # Grab the symmetry from the workflow database if possible
        sym = "c1"
        if js_check_dict("refine_01/0_refine_parms.json"):
            try:
                sym = str(js_open_dict("refine_01/0_refine_parms.json")["sym"])
            except Exception:
                pass

        # Have to tell the EM3DSymModel that there is a new sym
        self.set_symmetry(sym)

        if self.au_data is not None:
            combo_entries = sorted(self.au_data.keys(), reverse=True)
            if len(combo_entries) > 0:
                au = combo_entries[0]
                cls = self.au_data[au][0][0]
                self.on_au_selected(au, cls)
                self.mirror_eulers = True

        self.init_lock = False
        self.force_update = True  # force a display update in EM3DSymModel

        self.point_selected.connect(self.au_point_selected)

    def __del__(self):
        # Chain explicitly to the base-class destructor. It is kept here for
        # documentation: the original authors flag EM3DSymModel.__del__ as
        # important, so it must keep running for this subclass.
        EM3DSymModel.__del__(self)

    def initializeGL(self):
        """One-time GL setup: enable GL_NORMALIZE so normals are rescaled to unit length."""
        glEnable(GL_NORMALIZE)

    def generate_current_display_list(self, force=False):
        '''
        Rebuild the asymmetric-unit outline and the Euler-point display list.

        Overrides EM3DSymModel.generate_current_display_list.

        force -- forwarded to the base-class implementation when no
                 refinement data is loaded; otherwise unused here.

        Returns 0 while __init__ is still running (init_lock) or when no
        Euler orientations have been specified. Otherwise returns 1 after
        building the display list from self.specified_eulers.
        '''
        # Nothing can be drawn until __init__ has finished.
        if self.init_lock: return 0
        # NOTE(review): with no refinement data the base implementation runs,
        # but execution then falls through and rebuilds everything below as
        # well. This may be a missing `return` -- confirm against
        # EM3DSymModel before changing.
        if self.au_data == None or len(self.au_data) == 0:
            EM3DSymModel.generate_current_display_list(self, force)

        self.init_basic_shapes()
        # val is passed to get_asym_unit_points / trace_great_triangles:
        # 0 when nomirror is set, 1 otherwise.
        if self.nomirror == True: val = 0
        else: val = 1
        self.trace_great_arcs(self.sym_object.get_asym_unit_points(val))
        self.trace_great_triangles(val)

        # Only explicitly specified orientations are drawn.
        self.eulers = self.specified_eulers
        if self.eulers == None: return 0

        #		if not self.colors_specified: self.point_colors = []
        #		else: self.point_colors = self.specified_colors
        #		self.points = []
        #		for i in self.eulers:
        #			p = i.transpose()*Vec3f(0,0,self.radius)
        #			self.points.append(p)
        #			if not self.colors_specified: self.point_colors.append((0.34615, 0.3143, 0.0903,1))

        self.make_sym_dl_list(self.eulers)
        return 1

    def get_data_dims(self):
        """Return the (x, y, z) extent of the model: the sphere diameter on every axis."""
        diameter = self.radius * 2
        return (diameter, diameter, diameter)

    def width(self):
        """Width of the model: the sphere diameter."""
        return self.radius * 2

    def height(self):
        """Height of the model: the sphere diameter."""
        return self.radius * 2

    def depth(self):
        """Depth of the model: the sphere diameter."""
        return self.radius * 2

    def gen_refinement_data(self):
        """Fill ``self.au_data`` from numbered refinement directories in the CWD.

        Directories named ``refine_N``, ``multi_N`` or ``multinoali_N`` (N an
        integer) are kept, sorted by name, stored in ``self.dirs`` and
        printed. Each one is scanned with check_refine_db_dir, and its entries
        are merged into ``self.au_data`` when non-empty.
        """
        prefixes = ("refine_", "multi_", "multinoali_")

        def is_numbered(name):
            # The text after the matching prefix (sliced at len(prefix), not a
            # fixed width) must parse as an integer.
            for prefix in prefixes:
                if name.startswith(prefix):
                    try:
                        int(name[len(prefix):])
                    except ValueError:
                        return False
                    return True
            return False

        dirs, files = get_files_and_directories()
        self.dirs = sorted(d for d in dirs if is_numbered(d))
        print(self.dirs)

        self.au_data = {}
        for rdir in self.dirs:
            d = self.check_refine_db_dir(rdir)
            if len(d) != 0 and len(d[rdir]) != 0:
                self.au_data.update(d)

    def check_refine_db_dir(self,
                            dir,
                            s1="classes",
                            s2=None,
                            s3="cls_result",
                            s4="threed",
                            s5="projections"):
        """Collect per-iteration file names for refinement directory ``dir``.

        The last iteration number is taken from file names in ``dir`` that
        contain "threed" but neither "even" nor "odd". The number is read
        from characters 7-8 of the name (e.g. "threed_05.hdf" -> 5). Names
        whose characters 7-8 do not parse as an integer are ignored.

        For each iteration 1..max, two entries are produced: first
        "_NN_even.hdf", then "_NN_odd.hdf". Each entry is
        ``[sadd(dir, s1, ext), ..., sadd(dir, s5, ext)]``. ``s2`` used to be
        "class_indices".

        Returns ``{dir: entries}``, or ``{}`` after printing a message when
        no numbered "threed" file exists.
        """
        names = (s1, s2, s3, s4, s5)

        nums = []
        for fname in os.listdir(dir):
            if "threed" not in fname or "even" in fname or "odd" in fname:
                continue
            try:
                nums.append(int(fname[7:9]))
            except ValueError:
                continue  # one unexpected name must not hide the whole directory
        if not nums:
            print("Nothing in ", dir)
            return {}
        maxnum = max(nums)

        entries = []
        for i in range(1, maxnum + 1):
            for half in ("even", "odd"):
                ext = "_{:02d}_{}.hdf".format(i, half)
                entries.append([sadd(dir, s, ext) for s in names])

        return {dir: entries}

    def set_projection_file(self, projection_file):
        """Remember the name of the file holding the reference projections."""
        self.projection_file = projection_file

    def get_inspector(self):
        """Return the EMAsymmetricUnitInspector, creating and wiring it on first use.

        With no refinement data and ``mirror_eulers`` off, the inspector is
        built with two extra True flags. ``mirror_eulers`` is a proxy for the
        sparse_mode flag used by euler_display in EMAN2.py. Its au_selected
        signal is connected to on_au_selected.
        """
        if self.inspector:
            return self.inspector

        no_refinement_data = self.au_data is None or len(self.au_data) == 0
        if no_refinement_data and not self.mirror_eulers:
            self.inspector = EMAsymmetricUnitInspector(self, True, True)
        else:
            self.inspector = EMAsymmetricUnitInspector(self)
        self.inspector.au_selected.connect(self.on_au_selected)
        return self.inspector

    def on_au_selected(self, refine_dir, cls):
        """Load class-average file ``cls`` from refinement ``refine_dir`` for display.

        Finds the first entry of ``self.au_data[refine_dir]`` whose element 0
        equals ``cls``. Elements 4, 2 and 1 become the projection, alignment
        and class-db file names. Unless reading from explicit files, the
        particle file is read from "<refine_dir>/0_refine_parms.json". Cached
        per-refinement alignment images are reset, the Euler angles of the
        averages are plotted, and the previously studied class is reselected.

        The busy cursor shown while loading is reset on every exit path,
        including when an exception propagates.
        """
        refine_dir = str(refine_dir)
        cls = str(cls)
        self.refine_dir = refine_dir
        get_application().setOverrideCursor(Qt.BusyCursor)
        try:
            data = next((d for d in self.au_data[refine_dir] if d[0] == cls), [])
            if len(data) == 0:
                error("error, no data for %s %s, returning" % (refine_dir, cls))
                self.events_handlers["inspect"].reset()
                return

            if not self.readfrom:
                try:
                    self.particle_file = js_open_dict(
                        refine_dir + "/0_refine_parms.json")["input"]
                except Exception:
                    error("No data in " + refine_dir)
                    self.events_handlers["inspect"].reset()
                    return

            self.average_file = cls
            self.projection_file = data[4]
            self.alignment_file = data[2]
            self.clsdb = data[1]

            # Alignment/classification images belong to the previous selection.
            self.dx = None
            self.dy = None
            self.da = None
            self.dflip = None
            self.classes = None

            eulers = get_eulers_from(self.average_file)
            self.specify_eulers(eulers)
            self.set_emdata_list_as_data(
                EMLightWeightParticleCache.from_file(self.average_file),
                "ptcl_repr")
            self.force_update = True
            self.au_point_selected(self.class_idx, None)
            self.previous_len = len(eulers)
            if not self.init_lock:
                self.updateGL()
        finally:
            get_application().setOverrideCursor(Qt.ArrowCursor)

    def __get_file_headers(self, filename):
        """Return one EMData per image in ``filename``, in file order.

        Each image is read with read_image's third argument set to True,
        which is a header-only read as the method name indicates.
        """
        count = EMUtil.get_image_count(filename)
        headers = [EMData() for _ in range(count)]
        for idx, hdr in enumerate(headers):
            hdr.read_image(filename, idx, True)
        return headers

    def au_point_selected(self, i, event=None):
        """Show the projection and class average for asymmetric-unit point ``i``.

        When ``i`` is None, an event carrying the Shift modifier clears
        ``special_euler`` (and calls ``regen_dl()`` unless ``init_lock`` is
        set); nothing else happens. Otherwise ``self.euler_data[i]`` becomes
        ``self.average``. The projection whose index is stored in the
        average's ``projection_image_idx`` header is read from
        ``self.projection_file``. The projection, a projection-masked copy of
        the average and the average itself are then shown in
        ``proj_class_viewer`` (montage) and ``proj_class_single``, both
        created on first use. The particle viewer is refreshed if it is open,
        and ``i`` becomes the new ``special_euler``.

        Returns early, doing nothing, when ``self.readfrom`` is set or there is
        no ``euler_data``.

        Args:
            i: index into ``self.euler_data``, or None when nothing was picked.
            event: originating event; only its modifiers are inspected.
        """
        if self.readfrom:
            return
        if i == None:
            # Shift + pick on empty space clears the highlighted orientation
            if event != None and event.modifiers() & Qt.ShiftModifier:
                if self.special_euler != None:
                    self.special_euler = None
                    if not self.init_lock: self.regen_dl()
            return


#		self.arc_anim_points = None
        self.projection = None
        if self.euler_data:
            self.average = self.euler_data[i]  # class average for this orientation
            # the average's header records which projection it was classified against
            self.projection = EMData(
                self.projection_file,
                self.average.get_attr("projection_image_idx"))
            # scale the average's densities to match the projection
            self.average.process_inplace("normalize.toimage",
                                         {"to": self.projection})
            try:
                self.class_idx = self.average.get_attr("projection_image_idx")
                print("%d (%d)" % (self.class_idx, self.average["ptcl_repr"]))
            except:
                # NOTE(review): bare except - if "ptcl_repr" is missing the print
                # fails and class_idx is reset to -1 even though
                # projection_image_idx was read successfully; confirm intended.
                self.class_idx = -1
        else:
            return

        first = False
        if self.proj_class_viewer == None:
            # first use: create the montage viewer and the single-image viewer
            first = True
            self.proj_class_viewer = EMImageMXWidget(
                data=None, application=get_application())
            self.proj_class_viewer.module_closed.connect(
                self.on_mx_view_closed)
            self.proj_class_viewer.mx_image_selected.connect(
                self.on_mx_image_selected)
            get_application().show_specific(self.proj_class_viewer)

            self.proj_class_single = EMImage2DWidget(
                image=None, application=get_application())
            self.proj_class_single.module_closed.connect(
                self.on_mx_view_closed)
            get_application().show_specific(self.proj_class_single)

        # images to display: [projection, masked average, average] as available
        disp = []
        if self.projection != None: disp.append(self.projection)
        if self.average != None and self.projection != None:
            # ok, this really should be put into its own processor
            # copy the average's pixel size onto the projection before combining
            self.projection["apix_x"] = self.average["apix_x"]
            self.projection["apix_y"] = self.average["apix_y"]
            self.projection["apix_z"] = self.average["apix_z"]
            # presumably a 1/0 mask of the projection's non-zero pixels, applied
            # to the average and then filter-matched to the projection
            filt = self.projection.process("threshold.notzero")
            filt.mult(self.average)
            filt.process_inplace("filter.matchto", {"to": self.projection})

            disp.append(filt)

        if self.average != None:
            disp.append(self.average)

        self.proj_class_viewer.set_data(disp)
        self.proj_class_single.set_data(disp)

        self.proj_class_viewer.updateGL()
        self.proj_class_single.updateGL()
        # refresh the particle montage for the newly selected average
        if self.particle_viewer != None:
            self.on_mx_image_selected(None, None)
        if first: self.proj_class_viewer.optimally_resize()

        if i != self.special_euler:
            self.special_euler = i
            self.force_update = True

        if not self.init_lock: self.updateGL()

    def on_mx_view_closed(self):
        """Drop the references to both projection/class viewers once they close."""
        self.proj_class_viewer = self.proj_class_single = None

    def on_particle_mx_view_closed(self):
        """Forget the particle montage viewer after its window has closed."""
        self.particle_viewer = None

    def animation_done_event(self, animation):
        """Callback for a finished animation; intentionally does nothing here."""

    def alignment_time_animation(self, transforms):
        """Animate the view through ``transforms`` when there are at least two of them."""
        if len(transforms) >= 2:
            self.register_animatable(
                OrientationListAnimation(self, transforms, self.radius))

    def particle_selected(self, event, lc):
        """Animate the orientation history of the particle picked in the particle viewer.

        ``lc[3]`` is indexed as a mapping whose ``"Img #"`` entry is the
        particle index (set via ``ApplyAttribute`` when the montage is built).
        The files of ``self.au_data[self.refine_dir]`` are split into
        ``projections*`` and ``cls_result*`` lists. For each pair, the class
        number of the particle is read from the classification file, and the
        ``xform.projection`` of that projection is collected. The collected
        orientations are then animated.

        Args:
            event: the originating mouse event (unused).
            lc: selection info from the montage, or None for no selection.

        Raises:
            RuntimeError: if the numbers of projection and cls_result files differ.
        """
        if lc is None:
            return
        ptcl_idx = lc[3]["Img #"]
        prj = []
        cls_result = []
        for group in self.au_data[self.refine_dir]:
            for path in group:
                stag = base_name(path)
                if len(stag) > 11 and stag.startswith("projections"):
                    prj.append(path)
                elif len(stag) > 10 and stag.startswith("cls_result"):
                    cls_result.append(path)

        # the files are paired by position, so the two lists must line up
        if len(prj) != len(cls_result):
            raise RuntimeError(
                "The number of cls_result files does not match the number of projection files?"
            )

        transforms = []
        e = EMData()
        for prj_file, cr in zip(prj, cls_result):
            # a single pixel at (0, ptcl_idx) holds the class of this particle
            e.read_image(cr, 0, False, Region(0, ptcl_idx, 1, 1))
            p = int(e.get(0))
            # the third argument True presumably requests a header-only read
            e.read_image(prj_file, p, True)
            transforms.append(e["xform.projection"])

        self.alignment_time_animation(transforms)

    def on_mx_image_selected(self, event, lc):
        """Show the particles that went into the current class average.

        When ``lc`` is given, ``lc[0]`` (the index of the image picked in the
        projection/class montage) is stored in ``self.sel``. If an average is
        selected, the particles listed in its ``class_ptcl_idxs`` (included)
        and ``exc_class_ptcl_idxs`` (excluded) headers are loaded lazily from
        ``self.particle_file``. They are shown in ``self.particle_viewer``
        (created on first use) and tagged as the "Included"/"Excluded" sets.
        Unless ``self.sel`` is 0 or there is no alignment file, each particle
        is displayed with the 2-D alignment (shift, rotation, mirror) stored
        for it in ``self.alignment_file``.

        A busy cursor is shown while this runs. It is always popped again,
        including when no average is selected and when a particle's class
        column cannot be found.
        """
        app = get_application()
        app.setOverrideCursor(Qt.BusyCursor)
        try:
            if lc is not None:
                self.sel = lc[0]
            if self.average is None:
                return

            included = []
            if self.average.has_attr("class_ptcl_idxs"):
                included = self.average["class_ptcl_idxs"]
            excluded = []
            if self.average.has_attr("exc_class_ptcl_idxs"):
                excluded = self.average["exc_class_ptcl_idxs"]

            from eman2_gui.emimagemx import ApplyAttribute, ApplyTransform

            # [file, image index, per-image modifiers]; the modifier lists are
            # extended in place with alignment transforms further down
            bdata = []
            for val in included:
                bdata.append(
                    [self.particle_file, val, [ApplyAttribute("Img #", val)]])
            n_included = len(bdata)
            for val in excluded:
                bdata.append(
                    [self.particle_file, val, [ApplyAttribute("Img #", val)]])
            # montage positions: included particles first, then excluded ones
            idx_included = list(range(n_included))
            idx_excluded = list(range(n_included, len(bdata)))

            data = EMLightWeightParticleCache(bdata)

            first = False
            if self.particle_viewer is None:
                first = True
                self.particle_viewer = EMImageMXWidget(
                    data=None, application=app)
                self.particle_viewer.set_mouse_mode("App")
                self.particle_viewer.module_closed.connect(
                    self.on_particle_mx_view_closed)
                self.particle_viewer.mx_image_selected.connect(
                    self.particle_selected)
                app.show_specific(self.particle_viewer)

            self.check_images_in_memory()

            if self.sel != 0 and self.alignment_file is not None:
                for _name, idx, modifiers in bdata:
                    index = -1
                    if self.classes.get_xsize() == 1:
                        index = 0  # just assume it's the first one - this is potentially fatal assumption, but in obscure situations only
                    else:
                        # find the classification column that put this particle in our class
                        for j in range(self.classes.get_xsize()):
                            if int(self.classes.get(j, idx)) == self.class_idx:
                                index = j
                                break
                    if index == -1:
                        print("couldn't find")
                        return

                    x = self.dx.get(index, idx)
                    y = self.dy.get(index, idx)
                    a = self.da.get(index, idx)
                    m = self.dflip.get(index, idx)

                    t = Transform({"type": "2d", "alpha": a, "mirror": int(m)})
                    t.set_trans(x, y)
                    modifiers.append(ApplyTransform(t))
            self.particle_viewer.set_data(data)

            if first:
                self.particle_viewer.updateGL()
                self.particle_viewer.optimally_resize()

            self.particle_viewer.clear_sets(False)
            self.particle_viewer.enable_set("Excluded", idx_excluded, True,
                                            False)
            self.particle_viewer.enable_set("Included", idx_included, False,
                                            False)
            self.particle_viewer.updateGL()

            self.updateGL()
        finally:
            # pop the busy cursor pushed above (setOverrideCursor uses a stack)
            app.restoreOverrideCursor()

    def check_images_in_memory(self):
        """Lazily load the alignment images this viewer needs.

        Each attribute below is read from ``self.alignment_file`` at the listed
        image index the first time it is found unset. Nothing happens when
        there is no alignment file.
        """
        if self.alignment_file is None:
            return
        lazy_images = (("dx", 2), ("dy", 3), ("da", 4), ("dflip", 5),
                       ("classes", 0), ("inclusions", 1))
        for attr, image_index in lazy_images:
            if getattr(self, attr) is None:
                setattr(self, attr, EMData(self.alignment_file, image_index))

    def set_events_mode(self, mode):
        """Switch to events ``mode``; a mode not in ``events_mode_list`` is reported and ignored."""
        if mode in self.events_mode_list:
            self.events_mode = mode
            return
        print("error, unknown events mode", mode)

    def closeEvent(self, event):
        """Close every child window still open, unregister this widget, emit ``module_closed``.

        Each child attribute is re-read just before it is checked. Closing one
        viewer may run a ``module_closed`` handler that clears other viewer
        attributes.
        """
        for attr in ("inspector", "proj_class_viewer", "proj_class_single",
                     "particle_viewer"):
            child = getattr(self, attr)
            if child is not None:
                child.close()
        get_application().close_specific(self)
        self.module_closed.emit()
Пример #20
0
    def __init__(self, application, options, datafile=None):
        """Build the neural-network particle-sorting window.

        Creates the Train/Save/Load buttons and the LearnRate, PtclThresh and
        Niter text boxes. Every particle in ``options.setname`` is read into a
        "Sets"-mode montage viewer, and TensorFlow is imported for
        ``options.gpuid`` into the module-level ``tf``.

        ``datafile`` is accepted for interface compatibility but is not used.
        """
        global tf
        QtWidgets.QWidget.__init__(self)
        self.setMinimumSize(150, 100)
        self.setCentralWidget(QtWidgets.QWidget())
        self.gbl = QtWidgets.QGridLayout(self.centralWidget())

        # (attribute, label, tooltip): one button per row, spanning two columns
        button_specs = (
            ("bt_train", "Train", "Train neural network"),
            ("bt_save", "Save", "Save particle set"),
            ("bt_load", "Load", "Load neural network"),
        )
        for row, (attr, label, tip) in enumerate(button_specs):
            button = QtWidgets.QPushButton(label)
            button.setToolTip(tip)
            self.gbl.addWidget(button, row, 0, 1, 2)
            setattr(self, attr, button)

        for button, slot in ((self.bt_load, self.load_nnet),
                             (self.bt_train, self.train_nnet),
                             (self.bt_save, self.save_set)):
            button.clicked[bool].connect(slot)

        self.val_learnrate = TextBox("LearnRate", 1e-4)
        self.gbl.addWidget(self.val_learnrate, 0, 2, 1, 1)

        # default threshold: the number of particles NOT kept (fraction 1 - keep)
        total = EMUtil.get_image_count(options.setname)
        self.val_ptclthr = TextBox("PtclThresh", int(total * (1 - options.keep)))
        self.gbl.addWidget(self.val_ptclthr, 1, 2, 1, 1)

        self.val_niter = TextBox("Niter", 10)
        self.gbl.addWidget(self.val_niter, 2, 2, 1, 1)

        self.options = options
        self.app = weakref.ref(application)
        self.nnet = None
        self.trainset = []

        self.particles = EMData.read_images(options.setname)
        self.boxsz = self.particles[0]["nx"]

        viewer = EMImageMXWidget()
        self.ptclviewer = viewer
        viewer.setWindowTitle("Particles")
        self.sortidx = list(range(len(self.particles)))

        viewer.usetexture = False
        viewer.show()
        viewer.set_mouse_mode("Sets")
        viewer.rzonce = True

        viewer.set_data(self.particles)
        viewer.sets["bad_particles"] = set()
        viewer.update()
        self.ptclimg = get_image(self.particles, len(self.particles))
        tf = import_tensorflow(options.gpuid)
Пример #21
0
class EMTomoBoxer(QtGui.QMainWindow):
	"""This class represents the EMTomoBoxer application instance.  """
	keypress = QtCore.pyqtSignal(QtGui.QKeyEvent)
	module_closed = QtCore.pyqtSignal()

	def __init__(self,application,options,datafile):
		"""Build the tomogram boxer main window and load ``datafile``.

		Sets up the File menu, the three orthogonal views (xy, xz, zy), the depth
		slider and the control panel, then loads the tomogram via set_data().
		Creates the particle-list montage and the options/sets panels. Finally
		restores box sets, box sizes and box coordinates from the info JSON of
		``datafile`` (``info_name(datafile)``) and draws every box.

		Args:
			application: the Qt application; only a weak reference is kept.
			options: parsed command-line options; ``apix`` and ``mode`` are read here
				(``mode == "3D"`` selects circular box markers, anything else rectangles).
			datafile: path of the tomogram to box.
		"""
		QtGui.QWidget.__init__(self)
		self.initialized=False
		self.app=weakref.ref(application)
		self.options=options
		self.yshort=False
		self.apix=options.apix
		self.currentset=0
		self.shrink=1	# fixed at 1; options.shrink is not used
		self.setWindowTitle("Main Window (e2spt_boxer.py)")
		if options.mode=="3D":
			self.boxshape="circle"
		else:
			self.boxshape="rect"


		# Menu Bar
		self.mfile=self.menuBar().addMenu("File")
		self.mfile_open=self.mfile.addAction("Open")
		self.mfile_read_boxloc=self.mfile.addAction("Read Box Coord")
		self.mfile_save_boxloc=self.mfile.addAction("Save Box Coord")
		self.mfile_save_boxes_stack=self.mfile.addAction("Save Boxes as Stack")
		self.mfile_quit=self.mfile.addAction("Quit")


		self.setCentralWidget(QtGui.QWidget())
		self.gbl = QtGui.QGridLayout(self.centralWidget())

		# relative stretch factors
		self.gbl.setColumnStretch(0,1)
		self.gbl.setColumnStretch(1,4)
		self.gbl.setColumnStretch(2,0)
		self.gbl.setRowStretch(1,1)
		self.gbl.setRowStretch(0,4)

		# 3 orthogonal restricted projection views
		self.xyview = EMImage2DWidget()
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget()
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget()
		self.gbl.addWidget(self.zyview,0,0)

		# Select Z for xy view
		self.wdepth = QtGui.QSlider()
		self.gbl.addWidget(self.wdepth,1,2)

		### Control panel area in the lower-left grid cell (row 1, column 0)
		self.gbl2 = QtGui.QGridLayout()
		self.gbl.addLayout(self.gbl2,1,0)

		# box size
		self.wboxsize=ValBox(label="Box Size:",value=0)
		self.gbl2.addWidget(self.wboxsize,2,0,1,2)

		# number of slices averaged in each view
		self.wnlayers=QtGui.QSpinBox()
		self.wnlayers.setMinimum(1)
		self.wnlayers.setMaximum(256)
		self.wnlayers.setValue(1)
		self.gbl2.addWidget(self.wnlayers,3,1)

		# Local boxes in side view
		self.wlocalbox=QtGui.QCheckBox("Limit Side Boxes")
		self.gbl2.addWidget(self.wlocalbox,3,0)
		self.wlocalbox.setChecked(True)

		# scale factor
		self.wscale=ValSlider(rng=(.1,2),label="Sca:",value=1.0)
		self.gbl2.addWidget(self.wscale,4,0,1,2)

		# 2-D filters
		self.wfilt = ValSlider(rng=(0,150),label="Filt:",value=0.0)
		self.gbl2.addWidget(self.wfilt,5,0,1,2)
		
		self.curbox=-1
		
		self.boxes=[]						# one entry per box: [x, y, z, method, score, set id]
		self.boxesimgs=[]					# z projection of each box
		self.xydown=self.xzdown=self.zydown=None
		self.firsthbclick = None

		# coordinate display (boxes/curbox must already exist for get_x/get_y)
		self.wcoords=QtGui.QLabel("X: " + str(self.get_x()) + "\t\t" + "Y: " + str(self.get_y()) + "\t\t" + "Z: " + str(self.get_z()))
		self.gbl2.addWidget(self.wcoords, 1, 0, 1, 2)

		# file menu
		self.mfile_open.triggered[bool].connect(self.menu_file_open)
		self.mfile_read_boxloc.triggered[bool].connect(self.menu_file_read_boxloc)
		self.mfile_save_boxloc.triggered[bool].connect(self.menu_file_save_boxloc)
		self.mfile_save_boxes_stack.triggered[bool].connect(self.save_boxes)
		self.mfile_quit.triggered[bool].connect(self.menu_file_quit)

		# all other widgets
		self.wdepth.valueChanged[int].connect(self.event_depth)
		self.wnlayers.valueChanged[int].connect(self.event_nlayers)
		self.wboxsize.valueChanged.connect(self.event_boxsize)
		self.wscale.valueChanged.connect(self.event_scale)
		self.wfilt.valueChanged.connect(self.event_filter)
		self.wlocalbox.stateChanged[int].connect(self.event_localbox)

		self.xyview.mousemove.connect(self.xy_move)
		self.xyview.mousedown.connect(self.xy_down)
		self.xyview.mousedrag.connect(self.xy_drag)
		self.xyview.mouseup.connect(self.xy_up)
		self.xyview.mousewheel.connect(self.xy_wheel)
		self.xyview.signal_set_scale.connect(self.xy_scale)
		self.xyview.origin_update.connect(self.xy_origin)

		self.xzview.mousedown.connect(self.xz_down)
		self.xzview.mousedrag.connect(self.xz_drag)
		self.xzview.mouseup.connect(self.xz_up)
		self.xzview.signal_set_scale.connect(self.xz_scale)
		self.xzview.origin_update.connect(self.xz_origin)

		self.zyview.mousedown.connect(self.zy_down)
		self.zyview.mousedrag.connect(self.zy_drag)
		self.zyview.mouseup.connect(self.zy_up)
		self.zyview.signal_set_scale.connect(self.zy_scale)
		self.zyview.origin_update.connect(self.zy_origin)
		
		self.xyview.keypress.connect(self.key_press)
		self.datafilename=datafile
		self.basename=base_name(datafile)
		# filetag: text from the first "__" up to the extension, ending in "_";
		# it is prefixed to the suffix of saved particle files (see save_boxes)
		p0=datafile.find('__')
		if p0>0:
			p1=datafile.rfind('.')
			self.filetag=datafile[p0:p1]
			if self.filetag[-1]!='_':
				self.filetag+='_'
		else:
			self.filetag="__"
			
		data=EMData(datafile)
		self.set_data(data)

		# Boxviewer subwidget (details of a single box)
		self.boxviewer=EMBoxViewer()

		# Boxes Viewer (z projections of all boxes)
		self.boxesviewer=EMImageMXWidget()
		
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Particle List")
		self.boxesviewer.rzonce=True
		
		self.setspanel=EMTomoSetsPanel(self)

		self.optionviewer=EMTomoBoxerOptions(self)
		self.optionviewer.add_panel(self.setspanel,"Sets")
		
		
		self.optionviewer.show()
		
		self.boxesviewer.mx_image_selected.connect(self.img_selected)
		
		self.jsonfile=info_name(datafile)
		
		info=js_open_dict(self.jsonfile)
		self.sets={}		# set id -> set name
		self.boxsize={}		# set id -> box size
		if "class_list" in info:
			clslst=info["class_list"]
			for k in sorted(clslst.keys()):
				if type(clslst[k])==dict:
					self.sets[int(k)]=str(clslst[k]["name"])
					self.boxsize[int(k)]=int(clslst[k]["boxsize"])
				else:
					# entry holds only the set name (presumably an older info layout)
					self.sets[int(k)]=str(clslst[k])
					# NOTE(review): `boxsize` is never assigned in this method; this line
					# relies on a module-level `boxsize` and raises NameError otherwise.
					self.boxsize[int(k)]=boxsize
				
					
		
			
			
		clr=QtGui.QColor
		self.setcolors=[clr("blue"),clr("green"),clr("red"),clr("cyan"),clr("purple"),clr("orange"), clr("yellow"),clr("hotpink"),clr("gold")]
		self.sets_visible={}
				
		if "boxes_3d" in info:
			box=info["boxes_3d"]
			for i,b in enumerate(box):
				#### X-center,Y-center,Z-center,method,[score,[class #]]
				# missing trailing fields keep these defaults
				bdf=[0,0,0,"manual",0.0, 0]
				for j,bi in enumerate(b):  bdf[j]=bi
				
				
				# a box referring to an unknown set gets a default-named set
				if bdf[5] not in list(self.sets.keys()):
					clsi=int(bdf[5])
					self.sets[clsi]="particles_{:02d}".format(clsi)
					# NOTE(review): same undefined `boxsize` as above - confirm source.
					self.boxsize[clsi]=boxsize
				
				self.boxes.append(bdf)
		
		
		###### this is the new (2018-09) metadata standard..
		### now we use coordinates at full size from center of tomogram so it works for different binning and clipping
		### have to make it compatible with older versions though..
		if "apix_unbin" in info:
			# convert centre-relative, unbinned coordinates and box sizes into
			# pixel coordinates of the tomogram loaded here
			self.apix_unbin=info["apix_unbin"]
			self.apix_cur=apix=data["apix_x"]
			for b in self.boxes:
				b[0]=b[0]/apix*self.apix_unbin+data["nx"]//2
				b[1]=b[1]/apix*self.apix_unbin+data["ny"]//2
				b[2]=b[2]/apix*self.apix_unbin+data["nz"]//2
				
			for k in self.boxsize.keys():
				self.boxsize[k]=self.boxsize[k]/apix*self.apix_unbin
		else:
			# -1 marks an info file without the apix_unbin (older) convention
			self.apix_unbin=-1
		
		
		
		info.close()
		if len(self.sets)==0:
			self.new_set("particles_00")
		# only the first set (in insertion order) starts out visible
		self.sets_visible[list(self.sets.keys())[0]]=0
		self.currentset=sorted(self.sets.keys())[0]
		self.setspanel.update_sets()
		self.wboxsize.setValue(self.get_boxsize())

		print(self.sets)
		for i in range(len(self.boxes)):
			self.update_box(i)
		
		self.update_all()
		self.initialized=True

	def set_data(self,data):
		"""Make ``data`` the tomogram being boxed and discard the existing boxes.

		Caches the pixel size and (nx, ny, nz), fits the depth slider to the z range and
		centres it, clears the box list and the current box, and redraws everything once
		the window is initialized.
		"""
		self.data=data
		self.apix=data["apix_x"]
		nx,ny,nz=data["nx"],data["ny"],data["nz"]
		self.datasize=(nx,ny,nz)

		self.wdepth.setRange(0,nz-1)
		self.boxes=[]
		self.curbox=-1

		self.wdepth.setValue(old_div(nz,2))
		if self.initialized:
			self.update_all()

	def eraser_width(self):
		"""Return the eraser radius set in the options panel, as an int."""
		radius=self.optionviewer.eraser_radius.getValue()
		return int(radius)
		
	def get_cube(self,x,y,z, centerslice=False, boxsz=-1):
		"""Return a box centred on (x, y, z), cut from the loaded tomogram.

		``boxsz`` overrides the current set's box size when it is non-negative. With
		``centerslice`` the result is the single central z slice (bs x bs x 1) instead of a
		cube. A centre too far outside the volume yields a blank EMData of the same size.
		The tomogram's pixel size is copied onto the result when it is non-zero.
		"""
		bs=self.get_boxsize() if boxsz<0 else boxsz
		bz=1 if centerslice else bs
		nx,ny,nz=self.data["nx"],self.data["ny"],self.data["nz"]

		outside=(x<-bs//2 or y<-bs//2 or z<-bz//2
			or x>nx+bs//2 or y>ny+bs//2 or z>nz+bz//2)
		if outside:
			r=EMData(bs,bs,bz)
		else:
			r=self.data.get_clip(Region(x-bs//2,y-bs//2,z-bz//2,bs,bs,bz))

		if self.apix!=0:
			for key in ("apix_x","apix_y","apix_z"):
				r[key]=self.apix
		return r

	def get_slice(self,n,xyz):
		"""Return slice ``n`` of the loaded tomogram as a 2-D image.

		``xyz`` is the axis along which ``n`` runs: 0 = x (a yz slice, ny x nz),
		1 = y (an xz slice, nx x nz), anything else = z (an xy slice, nx x ny).
		The tomogram's pixel size is copied onto the slice when it is non-zero.
		"""
		nx,ny,nz=self.datasize
		if xyz==0:
			r=self.data.get_clip(Region(n,0,0,1,ny,nz))
			r.set_size(ny,nz,1)
		elif xyz==1:
			r=self.data.get_clip(Region(0,n,0,nx,1,nz))
			r.set_size(nx,nz,1)
		else:
			r=self.data.get_clip(Region(0,0,n,nx,ny,1))

		if self.apix!=0:
			for key in ("apix_x","apix_y","apix_z"):
				r[key]=self.apix
		return r

	def event_boxsize(self):
		"""Apply the size in the Box Size field to the current set.

		Stores the new size, rebuilds every box of the current set with it, and restores
		the previously current box (or "no current box"). Does nothing if the size is
		unchanged.
		"""
		newsize=int(self.wboxsize.getValue())
		if self.get_boxsize()==newsize:
			return

		self.boxsize[self.currentset]=newsize

		cb=self.curbox
		# suppress per-box side-view redraws and JSON saves; redraw once at the end
		self.initialized=False
		for i,box in enumerate(self.boxes):
			if box[5]==self.currentset:
				self.update_box(i)
		if cb>=0:
			# update_box makes its argument current; restore the previous selection
			self.update_box(cb)
		else:
			# update_box(-1) would redraw the last box under the bogus shape key -1
			self.curbox=-1
		self.initialized=True
		self.update_all()

	def event_projmode(self,state):
		"""Projection-mode toggle (False: mean, True: max); every view is simply redrawn."""
		self.update_all()

	def event_scale(self,newscale):
		"""Apply the scale slider value to all three views."""
		for view in (self.xyview,self.xzview,self.zyview):
			view.set_scale(newscale)

	def event_depth(self):
		"""Redraw the xy view after the depth slider moves (only once initialized)."""
		if not self.initialized:
			return
		self.update_xy()

	def event_nlayers(self):
		"""Redraw every view after the number of averaged slices changes."""
		self.update_all()

	def event_filter(self):
		"""Redraw every view after the low-pass filter slider changes."""
		self.update_all()

	def event_localbox(self,tog):
		"""Redraw the side views when "Limit Side Boxes" is toggled; ``tog`` is unused."""
		self.update_sides()

	def get_boxsize(self, clsid=-1):
		"""Return the box size (int pixels) of set ``clsid``.

		A negative ``clsid`` (the default) means the current set; if that set has no
		size, KeyError propagates. For an explicit set without a usable stored size, a
		message is printed and 32 is returned.
		"""
		if clsid<0:
			return int(self.boxsize[self.currentset])
		try:
			return int(self.boxsize[clsid])
		except (KeyError, TypeError, ValueError):
			# unknown set id, or a stored value that is not convertible to int
			print("No box size saved for {}..".format(clsid))
			return 32

	def nlayers(self):
		"""Number of slices averaged together in each view (spin-box value, as int)."""
		layers=self.wnlayers.value()
		return int(layers)

	def depth(self):
		"""Z index currently selected by the depth slider, as int."""
		position=self.wdepth.value()
		return int(position)

	def scale(self):
		"""Current display scale from the scale slider."""
		current=self.wscale.getValue()
		return current

	def get_x(self):
		"""X coordinate shown in the coordinate display (see get_coord)."""
		x_axis=0
		return self.get_coord(x_axis)

	def get_y(self):
		"""Y coordinate shown in the coordinate display (see get_coord)."""
		y_axis=1
		return self.get_coord(y_axis)

	def get_z(self):
		"""Z coordinate shown in the coordinate display: the depth slider position."""
		z=self.depth()
		return z

	def get_coord(self, coord_index):
		"""Return coordinate ``coord_index`` (0=x, 1=y, 2=z) of the box to report.

		That is the current box when one is selected, otherwise the most recently added
		box; 0 when there are no boxes. Previously a lone box was ignored (``len > 1``)
		and box 0 was treated as "nothing selected" (``if self.curbox``).
		"""
		if not self.boxes:
			return 0
		if 0<=self.curbox<len(self.boxes):
			return self.boxes[self.curbox][coord_index]
		return self.boxes[-1][coord_index]


	def menu_file_open(self,tog):
		"""File > Open is unsupported: tell the user to pass the file on the command line."""
		message="Sorry, in the current version, you must provide a file to open on the command-line."
		QtGui.QMessageBox.warning(None,"Error",message)

	def load_box_yshort(self, boxcoords):
		"""Return ``boxcoords`` with y and z swapped when y-short mode is active.

		Args:
			boxcoords: an (x, y, z) sequence.

		Returns:
			``[x, z, y]`` in y-short mode, otherwise ``boxcoords`` unchanged.
		"""
		# The bare name `options` is not defined in this class; use the options object
		# given to the constructor, falling back to the instance flag used elsewhere.
		yshort=getattr(self.options,"yshort",self.yshort)
		if yshort:
			return [boxcoords[0], boxcoords[2], boxcoords[1]]
		return boxcoords

	def menu_file_read_boxloc(self):
		"""Append boxes read from a whitespace-separated text file picked by the user.

		The first (up to) three columns of each line are x, y, z. They are divided by
		``self.shrink`` (old_div) and added to the current set as "manual" boxes with
		score 0. Lines without values are skipped. Cancelling the dialog does nothing.
		"""
		fsp=QtGui.QFileDialog.getOpenFileName(self, "Select output text file")
		if isinstance(fsp,tuple):
			# PyQt5-style API returns (filename, selected_filter)
			fsp=fsp[0]
		fsp=str(fsp)
		if not fsp:
			return

		with open(fsp,"r") as f:
			for line in f:
				coords=[old_div(int(float(v)),self.shrink) for v in line.split()[:3]]
				if not coords:
					# a blank line would otherwise add a phantom box at the origin
					continue
				bdf=[0,0,0,"manual",0.0, self.currentset]
				bdf[:len(coords)]=coords
				self.boxes.append(bdf)
				self.update_box(len(self.boxes)-1)

	def menu_file_save_boxloc(self):
		"""Write x, y, z of every box, multiplied by ``self.shrink``, to a user-chosen file.

		One tab-separated line per box; values are truncated to integers by ``%d``.
		Cancelling the dialog does nothing.
		"""
		shrinkf=self.shrink

		fsp=QtGui.QFileDialog.getSaveFileName(self, "Select output text file")
		if isinstance(fsp,tuple):
			# PyQt5-style API returns (filename, selected_filter)
			fsp=fsp[0]
		fsp=str(fsp)
		if not fsp:
			return

		with open(fsp,"w") as out:
			for b in self.boxes:
				out.write("%d\t%d\t%d\n"%(b[0]*shrinkf,b[1]*shrinkf,b[2]*shrinkf))


	def save_boxes(self, clsid=None):
		"""Extract the boxed particles and write them to a single HDF stack.

		Args:
			clsid: iterable of set ids to save; None or empty saves every box. The menu
				action is connected through ``triggered[bool]`` and so passes a bool, which
				is treated as "all sets".

		The user is asked for a filename suffix. Output goes to ``particles3d/`` (3-D mode:
		normalized, contrast-inverted subvolumes) or ``particles/`` (2-D mode: single
		centre slices) as ``<basename><filetag><suffix>.hdf``, replacing an existing
		file. All particles are cut with the box size of the first saved one; mismatches
		are reported. Each image records its source file and its coordinates relative to
		the tomogram centre.
		"""
		# triggered[bool] delivers a checked-state bool instead of a list of set ids
		if clsid is None or isinstance(clsid,bool):
			clsid=[]
		else:
			clsid=list(clsid)

		if len(clsid)==0:
			defaultname="ptcls.hdf"
		else:
			defaultname="_".join([self.sets[i] for i in clsid])+".hdf"
		
		name,ok=QtGui.QInputDialog.getText( self, "Save particles", "Filename suffix:", text=defaultname)
		if not ok:
			return
		name=self.filetag+str(name)
		if name[-4:].lower()!=".hdf" :
			name+=".hdf"

		if self.options.mode=="3D":
			dr="particles3d"
			is2d=False
		else:
			dr="particles"
			is2d=True

		if not os.path.isdir(dr):
			os.mkdir(dr)
		
		fsp=os.path.join(dr,self.basename)+name

		print("Saving {} particles to {}".format(self.options.mode, fsp))
		
		if os.path.isfile(fsp):
			print("{} exist. Overwritting...".format(fsp))
			os.remove(fsp)
		
		progress = QtGui.QProgressDialog("Saving", "Abort", 0, len(self.boxes),None)

		# tomogram centre: saved coordinates are relative to it
		sz=[s//2 for s in self.datasize]
		boxsz=-1
		for i,b in enumerate(self.boxes):
			if len(clsid)>0 and int(b[5]) not in clsid:
				continue
			
			bs=self.get_boxsize(b[5])
			if boxsz<0:
				boxsz=bs
			elif boxsz!=bs:
				print("Inconsistant box size in the particles to save.. Using {:d}..".format(boxsz))
				bs=boxsz
			
			img=self.get_cube(b[0], b[1], b[2], centerslice=is2d, boxsz=bs)
			if not is2d:
				img.process_inplace('normalize')
			
			img["ptcl_source_image"]=self.datafilename
			img["ptcl_source_coord"]=(b[0]-sz[0], b[1]-sz[1], b[2]-sz[2])
			
			if not is2d: #### do not invert contrast for 2D images
				img.mult(-1)
			
			img.write_image(fsp,-1)

			progress.setValue(i+1)
			if progress.wasCanceled():
				break
		progress.close()


	def menu_file_quit(self):
		"""File > Quit: close the main window."""
		self.close()

	def transform_coords(self, point, xform):
		"""Apply ``xform`` to the 3-D ``point`` and return the result as a 3-element list.

		``xform.get_matrix()`` is indexed as a flat 12-value list (presumably a row-major
		3x4 matrix). Note that the rotation terms are taken with a column stride
		(indices row, row+4, row+8) while the translation uses indices 3, 7, 11.
		NOTE(review): this applies the transpose of the rotation block; confirm intended.
		"""
		m=xform.get_matrix()
		translation=(m[3],m[7],m[11])
		return [m[row]*point[0] + m[row+4]*point[1] + m[row+8]*point[2] + translation[row] for row in range(3)]

	def get_averager(self):
		"""Return a new averager for building the projection views.

		Always a mean averager; a max-projection ("minmax") mode used to be selectable
		via a MaxProj button that is currently disabled.
		"""
		averager=Averagers.get("mean")
		return averager

	def update_sides(self):
		"""Redraw the xz and zy side views through the current box (or the volume centre).

		Records the chosen centre in ``self.curx``/``self.cury``. With "Limit Side
		Boxes" checked, a box marker is shown only if its set is visible and it lies within
		half a box size of the current y (xz view) or x (zy view). Otherwise every box
		of a visible set is shown. The zy view is the mean of ``nlayers()`` yz slices
		around x, transposed unless ``self.yshort``. The xz view is the mean of
		``nlayers()`` xz slices around y. Both are Gaussian low-pass filtered (cutoff
		1/Filt) when the Filt slider is non-zero. Does nothing when no data is loaded.
		"""

		if self.data==None:
			return

		# centre on the current box, or on the middle of the volume if none is selected
		if self.curbox==-1 :
			x=self.datasize[0]//2
			y=self.datasize[1]//2
			z=0
		else:
			x,y,z=self.boxes[self.curbox][:3]

		self.cury=y
		self.curx=x

		# update shape display
		if self.wlocalbox.isChecked():
			# xz view: show boxes whose y range contains the current y
			xzs=self.xzview.get_shapes()
			for i in range(len(self.boxes)):
				bs=self.get_boxsize(self.boxes[i][5])
				if self.boxes[i][1]<self.cury+old_div(bs,2) and self.boxes[i][1]>self.cury-old_div(bs,2) and  self.boxes[i][5] in self.sets_visible:
					xzs[i][0]=self.boxshape
				else:
					xzs[i][0]="hidden"

			# zy view: show boxes whose x range contains the current x
			zys=self.zyview.get_shapes()
			
			for i in range(len(self.boxes)):
				bs=self.get_boxsize(self.boxes[i][5])
				if self.boxes[i][0]<self.curx+old_div(bs,2) and self.boxes[i][0]>self.curx-old_div(bs,2) and  self.boxes[i][5] in self.sets_visible:
					zys[i][0]=self.boxshape
				else:
					zys[i][0]="hidden"
		else :
			xzs=self.xzview.get_shapes()
			zys=self.zyview.get_shapes()
		
			for i in range(len(self.boxes)):
				bs=self.get_boxsize(self.boxes[i][5])	# unused in this branch
				if  self.boxes[i][5] in self.sets_visible:
					xzs[i][0]=self.boxshape
					zys[i][0]=self.boxshape
				else:
					xzs[i][0]="hidden"
					zys[i][0]="hidden"

		self.xzview.shapechange=1
		self.zyview.shapechange=1

		# yz
		avgr=self.get_averager()

		# NOTE: the loop variable rebinds x; x is not used again afterwards
		for x in range(x-(self.nlayers()//2),x+((self.nlayers()+1)//2)):
			slc=self.get_slice(x,0)
			avgr.add_image(slc)

		av=avgr.finish()
		if not self.yshort:
			av.process_inplace("xform.transpose")

		if self.wfilt.getValue()!=0.0:
			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":old_div(1.0,self.wfilt.getValue()),"apix":self.apix})

		self.zyview.set_data(av)

		# xz
		avgr=self.get_averager()

		for y in range(y-old_div(self.nlayers(),2),y+old_div((self.nlayers()+1),2)):
			slc=self.get_slice(y,1)
			avgr.add_image(slc)

		av=avgr.finish()
		if self.wfilt.getValue()!=0.0:
			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":old_div(1.0,self.wfilt.getValue()),"apix":self.apix})

		self.xzview.set_data(av)


	def update_xy(self):
		"""Redraw the xy view for the current depth-slider position.

		Shows the mean of ``nlayers()`` z slices centred on the slider value. The image
		is Gaussian low-pass filtered (cutoff 1/Filt) when the Filt slider is non-zero.
		A box marker is shown only if its set is visible and its centre is close in z:
		within half the box size in 3-D mode, where the circle radius also shrinks with
		the z distance, or within one slice (zdist < 1) in 2-D mode. Once initialized,
		the current contrast settings are kept. Does nothing when no data is loaded.
		"""

		if self.data==None:
			return

		# Boxes should also be limited by default in the XY view
		if len(self.boxes) > 0:
			zc=self.wdepth.value()
			xys=self.xyview.get_shapes()
			for i in range(len(self.boxes)):

				bs=self.get_boxsize(self.boxes[i][5])
				zdist=abs(self.boxes[i][2] - zc)

				if self.options.mode=="3D":
					zthr=bs/2
					# shape element 6 is the circle radius (bs//2 in update_box); shrink it
					# linearly with distance from the displayed slice
					xys[i][6]=bs//2-zdist
				else:
					zthr=1
					
				if zdist < zthr and self.boxes[i][5] in self.sets_visible:
					xys[i][0]=self.boxshape
					
				else :
					xys[i][0]="hidden"
			self.xyview.shapechange=1

		avgr=Averagers.get("mean")

		slc=EMData()
		# average nlayers() slices centred on the slider position
		for z in range(self.wdepth.value()-self.nlayers()//2,self.wdepth.value()+(self.nlayers()+1)//2):
			slc=self.get_slice(z,2)
			avgr.add_image(slc)

		av=avgr.finish()

		if self.wfilt.getValue()!=0.0:

			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":old_div(1.0,self.wfilt.getValue()),"apix":self.apix})
		if self.initialized:
			self.xyview.set_data(av, keepcontrast=True)
		else:
			self.xyview.set_data(av)

	def update_all(self):
		"""Redraw the xy view, both side views and the particle montage (no-op without data)."""
		if self.data is None:
			return
		for refresh in (self.update_xy,self.update_sides,self.update_boximgs):
			refresh()

	def update_coords(self):
		"""Refresh the X/Y/Z coordinate label from get_x(), get_y() and get_z()."""
		parts=("X: " + str(self.get_x()), "Y: " + str(self.get_y()), "Z: " + str(self.get_z()))
		self.wcoords.setText("\t\t".join(parts))

	def inside_box(self,n,x=-1,y=-1,z=-1):
		"""Return True if image point (x, y, z) lies inside box ``n``.

		Negative coordinates are not checked. Boxes of hidden sets never match. The test
		is a sphere of radius boxsize/2 in 3-D mode. In 2-D mode any z other than the
		box's own z counts as far outside.
		"""
		box=self.boxes[n]
		if box[5] not in self.sets_visible:
			return False
		half=self.get_boxsize(box[5])/2
		dist2=0
		if x>=0:
			dist2+=(box[0]-x)**2
		if y>=0:
			dist2+=(box[1]-y)**2
		if z>=0:
			if self.options.mode=="3D":
				dist2+=(box[2]-z)**2
			elif box[2]!=z:
				# 2-D boxes live on a single slice; push other slices far outside
				dist2+=1e3*half**2
		return dist2<=half**2

	def do_deletion(self, delids):
		"""Remove the boxes whose indices are in ``delids`` and renumber the rest.

		Box records, box images and the shapes of all three views are compacted so the
		surviving boxes keep their relative order under new 0-based indices. The current
		box is cleared and all views are redrawn.
		"""
		delids=set(delids)		# O(1) membership instead of scanning a list per box
		kpids=[i for i in range(len(self.boxes)) if i not in delids]
		self.boxes=[self.boxes[i] for i in kpids]
		self.boxesimgs=[self.boxesimgs[i] for i in kpids]
		# shapes are keyed by box index, so re-key them to the new numbering
		for view in (self.xyview,self.xzview,self.zyview):
			view.shapes={i:view.shapes[k] for i,k in enumerate(kpids)}
		self.curbox=-1
		self.update_all()

	def del_box(self,n):
		"""Delete box ``n`` (out-of-range indices are ignored) and clear the box detail view."""
		if not 0<=n<len(self.boxes):
			return

		if self.boxviewer.get_data():
			self.boxviewer.set_data(None)
		self.curbox=-1
		self.do_deletion([n])




	def update_box(self,n,quiet=False):
		"""Redraw box ``n`` after it was created or adjusted, and make it current.

		(Re)creates its marker in the three views (circles in 3-D mode, rectangles in 2-D)
		and moves the depth slider to the box's z. Once initialized, the side views are
		refreshed. Unless ``quiet`` (used while dragging), the box's normalized
		centre-slice image is recomputed and, once initialized, the info JSON is saved.
		Once initialized, the particle montage is refreshed and the box selected there.
		An index past the end of ``self.boxes`` is ignored.
		"""
		try:
			box=self.boxes[n]
		except IndexError:
			return
		bs2=self.get_boxsize(box[5])//2

		# marker colour cycles through self.setcolors by set id
		color=self.setcolors[box[5]%len(self.setcolors)].getRgbF()
		if self.options.mode=="3D":
			self.xyview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[1],bs2,2]))
			self.xzview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[2],bs2,2]))
			self.zyview.add_shape(n,EMShape(("circle",color[0],color[1],color[2],box[2],box[1],bs2,2)))
		else:
			# 2-D: full box in xy, a thin bar one slice either side in the side views
			self.xyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[0]-bs2,box[1]-bs2,box[0]+bs2,box[1]+bs2,2]))
			self.xzview.add_shape(n,EMShape(["rect",color[0],color[1],color[2], 
				    box[0]-bs2,box[2]-1,box[0]+bs2,box[2]+1,2]))
			self.zyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[2]-1,box[1]-bs2,box[2]+1,box[1]+bs2,2]))

		if self.depth()!=box[2]:
			# changing the slider fires event_depth, which redraws the xy view
			self.wdepth.setValue(box[2])
		else:
			self.xyview.update()
		if self.initialized: self.update_sides()

		# For speed, we turn off updates while dragging a box around. Quiet is set until the mouse-up
		if not quiet:
			# Get the cube from the original data (normalized)
			proj=self.get_cube(box[0], box[1], box[2], centerslice=True, boxsz=self.get_boxsize(box[5]))
			proj.process_inplace("normalize")

			# grow the image list so that index n exists
			if len(self.boxesimgs)<=n:
				self.boxesimgs.extend([None]*(n+1-len(self.boxesimgs)))
			self.boxesimgs[n]=proj

			if self.initialized: self.SaveJson()
			
		if self.initialized:
			self.update_boximgs()

			if n!=self.curbox:
				self.boxesviewer.set_selected((n,),True)

		self.curbox=n
		self.update_coords()
		self.update_coords()

	def update_boximgs(self):
		"""Show the cached images of all boxes in visible sets in the box viewer.

		self.boxids is rebuilt so that self.boxids[k] is the index in
		self.boxes of the k-th image shown in the viewer.
		"""
		visible=self.sets_visible
		self.boxids=[idx for idx in range(len(self.boxesimgs)) if self.boxes[idx][5] in visible]
		shown=[self.boxesimgs[idx] for idx in self.boxids]
		self.boxesviewer.set_data(shown)
		self.boxesviewer.update()

	def img_selected(self,event,lc):
		"""Handle a click on an image in the box viewer.

		lc[0] is the position in the viewer; it is mapped to a box index via
		self.boxids. Shift-click deletes that box, a plain click selects it.
		If a box is selected afterwards, the three views scroll to it and its
		set becomes the current set.
		"""
		boxidx=self.boxids[lc[0]]
		if not event.modifiers()&Qt.ShiftModifier:
			self.update_box(boxidx)
		else:
			self.del_box(boxidx)
		if self.curbox<0:
			return
		sel=self.boxes[self.curbox]
		self.xyview.scroll_to(sel[0],sel[1])
		self.xzview.scroll_to(None,sel[2])
		self.zyview.scroll_to(sel[2],None)
		self.currentset=sel[5]
		panel=self.setspanel
		panel.initialized=False
		panel.update_sets()

	def del_region_xy(self, x=-1, y=-1, z=-1, rad=-1):
		"""Delete every visible box whose centre is closer than rad to (x,y,z).

		A negative coordinate drops that axis from the distance, so
		del_region_xy(x,y) erases through all slices. A negative rad means
		use self.eraser_width().
		"""
		if rad<0:
			rad=self.eraser_width()
		limit=rad**2
		center=(x,y,z)
		doomed=[]
		for idx,box in enumerate(self.boxes):
			if box[5] not in self.sets_visible:
				continue
			# (c>=0) zeroes the contribution of axes given as negative
			d2=sum((c>=0)*(box[axis]-c)**2 for axis,c in enumerate(center))
			if d2<limit:
				doomed.append(idx)
		self.do_deletion(doomed)

	def xy_down(self,event):
		"""Mouse press in the XY view.

		With the eraser enabled, erase boxes around the click (all slices).
		Otherwise a click inside an existing box grabs it for dragging
		(shift-click deletes it instead), and a plain click on empty space
		creates a new box at the current depth. Afterwards the side views
		scroll to the selected box, if any.
		"""
		px,py=self.xyview.scr_to_img((event.x(),event.y()))
		x,y=int(px),int(py)
		z=int(self.get_z())
		self.xydown=None
		# ignore clicks off the image on the low-coordinate sides
		if x<0 or y<0 : return
		if self.optionviewer.erasercheckbox.isChecked():
			self.del_region_xy(x,y)
			return

		hit=next((i for i in range(len(self.boxes)) if self.inside_box(i,x,y,z)), None)
		shift=event.modifiers()&Qt.ShiftModifier
		if hit is not None:
			if shift:
				self.del_box(hit)
				self.firsthbclick = None
			else:
				grabbed=self.boxes[hit]
				self.xydown=(hit,x,y,grabbed[0],grabbed[1])
				self.update_box(hit)
		elif not shift:
			self.boxes.append([x,y,self.depth(), 'manual', 0.0, self.currentset])
			# (box #, x at press, y at press, box x at press, box y at press)
			self.xydown=(len(self.boxes)-1,x,y,x,y)
			self.update_box(self.xydown[0])

		if self.curbox>=0:
			sel=self.boxes[self.curbox]
			self.xzview.scroll_to(None,sel[2])
			self.zyview.scroll_to(sel[2],None)

	def xy_drag(self,event):
		"""Mouse drag in the XY view.

		With the eraser enabled, erase boxes under the cursor and draw the
		eraser circle. Otherwise move the grabbed box (if any) by the mouse
		offset since the press, using a quiet update.
		"""
		mx,my=self.xyview.scr_to_img((event.x(),event.y()))
		mx,my=int(mx),int(my)
		if self.optionviewer.erasercheckbox.isChecked():
			self.del_region_xy(mx,my)
			view=self.xyview
			view.eraser_shape=EMShape(["circle",1,1,1,mx,my,self.eraser_width(),2])
			view.shapechange=1
			view.update()
			return

		if self.xydown is None : return
		idx,x0,y0,bx0,by0=self.xydown
		moved=self.boxes[idx]
		moved[0]=mx-x0+bx0
		moved[1]=my-y0+by0
		self.update_box(self.curbox,True)

	def xy_up(self,event):
		"""Mouse release in the XY view: finish any drag with a full box update."""
		if self.xydown is not None:
			self.update_box(self.curbox)
		self.xydown=None

	def xy_wheel(self,event):
		"""Mouse wheel in the XY view: move the displayed slice by one per event."""
		d=event.delta()
		step=(d>0)-(d<0)
		if step:
			self.wdepth.setValue(self.wdepth.value()+step)


	def xy_scale(self,news):
		"""The XY view was rescaled: mirror the new scale in the scale widget."""
		scale_widget=self.wscale
		scale_widget.setValue(news)

	def xy_origin(self,newor):
		"""The XY view was panned: sync the XZ view's x and the ZY view's y origin."""
		xz=self.xzview
		xz.set_origin(newor[0],xz.get_origin()[1],True)
		zy=self.zyview
		zy.set_origin(zy.get_origin()[0],newor[1],True)
	
	def xy_move(self,event):
		"""Mouse motion in the XY view: draw the eraser circle while the eraser is enabled."""
		view=self.xyview
		if not self.optionviewer.erasercheckbox.isChecked():
			view.eraser_shape=None
			return
		ex,ey=view.scr_to_img((event.x(),event.y()))
		view.eraser_shape=EMShape(["circle",1,1,1,ex,ey,self.eraser_width(),2])
		view.shapechange=1
		view.update()

	def xz_down(self,event):
		"""Mouse press in the XZ view.

		A click inside an existing box grabs it for dragging (shift-click
		deletes it); a plain click on empty space creates a new box at
		y=self.cury. When self.wlocalbox is unchecked a box is hit if it
		contains (x, get_y(), z) or (x, self.cury, z); when checked, only the
		latter is tested. Clicks are ignored while the eraser is enabled.
		"""
		px,pz=self.xzview.scr_to_img((event.x(),event.y()))
		x,z=int(px),int(pz)
		y=int(self.get_y())
		self.xzdown=None
		# ignore clicks off the image on the low-coordinate sides
		if x<0 or z<0 : return
		if self.optionviewer.erasercheckbox.isChecked():
			return

		def _hit(i):
			if not self.wlocalbox.isChecked() and self.inside_box(i,x,y,z):
				return True
			return self.inside_box(i,x,self.cury,z)

		hit=next((i for i in range(len(self.boxes)) if _hit(i)), None)
		shift=event.modifiers()&Qt.ShiftModifier
		if hit is not None:
			if shift:
				self.del_box(hit)
				self.firsthbclick = None
			else:
				grabbed=self.boxes[hit]
				self.xzdown=(hit,x,z,grabbed[0],grabbed[2])
				self.update_box(hit)
		elif not shift:
			self.boxes.append([x,self.cury,z, 'manual', 0.0, self.currentset])
			# (box #, x at press, z at press, box x at press, box z at press)
			self.xzdown=(len(self.boxes)-1,x,z,x,z)
			self.update_box(self.xzdown[0])

		if self.curbox>=0 :
			sel=self.boxes[self.curbox]
			self.xyview.scroll_to(None,sel[1])
			self.zyview.scroll_to(sel[2],None)

	def xz_drag(self,event):
		"""Mouse drag in the XZ view: move the grabbed box by the offset in x and z."""
		if self.xzdown is None : return
		px,pz=self.xzview.scr_to_img((event.x(),event.y()))
		idx,x0,z0,bx0,bz0=self.xzdown
		moved=self.boxes[idx]
		moved[0]=int(px)-x0+bx0
		moved[2]=int(pz)-z0+bz0
		self.update_box(self.curbox,True)

	def xz_up(self,event):
		"""Mouse release in the XZ view: finish any drag with a full box update."""
		if self.xzdown is not None:
			self.update_box(self.curbox)
		self.xzdown=None

	def xz_scale(self,news):
		"""The XZ view was rescaled: mirror the new scale in the scale widget."""
		scale_widget=self.wscale
		scale_widget.setValue(news)

	def xz_origin(self,newor):
		"""The XZ view was panned: keep the XY view's x origin in sync."""
		xy=self.xyview
		xy.set_origin(newor[0],xy.get_origin()[1],True)


	def zy_down(self,event):
		"""Mouse press in the ZY view.

		A click inside an existing box grabs it for dragging (shift-click
		deletes it); a plain click on empty space creates a new box at
		x=self.curx. When self.wlocalbox is unchecked a box is hit if it
		contains (get_x(), y, z) or (self.curx, y, z); when checked, only the
		latter is tested. Like xz_down, clicks are ignored while the eraser
		is enabled.
		"""
		z,y=self.zyview.scr_to_img((event.x(),event.y()))
		z,y=int(z),int(y)
		x=int(self.get_x())
		# Reset this view's own drag state. (This used to clear self.xydown,
		# leaving a stale zydown that zy_drag would keep applying to a box.)
		self.zydown=None
		if z<0 or y<0 : return		# no clicking outside the image (on 2 sides)
		# the eraser does not pick or create boxes in the side views
		if self.optionviewer.erasercheckbox.isChecked():
			return

		for i in range(len(self.boxes)):
			if (not self.wlocalbox.isChecked() and self.inside_box(i,x,y,z)) or self.inside_box(i,self.curx,y,z):
				if event.modifiers()&Qt.ShiftModifier:
					self.del_box(i)
					self.firsthbclick = None
				else :
					self.zydown=(i,z,y,self.boxes[i][2],self.boxes[i][1])
					self.update_box(i)
				break
		else:
			if not event.modifiers()&Qt.ShiftModifier:
				self.boxes.append([self.curx,y,z, 'manual', 0.0, self.currentset])
				self.zydown=(len(self.boxes)-1,z,y,z,y)		# box #, z down, y down, z box at down, y box at down
				self.update_box(self.zydown[0])

		if self.curbox>=0 :
			box=self.boxes[self.curbox]
			self.xyview.scroll_to(box[0],None)
			self.xzview.scroll_to(None,box[2])

	def zy_drag(self,event):
		"""Mouse drag in the ZY view: move the grabbed box by the offset in z and y."""
		if self.zydown is None : return
		pz,py=self.zyview.scr_to_img((event.x(),event.y()))
		idx,z0,y0,bz0,by0=self.zydown
		moved=self.boxes[idx]
		moved[2]=int(pz)-z0+bz0
		moved[1]=int(py)-y0+by0
		self.update_box(self.curbox,True)

	def zy_up(self,event):
		"""Mouse release in the ZY view: finish any drag with a full box update."""
		if self.zydown is not None: self.update_box(self.curbox)
		self.zydown=None

	def zy_scale(self,news):
		"""The ZY view was rescaled: mirror the new scale in the scale widget."""
		scale_widget=self.wscale
		scale_widget.setValue(news)

	def zy_origin(self,newor):
		"""The ZY view was panned: keep the XY view's y origin in sync."""
		xy=self.xyview
		xy.set_origin(xy.get_origin()[0],newor[1],True)
	
	
	def set_current_set(self, name):
		"""Make the set identified by name (via parse_setname) the current set.

		The box-size widget is updated to that set's box size and all views
		are refreshed.
		"""
		self.currentset=parse_setname(name)
		self.wboxsize.setValue(self.get_boxsize())
		self.update_all()
	
	
	def hide_set(self, name):
		"""Remove the set identified by name from the visible sets and redraw."""
		key=parse_setname(name)
		self.sets_visible.pop(key, None)
		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
	
	
	def show_set(self, name):
		"""Mark the set identified by name as visible and redraw.

		The current set is not changed, but the box-size widget is refreshed.
		"""
		key=parse_setname(name)
		self.sets_visible[key]=0
		self.wboxsize.setValue(self.get_boxsize())
		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
	
	
	def delete_set(self, name):
		"""Delete the set identified by name together with all of its boxes."""
		key=parse_setname(name)
		members=[idx for idx,box in enumerate(self.boxes) if box[5]==int(key)]
		self.do_deletion(members)

		# drop the set from every per-set table
		for table in (self.sets_visible, self.sets, self.boxsize):
			table.pop(key, None)

		self.curbox=-1
		self.update_all()
	
	def rename_set(self, oldname,  newname):
		"""Give the existing set identified by oldname the display name newname."""
		key=parse_setname(oldname)
		if key not in self.sets:
			return
		self.sets[key]=newname
	
	
	def new_set(self, name):
		"""Create a visible set called name under the smallest unused integer id.

		The default box size is 32 in 3D mode and 64 otherwise.
		"""
		# among 0..len(sets) at least one id is always free
		newid=next(i for i in range(len(self.sets)+1) if i not in self.sets)
		self.sets[newid]=name
		self.sets_visible[newid]=0
		self.boxsize[newid]=32 if self.options.mode=="3D" else 64
	
	def save_set(self):
		"""Save the boxes of every currently visible set through save_boxes()."""
		visible_ids=list(self.sets_visible)
		self.save_boxes(visible_ids)
	
	
	def key_press(self,event):
		"""Keyboard handler: key code 96 ('`') moves one slice up, 49 ('1') one
		slice down; any other key is re-emitted on self.keypress."""
		step={96:1, 49:-1}.get(event.key())
		if step is None:
			self.keypress.emit(event)
			return
		self.wdepth.setValue(self.wdepth.value()+step)

	def SaveJson(self):
		"""Write boxes and set definitions to the JSON info file self.jsonfile.

		If the info file contains "apix_unbin", box coordinates are shifted so
		the tomogram centre (nx//2, ny//2, nz//2) is the origin, then scaled
		by self.apix_cur/self.apix_unbin. Box sizes are scaled by the same
		factor. Otherwise boxes and sizes are written unchanged. Each set is
		stored under "class_list" with its name and integer box size.
		"""
		info=js_open_dict(self.jsonfile)
		center=(self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2)
		if "apix_unbin" in info:
			# keep the (c-o)*apix_cur/apix_unbin evaluation order for identical floats
			bxs=[[(b0[ax]-center[ax])*self.apix_cur/self.apix_unbin for ax in range(3)]+[b0[3], b0[4], b0[5]]
				for b0 in self.boxes]
			bxsz={k:self.boxsize[k]*self.apix_cur/self.apix_unbin for k in self.boxsize}
		else:
			bxs=self.boxes
			bxsz=self.boxsize

		info["boxes_3d"]=bxs
		info["class_list"]={int(key):{"name":self.sets[key], "boxsize":int(bxsz[key])} for key in self.sets}
		info.close()
	
	def closeEvent(self,event):
		"""Qt close handler: save boxes to JSON, close all child viewers, then
		emit module_closed (needed when a program runs its own event loop)."""
		print("Exiting")
		self.SaveJson()
		for attr in ("boxviewer", "boxesviewer", "optionviewer", "xyview", "xzview", "zyview"):
			getattr(self, attr).close()
		self.module_closed.emit()
Пример #22
0
class EMTomobox(QtWidgets.QMainWindow):
	"""Main window for picking particles in tomograms with a trainable NNet.

	The user marks good (label 1) and bad (label 0) reference positions in a
	tomogram. These are kept in self.references and written to
	info/boxrefs3d.hdf on close. They are used to train a neural network,
	which is then applied to tomograms to find particle positions. Particles
	are written to the "boxes_3d" list of each tomogram's info JSON under the
	class named options.label.
	"""

	def __init__(self,application,options,datafile=None):
		# NOTE(review): datafile is accepted but not used; self.datafile starts empty.
		QtWidgets.QWidget.__init__(self)
		self.setMinimumSize(700,200)
		
		#### load references first
		self.reffile="info/boxrefs3d.hdf"
		if os.path.isfile(self.reffile):
			self.references=EMData.read_images(self.reffile)
		else:
			self.references=[]
			
		self.path="tomograms/"
		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())
		
		self.imglst=QtWidgets.QTableWidget(1, 5, self)
		#self.imglst.verticalHeader().hide()
		for i,w in enumerate([50,200,70,70,70]):
			self.imglst.setColumnWidth(i,w)
		self.imglst.setMinimumSize(450, 100)
		self.gbl.addWidget(self.imglst, 0,4,10,10)
		
		self.bt_new=QtWidgets.QPushButton("New")
		self.bt_new.setToolTip("Build new neural network")
		self.gbl.addWidget(self.bt_new, 0,0,1,2)
		
		self.bt_load=QtWidgets.QPushButton("Load")
		self.bt_load.setToolTip("Load neural network")
		self.gbl.addWidget(self.bt_load, 1,0,1,2)
		
		self.bt_train=QtWidgets.QPushButton("Train")
		self.bt_train.setToolTip("Train neural network")
		self.gbl.addWidget(self.bt_train, 2,0,1,2)
		
		self.bt_save=QtWidgets.QPushButton("Save")
		self.bt_save.setToolTip("Save neural network")
		self.gbl.addWidget(self.bt_save, 3,0,1,2)
		
		self.bt_apply=QtWidgets.QPushButton("Apply")
		self.bt_apply.setToolTip("Apply neural network")
		self.gbl.addWidget(self.bt_apply, 4,0,1,2)
		
		self.bt_chgbx=QtWidgets.QPushButton("ChangeBx")
		self.bt_chgbx.setToolTip("Change box size")
		self.gbl.addWidget(self.bt_chgbx, 5,0,1,2)
		
		self.bt_applyall=QtWidgets.QPushButton("ApplyAll")
		self.bt_applyall.setToolTip("Apply to all tomograms")
		self.gbl.addWidget(self.bt_applyall, 6,0,1,2)
		
		self.box_display = QtWidgets.QComboBox()
		self.box_display.addItem("References")
		self.box_display.addItem("Particles")
		self.gbl.addWidget(self.box_display, 0,2,1,1)
		
		
		self.bt_new.clicked[bool].connect(self.new_nnet)
		self.bt_load.clicked[bool].connect(self.load_nnet)
		self.bt_train.clicked[bool].connect(self.train_nnet)
		self.bt_save.clicked[bool].connect(self.save_nnet)
		self.bt_apply.clicked[bool].connect(self.apply_nnet)
		self.bt_chgbx.clicked[bool].connect(self.change_boxsize)
		self.bt_applyall.clicked[bool].connect(self.apply_nnet_all)
		self.box_display.currentIndexChanged.connect(self.do_update)

		self.val_targetsize=TextBox("TargetSize", 1)
		self.gbl.addWidget(self.val_targetsize, 1,2,1,1)
		
		self.val_learnrate=TextBox("LearnRate", 1e-4)
		self.gbl.addWidget(self.val_learnrate, 2,2,1,1)
		
		self.val_ptclthr=TextBox("PtclThresh", 0.8)
		self.gbl.addWidget(self.val_ptclthr, 3,2,1,1)
		
		self.val_circlesize=TextBox("CircleSize", 24)
		self.gbl.addWidget(self.val_circlesize, 4,2,1,1)
		
		self.val_niter=TextBox("Niter", 20)
		self.gbl.addWidget(self.val_niter, 5,2,1,1)
		
		self.val_posmult=TextBox("PosMult", 1)
		self.gbl.addWidget(self.val_posmult, 6,2,1,1)
		
		self.val_lossfun = QtWidgets.QComboBox()
		self.val_lossfun.addItem("Sum")
		self.val_lossfun.addItem("Max")
		self.gbl.addWidget(self.val_lossfun, 7,2,1,1)
		
		self.options=options
		self.app=weakref.ref(application)
		
		self.nnet=None
		global tf
		tf=import_tensorflow(options.gpuid)
		
		
		# the network always sees nnetsize x nnetsize x thick inputs;
		# boxsize is the box size in tomogram pixels, rescaled to nnetsize
		self.nnetsize=96
		if len(self.references)==0:
			self.boxsize=self.nnetsize
		else:
			self.boxsize=self.references[0]["boxsz"]
		self.thick=9
		self.datafile=""
		self.data=None
		
		if not os.path.isdir("neuralnets"):
			os.mkdir("neuralnets")
		
		self.imgview = EMImage2DWidget()
		self.boxesviewer=[EMImageMXWidget(), EMImageMXWidget()]
		self.boxesviewer[0].setWindowTitle("Negative")
		self.boxesviewer[1].setWindowTitle("Positive")
		self.ptclviewer=EMImageMXWidget()
		self.ptclviewer.setWindowTitle("Particles")
		
		for boxview in (self.boxesviewer+[self.ptclviewer]):
			boxview.usetexture=False
			boxview.show()
			boxview.set_mouse_mode("App")
			boxview.rzonce=True
		
		# self.boximages is kept parallel to self.references (same order)
		self.boximages=[]
		self.ptclimages=[]
		for img in self.references:
			self.add_boximage(img)
		
		self.imgview.mouseup.connect(self.on_tomo_mouseup)
		self.imgview.keypress.connect(self.key_press)
		self.boxesviewer[0].mx_image_selected.connect(self.on_boxpos_selected)
		self.boxesviewer[1].mx_image_selected.connect(self.on_boxneg_selected)
		self.ptclviewer.mx_image_selected.connect(self.on_ptcl_selected)
		self.imglst.cellClicked[int, int].connect(self.on_list_selected)
		
		self.boxshapes=BoxShapes(img=self.imgview, points=[] )
		self.imgview.shapes = {0:self.boxshapes}

		glEnable(GL_POINT_SMOOTH)
		glEnable(GL_LINE_SMOOTH)
		glEnable(GL_POLYGON_SMOOTH)
		glEnable(GL_BLEND)
		glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
		self.trainset=[]
		self.tomoinp=[]
		self.segout=None
		self.curinfo=None
		
		self.update_list()
		self.do_update()
			
	
	def update_list(self):
		"""Rebuild self.imginfo and the tomogram table from self.path.

		Only tomograms with an info JSON file are listed. For each one the
		table shows the number of particles in the options.label class and
		the number of good/bad references picked from it.
		"""
		files=natural_sort([os.path.join(self.path,f) for f in os.listdir(self.path)])
		self.imginfo=[]
		for i,name in enumerate(files):
			basename=base_name(name)
			jsfile=info_name(basename)
			if not os.path.isfile(jsfile):
				continue
			info={"id":len(self.imginfo), "name":name, "basename":basename, "nptcl":0, "gref":0, "bref":0, "clsid":-1}
			js=js_open_dict(jsfile)
			clsid=-1
			if ("class_list" in js) and ("boxes_3d" in js):
				cls=js["class_list"]
				
				for k in list(cls.keys()):
					vname=str(cls[k]["name"])
					if vname==self.options.label:
						clsid=int(k)
						break
					
				ptcls=[p for p in js["boxes_3d"] if p[5]==clsid]
				info["nptcl"]=len(ptcls)
			
			refs=[r for r in self.references if base_name(r["fromtomo"])==basename]
			info["gref"]=len([r for r in refs if r["label"]==1])
			info["bref"]=len([r for r in refs if r["label"]==0])
			info["clsid"]=clsid
				
			self.imginfo.append(info)
			js.close()
			
		self.imglst.clear()
		self.imglst.setRowCount(len(self.imginfo))
		self.imglst.setColumnCount(5)
		self.imglst.setHorizontalHeaderLabels(["ID", "FileName", "Ptcls", "GoodRef", "BadRef"])
		self.imglst.setColumnHidden(0, True)
		for i,info in enumerate(self.imginfo):
			#### use Qt.EditRole so we can sort them as numbers instead of strings
			it=QtWidgets.QTableWidgetItem()
			it.setData(Qt.EditRole, int(info["id"]))
			self.imglst.setItem(i,0,it)
			
			self.imglst.setItem(i,1,QtWidgets.QTableWidgetItem(str(info["basename"])))
			
			it=QtWidgets.QTableWidgetItem()
			it.setData(Qt.EditRole, info["nptcl"])
			self.imglst.setItem(i,2, it)
			
			it=QtWidgets.QTableWidgetItem()
			it.setData(Qt.EditRole, info["gref"])
			self.imglst.setItem(i,3, it)
			
			it=QtWidgets.QTableWidgetItem()
			it.setData(Qt.EditRole, info["bref"])
			self.imglst.setItem(i,4, it)
			
		self.imglst.setVerticalHeaderLabels([str(i) for i in range(len(self.imginfo))])
	
	
	def on_list_selected(self, row, col):
		"""Table click: save the current tomogram's particles, then load the clicked one."""
		if self.curinfo:
			self.save_points()
			self.update_list()
		
		# column 0 (hidden) holds the index into self.imginfo
		idx=self.imglst.item(row, 0).text()
		self.curinfo=self.imginfo[int(idx)]
		self.set_data(self.curinfo["name"])
		
		
	def set_data(self, datafile):
		"""Load tomogram datafile (scaled by options.mult) and its saved particles.

		Particle coordinates in the info file are converted back to pixel
		coordinates of this tomogram: divided by apix_x/apix_unbin, then
		offset by the tomogram centre. Does nothing if datafile is already loaded.
		"""
		if self.datafile==datafile:
			return
		print("Reading {}...".format(datafile))
		self.datafile=datafile
		self.data=EMData(datafile)
		self.data.mult(self.options.mult)
		self.imgview.setWindowTitle(base_name(datafile))
		self.imgview.list_idx=self.data["nz"]//2
		self.imgview.set_data(self.data)
		self.imgview.show()
		self.infofile=info_name(datafile)
		
		js=js_open_dict(self.infofile)
		apix_cur=self.data["apix_x"]
		apix_unbin=js["apix_unbin"]
		self.apix_scale=apix_cur/apix_unbin
		self.tomocenter= np.array([self.data["nx"],self.data["ny"],self.data["nz"]])/2
		self.ptclimages=[]
		if "boxes_3d" in js:
			ptcls=np.array([p[:3] for p in js["boxes_3d"] if p[5]==self.curinfo["clsid"]])
			if len(ptcls)>0:
				ptcls= ptcls  / self.apix_scale + self.tomocenter
				for p in ptcls.tolist():
					self.add_ptcls(p[0], p[1], p[2])
		
		js.close()
		self.tomoinp=[]
		self.segout=None
		self.do_update()
	
	def change_boxsize(self):
		"""Ask for a new box size and re-extract every reference at that size.

		The size may not be smaller than the network input (self.nnetsize).
		Afterwards the currently selected tomogram is reloaded into self.data,
		and cached training/segmentation data is invalidated because it
		depends on the box size.
		"""
		size,ok=QtWidgets.QInputDialog.getText( self, "Box size", "Enter a new box size:")
		if not ok : return
		try:
			size=int(size)
		except ValueError:
			print("Invalid box size: {}".format(size))
			return
		if size<self.nnetsize:
			print("Cannot set box size ({}) smaller than Network input size ({}). Stop.".format(size, self.nnetsize))
			return
		self.boxsize=size
		print("Updating references...")
		oldref=[r for r in self.references]
		self.references=[]
		self.boximages=[]
		curdatafile=self.datafile
		
		tomolist={r["fromtomo"] for r in oldref}
		print("  {} references from {} tomograms...".format(len(oldref), len(tomolist)))
		for tomo in tomolist:
			imgs=[r for r in oldref if r["fromtomo"]==tomo]
			# add_reference extracts from self.data and tags with self.datafile
			self.datafile=tomo
			self.data=EMData(tomo)
			self.data.mult(self.options.mult)
			for m in imgs:
				p=m["pos"]
				self.add_reference(m["label"], p[0],p[1],p[2])
				
		oldref=[]
		# The loop above left self.data pointing at the last reference tomogram;
		# restore the tomogram that is actually selected.
		self.datafile=curdatafile
		if len(self.datafile)>0:
			self.data=EMData(self.datafile)
			self.data.mult(self.options.mult)
			self.curdata=EMData(self.datafile)
		else:
			self.data=None
			self.curdata=None
		self.do_update()
		self.trainset= []
		self.tomoinp=[]
		# the segmentation was computed at the old box-size scaling
		self.segout=None
		
	def new_nnet(self):
		"""Create a fresh, untrained network."""
		print("New network..")
		self.nnet=NNet(boxsize=self.nnetsize, thick=self.thick)
		
	def train_nnet(self):
		"""Train the network on the references, then write sample outputs.

		The training set is cached in self.trainset until references change.
		Each class is replicated so that the set holds roughly 500 images
		per class. After training, up to 100 random inputs and their network
		outputs are written to neuralnets/trainouts.hdf.
		"""
		if self.nnet is None:
			self.new_nnet()
			
		if len(self.trainset)==0:
			if len(self.references)==0:
				print("No references.")
				return
			
			print("Preparing training set...")
			labels=[b["label"] for b in self.references]
			ngood=np.mean(labels)
			# ncopy below divides by both ngood and (1-ngood)
			if ngood<=0 or ngood>=1:
				print("Need both good and bad references to train.")
				return
			nc=500/len(labels)
			ncopy=[int(round(nc/(1-ngood)+1)), int(round(nc/ngood+1))]
			imgs=[]
			labs=[]
			for i,p in enumerate(self.references):
				lb=labels[i]
				m=get_image(p, thick=self.thick, ncopy=ncopy[lb])
				imgs.extend(m)
				labs.extend([lb]*ncopy[lb])
			imgs=np.array(imgs, dtype=np.float32)
			labs=np.array(labs, dtype=np.float32)
			self.trainset=(imgs, labs)
			
		dataset = tf.data.Dataset.from_tensor_slices(self.trainset)
		dataset=dataset.shuffle(500).batch(32)
		usemax=(self.val_lossfun.currentText()=="Max")
		self.nnet.do_training(
			dataset, 
			learnrate=self.val_learnrate.getval(), 
			niter=int(self.val_niter.getval()),
			tarsz=self.val_targetsize.getval(),
			usemax=usemax, 
			posmult=self.val_posmult.getval()
			)
		
		self.segout=None
		print("Generating output...")
		
		imgs,labs=self.trainset
		idx=np.arange(len(imgs))
		np.random.shuffle(idx)
		idx=idx[:100]
		out=self.nnet.apply_network(imgs[idx])[:,:,:,0]
		
		outval=self.nnet.predict_class(np.array(imgs[idx], dtype=np.float32), usemax=usemax)
		outval=np.array(outval)[:,:,0]
		
		fname="neuralnets/trainouts.hdf"
		if os.path.isfile(fname):
			os.remove(fname)
			
		bx=self.nnetsize
		sz=out.shape[1]
		for i,o in enumerate(out):
			# input image (averaged over the last axis) followed by the network output
			m=np.mean(imgs[idx[i]], axis=-1)/3.0
			m=from_numpy(m)
			m["score"]=[int(labs[idx[i]]),int(labs[idx[i]])]
			m.write_image(fname, -1)
			m=from_numpy(o)
			m=m.get_clip(Region(sz//2-bx//2, sz//2-bx//2, bx, bx))
			m.scale(4)
			m["score"]=[float(outval[0,i]),float(outval[1,i])]
			m.write_image(fname, -1)
		print("Output written to {}...".format(fname))
		
	
	def apply_nnet_all(self):
		"""Apply the network to every listed tomogram.

		Holding Shift skips tomograms that already have particles.
		"""
		if self.nnet is None:
			print("Neural network not initialized...")
			return
		
		modifiers = QtWidgets.QApplication.keyboardModifiers()
		skipexist=(modifiers == QtCore.Qt.ShiftModifier)
		if skipexist:
			print("Skipping tomograms with particles")
			
		for i,info in enumerate(self.imginfo):
			self.curinfo=info
			if skipexist and info["nptcl"]>0:
				print("Skipping {}..".format(info["name"]))
				continue
			self.set_data(info["name"])
			self.apply_nnet()
		
	def apply_nnet(self):
		"""Segment the current tomogram with the network and pick particles.

		The tomogram is rescaled by nnetsize/boxsize so boxes match the
		network input. Every 4th slice is fed to the network. The output is
		low-pass filtered and its peaks above PtclThresh become particles,
		strongest first. Peaks closer than CircleSize to an already accepted
		one are dropped. The particles are saved to the info file.
		"""
		bx=self.nnetsize//8
		thk=self.thick//8
		if self.data is None:
			print("No data loaded")
			return
		if self.nnet is None:
			print("Neural network not initialized...")
			return
		
		ts=[self.data["nx"], self.data["ny"], self.data["nz"]]
		if self.nnetsize==self.boxsize:
			data=self.data
		else:
			data=self.data.copy()
			scale=self.nnetsize/self.boxsize
			data.scale(scale)
			ts2=[int(t*scale) for t in ts]
			data=data.get_clip(Region(ts[0]//2-ts2[0]//2, ts[1]//2-ts2[1]//2, ts[2]//2-ts2[2]//2,  ts2[0], ts2[1], ts2[2]))
			ts=ts2
			
		if len(self.tomoinp)==0:
			print("Preparing input...")
			
			m=[]
			for i in range(0, ts[2],4):
				m.extend(get_image(data, pos=[ts[0]//2, ts[1]//2, i], thick=self.thick, ncopy=1))
				
			self.tomoinp=np.array(m)

		if self.segout is None:
			print("Applying...")
			out=self.nnet.apply_network(self.tomoinp)[:,:,:,0]
			o=out.transpose(0,2,1).copy()
			o=from_numpy(o)
			o=o.get_clip(Region(o["nx"]//2-ts[0]//8, o["ny"]//2-ts[1]//8, o["nz"]//2-ts[2]//8, ts[0]//4, ts[1]//4, ts[2]//4))
			o.process_inplace("mask.zeroedge3d",{"x0":bx,"x1":bx,"y0":bx,"y1":bx,"z0":thk,"z1":thk})
			o.process_inplace("filter.lowpass.gauss",{"cutoff_abs":.5})
			self.segout=o.copy()
			o.write_image("neuralnets/segout.hdf")
		
		print("Finding boxes...")
		o=self.segout.copy()
		self.ptclimages=[]
		o.process_inplace("mask.onlypeaks",{"npeaks":5})
		img=o.numpy().copy()
		img[img<self.val_ptclthr.getval()]=0
		
		pnew=np.array(np.where(img>self.val_ptclthr.getval())).T
		val=img[pnew[:,0], pnew[:,1], pnew[:,2]]
		pnew=pnew[np.argsort(-val)]

		# numpy order is (z,y,x); segmentation is 4x downsampled, then undo the box-size scaling
		pnew=pnew[:, ::-1]*4
		pnew=pnew*self.boxsize/self.nnetsize

		dthr=self.val_circlesize.getval()
		
		if len(pnew)==0:
			pts=[]
		else:
			# greedy non-maximum suppression: points are sorted strongest first
			tree=KDTree(pnew)
			tokeep=np.ones(len(pnew), dtype=bool)
			for i in range(len(pnew)):
				if tokeep[i]:
					k=tree.query_ball_point(pnew[i], dthr)
					tokeep[k]=False
					tokeep[i]=True
			pts=pnew[tokeep].tolist()

		for p in pts:
			self.add_ptcls(p[0], p[1], p[2])
			
		print("Found {} particles...".format(len(pts)))
		self.do_update()
		self.save_points()
		self.update_list()
		
	def load_nnet(self):
		"""Load the network from neuralnets/nnet_save.hdf."""
		self.nnet=NNet.load_network("neuralnets/nnet_save.hdf", boxsize=self.nnetsize, thick=self.thick)
		
	def save_nnet(self):
		"""Save the network to neuralnets/nnet_save.hdf."""
		self.nnet.save_network("neuralnets/nnet_save.hdf")
		
	def key_press(self, event):
		"""Key code 96 ('`') shows the next slice, 49 ('1') the previous one."""
		if event.key()==96:
			self.imgview.increment_list_data(1)
		elif event.key()==49:	
			self.imgview.increment_list_data(-1)
		self.imgview.shapechange=1
		self.imgview.updateGL()
		return
	
	def on_tomo_mouseup(self, event):
		"""Left click in the tomogram view.

		Shift-click deletes the nearest displayed point within CircleSize:
		a particle in "Particles" mode, otherwise a reference of the current
		tomogram. A plain click adds a good reference; Ctrl-click adds a bad one.
		"""
		x,y=self.imgview.scr_to_img((event.x(),event.y()))		
		x,y =np.round(x), np.round(y)
		
		if not event.button()&Qt.LeftButton:
			return
		
		if  event.modifiers()&Qt.ShiftModifier:
			### delete point
			z=self.imgview.list_idx
			pos=np.array([x,y,z])
			mode=self.box_display.currentText()
			if mode=="Particles":
				ids=list(range(len(self.ptclimages)))
				pts=np.array([b["pos"] for b in self.ptclimages])
			else:
				# only references from the displayed tomogram are drawn, so
				# only those may be deleted; ids maps back into self.references
				ids=[i for i,r in enumerate(self.references) if r["fromtomo"]==self.datafile]
				pts=np.array([self.references[i]["pos"] for i in ids])
			if len(pts)==0:
				return
			dst=np.linalg.norm(pts[:,:3]-pos, axis=1)
			
			if np.sum(dst<self.val_circlesize.getval())>0:
				idx=ids[int(np.argsort(dst)[0])]
				if mode=="Particles":
					self.ptclimages=[p for i,p in enumerate(self.ptclimages) if i!=idx]
					self.save_points()
				else:
					# boximages is kept parallel to references
					self.references=[b for ib,b in enumerate(self.references) if ib!=idx]
					self.boximages=[b for ib,b in enumerate(self.boximages) if ib!=idx]
					self.trainset=[]
					
			self.do_update()
			
		else:
			#### hold control to add negative references
			label=int(event.modifiers()&Qt.ControlModifier)
			label=int(label==0)
			self.add_reference(label, x, y)
			self.do_update()
		
	def on_boxpos_selected(self,event,lc):
		"""Click in the first reference viewer (titled "Negative")."""
		self.on_box_selected(event, lc, ib=0)
		
	def on_boxneg_selected(self,event,lc):
		"""Click in the second reference viewer (titled "Positive")."""
		self.on_box_selected(event, lc, ib=1)
		
	def on_box_selected(self,event, lc, ib):
		"""Shift-click removes the clicked reference from viewer ib;
		Shift+Ctrl-click also removes every reference after it in that viewer."""
		if event.modifiers()&Qt.ShiftModifier:
			if event.modifiers()&Qt.ControlModifier:
				torm=len(self.boxesviewer[ib].data)-lc[0]
			else:
				torm=1
			
			for i in range(torm):
				img=self.boxesviewer[ib].data[lc[0]+i]
				pos=np.array(img["pos"])
				self.rm_reference(pos)
				self.boximages=[b for b in self.boximages if b!=img]
			self.do_update()
			
	def on_ptcl_selected(self,event,lc):
		"""Click in the particle viewer.

		Shift-click deletes the particle (Shift+Ctrl also deletes all after
		it). A plain click adds it as a good reference; Ctrl-click adds it
		as a bad reference.
		"""
		if  event.modifiers()&Qt.ShiftModifier:
			## delete ptcl
			if event.modifiers()&Qt.ControlModifier:
				self.ptclimages=[p for i,p in enumerate(self.ptclimages) if i<lc[0]]
			else:
				self.ptclimages=[p for i,p in enumerate(self.ptclimages) if i!=lc[0]]
			self.do_update()
			
		else:
			## add to good/bad refs
			if event.modifiers()&Qt.ControlModifier:
				ic=0
			else:
				ic=1
			
			img=self.ptclimages[lc[0]]
			pos=img["pos"]
			self.add_reference(ic, pos[0], pos[1], pos[2])
			self.do_update()
			
			
	def rm_reference(self, pos):
		"""Remove the first reference whose position is within 1 pixel of pos."""
		pts=np.array([b["pos"] for b in self.references])
		dst=np.linalg.norm(pts-pos, axis=1)
		if np.sum(dst<1)>0:
			idx=int(np.where(dst<1)[0][0])
			self.references=[b for i,b in enumerate(self.references) if i!=idx]
			self.trainset=[]
	
	def add_reference(self, label, x, y, z=None):
		"""Extract a reference at (x,y,z) of the current tomogram (z defaults to the shown slice).

		The extracted volume is nnetsize x nnetsize x thick; if boxsize
		differs from nnetsize, a boxsize cube is extracted and rescaled first.
		"""
		if z is None:
			z=self.imgview.list_idx
		
		if self.nnetsize==self.boxsize:
			b=self.boxsize
			t=self.thick
			img=self.data.get_clip(Region(x-b//2, y-b//2, z-t//2, b, b, t))
			
		else:
			b=self.boxsize
			img=self.data.get_clip(Region(x-b//2, y-b//2, z-b//2, b, b, b))
			img.scale(self.nnetsize/self.boxsize)
			b2=self.nnetsize
			t=self.thick
			img=img.get_clip(Region(b//2-b2//2, b//2-b2//2, b//2-t//2,  b2, b2, t))
			
		img["pos"]=[x,y,z]
		img["fromtomo"]=self.datafile
		img["label"]=label
		img["boxsz"]=self.boxsize
		self.references.append(img)
		self.add_boximage(img)
		self.trainset=[]
		
	def add_boximage(self, img):
		"""Append the central z slice of reference img to self.boximages."""
		img2d=img.get_clip(Region(0, 0, img["nz"]//2, img["nx"], img["ny"], 1))
		self.boximages.append(img2d)
			
	def add_ptcls(self, x, y, z):
		"""Append a boxsize x boxsize single-slice image of the particle at (x,y,z)."""
		b=self.boxsize
		img=self.data.get_clip(Region(x-b//2, y-b//2, z, b, b, 1))
		img["pos"]=[x,y,z]
		self.ptclimages.append(img)

	def save_points(self):
		"""Write the current particles to the tomogram's info file.

		Positions are converted back to centre-relative coordinates scaled by
		apix_scale. They replace all boxes of the options.label class, which
		is created with the next free id if it does not exist yet.
		"""
		if self.data is None:
			return
		js=js_open_dict(self.infofile)
		label=self.options.label
		pts=np.array([b["pos"] for b in self.ptclimages])
		if len(pts)>0:
			pts=(pts - self.tomocenter) * self.apix_scale
		pts=np.round(pts).astype(int)
		clsid=self.curinfo["clsid"]
		print('save particles to {} : {}'.format(clsid, label))
		if not "class_list" in js:
			js["class_list"]={}
		if not "boxes_3d" in js:
			js["boxes_3d"]=[]
			
		clslst=js["class_list"]
		if clsid==-1:
			if len(clslst.keys())==0:
				clsid=0
			else:
				clsid=max([int(k) for k in clslst.keys()])+1
			clslst[str(clsid)]={"name":label, "boxsize": int(self.boxshapes.circlesize*8)}
			js["class_list"]=clslst
			self.curinfo["clsid"]=clsid
			
		boxes=[b for b in js["boxes_3d"] if b[5]!=clsid]
		boxes=boxes+[[p[0], p[1], p[2], "tomobox", 0.0, clsid] for p in pts.tolist()]
		js["boxes_3d"]=boxes
		js.close()
		
	def clear_points(self):
		"""Placeholder; does nothing."""
		return
	
	def do_update(self):
		"""Redraw the markers in the tomogram view and refresh all image viewers.

		Markers show either the particles or the references of the current
		tomogram, depending on the display combo box.
		"""
		if self.box_display.currentText()=="Particles":
			pts=[b["pos"]+[1] for b in self.ptclimages]
		else:
			pts=[b["pos"]+[b["label"]] for b in self.references if b["fromtomo"]==self.datafile]
			
		self.boxshapes.points=pts
		self.boxshapes.circlesize=self.val_circlesize.getval()
		self.imgview.shapechange=1
		self.imgview.updateGL()
		
		self.ptclviewer.set_data(self.ptclimages)
		self.ptclviewer.update()
		for i in [0,1]:
			self.boxesviewer[i].set_data([b for b in self.boximages if b["label"]==i])
			self.boxesviewer[i].update()
		
	
	def save_references(self):
		"""Overwrite self.reffile with all current references."""
		if os.path.isfile(self.reffile):
			os.remove(self.reffile)
		
		for ref in self.references:
			ref.write_image(self.reffile, -1)
		
		
	def closeEvent(self, event):
		"""Save references and particles, then close all viewers."""
		self.save_references()
		self.imgview.close()
		self.save_points()
		for b in  (self.boxesviewer+[self.ptclviewer]):
			b.close()
Пример #23
0
def main():

    progname = os.path.basename(sys.argv[0])
    usage = """e2findlines sets/img.lst
	
	** EXPERIMENTAL **
	this program looks for ~ straight line segments in images, such as wrinkles in graphene oxide films or possible C-film edges

	"""

    parser = EMArgumentParser(usage=usage, version=EMANVERSION)
    parser.add_argument("--threshold",
                        type=float,
                        help="Threshold for separating particles, default=3",
                        default=3.0)
    parser.add_argument("--newsets",
                        default=False,
                        action="store_true",
                        help="Split lines/nolines into 2 new sets")
    #parser.add_argument("--output",type=str,help="Output filename (text file)", default="ptclplot.txt")
    parser.add_argument("--gui",
                        default=False,
                        action="store_true",
                        help="show histogram of values")
    parser.add_argument(
        "--threads",
        default=4,
        type=int,
        help="Number of threads to run in parallel on the local computer")
    parser.add_argument(
        "--verbose",
        "-v",
        dest="verbose",
        action="store",
        metavar="n",
        type=int,
        default=0,
        help=
        "verbose level [0-9], higher number means higher level of verboseness")
    parser.add_argument(
        "--ppid",
        type=int,
        help="Set the PID of the parent process, used for cross platform PPID",
        default=-1)
    parser.add_argument(
        "--invar",
        default=False,
        action="store_true",
        help=
        "create the invar file for the newsets. The newsets option must be used."
    )
    parser.add_argument("--zscore",
                        default=True,
                        action="store_true",
                        help="run Z-score-based line finding.")
    parser.add_argument("--rdnxform",
                        default=False,
                        action="store_true",
                        help="detect lines via radon transform")
    parser.add_argument(
        "--rthreshold",
        default=25,
        help="see scikit-image.transform.radon() parameter documentation.")
    parser.add_argument(
        "--rsigma",
        default=3,
        help="see scikit-image.transform.radon() parameter documentation.")

    (options, args) = parser.parse_args()

    if (len(args) < 1):
        parser.error("Please specify an input stack/set to operate on")

    E2n = E2init(sys.argv, options.ppid)

    options.threads += 1  # one extra thread for storing results

    if options.rdnxform:
        options.zscore = False
        print("running e2findlines.py with Radon transform method.")

        n = EMUtil.get_image_count(args[0])
        lines = []

        if options.threads - 1 == 0:
            t1 = time.time()
            for i in range(n):
                im = EMData(args[0], i)
                radon_im = radon(im.numpy(), preserve_range=True)
                laplacian = blob_log(radon_im,
                                     threshold=options.rthreshold,
                                     min_sigma=options.rsigma)
                if len(laplacian) == 0:
                    lines.append(0)
                else:
                    lines.append(1)
                print(f"{i} out of {n} images analyzed" + ' ' * 20,
                      end='\b' *
                      (len(str(f"{i} out of {n} images analyzed")) + 20),
                      flush=True)
            t2 = time.time()
            print(f"Total time for rdnxform (nonthreaded): {t2-t1}s")

        if options.threads - 1 > 0 and options.threads <= n:
            t1 = time.time()
            print("running threaded rdnxform")
            threaded_indices = [[] for x in range(options.threads - 1)]
            for i in range(n):
                threaded_indices[i % (options.threads - 1)].append(i)

            #completed = 0

            class ImageBatch(threading.Thread):
                def __init__(self, threadId, name, data_indices):
                    threading.Thread.__init__(self)
                    self.threadId = threadId
                    self.name = name
                    self.data_indices = data_indices

                def run(self):
                    for i in self.data_indices:
                        im = EMData(args[0], i)
                        radon_im = radon(im.numpy(), preserve_range=True)
                        laplacian = blob_log(radon_im,
                                             threshold=options.rthreshold,
                                             min_sigma=options.rsigma)
                        if len(laplacian) == 0:
                            lines.append(0)
                        else:
                            lines.append(1)
                        print(
                            f"{i} out of {n} images analyzed" + ' ' * 20,
                            end='\b' *
                            (len(str(f"{i} out of {n} images analyzed")) + 20),
                            flush=True)

            threads = [
                ImageBatch(x, "thread_%s" % x, threaded_indices[x])
                for x in range(options.threads - 1)
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            t2 = time.time()
            print(f"Total time rdnxform (threaded): {t2-t1}s")
            """for i in range(n):
				im=EMData(args[0],i)
				radon_im=radon(im.numpy(), preserve_range=True)
				laplacian=blob_log(radon_im, threshold=options.rthreshold, min_sigma=options.rsigma)
				if len(laplacian)==0:
					lines.append(0)
				else:
					lines.append(1)
				print(f"{i} out of {n} images analyzed"+' '*20, end ='\b'*(len(str(f"{i} out of {n} images analyzed"))+20), flush=True)"""

    if options.zscore:
        print("running e2findlines.py with Z-score method.")
        im0 = EMData(args[0], 0)  # first image
        r2 = im0["ny"] / 4  # outer radius

        # we build up a list of 'Z scores' which should be larger for images containing one or more parallel lines.
        # if 2 lines aren't parallel the number may be lower, even if the lines are strong, but should still be higher
        # than images without lines in most cases
        n = EMUtil.get_image_count(args[0])
        step = max(n // 500, 1)
        Z = []
        im2d = []
        for i in range(n):
            im = EMData(args[0], i)
            a = im.do_fft().calc_az_dist(60, -88.5, 3, 4, r2)
            d = np.array(a)
            Z.append((d.max() - d.mean()) / d.std())
            if i % step == 0:
                im["zscore"] = (d.max() - d.mean()) / d.std()
                im2d.append(im)

        if options.gui:
            # GUI display of a histogram of the Z scores
            from eman2_gui.emhist import EMHistogramWidget
            from eman2_gui.emimagemx import EMImageMXWidget
            from eman2_gui.emapplication import EMApp
            app = EMApp()
            histw = EMHistogramWidget(application=app)
            histw.set_data(Z)
            app.show_specific(histw)
            imd = EMImageMXWidget(application=app)
            im2d.sort(key=lambda x: x["zscore"])
            imd.set_data(im2d)
            app.show_specific(imd)
            app.exec_()
    """
	if options.newsets:
		lstin=LSXFile(args[0])

		# output containing images with lines
		linesfsp=args[0].rsplit(".",1)[0]+"_lines.lst"
		try: os.unlink(linesfsp)
		except: pass
		lstlines=LSXFile(linesfsp)	

		# output containin images without lines
		nolinesfsp=args[0].rsplit(".",1)[0]+"_nolines.lst"
		try: os.unlink(nolinesfsp)
		except: pass
		lstnolines=LSXFile(nolinesfsp)	

		for i,z in enumerate(Z):
			if z>options.threshold: lstlines[-1]=lstin[i]
			else: lstnolines[-1]=lstin[i]
	"""
    if options.newsets and not options.invar:
        lstin = LSXFile(args[0])

        # output containing images with lines
        linesfsp = args[0].split("__", 1)[0] + "_lines__" + args[0].split(
            "__", 1)[1]
        try:
            os.unlink(linesfsp)
        except:
            pass
        lstlines = LSXFile(linesfsp)

        # output containin images without lines
        nolinesfsp = args[0].split("__", 1)[0] + "_nolines__" + args[0].split(
            "__", 1)[1]
        try:
            os.unlink(nolinesfsp)
        except:
            pass
        lstnolines = LSXFile(nolinesfsp)

        if options.zscore:
            for i, z in enumerate(Z):
                if z > options.threshold: lstlines[-1] = lstin[i]
                else: lstnolines[-1] = lstin[i]
        if options.rdnxform:
            for i, r in enumerate(lines):
                if r != 0: lstlines[-1] = lstin[i]
                else: lstnolines[-1] = lstin[i]

    if options.newsets and options.invar:
        lstin = LSXFile(args[0])

        # output containing images with lines
        fnamemod = input("Type the filename modifier:   ")
        linesfsp = args[0].split(
            "__", 1)[0] + f"_lines_{fnamemod}__" + args[0].split("__", 1)[1]
        try:
            os.unlink(linesfsp)
        except:
            pass
        lstlines = LSXFile(linesfsp)

        # output containing images without lines
        nolinesfsp = args[0].split(
            "__", 1)[0] + f"_nolines_{fnamemod}__" + args[0].split("__", 1)[1]
        try:
            os.unlink(nolinesfsp)
        except:
            pass
        lstnolines = LSXFile(nolinesfsp)

        # output copy of nolines folder
        invarfsp = nolinesfsp.rsplit("_", 1)[0] + "_invar.lst"
        try:
            os.unlink(invarfsp)
        except:
            pass
        lstinvar = LSXFile(invarfsp)

        if options.zscore:
            for i, z in enumerate(Z):
                if z > options.threshold: lstlines[-1] = lstin[i]
                else:
                    lstnolines[-1] = lstin[i]
                    lstinvar[-1] = lstin[i]
        if options.rdnxform:
            for i, r in enumerate(lines):
                if r != 0: lstlines[-1] = lstin[i]
                else:
                    lstnolines[-1] = lstin[i]
                    lstinvar[-1] = lstin[i]

        print(f"running: e2proclst.py {invarfsp} --retype ctf_flip_invar")
        os.system(f"e2proclst.py {invarfsp} --retype ctf_flip_invar")
        print("invar file created.")

    E2end(E2n)
Пример #24
0
    def setData(self, data):
        """Load the image data to be filtered and rebuild the viewer window.

        data may be:
          * None -- sets self.data to None and returns immediately;
          * a filename (str) -- images are read from the file.  If
            self.dataidx indexes a valid image, only that image is used.
            Otherwise, for 2-D stacks, a regularly spaced subset is read
            when the file is large;
          * a single EMData, or a list of EMData objects.

        Afterwards the radial power spectrum of the first image is stored in
        self.pspecorig, with matching spatial frequencies in self.pspecs.  Any
        existing viewers are closed and a new one is created (image stack,
        single 2-D image, or 3-D scene).  Finally procChange(-1) is called.
        """
        # Identity test: EMData objects may overload ==, so "data == None"
        # is not a reliable way to detect a missing argument.
        if data is None:
            self.data = None
            return

        elif isinstance(data, str):
            self.datafile = data
            self.nimg = EMUtil.get_image_count(data)

            if self.dataidx >= 0 and self.dataidx < self.nimg:
                ii = self.dataidx
                self.nimg = 1
            else:
                ii = 0

            hdr = EMData(data, 0, 1)  # header only, used for the image size

            self.origdata = EMData(data, ii)

            if self.origdata["nz"] == 1:
                if self.nimg > 20 and hdr["ny"] > 512:
                    self.origdata = EMData.read_images(
                        data, list(range(0, self.nimg, self.nimg // 20))
                    )  # read regularly separated images from the file totalling ~20
                elif self.nimg > 100:
                    self.origdata = EMData.read_images(
                        data,
                        list(range(0, 72)) +
                        list(range(72, self.nimg, self.nimg // 100))
                    )  # read the first 72, then ~100 regularly separated images from the file
                elif self.nimg > 1:
                    self.origdata = EMData.read_images(data,
                                                       list(range(self.nimg)))
                else:
                    self.origdata = [self.origdata]
            else:
                self.origdata = [self.origdata]

        else:
            self.datafile = None
            if isinstance(data, EMData): self.origdata = [data]
            else: self.origdata = data

        self.nx = self.origdata[0]["nx"]
        self.ny = self.origdata[0]["ny"]
        self.nz = self.origdata[0]["nz"]
        if self.apix <= 0.0: self.apix = self.origdata[0]["apix_x"]
        EMProcessorWidget.parmdefault["apix"] = (0, (0.2, 10.0), self.apix,
                                                 None)

        origfft = self.origdata[0].do_fft()
        self.pspecorig = origfft.calc_radial_dist(old_div(self.ny, 2), 0.0,
                                                  1.0, 1)
        ds = old_div(1.0, (self.apix * self.ny))
        self.pspecs = [ds * i for i in range(len(self.pspecorig))]

        if self.viewer is not None:
            for v in self.viewer:
                v.close()

        if self.nz == 1 or self.force2d or (self.nx > 320
                                            and self.safemode == False):
            if len(self.origdata) > 1:
                self.viewer = [EMImageMXWidget()]
                self.mfile_save_stack.setEnabled(True)
            else:
                self.viewer = [EMImage2DWidget()]
                self.mfile_save_stack.setEnabled(False)
        else:
            self.mfile_save_stack.setEnabled(False)
            self.viewer = [EMScene3D()]
            self.sgdata = EMDataItem3D(test_image_3d(3), transform=Transform())
            self.viewer[0].insertNewNode('Data',
                                         self.sgdata,
                                         parentnode=self.viewer[0])
            isosurface = EMIsosurface(self.sgdata, transform=Transform())
            self.viewer[0].insertNewNode("Iso",
                                         isosurface,
                                         parentnode=self.sgdata)
            volslice = EMSliceItem3D(self.sgdata, transform=Transform())
            self.viewer[0].insertNewNode("Slice",
                                         volslice,
                                         parentnode=self.sgdata)

        if self.nz > 1: self.mfile_save_map.setEnabled(True)
        else: self.mfile_save_map.setEnabled(False)

        E2loadappwin("e2filtertool", "image", self.viewer[0].qt_parent)
        if self.origdata[0].has_attr("source_path"):
            winname = str(self.origdata[0]["source_path"])
        else:
            winname = "FilterTool"
        self.viewer[0].setWindowTitle(winname)

        self.procChange(-1)
Пример #25
0
	def __init__(self,application,options,datafile=None):
		"""Build the neural-network tomogram boxer window and its helper viewers.

		application -- the EMAN2 application object; only a weakref to it is stored.
		options -- parsed command-line options.  Stored as self.options;
		           options.gpuid is passed to import_tensorflow().
		datafile -- NOTE(review): not referenced in the visible body
		            (self.datafile is initialised to ""); confirm it is unused.

		Loads saved references from info/boxrefs3d.hdf if present and lays out
		the control buttons and parameter text boxes.  Creates the tomogram
		view, the Negative/Positive box galleries and the Particles gallery,
		then connects their signals to handler methods.
		"""
		# NOTE(review): setCentralWidget()/centralWidget() used below are
		# QMainWindow methods, but QWidget.__init__ is called here -- confirm
		# the intended base-class initialisation.
		QtWidgets.QWidget.__init__(self)
		self.setMinimumSize(700,200)
		
		#### load references first (the default box size below depends on them)
		self.reffile="info/boxrefs3d.hdf"
		if os.path.isfile(self.reffile):
			self.references=EMData.read_images(self.reffile)
		else:
			self.references=[]
			
		self.path="tomograms/"
		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())
		
		# table with 5 columns of fixed widths, placed to the right of the controls
		self.imglst=QtWidgets.QTableWidget(1, 5, self)
		#self.imglst.verticalHeader().hide()
		for i,w in enumerate([50,200,70,70,70]):
			self.imglst.setColumnWidth(i,w)
		self.imglst.setMinimumSize(450, 100)
		self.gbl.addWidget(self.imglst, 0,4,10,10)
		
		# column 0: action buttons
		self.bt_new=QtWidgets.QPushButton("New")
		self.bt_new.setToolTip("Build new neural network")
		self.gbl.addWidget(self.bt_new, 0,0,1,2)
		
		self.bt_load=QtWidgets.QPushButton("Load")
		self.bt_load.setToolTip("Load neural network")
		self.gbl.addWidget(self.bt_load, 1,0,1,2)
		
		self.bt_train=QtWidgets.QPushButton("Train")
		self.bt_train.setToolTip("Train neural network")
		self.gbl.addWidget(self.bt_train, 2,0,1,2)
		
		self.bt_save=QtWidgets.QPushButton("Save")
		self.bt_save.setToolTip("Save neural network")
		self.gbl.addWidget(self.bt_save, 3,0,1,2)
		
		self.bt_apply=QtWidgets.QPushButton("Apply")
		self.bt_apply.setToolTip("Apply neural network")
		self.gbl.addWidget(self.bt_apply, 4,0,1,2)
		
		self.bt_chgbx=QtWidgets.QPushButton("ChangeBx")
		self.bt_chgbx.setToolTip("Change box size")
		self.gbl.addWidget(self.bt_chgbx, 5,0,1,2)
		
		self.bt_applyall=QtWidgets.QPushButton("ApplyAll")
		self.bt_applyall.setToolTip("Apply to all tomograms")
		self.gbl.addWidget(self.bt_applyall, 6,0,1,2)
		
		# selects what do_update() draws: "References" or "Particles"
		self.box_display = QtWidgets.QComboBox()
		self.box_display.addItem("References")
		self.box_display.addItem("Particles")
		self.gbl.addWidget(self.box_display, 0,2,1,1)
		
		
		self.bt_new.clicked[bool].connect(self.new_nnet)
		self.bt_load.clicked[bool].connect(self.load_nnet)
		self.bt_train.clicked[bool].connect(self.train_nnet)
		self.bt_save.clicked[bool].connect(self.save_nnet)
		self.bt_apply.clicked[bool].connect(self.apply_nnet)
		self.bt_chgbx.clicked[bool].connect(self.change_boxsize)
		self.bt_applyall.clicked[bool].connect(self.apply_nnet_all)
		self.box_display.currentIndexChanged.connect(self.do_update)

		# column 2: parameter text boxes (label, initial value)
		self.val_targetsize=TextBox("TargetSize", 1)
		self.gbl.addWidget(self.val_targetsize, 1,2,1,1)
		
		self.val_learnrate=TextBox("LearnRate", 1e-4)
		self.gbl.addWidget(self.val_learnrate, 2,2,1,1)
		
		self.val_ptclthr=TextBox("PtclThresh", 0.8)
		self.gbl.addWidget(self.val_ptclthr, 3,2,1,1)
		
		self.val_circlesize=TextBox("CircleSize", 24)
		self.gbl.addWidget(self.val_circlesize, 4,2,1,1)
		
		self.val_niter=TextBox("Niter", 20)
		self.gbl.addWidget(self.val_niter, 5,2,1,1)
		
		self.val_posmult=TextBox("PosMult", 1)
		self.gbl.addWidget(self.val_posmult, 6,2,1,1)
		
		self.val_lossfun = QtWidgets.QComboBox()
		self.val_lossfun.addItem("Sum")
		self.val_lossfun.addItem("Max")
		self.gbl.addWidget(self.val_lossfun, 7,2,1,1)
		
		self.options=options
		# weak reference so this window does not keep the application alive
		self.app=weakref.ref(application)
		
		self.nnet=None
		# TensorFlow is imported lazily (with the selected GPU) into the module-global name tf
		global tf
		tf=import_tensorflow(options.gpuid)
		
		
		# box size defaults to the network input size unless saved references
		# exist, in which case the first reference's "boxsz" header value is used
		self.nnetsize=96
		if len(self.references)==0:
			self.boxsize=self.nnetsize
		else:
			self.boxsize=self.references[0]["boxsz"]
		self.thick=9  # NOTE(review): meaning/units of 9 not evident here
		self.datafile=""
		self.data=None
		
		if not os.path.isdir("neuralnets"):
			os.mkdir("neuralnets")
		
		self.imgview = EMImage2DWidget()
		self.boxesviewer=[EMImageMXWidget(), EMImageMXWidget()]
		self.boxesviewer[0].setWindowTitle("Negative")
		self.boxesviewer[1].setWindowTitle("Positive")
		self.ptclviewer=EMImageMXWidget()
		self.ptclviewer.setWindowTitle("Particles")
		
		# common configuration for the three galleries
		for boxview in (self.boxesviewer+[self.ptclviewer]):
			boxview.usetexture=False
			boxview.show()
			boxview.set_mouse_mode("App")
			boxview.rzonce=True
		
		self.boximages=[]
		self.ptclimages=[]
		for img in self.references:
			self.add_boximage(img)
		
		self.imgview.mouseup.connect(self.on_tomo_mouseup)
		self.imgview.keypress.connect(self.key_press)
		# NOTE(review): gallery 0 is titled "Negative" but is wired to
		# on_boxpos_selected, and gallery 1 ("Positive") to on_boxneg_selected;
		# confirm against the handlers whether the names are swapped.
		self.boxesviewer[0].mx_image_selected.connect(self.on_boxpos_selected)
		self.boxesviewer[1].mx_image_selected.connect(self.on_boxneg_selected)
		self.ptclviewer.mx_image_selected.connect(self.on_ptcl_selected)
		self.imglst.cellClicked[int, int].connect(self.on_list_selected)
		
		# overlay of box markers drawn on the tomogram view
		self.boxshapes=BoxShapes(img=self.imgview, points=[] )
		self.imgview.shapes = {0:self.boxshapes}

		# enable OpenGL point/line/polygon smoothing and alpha blending
		glEnable(GL_POINT_SMOOTH)
		glEnable(GL_LINE_SMOOTH );
		glEnable(GL_POLYGON_SMOOTH );
		glEnable(GL_BLEND);
		glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
		self.trainset=[]
		self.tomoinp=[]
		self.segout=None
		self.curinfo=None
		
		self.update_list()
		self.do_update()