示例#1
0
文件: e2ctfsim.py 项目: phonchi/eman2
class GUIctfsim(QtGui.QWidget):
    """Interactive CTF simulator window.

    Keeps a list of EMAN2Ctf parameter sets in ``self.data`` (entries are
    ``(name, ctf)``), edits the selected set with sliders, plots the 1-D CTF
    curve of every set and shows the 2-D CTF of the selected set. When an
    ``apply`` image is given, the selected CTF is also applied to that image
    and the results are displayed in a separate image window.
    """

    # Emitted from closeEvent; important when this dialog is driven by a
    # program running its own event loop.
    module_closed = QtCore.pyqtSignal()

    def __init__(self,
                 application,
                 apix=1.0,
                 voltage=300.0,
                 cs=4.1,
                 ac=10.0,
                 samples=256,
                 apply=None):
        """CTF simulation dialog.

        :param application: EMAN application object (held via weakref only).
        :param apix: default A/pixel for new CTF sets.
        :param voltage: default voltage (kV) for new CTF sets.
        :param cs: default Cs (mm) for new CTF sets.
        :param ac: default amplitude contrast (%) for new CTF sets.
        :param samples: default number of samples for new CTF sets.
        :param apply: optional image filename. Its first image gets the
            selected CTF applied, and its ``apix_x`` overrides ``apix``.
        """
        try:
            from eman2_gui.emimage2d import EMImage2DWidget
        except ImportError:
            print("Cannot import EMAN image GUI objects (EMImage2DWidget)")
            sys.exit(1)
        try:
            from eman2_gui.emplot2d import EMPlot2DWidget
        except ImportError:
            print(
                "Cannot import EMAN plot GUI objects (is matplotlib installed?)"
            )
            sys.exit(1)

        self.app = weakref.ref(application)

        # defaults used by on_new_but() for every new CTF set
        self.df_voltage = voltage
        self.df_apix = apix
        self.df_cs = cs
        self.df_ac = ac
        self.df_samples = samples
        self.img = None  # complex image the 2-D CTF is rendered into

        if apply is None:
            self.apply = None
            self.applyim = None
        else:
            self.apply = EMData(apply, 0)
            self.df_apix = self.apply["apix_x"]
            print("A/pix reset to ", self.df_apix)
            self.applyim = EMImage2DWidget(application=self.app())

        QtGui.QWidget.__init__(self, None)
        self.setWindowIcon(QtGui.QIcon(get_image_directory() + "ctf.png"))

        self.data = []  # list of (name, EMAN2Ctf)
        self.curset = 0  # index into self.data of the selected set
        self.plotmode = 0  # index of self.splotmode

        self.guiim = EMImage2DWidget(application=self.app())
        self.guiiminit = True  # a flag that's used to auto resize the first time the gui's set_data function is called
        self.guiplot = EMPlot2DWidget(application=self.app())

        self.guiim.mousedown.connect(self.imgmousedown)
        self.guiim.mousedrag.connect(self.imgmousedrag)
        self.guiim.mouseup.connect(self.imgmouseup)
        self.guiplot.mousedown.connect(self.plotmousedown)

        self.guiim.mmode = "app"

        # This object is itself a widget we need to set up
        self.hbl = QtGui.QHBoxLayout(self)
        self.hbl.setMargin(0)
        self.hbl.setSpacing(6)
        self.hbl.setObjectName("hbl")

        # plot list and plot mode combobox
        self.vbl2 = QtGui.QVBoxLayout()
        self.setlist = MyListWidget(self)
        self.setlist.setSizePolicy(QtGui.QSizePolicy.Preferred,
                                   QtGui.QSizePolicy.Expanding)
        self.vbl2.addWidget(self.setlist)

        self.splotmode = QtGui.QComboBox(self)
        self.splotmode.addItem("Amplitude")
        self.splotmode.addItem("Intensity")
        self.splotmode.addItem("Int w sum")
        self.splotmode.addItem("Amp w sum")
        self.vbl2.addWidget(self.splotmode)
        self.hbl.addLayout(self.vbl2)

        # ValSliders for CTF parameters
        self.vbl = QtGui.QVBoxLayout()
        self.vbl.setMargin(0)
        self.vbl.setSpacing(6)
        self.vbl.setObjectName("vbl")
        self.hbl.addLayout(self.vbl)

        self.imginfo = QtGui.QLabel("Info", self)
        self.vbl.addWidget(self.imginfo)

        self.sdefocus = ValSlider(self, (0, 5), "Defocus:", 0, 90)
        self.vbl.addWidget(self.sdefocus)

        self.sbfactor = ValSlider(self, (0, 1600), "B factor:", 100, 90)
        self.vbl.addWidget(self.sbfactor)

        self.sdfdiff = ValSlider(self, (0, 1), "DF Diff:", 0, 90)
        self.vbl.addWidget(self.sdfdiff)

        self.sdfang = ValSlider(self, (0, 180), "Df Angle:", 0, 90)
        self.vbl.addWidget(self.sdfang)

        self.sampcont = ValSlider(self, (0, 100), "% AC", 0, 90)
        self.vbl.addWidget(self.sampcont)

        self.sphase = ValSlider(self, (0, 1), "Phase/pi", 0, 90)
        self.vbl.addWidget(self.sphase)

        self.sapix = ValSlider(self, (.2, 10), "A/Pix:", 2, 90)
        self.vbl.addWidget(self.sapix)

        self.svoltage = ValSlider(self, (0, 1000), "Voltage (kV):", 0, 90)
        self.vbl.addWidget(self.svoltage)

        self.scs = ValSlider(self, (0, 5), "Cs (mm):", 0, 90)
        self.vbl.addWidget(self.scs)

        self.ssamples = ValSlider(self, (32, 1024), "# Samples:", 0, 90)
        self.ssamples.setIntonly(True)
        self.vbl.addWidget(self.ssamples)

        self.hbl_buttons = QtGui.QHBoxLayout()
        self.newbut = QtGui.QPushButton("New")
        self.hbl_buttons.addWidget(self.newbut)
        self.vbl.addLayout(self.hbl_buttons)

        self.on_new_but()

        self.sdefocus.valueChanged.connect(self.newCTF)
        self.sbfactor.valueChanged.connect(self.newCTF)
        self.sdfdiff.valueChanged.connect(self.newCTF)
        self.sdfang.valueChanged.connect(self.newCTF)
        self.sapix.valueChanged.connect(self.newCTF)
        self.sampcont.valueChanged.connect(self.newCTFac)
        self.sphase.valueChanged.connect(self.newCTFpha)
        self.svoltage.valueChanged.connect(self.newCTF)
        self.scs.valueChanged.connect(self.newCTF)
        self.ssamples.valueChanged.connect(self.newCTF)
        self.setlist.currentRowChanged[int].connect(self.newSet)
        self.setlist.keypress.connect(self.listkey)
        self.splotmode.currentIndexChanged[int].connect(self.newPlotMode)

        self.newbut.clicked[bool].connect(self.on_new_but)

        # on_new_but() ran before the signals above were connected, so the
        # sliders still hold their construction values; load the first set
        # into them (otherwise moving any slider would overwrite e.g. the
        # set's defocus with the slider's initial 0) and draw it.
        self.newSet(self.curset)

        self.resize(
            720, 380
        )  # figured these values out by printing the width and height in resize event

        E2loadappwin("e2ctfsim", "main", self)
        E2loadappwin("e2ctfsim", "image", self.guiim.qt_parent)
        E2loadappwin("e2ctfsim", "plot", self.guiplot.qt_parent)

        self.setWindowTitle("CTF")

    def _ensure_img(self, samples):
        """Make sure ``self.img`` is a zeroed complex image for ``samples`` rows.

        Returns True when a new image had to be allocated, False when the
        existing image already had the right size.
        """
        if self.img is not None and self.img["ny"] == samples:
            return False
        self.img = EMData(samples + 2, samples)
        self.img.to_zero()
        self.img.set_complex(1)
        return True

    def listkey(self, event):
        """Keyboard shortcuts for the set list: Left/Right nudge defocus by 0.01.

        Digit keys and 'R' are not handled. This dialog has no quality
        slider or parameter-recall function (``squality`` and
        ``on_recall_params`` are not defined here), so calling them raised
        AttributeError.
        """
        key = event.key()
        if key == Qt.Key_Left:
            self.sdefocus.setValue(self.sdefocus.getValue() - 0.01)
        elif key == Qt.Key_Right:
            self.sdefocus.setValue(self.sdefocus.getValue() + 0.01)

    def on_new_but(self):
        """Append a CTF set built from the dialog defaults and select it."""
        ctf = EMAN2Ctf()
        ctf.defocus = 1.0
        ctf.bfactor = 100.0
        ctf.voltage = self.df_voltage
        ctf.apix = self.df_apix
        ctf.cs = self.df_cs
        # Amplitude contrast lives in ``ampcont`` (the attribute the "% AC"
        # slider reads and writes). Assigning ``ctf.ac`` did not set it, so
        # the ``ac`` argument was silently ignored.
        ctf.ampcont = self.df_ac
        ctf.samples = self.df_samples
        self.data.append((str(len(self.data) + 1), ctf))
        self.curset = len(self.data) - 1  # index of the set just added
        self.update_data()

    def show_guis(self):
        """Show the image, applied-image and plot windows plus this dialog."""
        if self.guiim is not None:
            self.app().show_specific(self.guiim)
        if self.applyim is not None:
            self.app().show_specific(self.applyim)
        if self.guiplot is not None:
            self.app().show_specific(self.guiplot)

        self.show()

    def closeEvent(self, event):
        """Save window geometry, close the child windows and emit module_closed."""
        # Save under the same "e2ctfsim" key that __init__ loads from. Saving
        # under "e2ctf" overwrote e2ctf's geometry and never restored ours.
        E2saveappwin("e2ctfsim", "main", self)

        if self.guiim is not None:
            E2saveappwin("e2ctfsim", "image", self.guiim.qt_parent)
            self.app().close_specific(self.guiim)
            self.guiim = None
        if self.applyim is not None:
            self.app().close_specific(self.applyim)
            self.applyim = None
        if self.guiplot is not None:
            E2saveappwin("e2ctfsim", "plot", self.guiplot.qt_parent)
            self.app().close_specific(self.guiplot)
            # update_plot() tests this to know the plot is gone
            self.guiplot = None

        event.accept()
        self.app().close_specific(self)
        self.module_closed.emit(
        )  # this signal is important when e2ctf is being used by a program running its own event loop

    def update_data(self):
        """This will make sure the various widgets properly show the current data sets"""
        self.setlist.clear()
        for name, _ctf in self.data:
            self.setlist.addItem(name)
        self.setlist.setCurrentRow(self.curset)

    def update_plot(self):
        """Redraw the 1-D curves, the optional sum, and the selected set's 2-D views.

        Plot modes (``self.plotmode``): 0 amplitude, 1 intensity,
        2 intensity with sum, 3 amplitude with sum.
        """
        if self.guiplot is None or not self.data:
            return  # it's closed/not visible, or there is nothing to plot

        avg = None
        for d, (name, ctf) in enumerate(self.data):
            ds = old_div(1.0, (ctf.apix * 2.0 * ctf.samples))
            s = arange(0, ds * ctf.samples, ds)

            curve = array(ctf.compute_1d(len(s) * 2, ds, Ctf.CtfType.CTF_AMP))
            if self.plotmode in (1, 2):
                curve = curve**2

            if self.plotmode in (2, 3):
                if d == 0:
                    # Copy, not curve[:]. A numpy slice is a view, so the
                    # in-place += below would also modify the first curve.
                    avg = curve.copy()
                elif len(curve) != len(avg):
                    print(
                        "Number of samples must be fixed to compute an average ({})"
                        .format(d + 1))
                else:
                    avg += curve

            self.guiplot.set_data((s, curve),
                                  name,
                                  d == 0,
                                  True,
                                  color=d + 1)

        # only draw the sum when it lines up with the x axis being used
        if self.plotmode in (2, 3) and len(avg) == len(s):
            self.guiplot.set_data((s, avg), "Sum", False, True, color=0)

        self.guiplot.setAxisParms(r"s (1/$\AA$)", "CTF")

        # The 2-D views show the *selected* set. The loop variable ``ctf``
        # only refers to whichever set happens to be last.
        if 0 <= self.curset < len(self.data):
            cur = self.data[self.curset][1]
        else:
            cur = ctf

        if self.img is not None and self.guiim is not None:
            cur.compute_2d_complex(self.img, Ctf.CtfType.CTF_AMP, None)
            self.guiim.set_data(self.img)

        if self.applyim is not None:
            applyf = self.apply.do_fft()
            ctfmul = applyf.copy()
            cur.compute_2d_complex(ctfmul, Ctf.CtfType.CTF_AMP)
            ctfsgn = applyf.copy()
            cur.compute_2d_complex(ctfsgn, Ctf.CtfType.CTF_SIGN)
            applyf.mult(ctfmul)
            apply2 = applyf.do_ift()
            apply2.mult(
                5.0
            )  # roughly compensate for contrast reduction so apply comparable
            applyf.mult(ctfsgn)
            apply3 = applyf.do_ift()
            apply3.mult(5.0)
            self.applyim.set_data([apply2, apply3, self.apply])

    def newSet(self, val=0):
        """Select CTF set ``val``: load it into the sliders (quietly) and redraw.

        currentRowChanged reports -1 while the list is being cleared. That
        value, and any other out-of-range row, is ignored rather than
        selecting ``data[-1]``.
        """
        if not 0 <= val < len(self.data):
            return
        self.curset = val
        name, ctf = self.data[val]

        self.sdefocus.setValue(ctf.defocus, True)
        self.sbfactor.setValue(ctf.bfactor, True)
        self.sapix.setValue(ctf.apix, True)
        self.sampcont.setValue(ctf.ampcont, True)
        self.sphase.setValue(old_div(ctf.get_phase(), pi), True)
        self.svoltage.setValue(ctf.voltage, True)
        self.scs.setValue(ctf.cs, True)
        self.sdfdiff.setValue(ctf.dfdiff, True)
        self.sdfang.setValue(ctf.dfang, True)
        self.ssamples.setValue(ctf.samples, True)

        # make new image if necessary
        self._ensure_img(ctf.samples)
        self.guiim.set_data(self.img)

        self.guiim.qt_parent.setWindowTitle("e2ctfsim - 2D FFT - " + name)
        self.guiplot.qt_parent.setWindowTitle("e2ctfsim - Plot ")

        self.update_plot()

    def newPlotMode(self, mode):
        """Plot-mode combobox changed: remember the mode and redraw."""
        self.plotmode = mode
        self.update_plot()

    def newCTF(self):
        """Copy every slider value into the selected CTF set and redraw.

        The phase slider is not copied here. newCTFac/newCTFpha keep it in
        step with amplitude contrast.
        """
        ctf = self.data[self.curset][1]
        ctf.defocus = self.sdefocus.value
        ctf.bfactor = self.sbfactor.value
        ctf.dfdiff = self.sdfdiff.value
        ctf.dfang = self.sdfang.value
        ctf.apix = self.sapix.value
        ctf.ampcont = self.sampcont.value
        ctf.voltage = self.svoltage.value
        ctf.cs = self.scs.value
        ctf.samples = self.ssamples.value

        if self._ensure_img(self.ssamples.value):
            self.guiim.set_data(self.img)
        self.update_plot()

    def newCTFac(self):
        """Amplitude-contrast slider moved: store it and update the phase slider to match."""
        ctf = self.data[self.curset][1]
        ctf.ampcont = self.sampcont.value
        self.sphase.setValue(old_div(ctf.get_phase(), pi), True)

        if self._ensure_img(self.ssamples.value):
            self.guiim.set_data(self.img)
        self.update_plot()

    def newCTFpha(self):
        """Phase slider moved: store it and update the amplitude-contrast slider to match."""
        ctf = self.data[self.curset][1]
        ctf.set_phase(self.sphase.value * pi)
        self.sampcont.setValue(ctf.ampcont, True)

        if self._ensure_img(self.ssamples.value):
            self.guiim.set_data(self.img)
        self.update_plot()

    def imgmousedown(self, event):
        """Mouse press in the 2-D image (placeholder; no action yet)."""
        m = self.guiim.scr_to_img((event.x(), event.y()))

    def imgmousedrag(self, event):
        """Mouse drag in the 2-D image (placeholder; no action yet)."""
        m = self.guiim.scr_to_img((event.x(), event.y()))

    def imgmouseup(self, event):
        """Mouse release in the 2-D image (placeholder; no action yet)."""
        m = self.guiim.scr_to_img((event.x(), event.y()))

    def plotmousedown(self, event):
        """Mouse press in the plot (placeholder; no action yet)."""
        m = self.guiim.scr_to_img((event.x(), event.y()))

    def run(self):
        """If you make your own application outside of this object, you are free to use
        your own local app.exec_(). This is a convenience for ctf-only programs."""
        # self.app is a weakref; dereference it before calling exec_()
        self.app().exec_()

        return
