Example #1
0
class TelescopeEventView(tk.Frame, object):
    """A Tk frame showing the camera view of a single telescope.

    The frame embeds a matplotlib figure holding one borderless, axis-less
    ``Axes`` (equal aspect) that a ``CameraPlot`` draws into.
    """

    def __init__(self, root, telescope, data=None, *args, **kwargs):
        """Build the frame.

        :param root: parent Tk widget.
        :param telescope: telescope description, passed to ``CameraPlot``.
        :param data: optional initial camera data, passed to ``CameraPlot``.
        Any extra positional/keyword arguments are forwarded to ``CameraPlot``.
        """
        self.telescope = telescope
        super(TelescopeEventView, self).__init__(root)
        self.figure = Figure(figsize=(5, 5), facecolor='none')
        self.ax = Axes(self.figure, [0, 0, 1, 1], aspect=1)
        self.ax.set_axis_off()
        self.figure.add_axes(self.ax)
        self.camera_plot = CameraPlot(telescope, self.ax, data, *args, **kwargs)
        self.canvas = FigureCanvasTkAgg(self.figure, master=self)
        # FigureCanvasTkAgg.show() was deprecated in matplotlib 2.2 and removed
        # in 3.0; draw() renders the canvas on every version.
        self.canvas.draw()
        # get_tk_widget() is the public accessor for the canvas's Tk widget
        # (the private ``_tkcanvas``). Packing it twice only left the last
        # configuration (side=BOTTOM) in effect, so pack it once that way.
        widget = self.canvas.get_tk_widget()
        widget.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)
        widget.config(highlightthickness=0)

    @property
    def data(self):
        """The data currently shown by the camera plot."""
        return self.camera_plot.data

    @data.setter
    def data(self, value):
        """Replace the camera plot's data and redraw the canvas."""
        self.camera_plot.data = value
        self.canvas.draw()
Example #2
0
class MyMplCanvas(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).

    The canvas displays ``self.fig1``. ``self.fig2``/``self.axes2`` are
    created but not attached to this canvas. Subclasses draw their initial
    content by overriding :meth:`compute_initial_figure`.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        self.fig1 = Figure(figsize=(width, height), dpi=dpi)
        self.fig2 = Figure(figsize=(width, height), dpi=dpi)
        fig3 = Figure(figsize=(width, height), dpi=dpi)
        self.axes1 = self.fig1.add_subplot(223)
        self.axes2 = self.fig2.add_subplot(221)

        self.compute_initial_figure()

        FigureCanvas.__init__(self, self.fig1)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Draw a sine wave on an axes built in fig3, then move that axes into
        # fig1 so it shows up on this canvas.
        # NOTE(review): newer matplotlib versions refuse to re-parent an Axes
        # that already belongs to another figure -- confirm for the version used.
        t = arange(0.0, 1.0, 0.01)
        s = sin(2 * pi * t)
        axes3 = fig3.add_subplot(1, 2, 2)
        axes3.plot(t, s)
        axes3.set_figure(self.fig1)
        self.fig1.add_axes(axes3)

    def compute_initial_figure(self):
        """Hook for subclasses to draw initial content; does nothing here."""
        pass
Example #3
0
 def _figure_default(self):
     """
     Default value for the ``figure`` attribute: build a Figure directly.
     """
     fig = Figure()
     # A single plotting area, leaving a margin on every side.
     fig.add_axes([0.05, 0.1, 0.9, 0.85])
     return fig
    def getImage(self):
        """Render the current fit result to a PNG and return the fit image.

        Draws the spectrum (black), continuum (green) and fit (red) from
        ``self.fitresult['result']``. When the fit's ``sumflag`` is set, it
        also draws pile-up plus continuum (yellow). The y axis is
        logarithmic when ``self.plotDict.get('logy', True)`` is true. If
        drawing raises ``ValueError``, everything is redrawn on a linear y
        axis.

        The PNG is written to ``<self.outdir>/<self.outfile>.png``, replacing
        any existing file. Returns ``self.__getFitImage`` for that file name.
        """
        result = self.fitresult['result']

        def build(logy):
            # Draw all curves and the legend on a fresh figure; return the
            # canvas, axes and legend that the caller still needs.
            fig = Figure(figsize=(6, 3))  # in inches
            canvas = FigureCanvas(fig)
            ax = fig.add_axes([.15, .15, .8, .8])
            ax.set_axisbelow(True)
            axplot = ax.semilogy if logy else ax.plot
            energy = result['energy']
            axplot(energy, result['ydata'], 'k', lw=1.5)
            axplot(energy, result['continuum'], 'g', lw=1.5)
            legendlist = ['spectrum', 'continuum', 'fit']
            axplot(energy, result['yfit'], 'r', lw=1.5)
            fontproperties = FontProperties(size=8)
            if result['config']['fit']['sumflag']:
                axplot(energy, result['pileup'] + result['continuum'], 'y', lw=1.5)
                legendlist.append('pileup')
            # The legend spacing keyword was renamed in matplotlib 0.99.
            if matplotlib_version < '0.99.0':
                legend = ax.legend(legendlist, 0,
                                   prop=fontproperties, labelsep=0.02)
            else:
                legend = ax.legend(legendlist, 0,
                                   prop=fontproperties, labelspacing=0.02)
            return canvas, ax, legend

        logplot = self.plotDict.get('logy', True)
        try:
            canvas, ax, legend = build(logplot)
        except ValueError:
            # Retry on a linear y axis when the first attempt fails.
            canvas, ax, legend = build(False)

        ax.set_xlabel('Energy')
        ax.set_ylabel('Counts')
        legend.draw_frame(False)

        outfile = self.outdir + "/" + self.outfile + ".png"
        try:
            os.remove(outfile)
        except OSError:
            # No previous image to remove (or it cannot be removed); any real
            # problem will surface when print_figure writes the file below.
            pass

        canvas.print_figure(outfile)
        return self.__getFitImage(self.outfile + ".png")
Example #5
0
class Window():
    """Tk frame holding a three-axes demo figure plus a navigation toolbar."""

    def __init__(self, master):
        """Build the figure and pack it, with its toolbar, into a frame on ``master``."""
        self.f = Figure(figsize=(10, 9), dpi=80)
        # Axes rectangles are (left, bottom, width, height) in figure fractions.
        # NOTE(review): ``axisbg`` was deprecated in matplotlib 2.0 in favour of
        # ``facecolor`` and later removed; kept for the versions this targets.
        self.ax0 = self.f.add_axes((0.05, .05, .50, .50), axisbg=(.75, .75, .75), frameon=False)
        self.ax1 = self.f.add_axes((0.05, .55, .90, .45), axisbg=(.75, .75, .75), frameon=False)
        self.ax2 = self.f.add_axes((0.55, .05, .50, .50), axisbg=(.75, .75, .75), frameon=False)

        self.ax0.set_xlabel('Time (s)')
        self.ax0.set_ylabel('Frequency (Hz)')
        self.ax0.plot(np.max(np.random.rand(100, 10) * 10, axis=1), "r-")
        self.ax1.plot(np.max(np.random.rand(100, 10) * 10, axis=1), "g-")
        # Pie wedge sizes must be non-negative; raw randn() values are not.
        self.ax2.pie(np.abs(np.random.randn(10)) * 100)

        # Parent the frame to the ``master`` passed in, not a global ``root``,
        # and create it only once.
        self.frame = Tk.Frame(master)
        self.frame.pack(side=Tk.LEFT, fill=Tk.BOTH, expand=1)

        self.canvas = FigureCanvasTkAgg(self.f, master=self.frame)
        self.canvas.get_tk_widget().pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)
        # draw() renders the figure; show() was removed in matplotlib 3.0.
        self.canvas.draw()

        self.toolbar = NavigationToolbar2TkAgg(self.canvas, self.frame)
        self.toolbar.pack()
        self.toolbar.update()
Example #6
0
def plot_LO_horiz_stripes():
    '''
    Save two JPEG figures of the TOT LO horizontal stripes.

    The radar data were processed through pik1 with the hanning filter
    disabled so the stripes are more readily apparent. The first figure
    shows the whole transect with the TOT quiet-region bounds overlaid; the
    second is a tall, narrow crop over ``zoom_bounds``.
    '''
    # Whole transect, with the quiet region marked.
    full_fig = Figure((10, 4))
    full_canvas = FigureCanvas(full_fig)
    full_ax = full_fig.add_axes([0, 0, 1, 1])
    full_ax.axis('off')
    plot_radar(full_ax, 'TOT_stacked_nofilter', 3200, None, [135000, 230000])
    # Remember the radargram's limits; the overlays below would rescale them.
    saved_xlim, saved_ylim = full_ax.get_xlim(), full_ax.get_ylim()
    tot_bounds, _, _ = find_quiet_regions()
    full_ax.vlines(tot_bounds[0:2], 0, 3200, colors='red', linewidth=3, linestyles='dashed')
    plot_bounding_box(full_ax, tot_bounds, '', linewidth=4)
    full_ax.set_xlim(saved_xlim)
    full_ax.set_ylim(saved_ylim)
    full_canvas.print_figure('../FinalReport/figures/TOT_LO_stripes_d.jpg')

    # Zoomed crop.
    crop = [5000, 9000, 3000, 3200]
    crop_fig = Figure((2.5, 9))
    crop_canvas = FigureCanvas(crop_fig)
    crop_ax = crop_fig.add_axes([0, 0, 1, 1])
    crop_ax.axis('off')
    plot_radar(crop_ax, 'TOT_stacked_nofilter', 3200, crop, [135000, 230000])
    crop_canvas.print_figure('../FinalReport/figures/TOT_LO_stripes_zoom.jpg')
Example #7
0
class Canvas(FigureCanvas):
    """Figure canvas sized to its parent widget.

    It holds a main axes and a thin strip axes (``axcb``, presumably for a
    colorbar) along the bottom, both with their ticks removed. Subclasses
    implement :meth:`on_draw`; :meth:`redraw` runs :meth:`on_pre_draw`,
    :meth:`on_draw` and :meth:`on_post_draw` in that order.
    """

    def __init__(self, parent, dpi=300.0):
        size = parent.size()
        self.dpi = dpi
        # Figure size in inches = pixels / dpi. float() keeps this a true
        # division under Python 2 even when an integer dpi is passed.
        # NOTE(review): if FigureCanvas is the Qt canvas, these attributes
        # shadow QWidget.width()/height() on the instance -- confirm unused.
        self.width = size.width() / float(dpi)
        self.height = size.height() / float(dpi)
        self.figure = Figure(figsize=(self.width, self.height), dpi=self.dpi, facecolor='white', edgecolor='k', frameon=True)
        self.figure.subplots_adjust(left=0.00, bottom=0.00, right=1, top=1, wspace=None, hspace=None)
        # Main axes: left, bottom, width, height (figure fractions).
        self.axes = self.figure.add_axes([0.10, 0.15, 0.80, 0.80])
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.set_axis_off()
        # Thin strip along the bottom.
        self.axcb = self.figure.add_axes([0.1, 0.05, 0.8, 0.05])
        self.axcb.set_xticks([])
        self.axcb.set_yticks([])
        FigureCanvas.__init__(self, self.figure)
        self.updateGeometry()
        self.draw()
        self.setParent(parent)

    def on_pre_draw(self):
        """Hook run before :meth:`on_draw`; does nothing by default."""
        pass

    def on_draw(self):
        """Draw the content; subclasses must override."""
        raise NotImplementedError

    def on_post_draw(self):
        """Hook run after :meth:`on_draw`; pauses for 5 ms by default."""
        sleep(0.005)

    def redraw(self):
        """Run the pre-draw, draw and post-draw hooks in order."""
        self.on_pre_draw()
        self.on_draw()
        self.on_post_draw()
class PlotFigure(Frame):
    """wx frame embedding a matplotlib figure and navigation toolbar.

    An image of sin(x) + cos(y) is shifted and redrawn on every timer event.
    """

    def __init__(self):
        Frame.__init__(self, None, -1, "Test embedded wxFigure")

        self.fig = Figure((5, 4), 75)
        self.canvas = FigureCanvasWxAgg(self, -1, self.fig)
        self.toolbar = NavigationToolbar2Wx(self.canvas)
        self.toolbar.Realize()

        # On Windows the default frame size behaviour is incorrect (this is
        # not needed under Linux): stretch the toolbar to the canvas width.
        _, toolbar_height = self.toolbar.GetSizeTuple()
        canvas_width, _ = self.canvas.GetSizeTuple()
        self.toolbar.SetSize(Size(canvas_width, toolbar_height))

        # Now put all into a sizer
        sizer = BoxSizer(VERTICAL)
        # This way of adding to sizer allows resizing
        sizer.Add(self.canvas, 1, LEFT | TOP | GROW)
        # Best to allow the toolbar to resize!
        sizer.Add(self.toolbar, 0, GROW)
        self.SetSizer(sizer)
        self.Fit()
        # Route timer events with id TIMER_ID to onTimer. No timer is started
        # here; presumably the caller creates and starts it -- confirm.
        EVT_TIMER(self, TIMER_ID, self.onTimer)

    def init_plot_data(self):
        """Create the image and colorbar axes and draw the initial image.

        Must run before the first timer event: onTimer reads self.x, self.y
        and self.im, which are created here.
        """
        # jdh you can add a subplot directly from the fig rather than
        # the fig manager
        a = self.fig.add_axes([0.075, 0.1, 0.75, 0.85])
        cax = self.fig.add_axes([0.85, 0.1, 0.075, 0.85])
        # Assigning a 120-element ramp to .flat fills the 120x120 array row by
        # row, so every row of x holds the same ramp.
        self.x = npy.empty((120, 120))
        self.x.flat = npy.arange(120.0) * 2 * npy.pi / 120.0
        self.y = npy.empty((120, 120))
        self.y.flat = npy.arange(120.0) * 2 * npy.pi / 100.0
        # Transpose so that y varies down the rows instead of across them.
        self.y = npy.transpose(self.y)
        z = npy.sin(self.x) + npy.cos(self.y)
        self.im = a.imshow(z, cmap=cm.jet)  # , interpolation='nearest')
        self.fig.colorbar(self.im, cax=cax, orientation='vertical')

    def GetToolBar(self):
        # You will need to override GetToolBar if you are using an
        # unmanaged toolbar in your frame
        return self.toolbar

    def onTimer(self, evt):
        """Shift the phases, recompute the image data and redraw."""
        self.x += npy.pi / 15
        self.y += npy.pi / 20
        z = npy.sin(self.x) + npy.cos(self.y)
        self.im.set_array(z)
        self.canvas.draw()
        # (No gui_repaint() call needed: jdh notes wxagg's draw calls it already.)

    def onEraseBackground(self, evt):
        # this is supposed to prevent redraw flicker on some X servers...
        pass
Example #9
0
	def _figure_default(self):
		'''
		Build the default figure: one axes nearly filling it, with no tick marks on either axis.
		'''
		fig = Figure()
		ax = fig.add_axes([0.05, 0.04, 0.9, 0.92])
		for axis in (ax.get_xaxis(), ax.get_yaxis()):
			axis.set_ticks([])
		return fig
Example #10
0
def time_basic_plot():
    """Draw an empty WCSAxes for ``MSX_WCS`` over pixels -0.5..148.5 on both axes.

    The ``time_`` prefix suggests this is a timing benchmark.
    """
    fig = Figure()
    canvas = FigureCanvas(fig)
    wcs_ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    fig.add_axes(wcs_ax)
    limits = (-0.5, 148.5)
    wcs_ax.set_xlim(*limits)
    wcs_ax.set_ylim(*limits)
    canvas.draw()
Example #11
0
    def __init__(self, parent=None):
        """Build the canvas and its panels, initialised with random data.

        Panels: an FTT image panel, vertical/horizontal summation plots and
        an error plot. Four green fibre-marker segments are drawn around
        ``fiber_loc``.

        :param parent: accepted for API compatibility; not used here.
        """
        fig = Figure()
        super(ImageCanvas, self).__init__(fig)
        matplotlib.rcParams.update({'font.size': 10})

        # Generate axes on the figure: (left, bottom, width, height) fractions.
        self.axes_err = fig.add_axes([0.075, 0.05, 0.775, 0.1], facecolor='k')  # error plot
        self.axes_ftt = fig.add_axes([0.075, 0.175, 0.775, 0.675], frame_on=True)  # FTT image frame
        self.axes_vsum = fig.add_axes([0.075, 0.875, 0.775, 0.1], facecolor='k')  # vertical summation
        self.axes_hsum = fig.add_axes([0.875, 0.175, 0.1, 0.675], facecolor='k')  # horizontal summation

        # The image, random data for initialization.
        self.ftt_image = self.axes_ftt.imshow(
            np.random.rand(512, 512),
            cmap='gray',
            interpolation='none',
            extent=(1, 512, 512, 1),
            norm=LogNorm(vmin=0.001, vmax=1),
        )
        self.axes_ftt.set_axis_off()
        self.axes_ftt.set_aspect('auto')

        # Sample coordinates for the summation and error plots.
        self.xlims = np.linspace(0, 511, 512)
        self.ylims = np.linspace(0, 511, 512)
        self.errlims = np.linspace(0, 99, 100)

        # Error buffer: 100 samples x (x error, y error).
        self.errs = np.zeros((100, 2))

        # Summation and error lines, random data for initialization.
        self.vsum_lines, = self.axes_vsum.plot(self.ylims, np.random.rand(512, 1), color='w', linewidth=0.5)
        self.hsum_lines, = self.axes_hsum.plot(np.random.rand(512, 1), self.xlims, color='w', linewidth=0.5)
        self.xerr_lines, = self.axes_err.plot(self.errlims, self.errs[:, 0], color='r')
        self.yerr_lines, = self.axes_err.plot(self.errlims, self.errs[:, 1], color='y')

        # Rescale the line plots to their data (only meaningful once lines exist).
        for axes in (self.axes_vsum, self.axes_hsum, self.axes_err):
            axes.relim()
            axes.autoscale_view()

        # Fibre markers: two horizontal and two vertical segments, each
        # 1/12 of the image wide and offset from ``fiber_loc`` (a name
        # defined outside this method).
        self.marker_lines = []
        ratio = (512 / 12.0, 512 / 12.0)
        x = np.array([[-ratio[0] * 1.5, -ratio[0] * 0.5], [ratio[0] * 0.5, ratio[0] * 1.5], [0.0, 0.0], [0.0, 0.0]]) + fiber_loc[0]
        y = np.array([[0.0, 0.0], [0.0, 0.0], [-ratio[1] * 1.5, -ratio[1] * 0.5], [ratio[1] * 0.5, ratio[1] * 1.5]]) + fiber_loc[1]

        for xs, ys in zip(x, y):
            line = mlines.Line2D(xs, ys, color='g', linewidth=2.0)
            self.marker_lines.append(line)
            self.axes_ftt.add_line(line)

        # Keep the connection id so the handler can be disconnected later.
        self.cid_press = self.mpl_connect('button_press_event', self.on_press)
    def time_basic_plot(self):
        """Draw an empty WCSAxes built from ``self.msx_header`` over pixels
        -0.5..148.5 on both axes (the ``time_`` prefix suggests a benchmark)."""
        fig = Figure()
        canvas = FigureCanvas(fig)
        header_wcs = WCS(self.msx_header)
        wcs_ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=header_wcs)
        fig.add_axes(wcs_ax)
        limits = (-0.5, 148.5)
        wcs_ax.set_xlim(*limits)
        wcs_ax.set_ylim(*limits)
        canvas.draw()
Example #13
0
def time_basic_plot_with_grid():
    """Like an empty ``MSX_WCS`` WCSAxes plot, but with a translucent red
    grid drawn (the ``time_`` prefix suggests a benchmark)."""
    fig = Figure()
    canvas = FigureCanvas(fig)
    wcs_ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    fig.add_axes(wcs_ax)
    wcs_ax.grid(color='red', alpha=0.5, linestyle='solid')
    limits = (-0.5, 148.5)
    wcs_ax.set_xlim(*limits)
    wcs_ax.set_ylim(*limits)
    canvas.draw()
Example #14
0
def plot_quiet_regions():
    """Save one figure per transect (TOT, VCD, THW) showing the region the
    noise was calculated from.

    Each figure shows the ``<NAME>_LO`` radargram with the quiet region's
    bounds as dashed red vertical lines plus a bounding box. It is written
    to ``../FinalReport/figures/<NAME>_quiet_region.jpg``.
    """
    TOT_bounds, VCD_bounds, THW_bounds = find_quiet_regions()

    # Reference noise values for each region:
    # TOT/JKB2d/X16a gives:
    # mag = 32809.224658469, phs = -0.90421798501485484
    # VCD/JKB2g/DVD01a gives:
    # mag = 15720.217174332585, phs = -0.98350090576267946
    # THW/SJB2/DRP02a gives:
    # 26158.900202734963 (presumably mag), phs = 1.6808311318828895

    regions = (('TOT', TOT_bounds), ('VCD', VCD_bounds), ('THW', THW_bounds))
    for name, bounds in regions:
        fig = Figure((10, 8))
        canvas = FigureCanvas(fig)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        plot_radar(ax, name + '_LO', 3200, None, [135000, 234000])
        # Keep the radargram's limits so the overlays don't rescale the view.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.vlines(bounds[0:2], 0, 3200, colors='red', linewidth=3, linestyles='dashed')
        plot_bounding_box(ax, bounds, '', linewidth=4)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        canvas.print_figure('../FinalReport/figures/%s_quiet_region.jpg' % name)
Example #15
0
class CheckMeansPanel(wx.Panel):
	"""wx panel showing a 2-D array as an image with a horizontal colorbar.

	The figure holds an image axes covering the whole figure and a thin
	colorbar axes near the bottom edge; both are created once and reused
	by cla() and plot().
	"""

	def __init__(self, parent, ID=-1, label="", pos=wx.DefaultPosition, size=(100, 25)):
		# (0) Initialize panel:
		wx.Panel.__init__(self, parent, ID, pos, size, wx.STATIC_BORDER, label)
		self.SetMinSize(size)
		# (1) Create the matplotlib figure, size it to the panel, add axes:
		self.figure = Figure(facecolor=(0.8,) * 3)
		self.canvas = FigureCanvasWxAgg(self, -1, self.figure)
		self._resize()
		self._create_axes()

	def _create_axes(self):
		"""Create the image axes (whole figure) and the colorbar axes."""
		self.ax = self.figure.add_axes((0, 0, 1, 1), axisbg=[0.5] * 3)
		self.cax = self.figure.add_axes((0.1, 0.05, 0.8, 0.02), axisbg=[0.5] * 3)
		pyplot.setp(self.ax, xticks=[], yticks=[])

	def _resize(self):
		"""Match the canvas and figure size to the panel's client area."""
		szPixels = tuple(self.GetClientSize())
		self.canvas.SetSize(szPixels)
		dpi = self.figure.get_dpi()
		self.figure.set_size_inches(float(szPixels[0]) / dpi, float(szPixels[1]) / dpi)

	def cla(self):
		"""Clear both axes, restore the blank grey unit-square view and redraw."""
		self.ax.cla()
		self.cax.cla()
		self.ax.set_axis_bgcolor([0.5] * 3)
		pyplot.setp(self.ax, xticks=[], yticks=[], xlim=(0, 1), ylim=(0, 1))
		self.ax.axis('tight')
		self.canvas.draw()

	def plot(self, I):
		"""Show ``I`` as an image with a horizontal colorbar and redraw.

		Zero entries are replaced by NaN so they are not colour-mapped. The
		caller's array is left unmodified.
		"""
		# np.array (unlike np.asarray) always copies, so the NaN assignment
		# below can never write into a float array owned by the caller.
		I = np.array(I, dtype=float)
		I[I == 0] = np.nan
		im = self.ax.imshow(I, interpolation='nearest', origin='lower')
		pyplot.setp(self.ax, xticks=[], yticks=[])
		self.ax.set_axis_bgcolor([0.05] * 3)
		self.ax.axis('image')
		# Colour-bar the image just drawn: ax.images[0] would be the oldest
		# image when plot() is called repeatedly without cla(). Use this
		# panel's own figure; pyplot.colorbar goes through gcf()/gca() and
		# would create a stray pyplot figure.
		cb = self.figure.colorbar(im, cax=self.cax, orientation='horizontal')
		pyplot.setp(cb.ax.get_xticklabels(), color='0.5')
		self.canvas.draw()
Example #16
0
def time_contourf_with_transform():
    """Draw filled contours of ``DATA``, given in ``TWOMASS_WCS`` coordinates,
    onto a WCSAxes using ``MSX_WCS`` (the ``time_`` prefix suggests a benchmark)."""
    fig = Figure()
    canvas = FigureCanvas(fig)
    wcs_ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    fig.add_axes(wcs_ax)
    data_transform = wcs_ax.get_transform(TWOMASS_WCS)
    wcs_ax.contourf(DATA, transform=data_transform)
    # These limits keep the contours in the middle of the result.
    wcs_ax.set_xlim(32.5, 150.5)
    wcs_ax.set_ylim(-64.5, 64.5)
    canvas.draw()
Example #17
0
class ViewCanvas(FigureCanvas):
    """
    Viewer class for matplotlib 2D plotting widget.

    When given a parent widget, the canvas installs itself in a new vertical
    layout on it. It takes keyboard focus on click, so key events reach
    handlers registered with ``mpl_connect``.
    """

    def __init__(self, parent=None, width=6, height=4, dpi=110):
        """
        Init canvas.

        :param parent: Qt widget to embed into. When None, the canvas is
            created without being placed in a layout (previously this raised
            AttributeError on ``None.setLayout``).
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)

        # Here one can adjust the position of the CTX plot area.
        # NOTE(review): left=0.1 with width=1 puts the right edge 0.1 past the
        # figure -- confirm this overhang is intended.
        self.axes = self.fig.add_axes([0.1, 0, 1, 1])

        FigureCanvas.__init__(self, self.fig)

        if parent is not None:
            layout = QVBoxLayout(parent)
            layout.addWidget(self)
            parent.setLayout(layout)

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # The next two lines are needed in order to catch keypress events in
        # the plot canvas via mpl_connect().
        FigureCanvas.setFocusPolicy(self, QtCore.Qt.ClickFocus)
        FigureCanvas.setFocus(self)
Example #18
0
class CanvasPanel(wx.Panel):
    """wx panel that animates a white sine wave on a black matplotlib axes."""

    def __init__(self, parent):
        wx.Panel.__init__(self, parent)
        self.figure = Figure(facecolor="black")
        self.axes = self.figure.add_axes((0, 0, 1, 1))
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.SetSizer(self.sizer)
        self.Fit()

        self.axes.set_axis_bgcolor('black')
        # Phase offset of the wave; advanced by 0.05 (mod 1) each frame.
        self.delta = 0

        # The sample times never change, so build them once here rather than
        # rebuilding the array on every animation frame in draw().
        self._t = arange(0.0, 3.0, 0.01)
        s = sin(2 * pi * self._t)
        self.datatoplot, = self.axes.plot(self._t, s, 'w', linewidth=3.0)
        self.axes.set_ylim((-10, 10))
        self.axes.grid(True, color="w")

    def draw(self):
        """Advance the wave's phase, redraw, and schedule the next frame."""
        self.delta = (self.delta + 0.05) % 1
        self.datatoplot.set_ydata(sin(2 * pi * (self._t - self.delta)))

        # Re-arm the animation: draw() runs again after 10 ms.
        wx.CallLater(10, self.draw)
        self.canvas.draw()
        self.canvas.Refresh()
Example #19
0
class GraphPanel(wx.Panel):
    """wx panel plotting the recent history of several value streams."""

    def __init__(self, parent, streamValuesHistory):
        wx.Panel.__init__(self, parent)
        # Mapping whose values are the per-stream value lists to plot.
        self._streamValuesHistory = streamValuesHistory

        self.figure = Figure()
        self.figure.patch.set_facecolor('black')
        self.axes = self.figure.add_axes([0.1, 0.025, 0.9, 0.95])

        self.canvas = FigureCanvas(self, -1, self.figure)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.SetSizer(self.sizer)
        self.Fit()

        # Line colours, cycled in stream-iteration order.
        self._colors = [(0.0, 1.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0), (1.0, 0.0, 1.0)]

    def UpdateGraphs(self):
        """Redraw every stream in ``self._streamValuesHistory``.

        Each stream becomes one line with one sample per horizontal pixel of
        the axes. The most recent values are shown, and a stream shorter than
        that is left-padded with its oldest value. Empty streams are skipped
        (they still consume a colour so colours stay stable per stream), and
        nothing is plotted while the axes have no width.
        """
        self.axes.clear()
        self.axes.patch.set_facecolor((0, 0, 0))
        self.axes.grid(b=True, color=(0, 0.1, 0), which='major', linestyle='-', linewidth=1)
        self.axes.yaxis.set_tick_params(labelcolor=(0.6, 0.6, 0.6))
        self.axes.set_axisbelow(True)

        # The axes width is the same for every stream, so compute it once.
        valuesNumber = int(self.axes.get_window_extent().width)
        # With a width of 0, streamValues[-0:] would return the whole list
        # and X/Y lengths would not match.
        if valuesNumber > 0:
            X = range(0, valuesNumber)
            for iColor, streamValues in enumerate(self._streamValuesHistory.itervalues()):
                if len(streamValues) == 0:
                    # No data yet; indexing the oldest value would raise.
                    continue
                window = streamValues[-valuesNumber:]
                padding = [window[0]] * (valuesNumber - len(window))
                self.axes.plot(X, padding + window, color=self._colors[iColor % len(self._colors)], linewidth=1)

        self.canvas.draw()

    def buildMetadataImage(self, layerInfoList, width):
        """
        Creates the metadata caption for figures in the style used by WMSViz.

        The caption is a title over a bordered text box. It is drawn on a
        figure ``width`` pixels wide at 100 dpi. The axes, then the figure,
        are shrunk to fit the text, and the figure is converted with
        ``image_util.figureToImage``, whose result is returned.
        """
        self.metadataItems = self._buildMetadataItems(layerInfoList)
        self.width = width

        width = self.width
        height = 1600  # initial figure height in pixels; trimmed below
        dpi = 100
        transparent = False
        figsize = (width / float(dpi), height / float(dpi))
        fig = Figure(figsize=figsize, dpi=dpi, facecolor='w', frameon=(not transparent))
        axes = fig.add_axes([0.04, 0.04, 0.92, 0.92], frameon=True, xticks=[], yticks=[])
        renderer = Renderer(fig.dpi)

        _, titleHeight = self._drawTitleToAxes(axes, renderer)
        txt, textHeight = self._drawMetadataTextToAxes(axes, renderer, self.metadataItems)

        # Fit the axes round the text: keep the top edge and shrink the height
        # to title + text. Those heights are divided by the pixel height, so
        # they are presumably in pixels; float() keeps this a true division
        # under Python 2 if they are ints.
        pos = axes.get_position()
        newpos = Bbox([[pos.x0, pos.y1 - (titleHeight + textHeight) / float(height)], [pos.x1, pos.y1]])
        axes.set_position(newpos)

        # Position the text below the title.
        newAxisHeight = (newpos.y1 - newpos.y0) * height
        txt.set_position((0.02, 0.98 - (titleHeight / newAxisHeight)))

        for spine in axes.spines.values():
            spine.set_edgecolor(borderColor)

        # Draw heading box behind the title.
        headingBoxHeight = titleHeight - 1
        axes.add_patch(Rectangle((0, 1.0 - (headingBoxHeight / newAxisHeight)), 1, (headingBoxHeight / newAxisHeight),
                                 facecolor=borderColor,
                                 fill=True,
                                 linewidth=0))

        # Reduce the figure height to the axes plus a 20-pixel margin above
        # and below (20 / dpi inches).
        originalHeight = fig.get_figheight()
        pos = axes.get_position()
        topBound = 20 / float(dpi)
        textHeight = (pos.y1 - pos.y0) * originalHeight
        newHeight = topBound * 2 + textHeight

        # Work out the new proportions for the figure.
        border = topBound / float(newHeight)
        newpos = Bbox([[pos.x0, border], [pos.x1, 1 - border]])
        axes.set_position(newpos)

        fig.set_figheight(newHeight)

        return image_util.figureToImage(fig)
Example #21
0
def daily_timseries( ts ):
  """Return a 2.56 x 2.56 inch, 300 dpi canvas summarising ``ts``.

  The single full-bleed axes has:
  - a 0-500 y range with the ``SAFE`` band shaded green;
  - x limits from ``timestamps[0]`` to ``timestamps[1]`` plus one day, where
    ``timestamps = glucose.get_days(ts.time)`` (presumably the first and last
    day -- confirm);
  - a grid and no tick labels.

  The resulting x limits are logged at INFO level.
  """
  fig = Figure( ( 2.56, 2.56 ), 300 )
  canvas = FigureCanvas(fig)
  ax = fig.add_axes((0,0,1,1))

  ax.set_ylim( [ 0 , 500 ] )

  # Shade the SAFE range (the returned artist is not needed).
  ax.axhspan( SAFE[0], SAFE[1],
              facecolor='g', alpha=0.2,
              edgecolor = '#003333',
              linewidth=1
            )
  # XXX: gets a list of days.
  timestamps = glucose.get_days( ts.time )
  soleday = dates.relativedelta( days=1 )
  ax.set_xlim( [ timestamps[ 0 ], timestamps[ 1 ] + soleday ] )
  # Disabled for now: fig.autofmt_xdate( ) and plot_glucose_stems( ax, ts ).

  # Hide every tick label (major and minor, both axes).
  plt.setp(ax.get_xminorticklabels(), visible=False )
  plt.setp(ax.get_xmajorticklabels(), visible=False )
  plt.setp(ax.get_ymajorticklabels(), visible=False )
  plt.setp(ax.get_yminorticklabels(), visible=False )

  ax.grid(True)

  xmin, xmax = ax.get_xlim( )
  log.info( pformat( {
    'xlim': [ dates.num2date( xmin ), dates.num2date( xmax ) ],
  } ) )

  return canvas
Example #22
0
def pro2cap(name='C2'):
	"""Plot specimen ``name``'s backbone curve against three shear-capacity
	predictions and save it as ``<name>.png`` at 300 dpi.

	Curves:
	- the backbone from ``bacbone(name)``, in black;
	- the ACI318 (green), Priestly (red) and Sezen (blue) shear capacities
	  from ``SheerEQNS(name, fc=58.0)``, each also mirrored into the
	  negative quadrant.

	Displacements are divided by the yield displacement (6), giving the
	ductility factor on the x axis. Capacities are divided by 1000 for the
	kN axis (presumably N -> kN; confirm).
	"""
	# A plain Figure, not plt.figure(): pyplot would keep a reference to a
	# figure it created, so it would never be released after saving.
	fig = Figure(figsize=(5, 3.75))
	ax = fig.add_axes([0.16, 0.13, 0.74, 0.77])
	ax.grid(True)
	cv = FigureCanvas(fig)
	yield_disp = 6
	bb = bacbone(name)
	x1 = bb[:, 1] / yield_disp
	f1 = bb[:, 2]
	(Vp, Vs, Va) = SheerEQNS(name, fc=58.0)
	bbplot = ax.plot(x1, f1, linewidth=1.5, color='k')
	# Each capacity curve is drawn as given and mirrored through the origin.
	priplot = ax.plot(Vp[:, 0], Vp[:, 1] / 1000, -Vp[:, 0], -Vp[:, 1] / 1000, linewidth=1, color='r')
	senplot = ax.plot(Vs[:, 0], Vs[:, 1] / 1000, -Vs[:, 0], -Vs[:, 1] / 1000, linewidth=1, color='b')
	aciplot = ax.plot(Va[:, 0], Va[:, 1] / 1000, -Va[:, 0], -Va[:, 1] / 1000, linewidth=1, color='g')
	ax.set_xlabel(u'延性系数')  # "ductility factor"
	ax.set_ylabel(u'侧向力[kN]')  # "lateral force [kN]"
	# Legend labels: backbone curve, ACI318 code, Priestly shear capacity,
	# Sezen shear capacity.
	ax.legend([bbplot[0], aciplot[0], priplot[0], senplot[0]],
		[u'骨架曲线', u'ACI318-规范', u'Priestly抗剪能力', u'Sezen抗剪能力'],
		loc='upper left', fancybox=True, shadow=True, numpoints=1)
	cv.print_figure(name + '.png', dpi=300)
Example #23
0
        def serialize(dataset):
            """Render a vertical colorbar PNG for the first requested layer of
            ``dataset`` and return it as a one-element list of bytes.

            The layer is the first entry of the comma-separated ``LAYERS``
            query value, falling back to the first grid in ``dataset``. The
            colour range comes from ``self._get_actual_range``. ``dataset``
            is closed (when it has ``close``) even if rendering raises.
            """
            try:
                fix_map_attributes(dataset)
                fig = Figure(figsize=figsize, dpi=dpi)
                # ``patch`` is the current name for the deprecated
                # ``figurePatch``/``axesPatch`` aliases.
                fig.patch.set_alpha(0.0)
                ax = fig.add_axes([0.05, 0.05, 0.45, 0.85])
                ax.patch.set_alpha(0.5)

                # Plot requested grids.
                layers = [layer for layer in query.get('LAYERS', '').split(',')
                          if layer] or [var.id for var in walk(dataset, GridType)]
                layer = layers[0]
                # Resolve a dotted layer name (e.g. "a.b") from the dataset down.
                names = [dataset] + layer.split('.')
                grid = reduce(operator.getitem, names)

                actual_range = self._get_actual_range(grid)
                norm = Normalize(vmin=actual_range[0], vmax=actual_range[1])
                cb = ColorbarBase(ax, cmap=get_cmap(cmap), norm=norm,
                                  orientation='vertical')
                for tick in cb.ax.get_yticklabels():
                    tick.set_fontsize(14)
                    tick.set_color('white')

                # Save to buffer.
                canvas = FigureCanvas(fig)
                output = StringIO()
                canvas.print_png(output)
                return [output.getvalue()]
            finally:
                # Release the dataset whether or not rendering succeeded.
                if hasattr(dataset, 'close'):
                    dataset.close()
Example #24
0
    def _figure(self,instruments, instrument_color, instrument_names, flowcells):
        """Return a figure plotting each flowcell's Q30 (%) against its run date.

        - One scatter series per instrument in ``instruments``, coloured via
          ``instrument_color`` and labelled via ``instrument_names`` (falling
          back to the instrument id).
        - Only flowcells with a ``"q30"`` entry are plotted.
        - The run id is the text before the first ``"_"`` in the first element
          of each flowcell key; it is parsed as the x date and used as the tick
          label.
        - An Agg canvas is attached to the figure before it is returned.
        """
        fig = Figure(figsize=[12, 8])
        ax = fig.add_axes([0.1, 0.2, 0.8, 0.7])

        all_dates = []
        all_run_ids = []
        for instrument in instruments:
            colour = instrument_color[instrument]
            points = [(key, info["q30"]) for key, info in flowcells.items()
                      if "q30" in info and info["instrument"] == instrument]
            q30_values = [value for _, value in points]
            run_ids = [key[0].split("_")[0] for key, _ in points]
            run_dates = [parser.parse(run_id) for run_id in run_ids]

            series_label = instrument_names.get(instrument, instrument)
            ax.scatter(run_dates, q30_values, c=colour, s=100, marker='o', label=series_label)

            all_dates.extend(run_dates)
            all_run_ids.extend(run_ids)

        # One tick per distinct (date, run id) pair.
        unique_ticks = list(set(zip(all_dates, all_run_ids)))

        ax.set_xticks([tick[0] for tick in unique_ticks])
        ax.set_xticklabels([tick[1] for tick in unique_ticks], rotation=90)
        ax.set_xlabel("Run")
        ax.set_ylabel("%")
        ax.set_ylim([0, 100])

        ax.legend(loc="lower right", bbox_to_anchor=(1, 1), ncol=5)

        # Creating the canvas attaches it to ``fig``.
        FigureCanvasAgg(fig)

        return fig
Example #25
0
    def depthReport(self):
        """
            Method to draw a histogram of the number of new links
            discovered at each depth.
            (i.e. show how many links are required to reach a link)

            Returns silently if matplotlib or numpy cannot be imported.
            Otherwise it passes a heading, then the chart as an inline
            base64 PNG ``<img>`` tag, to ``self.reporter``.
        """
        try:
            from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
            self.FigureCanvas = FigureCanvas
            from matplotlib.figure import Figure
            self.Figure = Figure
            from numpy import arange
        except ImportError:
            return
        self.reporter("Analysis of link depth")
        fig = Figure(figsize=(4, 2.5))
        # Draw a histogram: one bar per depth, height = new links at that depth.
        width = 0.9
        rect = [0.12, 0.08, 0.9, 0.85]
        ax = fig.add_axes(rect)
        left = arange(len(self.linkDepth))
        # align='edge' makes each bar start at ``left``, which the tick
        # offset below assumes; matplotlib >= 2.0 centres bars by default.
        ax.bar(left, self.linkDepth, width=width, align='edge')
        # Put each x axis label at its bar's centre.
        ax.set_xticks(left + (width * 0.5))
        ax.set_xticklabels(left)

        chart = StringIO()
        canvas = self.FigureCanvas(fig)
        canvas.print_figure(chart)
        image = chart.getvalue()
        import base64
        base64Img = base64.b64encode(image)
        image = "<img src=\"data:image/png;base64,%s\">" % base64Img
        self.reporter(image)
class RewardWidget(QtGui.QWidget):
    """Qt widget plotting a reward trace with 100 slots per unit of time."""

    def __init__(self, parent=None):
        super(RewardWidget, self).__init__(parent)
        self.samples = 0
        self.resize(1500, 100)

        self.figure = Figure()

        self.canvas = FigureCanvasQTAgg(self.figure)

        self.axes = self.figure.add_axes([0, 0, 1, 1])

        self.layoutVertical = QtGui.QVBoxLayout(self)
        self.layoutVertical.addWidget(self.canvas)

    def set_time_range(self, time_range):
        """Set the (start, end) time range and reset the reward buffer.

        The buffer holds 100 zero slots per whole unit of the range. Must be
        called before add_data(), which reads ``time_range`` and ``rewards``.
        """
        self.time_range = time_range
        range_length = int((time_range[1] - time_range[0]))
        self.rewards = [0] * (range_length * 100)

    def add_data(self, data_range, data):
        """Set every slot covering ``data_range`` (both ends inclusive) to
        ``data`` and redraw the trace.

        NOTE(review): a range past the end of the buffer extends the list
        (Python slice assignment) -- confirm that is acceptable.
        """
        begin = int(round((data_range[0] - self.time_range[0]) * 100))
        end = int(round((data_range[1] - self.time_range[0]) * 100)) + 1
        self.rewards[begin:end] = [data] * (end - begin)

        range_length = int((self.time_range[1] - self.time_range[0]))
        self.axes.clear()
        self.axes.set_xlim([0, range_length * 100])
        self.axes.set_ylim([-10.0, 10.0])
        self.axes.plot(self.rewards)

        self.canvas.draw()
Example #27
0
    def __init__(self, masterWindow, style=None, scheme=None):
        """Create the plot canvas and navigation bar inside the main window.

        :param masterWindow: main window. Its ``ui.plotArea`` receives the
            navigation bar and then the canvas; it is also passed to
            ``AstonNavBar``.
        :param style: not referenced in this method.
        :param scheme: not referenced in this method.
        """
        self.masterWindow = masterWindow
        # Legend flag, initially off (presumably toggled elsewhere).
        self.legend = False

        plotArea = masterWindow.ui.plotArea

        #create the plotting canvas and its toolbar and add them
        tfig = Figure()
        tfig.set_facecolor('white')
        self.canvas = FigureCanvasQTAgg(tfig)
        self.navbar = AstonNavBar(self.canvas, masterWindow)
        # The nav bar is added first, so it sits before the canvas in the layout.
        plotArea.addWidget(self.navbar)
        plotArea.addWidget(self.canvas)

        #TODO this next line is the slowest in this module
        #self.plt = tfig.add_subplot(111, frameon=False)
        # Single frameless axes; rectangle is (left, bottom, width, height).
        self.plt = tfig.add_axes((0.05, 0.1, 0.9, 0.85), frame_on=False)
        # Hide tick marks on both axes.
        self.plt.xaxis.set_ticks_position('none')
        self.plt.yaxis.set_ticks_position('none')
        # Initially empty; presumably filled by other methods of this class.
        self.patches = []
        # Initially no value (the name suggests a colorbar handle -- confirm).
        self.cb = None

        #TODO: find a way to make the axes fill the figure properly
        #tfig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
        #tfig.tight_layout(pad=2)

        # Take focus on click and route mouse presses, releases and scrolls
        # to this object's handlers.
        self.canvas.setFocusPolicy(Qt.ClickFocus)
        self.canvas.mpl_connect('button_press_event', self.mousedown)
        self.canvas.mpl_connect('button_release_event', self.mouseup)
        self.canvas.mpl_connect('scroll_event', self.mousescroll)

        self.highlight = None
Example #28
0
def plot(request):
    """Render a small demo scatter plot and return it as a PNG response.

    Adapted from http://stackoverflow.com/a/5515994/185820

    The figure size in inches comes from the optional ``x`` and ``y``
    query-string parameters (default 4 x 4). These values are client-supplied,
    so they are clamped to ``[1, 20]``. This stops a request from forcing the
    server to render an arbitrarily large or non-positive-sized image.

    Args:
        request: request object exposing ``query_string``.

    Returns:
        ``Response`` whose body is the PNG bytes, with ``content_type`` and
        ``content_length`` set.

    Raises:
        ValueError: if ``x`` or ``y`` is not an integer.
    """
    import io
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    min_inches, max_inches = 1, 20

    def _size_param(params, key, default):
        # parse_qs maps each key to a list of values; only the first is used.
        if key not in params:
            return default
        return min(max(int(params[key][0]), min_inches), max_inches)

    qs = parse_qs(request.query_string)
    x = _size_param(qs, 'x', 4)
    y = _size_param(qs, 'y', 4)

    fig = Figure(figsize=[x, y])
    ax = fig.add_axes([.1, .1, .8, .8])
    ax.scatter([1, 2], [3, 4])
    canvas = FigureCanvasAgg(fig)

    # print_png writes bytes, so a bytes buffer is required. io.BytesIO exists
    # on Python 2.6+ and 3, whereas cStringIO is Python-2-only.
    buf = io.BytesIO()
    canvas.print_png(buf)
    data = buf.getvalue()

    # write image bytes back to the browser
    response = Response(data)
    response.content_type = 'image/png'
    response.content_length = len(data)
    return response
    def check_and_plot(self, A_nn, A0_nn, digits, keywords=''):
        """Assert ``A_nn`` equals ``A0_nn`` to ``digits`` places, plotting on failure.

        First checks that every rank of ``world`` holds identical inputs by
        gathering MD5 fingerprints of both matrices. If the max absolute
        difference is not ~0, rank 0 (when ``mpl`` is available) saves a PNG
        heat map of ``|A_nn - A0_nn|`` before the assertion is re-raised.

        Args:
            A_nn: matrix under test.
            A0_nn: reference matrix.
            digits: places passed to ``assertAlmostEqual``.
            keywords: comma-separated tags used in the plot text and file name.

        Raises:
            RuntimeError: if the ranks do not hold identical matrices.
            AssertionError: if the matrices differ.
        """
        # Construct fingerprint of input matrices for comparison
        fingerprint = np.array([md5_array(A_nn, numeric=True),
                                md5_array(A0_nn, numeric=True)])

        # Compare fingerprints across all processors: any row that differs
        # from rank 0's row means the inputs are not identical. This replaces
        # ndarray.ptp (removed in NumPy 2.0) and cannot overflow on int64.
        fingerprints = np.empty((world.size, 2), np.int64)
        world.all_gather(fingerprint, fingerprints)
        if (fingerprints != fingerprints[0]).any():
            raise RuntimeError('Distributed matrices are not identical!')

        # If assertion fails, catch temporarily while plotting, then re-raise
        try:
            self.assertAlmostEqual(np.abs(A_nn - A0_nn).max(), 0, digits)
        except AssertionError:
            if world.rank == 0 and mpl is not None:
                from matplotlib.figure import Figure
                fig = Figure()
                ax = fig.add_axes([0.0, 0.1, 1.0, 0.83])
                ax.set_title(self.__class__.__name__)
                im = ax.imshow(np.abs(A_nn - A0_nn), interpolation='nearest')
                fig.colorbar(im)
                fig.text(0.5, 0.05, 'Keywords: ' + keywords,
                         horizontalalignment='center', verticalalignment='top')

                from matplotlib.backends.backend_agg import FigureCanvasAgg
                img = 'ut_hsops_%s_%s.png' % (self.__class__.__name__,
                                              '_'.join(keywords.split(',')))
                FigureCanvasAgg(fig).print_figure(img.lower(), dpi=90)
            raise
Example #30
0
def small_plot(data, figsize=(3, 3), dpi=80):
    """Render a frameless scatter plot off-screen and return its RGB pixels.

    Args:
        data: ``(x, y, area, colors)`` passed to ``Axes.scatter`` as the
            positions, marker sizes (``s``) and colours (``c``).
        figsize: figure size in inches (default 3 x 3).
        dpi: rendering resolution (default 80).

    Returns:
        ``numpy.uint8`` array of shape ``(height, width, 3)``.
    """
    x, y, area, colors = data

    fig = Figure(figsize=figsize, dpi=dpi)
    # A single axes covering the whole figure, with no frame and no ticks.
    axes = fig.add_axes([0.0, 0.0, 1.0, 1.0], alpha=1.0)
    axes.set_frame_on(False)
    axes.set_xticks([])
    axes.set_yticks([])

    axes.scatter(x, y, s=area, c=colors, alpha=0.5)

    axes.set_xlim(4, 7)
    axes.set_ylim(4, 7)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()

    # tostring_rgb() is deprecated (Matplotlib 3.8) and np.fromstring is
    # deprecated for binary data. Read the RGBA buffer instead and drop the
    # alpha channel to keep the RGB output; copy so the result does not alias
    # the renderer's memory.
    ncols, nrows = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    return rgba.reshape(nrows, ncols, 4)[:, :, :3].copy()
Example #31
0
    "year": 1941,
}

# Figure is 768x1024
fig = Figure(
    figsize=(7.68, 10.24),
    dpi=100,
    facecolor="white",
    edgecolor="black",
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None,
)
canvas = FigureCanvas(fig)
ax_full = fig.add_axes([0, 0, 1, 1])
ax_full.set_xlim([0, 1])
ax_full.set_ylim([0, 1])
ax_full.set_axis_off()

# Paint the background white - why is this needed?
ax_full.add_patch(
    matplotlib.patches.Rectangle((0, 0), 1, 1, fill=True, facecolor="white"))

# Box with the data in
topLeft = (0.07 + imp["xshift"] / 768, 0.725 + imp["yshift"] / 1024)
topRight = (
    0.93 + imp["xshift"] / 768 + (imp["xscale"] - 1) * 0.86,
    0.725 + imp["yshift"] / 1024,
)
bottomLeft = (0.07 + imp["xshift"] / 768, 0.325 + imp["yshift"] / 1024)
             dpi=100,
             facecolor=(0.88, 0.88, 0.88, 1),
             edgecolor=None,
             linewidth=0.0,
             frameon=False,
             subplotpars=None,
             tight_layout=None)
canvas = FigureCanvas(fig)

# Pressure map
matplotlib.rc('image', aspect='auto')
projection = ccrs.RotatedPole(pole_longitude=60.0,
                              pole_latitude=0.0,
                              central_rotated_longitude=270.0)
extent = [-180, 180, -90, 90]
ax_prmsl = fig.add_axes([0.01, 0.505, 0.485, 0.485], projection=projection)
ax_prmsl.set_extent(extent, crs=projection)
contour_plot(ax_prmsl,
             prmsl,
             prmsl_r,
             scale=0.01,
             levels=numpy.arange(870, 1050, 7))
ax_t2m = fig.add_axes([0.505, 0.51, 0.485, 0.485], projection=projection)
ax_t2m.set_extent(extent, crs=projection)
contour_plot(ax_t2m, t2m, t2m_r, scale=1.0, levels=numpy.arange(230, 310, 10))
ax_z500 = fig.add_axes([0.01, 0.01, 0.485, 0.485], projection=projection)
ax_z500.set_extent(extent, crs=projection)
contour_plot(ax_z500,
             z500,
             z500_r,
             scale=1.0,
canvas = FigureCanvas(fig)
font = {
    'family': 'sans-serif',
    'sans-serif': 'Arial',
    'weight': 'normal',
    'size': 16
}
matplotlib.rc('font', **font)

# UK-centred projection
projection = ccrs.RotatedPole(pole_longitude=177.5, pole_latitude=35.5)
scale = 12
extent = [scale * -1, scale, scale * -1 * math.sqrt(2), scale * math.sqrt(2)]

# On the left - spaghetti-contour plot of original 20CRv3
ax_left = fig.add_axes([0.005, 0.01, 0.495, 0.98], projection=projection)
ax_left.set_axis_off()
ax_left.set_extent(extent, crs=projection)
ax_left.background_patch.set_facecolor((0.88, 0.88, 0.88, 1))
mg.background.add_grid(ax_left)
land_img_left = ax_left.background_img(name='GreyT', resolution='low')

# 20CRv3 data
prmsl = twcr.load('prmsl', dte, version='4.5.1')

# 20CRv3 data
prmsl = twcr.load('prmsl', dte, version='4.5.1')
obs_t = twcr.load_observations_fortime(dte, version='4.5.1')
# Filter to those assimilated and near the UK
obs_s = obs_t.loc[((obs_t['Latitude'] > 0) & (obs_t['Latitude'] < 90)) & (
    (obs_t['Longitude'] > 240) | (obs_t['Longitude'] < 100))].copy()
Example #34
0
    def __init__(self,
                 parent,
                 imgpanel,
                 color=0,
                 colormap_list=None,
                 default=None,
                 cmap_callback=None,
                 title='Color Table',
                 **kws):
        """Build a colour-table panel: colormap choice, preview strip and range controls.

        Args:
            parent: parent wx window.
            imgpanel: image panel. Its ``conf.cmap[color]`` is set here, and its
                ``conf.cmap_range`` sizes the sliders and the preview image.
            color: index into ``imgpanel.conf.cmap`` controlled by this panel.
            colormap_list: names offered in a choice box. When None, neither
                the choice box nor the "Reverse" checkbox is created.
            default: initial colormap name. When ``colormap_list`` is given,
                a trailing ``_r`` is stripped and "Reverse" is ticked.
                Falls back to ``colormap_list[0]``, then to ``'gray'``.
            cmap_callback: stored as ``self.cmap_callback``.
            title: label shown at the top of the panel.
            **kws: forwarded to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, parent, -1, **kws)

        self.imgpanel = imgpanel
        self.icol = color
        self.cmap_callback = cmap_callback

        labstyle = wx.ALIGN_LEFT | wx.LEFT | wx.TOP | wx.EXPAND
        sizer = wx.GridBagSizer(2, 2)

        self.title = wx.StaticText(self, label=title, size=(120, -1))
        sizer.Add(self.title, (0, 0), (1, 4), labstyle, 2)

        self.cmap_choice = None
        reverse = False
        cmapname = default
        if colormap_list is not None:
            cmap_choice = wx.Choice(self, size=(90, -1), choices=colormap_list)
            cmap_choice.Bind(wx.EVT_CHOICE, self.onCMap)
            self.cmap_choice = cmap_choice

            if cmapname is None:
                cmapname = colormap_list[0]

            # A trailing '_r' is shown as the base name plus a ticked
            # "Reverse" checkbox. Slice the name itself: ``cmap`` is the
            # colormap lookup object (it provides get_cmap below), not a string.
            if cmapname.endswith('_r'):
                reverse = True
                cmapname = cmapname[:-2]
            cmap_choice.SetStringSelection(cmapname)

            cmap_reverse = wx.CheckBox(self, label='Reverse', size=(60, -1))
            cmap_reverse.Bind(wx.EVT_CHECKBOX, self.onCMapReverse)
            cmap_reverse.SetValue(reverse)
            self.cmap_reverse = cmap_reverse

        if cmapname is None:
            cmapname = 'gray'
        # NOTE(review): if ``cmap`` is ``matplotlib.cm``, ``get_cmap`` was
        # removed in Matplotlib 3.9 -- confirm the version in use.
        self.imgpanel.conf.cmap[color] = cmap.get_cmap(cmapname)

        maxval = self.imgpanel.conf.cmap_range
        wd, ht = 1.00, 0.125

        # Preview image: a horizontal 0..1 ramp that is ``maxval`` columns wide
        # and ``int(maxval * ht)`` rows tall.
        self.cmap_dat = np.outer(np.ones(int(maxval * ht)),
                                 np.linspace(0, 1, maxval))

        fig = Figure((wd, ht), dpi=150)

        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_axis_off()
        self.cmap_canvas = FigureCanvas(self, -1, figure=fig)

        self.cmap_img = ax.imshow(self.cmap_dat,
                                  cmap=cmapname,
                                  interpolation='bilinear')
        # Low/high stretch sliders spanning 0..maxval.
        self.cmap_lo = wx.Slider(self,
                                 -1,
                                 0,
                                 0,
                                 maxval,
                                 style=wx.SL_HORIZONTAL)

        self.cmap_hi = wx.Slider(self,
                                 -1,
                                 maxval,
                                 0,
                                 maxval,
                                 style=wx.SL_HORIZONTAL)

        self.cmap_lo.Bind(wx.EVT_SCROLL, self.onStretchLow)
        self.cmap_hi.Bind(wx.EVT_SCROLL, self.onStretchHigh)

        # Grid layout: optional choice/reverse row, high slider, preview, low slider.
        irow = 0
        if self.cmap_choice is not None:
            irow += 1
            sizer.Add(self.cmap_choice, (irow, 0), (1, 2), labstyle, 2)
            sizer.Add(self.cmap_reverse, (irow, 2), (1, 2), labstyle, 2)

        irow += 1
        sizer.Add(self.cmap_hi, (irow, 0), (1, 4), labstyle, 2)
        irow += 1
        sizer.Add(self.cmap_canvas, (irow, 0), (1, 4), wx.ALIGN_CENTER, 0)
        irow += 1
        sizer.Add(self.cmap_lo, (irow, 0), (1, 4), labstyle, 2)

        self.imin_val = LabeledTextCtrl(self,
                                        0,
                                        size=(80, -1),
                                        labeltext='Range:',
                                        action=partial(self.onThreshold,
                                                       argu='lo'))
        self.imax_val = LabeledTextCtrl(self,
                                        maxval,
                                        size=(80, -1),
                                        labeltext=':',
                                        action=partial(self.onThreshold,
                                                       argu='hi'))
        self.islider_range = wx.StaticText(self,
                                           label='Shown: ',
                                           size=(90, -1))
        irow += 1
        sizer.Add(self.imin_val.label, (irow, 0), (1, 1), labstyle, 1)
        sizer.Add(self.imin_val, (irow, 1), (1, 1), labstyle, 0)
        sizer.Add(self.imax_val.label, (irow, 2), (1, 1), labstyle, 0)
        sizer.Add(self.imax_val, (irow, 3), (1, 1), labstyle, 0)

        irow += 1
        sizer.Add(self.islider_range, (irow, 0), (1, 4), labstyle, 0)

        pack(self, sizer)
        self.set_colormap(cmapname)
class LoadSaveTopoModule(ModuleTemplate):
    """
    Module to save the current topography in a subset of the sandbox
    and recreate it at a later time
    two different representations are saved to the numpy file:

    absolute Topography:
    deviation from the mean height inside the bounding box in millimeter

    relative Height:
    height of each pixel relative to the vmin and vmax of the currently used calibration.
    use relative height with the gempy module to get the same geologic map with different calibration settings.
    """
    def __init__(self, extent: list = None, **kwargs):
        """Initialise the save/load-topography module state and its widgets.

        Args:
            extent: sandbox extent. When given, ``extent[4]`` and ``extent[5]``
                are stored as ``vmin``/``vmax`` (used by ``getBoxFrame`` for the
                relative topography). When omitted, ``vmin``/``vmax`` are not set.
            **kwargs: accepted but not used in this method.
        """
        pn.extension()
        self.extent = extent
        if extent is not None:
            self.vmin, self.vmax = extent[4], extent[5]

        # Box in Kinect sensor pixels: bottom-left corner plus size.
        self.box_origin = [40, 40]
        self.box_width = 200
        self.box_height = 150
        self.absolute_topo = None
        self.relative_topo = None
        # Flag to know if a file has been loaded or not.
        self.is_loaded = False

        # Overlay selection used by plot(); 'None' draws no overlay image.
        self.current_show = 'None'
        self.difference_types = [
            'None', 'Show topography', 'Show difference',
            'Show gradient difference'
        ]
        self.cmap_difference = self._cmap_difference()
        self.difference = None
        self.loaded = None
        self.transparency_difference = 1
        self.npz_filename = None

        # Release-area boxes: size, corner array and origin tables.
        self.release_width = self.release_height = 10
        self.release_area = None
        self.release_area_origin = None
        self.aruco_release_area_origin = None

        self.data_filenames = ['None']
        self.file_id = None
        self.data_path = _test_data['topo']

        # Figure shown through a Panel pane by snapshotFrame().
        self.figure = Figure()
        self.ax = plt.Axes(self.figure, [0., 0., 1., 1.])
        self.figure.add_axes(self.ax)
        self.snapshot_frame = pn.pane.Matplotlib(self.figure,
                                                 tight=False,
                                                 height=500)
        plt.close(self.figure)  # close figure to prevent inline display

        # Image artist currently drawn on the sandbox axes.
        self._lod = None
        self.frame = None
        # Contour-line state.
        self.deactivate_main_contour = False
        self.contours_on = False
        # Build the widgets so they exist when used from another module.
        _ = self.widgets_box()
        logger.info("LoadSaveTopoModule loaded successfully")

    def update(self, sb_params: dict):
        """Refresh this module's overlay for a new sandbox frame.

        Reads ``frame``, ``ax``, ``marker`` and ``extent`` from ``sb_params``.
        Marker rows flagged ``is_inside_box`` are merged into the release-area
        origins. The overlay is then redrawn, and ``'active_contours'`` is set
        so the main contour lines are switched off while this module draws
        its own.

        Args:
            sb_params: shared sandbox parameter dict.

        Returns:
            The same dict, with ``'active_contours'`` updated.
        """
        frame = sb_params.get('frame')
        ax = sb_params.get('ax')
        marker = sb_params.get('marker')
        self.extent = sb_params.get('extent')
        self.frame = frame
        # ``.get`` yields None when there is no 'marker' entry. Guard against it
        # so len() is only applied to an actual marker table.
        if marker is not None and len(marker) > 0:
            self.aruco_release_area_origin = marker.loc[marker.is_inside_box,
                                                        ('box_x', 'box_y')]
            self.add_release_area_origin()
        self.plot(frame, ax)
        sb_params['active_contours'] = not self.deactivate_main_contour

        return sb_params

    def delete_rectangles_ax(self, ax):
        """Remove every ``Rectangle`` patch from ``ax`` (e.g. outlines drawn by showBox)."""
        # Walk backwards so removing the current patch does not skip others.
        for patch in reversed(ax.patches):
            if isinstance(patch, matplotlib.patches.Rectangle):
                patch.remove()

    def delete_im_ax(self, ax):
        """Remove the overlay image held in ``self._lod``, if there is one.

        ``ax`` is not used; the artist removes itself from its own axes.
        """
        if self._lod is None:
            return
        self._lod.remove()
        self._lod = None

    def set_show(self, i: str):
        """Choose the overlay drawn by ``plot`` (expected: an entry of ``difference_types``)."""
        self.current_show = i

    def plot(self, frame, ax):
        """Redraw this module's overlays on ``ax``.

        Clears previous rectangles and the overlay image. Then:

        * draws the overlay selected by ``self.current_show``;
        * draws or clears contour lines;
        * outlines the box and the release areas.

        ``frame`` is not read here; the helpers use ``self.frame``.
        """
        self.delete_rectangles_ax(ax)
        self.delete_im_ax(ax)

        show = self.current_show
        overlays = {
            self.difference_types[1]: self.showLoadedTopo,
            self.difference_types[2]: self.showDifference,
            self.difference_types[3]: self.showGradDifference,
        }
        if show == self.difference_types[0]:
            self.delete_im_ax(ax)
        elif show in overlays:
            overlays[show](ax)

        # Show contour lines of the plot
        self.delete_contourns(ax)
        if self.contours_on:
            self.showContourTopo(ax)
        elif self.deactivate_main_contour:
            self.delete_contourns(ax)
            self.deactivate_main_contour = False

        self.showBox(ax, self.box_origin, self.box_width, self.box_height)
        self.plot_release_area(ax, self.release_area_origin,
                               self.release_width, self.release_height)

    def moveBox_possible(self, x, y, width, height):
        """Move the box origin to ``(x, y)``, shrinking the box to fit the extent.

        When ``x + width`` reaches ``extent[1]`` (or ``y + height`` reaches
        ``extent[3]``), the width (height) is cut so the box ends at that value.
        Otherwise the requested size is kept.

        Args:
            x: x coordinate of the box origin.
            y: y coordinate of the box origin.
            width: requested box width.
            height: requested box height.
        """
        self.box_width = (self.extent[1] - x
                          if (x + width) >= self.extent[1] else width)
        self.box_height = (self.extent[3] - y
                           if (y + height) >= self.extent[3] else height)
        self.box_origin = [x, y]

    def add_release_area_origin(self, x=None, y=None):
        """Merge aruco release areas into the release-area origins and optionally add one.

        Missing tables are created empty with columns ``box_x``/``box_y``. The
        aruco-detected origins are concatenated with the existing ones, and
        duplicates are dropped. When both ``x`` and ``y`` are given, that point
        is appended as a new origin.

        Args:
            x: x coordinate of a new origin (optional).
            y: y coordinate of a new origin (optional).
        """
        columns = ('box_x', 'box_y')
        if self.release_area_origin is None:
            self.release_area_origin = pd.DataFrame(columns=columns)
        if self.aruco_release_area_origin is None:
            self.aruco_release_area_origin = pd.DataFrame(columns=columns)
        self.release_area_origin = pd.concat(
            (self.release_area_origin,
             self.aruco_release_area_origin)).drop_duplicates()
        if x is not None and y is not None:
            # DataFrame.append was removed in pandas 2.0; concatenate a
            # one-row frame instead, as append(..., ignore_index=True) did.
            new_row = pd.DataFrame([{'box_x': x, 'box_y': y}])
            self.release_area_origin = pd.concat(
                (self.release_area_origin, new_row), ignore_index=True)

    def plot_release_area(self, ax, origin: pd.DataFrame, width: int,
                          height: int):
        """Draw a ``width`` x ``height`` box centred on each row of ``origin``.

        Also stores the corners of those boxes in ``self.release_area``,
        relative to ``self.box_origin``. The array has shape ``(4, 2, n)`` with
        corners ``(x0, y0), (x0, y0 + h), (x0 + w, y0 + h), (x0 + w, y0)``.
        Does nothing when ``origin`` is None.

        Args:
            ax: matplotlib axes to draw on.
            origin: DataFrame with ``box_x``/``box_y`` centre coordinates.
            width: box width.
            height: box height.
        """
        if origin is None:
            return
        x_pos = origin.box_x
        y_pos = origin.box_y
        left = x_pos.values - width / 2
        bottom = y_pos.values - height / 2

        ox, oy = self.box_origin[0], self.box_origin[1]
        x0, x1 = left - ox, left + width - ox
        y0, y1 = bottom - oy, bottom + height - oy
        self.release_area = numpy.array([[x0, y0], [x0, y1], [x1, y1],
                                         [x1, y0]])

        for corner_x, corner_y in zip(left, bottom):
            self.showBox(ax, [corner_x, corner_y], width, height)

    @staticmethod
    def showBox(ax, origin: tuple, width: int, height: int):
        """Add an unfilled, white-edged rectangle to ``ax``.

        Args:
            ax: axes receiving the patch.
            origin: bottom-left corner in sensor pixel space.
            width: width in sensor pixels.
            height: height in sensor pixels.

        Returns:
            Always True.
        """
        outline = matplotlib.patches.Rectangle(origin, width, height,
                                               fill=False, edgecolor='white')
        ax.add_patch(outline)
        return True

    def getBoxFrame(self, frame: numpy.ndarray):
        """Crop ``frame`` to the box and return its absolute and relative topography.

        Args:
            frame: current topography, indexed ``[row (y), column (x)]``.

        Returns:
            ``(absolute_topo, relative_topo)``:

            * ``absolute_topo``: the crop minus its mean height;
            * ``relative_topo``: ``absolute_topo / (vmax - vmin)``.
        """
        row0, col0 = self.box_origin[1], self.box_origin[0]
        crop = frame[row0:row0 + self.box_height,
                     col0:col0 + self.box_width]
        absolute_topo = crop - crop.mean()
        return absolute_topo, absolute_topo / (self.vmax - self.vmin)

    def extractTopo(self):
        """Capture the box area of the current frame as the stored topography.

        Returns:
            ``(absolute_topo, relative_topo)`` as computed by ``getBoxFrame``.
        """
        self.is_loaded = True
        captured = self.getBoxFrame(self.frame)
        self.absolute_topo, self.relative_topo = captured
        return self.absolute_topo, self.relative_topo

    def saveTopo(self, filename="00_savedTopography.npz"):
        """Write absolute and relative topography to ``filename`` via ``numpy.savez``.

        The arrays are stored positionally, i.e. as ``arr_0`` and ``arr_1``.
        """
        arrays = (self.absolute_topo, self.relative_topo)
        numpy.savez(filename, *arrays)
        logger.info('Save topo successful')

    def save_release_area(self, filename="00_releaseArea.npy"):
        """Write ``self.release_area`` to ``filename`` with ``numpy.save``."""
        area = self.release_area
        numpy.save(filename, area)
        logger.info('Save area successful')

    def loadTopo(self, filename="00_savedTopography.npz"):
        """Load a topography into ``absolute_topo`` / ``relative_topo``.

        * ``.npz``: a file written by ``saveTopo`` (arrays ``arr_0``, ``arr_1``).
        * ``.npy``: a single array treated as an external DEM. It is resized and
          rescaled to the current box and extent height range by
          ``normalize_topography``.

        The extension check ignores case. Any other extension is logged and
        ignored, and the current state is left untouched. ``is_loaded`` is only
        set once the arrays are in place.

        Args:
            filename: path of the file to load.
        """
        extension = filename.rsplit(".", 1)[-1].lower()
        if extension not in ("npz", "npy"):
            logger.warning("Unsupported topography file %s (expected .npz or .npy)",
                           filename)
            return

        # NOTE(review): allow_pickle=True lets a crafted file run code on load.
        # Plain numeric arrays do not need it; consider dropping it if these
        # files can come from untrusted sources.
        if extension == "npz":
            # NpzFile keeps the archive open; the with-block closes it after reading.
            with numpy.load(filename, allow_pickle=True) as files:
                self.absolute_topo = files['arr_0']
                self.relative_topo = files['arr_1']
            logger.info('Load sandbox topography successfully')
        else:
            dem = numpy.load(filename, allow_pickle=True)
            target = [
                0, self.box_width, 0, self.box_height, self.extent[-2],
                self.extent[-1]
            ]
            self.absolute_topo, self.relative_topo = self.normalize_topography(
                dem, target)

        self.is_loaded = True
        self._get_id(filename)

    def showLoadedTopo(self, ax):
        """Overlay the stored topography on ``ax`` inside the box area.

        The stored array is cropped to the overlap with the current box (see
        ``getBoxShape``) and drawn with ``gist_earth`` at zorder 2. If nothing
        is loaded, only a warning is logged.

        Args:
            ax: axes to draw on.
        """
        if not self.is_loaded:
            logger.warning("No Topography loaded, please load a Topography")
            return

        rows, cols = self.getBoxShape()
        self.loaded = self.absolute_topo[:rows, :cols]
        # TODO: data is inverted, need to fix this for all the landsladides topography data
        self._lod = ax.imshow(self.loaded,
                              cmap='gist_earth',
                              origin="lower",
                              zorder=2,
                              extent=self.to_box_extent,
                              aspect="auto")

    def showContourTopo(self, ax):
        """Draw labelled black contour lines of the stored topography on ``ax``.

        Sets ``deactivate_main_contour`` while these contours are shown. If no
        topography is loaded, any contours are removed, the flag is cleared
        and a warning is logged.

        Args:
            ax: axes to draw on.
        """
        if not self.is_loaded:
            self.delete_contourns(ax)
            self.deactivate_main_contour = False
            logger.warning(
                "No Topography loaded, please load a Topography to display contour lines"
            )
            return

        self.deactivate_main_contour = True
        rows, cols = self.getBoxShape()
        self.loaded = self.absolute_topo[:rows, :cols]
        self._cont = ax.contour(self.loaded,
                                zorder=1000,
                                extent=self.to_box_extent,
                                colors="k")
        self._label = ax.clabel(self._cont,
                                inline=True,
                                fontsize=15,
                                fmt='%3.0f')

    def delete_contourns(self, ax):
        """Remove contour lines and text labels from ``ax``.

        Acts only while ``self.deactivate_main_contour`` is set. It removes
        every ``LineCollection`` in ``ax.collections`` and every ``Text`` in
        ``ax.artists``.

        NOTE(review): newer Matplotlib versions may store contour sets and
        clabel texts elsewhere or as other types (e.g. in ``ax.texts``).
        Confirm these loops still find them on the version in use.
        """
        if not self.deactivate_main_contour:
            return
        for coll in reversed(ax.collections):
            if isinstance(coll, matplotlib.collections.LineCollection):
                coll.remove()
        for text in reversed(ax.artists):
            if isinstance(text, matplotlib.text.Text):
                text.remove()

    @staticmethod
    def normalize_topography(dem, target_extent):
        """Resize a DEM to the box and rescale its heights to the sensor range.

        Useful when passing a DEM whose resolution is higher than the sensor's.

        Args:
            dem: 2-D height array of any size.
            target_extent: ``[minx, maxx, miny, maxy, vmin, vmax]``, i.e.
                ``[0, frame_width, 0, frame_height, vmin_sensor, vmax_sensor]``.
                Only ``maxx``/``maxy`` (output shape) and ``vmin``/``vmax``
                (height range) are used.

        Returns:
            ``(absolute_topo, relative_topo)``:

            * ``absolute_topo``: the heights scaled to span ``vmax - vmin``,
              minus their mean.
            * ``relative_topo``: the same scaled heights (lowest point at 0)
              divided by ``vmax - vmin``, i.e. in ``[0, 1]``.

            A flat DEM yields all zeros.
        """
        height_range = target_extent[-1] - target_extent[-2]
        # Change shape of numpy array to (frame_height, frame_width)
        topo_changed = skimage.transform.resize(
            dem, (target_extent[3], target_extent[1]),
            order=3,
            mode='edge',
            anti_aliasing=True,
            preserve_range=False)

        topo_min = topo_changed.min()
        topo_max = topo_changed.max()
        # Move the lowest point to 0 by *subtracting* the minimum. The
        # previous code added it, which skewed relative_topo whenever min != 0.
        topo_changed = topo_changed - topo_min
        # Stretch to the sensor height range; skip for a flat DEM to avoid 0/0.
        if topo_max > topo_min:
            topo_changed = topo_changed * height_range / (topo_max - topo_min)

        absolute_topo = topo_changed - topo_changed.mean()
        relative_topo = topo_changed / height_range

        return absolute_topo, relative_topo

    def modify_to_box_coordinates(self, frame):
        """Pad ``frame`` with NaNs so it lines up with the box position.

        Prepends ``box_origin[0]`` NaN columns, then ``box_origin[1]`` NaN
        rows, so a frame cropped at the box origin can be drawn in full-frame
        coordinates.

        Args:
            frame: 2-D array to pad.

        Returns:
            The padded array.
        """
        n_rows = frame.shape[0]
        left_pad = numpy.full((self.box_origin[0], n_rows), numpy.nan)
        frame = numpy.insert(frame, 0, left_pad, axis=1)

        n_cols = frame.shape[1]
        bottom_pad = numpy.full((self.box_origin[1], n_cols), numpy.nan)
        frame = numpy.insert(frame, 0, bottom_pad, axis=0)
        return frame

    def saveTopoVector(self):  # TODO:
        """Save a vector graphic of the contour map to disk.

        Not implemented yet: currently does nothing and returns None.
        """

    def _cmap_difference(self):
        """Build the 'difference_map' colormap: 512 samples of RdBu (256 per half)."""
        lower_half = plt.cm.RdBu(numpy.linspace(0, 0.5, 256))
        upper_half = plt.cm.RdBu(numpy.linspace(0.5, 1, 256))
        samples = numpy.vstack((lower_half, upper_half))
        return matplotlib.colors.LinearSegmentedColormap.from_list(
            'difference_map', samples)

    @property
    def norm_difference(self):
        """``TwoSlopeNorm`` centred at 0 spanning the min/max of ``absolute_topo``.

        NOTE: ``TwoSlopeNorm`` requires ``vmin < 0 < vmax``, so this raises
        if ``absolute_topo`` does not straddle zero (e.g. a perfectly flat crop).
        """
        topo = self.absolute_topo
        return matplotlib.colors.TwoSlopeNorm(vmin=topo.min(),
                                              vcenter=0,
                                              vmax=topo.max())

    def getBoxShape(self):
        """Return ``[rows, cols]`` shared by the current box crop and the stored topography."""
        current, _ = self.getBoxFrame(self.frame)
        cur_rows, cur_cols = current.shape
        saved_rows, saved_cols = self.absolute_topo.shape
        return [
            numpy.min((cur_rows, saved_rows)),
            numpy.min((cur_cols, saved_cols)),
        ]

    def extractDifference(self):
        """Return stored minus current absolute topography over their common shape."""
        current, _ = self.getBoxFrame(self.frame)
        rows, cols = self.getBoxShape()
        return self.absolute_topo[:rows, :cols] - current[:rows, :cols]

    @property
    def to_box_extent(self):
        """``(left, right, bottom, top)`` of the box, for ``imshow``/``contour`` ``extent=``."""
        x0, y0 = self.box_origin[0], self.box_origin[1]
        return (x0, x0 + self.box_width, y0, y0 + self.box_height)

    def showDifference(self, ax):
        """Overlay stored-minus-current topography on ``ax`` with the difference colormap.

        Logs a warning instead when nothing is loaded.

        Args:
            ax: axes to draw on.
        """
        if not self.is_loaded:
            logger.warning('No topography to show difference')
            return

        difference = self.extractDifference()
        self._lod = ax.imshow(difference,
                              cmap=self.cmap_difference,
                              alpha=self.transparency_difference,
                              norm=self.norm_difference,
                              origin="lower",
                              zorder=1,
                              extent=self.to_box_extent,
                              aspect="auto")

    def showGradDifference(self, ax):
        """Overlay the slope difference (stored minus current) on ``ax``.

        Colours are centred at 0 over a fixed ``[-5, 5]`` range. Logs a warning
        instead when nothing is loaded.

        Args:
            ax: axes to draw on.
        """
        if not self.is_loaded:
            logger.warning('No topography to show gradient difference')
            return

        grad = self.extractGradDifference()
        # Recent Matplotlib rejects passing vmin/vmax together with norm, so the
        # fixed range goes into the norm itself. Both slopes are clipped to
        # [0, 5] in extractGradDifference, so their difference lies in [-5, 5].
        norm = matplotlib.colors.TwoSlopeNorm(vmin=-5, vcenter=0, vmax=5)
        self._lod = ax.imshow(grad,
                              cmap=self.cmap_difference,
                              alpha=self.transparency_difference,
                              norm=norm,
                              origin="lower",
                              zorder=1,
                              extent=self.to_box_extent,
                              aspect="auto")

    def extractGradDifference(self):
        """Return stored-minus-current gradient magnitude over the common box shape.

        Each gradient magnitude is clipped to at most 5 before differencing.
        """
        def _clipped_slope(topo):
            # Per-cell gradient magnitude, capped at 5.
            d0, d1 = numpy.gradient(topo)
            return numpy.clip(numpy.sqrt(d0**2 + d1**2), -5, 5)

        current, _ = self.getBoxFrame(self.frame)
        slope_current = _clipped_slope(current)
        slope_saved = _clipped_slope(self.absolute_topo)

        rows, cols = self.getBoxShape()
        grad_diff = slope_current[:rows, :cols] - slope_saved[:rows, :cols]
        return -grad_diff

    def snapshotFrame(self):
        """Draw ``absolute_topo`` on the snapshot axes and refresh the Panel pane."""
        ax = self.ax
        ax.cla()
        ax.imshow(self.absolute_topo,
                  cmap='gist_earth',
                  origin="lower",
                  aspect='auto')
        ax.axis('equal')
        ax.set_axis_off()
        ax.set_title('Loaded Topography')
        # Tell the pane its 'object' changed so it re-renders the figure.
        self.snapshot_frame.param.trigger('object')

    def _search_all_data(self, data_path):
        self.data_filenames = os.listdir(data_path)

    def _get_id(self, filename):
        ids = [str(s) for s in filename if s.isdigit()]
        if len(ids) > 0:
            self.file_id = ids[-1]
        else:
            logger.warning("Unknown file id")

    def show_widgets(self):
        """Return a ``pn.Tabs`` with the box, release-area, load and save panels."""
        tab_builders = [
            ('Box widgets', self.widgets_box),
            ('Release area widgets', self.widgets_release_area),
            ('Load Topography', self.widgets_load),
            ('Save Topography', self.widgets_save),
        ]
        # panels are built in the listed order
        return pn.Tabs(*[(title, build()) for title, build in tab_builders])

    def widgets_release_area(self):
        """Build the panel of widgets controlling the release area.

        Creates two integer sliders (1-50) for the release-area width and
        height and a 'Show'/'Erase' button group, each wired to its
        ``_callback_*`` handler with ``onlychanged=False``. Returns a
        ``pn.Column`` with a title and the widget box.
        """
        # Release area widgets
        self._widget_release_width = pn.widgets.IntSlider(
            name='Release area width',
            value=self.release_width,
            start=1,
            end=50)
        self._widget_release_width.param.watch(self._callback_release_width,
                                               'value',
                                               onlychanged=False)

        self._widget_release_height = pn.widgets.IntSlider(
            name='Release area height',
            value=self.release_height,
            start=1,
            end=50)
        self._widget_release_height.param.watch(self._callback_release_height,
                                                'value',
                                                onlychanged=False)

        # RadioButtonGroup is single-select: its value must be one of the
        # options ('Erase'), not a list containing it.
        self._widget_show_release = pn.widgets.RadioButtonGroup(
            name='Show or erase the areas',
            options=['Show', 'Erase'],
            value='Erase',
            button_type='success')
        self._widget_show_release.param.watch(self._callback_show_release,
                                              'value',
                                              onlychanged=False)

        widgets = pn.WidgetBox(
            '<b>Modify the size and shape of the release area </b>',
            self._widget_release_width, self._widget_release_height,
            self._widget_show_release)
        panel = pn.Column("### Shape release area", widgets)

        return panel

    def widgets_box(self):
        """Build the panel that controls the interaction box.

        Widgets created and stored on ``self``:

        * ``_widget_show_type``: radio group over ``self.difference_types``;
        * ``_widget_move_box_horizontal`` / ``_widget_move_box_vertical``:
          sliders for the box origin, bounded by ``self.extent[1]`` and
          ``self.extent[3]``;
        * ``_widget_box_width`` / ``_widget_box_height``: sliders for the box
          size with the same bounds;
        * ``_widget_snapshot``: button wired to ``_callback_snapshot``;
        * ``_widget_plot_contours``: checkbox toggling contour lines.

        Every watcher uses ``onlychanged=False``. Returns a ``pn.Column`` with
        the widgets placed next to ``self.snapshot_frame``.
        """

        def watched(widget, handler, attribute='value'):
            # register the handler and hand the widget back for assignment
            widget.param.watch(handler, attribute, onlychanged=False)
            return widget

        self._widget_show_type = watched(
            pn.widgets.RadioBoxGroup(name='Show in sandbox',
                                     options=self.difference_types,
                                     value=self.difference_types[0],
                                     inline=False),
            self._callback_show)

        self._widget_move_box_horizontal = watched(
            pn.widgets.IntSlider(name='x box origin',
                                 value=self.box_origin[0],
                                 start=0,
                                 end=self.extent[1]),
            self._callback_move_box_horizontal)

        self._widget_move_box_vertical = watched(
            pn.widgets.IntSlider(name='y box origin',
                                 value=self.box_origin[1],
                                 start=0,
                                 end=self.extent[3]),
            self._callback_move_box_vertical)

        self._widget_box_width = watched(
            pn.widgets.IntSlider(name='box width',
                                 value=self.box_width,
                                 start=0,
                                 end=self.extent[1]),
            self._callback_box_width)

        self._widget_box_height = watched(
            pn.widgets.IntSlider(name='box height',
                                 value=self.box_height,
                                 start=0,
                                 end=self.extent[3]),
            self._callback_box_height)

        # Snapshots: buttons report presses through their 'clicks' counter
        self._widget_snapshot = watched(
            pn.widgets.Button(name="Snapshot", button_type="success"),
            self._callback_snapshot, 'clicks')

        self._widget_plot_contours = watched(
            pn.widgets.Checkbox(name='Show contours', value=self.contours_on),
            self._callback_plot_contours)

        controls = pn.Column('<b>Modify box size </b>',
                             self._widget_move_box_horizontal,
                             self._widget_move_box_vertical,
                             self._widget_box_width,
                             self._widget_box_height,
                             '<b>Take snapshot</b>',
                             self._widget_snapshot,
                             '<b>Show in sandbox</b>',
                             self._widget_show_type,
                             '<b>Show contour lines of target topography</b>',
                             self._widget_plot_contours)

        return pn.Column("### Interaction widgets",
                         pn.Row(controls, self.snapshot_frame))

    def widgets_save(self):
        """Build the panel used to save the current topography snapshot.

        A text input holds the target ``.npz`` filename and is pre-filled
        with ``savedTopography.npz`` inside ``_test_data['topo']``. A Save
        button triggers ``_callback_save``. Returns a ``pn.Column``.
        """
        self._widget_npz_filename = pn.widgets.TextInput(
            name='Choose a filename to save the current topography snapshot:')
        self._widget_npz_filename.param.watch(self._callback_filename_npz,
                                              'value',
                                              onlychanged=False)
        # Set the default only after the watcher exists so that
        # _callback_filename_npz copies it into self.npz_filename.
        self._widget_npz_filename.value = os.path.join(_test_data['topo'],
                                                       'savedTopography.npz')

        self._widget_save = pn.widgets.Button(name='Save')
        self._widget_save.param.watch(self._callback_save,
                                      'clicks',
                                      onlychanged=False)

        panel = pn.Column("### Save widget", '<b>Filename</b>',
                          self._widget_npz_filename, '<b>Save Topography</b>',
                          self._widget_save)
        return panel

    def widgets_load(self):
        """Build the panel used to load saved topographies.

        Offers the following widgets:

        * a directory text input initialised from ``self.data_path``;
        * a button that lists that directory;
        * a radio group over ``self.data_filenames``;
        * a file selector rooted at '~' with a button to load the chosen
          file.

        Returns a ``pn.Column``.
        """
        self._widget_data_path = path_input = pn.widgets.TextInput(
            name='Choose a folder to load the available topography snapshots:')
        # pre-fill before the watcher is attached, so this does not fire it
        path_input.value = self.data_path
        path_input.param.watch(self._callback_filename, 'value',
                               onlychanged=False)

        self._widget_load = load_button = pn.widgets.Button(
            name='Load Files in folder')
        load_button.param.watch(self._callback_load, 'clicks',
                                onlychanged=False)

        self._widget_available_topography = available = pn.widgets.RadioBoxGroup(
            name='Available Topographies',
            options=self.data_filenames,
            inline=False)
        available.param.watch(self._callback_available_topography, 'value',
                              onlychanged=False)

        # file browser starting at the user's home directory
        self._widget_other_topography = browser = pn.widgets.FileSelector('~')
        self._widget_load_other = load_other = pn.widgets.Button(
            name='Load other', button_type='success')
        load_other.param.watch(self._callback_load_other, 'clicks',
                               onlychanged=False)

        return pn.Column("### Load widget",
                         '<b>Directory path</b>', path_input,
                         '<b>Load Topography</b>', load_button,
                         '<b>Load available Topography</b>',
                         pn.WidgetBox(available),
                         '<b>Select another Topography file</b>',
                         browser,
                         load_other)

    def _callback_plot_contours(self, event):
        self.contours_on = event.new

    def _callback_show(self, event):
        self.set_show(event.new)

    def _callback_show_release(self, event):
        if event.new == 'Show':
            self.add_release_area_origin()
        else:
            self.release_area_origin = None

    def _callback_release_width(self, event):
        self.release_width = event.new

    def _callback_release_height(self, event):
        self.release_height = event.new

    def _callback_filename_npz(self, event):
        self.npz_filename = event.new

    def _callback_filename(self, event):
        self.data_path = event.new

    def _callback_save(self, event):
        if self.npz_filename is not None:
            self.saveTopo(filename=self.npz_filename)

    def _callback_load(self, event):
        if self.data_path is not None:
            # self.loadTopo(filename=self.npz_filename)
            # self.snapshotFrame()
            self._search_all_data(data_path=self.data_path)
            self._widget_available_topography.options = self.data_filenames
            self._widget_available_topography.sizing_mode = "scale_both"

    def _callback_move_box_horizontal(self, event):
        self.moveBox_possible(x=event.new,
                              y=self.box_origin[1],
                              width=self.box_width,
                              height=self.box_height)

    def _callback_move_box_vertical(self, event):
        self.moveBox_possible(x=self.box_origin[0],
                              y=event.new,
                              width=self.box_width,
                              height=self.box_height)

    def _callback_box_width(self, event):
        self.moveBox_possible(x=self.box_origin[0],
                              y=self.box_origin[1],
                              width=event.new,
                              height=self.box_height)

    def _callback_box_height(self, event):
        self.moveBox_possible(x=self.box_origin[0],
                              y=self.box_origin[1],
                              width=self.box_width,
                              height=event.new)

    def _callback_snapshot(self, event):
        self.extractTopo()
        self.snapshotFrame()

    def _callback_available_topography(self, event):
        if event.new is not None:
            self.loadTopo(filename=self.data_path + event.new)
            self.snapshotFrame()

    def _callback_load_other(self, event):
        self.loadTopo(filename=self._widget_other_topography.value[0])
        self.snapshotFrame()
Example #36
0
class AbstractGroupPlotPlugin(FigureCanvas):
    '''
    Abstract base class specifying interface of a group plot plugin.

    The plugin is itself a matplotlib canvas that owns ``self.fig`` (white
    background, 96 dpi). ``plot`` and ``configure`` are no-ops here and are
    meant to be overridden; subclasses also set ``type``, ``name`` and the
    capability flags.
    '''

    def __init__(self, preferences, parent=None):
        self.preferences = preferences

        self.fig = Figure(facecolor='white', dpi=96)

        FigureCanvas.__init__(self, self.fig)

        # NOTE: ``parent`` is accepted for interface compatibility but unused.

        # connection id of the current 'button_press_event' handler, if any
        self.cid = None

        self.type = '<none>'
        self.name = '<none>'
        self.bSupportsHighlight = False
        self.bPlotFeaturesIndividually = True

    def mouseEventCallback(self, callback):
        """Connect *callback* to mouse clicks, replacing any earlier handler."""
        if self.cid is not None:
            FigureCanvas.mpl_disconnect(self, self.cid)

        self.cid = FigureCanvas.mpl_connect(self, 'button_press_event',
                                            callback)

    def plot(self, profile, statsResults):
        """Draw the plot; overridden by concrete plugins."""
        pass

    def configure(self, profile, statsResults):
        """Configure the plot; overridden by concrete plugins."""
        pass

    def savePlot(self, filename, dpi=300):
        """Save the figure, taking the format from the text after the last '.'.

        Only png, pdf, ps, eps and svg are written; any other extension is
        ignored without error.
        """
        fmt = filename.rsplit('.', 1)[-1]
        if fmt in ('png', 'pdf', 'ps', 'eps', 'svg'):
            self.fig.savefig(filename,
                             format=fmt,
                             dpi=dpi,
                             facecolor='white',
                             edgecolor='white')

    def clear(self):
        """Remove everything from the figure."""
        self.fig.clear()

    def mirrorProperties(self, plotToCopy):
        """Copy ``type``, ``name`` and ``bSupportsHighlight`` from *plotToCopy*."""
        self.type = plotToCopy.type
        self.name = plotToCopy.name
        self.bSupportsHighlight = plotToCopy.bSupportsHighlight

    def _labelBounds(self, texts):
        """Return the union bounding box of *texts* in figure-fraction coordinates."""
        renderer = self.get_renderer()
        # display -> figure-fraction transform (replaces removed Bbox.inverse_transformed)
        toFigure = self.fig.transFigure.inverted()
        bboxes = [text.get_window_extent(renderer).transformed(toFigure)
                  for text in texts]
        return mtransforms.Bbox.union(bboxes)

    def labelExtents(self, xLabels, xFontSize, xRotation, yLabels, yFontSize,
                     yRotation):
        """Measure x and y tick-label extents.

        The labels are drawn as ticks on a temporary full-figure axes, then
        measured, and the figure is cleared again. Returns
        ``(xLabelBounds, yLabelBounds)`` as ``Bbox`` objects in figure
        fractions.
        """
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])

        tempAxes.set_xticks(np.arange(len(xLabels)))
        tempAxes.set_yticks(np.arange(len(yLabels)))

        xText = tempAxes.set_xticklabels(xLabels,
                                         size=xFontSize,
                                         rotation=xRotation)
        yText = tempAxes.set_yticklabels(yLabels,
                                         size=yFontSize,
                                         rotation=yRotation)

        xLabelBounds = self._labelBounds(xText)
        yLabelBounds = self._labelBounds(yText)

        self.fig.clear()

        return xLabelBounds, yLabelBounds

    def xLabelExtents(self, labels, fontSize, rotation=0):
        """Return the figure-fraction ``Bbox`` covering *labels* as x tick labels."""
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])
        tempAxes.set_xticks(np.arange(len(labels)))
        xLabels = tempAxes.set_xticklabels(labels,
                                           size=fontSize,
                                           rotation=rotation)

        xLabelBounds = self._labelBounds(xLabels)

        self.fig.clear()

        return xLabelBounds

    def yLabelExtents(self, labels, fontSize, rotation=0):
        """Return the figure-fraction ``Bbox`` covering *labels* as y tick labels."""
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])
        tempAxes.set_yticks(np.arange(len(labels)))
        yLabels = tempAxes.set_yticklabels(labels,
                                           size=fontSize,
                                           rotation=rotation)

        yLabelBounds = self._labelBounds(yLabels)

        self.fig.clear()

        return yLabelBounds

    def emptyAxis(self, title=''):
        """Draw a 6x4 inch placeholder plot saying there is nothing to show."""
        self.fig.clear()
        self.fig.set_size_inches(6, 4)
        axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8])

        axes.set_ylabel('No active features or degenerate plot', fontsize=8)
        axes.set_xlabel('No active features or degenerate plot', fontsize=8)
        axes.set_yticks([])
        axes.set_xticks([])
        axes.set_title(title)

        # .items() works on Python 2 dicts and on matplotlib's Spines mapping;
        # .iteritems() does not exist on Python 3.
        for loc, spine in axes.spines.items():
            if loc in ('right', 'top'):
                spine.set_color('none')

        self.draw()

    def formatLabels(self, labels):
        """Return tick-label strings formatted for display.

        Values below 0.01 use scientific notation with two decimals and a
        trimmed exponent (``'1.00e-05'`` becomes ``'1.00e-5'``). Other values
        use three decimals.
        """
        formattedLabels = []
        for label in labels:
            value = float(label.get_text())
            if value < 0.01:
                valueStr = '%.2e' % value
                if 'e-00' in valueStr:
                    valueStr = valueStr.replace('e-00', 'e-')
                elif 'e-0' in valueStr:
                    valueStr = valueStr.replace('e-0', 'e-')
            else:
                valueStr = '%.3f' % value

            formattedLabels.append(valueStr)

        return formattedLabels
Example #37
0
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None)
# Attach a canvas to `fig` (presumably the Figure built by the call just above).
canvas = FigureCanvas(fig)
# Default font for text created after the matplotlib.rc call below.
font = {
    'family': 'sans-serif',
    'sans-serif': 'Arial',
    'weight': 'normal',
    'size': 16
}
matplotlib.rc('font', **font)

# Plot the lines
# Main panel in the lower part of the figure ([left, bottom, width, height] in
# figure fractions); x range is start..end padded by one day on each side.
ax = fig.add_axes([0.06, 0.06, 0.93, 0.67],
                  xlim=((start - datetime.timedelta(days=1)),
                        (end + datetime.timedelta(days=1))),
                  ylim=ylim)
ax.set_ylabel('TMP2m minus v3 mean')
# Upper panel for the spread, same padded time range; its x axis is hidden.
ax2 = fig.add_axes([0.06, 0.75, 0.93, 0.22],
                   xlim=((start - datetime.timedelta(days=1)),
                         (end + datetime.timedelta(days=1))),
                   ylim=[.1, 1.1])
ax2.set_ylabel('Spread')
ax2.get_xaxis().set_visible(False)

# fromversion is defined elsewhere. NOTE(review): ndata appears to be 2-D
# (ndata[:, m] is used per member below) — confirm its layout.
(ndata, dts, spread) = fromversion('3')
# Mean of ndata along axis 1.
v3m = numpy.mean(ndata, 1)
for m in range(80):
    ax.add_line(
        Line2D(dts,
               ndata[:, m] - v3m,
Example #38
0
def dotplot(df, cutoff=0.05, term_num=10, figsize=(3, 6), scale=50):
    """Visualize enrichr or gsea results.

    :param df: GSEApy DataFrame results. It is copied, not modified.
    :param cutoff: p-adjust cut-off.
    :param term_num: number of most significant enriched terms to show.
    :param figsize: figure size in inches.
    :param scale: dotplot point size scale.
    :return: a dotplot figure for the enriched terms, or None when no term
        passes the cut-off.

    """
    # work on a copy so the caller's DataFrame is neither renamed nor extended
    df = df.copy()

    if 'fdr' in df.columns:
        # gsea results
        df.rename(columns={
            'fdr': 'Adjusted P-value',
        }, inplace=True)
        df['hits_ratio'] = df['matched_size'] / df['gene_set_size']
    else:
        # enrichr results: 'Overlap' holds "count/background"
        overlap = df['Overlap'].str.split("/")
        df['Count'] = overlap.str[0].astype(int)
        df['Background'] = overlap.str[1].astype(int)
        df['hits_ratio'] = df['Count'] / df['Background']

    # pvalue cut off
    df = df[df['Adjusted P-value'] <= cutoff]

    if len(df) < 1:
        logging.warning("Warning: No enrich terms when cutoff = %s", cutoff)
        return None
    # Descending order puts the most significant term last, i.e. at the top
    # of the y axis; tail() therefore keeps the term_num most significant ones.
    df = df.sort_values(by='Adjusted P-value', ascending=False)
    df = df.tail(term_num)
    # x axis values
    padj = df['Adjusted P-value']
    x = -padj.apply(np.log10)
    # y axis positions and labels
    y = list(range(len(df)))
    labels = df.Term.values

    area = np.pi * (df['hits_ratio'] * scale)**2

    # create scatter plot
    if hasattr(sys, 'ps1'):
        # working inside python console, show figure
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # if working on the command line, don't show the figure
        fig = Figure(figsize=figsize)
        FigureCanvas(fig)  # attach a canvas so the figure can be rendered
        ax = fig.add_subplot(111)
    # colour range spans the observed adjusted p-values
    vmin = padj.min()
    vmax = padj.max()
    sc = ax.scatter(x=x,
                    y=y,
                    s=area,
                    edgecolors='face',
                    c=padj,
                    cmap=plt.cm.RdBu,
                    vmin=vmin,
                    vmax=vmax)
    ax.set_xlabel("-log$_{10}$(Adjust P-value)")
    ax.yaxis.set_major_locator(plt.FixedLocator(y))
    ax.yaxis.set_major_formatter(plt.FixedFormatter(labels))
    ax.set_ylim([-1, len(df)])
    ax.grid()

    # colorbar
    cax = fig.add_axes([0.93, 0.20, 0.05, 0.20])
    cbar = fig.colorbar(
        sc,
        cax=cax,
    )
    # booleans, not the strings 'on'/'off', are accepted by tick_params
    cbar.ax.tick_params(right=False)
    cbar.ax.set_title('Padj', loc='left')

    # legend for dot sizes
    ax2 = fig.add_axes([0.93, 0.55, 0.05, 0.12])

    if len(df) >= 3:
        x2 = [0] * 3
        # index labels of the largest dot, the one closest to the median and
        # the smallest; labels (not positions) because df keeps its index
        idx = [
            area.idxmax(),
            (area - area.median()).abs().idxmin(),
            area.idxmin()
        ]
    else:
        # for fewer than 3 terms show every dot
        x2 = [0] * len(df)
        idx = list(df.index)

    y2 = list(range(len(x2)))
    ax2.scatter(x=x2, y=y2, s=area.loc[idx], c='black', edgecolors='face')

    for i, index in enumerate(idx):
        ax2.text(x=0.8,
                 y=y2[i],
                 s=df['hits_ratio'].loc[index].round(2),
                 verticalalignment='center',
                 horizontalalignment='left')
    ax2.set_title("Gene\nRatio", loc='left')

    # turn off all spines and ticks
    ax2.axis('off')

    return fig
Example #39
0
from matplotlib.figure import Figure
from matplotlib.patches import Polygon
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.numerix as nx

figsize = (3, 8)
# NOTE(review): dpi is not used by any statement visible in this fragment.
dpi = 80

# NOTE(review): `matplotlib.mpl` (like `matplotlib.numerix` above) only exists
# in very old matplotlib releases — confirm the targeted version.
from matplotlib import mpl
fig = Figure(figsize=figsize)

#ax = fig.add_subplot(111)
# Make a figure and axes with dimensions as desired.
#fig = pyplot.figure(figsize=(8,3))
# Two tall, narrow axes (left and right) to hold standalone colorbars;
# rectangles are [left, bottom, width, height] in figure fractions.
ax1 = fig.add_axes([0.05, 0.05, 0.15, 0.9])
ax2 = fig.add_axes([0.65, 0.05, 0.15, 0.9])

# Set the colormap and norm to correspond to the data for which
# the colorbar will be used.
cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=5, vmax=10)

# ColorbarBase derives from ScalarMappable and puts a colorbar
# in a specified axes, so it has everything needed for a
# standalone colorbar.  There are many more kwargs, but the
# following gives a basic continuous colorbar with ticks
# and labels.
cb1 = mpl.colorbar.ColorbarBase(ax1,
                                cmap=cmap,
                                norm=norm,
Example #40
0
    figsize=(15, 15 * 1.06 / 1.04),  # Width, Height (inches)
    dpi=100,
    facecolor=(0.88, 0.88, 0.88, 1),
    edgecolor=None,
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None)
# Attach a canvas to `fig` (presumably the Figure built by the call just above).
canvas = FigureCanvas(fig)

# Global projection
projection = ccrs.RotatedPole(pole_longitude=180.0, pole_latitude=90.0)
# Whole globe: [lon_min, lon_max, lat_min, lat_max] in the projection's coords.
extent = [-180, 180, -90, 90]

# Top half for the originals
ax_orig = fig.add_axes([0.02, 0.51, 0.96, 0.47], projection=projection)
ax_orig.set_axis_off()
ax_orig.set_extent(extent, crs=projection)
# Bottom half, same projection and extent (presumably the post-processed field).
ax_post = fig.add_axes([0.02, 0.02, 0.96, 0.47], projection=projection)
ax_post.set_axis_off()
ax_post.set_extent(extent, crs=projection)

# Background, grid and land for both
# NOTE(review): GeoAxes.background_patch is deprecated/removed in newer
# cartopy releases — confirm the cartopy version in use.
ax_orig.background_patch.set_facecolor((0.88, 0.88, 0.88, 1))
ax_post.background_patch.set_facecolor((0.88, 0.88, 0.88, 1))
mg.background.add_grid(ax_orig)
mg.background.add_grid(ax_post)
# 'GreyT' is a named user background image; presumably configured via
# cartopy's background-image settings — confirm it is available.
land_img_orig = ax_orig.background_img(name='GreyT', resolution='low')
land_img_post = ax_post.background_img(name='GreyT', resolution='low')
# Plot the pressures as contours
Example #41
0
 def _fmap_default(self):
     """Return the default Figure with one axes covering most of its area."""
     rect = [0.05, 0.04, 0.9, 0.92]  # [left, bottom, width, height], figure fractions
     fig = Figure()
     fig.add_axes(rect)
     return fig
Example #42
0
 def _figure_default(self):
     """Return the default white-background Figure with a single axes."""
     fig = Figure(facecolor='white')
     # axes rectangle: [left, bottom, width, height] in figure fractions
     fig.add_axes([0.08, 0.13, 0.85, 0.74])
     return fig
Example #43
0
from matplotlib.figure import Figure
from numpy import meshgrid
import matplotlib.numerix as nx
import matplotlib.cm as cm
from matplotlib.mlab import load

# read in topo data (on a regular lat/lon grid)
# longitudes go from 20 to 380.
# NOTE(review): matplotlib.mlab.load only exists in old matplotlib releases.
etopo = load('etopo20data.gz')
lons = load('etopo20lons.gz')
lats = load('etopo20lats.gz')
# create figure.
# NOTE(review): FigureCanvas and Basemap are not imported in the lines shown.
fig = Figure()
canvas = FigureCanvas(fig)
# create axes instance, leaving room for colorbar at bottom.
# (no colorbar is drawn in this fragment)
ax = fig.add_axes([0.125, 0.175, 0.75, 0.75])
# create Basemap instance for Robinson projection.
# set 'ax' keyword so pylab won't be imported.
# Centre longitude is the midpoint of the first and last grid longitudes.
m = Basemap(projection='robin', lon_0=0.5 * (lons[0] + lons[-1]), ax=ax)
# make filled contour plot.
# Project the lon/lat grid to map coordinates; 30 is the number of levels.
x, y = m(*meshgrid(lons, lats))
cs = m.contourf(x, y, etopo, 30, cmap=cm.jet)
# draw coastlines.
m.drawcoastlines()
# draw a line around the map region.
m.drawmapboundary()
# draw parallels and meridians.
# labels=[left, right, top, bottom]: parallels labelled left, meridians bottom.
m.drawparallels(nx.arange(-60., 90., 30.), labels=[1, 0, 0, 0], fontsize=10)
m.drawmeridians(nx.arange(0., 420., 60.), labels=[0, 0, 0, 1], fontsize=10)
# add a title.
ax.set_title('Robinson Projection')
Example #44
0
# Indices of element [1] of autoencoder.get_weights(), ordered by decreasing
# absolute value (ascending argsort, then reversed).
order = numpy.argsort(numpy.abs(autoencoder.get_weights()[1]))[::-1]

fig = Figure(
    figsize=(19.2, 10.8),  # 1920x1080, HD
    dpi=100,
    facecolor=(0.88, 0.88, 0.88, 1),
    edgecolor=None,
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None)
canvas = FigureCanvas(fig)

# Top right - map showing original and reconstructed fields
projection = ccrs.RotatedPole(pole_longitude=180.0, pole_latitude=90.0)
ax_map = fig.add_axes([0.505, 0.51, 0.475, 0.47], projection=projection)
ax_map.set_axis_off()
extent = [-180, 180, -90, 90]
ax_map.set_extent(extent, crs=projection)
# Global rc setting: images drawn afterwards default to aspect='auto'.
matplotlib.rc('image', aspect='auto')

# Run the data through the autoencoder and convert back to iris cube
pm = ic.copy()
pm.data = normalise(pm.data)
ict = tf.convert_to_tensor(pm.data, numpy.float32)
ict = tf.reshape(ict, [1, 18048])  # batch of one; NOTE(review): 18048 presumably == ic.data.size — confirm
result = autoencoder.predict_on_batch(ict)
result = tf.reshape(result, ic.data.shape)
pm.data = unnormalise(result)
# Background, grid and land
Example #45
0
    def new_window(self, controller, pathGML, pathSIMOGenMovData,
                   numberOfInfected, percentageInfection, spreadD, IP):
        global top, spreadDistance, frameNew, ax, fig, fig2D, ax2D, currentDay, labelDay, IncubationVal, currentTime, labelTime, timeIncreaser, var2, label2, HumanCount, infectedHumanNumber, healthyHumanNumber
        top = tkinter.Toplevel()
        top.title("Virus propagation model")
        top.attributes('-fullscreen', True)
        fig = Figure(figsize=(7, 7), dpi=100, facecolor='#F0F0F0')
        ax = Axes3D(fig, auto_add_to_figure=False)
        fig.add_axes(ax)
        fig2D = Figure(figsize=(5.5, 5.5), dpi=100, facecolor='#F0F0F0')
        ax2D = Axes3D(fig2D, auto_add_to_figure=False)
        fig2D.add_axes(ax2D)
        fig2D.suptitle("Floor " + str(floorChanger) + ":", fontsize=12)
        button1 = tkinter.Button(top,
                                 text="Close",
                                 font=fontName,
                                 command=lambda self=self, controller=
                                 controller: self.closeFunction(controller))
        button1.pack(padx=2, pady=2)
        spreadDistance = float(spreadD.get())
        IncubationVal = int(IP.get())
        # scaling by 1/150
        spreadDistance = spreadDistance / 150
        # parsing indoor gml data
        myGML_3D(pathGML.get())
        id_arr, startinitT, endfinalT, increasetime = gettingData(
            pathSIMOGenMovData.get())
        canvas1 = tkinter.Canvas(top,
                                 highlightbackground="black",
                                 highlightcolor="black",
                                 highlightthickness=1)
        canvas1.pack(padx=5, pady=5, expand=True, fill="both", side="right")
        frame_top = tkinter.Frame(top,
                                  highlightbackground="black",
                                  highlightcolor="black",
                                  highlightthickness=1)
        frame_top.pack(side="top", padx=1, pady=1)
        varNew = tkinter.StringVar()
        labelNew = tkinter.Label(frame_top,
                                 textvariable=varNew,
                                 font=('Arial', 16))
        varNew.set("Virus propagation model in Indoor space")
        labelNew.pack(side="top", padx=5, pady=5)
        labelDay = tkinter.StringVar()
        labelDay.set("Day: " + str(currentDay))
        main_label = tkinter.Label(frame_top,
                                   textvariable=labelDay,
                                   font=('Arial 14 bold'))
        main_label.pack(side="top", padx=5, pady=1)
        labelTime = tkinter.StringVar()
        labelTime.set("Time: " + str(sometime))
        main_labelTime = tkinter.Label(frame_top,
                                       textvariable=labelTime,
                                       font=('Arial 14 bold'))
        main_labelTime.pack(side="top", padx=5, pady=1)
        frame_bottom = tkinter.Frame(frame_top)
        frame_bottom.pack(side="bottom", padx=5, pady=1)
        canvas = FigureCanvasTkAgg(fig, master=frame_top)
        canvas.get_tk_widget().pack(side="left", padx=5, pady=1)
        canvas.mpl_connect('button_press_event', ax.axes._button_press)
        canvas.mpl_connect('button_release_event', ax.axes._button_release)
        canvas.mpl_connect('motion_notify_event', ax.axes._on_move)
        HumanCount = len(id_arr)
        # creating objects by assigning their id, path, start and end time to each person
        for ival, i in enumerate(np.arange(0, HumanCount)):
            regularHuman = Person(i, float(percentageInfection.get()))
            regularHuman.humanID = id_arr[i]
            humans.append(regularHuman)
            for ival2, i in enumerate(range(len(idWithCoord))):
                if regularHuman.humanID == idWithCoord[i][0]:
                    temporary = [
                        idWithCoord[i][1], idWithCoord[i][2],
                        idWithCoord[i][3] + 2.5
                    ]
                    regularHuman.path.append(temporary)
                    regularHuman.startT.append(idWithCoord[i][4])
                    regularHuman.endT.append(idWithCoord[i][5])
            regularHuman.pathSize = int(len(regularHuman.path))

        timeIncreaser = 41400 / increasetime
        for i, h in enumerate(humans):
            for j in range(len(h.endT)):
                difference = h.endT[j] - startinitT
                h.timeline.append(difference.seconds)
        for i, h in enumerate(humans):
            h.timeline = list(dict.fromkeys(h.timeline))
            h.timeline.sort()
        secondFrame = tkinter.Frame(frame_top,
                                    highlightbackground="black",
                                    highlightcolor="black",
                                    highlightthickness=1)
        canvasScroll = tkinter.Canvas(secondFrame)
        scrollbar = tkinter.Scrollbar(secondFrame,
                                      orient="vertical",
                                      command=canvasScroll.yview)
        frameNew = tkinter.Frame(canvasScroll)
        frameNew.bind(
            "<Configure>", lambda e: canvasScroll.configure(
                scrollregion=canvasScroll.bbox("all")))
        canvasScroll.create_window((0, 0), window=frameNew)
        canvasScroll.configure(yscrollcommand=scrollbar.set)
        tkinter.Label(frameNew, font=scrollFontBig, text="Events log").pack()
        tkinter.Label(frameNew,
                      font=scrollFontBig,
                      text="--------------------------").pack()
        numberInfectedFromEntry = int(numberOfInfected.get())
        for ival, i in enumerate(range(numberInfectedFromEntry)):
            humans[i].makeInfected()
            initalCase = "Person " + str(humans[i].humanID) + "\nis infected"
            tkinter.Label(frameNew, font=scrollFontSmall,
                          text=initalCase).pack()
            infectedHumanNumber = infectedHumanNumber + 1
        healthyHumanNumber = HumanCount - infectedHumanNumber
        secondFrame.pack(padx=5, pady=5)
        buttonVisualize = tkinter.Button(
            secondFrame,
            text="Visualize",
            font=fontName,
            command=lambda pathGML=pathGML: visualize())
        buttonVisualize.pack(padx=10, pady=1, side="left")
        canvasScroll.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")
        thirdFrame = tkinter.Frame(frame_top,
                                   highlightbackground="black",
                                   highlightcolor="black",
                                   highlightthickness=1)
        tkinter.Label(thirdFrame,
                      font=scrollFontBig,
                      text="Infection case coordinates in each floor").pack(
                          padx=1, pady=1)
        canvas2D = FigureCanvasTkAgg(fig2D, master=thirdFrame)
        thirdFrame.pack(padx=1, pady=1)
        fourthFrame = tkinter.Frame(thirdFrame,
                                    highlightbackground="black",
                                    highlightcolor="black",
                                    highlightthickness=1)
        fourthFrame.pack(side="left", padx=1, pady=1)
        canvas2D.get_tk_widget().pack(side="bottom",
                                      expand=True,
                                      padx=1,
                                      pady=1)
        canvas2D.mpl_connect('button_press_event', ax2D.axes._button_press)
        canvas2D.mpl_connect('button_release_event', ax2D.axes._button_release)
        canvas2D.mpl_connect('motion_notify_event', ax2D.axes._on_move)
        # create button for each floor
        for k, v in floorsAndValues.items():
            if k:
                floorNumber = k
                buttonF = tkinter.Button(
                    fourthFrame,
                    text="Floor" + str(k),
                    font=fontName,
                    command=lambda floorNumber=floorNumber: drawerByFloor(
                        floorNumber))
                buttonF.pack(side="top", fill="x", padx=10, pady=1)
        # setting the dimesions
        ax.set_xlim3d([0.0, max(highAndLowX)])
        ax.set_ylim3d([0.0, max(highAndLowY)])
        ax.set_zlim3d([0.0, max(highAndLowZ) * 3])
        ax.set_box_aspect(
            (max(highAndLowX), max(highAndLowY), max(highAndLowZ) * 3))
        ax2D.set_xlim3d([0.0, max(highAndLowX)])
        ax2D.set_ylim3d([0.0, max(highAndLowY)])
        ax2D.set_zlim3d([0.0, max(highAndLowZ) * 3])
        ax2D.set_box_aspect(
            (max(highAndLowX), max(highAndLowY), max(highAndLowZ) * 3))
        try:
            ax.set_aspect('equal')
        except NotImplementedError:
            pass
        allPoints = []
        allObjects = []
        allPointsDoors = []
        allObjectsDoors = []
        allPoints2D = []
        allObjects2D = []
        allPoints2D_Doors = []
        allObjects2D_Doors = []
        var1 = tkinter.StringVar()
        label1 = tkinter.Label(canvas1, textvariable=var1, font=('Arial', 14))
        var1.set("Total number of people: " + str(HumanCount))
        label1.pack(side="top", padx=5, pady=5)
        var2 = tkinter.StringVar()
        label2 = tkinter.Label(canvas1, textvariable=var2, font=('Arial', 14))
        var2.set("Number of infected people: " + str(infectedHumanNumber))
        label2.pack(side="top", padx=5, pady=5)
        labelNew = tkinter.Label(canvas1, justify='center')
        labelNew.pack()
        alphaVal = 0.3
        lineWidthVal = 0.05
        lineWidthVal2 = 0.05
        drawer(ax, allPoints, allPointsDoors, allObjects, allObjectsDoors,
               True, alphaVal, lineWidthVal, False, 0)
        alphaVal = 0.7
        lineWidthVal = 1
        drawer(ax2D, allPoints2D, allPoints2D_Doors, allObjects2D,
               allObjects2D_Doors, False, alphaVal, lineWidthVal, False, 0)
        ax.set_axis_off()
        ax2D.set_axis_off()
        ax2D.view_init(90)
        global ct, timeArray, f, c, axPie
        ct = [infectedHumanNumber, infectedHumanNumber]
        timeArray = [0, currentDay]
        f = plt.figure(figsize=(6, 4))
        c = f.add_subplot(1, 1, 1)
        c.axis([1, 20, 0, HumanCount])
        caja = plt.Rectangle((0, 0), 100, 100, fill=True)
        cvst, = c.plot(infectedHumanNumber,
                       color="red",
                       label="Infected people")
        c.legend(handles=[cvst])
        c.set_xlabel("Time (days)")
        c.set_ylabel("Infections")
        dataProportion = np.array([healthyHumanNumber, infectedHumanNumber])
        # Creating pie
        myPie, axPie = plt.subplots(figsize=(6, 6))
        axPie.pie(dataProportion,
                  autopct=lambda val: updaterOfValuesAndPercentage(
                      val, dataProportion),
                  explode=explode,
                  labels=labelCondition,
                  shadow=True,
                  colors=colors,
                  wedgeprops=wedgeProp)
        update(ct, timeArray, infectedHumanNumber, healthyHumanNumber)
        Graph(canvas1, f).pack(side="bottom", padx=10, pady=10)
        Graph(canvas1, myPie).pack(side="bottom", padx=10, pady=10)
        canvas.draw()
        canvas2D.draw()
        global anim, anim2D
        anim = FuncAnimation(fig,
                             updateALL,
                             frames=800,
                             interval=0,
                             blit=True,
                             repeat=True,
                             cache_frame_data=False)
        buttonPausingMov = tkinter.Button(
            frame_bottom,
            text="Pause simulation",
            bg='brown',
            fg='white',
            font=fontName,
            command=lambda anim=anim: pauseAnimation(anim))
        buttonPausingMov.pack(padx=2, pady=2, side="left")
        buttonStartingMov = tkinter.Button(
            frame_bottom,
            text="Continue simulation",
            bg='green',
            fg='white',
            font=fontName,
            command=lambda anim=anim: continueAnimation(anim))
        buttonStartingMov.pack(padx=2, pady=2, side="left")
# NOTE(review): `ict`, `tf` and `encoder` are defined earlier in the original
# script (not visible here). The reshape turns `ict` into a single-sample
# batch of shape (1, 79, 159, 4).
ict = tf.reshape(ict, [1, 79, 159, 4])
# Encode the batch. result[0] is reshaped to 10x10 below, so the encoder
# output is presumably 100 values per sample — TODO confirm.
result = encoder.predict_on_batch(ict)

# Plot the encoded
fig = Figure(figsize=(5, 5),
             dpi=100,
             facecolor=(0.88, 0.88, 0.88, 1),
             edgecolor=None,
             linewidth=0.0,
             frameon=False,
             subplotpars=None,
             tight_layout=None)
canvas = FigureCanvas(fig)

# Single axes - 10x10 array plot
ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_axis_off()  # Don't want surrounding x and y axis
x = numpy.linspace(0, 10, 10)
# Draw the encoded vector as a 10x10 colour grid spanning [0, 10] on both
# axes; the colour scale is fixed to [-3, 3] so plots are comparable.
latent_img = ax.pcolorfast(x,
                           x,
                           result[0].reshape(10, 10),
                           cmap='viridis',
                           alpha=1.0,
                           vmin=-3,
                           vmax=3,
                           zorder=20)

# Render the figure as a png
fig.savefig("latent_space.png")
Example #47
0
class MyMplCanvas(mpl_qt.FigureCanvasQTAgg):
    """Qt/matplotlib canvas showing a camera frame, with beam-alignment tools.

    Mouse presses on the image record the clicked pixel in
    ``mouseClickPos``. The widget's custom context menu then offers to move
    the motors so the clicked point lands on the beam, to (re)define the beam
    position at the clicked pixel, and to toggle the beam-position and
    reference-rectangle overlays.
    """

    def __init__(self, parent=None):
        import ast  # local import: safe parsing of the stored beam position

        self.fig = Figure()
        self.fig.patch.set_facecolor('white')
        super(MyMplCanvas, self).__init__(self.fig)
        self.setParent(parent)
        self.updateGeometry()
        self.setupPlot()
        self.mpl_connect('button_press_event', self.onPress)
        self.img = None
        self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
        self.mouseClickPos = None
        # setBeamPosition() stores str(self.beamPos) in the config. Parse it
        # with literal_eval so the config text cannot execute arbitrary code,
        # which eval() would allow.
        self.beamPos = ast.literal_eval(config.get('beam', 'pos').strip())

        self.customContextMenuRequested.connect(self.viewMenu)
        self.menu = qt.QMenu()

        self.actionMove = self.menu.addAction('move this point to beam',
                                              self.moveToBeam)

        self.actionDefineBeam = self.menu.addAction(
            'define beam position here', self.setBeamPosition)

        self.actionShowBeam = self.menu.addAction('show beam position',
                                                  self.showBeam)
        self.actionShowBeam.setCheckable(True)
        self.isBeamPositionVisible = True
        self.actionShowBeam.setChecked(self.isBeamPositionVisible)

        self.actionShowRect = self.menu.addAction('show reference rectangle',
                                                  self.showRect)
        self.actionShowRect.setCheckable(True)
        self.isRectVisible = True
        self.actionShowRect.setChecked(self.isRectVisible)

    def setupPlot(self):
        """Create a single frameless axes that fills the whole figure."""
        rect = [0., 0., 1., 1.]
        self.axes = self.fig.add_axes(rect)
        self.axes.xaxis.set_visible(False)
        self.axes.yaxis.set_visible(False)
        for spine in ['left', 'right', 'bottom', 'top']:
            self.axes.spines[spine].set_visible(False)
        self.axes.set_zorder(20)

    def imshow(self, img):
        """Show ``img``, reusing the existing image artist when possible.

        When the new frame's shape differs from the previous one, the extent
        and axis limits are reset to the new size.
        """
        if self.img is None:
            self.img = self.axes.imshow(img)
        else:
            prev = self.img.get_array()
            self.img.set_data(img)
            if prev.shape != img.shape:
                self.img.set_extent(
                    [-0.5, img.shape[1] - 0.5, img.shape[0] - 0.5, -0.5])
                self.axes.set_xlim((0, img.shape[1]))
                self.axes.set_ylim((img.shape[0], 0))
                # A navigation toolbar is optional; without one there is no
                # view history to refresh.
                toolbar = getattr(self, 'toolbar', None)
                if toolbar is not None:
                    toolbar.update()
        self.draw()

    def onPress(self, event):
        """Remember the clicked pixel and, in base-rect mode, set a corner."""
        if (event.xdata is None) or (event.ydata is None):
            # Click outside the axes: forget any previous position.
            self.mouseClickPos = None
            return
        self.mouseClickPos = int(round(event.xdata)), int(round(event.ydata))
        if not self.parent().buttonBaseRect.isChecked():
            return
        self.parent().buttonBaseRect.setCorner(*self.mouseClickPos)

    def viewMenu(self, position):
        """Pop up the context menu at ``position`` (widget coordinates)."""
        if self.mouseClickPos is None:
            return
        self.actionDefineBeam.setEnabled(
            not self.parent().buttonStraightRect.isChecked())
        self.actionMove.setEnabled(self.parent().canTransform())
        self.menu.exec_(self.mapToGlobal(position))
        self.parent().updateFrame()

    def setBeamPosition(self):
        """Store the last clicked pixel as the beam position in the config.

        If a beam position is already set (either coordinate > 0), the user
        is asked to confirm first.
        """
        if (self.beamPos[0] > 0) or (self.beamPos[1] > 0):
            msgBox = qt.QMessageBox()
            reply = msgBox.question(
                self, 'Confirm',
                'Do you really want to re-define beam position?',
                qt.QMessageBox.Yes | qt.QMessageBox.No, qt.QMessageBox.Yes)
            if reply == qt.QMessageBox.No:
                return
        # Slice-assign so that any holder of this list sees the update.
        self.beamPos[:] = self.mouseClickPos
        config.set('beam', 'pos', str(self.beamPos))
        write_config()
        self.parent().buttonStraightRect.update()

    def showBeam(self):
        """Toggle the beam-position overlay flag."""
        self.isBeamPositionVisible = not self.isBeamPositionVisible

    def showRect(self):
        """Toggle the reference-rectangle overlay flag."""
        self.isRectVisible = not self.isRectVisible

    def _moveMotor(self, motor, delta):
        """Shift ``motor``'s 'position' attribute by ``delta``.

        Returns True on success. On failure a critical message box is shown
        and False is returned. The box shows the text of the last line
        containing 'desc =' (minus its 7-character prefix) when present,
        otherwise the full error text.
        """
        try:
            current = motor.read_attribute('position').value
            motor.write_attribute('position', current + delta)
        except Exception as e:
            message = str(e)
            for line in reversed(message.splitlines()):
                if 'desc =' in line:
                    message = line.strip()[7:]
                    break
            msgBox = qt.QMessageBox()
            msgBox.critical(self, 'Motion has failed', message)
            return False
        return True

    def moveToBeam(self):
        """Move the motors so the last clicked point ends up on the beam.

        The click is converted to an offset from ``parent.beamPosRectified``
        scaled by ``parent.zoom``. The click is first passed through
        ``parent.transformPoint`` unless the straight-rect view is active.
        In test mode (``isTest``) the offsets are only printed.
        """
        x, y = self.mouseClickPos
        parent = self.parent()
        xC, yC = parent.beamPosRectified
        if not parent.buttonStraightRect.isChecked():
            xP, yP = parent.transformPoint((x, y))
            x0, y0 = (xP - xC) / parent.zoom, (yP - yC) / parent.zoom
        else:
            x0, y0 = (x - xC) / parent.zoom, (y - yC) / parent.zoom

        if isTest:
            print(-x0, y0)
            return
        # Both motors get the same error handling; an unhandled exception
        # here would escape the Qt slot. Stop after the first failure.
        if motorX is not None and not self._moveMotor(motorX, -x0):
            return
        if motorY is not None:
            self._moveMotor(motorY, y0)
    dpi=100,
    facecolor="white",
    edgecolor="black",
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None,
)
# NOTE(review): `fig` comes from the Figure(...) call just above, whose
# opening line is missing from this file — restore it or this will not run.
# Remove every subplot margin and gap so axes can span the whole figure.
fig.subplots_adjust(left=0.0,
                    right=1.0,
                    bottom=0.0,
                    top=1.0,
                    hspace=0.0,
                    wspace=0.0)
canvas = FigureCanvas(fig)
# One white-background axes covering the entire figure area
# (rect is [left, bottom, width, height] in figure fractions).
ax_full = fig.add_axes([0, 0, 1, 1], facecolor="white")


def showImage(filen, spi=1):
    """Load an image from the rainfall-rescue scratch area and resize it.

    The file is read from ``$SCRATCH/Robot_Rainfall_Rescue/from_Ed/images/``
    with all channels kept (IMREAD_UNCHANGED). It is then resized to
    1024 x 1632 (cv2 dsize is (width, height)).

    Raises Exception when OpenCV cannot read the file.

    NOTE(review): ``spi`` is unused, and the resized image is neither
    returned nor displayed here.
    """
    scratch_dir = os.getenv("SCRATCH")
    image_path = "{}/Robot_Rainfall_Rescue/from_Ed/images/{}".format(
        scratch_dir, filen)

    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    # imread signals failure by returning None rather than raising.
    if image is None:
        raise Exception("No such image file {}".format(image_path))

    # Bring every scan to a common size.
    image = cv2.resize(image, (1024, 1632))
Example #49
0
class AnatomicalCanvas(FigureCanvas):
    """Base canvas for anatomical views.

    Shows a 2D image (passed to ``_init_ui``) in a matplotlib figure and wires
    up mouse interaction:

    - scrolling zooms around the cursor;
    - the right mouse button (motion with it held, or its release) adjusts
      brightness/contrast.

    Point selection is a no-op hook (``on_select_point``) for subclasses.

    Attributes
    ----------
    point_selected_signal : QtCore.Signal
        Signal carrying three floats, intended to be emitted when the user
        selects a point on the canvas.

    """
    point_selected_signal = QtCore.Signal(float, float, float)
    _horizontal_nav = None
    _vertical_nav = None
    _navigation_state = False
    # Class-level default kept only for backward compatibility. __init__
    # gives every instance its own list, so canvases no longer share (and
    # clear() each other's) annotations.
    annotations = []
    last_update = 0
    # Minimum interval between intensity refreshes, in the units returned by
    # time() (presumably seconds, i.e. ~15 Hz).
    update_freq = 0.0667
    previous_point = (0, 0)

    def __init__(self,
                 parent,
                 width=8,
                 height=8,
                 dpi=100,
                 crosshair=False,
                 plot_points=False,
                 annotate=False,
                 vertical_nav=False,
                 horizontal_nav=False):
        """Create the figure/canvas and copy view options from ``parent``.

        ``parent`` must provide ``image``, ``params`` (whose ``vmin``/``vmax``
        are the initial display window) and ``_controller.position`` (initial
        x, y, z; truncated to int).
        """
        self._parent = parent
        self._image = parent.image
        self._params = parent.params
        self._crosshair = crosshair
        self._plot_points = plot_points
        self._annotate_points = annotate
        self._vertical_nav = vertical_nav
        self._horizontal_nav = horizontal_nav
        self.position = None
        # Per-instance list: a shared class-level list would leak annotation
        # artists between canvases.
        self.annotations = []

        self._x, self._y, self._z = [
            int(i) for i in self._parent._controller.position
        ]

        self._fig = Figure(figsize=(width, height), dpi=dpi)
        super(AnatomicalCanvas, self).__init__(self._fig)
        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding,
                                   QtWidgets.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.vmin_updated = self._params.vmin
        self.vmax_updated = self._params.vmax

    def _init_ui(self, data, aspect):
        """Connect mouse handlers and draw ``data`` with the view parameters."""
        self._fig.canvas.mpl_connect('button_release_event',
                                     self.on_select_point)
        self._fig.canvas.mpl_connect('scroll_event', self.on_zoom)
        self._fig.canvas.mpl_connect('button_release_event',
                                     self.on_change_intensity)
        self._fig.canvas.mpl_connect('motion_notify_event',
                                     self.on_change_intensity)

        # Leave the top 10% of the figure free (used by title()/suptitle).
        self._axes = self._fig.add_axes([0, 0, 1, 0.9], frameon=True)
        self._axes.axis('off')
        self.view = self._axes.imshow(data,
                                      cmap=self._params.cmap,
                                      interpolation=self._params.interp,
                                      vmin=self._params.vmin,
                                      vmax=self._params.vmax,
                                      alpha=self._params.alpha)
        self._axes.set_aspect(aspect)

        if self._crosshair:
            self.cursor = Cursor(self._axes,
                                 useblit=True,
                                 color='r',
                                 linewidth=1)

        # Empty red-dot line artist; filled in by plot_data().
        self.points = self._axes.plot([], [], '.r', markersize=7)[0]

    def title(self, message):
        """Set the figure-level title."""
        self._fig.suptitle(message, fontsize=10)

    def annotate(self, x, y, label):
        """Add a red text label just up-left of (x, y) and remember it."""
        self.annotations.append(
            self._axes.annotate(label,
                                xy=(x, y),
                                xytext=(-3, 3),
                                textcoords='offset points',
                                ha='right',
                                va='bottom',
                                color='r'))

    def clear(self):
        """Remove all annotations and plotted points from the axes."""
        for i in self.annotations:
            i.remove()
        self.annotations = []
        self.points.set_xdata([])
        self.points.set_ydata([])

    def refresh(self):
        """Apply the current intensity window, redraw overlays and the canvas.

        NOTE(review): plot_position() and plot_points() are not defined in
        this class; subclasses are expected to provide them.
        """
        self.view.set_clim(vmin=self.vmin_updated, vmax=self.vmax_updated)
        logger.debug("vmin_updated=%s, vmax_updated=%s", self.vmin_updated,
                     self.vmax_updated)
        self.plot_position()
        self.plot_points()
        self.view.figure.canvas.draw()

    def plot_data(self, xdata, ydata, labels):
        """Show points at (xdata, ydata), annotating them if enabled."""
        self.points.set_xdata(xdata)
        self.points.set_ydata(ydata)

        if self._annotate_points:
            for x, y, label in zip(xdata, ydata, labels):
                self.annotate(x, y, label)

    def on_zoom(self, event):
        """Rescale the view around the cursor on a scroll event.

        A scroll with button 'up' widens each half-extent by 1.3x; any other
        button narrows it by the same factor. The change is skipped when the
        new extent check against ``x_max``/``y_max`` fails. Those attributes
        are not set in this class — presumably by subclasses, TODO confirm.
        """
        if event.xdata is None or event.ydata is None:
            return

        if event.button == 'up':
            scale_factor = 1.3
        else:
            scale_factor = 1 / 1.3

        x = event.xdata
        y = event.ydata

        x_lim = self._axes.get_xlim()
        y_lim = self._axes.get_ylim()

        left = (x - x_lim[0]) * scale_factor
        right = (x_lim[1] - x) * scale_factor
        top = (y - y_lim[0]) * scale_factor
        bottom = (y_lim[1] - y) * scale_factor

        if x + right - left >= self.x_max or y + bottom - top >= self.y_max:
            return

        self._axes.set_xlim(x - left, x + right)
        self._axes.set_ylim(y - top, y + bottom)
        self.view.figure.canvas.draw()

    def on_select_point(self, event):
        """Hook for point selection; does nothing in the base class."""
        pass

    def on_change_intensity(self, event):
        """Adjust brightness/contrast from the cursor position (right button).

        Horizontal position shifts the display window (brightness); vertical
        position scales it (contrast). The centre of the view means no change.
        """
        if event.xdata is None or event.ydata is None:
            return

        if event.button == 3:  # right click
            curr_time = time()

            if curr_time - self.last_update <= self.update_freq:
                # TODO: throttling is inactive — last_update is never
                # updated, so this branch is effectively never taken.
                return

            if (abs(event.xdata - self.previous_point[0]) < 1
                    and abs(event.ydata - self.previous_point[1]) < 1):
                # TODO: previous_point is only updated inside this branch, so
                # it does not follow the cursor during normal drags.
                self.previous_point = (event.xdata, event.ydata)
                return

            logger.debug("X=%s, Y=%s", event.xdata, event.ydata)
            xlim, ylim = self._axes.get_xlim(), self._axes.get_ylim()
            x_factor = (event.xdata - xlim[0]) / float(
                xlim[1] - xlim[0])  # between 0 and 1. No change: 0.5
            y_factor = (event.ydata - ylim[1]) / float(ylim[0] - ylim[1])

            # get dynamic of the image
            vminvmax = self._params.vmax - self._params.vmin  # todo: get variable based on image quantization

            # adjust brightness by adding offset to image intensity
            # the "-" sign is there so that when moving the cursor to the right, brightness increases (more intuitive)
            # the 2.0 factor maximizes change.
            self.vmin_updated = self._params.vmin - (x_factor -
                                                     0.5) * vminvmax * 2.0
            self.vmax_updated = self._params.vmax - (x_factor -
                                                     0.5) * vminvmax * 2.0

            # adjust contrast by multiplying image dynamic by scaling factor
            # the factor 2.0 maximizes contrast change. For y_factor = 0.5, the scaling will be 1, which means no change
            # in contrast
            self.vmin_updated = self.vmin_updated * (y_factor * 2.0)
            self.vmax_updated = self.vmax_updated * (y_factor * 2.0)

            self.refresh()

    def horizontal_position(self, position):
        """Move the horizontal navigation line to ``position`` (if enabled).

        ``_horizontal_nav`` starts as the constructor flag and afterwards
        holds the drawn line. The flag has no ``remove()``, hence the
        AttributeError guard on first use.
        """
        if self._horizontal_nav:
            try:
                self._horizontal_nav.remove()
            except AttributeError:
                pass
            self._horizontal_nav = self._axes.axhline(position, color='r')

    def vertical_position(self, position):
        """Move the vertical navigation line to ``position`` (if enabled).

        Same flag-then-artist scheme as horizontal_position().
        """
        if self._vertical_nav:
            try:
                self._vertical_nav.remove()
            except AttributeError:
                pass
            self._vertical_nav = self._axes.axvline(position, color='r')

    def __repr__(self):
        return '{}: {}, {}, {}'.format(self.__class__, self._x, self._y,
                                       self._z)

    def __str__(self):
        # NOTE(review): renders as "x: y, z"; confirm whether "x, y, z" was
        # intended.
        return '{}: {}, {}'.format(self._x, self._y, self._z)
"""
使用matploamlib进行数据绘图
"""
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
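# matplotlib commands and formatting revolve around the canvas and the figure (Figure and Axes).
# Creating a custom figure with pyplot:
# figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)
# (the Figure class used below accepts the same options except num)
# num: figure number or name; a number is used as the figure's number, a string as its name
# figsize: width and height of the figure, in inches
# dpi: resolution of the figure in pixels per inch; defaults to rcParams["figure.dpi"] (100 since matplotlib 2.0, 80 in older releases)
# facecolor: background colour
# edgecolor: border colour
# frameon: whether to draw the figure frame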
fig = Figure()
# Attach an Agg (raster) canvas to the figure; it does the actual rendering.
canvas = FigureCanvas(fig)
# Axes rect is [left, bottom, width, height] in figure fractions.
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
line, = ax.plot([0, 2], [0, 2])
# Chart title
ax.set_title("a straight line ")
# x and y axis labels
ax.set_xlabel("x label")
ax.set_ylabel("y label")
# Render the figure to the given path; the format follows the extension.
# NOTE(review): ./res/img must already exist, and JPEG output may require
# Pillow depending on the matplotlib version.
canvas.print_figure('./res/img/chatpic1.jpg')
Example #51
0
class InterpolationInteractive:
    def quit(self):
        """Exit the application."""

        self.window.quit()  # stops mainloop
        self.window.destroy()  # this is necessary on Windows to prevent
        # Fatal Python Error: PyEval_RestoreThread:
        # NULL tstate

    def get_interpolated_quantity(self, star_mass, star_metallicity):
        """Return a callable for plotting an interpolation quantity."""

        if (star_mass < self.track_mass[0] or star_mass > self.track_mass[-1]
                or star_metallicity < self.track_metallicity[0]
                or star_metallicity > self.track_metallicity[-1]):
            return None

        if self.plot_quantity == 'I':
            return InterpolatedQuantitySum(
                self.interpolator_manager.current_interpolator()(
                    'ICONV', star_mass, star_metallicity),
                self.interpolator_manager.current_interpolator()(
                    'IRAD', star_mass, star_metallicity))
        else:
            if self.plot_quantity == 'R': quantity_id = 'radius'
            else: quantity_id = self.plot_quantity
            return self.interpolator_manager.current_interpolator()(
                quantity_id, star_mass, star_metallicity)

    def plot_interpolation(self,
                           star_mass,
                           star_metallicity,
                           plot_func,
                           deriv_order,
                           single_plot=False):
        """Plot an interpolated stellar evolution track."""

        interpolated_quantity = self.get_interpolated_quantity(
            star_mass, star_metallicity)

        try:
            resolution = int(self.resolution_text.get())
        except:
            resolution = 100
        interpolation_ages = numpy.exp(
            numpy.linspace(numpy.log(max(interpolated_quantity.min_age, 1e-5)),
                           numpy.log(interpolated_quantity.max_age),
                           resolution)[1:-1])

        plot_x = age_transform(star_mass, star_metallicity, interpolation_ages)
        if deriv_order > 0:
            derivatives = interpolated_quantity.deriv(interpolation_ages)
            plot_y = derivatives[deriv_order if single_plot else 0]
        else:
            plot_y = interpolated_quantity(interpolation_ages)

        plot_func(plot_x, plot_y, self.interp_plot_style.get())

        if single_plot: return

        if deriv_order > 0:
            d1_y = numpy.copy(derivatives[1])
            if deriv_order > 1:
                d2_y = numpy.copy(derivatives[2])

            if self.logy:
                d1_y /= plot_y
                if deriv_order > 1:
                    d2_y = (d2_y / plot_y - d1_y**2)

            if self.logx:
                deriv_plot_funcname = 'semilogx'
                d1_y *= plot_x
                if deriv_order > 1:
                    d2_y = d1_y + plot_x**2 * d2_y
            else:
                deriv_plot_funcname = 'plot'

            getattr(self.first_deriv_axes,
                    deriv_plot_funcname)(plot_x, d1_y,
                                         self.interp_plot_style.get())
            if deriv_order > 1:
                getattr(self.second_deriv_axes,
                        deriv_plot_funcname)(plot_x, d2_y,
                                             self.interp_plot_style.get())

    def separate_window(self, deriv_order):
        """Show one derivative order of the current interpolation in a new pyplot window.

        The axis scales follow the current logx/logy toggles.
        """

        pyplot = matplotlib.pyplot
        plotters_by_scale = {
            (True, True): pyplot.loglog,
            (True, False): pyplot.semilogx,
            (False, True): pyplot.semilogy,
            (False, False): pyplot.plot,
        }
        plot = plotters_by_scale[bool(self.logx), bool(self.logy)]

        self.plot_interpolation(self.interp['mass'],
                                self.interp['metallicity'],
                                plot,
                                deriv_order,
                                single_plot=True)
        pyplot.show()

    def display(self):
        """(Re-)draw the plot as currently configured by the user.

        Does nothing while ``do_not_display`` is set. Otherwise it:

        1. (re)defines the global ``age_transform`` from the age-transform
           entry;
        2. clears every axes, remembering their limits;
        3. draws the interpolated track when interpolation is on;
        4. draws every selected track, plus the tracks bracketing the
           interpolation point (per ``track_below``);
        5. restores the remembered limits unless auto-scaling was requested;
        6. redraws the canvas.
        """

        if self.do_not_display:
            return

        # NOTE(review): exec() of text typed into the GUI. Acceptable for a
        # local tool driven by its own user; never feed it untrusted input.
        exec(
            'def age_transform(m, feh, t) : return ' +
            self.age_transform_entry.get(), globals())
        main_x_lim = self.main_axes.get_xlim()
        main_y_lim = self.main_axes.get_ylim()

        self.main_axes.cla()
        if self.first_deriv_axes is not None:
            first_deriv_xlim = self.first_deriv_axes.get_xlim()
            first_deriv_ylim = self.first_deriv_axes.get_ylim()
            self.first_deriv_axes.cla()
        if self.second_deriv_axes is not None:
            second_deriv_xlim = self.second_deriv_axes.get_xlim()
            second_deriv_ylim = self.second_deriv_axes.get_ylim()
            self.second_deriv_axes.cla()

        if self.logx and self.logy:
            plot = self.main_axes.loglog
        elif self.logx and not self.logy:
            plot = self.main_axes.semilogx
        elif not self.logx and self.logy:
            plot = self.main_axes.semilogy
        else:
            plot = self.main_axes.plot

        if self.interpolate:
            self.plot_interpolation(self.interp['mass'],
                                    self.interp['metallicity'], plot,
                                    self.deriv)

        for track_mass_index, track_mass in enumerate(self.track_mass):
            nearby_track_mass = (track_mass_index == self.track_below['mass']
                                 or track_mass_index
                                 == self.track_below['mass'] + 1)
            for track_metallicity_index, track_metallicity in enumerate(
                    self.track_metallicity):
                nearby_track_metallicity = (
                    track_metallicity_index
                    == (self.track_below['metallicity'])
                    or track_metallicity_index
                    == (self.track_below['metallicity'] + 1))
                # track_state holds Tk variables. The variable objects are
                # always truthy, so the selection must be read via .get().
                selected = self.track_state[track_mass][track_metallicity].get()
                if not (selected or
                        (nearby_track_mass and nearby_track_metallicity)):
                    continue

                # Overlay the interpolation evaluated at this track's own
                # grid point, only for explicitly selected tracks.
                if selected and self.interpolate:
                    self.plot_interpolation(track_mass, track_metallicity,
                                            plot, self.deriv)

                track = self.tracks[track_mass][track_metallicity]
                plot(age_transform(track_mass, track_metallicity, track['t']),
                     track[self.plot_quantity], self.track_plot_style.get())

        if not self.main_auto_axes:
            self.main_axes.set_xlim(main_x_lim)
            self.main_axes.set_ylim(main_y_lim)
        if self.first_deriv_axes is not None:
            if self.first_deriv_auto_axes:
                self.first_deriv_axes.set_xlim(self.main_axes.get_xlim())
            else:
                self.first_deriv_axes.set_xlim(first_deriv_xlim)
                self.first_deriv_axes.set_ylim(first_deriv_ylim)
            if self.second_deriv_axes is not None:
                if self.second_deriv_auto_axes:
                    self.second_deriv_axes.set_xlim(self.main_axes.get_xlim())
                else:
                    self.second_deriv_axes.set_xlim(second_deriv_xlim)
                    self.second_deriv_axes.set_ylim(second_deriv_ylim)
        self.main_auto_axes = False
        self.first_deriv_auto_axes = False
        self.second_deriv_auto_axes = False
        # FigureCanvasTkAgg.show() (presumably the type of main_canvas) was
        # deprecated in matplotlib 2.2 and later removed; it was an alias of
        # draw(), which works in every version.
        self.main_canvas.draw()

    def on_key_event(self, event):
        """Forward a key press to matplotlib's default bindings for the main canvas."""

        key_press_handler(event,
                          canvas=self.main_canvas,
                          toolbar=self.toolbar)

    def toggle_log_y(self):
        """Flip the y axis between logarithmic and linear scale, then redraw."""

        self.logy = not self.logy
        relief = Tk.SUNKEN if self.logy else Tk.RAISED
        self.logy_button.config(relief=relief)
        # A scale change invalidates the remembered limits of every axes.
        self.main_auto_axes = self.first_deriv_auto_axes = \
            self.second_deriv_auto_axes = True
        self.display()

    def toggle_log_x(self):
        """Flip the x axis between logarithmic and linear scale, then redraw."""

        self.logx = not self.logx
        relief = Tk.SUNKEN if self.logx else Tk.RAISED
        self.logx_button.config(relief=relief)
        # A scale change invalidates the remembered limits of every axes.
        self.main_auto_axes = self.first_deriv_auto_axes = \
            self.second_deriv_auto_axes = True
        self.display()

    def toggle_interpolation(self):
        """Show or hide the interpolated track, then redraw."""

        self.interpolate = not self.interpolate
        relief = Tk.SUNKEN if self.interpolate else Tk.RAISED
        self.interpolate_button.config(relief=relief)
        self.display()

    def toggle_deriv(self, order):
        """Switch the derivative panel of ``order`` on/off and rebuild the layout.

        If the displayed order is already ``order`` or higher, it drops to
        ``order - 1``; otherwise it becomes ``order``.
        """

        self.deriv = order - 1 if self.deriv >= order else order

        # Discard the old derivative axes; the new layout recreates them.
        for attr_name in ('first_deriv_axes', 'second_deriv_axes'):
            old_axes = getattr(self, attr_name)
            if old_axes is not None:
                self.main_figure.delaxes(old_axes)
                setattr(self, attr_name, None)

        if self.deriv == 0:
            # Main axes use the whole figure.
            self.main_axes.set_position([0.05, 0.05, 0.9, 0.9])
        else:
            # Main axes on top, derivative panel(s) along the bottom half.
            self.main_axes.set_position([0.05, 0.5, 0.9, 0.45])
            add_axes = self.main_figure.add_axes
            if self.deriv == 1:
                self.first_deriv_axes = add_axes([0.05, 0.05, 0.9, 0.45])
            else:
                assert (self.deriv == 2)
                self.first_deriv_axes = add_axes([0.05, 0.05, 0.45, 0.45])
                self.second_deriv_axes = add_axes([0.5, 0.05, 0.45, 0.45])

        first_relief = Tk.SUNKEN if self.deriv >= 1 else Tk.RAISED
        second_relief = Tk.SUNKEN if self.deriv >= 2 else Tk.RAISED
        self.first_deriv_button.config(relief=first_relief)
        self.second_deriv_button.config(relief=second_relief)
        self.first_deriv_auto_axes = self.second_deriv_auto_axes = True
        self.display()

    def change_plot_quantity(self, plot_quantity=None):
        """Switch the plotted quantity and refresh the controls and plot.

        With no argument the quantity is read from the selector. The
        derivative buttons are enabled only up to the order listed for the
        quantity in ``max_deriv``. All axes are auto-scaled on the next draw.
        """

        if plot_quantity is None:
            plot_quantity = self.selected_plot_quantity.get()
        self.plot_quantity = plot_quantity

        highest_order = self.max_deriv[plot_quantity]
        self.first_deriv_button.config(
            state=Tk.NORMAL if highest_order != 0 else Tk.DISABLED)
        self.second_deriv_button.config(
            state=Tk.NORMAL if highest_order >= 2 else Tk.DISABLED)

        self.main_auto_axes = self.first_deriv_auto_axes = \
            self.second_deriv_auto_axes = True
        self.change_interpolated_quantity()
        self.display()

    def change_interpolated_quantity(self):
        """Re-create the interpolated quantity per the current settings.

        NOTE(review): the body is currently empty, so this method is a no-op
        placeholder. change_plot_quantity() and change_interp() call it, but
        nothing is recomputed yet. TODO implement or remove.
        """

    def change_interp(self, quantity, input_new_value):
        """
        Modify the stellar mass or [Fe/H] for the displayed interpolation.

        Stores the new value, locates the pair of tracks bracketing it,
        narrows the fine scale to that interval, synchronizes both scales and
        schedules a (debounced) redisplay.

        Args:
            - quantity: One of 'mass' or 'metallicity'.
            - input_new_value: The new value to set. Should be convertible
                               to float.

        Returns:
            - None: if the call was ignored because an update of the same
                    quantity is already in progress (a re-entrant call made
                    while this method is moving the scales).
            - True: if the quantity was actually changed.
        """

        if self.changing[quantity]:
            return None

        new_value = float(input_new_value)
        self.interp[quantity] = new_value
        track_quantities = getattr(self, 'track_' + quantity)

        # Index of the last track value not exceeding new_value. The search
        # stops at 0 so that a value below the first track still selects the
        # first interval, instead of wrapping to index -1, which would give
        # the fine scale a reversed range.
        index_below = len(track_quantities) - 2
        while index_below > 0 and track_quantities[index_below] > new_value:
            index_below -= 1
        self.track_below[quantity] = index_below

        # The scales' command callback is this method. Keep the guard raised
        # while we move them ourselves, and always lower it again.
        self.changing[quantity] = True
        try:
            self.fine_interp_scale[quantity].config(
                from_=(track_quantities[index_below] + 1e-5),
                to=(track_quantities[index_below + 1] - 1e-5))

            self.coarse_interp_scale[quantity].set(new_value)
            self.fine_interp_scale[quantity].set(new_value)

            self.change_interpolated_quantity()
        finally:
            self.changing[quantity] = False

        # Debounce: replace any pending redisplay with a fresh one.
        if self.display_job:
            self.window.after_cancel(self.display_job)
        self.display_job = self.window.after(10, self.display)

        return True

    def toggle_all_tracks(self):
        """
        Select or deselect every track at once.

        The 'All' button acts as a toggle. When it is raised, it is sunk and
        relabelled 'None', every mass and [Fe/H] button is sunk and every
        track checkbox variable is set. Otherwise everything is reset to
        raised/unset and the label goes back to 'All'. The plot is redrawn
        afterwards.
        """

        select_all = self.all_tracks_button.cget('relief') == Tk.RAISED
        new_relief = Tk.SUNKEN if select_all else Tk.RAISED

        self.all_tracks_button.config(relief=new_relief,
                                      text='None' if select_all else 'All')

        for button_map in (self.track_mass_button,
                           self.track_metallicity_button):
            for button in button_map.values():
                button.config(relief=new_relief)

        for per_mass_states in self.track_state.values():
            for state in per_mass_states.values():
                state.set(select_all)

        self.display()

    def toggle_track_mass(self, mass):
        """
        Flip the selection of every track sharing the given stellar mass.

        A raised mass button is sunk and all of that mass's track checkbox
        variables are set; a sunken one is raised and they are cleared.

        Args:
            - mass: Key into self.track_mass_button and self.track_state.
        """

        mass_button = self.track_mass_button[mass]
        turn_on = mass_button.cget('relief') == Tk.RAISED
        mass_button.config(relief=Tk.SUNKEN if turn_on else Tk.RAISED)

        for state in self.track_state[mass].values():
            state.set(turn_on)

        self.display()

    def toggle_track_metallicity(self, metallicity):
        """
        Flip the selection of every track sharing the given [Fe/H].

        Args:
            - metallicity: Key into self.track_metallicity_button and into
                           each per-mass dict of self.track_state.
        """

        feh_button = self.track_metallicity_button[metallicity]
        turn_on = feh_button.cget('relief') == Tk.RAISED
        feh_button.config(relief=Tk.SUNKEN if turn_on else Tk.RAISED)

        # Masses without a track at this [Fe/H] are skipped.
        matching_states = (per_mass[metallicity]
                           for per_mass in self.track_state.values()
                           if metallicity in per_mass)
        for state in matching_states:
            state.set(turn_on)

        self.display()

    def auto_axes(self):
        """Return all panels to matplotlib-chosen limits, then redraw."""

        for flag_name in ('main_auto_axes',
                          'first_deriv_auto_axes',
                          'second_deriv_auto_axes'):
            setattr(self, flag_name, True)
        self.display()

    def __init__(self, window, tracks_dir):
        """
        Set up user controls and display frame.

        Args:
            - window: The Tk widget hosting the application. All control
                      frames are gridded into it, and its after() method is
                      used later to schedule redisplays.
            - tracks_dir: Directory handed to read_MESA() to load the
                          tracks.
        """
        def create_main_axes():
            """Create a figure and add an axes to it for drawing."""

            self.main_figure = Figure(figsize=(5, 4), dpi=100)
            self.main_axes = self.main_figure.add_axes((0.05, 0.05, 0.9, 0.9))
            self.first_deriv_axes = None
            self.second_deriv_axes = None

        def create_axes_controls():
            """Create controls for plot quantity and log axes."""
            def create_x_controls():
                """Create the controls for the x axis."""

                x_controls_frame = Tk.Frame(window)
                x_controls_frame.grid(row=3, column=2)

                self.logx_button = Tk.Button(
                    x_controls_frame,
                    text='log10',
                    command=self.toggle_log_x,
                    relief=Tk.SUNKEN if self.logx else Tk.RAISED)
                self.age_transform_entry = Tk.Entry(x_controls_frame)
                self.age_transform_entry.insert(
                    0, 't * (1.0 + (t / 5.0) * m**5 * 10.0**(-0.2*feh))'
                    '* m**2.3 * 10.0**(-0.4*feh)')
                self.logx_button.grid(row=0, column=0)
                self.age_transform_entry.grid(row=0, column=1)

            def create_y_controls():
                """Create the controls for the y axis."""

                y_controls_frame = Tk.Frame(window)
                y_controls_frame.grid(row=1, column=1, rowspan=2)

                self.selected_plot_quantity = Tk.StringVar()
                self.selected_plot_quantity.set(self.plot_quantities[0])
                self.plot_quantity_menu = Tk.OptionMenu(
                    y_controls_frame,
                    self.selected_plot_quantity,
                    *self.plot_quantities,
                    command=self.change_plot_quantity)
                self.logy_button = Tk.Button(
                    y_controls_frame,
                    text='log10',
                    command=self.toggle_log_y,
                    relief=Tk.SUNKEN if self.logy else Tk.RAISED)
                self.first_deriv_button = Tk.Button(y_controls_frame,
                                                    text='d/dt',
                                                    command=functools.partial(
                                                        self.toggle_deriv, 1))
                self.second_deriv_button = Tk.Button(y_controls_frame,
                                                     text='d/dt',
                                                     command=functools.partial(
                                                         self.toggle_deriv, 2))
                self.logy_button.grid(row=0, column=0)
                self.first_deriv_button.grid(row=1, column=0)
                self.second_deriv_button.grid(row=2, column=0)
                self.plot_quantity_menu.grid(row=3, column=0)

            create_x_controls()
            create_y_controls()
            Tk.Button(window,
                      text='Auto Axes',
                      command=self.auto_axes,
                      relief=Tk.RAISED).grid(row=1,
                                             column=3,
                                             sticky=Tk.N + Tk.S + Tk.W + Tk.E)

        def create_interp_controls(interp_control_frame):
            """Create the controls for the interpolation mass & [Fe/H]."""

            self.coarse_interp_scale = dict()
            self.fine_interp_scale = dict()
            for index, quantity in enumerate(['mass', 'metallicity']):
                # Coarse scale spans the full range of available tracks.
                self.coarse_interp_scale[quantity] = Tk.Scale(
                    interp_control_frame,
                    from_=getattr(self, 'track_' + quantity)[0],
                    to=getattr(self, 'track_' + quantity)[-1],
                    resolution=-1,
                    length=1000,
                    orient=Tk.HORIZONTAL,
                    command=functools.partial(self.change_interp, quantity),
                    digits=6)

                # Fine scale range is set later by change_interp().
                self.fine_interp_scale[quantity] = Tk.Scale(
                    interp_control_frame,
                    resolution=-1,
                    length=1000,
                    orient=Tk.HORIZONTAL,
                    command=functools.partial(self.change_interp, quantity),
                    digits=6)

                self.coarse_interp_scale[quantity].grid(row=3 * index,
                                                        column=1)
                self.fine_interp_scale[quantity].grid(row=3 * index + 1,
                                                      column=1)
            Tk.Label(interp_control_frame, text='M*/Msun').grid(row=0,
                                                                column=0,
                                                                rowspan=2)
            Tk.Label(interp_control_frame, text='[Fe/H]').grid(row=3,
                                                               column=0,
                                                               rowspan=2)
            ttk.Separator(interp_control_frame,
                          orient=Tk.HORIZONTAL).grid(row=2,
                                                     column=0,
                                                     columnspan=2,
                                                     sticky="ew")

        def create_track_selectors(track_selectors_frame):
            """Create buttons to enable/disable tracks to display."""

            self.all_tracks_button = Tk.Button(track_selectors_frame,
                                               text='All',
                                               command=self.toggle_all_tracks,
                                               relief=Tk.RAISED)
            self.all_tracks_button.grid(row=0,
                                        column=0,
                                        rowspan=2,
                                        columnspan=2)

            # Layout: mass buttons run down column 1 starting at row 2, and
            # [Fe/H] buttons run along row 1 starting at column 2, with the
            # 'All' button filling the 2x2 top-left corner. Each axis label
            # therefore sits beside its own buttons (spanning them) rather
            # than inside the cells occupied by the 'All' button.
            Tk.Label(track_selectors_frame,
                     text='M*/Msun').grid(row=2,
                                          column=0,
                                          rowspan=max(1, len(self.track_mass)))
            Tk.Label(
                track_selectors_frame,
                text='[Fe/H]',
            ).grid(row=0,
                   column=2,
                   columnspan=max(1, len(self.track_metallicity)))

            self.track_state = dict()
            self.track_mass_button = dict()
            self.track_metallicity_button = dict()
            for row, mass in enumerate(self.track_mass):
                self.track_state[mass] = dict()
                self.track_mass_button[mass] = Tk.Button(
                    track_selectors_frame,
                    text='%.3f' % mass,
                    command=functools.partial(self.toggle_track_mass, mass),
                    relief=Tk.RAISED)
                self.track_mass_button[mass].grid(row=2 + row, column=1)
                for column, metallicity in enumerate(self.track_metallicity):
                    # Create each [Fe/H] header button only once.
                    if row == 0:
                        self.track_metallicity_button[metallicity] = \
                            Tk.Button(
                                track_selectors_frame,
                                text='%.3f' % metallicity,
                                command=functools.partial(
                                    self.toggle_track_metallicity,
                                    metallicity
                                ),
                                relief=Tk.RAISED
                            )
                        self.track_metallicity_button[metallicity].grid(
                            row=1, column=2 + column)
                    self.track_state[mass][metallicity] = Tk.IntVar()
                    Tk.Checkbutton(
                        track_selectors_frame,
                        text='',
                        variable=self.track_state[mass][metallicity],
                        command=self.display).grid(row=2 + row,
                                                   column=2 + column)

        def create_curve_controls(curve_control_frame):
            """Create controls for modifying plotting."""

            Tk.Button(curve_control_frame, text='Replot',
                      command=self.display).grid(row=0, column=0, columnspan=2)
            curve_setup_frame = Tk.Frame(curve_control_frame)
            curve_setup_frame.grid(row=1, column=0)
            Tk.Label(curve_setup_frame, text='Resolution:').grid(row=1,
                                                                 column=0)
            Tk.Entry(curve_setup_frame,
                     textvariable=self.resolution_text).grid(row=1, column=1)
            Tk.Label(curve_setup_frame, text='Track style:').grid(row=2,
                                                                  column=0)
            Tk.Entry(curve_setup_frame,
                     textvariable=self.track_plot_style).grid(row=2, column=1)

            Tk.Label(curve_setup_frame, text='Interp style:').grid(row=3,
                                                                   column=0)
            Tk.Entry(curve_setup_frame,
                     textvariable=self.interp_plot_style).grid(row=3, column=1)

            isolate_frame = Tk.Frame(curve_control_frame)
            isolate_frame.grid(row=1, column=1)
            Tk.Label(isolate_frame, text='Separate Window').grid(row=0,
                                                                 column=0)
            Tk.Button(isolate_frame,
                      text='Main curve',
                      command=functools.partial(self.separate_window,
                                                0)).grid(row=1, column=0)
            Tk.Button(isolate_frame,
                      text='First deriv',
                      command=functools.partial(self.separate_window,
                                                1)).grid(row=2, column=0)
            Tk.Button(isolate_frame,
                      text='Second deriv',
                      command=functools.partial(self.separate_window,
                                                2)).grid(row=3, column=0)

        def create_main_canvas(plot_frame):
            """Create the canvas for plotting undifferentiated quantities."""

            self.main_canvas = FigureCanvasTkAgg(self.main_figure,
                                                 master=plot_frame)
            # draw() is available in every matplotlib version. show() was a
            # deprecated alias of it that later releases removed.
            self.main_canvas.draw()
            self.main_canvas.get_tk_widget().pack(side=Tk.TOP,
                                                  fill=Tk.BOTH,
                                                  expand=1)
            # NOTE(review): newer matplotlib renamed NavigationToolbar2TkAgg
            # to NavigationToolbar2Tk; the import at the top of the file may
            # need updating.
            self.toolbar = NavigationToolbar2TkAgg(self.main_canvas,
                                                   plot_frame)
            self.toolbar.update()
            # Re-pack the canvas widget after creating the toolbar,
            # presumably so that it keeps expanding to fill the frame.
            self.main_canvas.get_tk_widget().pack(side=Tk.TOP,
                                                  fill=Tk.BOTH,
                                                  expand=1)
            self.main_canvas.mpl_connect('key_press_event', self.on_key_event)

        def set_initial_state():
            """Set initial states for member variables."""

            self.main_auto_axes = True
            self.first_deriv_auto_axes = True
            self.second_deriv_auto_axes = True
            self.do_not_display = True
            self.display_job = None
            self.plot_quantities = [
                'Iconv', 'Irad', 'I', 'R', 'Lum', 'Rrad', 'Mrad'
            ]
            # Highest derivative order that may be displayed per quantity.
            self.max_deriv = dict(Iconv=2,
                                  Irad=2,
                                  I=2,
                                  R=1,
                                  Lum=0,
                                  Rrad=1,
                                  Mrad=2)
            self.logx = True
            self.logy = True
            self.interpolate = False
            self.deriv = 0
            self.tracks = read_MESA(tracks_dir)
            self.track_mass = sorted(self.tracks.keys())
            self.track_metallicity = set()
            for mass_tracks in self.tracks.values():
                self.track_metallicity.update(mass_tracks.keys())
            self.track_metallicity = sorted(list(self.track_metallicity))
            self.interp = dict(mass=1.0, metallicity=0.0)
            self.enabled_tracks = [[False for feh in self.track_metallicity]
                                   for m in self.track_mass]

            self.resolution_text = Tk.StringVar()
            self.resolution_text.set('100')
            self.interp_plot_style = Tk.StringVar()
            self.interp_plot_style.set('.r')
            self.track_plot_style = Tk.StringVar()
            self.track_plot_style.set('xk')

            # Re-entrancy guards and bracketing-track indices used by
            # change_interp().
            self.changing = dict(mass=False, metallicity=False)
            self.track_below = dict()

        def configure_window():
            """Arrange the various application elements."""

            self.window = window

            plot_frame = Tk.Frame(window)
            plot_frame.grid(row=1,
                            column=2,
                            rowspan=2,
                            sticky=Tk.N + Tk.S + Tk.W + Tk.E)

            interp_control_frame = Tk.Frame(window)
            interp_control_frame.grid(row=0, column=1, columnspan=3)

            plot_control_frame = Tk.Frame(window)
            plot_control_frame.grid(row=2, column=3)
            track_selectors_frame = Tk.Frame(plot_control_frame)
            track_selectors_frame.grid(row=0, column=0)
            curve_control_frame = Tk.Frame(plot_control_frame)
            curve_control_frame.grid(row=1, column=0)

            # Let the plot cell absorb any extra window space.
            Tk.Grid.columnconfigure(window, 2, weight=1)
            Tk.Grid.rowconfigure(window, 1, weight=1)

            create_main_axes()
            create_main_canvas(plot_frame)
            create_axes_controls()
            create_interp_controls(interp_control_frame)
            create_track_selectors(track_selectors_frame)
            create_curve_controls(curve_control_frame)

            self.interpolate_button = Tk.Button(
                window,
                text='Interpolate',
                command=self.toggle_interpolation,
                relief=Tk.SUNKEN if self.interpolate else Tk.RAISED)
            self.interpolate_button.grid(row=0,
                                         column=0,
                                         sticky=Tk.N + Tk.S + Tk.W + Tk.E)

            interpolator_manager_frame = Tk.Frame(window)
            interpolator_manager_frame.grid(row=1, column=0, rowspan=2)
            # NOTE(review): serialized_interpolator_dir is not defined in
            # this method; presumably a module-level setting - confirm.
            self.interpolator_manager = InterpolatorManagerGUI(
                interpolator_manager_frame, serialized_interpolator_dir)

        set_initial_state()
        configure_window()

        self.change_plot_quantity(self.plot_quantities[0])
        self.change_interp('mass', self.interp['mass'])
        self.change_interp('metallicity', self.interp['metallicity'])
        self.do_not_display = False
        self.display()
Example #52
0
# Attach a canvas to the module-level figure ``fig`` (created earlier in the
# original script); FigureCanvas is presumably the Agg canvas - confirm.
canvas = FigureCanvas(fig)
# Default image aspect to 'auto' so images fill their axes.
matplotlib.rc('image', aspect='auto')


def add_latline(ax, latitude):
    """
    Draw a thin light-grey horizontal line at ``latitude`` on ``ax``.

    Latitude is mapped linearly so that -90 -> 0 and 90 -> 1 on the y axis.
    The line runs from ``start`` to ``end`` (module-level names, presumably
    datetime objects, converted with ``.timestamp()``) and is drawn with
    zorder 200.
    """
    y_frac = (latitude + 90) / 180
    x_span = [start.timestamp(), end.timestamp()]
    grid_line = Line2D(x_span, [y_frac, y_frac],
                       linewidth=0.5,
                       color=(0.8, 0.8, 0.8, 1),
                       zorder=200)
    ax.add_line(grid_line)


# Add a textured grey background
# Noise-field size as (number of x points, number of y points).
s = (2000, 600)
ax2 = fig.add_axes([0, 0.05, 1, 0.95], facecolor='green')
ax2.set_axis_off()
nd2 = numpy.random.rand(s[1], s[0])
# Grey ramp for the noise colormap (linspace default: 50 shades).
clrs = []
for shade in numpy.linspace(.42 + .01, .36 + .01):
    clrs.append((shade, shade, shade, 1))
y = numpy.linspace(0, 1, s[1])
x = numpy.linspace(0, 1, s[0])
img = ax2.pcolormesh(x,
                     y,
                     nd2,
                     cmap=matplotlib.colors.ListedColormap(clrs),
                     alpha=1.0,
                     shading='gouraud',
                     zorder=10)
# NOTE(review): the keyword arguments and closing parenthesis below are
# orphaned - they look like the tail of a call (possibly a
# ``fig = Figure(...)`` constructor) whose opening line is missing, so this
# fragment is not valid Python as-is. TODO restore the missing call head.
    edgecolor=None,
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None,
)
canvas = FigureCanvas(fig)
font = {
    "family": "sans-serif",
    "sans-serif": "Arial",
    "weight": "normal",
    "size": 12
}
matplotlib.rc("font", **font)

# Full-figure axes holding an opaque white base rectangle (zorder 1).
ax_full = fig.add_axes([0, 0, 1, 1])
ax_full.set_axis_off()
ax_full.add_patch(
    Rectangle((0, 0), 1, 1, facecolor=(1, 1, 1, 1), fill=True, zorder=1))

# Load ensemble member 0 of each field via twcr (version "4.6.1") at ``dte``.
t2m = twcr.load("air.2m", dte,
                version="4.6.1").extract(iris.Constraint(member=0))
u10m = twcr.load("uwnd.10m", dte,
                 version="4.6.1").extract(iris.Constraint(member=0))
v10m = twcr.load("vwnd.10m", dte,
                 version="4.6.1").extract(iris.Constraint(member=0))
precip = twcr.load("prate", dte,
                   version="4.6.1").extract(iris.Constraint(member=0))
# normalise_precip is a helper defined elsewhere in the original script.
precip = normalise_precip(precip)
obs = twcr.load_observations_fortime(dte, version="4.6.1")
# prmsl all members for spread
Example #54
0
class FitPlot():

    plt_time = 0
    plt_radius = 0
    fsize3d = 16
    fsize = matplotlib.rcParams['font.size']

    #ts_revisions = []
    edge_discontinuties = []
    core_discontinuties = []

    tbeg = 0
    tend = 7
    picked = False
    grid = False
    logy = False
    m2g = None

    def __init__(self, parent, fit_frame):
        self.parent = parent

        self.fit_frame = fit_frame

        self.rstride = 1
        self.cstride = 3

        self.tstep = None
        self.xlab = ''
        self.ylab = ''
        self.ylab_diff = ''

    def isfloat(self, num):
        """
        Return True if ``num`` is an empty string or is accepted by float().

        The empty string is deliberately accepted, presumably so that an
        input field may be cleared without being flagged as invalid.

        Args:
            num: Value to validate (presumably the text of an entry widget).

        Returns:
            bool: True for '' or anything float() can convert, else False.
        """
        if num == '':
            return True
        try:
            float(num)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures mean "not a float"; a bare except
            # would also swallow KeyboardInterrupt/SystemExit.
            return False
        return True

    def update_axis_range(self, tbeg, tend):
        """
        Apply a new time window to the slider and reset the rho limits.

        Does nothing when either bound is None.

        Args:
            tbeg: New lower time bound.
            tend: New upper time bound.
        """
        if tbeg is not None and tend is not None:
            slider = self.main_slider
            slider.valmin, slider.valmax = tbeg, tend
            self.sl_ax_main.set_xlim(tbeg, tend)
            opts = self.options
            self.ax_main.set_xlim(opts['rho_min'], opts['rho_max'])
            self.tbeg, self.tend = tbeg, tend

    def change_set_prof_load(self):
        """
        Refresh the fit figure after the fitted quantity was changed.

        Pushes the current ``eta``/``lam`` options into their sliders. If
        data are loaded, the plot is rebuilt and the current fit slice
        redrawn; otherwise the main axes are cleared (keeping the grid
        setting) and a redraw is requested.
        """
        for slider, key in ((self.sl_eta, 'eta'), (self.sl_lam, 'lam')):
            slider.set_val(self.options[key])

        if not self.options['data_loaded']:
            self.ax_main.cla()
            self.ax_main.grid(self.grid)
            self.fig.canvas.draw_idle()
            return

        self.init_plot()
        self.changed_fit_slice()

    def init_plot_data(self, prof, data_d, elms, mhd_modes):
        """
        Merge all diagnostics measuring ``prof`` and prepare them for fitting.

        Every channel in ``data_d['data']`` that contains ``prof`` is
        flattened into 1-D arrays (rho, value, error, weight, time, channel,
        point and diagnostic indices). Axis labels are built, the
        ``map2grid`` fitting object is created and the plot is re-initialised.

        Args:
            prof: Name of the profile quantity (key looked up in each
                channel).
            data_d: Dict with ``'data'`` (iterable of channels; presumably
                xarray Datasets given the ``.values``/``.attrs`` usage),
                ``'diag_names'`` and ``'tres'``.
            elms: ELM timing data, indexed below with ``'tvec'``, ``'data'``
                and ``'elm_beg'``.
            mhd_modes: MHD mode information or None.

        Returns:
            None. Returns early, after printing a message, when no channel
            provides ``prof`` data.
        """

        #merge all diagnostics together
        unit, labels = '', []
        data_rho, plot_rho, data, data_err = [], [], [], []
        weights, data_tvec, plot_tvec, diags = [], [], [], []
        #channel and point index for later identification
        ind_channels, ind_points = [], []
        n_ch, n_points = 0, 0
        for ch in data_d['data']:
            if prof not in ch:
                continue
            d = ch[prof].values
            data.append(d)
            err = np.copy(ch[prof + '_err'].values)
            #NOTE negative values are set to be masked
            mask = err <= 0
            #negative infinite ponts will not be shown
            err[np.isfinite(err) & mask] *= -1
            data_err.append(np.ma.array(err, mask=mask))
            data_rho.append(ch['rho'].values)
            data_tvec.append(
                np.tile(ch['time'].values,
                        data_rho[-1].shape[:0:-1] + (1, )).T)
            plot_tvec.append(np.tile(ch['time'].values, d.shape[1:] + (1, )).T)

            s = d.shape
            dch = 1 if len(s) == 1 else s[1]
            ind_channels.append(
                np.tile(np.arange(dch, dtype='uint32') + n_ch, (s[0], 1)))
            n_ch += dch
            ind_points.append(
                np.tile(
                    n_points +
                    np.arange(d.size, dtype='uint32').reshape(d.shape).T,
                    data_rho[-1].shape[d.ndim:] + ((1, ) * d.ndim)).T)
            n_points += d.size

            if 'weight' in ch:
                #non-local measurements
                weights.append(ch['weight'].values)
                plot_rho.append(ch['rho_tg'].values)
            else:
                weights.append(np.ones_like(d))
                plot_rho.append(ch['rho'].values)

            diags.append(ch['diags'].values)
            labels.append(ch[prof].attrs['label'])
            # Take the unit from a channel that actually holds ``prof``
            # (the last such channel wins).
            unit = ch[prof].attrs['units']

        if n_ch == 0 or n_points == 0:
            print('No data !! Try to extend time range')
            return

        diag_names = data_d['diag_names']
        label = ','.join(np.unique(labels))

        self.elms = elms
        self.mhd_modes = mhd_modes

        self.options['data_loaded'] = True
        self.options['fitted'] = False
        rho = np.hstack([r.ravel() for r in data_rho])
        y = np.hstack([d.ravel() for d in data])
        yerr = np.ma.hstack([de.ravel() for de in data_err])

        self.channel = np.hstack([c.ravel() for c in ind_channels])
        points = np.hstack([p.ravel() for p in ind_points])
        weights = np.hstack([w.ravel() for w in weights])
        tvec = np.hstack([t.ravel() for t in data_tvec])
        diags = np.hstack([d.ravel() for d in diags])
        self.plot_tvec = np.hstack([t.ravel() for t in plot_tvec])
        self.plot_rho = np.hstack([r.ravel() for r in plot_rho])
        self.options['rho_min'] = np.minimum(
            0, np.maximum(-1.1, self.plot_rho.min()))
        diag_dict = {d: i for i, d in enumerate(diag_names)}
        self.ind_diag = np.array([diag_dict[d] for d in diags])
        self.diags = diag_names

        if self.parent.elmsphase:
            #epl phase
            self.plot_tvec = np.interp(self.plot_tvec, self.elms['tvec'],
                                       self.elms['data'])
            tvec = np.interp(tvec, self.elms['tvec'], self.elms['data'])

        if self.parent.elmstime:
            #time from nearest ELM
            self.plot_tvec -= self.elms['elm_beg'][
                self.elms['elm_beg'].searchsorted(self.plot_tvec) - 1]
            tvec -= self.elms['elm_beg'][
                self.elms['elm_beg'].searchsorted(tvec) - 1]

        # NOTE(review): when self.tstep is already set, the string 'None' is
        # passed on; presumably set_trange treats it as "keep current".
        tstep = 'None'
        if self.tstep is None:
            tstep = float(data_d['tres'])

        self.parent.set_trange(np.amin(tvec), np.amax(tvec), tstep)

        # NOTE(review): init_plot() formats self.tres * 1e3, so self.tstep
        # must no longer be None here; presumably set_trange updates it.
        self.tres = self.tstep

        #plot timeslice nearest to the original location where are some data
        self.plt_time = tvec[np.argmin(abs(self.plt_time - tvec))]

        self.ylab = r'$%s\ [\mathrm{%s}]$' % (label, unit)
        xlab = self.options['rho_coord'].split('_')
        self.xlab = xlab[0]
        if self.options['rho_coord'][:3] in ['Psi', 'rho']:
            self.xlab = '\\' + self.xlab
        if len(xlab) > 1:
            self.xlab += '_{' + xlab[1] + '}'
        self.xlab = '$' + self.xlab + '$'

        self.ylab_diff = r'$R/L_{%s}/\rho\ [-]$' % label

        #create object of the fitting routine
        self.m2g = map2grid(rho, tvec, y, yerr, points, weights,
                            self.options['nr_new'], self.tstep)
        self.options['fit_prepared'] = False
        self.options['zeroed_outer'] = False
        self.options['elmrem_ind'] = False

        self.init_plot()
        self.changed_fit_slice()

    def init_plot(self):
        #clear and inicialize the main plot with the fits

        self.ax_main.cla()
        self.ax_main.grid(self.grid)

        self.ax_main.ticklabel_format(style='sci', scilimits=(-2, 2), axis='y')

        self.ax_main.set_ylabel(self.ylab, fontsize=self.fsize + 2)
        self.ax_main.set_xlabel(self.xlab, fontsize=self.fsize + 2)

        #the plots inside
        colors = matplotlib.cm.brg(np.linspace(0, 1, len(self.diags)))
        self.plotline, self.caplines, self.barlinecols = [], [], []
        self.replot_plot = [
            self.ax_main.plot([], [], '+', c=c)[0] for c in colors
        ]

        for i, d in enumerate(self.diags):
            plotline, caplines, barlinecols = self.ax_main.errorbar(
                0,
                np.nan,
                0,
                fmt='.',
                capsize=4,
                label=d,
                c=colors[i],
                zorder=1)
            self.plotline.append(plotline)
            self.caplines.append(caplines)
            self.barlinecols.append(barlinecols)

        self.fit_plot, = self.ax_main.plot([], [],
                                           'k-',
                                           linewidth=.5,
                                           zorder=2)
        nr = self.options['nr_new']
        self.fit_confidence = self.ax_main.fill_between(np.arange(nr),
                                                        np.zeros(nr),
                                                        np.zeros(nr),
                                                        alpha=.2,
                                                        facecolor='k',
                                                        edgecolor='None',
                                                        zorder=0)

        self.lcfs_line = self.ax_main.axvline(1, ls='--', c='k', visible=False)
        self.zero_line = self.ax_main.axhline(0, ls='--', c='k', visible=False)

        self.core_discontinuties = [
            self.ax_main.axvline(t, ls='-', lw=.5, c='k', visible=False)
            for t in eval(self.fit_options['sawteeth_times'].get())
        ]
        self.edge_discontinuties = [
            self.ax_main.axvline(t, ls='-', lw=.5, c='k', visible=False)
            for t in self.elms['elm_beg']
        ]

        if self.mhd_modes is not None:
            self.mhd_locations = {
                mode: self.ax_main.axvline(np.nan,
                                           ls='-',
                                           lw=.5,
                                           c='k',
                                           visible=False)
                for mode in self.mhd_modes['modes']
            }
            self.mhd_labels = {
                mode: self.ax_main.text(np.nan, 0, mode, visible=False)
                for mode in self.mhd_modes['modes']
            }

        self.spline_mean, = self.ax_main.plot([], [],
                                              '--',
                                              c='.5',
                                              linewidth=1,
                                              visible=False)
        self.spline_min, = self.ax_main.plot([], [],
                                             ':',
                                             c='.5',
                                             linewidth=1,
                                             visible=False)
        self.spline_max, = self.ax_main.plot([], [],
                                             ':',
                                             c='.5',
                                             linewidth=1,
                                             visible=False)

        leg = self.ax_main.legend(fancybox=True, loc='upper right')
        leg.get_frame().set_alpha(0.9)
        try:
            leg.set_draggable(True)
        except:
            leg.draggable()

        #make the legend interactive
        self.leg_diag_ind = {}
        for idiag, legline in enumerate(leg.legendHandles):
            legline.set_picker(20)  # 20 pts tolerance
            self.leg_diag_ind[legline] = idiag

        description = self.parent.device + ' %d' % self.shot
        self.plot_description = self.ax_main.text(
            1.01,
            .05,
            description,
            rotation='vertical',
            transform=self.ax_main.transAxes,
            verticalalignment='bottom',
            size=10,
            backgroundcolor='none',
            zorder=100)

        title_template = '%s, time: %.3fs'  # prints running simulation time
        self.time_text = self.ax_main.text(.05,
                                           .95,
                                           '',
                                           transform=self.ax_main.transAxes)
        self.chi2_text = self.ax_main.text(.05,
                                           .90,
                                           '',
                                           transform=self.ax_main.transAxes)

        self.fit_options['dt'].set('%.2g' % (self.tres * 1e3))
        self.click_event = {1: None, 2: None, 3: None}

        def line_select_callback(eclick, erelease):
            #'eclick and erelease are the press and release events'

            click_event = self.click_event[eclick.button]
            #make sure that this event is "unique", due to some bug in matplolib
            if click_event is None or click_event.xdata != eclick.xdata:
                self.click_event[eclick.button] = eclick
                x1, y1 = eclick.xdata, eclick.ydata
                x2, y2 = erelease.xdata, erelease.ydata

                #delete/undelet selected points
                undelete = eclick.button == 3
                what = 'channel' if self.ctrl else 'point'
                self.delete_points(eclick, (x1, x2), (y1, y2), what, undelete)

            self.RS_delete.set_visible(False)
            self.RS_undelete.set_visible(False)
            self.RS_delete.visible = True
            self.RS_undelete.visible = True

        rectprops = dict(facecolor='red',
                         edgecolor='red',
                         alpha=0.5,
                         fill=True,
                         zorder=99)

        self.RS_delete = RectangleSelector(
            self.ax_main,
            line_select_callback,
            drawtype='box',
            useblit=True,
            button=[1],  # don't use middle button
            minspanx=5,
            minspany=5,
            rectprops=rectprops,
            spancoords='pixels',
            interactive=True)
        rectprops = dict(facecolor='blue',
                         edgecolor='blue',
                         alpha=0.5,
                         fill=True,
                         zorder=99)

        self.RS_undelete = RectangleSelector(
            self.ax_main,
            line_select_callback,
            drawtype='box',
            useblit=True,
            button=[3],  # don't use middle button
            minspanx=5,
            minspany=5,
            rectprops=rectprops,
            spancoords='pixels',
            interactive=True)

    def changed_fit_slice(self):
        """Switch the main plot to the view selected by ``self.plot_type``.

        View 0 plots the data against time at a fixed radius; views 1 and 2
        plot the data (1) or the gradient (2) against rho at a fixed time.
        The method rebinds the step entry to the matching step variable
        (``fit_options['dt']`` or ``fit_options['dr']``), relabels the slider
        and both axes, sets the x-limits, toggles the reference lines
        (``lcfs_line`` in views 1/2, ``zero_line`` in views 0/2) and redraws
        the plots.
        """
        plot_type = self.plot_type.get()
        label_size = self.fsize + 2

        if plot_type in [1, 2]:
            # x axis is rho, the slider moves in time
            self.view_step.config(textvariable=self.fit_options['dt'])
            self.view_step_lbl.config(text='Plot step [ms]')
            # NOTE(review): this replaces the slider's label attribute with a
            # plain string; it does not change the text drawn next to the slider
            self.main_slider.label = 'Time:'
            self.parent.set_trange(self.tbeg, self.tend)
            self.ax_main.set_xlim(self.options['rho_min'],
                                  self.options['rho_max'])
            self.ax_main.set_xlabel(self.xlab, fontsize=label_size)
            ylabel = self.ylab if plot_type == 1 else self.ylab_diff
            self.ax_main.set_ylabel(ylabel, fontsize=label_size)

        if plot_type in [0]:
            # x axis is time, the slider moves in rho
            self.view_step.config(
                textvariable=self.fit_options.get('dr', 0.02))
            self.view_step_lbl.config(text='Radial step [-]')
            self.main_slider.label = 'Radius:'
            self.main_slider.valmin = self.options['rho_min']
            self.main_slider.valmax = self.options['rho_max']
            self.sl_ax_main.set_xlim(self.options['rho_min'],
                                     self.options['rho_max'])
            self.ax_main.set_xlim(self.tbeg, self.tend)
            self.ax_main.set_xlabel('time [s]', fontsize=label_size)
            # same label size as in the other views
            self.ax_main.set_ylabel(self.ylab, fontsize=label_size)

        self.lcfs_line.set_visible(plot_type in [1, 2])
        self.zero_line.set_visible(plot_type in [0, 2])

        self.updateMainSlider()
        self.PreparePloting()
        self.plot_step()
        self.plot3d(update=True)

    def init_fit_frame(self):
        """Build the widgets of the fit tab inside ``self.fit_frame``.

        Creates the view selector radio buttons (bound to ``self.plot_type``),
        the matplotlib figure ``self.fig`` with the main axes ``self.ax_main``,
        its Tk canvas and navigation toolbar, a legend of the mouse actions,
        the fit / playback / 3D buttons, the plot-step entry, the position
        slider ``self.main_slider`` and the smoothing sliders ``self.sl_eta``
        and ``self.sl_lam``, and connects the canvas event handlers.
        """

        # frame with the view selector above the canvas and the plot below
        fit_frame_up = tk.Frame(self.fit_frame)
        fit_frame_down = tk.LabelFrame(self.fit_frame, relief='groove')
        fit_frame_up.pack(side=tk.TOP, fill=tk.BOTH)
        fit_frame_down.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=tk.Y)

        self.plot_type = tk.IntVar(master=self.fit_frame)
        self.plot_type.set(1)
        r_buttons = 'Radial slice', 'Time slice', 'Gradient'
        for nbutt, butt in enumerate(r_buttons):
            button = tk.Radiobutton(fit_frame_up,
                                    text=butt,
                                    variable=self.plot_type,
                                    command=self.changed_fit_slice,
                                    value=nbutt)
            button.pack(anchor='w', side=tk.LEFT, pady=2, padx=2)

        # canvas frame

        self.fig = Figure(figsize=(10, 10), dpi=75)
        self.fig.patch.set_facecolor((.93, .93, .93))
        self.ax_main = self.fig.add_subplot(111)

        self.canvasMPL = tkagg.FigureCanvasTkAgg(self.fig,
                                                 master=fit_frame_down)
        self.toolbar = NavigationToolbar2Tk(self.canvasMPL, fit_frame_down)

        def print_figure(filename, **kwargs):
            # Replacement for the canvas' print_figure: unless the caller
            # passes bbox_inches, crop the saved image to the tight bbox of
            # the main axes (+0.3 inch margin) so the sliders are left out.
            if 'bbox_inches' not in kwargs:
                fig = self.ax_main.figure
                extent = self.ax_main.get_tightbbox(
                    fig.canvas.renderer).transformed(
                        fig.dpi_scale_trans.inverted())
                extent.y1 += .3
                extent.x1 += .3
                kwargs['bbox_inches'] = extent
            self.canvas_print_figure(filename, **kwargs)

        self.canvas_print_figure = self.toolbar.canvas.print_figure
        self.toolbar.canvas.print_figure = print_figure

        # get_tk_widget() is the Tk canvas widget; packing it once is enough
        self.canvasMPL.get_tk_widget().pack(side=tk.TOP,
                                            fill=tk.BOTH,
                                            expand=1)

        self.lcfs_line = self.ax_main.axvline(1, ls='--', c='k', visible=False)
        self.zero_line = self.ax_main.axhline(0, ls='--', c='k', visible=False)

        hbox1 = tk.Frame(fit_frame_down)
        hbox1.pack(side=tk.BOTTOM, fill=tk.X)
        mouse_help = tk.Label(hbox1, text='Mouse: ')
        mouse_left = tk.Label(hbox1,
                              text='Left (+Ctrl): del point (Channel)  ',
                              fg="#900000")
        mouse_mid = tk.Label(hbox1, text='Mid: re-fit  ', fg="#009000")
        mouse_right = tk.Label(hbox1,
                               text='Right: undelete point  ',
                               fg="#000090")
        mouse_wheel = tk.Label(hbox1, text='Wheel: shift', fg="#905090")

        for w in (mouse_help, mouse_left, mouse_mid, mouse_right, mouse_wheel):
            w.pack(side=tk.LEFT)

        hbox2 = tk.Frame(fit_frame_down)
        hbox2.pack(side=tk.BOTTOM, fill=tk.X)

        helv36 = tkinter.font.Font(family='Helvetica', size=10, weight='bold')
        calc_button = tk.Button(hbox2,
                                text='Fit',
                                bg='red',
                                command=self.calculate,
                                font=helv36)
        calc_button.pack(side=tk.LEFT)

        self.playfig = tk.PhotoImage(file=icon_dir + 'play.gif',
                                     master=self.fit_frame)
        self.pausefig = tk.PhotoImage(file=icon_dir + 'pause.gif',
                                      master=self.fit_frame)
        self.forwardfig = tk.PhotoImage(file=icon_dir + 'forward.gif',
                                        master=self.fit_frame)
        self.backwardfig = tk.PhotoImage(file=icon_dir + 'backward.gif',
                                         master=self.fit_frame)

        self.backward_button = tk.Button(hbox2,
                                         command=self.Backward,
                                         image=self.backwardfig)
        self.backward_button.pack(side=tk.LEFT)
        self.play_button = tk.Button(hbox2,
                                     command=self.Play,
                                     image=self.playfig)
        self.play_button.pack(side=tk.LEFT)
        self.forward_button = tk.Button(hbox2,
                                        command=self.Forward,
                                        image=self.forwardfig)
        self.forward_button.pack(side=tk.LEFT)

        self.button_3d = tk.Button(hbox2,
                                   command=self.plot3d,
                                   text='3D',
                                   font=helv36)
        self.button_3d.pack(side=tk.LEFT)

        self.stop = True
        self.ctrl = False
        self.shift = False

        def stop_handler(event=None, self=self):
            # pressing a step button (or moving a smoothing slider) ends Play()
            self.stop = True

        self.forward_button.bind("<Button-1>", stop_handler)
        self.backward_button.bind("<Button-1>", stop_handler)

        vcmd = hbox2.register(self.isfloat)

        self.view_step = tk.Entry(hbox2,
                                  width=4,
                                  validate="key",
                                  validatecommand=(vcmd, '%P'),
                                  justify=tk.CENTER)
        self.view_step_lbl = tk.Label(hbox2, text='Plot step [ms]')
        self.view_step.pack(side=tk.RIGHT, padx=10)
        self.view_step_lbl.pack(side=tk.RIGHT)

        axcolor = 'lightgoldenrodyellow'

        self.fig.subplots_adjust(left=.10,
                                 bottom=.20,
                                 right=.95,
                                 top=.95,
                                 hspace=.1,
                                 wspace=0)
        from numpy.lib import NumpyVersion
        # 'axisbg' was deprecated in favour of 'facecolor' in matplotlib
        # 2.0.0, so 2.0.0 itself already takes 'facecolor' (hence >=)
        kargs = {
            'facecolor': axcolor
        } if NumpyVersion(
            matplotlib.__version__) >= NumpyVersion('2.0.0') else {
                'axisbg': axcolor
            }
        self.sl_ax_main = self.fig.add_axes([.1, .10, .8, .03], **kargs)
        self.main_slider = Slider(self.sl_ax_main,
                                  '',
                                  self.tbeg,
                                  self.tend,
                                  valinit=self.tbeg)

        sl_ax = self.fig.add_axes([.1, .03, .35, .03], **kargs)
        self.sl_eta = Slider(sl_ax, '', 0, 1, valinit=self.options['eta'])

        sl_ax2 = self.fig.add_axes([.55, .03, .35, .03], **kargs)
        self.sl_lam = Slider(sl_ax2, '', 0, 1, valinit=self.options['lam'])

        self.fig.text(.1, .075, 'Time smoothing -->:')
        self.fig.text(.55, .075, 'Radial smoothing -->:')

        createToolTip(self.forward_button, 'Go forward by one step')
        createToolTip(self.backward_button, 'Go backward by one step')
        createToolTip(self.play_button,
                      'Go step by step forward, pause by second press')
        createToolTip(
            self.view_step,
            'Plotting time/radial step, this option influences only the plotting, not fitting!'
        )

        createToolTip(calc_button, 'Calculate the 2d fit of the data')

        def update_eta(eta):
            self.options['eta'] = eta
            stop_handler()

        def update_lam(lam):
            stop_handler()
            self.options['lam'] = lam

        def update_slider(val):
            # slider moved by the user: store the new time (views 1/2) or
            # radius (view 0) and redraw
            try:
                if self.plot_type.get() in [1, 2]:
                    self.plt_time = val
                if self.plot_type.get() in [0]:
                    self.plt_radius = val

                self.updateMainSlider()
                self.plot_step()
            except Exception:
                print(
                    '!!!!!!!!!!!!!!!!!!!!!main_slider error!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                )
                raise

        self.main_slider.on_changed(update_slider)

        self.cid1 = self.fig.canvas.mpl_connect('button_press_event',
                                                self.MouseInteraction)
        self.cid2 = self.fig.canvas.mpl_connect('scroll_event',
                                                self.WheelInteraction)
        self.cid3 = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
        self.cid4 = self.fig.canvas.mpl_connect('key_release_event',
                                                self.off_key)
        # give the canvas keyboard focus on click so that key events reach
        # on_key/off_key
        self.cid5 = self.fig.canvas.mpl_connect(
            'button_press_event',
            lambda event: self.fig.canvas.get_tk_widget().focus_set())
        self.cid6 = self.fig.canvas.mpl_connect('pick_event', self.legend_pick)

        self.sl_eta.on_changed(update_eta)
        self.sl_lam.on_changed(update_lam)

    def calculate(self):
        """Fit the loaded profile data in 2D and redraw the plots.

        Steps:

        * Abort with an error dialog when no shot number is set.
        * Load the data through ``self.parent.init_data()`` when ``self.m2g``
          is missing.
        * Reload the ELM signal when the selected one changed.
        * Prepare the fit (discontinuities, transformation, masking of
          ELM-affected and outer points) unless ``options['fit_prepared']``
          is already set.
        * Call ``self.m2g.Calculate(lam, eta)`` with the smoothing slider
          values and redraw the main and 3D plots.

        The busy cursor is always restored, even when the fit fails.

        Raises:
            Exception: if the data could not be loaded.
        """
        if self.shot is None:
            print('Set shot number first!')
            tkinter.messagebox.showerror('Missing shot number',
                                         'Shot number is not defined')
            # nothing can be loaded or fitted without a shot number
            return

        if self.m2g is None:
            self.parent.init_data()

            if self.m2g is None:
                raise Exception('Loading of data failed!')

        if self.elms['signal'] != self.fit_options['elm_signal'].get():
            print('Load new ELM signal ' +
                  self.fit_options['elm_signal'].get())
            self.elms = self.parent.data_loader(
                'elms', {'elm_signal': self.fit_options['elm_signal']})
            # remove the markers of the previous ELM signal; otherwise they
            # stay on the axes (possibly visible) without being tracked
            for line in getattr(self, 'edge_discontinuties', []):
                try:
                    line.remove()
                except (ValueError, NotImplementedError):
                    pass  # already detached from the axes
            self.edge_discontinuties = [
                self.ax_main.axvline(t, ls='-', lw=.2, c='k', visible=False)
                for t in self.elms['elm_beg']
            ]

        sys.stdout.write('  * Fitting  ... \t ')
        sys.stdout.flush()
        T = time.time()
        self.fit_frame.config(cursor="watch")
        self.fit_frame.update()
        try:
            self.saved_profiles = False

            # NOTE(review): eval() executes arbitrary text typed into the GUI;
            # ast.literal_eval would do if only a list of times is expected.
            sawteeth = eval(self.fit_options['sawteeth_times'].get()
                            ) if self.fit_options['sawteeth'].get() else []

            elmsync = self.fit_options['elmsync'].get()
            elms = self.elms['elm_beg'] if elmsync else []
            elm_phase = (self.elms['tvec'],
                         self.elms['data']) if elmsync else None

            robust_fit = self.fit_options['robustfit'].get()
            zeroedge = self.fit_options['zeroedge'].get()

            pedestal_rho = float(self.fit_options['pedestal_rho'].get())

            # functions for the profile transformation
            transform = transformations[
                self.fit_options['transformation'].get()]

            even_fun = self.options['rho_coord'] not in ['Psi', 'Psi_N']

            if self.m2g is None:
                return

            if not self.options['fit_prepared']:
                self.m2g.PrepareCalculation(zero_edge=zeroedge,
                                            core_discontinuties=sawteeth,
                                            edge_discontinuties=elms,
                                            transformation=transform,
                                            pedestal_rho=pedestal_rho,
                                            robust_fit=robust_fit,
                                            elm_phase=elm_phase,
                                            even_fun=even_fun)
                self.m2g.corrected = False

                # mask (or unmask again) edge points affected by ELMs
                if self.fit_options['elmrem'].get() and len(
                        self.elms['tvec']) > 2:
                    elm_phase = np.interp(self.plot_tvec, self.elms['tvec'],
                                          self.elms['data'])
                    # also remove data shortly before an ELM
                    elm_ind = (elm_phase < 0.1) | ((elm_phase > 0.95) &
                                                   (elm_phase < 1))
                    self.options['elmrem_ind'] = (self.plot_rho > .8) & elm_ind
                    self.m2g.Yerr.mask |= self.options['elmrem_ind']
                elif np.any(self.options['elmrem_ind']):
                    try:
                        self.m2g.Yerr.mask[self.options['elmrem_ind']] = False
                    except Exception:
                        print(
                            'np.shape(self.options[elmrem_inds]),self.m2g.Yerr.mask.shape ',
                            np.shape(self.options['elmrem_ind']),
                            self.m2g.Yerr.mask.shape, self.plot_rho.shape,
                            self.plot_tvec.shape)

                    self.options['elmrem_ind'] = False

                # mask (or unmask again) points in the outer region
                zeroed_outer = self.options['zeroed_outer']
                if self.fit_options['null_outer'].get():
                    if np.any(zeroed_outer):
                        self.m2g.Yerr.mask[zeroed_outer] = False
                    rho_lim = float(self.fit_options['outside_rho'].get())
                    zeroed_outer = self.plot_rho > rho_lim
                    self.m2g.Yerr.mask |= zeroed_outer
                elif np.any(zeroed_outer):
                    self.m2g.Yerr.mask[zeroed_outer] = False
                    zeroed_outer = False
                self.options['zeroed_outer'] = zeroed_outer
                self.options['fit_prepared'] = True

            lam = self.sl_lam.val
            eta = self.sl_eta.val
            self.m2g.Calculate(lam, eta)

            print('\t done in %.1fs' % (time.time() - T))

            self.options['fitted'] = True
            # raw string: '\c' is not a valid escape sequence
            self.chi2_text.set_text(r'$\chi^2/doF$: %.2f' % self.m2g.chi2)

            self.plot_step()
            self.plot3d(update=True)
        finally:
            # restore the normal cursor even when the fit or redraw fails
            self.fit_frame.config(cursor="")

    def Pause(self):
        """Stop the animation and turn the pause button back into a play button."""
        self.stop = True
        button = self.play_button
        button['image'], button['command'] = self.playfig, self.Play

    def Forward(self, mult=1):
        """Move the slider position by ``mult`` plot steps and redraw.

        In views 1 and 2 the time ``plt_time`` moves by the step entry
        ``fit_options['dt']`` (milliseconds). In view 0 the radius
        ``plt_radius`` moves by ``fit_options['dr']``. While Ctrl is held the
        step is five times smaller. Nothing happens when the relevant step
        entry is empty or not a number.

        Args:
            mult: number of steps to move; negative values move backwards.
        """
        plot_type = self.plot_type.get()
        # changed_fit_slice binds the step entry to 'dr' in view 0 and to
        # 'dt' otherwise, so validate the step that is actually used
        key = 'dr' if plot_type in [0] else 'dt'
        try:
            raw = self.fit_options[key].get()
        except Exception:
            return

        if raw == '' or not self.isfloat(raw):
            return
        try:
            step = float(raw)
        except (TypeError, ValueError):
            return  # e.g. partial input accepted by the entry validator

        if self.ctrl:
            mult /= 5

        if plot_type in [1, 2]:
            self.plt_time += step / 1e3 * mult
        if plot_type in [0]:
            self.plt_radius += step * mult

        self.updateMainSlider()
        self.plot_step()

    def Backward(self):
        """Move the slider position one plot step back (see ``Forward``)."""
        self.Forward(mult=-1)

    def Play(self):
        """Animate the main plot by stepping the slider position repeatedly.

        Turns the play button into a pause button, then loops:

        * Advance ``plt_radius`` by ``fit_options['dr']`` (view 0) or
          ``plt_time`` by ``fit_options['dt']`` milliseconds (views 1 and 2).
        * Schedule ``plot_step`` with ``after``.
        * Call the canvas widget's ``update()`` so pending Tk events (e.g. the
          pause/step buttons) are processed.

        The loop ends once ``self.stop`` is set. This also happens when
        ``plt_time`` leaves [m2g.t_min, m2g.t_max] or ``plt_radius`` leaves
        [rho_min, rho_max]; both ranges are checked in every view. The button
        is then restored by ``Pause``.
        """
        #animated plot
        self.stop = False
        self.play_button['image'] = self.pausefig
        self.play_button['command'] = self.Pause
        try:
            dt = float(self.fit_options['dt'].get()) / 1e3
        except:
            # fall back to a 10 ms step when the entry is not a number
            print('Invalid time step value!')
            dt = .01

        try:
            dr = float(self.fit_options['dr'].get())
        except:
            # fall back to a 0.05 radial step when the entry is not a number
            print('Invalid radial step value!')
            dr = .05

        while True:

            if self.plot_type.get() in [0]:
                self.plt_radius += dr

            if self.plot_type.get() in [1, 2]:
                self.plt_time += dt

            # stop at the end of the data range
            if not (self.m2g.t_min <= self.plt_time <= self.m2g.t_max):
                self.stop = True
            if not (self.options['rho_min'] <= self.plt_radius <=
                    self.options['rho_max']):
                self.stop = True

            # redraw asynchronously, then let Tk process pending events
            self.fit_frame.after(1, self.plot_step)
            time.sleep(1e-3)
            try:
                self.canvasMPL.get_tk_widget().update()
            except:
                # presumably the widget was destroyed (window closed); leave
                # without restoring the button -- TODO confirm
                return

            self.updateMainSlider()

            if self.stop:
                break

        self.Pause()

    def PreparePloting(self):
        """Set the title and the y-limits of the main plot for the current view.

        How the limits are chosen:

        * Data views (plot_type 0/1): the 0.1 %-99.5 % quantiles of the
          unmasked points whose value exceeds their error.
        * Gradient view (plot_type 2), once the fit is prepared: the 2 %-98 %
          quantiles of ``m2g.K`` inside rho < 0.8, with the upper limit
          doubled. Before that, a fixed 0-10 guess is used.
        * Fallback: 0-1 (0-10 in the gradient view) when no valid values
          exist, e.g. after all points were deleted.

        The lower limit is capped at 0 and, when negative, padded by 10 % of
        the range. The upper limit is padded by 20 %.
        """
        if hasattr(self.parent, 'BRIEF'):
            self.ax_main.set_title(self.parent.BRIEF)

        def finite_quantiles(values, probs):
            # mquantiles returns a fully masked result for empty input;
            # report that (or any non-finite limit) as None
            lims = np.ma.filled(mquantiles(values, probs), np.nan)
            if np.all(np.isfinite(lims)):
                return lims[0], lims[1]
            return None

        plot_type = self.plot_type.get()
        data_loaded = self.options['data_loaded']

        minlim, maxlim = 0, 1
        if plot_type in [0, 1] and data_loaded:
            valid = ~self.m2g.Yerr.mask & (self.m2g.Y.data >
                                           self.m2g.Yerr.data)
            lims = finite_quantiles(self.m2g.Y[valid], [.001, .995])
            if lims is not None:
                minlim, maxlim = lims
        if plot_type in [2] and data_loaded and self.m2g.prepared:
            lims = finite_quantiles(self.m2g.K[self.m2g.g_r < .8], [.02, .98])
            if lims is not None:
                minlim, maxlim = lims[0], lims[1] * 2
            else:
                minlim, maxlim = 0, 10
        elif plot_type in [2]:
            minlim, maxlim = 0, 10  #just guess of the range

        minlim = min(0, minlim)
        if minlim != 0:
            minlim -= (maxlim - minlim) * .1
        maxlim += (maxlim - minlim) * .2
        self.ax_main.set_ylim(minlim, maxlim)

    def updateMainSlider(self):
        """Clamp the current position to its range and redraw the slider by hand.

        View 0 clamps ``plt_radius`` to [rho_min, rho_max]. Views 1/2 clamp
        ``plt_time`` to [tbeg, tend]. The slider value, bar polygon and value
        text are then set directly, without calling ``Slider.set_val``, so
        that the ``on_changed`` callbacks are not fired.
        """
        mode = self.plot_type.get()
        if mode in [0]:
            lo, hi = self.options['rho_min'], self.options['rho_max']
            self.plt_radius = min(max(self.plt_radius, lo), hi)
            val = self.plt_radius
        if mode in [1, 2]:
            self.plt_time = min(max(self.plt_time, self.tbeg), self.tend)
            val = self.plt_time

        slider = self.main_slider
        slider.val = val
        # move the right edge of the filled bar to the new value
        xy = slider.poly.get_xy()
        xy[2:4, 0] = val
        slider.poly.set_xy(xy)
        slider.valtext.set_text('%.3f' % val)

    def plot_step(self):
        """Redraw the main plot at the current slider position.

        ``self.plot_type`` selects the view:

        * 0: data against time at radius ``self.plt_radius``.
        * 1: data against rho at time ``self.plt_time``.
        * 2: the fitted gradient profile (``m2g.K``) against rho at that time.

        Updated elements:

        * Points and error bars of every diagnostic (views 0/1 only), for
          points within half a plot step of the slider position. Masked
          points are drawn with infinite errors.
        * The fitted profile and its confidence band.
        * Discontinuity markers (view 0) and MHD mode locations.
        * The optional spline profiles provided by the parent.
        """
        if not self.options['data_loaded']:
            return

        t = self.plt_time
        r = self.plt_radius
        try:
            dt = float(self.fit_options['dt'].get()) / 1e3
        except Exception:
            print('Invalid time step value!')
            dt = .01
        try:
            dr = float(self.fit_options['dr'].get())
        except Exception:
            # an invalid entry must not abort the whole redraw
            print('Invalid radial step value!')
            dr = .05
        plot_type = self.plot_type.get()
        kinprof = self.parent.kin_prof

        if plot_type in [0]:
            self.time_text.set_text('rho: %.3f' % r)
            self.select = abs(self.plot_rho - r) <= abs(dr) / 2
            X = self.plot_tvec

        if plot_type in [1, 2]:
            self.time_text.set_text('time: %.4fs' % t)
            self.select = abs(self.plot_tvec - t) <= abs(dt) / 2
            X = self.plot_rho

        # measured data with error bars, one set of artists per diagnostic
        for idiag, _diag in enumerate(self.diags):
            dind = self.select & (self.ind_diag
                                  == idiag) & (self.m2g.Yerr.data > 0)
            show = plot_type in [0, 1] and bool(np.any(dind))

            self.replot_plot[idiag].set_visible(show)
            self.plotline[idiag].set_visible(show)
            for c in self.caplines[idiag]:
                c.set_visible(show)
            self.barlinecols[idiag][0].set_visible(show)
            if not show:
                continue

            x = X[dind]
            y = self.m2g.Y[dind]

            yerr = self.m2g.Yerr.data[dind]
            # masked (deleted) points get infinite error bars
            yerr[self.m2g.Yerr.mask[dind]] = np.inf

            self.replot_plot[idiag].set_data(x, self.m2g.retro_f[dind])
            self.plotline[idiag].set_data(x, y)

            lower, upper = y - yerr, y + yerr
            self.caplines[idiag][0].set_data(x, lower)
            self.caplines[idiag][1].set_data(x, upper)
            self.barlinecols[idiag][0].set_segments(
                list(zip(zip(x, lower), zip(x, upper))))

        # fitted profile with uncertainty; identical for all diagnostics,
        # so it is interpolated once per redraw
        if self.options['fitted'] and hasattr(self.m2g, 'g_t'):
            times = self.m2g.g_t[:, 0]
            rhos = self.m2g.g_r[0]
            if plot_type == 0:  # profile against time at radius r
                profiles = self.m2g.g.T, self.m2g.g_d.T, self.m2g.g_u.T
                grid, abscissa, pos = rhos, times, r
            elif plot_type == 1:  # profile against rho at time t
                profiles = self.m2g.g, self.m2g.g_d, self.m2g.g_u
                grid, abscissa, pos = times, rhos, t
            elif plot_type == 2:  # gradient profile against rho at time t
                profiles = self.m2g.K, self.m2g.Kerr_d, self.m2g.Kerr_u
                grid, abscissa, pos = times, rhos, t

            single_time = self.m2g.g_t.shape[0] == 1
            prof = []
            for d in profiles:
                if single_time:
                    prof.append(d[0])
                else:
                    # interpolate the 2D fit at the slider position along grid
                    interp = interp1d(grid,
                                      d,
                                      axis=0,
                                      copy=False,
                                      assume_sorted=True)
                    prof.append(interp(np.clip(pos, grid[0], grid[-1])))

            self.fit_plot.set_data(abscissa, prof[0])
            self.fit_confidence = update_fill_between(
                self.fit_confidence, abscissa, prof[1], prof[2], -np.inf,
                np.inf)

        # show discontinuities in time
        for d in self.core_discontinuties:
            d.set_visible(r < .3 and plot_type == 0)
        for d in self.edge_discontinuties:
            d.set_visible(r > .7 and plot_type == 0)

        # MHD mode locations (radial views of omega/Ti only)
        if self.mhd_modes is not None:
            for name, rho_loc in self.mhd_modes['modes'].items():
                loc = np.nan
                if plot_type in [1, 2] and kinprof in ['omega', 'Ti']:
                    it = np.argmin(abs(self.mhd_modes['tvec'] - t))
                    loc = rho_loc[it]

                if np.isfinite(loc):
                    self.mhd_locations[name].set_data([loc, loc], [0, 1])
                    self.mhd_labels[name].set_x(loc)
                    self.mhd_locations[name].set_visible(True)
                    self.mhd_labels[name].set_visible(True)
                else:
                    self.mhd_locations[name].set_visible(False)
                    self.mhd_labels[name].set_visible(False)

        # show also zipfit profiles
        # BUG how to avoid access of parent class?
        show_splines = self.parent.show_splines.get() == 1

        if show_splines and kinprof in self.parent.splines:
            splines = self.parent.splines[kinprof]

            y = splines['time'].values
            x = splines['rho'].values
            z = splines[kinprof].values
            ze = z * 0

            if kinprof + '_err' in splines:
                ze = splines[kinprof + '_err'].values

            y0 = t

            if plot_type == 0:  # temporal profiles
                z, ze, x = z.T, ze.T, x.T
                y, x = x, y
                y0 = r

            if np.ndim(y) == 2:
                y = y.mean(1)

            i = np.argmin(np.abs(y - y0))

            if np.ndim(x) == 2:
                x = x[i]

            z = z[i]
            ze = ze[i]

            if plot_type in [2]:
                # BUG: hard-coded minor radius a0 and major radius R0
                a0 = 0.6
                R0 = 1.7
                z_ = (z[1:] + z[:-1]) / 2
                x_ = (x[1:] + x[:-1]) / 2
                z = -(np.diff(z) / np.diff(x * a0) * R0 /
                      x_)[z_ != 0] / z_[z_ != 0]
                ze = 0
                x = x_[z_ != 0]

            self.spline_mean.set_data(x, z)
            self.spline_min.set_data(x, z - ze)
            self.spline_max.set_data(x, z + ze)

        self.spline_mean.set_visible(show_splines)
        self.spline_min.set_visible(show_splines)
        self.spline_max.set_visible(show_splines)

        self.fig.canvas.draw_idle()

    def plot3d(self, update=False):
        """Draw the fitted 2D profile as a wireframe in the '3D plot' figure.

        Does nothing until a fit exists. An already open '3D plot' figure is
        refreshed by replacing the previous wireframe.

        Args:
            update: if True, only refresh an already open 3D figure; if False,
                also create the figure when it is not open.
        """
        if not self.options['fitted']:
            return

        if plt.fignum_exists('3D plot'):
            try:
                ax = self.fig_3d.gca()
                # Artist.remove() works on all matplotlib versions, unlike
                # ax.collections.remove() (read-only list since mpl 3.7)
                self.wframe.remove()
            except (AttributeError, ValueError, NotImplementedError):
                # stale/missing figure or wireframe handle: skip the update
                return
        elif not update:
            self.fig_3d = plt.figure('3D plot')
            # Axes3D(fig) no longer adds itself to the figure (mpl >= 3.6)
            ax = self.fig_3d.add_subplot(111, projection='3d')
            ax.set_xlabel(self.xlab, fontsize=self.fsize3d)
            ax.set_ylabel('Time [s]', fontsize=self.fsize3d)
        else:
            return

        ax.set_zlabel(self.ylab, fontsize=self.fsize3d)
        self.wframe = ax.plot_wireframe(self.m2g.g_r,
                                        self.m2g.g_t,
                                        self.m2g.g,
                                        linewidth=.3,
                                        rstride=self.rstride,
                                        cstride=self.cstride)

        self.fig_3d.show()

    def MouseInteraction(self, event):
        """Dispatch a mouse click on the canvas.

        Button mapping:

        * Left: delete point (with Ctrl: the whole channel).
        * Middle: re-fit.
        * Right: undelete point (with Ctrl: the whole channel).

        A click that already triggered ``legend_pick`` (``self.picked``) is
        consumed without further action.
        """
        if self.picked:
            self.picked = False
            return

        button = event.button
        if button == 2:
            self.calculate()
            return
        if button == 1:
            handler = self.delete_channel if self.ctrl else self.delete_point
        elif button == 3:
            handler = self.undelete_channel if self.ctrl else self.undelete_point
        else:
            return
        handler(event)

    def legend_pick(self, event):
        """Mask or unmask all points of a diagnostic via its legend entry.

        A left double-click on the entry masks the points; a right
        double-click unmasks them. The method sets ``self.picked`` so that
        ``MouseInteraction`` ignores the same click.
        """
        mouse = event.mouseevent
        if not mouse.dblclick:
            return

        if mouse.button == 1:
            undelete = False
        elif mouse.button == 3:
            undelete = True
        else:
            return

        if event.artist not in self.leg_diag_ind:
            return

        i_diag = self.leg_diag_ind[event.artist]

        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0
        ind = np.isin(self.ind_diag, i_diag)

        self.m2g.Yerr.mask[ind] = not undelete

        self.plot_step()
        self.m2g.corrected = False  #force the upgrade
        self.picked = True

    def WheelInteraction(self, event):
        """Step the slider position by the scroll amount (truncated to an int)."""
        steps = int(event.step)
        self.Forward(steps)

    def delete_channel(self, event):
        """Mask every point of the channel nearest to a double-click."""
        self.delete_point(event, what='channel')

    def undelete_channel(self, event):
        """Unmask every point of the channel nearest to a double-click."""
        self.delete_point(event, what='channel', undelete=True)

    def undelete_point(self, event):
        """Unmask the masked point nearest to a double-click."""
        self.delete_point(event, what='point', undelete=True)

    def delete_point(self, event, what='point', undelete=False):
        """Forward a double-click at (xdata, ydata) to ``delete_points``.

        Single clicks are ignored.
        """
        if event.dblclick:
            self.delete_points(event,
                               event.xdata,
                               event.ydata,
                               what=what,
                               undelete=undelete)

    def delete_points(self, event, xc, yc, what='point', undelete=False):
        """Mask (or unmask) data points selected in the main plot.

        Only the points currently shown (``self.select``) in the opposite
        state are candidates. The method works in views 0 and 1 only.

        Args:
            event: matplotlib mouse event; ignored unless it occurred in
                ``self.ax_main``.
            xc, yc: scalar click coordinates, selecting the nearest point in
                axis-normalised distance; or pairs of coordinates spanning a
                rectangle, selecting all points inside it.
            what: 'point' affects the selected points only; 'channel' affects
                all points of their channel(s); 'diagnostic' affects all points
                of their diagnostic(s). Other values are reported and ignored.
            undelete: if True, unmask instead of mask.
        """
        if self.ax_main != event.inaxes or not self.options['data_loaded']:
            return
        if what not in ('point', 'channel', 'diagnostic'):
            print('Removing of "%s" is not supported' % (str(what)))
            return

        mask = self.m2g.Yerr.mask
        affected = self.select & (mask if undelete else ~mask)
        if not np.any(affected):
            return

        plot_type = self.plot_type.get()
        if plot_type == 1:
            x = self.plot_rho[affected]
        elif plot_type == 0:
            x = self.plot_tvec[affected]
        else:
            return  # no point selection in the gradient view

        y = self.m2g.Y[affected]
        if np.size(xc) == 1:
            # distance normalised by the visible axis ranges
            sx = np.ptp(self.ax_main.get_xlim())
            sy = np.ptp(self.ax_main.get_ylim())

            dist = np.hypot((x - xc) / sx, (y - yc) / sy)
            selected = np.argmin(dist)
        else:
            selected = (x >= min(xc)) & (x <= max(xc)) & (y >= min(yc)) & (
                y <= max(yc))
            if not np.any(selected):
                return

        ind = np.where(affected)[0][selected]

        action = 'recovered ' if undelete else 'deleted'

        if what == 'channel':
            ch = np.unique(self.channel[ind])
            ind = np.isin(self.channel, ch)
            print('Channel %s was ' % ch + action)
        elif what == 'diagnostic':
            i_diag = self.ind_diag[ind]
            ind = np.isin(self.ind_diag, i_diag)
            print('diagnostic %s was ' % i_diag + action)

        self.m2g.Yerr.mask[ind] = not undelete

        self.plot_step()
        self.m2g.corrected = False  #force the upgrade

    def on_key(self, event):
        """Dispatch a key-press event to the matching viewer action.

        * ``control`` (only once ``RS_delete`` exists): set ``self.ctrl`` and
          clear the key recorded on any pending ``eventpress`` of
          ``RS_delete`` and ``RS_undelete``.
        * ``shift``: set ``self.shift``.
        * ``left`` / ``right``: step backward / forward.
        * ``g``: toggle the grid of the main axes and redraw.
        * ``l``: toggle a logarithmic y axis (a non-positive lower y-limit is
          lifted to 1 before switching to log) and redraw.
        * space: play when stopped, otherwise pause.
        """
        key = event.key
        if key == 'control':
            if hasattr(self, 'RS_delete'):
                self.ctrl = True
                for attr in ('RS_delete', 'RS_undelete'):
                    press = getattr(self, attr).eventpress
                    if press is not None:
                        press.key = None
        elif key == 'shift':
            self.shift = True
        elif key == 'left':
            self.Backward()
        elif key == 'right':
            self.Forward()
        elif key == 'g':
            self.grid = not self.grid
            self.ax_main.grid(self.grid)
            self.fig.canvas.draw_idle()
        elif key == 'l':
            self.logy = not self.logy
            # A log axis cannot start at or below zero.
            if self.logy and self.ax_main.get_ylim()[0] <= 0:
                self.ax_main.set_ylim(1, None)
            self.ax_main.set_yscale('log' if self.logy else 'linear')
            self.fig.canvas.draw_idle()
        elif key == ' ':
            if self.stop:
                self.Play()
            else:
                self.Pause()

    def off_key(self, event):
        """Clear the ctrl or shift flag when that modifier key is released.

        Both ``'control'`` and ``'ctrl+control'`` are accepted as a ctrl
        release.
        """
        released = event.key
        if released in ('ctrl+control', 'control'):
            self.ctrl = False
        elif released == 'shift':
            self.shift = False
Example #55
0
File: scene.py Project: tukss/yt
    def save(self, fname=None, sigma_clip=None, render=True):
        r"""Saves a rendered image of the Scene to disk.

        Once you have created a scene, this saves an image array to disk with
        an optional filename. This function calls render() to generate an
        image array, unless the render parameter is set to False, in which case
        the most recently rendered scene is used if it exists.

        Parameters
        ----------
        fname: string, optional
            If specified, save the rendering to the file "fname".
            If unspecified, it creates a default based on the dataset filename
            and field of the first RenderSource in the scene, or
            "Render_opaque.png" when there is none. The file format is
            inferred from the filename's suffix; a name without a suffix gets
            ".png" appended. Supported formats are png, pdf, eps, and ps.
            Default: None
        sigma_clip: float, optional
            Image values greater than this number times the standard deviation
            plus the mean of the image will be clipped before saving. Useful
            for enhancing images as it gets rid of rare high pixel values.
            For png output the value is passed to the png writer unchanged.
            For pdf/eps/ps output, None scales the image by its largest
            non-zero value (no clipping).
            Default: None

            floor(vals > std_dev*sigma_clip + mean)
        render: boolean, optional
            If True, will always render the scene before saving.
            If False, will use results of previous render if it exists.
            Default: True

        Returns
        -------
            Nothing

        Raises
        ------
        NotImplementedError
            If the filename suffix is not one of the supported formats. This
            is checked before any rendering is done.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>>
        >>> sc = yt.create_scene(ds)
        >>> # Modify camera, sources, etc...
        >>> sc.save('test.png', sigma_clip=4)

        When saving multiple images without modifying the scene (camera,
        sources,etc.), render=False can be used to avoid re-rendering.
        This is useful for generating images at a range of sigma_clip values:

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>>
        >>> sc = yt.create_scene(ds)
        >>> # save with different sigma clipping values
        >>> sc.save('raw.png')  # The initial render call happens here
        >>> sc.save('clipped_2.png', sigma_clip=2, render=False)
        >>> sc.save('clipped_4.png', sigma_clip=4, render=False)

        """
        if fname is None:
            rensources = [s for s in self.sources.values()
                          if isinstance(s, RenderSource)]
            # if a volume source present, use its affiliated ds for fname
            if rensources:
                rs = rensources[0]
                basename = rs.data_source.ds.basename
                if isinstance(rs.field, str):
                    field = rs.field
                else:
                    field = rs.field[-1]
                fname = f"{basename}_Render_{field}.png"
            # if no volume source present, use a default filename
            else:
                fname = "Render_opaque.png"
        suffix = get_image_suffix(fname)
        if suffix == "":
            suffix = ".png"
            fname = f"{fname}{suffix}"
        # Reject unsupported formats before paying for a (costly) render.
        if suffix not in (".png", ".pdf", ".eps", ".ps"):
            raise NotImplementedError(f"Unknown file suffix '{suffix}'")

        render = self._sanitize_render(render)
        if render:
            self.render()
        mylog.info("Saving rendered image to %s", fname)

        # We can render pngs natively but for other formats we defer to
        # matplotlib.
        if suffix == ".png":
            self._last_render.write_png(fname, sigma_clip=sigma_clip)
            return

        from matplotlib.backends.backend_pdf import FigureCanvasPdf
        from matplotlib.backends.backend_ps import FigureCanvasPS
        from matplotlib.figure import Figure

        out = self._last_render
        shape = out.shape
        fig = Figure((shape[0] / 100.0, shape[1] / 100.0))
        if suffix == ".pdf":
            canvas = FigureCanvasPdf(fig)
        else:  # ".eps" or ".ps", validated above
            canvas = FigureCanvasPS(fig)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_axis_off()

        rgb = out[:, :, :3]
        nz = rgb[rgb.nonzero()]
        if nz.size == 0:
            # All-black image: any positive scale keeps it black and avoids
            # a NaN normalization from the mean of an empty array.
            max_val = 1.0
        elif sigma_clip is None:
            # No clipping requested (this used to raise TypeError on
            # ``None * std``): scale by the brightest value instead.
            max_val = nz.max()
        else:
            max_val = nz.mean() + sigma_clip * nz.std()
        # Scale alpha *before* the uint8 cast; casting first truncated every
        # fractional alpha to 0. Assumes alpha is nominally in [0, 1]
        # (TODO confirm); out-of-range values are clipped.
        alpha = (np.clip(out[:, :, 3], 0.0, 1.0) * 255).astype("uint8")
        rgb = np.clip(rgb / max_val, 0.0, 1.0) * 255
        rgba = np.concatenate([rgb.astype("uint8"), alpha[..., None]],
                              axis=-1)
        # not sure why we need rot90, but this makes the orientation
        # match the png writer
        ax.imshow(np.rot90(rgba), origin="lower")
        canvas.print_figure(fname, dpi=100)
Example #56
0
    def changePlotWidget(self, library, frame_for_plot):
        """Create and return a plot widget for the requested plotting library.

        Args:
            library: "PyQtGraph", "Qwt5" or "Matplotlib".
            frame_for_plot: parent widget of the Qwt5 plot (not used by the
                other libraries).

        Returns:
            * "PyQtGraph": a ``pg.PlotWidget`` with a grid, a red crosshair
              and "X"/"Y" text items.
            * "Qwt5" (when ``has_qwt``): a ``QwtPlot`` with a zoomer, a dotted
              grid and, outside Windows, a tracking picker.
            * "Matplotlib" (when ``has_mpl``): a ``FigureCanvasQTAgg``; its
              axes are also stored on ``self.subplot``.
            Otherwise None.
        """
        if library == "PyQtGraph":
            plotWdg = pg.PlotWidget()
            plotWdg.showGrid(True, True, 0.5)
            datavline = pg.InfiniteLine(0,
                                        angle=90,
                                        pen=pg.mkPen("r", width=1),
                                        name="cross_vertical")
            datahline = pg.InfiniteLine(0,
                                        angle=0,
                                        pen=pg.mkPen("r", width=1),
                                        name="cross_horizontal")
            plotWdg.addItem(datavline)
            plotWdg.addItem(datahline)
            # cursor
            xtextitem = pg.TextItem("X : /",
                                    color=(0, 0, 0),
                                    border=pg.mkPen(color=(0, 0, 0), width=1),
                                    fill=pg.mkBrush("w"),
                                    anchor=(0, 1))
            ytextitem = pg.TextItem(
                "Y : / ",
                color=(0, 0, 0),
                border=pg.mkPen(color=(0, 0, 0), width=1),
                fill=pg.mkBrush("w"),
                anchor=(0, 0),
            )
            plotWdg.addItem(xtextitem)
            plotWdg.addItem(ytextitem)

            plotWdg.getViewBox().autoRange(items=[])
            plotWdg.getViewBox().disableAutoRange()
            plotWdg.getViewBox().border = pg.mkPen(color=(0, 0, 0), width=1)

            return plotWdg

        elif library == "Qwt5" and has_qwt:
            plotWdg = QwtPlot(frame_for_plot)
            sizePolicy = QSizePolicy(QSizePolicy.Expanding,
                                     QSizePolicy.Expanding)
            sizePolicy.setHorizontalStretch(0)
            sizePolicy.setVerticalStretch(0)
            sizePolicy.setHeightForWidth(
                plotWdg.sizePolicy().hasHeightForWidth())
            plotWdg.setSizePolicy(sizePolicy)
            plotWdg.setMinimumSize(QSize(0, 0))
            plotWdg.setAutoFillBackground(False)
            # Decoration
            plotWdg.setCanvasBackground(Qt.white)
            plotWdg.plotLayout().setAlignCanvasToScales(True)
            zoomer = QwtPlotZoomer(QwtPlot.xBottom, QwtPlot.yLeft,
                                   QwtPicker.DragSelection,
                                   QwtPicker.AlwaysOff, plotWdg.canvas())
            zoomer.setRubberBandPen(QPen(Qt.blue))
            if platform.system() != "Windows":
                # disable picker in Windows due to crashes
                picker = QwtPlotPicker(
                    QwtPlot.xBottom,
                    QwtPlot.yLeft,
                    QwtPicker.NoSelection,
                    QwtPlotPicker.CrossRubberBand,
                    QwtPicker.AlwaysOn,
                    plotWdg.canvas(),
                )
                picker.setTrackerPen(QPen(Qt.green))
            # self.dockwidget.qwtPlot.insertLegend(QwtLegend(), QwtPlot.BottomLegend);
            grid = Qwt.QwtPlotGrid()
            grid.setPen(QPen(QColor("grey"), 0, Qt.DotLine))
            grid.attach(plotWdg)
            return plotWdg

        elif library == "Matplotlib" and has_mpl:
            from matplotlib.figure import Figure

            # Parse the major version properly ("5.15.2" -> 5) instead of
            # taking the first character.
            qt_major = int(qgis.PyQt.QtCore.QT_VERSION_STR.split(".")[0])
            if qt_major == 4:
                from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg
            elif qt_major == 5:
                from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
            else:
                # Qt 6+: previously FigureCanvasQTAgg was left unbound here
                # (NameError). matplotlib >= 3.5 ships a binding-agnostic
                # Qt backend.
                from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg

            fig = Figure(
                (1.0, 1.0),
                linewidth=0.0,
                subplotpars=matplotlib.figure.SubplotParams(left=0,
                                                            bottom=0,
                                                            right=1,
                                                            top=1,
                                                            wspace=0,
                                                            hspace=0),
            )

            # NOTE: ``rc`` comes from elsewhere in the file (presumably
            # matplotlib.rc) and changes the font globally, not just here.
            font = {"family": "arial", "weight": "normal", "size": 12}
            rc("font", **font)

            rect = fig.patch
            rect.set_facecolor((0.9, 0.9, 0.9))

            self.subplot = fig.add_axes((0.05, 0.15, 0.92, 0.82))
            self.subplot.set_xbound(0, 1000)
            self.subplot.set_ybound(0, 1000)
            self.manageMatplotlibAxe(self.subplot)
            canvas = FigureCanvasQTAgg(fig)
            sizePolicy = QSizePolicy(QSizePolicy.Expanding,
                                     QSizePolicy.Expanding)
            sizePolicy.setHorizontalStretch(0)
            sizePolicy.setVerticalStretch(0)
            canvas.setSizePolicy(sizePolicy)
            return canvas
Example #57
0
        def draw_plots():
            """Solve the IVP numerically three ways and plot the results.

            Reads x0, y0, xf and the number of grid points n from the entry
            widgets. On the n-point grid ``linspace(x0, xf, n)`` it evaluates
            ``Computations.euler``, ``euler_imp``, ``runge_kutta`` and
            ``exact``. It then draws two stacked axes: the four solutions, and
            each method's error relative to ``exact``. The canvas is placed on
            ``self`` at (250, 10).

            Raises:
                ValueError: if an entry cannot be parsed, or if n < 2. At least
                    two grid points are needed to define the step size h.
            """
            x0 = float(entry_x0.get())
            y0 = float(entry_y0.get())
            xf = float(entry_xf.get())
            n = int(entry_n.get())
            if n < 2:
                # h below divides by n - 1 (n == 1 was a ZeroDivisionError).
                raise ValueError("n must be at least 2, got %d" % n)

            f = Figure(figsize=(7.5, 7.5), dpi=100)

            h = (xf - x0) / (n - 1)

            x = numpy.linspace(x0, xf, n)
            # Every solution starts from the initial value y0.
            y1 = numpy.full(n, y0)
            y2 = numpy.full(n, y0)
            y3 = numpy.full(n, y0)
            y4 = numpy.full(n, y0)
            for i in range(1, n):
                y1[i] = Computations.euler(x0, y0, h, x[i])
                y2[i] = Computations.euler_imp(x0, y0, h, x[i])
                y3[i] = Computations.runge_kutta(x0, y0, h, x[i])
                y4[i] = Computations.exact(x[i])

            x_limit = (x0 - 1, xf + 1)
            y_limit = (min(y1[n - 1], y2[n - 1], y3[n - 1], y4[n - 1]) - 1,
                       y0 + 1)

            ax1 = f.add_axes([0.1, 0.55, 0.8, 0.35],
                             xlim=(x_limit[0], x_limit[1]),
                             ylim=(y_limit[0], y_limit[1]),
                             title="Solutions of DE y' = "
                             r"$\frac{y^2}{x^2}$ - 2")
            ax1.plot(x, y1, ':', linewidth=2.0)
            ax1.plot(x, y2, ':')
            ax1.plot(x, y3, ':', linewidth=2.0)
            ax1.plot(x, y4, linewidth=1.0)

            f.legend(("Euler Method", "Improved Euler Method",
                      "Runge-Kutta Method", "Exact solution of IVP"))

            # Errors of each method with respect to the exact solution.
            er1 = y4 - y1
            er2 = y4 - y2
            er3 = y4 - y3
            step = numpy.arange(n, dtype=float)
            errors = numpy.concatenate([er1, er2, er3])
            # The y-range always includes 0; NaNs are ignored, as the
            # former element-wise builtin max/min loop did.
            ma = max(0.0, float(numpy.nanmax(errors)))
            mi = min(0.0, float(numpy.nanmin(errors)))
            ax2 = f.add_axes([0.1, 0.1, 0.8, 0.35],
                             xlim=(0, n),
                             ylim=(mi - 0.5, ma + 0.5),
                             title="Errors")

            ax2.plot(step, er1, linewidth=1.0)
            ax2.plot(step, er2, linewidth=1.0)
            ax2.plot(step, er3, linewidth=1.0)
            canvas = FigureCanvasTkAgg(f, self)

            canvas.draw()
            # get_tk_widget() is the canvas' Tk widget; placing the private
            # _tkcanvas as well only repeated the same call.
            canvas.get_tk_widget().place(x=250, y=10)
Example #58
0
class Renderer:
    """Tk window that animates a pickled ``.param`` file with matplotlib.

    The unpickled object is handed to ``Plotter`` (defined elsewhere), whose
    ``render_frame`` / ``init_frame`` drive a ``FuncAnimation``. Unless
    ``hide_widgets`` is set, the figure also gets a "hide" check box and
    sliders for the number of circles ``n`` and the playback fps.
    """

    def __init__(self, file_name, hide, n, interval=100, hide_widgets=False):
        """Create the Tk root window and load the parameter file.

        Args:
            file_name: path to a pickled ``.param`` file. If empty or None, a
                file dialog opened in ``output`` asks for one.
            hide: initial "hide" flag forwarded to ``Plotter``.
            n: number of circles forwarded to ``Plotter``.
            interval: delay between animation frames in milliseconds.
            hide_widgets: if True, the check box and sliders are omitted.
        """
        self.root = tkinter.Tk()
        if not file_name:
            file_name = askopenfilename(initialdir="output", title="Select param", filetypes=[
                ("param files", "*.param")])
        self.file_name = file_name
        # ``with`` closes the file deterministically (it used to leak).
        # NOTE(security): unpickling can execute arbitrary code - only load
        # .param files from a trusted source.
        with open(file_name, 'rb') as param_file:
            X = pickle.load(param_file)
        self.plotter = Plotter(X, hide, n)
        self.interval = interval
        self.hide_widgets = hide_widgets

    def render(self):
        """Build the figure and either save the animation or run the Tk loop.

        Reads the module-level ``args`` (not visible here; presumably the
        parsed command line). When ``args.save`` is truthy, the animation is
        written with :meth:`save` to ``args.out``. Otherwise
        ``tkinter.mainloop()`` runs until the window is closed.
        """
        print(f'drawing with {self.plotter.n} circles')
        self.fig = Figure(figsize=(13, 13), dpi=100)
        self.ax = self.fig.subplots()
        # Title from the file actually loaded. A file picked through the
        # dialog in __init__ is only known as self.file_name; the global
        # args.file_name may be empty in that case.
        self.root.wm_title(f"Render - {base_name(self.file_name, True)}")
        canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=1)
        if not self.hide_widgets:
            rax = self.fig.add_axes([0.05, 0.0, 0.1, 0.1])
            rax.axis('off')
            self.check = CheckButtons(rax, ('hide',), (self.plotter.hide,))
            self.check.on_clicked(lambda _: self.plotter.toggle_hide())

            nax = self.fig.add_axes([0.2, 0.07, 0.7, 0.02])
            self.nslider = Slider(nax, 'n', 2, self.plotter.frames,
                                  valinit=self.plotter.n, valstep=1)
            self.nslider.on_changed(self._update_n)
            fpsax = self.fig.add_axes([0.2, 0.03, 0.7, 0.02])
            self.fpsslider = Slider(fpsax, 'fps', 1, 50,
                                    valinit=10, valstep=1)
            self.fpsslider.on_changed(self._update_fps)
        self._init_animation()
        if args.save:
            self.save(args.out)
        else:
            tkinter.mainloop()

    def _init_animation(self):
        """(Re)create the repeating FuncAnimation using ``self.interval``."""
        self.animation = FuncAnimation(self.fig,
                                       self.plotter.render_frame,
                                       frames=range(self.plotter.frames),
                                       interval=self.interval,
                                       repeat_delay=1000,
                                       init_func=partial(
                                           self.plotter.init_frame, self.ax),
                                       repeat=True)

    def _update_fps(self, fps):
        """fps slider callback: restart the animation at the new frame rate."""
        self.animation.event_source.stop()
        self.interval = int(1000 / fps)
        self._init_animation()

    def _update_n(self, n):
        """n slider callback: rebuild the Plotter with ``n`` circles."""
        self.animation.event_source.stop()
        self.plotter = Plotter(self.plotter.X, self.plotter.hide, int(n))
        self._init_animation()

    def save(self, out_fname):
        """Write the animation with FFmpeg at a fixed 10 fps, 1000 kbit/s.

        Args:
            out_fname: output path. If empty or None, defaults to
                ``output/<base name of the param file>.mp4``.
        """
        if not out_fname:
            out_fname = f'output/{base_name(self.file_name)}.mp4'
        print(f'saving to {out_fname}')
        self.animation.save(
            out_fname, writer=FFMpegWriter(fps=10, bitrate=1000))
Example #59
0
def thumbnail(infile, thumbfile, scale=0.1, interpolation='bilinear',
              preview=False):
    """
    make a thumbnail of image in *infile* with output filename
    *thumbfile*.

      *infile* the image file -- must be PNG or PIL readable if you
         have `PIL <http://www.pythonware.com/products/pil/>`_ installed.
         Both grayscale (2-D) and colour (3-D) images are accepted.

      *thumbfile*
        the thumbnail filename

      *scale*
        the scale factor for the thumbnail

      *interpolation*
        the interpolation scheme used in the resampling


      *preview*
        if True, the default backend (presumably a user interface
        backend) will be used which will cause a figure to be raised
        if :func:`~matplotlib.pyplot.show` is called.  If it is False,
        a pure image backend will be used depending on the extension,
        'png'->FigureCanvasAgg, 'pdf'->FigureCanvasPdf,
        'svg'->FigureCanvasSVG.  With *preview* False, any other
        extension raises :class:`ValueError` before *infile* is read.


    See examples/misc/image_thumbnail.py.

    .. htmlonly::

        :ref:`misc-image_thumbnail`

    Return value is the figure instance containing the thumbnail

    """
    extension = os.path.splitext(thumbfile)[1].lower()

    if preview:
        # let the UI backend do everything
        import matplotlib.pyplot as plt
    else:
        # Choose the pure image backend up front so an unsupported output
        # extension fails before the (possibly large) input is read.
        if extension == '.png':
            from matplotlib.backends.backend_agg \
                import FigureCanvasAgg as FigureCanvas
        elif extension == '.pdf':
            from matplotlib.backends.backend_pdf \
                import FigureCanvasPdf as FigureCanvas
        elif extension == '.svg':
            from matplotlib.backends.backend_svg \
                import FigureCanvasSVG as FigureCanvas
        else:
            raise ValueError("Can only handle "
                             "extensions 'png', 'svg' or 'pdf'")

    im = imread(infile)
    # Grayscale images are (rows, cols), colour ones (rows, cols, depth);
    # unpacking exactly three values used to reject grayscale input.
    rows, cols = im.shape[:2]

    # this doesn't really matter, it will cancel in the end, but we
    # need it for the mpl API
    dpi = 100

    height = float(rows) / dpi * scale
    width = float(cols) / dpi * scale

    if preview:
        fig = plt.figure(figsize=(width, height), dpi=dpi)
    else:
        from matplotlib.figure import Figure
        fig = Figure(figsize=(width, height), dpi=dpi)
        # Constructing the canvas attaches it to fig; the object itself is
        # not needed afterwards.
        FigureCanvas(fig)

    ax = fig.add_axes([0, 0, 1, 1], aspect='auto',
                      frameon=False, xticks=[], yticks=[])

    ax.imshow(im, aspect='auto', resample=True, interpolation=interpolation)
    fig.savefig(thumbfile, dpi=dpi)
    return fig
    def changePlotWidget(self, library, frame_for_plot):
        if library == "Qwt5" and has_qwt:
            plotWdg = QwtPlot(frame_for_plot)
            sizePolicy = QSizePolicy(QSizePolicy.Expanding,
                                     QSizePolicy.Expanding)
            sizePolicy.setHorizontalStretch(0)
            sizePolicy.setVerticalStretch(0)
            sizePolicy.setHeightForWidth(
                plotWdg.sizePolicy().hasHeightForWidth())
            plotWdg.setSizePolicy(sizePolicy)
            plotWdg.setMinimumSize(QSize(0, 0))
            plotWdg.setAutoFillBackground(False)
            #Decoration
            plotWdg.setCanvasBackground(Qt.white)
            plotWdg.plotLayout().setAlignCanvasToScales(True)
            zoomer = QwtPlotZoomer(QwtPlot.xBottom, QwtPlot.yLeft,
                                   QwtPicker.DragSelection,
                                   QwtPicker.AlwaysOff, plotWdg.canvas())
            zoomer.setRubberBandPen(QPen(Qt.blue))
            if platform.system() != "Windows":
                # disable picker in Windows due to crashes
                picker = QwtPlotPicker(QwtPlot.xBottom, QwtPlot.yLeft,
                                       QwtPicker.NoSelection,
                                       QwtPlotPicker.CrossRubberBand,
                                       QwtPicker.AlwaysOn, plotWdg.canvas())
                picker.setTrackerPen(QPen(Qt.green))
            #self.dockwidget.qwtPlot.insertLegend(QwtLegend(), QwtPlot.BottomLegend);
            grid = Qwt.QwtPlotGrid()
            grid.setPen(QPen(QColor('grey'), 0, Qt.DotLine))
            grid.attach(plotWdg)
            return plotWdg
        elif library == "Matplotlib" and has_mpl:
            from matplotlib.figure import Figure
            from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg

            fig = Figure((1.0, 1.0),
                         linewidth=0.0,
                         subplotpars=matplotlib.figure.SubplotParams(left=0,
                                                                     bottom=0,
                                                                     right=1,
                                                                     top=1,
                                                                     wspace=0,
                                                                     hspace=0))

            font = {'family': 'arial', 'weight': 'normal', 'size': 12}
            matplotlib.rc('font', **font)

            rect = fig.patch
            rect.set_facecolor((0.9, 0.9, 0.9))

            self.subplot = fig.add_axes((0.07, 0.15, 0.92, 0.82))
            self.subplot.set_xbound(0, 1000)
            self.subplot.set_ybound(0, 1000)
            self.manageMatplotlibAxe(self.subplot)
            canvas = FigureCanvasQTAgg(fig)
            sizePolicy = QSizePolicy(QSizePolicy.Expanding,
                                     QSizePolicy.Expanding)
            sizePolicy.setHorizontalStretch(0)
            sizePolicy.setVerticalStretch(0)
            canvas.setSizePolicy(sizePolicy)
            return canvas