Example #1
0
 def __init__(self, parent, red_farred):
     """Build the graph panel: a matplotlib canvas with a placeholder trace.

     Args:
         parent: wx parent window.
         red_farred: kept on the instance as ``self.red_farred``.
     """
     super(GraphPanel, self).__init__(parent, -1)
     self.SetBackgroundColour((218,238,255))
     self.SetWindowStyle(wx.RAISED_BORDER)

     # Panel state consulted later by the mouse/scroll handlers.
     self.integ_lines = []
     self.fractional_lines = []
     self.plot_unit = -1
     self.plot_mode = -1
     self.text = None
     self.x_label = X_LABEL
     self.red_farred = red_farred

     figure = Figure()
     figure.set_facecolor(color='#daeeff')
     self.axes = figure.add_subplot(111)
     self.x_data = range(340, 821)
     # Flat placeholder trace so the legend has an entry before real data.
     self.axes.plot(self.x_data, [0] * len(self.x_data), label='Scan 0')
     self.axes.legend(loc=1)

     self.canvas = FigureCanvasWxAgg(self, -1, figure)
     figure.tight_layout(pad=2.0)

     sizer = wx.BoxSizer(wx.VERTICAL)
     sizer.Add(self.canvas, 1, wx.EXPAND | wx.ALL)
     sizer.AddSpacer(20)
     add_toolbar(sizer, self.canvas)
     self.SetSizer(sizer)
     self.canvas.draw()

     callbacks = (('motion_notify_event', self.on_movement),
                  ('button_press_event', self.on_press),
                  ('scroll_event', self.on_scroll))
     for event_name, handler in callbacks:
         self.canvas.mpl_connect(event_name, handler)
Example #2
0
 def end(self):
     """Render every stored generation to PNG frames and encode them as MP4.

     Frames are written as ``<pref>0001.png``, ``<pref>0002.png``, ... and
     handed to ``avconv``. The frames are deleted afterwards, even if encoding
     fails.

     Returns:
         str: path of the generated video, ``'<pref>.mp4'``.

     Raises:
         subprocess.CalledProcessError: if ``avconv`` exits with an error.
         FileNotFoundError: if the ``avconv`` executable cannot be found.
     """
     import subprocess

     def frame_path(idx):
         return '%s%04d.png' % (self.pref, idx + 1)

     video_path = '%s.mp4' % self.pref
     num_frames = len(self.results)
     xmax, ymax = self._get_lims()

     # Remove stale frames left by an earlier run; stop at the first gap.
     for i in range(10000):
         try:
             os.remove(frame_path(i))
         except OSError:
             break  # No more

     prev = None
     for i in range(num_frames):
         fig = Figure(figsize=(16, 9))
         canvas = FigureCanvasAgg(fig)
         ax = fig.add_subplot(111, xlim=(-xmax, xmax), ylim=(-ymax, ymax))
         # _plot_2d receives its own return value from the previous frame.
         prev = _plot_2d(ax, self.results[i], prev)
         ax.text(0, -ymax, 'Generation: %03d' % (i * self.step),
                 ha='left', va='bottom', fontsize=16)
         fig.tight_layout()
         canvas.print_figure(frame_path(i))

     try:
         os.remove(video_path)
     except FileNotFoundError:
         pass  # OK, does not exist

     try:
         # Argument list, no shell: a prefix containing spaces or shell
         # metacharacters is passed through verbatim. Output options such as
         # -vcodec must come before the output file; avconv ignores trailing
         # options.
         subprocess.run(['avconv', '-r', '1', '-f', 'image2',
                         '-i', '%s%%04d.png' % self.pref,
                         '-vcodec', 'libx264', video_path],
                        check=True)
     finally:
         for i in range(num_frames):
             os.remove(frame_path(i))
     return video_path
Example #3
0
class DevPlot(Plot):
  """Stacked plot with one subplot per configured hardware event."""

  def __init__(self, k1=None, k2=None, processes=1, **kwargs):
    """Store the per-CPU type/event tables and initialise the base Plot.

    Args:
      k1: maps a CPU name (``ts.pmc_type``) to a list whose first entry is
        the type name used to choose scaling in ``plot``.
      k2: maps a CPU name to the list of event names to plot.
      processes: forwarded to ``Plot.__init__``.
    """
    # None sentinels instead of dict literals: default arguments are built
    # once, so mutable defaults would be shared by every instance.
    if k1 is None:
      k1 = {'intel_snb' : ['intel_snb','intel_snb','intel_snb']}
    if k2 is None:
      k2 = {'intel_snb':['LOAD_1D_ALL','INSTRUCTIONS_RETIRED','LOAD_OPS_ALL']}
    self.k1 = k1
    self.k2 = k2
    super(DevPlot,self).__init__(processes=processes,**kwargs)

  def plot(self,jobid,job_data=None):
    """Draw one subplot per event for ``jobid`` and write it via ``output``."""
    self.setup(jobid,job_data=job_data)
    cpu_name = self.ts.pmc_type
    type_name = self.k1[cpu_name][0]
    events = self.k2[cpu_name]

    ts = self.ts

    n_events = len(events)
    self.fig = Figure(figsize=(8,n_events*2+3),dpi=110)

    # 'mem' disables do_rate and scales by 2**10; 'cpu' scales by
    # wayness * 100. Any other type keeps do_rate=True and scale 1.0.
    do_rate = True
    scale = 1.0
    if type_name == 'mem':
      do_rate = False
      scale = 2.0**10
    if type_name == 'cpu':
      scale = ts.wayness*100.0

    for i, event in enumerate(events):
      self.ax = self.fig.add_subplot(n_events,1,i+1)
      self.plot_lines(self.ax, [i], xscale=3600., yscale=scale, do_rate=do_rate)
      self.ax.set_ylabel(event,size='small')
    if n_events:
      # Only the bottom subplot carries the time label. With no events there
      # is no axes in this figure to label.
      self.ax.set_xlabel("Time (hr)")
    self.fig.subplots_adjust(hspace=0.5)
    self.fig.tight_layout()

    self.output('devices')
Example #4
0
 def make_Histogram(self):
     """Recompute results and show the four averaged IQ histogram maps.

     Reads the table, runs ``functions.process``, rebuilds the correlation
     and TMS figures, then draws ``res.IQmapM_avg[0..3]`` as a 2x2 grid of
     images (II, QQ, IQ, QI). The grid is passed to ``update_page_1``.
     """
     self.read_table()
     functions.process(self.dispData, self.dicData)
     self.make_CorrFigs()
     self.make_TMSFig()
     on = self.dicData['hdf5_on']  # this one contains all the histogram axis
     res = self.dicData['res']  # this contains the calculation results
     fig1 = Figure(facecolor='white', edgecolor='white')
     # (title, x axis, y axis) for IQmapM_avg[0..3], in subplot order.
     panels = (('IIc', on.xII, on.yII),
               ('QQc', on.xQQ, on.yQQ),
               ('IQc', on.xIQ, on.yIQ),
               ('QIc', on.xQI, on.yQI))
     for idx, (title, x_axis, y_axis) in enumerate(panels):
         ax = fig1.add_subplot(2, 2, idx + 1)
         # 'lower' is the valid origin name; 'low' is rejected by current
         # matplotlib versions.
         ax.imshow(res.IQmapM_avg[idx], interpolation='nearest', origin='lower',
                   extent=[x_axis[0], x_axis[-1], y_axis[0], y_axis[-1]],
                   aspect='auto')
         ax.set_title(title)
     # Lay out after the titles exist so their space is reserved.
     fig1.tight_layout()
     self.update_page_1(fig1)  # send figure to the show_figure terminal
     self.read_table()
Example #5
0
    def plot_stats(self, coll="fireworks", interval="days", num_intervals=5,
                   states=None, style='bar', **kwargs):
        """
        Makes a chart with the summary data

        Args:
            coll (str): collection, either "fireworks", "workflows", or "launches"
            interval (str): one of "minutes", "hours", "days", "months", "years"
            num_intervals (int): number of intervals to go back in time from present moment
            states ([str]): states to include in plot, defaults to all states,
                note this also specifies the order of stacking
            style (str): style of plot to generate, can either be 'bar' or 'fill';
                any other value draws no data series

        Returns:
            matplotlib.figure.Figure: the stacked chart
        """
        results = self.get_stats(coll, interval, num_intervals, **kwargs)
        state_to_color = {"RUNNING": "#F4B90B",
                          "WAITING": "#1F62A2",
                          "FIZZLED": "#DB0051",
                          "READY": "#2E92F2",
                          "COMPLETED": "#24C75A",
                          "RESERVED": "#BB8BC1",
                          "ARCHIVED": "#7F8287",
                          "DEFUSED": "#B7BCC3",
                          "PAUSED": "#FFCFCA"
                          }
        states = states or state_to_color.keys()

        from matplotlib.figure import Figure
        from matplotlib.ticker import MaxNLocator

        fig = Figure()
        ax = fig.add_subplot(111)
        data = {state: [result['states'][state] for result in results]
                for state in states}

        x = range(len(results))
        bottom = [0] * len(results)
        for state in states:
            if not any(data[state]):
                continue  # skip states with no entries in any interval
            top = [b + v for b, v in zip(bottom, data[state])]
            # Compare by value: ``is`` tests object identity and can be False
            # for equal strings (e.g. values built at runtime).
            if style == 'bar':
                ax.bar(x, data[state], bottom=bottom,
                       color=state_to_color[state], label=state)
            elif style == 'fill':
                ax.fill_between(x, bottom, top,
                                color=state_to_color[state], label=state)
            bottom = top

        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        ax.set_xlabel("{} ago".format(interval), fontsize=18)
        ax.set_xlim([-0.5, num_intervals-0.5])
        ax.set_ylabel("number of {}".format(coll), fontsize=18)
        ax.tick_params(labelsize=14)
        ax.legend(fontsize=13)
        fig.tight_layout()
        return fig
Example #6
0
def plot_mu_parameters(pdict, outinfo, lumi=False):
    """Plot the ``mu_*`` fit parameters (and optionally ``Lumi``) and save it.

    Args:
        pdict: mapping of parameter name to value; entries named ``mu_<x>``
            are drawn with label ``<x>``. Values are passed on to
            ``_get_lab_x_y_err``.
        outinfo: dict with ``'outdir'`` (target directory) and ``'ext'``
            (file extension appended to ``'mu'``).
        lumi: when true, the ``'Lumi'`` entry is drawn as well.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigCanvas

    selected = [
        (name[len('mu_'):] if name.startswith('mu_') else name, value)
        for name, value in sorted(pdict.items())
        if name.startswith('mu_') or (lumi and name == 'Lumi')
    ]
    labels, positions, values, errors = _get_lab_x_y_err(selected)

    fig = Figure(figsize=(4, 4))
    canvas = FigCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, len(labels))
    ax.set_ylim(0, 2)
    ax.errorbar(positions, values, yerr=errors, **_eb_style)
    # Horizontal reference line at 1.
    ax.axhline(1, **_hline_style)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.tick_params(labelsize=_txt_size)

    target_dir = outinfo['outdir']
    fig.tight_layout(pad=0.3, h_pad=0.3, w_pad=0.3)
    canvas.print_figure(join(target_dir, 'mu' + outinfo['ext']))
class PlotOverview(qtgui.QWidget):
    """Year overview heat map of worked time minus desired time per day.

    Rows are months, columns 0-30 are days, and the rightmost columns hold
    the monthly total. Cell text is the difference in hours.
    """

    def __init__(self, db):
        self.db = db
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        super().__init__()

        lay_v = qtgui.QVBoxLayout()
        self.setLayout(lay_v)

        self.year = qtgui.QComboBox()
        self.year.currentIndexChanged.connect(self.plot)

        lay_h = qtgui.QHBoxLayout()
        lay_h.addWidget(self.year)
        lay_h.addStretch(1)
        lay_v.addLayout(lay_h)
        lay_v.addWidget(self.canvas)

        self.update()

    # NOTE(review): this name shadows QWidget.update() for Python callers.
    # It is kept because renaming would break existing callers.
    def update(self):
        """Refill the year selector from the earliest start date to this year.

        Keeps the previously selected year if it is still listed; otherwise
        selects the most recent year.
        """
        constraints = self.db.get_constraints()
        current_year = self.year.currentText()
        self.year.clear()
        first_year = min(constraints['start_date']).year
        # Strings, because currentText() returns a string. An int list never
        # matched, so the selection was always reset to the last year.
        years = [str(y) for y in range(first_year, datetime.datetime.now().year + 1)]
        self.year.addItems(years)
        try:
            self.year.setCurrentIndex(years.index(current_year))
        except ValueError:
            self.year.setCurrentIndex(len(years) - 1)

    @staticmethod
    def _hours_label(minutes):
        """Format *minutes* as hours with one decimal, dropping a lone leading zero.

        ``0.5`` -> ``.5`` and ``-0.5`` -> ``-.5``, while ``10.5`` is kept as is.
        The former pattern ``'0(?=[.])'`` also turned ``10.5`` into ``1.5``.
        """
        return re.sub(r'\b0(?=\.)', '', '{:.1f}'.format(minutes / 60))

    def plot(self):
        """Redraw the heat map for the year selected in the combo box."""
        year_text = self.year.currentText()
        if not year_text:
            # clear() in update() emits currentIndexChanged while the combo box
            # is empty. int('') would raise, and there is nothing to draw.
            return
        year = int(year_text)
        today = datetime.datetime.now().date()

        self.fig.clf()
        ax = self.fig.add_subplot(111)

        # Columns 0-30: days; 31: blank gap; 32-33: monthly total.
        worked = np.zeros((12, 34)) + np.nan

        for month in range(12):
            for day in range(calendar.monthrange(year, month + 1)[1]):
                date = datetime.date(year, month + 1, day + 1)
                if date < today:
                    t = self.db.get_worktime(date).total_seconds() / 60.0 - self.db.get_desiredtime(date)
                    worked[month, day] = t
                    ax.text(day, month, self._hours_label(t), ha='center', va='center')

        worked[:, 32:] = np.nansum(worked[:, :31], axis=1, keepdims=True)

        for month in range(12):
            ax.text(32.5, month, self._hours_label(worked[month, -1]), ha='center', va='center')

        ax.imshow(worked, vmin=-12*60, vmax=12*60, interpolation='none', cmap='coolwarm')
        ax.set_xticks(np.arange(31))
        ax.set_yticks(np.arange(12))
        ax.set_xticklabels(1 + np.arange(31))
        ax.set_yticklabels(calendar.month_name[1:])

        self.fig.tight_layout()
        self.canvas.draw()
Example #8
0
class MplAxes(object):
    """Embed a single-axes matplotlib figure in a Qt widget and track its size."""

    # Dots per inch of the figure. Also used to convert resize events
    # (pixels) into figure inches.
    _DPI = 100

    def __init__(self, parent):
        self._parent = parent
        # Keep the parent's handler so resize_graph can chain to it.
        # Overwriting resizeEvent outright used to discard the parent's
        # own resize handling.
        self._parent_resize_event = parent.resizeEvent
        self._parent.resizeEvent = self.resize_graph
        self.create_axes()
        self.redraw_figure()

    def create_axes(self):
        """Create figure, canvas and axes; put the canvas in a zero-margin layout."""
        self.figure = Figure(None, dpi=self._DPI)
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setParent(self._parent)

        axes_layout = QtGui.QVBoxLayout(self._parent)
        # setContentsMargins(0, ...) already zeroes every margin. The former
        # extra setMargin(0) was redundant and is obsolete (removed in Qt 6).
        axes_layout.setContentsMargins(0, 0, 0, 0)
        axes_layout.setSpacing(0)
        axes_layout.addWidget(self.canvas)
        self.canvas.setSizePolicy(QtGui.QSizePolicy.Expanding,
                                  QtGui.QSizePolicy.Expanding)
        self.canvas.updateGeometry()
        self.axes = self.figure.add_subplot(111)

    def resize_graph(self, event):
        """Match the figure to the parent's new size, redraw, then chain the parent."""
        new_size = event.size()
        self.figure.set_size_inches([new_size.width() / float(self._DPI),
                                     new_size.height() / float(self._DPI)])
        self.redraw_figure()
        self._parent_resize_event(event)

    def redraw_figure(self):
        """Re-apply tight_layout with pad 0.8 and redraw the canvas."""
        # Keyword form: newer matplotlib dropped the leading 'renderer'
        # argument and made the rest keyword-only.
        self.figure.tight_layout(pad=0.8)
        self.canvas.draw()
Example #9
0
def plot_alpha_parameters(pdict, outinfo):
    """Plot the parameters selected by ``_sort_alpha`` and save the figure.

    Nothing is drawn when ``_sort_alpha`` returns an empty list (e.g. fits
    without systematics).

    Args:
        pdict: fit parameters, passed to ``_sort_alpha``.
        outinfo: dict with ``'outdir'`` and the extension ``'ext'`` appended
            to ``'alpha'``.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigCanvas
    parlist, div_idxs = _sort_alpha(pdict)
    # in some fits we don't include systematics, don't draw anything
    if not parlist:
        return
    xlab, xpos, ypos, yerr = _get_lab_x_y_err(parlist)
    fig = Figure(figsize=(8, 4))
    fig.subplots_adjust(bottom=0.2)
    canvas = FigCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, len(xlab))
    ax.set_ylim(-2.5, 2.5)
    ax.errorbar(xpos, ypos, yerr=yerr, **_eb_style)
    ax.axhline(0, **_hline_style)
    # Vertical dividers at the indices returned by _sort_alpha
    # (presumably group boundaries).
    for hline in div_idxs:
        ax.axvline(hline, **_hline_style)
    ax.set_xticks(xpos)
    ax.set_xticklabels(xlab)
    ax.tick_params(labelsize=_txt_size)
    # Steeper labels when there are many parameters.
    rotation = 60 if len(xlab) < 10 else 90
    for lab in ax.get_xticklabels():
        lab.set_rotation(rotation)

    outdir = outinfo['outdir']
    fig.tight_layout(pad=0.3, h_pad=0.3, w_pad=0.3)
    # The option is 'bbox_inches'. The former 'bboxinches' spelling was not
    # recognised, so the tight bounding box was never applied.
    canvas.print_figure(
        join(outdir, 'alpha' + outinfo['ext']), bbox_inches='tight')
Example #10
0
    def __init__(self, histogramNumbers, histogramBins, title='', xlabel='', ylabel=''):
        """Create a 5 x 2.5 inch histogram canvas and draw the initial figure.

        Args:
            histogramNumbers: stored as ``self.histNum``.
            histogramBins: stored as ``self.histBins``.
            title, xlabel, ylabel: texts stored for the plotting code.
        """
        self.histNum = histogramNumbers
        self.histBins = histogramBins
        self.text_title = title
        self.text_xlabel = xlabel
        self.text_ylabel = ylabel

        # init figure
        fig = Figure(figsize=(5, 2.5))
        self.axes = fig.add_subplot(111)

        # We want the axes cleared every time plot() is called. Axes.hold()
        # was removed in matplotlib 3.0. On those versions the plotting code
        # has to clear the axes itself (e.g. self.axes.cla()).
        if hasattr(self.axes, 'hold'):
            self.axes.hold(False)

        # plot data
        self.compute_initial_figure()

        # init canvas (figure -> canvas)
        FigureCanvas.__init__(self, fig)

        # setup
        fig.tight_layout()
        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
Example #11
0
def make_1d_plots(in_file_name, out_dir, ext, b_eff=0.1, reject='U'):
    """Plot ``get_c_vs_u_const_beff`` curves for four taggers at fixed b_eff.

    Args:
        in_file_name: HDF5 input file.
        out_dir: output directory; created (with parents) if missing.
        ext: file extension of the saved plot.
        b_eff: fixed b-efficiency; its inverse is shown in the legend title.
        reject: rejection flavour forwarded to the helpers and used in the
            output file name.
    """
    textsize = _text_size
    with h5py.File(in_file_name, 'r') as in_file:
        taggers = {
            tag: get_c_vs_u_const_beff(in_file, tag, b_eff=b_eff, reject=reject)
            for tag in ['gaia', mv1uc_name, 'jfc', 'jfit']
        }

    fig = Figure(figsize=_fig_size)
    canvas = FigureCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    for tname, (vc, vu) in taggers.items():
        label, color = leg_labels_colors.get(tname, (tname, 'k'))
        ax.plot(vc, vu, label=label, color=color, linewidth=_line_width)
    leg = ax.legend(title='$b$-rejection = {}'.format(1/b_eff),
                    prop={'size': textsize})
    leg.get_title().set_fontsize(textsize)

    setup_1d_ctag_legs(ax, textsize, reject=reject)

    fig.tight_layout(pad=0, h_pad=0, w_pad=0)
    # exist_ok avoids the isdir/mkdir race and also creates missing parents.
    os.makedirs(out_dir, exist_ok=True)
    file_name = '{}/{rej}Rej-vs-cEff-brej{}{}'.format(
        out_dir, int(1.0/b_eff), ext, rej=reject.lower())
    canvas.print_figure(file_name, bbox_inches='tight')
def bargraph(request):
    """Return a PNG bar chart of 8 scenarios x 5 measures from GET parameters.

    Expects numeric GET parameters ``d10`` ... ``d84``. The first digit (1-8)
    selects the scenario on the x-axis; the second digit (0-4) selects the
    measure shown as one coloured series. If any value is missing or
    non-numeric, ``bad.html`` is rendered instead.
    """
    p = request.GET

    try:
        d = [tuple(float(p['d%d%d' % (row, col)]) for col in range(5))
             for row in range(1, 9)]
    except (KeyError, ValueError, TypeError):
        return render(request, "bad.html", {'type': 'bargraph'})

    tickM = ["2. Culture for retreatment",
             "3. GeneXpert for HIV positive only",
             "4. GeneXpert for smear positive only",
             "5. GeneXpert for all",
             "6. GeneXpert for all, culture confirmed",
             "7. MODS/TLA",
             "8. Same-day smear microscopy",
             "9. Same-day GeneXpert"]

    colors = ["grey", "blue", "green", "yellow", "red"]

    # One series per measure. list() is required because zip() returns a
    # non-subscriptable iterator on Python 3.
    ndata = list(zip(*d))

    loc = np.arange(len(ndata[0]))
    width = 0.15

    fig = Figure(facecolor='white')
    canvas = FigureCanvas(fig)

    ax = fig.add_subplot(111)

    rect = [ax.bar(loc + width * i, series, width, color=colors[i])
            for i, series in enumerate(ndata)]

    ax.set_ylim(-50, 100)
    ax.set_xlim(-width * 4, len(loc) + (4 * width))

    ax.set_xticks(loc + (2.5 * width))

    # Numeric rotation: newer matplotlib rejects numeric strings such as '30'.
    ax.set_xticklabels(tickM, rotation=30, size='small', stretch='condensed',
                       ha='right')

    ax.legend([bars[0] for bars in rect],
              ("TBInc", "MDRInc", "TBMort", "Yr1Cost", "Yr5Cost"), loc='best')

    ax.set_title("Graph Comparison")
    ax.axhline(color='black')

    ax.set_ylabel('percentage change from baseline')

    fig.tight_layout()

    response = HttpResponse(content_type='image/png')
    # print_figure is the public save entry point and accepts facecolor.
    # Newer print_png() no longer takes a facecolor argument.
    canvas.print_figure(response, format='png', facecolor=fig.get_facecolor())
    return response
Example #13
0
def _get_line_canvas(planes, axis, axsize=14, rebin=5, approval=''):
    """return the canvas with everything drawn on it

    Draws one normalised, rebinned step curve per flavour in ``planes``
    (from ``plane.project(axis)``) on a log y-scale. Also draws the cut
    marker for ``axis``, the legend and the ATLAS labels.
    """
    fig = Figure(figsize=(5.0, 5.0*3/4))
    canvas = FigureCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    for flav, plane in sorted(planes.items()):
        xv, yv = plane.project(axis)
        # Merge each group of `rebin` adjacent bins: sum the contents and
        # keep the group's first x value. Length must be divisible by rebin.
        yv = yv.reshape(-1, rebin).sum(1)
        xv = xv.reshape(-1, rebin)[:, 0]
        # Out-of-place division: in-place '/=' raises on integer arrays.
        yv = yv / yv.sum()
        lab = long_particle_names[flav] + ' jets'
        color = _peter_colors[flav]
        ax.plot(xv, yv, drawstyle='steps-post', label=lab, color=color)
    discriminated = {'U': r'\mathrm{light}', 'B': 'b'}[_parts_vs_ax[axis]]
    xname = r'$\log(P_{{c}} / P_{{ {} }})$'.format(discriminated)
    ax.set_xlabel(xname, x=0.98, ha='right', size=axsize)
    ax.set_xlim(_range_vs_ax[axis])
    ax.set_yscale('log')
    ax.set_ylabel('Fraction of jets',
                  y=0.98, ha='right', size=axsize)
    # Vertical marker at the cut value for this axis.
    ax.plot([_cut_vs_ax[axis]]*2, [1e-4, 0.05], '-', color='orange', lw=2)
    ax.set_ylim(1e-4, 1.5)
    ax.yaxis.set_major_formatter(FuncFormatter(log_formatting))
    legprops = dict(size=_text_size)
    legloc = _legpos_vs_ax[axis]
    ax.legend(framealpha=0, loc=legloc, prop=legprops)
    # Labels go on the left when the legend sits on the right.
    off_x = 0.05 if 'right' in legloc else 0.5
    off_y = 0.93
    ysp = 0.07
    add_atlas(ax, off_x + 0.13, off_y, size=_text_size, approval=approval)
    add_official_garbage(ax, off_x, off_y - ysp, size=_text_size, ysp=ysp)
    fig.tight_layout(pad=0, h_pad=0, w_pad=0)
    return canvas
Example #14
0
def make_1d_overlay(in_file_name, out_dir, ext, subset, b_effs=(0.1, 0.2)):
    """Overlay ``get_c_vs_u_const_beff`` curves for several b-efficiencies.

    Args:
        in_file_name: HDF5 input file.
        out_dir: output directory; created (with parents) if missing.
        ext: extension appended to ``ctag-1d-brej-overlay``.
        subset: taggers to draw; ``_default_overlay_1d`` is used when empty.
        b_effs: b-efficiencies, each paired with a line style from
            ``_b_eff_styles`` (zip drops entries beyond the available styles).
            A tuple default avoids a shared mutable default.
    """
    textsize = _text_size - 2

    with h5py.File(in_file_name, 'r') as in_file:
        taggers = {
            b_eff: {tag: get_c_vs_u_const_beff(in_file, tag, b_eff=b_eff)
                    for tag in (subset or _default_overlay_1d)}
            for b_eff in b_effs
        }

    fig = Figure(figsize=_fig_size)
    canvas = FigureCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    for b_eff, linestyle in zip(b_effs, _b_eff_styles):
        for tname, (vc, vu) in taggers[b_eff].items():
            label, color = leg_labels_colors.get(tname, (tname, 'k'))
            # Raw string: '\e' is not a valid escape sequence (warns on 3.12+).
            lab = r'$1 / \epsilon_{{ b }} = $ {rej:.0f}, {tname}'.format(
                rej=1/b_eff, tname=label)
            ax.plot(vc, vu, label=lab, color=color, linewidth=_line_width,
                    linestyle=linestyle)
    ax.set_xlim(0.1, 0.5)
    leg = ax.legend(prop={'size': textsize})
    leg.get_title().set_fontsize(textsize)

    setup_1d_ctag_legs(ax, textsize)

    fig.tight_layout(pad=0, h_pad=0, w_pad=0)
    # exist_ok avoids the isdir/mkdir race and also creates missing parents.
    os.makedirs(out_dir, exist_ok=True)
    file_name = '{}/ctag-1d-brej-overlay{}'.format(out_dir, ext)
    canvas.print_figure(file_name, bbox_inches='tight')
Example #15
0
    def draw_timeline(self):
        """Plot portfolio vs benchmark over the trading minutes; save to SNAPSHOT_IMG_FILE."""
        fig = Figure(figsize=(7, 3))
        # Attach a canvas so fig.savefig can render.
        FigureCanvas(fig)
        # 'axisbg' was removed in matplotlib 2.2; 'facecolor' replaces it.
        ax = fig.add_subplot(111, facecolor="white")

        xseries = range(len(TRADING_MINUTES))
        # Every 30th minute gets a tick label.
        xseries_show = xseries[::30]
        xlabels_show = TRADING_MINUTES[::30]

        def pct(x, _):
            """Tick formatter: fraction -> percentage with one decimal."""
            return "{0:1.1f}%".format(100*x)

        ax.plot(xseries, self._portfolio_timeline, label="portfolio", linewidth=1.0, color="r")
        ax.plot(xseries, self._benchmark_timeline, label="benchmark", linewidth=1.0, color="b")

        ax.yaxis.set_major_formatter(FuncFormatter(pct))
        for item in ax.get_yticklabels():
            item.set_size(10)

        ax.set_xlim(0, 240)
        ax.set_xticks(xseries_show)
        ax.set_xticklabels(xlabels_show, fontsize=10)

        ax.grid(True)
        ax.legend(loc=2, prop={"size": 9})

        fig.tight_layout()
        fig.savefig(SNAPSHOT_IMG_FILE)
        # Report the path actually written, not a hard-coded one.
        logger.info("Snapshot image saved at %s.", SNAPSHOT_IMG_FILE)
Example #16
0
	def draw(self, isotherms=True, isochores=True, isentrops=True, qIsolines=True, fig = None):
		"""Draw the pressure-enthalpy chart with the selected line families.

		Args:
			isotherms, isochores, isentrops: call plotIsotherms(),
				plotIsochores(), plotIsentrops() respectively when true.
			qIsolines: call plotDome() when true.
			fig: figure to draw into; a new white 16x10 inch Figure if None.

		Returns:
			The figure that was drawn into.
		"""
		fig = Figure(figsize=(16.0, 10.0), facecolor='white') if fig is None else fig
		self.fig = fig
		ax = self.ax = fig.add_subplot(1,1,1)
		ax.set_xlabel('Enthalpy [kJ/kg]')
		ax.set_ylabel('Pressure [bar]')
		ax.set_title(self.fluidName, y=1.04)
		ax.grid(True, which = 'both')

		# Figure size in dots (inches * dpi).
		width_in, height_in = fig.get_size_inches()
		dots_per_inch = fig.get_dpi()
		self.x_pts = width_in * dots_per_inch
		self.y_pts = height_in * dots_per_inch

		# Limits converted for display: J/kg -> kJ/kg and Pa -> bar.
		ax.set_xlim(self.hMin / 1e3, self.hMax / 1e3)
		ax.set_ylim(self.pMin / 1e5, self.pMax / 1e5)

		families = (
			(qIsolines, 'plotDome'),
			(isochores, 'plotIsochores'),
			(isotherms, 'plotIsotherms'),
			(isentrops, 'plotIsentrops'),
		)
		for enabled, method_name in families:
			if enabled:
				getattr(self, method_name)()

		self.ax.legend(loc='upper center',  bbox_to_anchor=(0.5, 1.05),  fontsize="small", ncol=4)
		fig.tight_layout()
		return fig
Example #17
0
def plotThreeWay(hist, title, filename=None, x_axis_title=None, minimum=None, maximum=None, bins=101):  # the famous 3 way plot (enhanced)
    """Draw a pixel map, a value histogram and a per-channel scatter plot.

    Args:
        hist: 2D (masked) array of per-pixel values.
        title: title of the pixel map.
        filename: None/empty to call ``fig.show()``, a ``PdfPages`` to append
            a page, or anything accepted by ``Figure.savefig``.
        x_axis_title: label of the value axis ('' if None).
        minimum: lower limit; None means 0, ``'minimum'`` means the smallest
            value of ``hist``.
        maximum: upper limit; None or ``'median'`` means twice the median,
            ``'maximum'`` means the largest value. Forced to 1 if below 1
            or if ``hist.all()`` is masked.
        bins: number of histogram bins.
    """
    if minimum is None:
        minimum = 0
    elif minimum == 'minimum':
        minimum = np.ma.min(hist)

    if maximum is None or maximum == 'median':
        maximum = np.ma.median(hist) * 2  # round_to_multiple(median * 2, math.floor(math.log10(median * 2)))
    elif maximum == 'maximum':
        maximum = np.ma.max(hist)  # round_to_multiple(maximum, math.floor(math.log10(maximum)))

    # hist.all() is np.ma.masked presumably only when everything is masked.
    if maximum < 1 or hist.all() is np.ma.masked:
        maximum = 1

    if x_axis_title is None:
        x_axis_title = ''

    fig = Figure()
    FigureCanvas(fig)  # attach a canvas so the figure can be rendered
    fig.patch.set_facecolor('white')

    pixel_ax = fig.add_subplot(311)
    create_2d_pixel_hist(fig, pixel_ax, hist, title=title, x_axis_title="column", y_axis_title="row", z_min=minimum if minimum else 0, z_max=maximum)
    hist_ax = fig.add_subplot(312)
    create_1d_hist(fig, hist_ax, hist, bins=bins, x_axis_title=x_axis_title, y_axis_title="#", x_min=minimum, x_max=maximum)
    scatter_ax = fig.add_subplot(313)
    create_pixel_scatter_plot(fig, scatter_ax, hist, x_axis_title="channel=row + column*336", y_axis_title=x_axis_title, y_min=minimum, y_max=maximum)
    fig.tight_layout()

    if isinstance(filename, PdfPages) and filename:
        filename.savefig(fig)
    elif not filename:
        # NOTE(review): on recent matplotlib, Figure.show() may only work for
        # pyplot-managed figures; confirm this branch still displays anything.
        fig.show()
    else:
        fig.savefig(filename)
Example #18
0
def confusion_matrix_(y_test,
                      y_pred,
                      target_names,
                      normalize=False,
                      title='Confusion matrix',
                      cmap=plt.cm.Blues):
    """Build a labelled confusion-matrix heat map figure.

    Args:
        y_test: true labels (rows of the matrix, axis label 'True label').
        y_pred: predicted labels (columns, axis label 'Predicted label').
        target_names: tick labels for both axes.
        normalize: divide each row by its sum.
        title: axes title.
        cmap: colormap of the image. It was previously ignored, and Blues
            was always used.

    Returns:
        The matplotlib Figure.
    """
    cm = confusion_matrix(y_test, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    # NOTE(review): changes numpy's global print options; kept for compatibility.
    np.set_printoptions(precision=2)
    fig = Figure()
    FigureCanvas(fig)  # attach a canvas so the figure can be rendered
    ax = fig.add_subplot(111)
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    fig.colorbar(im)
    tick_marks = np.arange(len(target_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(target_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(target_names)
    ax.set_title(title)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out last so the title and axis labels are taken into account.
    fig.tight_layout()
    return fig
Example #19
0
 def __init__(self, dataForCandle=None,parent=None, width=5, height=4, dpi=100):
     """Create a width x height inch canvas at ``dpi`` and draw via candlePlot.

     ``parent`` is accepted but not used here.
     """
     figure = Figure(figsize=(width, height), dpi=dpi)
     self.candlePlot(figure.add_subplot(111), dataForCandle, alpha=1.0)
     FigureCanvas.__init__(self, figure)
     figure.tight_layout()
Example #20
0
    def init_gui(self):
        """Load the Glade UI, connect its handlers and embed a sin(x) plot with toolbar."""
        builder = self.builder = Gtk.Builder()
        builder.add_from_file("finally.glade")

        # Each Glade signal name is also the name of its handler method.
        signal_names = (
            "window_destroy", "press_button", "clear_text", "remove_last_char",
            "calculate", "switch_page", "num_base_change", "fix_cursor",
            "enable_textview", "disable_textview", "prog_calc",
            "back_to_future", "plot",
        )
        builder.connect_signals({name: getattr(self, name) for name in signal_names})

        window = self.window = builder.get_object("main_window")
        window.set_size_request(250, 305)
        window.set_title("The Calculator")
        window.set_icon_from_file("/usr/share/pixmaps/thecalculator-icon.png")

        text_view = builder.get_object("class_17")
        self.text_buff = text_view.get_buffer()  # Access buffer from TextView
        text_view.grab_focus()

        default_base = builder.get_object("radiobutton3")
        default_base.set_active(True)
        self.num_base_change(default_base)

        # --- Plot page: sin(x) on roughly [-pi, pi] ---
        plot_window = builder.get_object("scrolledwindow1")
        toolbar_window = builder.get_object("scrolledwindow2")

        fig = Figure(figsize=(5, 5), dpi=120)
        ax = self.ax = fig.add_subplot(111)

        self.x = arange(-3.1415, 3.1415, 0.01)
        self.y = sin(self.x)

        ax.plot(self.x, self.y, label="sin(x)")
        ax.set_xlim(-3.2, 3.2)
        ax.set_ylim(-1.2, 1.2)
        ax.grid(True)

        fig.set_size_inches(9.5, 5.5, forward=True)
        fig.tight_layout()

        self.canvas = FigureCanvas(fig)
        plot_window.add_with_viewport(self.canvas)
        self.canvas.show()

        toolbar_window.add_with_viewport(NavigationToolbar(self.canvas, self.window))
Example #21
0
class FigureCanvas(FigureCanvasQTAgg):
    """Qt canvas with one subplot whose x-axis can be shared with another axes."""

    def __init__(self, parent_axis=None, parent=None, width=5, height=4, dpi=100):
        fig = self.fig = Figure(figsize=(width, height), dpi=dpi)
        super(FigureCanvas, self).__init__(fig)
        self.setParent(parent)

        axis = self.axis = fig.add_subplot(111, sharex=parent_axis)
        fig.tight_layout(pad=3)
        # Plain tick values: disable the offset notation.
        axis.ticklabel_format(useOffset=False)
class MyMplCanvas(FigureCanvas):
    """White canvas with a single axes whose y-axis is hidden."""

    def __init__(self, parent=None):
        self.fig = Figure(dpi=40)
        self.axes = self.fig.add_subplot(111)
        # Axes.hold() was removed in matplotlib 3.0; only call it where it
        # still exists. On newer versions, redrawing code must clear the axes
        # itself (e.g. self.axes.cla()).
        if hasattr(self.axes, 'hold'):
            self.axes.hold(False)
        self.fig.set_facecolor('white')
        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)
        self.axes.get_yaxis().set_visible(False)
        self.fig.tight_layout(pad=1, h_pad=1)
Example #23
0
def main():
    """Command-line entry: compute, save and (optionally) plot the 1D power spectrum.

    Raises:
        OSError: if the output file exists and --clobber was not given.
    """
    parser = argparse.ArgumentParser(
        description="Calculate 1D power spectrum within the EoR window")
    parser.add_argument("-C", "--clobber", dest="clobber", action="store_true",
                        help="overwrite the output files if already exist")
    parser.add_argument("-s", "--step", dest="step", type=float, default=1.1,
                        help="step ratio (>1; default: 1.1) between 2 " +
                        "consecutive radial k bins, i.e., logarithmic grid. " +
                        "if specified a value <=1, then an equal-width grid " +
                        "of current k bin size will be used.")
    parser.add_argument("-F", "--fov", dest="fov",
                        type=float, required=True,
                        help="instrumental FoV to determine the EoR window; " +
                        "SKA1-Low has FoV ~ 3.12 / (nu/200MHz) [deg], i.e., " +
                        "~5.03 @ 124, ~3.95 @ 158, ~3.18 @ 196")
    parser.add_argument("-e", "--conv-width", dest="conv_width",
                        type=float, default=3.0,
                        help="characteristic convolution width (default: 3.0)")
    parser.add_argument("-p", "--k-perp-min", dest="k_perp_min", type=float,
                        help="minimum k wavenumber perpendicular to LoS; " +
                        "unit: [Mpc^-1]")
    parser.add_argument("-P", "--k-perp-max", dest="k_perp_max", type=float,
                        help="maximum k wavenumber perpendicular to LoS")
    parser.add_argument("-l", "--k-los-min", dest="k_los_min", type=float,
                        help="minimum k wavenumber along LoS")
    parser.add_argument("-L", "--k-los-max", dest="k_los_max", type=float,
                        help="maximum k wavenumber along LoS")
    # Help text previously read "1D power power"; the word is "spectrum".
    parser.add_argument("--no-plot", dest="noplot", action="store_true",
                        help="do not plot and save the calculated 1D power " +
                        "spectrum within the EoR window")
    parser.add_argument("-i", "--infile", dest="infile", required=True,
                        help="2D power spectrum FITS file")
    parser.add_argument("-o", "--outfile", dest="outfile", required=True,
                        help="output TXT file to save the PSD data")
    args = parser.parse_args()

    if (not args.clobber) and os.path.exists(args.outfile):
        raise OSError("outfile '%s' already exists" % args.outfile)

    ps2d = PS2D(args.infile, fov=args.fov, e=args.conv_width,
                k_perp_min=args.k_perp_min, k_perp_max=args.k_perp_max,
                k_los_min=args.k_los_min, k_los_max=args.k_los_max)
    ps1d = PS1D(ps2d, step=args.step)
    ps1d.calc_ps1d()
    ps1d.save(args.outfile)

    if not args.noplot:
        fig = Figure(figsize=(8, 8), dpi=150)
        FigureCanvas(fig)  # attach a canvas so savefig can render
        ax = fig.add_subplot(1, 1, 1)
        ps1d.plot(ax=ax)
        fig.tight_layout()
        # The plot is saved next to the TXT output with a .png extension.
        plotfile = os.path.splitext(args.outfile)[0] + ".png"
        fig.savefig(plotfile)
        print("Plotted 1D power spectrum within EoR window: %s" % plotfile)
Example #24
0
class CanvasPanel( wx.Panel ):
	"""wx panel hosting one matplotlib axes for drawing line segments."""

	def __init__( self, parent ):
		wx.Panel.__init__( self, parent )
		self.figure = Figure()
		self.axes = self.figure.add_subplot(111)

		self.canvas = FigureCanvas( self, -1, self.figure )
		self.sizer = wx.BoxSizer( wx.VERTICAL )
		self.sizer.Add( self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW )
		self.SetSizer( self.sizer )
		self.Fit()

	def AddLine( self, x0, y0, x1, y1, colour = 'blue', showArrows = True, annotationText = '' ):
		"""Draw a 2-px segment from (x0, y0) to (x1, y1) and return its Line2D.

		A non-empty annotationText is placed at (x0, y0). When showArrows is
		true and the segment has non-zero length, an arrow covering the first
		half of the segment is drawn as well.
		"""
		lines = self.axes.plot( [x0, x1], [y0, y1], color=colour, linewidth=2 )

		if annotationText:
			self.axes.annotate( annotationText, xy=(x0, y0), xytext=( x0,y0), arrowprops=dict( facecolor='black', shrink=0.75 ) )

		if showArrows and not ( x0 == x1 and y0 == y1 ):
			# Arrow spans half the segment; head size scales with that length.
			dx = 0.5 * ( x1 - x0 )
			dy = 0.5 * ( y1 - y0 )
			line_len = sqrt( dx*dx + dy*dy )
			self.axes.arrow( x0, y0, dx, dy, color=colour, head_width=line_len / 8, head_length=line_len / 5 )

		return lines[0]

	def SetTitle( self, title ):
		"""Set the axes title and redraw."""
		self.axes.set_title(title)
		self.RefreshView()

	def SetAxisTitle( self, xTitle, yTitle ):
		"""Set x/y labels on every axes, re-run tight_layout and redraw."""
		for ax in self.figure.axes:
			ax.set_xlabel( xTitle )
			ax.set_ylabel( yTitle )
		self.figure.tight_layout()
		self.RefreshView()

	def ClearLines( self ):
		"""Clear every axes (lines, titles, labels) and redraw."""
		for ax in self.figure.axes:
			# cla() already removes all lines. Assigning `ax.lines = []` is
			# not supported on matplotlib >= 3.5 (Axes.lines became read-only).
			ax.cla()
		self.RefreshView()

	def RefreshView( self ):
		"""Redraw the figure's canvas."""
		self.figure.canvas.draw()
Example #25
0
File: plot.py Project: trmznt/genaf
def pie_chart(df):
    """Draw a pie chart from a two-column DataFrame and return save_figure's result.

    Column 0 supplies the wedge labels (and the x-axis label via its header);
    column 1 supplies the wedge sizes. Wedges run clockwise from the top.
    """
    wedge_labels = df.iloc[:, 0]
    wedge_sizes = df.iloc[:, 1]

    figure = Figure()
    fig_canvas = FigureCanvas(figure)
    axes = figure.add_subplot(111, aspect=1)

    axes.pie(wedge_sizes, labels=wedge_labels, counterclock=False, startangle=90)
    axes.set_xlabel(df.columns[0])
    figure.tight_layout()

    return save_figure(fig_canvas)
    def __init__(self, master, x_train, y_train, x_test, y_test, evaluator):
        """Build the testing page: decision regions, confusion matrix and scores.

        ``x_train``/``y_train`` are not used by this method; ``evaluator`` is
        expected to be trained already (it must provide ``reduce``, ``clf``
        and ``pipeline``).
        """
        Tk.Frame.__init__(self, master)
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap
        # is the supported lookup for registered colormaps.
        paired_cmap = plt.get_cmap("Paired")
        blues_cmap = plt.get_cmap("Blues")

        x_test_r = evaluator.reduce(x_test)  # feature dimensionality reduction

        # Decision regions of evaluator.clf on a mesh covering the first two
        # reduced features, with the test points overlaid.
        frame_test = Tk.Frame(self)
        frame_test.pack(fill='x', expand=1, padx=15, pady=15)
        figure_test = Figure(figsize=(4, 4), dpi=100)
        subplot_test = figure_test.add_subplot(111)
        subplot_test.set_title('Breast Cancer Testing')
        figure_test.tight_layout()

        h = .02  # step size in the mesh
        x1_min, x1_max = x_test_r[:, 0].min() - 1, x_test_r[:, 0].max() + 1
        x2_min, x2_max = x_test_r[:, 1].min() - 1, x_test_r[:, 1].max() + 1
        xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, h), np.arange(x2_min, x2_max, h))
        yy = evaluator.clf.predict(np.c_[xx1.ravel(), xx2.ravel()])
        yy = yy.reshape(xx1.shape)
        subplot_test.contourf(xx1, xx2, yy, cmap=paired_cmap, alpha=0.8)
        subplot_test.scatter(x_test_r[:, 0], x_test_r[:, 1], c=y_test, cmap=paired_cmap)
        self.attach_figure(figure_test, frame_test)

        # Page 5.1: test performance metrics (precision, recall, F-value)
        y_pred = evaluator.pipeline.predict(x_test)
        frame_matrix = Tk.Frame(self)
        frame_matrix.pack(side=Tk.LEFT, fill='x', expand=1, padx=15, pady=15)
        figure_matrix = Figure(figsize=(4, 4), dpi=100)
        subplot_matrix = figure_matrix.add_subplot(111)

        confmat = confusion_matrix(y_true=y_test, y_pred=y_pred)
        subplot_matrix.matshow(confmat, cmap=blues_cmap, alpha=0.3)
        for i in range(confmat.shape[0]):
            for j in range(confmat.shape[1]):
                subplot_matrix.text(x=j, y=i, s=confmat[i, j], va='center', ha='center')

        subplot_matrix.set_xlabel('predicted label')
        subplot_matrix.set_ylabel('true label')
        self.attach_figure(figure_matrix, frame_matrix)

        # Two-column grid of metric name / value.
        frame_result = Tk.Frame(self)
        frame_result.pack(side=Tk.LEFT, fill='x', expand=1, padx=15, pady=15)
        metrics = [
            ("Accuracy: ", evaluator.pipeline.score(x_test, y_test)),
            ("Precision: ", precision_score(y_true=y_test, y_pred=y_pred)),
            ("Recall: ", recall_score(y_true=y_test, y_pred=y_pred)),
            ("F-value: ", f1_score(y_true=y_test, y_pred=y_pred)),
        ]
        for row, (label, value) in enumerate(metrics):
            Tk.Label(frame_result, text=label).grid(row=row, column=0, sticky=Tk.W)
            Tk.Label(frame_result, text=str(value)).grid(row=row, column=1, sticky=Tk.W)
Example #27
0
    def __init__(self, master, item, event):
        """Open a top-level window plotting the waveform of event number ``item``.

        ``event.plot(ax)`` is called to draw onto the window's Axes; the window
        also gets a matplotlib navigation toolbar.
        """
        tk.Toplevel.__init__(self, master)
        self.wm_title("Waveform: event #{}".format(item))

        figure = Figure(figsize=(4,3), dpi=100)
        tk_canvas = FigureCanvasTkAgg(figure, self)
        tk_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=1)
        # NOTE(review): NavigationToolbar2TkAgg was replaced by
        # NavigationToolbar2Tk in matplotlib 3.0 — confirm the pinned version.
        nav_toolbar = NavigationToolbar2TkAgg(tk_canvas, self)

        axes = figure.gca()
        event.plot(axes)
        axes.set_title("Event #{}".format(item))
        axes.grid(True)

        figure.tight_layout()
        self.update()
Example #28
0
    def save_qa(helper, band, obsids,
                orders_w_solutions,
                reidentified_lines_map, p, m):
        """Save QA figures of the 2D line fit, before and after filtering by ``m``.

        ``m`` selects lines over the concatenation of every order's line
        arrays (presumably a boolean mask of lines well fit by the surface —
        confirm). It is split per order to filter ``reidentified_lines_map``
        and also applied directly to the flattened ``xl``/``yl``/``zl``.

        Two figures are drawn with ``check_fit``: one with all re-identified
        lines and one with only the lines kept by ``m``. They are written
        with ``figlist_to_pngs`` under the ``QA_PATH`` section using the base
        name ``"oh_fit2d_" + <basename of the first obsid>``.
        """
        from matplotlib.figure import Figure
        from igrins.libs.ecfit import get_ordered_line_data, check_fit
        from igrins.libs.qa_helper import figlist_to_pngs

        # Slice the flat mask m into one piece per order, following the
        # map's iteration order.
        keys = list(reidentified_lines_map.keys())
        di_list = [len(reidentified_lines_map[k_][0]) for k_ in keys]
        endi_list = np.add.accumulate(di_list)
        filter_mask = [m[endi - di:endi] for di, endi in zip(di_list, endi_list)]

        reidentified_lines_filtered = [
            (v_[0][mm], v_[1][mm])
            for v_, mm in zip((reidentified_lines_map[k_] for k_ in keys),
                              filter_mask)]

        # NOTE(review): pairs the map's key order with orders_w_solutions;
        # assumes both list the orders in the same sequence — confirm.
        reidentified_lines_map_filtered = dict(zip(orders_w_solutions,
                                                   reidentified_lines_filtered))

        xl, yl, zl = get_ordered_line_data(reidentified_lines_map)

        # Fit check using every re-identified line.
        fig1 = Figure(figsize=(12, 7))
        check_fit(fig1, xl, yl, zl, p,
                  orders_w_solutions,
                  reidentified_lines_map)
        fig1.tight_layout()

        # Fit check using only the lines kept by m.
        fig2 = Figure(figsize=(12, 7))
        check_fit(fig2, xl[m], yl[m], zl[m], p,
                  orders_w_solutions,
                  reidentified_lines_map_filtered)
        fig2.tight_layout()

        igr_path = helper.igr_path
        sky_basename = helper.get_basename(band, obsids[0])
        sky_figs = igr_path.get_section_filename_base("QA_PATH",
                                                      "oh_fit2d",
                                                      "oh_fit2d_" + sky_basename)
        figlist_to_pngs(sky_figs, [fig1, fig2])
def _plot_counts(counts, out_file):
    """Plot per-region up/down systematic variations relative to nominal.

    ``counts`` maps a systematic name to ``(regions, nominal, down, up, data)``.
    Region labels, nominal and data for the reference points come from the
    first entry iterated, so all entries presumably share them (confirm with
    callers). Data is drawn as black points with sqrt(data)/nominal error
    bars, and each systematic's up (^) and down (v) yields are drawn divided
    by nominal. Systematics are spread horizontally around each region centre.
    The figure is written to ``out_file``.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigCanvas
    from numpy import arange

    fig = Figure(figsize=(8, 3))
    canvas = FigCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)

    _, ex_sys = next(iter(counts.items()))
    ex_regs, ex_nom, *_, ex_data = ex_sys

    # Horizontal offsets so markers of different systematics do not overlap:
    # spread evenly across a band of width sysw centred on each region. A
    # single systematic is centred (the old formula divided by zero there).
    sysw = 0.5
    if len(counts) > 1:
        syst_increment = sysw / (len(counts) - 1)
        syst_initial = -sysw / 2
    else:
        syst_increment = 0.0
        syst_initial = 0.0
    sys_num = {x: n for n, x in enumerate(sorted(counts.keys()))}

    def get_offset(syst):
        """Return the x offset of systematic ``syst`` from a region centre."""
        return syst_initial + sys_num[syst] * syst_increment

    x_vals_base = arange(len(ex_regs)) + 0.5
    ax.errorbar(
        x_vals_base, ex_data / ex_nom, yerr=ex_data**0.5 / ex_nom,
        fmt='o', color='k', label='data')

    ax.set_xticks(x_vals_base)
    ax.set_xticklabels([reg_names.get(x, x) for x in ex_regs])
    for sysnm, (regions, nom, down, up, _data) in sorted(counts.items()):
        x_vals = x_vals_base + get_offset(sysnm)
        color = _syst_colors[sys_num[sysnm]]
        ax.set_xlim(0, len(regions))
        ax.plot(x_vals, up / nom, '^', color=color,
                label=alpha_names.get(sysnm, sysnm))
        ax.plot(x_vals, down / nom, 'v', color=color)

    ax.tick_params(labelsize=_txt_size)
    ax.legend(
        numpoints=1, ncol=5, borderaxespad=0.0, loc='upper left',
        handletextpad=0, columnspacing=1, framealpha=0.5, fontsize=10)

    ax.axhline(1, linestyle='--', color=(0, 0, 0, 0.5))
    ax.set_ylim(0.8, 1.3)
    ax.set_ylabel('Variation / Nominal')
    fig.tight_layout(pad=0.3, h_pad=0.3, w_pad=0.3)
    # The savefig/print_figure keyword is bbox_inches; "bboxinches" is not
    # a recognised option.
    canvas.print_figure(out_file, bbox_inches='tight')
    def __init__(self, master, x_train, y_train, x_test, y_test, evaluator):
        """Build the hyper-parameter page from a GridSearchCV over C (and gamma).

        Fits a 10-fold, accuracy-scored GridSearchCV of ``evaluator.pipeline``
        on the training data. It plots the mean accuracy of the linear-kernel
        candidates against C. A scrollable text box then lists the best
        parameters, a classification report on the test data and every
        searched candidate.
        """
        Tk.Frame.__init__(self, master)
        frame_linear_param = Tk.Frame(self)
        frame_linear_param.pack(fill='x', expand=1, padx=15, pady=15)
        figure_gs = Figure(figsize=(6, 4), dpi=100)
        subplot_linear_param = figure_gs.add_subplot(111)
        subplot_linear_param.set_xscale('log')
        subplot_linear_param.set_ylim([0.5, 1.0])
        subplot_linear_param.set_xlabel("C")
        subplot_linear_param.set_ylabel("Accuracy")
        subplot_linear_param.set_title("GridSearchCV on parameter C in SVM with linear-kernel")

        param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
        param_grid = [{'clf__C': param_range, 'clf__kernel': ['linear']},
                      {'clf__C': param_range, 'clf__gamma': param_range, 'clf__kernel': ['rbf']}]
        gs = GridSearchCV(estimator=evaluator.pipeline, param_grid=param_grid, scoring='accuracy', cv=10,
                          n_jobs=-1)
        gs = gs.fit(x_train, y_train)
        y_true, y_pred = y_test, gs.predict(x_test)

        # One (params, mean test accuracy, std of fold accuracies) row per
        # candidate, in search order. grid_scores_ was removed in
        # scikit-learn 0.20; use cv_results_ (0.18+) and fall back otherwise.
        if hasattr(gs, "cv_results_"):
            results = gs.cv_results_
            candidates = list(zip(results["params"],
                                  results["mean_test_score"],
                                  results["std_test_score"]))
        else:
            candidates = [(params, mean_score, scores.std())
                          for params, mean_score, scores in gs.grid_scores_]

        # The first len(param_range) candidates come from the linear-kernel
        # grid, which is the first entry of param_grid.
        linear_candidates = candidates[:len(param_range)]
        # Real lists: a Python 3 map object cannot be plotted.
        subplot_linear_param.plot([params['clf__C'] for params, _, _ in linear_candidates],
                                  [mean_score for _, mean_score, _ in linear_candidates],
                                  linewidth=1, label='svm_linear', color="blue", marker='o', markersize=5)
        subplot_linear_param.grid()
        # Lay out after the title and labels exist so they are accounted for.
        figure_gs.tight_layout()
        self.attach_figure(figure_gs, frame_linear_param)

        frame_rbf_param = Tk.Frame(self)
        frame_rbf_param.pack(fill='x', expand=1, padx=15, pady=15)
        scrollbar = Tk.Scrollbar(frame_rbf_param)
        scrollbar.pack(side=Tk.RIGHT, fill=Tk.Y)
        text_rbf_param = Tk.Text(frame_rbf_param, width=800, wrap='none', yscrollcommand=scrollbar.set)
        text_rbf_param.insert(Tk.END, "1. Best parameter: " + str(gs.best_params_) + "\n\n")
        text_rbf_param.insert(Tk.END, "2. Best parameter performance on testing data.\n\n ")
        text_rbf_param.insert(Tk.END, classification_report(y_true, y_pred))
        text_rbf_param.insert(Tk.END, "\n\n")
        text_rbf_param.insert(Tk.END, "3. All parameter searched by GridSearchCV.\n\n ")
        log = [["%0.3f" % (mean_score),
                "(+/-%0.03f)" % (std_score * 2),
                params["clf__C"],
                # dict.has_key does not exist in Python 3; empty for linear.
                params.get("clf__gamma") or "",
                params["clf__kernel"]]
               for params, mean_score, std_score in candidates]
        text_rbf_param.insert(Tk.END, tabulate(log, headers=["Accuracy", "SD", "C", "gamma", "type"]))
        text_rbf_param.pack()
        scrollbar.config(command=text_rbf_param.yview)
Example #31
0
def visualize_image_attr_multiple(attr: ndarray,
                                  original_image: Union[None, ndarray],
                                  methods: List[str],
                                  signs: List[str],
                                  titles: Union[None, List[str]] = None,
                                  fig_size: Tuple[int, int] = (8, 6),
                                  use_pyplot: bool = True,
                                  **kwargs: Any):
    r"""
    Visualizes attribution using multiple visualization methods displayed
    in a 1 x k grid, where k is the number of desired visualizations.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (H, W, C), with
                    channels as last dimension. Shape must also match that of
                    the original image if provided.
        original_image (numpy.array, optional):  Numpy array corresponding to
                    original image. Shape must be in the form (H, W, C), with
                    channels as the last dimension. Image can be provided either
                    with values in range 0-1 or 0-255. This is a necessary
                    argument for any visualization method which utilizes
                    the original image.
        methods (list of strings): List of strings of length k, defining method
                        for each visualization. Each method must be a valid
                        string argument for method to visualize_image_attr.
        signs (list of strings): List of strings of length k, defining signs for
                        each visualization. Each sign must be a valid
                        string argument for sign to visualize_image_attr.
        titles (list of strings, optional):  List of strings of length k, providing
                    a title string for each plot. If None is provided, no titles
                    are added to subplots.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (8, 6)
        use_pyplot (boolean, optional): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        **kwargs (Any, optional): Any additional arguments which will be passed
                    to every individual visualization. Such arguments include
                    `show_colorbar`, `alpha_overlay`, `cmap`, etc.


    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which the visualizations are created:
                    a pyplot figure when `use_pyplot` is True, otherwise a
                    standalone `Figure`.
        - **axis** (*list or array of matplotlib.pyplot.axis*):
                    The k axes of the 1 x k grid, in the order of `methods`.
                    When k == 1 this is a one-element list.

    Examples::

        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> ig = IntegratedGradients(net)
        >>> # Computes integrated gradients for class 3 for a given image .
        >>> attribution, delta = ig.attribute(orig_image, target=3)
        >>> # Displays original image and heat map visualization of
        >>> # computed attributions side by side.
        >>> _ = visualize_image_attr_multiple(attribution, orig_image,
        >>>                     ["original_image", "heat_map"], ["all", "positive"])
    """
    assert len(methods) == len(
        signs), "Methods and signs array lengths must match."
    if titles is not None:
        assert len(methods) == len(titles), (
            "If titles list is given, length must "
            "match that of methods list.")
    if use_pyplot:
        plt_fig = plt.figure(figsize=fig_size)
    else:
        plt_fig = Figure(figsize=fig_size)
    plt_axis = plt_fig.subplots(1, len(methods))

    # subplots() returns a bare Axes for a 1x1 grid; wrap it so it can be
    # indexed like the multi-axes case.
    if len(methods) == 1:
        plt_axis = [plt_axis]

    for i, (method, sign) in enumerate(zip(methods, signs)):
        visualize_image_attr(attr,
                             original_image=original_image,
                             method=method,
                             sign=sign,
                             plt_fig_axis=(plt_fig, plt_axis[i]),
                             use_pyplot=False,
                             title=titles[i] if titles else None,
                             **kwargs)
    plt_fig.tight_layout()
    if use_pyplot:
        plt.show()
    return plt_fig, plt_axis
Example #32
0
def plot_complex_image_map_line_profile_using_interface_plane(
        image_data,
        amplitude_data,
        phase_data,
        line_profile_amplitude_data,
        line_profile_phase_data,
        interface_plane,
        atom_plane_list=None,
        data_scale=1,
        atom_list=None,
        extra_marker_list=None,
        amplitude_image_lim=None,
        phase_image_lim=None,
        plot_title='',
        add_color_wheel=False,
        color_bar_markers=None,
        vector_to_plot=None,
        rotate_atom_plane_list_90_degrees=False,
        prune_outer_values=False,
        figname="map_data.jpg"):
    """
    Save a stacked figure: image, complex map, 2D colour bar, two line profiles.

    From top to bottom:
    - ``image_data`` with the atom planes drawn on it.
    - The map drawn by ``_make_subplot_map_from_complex_regular_grid`` from
      ``amplitude_data``/``phase_data``.
    - A phase (x) by amplitude (y) colour bar.
    - One line-profile panel for ``line_profile_amplitude_data``.
    - One line-profile panel for ``line_profile_phase_data``.

    Parameters
    ----------
    image_data : 2D array
        Top-panel image; shape (rows, cols) times ``data_scale`` gives the
        y/x extent.
    amplitude_data, phase_data : array
        Amplitude and phase of the complex map.
    line_profile_amplitude_data, line_profile_phase_data : 2D array
        Columns 0 and 1 are passed as the two data arguments of
        ``_make_subplot_line_profile`` (presumably position and value).
    interface_plane : atom plane
        Drawn in blue on the image and in red on the map; also passed to the
        map helper as ``atom_plane_marker``.
    atom_plane_list : list of atom planes, optional
        Drawn in red on the image.
    data_scale : number
        Multiplies pixel positions to obtain plotted coordinates.
    atom_list : list of Atom_Position instances
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
    amplitude_image_lim, phase_image_lim : (min, max), optional
        Limits for the map and the colour bar. When None, the NaN-ignoring
        range of ``amplitude_data`` / ``phase_data`` is used.
    plot_title : str
        Title of the image panel.
    add_color_wheel : bool
        Inset a colour wheel in the lower-left corner of the map panel.
    color_bar_markers : iterable of (phase, label), optional
        White vertical lines with text labels on the colour bar.
    vector_to_plot : passed through to the map helper.
    rotate_atom_plane_list_90_degrees : bool
        Draw each plane of ``atom_plane_list`` as one segment from its first
        position, rotated by 90 degrees.
    prune_outer_values : passed through to ``_make_subplot_line_profile``.
    figname : str
        Output path for ``Figure.savefig``.
    """
    # The colour bar indexes both limits. With the None defaults this used
    # to fail with TypeError, so fall back to the data range.
    if amplitude_image_lim is None:
        amplitude_image_lim = (np.nanmin(amplitude_data),
                               np.nanmax(amplitude_data))
    if phase_image_lim is None:
        phase_image_lim = (np.nanmin(phase_data), np.nanmax(phase_data))

    number_of_line_profiles = 2

    figsize = (10, 18 + 2 * number_of_line_profiles)

    fig = Figure(figsize=figsize)
    FigureCanvas(fig)

    # 100 grid rows for image/map/colour bar, plus 10 rows per line profile.
    gs = GridSpec(100 + 10 * number_of_line_profiles, 95)

    image_ax = fig.add_subplot(gs[0:45, :])
    distance_ax = fig.add_subplot(gs[45:90, :])
    colorbar_ax = fig.add_subplot(gs[90:100, :])

    line_profile_ax_list = []
    for i in range(number_of_line_profiles):
        gs_y_start = 100 + 10 * i
        line_profile_ax = fig.add_subplot(gs[gs_y_start:gs_y_start + 10, :])
        line_profile_ax_list.append(line_profile_ax)

    image_y_lim = (0, image_data.shape[0] * data_scale)
    image_x_lim = (0, image_data.shape[1] * data_scale)

    image_clim = to._get_clim_from_data(image_data, sigma=2)
    image_cax = image_ax.imshow(image_data,
                                origin='lower',
                                extent=[
                                    image_x_lim[0], image_x_lim[1],
                                    image_y_lim[0], image_y_lim[1]
                                ])

    image_cax.set_clim(image_clim[0], image_clim[1])
    image_ax.set_xlim(image_x_lim[0], image_x_lim[1])
    image_ax.set_ylim(image_y_lim[0], image_y_lim[1])
    image_ax.set_title(plot_title)

    if atom_plane_list is not None:
        for atom_plane in atom_plane_list:
            atom_plane_x = np.array(atom_plane.get_x_position_list())
            atom_plane_y = np.array(atom_plane.get_y_position_list())
            if rotate_atom_plane_list_90_degrees:
                # Rotate the plane's end-to-end vector by 90 degrees about
                # its first position and draw just that segment.
                start_x = atom_plane_x[0]
                start_y = atom_plane_y[0]
                delta_y = (atom_plane_x[-1] - atom_plane_x[0])
                delta_x = -(atom_plane_y[-1] - atom_plane_y[0])
                atom_plane_x = np.array([start_x, start_x + delta_x])
                atom_plane_y = np.array([start_y, start_y + delta_y])
            image_ax.plot(atom_plane_x * data_scale,
                          atom_plane_y * data_scale,
                          color='red',
                          lw=2)

    atom_plane_x = np.array(interface_plane.get_x_position_list())
    atom_plane_y = np.array(interface_plane.get_y_position_list())
    image_ax.plot(atom_plane_x * data_scale,
                  atom_plane_y * data_scale,
                  color='blue',
                  lw=2)

    _make_subplot_map_from_complex_regular_grid(
        distance_ax,
        amplitude_data,
        phase_data,
        atom_list=atom_list,
        amplitude_image_lim=amplitude_image_lim,
        phase_image_lim=phase_image_lim,
        atom_plane_marker=interface_plane,
        extra_marker_list=extra_marker_list,
        vector_to_plot=vector_to_plot)
    distance_ax.plot(atom_plane_x * data_scale,
                     atom_plane_y * data_scale,
                     color='red',
                     lw=2)

    line_profile_data_list = [
        line_profile_amplitude_data, line_profile_phase_data
    ]

    for line_profile_ax, line_profile_data in zip(line_profile_ax_list,
                                                  line_profile_data_list):
        _make_subplot_line_profile(line_profile_ax,
                                   line_profile_data[:, 0],
                                   line_profile_data[:, 1],
                                   prune_outer_values=prune_outer_values,
                                   scale_x=data_scale)

    # 2D colour bar: a ~100x100 grid of (amplitude, phase) samples rendered
    # with get_rgb_array, phase along x and amplitude along y.
    amplitude_delta = 0.01 * (amplitude_image_lim[1] - amplitude_image_lim[0])
    phase_delta = 0.01 * (phase_image_lim[1] - phase_image_lim[0])
    colorbar_mgrid = np.mgrid[
        amplitude_image_lim[0]:amplitude_image_lim[1]:amplitude_delta,
        phase_image_lim[0]:phase_image_lim[1]:phase_delta]
    colorbar_rgb = get_rgb_array(colorbar_mgrid[1], colorbar_mgrid[0])
    colorbar_ax.imshow(colorbar_rgb,
                       origin='lower',
                       extent=[
                           phase_image_lim[0], phase_image_lim[1],
                           amplitude_image_lim[0], amplitude_image_lim[1]
                       ])

    colorbar_ax.set_xlabel("Phase", size=6)

    if color_bar_markers is not None:
        for color_bar_marker in color_bar_markers:
            colorbar_ax.axvline(color_bar_marker[0], color='white')
            colorbar_ax.text(color_bar_marker[0],
                             amplitude_image_lim[0] * 0.97,
                             color_bar_marker[1],
                             transform=colorbar_ax.transData,
                             va='top',
                             ha='center',
                             fontsize=8)

    if add_color_wheel:
        ax_magnetic_color_wheel_gs = GridSpecFromSubplotSpec(
            40, 40, subplot_spec=gs[45:90, :])[30:39, 3:12]
        ax_magnetic_color_wheel = fig.add_subplot(ax_magnetic_color_wheel_gs)
        ax_magnetic_color_wheel.set_axis_off()

        make_color_wheel(ax_magnetic_color_wheel)

    fig.tight_layout()
    fig.savefig(figname)
    plt.close(fig)
Example #33
0
def graph(sensor_id):
    """Return a PNG response plotting one logged field of ``sensor_id`` over time.

    Query parameters:
        field: name of the data field to plot (required).
        hours: optional integer; only readings less than this many hours
            older than the newest reading are plotted.

    Aborts with 400 if ``field`` is missing or ``hours`` is not an integer.
    Aborts with 500 if the sensor is not listed, its log file is missing or
    empty, or ``field`` is not a logged field. The response carries headers
    that disable client caching.
    """
    import email.utils

    # TODO: Add date filtering options

    sensor_config = current_app.config["SENSOR_CONFIG"]
    data_dir = current_app.config["DATA_DIR"]
    logfile = util.get_logfile_path(data_dir, sensor_id)

    if sensor_id not in util.get_sensor_list(sensor_config):
        abort(500, "Sensor ID is not in the sensor list.")
    if not os.path.isfile(logfile):
        abort(500, "No data exists for the sensor.")

    json_data = util.get_json(logfile)

    args = request.args

    if "field" not in args:
        return abort(400, "Field not specified.")
    field = args["field"]

    # An empty log has no newest reading to filter or plot against.
    if not json_data:
        abort(500, "No data exists for the sensor.")

    # Valid fields are the keys of the first reading's "data", minus "id".
    first_reading = json_data[next(iter(json_data))]
    valid_fields = [name for name in first_reading["data"] if name != "id"]
    if field not in valid_fields:
        return abort(500, "Invalid field.")

    hours_back = None
    if "hours" in args:
        try:
            hours_back = int(args["hours"])
        except ValueError:
            abort(400, "The hours parameter must be an integer.")

    # Parse each timestamp once.
    readings = [(util.get_local_datetime(date_str), all_info)
                for date_str, all_info in json_data.items()]
    latest_date = max(date for date, _ in readings)

    points = []
    for date, all_info in readings:
        if (hours_back is not None and
                (latest_date - date).total_seconds() / 3600 >= hours_back):
            continue
        if field in all_info["data"]:
            points.append((date, all_info["data"][field]))
        else:
            print("Warning: data line did not contain {}!".format(field))

    # Plot chronologically regardless of the order of keys in the log.
    points.sort(key=lambda point: point[0])
    dates = [date for date, _ in points]
    values = [value for _, value in points]

    # The key to using matplotlib with flask: don't use pyplot!
    # https://matplotlib.org/3.1.1/faq/howto_faq.html

    fig = Figure(figsize=(13, 3))
    ax = fig.subplots()

    ax.plot(dates, values, util.plot_formats[field])
    ax.set_title("{} vs. Time".format(util.data_units[field][0]))
    ax.set_xlabel("Time")
    ax.set_ylabel("{} ({})".format(
        *[val for i, val in enumerate(util.data_units[field]) if i != 1]))
    ax.grid()

    ax.margins(x=0.01, y=0.15)  # Margins are percentages
    fig.tight_layout()

    bg_color = "#ededed"
    fig.patch.set_facecolor(bg_color)
    ax.patch.set_facecolor(bg_color)

    img_io = BytesIO()
    fig.savefig(img_io, format='png', facecolor=fig.get_facecolor())
    img_io.seek(0)

    response = make_response(img_io.getvalue())
    response.headers['Content-Type'] = 'image/png'

    # Disable caching. Last-Modified must be an HTTP-date
    # ("Mon, 01 Jan 2024 12:00:00 GMT"), not str() of a naive datetime.
    response.headers['Last-Modified'] = email.utils.format_datetime(
        datetime.datetime.now(datetime.timezone.utc), usegmt=True)
    response.headers[
        'Cache-Control'] = 'no-store, no-cache, must-revalidate, post-check=0, pre-check=0, max-age=0'
    response.headers['Pragma'] = 'no-cache'
    response.headers['Expires'] = '-1'

    return response
class PlottingCanvasView(QtWidgets.QWidget, PlottingCanvasViewInterface):
    def __init__(self, quick_edit, settings, parent=None):
        """Build the view: a plot toolbar above a vertical splitter containing
        the matplotlib canvas and the quick-edit widget.

        Args:
            quick_edit: widget placed under the canvas; it also reports the
                autoscale state.
            settings: source of min_y_range, y_axis_margin and the other
                display options this view reads.
            parent: optional Qt parent widget.
        """
        super().__init__(parent)
        # Display options (to become user-editable in the settings later).
        self._settings = settings
        self._min_y_range = settings.min_y_range
        self._y_axis_margin = settings.y_axis_margin
        self._x_tick_labels, self._y_tick_labels = None, None
        self._shaded_regions = {}

        # Figure, its Qt canvas and the toolbar driving that canvas.
        self.fig = Figure()
        self.fig.canvas = FigureCanvas(self.fig)
        canvas = self.fig.canvas
        canvas.setMinimumHeight(500)
        self.toolBar = PlotToolbar(canvas, self)

        # Start with a single Mantid axis.
        self.fig, _unused_axes = get_plot_fig(overplot=False,
                                              ax_properties=None,
                                              axes_num=1,
                                              fig=self.fig)
        self._number_of_axes = 1
        self._color_queue = [ColorQueue(DEFAULT_COLOR_CYCLE)]

        # Canvas above, quick-edit below; neither pane may collapse.
        splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical)
        splitter.addWidget(self.fig.canvas)
        self._quick_edit = quick_edit
        splitter.addWidget(self._quick_edit)
        splitter.setChildrenCollapsible(False)

        vbox = QtWidgets.QVBoxLayout()
        for widget in (self.toolBar, splitter):
            vbox.addWidget(widget)
        self.setLayout(vbox)

        self._plot_information_list = []  # type : List[PlotInformation]

    @property
    def autoscale_state(self):
        """Autoscale state as reported by the quick-edit widget."""
        quick_edit = self._quick_edit
        return quick_edit.autoscale_state

    @property
    def plotted_workspace_information(self):
        """The list of plot-information records for everything currently plotted."""
        info_list = self._plot_information_list
        return info_list

    @property
    def plotted_workspaces_and_indices(self):
        """Return two parallel lists: plotted workspace names and their indices."""
        records = self._plot_information_list
        names = [record.workspace_name for record in records]
        indices = [record.index for record in records]
        return names, indices

    @property
    def num_plotted_workspaces(self):
        """How many plot-information records are currently held."""
        records = self._plot_information_list
        return len(records)

    @property
    def number_of_axes(self):
        """Number of axes the canvas was last built with."""
        count = self._number_of_axes
        return count

    def set_x_ticks(self, x_ticks=None):
        """Store text labels for the x ticks (None clears them)."""
        self._x_tick_labels = x_ticks

    def set_y_ticks(self, y_ticks=None):
        """Store text labels for the y ticks (None clears them)."""
        self._y_tick_labels = y_ticks

    def create_new_plot_canvas(self, num_axes):
        """Rebuild the figure as an empty grid of ``num_axes`` axes and redraw.

        Resets the toolbar gridline flags, forgets all plot information and
        gives every axis a fresh colour queue. Condensed mode removes the
        spacing between subplots; otherwise tight_layout is applied.
        """
        self.toolBar.reset_gridline_flags()
        self._plot_information_list = []
        self._number_of_axes = num_axes
        self._color_queue = [ColorQueue(DEFAULT_COLOR_CYCLE)
                             for _axis in range(num_axes)]
        self.fig.clf()
        self.fig, _unused_axes = get_plot_fig(overplot=False,
                                              ax_properties=None,
                                              axes_num=num_axes,
                                              fig=self.fig)
        if not self._settings.is_condensed:
            self.fig.tight_layout()
        else:
            self.fig.subplots_adjust(wspace=0, hspace=0)
        self.fig.canvas.draw()

    def clear_all_workspaces_from_plot(self):
        """Clears all workspaces from the plot.

        Clears every axis and its tracked workspaces, resets the property
        cycle and colour queues, removes all shaded regions and forgets the
        stored plot information.
        """
        for ax in self.fig.axes:
            ax.cla()
            ax.tracked_workspaces.clear()
            ax.set_prop_cycle(None)

        for color_queue in self._color_queue:
            color_queue.reset()

        # _shaded_regions maps workspace name -> ShadedRegionInfo; iterating
        # the dict itself would call .remove() on the name strings.
        for shaded_region in self._shaded_regions.values():
            shaded_region.remove()
        self._shaded_regions = {}

        self._plot_information_list = []

    def _make_plot(self, workspace_plot_info: WorkspacePlotInformation):
        """Plot one workspace spectrum on its target axis.

        Returns the axis number used, or -1 when the workspace cannot be
        retrieved from the AnalysisDataService (nothing is recorded then).
        If the workspace has a shaded region and errors were requested, the
        region is shaded instead of drawing error bars.
        """
        name = workspace_plot_info.workspace_name
        try:
            workspace = AnalysisDataService.Instance().retrieve(name)
        except (RuntimeError, KeyError):
            return -1
        self._plot_information_list.append(workspace_plot_info)

        show_errors = workspace_plot_info.errors
        spectrum_index = workspace_plot_info.index
        target = workspace_plot_info.axis
        target_ax = self.fig.axes[target]

        kwargs = self._get_plot_kwargs(workspace_plot_info)
        kwargs['color'] = self._color_queue[target]()
        if name in self._shaded_regions and show_errors:
            show_errors = False
            self.shade_region(kwargs['color'], name)

        _do_single_plot(target_ax,
                        workspace,
                        spectrum_index,
                        errors=show_errors,
                        plot_kwargs=kwargs)
        return target

    def add_workspaces_to_plot(
            self, workspace_plot_info_list: List[WorkspacePlotInformation]):
        """Plot each entry of ``workspace_plot_info_list`` (workspace name,
        workspace index and target axis) and apply tick formatting.

        Entries whose workspace cannot be retrieved are skipped. In condensed
        mode, inner tick labels are hidden, as are those of grid cells that
        hold no axis.
        """
        nrows, ncols = get_num_row_and_col(self._number_of_axes)
        for info in workspace_plot_info_list:
            axis_number = self._make_plot(info)
            if axis_number >= 0:
                self._set_text_tick_labels(axis_number)
                if self._settings.is_condensed:
                    self.hide_axis(axis_number, nrows, ncols)
        # Remove labels from the empty cells of the grid.
        if self._settings.is_condensed:
            first_empty = int(self._number_of_axes)
            for empty_axis in range(first_empty, int(nrows * ncols)):
                self.hide_axis(empty_axis, nrows, ncols)

    def add_shaded_region(self, workspace_name, axis_number, x_values,
                          y1_values, y2_values):
        """Create or update the shaded region between y1 and y2 for a workspace."""
        axis = self.fig.axes[axis_number]
        if workspace_name not in self._shaded_regions:
            self._shaded_regions[workspace_name] = ShadedRegionInfo(
                workspace_name=workspace_name,
                axis=axis,
                x_values=x_values,
                y1_values=y1_values,
                y2_values=y2_values)
            return
        existing_region = self._shaded_regions[workspace_name]
        existing_region.update(axis=axis,
                               x_values=x_values,
                               y1_values=y1_values,
                               y2_values=y2_values)

    def _wrap_labels(self, labels: list) -> list:
        """Return ``labels`` with each one wrapped to lines of at most
        ``self._settings.wrap_width`` characters, joined by newlines."""
        width = self._settings.wrap_width
        wrapped = []
        for text in labels:
            wrapped.append("\n".join(wrap(text, width)))
        return wrapped

    def _set_text_tick_labels(self, axis_number):
        """Apply plain numeric tick formatting and any stored text tick labels.

        Major ticks use '{x:g}' and minor tick labels are suppressed on both
        axes. If text labels were stored, they are placed at ticks 0..n-1 and
        wrapped; x labels are also rotated and right-aligned.
        """
        ax = self.fig.axes[axis_number]
        # Keep tick values as written rather than offset/"simplified" ones.
        for axis_obj in (ax.xaxis, ax.yaxis):
            axis_obj.set_major_formatter(StrMethodFormatter('{x:g}'))
            axis_obj.set_minor_formatter(NullFormatter())

        x_labels = self._x_tick_labels
        if x_labels:
            ax.set_xticks(range(len(x_labels)))
            ax.set_xticklabels(self._wrap_labels(x_labels),
                               fontsize=self._settings.font_size,
                               rotation=self._settings.rotation,
                               ha="right")
        y_labels = self._y_tick_labels
        if y_labels:
            ax.set_yticks(range(len(y_labels)))
            ax.set_yticklabels(self._wrap_labels(y_labels),
                               fontsize=self._settings.font_size)

    def hide_axis(self, axis_number, nrows, ncols):
        """Remove redundant tick labels from one subplot of an ``nrows`` x ``ncols`` grid.

        - Subplots not on the bottom row lose their x tick labels and x label.
        - Subplots in an interior column lose their y tick labels and y label.
        - Subplots in the last column (when there is more than one column)
          have their y ticks and y label moved to the right-hand side.
        """
        row, col = convert_index_to_row_and_col(axis_number, nrows, ncols)
        ax = self.fig.axes[axis_number]
        on_bottom_row = row == nrows - 1
        in_first_col = col == 0
        in_last_col = col == ncols - 1

        if not on_bottom_row:
            ax.set_xticklabels([""] * len(ax.get_xticks().tolist()))
            ax.xaxis.label.set_visible(False)

        if not in_first_col and not in_last_col:
            ax.set_yticklabels([""] * len(ax.get_yticks().tolist()))
            ax.yaxis.label.set_visible(False)
        elif in_last_col and ncols > 1:
            ax.yaxis.set_label_position('right')
            ax.yaxis.tick_right()

    def remove_workspace_info_from_plot(
            self, workspace_plot_info_list: List[WorkspacePlotInformation]):
        """Remove the given workspace/axis entries from the plot.

        For each entry whose workspace still exists in the AnalysisDataService,
        every plotted entry with the same workspace name and axis:
        - has its colours handed back via
          ``_update_color_queue_on_workspace_removal``;
        - has its workspace artists removed from the axis;
        - is dropped from ``self._plot_information_list``;
        - has ``.remove()`` called on its shaded region, if one is recorded and
          the axis has collections. The record stays in ``self._shaded_regions``.

        When no workspaces remain plotted, the colour cycle is reset.

        The input list is traversed in reverse but is no longer reversed in
        place, so the caller's list is left untouched.
        """
        # Traverse in reverse so that we can maintain a unique color queue
        # (see _update_color_queue_on_workspace_removal). reversed() avoids
        # mutating the caller's list, which list.reverse() did.
        for workspace_plot_info in reversed(workspace_plot_info_list):
            workspace_name = workspace_plot_info.workspace_name
            if not AnalysisDataService.Instance().doesExist(workspace_name):
                continue

            workspace = AnalysisDataService.Instance().retrieve(workspace_name)
            # iterate a copy because matching entries are removed from the list
            for plotted_information in self._plot_information_list.copy():
                if not (workspace_name == plotted_information.workspace_name
                        and workspace_plot_info.axis == plotted_information.axis):
                    continue
                self._update_color_queue_on_workspace_removal(
                    workspace_plot_info.axis, workspace_name)
                axis = self.fig.axes[workspace_plot_info.axis]
                axis.remove_workspace_artists(workspace)
                self._plot_information_list.remove(plotted_information)
                # clear shaded regions from plot
                if axis.collections and \
                        plotted_information.workspace_name in self._shaded_regions:
                    self._shaded_regions[
                        plotted_information.workspace_name].remove()
        # If we have no plotted lines, reset the color cycle
        if self.num_plotted_workspaces == 0:
            self._reset_color_cycle()

    def remove_workspace_from_plot(self, workspace):
        """Remove every plotted entry that belongs to ``workspace``.

        Matching entries hand their colours back to the colour queue, have
        their artists removed from the owning axis and are dropped from
        ``self._plot_information_list``. A shaded region recorded for the
        workspace is removed and deleted from ``self._shaded_regions``.
        """
        for info in self._plot_information_list.copy():
            name = info.workspace_name
            if name != workspace.name():
                continue
            self._update_color_queue_on_workspace_removal(info.axis, name)
            owning_axis = self.fig.axes[info.axis]
            owning_axis.remove_workspace_artists(workspace)
            self._plot_information_list.remove(info)
            shaded = self._shaded_regions
            if name in shaded.keys():
                shaded[name].remove()
                del shaded[name]

    def _update_color_queue_on_workspace_removal(self, axis_number,
                                                 workspace_name):
        """Return the colours of ``workspace_name``'s artists to the axis colour queue.

        Does nothing if the workspace is not tracked on the axis.

        When colours are being repeated (the axis has more lines than
        ``NUMBER_OF_COLOURS``), a colour that is still used by some line on the
        axis is not re-added to the queue. This ensures lines share a colour
        only when there are more lines than colours. Each artist is judged on
        its own colour: one skipped colour no longer stops the remaining
        artists from being processed.
        """
        axis = self.fig.axes[axis_number]
        try:
            artist_info = axis.tracked_workspaces[workspace_name]
        except KeyError:
            return
        # Nothing in this loop adds or removes lines, so the "are we repeating
        # colours?" check and the in-use colours can be computed once.
        lines = axis.get_lines()
        repeating_colors = len(lines) > NUMBER_OF_COLOURS
        current_colors = [line.get_c() for line in lines] if repeating_colors else []
        for ws_artist in artist_info:
            for artist in ws_artist._artists:
                color = get_color_from_artist(artist)
                if repeating_colors and color in current_colors:
                    # Colour still in use: skip only this one (previously a
                    # `return` here also dropped the remaining artists' colours).
                    continue
                self._color_queue[axis_number] += color

    # Ads observer functions
    def replace_specified_workspace_in_plot(self, workspace):
        """Swap the plotted artists of ``workspace`` for its updated instance.

        Each plotted entry whose name equals ``workspace.name()`` has its
        axis artists replaced. If such an entry is plotted with errors and a
        shaded region is recorded for it, the old shading is removed and
        re-shaded in the colour of the workspace's first tracked artist.
        The figure is redrawn once at the end.
        """
        for info in self._plot_information_list:
            name = workspace.name()
            if name != info.workspace_name:
                continue
            target_axis = self.fig.axes[info.axis]
            target_axis.replace_workspace_artists(workspace)
            if name in self._shaded_regions.keys() and info.errors:
                # drop the old shading before drawing the new one
                self._shaded_regions[name].remove()
                leading_artist = target_axis.tracked_workspaces[name][0]._artists[0]
                self.shade_region(get_color_from_artist(leading_artist), name)

        self.redraw_figure()

    # not used for tiled plots
    def replot_workspace_with_error_state(self, workspace_name,
                                          with_errors: bool):
        """Replot ``workspace_name`` with or without error bars, then redraw.

        For every artist tracked for the workspace:
        - if a shaded region is recorded for the workspace, errors are shown
          as shading: turning errors on replots the artist without error bars
          (keeping the fit on top) and shades the region in the artist's
          colour; turning errors off removes the shading;
        - otherwise the artist is replotted with ``with_errors`` error bars.
        """
        for plot_info in self.plotted_workspace_information:
            # update plot info error state -> important for shading
            # NOTE(review): this sets ``errors`` on every plotted entry, not only
            # those belonging to ``workspace_name`` -- confirm this is intended.
            plot_info.errors = with_errors
            if plot_info.workspace_name == workspace_name:
                axis = self.fig.axes[plot_info.axis]
                workspace_name = plot_info.workspace_name
                artist_info = axis.tracked_workspaces[workspace_name]
                for ws_artist in artist_info:
                    for artist in ws_artist._artists:
                        # keep the artist's current colour when replotting
                        color = get_color_from_artist(artist)
                        plot_kwargs = self._get_plot_kwargs(plot_info)
                        plot_kwargs["color"] = color
                        if workspace_name in self._shaded_regions.keys():
                            if with_errors:
                                # replot without errors to get the fit on top
                                axis.replot_artist(artist, False,
                                                   **plot_kwargs)
                                self.shade_region(color, workspace_name)
                            else:
                                self._shaded_regions[workspace_name].remove()
                        else:
                            axis.replot_artist(artist, with_errors,
                                               **plot_kwargs)
        self.redraw_figure()

    def shade_region(self, color, name):
        """Call ``shade_region(color)`` on the shaded-region record stored for ``name``."""
        region_info = self._shaded_regions[name]
        region_info.shade_region(color)

    def set_axis_xlimits(self, axis_number, xlims):
        """Set the x range of axis ``axis_number`` to ``xlims[0]``..``xlims[1]``."""
        target = self.fig.axes[axis_number]
        lower, upper = xlims[0], xlims[1]
        target.set_xlim(lower, upper)

    def set_axis_ylimits(self, axis_number, ylims):
        """Set the y range of axis ``axis_number`` to ``ylims[0]``..``ylims[1]``."""
        target = self.fig.axes[axis_number]
        lower, upper = ylims[0], ylims[1]
        target.set_ylim(lower, upper)

    def set_axes_limits(self, xlim, ylim):
        """Apply the same ``xlim`` and ``ylim`` to every axis of the figure."""
        every_axis = self.fig.axes
        plt.setp(every_axis, xlim=xlim, ylim=ylim)

    def autoscale_y_axes(self):
        """Apply one shared autoscaled y range to every axis.

        The shared range runs from the lowest bottom to the highest top of the
        per-axis limits from ``_get_y_axis_autoscale_limits``. Does nothing
        when the figure has no axes.
        """
        all_axes = self.fig.axes
        limits = [self._get_y_axis_autoscale_limits(axis) for axis in all_axes]
        if not limits:
            return
        # Use min/max of the real limits rather than comparing against fixed
        # +/-1e9 sentinels. With sentinels, data whose limits all lay beyond
        # 1e9 in magnitude got a bound clamped at the sentinel.
        ymin = min(bottom for bottom, _ in limits)
        ymax = max(top for _, top in limits)
        plt.setp(all_axes, ylim=[ymin, ymax])

    @property
    def get_xlim_list(self):
        """List of ``[xmin, xmax]`` pairs, one per figure axis, in axis order."""
        return [[lower, upper]
                for lower, upper in (axis.get_xlim() for axis in self.fig.axes)]

    @property
    def get_ylim_list(self):
        """List of ``[ymin, ymax]`` pairs, one per figure axis, in axis order."""
        return [[lower, upper]
                for lower, upper in (axis.get_ylim() for axis in self.fig.axes)]

    def autoscale_selected_y_axis(self, axis_number):
        """Autoscale the y range of one axis; indices past the last axis are ignored."""
        all_axes = self.fig.axes
        if axis_number < len(all_axes):
            target = all_axes[axis_number]
            low, high = self._get_y_axis_autoscale_limits(target)
            target.set_ylim(low, high)

    def set_title(self, axis_number, title):
        """Set the title of axis ``axis_number``.

        Ignored when the index is not below ``self.number_of_axes`` or when
        the plot is in condensed mode.
        """
        index_in_range = axis_number < self.number_of_axes
        if index_in_range and not self._settings.is_condensed:
            self.fig.axes[axis_number].set_title(title)

    def get_axis_limits(self, axis_number):
        """Return ``(xmin, xmax, ymin, ymax)`` for axis ``axis_number``."""
        target = self.fig.axes[axis_number]
        (x_low, x_high), (y_low, y_high) = target.get_xlim(), target.get_ylim()
        return x_low, x_high, y_low, y_high

    def redraw_figure(self):
        """Refresh the toolbar and legends, re-apply tight layout unless condensed, then draw."""
        canvas = self.fig.canvas
        canvas.toolbar.update()
        self._redraw_legend()
        condensed = self._settings.is_condensed
        if not condensed:
            self.fig.tight_layout()
        canvas.draw()

    def _redraw_legend(self):
        for ax in self.fig.axes:
            if ax.get_legend_handles_labels()[0]:
                legend = ax.legend(prop=dict(size=5))
                legend_set_draggable(legend, True)

    def _get_plot_kwargs(self, workspace_info: WorkspacePlotInformation):
        """Build the keyword arguments used to plot ``workspace_info``.

        Always contains ``distribution=True``, the entry's label, and the
        marker and linestyle configured for its workspace. When the entry's
        index is an ``int`` it is also used as ``zorder``.
        """
        name = workspace_info.workspace_name
        plot_kwargs = {'distribution': True, 'label': workspace_info.label}
        plot_kwargs["marker"] = self._settings.get_marker(name)
        plot_kwargs["linestyle"] = self._settings.get_linestyle(name)
        if isinstance(workspace_info.index, int):
            # Replotting the fit line does not work: replot_artist currently
            # only acts if the data has changed, which it has not here. So the
            # index doubles as the z-order, putting fit lines (index 1 or 3)
            # above the data (index 0).
            plot_kwargs["zorder"] = workspace_info.index
        return plot_kwargs

    def _get_y_axis_autoscale_limits(self, axis):
        """Return padded ``(bottom, top)`` y limits for the lines on ``axis``.

        Each line's y extent is folded in via ``get_y_min_max_between_x_range``
        (presumably restricted to the current x range -- confirm). A missing
        bound falls back to ``-/+ self._min_y_range``. A zero-height range is
        widened by ``self._min_y_range`` each way. Finally a margin of
        ``self._y_axis_margin`` times the height is added on both sides.
        """
        lower_x, upper_x = sorted(axis.get_xlim())
        bottom, top = np.inf, -np.inf
        for line in axis.lines:
            bottom, top = get_y_min_max_between_x_range(
                line, lower_x, upper_x, bottom, top)

        if bottom == np.inf:
            bottom = -self._min_y_range
        if top == -np.inf:
            top = self._min_y_range
        if bottom == top:
            bottom -= self._min_y_range
            top += self._min_y_range

        margin = abs(top - bottom) * self._y_axis_margin
        return bottom - margin, top + margin

    def _reset_color_cycle(self):
        for i, ax in enumerate(self.fig.axes):
            ax.cla()
            ax.tracked_workspaces.clear()

    def resizeEvent(self, event):
        """On resize, re-apply the tight layout unless the plot is condensed."""
        if not self._settings.is_condensed:
            self.fig.tight_layout()

    def add_uncheck_autoscale_subscriber(self, observer):
        """Subscribe ``observer`` to the toolbar's uncheck-autoscale notifier."""
        notifier = self.toolBar.uncheck_autoscale_notifier
        notifier.add_subscriber(observer)

    def add_enable_autoscale_subscriber(self, observer):
        """Subscribe ``observer`` to the toolbar's enable-autoscale notifier."""
        notifier = self.toolBar.enable_autoscale_notifier
        notifier.add_subscriber(observer)

    def add_disable_autoscale_subscriber(self, observer):
        """Subscribe ``observer`` to the toolbar's uncheck-autoscale notifier.

        NOTE(review): this uses the same notifier as
        ``add_uncheck_autoscale_subscriber``. If the toolbar exposes a
        dedicated disable-autoscale notifier, this may be a copy-paste slip --
        confirm against the toolbar class.
        """
        notifier = self.toolBar.uncheck_autoscale_notifier
        notifier.add_subscriber(observer)

    def add_range_changed_subscriber(self, observer):
        """Subscribe ``observer`` to the toolbar's range-changed notifier."""
        notifier = self.toolBar.range_changed_notifier
        notifier.add_subscriber(observer)
Example #35
0
def plot_image_map_line_profile_using_interface_plane(
        image_data,
        heatmap_data_list,
        line_profile_data_list,
        interface_plane,
        atom_plane_list=None,
        data_scale=1,
        data_scale_z=1,
        atom_list=None,
        extra_marker_list=None,
        clim=None,
        plot_title='',
        vector_to_plot=None,
        rotate_atom_plane_list_90_degrees=False,
        prune_outer_values=False,
        figname="map_data.jpg"):
    """Save a stacked figure of an image, a heatmap and line profiles to ``figname``.

    From top to bottom the figure holds:
    1. ``image_data`` with the atom planes (red) and the interface plane
       (blue) overlaid;
    2. the heatmap built by ``_make_subplot_map_from_regular_grid`` with the
       interface plane overlaid in red;
    3. a horizontal colorbar for that heatmap;
    4. one strip per entry of ``line_profile_data_list``.

    Parameters
    ----------
    image_data : array with a ``shape`` (rows, columns first)
        Image for the top panel; its extent is its shape times ``data_scale``.
    line_profile_data_list : list of arrays
        Column 0 and column 1 of each array are passed to
        ``_make_subplot_line_profile``.
    atom_list : list of Atom_Position instances
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
    rotate_atom_plane_list_90_degrees : bool
        If True, each atom plane is drawn as a single segment starting at its
        first atom, along its first-to-last vector rotated by 90 degrees.
    """
    profile_count = len(line_profile_data_list)

    fig = Figure(figsize=(10, 18 + 2 * profile_count))
    FigureCanvas(fig)

    # 95 grid rows for image/heatmap/colorbar, plus 10 rows per line profile.
    gs = GridSpec(95 + 10 * profile_count, 95)
    image_ax = fig.add_subplot(gs[0:45, :])
    distance_ax = fig.add_subplot(gs[45:90, :])
    colorbar_ax = fig.add_subplot(gs[90:95, :])
    line_profile_ax_list = [
        fig.add_subplot(gs[95 + 10 * i:105 + 10 * i, :])
        for i in range(profile_count)
    ]

    image_y_lim = (0, image_data.shape[0] * data_scale)
    image_x_lim = (0, image_data.shape[1] * data_scale)

    image_clim = to._get_clim_from_data(image_data, sigma=2)
    image_cax = image_ax.imshow(
        image_data,
        origin='lower',
        extent=[image_x_lim[0], image_x_lim[1], image_y_lim[0], image_y_lim[1]])
    image_cax.set_clim(image_clim[0], image_clim[1])
    image_ax.set_xlim(image_x_lim[0], image_x_lim[1])
    image_ax.set_ylim(image_y_lim[0], image_y_lim[1])
    image_ax.set_title(plot_title)

    def _plane_segment(plane):
        """Return the unscaled (x, y) arrays to draw for one atom plane."""
        xs = np.array(plane.get_x_position_list())
        ys = np.array(plane.get_y_position_list())
        if not rotate_atom_plane_list_90_degrees:
            return xs, ys
        # rotate the first-to-last vector (dx, dy) to (-dy, dx)
        rotated_dx = -(ys[-1] - ys[0])
        rotated_dy = xs[-1] - xs[0]
        return (np.array([xs[0], xs[0] + rotated_dx]),
                np.array([ys[0], ys[0] + rotated_dy]))

    if atom_plane_list is not None:
        for atom_plane in atom_plane_list:
            plane_x, plane_y = _plane_segment(atom_plane)
            image_ax.plot(plane_x * data_scale, plane_y * data_scale,
                          color='red', lw=2)

    interface_x = np.array(interface_plane.get_x_position_list()) * data_scale
    interface_y = np.array(interface_plane.get_y_position_list()) * data_scale
    image_ax.plot(interface_x, interface_y, color='blue', lw=2)

    _make_subplot_map_from_regular_grid(distance_ax,
                                        heatmap_data_list,
                                        distance_data_scale=data_scale_z,
                                        clim=clim,
                                        atom_list=atom_list,
                                        atom_plane_marker=interface_plane,
                                        extra_marker_list=extra_marker_list,
                                        vector_to_plot=vector_to_plot)
    distance_cax = distance_ax.images[0]
    distance_ax.plot(interface_x, interface_y, color='red', lw=2)

    for profile_ax, profile_data in zip(line_profile_ax_list,
                                        line_profile_data_list):
        _make_subplot_line_profile(profile_ax,
                                   profile_data[:, 0],
                                   profile_data[:, 1],
                                   prune_outer_values=prune_outer_values,
                                   scale_x=data_scale,
                                   scale_z=data_scale_z)

    fig.tight_layout()
    fig.colorbar(distance_cax, cax=colorbar_ax, orientation='horizontal')
    fig.savefig(figname)
Example #36
0
class Timeline(ttk.Frame):
    """Tk frame that shows a one-row scatter "timeline" of the data's dates.

    The figure is (re)built by ``create_timeline`` from the data source of the
    owning window (``self.window.ds``).
    """

    def __init__(self, master):
        super(Timeline, self).__init__(master, height=100)
        self.window = master  #type:VisAnaWindow
        self.columnconfigure(0, weight=1)
        self.rowconfigure(0, weight=1)

        # Smallest aggregation accepted by setAggregation. The value 360 comes
        # from the previously commented-out default; units presumably minutes
        # (setAggregation's parameter is named `minutes`) -- confirm.
        self.mininterval = 360
        self.aggregation_amount = self.mininterval
        # np.array([]) is an empty array; np.ndarray([]) created an
        # uninitialised 0-d array instead.
        self.selected_dates = np.array([])
        self.shown_dates = np.array([])
        self.timeline = None

    @property
    def ds(self):
        """Data source of the owning window (``self.window.ds``).

        Several methods use ``self.ds``, which was never assigned.
        """
        return self.window.ds

    def create_timeline(self):
        """Build the timeline figure and (re)place its canvas in this frame.

        Shown dates are drawn as vertical tick marks. When cluster centroids
        exist they are coloured ``COLORS[_cluster]``. Otherwise they are blue,
        after dropping rows that are null in any column returned by
        ``get_significant_nan_columns``. If the scatter plot has a selection,
        the selected dates are overdrawn in black. Any previous canvas widget
        is destroyed before the new one is gridded.
        """
        base_dates = self.window.ds.df("base").index.values
        start_date = base_dates[0]
        end_date = base_dates[-1]

        self.fig = Figure(figsize=(12, 1), dpi=75)
        ax = self.fig.add_subplot(111)

        # draw all shown dates
        centroids = self.window.ds.get_data("cluster").centroids
        d = self.window.ds.df("cluster")
        if centroids is not None:
            # clustered dates: map each cluster label to its colour
            colors = np.asarray(COLORS)[d["_cluster"]]
            dates = d.index.values
            ax.scatter(dates, [1] * len(dates), c=colors, marker='|', s=300)
        else:
            for p in self.window.ds.get_significant_nan_columns():
                d = d.loc[d[p].notnull()]
            # normal (unclustered) dates
            dates = d.index.values
            ax.scatter(dates, [1] * len(dates), c="blue", marker='|', s=300)

        # overdraw the selected dates, if any
        if self.window.scatter.has_selection():
            selected_dates = self.window.ds.df("ss_selected").index.values
            ax.scatter(selected_dates, [1] * len(selected_dates),
                       c="black", marker='|', s=300)

        hfmt = mdates.DateFormatter("1. %b '%y")
        ax.xaxis.set_major_formatter(hfmt)
        self.fig.autofmt_xdate()
        ax.set_xlim([start_date, end_date])

        # turn off everything plotted by default except the bottom x axis
        ax.yaxis.set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.get_yaxis().set_ticklabels([])
        self.fig.tight_layout(pad=0)

        # add to GUI, replacing any previous canvas
        if self.timeline is not None:
            self.timeline.get_tk_widget().destroy()
        self.timeline = FigureCanvasTkAgg(self.fig, self)
        self.timeline.get_tk_widget().grid(column=0,
                                           row=0,
                                           sticky=(tk.N, tk.E, tk.W, tk.S))

    def draw_timeline(self):
        """Redraw the timeline; kept because ``selectionChanged`` calls this name."""
        self.create_timeline()

    def setAggregation(self, minutes):
        """Aggregate the base data over ``minutes``, but never below ``self.mininterval``."""
        self.aggregation_amount = max(minutes, self.mininterval)
        self.ds.aggregateTime("all_intervals", "COUNT",
                              self.aggregation_amount, "base")

    def selectionChanged(self, isNowSelected=False):
        """Refresh ``selected_dates`` from the data source and redraw the timeline."""
        self.selected_dates = np.array([])
        if isNowSelected:
            self.selected_dates = self.ds.df("selected_intervals").index.values
        self.draw_timeline()

    def shownDatesChanged(self):
        """Refresh ``shown_dates`` from the data source's ``shown_dates`` frame."""
        self.shown_dates = self.ds.df("shown_dates").index.values
Example #37
0
class mclass:
    """Tk window showing a polar sky plot plus a coloured status box.

    ``plot`` reads ``current_status`` and ``current_positions.csv``, redraws,
    and reschedules itself every 1200 ms on ``self.window``.
    """

    def __init__(self, window):
        self.window = window
        self.window.resizable(width=False, height=False)

        self.fig = Figure(figsize=(7.2, 7.2))
        self.ax = self.fig.add_subplot(111, projection='polar')

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.window)
        self.canvas.get_tk_widget().pack(side=RIGHT, padx=0, pady=0)
        # FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; draw() replaces it.
        self.canvas.draw()

        self.statusCanvas = Canvas(self.window)
        self.statusCanvas.pack(side=LEFT,
                               padx=5,
                               pady=5,
                               fill=BOTH,
                               expand=True)
        # Canvas item id of the status rectangle; created by plot().
        self.statusLable = None
        self.label = Label(
            self.statusCanvas,
            text="LAT = 50.3633 \n LON = 30.4961 \n alt = 211.8")
        self.label.pack(side=LEFT, padx=5, pady=5)

        self.plot()

    def plot(self):
        """Redraw the status box and the sky plot, then schedule the next refresh."""
        self.window.title("Sky Safety Manager >>>>>> UTC: " +
                          datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"))

        with open("current_status", 'r') as status_file:
            # strip the trailing newline so a "Danger\n" line still matches
            status = status_file.readline().strip()
        # Replace the previous rectangle instead of stacking a new canvas item
        # on every refresh.
        if self.statusLable is not None:
            self.statusCanvas.delete(self.statusLable)
        self.statusLable = self.statusCanvas.create_rectangle(
            0, 0, 120, 300, fill='red' if status == "Danger" else "green")

        self.ax.clear()
        self.ax.set_theta_zero_location('S')
        self.ax.set_theta_direction(-1)
        self.ax.set_rmax(90)
        self.ax.set_rticks(np.arange(0, 91, 10))  # less radial ticks
        self.ax.set_yticklabels(self.ax.get_yticks()[::-1])
        self.ax.set_rlabel_position(0)
        self.ax.grid(True)

        r = np.arange(0, 90, 1)
        theta = r

        # Each CSV row is (kind, PRN, azimuth, elevation); the last two are
        # parsed as Decimal.
        positions = []
        with open("current_positions.csv", newline='') as csv_file:
            for row in csv.reader(csv_file, delimiter=','):
                row[2] = Decimal(row[2])
                row[3] = Decimal(row[3])
                positions.append(row)
        for kind, PRN, Az, E in positions:
            if kind == "SAT":
                self.ax.annotate(
                    str(PRN),
                    xy=(radians(Az), r2el(E)),  # theta, radius
                    bbox=dict(boxstyle="circle", fc='red', alpha=0.3,
                              pad=1.25),
                    horizontalalignment='center',
                    verticalalignment='center')
            else:
                # non-satellite entries are labelled with their PRN in hex
                self.ax.annotate(hex(int(PRN))[2:],
                                 xy=(radians(Az), r2el(E)))
            self.ax.plot(radians(Az), r2el(E), '.')

        self.ax.plot(d2r(theta), r2el(r), linewidth=0)
        self.fig.tight_layout()
        self.canvas.draw()

        # Reschedule on this instance's window, not a module-level `window` global.
        self.window.after(1200, self.plot)
Example #38
0
class subplot(QtWidgets.QWidget):
    signal_quick_edit = QtCore.Signal(object)
    signal_rm_subplot = QtCore.Signal(object)
    signal_rm_line = QtCore.Signal(object)

    def __init__(self, context):
        """Build the figure, canvas and toolbar and lay them out in a grid.

        ``context`` holds the subplot state. The canvas's ``draw_event`` is
        wired to ``draw_event_callback`` so toolbar zoom/pan updates quick edit.
        """
        super(subplot, self).__init__()
        self._context = context
        self.figure = Figure()
        self.figure.set_facecolor("none")
        self.canvas = FigureCanvas(self.figure)
        self._rm_window = None
        self._selector_window = None
        # update quick edit from tool bar
        self.canvas.mpl_connect("draw_event", self.draw_event_callback)

        self._ADSObserver = SubplotADSObserver(self)

        layout = QtWidgets.QGridLayout()
        # toolbar in the first row
        toolbar = myToolbar(self.canvas, self)
        self.toolbar = toolbar
        toolbar.update()
        layout.addWidget(toolbar, 0, 0)
        toolbar.setRmConnection(self._rm)
        toolbar.setRmSubplotConnection(self._rm_subplot)
        # plot canvas in the second row
        self.plot_objects = {}
        layout.addWidget(self.canvas, 1, 0)
        self.setLayout(layout)

    """ this is called when the zoom
    or pan are used. We want to send a
    signal to update the axis ranges """

    def draw_event_callback(self, event):
        self.figure.tight_layout()
        for subplot in self.plot_objects.keys():
            self.emit_subplot_range(subplot)

    def add_annotate(self, subplotName, label):
        """Add ``label`` to subplot ``subplotName`` and redraw; unknown subplots are ignored."""
        if subplotName in self._context.subplots.keys():
            self._context.add_annotate(subplotName, label)
            self.canvas.draw()

    def add_vline(self, subplot_name, xvalue, name, color):
        """Add a named vertical line at ``xvalue`` to ``subplot_name`` and redraw; unknown subplots are ignored."""
        if subplot_name in self._context.subplots.keys():
            self._context.add_vline(subplot_name, xvalue, name, color)
            self.canvas.draw()

    def rm_annotate(self, subplot_name, name):
        """Remove label ``name`` from ``subplot_name`` and redraw; unknown subplots are ignored."""
        if subplot_name in self._context.subplots.keys():
            self._context.removeLabel(subplot_name, name)
            self.canvas.draw()

    def rm_vline(self, subplot_name, name):
        """Remove vertical line ``name`` from ``subplot_name`` and redraw; unknown subplots are ignored."""
        if subplot_name in self._context.subplots.keys():
            self._context.removeVLine(subplot_name, name)
            self.canvas.draw()

    # plot a workspace, if a new subplot create it.
    def plot(self, subplot_name, workspace, color=None, spec_num=1):
        """Plot spectrum ``spec_num`` of ``workspace`` on ``subplot_name``.

        A missing subplot is created first, in the next grid slot. For a newly
        created subplot, its axis range is emitted after the line is added.
        """
        is_new_subplot = subplot_name not in self._context.subplots.keys()
        if is_new_subplot:
            self.add_subplot(subplot_name, len(self.plot_objects))
        self._add_plotted_line(subplot_name,
                               workspace,
                               spec_num=spec_num,
                               color=color)
        if is_new_subplot:
            self.emit_subplot_range(subplot_name)

    def change_errors(self, state, subplot_names):
        """Set the error-bar state of each named subplot, then redraw once.

        Previously the canvas was redrawn after every subplot. Each draw fires
        ``draw_event_callback``, which re-emits every subplot's range, so the
        cost was quadratic in the number of subplots.
        """
        updated = False
        for subplotName in subplot_names:
            self._context.subplots[subplotName].change_errors(state)
            updated = True
        if updated:
            self.canvas.draw()

    # adds plotted line to context and updates GUI
    def _add_plotted_line(self, subplot_name, workspace, spec_num, color=None):
        """Register spectrum ``spec_num`` of ``workspace`` as a line of ``subplot_name`` and redraw."""
        context = self._context
        context.addLine(subplot_name, workspace, spec_num, color=color)
        self.canvas.draw()

    def add_subplot(self, subplot_name, number):
        """Create subplot ``subplot_name`` in grid slot ``number`` and register it.

        The context gridspec is resized to ``number + 1`` slots, and a titled
        'mantid'-projection axis is added in the slot and registered with the
        context. Finally the layout is refreshed.
        """
        self._context.update_gridspec(number + 1)
        slot = self._context.gridspec[number]
        new_axis = self.figure.add_subplot(slot, label=subplot_name,
                                           projection='mantid')
        self.plot_objects[subplot_name] = new_axis
        new_axis.set_title(subplot_name)
        self._context.addSubplot(subplot_name, new_axis)
        self._update()

    def _update(self):
        """Let the context lay out the figure, then redraw the canvas."""
        figure = self.figure
        self._context.update_layout(figure)
        self.canvas.draw()

    def emit_subplot_range(self, subplot_name):
        """Emit the quick-edit signal for ``subplot_name`` and redraw its annotations."""
        self.signal_quick_edit.emit(subplot_name)
        subplot_context = self._context.subplots[subplot_name]
        subplot_context.redraw_annotations()

    def set_plot_x_range(self, subplot_names, range):
        """Set the x range of each named subplot, refresh its annotations, then redraw once.

        ``range`` (which shadows the builtin; kept for caller compatibility) is
        passed straight to ``set_xlim``. The canvas is drawn once after all
        subplots are updated rather than once per subplot. Each draw fires
        ``draw_event_callback``, which re-emits every subplot's range.
        """
        updated = False
        for subplotName in subplot_names:
            # make a set method in context and set it there
            self.plot_objects[subplotName].set_xlim(range)
            self._context.subplots[subplotName].redraw_annotations()
            updated = True
        if updated:
            self.canvas.draw()

    def set_plot_y_range(self, subplot_names, y_range):
        """Set the y range of each named subplot, refresh its annotations, then redraw once.

        The canvas is drawn once after all subplots are updated rather than
        once per subplot. Each draw fires ``draw_event_callback``, which
        re-emits every subplot's range.
        """
        updated = False
        for subplotName in subplot_names:
            self.plot_objects[subplotName].set_ylim(y_range)
            self._context.subplots[subplotName].redraw_annotations()
            updated = True
        if updated:
            self.canvas.draw()

    def connect_quick_edit_signal(self, slot):
        """Connect ``slot`` to the quick-edit signal."""
        signal = self.signal_quick_edit
        signal.connect(slot)

    def disconnect_quick_edit_signal(self):
        """Disconnect all slots from the quick-edit signal."""
        signal = self.signal_quick_edit
        signal.disconnect()

    def connect_rm_subplot_signal(self, slot):
        """Connect ``slot`` to the remove-subplot signal."""
        signal = self.signal_rm_subplot
        signal.connect(slot)

    def disconnect_rm_subplot_signal(self):
        """Disconnect all slots from the remove-subplot signal."""
        signal = self.signal_rm_subplot
        signal.disconnect()

    def connect_rm_line_signal(self, slot):
        """Connect ``slot`` to the remove-line signal."""
        signal = self.signal_rm_line
        signal.connect(slot)

    def disconnect_rm_line_signal(self):
        """Disconnect all slots from the remove-line signal."""
        signal = self.signal_rm_line
        signal.disconnect()

    def set_y_autoscale(self, subplot_names, state):
        """Set the y-autoscale state of each named subplot, then redraw once.

        The canvas is drawn once after all subplots are updated rather than
        once per subplot. Each draw fires ``draw_event_callback``, which
        re-emits every subplot's range.
        """
        updated = False
        for subplotName in subplot_names:
            self._context.subplots[subplotName].change_auto(state)
            updated = True
        if updated:
            self.canvas.draw()

    def _rm(self):
        """Open the line-removal window, going through a subplot selector if needed.

        With exactly one subplot, the removal window is shown directly,
        reusing an existing one. Otherwise any open removal window and selector
        are closed. A new selector is then shown, and choosing a subplot opens
        its removal window.
        """
        names = list(self._context.subplots.keys())
        if len(names) == 1:
            if self._rm_window is None:
                self._get_rm_window(names[0])
            else:
                self._rm_window.show()
            return

        if self._rm_window is not None:
            self._rm_window.close()
            self._rm_window = None
        self._close_selector_window()

        selector = self._create_select_window(names)
        self._selector_window = selector
        selector.subplotSelectorSignal.connect(self._get_rm_window)
        selector.closeEventSignal.connect(self._close_selector_window)
        selector.setMinimumSize(300, 100)
        selector.show()

    def _rm_subplot(self):
        names = list(self._context.subplots.keys())
        # If the selector is hidden then close it
        self._close_selector_window()

        self._selector_window = self._create_select_window(names)
        self._selector_window.subplotSelectorSignal.connect(
            self._remove_subplot)
        self._selector_window.subplotSelectorSignal.connect(
            self._close_selector_window)
        self._selector_window.closeEventSignal.connect(
            self._close_selector_window)
        self._selector_window.setMinimumSize(300, 100)
        self._selector_window.show()

    def _create_select_window(self, names):
        """Build and return a ``SelectSubplot`` dialog offering *names*."""
        selector = SelectSubplot(names)
        return selector

    def _close_selector_window(self):
        if self._selector_window is not None:
            self._selector_window.close()
            self._selector_window = None

    def _create_rm_window(self, subplot_name):
        """Build a ``RemovePlotWindow`` listing the lines and vlines of *subplot_name*."""
        subplot = self._context.subplots[subplot_name]
        return RemovePlotWindow(
            lines=list(subplot.lines.keys()),
            vlines=subplot.vlines,
            subplot=subplot_name,
            parent=self,
        )

    def _get_rm_window(self, subplot_name):
        # always close selector after making a selection
        self._close_selector_window()
        # create the remove window
        self._rm_window = self._create_rm_window(subplot_name=subplot_name)
        self._rm_window.applyRemoveSignal.connect(self._apply_rm)
        self._rm_window.closeEventSignal.connect(self._close_rm_window)
        self._rm_window.setMinimumSize(200, 200)
        self._rm_window.show()

    def remove_lines(self, subplot_name, line_names):
        """Remove *line_names* from *subplot_name* and announce it.

        The remove-line signal is emitted with the names converted to ``str``.
        If the subplot is left with no lines it is deleted; otherwise the
        canvas is redrawn.
        """
        for line_name in line_names:
            self._context.subplots[subplot_name].removeLine(line_name)

        removed_names = [str(line_name) for line_name in line_names]
        self.signal_rm_line.emit(removed_names)

        if self._context.get_lines(subplot_name):
            self.canvas.draw()
        else:
            # nothing left to show -> drop the whole subplot
            self._remove_subplot(subplot_name)

    def _apply_rm(self, line_names):
        to_close = []
        for name in line_names:
            if self._rm_window.getState(name):
                to_close.append(name)

        self.remove_lines(self._rm_window.subplot, to_close)
        self._close_rm_window()

    def _close_rm_window(self):
        self._rm_window.close
        self._rm_window = None

    def _remove_subplot(self, subplot_name):
        self.figure.delaxes(self.plot_objects[subplot_name])
        del self.plot_objects[subplot_name]
        self._context.delete(subplot_name)
        self._context.update_gridspec(len(list(self.plot_objects.keys())))
        self._update()
        self.signal_rm_subplot.emit(subplot_name)

    def _rm_ws_from_plots(self, workspace_name):
        keys = deepcopy(list(self._context.subplots.keys()))
        for subplot in keys:
            labels = self._context.get_lines_from_WS(subplot, workspace_name)
            for label in labels:
                self._context.removePlotLine(subplot, label)
                self.canvas.draw()
            if self._context.is_subplot_empty(subplot):
                self._remove_subplot(subplot)

    def _replaced_ws(self, workspace):
        for subplot in self._context.subplots.keys():
            redraw = self._context.subplots[subplot].replace_ws(workspace)
            if redraw:
                self.canvas.draw()
Example #39
0
class ScatterPlotWidget(QtWidgets.QWidget):
    """Qt widget plotting result parameters as marker series against x positions.

    Each y-parameter gets one marker-only line artist in ``self.scatters``.
    Pick events on the canvas are forwarded to an ``InteractiveLegend``
    (presumably toggling series visibility). The visible series are then
    re-plotted with small horizontal offsets, so points sharing an x value
    stay distinguishable.
    """

    def __init__(self):
        QtWidgets.QWidget.__init__(self)
        self.init_tab()

    def init_tab(self):
        """Create the figure, canvas, toolbar, axes and legend."""
        self.setLayout(QtWidgets.QVBoxLayout())
        self.figure = Figure(figsize=(10, 10))

        # Add canvas
        self.canvas = FigureCanvasQTAgg(self.figure)

        def nothing(event, limits=None):
            """No-op update callback handed to the toolbar."""

        self.toolbar = widgets.CustomNavigationToolbar(
            canvas=self.canvas, widget=self, update_func=nothing)

        canvasbox = QtWidgets.QVBoxLayout()
        canvasbox.addWidget(self.toolbar)
        canvasbox.addWidget(self.canvas)
        # Vertically expanding spacer below the canvas keeps it at the top.
        canvasbox.addItem(QtWidgets.QSpacerItem(
            0, 0, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding))

        self.layout().addLayout(canvasbox)

        self.ax = self.figure.add_subplot()
        self.ax.grid()
        self.ax.spines['right'].set_visible(False)
        self.ax.spines['top'].set_visible(False)
        self.ax.tick_params(axis='y', color='0.75')
        self.ax.tick_params(axis='x', color='0.75')

        self.scatters = {}    # parameter name -> Line2D (markers only)
        self.resultvals = {}  # 'x' -> x positions, parameter name -> y values

        self.legend = InteractiveLegend(self, self.scatters, loc='upper left')

        self.canvas.mpl_connect('pick_event', self._onpick)

    def _onpick(self, event):
        """Forward a pick event to the legend, then re-plot the visible series."""
        self.legend._onpick(event)
        self.update_scatters()

    def remove_scatter(self, param):
        """Remove the series *param* from the axes and the legend, if present."""
        if param in self.scatters:
            self.scatters[param].remove()
            del self.scatters[param]
            self.legend.remove(param)

    def remove_all_scatters(self):
        """Remove every series and schedule a redraw."""
        # Iterate over a copy because remove_scatter deletes from the dict.
        for param in list(self.scatters):
            self.remove_scatter(param)
        self.canvas.draw_idle()

    def get_visible(self):
        """Get a list of visible scatters"""
        return [key for key, scatter in self.scatters.items() if scatter.get_visible()]

    def update_axis(self, xlabel, xticks, ylabel, yparams):
        """Update the axes after the table has changed.

        :param xlabel: x-axis label.
        :param xticks: either string labels (placed at positions 0..n-1 and
            rotated) or numeric tick positions.
        :param ylabel: y-axis label.
        :param yparams: parameter names; one series (colour ``C{i}``) and one
            legend entry is created for each.

        NOTE(review): entries already in ``self.scatters`` are overwritten
        without removing their artists; callers presumably call
        ``remove_all_scatters`` first.
        """
        self.ax.set_xlabel(xlabel)

        if len(xticks) > 0 and isinstance(xticks[0], str):
            self.ax.set_xticks(list(range(len(xticks))))
            self.minxspace = 1
            self.ax.set_xticklabels(xticks, rotation=90)
        else:
            self.ax.set_xticks(xticks)
            self.ax.set_xticklabels(xticks)
            # With fewer than two ticks np.diff is empty and .min() raises;
            # fall back to unit spacing, as for categorical ticks.
            self.minxspace = np.diff(xticks).min() if len(xticks) > 1 else 1

        self.ax.set_ylabel(ylabel)

        # Result parameters
        for i, param in enumerate(yparams):
            self.scatters[param], = self.ax.plot(
                [], [], ms=3, alpha=0.5, ls='', marker='o', color=f'C{i}')
            # A separate (empty) scatter artist serves as the legend handle.
            self.legend.add_item(
                param,
                handle=self.ax.scatter([], [], s=20, alpha=0.5, marker='o', color=f'C{i}'),
                label=param)

        self.legend._update_legend()
        self.figure.tight_layout()

    def update_scatters(self, resultvals=None):
        """Re-plot the visible series, optionally replacing the data first.

        :param resultvals: optional mapping with key ``'x'`` (x positions,
            presumably an array so that ``+ offset`` broadcasts) and one entry
            of y values per parameter. When given, it replaces the stored data.
        :returns: None.
        """
        # Update data if provided
        if resultvals is not None:
            self.resultvals.clear()
            self.resultvals.update(resultvals)

        visible = self.get_visible()
        if not visible:
            # Still redraw so that hiding the last series shows up on screen.
            self.canvas.draw_idle()
            return None

        # Spread the visible series within +/- a quarter of the smallest x
        # spacing; the linspace end points are dropped.
        xoffsets = np.linspace(
            -self.minxspace / 4, self.minxspace / 4, len(visible) + 2)[1:-1]

        i = 0
        for param, values in self.resultvals.items():
            if param == 'x' or param not in visible:
                continue
            self.scatters[param].set_data(self.resultvals['x'] + xoffsets[i], values)
            i += 1

        # Resize
        self.ax.relim()
        self.ax.autoscale_view(True, True, True)
        self.figure.tight_layout()
        self.canvas.draw_idle()
Example #40
0
class MLGraphViewer(QMainWindow):
    """Main window that overlays CSV-backed graphs on one matplotlib canvas.

    The left pane lists, recursively, the ``.csv`` files under each directory
    in ``dirs`` that ``Graph.getGraphDesc`` accepts. Checking an item plots
    that file on the shared ``self.host`` axes in the next free colour.
    Unchecking it clears the plot and frees the colour.
    """

    COLORS = [
        "#ff0000", "#ff8000", "#bfff00", "#00ff00", "#00ffbf", "#00bfff",
        "#8000ff", "#bf00ff", "#ff00ff", "#cc5151", "#7f3333", "#51cccc",
        "#337f7f", "#8ecc51", "#597f33", "#8e51cc", "#59337f", "#ccad51",
        "#7f6c33", "#51cc70", "#337f46", "#5170cc", "#33467f", "#cc51ad",
        "#7f336c", "#cc7f51", "#7f4f33", "#bccc51", "#757f33", "#60cc51",
        "#3c7f33", "#51cc9e", "#337f62", "#519ecc", "#33627f", "#6051cc",
        "#3c337f", "#bc51cc", "#75337f", "#cc517f", "#7f334f", "#cc6851",
        "#7f4133", "#cc9651", "#7f5e33", "#ccc451", "#7f7a33", "#a5cc51",
        "#677f33", "#77cc51", "#4a7f33", "#51cc59", "#337f37", "#51cc87",
        "#337f54", "#51ccb5", "#337f71", "#51b5cc", "#33717f", "#5187cc",
        "#33547f", "#5159cc", "#33377f", "#7751cc", "#4a337f", "#a551cc",
        "#67337f", "#cc51c4", "#7f337a", "#cc5196", "#7f335e", "#cc5168",
        "#7f3341", "#cc5d51", "#7f3a33", "#cc7451", "#7f4833", "#cc8a51",
        "#7f5633", "#cca151", "#7f6533", "#ccb851", "#7f7333", "#c8cc51",
        "#7d7f33", "#b1cc51", "#6e7f33", "#9acc51", "#607f33", "#83cc51",
        "#527f33", "#6ccc51", "#437f33", "#55cc51", "#357f33", "#51cc64",
        "#337f3e", "#51cc7b", "#337f4d", "#51cc92", "#337f5b", "#51cca9",
        "#337f69", "#51ccc0", "#337f78", "#51c0cc", "#33787f", "#51a9cc",
        "#33697f"
    ]
    # Colour -> True while the colour is free. NOTE: this is a class
    # attribute, so the pool is shared by all MLGraphViewer instances.
    COLOR_MAP = dict.fromkeys(COLORS, True)

    # -------------------------------------------------------------------- #

    def __init__(self, dirs, parent=None):
        """Create the window.

        :param dirs: directories whose CSV files are listed in the file tree.
        :param parent: optional parent widget.
        """
        QMainWindow.__init__(self, parent)

        self._dirs = dirs
        self._graphs = {}  # csv path -> Graph
        # Set to False while the tree is modified programmatically, so that
        # _onTreeItemChanged ignores the itemChanged signals this produces.
        self._handleChanges = True

        self.setWindowTitle('ML Graph Viewer')

        self.create_menu()
        self.create_main_frame()
        self.create_status_bar()

    # -------------------------------------------------------------------- #

    def _scaleXToFit(self):
        """Fit the host x-range to all visible graphs, with 2% padding each side.

        Does nothing if no graph is visible, or if the combined range is
        exactly (0, 1) or is empty/inverted.
        """
        active_graphs = [
            graph for graph in self._graphs.values() if graph.isVisible()
        ]
        if not active_graphs:
            return

        x_min = min(graph.min_x_val() for graph in active_graphs)
        x_max = max(graph.max_x_val() for graph in active_graphs)
        if ((x_min, x_max) == (0.0, 1.0)) or (x_min >= x_max):
            return

        x_delta = x_max - x_min
        x_min -= 0.02 * x_delta
        x_max += 0.02 * x_delta
        self.host.set_xlim(x_min, x_max)

    # -------------------------------------------------------------------- #

    def _onXlimsChange(self, axes):
        """'xlim_changed' callback: an x-range of exactly (0, 1) triggers a zoom reset."""
        if axes.get_xlim() == (0.0, 1.0):
            self._resetZoom()

    # -------------------------------------------------------------------- #

    def _onYlimsChange(self, axes):
        """'ylim_changed' callback; currently a no-op."""
        pass

    # -------------------------------------------------------------------- #

    def _getNextColor(self):
        """Return the first free colour in COLORS (does not mark it as used).

        :raises Exception: when every colour is in use.
        """
        for c in MLGraphViewer.COLORS:
            if MLGraphViewer.COLOR_MAP[c]:
                return c
        raise Exception("No more available colors.")

    # -------------------------------------------------------------------- #

    def _getOrCreateGraph(self, csv_path):
        """Return the cached Graph for *csv_path*, creating and caching it if needed."""
        csv_path = str(csv_path)
        if csv_path in self._graphs:
            graph = self._graphs[csv_path]
        else:
            label = csv_path
            desc = Graph.getGraphDesc(csv_path)
            graph = Graph(label, csv_path, desc.graph_type, desc.ymax)
            self._graphs[csv_path] = graph
        return graph

    # -------------------------------------------------------------------- #

    def _plotGraph(self, graph, tree_item):
        """Plot *graph* in a newly allocated colour and mark *tree_item* checked."""
        # Allocate a colour and mark it as taken.
        color = self._getNextColor()
        graph.setColor(color)
        MLGraphViewer.COLOR_MAP[graph.color()] = False

        self.fig, self.host, _ = graph.plot(self.fig, self.host)
        if not graph.isInVisibleRectX(self.host):
            self._scaleXToFit()
        tree_item.setBackground(0, QBrush(QColor(graph.color())))
        tree_item.setCheckState(0, Qt.Checked)

    # -------------------------------------------------------------------- #

    def _unplotGraph(self, graph, tree_item):
        """Clear *graph*, free its colour and mark *tree_item* unchecked."""
        # Free the colour.
        if graph.color() is not None:
            MLGraphViewer.COLOR_MAP[graph.color()] = True

        graph.clear()
        tree_item.setBackground(0, QBrush())
        tree_item.setCheckState(0, Qt.Unchecked)

    # -------------------------------------------------------------------- #

    def _onTreeItemChanged(self, item):
        """Plot or unplot the graph behind *item* according to its check state."""
        if not self._handleChanges:
            return

        csv_path = item.data(0, Qt.UserRole).toString()
        graph = self._getOrCreateGraph(csv_path)

        # _plotGraph/_unplotGraph modify the item, which re-emits itemChanged.
        self._handleChanges = False
        if item.checkState(0) == Qt.Checked:
            self._plotGraph(graph, item)
        else:
            self._unplotGraph(graph, item)
        self._handleChanges = True
        self.canvas.draw()

    # -------------------------------------------------------------------- #

    def _findItem(self, graph):
        """Return the tree item whose stored path equals graph.csvPath(), or None.

        NOTE(review): _onTreeItemChanged reads the same data via
        ``.toString()`` while this compares the raw ``data()`` value; confirm
        that the comparison matches under the PyQt API version in use.
        """
        csv_path = graph.csvPath()
        csv_name = os.path.basename(csv_path)
        tree_items = self._tree.findItems(csv_name, Qt.MatchRecursive)
        for item in tree_items:
            if item.data(0, Qt.UserRole) == csv_path:
                return item
        return None

    # -------------------------------------------------------------------- #

    def _resetHost(self):
        """Hide the host axes decorations and disable autoscaling."""
        self.host.axis('off')
        self.host.set_autoscale_on(False)

    # -------------------------------------------------------------------- #

    def _repaintCanvas(self):
        """Redraw the canvas."""
        self.canvas.draw()

    # -------------------------------------------------------------------- #

    def _resetZoom(self):
        """Re-plot every visible graph with y-range (0, 1) and refit the x-range."""
        self.host.set_ylim(0.0, 1.0)
        for graph in self._graphs.values():
            if graph.isVisible():
                graph.plot(self.fig, self.host)
        self._scaleXToFit()
        self._repaintCanvas()

    # -------------------------------------------------------------------- #

    def _refresh(self):
        """Rebuild the file tree and reconcile the cached graphs with it.

        Graphs whose file is still listed get their tree item re-checked and
        expanded into view. Graphs whose file disappeared are cleared and
        dropped. Their colour is freed, so the palette does not leak. All
        remaining visible graphs are then refreshed.
        """
        self._handleChanges = False
        self._tree.clear()
        for dir_path in self._dirs:
            name = os.path.basename(os.path.normpath(dir_path))
            top_level_item = QTreeWidgetItem(self._tree, [name])
            top_level_item.setExpanded(True)
            self._loadDir(dir_path, top_level_item)

        dead_graphs = []
        for graph in self._graphs.values():
            tree_item = self._findItem(graph)
            if tree_item is None:
                dead_graphs.append(graph)
            elif graph.isVisible():
                tree_item.setBackground(0, QBrush(QColor(graph.color())))
                tree_item.setCheckState(0, Qt.Checked)
                # Expand every ancestor so the checked item is visible.
                while tree_item is not None:
                    tree_item.setExpanded(True)
                    tree_item = tree_item.parent()

        # Close dead graphs.
        for graph in dead_graphs:
            csv_path = graph.csvPath()
            print("Warning: Graph %s no longer exists." % csv_path)
            if graph.isVisible():
                # Visible graphs still hold their colour; release it.
                if graph.color() is not None:
                    MLGraphViewer.COLOR_MAP[graph.color()] = True
                graph.clear()
            del self._graphs[csv_path]

        # Refresh graphs.
        for graph in self._graphs.values():
            if graph.isVisible():
                graph.refresh(self.fig, self.host)
        self._repaintCanvas()

        self._handleChanges = True

    # -------------------------------------------------------------------- #

    def _onRefreshClicked(self):
        """Slot for the Refresh button."""
        self._refresh()

    # -------------------------------------------------------------------- #

    def _loadDir(self, base_path, tree):
        """Recursively add sub-directories and recognised CSV files under *tree*.

        Directories that end up without children are removed again. Entries
        are added in sorted order, so the tree layout is deterministic.
        """
        for element in sorted(os.listdir(base_path)):
            path = os.path.join(base_path, element)
            is_dir = os.path.isdir(path)
            if (not is_dir) and (not path.endswith(".csv")):
                continue

            if is_dir:
                parent_itm = QTreeWidgetItem(tree, [element])
                self._loadDir(path, parent_itm)
                if parent_itm.childCount() == 0:
                    parent = parent_itm.parent()
                    root = parent_itm.treeWidget().invisibleRootItem()
                    (parent or root).removeChild(parent_itm)
                else:
                    parent_itm.setIcon(
                        0,
                        QIcon(
                            "/usr/share/icons/ubuntu-mono-light/places/16/folder-home.svg"
                        ))
            else:
                if Graph.getGraphDesc(path) is None:
                    continue
                item = QTreeWidgetItem(tree, [element])
                item.setData(0, Qt.UserRole, path)
                item.setCheckState(0, Qt.Unchecked)

    # -------------------------------------------------------------------- #

    def save_plot(self):
        """Ask for a file name and save the canvas to it as PNG.

        Previously this used a wx-style filter (``"...|*.png"``), which Qt
        cannot parse. It also read ``self.dpi``, which is never set, so the
        save raised AttributeError. The figure's own dpi is used instead.
        """
        file_choices = "PNG (*.png)"

        path = unicode(
            QFileDialog.getSaveFileName(self, 'Save file', '', file_choices))
        if path:
            self.canvas.print_figure(path, dpi=self.fig.dpi)
            self.statusBar().showMessage('Saved to %s' % path, 2000)

    # -------------------------------------------------------------------- #

    def on_about(self):
        """Show the About dialog."""
        msg = """ View one or more graphs on a joined canvas.
            
            Many thanks to Eli Bendersky for the PyPlot skeleton. 
        """
        QMessageBox.about(self, "ML Graph Viewer", msg.strip())

    # -------------------------------------------------------------------- #

    def on_pick(self, event):
        """matplotlib 'pick_event' handler: report the picked artist's bbox corners."""
        artist = event.artist
        # Not every pickable artist has get_bbox(); ignore those instead of
        # raising AttributeError inside the event loop.
        if not hasattr(artist, 'get_bbox'):
            return
        box_points = artist.get_bbox().get_points()
        msg = "You've clicked on a bar with coords:\n %s" % box_points

        QMessageBox.information(self, "Click!", msg)

    # -------------------------------------------------------------------- #

    def _openCSVFile(self, base_dir):
        """Show an open-file dialog rooted at '/' and return the chosen path.

        NOTE(review): *base_dir* is currently ignored.
        """
        selfilter = QString("CSV (*.csv)")
        file_path = QFileDialog.getOpenFileName(
            None, 'Open File', '/', "All files (*.*);;CSV (*.csv)", selfilter)
        return file_path

    # -------------------------------------------------------------------- #

    def create_main_frame(self):
        """Build the figure, canvas, toolbar and file tree, and lay them out."""
        self.main_frame = QWidget()

        # Create the mpl Figure (default size and dpi) and its Qt canvas.
        self.fig = Figure()
        self.host = self.fig.add_subplot(1, 1, 1)
        self._resetHost()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.fig.tight_layout(pad=0)
        self.fig.subplots_adjust(left=0.01, right=0.92, top=0.98, bottom=0.09)

        self.host.callbacks.connect('xlim_changed', self._onXlimsChange)
        self.host.callbacks.connect('ylim_changed', self._onYlimsChange)

        self.canvas.mpl_connect('pick_event', self.on_pick)

        # Navigation toolbar tied to the canvas.
        self.mpl_toolbar = NavigationToolbar(self.canvas, self.main_frame)

        splitter = QSplitter()

        self._tree = QTreeWidget(splitter)
        self._tree.setHeaderLabel("Files")
        # Populate before connecting, so the initial fill is not handled as user edits.
        self._refresh()
        self._tree.itemChanged.connect(self._onTreeItemChanged)

        self.b_refresh = QPushButton("Refresh")
        self.b_refresh.clicked.connect(self._onRefreshClicked)

        file_select_pane = QWidget(splitter)
        file_select_pane.setLayout(QVBoxLayout())
        file_select_pane.layout().addWidget(self._tree)
        file_select_pane.layout().addWidget(self.b_refresh)

        canvas_pane = QWidget(splitter)
        canvas_pane.setLayout(QVBoxLayout())
        canvas_pane.layout().addWidget(self.canvas)
        canvas_pane.layout().addWidget(self.mpl_toolbar)

        self.setCentralWidget(splitter)

    def create_status_bar(self):
        """Add a stretching label to the status bar."""
        self.status_text = QLabel("")
        self.statusBar().addWidget(self.status_text, 1)

    def create_menu(self):
        """Create the File (Save plot, Quit) and Help (About) menus."""
        self.file_menu = self.menuBar().addMenu("&File")

        save_action = self.create_action("&Save plot",
                                         shortcut="Ctrl+S",
                                         slot=self.save_plot,
                                         tip="Save the plot")
        quit_action = self.create_action("&Quit",
                                         slot=self.close,
                                         shortcut="Ctrl+Q",
                                         tip="Close the application")

        self.add_actions(self.file_menu, (save_action, None, quit_action))

        self.help_menu = self.menuBar().addMenu("&Help")
        about_action = self.create_action("&About",
                                          shortcut='F1',
                                          slot=self.on_about,
                                          tip='About the demo')

        self.add_actions(self.help_menu, (about_action, ))

    def add_actions(self, target, actions):
        """Add *actions* to *target*; a ``None`` entry inserts a separator."""
        for action in actions:
            if action is None:
                target.addSeparator()
            else:
                target.addAction(action)

    def create_action(self,
                      text,
                      slot=None,
                      shortcut=None,
                      icon=None,
                      tip=None,
                      checkable=False,
                      signal="triggered()"):
        """Create a QAction with optional icon, shortcut, tooltip and slot.

        The slot is connected with the old-style ``SIGNAL(signal)`` syntax.
        """
        action = QAction(text, self)
        if icon is not None:
            action.setIcon(QIcon(":/%s.png" % icon))
        if shortcut is not None:
            action.setShortcut(shortcut)
        if tip is not None:
            action.setToolTip(tip)
            action.setStatusTip(tip)
        if slot is not None:
            self.connect(action, SIGNAL(signal), slot)
        if checkable:
            action.setCheckable(True)
        return action
Example #41
0
class MPLPlot(FCanvas):
    """
    This is the basic matplotlib canvas widget we are using for matplotlib
    plots. This canvas only provides a few convenience tools for automatic
    sizing and creating subfigures, but is otherwise not very different
    from the class ``FCanvas`` that comes with matplotlib (and which we inherit).
    It can be used as any QT widget.
    """
    def __init__(self,
                 parent: Optional[QtWidgets.QWidget] = None,
                 width: float = 4.0,
                 height: float = 3.0,
                 dpi: int = 150,
                 nrows: int = 1,
                 ncols: int = 1):
        """
        Create the canvas.

        :param parent: the parent widget
        :param width: canvas width (inches)
        :param height: canvas height (inches)
        :param dpi: figure dpi
        :param nrows: number of subplot rows
        :param ncols: number of subplot columns
        """

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)

        self.axes: List[Axes] = []
        self._tightLayout = False
        self._showInfo = False
        self._infoArtist = None   # figure text showing self._info, if shown
        self._titleArtist = None  # figure text added by setFigureTitle
        self._info = ''
        self._meta_info: Dict[str, str] = {}

        self.clearFig(nrows, ncols)
        self.setParent(parent)

    def autosize(self) -> None:
        """
        Apply fixed default margins/spacings, or ``tight_layout`` when tight
        layout mode is on, then redraw.
        """
        if not self._tightLayout:
            self.fig.subplots_adjust(left=0.125,
                                     bottom=0.125,
                                     top=0.9,
                                     right=0.875,
                                     wspace=0.35,
                                     hspace=0.2)
        else:
            self.fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        self.draw()

    def clearFig(self,
                 nrows: int = 1,
                 ncols: int = 1,
                 naxes: int = 1) -> List[Axes]:
        """
        Clear and reset the canvas.

        :param nrows: number of subplot/axes rows to prepare
        :param ncols: number of subplot/axes column to prepare
        :param naxes: number of axes in total
        :returns: the created axes in the grid
        :raises ValueError: if ``naxes`` exceeds ``nrows * ncols``
        """
        self.fig.clear()
        # fig.clear() already removed every figure-level text, so drop the
        # stale handles rather than calling remove() on detached artists later.
        self._infoArtist = None
        self._titleArtist = None
        setMplDefaults(self)

        self.axes = []
        if naxes > nrows * ncols:
            raise ValueError(
                f'Number of axes ({naxes}) larger than rows ({nrows}) x '
                f'columns ({ncols}).')

        # Axes are created without sharex/sharey; each keeps its own limits.
        for i in range(1, naxes + 1):
            self.axes.append(self.fig.add_subplot(nrows, ncols, i))

        self.autosize()
        return self.axes

    def resizeEvent(self, event: QtGui.QResizeEvent) -> None:
        """
        Re-implementation of the widget resizeEvent method.
        Makes sure we resize the plots appropriately.
        """
        self.autosize()
        super().resizeEvent(event)

    def setTightLayout(self, tight: bool) -> None:
        """
        Set tight layout mode.
        :param tight: if true, use tight layout for autosizing.
        """
        self._tightLayout = tight
        self.autosize()

    def setShowInfo(self, show: bool) -> None:
        """Whether to show additional info in the plot"""
        self._showInfo = show
        self.updateInfo()

    def updateInfo(self) -> None:
        """Remove the current info text and, if enabled, add it again at the
        figure's lower-left corner; then redraw."""
        if self._infoArtist is not None:
            self._infoArtist.remove()
            self._infoArtist = None

        if self._showInfo:
            self._infoArtist = self.fig.text(
                0,
                0,
                self._info,
                fontsize='x-small',
                verticalalignment='bottom',
            )
        self.draw()

    def toClipboard(self) -> None:
        """
        Copy the current canvas to the clipboard as a 300-dpi PNG image.
        """
        # The context manager releases the buffer even if rendering fails.
        with io.BytesIO() as buf:
            self.fig.savefig(buf,
                             dpi=300,
                             facecolor='w',
                             format='png',
                             transparent=True)
            clipboard = QtWidgets.QApplication.clipboard()
            clipboard.setImage(QtGui.QImage.fromData(buf.getvalue()))

    def metaToClipboard(self) -> None:
        """Copy the meta info to the clipboard as ``key: value`` lines."""
        clipboard = QtWidgets.QApplication.clipboard()
        meta_info_string = "\n".join(f"{k}: {v}"
                                     for k, v in self._meta_info.items())
        clipboard.setText(meta_info_string)

    def setFigureTitle(self, title: str) -> None:
        """Show *title* centred at the top of the figure.

        A title set by a previous call is replaced rather than drawn over
        (before, each call stacked another text artist at the same spot).
        """
        if self._titleArtist is not None:
            self._titleArtist.remove()
        self._titleArtist = self.fig.text(0.5,
                                          0.99,
                                          title,
                                          horizontalalignment='center',
                                          verticalalignment='top',
                                          fontsize='small')
        self.draw()

    def setFigureInfo(self, info: str) -> None:
        """Display an info string in the figure"""
        self._info = info
        self.updateInfo()

    def setMetaInfo(self, meta_info: Dict[str, str]) -> None:
        """Store *meta_info*; ``metaToClipboard`` copies it as text."""
        self._meta_info = meta_info
Example #42
0
    def plotCrossingRates(self):
        """Plot longitude crossing rates and save them as
        ``<plotPath>/lon_crossing_syn.png``.

        The top panel shows east-west crossings and the bottom panel shows
        west-east crossings. In each panel, every gate longitude ``g`` gets a
        profile against latitude: the rates are scaled by 100 and drawn as an
        offset from ``x = 2 * g``. East-west rates go to the left, west-east
        rates to the right. The x ticks sit at ``2 * gateLons`` and are
        labelled with the gate longitudes.
        """
        LOG.debug("Plotting longitude crossing rates")
        fig = Figure()
        lonLimits = (2. * (self.lon_range.min() - 10.),
                     2. * (self.lon_range.max() + 10.))

        self._plotCrossingPanel(fig.add_subplot(2, 1, 1), -1.,
                                self.lonCrossingEWHist, self.synCrossEW,
                                self.synCrossEWUpper, self.synCrossEWLower,
                                lonLimits, "East-west crossings")
        self._plotCrossingPanel(fig.add_subplot(2, 1, 2), 1.,
                                self.lonCrossingWEHist, self.synCrossWE,
                                self.synCrossWEUpper, self.synCrossWELower,
                                lonLimits, "West-east crossings")

        fig.tight_layout()
        canvas = FigureCanvas(fig)
        canvas.print_figure(pjoin(self.plotPath, 'lon_crossing_syn.png'))

    def _plotCrossingPanel(self, ax, direction, observed, synthetic,
                           syntheticUpper, syntheticLower, lonLimits, xlabel):
        """Draw one crossing-rate panel onto *ax* and format its axes.

        :param direction: -1. draws the rates to the left of each gate, +1. to
            the right.
        :param observed: rates drawn in red, indexed ``[:, gate]`` (presumably
            the historical crossings; confirm).
        :param synthetic: rates drawn in black, indexed ``[:, gate]``.
        :param syntheticUpper: one edge of the grey band, indexed ``[:, gate]``.
        :param syntheticLower: other edge of the grey band, indexed ``[:, gate]``.
        :param lonLimits: ``(xmin, xmax)`` for the panel.
        :param xlabel: x-axis label.
        """
        lats = self.gateLats[:-1]
        scale = direction * 100.
        for i, gateLon in enumerate(self.gateLons):
            base = 2. * gateLon
            ax.plot(base + scale * observed[:, i], lats, color='r', lw=2)
            ax.plot(base + scale * synthetic[:, i], lats, color='k', lw=2)
            ax.fill_betweenx(lats,
                             base + scale * syntheticUpper[:, i],
                             base + scale * syntheticLower[:, i],
                             color='0.75',
                             alpha=0.7)

        ax.set_xlim(*lonLimits)
        ax.set_xticks(2. * self.gateLons)
        ax.set_xticklabels(self.gateLons.astype(int))
        ax.set_xlabel(xlabel)
        ax.set_ylim(self.gateLats.min(), self.gateLats[-2])
        ax.set_ylabel('Latitude')
        ax.grid(True)
Example #43
0
class ReportWidget(QWidget):
    """Display today's historical metric data.

    The widget owns a matplotlib figure laid out as an ``nrows`` x ``ncols``
    grid.

    * Column 0 holds time-series plots of each metric against clock time.
      Each also shows a dashed goal line and red/green fills where the
      metric is behind/ahead of the prorated goal.
    * Column 1 holds horizontal stacked histograms of the same metric, split
      into samples that met the goal ("thrive", lime) and samples that did
      not ("survive", orangered).

    A row of combo boxes chooses the start hour, the end hour and the step
    (fraction of an hour) of the goal line.

    NOTE(review): the layout assumes ``ncols == 2`` (``width_ratios`` has two
    entries and column 1 is treated as the histogram column). The plotting
    code special-cases row index 2, so it assumes three metrics. Confirm
    this before using other values.
    """

    def __init__(self, parent=None, nrows=3, ncols=2):
        super(ReportWidget, self).__init__(parent)

        # NOTE(review): ``parent`` must provide ``data``, ``horizontalLayout_2``
        # and ``reports_groupBox`` (all used below), so the ``None`` default
        # cannot actually be used.
        self.data = parent.data

        # Read the clock once so every derived value below agrees.
        now = datetime.now()
        # Goal window: from the current hour (but not before 9:00) until
        # 17:00, or until midnight once it is already 17:00 or later.
        self.start_hour_val = now.hour if now.hour > 9 else 9
        self.stop_hour_val = 24 if now.hour >= 17 else 17
        self.frac_hour_val = 0.5  # goal-line step in hours (alternative: 0.125)

        self.nrows, self.ncols = nrows, ncols
        self.facecolor = (49. / 255, 54. / 255, 59. / 255)
        self.figure = Figure(figsize=(4 * ncols, 3 * nrows),
                             facecolor=self.facecolor)
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)

        gs = gridspec.GridSpec(self.nrows,
                               self.ncols,
                               width_ratios=[3, 1],
                               wspace=0.05)
        for r in range(self.nrows):
            for c in range(self.ncols):
                self.figure.add_subplot(gs[r, c])
        self.axes = np.array(self.figure.axes).reshape(self.nrows, self.ncols)
        # All time-series plots (column 0) share the clock-time x axis.
        self.axes[0, 0].get_shared_x_axes().join(*self.axes[:, 0])
        # Each histogram shares its y axis (metric value) with the
        # time-series plot in the same row.
        for r in range(self.nrows):
            self.axes[r, 0].get_shared_y_axes().join(self.axes[r, 0],
                                                     self.axes[r, 1])

        self.cax = []

        self.reportsHBox = QHBoxLayout()
        start_combo = QComboBox()
        stop_combo = QComboBox()
        frac_combo = QComboBox()
        start_label = QLabel()
        stop_label = QLabel()
        frac_label = QLabel()
        zoom_buttom = QPushButton()

        start_combo.addItems(str(h) for h in range(24))
        stop_combo.addItems(str(h) for h in range(1, 25))
        frac_combo.addItems(str(f) for f in 1. / np.array([2, 4, 8, 16]))
        start_label.setText('Start Hour')
        stop_label.setText('End Hour')
        frac_label.setText('Increment Fraction')
        start_combo.setCurrentText(str(self.start_hour_val))
        stop_combo.setCurrentText(str(self.stop_hour_val))
        frac_combo.setCurrentText(str(self.frac_hour_val))
        zoom_buttom.setText('Quick Zoom')

        start_combo.activated[str].connect(self.update_start)
        stop_combo.activated[str].connect(self.update_stop)
        frac_combo.activated[str].connect(self.update_frac)
        zoom_buttom.clicked.connect(self.quick_zoom)

        self.reportsHBox.addWidget(zoom_buttom)
        self.reportsHBox.addWidget(start_label)
        self.reportsHBox.addWidget(start_combo)
        self.reportsHBox.addWidget(stop_label)
        self.reportsHBox.addWidget(stop_combo)
        self.reportsHBox.addWidget(frac_label)
        self.reportsHBox.addWidget(frac_combo)

        self.layoutVertical = QVBoxLayout(self)
        self.layoutVertical.addLayout(self.reportsHBox)
        self.layoutVertical.addWidget(self.toolbar)
        self.layoutVertical.addWidget(self.canvas)

        self.ims = np.empty((self.nrows, self.ncols), dtype=AxesImage)

        # Sets self.num_hours and self.day_hours from the start/stop/fraction.
        self._compute_day_hours()
        self.actual_day_hours = np.arange(now.hour + float(now.minute) / 60,
                                          23, 0.25)
        self.ts_hist_loc = [0, 1]
        self.completed_lines = [None] * self.nrows
        self.completed_hists = [[None, None] for _ in range(self.nrows)]
        self.ts_hist = None

        self.initialize_plots()
        # tight_layout() must run here, before subplots_adjust(); otherwise
        # the axis labels are ignored when laying out the figure.
        self.figure.tight_layout()
        self.figure.subplots_adjust(top=0.965,
                                    bottom=0.1,
                                    left=0.12,
                                    right=0.971,
                                    hspace=0,
                                    wspace=0.28)

        parent.horizontalLayout_2.addWidget(self)
        parent.reports_groupBox.setLayout(parent.horizontalLayout_2)

    def _compute_day_hours(self):
        """Recompute ``self.num_hours`` and ``self.day_hours``.

        The values come from the current start hour, stop hour and step.
        ``num_hours`` is the integer number of grid points (intervals + 1).
        With that count, ``np.linspace`` produces ``start, start + frac, ...,
        stop``, endpoint included, rather than ``arange``'s half-open range.
        ``np.linspace`` rejects non-integer and negative sample counts, so
        the count is rounded and clamped to at least one point.
        """
        n_steps = int(round((self.stop_hour_val - self.start_hour_val) /
                            self.frac_hour_val))
        self.num_hours = max(n_steps, 0) + 1
        self.day_hours = np.linspace(self.start_hour_val, self.stop_hour_val,
                                     self.num_hours)

    @staticmethod
    def _clear_collections(ax):
        """Remove every collection (e.g. ``fill_between`` fills) from *ax*.

        ``Axes.collections`` is a read-only view in recent matplotlib, so
        ``ax.collections.clear()`` no longer works there. Removing each
        artist individually works on old and new versions alike.
        """
        for collection in list(ax.collections):
            collection.remove()

    def quick_zoom(self):
        """Zoom every time-series plot to +/- 2 h around the latest sample."""
        xmin, xmax = self.data.work_time_hours[-1] + np.array([-2, 2])
        for ax in self.axes[:, 0]:
            ax.set_xlim(xmin, xmax)
            # TODO: optionally zoom the y axis around the recent metric/goal range.

    def update_start(self, text):
        """Slot for the "Start Hour" combo box.

        Parameters
        ----------
        text : str
            The selected hour, e.g. ``'9'``.
        """
        self.start_hour_val = int(text)
        self._compute_day_hours()
        self.update_goals(self.data.goals)

    def update_stop(self, text):
        """Slot for the "End Hour" combo box.

        Parameters
        ----------
        text : str
            The selected hour, e.g. ``'17'``.
        """
        self.stop_hour_val = int(text)
        self._compute_day_hours()
        self.update_goals(self.data.goals)

    def update_frac(self, text):
        """Slot for the "Increment Fraction" combo box.

        Parameters
        ----------
        text : str
            The goal-line step in hours, e.g. ``'0.25'``.
        """
        self.frac_hour_val = float(text)
        self._compute_day_hours()
        self.update_goals(self.data.goals)

    def initialize_plots(self):
        """Style every axes and draw the initial dashed goal lines.

        Column 0 (time series) gets a y label and a goal line. The line runs
        from each metric's start goal to its goal over ``self.day_hours``.
        Column 1 (histograms) has its y tick labels hidden, because it shares
        the y axis with column 0. Only the bottom row (index 2) keeps its x
        tick labels.
        """
        self.goal_lines = []
        for y in range(self.ncols):
            axes = self.axes[:, y]
            for i, start, goal, ax, ylabel in zip(range(len(self.data.goals)),
                                                  self.data.start_goals,
                                                  self.data.goals, axes,
                                                  self.data.ylabels):

                ax.tick_params(direction='in', color='w', labelcolor='w')
                for spine in ax.spines.values():
                    spine.set_edgecolor('w')
                ax.xaxis.label.set_color('w')
                ax.yaxis.label.set_color('w')
                ax.set_facecolor(self.facecolor)
                if i == 2 and y == 0:
                    ax.set_xlabel('Clock time (hours)')
                    ax.set_ylim(-self.data.daily.todo_goal,
                                self.data.daily.todo_goal)
                if i < 2:
                    plt.setp(ax.get_xticklabels(), visible=False)

                if y == 0:
                    goal_steps = np.linspace(start, goal, len(
                        self.day_hours))  # factor of 3600 will be from here
                    self.goal_lines.append(
                        ax.plot(self.day_hours,
                                goal_steps,
                                linestyle='--',
                                color='w',
                                linewidth=2))
                    ax.set_ylabel(ylabel)
                if i == 2 and y == 1:
                    ax.set_xlabel('Amount')
                if y == 1:
                    plt.setp(ax.get_yticklabels(), visible=False)

    def update_goals(self, goals):
        """Redraw the dashed goal lines for new goal values.

        Stores each goal into ``self.data.goals``. Also recomputes
        ``self.data.goal_hours`` as the goal prorated linearly over
        ``self.data.work_time_hours`` between the start and stop hours.

        Parameters
        ----------
        goals : sequence of float
            Final goal per metric (row).
        """
        for ig, line, start, goal, ax in zip(range(len(self.data.goals)),
                                             self.goal_lines,
                                             self.data.start_goals, goals,
                                             self.axes[:, 0]):
            line.pop(0).remove()
            self.data.goals[ig] = goal
            goal_steps = np.linspace(start, goal, len(self.day_hours))
            ax.set_ylim(0, goal)
            self.goal_lines[ig] = ax.plot(self.day_hours,
                                          goal_steps,
                                          linestyle='--',
                                          color='w',
                                          linewidth=2)
            if ig == 2:
                ax.set_ylim(-self.data.daily.todo_goal,
                            self.data.daily.todo_goal)
            # NOTE(review): divides by zero if the start and stop hours are
            # equal -- confirm how that selection should be handled.
            m = goal / (self.stop_hour_val - self.start_hour_val)
            current_goal = m * (self.data.work_time_hours -
                                self.start_hour_val)
            self.data.goal_hours[ig] = current_goal

    def update_lineplots(self, diff_sec):
        """
        During each TimerWidget.tick a (or several) data point(s) are added to the plots

        Redraws each metric's line in column 0 and fills the gap to the
        prorated goal. The fill is orangered where the metric is behind the
        goal and lime where it is at or ahead of it. The trailing goal values
        are then appended to ``self.data.goal_hours``.

        Parameters
        ----------
        diff_sec : int
                   the number of seconds since the code last updated (because of mac background mode)
        """
        completed_color = (64. / 255, 173. / 255, 233. / 255)
        current_goals = []
        for ig, line, ax, goal, metric in zip(range(len(self.data.goals)),
                                              self.completed_lines,
                                              self.axes[:, 0], self.data.goals,
                                              self.data.metrics_history):
            if line is not None:
                line.remove()
            self._clear_collections(ax)
            self.completed_lines[ig], = ax.plot(self.data.work_time_hours,
                                                metric,
                                                color=completed_color,
                                                linewidth=2)

            # The goal grows linearly from the start hour to ``goal`` at the
            # stop hour.
            m = goal / (self.stop_hour_val - self.start_hour_val)
            # The last stored goal value is recomputed and ``diff_sec`` (at
            # least one) new values are added -- presumably matching the
            # samples appended to work_time_hours since the last tick.
            new_inds = (diff_sec + 1) if diff_sec > 0 else 2
            current_goal = m * (self.data.work_time_hours[-new_inds:] -
                                self.start_hour_val)
            # Cap each point at the final goal once past the stop hour,
            # without flattening earlier points in the same window.
            current_goal = np.minimum(current_goal, goal)
            current_goals.append(current_goal)

            goal_hours = np.append(self.data.goal_hours[ig, :-1], current_goal)

            ax.fill_between(self.data.work_time_hours,
                            metric,
                            goal_hours,
                            where=goal_hours >= metric,
                            facecolor='orangered',
                            interpolate=True)

            ax.fill_between(self.data.work_time_hours,
                            metric,
                            goal_hours,
                            where=goal_hours <= metric,
                            facecolor='lime',
                            interpolate=True)

        # Draw once after all rows are updated instead of once per row.
        self.canvas.draw()

        self.data.goal_hours = np.append(self.data.goal_hours[:, :-1],
                                         current_goals,
                                         axis=1)

    def update_time_hist(self):
        """Redraw the stacked thrive/survive histograms in column 1.

        For each metric, samples where the goal was met
        (``goal_hours <= metric``) are drawn in lime and the rest in
        orangered. Both are drawn as one horizontal stacked histogram over
        the metric's bins in ``self.data.metric_bins``.
        """
        colors = ['lime', 'orangered']
        for ig, hist, ax, metric, bins in zip(range(len(self.data.goals)),
                                              self.completed_hists,
                                              self.axes[:, 1],
                                              self.data.metrics_history,
                                              self.data.metric_bins):

            # Remove the bars drawn by the previous update, if any.
            if None not in hist:
                for bars in hist:
                    bars.remove()
            self._clear_collections(ax)

            # TODO: make this faster by only updating the current bin.
            where_thrive = self.data.goal_hours[ig] <= metric
            thrive = metric[where_thrive]
            survive = metric[~where_thrive]

            _, _, self.completed_hists[ig] = ax.hist([thrive, survive],
                                                     bins=bins,
                                                     color=colors,
                                                     orientation='horizontal',
                                                     stacked=True)

        # Draw once after all rows are updated instead of once per row.
        self.canvas.draw()
Example #44
0
class CalibrationViewer(QtGui.QMainWindow):
    """Main window for inspecting olfactometer calibration recordings.

    Layout, left to right:

    * a list of trial groups;
    * the list of trials in the opened HDF5 file;
    * a scrollable panel of per-field trial filters;
    * a figure with two axes.

    The top axes ("PID traces") plots the ``'sniff'`` stream of each
    selected trial. The bottom axes ("Mean value") plots, per trial, the
    mean of stream samples 3000-4000 minus the mean of the first 2000
    samples, against the trial's ``odorconc``. It adds a linear fit for
    every trial group that has at least two distinct concentrations.
    """
    trialsChanged = QtCore.pyqtSignal()

    def __init__(self):
        self.filters = list()
        super(CalibrationViewer, self).__init__()
        self.setWindowTitle('Olfa Calibration')
        self.statusBar()
        self.trial_selected_list = []
        # Empty until a file is opened. Defining it here lets the filter
        # buttons be pressed before any data is loaded.
        self.trial_mask = np.ones(0, dtype=bool)
        self.trialsChanged.connect(self._trial_selection_changed)

        mainwidget = QtGui.QWidget(self)
        self.setCentralWidget(mainwidget)
        layout = QtGui.QGridLayout(mainwidget)
        mainwidget.setLayout(layout)

        menu = self.menuBar()
        filemenu = menu.addMenu("&File")
        toolsmenu = menu.addMenu("&Tools")

        openAction = QtGui.QAction("&Open recording...", self)
        openAction.triggered.connect(self._openAction_triggered)
        openAction.setStatusTip(
            "Open a HDF5 data file with calibration recording.")
        openAction.setShortcut("Ctrl+O")
        filemenu.addAction(openAction)
        saveFigsAction = QtGui.QAction('&Save figures...', self)
        saveFigsAction.triggered.connect(self._saveFiguresAction_triggered)
        saveFigsAction.setShortcut('Ctrl+S')
        saveFigsAction.setStatusTip("Saves current figures.")
        filemenu.addAction(saveFigsAction)
        exitAction = QtGui.QAction("&Quit", self)
        exitAction.setShortcut("Ctrl+Q")
        exitAction.setStatusTip("Quit program.")
        exitAction.triggered.connect(QtGui.qApp.quit)
        filemenu.addAction(exitAction)
        removeTrialAction = QtGui.QAction("&Remove trials", self)
        removeTrialAction.setStatusTip(
            'Permanently removes selected trials (bad trials) from trial list.'
        )
        removeTrialAction.triggered.connect(self._remove_trials)
        removeTrialAction.setShortcut('Ctrl+R')
        toolsmenu.addAction(removeTrialAction)

        # Column 0: trial groups.
        trial_group_list_box = QtGui.QGroupBox()
        trial_group_list_box.setTitle('Trial Groups')
        self.trial_group_list = TrialGroupListWidget()
        trial_group_layout = QtGui.QVBoxLayout()
        trial_group_list_box.setLayout(trial_group_layout)
        trial_group_layout.addWidget(self.trial_group_list)
        layout.addWidget(trial_group_list_box, 0, 0)
        self.trial_group_list.itemSelectionChanged.connect(
            self._trial_group_selection_changed)

        # Column 1: individual trials.
        trial_select_list_box = QtGui.QGroupBox()
        trial_select_list_box.setMouseTracking(True)
        trial_select_list_layout = QtGui.QVBoxLayout()
        trial_select_list_box.setLayout(trial_select_list_layout)
        trial_select_list_box.setTitle('Trials')
        self.trial_select_list = TrialListWidget()
        self.trial_select_list.setMouseTracking(True)
        trial_select_list_layout.addWidget(self.trial_select_list)
        self.trial_select_list.setSelectionMode(
            QtGui.QAbstractItemView.ExtendedSelection)
        self.trial_select_list.itemSelectionChanged.connect(
            self._trial_selection_changed)
        layout.addWidget(trial_select_list_box, 0, 1)
        self.trial_select_list.createGroupSig.connect(
            self.trial_group_list.create_group)

        # Column 2: filter panel (populated by build_filters()).
        filters_box = QtGui.QGroupBox("Trial filters.")
        filters_box_layout = QtGui.QVBoxLayout(filters_box)
        filters_scroll_area = QtGui.QScrollArea()
        filters_buttons = QtGui.QHBoxLayout()
        filters_all = QtGui.QPushButton('Select all', self)
        filters_all.clicked.connect(self._select_all_filters)
        filters_none = QtGui.QPushButton('Select none', self)
        filters_none.clicked.connect(self._select_none_filters)
        filters_buttons.addWidget(filters_all)
        filters_buttons.addWidget(filters_none)
        filters_box_layout.addLayout(filters_buttons)
        filters_box_layout.addWidget(filters_scroll_area)
        filters_wid = QtGui.QWidget()
        filters_scroll_area.setWidget(filters_wid)
        filters_scroll_area.setWidgetResizable(True)
        filters_scroll_area.setFixedWidth(300)
        self.filters_layout = QtGui.QVBoxLayout()
        filters_wid.setLayout(self.filters_layout)
        layout.addWidget(filters_box, 0, 2)

        # Column 3: plots.
        plots_box = QtGui.QGroupBox()
        plots_box.setTitle('Plots')
        plots_layout = QtGui.QHBoxLayout()

        self.figure = Figure((9, 5))
        self.figure.patch.set_facecolor('None')
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setParent(plots_box)
        self.canvas.setSizePolicy(QtGui.QSizePolicy.Expanding,
                                  QtGui.QSizePolicy.Expanding)
        plots_layout.addWidget(self.canvas)
        plots_box.setLayout(plots_layout)
        layout.addWidget(plots_box, 0, 3)
        self.ax_pid = self.figure.add_subplot(2, 1, 1)
        self.ax_pid.set_title('PID traces')
        self.ax_pid.set_ylabel('')
        self.ax_pid.set_xlabel('t (ms)')
        self.ax_mean_plots = self.figure.add_subplot(2, 1, 2)
        self.ax_mean_plots.set_title('Mean value')
        self.ax_mean_plots.set_ylabel('value')
        self.ax_mean_plots.set_xlabel('Concentration')
        self.ax_mean_plots.autoscale(enable=True, axis=u'both', tight=False)
        self.figure.tight_layout()

    @QtCore.pyqtSlot()
    def _list_context_menu_trig(self):
        pass

    @QtCore.pyqtSlot()
    def _filters_changed(self):
        """Combine all filter masks and hide trials rejected by any filter.

        Hidden trials are also deselected. The trial list's selection slot
        is disconnected while items are updated. ``itemSelectionChanged`` is
        then emitted once, so the plots are redrawn a single time.
        """
        mask = np.ones_like(self.trial_mask)
        for v in self.filters:
            mask *= v.trial_mask
        self.trial_mask = mask
        self.trial_select_list.itemSelectionChanged.disconnect(
            self._trial_selection_changed)
        for i in range(len(self.trial_mask)):
            keep = bool(self.trial_mask[i])
            it = self.trial_select_list.item(i)
            it.setHidden(not keep)
            it.setSelected(it.isSelected() and keep)
        self.trial_select_list.itemSelectionChanged.connect(
            self._trial_selection_changed)
        # Emit that something changed so that we redraw.
        self.trial_select_list.itemSelectionChanged.emit()

    @QtCore.pyqtSlot()
    def _openAction_triggered(self):
        """Let the user pick an HDF5 calibration file and load its trials."""
        filedialog = QtGui.QFileDialog(self)
        if os.path.exists('D:\\experiment\\raw_data'):
            startpath = 'D:\\experiment\\raw_data\\mouse_o_cal_cw\\sess_001'
        else:
            startpath = ''
        fn = filedialog.getOpenFileName(self, "Select a data file.", startpath,
                                        "HDF5 (*.h5)")
        if not fn:
            print('No file selected.')
            return
        data = CalibrationFile(str(fn))
        self.data = data
        self.trial_actions = []
        self.trial_select_list.clear()
        trials = self.data.trials
        trial_num_list = []
        for i, _ in enumerate(trials):
            # Constructing the item with the list widget as parent adds it.
            QtGui.QListWidgetItem("Trial {0}".format(i), self.trial_select_list)
            trial_num_list.append(i)
        self.trial_select_list.trial_num_list = np.array(trial_num_list)
        self.trial_mask = np.ones(len(self.trial_select_list.trial_num_list),
                                  dtype=bool)
        self.build_filters(trials)

    def build_filters(self, trials):
        """Rebuild the filter panel with one filter per matching field.

        A filter is created for every field of *trials* whose name starts
        with ``'odorconc'``, ``'olfas'`` or ``'dilutors'``. Each filter is
        shown collapsed under a button that toggles its visibility.

        Parameters
        ----------
        trials : numpy structured array
            The trial table; its ``dtype.names`` select the filter fields.
        """
        # Empty the panel. takeAt() only detaches layout items: the container
        # widgets stay parented (and visible) in the panel unless deleted.
        while self.filters_layout.count():
            item = self.filters_layout.takeAt(0)
            widget = item.widget()
            if widget is not None:
                widget.deleteLater()
        for f in self.filters:
            f.deleteLater()
        self.filters = list()

        colnames = trials.dtype.names
        if 'odorconc' not in colnames:
            self.error = QtGui.QErrorMessage()
            self.error.showMessage(
                'Data file must have "odorconc" field to allow plotting.')
        start_strings = ('odorconc', 'olfas', 'dilutors')
        filter_fields = [fieldname
                         for ss in start_strings
                         for fieldname in colnames
                         if fieldname.startswith(ss)]

        for field in filter_fields:
            filt = FiltersListWidget(field)
            filt.populate_list(trials)
            filt.setVisible(False)
            self.filters.append(filt)
            box = QtGui.QWidget()
            box.setSizePolicy(0, 0)
            show_button = QtGui.QPushButton(filt.fieldname)
            show_button.setStyleSheet('text-align:left; border:0px')
            show_button.clicked.connect(filt.toggle_visible)
            _filt_layout = QtGui.QVBoxLayout(box)
            _filt_layout.addWidget(show_button)
            _filt_layout.addWidget(filt)
            _filt_layout.setSpacing(0)
            self.filters_layout.addWidget(box)
            filt.filterChanged.connect(self._filters_changed)
        for v in self.filters:
            assert isinstance(v, FiltersListWidget)

        self.filters_layout.addStretch()
        self.filters_layout.setSpacing(0)

    @QtCore.pyqtSlot()
    def _saveFiguresAction_triggered(self):
        """Ask for a file name and save the current figure there.

        matplotlib infers the output format from the file extension. Nothing
        is saved if the dialog is cancelled.
        """
        self.saveDialog = QtGui.QFileDialog()
        saveloc = self.saveDialog.getSaveFileName(
            self, 'Save figure', '', 'PDF (*.pdf);;JPEG (*.jpg);;TIFF (*.tif)')
        saveloc = str(saveloc)
        if not saveloc:
            # Dialog was cancelled; savefig('') would raise.
            return
        self.figure.savefig(saveloc)

    @QtCore.pyqtSlot()
    def _remove_trials(self):
        """Permanently remove the selected trials from the trial list.

        The trial list widget, ``trial_num_list``, ``trial_mask``, every
        filter and the trial groups are all updated.
        """
        remove_idxes = [index.row()
                        for index in self.trial_select_list.selectedIndexes()]
        remove_set = set(remove_idxes)
        # Row numbers shift after every takeItem(), so re-query the selection.
        while self.trial_select_list.selectedIndexes():
            selected_idxes = self.trial_select_list.selectedIndexes()
            idx = selected_idxes[0].row()
            self.trial_select_list.takeItem(idx)

        old_trial_nums = self.trial_select_list.trial_num_list
        n_old = len(old_trial_nums)
        kept = [i for i in range(n_old) if i not in remove_set]
        remove_trialnums = [old_trial_nums[i] for i in range(n_old)
                            if i in remove_set]
        self.trial_mask = np.array([self.trial_mask[i] for i in kept],
                                   dtype=bool)
        self.trial_select_list.trial_num_list = np.array(
            [old_trial_nums[i] for i in kept], dtype=int)

        for f in self.filters:
            f.remove_trials(remove_idxes)
        self.trial_group_list._remove_trials(remove_trialnums)

    @QtCore.pyqtSlot()
    def _trial_selection_changed(self):
        """Replot the selected trials and sync the group selection.

        A group is marked selected exactly when all of its trials are
        selected. Group-list signals are blocked so this does not re-enter
        ``_trial_group_selection_changed``.
        """
        selected_trial_nums = [
            self.trial_select_list.trial_num_list[index.row()]
            for index in self.trial_select_list.selectedIndexes()
        ]
        self.update_plots(selected_trial_nums)
        selected_set = set(selected_trial_nums)
        self.trial_group_list.blockSignals(True)
        try:
            for i, g in zip(range(self.trial_group_list.count()),
                            self.trial_group_list.trial_groups):
                all_in = all(t in selected_set for t in g['trial_nums'])
                self.trial_group_list.item(i).setSelected(all_in)
        finally:
            self.trial_group_list.blockSignals(False)

    @QtCore.pyqtSlot()
    def _trial_group_selection_changed(self):
        """Select exactly the trials belonging to the selected groups.

        All filters are first reset via ``_select_all_filters``, presumably
        so that no group member is hidden. The trial selection is then
        rebuilt with signals blocked, and the plots are refreshed once at
        the end.
        """
        selected_idxes = self.trial_group_list.selectedIndexes()
        self._select_all_filters()
        selected_trial_nums = []
        for index in selected_idxes:
            group = self.trial_group_list.trial_groups[index.row()]
            selected_trial_nums.extend(group['trial_nums'])
        self.trial_select_list.blockSignals(True)
        try:
            for i in range(self.trial_select_list.count()):
                item = self.trial_select_list.item(i)
                self.trial_select_list.setItemSelected(item, False)
            trial_num_list = self.trial_select_list.trial_num_list
            for trial_num in selected_trial_nums:
                row = np.where(trial_num_list == trial_num)[0][0]
                it = self.trial_select_list.item(row)
                if not it.isSelected():
                    it.setSelected(True)
        finally:
            self.trial_select_list.blockSignals(False)
        self._trial_selection_changed()

    def update_plots(self, trials):
        """Redraw both axes for the given trial numbers.

        Parameters
        ----------
        trials : list of int
            Trial numbers to plot; an empty list clears the plots.
        """
        # TODO: make this changeable - this is the number of ms before/after
        # the trial to extract for the stream.
        padding = (2000, 2000)
        trial_streams = []
        trial_colors = []
        # Axes.lines is read-only in recent matplotlib; remove each artist.
        for ax in (self.ax_pid, self.ax_mean_plots):
            for line in list(ax.lines):
                line.remove()
        groups_by_trial = []
        all_groups = set()
        ntrials = len(trials)
        vals = np.empty(ntrials)
        concs = np.empty_like(vals)
        if trials:
            alpha = max([1. / len(trials), .25])
            for i, tn in enumerate(trials):
                color = self.trial_group_list.get_trial_color(tn)
                groups = self.trial_group_list.get_trial_groups(tn)
                trial_colors.append(color)
                groups_by_trial.append(groups)
                all_groups.update(groups)
                trial = self.data.return_trial(tn, padding=padding)
                stream = remove_stream_trend(trial.streams['sniff'],
                                             (0, padding[0]))
                stream -= stream[0:padding[0]].min()
                # TODO: remove baseline (N2) trial average from this.
                trial_streams.append(stream)
                self.ax_pid.plot(stream, color=color, alpha=alpha)
                conc = trial.trials['odorconc']
                baseline = np.mean(stream[:2000])
                val = np.mean(stream[3000:4000]) - baseline
                vals[i] = val
                concs[i] = conc
                self.ax_mean_plots.plot(conc, val, '.', color=color)

        # Truncate all streams to the shortest one so they stack into a 2-D array.
        minlen = min(len(s) for s in trial_streams) if trial_streams else 0
        streams_array = np.empty((ntrials, minlen))
        for i in range(ntrials):
            streams_array[i, :] = trial_streams[i][:minlen]

        for g in all_groups:
            mask = np.array([g in groups for groups in groups_by_trial],
                            dtype=bool)
            c = concs[mask]
            groupstreams = streams_array[mask]
            if len(np.unique(c)) < 2:
                # Single concentration: show the group's mean trace instead.
                self.ax_pid.plot(groupstreams.mean(axis=0),
                                 color='k',
                                 linewidth=2)
            else:
                v = vals[mask]
                slope, intercept, _, _, _ = stats.linregress(c, v)
                color = self.trial_group_list.get_group_color(g)
                minn, maxx = self.ax_mean_plots.get_xlim()
                x = np.array([minn, maxx])
                self.ax_mean_plots.plot(x, slope * x + intercept, color=color)

        # relim() only recomputes the data limits; autoscale_view() applies them.
        self.ax_pid.relim()
        self.ax_pid.autoscale_view()
        self.ax_mean_plots.set_yscale('log')
        self.ax_mean_plots.set_xscale('log')
        self.ax_mean_plots.relim()
        self.ax_mean_plots.autoscale_view()

        self.canvas.draw()

    @QtCore.pyqtSlot()
    def _select_none_filters(self):
        """Clear every filter's selection, then recompute the mask once."""
        for filt in self.filters:
            filt.filterChanged.disconnect(self._filters_changed)
            filt.clearSelection()
            filt.filterChanged.connect(self._filters_changed)
        self._filters_changed()

    @QtCore.pyqtSlot()
    def _select_all_filters(self):
        """Select all entries of every filter, then recompute the mask once."""
        for filt in self.filters:
            filt.filterChanged.disconnect(self._filters_changed)
            filt.selectAll()
            filt.filterChanged.connect(self._filters_changed)
        self._filters_changed()
Example #45
0
class ConanReportGraphsPresenter(Presenter[Gtk.Stack]):
    """Presenter for the left/right contact-angle graphs of a report.

    Two vertically stacked axes share an x axis; their lines are fed from
    ``graphs_service.left_angle`` and ``graphs_service.right_angle``.  A
    spinner is shown instead of the figure while either series is too short
    to plot.
    """

    spinner: TemplateChild[Gtk.Spinner] = TemplateChild('spinner')
    figure_container = TemplateChild('figure_container')  # type: TemplateChild[Gtk.Container]

    _analyses = ()

    @inject
    def __init__(self, graphs_service: ConanReportGraphsService) -> None:
        self.graphs_service = graphs_service

    def after_view_init(self) -> None:
        """Create the figure, canvas and axes, then subscribe to data updates."""
        self.figure = Figure(tight_layout=False)

        self.figure_canvas = FigureCanvas(self.figure)
        self.figure_canvas.props.hexpand = True
        self.figure_canvas.props.vexpand = True
        self.figure_canvas.props.visible = True
        self.figure_container.add(self.figure_canvas)

        self.figure_canvas_mapped = False

        self.figure_canvas.connect('map', self.canvas_map)
        self.figure_canvas.connect('unmap', self.canvas_unmap)
        self.figure_canvas.connect('size-allocate', self.canvas_size_allocate)

        left_ca_ax, right_ca_ax = self.figure.subplots(2, 1, sharex='col')

        left_ca_ax.set_ylabel('Left CA [°]')
        right_ca_ax.set_ylabel('Right CA [°]')
        # Draw x tick marks on both the top and bottom edges of the lower axes.
        right_ca_ax.xaxis.set_ticks_position('both')

        for ax in (left_ca_ax, right_ca_ax):
            ax.tick_params(axis='x', direction='inout')
            # Y tick labels are shown on the right-hand side only.
            ax.tick_params(axis='y', left=False, labelleft=False, right=True, labelright=True)
            # Values are stored in radians; label the ticks in degrees.  Each
            # axes gets its own formatter instance because matplotlib
            # formatters must not be shared between axes.
            ax.get_yaxis().set_major_formatter(
                FuncFormatter(lambda x, _: '{:.4g}'.format(math.degrees(x)))
            )
            ax.grid(axis='x', linestyle='--', color="#dddddd")
            ax.grid(axis='y', linestyle='-', color="#dddddd")

        self._left_ca_ax = left_ca_ax
        self._left_ca_line = left_ca_ax.plot([], marker='o', color='#0080ff')[0]
        self._right_ca_line = right_ca_ax.plot([], marker='o', color='#ff8000')[0]

        self.graphs_service.connect('notify::left-angle', self.data_changed)
        self.graphs_service.connect('notify::right-angle', self.data_changed)

        self.data_changed()

    def canvas_map(self, *_) -> None:
        """Record that the canvas is mapped and schedule a redraw."""
        self.figure_canvas_mapped = True
        self.figure_canvas.draw_idle()

    def canvas_unmap(self, *_) -> None:
        """Record that the canvas is no longer mapped."""
        self.figure_canvas_mapped = False

    def canvas_size_allocate(self, *_) -> None:
        """Re-fit the layout on resize, keeping the two axes touching."""
        self.figure.tight_layout(pad=2.0, h_pad=0)
        self.figure.subplots_adjust(hspace=0)

    @install
    @GObject.Property
    def analyses(self) -> Sequence[ConanAnalysisJob]:
        return self._analyses

    @analyses.setter
    def analyses(self, analyses: Iterable[ConanAnalysisJob]) -> None:
        # Materialise once and forward the stored tuple: ``analyses`` may be a
        # one-shot iterator, which tuple() has already exhausted.
        self._analyses = tuple(analyses)
        self.graphs_service.analyses = self._analyses

    def data_changed(self, *_) -> None:
        """Push the service's angle data into the lines and redraw.

        Shows the spinner instead when either series has at most one column.
        NOTE(review): the data is presumably a 2xN array (x row, y row) as
        accepted by ``Line2D.set_data`` — confirm against the service.
        """
        left_angle_data = self.graphs_service.left_angle
        right_angle_data = self.graphs_service.right_angle

        if left_angle_data.shape[1] <= 1 or right_angle_data.shape[1] <= 1:
            self.show_waiting_placeholder()
            return

        self.hide_waiting_placeholder()

        self._left_ca_line.set_data(left_angle_data)
        self._right_ca_line.set_data(right_angle_data)

        self.update_xlim()

        for line in (self._left_ca_line, self._right_ca_line):
            line.axes.relim()
            line.axes.margins(y=0.1)

        self.figure_canvas.draw()

    def update_xlim(self) -> None:
        """Fit the shared x axis to the x extent of both lines.

        Leaves the limits unchanged when there are fewer than two x values or
        they are all equal (a zero-width range).
        """
        all_xdata = (
            *self._left_ca_line.get_xdata(),
            *self._right_ca_line.get_xdata(),
        )

        if len(all_xdata) <= 1:
            return

        xmin = min(all_xdata)
        xmax = max(all_xdata)

        if xmin == xmax:
            return

        # The axes share x (sharex='col'), so setting it on one is enough.
        self._left_ca_ax.set_xlim(xmin, xmax)

    def show_waiting_placeholder(self) -> None:
        """Show and start the spinner in place of the figure."""
        self.host.set_visible_child(self.spinner)
        self.spinner.start()

    def hide_waiting_placeholder(self) -> None:
        """Show the figure and stop the spinner."""
        self.host.set_visible_child(self.figure_container)
        self.spinner.stop()
Example #46
0
class MapCanvas(FigureCanvasQTAgg):
    """Plotting maps.

    Artists drawn for each callsign are kept in ``self.trajectories`` so they
    can be removed before the next plot; ``draw`` is serialised with
    ``self.lock``.
    """
    def __init__(self, parent=None, width=5, height=4, dpi=100):

        logging.info("Initialize MapCanvas")
        # callsign -> list of matplotlib artists plotted for it
        self.trajectories = defaultdict(list)

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.main = parent
        # Created before the base-class initialiser runs, in case it
        # triggers draw(), which needs the lock.
        self.lock = Lock()

        FigureCanvasQTAgg.__init__(self, self.fig)
        self.setParent(parent)
        FigureCanvasQTAgg.setSizePolicy(self, QSizePolicy.Expanding,
                                        QSizePolicy.Expanding)
        FigureCanvasQTAgg.updateGeometry(self)
        self.create_map()

    def wheelEvent(self, event):
        """Zoom in/out with the mouse wheel (ignored on macOS)."""
        if sys.platform == "darwin":  # rather use pinch
            return
        self.zoom(event.angleDelta().y() > 0, 0.8)

    def zoom(self, zoom_in, factor):
        """Scale the view about its centre.

        The half-extents are multiplied by *factor* when zooming in and by
        ``1 / factor`` when zooming out.
        """
        min_x, max_x, min_y, max_y = self.ax.axis()
        if not zoom_in:
            factor = 1.0 / factor

        center_x = 0.5 * (max_x + min_x)
        delta_x = 0.5 * (max_x - min_x)
        center_y = 0.5 * (max_y + min_y)
        delta_y = 0.5 * (max_y - min_y)

        self.ax.axis((
            center_x - factor * delta_x,
            center_x + factor * delta_x,
            center_y - factor * delta_y,
            center_y + factor * delta_y,
        ))

        self.fig.tight_layout()
        self.lims = self.ax.axis()
        fmt = ", ".join("{:.5e}".format(t) for t in self.lims)
        logging.info("Zooming to {}".format(fmt))
        self.draw()

    def create_map(self,
                   projection: Union[str, Projection] = "EuroPP()") -> None:
        """(Re)build the map axes with *projection* and drop all trajectories.

        *projection* may be a projection instance or its name as a string,
        with or without the trailing ``()``.
        """
        if isinstance(projection, str):
            if not projection.endswith(")"):
                projection = projection + "()"
            # NOTE(security): eval() executes arbitrary code; only ever pass
            # trusted projection names here, never user-supplied text.
            projection = eval(projection)

        self.projection = projection
        self.trajectories.clear()

        with plt.style.context("traffic"):

            self.fig.clear()
            self.ax = self.fig.add_subplot(111, projection=self.projection)
            projection_name = projection.__class__.__name__.split(".")[-1]

            self.ax.add_feature(
                countries(scale="10m" if projection_name not in
                          ["Mercator", "Orthographic"] else "110m"))
            if projection_name in ["Lambert93", "GaussKruger", "Amersfoort"]:
                # Hardcoded projection list?! :o/
                self.ax.add_feature(rivers())

            self.fig.set_tight_layout(True)
            self.ax.background_patch.set_visible(False)
            self.ax.outline_patch.set_visible(False)
            self.ax.format_coord = lambda x, y: ""
            self.ax.set_global()

        self.draw()

    def _clear_trajectories(self) -> None:
        """Remove every plotted trajectory artist and forget them."""
        for artists in self.trajectories.values():
            for artist in artists:
                artist.remove()
        self.trajectories.clear()

    def plot_callsigns(self, traffic: Traffic, callsigns: List[str]) -> None:
        """Plot the trajectories of *callsigns*, or the default view if none.

        Flights whose data cannot be plotted (AttributeError/TypeError from
        the plotting calls) are skipped.
        """
        if traffic is None:
            return

        self._clear_trajectories()
        self.ax.set_prop_cycle(None)

        for c in callsigns:
            f = traffic[c]
            if f is not None:
                try:
                    self.trajectories[c] += f.plot(self.ax)
                    f_at = f.at()
                    # latitude == latitude is False for NaN, so this skips
                    # positions without a valid latitude.
                    if (f_at is not None and hasattr(f_at, "latitude")
                            and f_at.latitude == f_at.latitude):
                        self.trajectories[c] += f_at.plot(self.ax,
                                                          s=8,
                                                          text_kw=dict(s=c))
                except AttributeError:
                    # 'DataFrame' object has no attribute 'longitude'
                    pass
                except TypeError:
                    # NoneType object is not iterable
                    pass

        if len(callsigns) == 0:
            self.default_plot(traffic)

        self.draw()

    def default_plot(self, traffic: Traffic) -> None:
        """Plot the current position of every flight inside the visible extent.

        Markers are labelled with their callsign only when fewer than ten
        flights are visible.
        """
        if traffic is None:
            return
        # clear all trajectory pieces
        self._clear_trajectories()

        lon_min, lon_max, lat_min, lat_max = self.ax.get_extent(PlateCarree())
        cur_ats = [f.at() for f in traffic]
        cur_flights = [
            at for at in cur_ats
            if at is not None
            and hasattr(at, "latitude") and at.latitude is not None
            and lat_min <= at.latitude <= lat_max
            and lon_min <= at.longitude <= lon_max
        ]

        show_labels = len(cur_flights) < 10
        for at in cur_flights:
            label = at.callsign if show_labels else ""
            self.trajectories[at.callsign] += at.plot(
                self.ax, s=8, text_kw=dict(s=label))

        self.draw()

    def draw(self):
        """Redraw the canvas under the lock; no-op once the figure is gone."""
        with self.lock:
            if self.fig is None:
                return
            super().draw()
Example #47
0
class StartPage(Frame):
    """Start page showing expense ('Gasto') charts from the SQLite database.

    Left subplot: pie chart of the five categories with the largest total
    spending.  Right subplot: monthly totals for the 'alquiler' category.
    """

    def __init__(self, parent, controller):

        Frame.__init__(self, parent)
        query = "SELECT fecha,categoria,eur FROM facts_table WHERE tipo='Gasto'"
        query2 = "SELECT fecha,categoria,eur FROM facts_table WHERE tipo='Gasto' AND categoria='alquiler'"
        self.controller = controller
        self.db = base_paths.db_file
        self.figura = Figure(figsize=(10, 5), dpi=100)
        self.sub1 = self.figura.add_subplot(1, 2, 1)
        self.sub2 = self.figura.add_subplot(1, 2, 2)
        canvas = FigureCanvasTkAgg(self.figura, self)
        canvas.get_tk_widget().pack(expand=True)
        self.datos = self.get_data2(query)
        self.get_pie(self.datos)
        # The per-month plot shows the selected category (query2), not the
        # full expense set already used for the pie chart.
        self.datos2 = self.get_data2(query2)
        self.get_cat_time(self.datos2)
        self.figura.tight_layout()
        # FigureCanvasTkAgg.show() no longer exists; render once everything
        # has been plotted.
        canvas.draw()

    def get_data2(self, query):
        """Run *query* on the app database and return the rows as a DataFrame.

        The ``fecha`` column is parsed from ``dd/mm/YYYY`` strings.  The
        connection is always closed, even if the query fails.
        """
        cnx = sqlite3.connect(self.db)
        try:
            data = pd.read_sql_query(query, cnx)
        finally:
            cnx.close()
        data['fecha'] = pd.to_datetime(data['fecha'], format="%d/%m/%Y")
        return data

    @staticmethod
    def _limpiar_eur(x):
        """Convert an amount to float; strings may use a decimal comma."""
        if isinstance(x, str):
            return float(x.replace(',', '.'))
        return float(x)

    @staticmethod
    def _prepare_gastos(data):
        """Normalise *data* in place: parse dates, add month/year, clean eur."""
        data['fecha'] = pd.to_datetime(data['fecha'], format="%d/%m/%Y")
        data['mes'] = data['fecha'].dt.month
        data['anio'] = data['fecha'].dt.year
        data['eur'] = data['eur'].apply(StartPage._limpiar_eur)

    def get_pie(self, data):
        """Draw a pie chart of the five categories with the highest totals."""
        self._prepare_gastos(data)
        tablita = data[['categoria', 'eur']].pivot_table('eur',
                                                         index='categoria',
                                                         aggfunc='sum')
        top5 = tablita.sort_values(by='eur', ascending=False).head(5)
        self.sub1.pie(top5['eur'],
                      labels=top5.index,
                      autopct='%1.1f%%',
                      shadow=True,
                      startangle=90)
        self.sub1.legend(top5.index, loc=3)

    def get_cat_time(self, data):
        """Plot the total of *data* per month, labelled ``yy/mm``."""
        self._prepare_gastos(data)
        # pivot_table sorts by (anio, mes), so rows are already chronological
        # and labels stay aligned with their totals.
        tablita = data.pivot_table('eur', index=['anio', 'mes'], aggfunc='sum')
        meses = [
            pd.Timestamp(int(anio), int(mes), 1).strftime('%y/%m')
            for anio, mes in tablita.index
        ]
        self.sub2.plot(meses, tablita['eur'].to_numpy())
        self.sub2.tick_params(axis='x', labelrotation=60)
class BrightnessContrastEditor(QObject):
    """Histogram-based brightness/contrast editor for one or more arrays.

    The user picks a display range ``(ui_min, ui_max)`` inside ``data_range``
    via minimum/maximum/brightness/contrast sliders, each spanning
    ``0..NUM_INCREMENTS``.  A histogram of the data is drawn together with a
    line from ``(ui_min, bottom)`` to ``(ui_max, top)`` of the plot.
    """

    # Emitted as (ui_min, ui_max) whenever the display range is modified.
    edited = Signal(float, float)

    # Emitted when the reset button is pressed.
    reset = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)

        self._data_range = (0, 1)
        self._ui_min, self._ui_max = self._data_range
        self._data = None
        self.histogram = None
        self.histogram_artist = None
        self.line_artist = None

        # Starting value of the ImageJ-style "Auto" threshold; auto_pressed
        # halves it on each press and resets it once it drops below 10.
        self.default_auto_threshold = 5000
        self.current_auto_threshold = self.default_auto_threshold

        loader = UiLoader()
        self.ui = loader.load_file('brightness_contrast_editor.ui', parent)

        self.setup_plot()

        self.ui.minimum.setMaximum(NUM_INCREMENTS)
        self.ui.maximum.setMaximum(NUM_INCREMENTS)
        self.ui.brightness.setMaximum(NUM_INCREMENTS)
        self.ui.contrast.setMaximum(NUM_INCREMENTS)

        self.setup_connections()

    def setup_connections(self):
        """Connect the sliders and buttons to their handlers."""
        self.ui.minimum.valueChanged.connect(self.minimum_edited)
        self.ui.maximum.valueChanged.connect(self.maximum_edited)
        self.ui.brightness.valueChanged.connect(self.brightness_edited)
        self.ui.contrast.valueChanged.connect(self.contrast_edited)

        self.ui.set_data_range.pressed.connect(self.select_data_range)
        self.ui.reset.pressed.connect(self.reset_pressed)
        self.ui.auto_button.pressed.connect(self.auto_pressed)

    @property
    def data_range(self):
        """(min, max) range the sliders and histogram are mapped onto."""
        return self._data_range

    @data_range.setter
    def data_range(self, v):
        self._data_range = v
        self.clip_ui_range()
        self.ensure_min_max_space('max')
        self.update_gui()

    @property
    def data(self):
        """The displayed data: None, one array, or a tuple/list/dict of arrays."""
        return self._data

    @data.setter
    def data(self, v):
        self._data = v
        self.reset_data_range()

    @property
    def data_list(self):
        """The data as a list of arrays (dict values for a dict)."""
        if self.data is None:
            return []
        elif isinstance(self.data, (tuple, list)):
            return list(self.data)
        elif isinstance(self.data, dict):
            return list(self.data.values())
        else:
            return [self.data]

    @property
    def data_bounds(self):
        """(min, max) over all arrays, or (0, 1) when there is no data."""
        data = self.data_list
        # Also covers empty containers, where min()/max() would raise.
        if not data:
            return (0, 1)

        mins = [x.min() for x in data]
        maxes = [x.max() for x in data]
        return (min(mins), max(maxes))

    def reset_data_range(self):
        """Set data_range back to the bounds of the data."""
        self.data_range = self.data_bounds

    def update_gui(self):
        """Refresh every slider, label and plot element."""
        self.update_brightness()
        self.update_contrast()
        self.update_histogram()
        self.update_range_labels()
        self.update_line()

    @property
    def data_min(self):
        return self.data_range[0]

    @property
    def data_max(self):
        return self.data_range[1]

    @property
    def data_mean(self):
        return np.mean(self.data_range)

    @property
    def data_width(self):
        return self.data_range[1] - self.data_range[0]

    @property
    def ui_min(self):
        """Lower end of the selected display range (data units)."""
        return self._ui_min

    @ui_min.setter
    def ui_min(self, v):
        self._ui_min = v
        slider_v = np.interp(v, self.data_range, (0, NUM_INCREMENTS))
        self.ui.minimum.setValue(slider_v)
        self.update_range_labels()
        self.update_line()
        self.modified()

    @property
    def ui_max(self):
        """Upper end of the selected display range (data units)."""
        return self._ui_max

    @ui_max.setter
    def ui_max(self, v):
        self._ui_max = v
        slider_v = np.interp(v, self.data_range, (0, NUM_INCREMENTS))
        self.ui.maximum.setValue(slider_v)
        self.update_range_labels()
        self.update_line()
        self.modified()

    def clip_ui_range(self):
        # Clip the ui min and max to be in the data range
        if self.ui_min < self.data_min:
            self.ui_min = self.data_min

        if self.ui_max > self.data_max:
            self.ui_max = self.data_max

    @property
    def ui_mean(self):
        return np.mean((self.ui_min, self.ui_max))

    @ui_mean.setter
    def ui_mean(self, v):
        # Shift the range so its centre is v, keeping its width.
        offset = v - self.ui_mean
        self.ui_range = (self.ui_min + offset, self.ui_max + offset)

    @property
    def ui_width(self):
        return self.ui_max - self.ui_min

    @ui_width.setter
    def ui_width(self, v):
        # Grow/shrink symmetrically about the current centre.
        offset = (v - self.ui_width) / 2
        self.ui_range = (self.ui_min - offset, self.ui_max + offset)

    @property
    def ui_range(self):
        return (self.ui_min, self.ui_max)

    @ui_range.setter
    def ui_range(self, v):
        # Block signals so the two setters emit nothing; emit once below.
        with block_signals(self, self.ui.minimum, self.ui.maximum):
            self.ui_min = v[0]
            self.ui_max = v[1]

        self.modified()

    @property
    def ui_brightness(self):
        """Brightness slider position as a percentage (0-100)."""
        return self.ui.brightness.value() / NUM_INCREMENTS * 100

    @ui_brightness.setter
    def ui_brightness(self, v):
        self.ui.brightness.setValue(v / 100 * NUM_INCREMENTS)

    @property
    def ui_contrast(self):
        """Contrast slider position as a percentage (0-100)."""
        return self.ui.contrast.value() / NUM_INCREMENTS * 100

    @ui_contrast.setter
    def ui_contrast(self, v):
        self.ui.contrast.setValue(v / 100 * NUM_INCREMENTS)

    @property
    def contrast(self):
        """Contrast 0-100 derived from ui_width vs data_width (50 when equal).

        A narrower display range gives a higher contrast.
        """
        angle = np.arctan((self.ui_width - self.data_width) / self.data_width)
        return 100 - np.interp(angle, (-np.pi / 4, np.pi / 4), (0, 100))

    @contrast.setter
    def contrast(self, v):
        angle = np.interp(100 - v, (0, 100), (-np.pi / 4, np.pi / 4))
        self.ui_width = np.tan(angle) * self.data_width + self.data_width

    @property
    def brightness(self):
        """Brightness 0-100; 50 when ui_mean is the centre of data_range."""
        return 100 - np.interp(self.ui_mean, self.data_range, (0, 100))

    @brightness.setter
    def brightness(self, v):
        self.ui_mean = np.interp(100 - v, (0, 100), self.data_range)

    def ensure_min_max_space(self, one_to_change):
        """Keep the maximum slider at least one step above the minimum.

        Moves the maximum slider when *one_to_change* is ``'max'``, otherwise
        the minimum slider, without emitting its signals, and updates the
        matching private ui value.
        """
        if self.ui.maximum.value() > self.ui.minimum.value():
            return

        if one_to_change == 'max':
            w = self.ui.maximum
            v = self.ui.minimum.value() + 1
            a = '_ui_max'
        else:
            w = self.ui.minimum
            v = self.ui.maximum.value() - 1
            a = '_ui_min'

        with block_signals(w):
            w.setValue(v)

        interpolated = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        setattr(self, a, interpolated)

    def minimum_edited(self):
        """Handle a change of the minimum slider."""
        v = self.ui.minimum.value()
        self._ui_min = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        self.clip_ui_range()
        self.ensure_min_max_space('max')

        self.update_brightness()
        self.update_contrast()
        self.update_range_labels()
        self.update_line()
        self.modified()

    def maximum_edited(self):
        """Handle a change of the maximum slider."""
        v = self.ui.maximum.value()
        self._ui_max = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        self.clip_ui_range()
        self.ensure_min_max_space('min')

        self.update_brightness()
        self.update_contrast()
        self.update_range_labels()
        self.update_line()
        self.modified()

    def update_brightness(self):
        """Sync the brightness slider to the current range, silently."""
        with block_signals(self, self.ui.brightness):
            self.ui_brightness = self.brightness

    def update_contrast(self):
        """Sync the contrast slider to the current range, silently."""
        with block_signals(self, self.ui.contrast):
            self.ui_contrast = self.contrast

    def brightness_edited(self, v):
        self.brightness = self.ui_brightness
        self.update_contrast()

    def contrast_edited(self, v):
        self.contrast = self.ui_contrast
        self.update_brightness()

    def modified(self):
        """Emit ``edited`` with the current display range."""
        self.edited.emit(self.ui_min, self.ui_max)

    def setup_plot(self):
        """Create the histogram figure/canvas and add it to the UI layout."""
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)
        self.axis = self.figure.add_subplot(111)

        # Turn off ticks
        self.axis.axis('off')

        self.figure.tight_layout()

        self.ui.plot_layout.addWidget(self.canvas)

    def clear_plot(self):
        self.axis.clear()
        self.histogram_artist = None
        self.line_artist = None

    def update_histogram(self):
        """Recompute the summed histogram of all arrays and redraw it."""
        # Clear the plot so everything will be re-drawn from scratch
        self.clear_plot()

        data = self.data_list
        if not data:
            return

        histograms = []
        bins = None
        for datum in data:
            kwargs = {
                'a': datum,
                'bins': HISTOGRAM_NUM_BINS,
                'range': self.data_range,
            }
            # Same bins/range for every array, so the edges are shared.
            hist, bins = np.histogram(**kwargs)
            histograms.append(hist)

        self.histogram = sum(histograms)
        # Draw the precomputed counts as bars over the real bin edges.
        # (Passing the counts themselves as ``x`` would plot a histogram of
        # the count values instead of the data.)
        kwargs = {
            'x': bins[:-1],
            'bins': bins,
            'weights': self.histogram,
            'color': 'black',
        }
        self.histogram_artist = self.axis.hist(**kwargs)[2]
        # Pin the x limits to the bin range so update_line()'s mapping of
        # data_range onto the x limits lines up with the bars.
        self.axis.set_xlim(bins[0], bins[-1])

        self.canvas.draw()

    def update_range_labels(self):
        """Show ui_min/ui_max with two decimals in the min/max labels."""
        labels = (self.ui.min_label, self.ui.max_label)
        texts = [f'{x:.2f}' for x in self.ui_range]
        for label, text in zip(labels, texts):
            label.setText(text)

    def create_line(self):
        xs = (self.ui_min, self.ui_max)
        ys = self.axis.get_ylim()
        kwargs = {
            'scalex': False,
            'scaley': False,
            'color': 'black',
        }
        self.line_artist, = self.axis.plot(xs, ys, **kwargs)

    def update_line(self):
        """Move the range line to the current ui_min/ui_max."""
        if self.line_artist is None:
            self.create_line()

        xs = (self.ui_min, self.ui_max)
        ys = self.axis.get_ylim()

        xlim = self.axis.get_xlim()

        # Rescale the xs to be in the plot scaling
        interp = interp1d(self.data_range, xlim, fill_value='extrapolate')

        self.line_artist.set_data(interp(xs), ys)
        self.canvas.draw_idle()

    @property
    def max_num_pixels(self):
        return max(np.prod(x.shape) for x in self.data_list)

    def select_data_range(self):
        """Let the user pick a new data_range in a dialog."""
        dialog = QDialog(self.ui)
        layout = QVBoxLayout()
        dialog.setLayout(layout)

        range_widget = RangeWidget(dialog)
        range_widget.bounds = self.data_bounds
        range_widget.min = self.data_range[0]
        range_widget.max = self.data_range[1]
        layout.addWidget(range_widget.ui)

        buttons = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        button_box = QDialogButtonBox(buttons, dialog)
        button_box.accepted.connect(dialog.accept)
        button_box.rejected.connect(dialog.reject)
        layout.addWidget(button_box)

        if not dialog.exec_():
            # User canceled
            return

        data_range = range_widget.range
        if data_range[0] >= data_range[1]:
            message = 'Min cannot be greater than or equal to the max'
            QMessageBox.critical(self.ui, 'Validation Error', message)
            return

        if self.data_range == data_range:
            # Nothing changed...
            return

        self.data_range = data_range
        self.modified()

    def reset_pressed(self):
        self.reset_data_range()
        self.reset_auto_threshold()
        self.reset.emit()

    def reset_auto_threshold(self):
        self.current_auto_threshold = self.default_auto_threshold

    def auto_pressed(self):
        """Pick a display range from the histogram, as ImageJ's "Auto" does.

        Each press halves the threshold divisor (resetting it once below 10),
        then selects the first and last bins whose count lies in
        ``(pixel_count / threshold, pixel_count / 10]``.
        """
        data_range = self.data_range
        hist = self.histogram

        # Nothing to analyse (also avoids an unbound loop index below).
        if hist is None or len(hist) == 0:
            return

        # FIXME: should we do something other than max_num_pixels?
        pixel_count = self.max_num_pixels
        num_bins = len(hist)
        hist_start = data_range[0]
        bin_size = self.data_width / num_bins
        auto_threshold = self.current_auto_threshold

        # Perform the operation as ImageJ does it
        if auto_threshold < 10:
            auto_threshold = self.default_auto_threshold
        else:
            auto_threshold /= 2

        self.current_auto_threshold = auto_threshold

        limit = pixel_count / 10
        threshold = pixel_count / auto_threshold
        for i, count in enumerate(hist):
            if threshold < count <= limit:
                break

        h_min = i

        for i, count in reversed_enumerate(hist):
            if threshold < count <= limit:
                break

        h_max = i

        if h_max < h_min:
            # Reset the range
            self.reset_auto_threshold()
            self.ui_range = self.data_range
        else:
            vmin = hist_start + h_min * bin_size
            vmax = hist_start + h_max * bin_size
            if vmin == vmax:
                vmin, vmax = data_range

            self.ui_range = vmin, vmax

        self.update_brightness()
        self.update_contrast()
Example #49
0
    canvas.draw()


# Make a matplotlib plot in a tkinter window.

root = tk.Tk()
root.title('Matplotlib in window')

# Make fake data: y = x**2 sampled on [0, 4) in steps of 0.01.
x = np.arange(0, 4, 0.01)
y = x**2

# Make the plot.
fig = Figure()

# Create an axes on the figure and add to it.
MyPlot = fig.add_subplot()
MyPlot.plot(x, y, label='this is data')
MyPlot.set_xlabel('x axis')
MyPlot.set_title('title')
MyPlot.legend(loc='best')
MyPlot.tick_params(direction='in')

# tight_layout() only has something to lay out once the figure has axes and
# labels, so call it after the subplot is populated, not on the empty figure.
fig.tight_layout()

canvas = FigureCanvasTkAgg(fig, master=root)
canvas.draw()
canvas.get_tk_widget().grid(row=0, column=0, columnspan=2)

# Make a button that closes the window.
# NOTE(review): in this chunk the button is never placed with grid()/pack()
# and root.mainloop() is never called — confirm the script continues.
b = tk.Button(root, text='quit here', command=root.destroy)
class ParticleTrackerWindow(QParticleTrackerWindow, Ui_ParticleTrackerWindow):
    """Window for picking particles of a species and following them in time.

    The widgets come from the Designer-generated UI mixin (``setupUi``).
    The window offers:

    * a graphic selector: a scatter plot of one species at one time step,
      where particles are picked either with a rectangle drawn on the plot
      or with min/max filters on each quantity;
    * evolution plots of the tracked particles (``mainFigure``);
    * instantaneous plots of the tracked particles at a chosen time step
      (``instantPlotsFigure``);
    * a table of tracked particles whose data can be exported to a folder.

    Parameters
    ----------
    dataContainer
        Handed to ``ParticleTracker``, which performs all data access.
    colormapsCollection
        Handed to every subplot this window creates.
    dataPlotter
        Object whose ``MakePlot(figure, subplots, rows, columns[, step])``
        draws a list of subplots into a figure.
    """

    def __init__(self, dataContainer, colormapsCollection, dataPlotter):
        super(ParticleTrackerWindow, self).__init__()
        self.setupUi(self)
        self.particleTracker = ParticleTracker(dataContainer)
        self.colormapsCollection = colormapsCollection
        self.dataPlotter = dataPlotter
        # Scatter subplot shown in the graphic selector (rebuilt per redraw).
        self.selectorSubplot = None
        # RectangleSelector attached to the selector axes; None until the
        # first selector plot. Kept on self so it is not garbage-collected.
        self.toggle_selector = None
        # Subplots shown in the evolution / instant figures.
        self.evolSubplotList = list()
        self.instantSubplotList = list()
        # Grid shape of each figure. The counters alternate which dimension
        # grows next (even: rows, odd: columns) as subplots are added.
        self.evolSubplotRows = 1
        self.evolSubplotColumns = 1
        self.instantSubplotRows = 1
        self.instantSubplotColumns = 1
        self.increaseEvolRowsColumnsCounter = 0
        self.increaseInstantRowsColumnsCounter = 0
        # Time steps the selector / instant sliders snap to.
        self.timeSteps = np.zeros(1)
        self.instantTimeSteps = np.zeros(1)
        # Particles found by the last search; one row each in
        # particleList_tableWidget.
        self.particleList = list()
        # True while the axis combo boxes are being repopulated, so their
        # index-changed signals do not trigger a replot.
        self._updatingUI = False
        self.CreateCanvasAndFigures()
        self.FillInitialUI()
        self.RegisterUIEvents()

    def _AddFigureCanvas(self, layout):
        """Create a white Figure, add its canvas to ``layout`` and draw it."""
        figure = Figure()
        figure.patch.set_facecolor("white")
        canvas = FigureCanvas(figure)
        layout.addWidget(canvas)
        canvas.draw()
        return figure, canvas

    def CreateCanvasAndFigures(self):
        """Create the selector, evolution and instant figures and toolbars."""
        # Graphic selector (no toolbar).
        self.selectorFigure, self.selectorCanvas = self._AddFigureCanvas(
            self.selectorPlot_layout)
        # Evolution plots.
        self.mainFigure, self.mainCanvas = self._AddFigureCanvas(
            self.mainPlot_layout)
        self.toolbar = NavigationToolbar(self.mainCanvas,
                                         self.mainPlot_widget,
                                         coordinates=True)
        self.mainPlot_layout.addWidget(self.toolbar)
        # Instant plots.
        self.instantPlotsFigure, self.instantPlotsCanvas = (
            self._AddFigureCanvas(self.instantPlots_layout))
        self.instantPlotsToolbar = NavigationToolbar(self.instantPlotsCanvas,
                                                     self.instantPlots_widget,
                                                     coordinates=True)
        self.instantPlots_layout.addWidget(self.instantPlotsToolbar)

    def CreateSelectorSubplotObject(self):
        """Build the selector scatter subplot for the chosen species/axes.

        The x/y datasets come from the axis combo boxes and the point
        weights from the species' "Charge" dataset. The colorbar and title
        are suppressed to keep the small selector plot uncluttered.
        """
        speciesName = str(self.speciesSelector_comboBox.currentText())
        xAxis = str(self.xAxis_comboBox.currentText())
        yAxis = str(self.yAxis_comboBox.currentText())
        getDataSet = self.particleTracker.GetSpeciesDataSet
        dataSets = {
            "x": RawDataSetToPlot(getDataSet(speciesName, xAxis)),
            "y": RawDataSetToPlot(getDataSet(speciesName, yAxis)),
            "weight": RawDataSetToPlot(getDataSet(speciesName, "Charge")),
        }
        self.selectorSubplot = RawDataSubplot(1, self.colormapsCollection,
                                              dataSets)
        self.selectorSubplot.SetPlotType("Scatter")
        self.selectorSubplot.SetPlotProperty("General", "DisplayColorbar",
                                             False)
        self.selectorSubplot.SetAxisProperty("x", "LabelFontSize", 10)
        self.selectorSubplot.SetAxisProperty("y", "LabelFontSize", 10)
        self.selectorSubplot.SetTitleProperty("FontSize", 0)

    def RegisterUIEvents(self):
        """Connect every widget signal to its handler."""
        self.selectorTimeStep_Slider.valueChanged.connect(
            self.SelectorTimeStepSlider_ValueChanged)
        self.selectorTimeStep_Slider.sliderReleased.connect(
            self.SelectorTimeStepSlider_Released)
        self.instantTimeStep_Slider.valueChanged.connect(
            self.InstantTimeStepSlider_ValueChanged)
        self.instantTimeStep_Slider.sliderReleased.connect(
            self.InstantTimeStepSlider_Released)
        self.speciesSelector_comboBox.currentIndexChanged.connect(
            self.SpeciesSelectorComboBox_IndexChanged)
        self.xAxis_comboBox.currentIndexChanged.connect(
            self.AxisComboBox_IndexChanged)
        self.yAxis_comboBox.currentIndexChanged.connect(
            self.AxisComboBox_IndexChanged)
        self.rectangleSelection_Button.clicked.connect(
            self.RectangleSelectionButton_Clicked)
        self.findParticles_button.clicked.connect(
            self.FindParticlesButton_Clicked)
        self.trackParticles_Button.clicked.connect(
            self.TrackParticlesButton_Clicked)
        self.plotType_radioButton_1.toggled.connect(
            self.PlotTypeRadioButton_Toggled)
        self.plotType_radioButton_2.toggled.connect(
            self.PlotTypeRadioButton_Toggled)
        self.instPlotType_radioButton.toggled.connect(
            self.InstantPlotTypeRadioButton_Toggled)
        self.instPlotType_radioButton_2.toggled.connect(
            self.InstantPlotTypeRadioButton_Toggled)
        self.addToPlot_Button.clicked.connect(self.AddToPlotButton_Clicked)
        self.addToInstantPlot_Button.clicked.connect(
            self.AddToInstantPlotButton_Clicked)
        self.plot_pushButton.clicked.connect(self.PlotPushButton_Clicked)
        self.plotInstant_pushButton.clicked.connect(
            self.PlotInstantPushButton_Clicked)
        self.manualSelection_radioButton.toggled.connect(
            self.SelectMode_Changed)
        self.selectAll_radioButton.toggled.connect(self.SelectMode_Changed)
        self.selectFraction_radioButton.toggled.connect(
            self.SelectMode_Changed)
        self.browseExportPath_pushButton.clicked.connect(
            self.BrowseExportPathButton_Clicked)
        self.selectAllExport_checkBox.toggled.connect(
            self.SelectAllExportCheckBox_StatusChanged)
        self.exportData_pushButton.clicked.connect(
            self.ExportDataButton_Clicked)
        self.nextStep_Button.clicked.connect(self.NextButton_Clicked)
        self.prevStep_Button.clicked.connect(self.PrevButton_Clicked)

    def FillInitialUI(self):
        """Populate the species combo box, led by a placeholder entry."""
        # Copy so the tracker's own list is never modified.
        speciesNames = list(self.particleTracker.GetSpeciesNames())
        if speciesNames:
            placeholder = "Select Species"
        else:
            placeholder = "No species available"
        self.speciesSelector_comboBox.addItems([placeholder] + speciesNames)

    def _TrackingLabelText(self):
        """Return the "Tracking N particle(s)" label text."""
        return ("Tracking " +
                str(self.particleTracker.GetTotalNumberOfTrackedParticles()) +
                " particle(s)")

    def FillEvolutionPlotsUI(self):
        """Fill the evolution-plot axis combo boxes and the tracking label."""
        quantities = (self.particleTracker.
                      GetAvailableWholeSimulationQuantitiesInParticles())
        for comboBox in (self.x_comboBox, self.y_comboBox, self.z_comboBox):
            comboBox.clear()
            comboBox.addItems(quantities)
        self.trackedParticles_Label.setText(self._TrackingLabelText())

    def FillInstantPlotsUI(self):
        """Fill the instant-plot axis combo boxes and the tracking label."""
        dataSetNames = self.particleTracker.GetNamesOfInstantRawDataSets()
        for comboBox in (self.instX_comboBox, self.instY_comboBox,
                         self.instZ_comboBox):
            comboBox.clear()
            comboBox.addItems(dataSetNames)
        self.trackedParticles_Label_3.setText(self._TrackingLabelText())

    def FillExportDataUI(self):
        """Show the tracked-particle table and default the export path."""
        self.CreateTrackedParticlesTable()
        self.exportPath_lineEdit.setText(
            self.particleTracker.GetDataLocation())

    # ------------------------------------------------------------------
    # UI events
    # ------------------------------------------------------------------

    def SelectorTimeStepSlider_ValueChanged(self):
        """Mirror the selector slider value in its line edit."""
        self.selectorTimeStep_lineEdit.setText(
            str(self.selectorTimeStep_Slider.value()))

    def SelectorTimeStepSlider_Released(self):
        """Snap the selector slider to an available time step and replot."""
        value = self.selectorTimeStep_Slider.value()
        closest = self._ClosestTimeStep(value, self.timeSteps)
        if closest is not None and closest != value:
            self.selectorTimeStep_Slider.setValue(closest)
        self.MakeSelectorPlot()
        self._CreateFineSelectionTable()

    def InstantTimeStepSlider_ValueChanged(self):
        """Mirror the instant slider value in its line edit."""
        self.instantTimeStep_lineEdit.setText(
            str(self.instantTimeStep_Slider.value()))

    def InstantTimeStepSlider_Released(self):
        """Snap the instant slider to an available time step and replot."""
        value = self.instantTimeStep_Slider.value()
        # Snap within instantTimeSteps (the previous code indexed
        # self.timeSteps with indices computed on instantTimeSteps).
        closest = self._ClosestTimeStep(value, self.instantTimeSteps)
        if closest is not None and closest != value:
            self.instantTimeStep_Slider.setValue(closest)
        self.MakeInstantPlots()

    def NextButton_Clicked(self):
        """Move the instant plots to the next available time step."""
        self._StepInstantTimeStep(+1)

    def PrevButton_Clicked(self):
        """Move the instant plots to the previous available time step."""
        self._StepInstantTimeStep(-1)

    def _StepInstantTimeStep(self, offset):
        """Move the instant slider ``offset`` entries through instantTimeSteps.

        The current slider value is first snapped to the nearest available
        time step, so a value that is not an exact time step no longer
        raises. The result is clamped to the first/last step, and the
        instant plots are always redrawn.
        """
        timeSteps = np.asarray(self.instantTimeSteps)
        current = self._ClosestTimeStep(self.instantTimeStep_Slider.value(),
                                        timeSteps)
        if current is not None:
            index = int(np.flatnonzero(timeSteps == current)[0])
            newIndex = min(max(index + offset, 0), len(timeSteps) - 1)
            self.instantTimeStep_Slider.setValue(timeSteps[newIndex])
        self.MakeInstantPlots()

    @staticmethod
    def _ClosestTimeStep(value, timeSteps):
        """Return the element of ``timeSteps`` nearest to ``value``.

        ``value`` itself is returned when it is one of the time steps. On a
        tie the lower neighbour wins. Values beyond either end snap to that
        end. Returns None when ``timeSteps`` is empty.
        """
        timeSteps = np.asarray(timeSteps)
        if timeSteps.size == 0:
            return None
        if value in timeSteps:
            return value
        higher = timeSteps[timeSteps > value]
        lower = timeSteps[timeSteps < value]
        if higher.size == 0:
            return lower.max()
        if lower.size == 0:
            return higher.min()
        closestHigher = higher.min()
        closestLower = lower.max()
        if abs(value - closestHigher) < abs(value - closestLower):
            return closestHigher
        return closestLower

    def SpeciesSelectorComboBox_IndexChanged(self):
        """Reset the axis choices and redraw for the newly chosen species."""
        self._SetGraphicSelectorComboBoxes()
        self.MakeSelectorPlot()
        self._CreateFineSelectionTable()

    def _SetGraphicSelectorComboBoxes(self):
        """Refill the selector axis combo boxes with the species' datasets.

        Defaults to "z" on the x axis and "y" on the y axis when those
        datasets exist. Replots are suppressed while the boxes change.
        """
        self._updatingUI = True
        try:
            speciesName = str(self.speciesSelector_comboBox.currentText())
            axisList = self.particleTracker.GetSpeciesRawDataSetNames(
                speciesName)
            self.xAxis_comboBox.clear()
            self.yAxis_comboBox.clear()
            self.xAxis_comboBox.addItems(axisList)
            self.yAxis_comboBox.addItems(axisList)
            if "z" in axisList:
                self.xAxis_comboBox.setCurrentIndex(axisList.index("z"))
            if "y" in axisList:
                self.yAxis_comboBox.setCurrentIndex(axisList.index("y"))
        finally:
            # Always re-enable replots, even if repopulating failed.
            self._updatingUI = False

    def AxisComboBox_IndexChanged(self):
        """Replot the selector when the user picks a different axis."""
        if not self._updatingUI:
            self.MakeSelectorPlot()

    def RectangleSelectionButton_Clicked(self):
        """Enable rectangle selection on the selector plot (if one exists)."""
        if self.toggle_selector is not None:
            self.toggle_selector.set_active(True)

    def FindParticlesButton_Clicked(self):
        """Search for particles using the checked min/max filters."""
        speciesName = str(self.speciesSelector_comboBox.currentText())
        timeStep = self.selectorTimeStep_Slider.value()
        self.FindParticles(timeStep, speciesName, self.GetSelectedFilters())

    def TrackParticlesButton_Clicked(self):
        """Track the checked particles and refresh the dependent tabs."""
        self.particleTracker.SetParticlesToTrack(self.GetSelectedParticles())
        self.particleTracker.FillEvolutionOfAllDataSetsInParticles()
        self.particleTracker.MakeInstantaneousRawDataSets()
        self.FillEvolutionPlotsUI()
        self.FillInstantPlotsUI()
        self.FillExportDataUI()

    def PlotTypeRadioButton_Toggled(self):
        """Enable the evolution z-axis combo box only for the second type."""
        self.z_comboBox.setEnabled(self.plotType_radioButton_2.isChecked())

    def InstantPlotTypeRadioButton_Toggled(self):
        """Enable the instant z-axis combo box only for the second type."""
        self.instZ_comboBox.setEnabled(
            self.instPlotType_radioButton_2.isChecked())

    def _AddSubplotItem(self, listWidget, subplot, plotCallback):
        """Append a SubplotItem row for ``subplot`` to ``listWidget``."""
        itemWidget = SubplotItem(subplot, plotCallback, self)
        listItem = QtWidgets.QListWidgetItem()
        listItem.setSizeHint(QtCore.QSize(100, 40))
        listWidget.addItem(listItem)
        listWidget.setItemWidget(listItem, itemWidget)

    def AddToPlotButton_Clicked(self):
        """Add an evolution subplot for the chosen x/y (and optional z)."""
        xDataSetName = str(self.x_comboBox.currentText())
        yDataSetName = str(self.y_comboBox.currentText())
        zDataSetName = None
        if self.plotType_radioButton_2.isChecked():
            zDataSetName = str(self.z_comboBox.currentText())
        subplot = RawDataEvolutionSubplot(
            len(self.evolSubplotList) + 1, self.colormapsCollection,
            self.particleTracker.GetTrackedParticlesDataToPlot(
                xDataSetName, yDataSetName, zDataSetName),
            self.particleTracker.GetTrackedSpeciesName())
        self.evolSubplotList.append(subplot)
        self.SetAutoEvolColumnsAndRows()
        self._AddSubplotItem(self.subplots_listWidget, subplot,
                             self.MakeEvolPlots)

    def AddToInstantPlotButton_Clicked(self):
        """Add an instant subplot for the chosen x/y (and optional z).

        Besides the axis datasets, the momentum datasets ("Px", "Py" and,
        for the second plot type, "Pz") and "Charge" as the weight are
        always included.
        """
        xDataSetName = str(self.instX_comboBox.currentText())
        yDataSetName = str(self.instY_comboBox.currentText())
        # (dataset key, name of the instant raw dataset it is read from);
        # the key order matches the order the datasets were added before.
        sources = [("x", xDataSetName), ("y", yDataSetName),
                   ("Px", "Px"), ("Py", "Py")]
        if self.instPlotType_radioButton_2.isChecked():
            sources += [("z", str(self.instZ_comboBox.currentText())),
                        ("Pz", "Pz")]
        sources.append(("weight", "Charge"))
        dataSets = {
            key: RawDataSetToPlot(
                self.particleTracker.GetInstantRawDataSet(name))
            for key, name in sources
        }
        subplot = RawDataSubplot(len(self.instantSubplotList) + 1,
                                 self.colormapsCollection, dataSets)
        self.instantSubplotList.append(subplot)
        self.SetAutoInstantColumnsAndRows()
        self._AddSubplotItem(self.instantSubplots_listWidget, subplot,
                             self.MakeInstantPlots)
        self.SetInstantTimeSteps()

    def PlotPushButton_Clicked(self):
        """Redraw the evolution plots."""
        self.MakeEvolPlots()

    def PlotInstantPushButton_Clicked(self):
        """Redraw the instant plots."""
        self.MakeInstantPlots()

    def SelectMode_Changed(self):
        """Check all, none, or every N-th found particle per selection mode."""
        table = self.particleList_tableWidget
        if self.selectAll_radioButton.isChecked():
            self._SetCheckStates(table, lambda row: True)
        elif self.manualSelection_radioButton.isChecked():
            self._SetCheckStates(table, lambda row: False)
        elif self.selectFraction_radioButton.isChecked():
            # A step of 0 checks every row, as the old numpy modulo did.
            step = max(self.selectFraction_spinBox.value(), 1)
            self._SetCheckStates(table, lambda row: row % step == 0)

    def SelectAllExportCheckBox_StatusChanged(self):
        """Check or uncheck every tracked particle to follow the checkbox."""
        checked = self.selectAllExport_checkBox.isChecked()
        self._SetCheckStates(self.trackedParticlesList_tableWidget,
                             lambda row: checked)

    def BrowseExportPathButton_Clicked(self):
        """Let the user choose the export folder."""
        self.OpenFolderDialog()

    def ExportDataButton_Clicked(self):
        """Export the checked tracked particles to the chosen folder."""
        particleIndices = self.GetIndicesOfParticlesToExport()
        self.particleTracker.ExportParticleData(
            particleIndices, str(self.exportPath_lineEdit.text()))

    # ------------------------------------------------------------------
    # Rectangle selector
    # ------------------------------------------------------------------

    def line_select_callback(self, eclick, erelease):
        """Search for particles inside the rectangle drawn on the selector.

        ``eclick`` and ``erelease`` are the press and release events from
        the RectangleSelector. Their data coordinates become min/max
        filters on the quantities currently on the x and y axes.
        """
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        filters = {}
        filters[str(self.xAxis_comboBox.currentText())] = (min(x1, x2),
                                                           max(x1, x2))
        filters[str(self.yAxis_comboBox.currentText())] = (min(y1, y2),
                                                           max(y1, y2))
        self.FindParticles(self.selectorTimeStep_Slider.value(),
                           str(self.speciesSelector_comboBox.currentText()),
                           filters)

    # ------------------------------------------------------------------
    # Other functions
    # ------------------------------------------------------------------

    def OpenFolderDialog(self):
        """Ask for an export folder; keep the current path if cancelled."""
        folderPath = str(
            QtWidgets.QFileDialog.getExistingDirectory(
                self, "Export data to:", str(self.exportPath_lineEdit.text())))
        if folderPath != "":
            self.exportPath_lineEdit.setText(folderPath)

    def FindParticles(self, timeStep, speciesName, filter):
        """Find particles matching ``filter`` and show them in the table.

        ``filter`` maps a quantity name to a ``(min, max)`` tuple.
        """
        self.particleList = self.particleTracker.FindParticles(
            timeStep, speciesName, filter)
        self.CreateParticleTable()

    @staticmethod
    def _FillParticleTable(tableWidget, particles):
        """Show one row per particle: a checkbox, then its quantities.

        The columns after the checkbox are the first particle's
        time-step quantity names. An empty ``particles`` clears the table
        so no stale rows remain.
        """
        if len(particles) == 0:
            tableWidget.setRowCount(0)
            return
        # Copy: the header insertion must not modify the particle's list.
        variableNames = list(particles[0].GetNamesOfTimeStepQuantities())
        tableWidget.setColumnCount(len(variableNames) + 1)
        tableWidget.setRowCount(len(particles))
        for row, particle in enumerate(particles):
            checkItem = QtWidgets.QTableWidgetItem()
            checkItem.setCheckState(QtCore.Qt.Unchecked)
            tableWidget.setItem(row, 0, checkItem)
            quantities = particle.GetCurrentTimeStepQuantities()
            for column, name in enumerate(variableNames, start=1):
                # Values are shown as float64, as before.
                text = str(np.float64(quantities[name]))
                tableWidget.setItem(row, column,
                                    QtWidgets.QTableWidgetItem(text))
        tableWidget.setHorizontalHeaderLabels([" "] + variableNames)
        tableWidget.resizeColumnsToContents()

    def CreateParticleTable(self):
        """Show the particles found by the last search (or clear the table)."""
        self._FillParticleTable(self.particleList_tableWidget,
                                self.particleList)

    @staticmethod
    def _CheckedRows(tableWidget):
        """Return the indices of rows whose column-0 checkbox is checked."""
        return [
            row for row in range(tableWidget.rowCount())
            if tableWidget.item(row, 0).checkState() == QtCore.Qt.Checked
        ]

    @staticmethod
    def _SetCheckStates(tableWidget, isChecked):
        """Set every row's column-0 checkbox to ``isChecked(row)``."""
        for row in range(tableWidget.rowCount()):
            state = QtCore.Qt.Checked if isChecked(row) else QtCore.Qt.Unchecked
            tableWidget.item(row, 0).setCheckState(state)

    def GetSelectedParticles(self):
        """Return the found particles whose table rows are checked."""
        return [self.particleList[row]
                for row in self._CheckedRows(self.particleList_tableWidget)]

    def GetSelectedFilters(self):
        """Return ``{quantity: (min, max)}`` for each checked filter row."""
        table = self.advancedSelector_tableWidget
        filters = {}
        for row in self._CheckedRows(table):
            name = str(table.item(row, 1).text())
            minVal = float(table.item(row, 2).text())
            maxVal = float(table.item(row, 3).text())
            filters[name] = (minVal, maxVal)
        return filters

    def CreateTrackedParticlesTable(self):
        """Show the tracked particles (clears the table if there are none)."""
        self._FillParticleTable(self.trackedParticlesList_tableWidget,
                                self.particleTracker.GetTrackedParticles())

    def _CreateFineSelectionTable(self):
        """Build the filter table: one row per quantity with its min/max."""
        tableData = self._GetMaximumsAndMinimumsOfQuantities()
        tableHeaders = ["Use", "Quantity", "Minimum", "Maximum"]
        table = self.advancedSelector_tableWidget
        table.setRowCount(len(tableData))
        table.setColumnCount(len(tableHeaders))
        for row, rowData in enumerate(tableData):
            useItem = QtWidgets.QTableWidgetItem()
            useItem.setCheckState(QtCore.Qt.Unchecked)
            table.setItem(row, 0, useItem)
            for column, value in enumerate(rowData, start=1):
                table.setItem(row, column,
                              QtWidgets.QTableWidgetItem(str(value)))
        table.setHorizontalHeaderLabels(tableHeaders)
        table.resizeColumnsToContents()

    def _GetMaximumsAndMinimumsOfQuantities(self):
        """Return ``[name, min, max]`` for each dataset of the species.

        Values are taken at the selector slider's time step, in the
        dataset's original units.
        """
        speciesName = str(self.speciesSelector_comboBox.currentText())
        timeStep = self.selectorTimeStep_Slider.value()
        maxMinList = []
        for quantity in self.particleTracker.GetSpeciesRawDataSetNames(
                speciesName):
            dataSet = self.particleTracker.GetSpeciesDataSet(
                speciesName, quantity)
            data = dataSet.GetDataInOriginalUnits(timeStep)
            maxMinList.append([quantity, np.min(data), np.max(data)])
        return maxMinList

    def GetIndicesOfParticlesToExport(self):
        """Return the row indices of the checked tracked particles."""
        return self._CheckedRows(self.trackedParticlesList_tableWidget)

    def MakeSelectorPlot(self):
        """Redraw the graphic selector for the current species/time step.

        Does nothing while a placeholder entry ("Select Species" or "No
        species available") is selected. Otherwise it rebuilds the
        selector subplot, attaches an initially inactive rectangle
        selector, and updates the selector slider's range.
        """
        if self.speciesSelector_comboBox.currentText() in [
                "Select Species", "No species available"
        ]:
            return
        self.CreateSelectorSubplotObject()
        # DataPlotter.MakePlot only accepts lists of subplots.
        self.dataPlotter.MakePlot(self.selectorFigure, [self.selectorSubplot],
                                  1, 1, self.selectorTimeStep_Slider.value())
        ax = self.selectorFigure.axes[0]
        # 'drawtype' is not passed: a box is RectangleSelector's default, and
        # the keyword was removed in Matplotlib 3.7.
        self.toggle_selector = RectangleSelector(
            ax,
            self.line_select_callback,
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        self.toggle_selector.set_active(False)
        self.selectorFigure.tight_layout()
        self.selectorCanvas.draw()
        self.SetTimeSteps()

    def SetTimeSteps(self):
        """Adopt the selector subplot's time steps as the slider range."""
        self.timeSteps = self.selectorSubplot.GetTimeSteps()
        self.selectorTimeStep_Slider.setMinimum(min(self.timeSteps))
        self.selectorTimeStep_Slider.setMaximum(max(self.timeSteps))

    def SetAutoEvolColumnsAndRows(self):
        """Grow the evolution grid by one row or column if it is full."""
        if self.evolSubplotRows * self.evolSubplotColumns < len(
                self.evolSubplotList):
            if self.increaseEvolRowsColumnsCounter % 2 == 0:
                self.evolSubplotRows += 1
            else:
                self.evolSubplotColumns += 1
            self.increaseEvolRowsColumnsCounter += 1

    def SetAutoInstantColumnsAndRows(self):
        """Grow the instant grid by one row or column if it is full."""
        if self.instantSubplotRows * self.instantSubplotColumns < len(
                self.instantSubplotList):
            if self.increaseInstantRowsColumnsCounter % 2 == 0:
                self.instantSubplotRows += 1
            else:
                self.instantSubplotColumns += 1
            self.increaseInstantRowsColumnsCounter += 1

    def RemoveSubplot(self, item):
        """Remove the evolution subplot held by ``item`` (a SubplotItem).

        Later subplots move up one position. The most recent grid growth
        is undone if the remaining subplots still fit: an even counter
        means columns grew last, an odd counter means rows did.

        NOTE(review): this only searches evolSubplotList. SubplotItems for
        instant subplots are also created with this window as parent. If
        they call RemoveSubplot, ``index`` raises ValueError; confirm
        against SubplotItem.
        """
        index = self.evolSubplotList.index(item.subplot)
        self.evolSubplotList.remove(item.subplot)
        self.subplots_listWidget.takeItem(index)
        for subplot in self.evolSubplotList:
            if subplot.GetPosition() > index + 1:
                subplot.SetPosition(subplot.GetPosition() - 1)
        if len(self.evolSubplotList) > 0:
            if self.increaseEvolRowsColumnsCounter % 2 == 0:
                if len(self.evolSubplotList) <= self.evolSubplotRows * (
                        self.evolSubplotColumns - 1):
                    self.evolSubplotColumns -= 1
                    self.increaseEvolRowsColumnsCounter -= 1
            else:
                if len(self.evolSubplotList) <= (self.evolSubplotRows -
                                                 1) * self.evolSubplotColumns:
                    self.evolSubplotRows -= 1
                    self.increaseEvolRowsColumnsCounter -= 1

    def MakeEvolPlots(self):
        """Draw all evolution subplots into the main figure."""
        self.dataPlotter.MakePlot(self.mainFigure, self.evolSubplotList,
                                  self.evolSubplotRows,
                                  self.evolSubplotColumns)
        self.mainCanvas.draw()

    def MakeInstantPlots(self):
        """Draw all instant subplots at the instant slider's time step."""
        timeStep = self.instantTimeStep_Slider.value()
        self.dataPlotter.MakePlot(self.instantPlotsFigure,
                                  self.instantSubplotList,
                                  self.instantSubplotRows,
                                  self.instantSubplotColumns, timeStep)
        self.instantPlotsCanvas.draw()

    def SetInstantTimeSteps(self):
        """Use the time steps common to all instant subplots.

        The slider range is only updated when at least one common time
        step exists, so an empty intersection no longer crashes min/max.
        """
        commonSteps = None
        for subplot in self.instantSubplotList:
            steps = subplot.GetTimeSteps()
            if commonSteps is None:
                commonSteps = steps
            else:
                commonSteps = np.intersect1d(commonSteps, steps)
        if commonSteps is None:
            return
        self.instantTimeSteps = commonSteps
        if len(commonSteps) > 0:
            self.instantTimeStep_Slider.setMinimum(min(commonSteps))
            self.instantTimeStep_Slider.setMaximum(max(commonSteps))
Example #51
0
def plot_zone_vector_and_atom_distance_map(image_data,
                                           distance_data,
                                           atom_planes=None,
                                           distance_data_scale=1,
                                           atom_list=None,
                                           extra_marker_list=None,
                                           clim=None,
                                           atom_plane_marker=None,
                                           plot_title='',
                                           vector_to_plot=None,
                                           figsize=(10, 20),
                                           figname="map_data.jpg"):
    """Save a figure with an image panel above a distance-map panel.

    The top panel shows ``image_data`` with ``imshow``. Its colour limits come
    from ``to._get_clim_from_data(image_data, sigma=2)``. It is optionally
    overlaid with every plane in ``atom_planes`` (blue line plus a red index
    label at the plane's start atom) and with ``atom_plane_marker`` (red
    line). The bottom panel is drawn by
    ``_make_subplot_map_from_regular_grid``. A horizontal colorbar for its
    first image is added underneath. The figure is written to ``figname`` and
    then closed.

    Parameters
    ----------
    image_data : array-like
        Image for the top panel. ``shape[0]`` and ``shape[1]`` set the y and
        x axis limits.
    distance_data :
        Forwarded unchanged to ``_make_subplot_map_from_regular_grid``.
    atom_planes : iterable, optional
        Objects providing ``get_x_position_list()``,
        ``get_y_position_list()`` and ``start_atom.pixel_x/pixel_y``.
    distance_data_scale, clim, vector_to_plot :
        Forwarded to ``_make_subplot_map_from_regular_grid``.
    atom_list : list of Atom_Position instances
        Forwarded to ``_make_subplot_map_from_regular_grid``.
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
        Forwarded to ``_make_subplot_map_from_regular_grid``.
    atom_plane_marker : optional
        Plane drawn as a red line on the image panel. It is also forwarded to
        the distance-map helper.
    plot_title : str
        Title of the image panel.
    figsize : tuple
        Figure size passed to ``Figure``.
    figname : str
        Output path given to ``fig.savefig``.

    Raises
    ------
    ValueError
        If ``image_data`` is None.
    """
    if image_data is None:
        raise ValueError("Image data is None, no data to plot")

    fig = Figure(figsize=figsize)
    FigureCanvas(fig)  # attach a canvas so the figure can be saved

    # 95-row grid: rows 0-44 image, rows 45-89 distance map, rows 90-94 colorbar.
    layout = GridSpec(95, 95)
    image_ax, distance_ax, colorbar_ax = (
        fig.add_subplot(layout[row_span, :])
        for row_span in (slice(0, 45), slice(45, 90), slice(90, None))
    )

    # --- top panel: image with optional atom-plane overlays ---------------
    limits = to._get_clim_from_data(image_data, sigma=2)
    image_layer = image_ax.imshow(image_data)
    image_layer.set_clim(limits[0], limits[1])

    for plane_number, plane in enumerate(atom_planes or ()):
        xs = plane.get_x_position_list()
        ys = plane.get_y_position_list()
        image_ax.plot(xs, ys, lw=3, color='blue')
        origin_atom = plane.start_atom
        image_ax.text(origin_atom.pixel_x,
                      origin_atom.pixel_y,
                      str(plane_number),
                      color='red')

    n_rows, n_cols = image_data.shape[0], image_data.shape[1]
    image_ax.set_ylim(0, n_rows)
    image_ax.set_xlim(0, n_cols)
    image_ax.set_title(plot_title)

    if atom_plane_marker:
        marker_xs = atom_plane_marker.get_x_position_list()
        marker_ys = atom_plane_marker.get_y_position_list()
        image_ax.plot(marker_xs, marker_ys, color='red', lw=2)

    # --- bottom panel: distance map and its colorbar ----------------------
    _make_subplot_map_from_regular_grid(
        distance_ax,
        distance_data,
        distance_data_scale=distance_data_scale,
        clim=clim,
        atom_list=atom_list,
        atom_plane_marker=atom_plane_marker,
        extra_marker_list=extra_marker_list,
        vector_to_plot=vector_to_plot)
    distance_image = distance_ax.images[0]

    fig.tight_layout()
    fig.colorbar(distance_image, cax=colorbar_ax, orientation='horizontal')
    fig.savefig(figname)
    plt.close(fig)
Example #52
0
    def create_widgets(self):
        """Create Widgets for the main window"""

        # parent frame
        overall = Frame(self.parent, relief=RIDGE, background='gray40')
        overall.pack(side=LEFT, fill=BOTH)

        leftFrame = Frame(overall, relief=FLAT, background='gray40')
        leftFrame.pack(side=TOP, fill=Y)

        self.centerFrame = Frame(self.parent,
                                 relief=RIDGE,
                                 background='gray40')
        self.centerFrame.pack(side=LEFT, fill=Y)

        rightFrame = Frame(self.parent, relief=FLAT, background='gray30')
        rightFrame.pack(side=TOP, fill=Y)

        # Mode selection Frame
        self.mode = LabelFrame(leftFrame,
                               text="Mode",
                               font=('arial', 11, 'bold'))
        self.mode.configure(background='gray40')
        self.mode.pack(anchor=W)
        self.mode_var = IntVar()

        self.manual_mode = Radiobutton(self.mode,
                                       text='Manual',
                                       value=1,
                                       variable=self.mode_var,
                                       indicatoron=0,
                                       fg='red',
                                       background='gray10',
                                       font=('verdana', 11, 'bold'),
                                       width=10,
                                       command=self.check_current_mode,
                                       selectcolor='lawngreen').pack(side=LEFT,
                                                                     padx=18,
                                                                     pady=5)

        self.intelligent_mode = Radiobutton(self.mode,
                                            text='Intelligent',
                                            value=2,
                                            variable=self.mode_var,
                                            indicatoron=0,
                                            fg='red',
                                            background='gray10',
                                            font=('verdana', 11, 'bold'),
                                            width=10,
                                            command=self.check_current_mode,
                                            selectcolor='lawngreen').pack(
                                                side=LEFT, padx=18, pady=5)

        self.remote_mode = Radiobutton(self.mode,
                                       text='Remote',
                                       value=3,
                                       variable=self.mode_var,
                                       indicatoron=0,
                                       fg='red',
                                       background='gray10',
                                       font=('verdana', 11, 'bold'),
                                       width=10,
                                       command=self.check_current_mode,
                                       selectcolor='lawngreen').pack(side=LEFT,
                                                                     padx=18,
                                                                     pady=5)

        # ==============================Pick coordinates frame================================================
        self.pick_group = LabelFrame(leftFrame,
                                     text="Pick Position",
                                     font=('arial', 11, 'bold'))
        self.pick_group.configure(background='gray40')
        self.pick_group.pack(anchor=W)

        Label(
            self.pick_group,
            font=('verdana', 11),
            text="Enter coordinates of the desired position of gripper center",
            background='gray40').pack()

        # picking position coordinates entry frame
        pick_entry_frame = Frame(self.pick_group,
                                 relief=FLAT,
                                 background='gray40')
        pick_entry_frame.configure(background='gray40')

        x_coords = Frame(pick_entry_frame, relief=FLAT, background='gray40')
        x_coords.pack()
        x_label = Label(x_coords,
                        text='x:                      ',
                        font=('verdana', 11),
                        background='gray40')
        self.x = numericValidator.NumericEntry(x_coords,
                                               font=('verdana', 11, 'bold'),
                                               background='gray70',
                                               width=10)
        x_label.pack(side=LEFT)
        self.x.pack(side=LEFT)

        y_coords = Frame(pick_entry_frame, relief=FLAT, background='gray40')
        y_coords.pack()
        y_label = Label(y_coords,
                        text='y:                      ',
                        font=('verdana', 11),
                        background='gray40')
        self.y = numericValidator.NumericEntry(y_coords,
                                               font=('verdana', 11, 'bold'),
                                               background='gray70',
                                               width=10)
        y_label.pack(side=LEFT)
        self.y.pack(side=LEFT)

        z_coords = Frame(pick_entry_frame, relief=FLAT, background='gray40')
        z_coords.pack()
        z_label = Label(z_coords,
                        text='z:                      ',
                        font=('verdana', 11),
                        background='gray40')
        self.z = numericValidator.NumericEntry(z_coords,
                                               font=('verdana', 11, 'bold'),
                                               background='gray70',
                                               width=10)
        z_label.pack(side=LEFT)
        self.z.pack(side=LEFT)

        # Inverse Kinematics computing button
        self.pick_ik = Button(pick_entry_frame,
                              text="Compute IK",
                              borderwidth=1,
                              font=('verdana', 11, 'bold'),
                              relief=SOLID,
                              background='gray10',
                              fg='orange')

        self.pick_ik.pack(pady=5)

        pick_entry_frame.pack()

        # pick position joint angles
        Label(self.pick_group,
              font=('verdana', 11),
              text="Joint Angles (deg)",
              background='gray40').pack(pady=5)
        joint_angles_frame = Frame(self.pick_group,
                                   relief=FLAT,
                                   background='gray40')

        shoulder_az = Frame(joint_angles_frame,
                            relief=FLAT,
                            background='gray40')
        shoulder_az.pack()
        Label(shoulder_az,
              text='Shoulder Azimuth:',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.shoulder_azimuth = numericValidator.NumericEntry(
            shoulder_az,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.shoulder_azimuth.pack(side=LEFT)

        shoulder_pivot = Frame(joint_angles_frame,
                               relief=FLAT,
                               background='gray40')
        shoulder_pivot.pack()
        Label(shoulder_pivot,
              text='Shoulder Pivot:     ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.shoulder_pivot = numericValidator.NumericEntry(
            shoulder_pivot,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.shoulder_pivot.pack(side=LEFT)

        elbow_p = Frame(joint_angles_frame, relief=FLAT, background='gray40')
        elbow_p.pack()
        Label(elbow_p,
              text='Elbow Pivot:         ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.elbow_pivot = numericValidator.NumericEntry(elbow_p,
                                                         width=10,
                                                         font=('verdana', 11,
                                                               'bold'),
                                                         background='gray70')
        self.elbow_pivot.pack(side=LEFT)

        wrist_p = Frame(joint_angles_frame, relief=FLAT, background='gray40')
        wrist_p.pack()
        Label(wrist_p,
              text='Wrist Pitch:          ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.wrist_pivot = numericValidator.NumericEntry(wrist_p,
                                                         width=10,
                                                         font=('verdana', 11,
                                                               'bold'),
                                                         background='gray70')
        self.wrist_pivot.pack(side=LEFT)

        joint_angles_frame.pack(side=TOP)

        # pick_x.component('entry').focus_set()
        # ===================================END================================================================

        # ===================================placing coordinates frame======================================
        self.place_group = LabelFrame(leftFrame,
                                      text='Place Position',
                                      font=('arial', 11, 'bold'),
                                      background='gray40')
        self.place_group.pack(anchor=W)

        Label(
            self.place_group,
            font=('verdana', 11),
            text="Enter coordinates of the desired position of gripper center",
            background='gray40').pack()

        # picking position coordinates entry frame
        place_entry_frame = Frame(self.place_group, relief=FLAT)

        place_x_coords = Frame(place_entry_frame,
                               relief=FLAT,
                               background='gray40')
        place_x_coords.pack()
        place_x_label = Label(place_x_coords,
                              text='x:                      ',
                              font=('verdana', 11),
                              background='gray40')
        self.place_x = numericValidator.NumericEntry(place_x_coords,
                                                     font=('verdana', 11,
                                                           'bold'),
                                                     background='gray70',
                                                     width=10)
        place_x_label.pack(side=LEFT)
        self.place_x.pack(side=LEFT)

        place_y_coords = Frame(place_entry_frame,
                               relief=FLAT,
                               background='gray40')
        place_y_coords.pack()
        place_y_label = Label(place_y_coords,
                              text='y:                      ',
                              font=('verdana', 11),
                              background='gray40')
        self.place_y = numericValidator.NumericEntry(place_y_coords,
                                                     font=('verdana', 11,
                                                           'bold'),
                                                     background='gray70',
                                                     width=10)
        place_y_label.pack(side=LEFT)
        self.place_y.pack(side=LEFT)

        place_z_coords = Frame(place_entry_frame,
                               relief=FLAT,
                               background='gray40')
        place_z_coords.pack()
        place_z_label = Label(place_z_coords,
                              text='z:                      ',
                              font=('verdana', 11),
                              background='gray40')
        self.place_z = numericValidator.NumericEntry(place_z_coords,
                                                     font=('verdana', 11,
                                                           'bold'),
                                                     background='gray70',
                                                     width=10)
        place_z_label.pack(side=LEFT)
        self.place_z.pack(side=LEFT)

        # Inverse Kinematics computing button
        self.place_ik = Button(self.place_group,
                               text="Compute IK",
                               borderwidth=1,
                               font=('verdana', 11, 'bold'),
                               relief=SOLID,
                               background='gray10',
                               fg='orange')

        place_entry_frame.pack()

        self.place_ik.pack(pady=5)

        # pick position joint angles
        Label(self.place_group,
              font=('verdana', 11),
              text="Joint Angles (deg)",
              background='gray40').pack(pady=5)
        place_joint_angles = Frame(self.place_group,
                                   relief=FLAT,
                                   background='gray40')

        shoulder_az = Frame(place_joint_angles,
                            relief=FLAT,
                            background='gray40')
        shoulder_az.pack()
        Label(shoulder_az,
              text='Shoulder Azimuth:',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.shoulder_azimuth = numericValidator.NumericEntry(
            shoulder_az,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.shoulder_azimuth.pack(side=LEFT)

        shoulder_pivot = Frame(place_joint_angles,
                               relief=FLAT,
                               background='gray40')
        shoulder_pivot.pack()
        Label(shoulder_pivot,
              text='Shoulder Pivot:     ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.shoulder_pivot = numericValidator.NumericEntry(
            shoulder_pivot,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.shoulder_pivot.pack(side=LEFT)

        elbow_p = Frame(place_joint_angles, relief=FLAT, background='gray40')
        elbow_p.pack()
        Label(elbow_p,
              text='Elbow Pivot:         ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.elbow_pivot = numericValidator.NumericEntry(
            elbow_p,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.elbow_pivot.pack(side=LEFT)

        wrist_p = Frame(place_joint_angles, relief=FLAT, background='gray40')
        wrist_p.pack()
        Label(wrist_p,
              text='Wrist Pitch:          ',
              font=('verdana', 11),
              background='gray40').pack(side=LEFT)
        self.wrist_pivot = numericValidator.NumericEntry(
            wrist_p,
            width=10,
            font=('verdana', 11, 'bold'),
            background='gray70',
        )
        self.wrist_pivot.pack(side=LEFT)

        place_joint_angles.pack(side=TOP)
        # ===================================END====================================================================

        # ===================================Operation Buttons======================================================
        operations = Frame(leftFrame, relief=FLAT, background='gray40')
        operations.pack(side=LEFT, anchor=W, fill=Y)

        #===============================OPERATION BUTTONS================================================

        # Load Button
        self.load_parameters = Button(operations,
                                      text='  UPLOAD  ',
                                      font=('lucida', 11, 'bold'),
                                      fg='red',
                                      width=11,
                                      background='gray60',
                                      command=self.load_to_console)
        self.load_parameters.pack(padx=4, pady=10, side=LEFT)

        # run button
        self.btnRun = Button(operations,
                             text='RUN ',
                             font=('lucida', 11, 'bold'),
                             fg='red',
                             width=10,
                             background='gray60',
                             command=self.run)
        self.btnRun.pack(padx=4, pady=10, side=LEFT)

        self.btnStop = Button(operations,
                              text='STOP',
                              font=('lucida', 11, 'bold'),
                              fg='red',
                              width=10,
                              background='gray60',
                              command=self.stop)
        self.btnStop.pack(padx=4, pady=10, side=LEFT)

        self.btnExit = Button(operations,
                              text='EXIT',
                              font=('lucida', 11, 'bold'),
                              fg='red',
                              width=10,
                              background='gray60',
                              command=self.exit)
        self.btnExit.pack(padx=4, pady=10, side=LEFT)
        # =================================END========================================================================

        # ==================================No of objects to be picked================================================

        parametersFrame = LabelFrame(self.centerFrame,
                                     background='gray40',
                                     text='Operation Parameters',
                                     font=('arial', 11, 'bold'))
        parametersFrame.pack(side=TOP)

        objects = Frame(parametersFrame, background='gray40')
        objectsLabel = Label(objects,
                             text='No. of Objects:   ',
                             font=('verdana', 11),
                             background='gray40')
        self.objectsEntry = numericValidator.NumericEntry(objects,
                                                          background='gray70',
                                                          font=('verdana', 11,
                                                                'bold'),
                                                          width=20)
        objectsLabel.pack(side=LEFT)
        self.objectsEntry.pack(side=LEFT, pady=4)

        objects.pack(side=TOP)

        servoFrame = Frame(parametersFrame, background='gray40')
        servoLabel = Label(servoFrame,
                           text="Servo Speed:      ",
                           font=('verdana', 11),
                           background='gray40')
        # servoEntry = Entry(servoFrame, background='gray70', width=10)
        self.servoEntry = Pmw.Counter(
            servoFrame,
            entry_width=30,
            entryfield_value='12.5',
            datatype={
                'counter': 'real',
                'separator': '.'
            },
            entryfield_validate={
                'validator': 'real',
                'min': 0.0,
                'max': 15.0,
                'separator': '.'
            },
            increment=.2,
        )

        servoLabel.pack(side=LEFT)
        self.servoEntry.pack(side=LEFT, fill=Y, pady=4)
        servoFrame.pack(side=LEFT)

        # ==================================END======================================================================

        # status bar
        sbar = Frame(overall, relief=SUNKEN, background='gray40')
        self.statusbar = Label(sbar,
                               text='Status:',
                               bd=1,
                               anchor=W,
                               font=('consolas', 13, 'bold'))
        self.statusbar.configure(background='gray40', fg='lawngreen')
        self.cstat = Label(sbar,
                           text='...',
                           fg='lawngreen',
                           background='gray40',
                           font=('consolas', 13, 'bold'))

        self.statusbar.pack(anchor=W, side=LEFT)
        self.cstat.pack(anchor=W, side=LEFT)
        sbar.pack(fill=X)

        jaw_label = Label(self.centerFrame,
                          text='Set Jaw width(mm)',
                          font=('verdana', 10, 'bold'),
                          background='gray40')
        jaw_label.pack(anchor=CENTER)

        # CREATE THE JAW WIDTH SLIDER
        self.slider_var = DoubleVar()
        self.scaler = Scale(self.centerFrame,
                            variable=self.slider_var,
                            orient=HORIZONTAL,
                            from_=0,
                            to=90,
                            tickinterval=0.2,
                            sliderlength=4,
                            relief=FLAT,
                            length=355,
                            fg='steelblue',
                            bg='gray40',
                            activebackground='brown',
                            font=('consolas', 11, 'bold'))
        self.scaler.pack(anchor=CENTER)

        # ==================System info==============================================
        # Show the robot arm image
        simulator = LabelFrame(self.centerFrame,
                               text='System Information',
                               font=('arial', 11, 'bold'),
                               height=700)
        simulator.pack(side=TOP, fill=Y, anchor=W)
        simulator.configure(background='gray40')

        # environment_parameters = LabelFrame(simulator, text='<<<Loaded Parameters:>', font=('consolas', 11))
        # environment_parameters.configure(background='gray40', fg='black')
        # environment_parameters.pack(side=TOP)

        tFrame = Frame(simulator, relief=FLAT)
        Label(tFrame,
              text='Modified: ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.ctime = Label(tFrame,
                           text='_',
                           fg='black',
                           background='gray40',
                           font=('consolas', 11))
        self.ctime.pack(anchor=W)
        tFrame.pack(anchor=W)

        filler = Label(simulator,
                       text='============================================',
                       background='gray40',
                       fg='black')
        filler.pack()

        username = Frame(simulator, relief=FLAT)
        Label(username,
              text='Current User :'******'black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.uname = Label(username,
                           text='e5430',
                           background='gray40',
                           fg='black',
                           font=('consolas', 11))
        self.uname.pack(anchor=W)
        username.pack(anchor=W)

        operatinSys = Frame(simulator, relief=FLAT)
        Label(operatinSys,
              text='OS :',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.operatingsys = Label(operatinSys,
                                  text='_',
                                  background='gray40',
                                  fg='black',
                                  font=('consolas', 11))
        self.operatingsys.pack(anchor=W)
        operatinSys.pack(anchor=W)

        controller = Frame(simulator, relief=FLAT)
        Label(controller,
              text='Microcontroller :',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.control = Label(controller,
                             text='Arduino UNO ATMEGA328P',
                             background='gray40',
                             fg='black',
                             font=('consolas', 11))
        self.control.pack(anchor=W)
        controller.pack(anchor=W)

        configuration = Frame(simulator, relief=FLAT)
        Label(configuration,
              text='DoF :',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.config = Label(configuration,
                            text='6',
                            background='gray40',
                            fg='black',
                            font=('consolas', 11))
        self.config.pack(anchor=W)
        configuration.pack(anchor=W)

        run = Frame(simulator, relief=FLAT)
        Label(run,
              text='Run mode:',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.run_mode = Label(run,
                              text='_',
                              background='gray40',
                              fg='black',
                              font=('consolas', 11))
        self.run_mode.pack(anchor=W)
        run.pack(anchor=W)

        coords = Frame(simulator, relief=FLAT)
        Label(coords,
              text='Pick coordinates : ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.pick_coords = Label(coords,
                                 text='_',
                                 background='gray40',
                                 fg='black',
                                 font=('consolas', 11))
        self.pick_coords.pack(anchor=W)
        coords.pack(anchor=W)

        pcoords = Frame(simulator, relief=FLAT)
        Label(pcoords,
              text='Place Coordinates : ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.place_coords = Label(pcoords,
                                  text='_',
                                  background='gray40',
                                  fg='black',
                                  font=('consolas', 11))
        self.place_coords.pack(anchor=W)
        pcoords.pack(anchor=W)

        speed = Frame(simulator, relief=FLAT)
        Label(speed,
              text='Servo speed : ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.servo_speed = Label(speed,
                                 text='_',
                                 background='gray40',
                                 fg='black',
                                 font=('consolas', 11))
        self.servo_speed.pack(anchor=W)
        speed.pack(anchor=W)

        objs = Frame(simulator, relief=FLAT)
        Label(objs,
              text='Objects : ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.object = Label(objs,
                            text='_',
                            background='gray40',
                            fg='black',
                            font=('consolas', 11))
        self.object.pack(anchor=W)
        objs.pack(anchor=W)

        grip = Frame(simulator, relief=FLAT)
        Label(grip,
              text='Gripper width : ',
              fg='black',
              background='gray40',
              font=('consolas', 11)).pack(side=LEFT)
        self.gripper_width = Label(grip,
                                   text='_',
                                   background='gray40',
                                   fg='black',
                                   font=('consolas', 11))
        self.gripper_width.pack(anchor=W)
        grip.pack(anchor=W)

        # =====================END========================================================

        # Serial port
        self.centerLower = Frame(self.centerFrame, background='gray40')
        self.centerLower.pack(side=TOP)

        serialFrame = Frame(self.centerLower, background='gray40')
        serialFrame.pack(fill=X, side=TOP)

        serialLabel = Label(serialFrame,
                            text='Serial Port:              ',
                            background='gray40',
                            font=('verdana', 11))
        serialLabel.pack(side=LEFT, anchor=W, fill=X)
        self.ports = ['COM13', 'COM17']
        self.serialport = ttk.Combobox(serialFrame,
                                       font=('courier', 11),
                                       values=['COM13', 'COM17'])

        self.serialport.pack(side=LEFT)
        self.serialport.current(0)

        self.connect_status = Frame(self.centerLower, background='gray40')
        self.connect_status.pack(side=BOTTOM)
        self.connectLabel = Label(self.connect_status,
                                  text='',
                                  background='gray40',
                                  font=('verdana', 10, 'bold'))
        self.connectLabel.pack(anchor=CENTER)

        # =====================ANALOG GAUGES================================================

        # subdividing the right frame
        rightupper = LabelFrame(rightFrame,
                                bg='gray40',
                                text='Monitors',
                                font=('verdana', 10, 'bold'))
        rightupper.pack()

        self.inputVoltage = tools.Gauge(rightupper,
                                        width=180,
                                        height=150,
                                        min_value=0.0,
                                        max_value=5,
                                        label='Volts-In',
                                        unit='V',
                                        bg='gray40',
                                        yellow=100,
                                        red=100,
                                        red_low=80)
        self.inputVoltage.set_value(4.5)
        self.inputVoltage.grid(row=0, column=0)

        self.motorOne = tools.Gauge(rightupper,
                                    width=180,
                                    height=150,
                                    min_value=0.0,
                                    max_value=5,
                                    label='M1',
                                    unit='mW',
                                    bg='gray40',
                                    red_low=70,
                                    yellow=100)
        self.motorOne.set_value(4)
        self.motorOne.grid(row=0, column=1)

        self.motorTwo = tools.Gauge(rightupper,
                                    width=180,
                                    height=150,
                                    min_value=0.0,
                                    max_value=5,
                                    label='M2',
                                    unit='mW',
                                    bg='gray40',
                                    red_low=70,
                                    yellow=100)
        self.motorTwo.set_value(4.5)
        self.motorTwo.grid(row=0, column=2)

        self.motorThree = tools.Gauge(rightupper,
                                      width=180,
                                      height=150,
                                      min_value=0.0,
                                      max_value=5,
                                      label='M3',
                                      unit='mW',
                                      bg='gray40',
                                      red_low=70,
                                      yellow=100)
        self.motorThree.set_value(4.5)
        self.motorThree.grid(row=1, column=0)

        self.motorFour = tools.Gauge(rightupper,
                                     width=180,
                                     height=150,
                                     min_value=0.0,
                                     max_value=5,
                                     label='M4',
                                     unit='mW',
                                     bg='gray40',
                                     red_low=70,
                                     yellow=100)
        self.motorFour.set_value(3)
        self.motorFour.grid(row=1, column=1)
        #
        self.motorFive = tools.Gauge(rightupper,
                                     width=180,
                                     height=150,
                                     min_value=0.0,
                                     max_value=5,
                                     label='M5',
                                     unit='mw',
                                     bg='gray40',
                                     red_low=70,
                                     yellow=100)
        self.motorFive.set_value(4.5)
        self.motorFive.grid(row=1, column=2)
        #
        self.gripperMotor = tools.Gauge(rightupper,
                                        width=180,
                                        height=150,
                                        min_value=0.0,
                                        max_value=5,
                                        label='Gmotor',
                                        unit='mW',
                                        bg='gray40',
                                        red_low=70,
                                        yellow=100)
        self.gripperMotor.set_value(4.5)
        self.gripperMotor.grid(row=2, column=0)

        self.gripperForce = tools.Gauge(rightupper,
                                        width=180,
                                        height=150,
                                        min_value=0.0,
                                        max_value=5,
                                        label='GrForce',
                                        unit='mN',
                                        bg='gray40',
                                        red_low=70,
                                        yellow=100)
        self.gripperForce.set_value(4.5)
        self.gripperForce.grid(row=2, column=1)

        self.proximitySensor = tools.Gauge(rightupper,
                                           width=180,
                                           height=150,
                                           min_value=0.0,
                                           max_value=5,
                                           label='Contact',
                                           unit='mm',
                                           bg='gray40',
                                           yellow=100,
                                           red=100)
        self.proximitySensor.set_value(4.5)
        self.proximitySensor.grid(row=2, column=2)
        # ===============================END================================================

        # ============================GRIPPER FORCE GRAPH====================================
        rightLower = LabelFrame(rightFrame,
                                bg='gray40',
                                text='Gripper Force',
                                font=('verdana', 10, 'bold'))
        rightLower.pack(fill=X)

        figure = Figure(figsize=(4.5, 2.5), dpi=100)
        figure.tight_layout(h_pad=5)

        subplot = figure.add_subplot(111)
        # subplot.set_xlabel('Time(s)')
        subplot.set_ylabel('Force(N)')
        subplot.plot([1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 5, 6, 1, 2, 3],
                     color='orange')

        graphcanvas = FigureCanvasTkAgg(figure, rightLower)
        graphcanvas.draw()
        graphcanvas.get_tk_widget().pack(fill=X)
Example #53
0
class PlotView(QtWidgets.QWidget):
    """
    Qt widget hosting a matplotlib figure with up to four named subplots.

    Besides the canvas, the widget shows a navigation toolbar, a dropdown to
    pick one subplot (or "All"), X/Y axis-limit editors and an "Errors"
    checkbox that toggles error-bar plotting for the selected subplot(s).

    Signals:
        subplotRemovedSignal(name): emitted after a subplot has been removed.
        plotCloseSignal(): emitted when the widget receives a close event.
    """
    subplotRemovedSignal = QtCore.Signal(object)
    plotCloseSignal = QtCore.Signal()

    def __init__(self):
        super(PlotView, self).__init__()
        # subplot name -> matplotlib Axes, kept in insertion order
        self.plots = OrderedDict()
        # names of the subplots that are currently drawn with error bars
        self.errors_list = set()
        self.plot_storage = {}  # stores lines and info to create lines
        self.current_grid = None
        # number of subplots -> gridspec used to lay out that many subplots
        self.gridspecs = {
            1: gridspec.GridSpec(1, 1),
            2: gridspec.GridSpec(1, 2),
            3: gridspec.GridSpec(3, 1),
            4: gridspec.GridSpec(2, 2)
        }
        self.figure = Figure()
        self.figure.set_facecolor("none")
        self.canvas = FigureCanvas(self.figure)

        self.plot_selector = QtWidgets.QComboBox()
        self._update_plot_selector()
        self.plot_selector.currentIndexChanged[str].connect(self._set_bounds)

        button_layout = QtWidgets.QHBoxLayout()
        self.x_axis_changer = AxisChangerPresenter(AxisChangerView("X"))
        self.x_axis_changer.on_upper_bound_changed(self._update_x_axis_upper)
        self.x_axis_changer.on_lower_bound_changed(self._update_x_axis_lower)

        self.y_axis_changer = AxisChangerPresenter(AxisChangerView("Y"))
        self.y_axis_changer.on_upper_bound_changed(self._update_y_axis_upper)
        self.y_axis_changer.on_lower_bound_changed(self._update_y_axis_lower)

        self.errors = QtWidgets.QCheckBox("Errors")
        self.errors.stateChanged.connect(self._errors_changed)

        button_layout.addWidget(self.plot_selector)
        button_layout.addWidget(self.x_axis_changer.view)
        button_layout.addWidget(self.y_axis_changer.view)
        button_layout.addWidget(self.errors)

        grid = QtWidgets.QGridLayout()

        self.toolbar = myToolbar(self.canvas, self)
        self.toolbar.update()

        grid.addWidget(self.toolbar, 0, 0)
        grid.addWidget(self.canvas, 1, 0)
        grid.addLayout(button_layout, 2, 0)
        self.setLayout(grid)

    def setAddConnection(self, slot):
        """ Forwards the 'add' connection slot to the toolbar. """
        self.toolbar.setAddConnection(slot)

    def setRmConnection(self, slot):
        """ Forwards the 'remove' connection slot to the toolbar. """
        self.toolbar.setRmConnection(slot)

    def _redo_layout(func):
        """
        Simple decorator (@_redo_layout) to call tight_layout() on plots
        and to redraw the canvas after the wrapped method has run.
        (https://www.python.org/dev/peps/pep-0318/)

        It is defined in the class body only so it can decorate the methods
        below; it is not meant to be called on an instance.
        """
        import functools

        @functools.wraps(func)  # keep the wrapped method's name/docstring
        def wrapper(self, *args, **kwargs):
            output = func(self, *args, **kwargs)
            # tight_layout() is only applied when at least one subplot exists
            if self.plots:
                self.figure.tight_layout()
            self.canvas.draw()
            return output

        return wrapper

    def _silent_checkbox_check(self, state):
        """ Checks a checkbox without emitting a checked event. """
        self.errors.blockSignals(True)
        self.errors.setChecked(state)
        self.errors.blockSignals(False)

    def _set_plot_bounds(self, name, plot):
        """
        Sets AxisChanger bounds to the given plot bounds and updates
            the plot-specific error checkbox.
        """
        self.x_axis_changer.set_bounds(plot.get_xlim())
        self.y_axis_changer.set_bounds(plot.get_ylim())
        self._silent_checkbox_check(name in self.errors_list)

    def _set_bounds(self, new_plot):
        """
        Sets AxisChanger bounds if a specific plot is selected, or clears the
            AxisChanger fields if the selection is empty. "All" leaves them as is.
        """
        new_plot = str(new_plot)
        if new_plot and new_plot != "All":
            plot = self.get_subplot(new_plot)
            self._set_plot_bounds(new_plot, plot)
        elif not new_plot:
            self.x_axis_changer.clear_bounds()
            self.y_axis_changer.clear_bounds()

    def _get_current_plot_name(self):
        """ Returns the 'current' plot name based on the dropdown selector. """
        return str(self.plot_selector.currentText())

    def _get_current_plots(self):
        """
        Returns a list of the current plot, or all plots if 'All' is selected.

        Raises KeyError if the selected name is not a known subplot.
        """
        name = self._get_current_plot_name()
        if name == "All":
            return list(self.plots.values())
        return [self.get_subplot(name)]

    @_redo_layout
    def _update_x_axis(self, bound):
        """ Updates the plot's x limits with the specified bound. """
        try:
            for plot in self._get_current_plots():
                plot.set_xlim(**bound)
        except KeyError:
            # selection does not name an existing subplot: nothing to update
            return

    def _update_x_axis_lower(self, bound):
        """ Updates the lower x axis limit. """
        self._update_x_axis({"left": bound})

    def _update_x_axis_upper(self, bound):
        """ Updates the upper x axis limit. """
        self._update_x_axis({"right": bound})

    @_redo_layout
    def _update_y_axis(self, bound):
        """ Updates the plot's y limits with the specified bound. """
        try:
            for plot in self._get_current_plots():
                plot.set_ylim(**bound)
        except KeyError:
            # selection does not name an existing subplot: nothing to update
            return

    def _update_y_axis_lower(self, bound):
        """ Updates the lower y axis limit. """
        self._update_y_axis({"bottom": bound})

    def _update_y_axis_upper(self, bound):
        """ Updates the upper y axis limit. """
        self._update_y_axis({"top": bound})

    def _modify_errors_list(self, name, state):
        """
        Adds/Removes a plot name to the errors set depending on the 'state' bool.
        Removing a name that is not present is a no-op.
        """
        if state:
            self.errors_list.add(name)
        else:
            self.errors_list.discard(name)

    def _change_plot_errors(self, name, plot, state):
        """
        Removes the previous plot and redraws with/without errors depending on the state.
        The axis limits in place before the redraw are restored afterwards.
        """
        self._modify_errors_list(name, state)
        # get a copy of all the workspaces
        workspaces = copy(self.plot_storage[name].ws)
        # get the limits before replotting, so they appear unchanged.
        x, y = plot.get_xlim(), plot.get_ylim()
        # clear out the old container
        self.plot_storage[name].delete()
        for workspace in workspaces:
            self.plot(name, workspace)
        plot.set_xlim(x)
        plot.set_ylim(y)
        self._set_bounds(name)  # set AxisChanger bounds again.

    @_redo_layout
    def _errors_changed(self, state):
        """ Replots subplots with errors depending on the current selection. """
        current_name = self._get_current_plot_name()
        if current_name == "All":
            for name, plot in iteritems(self.plots):
                self._change_plot_errors(name, plot, state)
        else:
            self._change_plot_errors(current_name,
                                     self.get_subplot(current_name), state)

    def _set_positions(self, positions):
        """ Moves all subplots based on a gridspec change. """
        for plot, pos in zip(self.plots.values(), positions):
            grid_pos = self.current_grid[pos[0], pos[1]]
            # move the axes to the cell's position within the figure
            plot.set_position(grid_pos.get_position(self.figure))
            # required because tight_layout() is used.
            plot.set_subplotspec(grid_pos)

    @_redo_layout
    def _update_gridspec(self, new_plots, last=None):
        """
        Updates the gridspec for 'new_plots' subplots; adds a 'last' subplot if
        one is supplied. Raises KeyError if no gridspec exists for 'new_plots'.
        """
        if new_plots:
            self.current_grid = self.gridspecs[new_plots]
            positions = putils.get_layout(new_plots)
            self._set_positions(positions)
            if last is not None:
                # label is necessary to fix
                # https://github.com/matplotlib/matplotlib/issues/4786
                pos = self.current_grid[positions[-1][0], positions[-1][1]]
                self.plots[last] = self.figure.add_subplot(pos, label=last)
                self.plots[last].set_subplotspec(pos)
        self._update_plot_selector()

    def _update_plot_selector(self):
        """ Updates plot selector (dropdown): "All" followed by every subplot name. """
        self.plot_selector.clear()
        self.plot_selector.addItem("All")
        self.plot_selector.addItems(list(self.plots.keys()))

    @_redo_layout
    def plot(self, name, workspace):
        """ Plots a workspace to a subplot (with errors, if necessary). """
        if name in self.errors_list:
            self.plot_workspace_errors(name, workspace)
        else:
            self.plot_workspace(name, workspace)
        self._set_bounds(name)

    def _add_plotted_line(self, name, label, lines, workspace):
        """ Appends plotted lines to the related subplot list. """
        self.plot_storage[name].addLine(label, lines, workspace)

    def plot_workspace_errors(self, name, workspace):
        """ Plots a workspace with errors, and appends caps/bars to the subplot list. """
        subplot = self.get_subplot(name)
        line, cap_lines, bar_lines = plots.plotfunctions.errorbar(subplot,
                                                                  workspace,
                                                                  specNum=1)
        # make a tmp plot to get auto generated legend name
        tmp, = plots.plotfunctions.plot(subplot, workspace, specNum=1)
        label = tmp.get_label()
        # remove the tmp line
        tmp.remove()
        del tmp
        # collect results
        all_lines = [line]
        all_lines.extend(cap_lines)
        all_lines.extend(bar_lines)
        self._add_plotted_line(name, label, all_lines, workspace)

    def plot_workspace(self, name, workspace):
        """ Plots a workspace normally. """
        subplot = self.get_subplot(name)
        line, = plots.plotfunctions.plot(subplot, workspace, specNum=1)
        self._add_plotted_line(name, line.get_label(), [line], workspace)

    def get_subplot(self, name):
        """ Returns the subplot corresponding to a given name (KeyError if unknown). """
        return self.plots[name]

    def get_subplots(self):
        """ Returns all subplots (the name -> axes mapping itself). """
        return self.plots

    def add_subplot(self, name):
        """ will raise KeyError if: plots exceed 4 """
        self._update_gridspec(len(self.plots) + 1, last=name)
        self.plot_storage[name] = subPlot(name)
        return self.get_subplot(name)

    def remove_subplot(self, name):
        """ will raise KeyError if: 'name' isn't a plot; there are no plots """
        self.figure.delaxes(self.get_subplot(name))
        del self.plots[name]
        del self.plot_storage[name]
        self._update_gridspec(len(self.plots))
        self.subplotRemovedSignal.emit(name)

    def removeLine(self, subplot, label):
        """ Removes the line with 'label' from the named subplot's storage. """
        self.plot_storage[subplot].removeLine(label)

    @_redo_layout
    def add_moveable_vline(self, plot_name, x_value, y_minx, y_max, **kwargs):
        """ Placeholder: draws nothing; only re-layouts and redraws the canvas. """
        pass

    @_redo_layout
    def add_moveable_hline(self, plot_name, y_value, x_min, x_max, **kwargs):
        """ Placeholder: draws nothing; only re-layouts and redraws the canvas. """
        pass

    def closeEvent(self, event):
        """ Emits plotCloseSignal when the widget is closed. """
        self.plotCloseSignal.emit()

    def plotCloseConnection(self, slot):
        """ Connects 'slot' to plotCloseSignal. """
        self.plotCloseSignal.connect(slot)

    @property
    def subplot_names(self):
        """ Names of all stored subplots. """
        return self.plot_storage.keys()

    def line_labels(self, subplot):
        """ Labels of all lines stored for the named subplot. """
        return self.plot_storage[subplot].lines.keys()
Example #54
0
def grism_sky_column_average_GP(asn_file='GDN12-G102_asn.fits', mask_grow=8):
    """
    Remove column-averaged residuals from grism exposures, smooth with Gaussian Processes.

    For every exposure listed in ``asn_file``:

    1. open ``<exposure>_flt.fits`` in update mode and build a mask of
       background pixels: outside the segmentation map
       ``<exposure>_flt.seg.fits`` grown by ``mask_grow`` pixels, with none
       of the bits 4+32+16+512+2048+4096 set in extension 3 (presumably DQ),
       and with extension 2 > 0;
    2. per column, clip the masked extension-1 pixels to their 16th-84th
       percentile range and record the mean and its standard error;
    3. fit a Gaussian Process to that profile (offset by the ``GSKY00``
       primary-header value, default 1) and subtract the prediction from
       extension 1, writing the result back to the file;
    4. save a diagnostic plot to ``<flt name without .fits>.column.png``.

    Columns with no usable pixels are left out of the fit instead of
    aborting the run.

    Parameters
    ----------
    asn_file : str
        Association file read with ``threedhst.utils.ASNFile``.
    mask_grow : int
        Size of the maximum filter used to grow the segmentation mask.
    """
    import scipy.ndimage as nd
    import astropy.io.fits as pyfits
    # NOTE(review): GaussianProcess was removed in scikit-learn 0.20; this
    # needs an older scikit-learn or a port to GaussianProcessRegressor.
    from sklearn.gaussian_process import GaussianProcess
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    # Bits in extension 3 that exclude a pixel from the background estimate.
    bad_dq_bits = 4 + 32 + 16 + 512 + 2048 + 4096

    asn = threedhst.utils.ASNFile(asn_file)

    for exposure in asn.exposures:
        # Context managers close both files even if a later step raises.
        with pyfits.open('%s_flt.fits' % (exposure), mode='update') as flt:
            segfile = '%s_flt.seg.fits' % (exposure)
            with pyfits.open(segfile) as seg_hdul:
                seg = seg_hdul[0].data
                seg_mask = nd.maximum_filter((seg > 0), size=mask_grow) == 0
            dq_ok = (flt[3].data & bad_dq_bits) == 0

            mask = seg_mask & dq_ok & (flt[2].data > 0)

            threedhst.showMessage('Remove column average (GP): %s' % (exposure))

            #### 1D column averages
            ncols = flt[1].data.shape[1]
            xmsk = np.arange(ncols)

            masked = flt[1].data * 1
            masked[~mask] = np.nan
            yres = np.zeros(ncols)
            yrms = yres * 0.
            for i in range(ncols):
                # View into `masked`: clipping below only touches this column.
                column = masked[:, i]
                ok = np.isfinite(column)
                if not ok.any():
                    # np.percentile fails on an empty array; mark the column
                    # as NaN so it is excluded from the GP fit below.
                    yres[i] = yrms[i] = np.nan
                    continue
                lo, hi = np.percentile(column[ok], [16, 84])
                column[(column > hi) | (column < lo)] = np.nan
                msk = np.isfinite(column)
                yres[i] = np.mean(column[msk])
                yrms[i] = np.std(column[msk]) / np.sqrt(msk.sum())

            yok = np.isfinite(yres)
            bg_sky = flt[0].header.get('GSKY00', 1)

            gp = GaussianProcess(regr='constant', corr='squared_exponential',
                                 theta0=8, thetaL=7, thetaU=12,
                                 nugget=(yrms / bg_sky)[yok] ** 2,
                                 random_start=10, verbose=True,
                                 normalize=True)
            gp.fit(np.atleast_2d(xmsk[yok]).T, yres[yok] + bg_sky)
            y_pred, MSE = gp.predict(np.atleast_2d(xmsk).T, eval_MSE=True)
            gp_sigma = np.sqrt(MSE)

            # Subtract the smoothed column profile (sky offset removed again).
            flt[1].data -= y_pred - bg_sky
            flt.flush()
            flt_name = flt.filename()

        ### Make figure
        fig = Figure(figsize=[6, 4], dpi=100)

        fig.subplots_adjust(wspace=0.25, hspace=0.02, left=0.15,
                            bottom=0.08, right=0.97, top=0.92)

        ax = fig.add_subplot(111)
        ax.set_title(flt_name)

        ax.plot(yres, color='black', alpha=0.3)
        ax.plot(y_pred - bg_sky, color='red', linewidth=2, alpha=0.7)
        ax.fill_between(xmsk, y_pred - bg_sky + gp_sigma,
                        y_pred - bg_sky - gp_sigma, color='red', alpha=0.3)

        ax.set_xlim(0, ncols)
        ax.set_xlabel('x pix')
        ax.set_ylabel('BG residual (e/s)')
        fig.tight_layout(pad=0.2)

        canvas = FigureCanvasAgg(fig)
        canvas.print_figure(flt_name.split('.fits')[0] + '.column.png',
                            dpi=100, transparent=False)
Example #55
0
class Plotting:
    def __init__(self, fitWindow):
        """
        Set default plotting options and build a placeholder figure.

        Args:
            fitWindow: stored unchanged as ``self.fitWindow``; its type is not
                visible here (presumably the fitting window — TODO confirm).
        """
        self.fitWindow = fitWindow

        # Plot display flags: only flag_ca (contact area curves) starts
        # enabled; every other display flag starts disabled.
        self.flag_ca = True
        (self.flag_ra, self.flag_cl, self.flag_rl, self.flag_cn,
         self.flag_ecc, self.flag_lf, self.flag_zp, self.flag_xp,
         self.flag_ap, self.flag_fp, self.flag_st, self.flag_zd) = (False,) * 12

        # General appearance defaults.
        # NOTE(review): plotData looks x_var up among keys such as 'Time (s)';
        # 'Time' is not one of them — confirm x_var is always overwritten
        # before plotData runs.
        self.x_var = 'Time'  # x axis default parameter
        self.legendPos = "upper right"
        self.show_title = True
        self.showLegend2 = True
        self.fontSize = 12

        # Fit-result display defaults.
        self.fit_pos = '0.5,0.5'
        self.fit_show = False
        self.slope = ''
        self.slope_unit = ''

        # Placeholder figure: a red sine curve shown until real data arrives.
        self.fig1 = Figure(figsize=(11, 5), dpi=100)
        demo_axes = self.fig1.add_subplot(111)
        demo_x = np.linspace(0, 4, 50)
        demo_axes.plot(demo_x, np.sin(demo_x), 'r-',
                       linewidth=1, markersize=1)

        self.plotWidget = PlotWidget(fig=self.fig1,
                                     cursor1_init=2,
                                     cursor2_init=6)

    def plotData(self, unit):  #prepare plot

        xDict = {
            'Vertical Position (μm)': self.dist_vert1,
            'Lateral Position (μm)': self.dist_lat1,
            'Deformation (μm)': self.deform_vert,
            'Time (s)': self.time1
        }
        xAxisData = xDict.get(self.x_var)

        markerlist = ["o", "v", "^", "s", "P", "*", "D", "<", "X", ">"]
        linelist = [":", "-.", "--", "-", ":", "-.", "--", "-", ":", "-."]

        plt.rcParams.update({'font.size': self.fontSize})

        # self.fig1 = plt.figure(num="Force/Area vs Time", figsize = [11, 5])
        # self.fig1 = Figure(figsize=(11, 5), dpi=100)

        # self.fig1.canvas.mpl_connect('close_event', self.handle_close)

        print("fig1")

        #store cursor position values before clearing plot
        if self.plotWidget.wid.cursor1 == None:
            c1_init = None
        else:
            c1_init = self.plotWidget.wid.cursor1.get_xdata()[0]

        if self.plotWidget.wid.cursor2 == None:
            c2_init = None
        else:
            c2_init = self.plotWidget.wid.cursor2.get_xdata()[0]

        self.fig1.clear()
        ax1 = self.fig1.add_subplot(1, 1, 1)
        lns = []

        # ax1.set_title('Speed = ' + str(self.speed_um) + ' μm/s')
        ax1.set_xlabel(self.x_var)
        ax1.set_ylabel('Vertical Force (μN)', color='r')
        p1, = ax1.plot(xAxisData[self.plot_slice],
                       self.force_vert1_shifted[self.plot_slice],
                       'ro',
                       alpha=0.5,
                       linewidth=1,
                       markersize=1,
                       label="Vertical Force")
        lns.append(p1)

        # self.plotWidget.mpl_connect('close_event', self.handle_close)

        if self.ptsnumber != 0:
            ##            ptsperstep = int(self.ptsnumber/self.step_num)
            i = 0
            lns_reg = []  #region legend handle
            lab_reg = []  #region legend label
            speed_inview = []  #speed list in plot range
            for a in self.steps:  #shade step regions
                if i < ((self.plot_slice.start + 1) / self.ptsperstep) - 1:
                    i += 1
                    continue

                if self.ptsperstep * (i + 1) - 1 > self.plot_slice.stop:
                    endpoint = self.plot_slice.stop - 1
                    exit_flag = True
                else:
                    endpoint = self.ptsperstep * (i + 1) - 1
                    exit_flag = False

                if self.ptsperstep * i < self.plot_slice.start:
                    startpoint = self.plot_slice.start
                else:
                    startpoint = self.ptsperstep * i

                x_start = min(xAxisData[startpoint:endpoint])
                x_end = max(xAxisData[startpoint:endpoint])
                if a == 'Front':
                    v1 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='aliceblue',
                                     label=a)
                    lns_reg.append(v1)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Back':
                    v2 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='whitesmoke',
                                     label=a)
                    lns_reg.append(v2)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Up':
                    v3 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='honeydew',
                                     label=a)
                    lns_reg.append(v3)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Down':
                    v4 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='linen',
                                     label=a)
                    lns_reg.append(v4)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Pause':
                    v5 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='lightyellow',
                                     label=a)
                    lns_reg.append(v5)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                i += 1

        if self.show_title == True:
            ax1.set_title('Speed = ' +
                          str(speed_inview).replace('[', '').replace(']', '') +
                          ' μm/s')
        if self.showLegend2 == True:
            dict_reg = dict(zip(lab_reg,
                                lns_reg))  #legend dictionary (remove dup)
            self.fig1.legend(dict_reg.values(),
                             dict_reg.keys(),
                             loc='lower right',
                             ncol=len(lns_reg))

        if self.flag_ap == True:  #show adhesion calc
            #fill adhesion energy region
            ax1.fill_between(xAxisData[self.energy_slice],
                             self.forceDict["zero1"][0],
                             self.force_vert1_shifted[self.energy_slice],
                             color='black')
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax1.axhline(y=self.forceDict["zero1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axhline(y=self.forceDict["force_min1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axhline(y=self.forceDict["force_max1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axvline(x=xAxisData[self.time1.index(
                    self.indDict["time1_max"][i])],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                i += 1

        if self.flag_ca == True or self.flag_ra == True:
            ax2 = ax1.twinx()  #secondary axis
            ##                cmap = plt.cm.get_cmap("Reds")  # type: matplotlib.colors.ListedColormap
            num = len(self.rangeDict.keys())
            ##                colors = plt.cm.Reds(np.linspace(0.3,1,num))
            colors = plt.cm.Greens([0.7, 0.5, 0.9, 0.3, 1])
            ax2.set_prop_cycle(color=colors)
            ax2.set_ylabel('Area ($' + unit + '^2$)', color='g')
            if self.flag_ca == True:
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p2, = ax2.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][0][self.plot_slice2],
                                   '-' + markerlist[i],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="Contact Area: " + k)
                    # p2.set_animated(True) #BLIT THIS CHECK!!!
                    lns.append(p2)
                    if self.flag_ap == True:  #adhesion calc
                        ax2.plot(self.indDict["time1_max"][i],
                                 self.areaDict["area2_pulloff"][i],
                                 'y' + markerlist[i],
                                 alpha=0.8)
                    if self.flag_fp == True:  #friction calc
                        ax2.plot(self.indDict["time1_lat_avg"][i],
                                 self.areaDict["area_friction"][i],
                                 'g' + markerlist[i],
                                 alpha=0.8)
                    i += 1
            if self.flag_ra == True:  #consider first key since auto roi is same for all keys
                colors = plt.cm.Blues([0.7, 0.5, 0.9, 0.3, 1])
                ax2.set_prop_cycle(color=colors)
                j = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p3, = ax2.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][3][self.plot_slice2],
                                   '-' + markerlist[j],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="ROI Area: " + k)
                    lns.append(p3)
                    j += 1

        if self.flag_lf == True:
            ax3 = ax1.twinx()  #lateral force
            ax3.set_ylabel('Lateral Force (μN)', color='c')
            ax3.spines['left'].set_position(
                ('outward', int(6 * self.fontSize)))
            ax3.spines["left"].set_visible(True)
            ax3.yaxis.set_label_position('left')
            ax3.yaxis.set_ticks_position('left')
            if self.invert_latf == True:
                ax3.invert_yaxis()
            if self.flag_lf == True:
                p4, = ax3.plot(xAxisData[self.plot_slice],
                               self.force_lat1_shifted[self.plot_slice],
                               'co',
                               alpha=0.5,
                               linewidth=1,
                               markersize=1,
                               label="Lateral Force")

##            if self.flag_lf_filter == True:
##                p4, = ax3.plot(self.time1[self.plot_slice], self.force_lat1_filtered_shifted[self.plot_slice], '-c',
##                     alpha=0.5, linewidth=1, label="Lateral Force")

            if self.flag_fp == True:  #show friction calc
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    ax3.axhline(y=self.forceDict["force_lat_max"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax3.axhline(y=self.forceDict["force_lat_min"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax1.axhline(y=self.forceDict["force_max2"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax3.axvline(x=xAxisData[self.time1.index(
                        self.indDict["time1_lat_avg"][i])],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    # ax2.plot(self.indDict["time1_lat_avg"][i],
                    #          self.areaDict["area_friction"][i],
                    #          'g' + markerlist[i], alpha=0.8)
                    i += 1
                ax3.axhline(y=self.forceDict["zero2"],
                            color='g',
                            alpha=0.5,
                            linestyle=linelist[0],
                            linewidth=1)
            lns.append(p4)
        else:
            ax3 = None

        if self.flag_zp == True or self.flag_xp == True or self.flag_zd:  #piezo position/deformation
            ax4 = ax1.twinx()  #piezo waveform
            ax4.set_ylabel('Displacement (μm)', color='violet')
            if self.flag_ca == True or self.flag_ra == True:  #shift axis if area plotted
                ax4.spines['right'].set_position(
                    ('outward', int(7 * self.fontSize)))
##                ax4.invert_yaxis()
            if self.flag_zp == True:
                p5, = ax4.plot(xAxisData[self.plot_slice],
                               self.dist_vert1[self.plot_slice],
                               '-',
                               markersize=1,
                               color='violet',
                               alpha=0.5,
                               label="Vertical Piezo")
                lns.append(p5)
            if self.flag_xp == True:
                p6, = ax4.plot(xAxisData[self.plot_slice],
                               self.dist_lat1[self.plot_slice],
                               '-.',
                               markersize=1,
                               color='violet',
                               alpha=0.5,
                               label="Lateral Piezo")
                lns.append(p6)
            if self.flag_zd == True:  #actual deformation plot
                p12, = ax4.plot(xAxisData[self.plot_slice],
                                self.deform_vert[self.plot_slice],
                                '-o',
                                markersize=1,
                                color='violet',
                                alpha=0.5,
                                label="Deformation")
                if self.flag_ap == True:
                    ax1.axvline(x=xAxisData[self.deform_tol],
                                color='violet',
                                alpha=1,
                                linestyle=":",
                                linewidth=1)
                lns.append(p12)

        if self.flag_cl == True or self.flag_rl == True:
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.set_ylabel('Length ($' + unit + '$)', color='brown')
            if self.flag_ca == True or self.flag_ra == True:
                ax5.spines['right'].set_position(
                    ('outward', int(7 * self.fontSize)))
            if self.flag_cl == True:  #contact length
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p7, = ax5.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][1][self.plot_slice2],
                                   '-' + markerlist[i],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="Contact Length: " + k)
                    lns.append(p7)
                    i += 1
            if self.flag_rl == True:  #roi length
                ##                ax5 = ax1.twinx()
                num = len(self.rangeDict.keys())
                colors = plt.cm.Wistia(np.linspace(0.2, 0.7, num))
                ax5.set_prop_cycle(color=colors)
                ##                ax5.spines['right'].set_position(('outward', 70))
                j = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
##                    ax5.set_ylabel('Length ($' + unit + '$)', color = 'brown')
                    p8, = ax5.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][4][self.plot_slice2],
                                   '-' + markerlist[j],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="ROI Length: " + k)
                    lns.append(p8)
                    j += 1
        if self.flag_cn == True:  #contact number
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.spines['right'].set_position(
                ('outward', int(7 * self.fontSize)))
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax5.set_ylabel('Number', color='brown')
                p9, = ax5.plot(self.time2[self.plot_slice2],
                               self.dataDict[k][2][self.plot_slice2],
                               '-' + markerlist[i],
                               alpha=0.5,
                               linewidth=1,
                               markersize=2,
                               label="Contact Number: " + k)
                lns.append(p9)
                i += 1
        if self.flag_ecc == True:  #contact eccentricity
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.spines['right'].set_position(
                ('outward', int(7 * self.fontSize)))
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax5.set_ylabel('Eccentricity' + unit + '$)', color='brown')
                p10, = ax5.plot(self.time2[self.plot_slice2],
                                self.dataDict[k][5][self.plot_slice2],
                                '-' + markerlist[i],
                                alpha=0.5,
                                linewidth=1,
                                markersize=2,
                                label="Median Eccentricity: " + k)
                lns.append(p10)
                i += 1

        if self.flag_st == True or self.flag_lf_filter == True:  #stress CHECK!
            ax6 = ax1.twinx()
            ax6.set_ylabel('Stress (μN/$' + unit + '^2$)', color='c')
            ax6.spines['left'].set_position(
                ('outward', int(6 * self.fontSize)))
            ax6.spines["left"].set_visible(True)
            ax6.yaxis.set_label_position('left')
            ax6.yaxis.set_ticks_position('left')
            if self.flag_st == True:
                p11, = ax6.plot(xAxisData[self.plot_slice],
                                self.stress[self.plot_slice],
                                'co',
                                alpha=0.5,
                                linewidth=1,
                                markersize=1,
                                label="Stress")
            if self.flag_lf_filter == True:
                p11, = ax6.plot(xAxisData[self.plot_slice],
                                self.stress_filtered[self.plot_slice],
                                '-c',
                                alpha=0.5,
                                linewidth=1,
                                markersize=1,
                                label="Stress")

            lns.append(p11)

##            lns = [p1, p3, p2, p4, p5]
##        else:
##            lns = [p1, p2]

        ax1.legend(handles=lns, loc=self.legendPos)

        if self.fitWindow.enableFitting.isChecked() == True:
            axDict = {'Vertical Force (μN)': ax1, 'Lateral Force (μN)': ax3}
            # yDict = {'Vertical Force (μN)':self.force_vert1_shifted,
            #          'Lateral Force (μN)':self.force_lat1_shifted}
            # fit_slice = slice(int(self.startFit * self.ptsnumber/100),
            #                   int(self.endFit * self.ptsnumber/100))
            self.slope_unit = self.fitWindow.yFit.currentText().split('(')[1].split(')')[0] + '/' +\
                              self.fitWindow.xFit.currentText().split('(')[1].split(')')[0]
            text_pos = self.fit_pos.split(",")

            # self.slope = fitting.polyfitData(xDict.get(self.fit_x)[fit_slice], yDict.get(self.fit_y)[fit_slice],
            #                          axDict.get(self.fit_y), xAxisData[fit_slice], unit = self.slope_unit,
            #                          eq_pos = text_pos, fit_order = 1, fit_show = self.fit_show)
            ax_fit = axDict.get(self.fitWindow.yFit.currentText())
            ax_fit.plot(xAxisData[self.fitWindow.fit_slice],
                        self.fitWindow.fit_ydata,
                        color='black',
                        linewidth=2,
                        linestyle='dashed')

            ##        print(eq_pos)
            if self.fit_show == True and \
                self.fitWindow.fittingFunctionType.currentText() == 'Linear':
                self.slope = self.fitWindow.fitParams['m']
                slope_label = "Slope: " + "%.4f"%(self.slope) + \
                    ' (' + self.slope_unit + ')'
                ax_top = self.fig1.get_axes()[-1]
                ax_top.text(float(text_pos[0]),
                            float(text_pos[1]),
                            slope_label,
                            ha='right',
                            transform=ax_top.transAxes,
                            color='black',
                            bbox=dict(facecolor='white',
                                      edgecolor='black',
                                      alpha=0.5),
                            picker=5)
        else:
            self.slope = ''
            self.slope_unit = ''

        self.fig1.tight_layout()
        self.fig1.canvas.draw()

        self.plotWidget.wid.axes = self.fig1.get_axes()[-1]

        self.plotWidget.wid.add_cursors(cursor1_init=c1_init,
                                        cursor2_init=c2_init)

        print("plot finish")
        # self.plotWidget.wid.draw_idle()

        # self.fig1.tight_layout()
        # self.fig1.canvas.draw()

#     def showPlot(self): #show plot
# ##        self.fig1.show()
#         try:
#             # plt.pause(0.05)
#             # self.axes.relim()
#             # self.axes.autoscale()
#             self.fig1.tight_layout()
#             self.fig1.canvas.draw()
#             # self.plotWidget
#             # self.plotWidget.show()
#         except Exception as e:
#             print(e)

##        plt.show(block=False)
##        plt.draw()

# def handle_close(self, evt): #figure closed event
#     self.fig1_close = True
#     print("close")

    def convertPlot(self):  #convert plot to numpy
        """Render ``self.fig1`` and return its pixels as an RGB image array.

        Returns:
            A writable ``uint8`` array of shape ``(height, width, 3)`` built
            from the canvas' RGB byte buffer.
        """
        canvas = self.fig1.canvas
        canvas.draw()
        raw = canvas.tostring_rgb()
        # get_width_height() returns (width, height); image arrays are
        # indexed row-major as (height, width, channel).
        width, height = canvas.get_width_height()
        # The binary mode of np.fromstring is deprecated; frombuffer reads the
        # bytes directly. frombuffer yields a read-only view of the bytes
        # object, so copy to keep returning a writable array as before.
        data = np.frombuffer(raw, dtype=np.uint8)
        return data.reshape(height, width, 3).copy()

    def savePlot(self, filepath):  #save force plots
        """Save ``self.fig1`` to *filepath* and pickle the figure next to it.

        The pickle goes to *filepath* with its extension replaced by
        ``.pickle`` (or with ``.pickle`` appended when there is no extension).

        Args:
            filepath: destination of the rendered figure; its extension selects
                the output format used by ``Figure.savefig``.
        """
        import os.path

        print("save plot")
        self.fig1.savefig(filepath, orientation='landscape', transparent=True)
        # splitext copes with extensions of any length (".jpeg", ".tiff") and
        # with paths that have none; chopping off 4 characters did not.
        pickle_path = os.path.splitext(filepath)[0] + '.pickle'
        #save figure object as pickle file
        with open(pickle_path, 'wb') as f:
            pickle.dump(self.fig1, f, pickle.HIGHEST_PROTOCOL)


# class PlotWindow(QWidget):
#     def __init__(self, fig, *args, **kwargs):
# ##        super(QWidget, self).__init__(*args, **kwargs)
#         super().__init__()
#         self.setGeometry(100, 100, 1000, 500)
#         self.setWindowTitle("Plot")
#         self.fig = fig
#         self.home()

#     def home(self):
#         self.plotWidget = Plot2Widget(self.fig,cursor1_init=2,cursor2_init=6)
#         plotToolbar = NavigationToolbar(self.plotWidget, self)

#         plotGroupBox = QGroupBox()
#         plotlayout=QGridLayout()
#         plotGroupBox.setLayout(plotlayout)
#         plotlayout.addWidget(plotToolbar, 0, 0, 1, 1)
#         plotlayout.addWidget(self.plotWidget, 1, 0, 1, 1)

#         layout=QGridLayout()
#         layout.addWidget(plotGroupBox, 0, 0, 1, 1)

#         self.setLayout(layout)
# self.show()
# startFitLabel = QLabel("Start (%):")
Example #56
0
class ResultsWindow(QtWidgets.QMainWindow):
    """Window for browsing a patient's simulation results node by node.

    Four panels show flow rate, pressure (with a secondary mmHg axis),
    radius and velocity over time for the node picked with the vessel combo
    box and the node slider. The results file is watched on disk; when it
    changes the results are reloaded and the panels redrawn.
    """

    # Pascal -> millimetre of mercury conversion factor.
    PA_TO_MMHG = 0.007500617

    def __init__(self, patient):
        super(ResultsWindow, self).__init__()

        self.Patient = patient
        self.Results = patient.Results
        self.Vessels = patient.Topology.Vessels

        self.main_widget = QtWidgets.QWidget(self)
        self.fs_watcher = QtCore.QFileSystemWatcher([self.Results.File])
        self.fs_watcher.fileChanged.connect(self.file_changed)

        self.fig = Figure()
        self.ax1 = self.fig.add_subplot(141)
        self.ax2 = self.fig.add_subplot(142)
        self.ax2_c = self.ax2.twinx()  # pressure in mmHg, mirrors ax2
        self.ax3 = self.fig.add_subplot(143)
        self.ax4 = self.fig.add_subplot(144)
        self.canvas = FigureCanvasQTAgg(self.fig)

        self.ax1.set_aspect('auto')
        self.ax2.set_aspect('auto')
        self.ax3.set_aspect('auto')

        self.canvas.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                  QtWidgets.QSizePolicy.Expanding)
        self.canvas.updateGeometry()

        # Vessel selector
        self.VesselSelector = QtWidgets.QComboBox()
        self.VesselSelector.setStyleSheet("QComboBox { combobox-popup: 0; }")
        self.VesselSelector.setMaxVisibleItems(10)
        self.VesselSelector.addItems([vessel.Name for vessel in self.Vessels])

        # Node selector
        self.NodeSelector = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.NodeSelector.setTickPosition(QtWidgets.QSlider.TicksBelow)

        # Save figure button
        self.SaveButton = QtWidgets.QPushButton("Save")
        # Exit Button
        self.ExitButton = QtWidgets.QPushButton("Exit")

        # Linking stuff to actions
        self.VesselSelector.currentIndexChanged.connect(self.updateVessel)
        self.NodeSelector.valueChanged.connect(self.updatefig)
        # clicked(bool checked) would hand its bool to savefig as figuredpi
        # (giving dpi=False); the lambda drops the argument so the default
        # dpi is used.
        self.SaveButton.clicked.connect(lambda: self.savefig())
        self.ExitButton.clicked.connect(self.exit)

        # Add widgets to the layout
        # NOTE(review): this attribute shadows QWidget.layout(); kept for
        # compatibility with existing references.
        self.layout = QtWidgets.QGridLayout(self.main_widget)
        self.layout.addWidget(QtWidgets.QLabel("Select Vessel"))
        self.layout.addWidget(self.VesselSelector)
        self.layout.addWidget(self.NodeSelector)
        self.layout.addWidget(self.canvas, 0, 0)

        self.groupBox = QtWidgets.QGroupBox()
        self.Buttons = QtWidgets.QGridLayout()

        self.Buttons.addWidget(self.SaveButton)
        self.Buttons.addWidget(self.ExitButton)
        self.groupBox.setLayout(self.Buttons)

        self.layout.addWidget(self.groupBox, 0, 1)
        self.setCentralWidget(self.main_widget)
        self.resize(2000, 800)
        self.updateVessel()
        self.updatefig()
        self.show()

    def updateVessel(self):
        """Restrict the node slider to the node numbers of the chosen vessel."""
        CurrentVessel = self.VesselSelector.currentText()
        StartNode, EndNode = [[vessel.Nodes[0], vessel.Nodes[-1]]
                              for vessel in self.Vessels
                              if vessel.Name == CurrentVessel][0]
        self.NodeSelector.setMinimum(StartNode.Number)
        self.NodeSelector.setMaximum(EndNode.Number)

    @QtCore.pyqtSlot(str)
    def file_changed(self, path=None):
        """Reload results and redraw after the watched file changed.

        Args:
            path: file path emitted by ``QFileSystemWatcher.fileChanged``.
                Optional so the slot signature matches the ``str`` decorator
                and the method can still be called without arguments.
        """
        print('File Changed.')
        time.sleep(.100)  # wait for writing to finish
        # QFileSystemWatcher stops watching a file once it is removed or
        # renamed (e.g. when a writer replaces it); re-register it so later
        # changes are still reported.
        if self.Results.File not in self.fs_watcher.files():
            self.fs_watcher.addPath(self.Results.File)
        self.Patient.Results.LoadResults(self.Results.File)
        self.Patient.CorrectforDirectionVectorOutlets()
        self.Patient.Results.CalculateVelocity()
        self.Patient.Results.CalculateMeanResultsPerNode()
        self.updatefig()

    def updatefig(self):
        """Redraw all four panels for the node currently selected."""
        CurrentNode = self.NodeSelector.value()
        results = self.Results

        # One entry per panel:
        # (axes, title, y label, per-heartbeat series, per-node mean series)
        panels = (
            (self.ax1, "Flow rate over time", "Flow rate (mL/s)",
             results.VolumeFlowRate, results.MeanVolumeFlowRatePerNode),
            (self.ax2, "Pressure over time", "Pressure (Pa)",
             results.Pressure, results.MeanPressurePerNode),
            (self.ax3, "Radius over time", "Radius (mm)",
             results.Radius, results.MeanRadiusPerNode),
            (self.ax4, "Velocity over time", "Velocity (m/s)",
             results.Velocity, results.MeanVelocityPerNode),
        )

        for ax, _, _, _, _ in panels:
            ax.clear()
        self.ax2_c.clear()

        for ax, title, ylabel, _, _ in panels:
            ax.set_title(title, fontsize=30)
            ax.set_xlabel("Time (s)", fontsize=30)
            ax.set_ylabel(ylabel, fontsize=30)
            ax.xaxis.set_tick_params(labelsize=25)
            ax.yaxis.set_tick_params(labelsize=25)
        self.ax2_c.set_ylabel("Pressure (mmHg)", fontsize=30)
        self.ax2_c.yaxis.set_tick_params(labelsize=25)

        # ax.clear() discards callbacks registered on the axes, so the
        # Pa -> mmHg link must be reconnected on every redraw.
        self.ax2.callbacks.connect("ylim_changed", self.PressureTommhg)

        time_axis = results.Time[0]
        # All panels use the heartbeat count of the flow-rate series.
        n_beats = len(results.VolumeFlowRate)
        for ax, _, _, series, mean in panels:
            for hb in range(n_beats):
                ax.plot(time_axis, series[hb][CurrentNode])
            # Last entry of the per-node mean series, in the lower-left corner.
            ax.text(0.0,
                    0.0,
                    str(mean[-1][CurrentNode]),
                    fontsize=25,
                    transform=ax.transAxes)
            ax.grid()

        # Lay out first so the deferred draw uses the final geometry.
        self.fig.tight_layout()
        self.fig.canvas.draw_idle()

    def savefig(self, figuredpi=72):
        """Save the figure and the node's last-heartbeat data to a directory.

        Asks the user for a directory and writes ``Figure Node:<n>.png`` and
        ``Node: <n>.csv`` into it. Does nothing if the dialog is cancelled.

        Args:
            figuredpi: resolution of the saved PNG.
        """
        node = self.NodeSelector.value()
        path = str(
            QtWidgets.QFileDialog.getExistingDirectory(
                self, caption="Select Directory"))
        if not path:
            # Dialog cancelled: an empty path would write into the root dir.
            return

        # NOTE(review): ':' is not a valid file-name character on Windows.
        self.fig.savefig(path + '/Figure Node:' + str(node) + '.png',
                         dpi=figuredpi)

        results = self.Results
        with open(path + '/Node: ' + str(node) + '.csv', 'w') as f:
            # Radius unit matches the "Radius (mm)" axis label of the plot.
            f.write(
                "time(s),Flowrate(mL/s),Pressure(Pa),Radius(mm),Velocity(m/s)\n"
            )
            for i, t in enumerate(results.Time[-1]):
                f.write("%s,%s,%s,%s,%s\n" %
                        (t,
                         results.VolumeFlowRate[-1][node][i],
                         results.Pressure[-1][node][i],
                         results.Radius[-1][node][i],
                         results.Velocity[-1][node][i]))

    def PressureTommhg(self, ax2):
        """Mirror the pressure axis limits onto the mmHg twin axis.

        Connected to ``ax2``'s ``"ylim_changed"`` callback, which passes the
        axes whose limits changed.
        """
        y1, y2 = ax2.get_ylim()
        self.ax2_c.set_ylim(y1 * self.PA_TO_MMHG, y2 * self.PA_TO_MMHG)

    def exit(self):
        """Close the window."""
        self.close()
Example #57
0
def make_chart_response(country, deaths_start, avg_before_deaths, df_to_show):
    """Render the quarantine-assessment chart for *country* as a PNG buffer.

    Plots confirmed deaths (log scale) and NO2 pollution (twin axis) over
    time. Marks the official quarantine date when one is recorded in
    ``all_countries_response.csv``. Adds a footnote and the site logo.

    Args:
        country: name matched against the ``Country`` column of
            ``all_countries_response.csv``.
        deaths_start: death count quoted in the footnote.
        avg_before_deaths: passed to seaborn as ``y`` for the
            "Average pollution" line.
        df_to_show: frame with ``City``, ``Date``, ``TotalDeaths`` and
            ``no2`` columns.

    Returns:
        A ``BytesIO`` rewound to position 0 holding the PNG image.
    """
    city = df_to_show['City'].iloc[0]
    df_quar = pd.read_csv(data_dir + 'all_countries_response.csv',
                          parse_dates=['Quarantine'])
    # A country absent from the response table is treated like one with no
    # recorded quarantine instead of raising IndexError.
    quarantine_dates = df_quar.loc[df_quar['Country'] == country,
                                   'Quarantine']
    quarantine = (quarantine_dates.iloc[0]
                  if not quarantine_dates.empty else pd.NaT)

    week = mdates.WeekdayLocator(interval=2)  # a tick every second week
    month_fmt = mdates.DateFormatter('%b-%d')

    y_lim = df_to_show['TotalDeaths'].max() * 1.2
    y2_lim = df_to_show['no2'].max() * 1.8

    # Generate the figure **without using pyplot**.
    fig = Figure(figsize=(10, 5))

    ax = fig.subplots()

    ax.set_title('Assessing quarantine implementation - ' + country,
                 fontsize=16,
                 loc='left')

    if not pd.isnull(quarantine):
        ax.axvline(x=quarantine,
                   color='k',
                   linestyle='--',
                   lw=3,
                   label='Official quarantine')

    ax.scatter(df_to_show['Date'],
               df_to_show['TotalDeaths'],
               color='black',
               alpha=0.7,
               label='Confirmed deaths')

    ax.xaxis.set_major_locator(week)
    ax.xaxis.set_major_formatter(month_fmt)

    ax.set_yscale('log')
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda y, _: '{:g}'.format(y)))
    ax.set_ylim(1, y_lim)
    ax.set(ylabel='Confirmed deaths')

    ax2 = ax.twinx()

    # Raw strings: "\m" is not a valid escape in a normal string literal.
    sns.lineplot(x="Date",
                 y='no2',
                 alpha=0.7,
                 lw=6,
                 label=r'Daily $\mathrm{{NO}}_2$ pollution *',
                 ax=ax2,
                 data=df_to_show)
    sns.lineplot(x="Date",
                 y=avg_before_deaths,
                 alpha=0.7,
                 lw=6,
                 label='Average pollution **',
                 ax=ax2,
                 data=df_to_show)

    ax2.grid(False)
    ax2.xaxis.set_major_locator(week)
    ax2.xaxis.set_major_formatter(month_fmt)
    ax2.set_ylim(1, y2_lim)
    ax2.set(ylabel=r'$\mathrm{{NO}}_2$ pollution')

    # ask matplotlib for the plotted objects and their labels
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='upper left')

    # Raw pieces for the LaTeX, a plain "\n" for the real line break;
    # doubled braces survive .format() as single braces.
    annotation = (
        r"* Median of $\mathrm{{NO}}_2$ measurements in the most affected "
        r"city ({city}), 5 days rolling average over time series"
        "\n"
        r"** Average daily $\mathrm{{NO}}_2$ measurements from the beginning "
        r"of 2020 until the first day after {deaths_start} deaths"
    ).format(city=city, deaths_start=deaths_start)
    ax.annotate(annotation, (0, 0), (0, -30),
                xycoords='axes fraction',
                textcoords='offset points',
                va='top')

    logo = plt.imread('./static/img/new_logo_site.png')
    ax.figure.figimage(logo, 100, 110, alpha=.35, zorder=1)

    fig.tight_layout()
    # Save it to a temporary buffer.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf
Example #58
0
class MplWidget(QWidget):
    """
    Construct a subwidget with Matplotlib canvas and NavigationToolbar
    """
    def __init__(self, parent):
        super(MplWidget, self).__init__(parent)
        # Create the mpl figure and construct the canvas with it.
        self.plt_lim = []  # x,y plot limits
        # Limits restored by redraw() while zoom is locked. `limits` keeps the
        # limits of the last axes (as before); `_axes_limits` keeps one entry
        # per axes of self.fig so multi-axes figures are restored correctly.
        self.limits = None
        self._axes_limits = []
        self.fig = Figure()

        self.pltCanv = FigureCanvas(self.fig)
        self.pltCanv.setSizePolicy(QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)

        # Needed for mouse modifiers (x,y, <CTRL>, ...):
        #    Key press events in general are not processed unless you
        #    "activate the focus of Qt onto your mpl canvas"
        # http://stackoverflow.com/questions/22043549/matplotlib-and-qt-mouse-press-event-key-is-always-none
        self.pltCanv.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.pltCanv.setFocus()

        self.pltCanv.updateGeometry()

        # Create a custom navigation toolbar, tied to the canvas, and
        # initialize toolbar settings.
        self.mplToolbar = MyMplToolbar(self.pltCanv, self)
        self.mplToolbar.grid = True
        self.mplToolbar.lock_zoom = False
        self.mplToolbar.enable_update(state=True)
        self.mplToolbar.sigEnabled.connect(self.clear_disabled_figure)

        # Widget layout: toolbar above the canvas.
        self.layVMainMpl = QVBoxLayout()
        self.layVMainMpl.addWidget(self.mplToolbar)
        self.layVMainMpl.addWidget(self.pltCanv)

        self.setLayout(self.layVMainMpl)

    #--------------------------------------------------------------------------

    def save_limits(self):
        """
        Save x- and y-limits of all axes when zoom is unlocked
        """
        if not self.mplToolbar.lock_zoom:
            self._store_limits()

    def _store_limits(self):
        """Record the current limits of every axes, one entry per axes."""
        self._axes_limits = [ax.axis() for ax in self.fig.axes]
        if self._axes_limits:
            self.limits = self._axes_limits[-1]

    def _restore_limits(self):
        """Re-apply limits recorded by _store_limits() to the current axes."""
        axes = self.fig.axes
        if len(self._axes_limits) == len(axes):
            for ax, lim in zip(axes, self._axes_limits):
                ax.axis(lim)
        elif self.limits is not None:
            # The set of axes changed since the limits were saved: fall back
            # to applying the single saved limits to every axes.
            for ax in axes:
                ax.axis(self.limits)

    #--------------------------------------------------------------------------

    def redraw(self):
        """
        Redraw the figure with new properties (grid, zoom lock)

        With zoom locked, previously saved limits are restored per axes;
        otherwise the current limits are recorded for later restoring.
        """
        # only execute when at least one axis exists -> tight_layout crashes otherwise
        if self.fig.axes:
            for ax in self.fig.axes:
                ax.grid(self.mplToolbar.grid)  # toggle grid on every axes
            if self.mplToolbar.lock_zoom:
                self._restore_limits()
            else:
                self._store_limits()

            self.fig.tight_layout(pad=0.2)
        self.pltCanv.draw()  # now (re-)draw the figure

    #--------------------------------------------------------------------------

    def clear_disabled_figure(self):
        """
        Clear the figure when it is disabled in the mplToolbar
        """
        if not self.mplToolbar.enabled:
            self.fig.clf()
            self.pltCanv.draw()
        else:
            self.redraw()

    #--------------------------------------------------------------------------

    def plt_full_view(self):
        """
        Zoom to full extent of data if axes is set to "navigable"
        by the navigation toolbar
        """
        # Add current view limits to view history to enable "back to previous view"
        self.mplToolbar.push_current()
        for ax in self.fig.axes:
            if ax.get_navigate():
                ax.autoscale()
        self.redraw()

    #--------------------------------------------------------------------------

    def get_full_extent(self, ax, pad=0.0):
        """
        Get the full extent of an axes, including axes labels, tick labels, and
        titles, in display coordinates, scaled by a factor of ``1 + pad``.
        """
        #http://stackoverflow.com/questions/14712665/matplotlib-subplot-background-axes-face-labels-colour-or-figure-axes-coor
        # For text objects, we need to draw the figure first, otherwise the extents
        # are undefined.
        self.pltCanv.draw()
        items = ax.get_xticklabels() + ax.get_yticklabels()
        items += [ax, ax.title, ax.xaxis.label, ax.yaxis.label]
        bbox = Bbox.union([item.get_window_extent() for item in items])
        return bbox.expanded(1.0 + pad, 1.0 + pad)
Example #59
0
# Startup cost of an MPI Python job versus number of ranks (32 ranks per
# node). collect.txt holds one column per quantity: node count followed by
# the measured wall times in seconds.
nodes, bcast, tar, chmod, bcasttot, launch = numpy.loadtxt('./collect.txt',
                                                           unpack=True)
ranks = nodes * 32

f = Figure(figsize=(4, 4))
ax = f.add_subplot(111)

# Series are added in legend order.
ax.plot(ranks, launch, 's ', color='r', mec='none', label='import scipy')
ax.plot(ranks, bcasttot, ls='none', marker=(8, 2, 0), color='m',
        label='bcast')
ax.plot(ranks, bcast, 'x ', color='g', mew=1, label='bcast/MPI_Bcast')
ax.plot(ranks, tar, '+ ', color='b', mew=1, label='bcast/tar xzvf')
ax.plot(ranks, launch + bcasttot, 'D ', color='k', label='total')

ax.set_xlabel('Number of Ranks')
ax.set_ylabel('Wall time [sec]')
ax.set_ylim(3e-1, 2e3)
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend(loc='upper left', frameon=False, ncol=1, fontsize='small')

# Attach an Agg canvas before laying out and saving the bare Figure.
canvas = FigureCanvasAgg(f)
f.tight_layout()
for out_name, out_dpi in (('cray-xt-startup-time.png', 72),
                          ('cray-xt-startup-time-hires.png', 200),
                          ('cray-xt-startup-time-hires.pdf', 200)):
    f.savefig(out_name, dpi=out_dpi)
Example #60
0
class GraphWindow(tk.Toplevel):
    '''
    General plan:
        * one frame of _what to plot_ (points, regression line, design curves, etc)
        * One frame of _how to plot_ (symbols, lines, limits, grid)
        * misc (e.g. make plot button)
    This class does not do ANY plotting directly: it is just an interface to 
    choose parameters for the plotting performed in EdnaCalc
    '''

    ##########################################################################
    ###############     USER INTERFACE INITIALISATION
    ##########################################################################
    def __init__(self, parent):
        '''Build the plotting-options window.

        ``parent`` is stored on ``self.parent``. Other methods of this class
        use it to reach ``parent.calc`` and ``parent.selected_data``, or treat
        ``None`` as a debugging run. The window itself is a new ``tk.Toplevel``
        stored on ``self.root``.
        NOTE(review): ``tk.Toplevel.__init__`` is never called on ``self``;
        all widgets live under ``self.root`` instead.
        '''
        self.parent = parent
        self.root = tk.Toplevel()
        self.root.title("PyEdna")
        # Variables first (widgets bind to them), then the frames, the widgets
        # inside those frames, and finally the embedded Matplotlib canvas.
        build_steps = (self.init_values, self.init_frames, self.init_what,
                       self.init_how, self.init_misc, self.init_graph)
        for build_step in build_steps:
            build_step()
        # Construction done: set_graph_size may now react to its traces.
        self.busy_starting = False
        return None

    def init_values(self):
        '''Create the Tk variables that the buttons and entries are bound to.

        Also sets ``self.busy_starting`` to True. ``__init__`` resets it to
        False once the whole window is built, and ``set_graph_size`` does
        nothing while it is True.

        Traces are registered with ``Variable.trace_add("write", ...)`` instead
        of the deprecated ``Variable.trace`` (an alias of ``trace_variable``,
        which wraps a legacy Tcl command). Both callbacks accept ``*args``, so
        the different mode argument ("write" instead of "w") is irrelevant.
        '''
        self.busy_starting = True  # Used to prevent callbacks until everything is initialised

        # "What to plot": data points and regression line are ticked by default
        self.plot_points = tk.BooleanVar()
        self.plot_points.set(True)
        self.plot_regression = tk.BooleanVar()
        self.plot_regression.set(True)
        self.plot_conf_pt = tk.BooleanVar()
        self.plot_conf_reg = tk.BooleanVar()
        self.plot_dc_bs540 = tk.BooleanVar()
        self.plot_dc_ec3 = tk.BooleanVar()

        # Plot grid lines (BooleanVar starts False, i.e. off)
        self.grid_major = tk.BooleanVar()
        self.grid_minor = tk.BooleanVar()

        # Matplotlib marker / linestyle strings, passed through unchanged
        self.symbol = tk.StringVar()
        self.symbol.set("o")

        self.line = tk.StringVar()
        self.line.set("-")

        # Axis-limit mode. The trace is added only after the initial value is
        # set, so btn_axis_limits_changed does not run during construction.
        self.axis_limit = tk.StringVar()
        self.axis_limit.set(STEEL)
        self.axis_limit.trace_add("write", self.btn_axis_limits_changed)

        self.axis_y_style = tk.StringVar()
        self.axis_y_style.set(LOG)

        # Keeping track of graph size in cm
        self.x_size = tk.DoubleVar()
        self.y_size = tk.DoubleVar()
        self.x_size.trace_add("write", self.set_graph_size)
        self.y_size.trace_add("write", self.set_graph_size)

        # Axis limits in (n_min, n_max, s_min, s_max) order, starting from the
        # steel preset
        self.n_min = tk.IntVar()
        self.n_max = tk.IntVar()
        self.s_min = tk.IntVar()
        self.s_max = tk.IntVar()
        self.limit_vars = (self.n_min, self.n_max, self.s_min, self.s_max)
        for var, limit in zip(self.limit_vars, STEEL_lim):
            var.set(limit)

        self.font_size = tk.IntVar()
        self.font_size.set(12)

        return None

    def init_frames(self):
        '''Create and lay out the container frames.

        ``self.root`` is split into two columns:
          * column 0, ``frame_left``: the control panels stacked vertically,
            ``frame_what`` (row 1), ``frame_how`` (row 2), ``frame_misc`` (row 3);
          * column 1, ``frame_right``: holds the graph. This column and row 0
            are the only parts of the root grid given a resize weight.
        '''
        root = self.root
        self.frame_left = tk.Frame(root)
        self.frame_right = tk.Frame(root)
        for column, side_frame in enumerate((self.frame_left, self.frame_right)):
            side_frame.grid(row=0, column=column, sticky="nsew")
        root.grid_columnconfigure(1, weight=1)
        root.grid_rowconfigure(0, weight=1)

        left = self.frame_left
        self.frame_what = tk.Frame(left, height=250)
        self.frame_how = tk.Frame(left, height=250)
        self.frame_misc = tk.Frame(left, height=50)

        panel_options = dict(sticky="nsew", pady=PAD, padx=PAD)
        self.frame_what.grid(row=1, column=0, **panel_options)
        self.frame_what.grid_columnconfigure(0, weight=1)  # allow to grow width-wise
        self.frame_how.grid(row=2, column=0, **panel_options)
        for column in range(3):
            # three equal-width columns sharing the "how" uniform group
            self.frame_how.grid_columnconfigure(column, weight=1, uniform="how")
        self.frame_misc.grid(row=3, column=0, **panel_options)
        self.frame_misc.grid_columnconfigure(0, weight=1)
        return None

    def init_what(self):
        '''Build the "what to plot" panel inside ``self.frame_what``.

        There is one check box per plottable feature, each bound to its
        BooleanVar:
            * data points and regression line,
            * the two 95% confidence options,
            * under a "Design Curves" sub-heading, the BS540/NS3472 and EC3
              design curves.
        Everything sits in a single grid column, rows 0-7.
        '''
        frame = self.frame_what

        def checkbox(label, variable):
            return tk.Checkbutton(frame, text=label, variable=variable)

        self.what_title = tk.Label(frame, text="Plotting Features", font=FONT)
        self.what_subtitle = tk.Label(frame, text="Design Curves", font=FONT)

        self.bt_plot_points = checkbox("Data points", self.plot_points)
        self.bt_plot_regres = checkbox("Regression line", self.plot_regression)
        # NOTE: the *_points_conf / *_regres_conf attribute names do not match
        # their labels; the label <-> variable pairing below is what matters.
        self.bt_plot_points_conf = checkbox("95% conf. for reg. line",
                                            self.plot_conf_reg)
        self.bt_plot_regres_conf = checkbox("95% conf. for given value of S",
                                            self.plot_conf_pt)
        self.bt_dc_bs540 = checkbox("95% Surv, 97.5% Conf (BS540, NS3472)",
                                    self.plot_dc_bs540)
        self.bt_dc_ec3 = checkbox("95% Surv, 75% Conf (EC3)", self.plot_dc_ec3)

        # Top-to-bottom layout: headings stretch, check boxes hug the left edge.
        layout = (
            (self.what_title, "nsew"),
            (self.bt_plot_points, "nsw"),
            (self.bt_plot_regres, "nsw"),
            (self.bt_plot_points_conf, "nsw"),
            (self.bt_plot_regres_conf, "nsw"),
            (self.what_subtitle, "nsew"),
            (self.bt_dc_bs540, "nsw"),
            (self.bt_dc_ec3, "nsw"),
        )
        for row, (widget, sticky) in enumerate(layout):
            widget.grid(row=row, sticky=sticky)
        return None

    def init_how(self):
        '''Build the "how to plot" panel inside ``self.frame_how``.

        The upper part (rows 1-8) has one radio-button group per column, each
        a heading followed by its buttons:
            column 0 - marker symbol   (``self.symbol``)
            column 1 - line style      (``self.line``)
            column 2 - axis-limit mode (``self.axis_limit``)
        The lower part (from row 10) holds the major/minor grid check boxes in
        column 0 and the log/linear Y-axis choice in column 1.
        Symbol and line values are Matplotlib marker / linestyle strings.
        '''
        frame = self.frame_how
        heading = tk.Label(frame, text="Plotting style", font=FONT)
        heading.grid(row=0, column=0, columnspan=4, sticky="nsew")

        def radio_group(title, variable, choices, column):
            # Heading and buttons are all created before any is gridded.
            label = tk.Label(frame, text=title, font=FONT)
            buttons = [tk.Radiobutton(frame, text=text, variable=variable,
                                      value=value)
                       for text, value in choices]
            label.grid(row=1, column=column, sticky="nsew")
            for offset, button in enumerate(buttons, start=2):
                button.grid(row=offset, column=column, sticky="nsw")

        radio_group("Symbol style", self.symbol,
                    (("Circle", "o"), ("Square", "s"), ("Triangle", "^"),
                     ("Cross", "x"), ("Star", "*"), ("Diamond", "d")),
                    column=0)
        radio_group("Line style", self.line,
                    (("Solid", "-"), ("Dashed", "--"), ("Dotted", ":"),
                     ("Dash-dot", "-.")),
                    column=1)
        radio_group("Axis Limits", self.axis_limit,
                    (("Steel", STEEL), ("Aluminium", AL), ("N-limits", NN),
                     ("S-limits", SS), ("Auto limits", AUTO),
                     ("Manual limits", MANUAL)),
                    column=2)

        # Lower section. "Grid" here is the plot's grid lines, not Tk's grid().
        lower_row = 10
        grid_heading = tk.Label(frame, text="Grid", font=FONT)
        grid_heading.grid(row=lower_row, column=0, sticky="nsew")
        grid_boxes = (("Major", self.grid_major), ("Minor", self.grid_minor))
        for offset, (text, variable) in enumerate(grid_boxes, start=1):
            box = tk.Checkbutton(frame, text=text, variable=variable)
            box.grid(row=lower_row + offset, column=0, sticky="nsw")

        yaxis_heading = tk.Label(frame, text="Y axis", font=FONT)
        yaxis_heading.grid(row=lower_row, column=1, sticky="nsew")
        yaxis_buttons = [tk.Radiobutton(frame, text=text,
                                        variable=self.axis_y_style, value=value)
                         for text, value in (("Log", LOG), ("Linear", LINEAR))]
        for offset, button in enumerate(yaxis_buttons, start=1):
            button.grid(row=lower_row + offset, column=1, sticky="nsw")

        return None

    def init_misc(self):
        '''Build the "Figure Style" panel and the "Plot SN curve" button.

        ``self.frame_misc`` gets four axis-limit entries, bound to
        ``self.n_min``/``n_max``/``s_min``/``s_max`` and created disabled
        (``btn_axis_limits_changed`` later toggles their state). They are kept
        in ``self.limit_entries`` in that same order. The panel also gets a
        font-size entry bound to ``self.font_size``. The plot button goes into
        ``self.frame_left`` (row 4) and calls ``self.plot_curve``.
        '''
        frame = self.frame_misc
        title = tk.Label(frame, text="Figure Style", font=FONT)
        title.grid(row=0, column=0, columnspan=4, sticky="nsew")

        # (caption, variable, grid row, caption column); each entry sits one column to the right
        limit_fields = (
            ("N (min)", self.n_min, 4, 0),
            ("N (max)", self.n_max, 5, 0),
            ("S (min)", self.s_min, 4, 2),
            ("S (max)", self.s_max, 5, 2),
        )
        captions = [tk.Label(frame, text=caption)
                    for caption, _var, _row, _col in limit_fields]
        entries = [tk.Entry(frame, textvariable=var, state="disabled")
                   for _caption, var, _row, _col in limit_fields]
        self.limit_entries = tuple(entries)

        for caption, (_c, _v, row, col) in zip(captions, limit_fields):
            caption.grid(row=row, column=col, sticky="nsw")
        for entry, (_c, _v, row, col) in zip(entries, limit_fields):
            entry.grid(row=row, column=col + 1, sticky="nsew")

        font_caption = tk.Label(frame, text="Font (pt)")
        font_entry = tk.Entry(frame, textvariable=self.font_size, state="normal")
        font_caption.grid(row=6, column=0, sticky="nsw")
        font_entry.grid(row=6, column=1, sticky="nsew")

        self.btn_plot = tk.Button(self.frame_left, text="Plot SN curve",
                                  font=FONT, command=self.plot_curve,
                                  state="normal")
        self.btn_plot.grid(row=4, column=0, sticky="nsew", pady=PAD, padx=PAD)

        return None

    def init_graph(self):
        '''Embed a Matplotlib figure and navigation toolbar in ``self.frame_right``.

        Creates ``self.fig`` with one axes ``self.ax``, the Tk canvas
        ``self.graph`` and ``self.graph_toolbar``. It binds ``graph_resized``
        to the frame's <Configure> event and sets the axis labels and title.
        Based on https://matplotlib.org/3.1.0/gallery/user_interfaces/embedding_in_tk_sgskip.html
        '''
        self.fig = figure = Figure()
        self.ax = figure.add_subplot(111)
        self.graph = canvas = FigureCanvasTkAgg(figure, master=self.frame_right)
        canvas.draw()
        canvas_widget = canvas.get_tk_widget()
        canvas_widget.pack(side="top", fill="both", expand=1)
        self.graph_toolbar = NavigationToolbar2Tk(canvas, self.frame_right)
        self.graph_toolbar.update()
        # Packed again after the toolbar exists, as in the Matplotlib example.
        canvas_widget.pack(side="top", fill="both", expand=1)

        self.frame_right.bind("<Configure>", self.graph_resized)

        axes = self.ax
        axes.set_xlabel("N [cycles]")
        axes.set_ylabel("S [MPa]")
        axes.set_title('Fatigue Lifecycle')

        return None

    def btn_axis_limits_changed(self, *args, **kwargs):
        '''Trace callback for ``self.axis_limit``: refresh the four limit entries.

        Limits are always ordered (n_min, n_max, s_min, s_max). Per mode:
            STEEL / AL - preset ``STEEL_lim`` / ``AL_lim``, entries read-only
            NN         - user-edited N limits, S limits from the data
            SS         - N limits from the data, user-edited S limits
            AUTO       - all limits from the data, entries read-only
            MANUAL     - all four limits user-edited
        "From the data" means the data min/max widened by a 20% margin (see
        ``_data_range_limits``). The data is only read for the NN, SS and AUTO
        modes, so the preset and manual modes work even when no data is
        selected. ``*args`` / ``**kwargs`` absorb the arguments Tk passes to
        trace callbacks.

        Raises:
            ValueError: if ``self.axis_limit`` holds an unknown mode.
        '''
        # TODO: Should this happen over in EdnaCalc? It is also UI related,
        # including the numbers that should appear in the entries.
        limit_type = self.axis_limit.get()
        read_only = ("disabled",) * 4
        if limit_type == STEEL:
            states, lims = read_only, STEEL_lim
        elif limit_type == AL:
            states, lims = read_only, AL_lim
        elif limit_type == NN:
            data_lims = self._data_range_limits()
            states = ("normal", "normal", "disabled", "disabled")
            lims = (self.n_min.get(), self.n_max.get(),
                    data_lims[2], data_lims[3])
        elif limit_type == SS:
            data_lims = self._data_range_limits()
            states = ("disabled", "disabled", "normal", "normal")
            lims = (data_lims[0], data_lims[1],
                    self.s_min.get(), self.s_max.get())
        elif limit_type == AUTO:
            states, lims = read_only, self._data_range_limits()
        elif limit_type == MANUAL:
            states = ("normal",) * 4
            lims = (self.n_min.get(), self.n_max.get(), self.s_min.get(),
                    self.s_max.get())
        else:
            raise ValueError("Unknown axis limit mode: %r" % (limit_type,))

        for entry, var, state, value in zip(self.limit_entries,
                                            self.limit_vars, states, lims):
            entry.config(state=state)
            var.set(value)

        return None

    def _data_range_limits(self):
        '''Return (n_min, n_max, s_min, s_max) of the selected data with a 20% margin.

        Following the (n_min, n_max, s_min, s_max) ordering, N is taken from
        column 1 of the data and S from column 0. Minima are scaled by 0.8 and
        maxima by 1.2. Without a parent window a small dummy array is used
        (debugging only). The result is always a float array.
        '''
        if self.parent is not None:
            # TODO This ignores merging; EdnaCalc should probably provide the relevant data
            data = self.parent.calc.data[self.parent.selected_data]
        else:  # DEBUGGING PURPOSES ONLY
            data = np.arange(10).reshape(5, 2)
        # Force a float dtype. Multiplying an integer array in place by the
        # float margins raises a NumPy casting error.
        ranges = np.array((np.min(data[:, 1]), np.max(data[:, 1]),
                           np.min(data[:, 0]), np.max(data[:, 0])),
                          dtype=float)
        return ranges * np.array((0.8, 1.2, 0.8, 1.2))  # give a 20% buffer around

    def graph_resized(self, event, **kwargs):
        '''<Configure> handler: publish the figure size (in cm) to the size variables.

        Reads ``self.fig.get_size_inches()``, multiplies by ``INCH2CM`` and
        writes the result to ``self.x_size`` and ``self.y_size``. ``event`` is
        not used.
        NOTE(review): each write fires the variable's trace (set_graph_size).
        x_size is written first, so that callback briefly sees the new width
        with the old height.
        '''
        width_in, height_in = self.fig.get_size_inches()
        for size_var, inches in ((self.x_size, width_in),
                                 (self.y_size, height_in)):
            size_var.set(inches * INCH2CM)
        return None

    def set_graph_size(self, *args, **kwargs):
        '''callback when graph size setting changed'''
        # TODO: THis bit doesn't actually work!

        if self.busy_starting:
            #print("Callback (pass)")
            pass
        else:
            #print("Callback (change)")
            x = self.x_size.get() / INCH2CM
            y = self.y_size.get() / INCH2CM
            self.fig.set_size_inches(x, y)
        return None

    ##########################################################################
    ###############     FUNCTIONAL CODE HERE
    ##########################################################################

    def plot_curve(self):
        '''Callback of the "Plot SN curve" button.

        Collects every plotting option from the UI into a keyword dict and
        passes it, together with ``self.parent.selected_data``, to
        ``self.parent.calc.plot_results``. The actual plotting lives in
        EdnaCalc so it can also run standalone. The first four values that
        call returns are truncated with ``int()`` and written into the
        axis-limit variables, so auto-set limits become visible. The canvas is
        then refreshed. Without a parent (debugging only) the options are
        printed instead.
        '''
        options = dict(
            marker=self.symbol.get(),
            line_style=self.line.get(),
            axis_limits=self.get_axis_limits(),
            axis_style=self.axis_y_style.get() == LOG,  # a boolean: is the Y axis log?
            grid_major=self.grid_major.get(),
            grid_minor=self.grid_minor.get(),
            plot_points=self.plot_points.get(),
            plot_regression=self.plot_regression.get(),
            plot_points_conf=self.plot_conf_pt.get(),
            plot_regression_conf=self.plot_conf_reg.get(),
            plot_dc_bs540=self.plot_dc_bs540.get(),
            plot_dc_ec3=self.plot_dc_ec3.get(),
            font=self.font_size.get(),
            fig=self.fig,
            ax=self.ax,
        )
        if self.parent is None:
            # Debugging aid: the window was opened without a parent application.
            print("Plot! \n" + str(options))
            return None

        new_limits = self.parent.calc.plot_results(self.parent.selected_data,
                                                   **options)
        # Show the limits actually used (matters for auto-set axes).
        for index, limit_var in enumerate(self.limit_vars):
            limit_var.set(int(new_limits[index]))
        self.refresh_graph()
        return None

    def get_axis_limits(self):
        '''Return the Matplotlib axis limits as ``(x_limits, y_limits)``.

        Each element is either a ``(min, max)`` pair or ``None``. ``None``
        lets Matplotlib autoscale that axis, which the Tk entries cannot
        express; that is why this differs from ``btn_axis_limits_changed``.
        X is N and Y is S:
            AUTO -> (None, None)
            NN   -> ((n_min, n_max), None)
            SS   -> (None, (s_min, s_max))
            any other mode -> ((n_min, n_max), (s_min, s_max))
        '''
        mode = self.axis_limit.get()
        n_floats = mode == AUTO or mode == SS
        s_floats = mode == AUTO or mode == NN

        n_limits = None if n_floats else (self.n_min.get(), self.n_max.get())
        s_limits = None if s_floats else (self.s_min.get(), self.s_max.get())
        return (n_limits, s_limits)

    def refresh_graph(self, *args, **kwargs):
        self.fig.tight_layout()
        self.graph.draw()
        return None