示例#2
0
class EMTomoBoxer(QtWidgets.QMainWindow):
	"""This class represents the EMTomoBoxer application instance.  """
	keypress = QtCore.pyqtSignal(QtGui.QKeyEvent)
	module_closed = QtCore.pyqtSignal()

	def __init__(self,application,options,datafile):
		"""Build the boxer window, load ``datafile`` and restore its saved boxes.

		:param application: EMAN application object; kept only as a weak reference.
		:param options: parsed options. ``apix`` and ``mode`` are read here;
		    mode "3D" draws circular boxes, any other mode draws rectangles.
		:param datafile: tomogram filename. Box coordinates, set names and box
		    sizes are read from the file named by info_name(datafile).
		"""
		# NOTE(review): the class derives from QMainWindow but calls
		# QWidget.__init__; QtWidgets.QMainWindow.__init__ would be the
		# conventional call — confirm intent.
		QtWidgets.QWidget.__init__(self)
		self.initialized=False	# suppresses redraws until construction finishes
		self.app=weakref.ref(application)
		self.options=options
		self.apix=options.apix
		self.currentset=0
		self.shrink=1#options.shrink -- fixed at 1; options.shrink is not used
		self.setWindowTitle("Main Window (e2spt_boxer.py)")
		if options.mode=="3D":
			self.boxshape="circle"
		else:
			self.boxshape="rect"

		# global transform applied to the displayed boxes (see get_rotated_boxes)
		self.globalxf=Transform()
		
		# Menu Bar
		self.mfile=self.menuBar().addMenu("File")
		#self.mfile_open=self.mfile.addAction("Open")
		self.mfile_read_boxloc=self.mfile.addAction("Read Box Coord")
		self.mfile_save_boxloc=self.mfile.addAction("Save Box Coord")
		self.mfile_save_boxpdb=self.mfile.addAction("Save Coord as PDB")
		self.mfile_save_boxes_stack=self.mfile.addAction("Save Boxes as Stack")
		#self.mfile_quit=self.mfile.addAction("Quit")


		self.setCentralWidget(QtWidgets.QWidget())
		self.gbl = QtWidgets.QGridLayout(self.centralWidget())

		# relative stretch factors
		self.gbl.setColumnMinimumWidth(0,200)
		self.gbl.setRowMinimumHeight(0,200)
		self.gbl.setColumnStretch(0,0)
		self.gbl.setColumnStretch(1,100)
		self.gbl.setColumnStretch(2,0)
		self.gbl.setRowStretch(1,0)
		self.gbl.setRowStretch(0,100)
		

		# 3 orthogonal restricted projection views
		# layout: xy at (0,1), xz below it at (1,1), zy to its left at (0,0)
		self.xyview = EMImage2DWidget(sizehint=(1024,1024))
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget(sizehint=(1024,256))
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget(sizehint=(256,1024))
		self.gbl.addWidget(self.zyview,0,0)

		# Select Z for xy view
		self.wdepth = QtWidgets.QSlider()
		self.gbl.addWidget(self.wdepth,1,2)

		### Control panel area in the lower-left grid cell (row 1, column 0)
		self.gbl2 = QtWidgets.QGridLayout()
		self.gbl.addLayout(self.gbl2,1,0)

		#self.wxpos = QtWidgets.QSlider(Qt.Horizontal)
		#self.gbl2.addWidget(self.wxpos,0,0)
		
		#self.wypos = QtWidgets.QSlider(Qt.Vertical)
		#self.gbl2.addWidget(self.wypos,0,3,6,1)
		
		# box size
		self.wboxsize=ValBox(label="Box Size:",value=0)
		self.gbl2.addWidget(self.wboxsize,2,0)

		# max or mean
		#self.wmaxmean=QtWidgets.QPushButton("MaxProj")
		#self.wmaxmean.setCheckable(True)
		#self.gbl2.addWidget(self.wmaxmean,3,0)

		# number slices (slab thickness used by get_slice via nlayers())
		label0=QtWidgets.QLabel("Thickness")
		self.gbl2.addWidget(label0,3,0)

		self.wnlayers=QtWidgets.QSpinBox()
		self.wnlayers.setMinimum(1)
		self.wnlayers.setMaximum(256)
		self.wnlayers.setValue(1)
		self.gbl2.addWidget(self.wnlayers,3,1)

		# Local boxes in side view
		self.wlocalbox=QtWidgets.QCheckBox("Limit Side Boxes")
		self.gbl2.addWidget(self.wlocalbox,4,0)
		self.wlocalbox.setChecked(True)
		
		self.button_flat = QtWidgets.QPushButton("Flatten")
		self.gbl2.addWidget(self.button_flat,5,0)
		self.button_reset = QtWidgets.QPushButton("Reset")
		self.gbl2.addWidget(self.button_reset,5,1)
		## scale factor
		#self.wscale=ValSlider(rng=(.1,2),label="Sca:",value=1.0)
		#self.gbl2.addWidget(self.wscale,4,0,1,2)

		# 2-D filters (nonzero value -> gaussian lowpass at 1/value, see update_sliceview)
		self.wfilt = ValSlider(rng=(0,150),label="Filt",value=0.0)
		self.gbl2.addWidget(self.wfilt,6,0,1,2)
		
		self.curbox=-1	# index of the selected box, -1 for none
		
		self.boxes=[]						# array of box info, each is (x,y,z,...)
		self.boxesimgs=[]					# z projection of each box
		self.dragging=-1

		##coordinate display
		self.wcoords=QtWidgets.QLabel("")
		self.gbl2.addWidget(self.wcoords, 1, 0, 1, 2)
		
		self.button_flat.clicked[bool].connect(self.flatten_tomo)
		self.button_reset.clicked[bool].connect(self.reset_flatten_tomo)

		# file menu
		#self.mfile_open.triggered[bool].connect(self.menu_file_open)
		self.mfile_read_boxloc.triggered[bool].connect(self.menu_file_read_boxloc)
		self.mfile_save_boxloc.triggered[bool].connect(self.menu_file_save_boxloc)
		self.mfile_save_boxpdb.triggered[bool].connect(self.menu_file_save_boxpdb)
		
		# NOTE(review): triggered[bool] may pass the checked flag into
		# save_boxes' clsid parameter — confirm save_boxes tolerates a bool.
		self.mfile_save_boxes_stack.triggered[bool].connect(self.save_boxes)
		#self.mfile_quit.triggered[bool].connect(self.menu_file_quit)

		# all other widgets
		self.wdepth.valueChanged[int].connect(self.event_depth)
		self.wnlayers.valueChanged[int].connect(self.event_nlayers)
		self.wboxsize.valueChanged.connect(self.event_boxsize)
		#self.wmaxmean.clicked[bool].connect(self.event_projmode)
		#self.wscale.valueChanged.connect(self.event_scale)
		self.wfilt.valueChanged.connect(self.event_filter)
		self.wlocalbox.stateChanged[int].connect(self.event_localbox)

		# mouse/scale/origin signals of the three views
		self.xyview.mousemove.connect(self.xy_move)
		self.xyview.mousedown.connect(self.xy_down)
		self.xyview.mousedrag.connect(self.xy_drag)
		self.xyview.mouseup.connect(self.mouse_up)
		self.xyview.mousewheel.connect(self.xy_wheel)
		self.xyview.signal_set_scale.connect(self.event_scale)
		self.xyview.origin_update.connect(self.xy_origin)

		self.xzview.mousedown.connect(self.xz_down)
		self.xzview.mousedrag.connect(self.xz_drag)
		self.xzview.mouseup.connect(self.mouse_up)
		self.xzview.mousewheel.connect(self.xz_wheel)
		self.xzview.signal_set_scale.connect(self.event_scale)
		self.xzview.origin_update.connect(self.xz_origin)
		self.xzview.mousemove.connect(self.xz_move)

		self.zyview.mousedown.connect(self.zy_down)
		self.zyview.mousedrag.connect(self.zy_drag)
		self.zyview.mouseup.connect(self.mouse_up)
		self.zyview.mousewheel.connect(self.zy_wheel)
		self.zyview.signal_set_scale.connect(self.event_scale)
		self.zyview.origin_update.connect(self.zy_origin)
		self.zyview.mousemove.connect(self.zy_move)
		
		self.xyview.keypress.connect(self.key_press)
		self.datafilename=datafile
		self.basename=base_name(datafile)
		# filetag: text from the first "__" up to the extension, forced to end
		# in "_" (e.g. "a__b.hdf" -> "__b_"); "__" when the name has no "__".
		# It prefixes saved particle file names (see save_boxes).
		p0=datafile.find('__')
		if p0>0:
			p1=datafile.rfind('.')
			self.filetag=datafile[p0:p1]
			if self.filetag[-1]!='_':
				self.filetag+='_'
		else:
			self.filetag="__"
			
		# set_data also resets self.boxes; boxes are loaded from the json below
		data=EMData(datafile)
		self.set_data(data)

		# Boxviewer subwidget (details of a single box)
		#self.boxviewer=EMBoxViewer()
		#self.app().attach_child(self.boxviewer)

		# Boxes Viewer (z projections of all boxes)
		self.boxesviewer=EMImageMXWidget()
		
		#self.app().attach_child(self.boxesviewer)
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Particle List")
		self.boxesviewer.rzonce=True
		
		self.setspanel=EMTomoSetsPanel(self)

		self.optionviewer=EMTomoBoxerOptions(self)
		self.optionviewer.add_panel(self.setspanel,"Sets")
		
		
		self.optionviewer.show()
		
		self.boxesviewer.mx_image_selected.connect(self.img_selected)
		
		##################
		#### deal with metadata in the _info.json file...
		
		self.jsonfile=info_name(datafile)
		info=js_open_dict(self.jsonfile)
		
		#### read particle classes
		# class_list entries are either {"name":..., "boxsize":...} or just a
		# name (older files), in which case the box size defaults to 64
		self.sets={}
		self.boxsize={}
		if "class_list" in info:
			clslst=info["class_list"]
			for k in sorted(clslst.keys()):
				if type(clslst[k])==dict:
					self.sets[int(k)]=str(clslst[k]["name"])
					self.boxsize[int(k)]=int(clslst[k]["boxsize"])
				else:
					self.sets[int(k)]=str(clslst[k])
					self.boxsize[int(k)]=64
					
		# display colour per set id (indexed modulo the list length in update_box_shape)
		clr=QtGui.QColor
		self.setcolors=[QtGui.QBrush(clr("blue")),QtGui.QBrush(clr("green")),QtGui.QBrush(clr("red")),QtGui.QBrush(clr("cyan")),QtGui.QBrush(clr("purple")),QtGui.QBrush(clr("orange")), QtGui.QBrush(clr("yellow")),QtGui.QBrush(clr("hotpink")),QtGui.QBrush(clr("gold"))]
		self.sets_visible={}	# set ids currently shown; only key membership is tested in this file
				
		#### read boxes
		if "boxes_3d" in info:
			box=info["boxes_3d"]
			for i,b in enumerate(box):
				#### X-center,Y-center,Z-center,method,[score,[class #]]
				# missing trailing fields keep the defaults below
				bdf=[0,0,0,"manual",0.0, 0]
				for j,bi in enumerate(b):  bdf[j]=bi
				
				
				# a box referring to an unknown class creates that class
				if bdf[5] not in list(self.sets.keys()):
					clsi=int(bdf[5])
					self.sets[clsi]="particles_{:02d}".format(clsi)
					self.boxsize[clsi]=64
				
				self.boxes.append(bdf)
		
		###### this is the new (2018-09) metadata standard..
		### now we use coordinates at full size from center of tomogram so it works for different binning and clipping
		### have to make it compatible with older versions though..
		# convert stored (unbinned, centre-relative) coordinates and box sizes
		# into voxel units of the loaded data
		if "apix_unbin" in info:
			self.apix_unbin=info["apix_unbin"]
			self.apix_cur=apix=data["apix_x"]
			for b in self.boxes:
				b[0]=b[0]/apix*self.apix_unbin+data["nx"]//2
				b[1]=b[1]/apix*self.apix_unbin+data["ny"]//2
				b[2]=b[2]/apix*self.apix_unbin+data["nz"]//2
				
			for k in self.boxsize.keys():
				self.boxsize[k]=int(np.round(self.boxsize[k]*self.apix_unbin/apix))
		else:
			self.apix_unbin=-1	# marks old-style (uncentred, current-binning) coordinates
			
		info.close()
		
		E2loadappwin("e2sptboxer","main",self)
		E2loadappwin("e2sptboxer","boxes",self.boxesviewer.qt_parent)
		E2loadappwin("e2sptboxer","option",self.optionviewer)
		
		#### particle classes
		if len(self.sets)==0:
			self.new_set("particles_00")
		# NOTE(review): the visible set is the first key in dict order while
		# currentset is the smallest key; these can differ — confirm intent.
		self.sets_visible[list(self.sets.keys())[0]]=0
		self.currentset=sorted(self.sets.keys())[0]
		self.setspanel.update_sets()
		self.wboxsize.setValue(self.get_boxsize())

		#print(self.sets)
		for i in range(len(self.boxes)):
			self.update_box(i)
		
		self.update_all()
		self.initialized=True
		

	def set_data(self,data):
		"""Install ``data`` as the tomogram being boxed.

		Centres the view position in the volume, sizes the side panels and
		depth slider, discards all boxes and the selection, and redraws once
		the window has finished initialising.
		"""
		nx, ny, nz = data["nx"], data["ny"], data["nz"]

		self.data = data
		self.apix = data["apix_x"]
		self.datasize = (nx, ny, nz)
		self.x_loc, self.y_loc, self.z_loc = nx//2, ny//2, nz//2

		panel = max(250, nz)
		self.gbl.setRowMinimumHeight(1, panel)
		self.gbl.setColumnMinimumWidth(0, panel)
		print(nx, ny, nz)

		self.wdepth.setRange(0, nz-1)
		self.wdepth.setValue(nz//2)
		self.boxes = []
		self.curbox = -1

		if not self.initialized:
			return
		self.update_all()

	def eraser_width(self):
		"""Return the eraser radius from the options panel, as an int."""
		radius = self.optionviewer.eraser_radius.getValue()
		return int(radius)
		
	def get_cube(self,x,y,z, centerslice=False, boxsz=-1):
		"""Returns a box-sized cube at the given center location.

		:param centerslice: if true, return only the central z slice
		    (a bs x bs x 1 region) instead of a full cube.
		:param boxsz: box size to use; negative means the current set's size.
		If the centre lies too far outside the volume, a new EMData of the
		requested size is returned instead of a clip.
		"""
		bs = self.get_boxsize() if boxsz < 0 else boxsz
		bz = 1 if centerslice else bs

		# note: -bs//2 is (-bs)//2 (floor), kept exactly as the lower bound
		lo, loz = -bs//2, -bz//2
		half, halfz = bs//2, bz//2
		far_outside = (x < lo or y < lo or z < loz
			or x > self.data["nx"]+half or y > self.data["ny"]+half or z > self.data["nz"]+halfz)

		if far_outside:
			r = EMData(bs, bs, bz)
		else:
			r = self.data.get_clip(Region(x-half, y-half, z-halfz, bs, bs, bz))

		if self.apix != 0:
			r["apix_x"] = r["apix_y"] = r["apix_z"] = self.apix
		return r

	def get_slice(self,idx,thk=1,axis="z"):
		"""Return a slab of 2*(thk-1)+1 slices centred on ``idx`` along ``axis``.

		The slab is summed with the "misc.directional_sum" processor and then
		divided by the slice count. The rotated volume ``self.dataxf`` is used
		whenever ``self.globalxf`` is not the identity.
		"""
		source = self.data if self.globalxf.is_identity() else self.dataxf

		reach = int(thk-1)
		centre = int(idx)
		opts = {"axis": axis, "first": centre-reach, "last": centre+reach}
		slab = source.process("misc.directional_sum", opts)
		slab.div(reach*2+1)

		if self.apix != 0:
			slab["apix_x"] = slab["apix_y"] = slab["apix_z"] = self.apix
		return slab

	def event_boxsize(self):
		"""Box-size widget changed: store it for the current set and refresh its boxes."""
		requested = int(self.wboxsize.getValue())
		if self.get_boxsize() == requested:
			return

		self.boxsize[self.currentset] = requested

		# Hold initialized False while the boxes are updated, then redraw
		# once. Presumably update_box checks this flag (not visible here).
		self.initialized = False
		for idx in range(len(self.boxes)):
			if self.boxes[idx][5] != self.currentset:
				continue
			self.update_box(idx)
		self.initialized = True
		self.update_all()

	#def event_projmode(self,state):
		#"""Projection mode can be simple average (state=False) or maximum projection (state=True)"""
		#self.update_all()

	def event_scale(self,newscale):
		"""Apply a new zoom scale to all three orthogonal views."""
		for view in (self.xyview, self.xzview, self.zyview):
			view.set_scale(newscale)

	def event_depth(self):
		"""Depth slider moved: track its value as z and redraw when ready."""
		depth = self.wdepth.value()
		if depth != self.z_loc:
			self.z_loc = depth
		if not self.initialized:
			return
		self.update_sliceview()

	def event_nlayers(self):
		"""Slab thickness changed: redraw all views."""
		self.update_all()

	def event_filter(self):
		"""Lowpass filter slider changed: redraw all views."""
		self.update_all()

	def event_localbox(self,tog):
		"""'Limit Side Boxes' toggled: redraw only the two side views.

		``tog`` (the checkbox state) is ignored; update_sliceview reads the
		checkbox itself.
		"""
		side_views = ['x', 'y']
		self.update_sliceview(axis=side_views)

	def get_boxsize(self, clsid=-1):
		"""Return the box size (as int) of particle set ``clsid``.

		:param clsid: set id; negative (default) means the current set.
		:returns: the stored size. For an explicit ``clsid`` with no usable
		    stored size, a notice is printed and 32 is returned.
		:raises KeyError: if ``clsid`` is negative and the current set has no
		    stored size.
		"""
		if clsid<0:
			return int(self.boxsize[self.currentset])
		try:
			return int(self.boxsize[clsid])
		except (KeyError, TypeError, ValueError):
			# Narrowed from a bare ``except:``, which also swallowed
			# KeyboardInterrupt and SystemExit.
			print("No box size saved for {}..".format(clsid))
			return 32

	def nlayers(self):
		"""Return the slab thickness from the 'Thickness' spin box, as an int."""
		count = self.wnlayers.value()
		return int(count)

	def menu_file_read_boxloc(self):
		"""Append boxes read from a whitespace-separated text file.

		The first three columns of each line are x, y, z. Each is truncated
		to int and divided by ``self.shrink`` (with old_div). New boxes get
		method "manual", score 0.0 and the current set. Blank lines are
		skipped. Cancelling the dialog does nothing.
		"""
		fsp=str(QtWidgets.QFileDialog.getOpenFileName(self, "Select input text file")[0])
		if not fsp:
			return	# dialog cancelled
		if not os.path.isfile(fsp):
			print("file does not exist")
			return

		with open(fsp,"r") as f:
			for line in f:
				coords=[old_div(int(float(v)),self.shrink) for v in line.split()[:3]]
				if not coords:
					# A blank line used to add a spurious box at (0, 0, 0).
					continue
				bdf=[0,0,0,"manual",0.0, self.currentset]
				bdf[:len(coords)]=coords
				self.boxes.append(bdf)
				self.update_box(len(self.boxes)-1)

	def menu_file_save_boxloc(self):
		"""Write every box centre as tab-separated integer x, y, z.

		Coordinates are multiplied by ``self.shrink``, the inverse of the
		scaling applied by menu_file_read_boxloc. Cancelling the dialog does
		nothing.
		"""
		shrinkf=self.shrink

		fsp=str(QtWidgets.QFileDialog.getSaveFileName(self, "Select output text file")[0])
		if len(fsp)==0:
			return

		# ``with`` closes the file even if a write fails part-way
		with open(fsp,"w") as out:
			for b in self.boxes:
				out.write("%d\t%d\t%d\n"%(b[0]*shrinkf,b[1]*shrinkf,b[2]*shrinkf))
		
	def menu_file_save_boxpdb(self):
		"""Save the centres of boxes in visible sets as a PDB file.

		Coordinates are divided by 10 before writing, so the PDB should be
		viewed with a voxel size of 0.1 (as the final message says). Nothing
		is written when no box belongs to a visible set.
		"""
		fsp=str(QtWidgets.QFileDialog.getSaveFileName(self, "Select output PDB file", filter="PDB (*.pdb)")[0])
		if len(fsp)==0:
			return
		if fsp[-4:].lower()!=".pdb" :
			fsp+=".pdb"
		visible=set(self.sets_visible.keys())
		if len(visible)==0:
			print("No visible particles to save")
			return

		pts=[[b[0], b[1], b[2]] for b in self.boxes if int(b[5]) in visible]
		if len(pts)==0:
			# Visible sets exist but contain no boxes. Don't hand an empty
			# array to numpy2pdb.
			print("No visible particles to save")
			return
		bxs=np.array(pts)/10

		numpy2pdb(bxs, fsp)
		print("PDB saved to {}. Use voxel size 0.1".format(fsp))
		
	def save_boxes(self, clsid=None):
		"""Extract the boxed particles and append them to one HDF stack.

		:param clsid: iterable of set ids to save. None or empty (the default)
		    saves every box. The "Save Boxes as Stack" action is connected via
		    ``triggered[bool]``, so a bool may arrive here; it is treated like
		    None. Previously ``len(False)`` raised TypeError.

		The output goes to particles3d/ (3D mode) or particles/ (otherwise) as
		<basename><filetag><suffix>.hdf, and an existing file is replaced. 3D
		particles are normalized and contrast-inverted; 2D particles are
		single central slices. Every particle is cut at the box size of the
		first saved box.
		"""
		if clsid is None or isinstance(clsid, bool):
			clsid=[]
		else:
			clsid=list(clsid)

		if len(clsid)==0:
			defaultname="ptcls.hdf"
		else:
			defaultname="_".join([self.sets[i] for i in clsid])+".hdf"
		
		name,ok=QtWidgets.QInputDialog.getText( self, "Save particles", "Filename suffix:", text=defaultname)
		if not ok:
			return
		name=self.filetag+str(name)
		if name[-4:].lower()!=".hdf" :
			name+=".hdf"

		if self.options.mode=="3D":
			dr="particles3d"
			is2d=False
		else:
			dr="particles"
			is2d=True

		if not os.path.isdir(dr):
			os.mkdir(dr)
		
		fsp=os.path.join(dr,self.basename)+name

		print("Saving {} particles to {}".format(self.options.mode, fsp))
		
		if os.path.isfile(fsp):
			print("{} exist. Overwritting...".format(fsp))
			os.remove(fsp)

		selected=set(clsid)
		# tomogram centre; source coordinates are stored relative to it
		sz=[s//2 for s in self.datasize]

		progress = QtWidgets.QProgressDialog("Saving", "Abort", 0, len(self.boxes),None)

		boxsz=-1
		for i,b in enumerate(self.boxes):
			if selected and int(b[5]) not in selected:
				continue

			bs=self.get_boxsize(b[5])
			if boxsz<0:
				boxsz=bs
			elif boxsz!=bs:
				print("Inconsistant box size in the particles to save.. Using {:d}..".format(boxsz))
				bs=boxsz

			img=self.get_cube(b[0], b[1], b[2], centerslice=is2d, boxsz=bs)
			if not is2d:
				img.process_inplace('normalize')
			
			img["ptcl_source_image"]=self.datafilename
			img["ptcl_source_coord"]=(b[0]-sz[0], b[1]-sz[1], b[2]-sz[2])
			
			if not is2d: #### do not invert contrast for 2D images
				img.mult(-1)
			
			img.write_image(fsp,-1)

			progress.setValue(i+1)
			if progress.wasCanceled():
				break

		# Skipped boxes can leave the dialog short of its maximum, and it
		# only auto-closes at the maximum, so close it explicitly.
		progress.close()

	def update_sliceview(self, axis=('x','y','z')):
		"""Redraw the slice image and box outlines of the requested views.

		:param axis: iterable of view keys: 'z' is the xy view, 'y' the xz
		    view and 'x' the zy view. The default is an immutable tuple
		    rather than a shared list.

		A box is drawn when its set is visible and it intersects the current
		slice. In 2D mode the xy view uses a one-slice tolerance. With
		"Limit Side Boxes" unchecked, every box is drawn at full size in the
		side views.
		"""
		boxes=self.get_rotated_boxes()
		
		allside=(not self.wlocalbox.isChecked())
		
		# view key -> [box coordinate index, view widget, current slice position]
		pms={'z':[2, self.xyview, self.z_loc],
		     'y':[1, self.xzview, self.y_loc],
		     'x':[0, self.zyview, self.x_loc]}
		
		# Index of the line-width entry in a shape list. For circles it is 7,
		# matching the EMShape lists in update_box_shape. For rects it is 8,
		# presumably because rects store two corners.
		if self.boxshape=="circle": lwi=7
		else: lwi=8
		
		for ax in axis:
			view=pms[ax][1]
			if len(view.get_shapes())!=len(boxes):
				### something changes the box shapes; rebuild them all
				for i,b in enumerate(boxes):
					self.update_box_shape(i,b)

		filtval=self.wfilt.getValue()	# read once instead of twice per view
		for ax in axis:
			ia, view, loc=pms[ax]
			
			## update the box shapes
			shp=view.get_shapes()
			for i,b in enumerate(boxes):
				bs=self.get_boxsize(b[5])
				dst=abs(b[ia] - loc)
				
				inplane=dst<bs//2
				rad=bs//2-dst
				
				if ax!='z' and allside:
					## display all side view boxes in this mode
					inplane=True
					rad=bs//2
					
				if ax=='z' and self.options.mode=="2D":
					## boxes are 1 slice thick in 2d mode
					inplane=dst<1
				
				if inplane and (b[5] in self.sets_visible):
					shp[i][0]=self.boxshape
					## selected box is slightly thicker
					shp[i][lwi]=3 if self.curbox==i else 2
					if self.options.mode=="3D":
						shp[i][6]=rad
				else:
					shp[i][0]="hidden"
				
			view.shapechange=1
			img=self.get_slice(loc, self.nlayers(), ax)
			if filtval!=0.0:
				img.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/filtval,"apix":self.apix})

			view.set_data(img)
			
		self.update_coords()

	def get_rotated_boxes(self):
		"""Return the boxes as displayed, i.e. with ``self.globalxf`` applied.

		With an identity transform this returns ``self.boxes`` itself, not a
		copy. Otherwise it returns new [x, y, z, method, score, set] lists
		whose centres are transformed about the volume centre.
		"""
		if not self.boxes:
			return []
		if self.globalxf.is_identity():
			return self.boxes

		centre=[self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2]
		offsets=np.array([b[:3] for b in self.boxes])-centre
		moved=np.array([self.globalxf.transform(o.tolist()) for o in offsets])+centre
		return [[p[0],p[1],p[2],b[3],b[4],b[5]] for p,b in zip(moved,self.boxes)]
		
	def update_all(self):
		"""Redisplay all widgets (slice views and the particle list).

		Does nothing until a volume has been loaded.
		"""
		# identity test: `==` against an EMData object may be overloaded, `is None` is unambiguous
		if self.data is None:
			return

		self.update_sliceview()
		self.update_boximgs()

	def update_coords(self):
		"""Show the current cursor location, truncated to ints, in the coordinate label."""
		ix,iy,iz=(int(v) for v in (self.x_loc,self.y_loc,self.z_loc))
		self.wcoords.setText("X: {:d}\tY: {:d}\tZ: {:d}".format(ix,iy,iz))

	def inside_box(self,n,x=-1,y=-1,z=-1):
		"""Return True if the image-space point (x, y, z) lies inside box number n.

		A coordinate passed as a negative value is left out of the test. Boxes of
		hidden sets never match. In 3D mode the box is a sphere of radius
		boxsize/2. In 2D mode x/y use the same radius, and a given z must equal
		the box's slice exactly.
		"""
		box=self.boxes[n]
		if box[5] not in self.sets_visible:
			return False
		half=self.get_boxsize(box[5])/2
		limit=half**2
		dist2=0
		if x>=0:
			dist2+=(box[0]-x)**2
		if y>=0:
			dist2+=(box[1]-y)**2
		if z>=0:
			if self.options.mode=="3D":
				dist2+=(box[2]-z)**2
			elif box[2]!=z:
				# off-slice in 2D: push the distance far beyond the radius
				dist2+=1e3*limit
		return dist2<=limit

	def del_box(self, delids):
		"""Delete one or more boxes and renumber the remaining ones.

		delids : a single box index, or a list/tuple/set/range of indices.

		Removes the boxes, their cached images and their overlay shapes in all
		three views. Shape keys are renumbered so they stay aligned with
		self.boxes. Then clears the selection and redraws.
		"""
		# any collection of ids is accepted; previously only a list was, and a tuple deleted nothing
		if isinstance(delids, (list, tuple, set, frozenset, range)):
			drop=set(delids)
		else:
			drop={delids}
		
		# set membership keeps this O(n) instead of O(n * len(delids))
		kpids=[i for i in range(len(self.boxes)) if i not in drop]
		self.boxes=[self.boxes[i] for i in kpids]
		self.boxesimgs=[self.boxesimgs[i] for i in kpids]
		for view in (self.xyview, self.xzview, self.zyview):
			view.shapes={i:view.shapes[k] for i,k in enumerate(kpids)}
		self.curbox=-1
		self.update_all()
	
	def update_box_shape(self,n, box):
		"""(Re)create the overlay shape of box n in all three slice views.

		box is [x, y, z, method, score, set_id] in display coordinates. The color
		comes from the box's set. The current box gets a thicker outline (3 vs 2).
		In 3D mode circles of radius boxsize/2 are drawn. In 2D mode the xy view
		gets a boxsize square and the side views a strip spanning z-1..z+1.
		"""
		bs2=self.get_boxsize(box[5])//2
		lw=3 if n==self.curbox else 2
		color=self.setcolors[box[5]%len(self.setcolors)].color().getRgbF()
		cr,cg,cb=color[0],color[1],color[2]
		if self.options.mode=="3D":
			self.xyview.add_shape(n,EMShape(["circle",cr,cg,cb,box[0],box[1],bs2,lw]))
			self.xzview.add_shape(n,EMShape(["circle",cr,cg,cb,box[0],box[2],bs2,lw]))
			self.zyview.add_shape(n,EMShape(["circle",cr,cg,cb,box[2],box[1],bs2,lw]))
		else:
			# rects use the same selection-dependent width as circles (it was hard-coded to 2)
			self.xyview.add_shape(n,EMShape(["rect",cr,cg,cb,
				    box[0]-bs2,box[1]-bs2,box[0]+bs2,box[1]+bs2,lw]))
			self.xzview.add_shape(n,EMShape(["rect",cr,cg,cb,
				    box[0]-bs2,box[2]-1,box[0]+bs2,box[2]+1,lw]))
			self.zyview.add_shape(n,EMShape(["rect",cr,cg,cb,
				    box[2]-1,box[1]-bs2,box[2]+1,box[1]+bs2,lw]))
			
	
	def update_box(self,n,quiet=False):
		"""Refresh box n after it was added or moved.

		Rebuilds the box's overlay shape and, once the GUI is initialized,
		redraws the slice views. Unless `quiet` is set (used while dragging, for
		speed), it also re-extracts the normalized central slice of the box into
		self.boxesimgs and saves the boxes to the JSON info file. An out-of-range
		n is ignored.
		"""
		if n<0 or n>=len(self.boxes):
			return
		
		box=self.boxes[n]
		
		boxes=self.get_rotated_boxes()
		self.update_box_shape(n,boxes[n])

		if self.initialized: 
			self.update_sliceview()

		# For speed, we turn off updates while dragging a box around. Quiet is set until the mouse-up
		if not quiet:
			# Get the central slice of the box from the original data (normalized)
			proj=self.get_cube(box[0], box[1], box[2], centerslice=True, boxsz=self.get_boxsize(box[5]))
			proj.process_inplace("normalize")
			
			# pad the image cache so that index n exists
			self.boxesimgs.extend([None]*(n+1-len(self.boxesimgs)))
			self.boxesimgs[n]=proj

			if self.initialized: self.SaveJson()
			
		if self.initialized:
			self.update_boximgs()

		self.curbox=n

	def update_boximgs(self):
		"""Send the cached images of boxes in visible sets to the particle-list viewer.

		self.boxids records, for each viewer position, the matching box index.
		"""
		visible=self.sets_visible
		self.boxids=[idx for idx in range(len(self.boxesimgs)) if self.boxes[idx][5] in visible]
		shown=[self.boxesimgs[idx] for idx in self.boxids]
		self.boxesviewer.set_data(shown)
		self.boxesviewer.update()
		return

	def img_selected(self,event,lc):
		"""Handle a click in the particle-list viewer.

		lc[0] is the clicked position in the viewer; self.boxids maps it back to
		a box index. Shift-click deletes that box. Shift+Ctrl-click deletes it and
		every later box. A plain click selects it and centers the views on it.
		"""
		lci=self.boxids[lc[0]]
		mods=event.modifiers()
		if not mods&Qt.ShiftModifier:
			self.curbox=lci
			bx,by,bz=self.boxes[lci][:3]
			self.x_loc,self.y_loc,self.z_loc=self.rotate_coord([bx,by,bz], inv=False)
			self.scroll_to(self.x_loc,self.y_loc,self.z_loc)
			self.update_sliceview()
			return
		if mods&Qt.ControlModifier:
			self.del_box(list(range(lci, len(self.boxes))))
		else:
			self.del_box(lci)
			
			
	def rotate_coord(self, p, inv=True):
		"""Map a 3-D point between the display (flattened) and the stored frame.

		inv=True (default) applies the inverse of self.globalxf; inv=False applies
		the forward transform. The rotation is about the volume center. When the
		global transform is the identity, p is returned unchanged (same object).
		"""
		if self.globalxf.is_identity():
			return p
		cnt=[self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2]
		shifted=[p[k]-cnt[k] for k in range(3)]
		xf=Transform(self.globalxf)
		if inv:
			xf.invert()
		# NOTE(review): relies on the transform result supporting `+ list` (presumably an EMAN2 Vec3f) - confirm
		return xf.transform(shifted)+cnt

	def del_region_xy(self, x=-1, y=-1, z=-1, rad=-1):
		"""Delete every visible box whose center lies strictly within `rad` of (x, y, z).

		Distances are measured against the rotated (display-frame) boxes. Negative
		coordinates are left out of the distance. A negative rad means "use the
		eraser width from the options panel".
		"""
		if rad<0:
			rad=self.eraser_width()
		
		r2=rad**2
		hits=[]
		for idx,bx in enumerate(self.get_rotated_boxes()):
			if bx[5] not in self.sets_visible:
				continue
			d2=(x>=0)*(bx[0]-x)**2 + (y>=0)*(bx[1]-y)**2 + (z>=0)*(bx[2]-z)**2
			if d2<r2:
				hits.append(idx)
		self.del_box(hits)
	
	def scroll_to(self, x,y,z, axis=""):
		"""Scroll the slice views so that (x, y) is centered.

		The view whose normal is `axis` ("x", "y" or "z") is left alone. Side views
		are centered on the middle of the z range. z itself is not used.
		"""
		zmid=self.data["nz"]/2
		targets=(("z", self.xyview, (x, y)),
		         ("y", self.xzview, (x, zmid)),
		         ("x", self.zyview, (zmid, y)))
		for skip, view, (a, b) in targets:
			if axis!=skip:
				view.scroll_to(a, b, True)
	
	#### mouse click
	def xy_down(self,event):
		"""Mouse press in the xy view: forward the image point at the current z slice."""
		pos=(event.x(),event.y())
		ix,iy=self.xyview.scr_to_img(pos)
		self.mouse_down(event, ix, iy, self.z_loc, "z")
		
	def xz_down(self,event):
		"""Mouse press in the xz view: forward the image point at the current y slice."""
		pos=(event.x(),event.y())
		ix,iz=self.xzview.scr_to_img(pos)
		self.mouse_down(event, ix, self.y_loc, iz, "y")
			
	def zy_down(self,event):
		"""Mouse press in the zy view: forward the image point at the current x slice."""
		pos=(event.x(),event.y())
		iz,iy=self.zyview.scr_to_img(pos)
		self.mouse_down(event, self.x_loc, iy, iz, "x")
		
	def mouse_down(self,event, x, y, z, axis):
		"""Handle a mouse press at display coordinates (x, y, z) in one slice view.

		axis is the clicked view's normal ("z" for xy, "y" for xz, "x" for zy)
		and is passed on to scroll_to. Presses with any negative coordinate are
		ignored. In eraser mode, boxes within the eraser radius are deleted.
		Otherwise a press inside an existing box (the first match by index)
		deletes it with Shift held, or starts dragging it. A press on empty
		space without Shift adds a new box to the current set and starts
		dragging it.
		"""
		if min(x,y,z)<0: return
		
		# stored box centers live in the unflattened frame, so map the click back
		xr,yr,zr=self.rotate_coord((x,y,z))
		#print(x,y,z,xr,yr,zr)
	
		if self.optionviewer.erasercheckbox.isChecked():
			# del_region_xy compares against rotated (display-frame) boxes, hence display coords
			self.del_region_xy(x,y,z)
			return
			
		for i in range(len(self.boxes)):
			if self.inside_box(i,xr,yr,zr):
				
				if event.modifiers()&Qt.ShiftModifier:  ## delete box
					self.del_box(i)

				else:  ## start dragging
					self.dragging=i
					self.curbox=i
					self.scroll_to(x,y,z,axis)
					
				break
		else:
			# for/else: reached only when no existing box was hit
			if not event.modifiers()&Qt.ShiftModifier: ## add box

				self.x_loc, self.y_loc, self.z_loc=x,y,z
				self.scroll_to(x,y,z,axis)
				self.curbox=len(self.boxes)
				# new box: [x, y, z, method, score, set_id] in the stored frame
				self.boxes.append(([xr,yr,zr, 'manual', 0.0, self.currentset]))
				self.update_box(len(self.boxes)-1)
				self.dragging=len(self.boxes)-1
				
				

	#### eraser mode
	def xy_move(self,event):
		"""Mouse motion over the xy view (eraser cursor handling)."""
		target=self.xyview
		self.mouse_move(event, target)
			
	def xz_move(self,event):
		"""Mouse motion over the xz view (eraser cursor handling)."""
		target=self.xzview
		self.mouse_move(event, target)
		
	def zy_move(self,event):
		"""Mouse motion over the zy view (eraser cursor handling)."""
		target=self.zyview
		self.mouse_move(event, target)
			
	
	def mouse_move(self,event,view):
		"""Update the eraser cursor for mouse motion over `view`.

		In eraser mode the cursor circle is cleared from all three views and
		redrawn on `view` at the mouse position with the eraser radius.
		Otherwise only `view`'s eraser cursor is removed.
		"""
		if not self.optionviewer.erasercheckbox.isChecked():
			view.eraser_shape=None
			return
		for v in (self.xyview, self.xzview, self.zyview):
			v.eraser_shape=None
		px,py=view.scr_to_img((event.x(),event.y()))
		view.eraser_shape=EMShape(["circle",1,1,1,px,py,self.eraser_width(),2])
		view.shapechange=1
		view.update()
			
	
	#### dragging...
	def mouse_drag(self,x, y, z):
		"""Move the box being dragged to display position (x, y, z).

		Ignored when no drag is active or any coordinate is negative. The cursor
		location follows the display position. The stored box center is that
		point mapped back through rotate_coord. The update is quiet until the
		mouse is released.
		"""
		if self.dragging<0 or min(x,y,z)<0:
			return
		self.x_loc, self.y_loc, self.z_loc=x,y,z
		rx,ry,rz=self.rotate_coord((x,y,z))
		self.boxes[self.dragging][:3]=rx,ry,rz
		self.update_box(self.dragging,True)

	def xy_drag(self,event):
		"""Mouse drag in the xy view: move the dragged box within the current z slice."""
		if self.dragging<0:
			return
		ix,iy=self.xyview.scr_to_img((event.x(),event.y()))
		self.mouse_drag(ix,iy,self.z_loc)

	def xz_drag(self,event):
		"""Mouse drag in the xz view: move the dragged box within the current y slice."""
		if self.dragging<0:
			return
		ix,iz=self.xzview.scr_to_img((event.x(),event.y()))
		self.mouse_drag(ix,self.y_loc,iz)
	
	def zy_drag(self,event):
		"""Mouse drag in the zy view: move the dragged box within the current x slice."""
		if self.dragging<0:
			return
		iz,iy=self.zyview.scr_to_img((event.x(),event.y()))
		self.mouse_drag(self.x_loc,iy,iz)
		
	def mouse_up(self,event):
		"""Mouse release: finish any drag with a full (non-quiet) box update."""
		active=self.dragging
		if active>=0:
			self.update_box(active)
		self.dragging=-1

	
	#### keep the same origin for the 3 views
	def xy_origin(self,newor):
		"""Keep the side views aligned after the xy view pans.

		The xz view follows the new x origin and the zy view the new y origin.
		"""
		ox,oy=newor[0],newor[1]
		self.xzview.set_origin(ox,self.xzview.get_origin()[1],True)
		self.zyview.set_origin(self.zyview.get_origin()[0],oy,True)
	
	def xz_origin(self,newor):
		"""After the xz view pans, make the xy view follow its new x origin."""
		keep_y=self.xyview.get_origin()[1]
		self.xyview.set_origin(newor[0],keep_y,True)

	def zy_origin(self,newor):
		"""After the zy view pans, make the xy view follow its new y origin."""
		keep_x=self.xyview.get_origin()[0]
		self.xyview.set_origin(keep_x,newor[1],True)


	##### go up/down with shift+wheel
	def xy_wheel(self, event):
		"""Wheel over the xy view: step the z slice by one in the wheel direction.

		The target slice must lie in [0, nz). Slice 0 is a valid slice; the old
		`z>0` test made it unreachable with the wheel.
		"""
		z=int(self.z_loc+ np.sign(event.angleDelta().y()))
		if 0<=z<self.data["nz"]:
			self.wdepth.setValue(z)
	
	def xz_wheel(self, event):
		"""Wheel over the xz view: step the y slice by one and redraw that view.

		The target slice must lie in [0, ny). Slice 0 is valid; it was excluded
		by the old `y>0` test.
		"""
		y=int(self.y_loc+np.sign(event.angleDelta().y()))
		if 0<=y<self.data["ny"]:
			self.y_loc=y
			self.update_sliceview(['y'])
		
	def zy_wheel(self, event):
		"""Wheel over the zy view: step the x slice by one and redraw that view.

		The target slice must lie in [0, nx). Slice 0 is valid; it was excluded
		by the old `x>0` test.
		"""
		x=int(self.x_loc+np.sign(event.angleDelta().y()))
		if 0<=x<self.data["nx"]:
			self.x_loc=x
			self.update_sliceview(['x'])
			

	########
	def set_current_set(self, name):
		"""Make set `name` (normalized by parse_setname) the active set, show its box size, and redraw."""
		self.currentset=parse_setname(name)
		self.wboxsize.setValue(self.get_boxsize())
		self.update_all()
		return
	
	
	def hide_set(self, name):
		"""Hide the boxes of set `name`, then redraw if the GUI is initialized."""
		self.sets_visible.pop(parse_setname(name), None)
		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
		return
	
	
	def show_set(self, name):
		"""Make set `name` visible, then redraw if the GUI is initialized."""
		sid=parse_setname(name)
		self.sets_visible[sid]=0
		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
		return
	
	
	def delete_set(self, name):
		"""Remove set `name`.

		Deletes all of its boxes, drops its visibility, name and box-size
		entries, and redraws.
		"""
		sid=parse_setname(name)
		## indices of the boxes belonging to the set (these are deleted)
		doomed=[idx for idx,bx in enumerate(self.boxes) if bx[5]==int(sid)]
		self.del_box(doomed)
		
		for table in (self.sets_visible, self.sets, self.boxsize):
			table.pop(sid, None)
		
		self.update_all()
		
		return
	
	def rename_set(self, oldname,  newname):
		"""Rename an existing set. Unknown set names are ignored."""
		sid=parse_setname(oldname)
		if sid not in self.sets:
			return
		self.sets[sid]=newname
	
	
	def new_set(self, name):
		"""Create a visible set called `name` with the lowest unused integer id.

		The default box size is 32 in 3D mode and 64 otherwise.
		"""
		# among len(sets)+1 candidate ids at least one is always free
		sid=next(k for k in range(len(self.sets)+1) if k not in self.sets)
		self.sets[sid]=name
		self.sets_visible[sid]=0
		self.boxsize[sid]=32 if self.options.mode=="3D" else 64
		return
	
	def save_set(self):
		"""Save the particles of every currently visible set."""
		visible_ids=[sid for sid in self.sets_visible]
		self.save_boxes(visible_ids)
		return
	
	
	def key_press(self,event):
		"""Keyboard shortcuts: "`" moves one slice up and "1" one slice down; other keys are re-emitted."""
		# 96 == "`", 49 == "1" (arrow keys are taken by the image widget)
		step={96:1, 49:-1}.get(event.key())
		if step is None:
			self.keypress.emit(event)
			return
		self.wdepth.setValue(self.z_loc+step)

	def flatten_tomo(self):
		"""Level the tomogram using a plane fitted through the visible particles.

		A 3-component PCA of the visible box centers is computed, and its third
		component is used as the plane normal. Its sign is flipped so the z part
		is non-negative. A Transform is built from that normal with set_rotation,
		inverted, and reduced to its xtilt/ytilt (ztilt is zeroed). The result is
		stored in self.globalxf. A transformed copy of the volume is kept in
		self.dataxf, and all box shapes are rebuilt in the new frame. Needs at
		least 3 visible boxes.
		"""
		print("Flatten tomogram by particles coordinates")
		vis=list(self.sets_visible.keys())
		pts=[b[:3] for b in self.boxes if b[5] in vis]
		if len(pts)<3:
			print("Too few visible particles. Cannot flatten tomogram.")
			return
		pts=np.array(pts)
		# NOTE(review): PCA is presumably sklearn's (components_ ordered by decreasing variance, so c[2] ~ plane normal) - confirm
		pca=PCA(3)
		pca.fit(pts);
		c=pca.components_
		t=Transform()
		# cc is a view of c[2] (assuming components_ is an ndarray), so the in-place sign flip also changes c[2] used below
		cc=c[2]
		if cc[2]!=0:
			cc*=np.sign(cc[2])
		
		t.set_rotation(c[2].tolist())
		t.invert()
		xyz=t.get_params("xyz")
		# keep only the tilts; in-plane rotation is discarded
		xyz["ztilt"]=0
		print("xtilt {:.02f}, ytilt {:.02f}".format(xyz["xtilt"], xyz["ytilt"]))
		t=Transform(xyz)
		self.globalxf=t
		self.dataxf=self.data.process("xform",{"transform":t})
		
		# shapes are rebuilt from scratch in the rotated frame
		self.xyview.shapes={}
		self.zyview.shapes={}
		self.xzview.shapes={}
		
		boxes=self.get_rotated_boxes()
		for i,b in enumerate(boxes):
			self.update_box_shape(i,b)
		
		self.update_sliceview()
		print("Done")
	
	def reset_flatten_tomo(self, event):
		"""Undo flattening: restore an identity global transform and rebuild every box shape.

		NOTE(review): self.dataxf (set by flatten_tomo) is not touched here - confirm whether it should be reset.
		"""
		self.globalxf=Transform()
		for view in (self.xyview, self.zyview, self.xzview):
			view.shapes={}
		for idx,bx in enumerate(self.get_rotated_boxes()):
			self.update_box_shape(idx,bx)
		self.update_sliceview()
		

	def SaveJson(self):
		"""Write the boxes and set definitions to the tomogram's JSON info file.

		If the info file has "apix_unbin", box centers are stored relative to the
		volume center and scaled by apix_cur/apix_unbin, and box sizes are scaled
		the same way and rounded. Otherwise the in-memory values are written
		unchanged. The info file is closed even when writing fails.
		"""
		info=js_open_dict(self.jsonfile)
		try:
			sx,sy,sz=(self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2)
			if "apix_unbin" in info:
				bxs=[[(b0[0]-sx)*self.apix_cur/self.apix_unbin,
					(b0[1]-sy)*self.apix_cur/self.apix_unbin,
					(b0[2]-sz)*self.apix_cur/self.apix_unbin,
					b0[3], b0[4], b0[5]] for b0 in self.boxes]
				bxsz={k:np.round(self.boxsize[k]*self.apix_cur/self.apix_unbin) for k in self.boxsize.keys()}
			else:
				bxs=self.boxes
				bxsz=self.boxsize

			info["boxes_3d"]=bxs
			clslst={}
			for key in list(self.sets.keys()):
				clslst[int(key)]={
					"name":self.sets[key],
					"boxsize":int(bxsz[key]),
					}
			info["class_list"]=clslst
		finally:
			# previously any exception above left the info dict open
			info.close()
	
	def closeEvent(self,event):
		"""Save boxes and window geometry, close the helper windows, and emit module_closed."""
		print("Exiting")
		self.SaveJson()
		
		E2saveappwin("e2sptboxer","main",self)
		boxes_parent=self.boxesviewer.qt_parent
		E2saveappwin("e2sptboxer","boxes",boxes_parent)
		E2saveappwin("e2sptboxer","option",self.optionviewer)
		
		for window in (self.boxesviewer, self.optionviewer):
			window.close()
		
		self.module_closed.emit() # this signal is important when e2ctf is being used by a program running its own event loop
示例#3
0
class EMBoxViewer(QtGui.QWidget):
	"""This is a multi-paned view showing a single boxed out particle from a larger tomogram"""

	def __init__(self):
		"""Build the 2x2 grid of projection panels plus a low-pass filter slider."""
		QtGui.QWidget.__init__(self)
		self.setWindowTitle("Single Particle View")

		self.resize(300,300)

		self.gbl = QtGui.QGridLayout(self)
		self.xyview = EMImage2DWidget()
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget()
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget()
		self.gbl.addWidget(self.zyview,0,0)
		self.data = None


		# This puts an isosurface view in the lower left corner, but was causing a lot of segfaults, so switching to 2-D slices for now
		#self.d3view = EMScene3D()
		#self.d3viewdata = EMDataItem3D(test_image_3d(3), transform=Transform())
		#isosurface = EMIsosurface(self.d3viewdata, transform=Transform())
		#self.d3view.insertNewNode('', self.d3viewdata, parentnode=self.d3view)
		#self.d3view.insertNewNode("Iso", isosurface, parentnode=self.d3viewdata )

		self.d3view = EMImage2DWidget()
		self.gbl.addWidget(self.d3view,1,0)

		self.wfilt = ValSlider(rng=(0,50),label="Filter:",value=0.0)
		self.gbl.addWidget(self.wfilt,2,0,1,2)

		self.wfilt.valueChanged.connect(self.event_filter)

		self.gbl.setRowStretch(2,1)
		self.gbl.setRowStretch(0,5)
		self.gbl.setRowStretch(1,5)
		
	def set_data(self,data):
		"""Sets the current volume to display"""

		self.data=data
		self.fdata=data

		self.update()
		self.show()

	def get_data(self):
		"""Return the volume currently displayed (None if none)."""
		return self.data

	def update(self):
		"""Redraw the three directional-sum projections and the full-volume panel.

		Without a volume the projections are cleared and a test image is shown.
		The Gaussian low-pass is applied only while the filter slider is above 4.
		Otherwise the unfiltered volume is shown.
		"""
		if self.data is None:
			self.xyview.set_data(None)
			self.xzview.set_data(None)
			self.zyview.set_data(None)

			#self.d3viewdata.setData(test_image_3d(3))
			#self.d3view.updateSG()
			self.d3view.set_data(test_image_3d(3))

			return

		if self.wfilt.getValue()>4 :
			self.fdata=self.data.process("filter.lowpass.gauss",{"cutoff_freq":old_div(1.0,self.wfilt.getValue()),"apix":self.data['apix_x']}) #JESUS
		else:
			# without this reset, a previously filtered copy stayed on screen after lowering the slider
			self.fdata=self.data

		xyd=self.fdata.process("misc.directional_sum",{"axis":"z"})
		xzd=self.fdata.process("misc.directional_sum",{"axis":"y"})
		zyd=self.fdata.process("misc.directional_sum",{"axis":"x"})

		self.xyview.set_data(xyd)
		self.xzview.set_data(xzd)
		self.zyview.set_data(zyd)

		#self.d3viewdata.setData(self.fdata)
		#self.d3view.updateSG()
		self.d3view.set_data(self.fdata)


	def event_filter(self,value):
		"""Filter slider moved: redraw with the new setting."""
		self.update()

	def closeEvent(self, event):
		"""Close all four image panels together with this window."""
		self.d3view.close()
		self.xyview.close()
		self.xzview.close()
		self.zyview.close()
示例#4
0
class EMTomoBoxer(QtGui.QMainWindow):
	"""This class represents the EMTomoBoxer application instance.  """
	keypress = QtCore.pyqtSignal(QtGui.QKeyEvent)
	module_closed = QtCore.pyqtSignal()

	def __init__(self,application,options,datafile):
		"""Build the main boxer window for the tomogram in `datafile`.

		application : EMAN2 application object (kept only as a weak reference).
		options     : options object; this code reads .apix and .mode ("3D" or other = 2D).
		datafile    : path of the tomogram to load.

		Creates the menus, the three orthogonal slice views and the control
		panel. Loads the volume and opens the particle-list and options windows.
		Restores sets and boxes from the JSON info file, converting from the
		center-relative, unbinned-pixel convention when "apix_unbin" is present.
		"""
		# NOTE(review): QWidget.__init__ is called although the base class is QMainWindow - confirm this is intended
		QtGui.QWidget.__init__(self)
		self.initialized=False
		self.app=weakref.ref(application)
		self.options=options
		self.yshort=False
		self.apix=options.apix
		self.currentset=0
		self.shrink=1#options.shrink
		self.setWindowTitle("Main Window (e2spt_boxer.py)")
		# 3D boxes are drawn as circles, 2D boxes as rectangles
		if options.mode=="3D":
			self.boxshape="circle"
		else:
			self.boxshape="rect"


		# Menu Bar
		self.mfile=self.menuBar().addMenu("File")
		self.mfile_open=self.mfile.addAction("Open")
		self.mfile_read_boxloc=self.mfile.addAction("Read Box Coord")
		self.mfile_save_boxloc=self.mfile.addAction("Save Box Coord")
		self.mfile_save_boxes_stack=self.mfile.addAction("Save Boxes as Stack")
		self.mfile_quit=self.mfile.addAction("Quit")


		self.setCentralWidget(QtGui.QWidget())
		self.gbl = QtGui.QGridLayout(self.centralWidget())

		# relative stretch factors
		self.gbl.setColumnStretch(0,1)
		self.gbl.setColumnStretch(1,4)
		self.gbl.setColumnStretch(2,0)
		self.gbl.setRowStretch(1,1)
		self.gbl.setRowStretch(0,4)

		# 3 orthogonal restricted projection views
		self.xyview = EMImage2DWidget()
		self.gbl.addWidget(self.xyview,0,1)

		self.xzview = EMImage2DWidget()
		self.gbl.addWidget(self.xzview,1,1)

		self.zyview = EMImage2DWidget()
		self.gbl.addWidget(self.zyview,0,0)

		# Select Z for xy view
		self.wdepth = QtGui.QSlider()
		self.gbl.addWidget(self.wdepth,1,2)

		### Control panel area in upper left corner
		self.gbl2 = QtGui.QGridLayout()
		self.gbl.addLayout(self.gbl2,1,0)

		#self.wxpos = QtGui.QSlider(Qt.Horizontal)
		#self.gbl2.addWidget(self.wxpos,0,0)
		
		#self.wypos = QtGui.QSlider(Qt.Vertical)
		#self.gbl2.addWidget(self.wypos,0,3,6,1)
		
		# box size
		self.wboxsize=ValBox(label="Box Size:",value=0)
		self.gbl2.addWidget(self.wboxsize,2,0,1,2)

		# max or mean
		#self.wmaxmean=QtGui.QPushButton("MaxProj")
		#self.wmaxmean.setCheckable(True)
		#self.gbl2.addWidget(self.wmaxmean,3,0)

		# number slices
		self.wnlayers=QtGui.QSpinBox()
		self.wnlayers.setMinimum(1)
		self.wnlayers.setMaximum(256)
		self.wnlayers.setValue(1)
		self.gbl2.addWidget(self.wnlayers,3,1)

		# Local boxes in side view
		self.wlocalbox=QtGui.QCheckBox("Limit Side Boxes")
		self.gbl2.addWidget(self.wlocalbox,3,0)
		self.wlocalbox.setChecked(True)

		# scale factor
		self.wscale=ValSlider(rng=(.1,2),label="Sca:",value=1.0)
		self.gbl2.addWidget(self.wscale,4,0,1,2)

		# 2-D filters
		self.wfilt = ValSlider(rng=(0,150),label="Filt:",value=0.0)
		self.gbl2.addWidget(self.wfilt,5,0,1,2)
		
		self.curbox=-1
		
		self.boxes=[]						# array of box info, each is (x,y,z,...)
		self.boxesimgs=[]					# z projection of each box
		self.xydown=self.xzdown=self.zydown=None
		self.firsthbclick = None

		# coordinate display
		self.wcoords=QtGui.QLabel("X: " + str(self.get_x()) + "\t\t" + "Y: " + str(self.get_y()) + "\t\t" + "Z: " + str(self.get_z()))
		self.gbl2.addWidget(self.wcoords, 1, 0, 1, 2)

		# file menu
		self.mfile_open.triggered[bool].connect(self.menu_file_open)
		self.mfile_read_boxloc.triggered[bool].connect(self.menu_file_read_boxloc)
		self.mfile_save_boxloc.triggered[bool].connect(self.menu_file_save_boxloc)
		# NOTE(review): triggered[bool] passes the checked flag as save_boxes' clsid argument - save_boxes must tolerate a bool
		self.mfile_save_boxes_stack.triggered[bool].connect(self.save_boxes)
		self.mfile_quit.triggered[bool].connect(self.menu_file_quit)

		# all other widgets
		self.wdepth.valueChanged[int].connect(self.event_depth)
		self.wnlayers.valueChanged[int].connect(self.event_nlayers)
		self.wboxsize.valueChanged.connect(self.event_boxsize)
		#self.wmaxmean.clicked[bool].connect(self.event_projmode)
		self.wscale.valueChanged.connect(self.event_scale)
		self.wfilt.valueChanged.connect(self.event_filter)
		self.wlocalbox.stateChanged[int].connect(self.event_localbox)

		self.xyview.mousemove.connect(self.xy_move)
		self.xyview.mousedown.connect(self.xy_down)
		self.xyview.mousedrag.connect(self.xy_drag)
		self.xyview.mouseup.connect(self.xy_up)
		self.xyview.mousewheel.connect(self.xy_wheel)
		self.xyview.signal_set_scale.connect(self.xy_scale)
		self.xyview.origin_update.connect(self.xy_origin)

		self.xzview.mousedown.connect(self.xz_down)
		self.xzview.mousedrag.connect(self.xz_drag)
		self.xzview.mouseup.connect(self.xz_up)
		self.xzview.signal_set_scale.connect(self.xz_scale)
		self.xzview.origin_update.connect(self.xz_origin)

		self.zyview.mousedown.connect(self.zy_down)
		self.zyview.mousedrag.connect(self.zy_drag)
		self.zyview.mouseup.connect(self.zy_up)
		self.zyview.signal_set_scale.connect(self.zy_scale)
		self.zyview.origin_update.connect(self.zy_origin)
		
		self.xyview.keypress.connect(self.key_press)
		self.datafilename=datafile
		self.basename=base_name(datafile)
		# filetag: the "__tag" part of the filename (ending in "_"), used as the saved-particle filename prefix
		p0=datafile.find('__')
		if p0>0:
			p1=datafile.rfind('.')
			self.filetag=datafile[p0:p1]
			if self.filetag[-1]!='_':
				self.filetag+='_'
		else:
			self.filetag="__"
			
		data=EMData(datafile)
		self.set_data(data)

		# Boxviewer subwidget (details of a single box)
		self.boxviewer=EMBoxViewer()
		#self.app().attach_child(self.boxviewer)

		# Boxes Viewer (z projections of all boxes)
		self.boxesviewer=EMImageMXWidget()
		
		#self.app().attach_child(self.boxesviewer)
		self.boxesviewer.show()
		self.boxesviewer.set_mouse_mode("App")
		self.boxesviewer.setWindowTitle("Particle List")
		self.boxesviewer.rzonce=True
		
		self.setspanel=EMTomoSetsPanel(self)

		self.optionviewer=EMTomoBoxerOptions(self)
		self.optionviewer.add_panel(self.setspanel,"Sets")
		
		
		self.optionviewer.show()
		
		# Average viewer shows results of background tomographic processing
#		self.averageviewer=EMAverageViewer(self)
		#self.averageviewer.show()

		self.boxesviewer.mx_image_selected.connect(self.img_selected)
		
		self.jsonfile=info_name(datafile)
		
		# restore set names / box sizes; entries are either {"name","boxsize"} dicts or (older format) plain names
		info=js_open_dict(self.jsonfile)
		self.sets={}
		self.boxsize={}
		if "class_list" in info:
			clslst=info["class_list"]
			for k in sorted(clslst.keys()):
				if type(clslst[k])==dict:
					self.sets[int(k)]=str(clslst[k]["name"])
					self.boxsize[int(k)]=int(clslst[k]["boxsize"])
				else:
					self.sets[int(k)]=str(clslst[k])
					# NOTE(review): `boxsize` is not defined in this method; unless a module-level `boxsize` exists this raises NameError - confirm
					self.boxsize[int(k)]=boxsize
				
					
		
			
			
		clr=QtGui.QColor
		self.setcolors=[clr("blue"),clr("green"),clr("red"),clr("cyan"),clr("purple"),clr("orange"), clr("yellow"),clr("hotpink"),clr("gold")]
		self.sets_visible={}
				
		if "boxes_3d" in info:
			box=info["boxes_3d"]
			for i,b in enumerate(box):
				#### X-center,Y-center,Z-center,method,[score,[class #]]
				# missing trailing fields keep the defaults below
				bdf=[0,0,0,"manual",0.0, 0]
				for j,bi in enumerate(b):  bdf[j]=bi
				
				
				# boxes referencing an unknown set get an auto-named set
				if bdf[5] not in list(self.sets.keys()):
					clsi=int(bdf[5])
					self.sets[clsi]="particles_{:02d}".format(clsi)
					# NOTE(review): same undefined `boxsize` name as above - confirm
					self.boxsize[clsi]=boxsize
				
				self.boxes.append(bdf)
		
		
		###### this is the new (2018-09) metadata standard..
		### now we use coordinates at full size from center of tomogram so it works for different binning and clipping
		### have to make it compatible with older versions though..
		if "apix_unbin" in info:
			self.apix_unbin=info["apix_unbin"]
			self.apix_cur=apix=data["apix_x"]
			# convert center-relative unbinned coordinates to corner-origin pixels of this volume
			for b in self.boxes:
				b[0]=b[0]/apix*self.apix_unbin+data["nx"]//2
				b[1]=b[1]/apix*self.apix_unbin+data["ny"]//2
				b[2]=b[2]/apix*self.apix_unbin+data["nz"]//2
				
			for k in self.boxsize.keys():
				self.boxsize[k]=self.boxsize[k]/apix*self.apix_unbin
		else:
			self.apix_unbin=-1
		
		
		
		info.close()
		if len(self.sets)==0:
			self.new_set("particles_00")
		# the first set in dict order is made visible, while currentset is the smallest id; these can differ
		self.sets_visible[list(self.sets.keys())[0]]=0
		self.currentset=sorted(self.sets.keys())[0]
		self.setspanel.update_sets()
		self.wboxsize.setValue(self.get_boxsize())

		print(self.sets)
		for i in range(len(self.boxes)):
			self.update_box(i)
		
		self.update_all()
		self.initialized=True

	def set_data(self,data):
		"""Install `data` as the displayed volume.

		Caches apix and (nx, ny, nz), sets the depth slider range to the z extent
		and moves it to the middle slice, clears the box list and selection, and
		redraws once the GUI is initialized.
		"""
		self.data=data
		self.apix=data["apix_x"]

		nx,ny,nz=data["nx"],data["ny"],data["nz"]
		self.datasize=(nx,ny,nz)

		self.wdepth.setRange(0,nz-1)
		self.boxes=[]
		self.curbox=-1

		self.wdepth.setValue(old_div(nz,2))
		if not self.initialized:
			return
		self.update_all()

	def eraser_width(self):
		"""Return the eraser radius from the options panel, truncated to int."""
		radius=self.optionviewer.eraser_radius.getValue()
		return int(radius)
		
	def get_cube(self,x,y,z, centerslice=False, boxsz=-1):
		"""Return a box-sized region of the volume centered at (x, y, z).

		boxsz<0 uses the current set's box size. With centerslice=True the
		region is one slice thick (bs x bs x 1). If the center lies too far
		outside the volume, a blank EMData of that size is returned instead.
		The result's apix is set from self.apix when it is nonzero.
		"""
		bs=self.get_boxsize() if boxsz<0 else boxsz
		bz=1 if centerslice else bs

		# NB: (-bs)//2 floors toward -inf; this is the exact lower bound used previously
		lo=(-bs)//2
		outside=(x<lo or y<lo or z<(-bz)//2
			or x>self.data["nx"]+bs//2 or y>self.data["ny"]+bs//2 or z>self.data["nz"]+bz//2)
		if outside:
			r=EMData(bs,bs,bz)
		else:
			half=bs//2
			r=self.data.get_clip(Region(x-half,y-half,z-bz//2,bs,bs,bz))

		if self.apix!=0 :
			for key in ("apix_x","apix_y","apix_z"):
				r[key]=self.apix

		#if options.normproc:
			#r.process_inplace(options.normproc)
		return r

	def get_slice(self,n,xyz):
		"""Return slice n of the in-memory volume as a 2-D image.

		xyz is the axis along which n runs: 0=x (yz image), 1=y (xz image),
		2=z (xy image). apix is copied onto the slice when it is nonzero.
		"""
		nx,ny,nz=self.datasize
		if xyz==0:
			r=self.data.get_clip(Region(n,0,0,1,ny,nz))
			r.set_size(ny,nz,1)
		elif xyz==1:
			r=self.data.get_clip(Region(0,n,0,nx,1,nz))
			r.set_size(nx,nz,1)
		else:
			r=self.data.get_clip(Region(0,0,n,nx,ny,1))

		if self.apix!=0 :
			for key in ("apix_x","apix_y","apix_z"):
				r[key]=self.apix
		return r

	def event_boxsize(self):
		"""Apply the box size typed into the box-size field to the current set.

		Every box of the current set is re-extracted while the GUI is marked
		uninitialized, then the previously current box is updated again and
		everything is redrawn.
		"""
		current=self.get_boxsize()
		newsize=int(self.wboxsize.getValue())
		if current==newsize:
			return
		
		self.boxsize[self.currentset]=newsize
		
		selected=self.curbox
		self.initialized=False
		for idx in range(len(self.boxes)):
			if self.boxes[idx][5]!=self.currentset:
				continue
			self.update_box(idx)
		self.update_box(selected)
		self.initialized=True
		self.update_all()

	def event_projmode(self,state):
		"""Projection-mode toggle: average (state=False) or maximum (state=True); currently this just redraws."""
		return self.update_all()

	def event_scale(self,newscale):
		"""Apply the zoom slider value to all three slice views."""
		for view in (self.xyview, self.xzview, self.zyview):
			view.set_scale(newscale)

	def event_depth(self):
		"""Depth slider moved: redraw the xy view (only after initialization)."""
		if not self.initialized:
			return
		self.update_xy()

	def event_nlayers(self):
		"""Number of averaged slices changed: redraw everything."""
		return self.update_all()

	def event_filter(self):
		"""Low-pass filter slider moved: redraw everything."""
		return self.update_all()

	def event_localbox(self,tog):
		"""'Limit Side Boxes' toggled: redraw the side views."""
		return self.update_sides()

	def get_boxsize(self, clsid=-1):
		"""Return the box size (int) of set `clsid`, or of the current set when clsid<0.

		For an explicit set id without a usable stored size, a message is printed
		and 32 is returned. A missing size for the current set raises KeyError.
		"""
		if clsid<0:
			return int(self.boxsize[self.currentset])
		try:
			return int(self.boxsize[clsid])
		except (KeyError, TypeError, ValueError, OverflowError):
			# only lookup/conversion failures fall back; the bare except also swallowed KeyboardInterrupt etc.
			print("No box size saved for {}..".format(clsid))
			return 32

	def nlayers(self):
		"""Number of slices averaged into each view, from the spin box."""
		count=self.wnlayers.value()
		return int(count)

	def depth(self):
		"""Current z slice selected by the depth slider."""
		position=self.wdepth.value()
		return int(position)

	def scale(self):
		"""Current zoom factor from the scale slider."""
		factor=self.wscale.getValue()
		return factor

	def get_x(self):
		"""X coordinate shown in the coordinate label (see get_coord)."""
		return self.get_coord(coord_index=0)

	def get_y(self):
		"""Y coordinate shown in the coordinate label (see get_coord)."""
		return self.get_coord(coord_index=1)

	def get_z(self):
		"""Z coordinate shown in the coordinate label: the depth slider position."""
		current=self.depth()
		return current

	def get_coord(self, coord_index):
		"""Return coordinate `coord_index` (0=x, 1=y, 2=z) of the selected box.

		Falls back to the most recently added box when no valid box is selected
		(curbox is -1 or out of range). Returns 0 when there are no boxes.
		"""
		if not self.boxes:
			return 0
		# the old `if self.curbox:` treated box 0 as "no selection", and `len>1` ignored a single box
		if 0<=self.curbox<len(self.boxes):
			return self.boxes[self.curbox][coord_index]
		return self.boxes[-1][coord_index]


	def menu_file_open(self,tog):
		"""File > Open handler: interactive opening is unsupported, so show a warning."""
		msg="Sorry, in the current version, you must provide a file to open on the command-line."
		QtGui.QMessageBox.warning(None,"Error",msg)

	def load_box_yshort(self, boxcoords):
		"""Return boxcoords with y and z swapped when yshort is enabled, otherwise unchanged.

		The flag is read from self.options, falling back to self.yshort when the
		options object has no `yshort` attribute.
		"""
		# the old code read a module-level `options`, which is not guaranteed to exist; use the instance's options
		if getattr(self.options, "yshort", self.yshort):
			return [boxcoords[0], boxcoords[2], boxcoords[1]]
		return boxcoords

	def menu_file_read_boxloc(self):
		"""Append boxes read from a whitespace-separated text file of "x y z" centers.

		Each coordinate is truncated to int and divided by self.shrink. Each new
		box goes into the current set with method "manual". Blank lines are
		skipped, and a cancelled dialog does nothing.
		"""
		fsp=str(QtGui.QFileDialog.getOpenFileName(self, "Select input text file"))
		if not fsp:
			return

		# `file()` only exists in Python 2; open() in a with-block also guarantees the close
		with open(fsp,"r") as f:
			for line in f:
				fields=line.split()[:3]
				if not fields:
					continue
				bdf=[0,0,0,"manual",0.0, self.currentset]
				for j,val in enumerate(fields):
					bdf[j]=old_div(int(float(val)),self.shrink)
				self.boxes.append(bdf)
				self.update_box(len(self.boxes)-1)

	def menu_file_save_boxloc(self):
		"""Write box centers, multiplied by self.shrink and formatted as ints, to a tab-separated text file.

		A cancelled dialog (empty filename) writes nothing.
		"""
		shrinkf=self.shrink 								#jesus

		fsp=str(QtGui.QFileDialog.getSaveFileName(self, "Select output text file"))
		if not fsp:
			return

		# `file()` only exists in Python 2; the with-block closes the file even on errors
		with open(fsp,"w") as out:
			for b in self.boxes:
				out.write("%d\t%d\t%d\n"%(b[0]*shrinkf,b[1]*shrinkf,b[2]*shrinkf))


	def save_boxes(self, clsid=()):
		"""Extract boxes and write them to an HDF stack in particles/ or particles3d/.

		clsid : collection of set ids to save. An empty value saves every box, and
		        so does a non-collection such as the bool that a QAction's
		        triggered[bool] signal passes.

		The user is asked for a filename suffix, and an existing file is
		overwritten. 3D boxes are normalized and contrast-inverted. 2D boxes are
		saved as single central slices. All boxes are written with the box size
		of the first saved box.
		"""
		# a bool from triggered[bool] used to reach len(False) and raise TypeError
		if not isinstance(clsid, (list, tuple, set, frozenset)):
			clsid=[]

		if len(clsid)==0:
			defaultname="ptcls.hdf"
		else:
			defaultname="_".join([self.sets[i] for i in clsid])+".hdf"
		
		name,ok=QtGui.QInputDialog.getText( self, "Save particles", "Filename suffix:", text=defaultname)
		if not ok:
			return
		name=self.filetag+str(name)
		if name[-4:].lower()!=".hdf" :
			name+=".hdf"
			
		if self.options.mode=="3D":
			dr="particles3d"
			is2d=False
		else:
			dr="particles"
			is2d=True
		
		if not os.path.isdir(dr):
			os.mkdir(dr)
		
		fsp=os.path.join(dr,self.basename)+name

		print("Saving {} particles to {}".format(self.options.mode, fsp))
		
		if os.path.isfile(fsp):
			print("{} exist. Overwritting...".format(fsp))
			os.remove(fsp)
		
		# volume center, used for center-relative ptcl_source_coord (loop invariant)
		sz=[s//2 for s in self.datasize]
		progress = QtGui.QProgressDialog("Saving", "Abort", 0, len(self.boxes),None)
		
		boxsz=-1
		try:
			for i,b in enumerate(self.boxes):
				if len(clsid)>0 and int(b[5]) not in clsid:
					continue
				
				bs=self.get_boxsize(b[5])
				if boxsz<0:
					boxsz=bs
				elif boxsz!=bs:
					print("Inconsistant box size in the particles to save.. Using {:d}..".format(boxsz))
					bs=boxsz
				
				img=self.get_cube(b[0], b[1], b[2], centerslice=is2d, boxsz=bs)
				if not is2d:
					img.process_inplace('normalize')
				
				img["ptcl_source_image"]=self.datafilename
				img["ptcl_source_coord"]=(b[0]-sz[0], b[1]-sz[1], b[2]-sz[2])
				
				if not is2d: #### do not invert contrast for 2D images
					img.mult(-1)
				
				img.write_image(fsp,-1)

				progress.setValue(i+1)
				if progress.wasCanceled():
					break
		finally:
			# when trailing boxes were filtered out the value never reached the maximum, so the dialog stayed open
			progress.close()


	def menu_file_quit(self):
		"""File > Quit handler: close the main window."""
		closer=self.close
		closer()

	def transform_coords(self, point, xform):
		"""Apply the 12-element matrix from xform.get_matrix() to a 3-D point.

		NOTE(review): the rotation entries are taken column-wise (m[0], m[4], m[8]
		for x) while the translation comes from m[3], m[7], m[11]. If get_matrix()
		is row-major 3x4, this applies the transposed rotation plus the
		translation - confirm that is intended.
		"""
		m=xform.get_matrix()
		px,py,pz=point[0],point[1],point[2]
		return [m[k]*px + m[k+4]*py + m[k+8]*pz + m[t] for k,t in ((0,3),(1,7),(2,11))]

	def get_averager(self):
		"""Return a new mean averager for building multi-slice projection views."""
		#if self.wmaxmean.isChecked() : return Averagers.get("minmax",{"max":1})
		kind="mean"
		return Averagers.get(kind)

	def update_sides(self):
		"""Redraw the xz and zy side views around the current center.

		The center is the selected box, or the middle of the volume in x/y when
		nothing is selected. Box overlays are shown only for visible sets. When
		"Limit Side Boxes" is checked, they are also limited to boxes within half
		a box of the center plane. Each side view averages nlayers() slices and
		is optionally low-pass filtered.
		"""
		if self.data is None:
			return

		if self.curbox==-1 :
			x=self.datasize[0]//2
			y=self.datasize[1]//2
		else:
			x,y=self.boxes[self.curbox][:2]

		self.cury=y
		self.curx=x

		# update shape display
		xzs=self.xzview.get_shapes()
		zys=self.zyview.get_shapes()
		if self.wlocalbox.isChecked():
			for i,box in enumerate(self.boxes):
				half=self.get_boxsize(box[5])//2
				near=self.cury-half<box[1]<self.cury+half
				xzs[i][0]=self.boxshape if (near and box[5] in self.sets_visible) else "hidden"

			for i,box in enumerate(self.boxes):
				half=self.get_boxsize(box[5])//2
				near=self.curx-half<box[0]<self.curx+half
				zys[i][0]=self.boxshape if (near and box[5] in self.sets_visible) else "hidden"
		else :
			for i,box in enumerate(self.boxes):
				shape=self.boxshape if box[5] in self.sets_visible else "hidden"
				xzs[i][0]=shape
				zys[i][0]=shape

		self.xzview.shapechange=1
		self.zyview.shapechange=1

		n=self.nlayers()
		# box centers can be floats (e.g. after apix rescaling); range() needs ints
		xc=int(x)
		yc=int(y)

		# yz
		avgr=self.get_averager()
		for xi in range(xc-n//2,xc+(n+1)//2):
			avgr.add_image(self.get_slice(xi,0))

		av=avgr.finish()
		if not self.yshort:
			av.process_inplace("xform.transpose")

		if self.wfilt.getValue()!=0.0:
			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/self.wfilt.getValue(),"apix":self.apix})

		self.zyview.set_data(av)

		# xz
		avgr=self.get_averager()
		for yi in range(yc-n//2,yc+(n+1)//2):
			avgr.add_image(self.get_slice(yi,1))

		av=avgr.finish()
		if self.wfilt.getValue()!=0.0:
			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/self.wfilt.getValue(),"apix":self.apix})

		self.xzview.set_data(av)


	def update_xy(self):
		"""Redraw the xy view at the depth slider position.

		Box overlays are shown for visible sets close to the slice. In 3D mode
		this means within half a box, and the circle radius shrinks with the
		distance. In 2D mode it means |dz| < 1. The image averages nlayers() z
		slices and is optionally low-pass filtered. Contrast is kept once the
		GUI is initialized.
		"""
		if self.data is None:
			return

		zc=self.wdepth.value()
		# Boxes should also be limited by default in the XY view
		if self.boxes:
			xys=self.xyview.get_shapes()
			for i,box in enumerate(self.boxes):
				bs=self.get_boxsize(box[5])
				zdist=abs(box[2] - zc)

				if self.options.mode=="3D":
					zthr=bs/2
					xys[i][6]=bs//2-zdist
				else:
					zthr=1
					
				if zdist < zthr and box[5] in self.sets_visible:
					xys[i][0]=self.boxshape
				else :
					xys[i][0]="hidden"
			self.xyview.shapechange=1

		# same mean averager as the side views (the unused EMData() allocation was dropped)
		avgr=self.get_averager()
		n=self.nlayers()
		for z in range(zc-n//2,zc+(n+1)//2):
			avgr.add_image(self.get_slice(z,2))

		av=avgr.finish()

		if self.wfilt.getValue()!=0.0:
			av.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/self.wfilt.getValue(),"apix":self.apix})
		if self.initialized:
			self.xyview.set_data(av, keepcontrast=True)
		else:
			self.xyview.set_data(av)

	def update_all(self):
		"""Redisplay every view: the XY slice, both side views and the box gallery.

		Does nothing until data has been loaded (``self.data`` is None).
		"""
		# identity test: never invoke a possibly-overloaded __eq__ on the data
		if self.data is None:
			return

		self.update_xy()
		self.update_sides()
		self.update_boximgs()

	def update_coords(self):
		"""Show the current X, Y and Z position in the coordinate label."""
		fields=["X: " + str(self.get_x()), "Y: " + str(self.get_y()), "Z: " + str(self.get_z())]
		self.wcoords.setText("\t\t".join(fields))

	def inside_box(self,n,x=-1,y=-1,z=-1):
		"""Return True if image point (x, y, z) lies inside box number n.

		A negative coordinate means "do not test this axis".  Boxes in hidden
		sets never match.  The radius is half the box size of the box's set.
		In 3D mode all tested axes form a Euclidean distance; otherwise X/Y
		are tested by distance and a tested Z must equal the box's Z exactly.
		"""
		box=self.boxes[n]
		if box[5] not in self.sets_visible:
			return False
		half=self.get_boxsize(box[5])/2
		limit=half**2

		dist2=0
		if x>=0:
			dist2+=(box[0]-x)**2
		if y>=0:
			dist2+=(box[1]-y)**2
		if z>=0:
			if self.options.mode=="3D":
				dist2+=(box[2]-z)**2
			elif box[2]!=z:
				# off-slice in 2D mode: push the distance far past the limit
				dist2+=1e3*limit
		return dist2<=limit

	def do_deletion(self, delids):
		"""Remove the boxes whose indices are in ``delids`` and redraw.

		Surviving boxes keep their relative order and are renumbered from 0.
		Their cached images and the shapes in the XY, XZ and ZY views are
		re-keyed to the new indices.  The current box selection is cleared.
		"""
		# a set gives O(1) membership tests instead of scanning delids per box
		doomed=set(delids)
		kpids=[i for i in range(len(self.boxes)) if i not in doomed]
		self.boxes=[self.boxes[i] for i in kpids]
		self.boxesimgs=[self.boxesimgs[i] for i in kpids]
		for view in (self.xyview, self.xzview, self.zyview):
			view.shapes={i:view.shapes[k] for i,k in enumerate(kpids)}
		self.curbox=-1
		self.update_all()

	def del_box(self,n):
		"""Delete box number n; indices outside ``self.boxes`` are ignored.

		Clears the single-box viewer if it currently holds data, then removes
		the box through do_deletion, which keeps the remaining boxes in their
		original order, clears the current selection and redraws everything.
		"""
		if n<0 or n>=len(self.boxes):
			return

		if self.boxviewer.get_data():
			self.boxviewer.set_data(None)
		self.curbox=-1
		self.do_deletion([n])




	def update_box(self,n,quiet=False):
		"""Redraw box n everywhere after it has been created or moved.

		Re-adds the box outline (circles in 3D mode, rectangles otherwise) to the
		XY, XZ and ZY views, moves the depth control to the box's Z (or just
		repaints the XY view if already there) and makes n the current box.
		Unless ``quiet`` is set (used while dragging, for speed), the box's
		image is re-extracted, normalized and cached in ``self.boxesimgs``, and
		the boxes are saved once the GUI is initialized.  An index not in
		``self.boxes`` is ignored.
		"""
		try:
			box=self.boxes[n]
		except IndexError:
			return
		bs2=self.get_boxsize(box[5])//2

		color=self.setcolors[box[5]%len(self.setcolors)].getRgbF()
		if self.options.mode=="3D":
			self.xyview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[1],bs2,2]))
			self.xzview.add_shape(n,EMShape(["circle",color[0],color[1],color[2],box[0],box[2],bs2,2]))
			self.zyview.add_shape(n,EMShape(("circle",color[0],color[1],color[2],box[2],box[1],bs2,2)))
		else:
			self.xyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[0]-bs2,box[1]-bs2,box[0]+bs2,box[1]+bs2,2]))
			self.xzview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[0]-bs2,box[2]-1,box[0]+bs2,box[2]+1,2]))
			self.zyview.add_shape(n,EMShape(["rect",color[0],color[1],color[2],
				    box[2]-1,box[1]-bs2,box[2]+1,box[1]+bs2,2]))

		if self.depth()!=box[2]:
			self.wdepth.setValue(box[2])
		else:
			self.xyview.update()
		if self.initialized: self.update_sides()

		# For speed, we turn off updates while dragging a box around. Quiet is set until the mouse-up
		if not quiet:
			# Get the cube from the original data (normalized)
			proj=self.get_cube(box[0], box[1], box[2], centerslice=True, boxsz=self.get_boxsize(box[5]))
			proj.process_inplace("normalize")

			# pad the image cache with None so that index n exists
			for i in range(len(self.boxesimgs),n+1):
				self.boxesimgs.append(None)

			self.boxesimgs[n]=proj

			if self.initialized: self.SaveJson()

		if self.initialized:
			self.update_boximgs()

			# The gallery lists only boxes of visible sets (self.boxids maps
			# gallery position -> box index), so select by gallery position,
			# not by box index.
			if n!=self.curbox and n in self.boxids:
				self.boxesviewer.set_selected((self.boxids.index(n),),True)

		self.curbox=n
		self.update_coords()

	def update_boximgs(self):
		"""Show the cached images of boxes in visible sets in the box gallery.

		``self.boxids`` is rebuilt so that entry k is the index in
		``self.boxes`` of the k-th image shown.
		"""
		ids=[]
		for idx in range(len(self.boxesimgs)):
			if self.boxes[idx][5] in self.sets_visible:
				ids.append(idx)
		self.boxids=ids

		shown=[self.boxesimgs[idx] for idx in ids]
		self.boxesviewer.set_data(shown)
		self.boxesviewer.update()

	def img_selected(self,event,lc):
		"""Handle a click on an image in the box gallery.

		``lc[0]`` is the gallery position, translated to a box index through
		``self.boxids``.  Shift-click deletes that box; a plain click selects
		it.  If a box is current afterwards, the three views scroll to it and
		the sets panel switches to that box's set.
		"""
		boxidx=self.boxids[lc[0]]
		shift=event.modifiers()&Qt.ShiftModifier
		if shift:
			self.del_box(boxidx)
		else:
			self.update_box(boxidx)

		if self.curbox<0:
			return

		box=self.boxes[self.curbox]
		self.xyview.scroll_to(box[0],box[1])
		self.xzview.scroll_to(None,box[2])
		self.zyview.scroll_to(box[2],None)
		self.currentset=box[5]
		self.setspanel.initialized=False
		self.setspanel.update_sets()

	def del_region_xy(self, x=-1, y=-1, z=-1, rad=-1):
		"""Delete every visible box whose centre is strictly within ``rad`` of (x, y, z).

		A negative coordinate is left out of the distance.  A negative
		``rad`` means the current eraser width.  Boxes in hidden sets are
		never deleted.  When no box qualifies, nothing is redrawn.
		"""
		if rad<0:
			rad=self.eraser_width()
		r2=rad**2

		delids=[]
		for i,b in enumerate(self.boxes):
			if b[5] not in self.sets_visible:
				continue
			d2=(x>=0)*(b[0]-x)**2 + (y>=0)*(b[1]-y)**2 + (z>=0)*(b[2]-z)**2
			if d2 < r2:
				delids.append(i)

		# The eraser calls this on every drag event.  Skip the full
		# rebuild/redraw in do_deletion when there is nothing to remove.
		if delids:
			self.do_deletion(delids)

	def xy_down(self,event):
		"""Mouse press in the XY view.

		With the eraser enabled, boxes around the click are deleted.  Otherwise:
		- shift-click on a box deletes it;
		- a plain click on a box selects it and starts a drag;
		- a plain click on empty space adds a new 'manual' box at the current
		  depth in the current set.
		Afterwards the side views scroll to the current box's Z, if any.
		"""
		px,py=self.xyview.scr_to_img((event.x(),event.y()))
		px,py=int(px),int(py)
		pz=int(self.get_z())
		self.xydown=None
		if px<0 or py<0 : return		# no clicking outside the image (on 2 sides)
		if self.optionviewer.erasercheckbox.isChecked():
			self.del_region_xy(px,py)
			return

		shift=event.modifiers()&Qt.ShiftModifier
		hit=next((i for i in range(len(self.boxes)) if self.inside_box(i,px,py,pz)), None)

		if hit is not None:
			if shift:
				self.del_box(hit)
				self.firsthbclick = None
			else:
				self.xydown=(hit,px,py,self.boxes[hit][0],self.boxes[hit][1])
				self.update_box(hit)
		elif not shift:
			self.boxes.append([px,py,self.depth(), 'manual', 0.0, self.currentset])
			newidx=len(self.boxes)-1
			self.xydown=(newidx,px,py,px,py)		# box #, x down, y down, x box at down, y box at down
			self.update_box(newidx)

		if self.curbox>=0:
			box=self.boxes[self.curbox]
			self.xzview.scroll_to(None,box[2])
			self.zyview.scroll_to(box[2],None)

	def xy_drag(self,event):
		"""Mouse drag in the XY view.

		With the eraser enabled, erases boxes under the cursor and draws the
		eraser outline there.  Otherwise, if a box was grabbed on press, it is
		moved by the drag offset and given a quiet (fast) update.
		"""
		x,y=self.xyview.scr_to_img((event.x(),event.y()))
		x,y=int(x),int(y)

		if self.optionviewer.erasercheckbox.isChecked():
			self.del_region_xy(x,y)
			self.xyview.eraser_shape=EMShape(["circle",1,1,1,x,y,self.eraser_width(),2])
			self.xyview.shapechange=1
			self.xyview.update()
			return

		if self.xydown is None:
			return

		idx,x0,y0,bx0,by0=self.xydown
		box=self.boxes[idx]
		box[0]=bx0+(x-x0)
		box[1]=by0+(y-y0)
		# quiet: skip image extraction until mouse-up
		self.update_box(self.curbox,True)

	def xy_up(self,event):
		"""Mouse release in the XY view: finish a drag with a full (non-quiet) update."""
		was_dragging=self.xydown is not None
		if was_dragging:
			self.update_box(self.curbox)
		self.xydown=None

	def xy_wheel(self,event):
		"""Mouse wheel in the XY view steps the depth by one slice per event."""
		delta=event.delta()
		if delta==0:
			return
		step=1 if delta>0 else -1
		self.wdepth.setValue(self.wdepth.value()+step)


	def xy_scale(self,news):
		"""The XY view was rescaled; mirror the new scale in the scale widget."""
		scale_widget=self.wscale
		scale_widget.setValue(news)

	def xy_origin(self,newor):
		"""The XY view was panned; keep the side views aligned with it.

		The XZ view takes the new X origin and the ZY view takes the new Y
		origin.  Each keeps its own origin on its other axis.
		"""
		new_x,new_y=newor[0],newor[1]
		self.xzview.set_origin(new_x,self.xzview.get_origin()[1],True)
		self.zyview.set_origin(self.zyview.get_origin()[0],new_y,True)
	
	def xy_move(self,event):
		"""Mouse move in the XY view: track the eraser outline while erasing, else clear it."""
		if not self.optionviewer.erasercheckbox.isChecked():
			self.xyview.eraser_shape=None
			return

		ex,ey=self.xyview.scr_to_img((event.x(),event.y()))
		self.xyview.eraser_shape=EMShape(["circle",1,1,1,ex,ey,self.eraser_width(),2])
		self.xyview.shapechange=1
		self.xyview.update()

	def xz_down(self,event):
		"""Mouse press in the XZ view.

		Ignored while the eraser is active.  Otherwise:
		- shift-click on a box deletes it;
		- a plain click on a box selects it and starts a drag;
		- a plain click on empty space adds a new 'manual' box at Y=self.cury
		  in the current set.
		Afterwards the XY and ZY views scroll to the current box, if any.
		"""
		x,z=self.xzview.scr_to_img((event.x(),event.y()))
		x,z=int(x),int(z)
		y=int(self.get_y())
		self.xzdown=None
		if x<0 or z<0 : return		# no clicking outside the image (on 2 sides)
		if self.optionviewer.erasercheckbox.isChecked():
			return

		def hit(i):
			# unless "local box" is checked, also test at get_y(); always test at self.cury
			if not self.wlocalbox.isChecked() and self.inside_box(i,x,y,z):
				return True
			return self.inside_box(i,x,self.cury,z)

		shift=event.modifiers()&Qt.ShiftModifier
		found=next((i for i in range(len(self.boxes)) if hit(i)), None)

		if found is None:
			if not shift:
				self.boxes.append([x,self.cury,z, 'manual', 0.0, self.currentset])
				found=len(self.boxes)-1
				self.xzdown=(found,x,z,x,z)		# box #, x down, z down, x box at down, z box at down
				self.update_box(found)
		elif shift:
			self.del_box(found)
			self.firsthbclick = None
		else:
			self.xzdown=(found,x,z,self.boxes[found][0],self.boxes[found][2])
			self.update_box(found)

		if self.curbox>=0 :
			box=self.boxes[self.curbox]
			self.xyview.scroll_to(None,box[1])
			self.zyview.scroll_to(box[2],None)

	def xz_drag(self,event):
		"""Mouse drag in the XZ view: move the grabbed box in X and Z (quiet update)."""
		if self.xzdown is None:
			return

		px,pz=self.xzview.scr_to_img((event.x(),event.y()))
		idx,x0,z0,bx0,bz0=self.xzdown
		box=self.boxes[idx]
		box[0]=bx0+(int(px)-x0)
		box[2]=bz0+(int(pz)-z0)
		self.update_box(self.curbox,True)

	def xz_up(self,event):
		"""Mouse release in the XZ view: finish a drag with a full (non-quiet) update."""
		was_dragging=self.xzdown is not None
		if was_dragging:
			self.update_box(self.curbox)
		self.xzdown=None

	def xz_scale(self,news):
		"""The XZ view was rescaled; mirror the new scale in the scale widget."""
		self.wscale.setValue(news)

	def xz_origin(self,newor):
		"""The XZ view was panned; make the XY view follow the new X origin.

		The XY view keeps its own Y origin.
		"""
		xyo=self.xyview.get_origin()
		self.xyview.set_origin(newor[0],xyo[1],True)


	def zy_down(self,event):
		"""Mouse press in the ZY view.

		Ignored while the eraser is active, matching xz_down.  Otherwise:
		- shift-click on a box deletes it;
		- a plain click on a box selects it and starts a drag;
		- a plain click on empty space adds a new 'manual' box at X=self.curx
		  in the current set.
		Afterwards the XY and XZ views scroll to the current box, if any.
		"""
		z,y=self.zyview.scr_to_img((event.x(),event.y()))
		z,y=int(z),int(y)
		x=int(self.get_x())
		# clear this view's own drag state; zy_drag only acts while self.zydown is set
		self.zydown=None
		if z<0 or y<0 : return		# no clicking outside the image (on 2 sides)
		if self.optionviewer.erasercheckbox.isChecked():
			return

		for i in range(len(self.boxes)):
			if (not self.wlocalbox.isChecked() and self.inside_box(i,x,y,z)) or self.inside_box(i,self.curx,y,z):
				if event.modifiers()&Qt.ShiftModifier:
					self.del_box(i)
					self.firsthbclick = None
				else :
					self.zydown=(i,z,y,self.boxes[i][2],self.boxes[i][1])
					self.update_box(i)
				break
		else:
			if not event.modifiers()&Qt.ShiftModifier:
				self.boxes.append([self.curx,y,z, 'manual', 0.0, self.currentset])
				self.zydown=(len(self.boxes)-1,z,y,z,y)		# box #, z down, y down, z box at down, y box at down
				self.update_box(self.zydown[0])

		if self.curbox>=0 :
			box=self.boxes[self.curbox]
			self.xyview.scroll_to(box[0],None)
			self.xzview.scroll_to(None,box[2])

	def zy_drag(self,event):
		"""Mouse drag in the ZY view: move the grabbed box in Z and Y (quiet update)."""
		if self.zydown is None:
			return

		pz,py=self.zyview.scr_to_img((event.x(),event.y()))
		idx,z0,y0,bz0,by0=self.zydown
		box=self.boxes[idx]
		box[2]=bz0+(int(pz)-z0)
		box[1]=by0+(int(py)-y0)
		self.update_box(self.curbox,True)

	def zy_up(self,event):
		"""Mouse release in the ZY view: finish a drag with a full (non-quiet) update."""
		was_dragging=self.zydown is not None
		if was_dragging:
			self.update_box(self.curbox)
		self.zydown=None

	def zy_scale(self,news):
		"""The ZY view was rescaled; mirror the new scale in the scale widget."""
		self.wscale.setValue(news)

	def zy_origin(self,newor):
		"""The ZY view was panned; make the XY view follow the new Y origin.

		The XY view keeps its own X origin.
		"""
		xyo=self.xyview.get_origin()
		self.xyview.set_origin(xyo[0],newor[1],True)
	
	
	def set_current_set(self, name):
		"""Make the set named ``name`` current, refresh the box-size widget and redraw."""
		self.currentset=parse_setname(name)
		self.wboxsize.setValue(self.get_boxsize())
		self.update_all()
	
	
	def hide_set(self, name):
		"""Hide the boxes of set ``name``; the views are refreshed once the GUI is initialized."""
		key=parse_setname(name)
		self.sets_visible.pop(key, None)

		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
	
	
	def show_set(self, name):
		"""Make set ``name`` visible and refresh the box-size widget.

		The views are refreshed once the GUI is initialized.
		"""
		key=parse_setname(name)
		self.sets_visible[key]=0
		self.wboxsize.setValue(self.get_boxsize())

		if not self.initialized:
			return
		self.update_all()
		self.update_boximgs()
	
	
	def delete_set(self, name):
		"""Delete set ``name`` with all of its boxes, forget its visibility, name and box size, then redraw."""
		key=parse_setname(name)
		doomed=[idx for idx,bx in enumerate(self.boxes) if bx[5]==int(key)]
		self.do_deletion(doomed)

		for table in (self.sets_visible, self.sets, self.boxsize):
			table.pop(key, None)

		self.curbox=-1
		self.update_all()
	
	def rename_set(self, oldname,  newname):
		"""Store ``newname`` as the name of existing set ``oldname``; unknown sets are ignored."""
		key=parse_setname(oldname)
		if key not in self.sets:
			return
		self.sets[key]=newname
	
	
	def new_set(self, name):
		"""Create a visible set called ``name`` under the lowest unused integer id.

		The default box size is 32 in 3D mode and 64 otherwise.
		"""
		newid=0
		while newid in self.sets:
			newid+=1

		self.sets[newid]=name
		self.sets_visible[newid]=0
		self.boxsize[newid]=32 if self.options.mode=="3D" else 64
	
	def save_set(self):
		"""Save the boxes of all currently visible sets."""
		visible_ids=list(self.sets_visible)
		self.save_boxes(visible_ids)
	
	
	def key_press(self,event):
		"""Key handler: key code 96 steps the depth up, 49 steps it down.

		Any other key is re-emitted through the ``keypress`` signal.
		"""
		# Qt key codes: 96 == Qt.Key_QuoteLeft (backtick), 49 == Qt.Key_1
		step={96:1, 49:-1}.get(event.key())
		if step is None:
			self.keypress.emit(event)
		else:
			self.wdepth.setValue(self.wdepth.value()+step)

	def SaveJson(self):
		"""Write the boxes and per-set info (name, box size) to ``self.jsonfile``.

		If the file already contains ``apix_unbin``, box coordinates are stored
		relative to the volume centre.  Coordinates and box sizes are both
		scaled by apix_cur/apix_unbin.  Otherwise boxes and sizes are stored
		unchanged.  The JSON dict is always closed, even if writing fails.
		"""
		info=js_open_dict(self.jsonfile)
		try:
			if "apix_unbin" in info:
				# volume centre, only needed for the rescaled coordinates
				sx,sy,sz=(self.data["nx"]//2,self.data["ny"]//2,self.data["nz"]//2)
				bxs=[]
				for b0 in self.boxes:
					bxs.append([(b0[0]-sx)*self.apix_cur/self.apix_unbin,
						(b0[1]-sy)*self.apix_cur/self.apix_unbin,
						(b0[2]-sz)*self.apix_cur/self.apix_unbin,
						b0[3], b0[4], b0[5]])
				bxsz={k: self.boxsize[k]*self.apix_cur/self.apix_unbin for k in self.boxsize}
			else:
				bxs=self.boxes
				bxsz=self.boxsize

			info["boxes_3d"]=bxs
			info["class_list"]={int(key): {"name":self.sets[key], "boxsize":int(bxsz[key])} for key in list(self.sets.keys())}
		finally:
			info.close()
	
	def closeEvent(self,event):
		"""On window close: save the boxes, close every child viewer, then emit ``module_closed``."""
		print("Exiting")
		self.SaveJson()

		children=(self.boxviewer, self.boxesviewer, self.optionviewer,
			self.xyview, self.xzview, self.zyview)
		for child in children:
			child.close()

		# lets a host program running its own event loop learn that this window closed
		self.module_closed.emit()
示例#5
0
class EMHelloWorldInspector(QtWidgets.QWidget):
    """Control panel widget that forwards display settings to ``target``.

    The panel has "Wire" and "Light" toggle buttons above a tab widget:
    - the "Main" tab holds a zoom slider, a material combo box, x/y/z
      translation spin boxes, a rotation-convention selector and three
      Euler-angle sliders;
    - the "GL" tab holds GL shade (contrast) and boost (brightness) sliders.
    Widget changes are connected to the matching setter methods of ``target``.
    """

    def __init__(self, target):
        QtWidgets.QWidget.__init__(self, None)
        self.target = target

        self.vbl = QtWidgets.QVBoxLayout(self)
        self.vbl.setContentsMargins(0, 0, 0, 0)
        self.vbl.setSpacing(6)
        self.vbl.setObjectName("vbl")

        self.hbl = QtWidgets.QHBoxLayout()
        self.hbl.setContentsMargins(0, 0, 0, 0)
        self.hbl.setSpacing(6)
        self.hbl.setObjectName("hbl")
        self.vbl.addLayout(self.hbl)

        self.vbl2 = QtWidgets.QVBoxLayout()
        self.vbl2.setContentsMargins(0, 0, 0, 0)
        self.vbl2.setSpacing(6)
        self.vbl2.setObjectName("vbl2")
        self.hbl.addLayout(self.vbl2)

        self.wiretog = QtWidgets.QPushButton("Wire")
        self.wiretog.setCheckable(True)
        self.vbl2.addWidget(self.wiretog)

        self.lighttog = QtWidgets.QPushButton("Light")
        self.lighttog.setCheckable(True)
        self.vbl2.addWidget(self.lighttog)

        self.tabwidget = QtWidgets.QTabWidget()
        self.maintab = None
        self.tabwidget.addTab(self.get_main_tab(), "Main")
        self.tabwidget.addTab(self.get_GL_tab(), "GL")
        self.vbl.addWidget(self.tabwidget)
        self.n3_showing = False

        self.scale.valueChanged.connect(target.set_scale)
        self.az.valueChanged.connect(self.slider_rotate)
        self.alt.valueChanged.connect(self.slider_rotate)
        self.phi.valueChanged.connect(self.slider_rotate)
        self.cbb.currentIndexChanged[str].connect(target.setColor)
        self.src.currentIndexChanged[str].connect(self.set_src)
        # QDoubleSpinBox.valueChanged is overloaded; select the numeric overload
        # with the Python type `float` (`double` is not a Python name).
        self.x_trans.valueChanged[float].connect(target.set_cam_x)
        self.y_trans.valueChanged[float].connect(target.set_cam_y)
        self.z_trans.valueChanged[float].connect(target.set_cam_z)
        self.wiretog.toggled[bool].connect(target.toggle_wire)
        self.lighttog.toggled[bool].connect(target.toggle_light)
        self.glcontrast.valueChanged.connect(target.set_GL_contrast)
        self.glbrightness.valueChanged.connect(target.set_GL_brightness)

    def get_GL_tab(self):
        """Build the "GL" tab holding the GL shade and boost sliders."""
        self.gltab = QtWidgets.QWidget()
        gltab = self.gltab

        gltab.vbl = QtWidgets.QVBoxLayout(self.gltab)
        gltab.vbl.setContentsMargins(0, 0, 0, 0)
        gltab.vbl.setSpacing(6)
        gltab.vbl.setObjectName("Main")

        self.glcontrast = ValSlider(gltab, (1.0, 5.0), "GLShd:")
        self.glcontrast.setObjectName("GLShade")
        self.glcontrast.setValue(1.0)
        gltab.vbl.addWidget(self.glcontrast)

        self.glbrightness = ValSlider(gltab, (-1.0, 0.0), "GLBst:")
        self.glbrightness.setObjectName("GLBoost")
        # NOTE(review): setting 0.1 and then 0.0 appears to force a value
        # change so the slider refreshes; kept as in the original
        self.glbrightness.setValue(0.1)
        self.glbrightness.setValue(0.0)
        gltab.vbl.addWidget(self.glbrightness)

        return gltab

    def get_main_tab(self):
        """Build (once) and return the "Main" tab.

        The tab holds the zoom slider, material selector, x/y/z translation,
        rotation-convention selector and the three rotation sliders.
        """
        if self.maintab is None:
            self.maintab = QtWidgets.QWidget()
            maintab = self.maintab
            maintab.vbl = QtWidgets.QVBoxLayout(self.maintab)
            maintab.vbl.setContentsMargins(0, 0, 0, 0)
            maintab.vbl.setSpacing(6)
            maintab.vbl.setObjectName("Main")

            self.scale = ValSlider(maintab, (0.01, 30.0), "Zoom:")
            self.scale.setObjectName("scale")
            self.scale.setValue(1.0)
            maintab.vbl.addWidget(self.scale)

            self.hbl_color = QtWidgets.QHBoxLayout()
            self.hbl_color.setContentsMargins(0, 0, 0, 0)
            self.hbl_color.setSpacing(6)
            self.hbl_color.setObjectName("Material")
            maintab.vbl.addLayout(self.hbl_color)

            self.color_label = QtWidgets.QLabel()
            self.color_label.setText('Material')
            self.hbl_color.addWidget(self.color_label)

            self.cbb = QtWidgets.QComboBox(maintab)
            self.hbl_color.addWidget(self.cbb)

            self.hbl_trans = QtWidgets.QHBoxLayout()
            self.hbl_trans.setContentsMargins(0, 0, 0, 0)
            self.hbl_trans.setSpacing(6)
            self.hbl_trans.setObjectName("Trans")
            maintab.vbl.addLayout(self.hbl_trans)

            self.x_label = QtWidgets.QLabel()
            self.x_label.setText('x')
            self.hbl_trans.addWidget(self.x_label)

            # parented to the tab like the y/z spin boxes
            self.x_trans = QtWidgets.QDoubleSpinBox(maintab)
            self.x_trans.setMinimum(-10000)
            self.x_trans.setMaximum(10000)
            self.x_trans.setValue(0.0)
            self.hbl_trans.addWidget(self.x_trans)

            self.y_label = QtWidgets.QLabel()
            self.y_label.setText('y')
            self.hbl_trans.addWidget(self.y_label)

            self.y_trans = QtWidgets.QDoubleSpinBox(maintab)
            self.y_trans.setMinimum(-10000)
            self.y_trans.setMaximum(10000)
            self.y_trans.setValue(0.0)
            self.hbl_trans.addWidget(self.y_trans)

            self.z_label = QtWidgets.QLabel()
            self.z_label.setText('z')
            self.hbl_trans.addWidget(self.z_label)

            self.z_trans = QtWidgets.QDoubleSpinBox(maintab)
            self.z_trans.setMinimum(-10000)
            self.z_trans.setMaximum(10000)
            self.z_trans.setValue(0.0)
            self.hbl_trans.addWidget(self.z_trans)

            self.hbl_src = QtWidgets.QHBoxLayout()
            self.hbl_src.setContentsMargins(0, 0, 0, 0)
            self.hbl_src.setSpacing(6)
            self.hbl_src.setObjectName("hbl")
            maintab.vbl.addLayout(self.hbl_src)

            self.label_src = QtWidgets.QLabel()
            self.label_src.setText('Rotation Convention')
            self.hbl_src.addWidget(self.label_src)

            self.src = QtWidgets.QComboBox(maintab)
            self.load_src_options(self.src)
            self.hbl_src.addWidget(self.src)

            # set default value -1 ensures that the val slider is updated the first time it is created
            self.az = ValSlider(self, (-360.0, 360.0), "az", -1)
            self.az.setObjectName("az")
            maintab.vbl.addWidget(self.az)

            self.alt = ValSlider(self, (-180.0, 180.0), "alt", -1)
            self.alt.setObjectName("alt")
            maintab.vbl.addWidget(self.alt)

            self.phi = ValSlider(self, (-360.0, 360.0), "phi", -1)
            self.phi.setObjectName("phi")
            maintab.vbl.addWidget(self.phi)

            self.current_src = EULER_EMAN

        return self.maintab

    def set_xy_trans(self, x, y):
        """Set the x and y translation spin boxes."""
        self.x_trans.setValue(x)
        self.y_trans.setValue(y)

    def set_xyz_trans(self, x, y, z):
        """Set the x, y and z translation spin boxes."""
        self.x_trans.setValue(x)
        self.y_trans.setValue(y)
        self.z_trans.setValue(z)

    def set_translate_scale(self, xscale, yscale, zscale):
        """Set the single-step increments of the x, y and z translation spin boxes."""
        self.x_trans.setSingleStep(xscale)
        self.y_trans.setSingleStep(yscale)
        self.z_trans.setSingleStep(zscale)

    def update_rotations(self, t3d):
        """Show the rotation of Transform ``t3d`` on the sliders, in the selected convention."""
        convention = str(self.src.currentText())
        #FIXME: Transform.get_rotation() wants a string sometimes and a EulerType other times
        try:
            rot = t3d.get_rotation(str(self.src_map[convention]))
        except Exception as e:  #doing a quick fix
            print(e)
            print(
                "Developers: This catches a large range of exceptions... a better way surely exists"
            )
            rot = t3d.get_rotation(self.src_map[convention])

        if (self.src_map[convention] == EULER_SPIN):
            self.n3.setValue(rot[self.n3.getLabel()], True)

        self.az.setValue(rot[self.az.getLabel()], True)
        self.alt.setValue(rot[self.alt.getLabel()], True)
        self.phi.setValue(rot[self.phi.getLabel()], True)

    def slider_rotate(self):
        """Push the rotation currently shown on the sliders to the target."""
        self.target.load_rotation(self.get_current_rotation())

    def get_current_rotation(self):
        """Build a Transform from the slider values in the current convention.

        For the SPIN convention the (n1, n2, n3) axis from the sliders is
        normalized to unit length.  An all-zero axis is left as-is instead
        of raising ZeroDivisionError.
        """
        rot = {}
        rot[self.az.getLabel()] = self.az.getValue()
        if self.current_src == EULER_SPIN:
            n1 = self.alt.getValue()
            n2 = self.phi.getValue()
            n3 = self.n3.getValue()

            norm = sqrt(n1 * n1 + n2 * n2 + n3 * n3)
            if norm > 0:
                n1 /= norm
                n2 /= norm
                n3 /= norm

            rot[self.alt.getLabel()] = n1
            rot[self.phi.getLabel()] = n2
            rot[self.n3.getLabel()] = n3
        else:
            rot[self.alt.getLabel()] = self.alt.getValue()
            rot[self.phi.getLabel()] = self.phi.getValue()

        return Transform(self.current_src, rot)

    def set_src(self, val):
        """Switch the rotation convention to ``val`` while keeping the current rotation.

        Relabels (and for SPIN, re-ranges) the sliders, adding or removing the
        extra n3 slider as needed, then redisplays the current rotation in
        the new convention.
        """
        t3d = self.get_current_rotation()

        if (self.n3_showing):
            self.vbl.removeWidget(self.n3)
            self.n3.deleteLater()
            self.n3_showing = False
            # restore the ranges the sliders were created with in get_main_tab
            self.az.setRange(-360, 360)
            self.alt.setRange(-180, 180)
            self.phi.setRange(-360, 360)

        if (self.src_map[str(val)] == EULER_SPIDER):
            self.az.setLabel('phi')
            self.alt.setLabel('theta')
            self.phi.setLabel('psi')
        elif (self.src_map[str(val)] == EULER_EMAN):
            self.az.setLabel('az')
            self.alt.setLabel('alt')
            self.phi.setLabel('phi')
        elif (self.src_map[str(val)] == EULER_IMAGIC):
            self.az.setLabel('alpha')
            self.alt.setLabel('beta')
            self.phi.setLabel('gamma')
        elif (self.src_map[str(val)] == EULER_XYZ):
            self.az.setLabel('xtilt')
            self.alt.setLabel('ytilt')
            self.phi.setLabel('ztilt')
        elif (self.src_map[str(val)] == EULER_MRC):
            self.az.setLabel('phi')
            self.alt.setLabel('theta')
            self.phi.setLabel('omega')
        elif (self.src_map[str(val)] == EULER_SPIN):
            self.az.setLabel('omega')
            self.alt.setRange(-1, 1)
            self.phi.setRange(-1, 1)

            self.alt.setLabel('n1')
            self.phi.setLabel('n2')

            self.n3 = ValSlider(self, (-360.0, 360.0), "n3", -1)
            self.n3.setRange(-1, 1)
            self.n3.setObjectName("n3")
            self.vbl.addWidget(self.n3)
            self.n3.valueChanged.connect(self.slider_rotate)
            self.n3_showing = True

        self.current_src = self.src_map[str(val)]
        self.update_rotations(t3d)

    def load_src_options(self, widgit):
        """Fill combo box ``widgit`` with the supported rotation conventions."""
        self.load_src()
        for i in self.src_strings:
            widgit.addItem(i)

    # read src as 'supported rotation conventions'
    def load_src(self):
        """Populate ``src_strings`` and the ``src_map`` (name -> convention) lookup."""
        src_flags = [EULER_EMAN, EULER_SPIDER, EULER_IMAGIC, EULER_MRC, EULER_SPIN, EULER_XYZ]

        self.src_strings = []
        self.src_map = {}
        for i in src_flags:
            self.src_strings.append(str(i))
            self.src_map[str(i)] = i

    def setColors(self, colors, current_color):
        """Add ``colors`` to the material combo box and select ``current_color`` if present."""
        for idx, color in enumerate(colors):
            self.cbb.addItem(color)
            if color == current_color:
                self.cbb.setCurrentIndex(idx)

    def set_scale(self, newscale):
        """Set the zoom slider to ``newscale``."""
        self.scale.setValue(newscale)
示例#6
0
class TrackerControl(QtWidgets.QWidget):
    def __init__(self,
                 app,
                 maxshift,
                 invert=False,
                 seqali=False,
                 tiltstep=2.0):
        """Build the tilt-series tracking control panel and its viewer windows.

        app      -- application object, passed as ``application`` to every
                    EMImage2DWidget / EMImage3DWidget created here
        maxshift -- padding/search range used when tracking and aligning
                    boxes between tilts (image units, presumably pixels)
        invert   -- if True, image data is multiplied by -1 when displayed,
                    boxed and projection-aligned
        seqali   -- if True, find_boxes adds the previous step's reference into
                    each new reference (sequential alignment)
        tiltstep -- angular increment between consecutive tilts, passed as
                    ``angstep`` to the reconstruction and projection alignment
        """
        self.app = app
        self.maxshift = maxshift
        self.seqali = seqali
        self.invert = invert
        self.tiltstep = tiltstep

        # the control panel
        QtWidgets.QWidget.__init__(self, None)

        self.gbl = QtWidgets.QGridLayout(self)
        self.gbl.setContentsMargins(0, 0, 0, 0)
        self.gbl.setSpacing(6)
        self.gbl.setObjectName("hbl")

        # action buttons
        self.bcenalign = QtWidgets.QPushButton("Center Align")
        self.bprojalign = QtWidgets.QPushButton("Proj. Realign")
        self.btiltaxis = QtWidgets.QPushButton("Tilt Axis")
        # tilt-axis angle; parsed with float() by the reconstruction/alignment
        # methods and overwritten by tilt_axis()
        self.btiltaxisval = QtWidgets.QLineEdit("90.0")
        # NOTE: created and connected below, but never added to the layout
        # (its addWidget call is commented out), so it is not shown
        self.bsavedata = QtWidgets.QPushButton("Save Data")
        self.breconst = QtWidgets.QPushButton("3D Normal")
        # reconstruction mode 0..2; used as an index into reconmodes
        self.sbmode = QtWidgets.QSpinBox(self)
        self.sbmode.setRange(0, 2)
        self.sbmode.setValue(0)
        self.bmagict = QtWidgets.QPushButton("3D Tomofill")
        self.bmagics = QtWidgets.QPushButton("3D Sph")
        self.bmagicc = QtWidgets.QPushButton("3D Cyl")
        # Gaussian low-pass cutoff (cutoff_abs) applied to the 3-D map in update_3d
        self.vslpfilt = ValSlider(self, (0, .5), "Filter", 0.5, 50)

        self.gbl.addWidget(self.bcenalign, 0, 0)
        self.gbl.addWidget(self.bprojalign, 0, 1)
        self.gbl.addWidget(self.btiltaxis, 0, 2)
        self.gbl.addWidget(self.btiltaxisval, 0, 3)
        #		self.gbl.addWidget(self.bsavedata,0,3)
        self.gbl.addWidget(self.breconst, 1, 0)
        self.gbl.addWidget(self.sbmode, 2, 0, 1, 1)
        self.gbl.addWidget(self.vslpfilt, 3, 0, 1, 4)
        self.gbl.addWidget(self.bmagict, 1, 1)
        self.gbl.addWidget(self.bmagics, 1, 2)
        self.gbl.addWidget(self.bmagicc, 1, 3)

        self.bcenalign.clicked[bool].connect(self.do_cenalign)
        self.bprojalign.clicked[bool].connect(self.do_projalign)
        self.btiltaxis.clicked[bool].connect(self.do_tiltaxis)
        self.bsavedata.clicked[bool].connect(self.do_savedata)
        self.breconst.clicked[bool].connect(self.do_reconst)
        self.bmagict.clicked[bool].connect(self.do_magict)
        self.bmagics.clicked[bool].connect(self.do_magics)
        self.bmagicc.clicked[bool].connect(self.do_magicc)
        self.vslpfilt.valueChanged.connect(self.do_filter)

        # the single image display widget
        # im2d: current tilt; imboxed: boxed particle stack; improj: summed
        # projections of the 3-D map; imslice: 3-D map slices; imvol: 3-D view
        self.im2d = EMImage2DWidget(application=app, winid="tomotrackbox.big")
        self.imboxed = EMImage2DWidget(application=app,
                                       winid="tomotrackbox.small")
        self.improj = EMImage2DWidget(application=app,
                                      winid="tomotrackbox.proj")
        self.imslice = EMImage2DWidget(application=app,
                                       winid="tomotrackbox.3dslice")
        self.imvol = EMImage3DWidget(application=app, winid="tomotrackbox.3d")

        # get some signals from the window.
        self.im2d.mousedown.connect(self.down)
        self.im2d.mousedrag.connect(self.drag)
        self.im2d.mouseup.connect(self.up)
        self.im2d.signal_increment_list_data.connect(self.change_tilt)

        # set by set_image(): file name and header dict of the tilt series
        self.imagefile = None
        self.imageparm = None
        # one EMShape box per tilt (None until find_boxes has run)
        self.tiltshapes = None
        # index of the tilt currently displayed, and of the previous one
        self.curtilt = 0
        self.oldtilt = self.curtilt
        # latest reconstruction (EMData) or None
        self.map3d = None
        # mouse state: press point for a new box / (press point, box corners)
        # for a shift-drag adjustment of an existing box
        self.downloc = None
        self.downadjloc = None

        self.show()
        self.im2d.show()

    def closeEvent(self, event):
        """Close every auxiliary viewer window together with the panel.

        Each child viewer gets its own fresh QCloseEvent, in this order: big
        image, boxed stack, projections, 3-D slices, 3-D volume.  The panel's
        own close *event* is then accepted.
        """
        viewers = (self.im2d, self.imboxed, self.improj, self.imslice,
                   self.imvol)
        for viewer in viewers:
            viewer.closeEvent(QtGui.QCloseEvent())
        event.accept()

    def do_cenalign(self, x=0):
        """Slot for the "Center Align" button.

        *x* receives the button's clicked(bool) flag and is ignored.  Runs the
        iterative centring of all boxes, then refreshes the boxed-stack view.
        """
        self.cenalign_stack()
        self.update_stack()

    def do_projalign(self, x=0):
        """Slot for the "Proj. Realign" button.

        *x* (the clicked flag) is ignored.  Re-registers every box against
        projections of the current 3-D map, using ``self.tiltstep`` as the
        angular step, then redraws the boxed stack.
        """
        step = self.tiltstep
        self.projection_align(step)
        self.update_stack()

    def do_tiltaxis(self, x=0):
        """Slot for the "Tilt Axis" button.

        Runs ``tilt_axis``, which estimates the tilt-axis angle and writes it
        into the tilt-axis text field (it may return early without changing
        it).  *x* receives the button's clicked(bool) flag and is ignored.  It
        is accepted with a default, like the other button slots, so the slot
        works whether or not the flag is passed.
        """
        self.tilt_axis()

    def do_reconst(self, x=0):
        """Slot for the "3D Normal" button (*x*, the clicked flag, is ignored).

        Reconstructs a volume from the current boxed stack using the mode set
        in the spin box, stores it in ``self.map3d`` and refreshes the 3-D
        views.
        """
        boxed = self.get_boxed_stack()
        recon_mode = self.sbmode.value()
        self.map3d = self.reconstruct(boxed, self.tiltstep, recon_mode)
        self.update_3d()

    def do_magict(self, x):
        """Slot for the "3D Tomofill" button (*x*, the clicked flag, is ignored).

        Builds a reconstruction of the boxed stack with the missing wedge
        filled (see ``reconstruct_wedgefill``), using the spin-box mode, and
        shows it in the 3-D views.
        """
        boxed = self.get_boxed_stack()
        recon_mode = self.sbmode.value()
        self.map3d = self.reconstruct_wedgefill(boxed, self.tiltstep,
                                                recon_mode)
        self.update_3d()

    def do_magics(self, x):
        """Slot for the "3D Sph" button; not implemented, so it does nothing."""
        return None

    def do_magicc(self, x):
        """Slot for the "3D Cyl" button; not implemented, so it does nothing."""
        return None

    def do_filter(self, v):
        """Slot for the low-pass filter slider.

        Stores the new value *v* in ``self.lpfilt`` and re-filters and
        redisplays the current 3-D map.  ``update_3d`` reads the cutoff
        directly from the slider.  Does nothing until a map has been
        reconstructed.
        """
        # Identity test: map3d is an EMData object, and `is None` cannot be
        # affected by any comparison operator that type may define.
        if self.map3d is None:
            return
        self.lpfilt = v
        self.update_3d()

    def do_savedata(self):
        """Placeholder slot for the "Save Data" button (not in the layout); a no-op."""

    def update_3d(self):
        """Low-pass filter the current 3-D map and refresh every 3-D display.

        ``self.map3d`` is filtered with ``filter.lowpass.gauss`` at the
        slider's cutoff, and the result is stored in ``self.filt3d``.  It is
        shown three ways:
        - in the volume viewer;
        - as z/y/x directional-sum projections in the projection viewer;
        - as a slice stack.
        Does nothing if no map has been reconstructed yet.
        """
        if self.map3d is None:
            return

        self.filt3d = self.map3d.process(
            "filter.lowpass.gauss", {"cutoff_abs": self.vslpfilt.getValue()})

        self.imvol.set_data(self.filt3d)
        self.imvol.show()
        self.imvol.updateGL()

        # Sum along each axis and reshape each result to an sz x sz image.
        # sz is the map's nx, so this assumes a cubic map.
        sz = self.map3d["nx"]
        sums = {}
        for axis in ("x", "y", "z"):
            proj = self.filt3d.process("misc.directional_sum", {"axis": axis})
            proj.set_size(sz, sz, 1)
            sums[axis] = proj

        self.improj.set_data([sums["z"], sums["y"], sums["x"]])
        self.improj.show()
        self.improj.updateGL()

        self.imslice.set_data(self.filt3d)
        self.imslice.show()
        self.imslice.updateGL()

    def update_stack(self):
        """Re-extract the boxed particle from every tilt and show the stack."""
        particles = self.get_boxed_stack()
        viewer = self.imboxed
        viewer.set_data(particles)
        viewer.show()
        viewer.updateGL()

    def set_image(self, fsp):
        """Load the tilt series stored in file *fsp* (e.g. an ali file).

        Steps:
        - Reads image 0 of *fsp* with the header-only flag and keeps its
          attribute dict in ``self.imageparm``.
        - Prints the dimensions.
        - Selects the middle tilt (index nz // 2).
        - Clears all boxes and displays that tilt.
        """
        header = EMData(fsp, 0, True).get_attr_dict()
        self.imageparm = header
        nslices = header["nz"]
        print("%d slices at %d x %d" % (nslices, header["nx"], header["ny"]))

        self.imagefile = fsp

        self.curtilt = old_div(nslices, 2)
        self.tiltshapes = [None] * nslices
        self.update_tilt()

    def update_tilt(self):
        """Display the current tilt image and its box, if one exists.

        Reads z-slice ``self.curtilt`` of the tilt series (negated when
        ``self.invert`` is set) and shows it with a tilt-number label.  If
        this tilt has a box, the box is drawn as "finalbox".  The view origin
        is then shifted by the box displacement between the previously shown
        tilt and this one, presumably to keep the tracked feature steady on
        screen.  Does nothing before an image is loaded.
        """
        if self.imagefile is None:
            return

        self.curimg = EMData(
            self.imagefile, 0, False,
            Region(0, 0, self.curtilt, self.imageparm["nx"],
                   self.imageparm["ny"], 1))
        if self.invert:
            self.curimg.mult(-1.0)
        self.im2d.set_data(self.curimg)

        label = EMShape(
            ["scrlabel", .7, .3, 0, 20.0, 20.0,
             "%d" % self.curtilt, 200.0, 1])
        self.im2d.add_shape("tilt", label)

        box = self.tiltshapes[self.curtilt]
        if box is not None:
            self.im2d.add_shape("finalbox", box)

            # Only pan when the previously shown tilt also has a box to
            # measure the displacement against.
            oldbox = self.tiltshapes[self.oldtilt]
            if oldbox is not None:
                s0 = oldbox.getShape()
                s1 = box.getShape()
                dx = s0[4] - s1[4]
                dy = s0[5] - s1[5]
                self.im2d.set_origin(self.im2d.origin[0] - dx,
                                     self.im2d.origin[1] - dy)
            self.oldtilt = self.curtilt

        self.im2d.updateGL()

    def change_tilt(self, direc):
        """Step the displayed tilt by *direc* when the user presses up/down.

        The new index is clamped to [0, nz - 1].  The call is ignored until a
        tilt series has been loaded, because ``self.imageparm`` is still None
        before then.
        """
        if self.imageparm is None:
            return
        self.oldtilt = self.curtilt
        last = self.imageparm["nz"] - 1
        self.curtilt = min(max(self.curtilt + direc, 0), last)

        self.update_tilt()

    def down(self, event, lc):
        """Mouse-press handler for the main image window.

        *event* carries the window-space coordinates; *lc* is the position in
        image space.  Only the left button is handled:
        - A plain press records *lc* as the start of a new box.
        - A shift+press starts moving the current tilt's existing box.  It
          records the press position and that box's corners (shape elements
          4..7).

        A shift+press is ignored when the current tilt has no box yet;
        otherwise it would fail on the missing shape.
        """
        if not event.buttons() & Qt.LeftButton:
            return

        if event.modifiers() & Qt.ShiftModifier:
            if self.tiltshapes is None:
                return
            box = self.tiltshapes[self.curtilt]
            if box is None:
                return
            self.downadjloc = (lc, box.getShape()[4:8])
        else:
            self.downloc = lc

    def drag(self, event, lc):
        """Mouse-drag handler for the main image window.

        Drawing a new box: shows a square rubber-band box centred on the press
        point.  Its half-width is the larger of the x/y drag distances, with
        the full width rounded through ``good_size``.  A label shows the tilt
        number and the box size.

        Moving an existing box (shift-drag): shows the current tilt's box
        translated by the drag offset.

        The window is redrawn in every case.
        """
        start = self.downloc
        adjust = self.downadjloc
        if start is not None:
            half = max(abs(lc[0] - start[0]), abs(lc[1] - start[1]))
            half = old_div(good_size(half * 2), 2)
            rubber = EMShape([
                "rectpoint", 0, .7, 0,
                start[0] - half, start[1] - half,
                start[0] + half, start[1] + half, 1
            ])
            self.im2d.add_shape("box", rubber)
            size_text = "%d (%d x %d)" % (self.curtilt, half * 2, half * 2)
            size_label = EMShape(
                ["scrlabel", .7, .7, 0, 20.0, 20.0, size_text, 200.0, 1])
            self.im2d.add_shape("tilt", size_label)
        elif adjust is not None:
            press, corners = adjust
            offx = lc[0] - press[0]
            offy = lc[1] - press[1]
            moved = self.tiltshapes[self.curtilt].getShape()[:]
            moved[4] = corners[0] + offx
            moved[5] = corners[1] + offy
            moved[6] = corners[2] + offx
            moved[7] = corners[3] + offy
            self.im2d.add_shape("box", EMShape(moved))

        self.im2d.updateGL()

    def up(self, event, lc):
        """Mouse-release handler for the main image window.

        Finishing a new box: the rubber-band box is removed.  If the pointer
        moved more than 5 units (image coordinates) from the press point, all
        boxes are discarded.  The square box, sized as in ``drag``, is then
        tracked through the whole series with ``find_boxes``.  The current
        tilt is redisplayed either way.

        Finishing a shift-drag: the current tilt's box is replaced by the
        translated box and drawn as the final box.  The tilt view and the
        boxed stack are then refreshed.
        """
        start = self.downloc
        if start is not None:
            half = max(abs(lc[0] - start[0]), abs(lc[1] - start[1]))
            half = old_div(good_size(half * 2), 2)
            newbox = EMShape([
                "rectpoint", .7, .2, 0,
                start[0] - half, start[1] - half,
                start[0] + half, start[1] + half, 1
            ])
            self.im2d.del_shape("box")
            # very short drags are treated as clicks and keep existing boxes
            if hypot(lc[0] - start[0], lc[1] - start[1]) > 5:
                self.tiltshapes = [None] * self.imageparm["nz"]
                self.find_boxes(newbox)

            self.update_tilt()
            self.downloc = None
            return

        adjust = self.downadjloc
        if adjust is None:
            return

        press, corners = adjust
        offx = lc[0] - press[0]
        offy = lc[1] - press[1]
        moved = self.tiltshapes[self.curtilt].getShape()[:]
        moved[4] = corners[0] + offx
        moved[5] = corners[1] + offy
        moved[6] = corners[2] + offx
        moved[7] = corners[3] + offy
        self.tiltshapes[self.curtilt] = EMShape(moved)
        self.im2d.add_shape("finalbox", self.tiltshapes[self.curtilt])
        self.im2d.del_shape("box")

        self.update_tilt()
        self.update_stack()
        self.downadjloc = None

    def get_boxed_stack(self):
        """Extract the boxed region from every tilt as a list of EMData.

        For tilt i, the rectangle (x0, y0)-(x1, y1) from elements 4..7 of
        ``self.tiltshapes[i]`` is read out of z-slice i of the image file.
        Each particle:
        - records its box centre and tilt index in ``ptcl_source_coord``;
        - records the file name in ``ptcl_source_image``;
        - is negated when ``self.invert`` is set;
        - is normalised with ``normalize.edgemean``.

        Every tilt must already have a box.
        """
        particles = []
        for tilt in range(self.imageparm["nz"]):
            x0, y0, x1, y1 = self.tiltshapes[tilt].getShape()[4:8]
            particle = EMData(self.imagefile, 0, False,
                              Region(x0, y0, tilt, x1 - x0, y1 - y0, 1))
            centre_x = int(old_div((x1 + x0), 2.0))
            centre_y = int(old_div((y1 + y0), 2.0))
            particle["ptcl_source_coord"] = (centre_x, centre_y, tilt)
            particle["ptcl_source_image"] = str(self.imagefile)
            if self.invert:
                particle.mult(-1.0)
            particle.process_inplace("normalize.edgemean")
            particles.append(particle)

        return particles

    def cenalign_stack(self):
        """Iteratively centre every box on the average of all boxed particles.

        Runs five rounds.  In each round:
        - The boxed stack is summed, and the sum is band-pass filtered
          (Gaussian low-pass 0.1, high-pass 0.02) and centred with
          ``xform.centeracf``.
        - Each particle is filtered the same way and aligned translationally
          to that reference.
        - Its box in ``self.tiltshapes`` is shifted in place by the negated
          alignment translation.

        Finally the boxed-stack viewer receives the re-extracted stack.
        """
        for _round in range(5):
            particles = self.get_boxed_stack()

            # reference: band-passed, ACF-centred sum of all particles
            reference = particles[0].copy()
            for extra in particles[1:]:
                reference.add(extra)
            reference.process_inplace("filter.lowpass.gauss",
                                      {"cutoff_abs": .1})
            reference.process_inplace("filter.highpass.gauss",
                                      {"cutoff_abs": .02})
            reference.process_inplace("xform.centeracf")

            # move each box by its particle's offset from the reference
            for idx, particle in enumerate(particles):
                particle.process_inplace("filter.lowpass.gauss",
                                         {"cutoff_abs": .1})
                particle.process_inplace("filter.highpass.gauss",
                                         {"cutoff_abs": .02})
                aligned = particle.align("translational", reference)
                shift = aligned["xform.align2d"].get_trans()
                self.tiltshapes[idx].translate(-shift[0], -shift[1])

        # Update the stack display
        self.imboxed.set_data(self.get_boxed_stack())

    def reconstruct_wedgefill(self, stack, angstep, mode=2):
        """Fills the missing wedge with the average of the slices.

        Reconstructs the boxed tilt *stack*.  Particle i is assigned
        alt = (i - N//2) * *angstep*, stored as its "alt" attribute.  The
        Fourier reconstructor runs in mode ``reconmodes[mode]``, padded to
        ``Util.calc_best_fft_size(int(1.5 * box))``.

        Slices are generated every ``samp`` degrees over roughly -90..+90:
        - inside the measured tilt range, the two bracketing particles are
          blended linearly in real space;
        - outside it (the missing wedge), the average of all particles is used.

        Five iterations run, alternating insert_slice (even ri) and
        determine_slice_agreement (odd ri).  Returns the volume clipped back
        to box x box x box and edge-mean normalised.

        Side effect: sets the "alt" attribute on every particle of *stack*.
        """
        print("Making 3D tomofill")

        taxis = float(self.btiltaxisval.text())
        boxsize = stack[0]["nx"]
        pad = Util.calc_best_fft_size(int(boxsize * 1.5))

        # average all of the slices together
        av = stack[0].copy()
        for p in stack[1:]:
            av += p
        av.del_attr("xform.projection")
        av.mult(old_div(1.0, (len(stack))))
        # pad the average (centred) to pad x pad
        av = av.get_clip(
            Region(old_div(-(pad - boxsize), 2), old_div(-(pad - boxsize), 2),
                   pad, pad))

        for i, p in enumerate(stack):
            p["alt"] = (i - old_div(len(stack), 2)) * angstep

        # Determine a good angular step for filling Fourier space
        # fullsamp = 360 / (pi * box) degrees, presumably the angular step
        # that samples the edge of Fourier space about once per pixel.
        # NOTE(review): samp = 1 / floor(angstep / fullsamp) does not scale
        # with angstep; angstep / floor(angstep / fullsamp) may have been
        # intended (the current form oversamples) -- confirm.
        fullsamp = old_div(360.0, (boxsize * pi))
        if old_div(angstep, fullsamp) > 2.0:
            samp = old_div(1.0, (floor(old_div(angstep, fullsamp))))
        else:
            samp = angstep

        print("Subsampling = %1.2f" % samp)

        # Now the reconstruction
        recon = Reconstructors.get(
            "fourier", {
                "sym": "c1",
                "size": (pad, pad, pad),
                "mode": reconmodes[mode],
                "verbose": True
            })
        recon.setup()

        for ri in range(5):
            print("Iteration ", ri)
            for a in [
                    i * samp for i in range(-int(old_div(90.0, samp)),
                                            int(old_div(90.0, samp)) + 1)
            ]:
                # find ii with alt[ii] <= a < alt[ii+1]; ii = -1 if none
                # (for/else: the else runs only when the loop did not break)
                for ii in range(len(stack) - 1):
                    if stack[ii]["alt"] <= a and stack[ii + 1]["alt"] > a:
                        break
                else:
                    ii = -1

                if a < stack[0]["alt"]:
                    # below the first measured tilt: use the overall average
                    p = av
                    #frac=0.5*(a-stack[0]["alt"])/(-90.0-stack[0]["alt"])
                    ## a bit wierd. At the ends (missing wedge) we use the average over all tilts. This could be improved
                    #p=stack[0].get_clip(Region(-(pad-boxsize)/2,-(pad-boxsize)/2,pad,pad))*(1.0-frac)+stack[-1].get_clip(Region(-(pad-boxsize)/2,-(pad-boxsize)/2,pad,pad))*frac
#					print a," avg ",frac,stack[0]["alt"]
                elif ii == -1:
                    # at/above the last measured tilt: use the overall average
                    p = av
                    #frac=0.5*(a-stack[-1]["alt"])/(90.0-stack[-1]["alt"])+.5
                    ## a bit wierd. At the ends (missing wedge) we use the average over all tilts. This could be improved
                    #p=stack[-1].get_clip(Region(-(pad-boxsize)/2,-(pad-boxsize)/2,pad,pad))*(1.0-frac)+stack[0].get_clip(Region(-(pad-boxsize)/2,-(pad-boxsize)/2,pad,pad))*frac
#					print a," avg ",frac
                else:
                    # We average slices in real space, producing a rotational 'smearing' effect
                    # (linear weight by distance from alt[ii], in units of angstep)
                    frac = old_div((a - stack[ii]["alt"]), angstep)
                    p = stack[ii].get_clip(
                        Region(old_div(-(pad - boxsize), 2),
                               old_div(-(pad - boxsize), 2), pad, pad)
                    ) * (1.0 - frac) + stack[ii + 1].get_clip(
                        Region(old_div(-(pad - boxsize), 2),
                               old_div(-(pad - boxsize), 2), pad, pad)) * frac
#					print a,ii,ii+1,stack[ii]["alt"],frac

                xf = Transform({
                    "type": "eman",
                    "alt": a,
                    "az": -taxis,
                    "phi": taxis
                })
                # NOTE: when p is av, this overwrites av's xform.projection
                p["xform.projection"] = xf

                # odd iterations only assess agreement; even ones insert
                if ri % 2 == 1:
                    recon.determine_slice_agreement(p, xf, 1)
                else:
                    recon.insert_slice(p, xf)

        ret = recon.finish()
        print("Done")
        # clip the padded volume back to the original box size (centred)
        ret = ret.get_clip(
            Region(old_div((pad - boxsize), 2), old_div((pad - boxsize), 2),
                   old_div((pad - boxsize), 2), boxsize, boxsize, boxsize))
        ret.process_inplace("normalize.edgemean")
        #		ret=ret.get_clip(Region((pad-boxsize)/2,(pad-boxsize)/2,(pad-boxsize)/2,boxsize,boxsize,boxsize))

        return ret

    def reconstruct_ca(self, stack, angstep, mode=2):
        """Cylindrically averaged tomographic model, generally used for filling empty spaces.

        Steps:
        - The particles in *stack* are averaged.
        - The average is padded (centred) to
          ``Util.calc_best_fft_size(int(1.5 * box))``.
        - The same averaged image is inserted at every alt from -180 up to
          (but not including) 180 degrees, in steps of *angstep*, about the
          tilt axis from the tilt-axis text field.
        - Three rounds run; every round after the first first calls
          determine_slice_agreement at each orientation.

        The input particles are not modified.  *mode* is currently unused.
        Returns the padded (not clipped) volume, edge-mean normalised.
        """
        print("Making CA")

        taxis = float(self.btiltaxisval.text())
        boxsize = stack[0]["nx"]
        pad = Util.calc_best_fft_size(int(boxsize * 1.5))

        # Average all of the slices together.  The 1/N scaling is applied to
        # the sum itself, so the caller's particles are left untouched.
        av = stack[0].copy()
        for particle in stack[1:]:
            av += particle
        av.del_attr("xform.projection")
        av.mult(old_div(1.0, len(stack)))
        av = av.get_clip(
            Region(old_div(-(pad - boxsize), 2), old_div(-(pad - boxsize), 2),
                   pad, pad))

        recon = Reconstructors.get("fourier", {
            "quiet": True,
            "sym": "c1",
            "x_in": pad,
            "y_in": pad
        })
        recon.setup()

        def orientation(alt):
            """Orientation of a slice tilted by *alt* about the tilt axis."""
            return Transform({
                "type": "eman",
                "alt": alt,
                "az": -taxis,
                "phi": taxis
            })

        for ri in range(3):
            if ri > 0:
                alt = -180.0
                while alt < 180.0:
                    recon.determine_slice_agreement(av, orientation(alt), 1)
                    alt += angstep
            alt = -180.0
            while alt < 180.0:
                recon.insert_slice(av, orientation(alt))
                alt += angstep

        ret = recon.finish()
        ret.process_inplace("normalize.edgemean")

        return ret

    def reconstruct(self, stack, angstep, mode=0, initmodel=None):
        """Tomographic reconstruction of the current boxed *stack*.

        Particle i gets ``xform.projection`` = alt (i - N//2) * *angstep*
        about the tilt axis from the tilt-axis text field.  The Fourier
        reconstruction uses mode ``reconmodes[mode]`` and is padded to
        ``good_size(int(1.5 * box))``.  It runs in two passes:
        - First pass: every slice is inserted, and each slice is then scored
          with determine_slice_agreement (``reconstruct_absqual`` and
          ``reconstruct_norm``).
        - Second pass: the reconstructor is reset.  Slices scoring below 0.7 x
          the mean of the three central scores (all scores when there are
          fewer than three) are skipped; the rest are scaled by their norm and
          inserted.

        If *initmodel* is given, it seeds both passes.  Returns the volume
        clipped back to box x box x box.
        """
        if initmodel is not None:
            print("Using initial model")

        taxis = float(self.btiltaxisval.text())

        boxsize = stack[0]["nx"]
        pad = good_size(int(boxsize * 1.5))
        offset = old_div(-(pad - boxsize), 2)

        for i, p in enumerate(stack):
            p["xform.projection"] = Transform({
                "type": "eman",
                "alt": (i - old_div(len(stack), 2)) * angstep,
                "az": -taxis,
                "phi": taxis
            })

        recon = Reconstructors.get(
            "fourier", {
                "sym": "c1",
                "size": (pad, pad, pad),
                "mode": reconmodes[mode],
                "verbose": True
            })

        def padded_slice(p):
            """Pad particle *p* (centred) to pad x pad and preprocess it."""
            p2 = p.get_clip(Region(offset, offset, pad, pad))
            return recon.preprocess_slice(p2, p["xform.projection"])

        def reset():
            """(Re)initialise the reconstructor, seeded with initmodel if any."""
            if initmodel is not None:
                recon.setup(initmodel, .01)
            else:
                recon.setup()

        reset()
        scores = []

        # First pass: build a provisional model from every slice
        for i, p in enumerate(stack):
            recon.insert_slice(padded_slice(p), p["xform.projection"], 1.0)
            print(" %d    \r" % i)
        print("")

        # after building the model once we can assess how well everything agrees
        for i, p in enumerate(stack):
            p2 = padded_slice(p)
            recon.determine_slice_agreement(p2, p["xform.projection"], 1.0,
                                            True)
            scores.append((p2["reconstruct_absqual"], p2["reconstruct_norm"]))
            print(" %d\t%1.3f    \r" % (i, scores[-1][0]))
        print("")

        # setup for the second run
        reset()

        # quality threshold: 0.7 x mean of the three central scores (this is
        # rather arbitrary); fall back to all scores for very short series
        nscores = len(scores)
        half = old_div(nscores, 2)
        if nscores >= 3:
            central = (scores[half][0] + scores[half - 1][0] +
                       scores[half + 1][0])
            thr = old_div(0.7 * central, 3)
        else:
            thr = 0.7 * sum(sc[0] for sc in scores) / nscores

        # Second pass: insert only the slices that pass the threshold,
        # scaled by their normalisation factor
        for i, p in enumerate(stack):
            if scores[i][0] < thr:
                print("%d. %1.3f *" % (i, scores[i][0]))
                continue

            print("%d. %1.2f \t%1.3f\t%1.3f" %
                  (i, p["xform.projection"].get_rotation("eman")["alt"],
                   scores[i][0], scores[i][1]))
            p2 = padded_slice(p)
            p2.mult(scores[i][1])
            recon.insert_slice(p2, p["xform.projection"], 1.0)

        # presumably saves the normalisation volume to norm.mrc -- confirm
        recon.set_param("savenorm", "norm.mrc")
        ret = recon.finish(True)
        start = old_div((pad - boxsize), 2)
        ret = ret.get_clip(
            Region(start, start, start, boxsize, boxsize, boxsize))

        return ret

    def tilt_axis(self):
        """Estimate the tilt-axis angle and write it into the tilt-axis field.

        Steps:
        - Sums, over all tilts, an sz x sz window around each box's
          (x0, y0) point.  The size starts at ``good_size(nx // 2)`` and is
          halved until every window fits inside the image.  The method gives
          up silently once sz drops below 32.
        - Flattens the background of the sum, normalises it and applies a
          circular mask.
        - Computes the azimuthal power distribution of the FFT in 0.5-degree
          bins, from radius 10 to sz/2 - 1.
        - Writes the angle with the highest power to ``self.btiltaxisval``.
        """
        ntilt = self.imageparm["nz"]
        sz = good_size(old_div(self.imageparm["nx"], 2))
        while True:
            av = None
            n = 0
            half = old_div(sz, 2)
            for i in range(ntilt):
                refshape = self.tiltshapes[i].getShape()
                # NOTE(review): the window is centred on the box's first
                # corner (elements 4, 5), not on its centre -- confirm.
                if (refshape[4] <= half or refshape[5] <= half
                        or self.imageparm["nx"] - refshape[4] <= half
                        or self.imageparm["ny"] - refshape[5] <= half):
                    break
                img = EMData(
                    self.imagefile, 0, False,
                    Region(refshape[4] - half, refshape[5] - half, i, sz, sz,
                           1))
                if self.invert:
                    img.mult(-1.0)
                img.process_inplace("normalize.edgemean")

                if av is None:
                    av = img
                else:
                    av.add(img)
                n += 1

            if n == ntilt:
                break
            # Integer halving: sz feeds good_size() and Region sizes, and
            # "/=" would turn it into a float under Python 3.
            sz //= 2
            if sz < 32:
                return
            print(
                "You may wish to center on a feature closer to the center of the image next time -> ",
                sz)

        # pad, flatten the background, then clip back and mask
        sz2 = good_size(sz + 128)
        av2 = av.get_clip(
            Region(old_div((sz - sz2), 2), old_div((sz - sz2), 2), sz2, sz2))
        av2.process_inplace("mask.zeroedgefill")
        av2.process_inplace("filter.flattenbackground", {"radius": 64})
        av = av2.get_clip(
            Region(old_div((sz2 - sz), 2), old_div((sz2 - sz), 2), sz, sz))
        av.process_inplace("normalize.edgemean")
        av.process_inplace("mask.sharp", {"outer_radius": old_div(sz, 2) - 1})

        # azimuthal power in 360 bins of 0.5 degrees; pick the strongest angle
        f = av.do_fft()
        dist = f.calc_az_dist(360, -90.25, 0.5, 10.0, old_div(sz, 2) - 1)
        power_by_angle = [(power, j * 0.5 - 90) for j, power in enumerate(dist)]
        self.btiltaxisval.setText(str(max(power_by_angle)[1]))


#		print max(d)
#		print min(d)
#		plot(d)

    def projection_align(self, angstep=2.0):
        """Realign the current set of boxes by matching projections of the map.

        Single pass over the tilts.  For tilt i:
        - The reference is ``self.map3d`` projected at alt
          (i - nz//2) * *angstep* about the tilt axis from the text field.
          It is low-pass filtered, normalised and zero-padded by
          ``self.maxshift`` on each side.
        - The target is read from the image file around the current box,
          padded with real image data by ``self.maxshift`` on each side.
        - The integer translation from translational alignment (search
          limited to 4/5 of maxshift) is applied to the box in place.

        If no 3-D map has been reconstructed, a message is printed and
        nothing else happens.
        """
        if self.map3d is None:
            print("No 3-D map available; reconstruct before projection alignment")
            return

        taxis = float(self.btiltaxisval.text())
        # one entry per tilt; the boxed particles themselves are not needed
        ntilt = self.imageparm["nz"]
        ms = self.maxshift

        for i in range(ntilt):
            # NOTE(review): orientation convention was questioned by the
            # original author ("is this right ?")
            ort = Transform({
                "type": "eman",
                "alt": (i - old_div(ntilt, 2)) * angstep,
                "az": -taxis,
                "phi": taxis
            })
            curshape = self.tiltshapes[i].getShape()

            # Project the reference at the box size, then pad it a bit
            ref = self.map3d.project("standard", ort)
            ref.process_inplace("filter.lowpass.gauss", {"cutoff_abs": .1})
            ref.process_inplace("normalize.edgemean")
            ref = ref.get_clip(
                Region(-ms, -ms, ref["nx"] + ms * 2, ref["ny"] + ms * 2))

            # when we read the alignment target, we pad with actual image data since the object will have moved
            trg = EMData(
                self.imagefile, 0, False,
                Region(curshape[4] - ms, curshape[5] - ms, i,
                       curshape[6] - curshape[4] + ms * 2,
                       curshape[7] - curshape[5] + ms * 2, 1))
            trg.process_inplace("filter.lowpass.gauss", {"cutoff_abs": .1})
            trg.process_inplace("normalize.edgemean")
            if self.invert:
                trg.mult(-1.0)

            aln = ref.align("translational", trg, {
                "intonly": 1,
                "maxshift": old_div(ms * 4, 5)
            })
            trans = aln["xform.align2d"].get_trans()
            print(i, trans[0], trans[1])
            # visual check of the last three alignments
            if i > ntilt - 4:
                display([ref, trg, aln])

            self.tiltshapes[i].translate(trans[0], trans[1])

    def find_boxes(self, mainshape):
        """Track a user-selected box through the entire tilt series.

        *mainshape* (an EMShape with box corners in elements 4..7) becomes the
        box of the current tilt.  Boxes for the other tilts are found one step
        at a time with ``_track_box``:
        - forward from the current tilt to the last;
        - then backward from the current tilt to the first.

        The boxed-stack view is refreshed at the end.  Does nothing if no
        image is loaded.
        """
        if self.imagefile is None:
            return

        self.tiltshapes[self.curtilt] = mainshape

        # Forward pass: each later tilt is tracked from the tilt before it.
        lref = None
        for i in range(self.curtilt + 1, self.imageparm["nz"]):
            lref = self._track_box(i - 1, i, lref, clamp=True)

        # Backward pass: each earlier tilt is tracked from the tilt after it.
        # NOTE(review): the backward pass has never applied the 4-sigma clamp
        # used going forward; preserved as-is -- confirm whether intended.
        lref = None
        for i in range(self.curtilt - 1, -1, -1):
            lref = self._track_box(i + 1, i, lref, clamp=False)

        self.update_stack()

    def _track_box(self, src, dst, lref, clamp):
        """Locate the box for tilt *dst* from the box at tilt *src*.

        The box contents at *src* become the reference.  The reference is
        low-pass filtered, normalised and zero-padded by ``self.maxshift``.
        With ``self.seqali`` set, the previous step's reference *lref* is
        added into it.  The target is the same area at *dst*, padded with
        real image data.  When *clamp* is true, both images are first clamped
        to 4 sigma.

        The integer translation from translational alignment becomes the new
        box, stored in ``self.tiltshapes[dst]``.  Returns the reference used,
        for chaining into the next step.
        """
        refshape = self.tiltshapes[src].getShape()
        x0, y0, x1, y1 = refshape[4:8]
        ms = self.maxshift

        # Read the reference at the user specified size, then pad it a bit
        ref = EMData(self.imagefile, 0, False,
                     Region(x0, y0, src, x1 - x0, y1 - y0, 1))
        if clamp:
            ref.process_inplace("threshold.clampminmax.nsigma",
                                {"nsigma": 4.0})
        ref.process_inplace("filter.lowpass.gauss", {"cutoff_abs": .1})
        ref.process_inplace("normalize.edgemean")
        ref = ref.get_clip(
            Region(-ms, -ms, ref["nx"] + ms * 2, ref["ny"] + ms * 2))
        if lref is not None and self.seqali:
            ref.add(lref)
        ref.process_inplace("normalize.edgemean")  # older images contribute less

        # when we read the alignment target, we pad with actual image data since the object will have moved
        trg = EMData(self.imagefile, 0, False,
                     Region(x0 - ms, y0 - ms, dst,
                            x1 - x0 + ms * 2, y1 - y0 + ms * 2, 1))
        if clamp:
            trg.process_inplace("threshold.clampminmax.nsigma",
                                {"nsigma": 4.0})
        trg.process_inplace("filter.lowpass.gauss", {"cutoff_abs": .1})
        trg.process_inplace("normalize.edgemean")

        aln = ref.align(
            "translational", trg, {
                "intonly": 1,
                "maxshift": old_div(ms * 4, 5),
                "masked": 1
            })
        trans = aln["xform.align2d"].get_trans()

        self.tiltshapes[dst] = EMShape([
            "rectpoint", .7, .2, 0, x0 + trans[0], y0 + trans[1],
            x1 + trans[0], y1 + trans[1], 1
        ])
        print(dst, trans[0], trans[1])
        return ref