Ejemplo n.º 1
0
def test_repeated_save_with_alpha():
    """Saving a translucent figure twice must not leak state between renders.

    The figure background is bluish green ``(0, 1, 0.4)`` with alpha 0.25.
    After two consecutive ``savefig`` calls to the same PNG path, the first
    pixel read back must still be approximately ``(0, 1, 0.4, 0.25)``.
    """
    fig = Figure([1, 0.4])
    # Constructing the canvas registers it on fig.canvas; no name needed.
    FigureCanvas(fig)
    fig.set_facecolor((0, 1, 0.4))
    fig.patch.set_alpha(0.25)

    # The target color is fig.patch.get_facecolor()

    # mkstemp returns an already-open OS-level descriptor along with the
    # path. savefig reopens the file by name, so close the descriptor right
    # away instead of leaking it.
    fd, img_fname = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    try:
        fig.savefig(img_fname,
                    facecolor=fig.get_facecolor(),
                    edgecolor='none')

        # Save the figure again to check that the
        # colors don't bleed from the previous renderer.
        fig.savefig(img_fname,
                    facecolor=fig.get_facecolor(),
                    edgecolor='none')

        # Check the first pixel has the desired color & alpha
        # (approx: 0, 1.0, 0.4, 0.25)
        assert_array_almost_equal(tuple(imread(img_fname)[0, 0]),
                                  (0.0, 1.0, 0.4, 0.250),
                                  decimal=3)
    finally:
        os.remove(img_fname)
Ejemplo n.º 2
0
class MplCanvas(FigureCanvas):
    """Qt-embeddable matplotlib canvas with one subplot and a navigation toolbar.

    The toolbar is created in ``__init__`` but is not added to any layout
    here; retrieve it with :meth:`getToolbar` and place it where needed.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        # NOTE: rcParams is process-wide, so this shrinks the default font
        # for every figure created afterwards, not only this canvas.
        matplotlib.rcParams['font.size'] = 8
        self.figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.figure.add_subplot(111)

        FigureCanvas.__init__(self, self.figure)
        self.setParent(parent)

        self.toolbar = NavigationToolbar(self, parent)
        self.toolbar.setIconSize(QSize(16, 16))

        grow = QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, grow, grow)
        FigureCanvas.updateGeometry(self)

    def getToolbar(self):
        """Return the navigation toolbar created in ``__init__``."""
        return self.toolbar

    def clear(self):
        """Wipe the figure and start over with a fresh single subplot."""
        fig = self.figure
        fig.clear()
        self.axes = fig.add_subplot(111)

    def test(self):
        """Draw a trivial sample line (y = 1, 2, 3, 4) on the current axes."""
        sample = [1, 2, 3, 4]
        self.axes.plot(sample)

    def saveAs(self, fname):
        """Save the whole figure to *fname*."""
        self.figure.savefig(fname)
Ejemplo n.º 3
0
def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None):
    """
    Save a 2D :class:`numpy.array` as an image, one pixel per array element.
    Which output formats are available depends on the backend in use.

    Arguments:
      *fname*:
        Path to the output file as a string, or a Python file-like object.
        When *format* is *None* and *fname* is a string, the output format
        is taken from the filename extension.
      *arr*:
        The 2D array to save.
    Keyword arguments:
      *vmin*/*vmax*: [ None | scalar ]
        Fix the data values mapped to the lower/upper ends of the colormap.
        A limit left as None is taken from the min/max of *arr*.
      *cmap*:
        A colors.Colormap instance, e.g. cm.jet. None means the rc
        image.cmap value.
      *format*:
        A file extension supported by the active backend; most backends
        handle png, pdf, ps, eps and svg.
      *origin*
        [ 'upper' | 'lower' ] Whether the [0,0] element of the array goes
        to the upper-left or lower-left corner. None means the rc
        image.origin value.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure

    # At dpi=1 an N-inch figure is N pixels, so sizing the figure by the
    # reversed array shape (columns, rows) yields one pixel per element.
    width_height = arr.shape[::-1]
    fig = Figure(figsize=width_height, dpi=1, frameon=False)
    FigureCanvas(fig)  # attaches an Agg canvas to fig so savefig can render
    fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin)
    fig.savefig(fname, dpi=1, format=format)
Ejemplo n.º 4
0
def plot_tdc_event(points, filename=None):
    """Draw a 3D scatter plot of a single TPC event.

    :param points: array indexed as ``points[:, k]``. Columns 0-2 are used as
        x, y and time, and column 3 as the charge that sets marker color and
        size (presumably an (N, 4) numpy array; TODO confirm).
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages`` object
        gets the figure appended as a page. Anything else is passed to
        ``fig.savefig``.
    :returns: the created Figure.
    """
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot(111, projection='3d')
    xs = points[:, 0]
    ys = points[:, 1]
    zs = points[:, 2]
    cs = points[:, 3]

    # Marker area grows with the square of the charge.
    p = ax.scatter(xs, ys, zs, c=cs, s=cs ** 2 / 5., marker='o')

    ax.set_xlabel('x [250 um]')
    ax.set_ylabel('y [50 um]')
    ax.set_zlabel('t [25 ns]')
    # Axes.title is a Text attribute, not a method: calling it raised
    # TypeError. set_title is the setter.
    ax.set_title('Track of one TPC event')
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 336)

    c_bar = fig.colorbar(p)
    c_bar.set_label('charge [TOT]')

    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
    return fig
Ejemplo n.º 5
0
def tBidBax_kd(mcmc_set):
    """Estimate the tBid/iBax Kd (kr / kf) from pooled MCMC samples.

    Draws ``num_kds`` positions with ``mcmc_set.get_sample_position()``. For
    each one it computes ``10**kr / 10**kf``, using the entries of the
    estimated parameters ``tBid_iBax_kr`` and ``tBid_iBax_kf``. The ``10 **``
    suggests positions are stored as log10 values; TODO confirm. A histogram
    of the sampled distribution is saved to
    ``<mcmc_set.name>_tBidiBax_kd_dist.png``.

    :returns: ``MeanSdResult(mean, sd, plot_filename)`` for the sampled Kd
        values, or ``Result(None, None)`` if either parameter is not among the
        chain's estimated parameters.

    .. todo:: document the basis for this
    """
    num_kds = 10000
    # Map each estimated parameter's name to its index in a position vector.
    # As with a sequential scan, the last occurrence of a name wins.
    estimate_params = mcmc_set.chains[0].options.estimate_params
    index_of = {p.name: i for i, p in enumerate(estimate_params)}
    tBid_iBax_kf_index = index_of.get('tBid_iBax_kf')
    tBid_iBax_kr_index = index_of.get('tBid_iBax_kr')
    # If we couldn't find the parameters, return None for the result
    if tBid_iBax_kf_index is None or tBid_iBax_kr_index is None:
        return Result(None, None)
    # Sample the kr/kf ratio across the pooled chains
    kd_dist = np.zeros(num_kds)
    for i in range(num_kds):
        position = mcmc_set.get_sample_position()
        kd_dist[i] = ((10 ** position[tBid_iBax_kr_index]) /
                      (10 ** position[tBid_iBax_kf_index]))
    # Calculate the mean and standard deviation
    mean = kd_dist.mean()
    sd = kd_dist.std()

    # Plot the Kd distribution
    plot_filename = '%s_tBidiBax_kd_dist.png' % mcmc_set.name
    fig = Figure()
    ax = fig.gca()
    ax.hist(kd_dist)
    canvas = FigureCanvasAgg(fig)
    fig.set_canvas(canvas)
    fig.savefig(plot_filename)

    return MeanSdResult(mean, sd, plot_filename)
Ejemplo n.º 6
0
    def draw_timeline(self):
        """Plot portfolio vs. benchmark timelines and save to SNAPSHOT_IMG_FILE.

        Both series are plotted against the index of TRADING_MINUTES, with
        the y axis formatted as percentages (value * 100). Every 30th minute
        gets a labelled tick.
        """
        fig = Figure(figsize=(7, 3))
        FigureCanvas(fig)  # registers the canvas on fig so savefig can render
        ax = fig.add_subplot(111, axisbg="white")

        minutes = range(len(TRADING_MINUTES))
        tick_positions = minutes[::30]
        tick_labels = TRADING_MINUTES[::30]

        def as_percent(value, _pos):
            return "{0:1.1f}%".format(100*value)

        ax.clear()
        series_specs = (
            (self._portfolio_timeline, "portfolio", "r"),
            (self._benchmark_timeline, "benchmark", "b"),
        )
        for series, series_name, colour in series_specs:
            ax.plot(minutes, series, label=series_name, linewidth=1.0, color=colour)

        ax.yaxis.set_major_formatter(FuncFormatter(as_percent))
        for tick_label in ax.get_yticklabels():
            tick_label.set_size(10)

        ax.set_xlim(0, 240)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=10)

        ax.grid(True)
        ax.legend(loc=2, prop={"size": 9})

        fig.tight_layout()
        fig.savefig(SNAPSHOT_IMG_FILE)
        logger.info("Snapshot image saved at ./static/temp/snapshot.jpg.")
Ejemplo n.º 7
0
def test_repeated_save_with_alpha():
    """Saving a translucent figure twice must not leak state between renders.

    The figure background is bluish green ``(0, 1, 0.4)`` with alpha 0.25.
    After two consecutive ``savefig`` calls into the same in-memory buffer,
    the first pixel read back must still be approximately
    ``(0, 1, 0.4, 0.25)``.
    """
    fig = Figure([1, 0.4])
    fig.set_facecolor((0, 1, 0.4))
    fig.patch.set_alpha(0.25)

    # The target color is fig.patch.get_facecolor()

    buf = io.BytesIO()

    fig.savefig(buf,
                facecolor=fig.get_facecolor(),
                edgecolor='none')

    # Save the figure again to check that the
    # colors don't bleed from the previous renderer.
    # Rewind *and* truncate so the buffer holds only the second PNG.
    # Otherwise a shorter second image would leave trailing bytes of the
    # first one behind it.
    buf.seek(0)
    buf.truncate()
    fig.savefig(buf,
                facecolor=fig.get_facecolor(),
                edgecolor='none')

    # Check the first pixel has the desired color & alpha
    # (approx: 0, 1.0, 0.4, 0.25)
    buf.seek(0)
    assert_array_almost_equal(tuple(imread(buf)[0, 0]),
                              (0.0, 1.0, 0.4, 0.250),
                              decimal=3)
Ejemplo n.º 8
0
def residuals_at_max_likelihood(mcmc_set):
    """Plot the residuals at the maximum-likelihood position of ``mcmc_set``.

    Residuals come from the first chain's ``get_residuals``. Element 0 is
    plotted as x and element 1 as y. A full-size plot and a dpi=10 thumbnail
    are written to the current directory.

    :returns: ``ThumbnailResult(thumbnail_filename, plot_filename)``, or
        ``Result(None, None)`` when ``maximum_likelihood()`` raises
        ``NoPositionsException``.
    """
    try:
        _, ml_position = mcmc_set.maximum_likelihood()
    except NoPositionsException:
        return Result(None, None)

    residuals = mcmc_set.chains[0].get_residuals(ml_position)

    set_name = mcmc_set.name
    plot_filename = '%s_max_likelihood_residuals.png' % set_name
    thumbnail_filename = '%s_max_likelihood_residuals_th.png' % set_name

    fig = Figure()
    ax = fig.gca()
    ax.plot(residuals[0], residuals[1])
    ax.set_title('Residuals at Maximum Likelihood')
    # TODO: axis labels ('Time', 'Residual') were sketched here using the
    # pyplot-style ax.xlabel/ax.ylabel; on an Axes these are
    # set_xlabel/set_ylabel.
    fig.set_canvas(FigureCanvasAgg(fig))
    fig.savefig(plot_filename)
    fig.savefig(thumbnail_filename, dpi=10)

    return ThumbnailResult(thumbnail_filename, plot_filename)
Ejemplo n.º 9
0
def acf_of_ml_residuals(mcmc_set):
    """Plot the autocorrelation of the residuals at the maximum-likelihood position.

    The autocorrelation is ``np.correlate(r, r, mode='full')``, unnormalised,
    where ``r`` is element 1 of the first chain's residuals. A full-size plot
    and a dpi=10 thumbnail are written to the current directory.

    :returns: ``ThumbnailResult(thumbnail_filename, plot_filename)``, or
        ``Result(None, None)`` when ``maximum_likelihood()`` raises
        ``NoPositionsException``.
    """
    try:
        _, ml_position = mcmc_set.maximum_likelihood()
    except NoPositionsException:
        return Result(None, None)

    residual_values = mcmc_set.chains[0].get_residuals(ml_position)[1]
    acf = np.correlate(residual_values, residual_values, mode='full')

    set_name = mcmc_set.name
    plot_filename = '%s_acf_of_ml_residuals.png' % set_name
    thumbnail_filename = '%s_acf_of_ml_residuals_th.png' % set_name

    fig = Figure()
    ax = fig.gca()
    ax.plot(acf)
    ax.set_title('Autocorrelation of Maximum Likelihood Residuals')

    fig.set_canvas(FigureCanvasAgg(fig))
    fig.savefig(plot_filename)
    fig.savefig(thumbnail_filename, dpi=10)

    return ThumbnailResult(thumbnail_filename, plot_filename)
Ejemplo n.º 10
0
def lineChart(size, data, output):
    """Render a line chart, filling the area under each series, as a transparent PNG.

    :param size: indexable ``(width_in, height_in, dpi)``.
    :param data: mapping with the following keys.

        - ``'x'`` and ``'y'``: ``'y'`` is a list of dicts, each with
          ``'data'`` and an optional ``'label'``.
        - ``'xlabel'`` and ``'ylabel'``.
        - Optional ``'title'``.
        - Optional ``'allTicks'``: only the exact value True puts a tick
          on every x.
        - Optional ``'xtickFnc'``: a callable mapping tick value to label.
    :param output: path or file-like object handed to ``savefig``.
    """
    xs = data['x']
    series_list = data['y']

    fig = Figure(figsize=(size[0], size[1]), dpi=size[2])
    FigureCanvas(fig)  # Stores canvas on fig.canvas

    axis = fig.add_subplot(111)
    axis.grid(color='r', linestyle='dotted', linewidth=0.1, alpha=0.5)

    for series in series_list:
        values = series['data']
        axis.plot(xs, values, label=series.get('label'), marker='.', color='orange')
        axis.fill_between(xs, values, 0)

    axis.set_title(data.get('title', ''))
    axis.set_xlabel(data['xlabel'])
    axis.set_ylabel(data['ylabel'])

    if data.get('allTicks', True) is True:
        axis.set_xticks(xs)

    if 'xtickFnc' in data:
        to_label = data['xtickFnc']
        axis.set_xticklabels([to_label(tick) for tick in axis.get_xticks()])

    axis.legend()

    fig.savefig(output, format='png', transparent=True)
Ejemplo n.º 11
0
def write_figures(prefix, directory, dose_name, dose_data, data, ec50_coeffs, feature_set, log_transform):
    """Write one PDF plot of the fitted dose-response curve per measurement

    prefix - prefix for file names
    directory - write files into this directory
    dose_name - name of the dose measurement (not used by this function)
    dose_data - doses per image
    data - data per image; column i belongs to feature_set[i]
    ec50_coeffs - coefficients calculated by calculate_ec50; row i is passed
                  to sigmoid() for feature_set[i]
    feature_set - tuples of object name and feature name in same order as data
    log_transform - true to log-transform the dose data

    Each file is "<directory>/<prefix><object>_<feature>.pdf". It shows
    sigmoid() evaluated at 100 evenly spaced doses as a line, and at the
    measured doses as markers.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_pdf import FigureCanvasPdf

    if log_transform:
        dose_data = np.log(dose_data)
    # The curve normally starts at dose 0, but log-transformed doses below 1
    # are negative. Start at the smallest finite dose when it is below 0 so
    # the curve covers every plotted point. Non-finite values such as
    # log(0) = -inf are ignored.
    doses = np.asarray(dose_data, dtype=float)
    finite_doses = doses[np.isfinite(doses)]
    x_start = min(0.0, float(finite_doses.min())) if finite_doses.size else 0.0
    for i, (object_name, feature_name) in enumerate(feature_set):
        # NOTE(review): fdata is extracted but never plotted; the markers
        # below show the fitted value at each dose. Confirm whether the
        # measured values were meant to be plotted instead.
        fdata = data[:, i]
        fcoeffs = ec50_coeffs[i, :]
        filename = "%s%s_%s.pdf" % (prefix, object_name, feature_name)
        pathname = os.path.join(directory, filename)
        f = Figure()
        FigureCanvasPdf(f)  # registers the PDF canvas on f for savefig
        ax = f.add_subplot(1, 1, 1)
        x = np.linspace(x_start, np.max(dose_data), num=100)
        y = sigmoid(fcoeffs, x)
        ax.plot(x, y)
        dose_y = sigmoid(fcoeffs, dose_data)
        ax.plot(dose_data, dose_y, "o")
        ax.set_xlabel("Dose")
        ax.set_ylabel("Response")
        ax.set_title("%s_%s" % (object_name, feature_name))
        f.savefig(pathname)
Ejemplo n.º 12
0
def test_moll():
    """Smoke-test the custom "ast.mollweide" projection by rendering it to PNG.

    Calls ``mapshow`` with ``np.arange(48)``, plots 1000 random points with
    RA in [30, 60) and Dec in [-50, 50), and adds reference lines at y=-20
    and x=140. The result is written to ``xxx-moll.png`` in the current
    directory.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    import numpy as np

    fig = Figure(figsize=(6, 6))

    ax = fig.add_subplot(111, projection="ast.mollweide")

    # (An earlier draw of 10000 random ra/dec values was overwritten below
    # without ever being used and has been removed.)
    ax.mapshow(np.arange(48), shading='flat')

    ra = np.random.uniform(size=1000, low=30, high=60)
    dec = np.random.uniform(size=1000, low=-50, high=50)

    ax.plot(ra, dec, '*')
    ax.axhline(-20)
    ax.axvline(140)

    ax.grid()
    FigureCanvasAgg(fig)  # attach an Agg canvas so savefig can render
    fig.savefig('xxx-moll.png')
Ejemplo n.º 13
0
def barChart(size, data, output):
    """Render a stacked bar chart as a transparent PNG.

    :param size: indexable ``(width_in, height_in, dpi)``.
    :param data: mapping with the following keys.

        - ``'x'``: one entry per bar position.
        - ``'y'``: a list of series dicts, each with ``'data'`` and an
          optional ``'label'``. Series are stacked in order.
        - ``'xlabel'`` and ``'ylabel'``.
        - Optional ``'title'``.
        - Optional ``'allTicks'``: only the exact value True puts a tick
          on every bar.
        - Optional ``'xtickFnc'``: a callable mapping tick value to label.
    :param output: path or file-like object handed to ``savefig``.

    An empty ``'y'`` list produces an empty chart instead of raising
    ``IndexError``.
    """
    d = data['x']
    ind = np.arange(len(d))
    ys = data['y']

    width = 0.60
    fig = Figure(figsize=(size[0], size[1]), dpi=size[2])
    FigureCanvas(fig)  # Stores canvas on fig.canvas

    axis = fig.add_subplot(111)
    axis.grid(color='r', linestyle='dotted', linewidth=0.1, alpha=0.5)

    # Running top of each stack, sized from the first series. Guarded
    # because ys[0] does not exist when there are no series; the loop below
    # is then skipped and bottom is never used.
    bottom = np.zeros(len(ys[0]['data'])) if ys else None
    for y in ys:
        axis.bar(ind, y['data'], width, bottom=bottom, label=y.get('label'))
        bottom += np.array(y['data'])

    axis.set_title(data.get('title', ''))
    axis.set_xlabel(data['xlabel'])
    axis.set_ylabel(data['ylabel'])

    if data.get('allTicks', True) is True:
        axis.set_xticks(ind)

    if 'xtickFnc' in data:
        axis.set_xticklabels([data['xtickFnc'](v) for v in axis.get_xticks()])

    axis.legend()

    fig.savefig(output, format='png', transparent=True)
Ejemplo n.º 14
0
class MatplotlibWidget(QtGui.QWidget):
    """Qt widget embedding a single-axes matplotlib figure above a "Save" button."""

    def __init__(self, parent=None):
        super(MatplotlibWidget, self).__init__(parent)

        # Figure size in inches.
        figwidth = 5.0
        figheight = 3.5

        self.figure = Figure(figsize=(figwidth, figheight))
        self.canvas = FigureCanvas(self.figure)
        self.axis = self.figure.add_subplot(111)
        self.axis.clear()
        self.layoutVertical = QtGui.QVBoxLayout(self)
        self.layoutVertical.addWidget(self.canvas)

        savePlotBtn = QtGui.QPushButton("Save")
        self.layoutVertical.addWidget(savePlotBtn)
        savePlotBtn.clicked.connect(self.onSave)

    def add_data(self, values, labels):
        """Replace the axes contents with a pie chart of *values* and repaint."""
        self.axis.clear()
        self.axis.pie(values, labels=labels, explode=None,
                      autopct='%1.1f%%', shadow=True, startangle=90)
        # The figure is not managed by pyplot, so nothing schedules a repaint
        # automatically. Request one so the new chart actually shows up.
        self.canvas.draw_idle()

    def onSave(self):
        """Ask the user for a destination and save the figure there.

        Nothing is written if the dialog is cancelled (empty path).
        """
        result = QtGui.QFileDialog.getSaveFileName(self, 'Save file', '',)
        # PyQt5/PySide2-style bindings return (path, selected_filter), while
        # PyQt4 with API v2 returns only the path.
        fname = result[0] if isinstance(result, tuple) else result
        if not fname:
            return
        self.figure.savefig(fname)
Ejemplo n.º 15
0
    def save_as_plt(self, fname, pixel_array=None, vmin=None, vmax=None,
        cmap=None, format=None, origin=None):
        """ This method saves the image from a numpy array using matplotlib

        :param fname: Location and name of the image file to be saved.
        :param pixel_array: Numpy pixel array, i.e. ``numpy()`` return value;
            defaults to ``self.numpy()``
        :param vmin: matplotlib vmin
        :param vmax: matplotlib vmax
        :param cmap: matplotlib color map; defaults to ``cm.bone``
        :param format: matplotlib format
        :param origin: matplotlib origin

        This method returns True when it finishes; any failure propagates as
        an exception.
        """
        from matplotlib.backends.backend_agg \
        import FigureCanvasAgg as FigureCanvas
        from matplotlib.figure import Figure
        # Same cm module that pylab re-exports, imported directly so pyplot
        # (and its backend selection) is not pulled in as a side effect.
        from matplotlib import cm

        if pixel_array is None:
            pixel_array = self.numpy()

        if cmap is None:
            cmap = cm.bone
        # At dpi=1 an N-inch figure is N pixels, so sizing by the reversed
        # array shape gives one output pixel per array element.
        fig = Figure(figsize=pixel_array.shape[::-1], dpi=1, frameon=False)
        FigureCanvas(fig)  # attach an Agg canvas so savefig can render
        fig.figimage(pixel_array, cmap=cmap, vmin=vmin,
            vmax=vmax, origin=origin)
        fig.savefig(fname, dpi=1, format=format)
        return True
Ejemplo n.º 16
0
    def make_reference_plot(self, file_name):
        """Plot the reference leakage curve and write an HTML link snippet.

        Saves a semilog-y plot of ``self.leakage_data[0]`` (index 0 on x,
        ``abs`` of index 1 on y, labelled as voltage and current) to
        *file_name*. The plot marks ``self.spot`` with a vertical line and is
        stamped with the current time. The Markdown-rendered snippet is then
        written to ``file_name[:-4] + ".html"``, i.e. *file_name* with its
        last four characters (presumably a ".png" extension) replaced.
        """
        ref_fig = Figure()
        FigureCanvas(ref_fig)
        refFile = self.leakage_data[0]

        ref_plot = ref_fig.add_subplot(111)
        ref_plot.semilogy(refFile[0], np.abs(refFile[1]), 'bo-')
        ref_plot.set_xlabel(r'Voltage [V]')
        ref_plot.set_ylabel(r'Current [A]')
        ref_plot.set_title(self.devName + "- Reference Curve - Leakage")
        ref_plot.grid(True)
        # (xmin, xmax, ymin, ymax) of the current view, used to place labels.
        myaxis = ref_plot.axis()
        date_deb = datetime.now()
        ref_plot.axvline(self.spot, color='m', linestyle='-.', linewidth=2)
        ref_plot.text(1.05 * self.spot, 0.4 * myaxis[3],
                      " spot: {0:.4}V ".format(self.spot),
                      color='m', rotation=270)
        ref_plot.text(myaxis[0],
                      myaxis[2] * 2,
                      " " + date_deb.ctime())
        ref_fig.savefig(file_name)

        # The image link target is hard-coded to "reference.png",
        # independent of file_name.
        textFile = ('[<img src="reference.png"'
                    + ' align="center" alt="Reference"> ](../'
                    + self.devName + '_report.html)')
        text_output = md.markdown(textFile, extensions=['extra'])
        # The context manager closes the file even if write() raises.
        with open(file_name[:-4] + ".html", "w") as text_out:
            text_out.write(text_output)
Ejemplo n.º 17
0
def plotBatchResults(db):
    '''Hook called from woo.batch.writeResults.

    Reads all results from the batch database *db* and draws two stacked
    panels against time: kinetic energy, and relative energy error. Each
    result gets one line per panel. The figure is saved as
    "<db without a trailing .sqlite>.pdf".

    Results with an empty title are labelled by their sceneId. Note that
    this fallback is written back into the result dict.
    '''
    import os
    import re
    import woo.batch
    results = woo.batch.dbReadResults(db)
    # Raw string: '\.' in a plain literal is an invalid escape sequence
    # (DeprecationWarning/SyntaxWarning on Python 3).
    out = '%s.pdf' % re.sub(r'\.sqlite$', '', db)
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    fig = Figure()
    FigureCanvasAgg(fig)  # attach a canvas so savefig can render
    ax1 = fig.add_subplot(2, 1, 1)
    ax2 = fig.add_subplot(2, 1, 2)
    ax1.set_xlabel('Time [s]')
    ax1.set_ylabel('Kinetic energy [J]')
    ax1.grid(True)
    ax2.set_xlabel('Time [s]')
    ax2.set_ylabel('Relative energy error')
    ax2.grid(True)
    for res in results:
        series = res['series']
        if not res['title']:
            res['title'] = res['sceneId']
        ax1.plot(series['t'], series['kinetic'], label=res['title'], alpha=.6)
        ax2.plot(series['t'], series['relErr'], label=res['title'], alpha=.6)
    for ax, loc in (ax1, 'lower left'), (ax2, 'lower right'):
        legend = ax.legend(loc=loc, labelspacing=.2, prop={'size': 7})
        legend.get_frame().set_alpha(.4)
    fig.savefig(out)
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('Batch figure saved to file://%s' % os.path.abspath(out))
def main(argv=None):
    width = 1
    for test in all_tests():
        fig = Figure(figsize=(10, 5))
        FigureCanvasAgg(fig)  # why do i have to do this?
        ax = fig.add_subplot(1, 1, 1)
        x = np.array([float(i)*(len(STORAGE_DRIVERS)+1) for i in range(len(NUM_INSTANCES))])
        bars = []

        for driver, color in zip(STORAGE_DRIVERS, 'mbcg'):
            means = []
            for num_instances in NUM_INSTANCES:
                means.append(average_from_file(os.path.join(
                    'results',
                    '{}.{}.{}'.format(test, num_instances, driver),
                )))

            bars.append(ax.bar(x, means, width, color=color))
            x += width

        ax.set_title(test)
        ax.set_xlabel('num parallel processes' + ' ' * 90)  # lol
        ax.set_ylabel('seconds for completion')
        ax.set_xticklabels(NUM_INSTANCES)
        ax.set_xticks(x - 1.5)

        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.3, box.width, box.height * 0.7])
        ax.legend((bar[0] for bar in bars), STORAGE_DRIVERS,
                bbox_to_anchor=(1, -0.1),
              fancybox=True, shadow=True, ncol=5)

        fig.savefig(os.path.join('graphs', test + '.png'))
Ejemplo n.º 19
0
class Canvas(FigureCanvas):
    """Qt-embeddable canvas holding a single 2D or 3D axes.

    ``nD == 2`` creates an ordinary subplot; any other value creates a
    ``projection='3d'`` subplot. Subclasses draw initial content by
    overriding :meth:`compute_initial_figure`.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100, nD=2):
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        FigureCanvas.__init__(self, self.fig)
        self.dim = nD
        # Compare by value: ``is 2`` relied on CPython's small-int caching
        # and is a SyntaxWarning since Python 3.8.
        if self.dim == 2:
            self.axes = self.fig.add_subplot(1, 1, 1)
        else:
            self.axes = self.fig.add_subplot(1, 1, 1, projection='3d')
        # NOTE: Axes.hold() was removed in matplotlib 3.0, so this line
        # requires an older matplotlib.
        self.axes.hold(False)
        self.compute_initial_figure()
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        """Hook for subclasses to draw initial content; does nothing here."""
        pass

    def export_pdf(self):
        """Save the figure as 'export.pdf' in the current directory."""
        fname = 'export.pdf'
        self.fig.savefig(fname)

    def export_jpg(self):
        """Save the figure as 'export.jpg' in the current directory."""
        fname = 'export.jpg'
        self.fig.savefig(fname)
Ejemplo n.º 20
0
def generate_scatterplot(datafname, imgfname):
    """Render a 2D scatter plot of Gocator XYZ data to an image file.

    Reads comma-separated x, y, z columns from *datafname* and drops points
    whose z equals -32.768 (presumably the sensor's no-data value). The
    remaining points are coloured by z, the view is fitted to their x/y
    extent, and the figure is saved to *imgfname*.

    NOTE: the rcParams changes made here are process-wide and persist after
    this call.
    """
    rc_overrides = (
        ('axes.formatter.limits', (-4, 4)),
        ('font.size', 14),
        ('axes.titlesize', 12),
        ('axes.labelsize', 12),
        ('xtick.labelsize', 11),
        ('ytick.labelsize', 11),
    )
    for rc_key, rc_value in rc_overrides:
        matplotlib.rcParams[rc_key] = rc_value
    figure = Figure()
    FigureCanvas(figure)  # attach a canvas so savefig can render
    axes = figure.gca()
    x, y, z = np.genfromtxt(datafname, delimiter=",", unpack=True)
    has_reading = z != -32.768
    xi, yi, zi = x[has_reading], y[has_reading], z[has_reading]
    scatter_plt = axes.scatter(xi, yi, c=zi, cmap=cm.get_cmap("Set1"), marker=',')
    axes.grid(True)
    axes.axis([np.min(xi), np.max(xi), np.min(yi), np.max(yi)])
    axes.set_title(os.path.basename(datafname))
    colorbar = figure.colorbar(scatter_plt)
    colorbar.set_label("Range [mm]")
    axes.set_xlabel("Horizontal Position [mm]")
    axes.set_ylabel("Scan Position [mm]")
    figure.savefig(imgfname)
Ejemplo n.º 21
0
def plot_parameter_curve(mcmc_set, p_index, p_name):
    """Plot one two-exponential fit parameter against [Bax] for model and data.

    The model curve (blue line) comes from ``two_exp_fits_dict[mcmc_set.name]``.
    The data points (red circles) come from ``two_exp_fits_dict['data']``, and
    ``fit_data(mcmc_set)`` is called first if that entry is missing. The
    x values are the first chain's ``data.columns``. A full-size plot and a
    dpi=10 thumbnail are written.

    :raises Exception: if the fits for ``mcmc_set`` have not been run yet.
    :returns: ``ThumbnailResult(thumbnail_filename, plot_filename)``.
    """
    # The model fits must already exist; the data fits are computed on demand.
    if mcmc_set.name not in two_exp_fits_dict:
        raise Exception('%s not found in two_exp_fits_dict!' % mcmc_set.name)
    if 'data' not in two_exp_fits_dict:
        fit_data(mcmc_set)

    model_values = two_exp_fits_dict[mcmc_set.name][p_index]
    data_values = two_exp_fits_dict['data'][p_index]
    concentrations = mcmc_set.chains[0].data.columns
    plot_filename = '%s_%s_curve.png' % (mcmc_set.name, p_name)
    thumbnail_filename = '%s_%s_curve_th.png' % (mcmc_set.name, p_name)

    fig = Figure()
    ax = fig.gca()
    ax.plot(concentrations, model_values, 'b')
    ax.plot(concentrations, data_values, marker='o', linestyle='', color='r')
    ax.set_ylabel('%s value' % p_name)
    ax.set_xlabel('[Bax] (nM)')
    ax.set_title('%s for %s' % (mcmc_set.name, p_name))
    fig.set_canvas(FigureCanvasAgg(fig))
    fig.savefig(plot_filename)
    fig.savefig(thumbnail_filename, dpi=10)
    return ThumbnailResult(thumbnail_filename, plot_filename)
Ejemplo n.º 22
0
class MplCanvas(FigureCanvas):
    """Frameless, axis-less canvas whose figure can be exported as a transparent PNG."""

    def __init__(self, figsize=(8, 6), dpi=80):
        self.fig = Figure(figsize=figsize, dpi=dpi)
        # Let the single axes fill the entire figure with no margins.
        self.fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1.0, top=1.0, bottom=0)

        self.ax = self.fig.add_subplot(111, frameon=False)
        self.ax.patch.set_visible(False)
        self.ax.set_axis_off()
        FigureCanvas.__init__(self, self.fig)
        FigureCanvas.setSizePolicy(self, QtGui.QSizePolicy.Expanding, QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def saveFig(self, str_io_buf):
        """Write the figure as a transparent PNG into *str_io_buf*.

        *str_io_buf* may be an in-memory buffer (the original use case) or
        anything else ``savefig`` accepts. Side effect: the figure and axes
        backgrounds are left at alpha 0 afterwards.
        """
        self.setAlpha(0.0)
        # transparent=True already suppresses the figure background. The
        # former frameon=False savefig argument was redundant and is
        # rejected by newer matplotlib (removed in 3.3).
        self.fig.savefig(str_io_buf, transparent=True, format='png')

    def getFig(self):
        """Return the underlying Figure."""
        return self.fig

    def setAlpha(self, value):
        """Set the alpha of both the figure and axes background patches."""
        self.fig.patch.set_alpha(value)
        self.ax.patch.set_alpha(value)
        self.ax.set_axis_off()

    def set_face_color(self, color):
        """Set the figure background colour (e.g. "#000000")."""
        self.fig.set_facecolor(color)
Ejemplo n.º 23
0
def plotThreeWay(hist, title, filename=None, x_axis_title=None, minimum=None, maximum=None, bins=101):  # the famous 3 way plot (enhanced)
    """Draw the "three-way" plot of a pixel histogram in three stacked panels.

    The panels are a 2D pixel map, a 1D distribution, and a per-channel
    scatter plot, all sharing the value limits [minimum, maximum].

    :param minimum: None gives 0; 'minimum' gives ``np.ma.min(hist)``; any
        other value is used as given.
    :param maximum: None or 'median' gives twice ``np.ma.median(hist)``;
        'maximum' gives ``np.ma.max(hist)``; any other value is used as
        given. It is forced to 1 if below 1 or if *hist* is fully masked.
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages``
        object gets the figure appended as a page. Anything else is passed
        to ``fig.savefig``.
    """
    if minimum is None:
        minimum = 0
    elif minimum == 'minimum':
        minimum = np.ma.min(hist)

    # Rounding to a "nice" value was considered for both cases below, e.g.
    # round_to_multiple(x, math.floor(math.log10(x))), but is not applied.
    if maximum is None or maximum == 'median':
        maximum = np.ma.median(hist) * 2
    elif maximum == 'maximum':
        maximum = np.ma.max(hist)
    if maximum < 1 or hist.all() is np.ma.masked:
        maximum = 1

    if x_axis_title is None:
        x_axis_title = ''
    fig = Figure()
    FigureCanvas(fig)
    fig.patch.set_facecolor('white')

    ax_map = fig.add_subplot(311)
    create_2d_pixel_hist(fig, ax_map, hist, title=title, x_axis_title="column", y_axis_title="row", z_min=minimum if minimum else 0, z_max=maximum)
    ax_dist = fig.add_subplot(312)
    create_1d_hist(fig, ax_dist, hist, bins=bins, x_axis_title=x_axis_title, y_axis_title="#", x_min=minimum, x_max=maximum)
    ax_scatter = fig.add_subplot(313)
    create_pixel_scatter_plot(fig, ax_scatter, hist, x_axis_title="channel=row + column*336", y_axis_title=x_axis_title, y_min=minimum, y_max=maximum)
    fig.tight_layout()

    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 24
0
def plot_cluster_tot_size(hist, median=False, z_max=None, filename=None):
    """Plot cluster ToT against cluster size as a 2D colour map.

    Only ``hist[0:50, 0:20]`` is drawn; per the axis labels, rows are
    cluster ToT and columns are cluster size. Masked/bad cells are drawn
    white.

    :param median: accepted but unused.
    :param z_max: upper colour limit; defaults to the maximum of the shown
        region. It is forced to 1 if below 1 or if the region is fully masked.
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages``
        object gets the figure appended as a page. Anything else is passed
        to ``fig.savefig``.
    """
    import copy

    H = hist[0:50, 0:20]
    if z_max is None:
        z_max = np.ma.max(H)
    if z_max < 1 or H.all() is np.ma.masked:
        z_max = 1
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot(111)
    extent = [-0.5, 20.5, 49.5, -0.5]
    bounds = np.linspace(start=0, stop=z_max, num=255, endpoint=True)
    # Work on a private copy. cm.get_cmap can hand back the globally
    # registered colormap, and set_bad would otherwise change 'jet' for every
    # later plot in the process. A deep copy also duplicates the internal
    # lookup table.
    cmap = copy.deepcopy(cm.get_cmap('jet'))
    cmap.set_bad('w')
    norm = colors.BoundaryNorm(bounds, cmap.N)
    im = ax.imshow(H, aspect="auto", interpolation='nearest', cmap=cmap, norm=norm, extent=extent)  # for monitoring
    ax.set_title('Cluster size and cluster ToT (' + str(np.sum(H) / 2) + ' entries)')
    ax.set_xlabel('cluster size')
    ax.set_ylabel('cluster ToT')

    ax.invert_yaxis()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    cb = fig.colorbar(im, cax=cax, ticks=np.linspace(start=0, stop=z_max, num=9, endpoint=True))
    cb.set_label("#")
    fig.patch.set_facecolor('white')
    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 25
0
def plot_1d_hist(hist, yerr=None, title=None, x_axis_title=None, y_axis_title=None, x_ticks=None, color='r', plot_range=None, log_y=False, filename=None, figure_name=None):
    """Plot a 1D histogram as a bar chart.

    :param hist: bin contents, indexed as ``hist[plot_range]`` (so it needs
        to be a numpy array).
    :param yerr: optional error bars passed to ``Axes.bar``.
    :param plot_range: bin indices to draw. Defaults to all bins; an empty
        range falls back to ``[0]``.
    :param x_ticks: optional tick labels, one per entry of *plot_range*.
    :param log_y: use a log y scale; ignored when all bins are close to 0.
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages``
        object gets the figure appended as a page. Anything else is passed
        to ``fig.savefig``.
    :param figure_name: accepted but unused.
    """
    logging.info('Plot 1d histogram%s', (': ' + title) if title is not None else '')
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot(111)
    if plot_range is None:
        plot_range = range(0, len(hist))
    if not plot_range:
        plot_range = [0]
    # Pass x/height positionally. The keyword was ``left`` in old matplotlib
    # and is ``x`` in current versions, so ``left=`` fails there. yerr=None
    # is Axes.bar's default, so one call covers both the with-errors and
    # without-errors cases.
    ax.bar(plot_range, hist[plot_range], color=color, align='center', yerr=yerr)
    ax.set_xlim((min(plot_range) - 0.5, max(plot_range) + 0.5))
    ax.set_title(title)
    if x_axis_title is not None:
        ax.set_xlabel(x_axis_title)
    if y_axis_title is not None:
        ax.set_ylabel(y_axis_title)
    if x_ticks is not None:
        # plot_range is always set by this point, so ticks go on its entries.
        ax.set_xticks(plot_range)
        ax.set_xticklabels(x_ticks)
        ax.tick_params(which='both', labelsize=8)
    if np.allclose(hist, 0.0):
        ax.set_ylim((0, 1))
    elif log_y:
        ax.set_yscale('log')
    ax.grid(True)
    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 26
0
def plot_correlation(hist, title="Hit correlation", xlabel=None, ylabel=None, filename=None):
    """Plot a 2D correlation histogram with a discrete colour bar.

    :param hist: indexable triple. ``hist[0]`` holds the counts image,
        ``hist[1]`` gives the y extent (first/last value, padded by 0.5) and
        ``hist[2]`` the x extent.
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages``
        object gets the figure appended as a page. Anything else is passed
        to ``fig.savefig``.
    """
    logging.info("Plotting correlations")
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot(1, 1, 1)
    cmap = cm.get_cmap('jet')
    counts, y_edges, x_edges = hist[0], hist[1], hist[2]
    extent = [x_edges[0] - 0.5, x_edges[-1] + 0.5, y_edges[-1] + 0.5, y_edges[0] - 0.5]
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    im = ax.imshow(counts, extent=extent, cmap=cmap, interpolation='nearest')
    ax.invert_yaxis()

    # Colour bar in a slim axes on the right, with 255 boundaries and
    # 9 ticks from 0 to the largest count.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    z_max = np.max(counts)
    bounds = np.linspace(start=0, stop=z_max, num=255, endpoint=True)
    norm = colors.BoundaryNorm(bounds, cmap.N)
    ticks = np.linspace(start=0, stop=z_max, num=9, endpoint=True)
    fig.colorbar(im, boundaries=bounds, cmap=cmap, norm=norm, ticks=ticks, cax=cax)

    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 27
0
def plot_scatter_time(x, y, yerr=None, title=None, legend=None, plot_range=None, plot_range_y=None, x_label=None, y_label=None, marker_style='-o', log_x=False, log_y=False, filename=None):
    """Plot *y* against timestamps *x*, optionally with symmetric error bars.

    :param x: timestamps, converted with ``datetime.fromtimestamp`` (local time).
    :param yerr: symmetric y errors; if given, ``errorbar`` is used
        instead of ``plot``.
    :param legend: legend labels, placed at the best location (loc=0).
    :param plot_range: only its min and max are used, as x limits.
    :param plot_range_y: only its min and max are used, as y limits.
    :param filename: if falsy, ``fig.show()`` is called. A ``PdfPages``
        object gets the figure appended as a page. Anything else is passed
        to ``fig.savefig``.
    """
    logging.info("Plot time scatter plot %s", (': ' + title) if title is not None else '')
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot(111)
    ax.format_xdata = mdates.DateFormatter('%Y-%m-%d')
    times = [datetime.fromtimestamp(time) for time in x]
    if yerr is not None:
        ax.errorbar(times, y, yerr=[yerr, yerr], fmt=marker_style)
    else:
        ax.plot(times, y, marker_style)
    ax.set_title(title)
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    # Axes has set_xscale/set_yscale; xscale/yscale are pyplot-only
    # functions and raised AttributeError here.
    if log_x:
        ax.set_xscale('log')
    if log_y:
        ax.set_yscale('log')
    if plot_range:
        ax.set_xlim((min(plot_range), max(plot_range)))
    if plot_range_y:
        ax.set_ylim((min(plot_range_y), max(plot_range_y)))
    if legend:
        # Pass loc by keyword. Newer matplotlib reads two positional
        # arguments as (handles, labels), not (labels, loc).
        ax.legend(legend, loc=0)
    ax.grid(True)
    if not filename:
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 28
0
def plot_correlations(filenames, limit=None):
    """Save row-vs-row hit correlation heatmaps for every pair of input files.

    Reads the first 15000 (Event, Row) hits from each file's ``Hits`` table,
    as column ``Row<i>`` for the i-th file. The tables are joined on the
    event number, rows with missing values are dropped, and duplicate rows
    are removed (the event number is compared too; the last one is kept).
    For every ordered pair of ``Row`` columns, a 336x336 2D histogram is
    written to ``<colA>_<colB>.pdf`` in the current directory.

    ``limit`` is accepted but currently unused.
    """
    DataFrame = pd.DataFrame()
    index = 0
    for fileName in filenames:
        with pd.get_store(fileName, 'r') as store:
            tempDataFrame = pd.DataFrame({'Event': store.Hits.Event[:15000], 'Row' + str(index): store.Hits.Row[:15000]})
            tempDataFrame = tempDataFrame.set_index('Event')
            DataFrame = tempDataFrame.join(DataFrame)
            DataFrame = DataFrame.dropna()
            index += 1
            del tempDataFrame
    # Expose the event number as a column so duplicate detection includes it.
    DataFrame["index"] = DataFrame.index
    DataFrame.drop_duplicates(take_last=True, inplace=True)
    del DataFrame["index"]
    # One-element tuple: the trailing comma matters. ('Row') is just the
    # string 'Row', which was iterated character by character ('R', 'o',
    # 'w') and produced every plot three times with wrong titles.
    correlationNames = ('Row',)
    for corName in correlationNames:
        for colName in itertools.permutations(DataFrame.filter(regex=corName), 2):
            if corName == 'Col':
                heatmap, xedges, yedges = np.histogram2d(DataFrame[colName[0]], DataFrame[colName[1]], bins=(80, 80), range=[[1, 80], [1, 80]])
            else:
                heatmap, xedges, yedges = np.histogram2d(DataFrame[colName[0]], DataFrame[colName[1]], bins=(336, 336), range=[[1, 336], [1, 336]])
            extent = [yedges[0] - 0.5, yedges[-1] + 0.5, xedges[-1] + 0.5, xedges[0] - 0.5]
            cmap = cm.get_cmap('hot', 40)
            fig = Figure()
            FigureCanvas(fig)
            ax = fig.add_subplot(111)
            ax.imshow(heatmap, extent=extent, cmap=cmap, interpolation='nearest')
            ax.invert_yaxis()
            ax.set_xlabel(colName[0])
            ax.set_ylabel(colName[1])
            ax.set_title('Correlation plot(' + corName + ')')
            fig.savefig(colName[0] + '_' + colName[1] + '.pdf')
Ejemplo n.º 29
0
def plot_scatter(x, y, x_err=None, y_err=None, title=None, legend=None, plot_range=None, plot_range_y=None, x_label=None, y_label=None, marker_style='-o', log_x=False, log_y=False, filename=None):
    """Plot y versus x, optionally with symmetric error bars, and output it.

    :param x, y: data sequences.
    :param x_err, y_err: optional error values; used for both the lower and
        upper error bar.
    :param title: optional axes title (also included in the log message).
    :param legend: optional sequence of legend labels.
    :param plot_range, plot_range_y: optional sequences; their min/max set the
        x/y axis limits.
    :param x_label, y_label: optional axis labels.
    :param marker_style: matplotlib format string for the data.
    :param log_x, log_y: use a logarithmic x/y scale.
    :param filename: falsy -> ``fig.show()``; a ``PdfPages`` instance -> append
        the figure as a page; otherwise passed to ``fig.savefig``.
    """
    logging.info('Plot scatter plot %s', (': ' + title) if title is not None else '')
    fig = Figure()
    FigureCanvas(fig)  # attach a canvas so the figure can be rendered
    ax = fig.add_subplot(111)
    if x_err is not None or y_err is not None:
        # errorbar accepts a (lower, upper) pair; use the same values for both.
        ax.errorbar(x, y,
                    xerr=None if x_err is None else [x_err, x_err],
                    yerr=None if y_err is None else [y_err, y_err],
                    fmt=marker_style)
    else:
        ax.plot(x, y, marker_style, markersize=1)
    if title is not None:
        ax.set_title(title)
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if log_x:
        ax.set_xscale('log')
    if log_y:
        ax.set_yscale('log')
    if plot_range:
        ax.set_xlim((min(plot_range), max(plot_range)))
    if plot_range_y:
        ax.set_ylim((min(plot_range_y), max(plot_range_y)))
    if legend:
        # Pass loc by keyword: a second positional argument is interpreted as
        # `labels` (handles, labels) by matplotlib >= 2.
        ax.legend(legend, loc=0)
    ax.grid(True)
    if not filename:
        # NOTE(review): Figure.show() needs a figure managed by pyplot; with a
        # bare FigureCanvas this may raise -- confirm intended backend.
        fig.show()
    elif isinstance(filename, PdfPages):
        filename.savefig(fig)
    else:
        fig.savefig(filename)
Ejemplo n.º 30
0
def make_xafs_plot(x, y, title, xlabel='Energy (eV)', ylabel='mu', x0=None,
                   ref_mu=None, ref_name=None):
    """Render an XAFS-style curve to PNG and return it base64-encoded.

    :param x, y: data sequences to plot.
    :param title: axes title; ``" (ref_name)"`` is appended when a reference
        curve is drawn.
    :param xlabel, ylabel: axis labels.
    :param x0: when not None, a vertical marker is drawn at x = 0 spanning
        [min(y), max(y)] and the reference curve (if any) is plotted.
        NOTE(review): the marker is at 0, not at x0 -- presumably x is given
        relative to x0; confirm with callers.
    :param ref_mu: optional reference curve (same x), only drawn if x0 is set.
    :param ref_name: label appended to the title for the reference curve.
    :returns: base64-encoded PNG bytes.
    """
    fig = Figure(figsize=(8.5, 5.0), dpi=300)
    FigureCanvas(fig)  # attach a canvas so fig.savefig() can render
    axes = fig.add_axes([0.16, 0.16, 0.75, 0.75])
    # The ``axisbg`` keyword / set_axis_bgcolor were replaced by facecolor in
    # matplotlib 2.0 (and removed in 2.2); support both APIs.
    if hasattr(axes, 'set_facecolor'):
        axes.set_facecolor('#FFFFFF')
    else:
        axes.set_axis_bgcolor('#FFFFFF')

    axes.set_xlabel(xlabel, fontproperties=mpl_lfont)
    axes.set_ylabel(ylabel, fontproperties=mpl_lfont)
    axes.plot(x, y, linewidth=3.5)
    y_lo, y_hi = min(y), max(y)
    if x0 is not None:
        # axvline's ymin/ymax are axes fractions (0..1), not data values; use
        # vlines so the marker spans the data range [min(y), max(y)].
        axes.vlines(0, y_lo, y_hi, colors='#CCBBDD', linewidth=2, zorder=-10)

        if ref_mu is not None:
            axes.plot(x, ref_mu, linewidth=4, zorder=-5, color='#AA4444')
            title = "%s (%s)" % (title, ref_name)

    # Pad both axes by 5% of the data span.
    x_lo, x_hi = min(x), max(x)
    x_span = x_hi - x_lo
    y_span = y_hi - y_lo
    axes.set_xlim((x_lo - x_span * 0.05, x_hi + x_span * 0.05), emit=True)
    axes.set_ylim((y_lo - y_span * 0.05, y_hi + y_span * 0.05), emit=True)
    axes.set_title(title, fontproperties=mpl_lfont)

    figdata = io.BytesIO()
    fig.savefig(figdata, format='png', facecolor='#FDFDFA')
    return base64.b64encode(figdata.getvalue())
Ejemplo n.º 31
0
class CalibrationViewer(QtGui.QMainWindow):
    """Main window for inspecting olfactometer calibration recordings.

    One grid row holds the trial-group list, the trial list, the field filters
    and a two-panel matplotlib figure (PID traces on top, mean value vs.
    concentration below). Trials can be chosen individually, through trial
    groups, or narrowed down with the filters.
    """
    # Connected to _trial_selection_changed in __init__; not emitted anywhere
    # in this class.
    trialsChanged = QtCore.pyqtSignal()

    def __init__(self):
        # Filter widgets, (re)built by build_filters().
        self.filters = list()
        super(CalibrationViewer, self).__init__()
        self.setWindowTitle('Olfa Calibration')
        self.statusBar()  # create the status bar that shows action status tips
        self.trial_selected_list = []
        self.trialsChanged.connect(self._trial_selection_changed)

        mainwidget = QtGui.QWidget(self)
        self.setCentralWidget(mainwidget)
        layout = QtGui.QGridLayout(mainwidget)
        mainwidget.setLayout(layout)

        # --- Menus -----------------------------------------------------------
        menu = self.menuBar()
        filemenu = menu.addMenu("&File")
        toolsmenu = menu.addMenu("&Tools")

        openAction = QtGui.QAction("&Open recording...", self)
        openAction.triggered.connect(self._openAction_triggered)
        openAction.setStatusTip(
            "Open a HDF5 data file with calibration recording.")
        openAction.setShortcut("Ctrl+O")
        filemenu.addAction(openAction)
        saveFigsAction = QtGui.QAction('&Save figures...', self)
        saveFigsAction.triggered.connect(self._saveFiguresAction_triggered)
        saveFigsAction.setShortcut('Ctrl+S')
        # The tip belongs to the save action (setting it on openAction
        # overwrote the open action's tip).
        saveFigsAction.setStatusTip("Saves current figures.")
        filemenu.addAction(saveFigsAction)
        exitAction = QtGui.QAction("&Quit", self)
        exitAction.setShortcut("Ctrl+Q")
        exitAction.setStatusTip("Quit program.")
        exitAction.triggered.connect(QtGui.qApp.quit)
        filemenu.addAction(exitAction)
        removeTrialAction = QtGui.QAction("&Remove trials", self)
        removeTrialAction.setStatusTip(
            'Permanently removes selected trials (bad trials) from trial list.'
        )
        removeTrialAction.triggered.connect(self._remove_trials)
        removeTrialAction.setShortcut('Ctrl+R')
        toolsmenu.addAction(removeTrialAction)

        # --- Column 0: trial groups ------------------------------------------
        trial_group_list_box = QtGui.QGroupBox()
        trial_group_list_box.setTitle('Trial Groups')
        self.trial_group_list = TrialGroupListWidget()
        trial_group_layout = QtGui.QVBoxLayout()
        trial_group_list_box.setLayout(trial_group_layout)
        trial_group_layout.addWidget(self.trial_group_list)
        layout.addWidget(trial_group_list_box, 0, 0)
        self.trial_group_list.itemSelectionChanged.connect(
            self._trial_group_selection_changed)

        # --- Column 1: individual trials -------------------------------------
        trial_select_list_box = QtGui.QGroupBox()
        trial_select_list_box.setMouseTracking(True)
        trial_select_list_layout = QtGui.QVBoxLayout()
        trial_select_list_box.setLayout(trial_select_list_layout)
        trial_select_list_box.setTitle('Trials')
        self.trial_select_list = TrialListWidget()
        self.trial_select_list.setMouseTracking(True)
        trial_select_list_layout.addWidget(self.trial_select_list)
        self.trial_select_list.setSelectionMode(
            QtGui.QAbstractItemView.ExtendedSelection)
        self.trial_select_list.itemSelectionChanged.connect(
            self._trial_selection_changed)
        layout.addWidget(trial_select_list_box, 0, 1)
        self.trial_select_list.createGroupSig.connect(
            self.trial_group_list.create_group)

        # --- Column 2: filters (scrollable) ----------------------------------
        filters_box = QtGui.QGroupBox("Trial filters.")
        filters_box_layout = QtGui.QVBoxLayout(filters_box)
        filters_scroll_area = QtGui.QScrollArea()
        filters_buttons = QtGui.QHBoxLayout()
        filters_all = QtGui.QPushButton('Select all', self)
        filters_all.clicked.connect(self._select_all_filters)
        filters_none = QtGui.QPushButton('Select none', self)
        filters_none.clicked.connect(self._select_none_filters)
        filters_buttons.addWidget(filters_all)
        filters_buttons.addWidget(filters_none)
        filters_box_layout.addLayout(filters_buttons)
        filters_box_layout.addWidget(filters_scroll_area)
        filters_wid = QtGui.QWidget()
        filters_scroll_area.setWidget(filters_wid)
        filters_scroll_area.setWidgetResizable(True)
        filters_scroll_area.setFixedWidth(300)
        self.filters_layout = QtGui.QVBoxLayout()
        filters_wid.setLayout(self.filters_layout)
        layout.addWidget(filters_box, 0, 2)

        # --- Column 3: plots -------------------------------------------------
        plots_box = QtGui.QGroupBox()
        plots_box.setTitle('Plots')
        plots_layout = QtGui.QHBoxLayout()

        self.figure = Figure((9, 5))
        self.figure.patch.set_facecolor('None')
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setParent(plots_box)
        self.canvas.setSizePolicy(QtGui.QSizePolicy.Expanding,
                                  QtGui.QSizePolicy.Expanding)
        plots_layout.addWidget(self.canvas)
        plots_box.setLayout(plots_layout)
        layout.addWidget(plots_box, 0, 3)
        self.ax_pid = self.figure.add_subplot(2, 1, 1)
        self.ax_pid.set_title('PID traces')
        self.ax_pid.set_ylabel('')
        self.ax_pid.set_xlabel('t (ms)')
        self.ax_mean_plots = self.figure.add_subplot(2, 1, 2)
        self.ax_mean_plots.set_title('Mean value')
        self.ax_mean_plots.set_ylabel('value')
        self.ax_mean_plots.set_xlabel('Concentration')
        self.ax_mean_plots.autoscale(enable=True, axis=u'both', tight=False)
        self.figure.tight_layout()

    @QtCore.pyqtSlot()
    def _list_context_menu_trig(self):
        """Placeholder slot; currently does nothing."""
        pass

    @QtCore.pyqtSlot()
    def _filters_changed(self):
        """Apply all filters: hide/deselect excluded trials, then redraw.

        ``self.trial_mask`` becomes the element-wise product of every filter's
        ``trial_mask``.
        """
        mask = np.ones_like(self.trial_mask)
        for flt in self.filters:
            mask *= flt.trial_mask
        self.trial_mask = mask
        # Avoid one redraw per item while updating; a single emit follows.
        self.trial_select_list.itemSelectionChanged.disconnect(
            self._trial_selection_changed)
        for i, keep in enumerate(self.trial_mask):
            it = self.trial_select_list.item(i)
            it.setHidden(not keep)
            # Excluded trials are also deselected.
            it.setSelected(it.isSelected() * keep)
        self.trial_select_list.itemSelectionChanged.connect(
            self._trial_selection_changed)
        # Emit once so the plots are redrawn for the new selection.
        self.trial_select_list.itemSelectionChanged.emit()

    @QtCore.pyqtSlot()
    def _openAction_triggered(self):
        """Let the user pick an HDF5 file and load its trials into the UI."""
        filedialog = QtGui.QFileDialog(self)
        if os.path.exists('D:\\experiment\\raw_data'):
            startpath = 'D:\\experiment\\raw_data\\mouse_o_cal_cw\\sess_001'
        else:
            startpath = ''
        fn = filedialog.getOpenFileName(self, "Select a data file.", startpath,
                                        "HDF5 (*.h5)")
        if not fn:
            print('No file selected.')
            return
        self.data = CalibrationFile(str(fn))
        self.trial_actions = []
        self.trial_select_list.clear()
        trials = self.data.trials
        trial_num_list = []
        for i, _ in enumerate(trials):
            QtGui.QListWidgetItem("Trial {0}".format(i), self.trial_select_list)
            trial_num_list.append(i)
        # Maps list row -> original trial number (rows shift when trials are
        # removed).
        self.trial_select_list.trial_num_list = np.array(trial_num_list)
        self.trial_mask = np.ones(len(self.trial_select_list.trial_num_list),
                                  dtype=bool)
        self.build_filters(trials)

    def build_filters(self, trials):
        """(Re)create one FiltersListWidget per filterable field of *trials*.

        Fields whose names start with 'odorconc', 'olfas' or 'dilutors' get a
        filter, each wrapped in a box with a button that toggles its visibility.

        :param trials: structured array of trials (``trials.dtype.names`` is
            read); the filters themselves are populated from ``self.data.trials``.
        """
        # Empty the layout and delete the old filter boxes: taking an item out
        # of a layout neither deletes nor hides its widget.
        while self.filters_layout.count():
            item = self.filters_layout.takeAt(0)
            widget = item.widget()
            if widget is not None:
                widget.deleteLater()
        for flt in self.filters:
            flt.deleteLater()
        self.filters = list()
        colnames = trials.dtype.names
        if 'odorconc' not in colnames:
            self.error = QtGui.QErrorMessage()
            self.error.showMessage(
                'Data file must have "odorconc" field to allow plotting.')
        start_strings = ('odorconc', 'olfas', 'dilutors')
        filter_fields = [fieldname
                         for ss in start_strings
                         for fieldname in colnames
                         if fieldname.startswith(ss)]

        for field in filter_fields:
            flt = FiltersListWidget(field)
            flt.populate_list(self.data.trials)
            flt.setVisible(False)  # collapsed until its button is clicked
            self.filters.append(flt)
            box = QtGui.QWidget()
            box.setSizePolicy(0, 0)
            show_button = QtGui.QPushButton(flt.fieldname)
            show_button.setStyleSheet('text-align:left; border:0px')
            show_button.clicked.connect(flt.toggle_visible)
            _filt_layout = QtGui.QVBoxLayout(box)
            _filt_layout.addWidget(show_button)
            _filt_layout.addWidget(flt)
            _filt_layout.setSpacing(0)
            self.filters_layout.addWidget(box)
            flt.filterChanged.connect(self._filters_changed)
        for flt in self.filters:
            assert isinstance(flt, FiltersListWidget)

        self.filters_layout.addStretch()
        self.filters_layout.setSpacing(0)

    @QtCore.pyqtSlot()
    def _saveFiguresAction_triggered(self):
        """Ask for a destination and save the current figure there."""
        self.saveDialog = QtGui.QFileDialog()
        saveloc = self.saveDialog.getSaveFileName(
            self, 'Save figure', '', 'PDF (*.pdf);;JPEG (*.jpg);;TIFF (*.tif)')
        saveloc = str(saveloc)
        if not saveloc:
            # Dialog cancelled: savefig('') would raise.
            return
        self.figure.savefig(saveloc)

    @QtCore.pyqtSlot()
    def _remove_trials(self):
        """Permanently remove the selected trials from list, mask and filters."""
        remove_idxes = [index.row()
                        for index in self.trial_select_list.selectedIndexes()]
        # takeItem() shifts rows, so re-query the selection after each removal.
        while self.trial_select_list.selectedIndexes():
            row = self.trial_select_list.selectedIndexes()[0].row()
            self.trial_select_list.takeItem(row)
        old_trial_nums = self.trial_select_list.trial_num_list
        remove_set = set(remove_idxes)
        # np.int was only an alias of int and no longer exists in NumPy >= 1.24.
        new_trials_array = np.zeros(len(old_trial_nums) - len(remove_idxes),
                                    dtype=int)
        new_trials_mask = np.zeros_like(new_trials_array, dtype=bool)
        remove_trialnums = []
        ii = 0
        for i, trial_num in enumerate(old_trial_nums):
            if i in remove_set:
                remove_trialnums.append(trial_num)
            else:
                new_trials_mask[ii] = self.trial_mask[i]
                new_trials_array[ii] = trial_num
                ii += 1
        self.trial_mask = new_trials_mask
        self.trial_select_list.trial_num_list = new_trials_array

        for flt in self.filters:
            flt.remove_trials(remove_idxes)
        self.trial_group_list._remove_trials(remove_trialnums)

    @QtCore.pyqtSlot()
    def _trial_selection_changed(self):
        """Redraw for the selected trials and sync the group selection."""
        selected_trial_nums = [
            self.trial_select_list.trial_num_list[index.row()]
            for index in self.trial_select_list.selectedIndexes()]
        self.update_plots(selected_trial_nums)
        # Select exactly the groups whose trials are all selected. Signals are
        # blocked so this does not re-enter _trial_group_selection_changed.
        self.trial_group_list.blockSignals(True)
        for i, g in zip(xrange(self.trial_group_list.count()),
                        self.trial_group_list.trial_groups):
            it = self.trial_group_list.item(i)
            it.setSelected(
                all(t in selected_trial_nums for t in g['trial_nums']))
        self.trial_group_list.blockSignals(False)

    @QtCore.pyqtSlot()
    def _trial_group_selection_changed(self):
        """Select every trial belonging to the selected groups, then redraw."""
        selected_idxes = self.trial_group_list.selectedIndexes()
        self._select_all_filters()
        selected_trial_nums = []
        for index in selected_idxes:
            group = self.trial_group_list.trial_groups[index.row()]
            selected_trial_nums.extend(group['trial_nums'])
        self.trial_select_list.blockSignals(True)
        for i in range(self.trial_select_list.count()):
            item = self.trial_select_list.item(i)
            self.trial_select_list.setItemSelected(item, False)
        for trial_num in selected_trial_nums:
            row = np.where(self.trial_select_list.trial_num_list == trial_num)[0][0]
            it = self.trial_select_list.item(row)
            if not it.isSelected():
                it.setSelected(True)
        self.trial_select_list.blockSignals(False)
        self._trial_selection_changed()

    def update_plots(self, trials):
        """Redraw both axes for the given trial numbers.

        Top axes: each trial's detrended 'sniff' stream plus, for groups whose
        trials share a single concentration, the group-mean stream (black).
        Bottom axes: per-trial mean response vs. 'odorconc', plus a linear
        regression line per group with at least two distinct concentrations.

        :param trials: sequence of trial numbers (may be empty).
        """
        # TODO: make this changeable - ms of stream extracted before/after trial.
        padding = (2000, 2000)
        trial_streams = []
        trial_colors = []
        # Axes.lines is an immutable ArtistList in matplotlib >= 3.5, so remove
        # each line via Artist.remove() instead of popping from the list.
        for ax in (self.ax_pid, self.ax_mean_plots):
            for line in list(ax.lines):
                line.remove()
        groups_by_trial = []
        all_groups = set()
        ntrials = len(trials)
        vals = np.empty(ntrials)
        concs = np.empty_like(vals)
        if trials:
            alpha = max([1. / len(trials), .25])
            for i, tn in enumerate(trials):
                color = self.trial_group_list.get_trial_color(tn)
                groups = self.trial_group_list.get_trial_groups(tn)
                trial_colors.append(color)
                groups_by_trial.append(groups)
                all_groups.update(groups)
                trial = self.data.return_trial(tn, padding=padding)
                stream = remove_stream_trend(trial.streams['sniff'],
                                             (0, padding[0]))
                stream -= stream[0:padding[0]].min()
                # TODO: remove baseline (N2) trial average from this.
                trial_streams.append(stream)
                self.ax_pid.plot(stream, color=color, alpha=alpha)
                conc = trial.trials['odorconc']
                # Response = mean of samples 3000-4000 minus mean of the first
                # 2000 samples (the pre-trial padding).
                baseline = np.mean(stream[:2000])
                val = np.mean(stream[3000:4000]) - baseline
                vals[i] = val
                concs[i] = conc
                self.ax_mean_plots.plot(conc, val, '.', color=color)
        # Truncate all streams to the shortest one so they stack into an array.
        minlen = min(len(s) for s in trial_streams) if trial_streams else 0
        streams_array = np.empty((ntrials, minlen))
        for i, stream in enumerate(trial_streams):
            streams_array[i, :] = stream[:minlen]
        for g in all_groups:
            mask = np.array([g in groups for groups in groups_by_trial],
                            dtype=bool)
            c = concs[mask]
            groupstreams = streams_array[mask]
            if len(np.unique(c)) < 2:
                self.ax_pid.plot(groupstreams.mean(axis=0),
                                 color='k',
                                 linewidth=2)
            else:
                v = vals[mask]
                slope, intercept, _, _, _ = stats.linregress(c, v)
                color = self.trial_group_list.get_group_color(g)
                minn, maxx = self.ax_mean_plots.get_xlim()
                x = np.array([minn, maxx])
                self.ax_mean_plots.plot(x, slope * x + intercept, color=color)
        self.ax_pid.relim()
        self.ax_mean_plots.set_yscale('log')
        self.ax_mean_plots.set_xscale('log')
        self.ax_mean_plots.relim()

        self.canvas.draw()

    @QtCore.pyqtSlot()
    def _select_none_filters(self):
        """Clear every filter's selection, then apply the filters once."""
        for flt in self.filters:
            flt.filterChanged.disconnect(self._filters_changed)
            flt.clearSelection()
            flt.filterChanged.connect(self._filters_changed)
        self._filters_changed()

    @QtCore.pyqtSlot()
    def _select_all_filters(self):
        """Select every entry of every filter, then apply the filters once."""
        for flt in self.filters:
            flt.filterChanged.disconnect(self._filters_changed)
            flt.selectAll()
            flt.filterChanged.connect(self._filters_changed)
        self._filters_changed()
Ejemplo n.º 32
0
class Plotter(QMainWindow):
    """Window overlaying detection-probability vs. time curves of spectra.

    Each loaded spectrum contributes a (time, probability) curve. The slider
    sets a detection-probability threshold; each curve's legend entry reports
    the time at which it reaches that threshold (linear interpolation).
    """
    # Colour+marker format strings, cycled for successive curves.
    _MARKERS = ['c+', 'rx', 'k*', 'm1', 'y.', 'g8', 'b2', 'mh', 'co', 'k+']

    def __init__(self):
        super().__init__()
        # Per-window state. As class attributes these were shared by every
        # Plotter instance.
        self.legends = []      # legend strings of the current plot
        self.base_names = []   # spectrum names in load order
        self.database = {}     # name -> [time, probability]
        self.size_policy = QSizePolicy.Expanding
        # NOTE: this attribute shadows QWidget.font(); name kept for callers.
        self.font = QFont()
        self.font.setPointSize(12)
        self.setWindowTitle('Calibrated Spectrum Viewer')
        self.menu()
        self.geometry()
        self.showMaximized()
        self.show()
        self.figure.tight_layout()

    def menu(self):
        """Create the File menu (load / save image / clear)."""
        self.menuFile = self.menuBar().addMenu('&File')
        self.load_new = QAction('&Load New Spectrum')
        self.load_new.triggered.connect(self.new_graph)
        self.load_new.setShortcut('Ctrl+O')
        self.load_new.setToolTip('Load a new calibrated spectrum')

        self.save_figure = QAction('&Save Spectrum Image')
        self.save_figure.triggered.connect(self.save_fig)
        self.save_figure.setShortcut('Ctrl+S')
        self.save_figure.setEnabled(False)

        self.clear = QAction('&Clear Graph')
        self.clear.triggered.connect(self.clear_graph)
        self.clear.setEnabled(False)

        self.menuFile.addActions([self.load_new, self.save_figure, self.clear])

    def geometry(self):
        """Build the threshold slider dock and the central plot area.

        NOTE: this overrides QWidget.geometry() (which returns a QRect); the
        name is kept for compatibility with existing callers.
        """
        self.prop_label = QLabel('Detection Probability: 99%')
        self.prop_label.setSizePolicy(self.size_policy, self.size_policy)
        self.prop_label.setFont(self.font)

        self.prob = QSlider(Qt.Horizontal)
        self.prob.setSizePolicy(self.size_policy, self.size_policy)
        self.prob.setFont(self.font)
        self.prob.setMinimum(0)
        self.prob.setMaximum(100)
        self.prob.setSingleStep(1)
        self.prob.setValue(99)
        self.prob.setTickInterval(10)
        self.prob.setTickPosition(QSlider.TicksBelow)
        self.prob.valueChanged.connect(self.label_update)

        self.plot_window = QWidget()
        layout = QVBoxLayout()
        self.figure = Figure()
        self._canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self._canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self._canvas)
        self.plot_window.setLayout(layout)

        topper = QDockWidget('Adjustments')

        horwidget = QWidget()
        hlayout = QHBoxLayout()
        hlayout.addWidget(self.prop_label)
        hlayout.addWidget(self.prob)
        horwidget.setLayout(hlayout)
        topper.setWidget(horwidget)
        self.addDockWidget(Qt.TopDockWidgetArea, topper)

        self.setCentralWidget(self.plot_window)
        self.static_ax = self._canvas.figure.subplots()
        self._reset_axes()

    def _reset_axes(self):
        """Apply the fixed x range and the axis labels to the plot axes."""
        self.static_ax.set_xlim(0, 1800)
        self.static_ax.set_xlabel('Time [s]')
        self.static_ax.set_ylabel('Probability of Detection [%]')

    def new_graph(self):
        """Open the Load_New dialog; its Add button triggers new_()."""
        self.adder = Load_New()
        self.adder.add.clicked.connect(self.new_)

    def new_(self):
        """Store the spectrum prepared by the dialog and redraw."""
        self.adder.process()

        time, probability = self.adder.time, self.adder.probability
        # Re-loading an existing name replaces its data instead of plotting
        # the same name twice.
        if self.adder.name not in self.database:
            self.base_names.append(self.adder.name)
        self.database[self.adder.name] = [time, probability]
        self.replot()

    def replot(self):
        """Redraw all stored curves with their time-to-threshold legend."""
        self.static_ax.clear()
        self._reset_axes()
        eva = int(self.prob.value())
        self.legends = []
        for name in self.base_names:
            time, probability = self.database[name]
            try:
                # np.interp assumes `probability` is increasing.
                detect_time = np.interp(eva, probability, time)
            except (ValueError, TypeError):
                # Best effort: skip curves that cannot be interpolated
                # (e.g. empty or length-mismatched data).
                continue
            # Cycle the markers so curves beyond the tenth are still drawn.
            marker = self._MARKERS[len(self.legends) % len(self._MARKERS)]
            self.legends.append('{}: {:.2f}s'.format(name, detect_time))
            self.static_ax.plot(
                time, probability, marker,
                label='{}: Time to {}% detection probability: {:.2f}s'.format(
                    name, eva, detect_time))
        self.static_ax.axhline(eva, color='r', linestyle='--', linewidth=0.5)
        self.static_ax.legend(prop={'size': 18})
        self._canvas.draw()
        self.clear.setEnabled(True)
        self.save_figure.setEnabled(True)

    def save_fig(self):
        """Ask for a PNG/JPEG path and save the figure at 600 dpi."""
        options = ('Portable Network Graphics (*.png);;'
                   'Joint Photographic Experts Group(*.jpg)')
        file_name, selected_filter = QFileDialog.getSaveFileName(
            self, 'Image Save', "", options)

        if file_name and selected_filter:
            # savefig() has no `figsize` parameter; recent matplotlib versions
            # raise on unknown keywords.
            self.figure.savefig(file_name, dpi=600)

    def clear_graph(self):
        """Forget all loaded spectra and reset the axes to an empty plot."""
        # Drop the stored curves as well; otherwise the next replot (e.g. a
        # slider move) would draw them all again.
        self.base_names.clear()
        self.database.clear()
        self.legends = []
        self.static_ax.clear()
        self._reset_axes()
        self.static_ax.set_ylim(0, 100)
        self._canvas.draw()
        self.clear.setEnabled(False)
        self.save_figure.setEnabled(False)

    def label_update(self):
        """Show the slider's threshold in the label and redraw."""
        val = self.prob.value()
        self.prop_label.setText('Detection Probability: {}%'.format(val))
        self.replot()
Ejemplo n.º 33
0
class MplCanvas(FigureCanvas):
    """Qt widget embedding a single-axes matplotlib figure for the plotter."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """
        Create the figure with one subplot and embed it in *parent*.

        :param parent: the QWidget parent (may be None)
        :param width: figure width (matplotlib figsize, inches)
        :param height: figure height (matplotlib figsize, inches)
        :param dpi: figure resolution in dots per inch
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.fig = figure
        self.axes = figure.add_subplot(111)

        FigureCanvas.__init__(self, figure)
        self.setParent(parent)

        grow = QtWidgets.QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, grow, grow)
        FigureCanvas.updateGeometry(self)

    def update_plot(
            self,
            data,
            x_axis,
            y_axis,
            color,
            marker,
            style,
            x_scale,
            y_scale,
            linewidth,
            font_title,
            font_axes):
        """
        Clear the axes and redraw *y_axis* against *x_axis* from *data*.

        Nothing but the grid is drawn when either key is missing from *data*.
        Axis labels get a "(unit)" suffix and the title becomes "Y~X" when the
        names are listed in ``gopem.helper.UnitTable``.

        :param data: dictionary mapping series names to data points
        :param x_axis: key of the series for the X axis
        :param y_axis: key of the series for the Y axis
        :param color: line color name
        :param marker: data marker
        :param style: line style
        :param x_scale: x-axis scale
        :param y_scale: y-axis scale
        :param linewidth: line width
        :param font_title: title font size
        :param font_axes: axis-label and tick font size
        :return: None
        """
        ax = self.axes
        ax.cla()
        ax.grid(True, linestyle='-.', which='both')
        if x_axis not in data or y_axis not in data:
            self.draw()
            return

        ax.plot(
            data[x_axis],
            data[y_axis],
            color=color,
            marker=marker,
            linestyle=style,
            linewidth=linewidth)

        units = gopem.helper.UnitTable
        y_known = y_axis in units
        x_known = x_axis in units

        if y_known and units[y_axis][1] is not None:
            y_unit = "({0})".format(units[y_axis][1])
        else:
            y_unit = ""
        if x_known and x_axis in units and units[x_axis][1] is not None:
            x_unit = "({0})".format(units[x_axis][1])
        else:
            x_unit = ""

        # A title is only shown when both axes have a UnitTable entry.
        if y_known and x_known:
            ax.set_title(units[y_axis][0] + "~" + units[x_axis][0],
                         fontsize=font_title)

        ax.set_xlabel(x_axis + x_unit, fontsize=font_axes)
        ax.set_ylabel(y_axis + y_unit, fontsize=font_axes)
        ax.set_yscale(y_scale)
        ax.set_xscale(x_scale)
        ax.tick_params(labelsize=font_axes)

        self.draw()

    def save_fig(self, filename, transparent):
        """
        Write the figure to *filename*.

        :param filename: output file name
        :param transparent: passed to savefig's ``transparent`` option
        :return: None
        """
        self.fig.savefig(filename, transparent=transparent)
Ejemplo n.º 34
0
class PlotCanvas(FigureCanvas):
    """Qt canvas showing ``linenumber`` marker-only data series on one axes."""

    # Default per-series styles; cycled when more series are requested.
    _COLORS = ('purple', 'red', 'green', 'orange', 'yellow')
    _MARKERS = ('o', 'o', '*', 'o', 'o')
    _DEFAULT_SIZE = 2

    def __init__(self,
                 linenumber,
                 ownsizes=None,
                 parent=None,
                 width=5,
                 height=4,
                 dpi=100):
        """Create the figure and one Line2D per series.

        :param linenumber: number of series; stored as ``ldict['line<i>']``.
        :param ownsizes: optional sequence of marker sizes, indexed per series
            (must have at least ``linenumber`` entries); default size is 2.
        :param parent: Qt parent widget.
        :param width, height, dpi: figure size and resolution.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)
        super().__init__(self.fig)
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self, Wid.QSizePolicy.Expanding,
                                   Wid.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        self.ldict = {}
        for i in range(linenumber):
            size = self._DEFAULT_SIZE if ownsizes is None else ownsizes[i]
            self.ldict['line' + str(i)] = lines.Line2D(
                [], [],
                marker=self._MARKERS[i % len(self._MARKERS)],
                color=self._COLORS[i % len(self._COLORS)],
                linestyle='',
                markersize=size)

        # Configure the axes before the first draw so grid/labels show at once.
        self.axes.grid()
        self.axes.set_ylabel('y')
        self.axes.set_xlabel('x')
        self.draw()

    def scalePlot(self, scale='linear'):
        """Set the y-axis scale (e.g. 'linear' or 'log')."""
        self.axes.set_yscale(scale)

    def limitPlot(self, xvalues, yvalues):
        """Set the x/y axis limits and redraw."""
        self.axes.set_xlim(xvalues)
        self.axes.set_ylim(yvalues)
        self.fig.canvas.draw()
        self.flush_events()

    def getData(self):
        """Print the (x, y) data of every series."""
        for key in self.ldict:
            print(self.ldict[key].get_data())

    def updatePlot(self, *args):
        """Replace series data and redraw.

        :param args: one ``(x, y)`` pair per series, in ``ldict`` order; extra
            series without a matching argument keep their data.
        """
        for data, key in zip(args, self.ldict):
            line = self.ldict[key]
            line.set_data(data)
            # Attach each line only once: calling add_line on every update
            # appended duplicate references of the same line to the axes.
            if line.axes is None:
                self.axes.add_line(line)
        # Recompute the data limits from the current data of all lines.
        self.axes.relim()
        self.fig.canvas.draw()
        self.flush_events()

    def updatexyLabels(self, xlabel, ylabel):
        """Set the x and y axis labels."""
        self.axes.set_xlabel(xlabel)
        self.axes.set_ylabel(ylabel)

    def savePlot(self, name):
        """Save the figure to *name*."""
        self.fig.savefig(name)
Ejemplo n.º 35
0
                 label=True,
                 linewidths=2)
# Contour the original and the modified pressure fields with identical styling.
for _panel, _pressure in ((ax_old, pm_original), (ax_new, pm_modified)):
    mg.pressure.plot(_panel,
                     _pressure,
                     scale=0.01,
                     resolution=0.25,
                     levels=numpy.arange(870, 1050, 7),
                     colors='blue',
                     label=True,
                     linewidths=2)

# Date/hour of the data shown, used for the label and the output file name.
_when = (args.year, args.month, args.day, args.hour)

# Mark the data used
mg.utils.plot_label(ax_new,
                    '%04d-%02d-%02d:%02d' % _when,
                    facecolor=fig.get_facecolor(),
                    x_fraction=0.98,
                    horizontalalignment='right')

# Render the figure as a png
fig.savefig("comparison_%04d-%02d-%02d:%02d.png" % _when)
Ejemplo n.º 36
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.

    The figure holds two axes sharing the X axis. ``self.ax`` (label
    "left") carries the left Y axis, images, items and markers.
    ``self.ax2`` (created with ``twinx``, label "right") carries curves
    attached to the right Y axis.
    """

    def __init__(self, plot, parent=None):
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True
        self._axesDisplayed = True
        self._matplotlibVersion = _parse_version(matplotlib.__version__)

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # disable the use of offsets
        try:
            self.ax.get_yaxis().get_major_formatter().set_useOffset(False)
            self.ax.get_xaxis().get_major_formatter().set_useOffset(False)
            self.ax2.get_yaxis().get_major_formatter().set_useOffset(False)
            self.ax2.get_xaxis().get_major_formatter().set_useOffset(False)
        except Exception:  # best effort: depends on the matplotlib version
            _logger.warning('Cannot disabled axes offsets in %s ',
                            matplotlib.__version__)

        # critical for picking!!!!
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        if self._matplotlibVersion < _parse_version('2'):
            self.ax.set_axis_bgcolor('none')
        else:
            self.ax.set_facecolor('none')
        self.fig.sca(self.ax)

        self._overlays = set()  # animated artists/containers (overlays)
        self._background = None

        self._colormaps = {}

        # Empty tuple when no cursor, else (horizontal line, vertical line)
        self._graphCursor = tuple()

        self._enableAxis('right', False)
        self._isXAxisTimeSeries = False

    # Add methods

    def addCurve(self, x, y, legend,
                 color, symbol, linewidth, linestyle,
                 yaxis,
                 xerror, yerror, z, selectable,
                 fill, alpha, symbolsize):
        """Add a curve to the left or right Y axes.

        If ``color`` is an array with one color per point (same length as
        ``x``), the points are drawn with ``Axes.scatter``. A plain line in
        the first color is added when ``linestyle`` is set. Otherwise a
        regular ``Axes.plot`` line is used.

        :param x: X coordinates.
        :param y: Y coordinates.
        :param str legend: Label given to the created artists.
        :param color: One color, or an array of per-point colors. A
            4-sequence whose last item is an int/uint8/int8 is treated as
            0-255 RGBA and rescaled to [0, 1]. Per-point arrays whose dtype
            is neither float32 nor float64 are divided by 255.
        :param symbol: matplotlib marker.
        :param linewidth: Line width.
        :param linestyle: matplotlib line style.
        :param str yaxis: 'left' or 'right'.
        :param xerror: X error bars, or None.
        :param yerror: Y error bars, or None.
        :param z: zorder applied to every created artist.
        :param bool selectable: If True, artists are made pickable
            (picker tolerance of 3).
        :param bool fill: Fill the area between FLOAT32_MINPOS and the curve.
        :param float alpha: Opacity applied to every artist when < 1.
        :param symbolsize: Marker size (squared for scatter marker area).
        :returns: A Container grouping all created artists.
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill, alpha, symbolsize):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        if (len(color) == 4 and
                type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            # numpy.float (removed in NumPy 1.24) was an alias of float64
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
                    yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x, y, label=legend,
                                      xerr=xerror, yerr=yerror,
                                      linestyle=' ', color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x, y, label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x, y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker,
                                   s=symbolsize**2)
            artists.append(scatter)

            if fill:
                artists.append(axes.fill_between(
                    x, FLOAT32_MINPOS, y, facecolor=actualColor[0], linestyle=''))

        else:  # Curve
            curveList = axes.plot(x, y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker,
                                  markersize=symbolsize)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color))

        for artist in artists:
            artist.set_zorder(z)
            if alpha < 1:
                artist.set_alpha(alpha)

        return Container(artists)

    def addImage(self, data, legend,
                 origin, scale, z,
                 selectable, draggable,
                 colormap, alpha):
        """Add an image to the left axes, always displayed as RGBA.

        :param data: 2D data (converted with ``colormap.applyToData``) or
            an already colored image.
        :param str legend: Image legend (stored as "__IMAGE__" + legend).
        :param origin: (x, y) data coordinates of the first pixel corner.
        :param scale: (sx, sy) pixel size. A negative scale flips the
            data along that axis.
        :param z: zorder of the image.
        :param bool selectable: Image is pickable if selectable or draggable.
        :param bool draggable: See ``selectable``.
        :param colormap: Colormap used for 2D data.
        :param float alpha: Opacity applied when < 1.
        :returns: The created image artist.
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, legend, origin, scale, z,
                          selectable, draggable):
            assert parameter is not None

        origin = float(origin[0]), float(origin[1])
        scale = float(scale[0]), float(scale[1])
        height, width = data.shape[0:2]

        picker = (selectable or draggable)

        # All image are shown as RGBA image
        image = Image(self.ax,
                      label="__IMAGE__" + legend,
                      interpolation='nearest',
                      picker=picker,
                      zorder=z,
                      origin='lower')

        if alpha < 1:
            image.set_alpha(alpha)

        # Set image extent
        xmin = origin[0]
        xmax = xmin + scale[0] * width
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin

        ymin = origin[1]
        ymax = ymin + scale[1] * height
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin

        image.set_extent((xmin, xmax, ymin, ymax))

        # Set image data
        if scale[0] < 0. or scale[1] < 0.:
            # For negative scale, step by -1
            xstep = 1 if scale[0] >= 0. else -1
            ystep = 1 if scale[1] >= 0. else -1
            data = data[::ystep, ::xstep]

        if data.ndim == 2:  # Data image, convert to RGBA image
            data = colormap.applyToData(data)

        image.set_data(data)
        self.ax.add_artist(image)
        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z,
                linestyle, linewidth, linebgcolor):
        """Add a shape item to the left axes.

        :param x: X coordinates of the shape.
        :param y: Y coordinates of the shape.
        :param str legend: Item label.
        :param str shape: One of 'line', 'hline', 'vline', 'rectangle',
            'polygon', 'polylines'.
        :param color: Item color.
        :param bool fill: Hatch rectangles ('.') and polygons ('/').
        :param bool overlay: If True, the item is animated and tracked as
            an overlay.
        :param z: zorder of the item.
        :param linestyle: Line style (normalized by ``normalize_linestyle``).
        :param linewidth: Line width.
        :param linebgcolor: Background color of dashed lines. Only used for
            rectangle, polygon and polylines.
        :returns: The created artist.
        :raises NotImplementedError: For an unsupported shape.
        """
        if (linebgcolor is not None and
                shape not in ('rectangle', 'polygon', 'polylines')):
            _logger.warning(
                'linebgcolor not implemented for %s with matplotlib backend',
                shape)
        # asarray avoids copies like copy=False did, but unlike
        # numpy.array(..., copy=False) it does not raise on NumPy >= 2 when
        # a copy is required (e.g. list input).
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        linestyle = normalize_linestyle(linestyle)

        if shape == "line":
            item = self.ax.plot(x, y, label=legend, color=color,
                                linestyle=linestyle, linewidth=linewidth,
                                marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color,
                             linestyle=linestyle,
                             linewidth=linewidth)
            if fill:
                item.set_hatch('.')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            points = numpy.array((xView, yView)).T
            if shape == 'polygon':
                closed = True
            else:  # shape == 'polylines'
                closed = numpy.all(numpy.equal(points[0], points[-1]))
            item = Polygon(points,
                           closed=closed,
                           fill=False,
                           label=legend,
                           color=color,
                           linestyle=linestyle,
                           linewidth=linewidth)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color,
                  selectable, draggable,
                  symbol, linestyle, linewidth, constraint):
        """Add a point, vertical or horizontal marker with optional text.

        A point marker is used when both ``x`` and ``y`` are given, a
        vertical line when only ``x`` is given, and a horizontal line when
        only ``y`` is given. Markers are always animated overlays.

        :returns: A ``_MarkerContainer`` holding the line and text artists.
        :raises RuntimeError: If both ``x`` and ``y`` are None.
        """
        legend = "__MARKER__" + legend

        textArtist = None

        xmin, xmax = self.getGraphXLimits()
        ymin, ymax = self.getGraphYLimits(axis='left')

        if x is not None and y is not None:
            line = self.ax.plot(x, y, label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                textArtist = self.ax.text(x, y, text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x,
                                   label=legend,
                                   color=color,
                                   linewidth=linewidth,
                                   linestyle=linestyle)
            if text is not None:
                # Y position will be updated in updateMarkerText call
                textArtist = self.ax.text(x, 1., " " + text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y,
                                   label=legend,
                                   color=color,
                                   linewidth=linewidth,
                                   linestyle=linestyle)

            if text is not None:
                # X position will be updated in updateMarkerText call
                textArtist = self.ax.text(1., y, " " + text,
                                          color=color,
                                          horizontalalignment='right',
                                          verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        # All markers are overlays
        line.set_animated(True)
        if textArtist is not None:
            textArtist.set_animated(True)

        artists = [line] if textArtist is None else [line, textArtist]
        container = _MarkerContainer(artists, x, y)
        container.updateMarkerText(xmin, xmax, ymin, ymax)
        self._overlays.add(container)

        return container

    def _updateMarkers(self):
        """Reposition marker texts according to the current left axes bounds."""
        xmin, xmax = self.ax.get_xbound()
        ymin, ymax = self.ax.get_ybound()
        for item in self._overlays:
            if isinstance(item, _MarkerContainer):
                item.updateMarkerText(xmin, xmax, ymin, ymax)

    # Remove methods

    def remove(self, item):
        """Remove an item and forget it as an overlay.

        A ValueError from ``item.remove()`` is ignored because the item may
        already have been removed.
        """
        # Warning: It also needs to remove extra stuff if added as for markers
        self._overlays.discard(item)
        try:
            item.remove()
        except ValueError:
            pass  # Already removed e.g., in set[X|Y]AxisLogarithmic

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Create or remove the cross-hair cursor lines.

        :param bool flag: True to create hidden, animated horizontal and
            vertical lines on the left axes. False to remove them if any.
        :param color: Cursor line color.
        :param linewidth: Cursor line width.
        :param linestyle: Cursor line style.
        """
        if flag:
            lineh = self.ax.axhline(
                self.ax.get_ybound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(
                self.ax.get_xbound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev
        else:
            # _graphCursor is an empty tuple when no cursor exists. Testing
            # `is not None` was always true, so unpacking raised ValueError
            # when disabling a cursor that was never created.
            if self._graphCursor:
                lineh, linev = self._graphCursor
                lineh.remove()
                linev.remove()
                self._graphCursor = tuple()

    # Active curve

    def setCurveColor(self, curve, color):
        """Recolor the artists composing a curve container.

        Line2D and LineCollection get ``set_color``. PathCollection (scatter)
        gets both face and edge colors. Any other artist is left unchanged
        and a warning is logged.
        """
        for artist in curve.get_children():
            if isinstance(artist, (Line2D, LineCollection)):
                artist.set_color(color)
            elif isinstance(artist, PathCollection):
                artist.set_facecolors(color)
                artist.set_edgecolors(color)
            else:
                _logger.warning(
                    'setActiveCurve ignoring artist %s', str(artist))

    # Misc.

    def getWidgetHandle(self):
        """Return the canvas attached to the figure."""
        return self.fig.canvas

    def _enableAxis(self, axis, flag=True):
        """Show/hide Y axis

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Default, True
        """
        assert axis in ('right', 'left')
        axes = self.ax2 if axis == 'right' else self.ax
        axes.get_yaxis().set_visible(flag)

    def replot(self):
        """Do not perform rendering.

        Override in subclass to actually draw something.
        This base implementation clears the dirty-limits flag and hides the
        right Y axis when it holds no line.
        """
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only support curve for now
        # Hide right Y axis if no line is present
        self._dirtyLimits = False
        if not self.ax2.lines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure and mark the plot dirty.

        :param fileName: Path or file-like object.
        :param str fileFormat: Format passed to ``savefig``.
        :param dpi: Resolution, or None for the matplotlib default.
        """
        # fileName can be also a StringIO or file instance
        if dpi is not None:
            self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
        else:
            self.fig.savefig(fileName, format=fileFormat)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        """Set the title of the left axes."""
        self.ax.set_title(title)

    def setGraphXLabel(self, label):
        """Set the X axis label."""
        self.ax.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        """Set the Y label of the 'left' axes, otherwise of the right axes."""
        axes = self.ax if axis == 'left' else self.ax2
        axes.set_ylabel(label)

    # Graph limits

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and optionally right Y limits.

        Bounds are reordered so that Y limits follow the Y axis inversion.
        """
        # Let matplotlib taking care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        if y2min is not None and y2max is not None:
            if not self.isYAxisInverted():
                self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
            else:
                self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))

        if not self.isYAxisInverted():
            self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
        else:
            self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))

        self._updateMarkers()

    def getGraphXLimits(self):
        """Return (xmin, xmax), replotting first if aspect ratio may alter them."""
        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        """Set the X limits (in increasing order)."""
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
        self._updateMarkers()

    def getGraphYLimits(self, axis):
        """Return (ymin, ymax) of the given axis, or None if it is hidden.

        :param str axis: 'left' or 'right'
        """
        assert axis in ('left', 'right')
        ax = self.ax2 if axis == 'right' else self.ax

        if not ax.get_visible():
            return None

        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits

        return ax.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        """Set the Y limits of the given axis, respecting aspect ratio/inversion."""
        ax = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of shared axis when keeping aspect ratio
            # So x limits are kept when changing y limits....
            # Change x limits first by taking into account aspect ratio
            # and then change y limits.. so matplotlib does not need
            # to make change (to y) to keep aspect ratio
            xmin, xmax = ax.get_xbound()
            curYMin, curYMax = ax.get_ybound()

            newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
            xcenter = 0.5 * (xmin + xmax)
            ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)

        if not self.isYAxisInverted():
            ax.set_ylim(ymin, ymax)
        else:
            ax.set_ylim(ymax, ymin)

        self._updateMarkers()

    # Graph axes

    def setXAxisTimeZone(self, tz):
        """Store the time zone and rebuild the X axis locator/formatter."""
        super(BackendMatplotlib, self).setXAxisTimeZone(tz)

        # Make new formatter and locator with the time zone.
        self.setXAxisTimeSeries(self.isXAxisTimeSeries())

    def isXAxisTimeSeries(self):
        """Return True if the X axis is displayed as time series."""
        return self._isXAxisTimeSeries

    def setXAxisTimeSeries(self, isTimeSeries):
        """Switch the X axis between date ticks and plain scalar ticks."""
        self._isXAxisTimeSeries = isTimeSeries
        if self._isXAxisTimeSeries:
            # We can't use a matplotlib.dates.DateFormatter because it expects
            # the data to be in datetimes. Silx works internally with
            # timestamps (floats).
            locator = NiceDateLocator(tz=self.getXAxisTimeZone())
            self.ax.xaxis.set_major_locator(locator)
            self.ax.xaxis.set_major_formatter(
                NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
        else:
            try:
                scalarFormatter = ScalarFormatter(useOffset=False)
            except Exception:  # fall back to default formatter
                _logger.warning('Cannot disabled axes offsets in %s ',
                                matplotlib.__version__)
                scalarFormatter = ScalarFormatter()
            self.ax.xaxis.set_major_formatter(scalarFormatter)

    def setXAxisLogarithmic(self, flag):
        """Switch both axes' X scale between 'log' and 'linear'."""
        # Workaround for matplotlib 2.1.0 when one tries to set an axis
        # to log scale with both limits <= 0
        # In this case a draw with positive limits is needed first
        if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
            xlim = self.ax.get_xlim()
            if xlim[0] <= 0 and xlim[1] <= 0:
                self.ax.set_xlim(1, 10)
                self.draw()

        self.ax2.set_xscale('log' if flag else 'linear')
        self.ax.set_xscale('log' if flag else 'linear')

    def setYAxisLogarithmic(self, flag):
        """Switch both axes' Y scale between 'log' and 'linear'."""
        # Workaround for matplotlib 2.0 issue with negative bounds
        # before switching to log scale
        if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
            redraw = False
            for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
                ylim = axis.get_ylim()
                if ylim[0] <= 0 or ylim[1] <= 0:
                    dataRange = self._plot.getDataRange()[dataRangeIndex]
                    if dataRange is None:
                        dataRange = 1, 100  # Fallback
                    axis.set_ylim(*dataRange)
                    redraw = True
            if redraw:
                self.draw()

        self.ax2.set_yscale('log' if flag else 'linear')
        self.ax.set_yscale('log' if flag else 'linear')

    def setYAxisInverted(self, flag):
        """Invert the left Y axis if its state differs from ``flag``."""
        if self.ax.yaxis_inverted() != bool(flag):
            self.ax.invert_yaxis()

    def isYAxisInverted(self):
        """Return True if the left Y axis is inverted."""
        return self.ax.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        """Return True if the left axes use an aspect ratio of 1."""
        return self.ax.get_aspect() in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        """Set aspect 1.0 (flag True) or 'auto' on both axes."""
        self.ax.set_aspect(1.0 if flag else 'auto')
        self.ax2.set_aspect(1.0 if flag else 'auto')

    def setGraphGrid(self, which):
        """Disable all grid lines, then enable ``which`` grid unless None."""
        self.ax.grid(False, which='both')  # Disable all grid first
        if which is not None:
            self.ax.grid(True, which=which)

    # Data <-> Pixel coordinates conversion

    def _mplQtYAxisCoordConversion(self, y):
        """Qt origin (top) to/from matplotlib origin (bottom) conversion.

        :rtype: float
        """
        height = self.fig.get_window_extent().height
        return height - y

    def dataToPixel(self, x, y, axis):
        """Convert data coordinates to pixels with a top-left origin.

        :param str axis: 'right' uses the right axes, anything else the left.
        :returns: (xPixel, yPixel)
        """
        ax = self.ax2 if axis == "right" else self.ax

        pixels = ax.transData.transform_point((x, y))
        xPixel, yPixel = pixels.T

        # Convert from matplotlib origin (bottom) to Qt origin (top)
        yPixel = self._mplQtYAxisCoordConversion(yPixel)

        return xPixel, yPixel

    def pixelToData(self, x, y, axis, check):
        """Convert pixels (top-left origin) to data coordinates.

        :param str axis: 'right' uses the right axes, anything else the left.
        :param bool check: If True, return None when outside the limits.
        :returns: (x, y) in data coordinates, or None.
        """
        ax = self.ax2 if axis == "right" else self.ax

        # Convert from Qt origin (top) to matplotlib origin (bottom)
        y = self._mplQtYAxisCoordConversion(y)

        inv = ax.transData.inverted()
        x, y = inv.transform_point((x, y))

        if check:
            xmin, xmax = self.getGraphXLimits()
            ymin, ymax = self.getGraphYLimits(axis=axis)

            if x > xmax or x < xmin or y > ymax or y < ymin:
                return None  # (x, y) is out of plot area

        return x, y

    def getPlotBoundsInPixels(self):
        """Return (left, top, width, height) of the plot area in pixels."""
        bbox = self.ax.get_window_extent()
        # Warning this is not returning int...
        return (bbox.xmin,
                self._mplQtYAxisCoordConversion(bbox.ymax),
                bbox.width,
                bbox.height)

    def setAxesDisplayed(self, displayed):
        """Display or not the axes.

        :param bool displayed: If `True` axes are displayed. If `False` axes
            are not anymore visible and the margin used for them is removed.
        """
        BackendBase.BackendBase.setAxesDisplayed(self, displayed)
        if displayed:
            # show axes and viewbox rect
            self.ax.set_axis_on()
            self.ax2.set_axis_on()
            # set the default margins
            self.ax.set_position([.15, .15, .75, .75])
            self.ax2.set_position([.15, .15, .75, .75])
        else:
            # hide axes and viewbox rect
            self.ax.set_axis_off()
            self.ax2.set_axis_off()
            # remove external margins
            self.ax.set_position([0, 0, 1, 1])
            self.ax2.set_position([0, 0, 1, 1])
        self._synchronizeBackgroundColors()
        self._synchronizeForegroundColors()
        self._plot._setDirtyPlot()

    def _synchronizeBackgroundColors(self):
        """Apply the plot's background/data-background colors to the figure."""
        backgroundColor = self._plot.getBackgroundColor().getRgbF()

        dataBackgroundColor = self._plot.getDataBackgroundColor()
        if dataBackgroundColor.isValid():
            dataBackgroundColor = dataBackgroundColor.getRgbF()
        else:
            dataBackgroundColor = backgroundColor

        if self.ax.axison:
            self.fig.patch.set_facecolor(backgroundColor)
            if self._matplotlibVersion < _parse_version('2'):
                self.ax.set_axis_bgcolor(dataBackgroundColor)
            else:
                self.ax.set_facecolor(dataBackgroundColor)
        else:
            self.fig.patch.set_facecolor(dataBackgroundColor)

    def _synchronizeForegroundColors(self):
        """Apply the plot's foreground and grid colors to the left axes."""
        foregroundColor = self._plot.getForegroundColor().getRgbF()

        gridColor = self._plot.getGridColor()
        if gridColor.isValid():
            gridColor = gridColor.getRgbF()
        else:
            gridColor = foregroundColor

        if self.ax.axison:
            self.ax.spines['bottom'].set_color(foregroundColor)
            self.ax.spines['top'].set_color(foregroundColor)
            self.ax.spines['right'].set_color(foregroundColor)
            self.ax.spines['left'].set_color(foregroundColor)
            self.ax.tick_params(axis='x', colors=foregroundColor)
            self.ax.tick_params(axis='y', colors=foregroundColor)
            self.ax.yaxis.label.set_color(foregroundColor)
            self.ax.xaxis.label.set_color(foregroundColor)
            self.ax.title.set_color(foregroundColor)

            for line in self.ax.get_xgridlines():
                line.set_color(gridColor)

            for line in self.ax.get_ygridlines():
                line.set_color(gridColor)
Ejemplo n.º 37
0
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

import numpy as np

# Attach an Agg canvas so the figure can be rendered/saved without pyplot.
fig = Figure()
canvas = FigureCanvas(fig)

# 10000 samples drawn with numpy.random.randn.
x = np.random.randn(10000)

ax = fig.add_subplot(111)
ax.hist(x, 100)  # 100 bins

ax.set_title('Normal distribution with u=0,sigma=1')
# Raw string: the previous literal relied on the invalid escape sequences
# "\p" and "\m" (SyntaxWarning since Python 3.12). The path value is unchanged.
fig.savefig(r'.\png\matplotlib_histogram_1.png')
Ejemplo n.º 38
0
def plot_zone_vector_and_atom_distance_map(image_data,
                                           distance_data,
                                           atom_planes=None,
                                           distance_data_scale=1,
                                           atom_list=None,
                                           extra_marker_list=None,
                                           clim=None,
                                           atom_plane_marker=None,
                                           plot_title='',
                                           vector_to_plot=None,
                                           figsize=(10, 20),
                                           figname="map_data.jpg"):
    """Save an image panel stacked above a distance-map panel.

    The top panel shows ``image_data``, optionally overlaid with atom planes
    and a highlighted atom plane. The middle panel is filled by
    ``_make_subplot_map_from_regular_grid`` from ``distance_data``. The
    bottom strip holds a horizontal colorbar for that map. The figure is
    written to ``figname``.

    Parameters
    ----------
    image_data : array
        Image shown in the top panel. Its first two dimensions set the
        panel's y/x limits.
    distance_data
        Data forwarded to ``_make_subplot_map_from_regular_grid``.
    atom_planes : iterable of atom planes, optional
        Each plane is drawn as a blue line, and its index is written in red
        at its start atom.
    distance_data_scale : number
        Forwarded to the map helper.
    atom_list : list of Atom_Position instances
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
    clim
        Forwarded to the map helper.
    atom_plane_marker : atom plane, optional
        Drawn in red on the top panel and forwarded to the map helper.
    plot_title : str
        Title of the top panel.
    vector_to_plot
        Forwarded to the map helper.
    figsize : tuple
        Figure size passed to ``Figure``.
    figname : str
        Output path passed to ``savefig``.

    Raises
    ------
    ValueError
        If ``image_data`` is None.
    """
    if image_data is None:
        raise ValueError("Image data is None, no data to plot")

    figure = Figure(figsize=figsize)
    FigureCanvas(figure)  # gives the figure a canvas so savefig works

    layout = GridSpec(95, 95)
    image_ax, distance_ax, colorbar_ax = (
        figure.add_subplot(layout[0:45, :]),
        figure.add_subplot(layout[45:90, :]),
        figure.add_subplot(layout[90:, :]),
    )

    # Top panel: image with a 2-sigma colour range.
    image_clim = to._get_clim_from_data(image_data, sigma=2)
    image_mappable = image_ax.imshow(image_data)
    image_mappable.set_clim(image_clim[0], image_clim[1])

    for plane_number, plane in enumerate(atom_planes or []):
        image_ax.plot(plane.get_x_position_list(),
                      plane.get_y_position_list(),
                      lw=3, color='blue')
        start = plane.start_atom
        image_ax.text(start.pixel_x, start.pixel_y, str(plane_number),
                      color='red')

    image_ax.set_ylim(0, image_data.shape[0])
    image_ax.set_xlim(0, image_data.shape[1])
    image_ax.set_title(plot_title)

    if atom_plane_marker:
        image_ax.plot(atom_plane_marker.get_x_position_list(),
                      atom_plane_marker.get_y_position_list(),
                      color='red', lw=2)

    # Middle panel: distance map drawn by the shared helper.
    _make_subplot_map_from_regular_grid(
        distance_ax,
        distance_data,
        distance_data_scale=distance_data_scale,
        clim=clim,
        atom_list=atom_list,
        atom_plane_marker=atom_plane_marker,
        extra_marker_list=extra_marker_list,
        vector_to_plot=vector_to_plot)
    distance_mappable = distance_ax.images[0]

    figure.tight_layout()
    figure.colorbar(distance_mappable, cax=colorbar_ax,
                    orientation='horizontal')
    figure.savefig(figname)
    plt.close(figure)
Ejemplo n.º 39
0
def plot_complex_image_map_line_profile_using_interface_plane(
        image_data,
        amplitude_data,
        phase_data,
        line_profile_amplitude_data,
        line_profile_phase_data,
        interface_plane,
        atom_plane_list=None,
        data_scale=1,
        atom_list=None,
        extra_marker_list=None,
        amplitude_image_lim=None,
        phase_image_lim=None,
        plot_title='',
        add_color_wheel=False,
        color_bar_markers=None,
        vector_to_plot=None,
        rotate_atom_plane_list_90_degrees=False,
        prune_outer_values=False,
        figname="map_data.jpg"):
    """
    Save a figure with an image, a complex (amplitude/phase) map, a 2D
    amplitude/phase colour bar and two line profiles.

    From top to bottom the figure contains: ``image_data`` with the planes
    of ``atom_plane_list`` (red) and ``interface_plane`` (blue) drawn on it;
    the map drawn by ``_make_subplot_map_from_complex_regular_grid`` with
    ``interface_plane`` overlaid in red; a colour bar spanning
    ``phase_image_lim`` horizontally and ``amplitude_image_lim`` vertically;
    the amplitude line profile; the phase line profile.

    Parameters
    ----------
    image_data : 2D array
    amplitude_data, phase_data : passed to the complex map helper
    line_profile_amplitude_data, line_profile_phase_data : 2D arrays
        Column 0 is used as the x values and column 1 as the y values of
        the respective line profile.
    interface_plane : atom plane object
        Must provide ``get_x_position_list`` / ``get_y_position_list``.
    atom_plane_list : list of atom plane objects, optional
    data_scale : number
        Multiplies pixel coordinates for the image extent, the plotted
        planes and the line-profile x values.
    atom_list : list of Atom_Position instances
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
    amplitude_image_lim, phase_image_lim : (min, max)
        Required despite the ``None`` default: the colour bar is built
        directly from these limits.
    plot_title : str
    add_color_wheel : bool
        Draw a colour wheel inset in the lower-left part of the map.
    color_bar_markers : iterable of (phase_value, label), optional
        Drawn as white vertical lines with a text label on the colour bar.
    vector_to_plot : passed through to the complex map helper
    rotate_atom_plane_list_90_degrees : bool
        Replace each plane in ``atom_plane_list`` by a segment starting at
        its first atom, perpendicular to the first-to-last-atom chord and of
        the same length.
    prune_outer_values : passed through to ``_make_subplot_line_profile``
    figname : str
        File name passed to ``fig.savefig``.

    Raises
    ------
    ValueError
        If ``image_data``, ``amplitude_image_lim`` or ``phase_image_lim``
        is None.
    """
    if image_data is None:
        raise ValueError("Image data is None, no data to plot")
    # The colour bar is built from both limits; fail fast with a clear
    # message instead of a TypeError after most of the plotting is done.
    if amplitude_image_lim is None or phase_image_lim is None:
        raise ValueError(
            "amplitude_image_lim and phase_image_lim must both be given, "
            "they define the colour bar range")

    number_of_line_profiles = 2

    figsize = (10, 18 + 2 * number_of_line_profiles)

    fig = Figure(figsize=figsize)
    FigureCanvas(fig)

    gs = GridSpec(100 + 10 * number_of_line_profiles, 95)

    image_ax = fig.add_subplot(gs[0:45, :])
    distance_ax = fig.add_subplot(gs[45:90, :])
    colorbar_ax = fig.add_subplot(gs[90:100, :])

    # One 10-row strip per line profile below the colour bar.
    line_profile_ax_list = []
    for i in range(number_of_line_profiles):
        gs_y_start = 100 + 10 * i
        line_profile_ax = fig.add_subplot(gs[gs_y_start:gs_y_start + 10, :])
        line_profile_ax_list.append(line_profile_ax)

    image_y_lim = (0, image_data.shape[0] * data_scale)
    image_x_lim = (0, image_data.shape[1] * data_scale)

    image_clim = to._get_clim_from_data(image_data, sigma=2)
    image_cax = image_ax.imshow(image_data,
                                origin='lower',
                                extent=[
                                    image_x_lim[0], image_x_lim[1],
                                    image_y_lim[0], image_y_lim[1]
                                ])

    image_cax.set_clim(image_clim[0], image_clim[1])
    image_ax.set_xlim(image_x_lim[0], image_x_lim[1])
    image_ax.set_ylim(image_y_lim[0], image_y_lim[1])
    image_ax.set_title(plot_title)

    if atom_plane_list is not None:
        for atom_plane in atom_plane_list:
            atom_plane_x = np.array(atom_plane.get_x_position_list())
            atom_plane_y = np.array(atom_plane.get_y_position_list())
            if rotate_atom_plane_list_90_degrees:
                # Rotate the first-to-last chord by 90 degrees around the
                # first atom: (dx, dy) -> (-dy, dx).
                start_x = atom_plane_x[0]
                start_y = atom_plane_y[0]
                delta_y = (atom_plane_x[-1] - atom_plane_x[0])
                delta_x = -(atom_plane_y[-1] - atom_plane_y[0])
                atom_plane_x = np.array([start_x, start_x + delta_x])
                atom_plane_y = np.array([start_y, start_y + delta_y])
            image_ax.plot(atom_plane_x * data_scale,
                          atom_plane_y * data_scale,
                          color='red',
                          lw=2)

    interface_x = np.array(interface_plane.get_x_position_list()) * data_scale
    interface_y = np.array(interface_plane.get_y_position_list()) * data_scale
    image_ax.plot(interface_x, interface_y, color='blue', lw=2)

    _make_subplot_map_from_complex_regular_grid(
        distance_ax,
        amplitude_data,
        phase_data,
        atom_list=atom_list,
        amplitude_image_lim=amplitude_image_lim,
        phase_image_lim=phase_image_lim,
        atom_plane_marker=interface_plane,
        extra_marker_list=extra_marker_list,
        vector_to_plot=vector_to_plot)
    distance_ax.plot(interface_x, interface_y, color='red', lw=2)

    line_profile_data_list = [
        line_profile_amplitude_data, line_profile_phase_data
    ]

    for line_profile_ax, line_profile_data in zip(line_profile_ax_list,
                                                  line_profile_data_list):
        _make_subplot_line_profile(line_profile_ax,
                                   line_profile_data[:, 0],
                                   line_profile_data[:, 1],
                                   prune_outer_values=prune_outer_values,
                                   scale_x=data_scale)

    # 2D colour bar: about 100 steps along each of amplitude and phase.
    amplitude_delta = 0.01 * (amplitude_image_lim[1] - amplitude_image_lim[0])
    phase_delta = 0.01 * (phase_image_lim[1] - phase_image_lim[0])
    colorbar_mgrid = np.mgrid[
        amplitude_image_lim[0]:amplitude_image_lim[1]:amplitude_delta,
        phase_image_lim[0]:phase_image_lim[1]:phase_delta]
    colorbar_rgb = get_rgb_array(colorbar_mgrid[1], colorbar_mgrid[0])
    colorbar_ax.imshow(colorbar_rgb,
                       origin='lower',
                       extent=[
                           phase_image_lim[0], phase_image_lim[1],
                           amplitude_image_lim[0], amplitude_image_lim[1]
                       ])

    colorbar_ax.set_xlabel("Phase", size=6)

    if color_bar_markers is not None:
        for color_bar_marker in color_bar_markers:
            colorbar_ax.axvline(color_bar_marker[0], color='white')
            colorbar_ax.text(color_bar_marker[0],
                             amplitude_image_lim[0] * 0.97,
                             color_bar_marker[1],
                             transform=colorbar_ax.transData,
                             va='top',
                             ha='center',
                             fontsize=8)

    if add_color_wheel:
        # Rows 30-38 / columns 3-11 of a 40x40 grid over the map area.
        ax_magnetic_color_wheel_gs = GridSpecFromSubplotSpec(
            40, 40, subplot_spec=gs[45:90, :])[30:39, 3:12]
        ax_magnetic_color_wheel = fig.add_subplot(ax_magnetic_color_wheel_gs)
        ax_magnetic_color_wheel.set_axis_off()

        make_color_wheel(ax_magnetic_color_wheel)

    fig.tight_layout()
    fig.savefig(figname)
    plt.close(fig)
Ejemplo n.º 40
0
# Label location of a station in ax_full coordinates
def pos_right(idx):
    """Return the anchor point of station *idx* in the right-hand column.

    The column sits at x = 0.668. Its vertical band starts at y = 0.05 and
    has a height of 0.94. The band is split into ``len(stations)`` equal
    slots, and the station is placed at the centre of its slot. The result
    is a dict with keys ``'x'`` and ``'y'``.
    """
    slot_height = 0.94 / len(stations)
    return {'x': 0.668, 'y': 0.05 + slot_height * (idx + 0.5)}


# For every station whose left-hand point lies inside the accepted box
# (x in [0.335, 0.335 + 0.323], y in [0.005, 0.005 + 0.94]), draw a small
# red marker at its right-hand label position. Then connect the two points
# with a thin red line.
for i in range(len(stations)):
    p_left = pos_left(i)
    if (p_left['x'] < 0.335 or p_left['x'] > (0.335 + 0.323)
            or p_left['y'] < 0.005 or p_left['y'] > (0.005 + 0.94)):
        continue
    p_right = pos_right(i)
    ax_full.add_patch(
        Circle((p_right['x'], p_right['y']), radius=0.001,
               facecolor=(1, 0, 0, 1), edgecolor=(0, 0, 0, 1),
               alpha=1, zorder=1))
    ax_full.add_line(
        matplotlib.lines.Line2D(xdata=(p_left['x'], p_right['x']),
                                ydata=(p_left['y'], p_right['y']),
                                linestyle='solid', linewidth=0.2,
                                color=(1, 0, 0, 1.0), zorder=1))

# Output as png, named after the analysis time (YYYYMMDDHH)
fig.savefig('Leave_one_out_%04d%02d%02d%02d.png' % (year, month, day, hour))
Ejemplo n.º 41
0
class AbstractMultiGroupPlotPlugin(FigureCanvas):
    '''
    Abstract base class specifying interface of a multiple group plot plugin.

    The canvas wraps ``self.fig``, a white 96-dpi matplotlib Figure.
    Subclasses override ``plot`` and ``configure``. The remaining methods are
    shared helpers:
      * measuring tick-label extents;
      * saving the figure;
      * drawing an empty placeholder axis;
      * formatting numeric labels.
    '''

    def __init__(self, preferences, parent=None):
        self.preferences = preferences

        self.fig = Figure(facecolor='white', dpi=96)

        FigureCanvas.__init__(self, self.fig)

        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Fixed,
                                   QtGui.QSizePolicy.Fixed)
        FigureCanvas.updateGeometry(self)

        # Connection id of the handler installed by mouseEventCallback.
        self.cid = None

        self.type = '<none>'
        self.name = '<none>'
        self.bSupportsHighlight = False
        self.bPlotFeaturesIndividually = True
        self.bRunPostHocTest = False

    def mouseEventCallback(self, callback):
        '''Install *callback* for 'button_press_event', replacing the handler
        previously installed through this method (if any).'''
        if self.cid is not None:
            FigureCanvas.mpl_disconnect(self, self.cid)

        self.cid = FigureCanvas.mpl_connect(self, 'button_press_event', callback)

    def plot(self, profile, statsResults):
        '''Draw the plot. Does nothing here; meant to be overridden.'''
        pass

    def configure(self, profile, statsResults):
        '''Configure the plot. Does nothing here; meant to be overridden.'''
        pass

    def savePlot(self, filename, dpi=300):
        '''Save the figure to *filename*, using its extension as the format.

        Only png, pdf, ps, eps and svg files are written; the extension is
        matched case-insensitively. Any other extension is silently ignored.
        '''
        # Text after the last '.' (the whole name if there is no '.').
        fmt = filename[filename.rfind('.') + 1:].lower()
        if fmt in ('png', 'pdf', 'ps', 'eps', 'svg'):
            self.fig.savefig(filename, format=fmt, dpi=dpi,
                             facecolor='white', edgecolor='white')

    def clear(self):
        self.fig.clear()

    def mirrorProperties(self, plotToCopy):
        '''Copy type, name and highlight support from *plotToCopy*.'''
        self.type = plotToCopy.type
        self.name = plotToCopy.name
        self.bSupportsHighlight = plotToCopy.bSupportsHighlight

    def _labelBounds(self, labels):
        '''Return the union of the rendered extents of *labels*, expressed in
        figure coordinates.'''
        renderer = self.get_renderer()
        # Bbox.inverse_transformed was deprecated in Matplotlib 3.3 and later
        # removed; transformed(t.inverted()) is the equivalent.
        displayToFigure = self.fig.transFigure.inverted()
        bboxes = [label.get_window_extent(renderer).transformed(displayToFigure)
                  for label in labels]
        return mtransforms.Bbox.union(bboxes)

    def labelExtents(self, xLabels, xFontSize, xRotation, yLabels, yFontSize, yRotation):
        '''Measure x and y tick labels drawn with the given sizes/rotations.

        Returns (xLabelBounds, yLabelBounds) as Bboxes in figure
        coordinates. The figure is cleared before and after measuring.
        '''
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])

        tempAxes.set_xticks(np.arange(len(xLabels)))
        tempAxes.set_yticks(np.arange(len(yLabels)))

        xText = tempAxes.set_xticklabels(xLabels, size=xFontSize, rotation=xRotation)
        yText = tempAxes.set_yticklabels(yLabels, size=yFontSize, rotation=yRotation)

        xLabelBounds = self._labelBounds(xText)
        yLabelBounds = self._labelBounds(yText)

        self.fig.clear()

        return xLabelBounds, yLabelBounds

    def xLabelExtents(self, labels, fontSize, rotation=0):
        '''Return the Bbox (figure coordinates) covering *labels* drawn as
        x tick labels. The figure is cleared before and after measuring.'''
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])
        tempAxes.set_xticks(np.arange(len(labels)))
        xLabels = tempAxes.set_xticklabels(labels, size=fontSize, rotation=rotation)

        xLabelBounds = self._labelBounds(xLabels)

        self.fig.clear()

        return xLabelBounds

    def yLabelExtents(self, labels, fontSize, rotation=0):
        '''Return the Bbox (figure coordinates) covering *labels* drawn as
        y tick labels. The figure is cleared before and after measuring.'''
        self.fig.clear()

        tempAxes = self.fig.add_axes([0, 0, 1.0, 1.0])
        tempAxes.set_yticks(np.arange(len(labels)))
        yLabels = tempAxes.set_yticklabels(labels, size=fontSize, rotation=rotation)

        yLabelBounds = self._labelBounds(yLabels)

        self.fig.clear()

        return yLabelBounds

    def emptyAxis(self, title=''):
        '''Show a 6x4 inch placeholder axis without ticks that carries a
        "no active features" message and *title*.'''
        self.fig.clear()
        self.fig.set_size_inches(6, 4)
        ax = self.fig.add_axes([0.1, 0.1, 0.8, 0.8])

        ax.set_ylabel('No active features or degenerate plot', fontsize=8)
        ax.set_xlabel('No active features or degenerate plot', fontsize=8)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_title(title)

        # Hide the right and top spines. Use .items(): the former
        # .iteritems() is Python-2 only and fails on Python 3.
        for loc, spine in ax.spines.items():
            if loc in ('right', 'top'):
                spine.set_color('none')

        self.updateGeometry()
        self.draw()

    def formatLabels(self, labels):
        '''Return the numeric text of each label, reformatted.

        Values below 0.01 use '%.2e', with the leading zero of a negative
        exponent removed (e.g. '1.00e-05' -> '1.00e-5'). Other values use
        '%.3f'.
        '''
        formattedLabels = []
        for label in labels:
            value = float(label.get_text())
            if value < 0.01:
                valueStr = '%.2e' % value
                if 'e-00' in valueStr:
                    valueStr = valueStr.replace('e-00', 'e-')
                elif 'e-0' in valueStr:
                    valueStr = valueStr.replace('e-0', 'e-')
            else:
                valueStr = '%.3f' % value

            formattedLabels.append(valueStr)

        return formattedLabels
Ejemplo n.º 42
0
def plot_image_map_line_profile_using_interface_plane(
        image_data,
        heatmap_data_list,
        line_profile_data_list,
        interface_plane,
        atom_plane_list=None,
        data_scale=1,
        data_scale_z=1,
        atom_list=None,
        extra_marker_list=None,
        clim=None,
        plot_title='',
        vector_to_plot=None,
        rotate_atom_plane_list_90_degrees=False,
        prune_outer_values=False,
        figname="map_data.jpg"):
    """
    Save a figure with an image, a heatmap with colour bar and one line
    profile per entry of ``line_profile_data_list``.

    From top to bottom the figure contains: ``image_data`` with the planes
    of ``atom_plane_list`` (red) and ``interface_plane`` (blue) drawn on it;
    the map drawn by ``_make_subplot_map_from_regular_grid`` with
    ``interface_plane`` overlaid in red; a horizontal colour bar for that
    map; the line profiles.

    Parameters
    ----------
    image_data : 2D array
    heatmap_data_list : passed as the data of the regular-grid map helper
    line_profile_data_list : list of 2D arrays
        Each array gets its own subplot, with column 0 used as x values and
        column 1 as y values.
    interface_plane : atom plane object
        Must provide ``get_x_position_list`` / ``get_y_position_list``.
    atom_plane_list : list of atom plane objects, optional
    data_scale : number
        Multiplies pixel coordinates for the image extent, the plotted
        planes and the line-profile x values.
    data_scale_z : number
        Passed as ``distance_data_scale`` to the map helper and as
        ``scale_z`` to the line-profile helper.
    atom_list : list of Atom_Position instances
    extra_marker_list : two arrays of x and y [[x_values], [y_values]]
    clim : passed through to the map helper
    plot_title : str
    vector_to_plot : passed through to the map helper
    rotate_atom_plane_list_90_degrees : bool
        Replace each plane in ``atom_plane_list`` by a segment starting at
        its first atom, perpendicular to the first-to-last-atom chord and of
        the same length.
    prune_outer_values : passed through to ``_make_subplot_line_profile``
    figname : str
        File name passed to ``fig.savefig``.

    Raises
    ------
    ValueError
        If ``image_data`` is None.
    """
    if image_data is None:
        raise ValueError("Image data is None, no data to plot")

    number_of_line_profiles = len(line_profile_data_list)

    figsize = (10, 18 + 2 * number_of_line_profiles)

    fig = Figure(figsize=figsize)
    FigureCanvas(fig)

    gs = GridSpec(95 + 10 * number_of_line_profiles, 95)

    image_ax = fig.add_subplot(gs[0:45, :])
    distance_ax = fig.add_subplot(gs[45:90, :])
    colorbar_ax = fig.add_subplot(gs[90:95, :])
    # One 10-row strip per line profile below the colour bar.
    line_profile_ax_list = []
    for i in range(number_of_line_profiles):
        gs_y_start = 95 + 10 * i
        line_profile_ax = fig.add_subplot(gs[gs_y_start:gs_y_start + 10, :])
        line_profile_ax_list.append(line_profile_ax)

    image_y_lim = (0, image_data.shape[0] * data_scale)
    image_x_lim = (0, image_data.shape[1] * data_scale)

    image_clim = to._get_clim_from_data(image_data, sigma=2)
    image_cax = image_ax.imshow(image_data,
                                origin='lower',
                                extent=[
                                    image_x_lim[0], image_x_lim[1],
                                    image_y_lim[0], image_y_lim[1]
                                ])

    image_cax.set_clim(image_clim[0], image_clim[1])
    image_ax.set_xlim(image_x_lim[0], image_x_lim[1])
    image_ax.set_ylim(image_y_lim[0], image_y_lim[1])
    image_ax.set_title(plot_title)

    if atom_plane_list is not None:
        for atom_plane in atom_plane_list:
            atom_plane_x = np.array(atom_plane.get_x_position_list())
            atom_plane_y = np.array(atom_plane.get_y_position_list())
            if rotate_atom_plane_list_90_degrees:
                # Rotate the first-to-last chord by 90 degrees around the
                # first atom: (dx, dy) -> (-dy, dx).
                start_x = atom_plane_x[0]
                start_y = atom_plane_y[0]
                delta_y = (atom_plane_x[-1] - atom_plane_x[0])
                delta_x = -(atom_plane_y[-1] - atom_plane_y[0])
                atom_plane_x = np.array([start_x, start_x + delta_x])
                atom_plane_y = np.array([start_y, start_y + delta_y])
            image_ax.plot(atom_plane_x * data_scale,
                          atom_plane_y * data_scale,
                          color='red',
                          lw=2)

    interface_x = np.array(interface_plane.get_x_position_list()) * data_scale
    interface_y = np.array(interface_plane.get_y_position_list()) * data_scale
    image_ax.plot(interface_x, interface_y, color='blue', lw=2)

    _make_subplot_map_from_regular_grid(distance_ax,
                                        heatmap_data_list,
                                        distance_data_scale=data_scale_z,
                                        clim=clim,
                                        atom_list=atom_list,
                                        atom_plane_marker=interface_plane,
                                        extra_marker_list=extra_marker_list,
                                        vector_to_plot=vector_to_plot)
    # The helper draws the map as the first image on the axes; reuse it for
    # the colour bar.
    distance_cax = distance_ax.images[0]
    distance_ax.plot(interface_x, interface_y, color='red', lw=2)

    for line_profile_ax, line_profile_data in zip(line_profile_ax_list,
                                                  line_profile_data_list):
        _make_subplot_line_profile(line_profile_ax,
                                   line_profile_data[:, 0],
                                   line_profile_data[:, 1],
                                   prune_outer_values=prune_outer_values,
                                   scale_x=data_scale,
                                   scale_z=data_scale_z)

    fig.tight_layout()
    fig.colorbar(distance_cax, cax=colorbar_ax, orientation='horizontal')
    fig.savefig(figname)
Ejemplo n.º 43
0
def graph(sensor_id):
    """Return a PNG plot of one logged field of *sensor_id* over time.

    Query parameters
    ----------------
    field : str, required
        Data field to plot. It must be one of the fields, other than ``id``,
        present in the first logged entry.
    hours : int, optional
        Only plot entries whose timestamp is less than this many hours older
        than the newest entry.

    The response carries headers that disable client caching.

    Aborts with 400 when ``field`` is missing or invalid, or when ``hours``
    is not an integer. Aborts with 500 when the sensor is not in the sensor
    list or has no logged data.
    """
    # TODO: Add date filtering options
    import email.utils  # only needed to format the Last-Modified header

    sensor_config = current_app.config["SENSOR_CONFIG"]
    data_dir = current_app.config["DATA_DIR"]
    logfile = util.get_logfile_path(data_dir, sensor_id)

    if sensor_id not in util.get_sensor_list(sensor_config):
        abort(500, "Sensor ID is not in the sensor list.")
    if not os.path.isfile(logfile):
        abort(500, "No data exists for the sensor.")

    json_data = util.get_json(logfile)

    args = request.args

    if "field" not in args:
        abort(400, "Field not specified.")
    field = args["field"]

    # An empty log has no newest date to filter against.
    if not json_data:
        abort(500, "No data exists for the sensor.")

    first_entry = json_data[next(iter(json_data))]
    valid_fields = [name for name in first_entry["data"] if name != "id"]
    if field not in valid_fields:
        # A bad query parameter is a client error, like a missing field.
        abort(400, "Invalid field.")

    hours_back = None
    if "hours" in args:
        try:
            hours_back = int(args["hours"])
        except ValueError:
            abort(400, "Invalid hours value.")

    # Parse each timestamp once and reuse it below.
    entries = [(util.get_local_datetime(date_str), all_info)
               for date_str, all_info in json_data.items()]
    latest_date = max(date for date, _ in entries)

    dates = []
    values = []

    for date, all_info in entries:
        if (hours_back is not None
                and (latest_date - date).total_seconds() / 3600 >= hours_back):
            continue
        if field in all_info["data"]:
            dates.append(date)
            values.append(all_info["data"][field])
        else:
            print("Warning: data line did not contain {}!".format(field))

    # The key to using matplotlib with flask: don't use pyplot!
    # https://matplotlib.org/3.1.1/faq/howto_faq.html

    fig = Figure(figsize=(13, 3))
    ax = fig.subplots()

    ax.plot(dates, values, util.plot_formats[field])
    units = util.data_units[field]
    ax.set_title("{} vs. Time".format(units[0]))
    ax.set_xlabel("Time")
    # The y label is filled from every data_units entry except index 1.
    ax.set_ylabel("{} ({})".format(
        *[val for i, val in enumerate(units) if i != 1]))
    ax.grid()

    ax.margins(x=0.01, y=0.15)  # Margins are fractions of the data interval
    fig.tight_layout()

    bg_color = "#ededed"
    fig.patch.set_facecolor(bg_color)
    ax.patch.set_facecolor(bg_color)

    img_io = BytesIO()
    fig.savefig(img_io, format='png', facecolor=fig.get_facecolor())
    img_io.seek(0)

    response = make_response(img_io.getvalue())
    response.headers['Content-Type'] = 'image/png'

    # Disable caching. Last-Modified must be an HTTP-date (RFC 7231), not
    # str(datetime), so format the current time explicitly in GMT.
    response.headers['Last-Modified'] = email.utils.format_datetime(
        datetime.datetime.now(datetime.timezone.utc), usegmt=True)
    response.headers[
        'Cache-Control'] = 'no-store, no-cache, must-revalidate, post-check=0, pre-check=0, max-age=0'
    response.headers['Pragma'] = 'no-cache'
    response.headers['Expires'] = '-1'

    return response
Ejemplo n.º 44
0
class Plotting:
    def __init__(self, fitWindow):
        """Set default plot options and create a placeholder figure.

        ``fitWindow`` is stored as ``self.fitWindow``. The placeholder figure
        is ``self.fig1`` (11x5 inches at 100 dpi) and initially shows a red
        sine curve. It is wrapped in a ``PlotWidget`` stored as
        ``self.plotWidget``, with its cursors initialised at 2 and 6.
        """
        self.fitWindow = fitWindow

        # Plot display flags: only the contact-area series starts enabled.
        self.flag_ca = True
        for flag_name in ('flag_ra', 'flag_cl', 'flag_rl', 'flag_cn',
                          'flag_ecc', 'flag_lf', 'flag_zp', 'flag_xp',
                          'flag_ap', 'flag_fp', 'flag_st', 'flag_zd'):
            setattr(self, flag_name, False)

        # NOTE(review): plotData() looks x_var up in a dict whose time key is
        # 'Time (s)', not 'Time'. Confirm that x_var is overwritten before
        # the first plot; otherwise the lookup yields None.
        self.x_var = 'Time'  # default x-axis quantity
        self.legendPos = "upper right"
        self.show_title = True
        self.showLegend2 = True
        self.fontSize = 12

        # Fit-result display settings.
        self.fit_pos = '0.5,0.5'
        self.fit_show = False
        self.slope = ''
        self.slope_unit = ''

        # Placeholder content until real data is plotted.
        self.fig1 = Figure(figsize=(11, 5), dpi=100)
        placeholder_ax = self.fig1.add_subplot(111)
        x_values = np.linspace(0, 4, 50)
        placeholder_ax.plot(x_values, np.sin(x_values), 'r-',
                            linewidth=1, markersize=1)

        self.plotWidget = PlotWidget(fig=self.fig1,
                                     cursor1_init=2,
                                     cursor2_init=6)

    def plotData(self, unit):  #prepare plot

        xDict = {
            'Vertical Position (μm)': self.dist_vert1,
            'Lateral Position (μm)': self.dist_lat1,
            'Deformation (μm)': self.deform_vert,
            'Time (s)': self.time1
        }
        xAxisData = xDict.get(self.x_var)

        markerlist = ["o", "v", "^", "s", "P", "*", "D", "<", "X", ">"]
        linelist = [":", "-.", "--", "-", ":", "-.", "--", "-", ":", "-."]

        plt.rcParams.update({'font.size': self.fontSize})

        # self.fig1 = plt.figure(num="Force/Area vs Time", figsize = [11, 5])
        # self.fig1 = Figure(figsize=(11, 5), dpi=100)

        # self.fig1.canvas.mpl_connect('close_event', self.handle_close)

        print("fig1")

        #store cursor position values before clearing plot
        if self.plotWidget.wid.cursor1 == None:
            c1_init = None
        else:
            c1_init = self.plotWidget.wid.cursor1.get_xdata()[0]

        if self.plotWidget.wid.cursor2 == None:
            c2_init = None
        else:
            c2_init = self.plotWidget.wid.cursor2.get_xdata()[0]

        self.fig1.clear()
        ax1 = self.fig1.add_subplot(1, 1, 1)
        lns = []

        # ax1.set_title('Speed = ' + str(self.speed_um) + ' μm/s')
        ax1.set_xlabel(self.x_var)
        ax1.set_ylabel('Vertical Force (μN)', color='r')
        p1, = ax1.plot(xAxisData[self.plot_slice],
                       self.force_vert1_shifted[self.plot_slice],
                       'ro',
                       alpha=0.5,
                       linewidth=1,
                       markersize=1,
                       label="Vertical Force")
        lns.append(p1)

        # self.plotWidget.mpl_connect('close_event', self.handle_close)

        if self.ptsnumber != 0:
            ##            ptsperstep = int(self.ptsnumber/self.step_num)
            i = 0
            lns_reg = []  #region legend handle
            lab_reg = []  #region legend label
            speed_inview = []  #speed list in plot range
            for a in self.steps:  #shade step regions
                if i < ((self.plot_slice.start + 1) / self.ptsperstep) - 1:
                    i += 1
                    continue

                if self.ptsperstep * (i + 1) - 1 > self.plot_slice.stop:
                    endpoint = self.plot_slice.stop - 1
                    exit_flag = True
                else:
                    endpoint = self.ptsperstep * (i + 1) - 1
                    exit_flag = False

                if self.ptsperstep * i < self.plot_slice.start:
                    startpoint = self.plot_slice.start
                else:
                    startpoint = self.ptsperstep * i

                x_start = min(xAxisData[startpoint:endpoint])
                x_end = max(xAxisData[startpoint:endpoint])
                if a == 'Front':
                    v1 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='aliceblue',
                                     label=a)
                    lns_reg.append(v1)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Back':
                    v2 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='whitesmoke',
                                     label=a)
                    lns_reg.append(v2)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Up':
                    v3 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='honeydew',
                                     label=a)
                    lns_reg.append(v3)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Down':
                    v4 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='linen',
                                     label=a)
                    lns_reg.append(v4)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                elif a == 'Pause':
                    v5 = ax1.axvspan(x_start,
                                     x_end,
                                     alpha=0.9,
                                     color='lightyellow',
                                     label=a)
                    lns_reg.append(v5)
                    lab_reg.append(a)
                    speed_inview.append(self.speed_um[i])
                    if exit_flag == True:
                        break
                i += 1

        if self.show_title == True:
            ax1.set_title('Speed = ' +
                          str(speed_inview).replace('[', '').replace(']', '') +
                          ' μm/s')
        if self.showLegend2 == True:
            dict_reg = dict(zip(lab_reg,
                                lns_reg))  #legend dictionary (remove dup)
            self.fig1.legend(dict_reg.values(),
                             dict_reg.keys(),
                             loc='lower right',
                             ncol=len(lns_reg))

        if self.flag_ap == True:  #show adhesion calc
            #fill adhesion energy region
            ax1.fill_between(xAxisData[self.energy_slice],
                             self.forceDict["zero1"][0],
                             self.force_vert1_shifted[self.energy_slice],
                             color='black')
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax1.axhline(y=self.forceDict["zero1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axhline(y=self.forceDict["force_min1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axhline(y=self.forceDict["force_max1"][i],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                ax1.axvline(x=xAxisData[self.time1.index(
                    self.indDict["time1_max"][i])],
                            color='y',
                            alpha=1,
                            linestyle=linelist[i],
                            linewidth=1)
                i += 1

        if self.flag_ca == True or self.flag_ra == True:
            ax2 = ax1.twinx()  #secondary axis
            ##                cmap = plt.cm.get_cmap("Reds")  # type: matplotlib.colors.ListedColormap
            num = len(self.rangeDict.keys())
            ##                colors = plt.cm.Reds(np.linspace(0.3,1,num))
            colors = plt.cm.Greens([0.7, 0.5, 0.9, 0.3, 1])
            ax2.set_prop_cycle(color=colors)
            ax2.set_ylabel('Area ($' + unit + '^2$)', color='g')
            if self.flag_ca == True:
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p2, = ax2.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][0][self.plot_slice2],
                                   '-' + markerlist[i],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="Contact Area: " + k)
                    # p2.set_animated(True) #BLIT THIS CHECK!!!
                    lns.append(p2)
                    if self.flag_ap == True:  #adhesion calc
                        ax2.plot(self.indDict["time1_max"][i],
                                 self.areaDict["area2_pulloff"][i],
                                 'y' + markerlist[i],
                                 alpha=0.8)
                    if self.flag_fp == True:  #friction calc
                        ax2.plot(self.indDict["time1_lat_avg"][i],
                                 self.areaDict["area_friction"][i],
                                 'g' + markerlist[i],
                                 alpha=0.8)
                    i += 1
            if self.flag_ra == True:  #consider first key since auto roi is same for all keys
                colors = plt.cm.Blues([0.7, 0.5, 0.9, 0.3, 1])
                ax2.set_prop_cycle(color=colors)
                j = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p3, = ax2.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][3][self.plot_slice2],
                                   '-' + markerlist[j],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="ROI Area: " + k)
                    lns.append(p3)
                    j += 1

        if self.flag_lf == True:
            ax3 = ax1.twinx()  #lateral force
            ax3.set_ylabel('Lateral Force (μN)', color='c')
            ax3.spines['left'].set_position(
                ('outward', int(6 * self.fontSize)))
            ax3.spines["left"].set_visible(True)
            ax3.yaxis.set_label_position('left')
            ax3.yaxis.set_ticks_position('left')
            if self.invert_latf == True:
                ax3.invert_yaxis()
            if self.flag_lf == True:
                p4, = ax3.plot(xAxisData[self.plot_slice],
                               self.force_lat1_shifted[self.plot_slice],
                               'co',
                               alpha=0.5,
                               linewidth=1,
                               markersize=1,
                               label="Lateral Force")

##            if self.flag_lf_filter == True:
##                p4, = ax3.plot(self.time1[self.plot_slice], self.force_lat1_filtered_shifted[self.plot_slice], '-c',
##                     alpha=0.5, linewidth=1, label="Lateral Force")

            if self.flag_fp == True:  #show friction calc
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    ax3.axhline(y=self.forceDict["force_lat_max"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax3.axhline(y=self.forceDict["force_lat_min"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax1.axhline(y=self.forceDict["force_max2"][i],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    ax3.axvline(x=xAxisData[self.time1.index(
                        self.indDict["time1_lat_avg"][i])],
                                color='g',
                                alpha=1,
                                linestyle=linelist[i],
                                linewidth=1)
                    # ax2.plot(self.indDict["time1_lat_avg"][i],
                    #          self.areaDict["area_friction"][i],
                    #          'g' + markerlist[i], alpha=0.8)
                    i += 1
                ax3.axhline(y=self.forceDict["zero2"],
                            color='g',
                            alpha=0.5,
                            linestyle=linelist[0],
                            linewidth=1)
            lns.append(p4)
        else:
            ax3 = None

        if self.flag_zp == True or self.flag_xp == True or self.flag_zd:  #piezo position/deformation
            ax4 = ax1.twinx()  #piezo waveform
            ax4.set_ylabel('Displacement (μm)', color='violet')
            if self.flag_ca == True or self.flag_ra == True:  #shift axis if area plotted
                ax4.spines['right'].set_position(
                    ('outward', int(7 * self.fontSize)))
##                ax4.invert_yaxis()
            if self.flag_zp == True:
                p5, = ax4.plot(xAxisData[self.plot_slice],
                               self.dist_vert1[self.plot_slice],
                               '-',
                               markersize=1,
                               color='violet',
                               alpha=0.5,
                               label="Vertical Piezo")
                lns.append(p5)
            if self.flag_xp == True:
                p6, = ax4.plot(xAxisData[self.plot_slice],
                               self.dist_lat1[self.plot_slice],
                               '-.',
                               markersize=1,
                               color='violet',
                               alpha=0.5,
                               label="Lateral Piezo")
                lns.append(p6)
            if self.flag_zd == True:  #actual deformation plot
                p12, = ax4.plot(xAxisData[self.plot_slice],
                                self.deform_vert[self.plot_slice],
                                '-o',
                                markersize=1,
                                color='violet',
                                alpha=0.5,
                                label="Deformation")
                if self.flag_ap == True:
                    ax1.axvline(x=xAxisData[self.deform_tol],
                                color='violet',
                                alpha=1,
                                linestyle=":",
                                linewidth=1)
                lns.append(p12)

        if self.flag_cl == True or self.flag_rl == True:
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.set_ylabel('Length ($' + unit + '$)', color='brown')
            if self.flag_ca == True or self.flag_ra == True:
                ax5.spines['right'].set_position(
                    ('outward', int(7 * self.fontSize)))
            if self.flag_cl == True:  #contact length
                i = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
                    p7, = ax5.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][1][self.plot_slice2],
                                   '-' + markerlist[i],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="Contact Length: " + k)
                    lns.append(p7)
                    i += 1
            if self.flag_rl == True:  #roi length
                ##                ax5 = ax1.twinx()
                num = len(self.rangeDict.keys())
                colors = plt.cm.Wistia(np.linspace(0.2, 0.7, num))
                ax5.set_prop_cycle(color=colors)
                ##                ax5.spines['right'].set_position(('outward', 70))
                j = 0
                for k in self.rangeDict.keys():
                    if len(self.rangeDict.keys()) > 1 and k == "Default":
                        continue
##                    ax5.set_ylabel('Length ($' + unit + '$)', color = 'brown')
                    p8, = ax5.plot(self.time2[self.plot_slice2],
                                   self.dataDict[k][4][self.plot_slice2],
                                   '-' + markerlist[j],
                                   alpha=0.5,
                                   linewidth=1,
                                   markersize=2,
                                   label="ROI Length: " + k)
                    lns.append(p8)
                    j += 1
        if self.flag_cn == True:  #contact number
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.spines['right'].set_position(
                ('outward', int(7 * self.fontSize)))
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax5.set_ylabel('Number', color='brown')
                p9, = ax5.plot(self.time2[self.plot_slice2],
                               self.dataDict[k][2][self.plot_slice2],
                               '-' + markerlist[i],
                               alpha=0.5,
                               linewidth=1,
                               markersize=2,
                               label="Contact Number: " + k)
                lns.append(p9)
                i += 1
        if self.flag_ecc == True:  #contact eccentricity
            ax5 = ax1.twinx()
            num = len(self.rangeDict.keys())
            colors = plt.cm.copper(np.linspace(0.2, 0.7, num))
            ax5.set_prop_cycle(color=colors)
            ax5.spines['right'].set_position(
                ('outward', int(7 * self.fontSize)))
            i = 0
            for k in self.rangeDict.keys():
                if len(self.rangeDict.keys()) > 1 and k == "Default":
                    continue
                ax5.set_ylabel('Eccentricity' + unit + '$)', color='brown')
                p10, = ax5.plot(self.time2[self.plot_slice2],
                                self.dataDict[k][5][self.plot_slice2],
                                '-' + markerlist[i],
                                alpha=0.5,
                                linewidth=1,
                                markersize=2,
                                label="Median Eccentricity: " + k)
                lns.append(p10)
                i += 1

        if self.flag_st == True or self.flag_lf_filter == True:  #stress CHECK!
            ax6 = ax1.twinx()
            ax6.set_ylabel('Stress (μN/$' + unit + '^2$)', color='c')
            ax6.spines['left'].set_position(
                ('outward', int(6 * self.fontSize)))
            ax6.spines["left"].set_visible(True)
            ax6.yaxis.set_label_position('left')
            ax6.yaxis.set_ticks_position('left')
            if self.flag_st == True:
                p11, = ax6.plot(xAxisData[self.plot_slice],
                                self.stress[self.plot_slice],
                                'co',
                                alpha=0.5,
                                linewidth=1,
                                markersize=1,
                                label="Stress")
            if self.flag_lf_filter == True:
                p11, = ax6.plot(xAxisData[self.plot_slice],
                                self.stress_filtered[self.plot_slice],
                                '-c',
                                alpha=0.5,
                                linewidth=1,
                                markersize=1,
                                label="Stress")

            lns.append(p11)

##            lns = [p1, p3, p2, p4, p5]
##        else:
##            lns = [p1, p2]

        ax1.legend(handles=lns, loc=self.legendPos)

        if self.fitWindow.enableFitting.isChecked() == True:
            axDict = {'Vertical Force (μN)': ax1, 'Lateral Force (μN)': ax3}
            # yDict = {'Vertical Force (μN)':self.force_vert1_shifted,
            #          'Lateral Force (μN)':self.force_lat1_shifted}
            # fit_slice = slice(int(self.startFit * self.ptsnumber/100),
            #                   int(self.endFit * self.ptsnumber/100))
            self.slope_unit = self.fitWindow.yFit.currentText().split('(')[1].split(')')[0] + '/' +\
                              self.fitWindow.xFit.currentText().split('(')[1].split(')')[0]
            text_pos = self.fit_pos.split(",")

            # self.slope = fitting.polyfitData(xDict.get(self.fit_x)[fit_slice], yDict.get(self.fit_y)[fit_slice],
            #                          axDict.get(self.fit_y), xAxisData[fit_slice], unit = self.slope_unit,
            #                          eq_pos = text_pos, fit_order = 1, fit_show = self.fit_show)
            ax_fit = axDict.get(self.fitWindow.yFit.currentText())
            ax_fit.plot(xAxisData[self.fitWindow.fit_slice],
                        self.fitWindow.fit_ydata,
                        color='black',
                        linewidth=2,
                        linestyle='dashed')

            ##        print(eq_pos)
            if self.fit_show == True and \
                self.fitWindow.fittingFunctionType.currentText() == 'Linear':
                self.slope = self.fitWindow.fitParams['m']
                slope_label = "Slope: " + "%.4f"%(self.slope) + \
                    ' (' + self.slope_unit + ')'
                ax_top = self.fig1.get_axes()[-1]
                ax_top.text(float(text_pos[0]),
                            float(text_pos[1]),
                            slope_label,
                            ha='right',
                            transform=ax_top.transAxes,
                            color='black',
                            bbox=dict(facecolor='white',
                                      edgecolor='black',
                                      alpha=0.5),
                            picker=5)
        else:
            self.slope = ''
            self.slope_unit = ''

        self.fig1.tight_layout()
        self.fig1.canvas.draw()

        self.plotWidget.wid.axes = self.fig1.get_axes()[-1]

        self.plotWidget.wid.add_cursors(cursor1_init=c1_init,
                                        cursor2_init=c2_init)

        print("plot finish")
        # self.plotWidget.wid.draw_idle()

        # self.fig1.tight_layout()
        # self.fig1.canvas.draw()

#     def showPlot(self): #show plot
# ##        self.fig1.show()
#         try:
#             # plt.pause(0.05)
#             # self.axes.relim()
#             # self.axes.autoscale()
#             self.fig1.tight_layout()
#             self.fig1.canvas.draw()
#             # self.plotWidget
#             # self.plotWidget.show()
#         except Exception as e:
#             print(e)

##        plt.show(block=False)
##        plt.draw()

# def handle_close(self, evt): #figure closed event
#     self.fig1_close = True
#     print("close")

    def convertPlot(self):
        """Render ``self.fig1`` and return its pixels as an RGB numpy array.

        Returns:
            numpy.ndarray: ``uint8`` array of shape (height, width, 3) with the
            canvas pixels (alpha channel dropped). The array owns its memory,
            so callers may modify it without touching the canvas buffer.
        """
        canvas = self.fig1.canvas
        canvas.draw()
        # buffer_rgba() exposes the Agg render buffer already shaped
        # (height, width, 4). Reading the shape from the buffer itself avoids
        # np.fromstring (deprecated) and tostring_rgb (deprecated/removed), and
        # cannot disagree with get_width_height() when a device pixel ratio
        # other than 1 is in effect.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[..., :3].copy()

    def savePlot(self, filepath):  #save force plots
        """Save ``self.fig1`` to ``filepath`` and pickle the figure next to it.

        The image is written in landscape orientation with a transparent
        background. The pickle is written to ``filepath`` with its extension
        replaced by ``.pickle`` (e.g. ``plot.png`` -> ``plot.pickle``).

        Args:
            filepath: destination image path (str or path-like).
        """
        import os  # local import: only needed to split off the extension

        print("save plot")
        self.fig1.savefig(filepath, orientation='landscape', transparent=True)
        # splitext handles extensions of any length (".png", ".jpeg", none),
        # unlike slicing off a fixed four characters.
        pickle_path = os.path.splitext(filepath)[0] + '.pickle'
        # Save the figure object itself so it can be reloaded and re-edited.
        with open(pickle_path, 'wb') as f:
            pickle.dump(self.fig1, f, pickle.HIGHEST_PROTOCOL)


# class PlotWindow(QWidget):
#     def __init__(self, fig, *args, **kwargs):
# ##        super(QWidget, self).__init__(*args, **kwargs)
#         super().__init__()
#         self.setGeometry(100, 100, 1000, 500)
#         self.setWindowTitle("Plot")
#         self.fig = fig
#         self.home()

#     def home(self):
#         self.plotWidget = Plot2Widget(self.fig,cursor1_init=2,cursor2_init=6)
#         plotToolbar = NavigationToolbar(self.plotWidget, self)

#         plotGroupBox = QGroupBox()
#         plotlayout=QGridLayout()
#         plotGroupBox.setLayout(plotlayout)
#         plotlayout.addWidget(plotToolbar, 0, 0, 1, 1)
#         plotlayout.addWidget(self.plotWidget, 1, 0, 1, 1)

#         layout=QGridLayout()
#         layout.addWidget(plotGroupBox, 0, 0, 1, 1)

#         self.setLayout(layout)
# self.show()
# startFitLabel = QLabel("Start (%):")
Ejemplo n.º 45
0
def make_chart_response(country, deaths_start, avg_before_deaths, df_to_show):
    """Render the quarantine-assessment chart for ``country`` as a PNG buffer.

    Plots confirmed deaths on a log-scaled left axis against daily NO2
    pollution and an average-pollution line on a twin right axis. It marks the
    official quarantine date when one is recorded and stamps the site logo.

    Args:
        country: value matched against the ``Country`` column of
            ``all_countries_response.csv`` (read from ``data_dir``).
        deaths_start: death count quoted in the footnote text.
        avg_before_deaths: passed as ``y`` to ``seaborn.lineplot`` with
            ``data=df_to_show`` (presumably a column name -- confirm).
        df_to_show: frame with ``City``, ``Date``, ``TotalDeaths`` and ``no2``
            columns; the ``City`` of its first row is named in the footnote.

    Returns:
        BytesIO: PNG image bytes, rewound to position 0.
    """
    city = df_to_show['City'].iloc[0]
    df_quar = pd.read_csv(data_dir + 'all_countries_response.csv',
                          parse_dates=['Quarantine'])
    # A country absent from the CSV is treated like one without a recorded
    # quarantine date (handled by the isnull check below) instead of raising
    # IndexError.
    matches = df_quar.loc[df_quar['Country'] == country, 'Quarantine']
    quarantine = matches.iloc[0] if not matches.empty else pd.NaT

    week = mdates.WeekdayLocator(interval=2)  # a tick every second week
    month_fmt = mdates.DateFormatter('%b-%d')

    y_lim = df_to_show['TotalDeaths'].max() * 1.2
    y2_lim = df_to_show['no2'].max() * 1.8

    # Generate the figure **without using pyplot**.
    fig = Figure(figsize=(10, 5))

    ax = fig.subplots()

    ax.set_title('Assessing quarantine implementation - ' + country,
                 fontsize=16,
                 loc='left')

    if not pd.isnull(quarantine):
        ax.axvline(x=quarantine,
                   color='k',
                   linestyle='--',
                   lw=3,
                   label='Official quarantine')

    ax.scatter(df_to_show['Date'],
               df_to_show['TotalDeaths'],
               color='black',
               alpha=0.7,
               label='Confirmed deaths')

    ax.xaxis.set_major_locator(week)
    ax.xaxis.set_major_formatter(month_fmt)

    ax.set_yscale('log')
    # Plain numbers ("100", not "10^2") on the log axis.
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda y, _: '{:g}'.format(y)))
    ax.set_ylim(1, y_lim)
    ax.set(ylabel='Confirmed deaths')

    ax2 = ax.twinx()

    # Raw strings: "\m" is not a valid escape. The doubled braces are passed
    # through literally and mathtext still renders them as "NO".
    sns.lineplot(x="Date",
                 y='no2',
                 alpha=0.7,
                 lw=6,
                 label=r'Daily $\mathrm{{NO}}_2$ pollution *',
                 ax=ax2,
                 data=df_to_show)
    sns.lineplot(x="Date",
                 y=avg_before_deaths,
                 alpha=0.7,
                 lw=6,
                 label='Average pollution **',
                 ax=ax2,
                 data=df_to_show)

    ax2.grid(False)
    ax2.xaxis.set_major_locator(week)
    ax2.xaxis.set_major_formatter(month_fmt)
    ax2.set_ylim(1, y2_lim)
    ax2.set(ylabel=r'$\mathrm{{NO}}_2$ pollution')

    # Merge the handles of both axes into a single legend.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='upper left')

    # This text goes through str.format, so "{{NO}}" becomes "{NO}" here.
    annotation = (
        "* Median of $\\mathrm{{NO}}_2$ measurements in the most affected city "
        "({city}), 5 days rolling average over time series\n"
        "** Average daily $\\mathrm{{NO}}_2$ measurements from the beginning of "
        "2020 until the first day after {deaths_start} deaths"
    ).format(city=city, deaths_start=deaths_start)
    # Place the footnote 30 points below the bottom-left corner of the axes.
    ax.annotate(annotation, (0, 0), (0, -30),
                xycoords='axes fraction',
                textcoords='offset points',
                va='top')

    logo = plt.imread('./static/img/new_logo_site.png')
    fig.figimage(logo, 100, 110, alpha=.35, zorder=1)

    fig.tight_layout()
    # Save it to an in-memory buffer and rewind it for the caller.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf
Ejemplo n.º 46
0
# Axes filling nearly the whole figure, covering 1926-1937 and +/-5 anomaly.
ax = fig.add_axes([0.04, 0.07, 0.95, 0.92])
ax.set_xlim(xmin=1926, xmax=1937)
ax.set_ylim(ymin=-5, ymax=5)
ax.set_xticks(list(range(1927, 1936)))
ax.set_yticks(list(range(-4, 5)))
ax.set_ylabel('TMP2m anomaly')

# Zero-anomaly reference line; zorder 100 puts it under the member traces.
ax.add_line(Line2D((1926, 1937), (0, 0),
                   linewidth=1.0, color=(0, 0, 0, 0.5), alpha=1.0,
                   zorder=100))

# Light grey band marking the flood period (tall enough to span any y-range).
ax.add_patch(Rectangle((1931.5, -100), 0.3, 200,
                       fill=True, facecolor=(0, 0, 0, 0.1)))

# One faint blue trace for each of the first 80 entries of tSeries.
for member in range(80):
    ax.add_line(Line2D(dts, tSeries[member],
                       linewidth=0.5, color=(0, 0, 1, 1), alpha=0.1,
                       zorder=200))

fig.savefig('TMP2m_ts.png')
Ejemplo n.º 47
0
           frameon=False,
           subplotpars=None,
           tight_layout=None)
canvas = FigureCanvas(fig)

# Hidden layer: plot weight array 1 of the autoencoder.
plot_hidden(autoencoder.get_weights()[1])

# Global projection (module-level; presumably read by plot_weights -- confirm)
projection = ccrs.RotatedPole(pole_longitude=180.0, pole_latitude=90.0)
extent = [-180, 180, -90, 90]

# Weight arrays 0 and 2 (presumably encoder and decoder kernels): one panel
# per entry of `order`, sharing a colour range of mean +/- 3 std.
for layer in (0, 2):
    w_l = autoencoder.get_weights()[layer]
    vmin = numpy.mean(w_l) - numpy.std(w_l) * 3
    vmax = numpy.mean(w_l) + numpy.std(w_l) * 3
    count = 0  # panel index within this layer
    for channel in order:
        w_in = ic.copy()
        # Array 0 is sliced by column, array 2 by row, for the same channel.
        w_in.data = (w_l[:, channel] if layer == 0
                     else w_l[channel, :]).reshape(ic.data.shape)
        # Flip sign to match the sign of the channel's entry in array 1.
        w_in.data *= numpy.sign(autoencoder.get_weights()[1][channel])
        plot_weights(w_in, layer=layer, channel=count, nchannels=36,
                     vmin=vmin, vmax=vmax)
        count += 1

# Render the figure as a png
fig.savefig("weights.png")
Ejemplo n.º 48
0
# Zoom on 1930-1932; precipitation shown as actual values, 0-170.
ax.set_xlim(xmin=1930, xmax=1932)
ax.set_ylim(ymin=0, ymax=170)
ax.set_xticks(list(range(1927, 1936)))
#ax.set_yticks([-4,-3,-2,-1,0,1,2,3,4])
ax.set_ylabel('PRATE actual')

# Light grey band marking the flood period.
ax.add_patch(Rectangle((1931.5, -1), 0.3, 200,
                       fill=True, facecolor=(0, 0, 0, 0.1)))

# Climatology curve (clim scaled by 11) in red, under the member traces.
ax.add_line(Line2D(dts, clim * 11,
                   linewidth=1.5, color=(1, 0, 0, 1), alpha=0.5,
                   zorder=100))

# One faint blue trace for each of the first 80 entries of tSeries.
for member in range(80):
    ax.add_line(Line2D(dts, tSeries[member],
                       linewidth=0.5, color=(0, 0, 1, 1), alpha=0.1,
                       zorder=200))

fig.savefig('PRATE_ts_actuals.png')


    def train(self, sess, config):
        """Train the GAN.

        Runs until ``config.epoch`` epochs have elapsed, the coordinator asks
        to stop, or the input pipeline raises ``tf.errors.OutOfRangeError``.
        Each iteration performs the discriminator and generator updates in
        ``_run_train_iteration``. After the first iteration and every
        ``config.save_freq`` iterations, sample spectrograms and WAV files are
        written under ``config.save_path`` (see ``_write_samples``). A
        checkpoint is saved every 2000 iterations and when the epoch limit is
        reached.

        Args:
            sess: session used only to run the variable initializer; every
                other run goes through ``self.sess``.
            config: object providing ``save_path``, ``save_freq`` and
                ``epoch``.
        """
        print('initializing...opt')
        d_opt = self.d_opt
        g_opt = self.g_opt

        try:
            init = tf.global_variables_initializer()
        except AttributeError:
            # TensorFlow < 0.12 only provides the older initializer name.
            init = tf.initialize_all_variables()
        sess.run(init)

        print('initializing...var')

        train_dir = os.path.join(config.save_path, 'train')
        os.makedirs(train_dir, exist_ok=True)
        self.writer = tf.summary.FileWriter(train_dir, self.sess.graph)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        sample_low, sample_high, sample_z = self.sess.run(
            [self.gt_low[0], self.gt_high[0], self.zz_A[0]],
            feed_dict={
                self.is_valid: False,
                self.is_train: True,
                self.is_mismatch: False
            })
        # The validation fetch result is unused. The run is kept so the
        # session is driven exactly as before.
        self.sess.run(
            [self.gt_low[0], self.gt_high[0], self.zz_A[0]],
            feed_dict={
                self.is_valid: True,
                self.is_train: False,
                self.is_mismatch: False
            })

        print('sample low shape: ', sample_low.shape)
        print('sample high shape: ', sample_high.shape)
        print('sample z shape: ', sample_z.shape)

        save_path = config.save_path
        counter = 0
        # Count the examples by iterating over every record in the file.
        num_examples = sum(
            1 for _ in tf.python_io.tf_record_iterator(self.tfrecords))
        print("total num of patches in tfrecords", self.tfrecords, ":  ",
              num_examples)

        # True division under Python 3: an epoch ends once batch_idx reaches
        # ceil(num_examples / batch_size).
        num_batches = num_examples / self.batch_size
        print('batches per epoch: ', num_batches)

        if self.load(self.save_path):
            print('load success')
        else:
            print('load failed')
        batch_idx = 0
        current_epoch = 0
        batch_timings = []
        d_A_losses = []
        d_B_losses = []
        g_adv_losses = []
        g_l1_losses_BAB = []
        g_l1_losses_AB = []
        g_l1_losses_ABA = []
        g_l1_losses_BA = []

        try:
            while not coord.should_stop():
                start = timeit.default_timer()
                (d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BA_loss,
                 g_ABA_loss, g_BAB_loss) = self._run_train_iteration(
                     d_opt, g_opt)
                end = timeit.default_timer()

                batch_timings.append(end - start)
                d_A_losses.append(d_A_loss)
                d_B_losses.append(d_B_loss)
                g_adv_losses.append(g_adv_loss)
                g_l1_losses_BAB.append(g_BAB_loss)  # clean - reverb - clean
                g_l1_losses_AB.append(g_AB_loss)  # reverb - clean
                g_l1_losses_ABA.append(g_ABA_loss)  # reverb - clean  - reverb
                g_l1_losses_BA.append(g_BA_loss)  # clean - reverb

                print(
                    '{}/{} (epoch {}), d_A_loss = {:.5f}, '
                    'd_B_loss = {:.5f}, '  #d_nfk_loss = {:.5f}, '
                    'g_adv_loss = {:.5f}, g_AB_loss = {:.5f}, g_BAB_loss = {:.5f}, '
                    'g_BA_loss = {:.5f}, g_ABA_loss = {:.5f}, '
                    ' time/batch = {:.5f}, '
                    'mtime/batch = {:.5f}'.format(
                        counter, config.epoch * num_batches, current_epoch,
                        d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BAB_loss,
                        g_BA_loss, g_ABA_loss, end - start,
                        np.mean(batch_timings)))
                batch_idx += 1
                counter += 1

                # counter is at least 1 here, so no separate > 0 check needed.
                if counter % 2000 == 0:
                    self.save(config.save_path, counter)

                if counter % config.save_freq == 0 or counter == 1:
                    self._write_samples(save_path, counter)

                if batch_idx >= num_batches:
                    current_epoch += 1
                    batch_idx = 0  # reset batch idx

                if current_epoch >= config.epoch:
                    # Report the limit that was actually checked.
                    print(str(config.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.save(config.save_path, counter)
                    break

        except tf.errors.InternalError:
            print('InternalError')

        except tf.errors.OutOfRangeError:
            print('done training')
        finally:
            coord.request_stop()
        coord.join(threads)

    def _run_train_iteration(self, d_opt, g_opt):
        """Run one training iteration and return the last fetched losses.

        Performs ``self.disc_updates`` rounds of discriminator updates. Each
        round runs ``d_opt`` with ``is_mismatch`` True and then False, followed
        by ``self.d_clip`` when ``self.d_clip_weights`` is set. It then runs
        ``g_opt`` twice, with ``is_mismatch`` True and then False.

        Returns:
            tuple: (d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BA_loss,
            g_ABA_loss, g_BAB_loss). The discriminator losses are NaN when
            ``self.disc_updates`` is 0.
        """
        d_A_loss = d_B_loss = float('nan')
        for _ in range(self.disc_updates):
            for mismatch in (True, False):
                _, d_A_loss, d_B_loss = self.sess.run(
                    [d_opt, self.d_A_losses[0], self.d_B_losses[0]],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True,
                        self.is_mismatch: mismatch
                    })
            if self.d_clip_weights:
                self.sess.run(self.d_clip,
                              feed_dict={
                                  self.is_valid: False,
                                  self.is_train: True
                              })

        g_fetches = [
            g_opt, self.g_adv_losses[0], self.g_losses_AB[0],
            self.g_losses_BA[0], self.g_l1_losses_ABA[0],
            self.g_l1_losses_BAB[0]
        ]
        for mismatch in (True, False):
            (_, g_adv_loss, g_AB_loss, g_BA_loss, g_ABA_loss,
             g_BAB_loss) = self.sess.run(
                 g_fetches,
                 feed_dict={
                     self.is_valid: False,
                     self.is_train: True,
                     self.is_mismatch: mismatch
                 })
        return (d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BA_loss,
                g_ABA_loss, g_BAB_loss)

    def _write_samples(self, save_path, counter):
        """Write sample spectrogram PNGs and validation WAVs tagged ``counter``.

        Files go to ``<save_path>/spec`` and ``<save_path>/wav``; the
        ``wav``, ``txt`` and ``spec`` sub-directories are created if missing.
        """
        s_A, s_B, s_reverb, s_gt, ori_phase, rev_phase = self.sess.run(
            [
                self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                self.gt_low[0][0, :, :, :],
                self.gt_high[0][0, :, :, :],
                self.ori_phase_[0][0, :, :, :],
                self.rev_phase_[0][0, :, :, :]
            ],
            feed_dict={
                self.is_valid: True,
                self.is_train: False,
                self.is_mismatch: False
            })

        for sub_dir in ('wav', 'txt', 'spec'):
            os.makedirs(save_path + '/' + sub_dir, exist_ok=True)

        print(str(counter) + 'th finished')

        self._save_spectrogram_png(
            (s_reverb, s_gt, s_A, s_B),
            save_path + '/spec/' + 'valid_batch_index' + str(counter) +
            '-th_pr.png')

        # GG_A output and gt_low are resynthesised with rev_phase_; GG_B
        # output and gt_high with ori_phase_.
        for suffix, spec, phase in (
                ('_AB(dereverb).wav', s_A, rev_phase),
                ('_BA(reverb).wav', s_B, ori_phase),
                ('_reverb.wav', s_reverb, rev_phase),
                ('_orig.wav', s_gt, ori_phase)):
            audio = librosa.istft(self.inv_magphase(spec, phase))
            librosa.output.write_wav(
                save_path + '/wav/' + str(counter) + suffix, audio, 16000)

        t_AB, t_BA, t_reverb, t_gt = self.sess.run(
            [
                self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                self.gt_low[0][0, :, :, :],
                self.gt_high[0][0, :, :, :]
            ],
            feed_dict={
                self.is_valid: False,
                self.is_train: True,
                self.is_mismatch: False
            })
        self._save_spectrogram_png(
            (t_reverb, t_gt, t_AB, t_BA),
            save_path + '/spec/' + 'train_batch_index' + str(counter) +
            '-th_pr.png')

    def _save_spectrogram_png(self, arrays, path):
        """Tile ``self.get_spectrum`` of each array side by side into a PNG.

        Each spectrum is reshaped to (512, 128) before concatenation along
        axis 1; the image is saved with one pixel per element.
        """
        S = np.concatenate(
            [self.get_spectrum(x).reshape(512, 128) for x in arrays], axis=1)
        # figsize (in inches) equal to (width, height) at dpi=1 gives one
        # pixel per array element.
        fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
        FigureCanvas(fig)  # attach a canvas so savefig works without pyplot
        fig.figimage(S, cmap='jet')
        fig.savefig(path)
Ejemplo n.º 50
0
def make_chart_comparison(df_places_to_show,
                          level='countries',
                          scale='log',
                          y='total',
                          mode='static'):
    """Render a COVID-19 cases-vs-deaths comparison chart and return it as PNG.

    Confirmed cases are drawn as a scatter plot on the left y-axis and deaths
    as line plots on a twin right y-axis, both against the ``DayAdj`` column
    and coloured per place.

    :param df_places_to_show: DataFrame with the columns ``Name``, ``Date``,
        ``gap``, ``DayAdj``, ``Total<suffix>`` and ``TotalDeaths<suffix>``
        (``<suffix>`` is ``''`` or ``'Per100k'``, see *y*). It must contain at
        least two distinct names: the first two are quoted in the x label.
    :param level: ``'countries'`` labels places "Country"; any other value
        labels them "City".
    :param scale: ``'log'`` for logarithmic y-axes; anything else is linear.
    :param y: ``'total'`` plots absolute counts; any other value plots the
        per-100k columns.
    :param mode: accepted for interface compatibility; not used here.
    :return: an ``io.BytesIO`` rewound to position 0 holding the PNG image.
    """
    months = mdates.MonthLocator()  # a major tick at the start of every month
    month_fmt = mdates.DateFormatter('%b-%d')

    per_100k = y != 'total'
    var_y_suffix = 'Per100k' if per_100k else ''
    label_y_y = ' per 100k' if per_100k else ''
    label_y_scale = ' (log)' if scale == 'log' else ''
    # A log axis needs a strictly positive lower limit.
    y_bottom = 0.5 if scale == 'log' else -5

    # Most recent date present in the data; shown in the title.
    date = df_places_to_show['Date'].max()

    # Day shift quoted in the x label ("shifted N days to align death curves").
    gap = int(df_places_to_show['gap'].min())

    # Both y axes share the same upper limit: the largest case value.
    y_lim = df_places_to_show['Total' + var_y_suffix].max()

    # Generate the figure **without using pyplot**.
    fig = Figure(figsize=(8, 5))
    ax = fig.subplots()

    places_to_show = df_places_to_show['Name'].unique()[:2]

    place_name = 'Country' if level == 'countries' else 'City'
    df_places_to_show = df_places_to_show.rename(columns={'Name': place_name})

    ax.set_title('{} Comparison - COVID-19 Cases vs. Deaths - {}'.format(
        place_name, date),
                 fontsize=14)

    # Confirmed cases: scatter on the primary (left) axis.
    sns.scatterplot(x="DayAdj",
                    y='Total' + var_y_suffix,
                    hue=place_name,
                    lw=6,
                    alpha=0.8,
                    data=df_places_to_show,
                    ax=ax)

    ax.xaxis.set_major_locator(months)
    ax.xaxis.set_major_formatter(month_fmt)

    ax.legend(loc='upper left', title="Confirmed cases", frameon=True)

    ax.set(
        ylabel='Total confirmed cases{}{}'.format(label_y_y, label_y_scale),
        xlabel="Date for {} ({}'s data shifted {} days to align death curves)".
        format(places_to_show[0], places_to_show[1], gap))

    ax.set_ylim(y_bottom, y_lim)

    # Deaths: lines on a twin (right) axis that shares the x axis.
    ax2 = ax.twinx()

    if scale == 'log':
        ax.set_yscale('log')
        ax2.set_yscale('log')

    # Format tick values with the compact '{:g}' style on both y axes.
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda value, _: '{:g}'.format(value)))
    ax2.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda value, _: '{:g}'.format(value)))

    ax2.grid(False)

    sns.lineplot(x="DayAdj",
                 y='TotalDeaths' + var_y_suffix,
                 hue=place_name,
                 alpha=0.7,
                 lw=6,
                 ax=ax2,
                 data=df_places_to_show)

    ax2.legend(loc='lower right', title="Deaths", frameon=True)

    ax2.set(ylabel='Total deaths{}{}'.format(label_y_y, label_y_scale))

    ax2.set_ylim(y_bottom, y_lim)

    # Semi-transparent logo watermark placed at pixel offset (95, 70).
    logo = plt.imread('./static/img/new_logo_site.png')
    fig.figimage(logo, 95, 70, alpha=.35, zorder=1)

    fig.tight_layout()

    # Save it to an in-memory buffer and rewind it for the caller.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf
Ejemplo n.º 51
0
def _plot_training_curve(ax, series, ylabel, title=None, ylim=None):
    """Plot raw values as faint dots plus their ``ma_conv`` curve, then label *ax*."""
    ax.plot(series, '.', alpha=0.25)
    ax.plot(helpers.stats.ma_conv(series, 0))
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(ylim)


def visualize_manipulation_training(flow, epoch, save_dir=None):
    """
    Visualize progress of manipulation training.

    Draws a 3-column grid: NIP loss / PSNR / SSIM, FAN loss / accuracy /
    confusion matrix and, when ``flow.codec`` is a ``DCN``, a third row with
    the codec's loss / SSIM / entropy.

    :param flow: the manipulation-training flow; this function reads
        ``flow.nip``, ``flow.fan`` and ``flow.codec`` (their ``performance``
        dicts, plus ``flow.nip.class_name``), ``flow.n_classes`` and
        ``flow._forensics_classes``
    :param epoch: epoch counter to be appended to the output filename
    :param save_dir: directory where ``manip_validation_{epoch:05d}.jpg`` is
        written (created if missing); if None the figure is returned instead
    :return: None (if output to file requested) or the Figure
    """
    # Basic figure setup
    has_dcn = isinstance(flow.codec, DCN)
    images_x = 3
    images_y = 3 if has_dcn else 2
    fig = Figure(figsize=(18, 10 / images_x * images_y))
    conf = np.array(flow.fan.performance['confusion'])

    nip_perf = flow.nip.performance
    fan_perf = flow.fan.performance

    # Row 1: NIP training loss and validation quality
    _plot_training_curve(fig.add_subplot(images_y, images_x, 1),
                         nip_perf['loss']['training'],
                         '{} NIP loss'.format(flow.nip.class_name),
                         title='Loss')
    _plot_training_curve(fig.add_subplot(images_y, images_x, 2),
                         nip_perf['psnr']['validation'],
                         '{} NIP psnr'.format(flow.nip.class_name),
                         title='PSNR',
                         ylim=[30, 50])
    _plot_training_curve(fig.add_subplot(images_y, images_x, 3),
                         nip_perf['ssim']['validation'],
                         '{} NIP ssim'.format(flow.nip.class_name),
                         title='SSIM',
                         ylim=[0.8, 1])

    # Row 2: FAN loss, accuracy and confusion matrix
    _plot_training_curve(fig.add_subplot(images_y, images_x, 4),
                         fan_perf['loss']['training'], 'FAN loss')
    _plot_training_curve(fig.add_subplot(images_y, images_x, 5),
                         fan_perf['accuracy']['validation'],
                         'FAN accuracy',
                         ylim=[0, 1])

    ax = fig.add_subplot(images_y, images_x, 6)
    ax.imshow(conf, vmin=0, vmax=1)

    ax.set_xticks(range(flow.n_classes))
    ax.set_xticklabels(flow._forensics_classes, rotation='vertical')
    ax.set_yticks(range(flow.n_classes))
    ax.set_yticklabels(flow._forensics_classes)

    # Annotate the diagonal (per-class accuracy); text colour flips for contrast.
    for r in range(flow.n_classes):
        ax.text(r,
                r,
                '{:.2f}'.format(conf[r, r]),
                horizontalalignment='center',
                color='b' if conf[r, r] > 0.5 else 'w')

    ax.set_xlabel('PREDICTED class')
    ax.set_ylabel('TRUE class')
    ax.set_title('Accuracy: {:.2f}'.format(np.mean(np.diag(conf))))

    # Row 3 (only for a trainable DCN codec): its validation metrics
    if has_dcn:
        codec_perf = flow.codec.performance
        _plot_training_curve(fig.add_subplot(images_y, images_x, 7),
                             codec_perf['loss']['validation'], 'DCN loss')
        _plot_training_curve(fig.add_subplot(images_y, images_x, 8),
                             codec_perf['ssim']['validation'],
                             'DCN ssim',
                             ylim=[0.8, 1])
        _plot_training_curve(fig.add_subplot(images_y, images_x, 9),
                             codec_perf['entropy']['validation'],
                             'DCN entropy')

    if save_dir is None:
        return fig

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    fig.savefig('{}/manip_validation_{:05d}.jpg'.format(save_dir, epoch),
                bbox_inches='tight',
                dpi=100)
    return None
Ejemplo n.º 52
0
def vis_points(run=None,
               d=None,
               sdata_file=None,
               y_target=0,
               seed=None,
               blocksize=None,
               highlight_block=None):
    """Render saved point snapshots of an experiment to PNGs, then make a movie.

    Every file in *d* named ``step*_X.npy`` is drawn as a scatter plot and
    saved next to it as ``step*_X.png``. An optional ``*_IX.npy`` companion is
    overlaid as black dots. When *sdata_file* is given, the ground-truth
    locations ``sdata.SX`` are also rendered to ``true.png``. Finally
    ``avconv`` is run in *d* to assemble ``step_%05d_X.png`` into
    ``gprf.mp4``.

    :param run: experiment id, used to locate the directory when *d* is None
    :param d: experiment directory (defaults to ``exp_dir(run)``)
    :param sdata_file: pickle holding the dataset object; needed for the
        ground-truth frame and for any colouring other than plain (``c=None``)
    :param y_target: point colouring; ``>= 0`` colours by column
        ``sdata.SY[:, y_target]`` with ``vmin=-3, vmax=3``; ``-1`` by distance
        from the true location; ``-2`` by blocks from
        ``sdata.cluster_rpc(blocksize)`` (after seeding numpy with *seed*);
        ``-3`` by blocks from ``grid_centers(blocksize)``
    :param seed: numpy seed used before clustering when ``y_target == -2``
    :param blocksize: block size for ``y_target`` in ``(-2, -3)``
    :param highlight_block: index of a single block to highlight (others grey)
    """
    if d is None:
        d = exp_dir(run)

    sdata = None
    if sdata_file is not None:
        with open(sdata_file, 'rb') as f:
            # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
            sdata = pickle.load(f)

    # The pseudo-file "true.xxx" renders the ground truth, so it is only
    # processed when sdata is available (previously it raised NameError).
    fnames = sorted(os.listdir(d))
    if sdata is not None:
        fnames = ["true.xxx"] + fnames

    for fname in fnames:
        if fname == "true.xxx":
            X = sdata.SX
        elif fname.startswith("step") and fname.endswith("_X.npy"):
            X = np.load(os.path.join(d, fname))
        else:
            continue

        # Optional induced/inducing points stored alongside the snapshot.
        try:
            ix_fname = fname.replace("_X", "_IX")
            IX = np.load(os.path.join(d, ix_fname))
        except (IOError, OSError, ValueError):
            IX = None

        fig = Figure(dpi=144, figsize=(14, 14))
        fig.patch.set_facecolor('white')
        ax = fig.add_subplot(111)

        cmap = "jet"
        sargs = {}
        if y_target == -1:
            # plot "wrongness": distance of each point from its true location
            c = np.sqrt(np.sum((X - sdata.SX)**2, axis=1))
            cmap = "hot"
        elif y_target in (-2, -3):
            # plot blocks
            c = np.zeros((X.shape[0]))

            if y_target == -2:
                np.random.seed(seed)
                sdata.cluster_rpc(blocksize)
            else:
                centers = grid_centers(blocksize)
                sdata.set_centers(centers)

            cmap = "prism"
            if highlight_block is not None:
                block_colors = np.ones((len(sdata.block_idxs), )) * 0.4
                block_colors[highlight_block] = 0.0
            else:
                block_colors = np.linspace(0.0, 1.0, len(sdata.block_idxs))
            block_idxs = sdata.reblock(X)
            for i, idxs in enumerate(block_idxs):
                c[idxs] = block_colors[i]
        elif sdata_file is None:
            c = None
        else:
            c = sdata.SY[:, y_target:y_target + 1].flatten()
            sargs['vmin'] = -3.0
            sargs['vmax'] = 3.0

        # Scale coordinates to a sqrt(npts) x sqrt(npts) plotting box. A new
        # array is built (not `*=`) so sdata.SX is never modified in place;
        # later frames compare against it and re-block with it.
        npts = len(X)
        xmax = np.sqrt(npts)
        X = X * xmax

        if IX is not None:
            IX = IX * xmax
            ax.scatter(IX[:, 0],
                       IX[:, 1],
                       alpha=1.0,
                       c="black",
                       s=25,
                       marker='o',
                       linewidths=0.0,
                       **sargs)

        ax.scatter(X[:, 0],
                   X[:, 1],
                   alpha=1.0,
                   c=c,
                   cmap=cmap,
                   s=70,
                   marker='.',
                   linewidths=0.0,
                   **sargs)
        ax.set_xlim((0, xmax))
        ax.set_ylim((0, xmax))

        ax.set_yticks([20, 40, 60, 80, 100])

        ax.tick_params(axis='x', labelsize=30)
        ax.tick_params(axis='y', labelsize=30)

        # Attach an Agg canvas to the figure before saving.
        FigureCanvasAgg(fig)

        out_name = os.path.join(d, fname[:-4] + ".png")
        fig.savefig(out_name, bbox_inches="tight")
        print("wrote " + out_name)

    print("generating movie...:")
    cmd = "avconv -f image2 -r 5 -i step_%05d_X.png -qscale 28 gprf.mp4".split(
        " ")
    import subprocess
    p = subprocess.Popen(cmd, cwd=d)
    p.wait()
    print("done")
Ejemplo n.º 53
0
# The x-axis label is intentionally left unset: ax.set_xlabel('Year')
# add a label to the y axis
ax.set_ylabel('Temperature in Celsius')
# add a title
ax.set_title(
    'Yearly Average Temperatures measured in Santacruz, Bombay for the years 1973-2019'
)

# Hide the frame (all four spines). This is a plain loop because the call is
# made only for its side effect.
for spine in ax.spines.values():
    spine.set_visible(False)

# Enlarge the figure by 50% in both dimensions before saving.
DefaultSize = fig.get_size_inches()

fig.set_size_inches((DefaultSize[0] * 1.5, DefaultSize[1] * 1.5))

fig.savefig('weatherplot.png')

# Plotting daily average temperatures over different decades

# YEAR, MONTH and DAY are compared against string literals below, so they are
# string columns here (the year range check is a lexicographic comparison).
santacruz_df80s = santacruz_df[(santacruz_df['YEAR'] >= '1980')
                               & (santacruz_df['YEAR'] < '1990')]
print(santacruz_df80s)

# Drop leap days so every year contributes the same 365 (MONTH, DAY) keys.
santacruz_df80s = santacruz_df80s.drop(
    santacruz_df80s[(santacruz_df80s['MONTH'] == '02')
                    & (santacruz_df80s['DAY'] == '29')].index)

# Mean TAVG per calendar day across the decade, in a column named AVERAGE.
daily_avg_80s = santacruz_df80s.groupby(
    ['MONTH', 'DAY'])['TAVG'].agg(AVERAGE='mean').reset_index()
print(daily_avg_80s)
Ejemplo n.º 54
0
def matplotlib(pltid):
    """Generate a plot image with Matplotlib for plot *pltid* and display it.

    Reads the plot's title, type and options from the ``plots`` table and its
    data file/columns from the ``datasource`` table, draws a line
    (``'mpl-line'``) or bar (``'mpl-bar'``) chart, saves it as
    ``<title>.png`` in the case directory and renders the
    ``plots/matplotlib`` template.

    The ``app`` and ``cid`` values come from the request query string.

    :param pltid: id of the row in the ``plots`` table
    :return: the rendered template, or an ``"ERROR: ..."`` string when the
        plot is unknown, has no data source, or its type is unsupported
    """
    # in the future create a private function __import__ to import third-party
    # libraries, so that it can respond gracefully.  See for example the
    # Examples section at https://docs.python.org/2/library/imp.html
    user = root.authorized()
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    app = request.query.app
    cid = request.query.cid

    fig = Figure()
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)

    # get info about plot from db
    p = Plot()
    result = db(plots.id == pltid).select().first()
    if result is None:
        return "ERROR: plot not found"
    plot_title = result['title']
    plottype = result['ptype']
    options = result['options']

    # parse plot options (e.g. xaxis: {axisLabel: "time"}) to set axis labels
    for axis_key, set_label in (("xaxis", ax.set_xlabel),
                                ("yaxis", ax.set_ylabel)):
        m = re.search(axis_key + r":\s*{(.*)}", options)
        if m:
            n = re.search(r'axisLabel:\s*"(\w*)"', m.group(1))
            if n:
                set_label(n.group(1))

    # get info about data source
    # fix in the future to handle multiple data sources (the last row wins)
    plotfn = None
    result = db(datasource.pltid == pltid).select()
    for r in result:
        plotfn = r['filename']
        cols = r['cols']
        line_range = r['line_range']
        (col1str, col2str) = cols.split(":")
        col1 = int(col1str)
        col2 = int(col2str)
        if line_range is not None:
            # to prevent breaking current spc apps, still support
            # expressions like 1:1000, but in the future this should
            # be changed to a range 1-1000.  Therefore, using : is deprecated
            # and will be removed in the future.
            # NOTE(review): the parsed range is not applied to the data yet.
            (line1str, line2str) = re.split("[-:]", line_range)

    if plotfn is None:
        return "ERROR: no data source defined for this plot"

    # NOTE(review): cid (from the query string) and plot_title (from the db)
    # are joined into filesystem paths without validation; consider
    # rejecting path separators / '..' before using them.
    plotfn = re.sub(r"<cid>", cid, plotfn)
    sim_dir = os.path.join(user_dir, user, app, cid)
    plotpath = os.path.join(sim_dir, plotfn)
    # convert elements from strings to floats
    xx = [float(i) for i in p.get_column_of_data(plotpath, col1)]
    yy = [float(i) for i in p.get_column_of_data(plotpath, col2)]

    # plot
    if plottype == 'mpl-line':
        ax.plot(xx, yy)
    elif plottype == 'mpl-bar':
        ax.bar(xx, yy)
    else:
        return "ERROR: plottype not supported"
    # Attach an Agg canvas so fig.savefig() below can render without pyplot.
    FigureCanvas(fig)

    # save file
    tmp_dir = "static/tmp"
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)
    fn = plot_title + '.png'
    fig.set_size_inches(7, 4)
    img_path = os.path.join(sim_dir, fn)
    fig.savefig(img_path)

    # get list of all plots for this app
    query = (apps.id == plots.appid) & (apps.name == app)
    list_of_plots = db(query).select()
    stats = compute_stats(plotpath)

    params = {
        'image': fn,
        'app': app,
        'cid': cid,
        'pltid': pltid,
        'plotpath': plotpath,
        'img_path': img_path,
        'plot_title': plot_title,
        'rows': list_of_plots,
        'stats': stats
    }
    return template('plots/matplotlib', params)
Ejemplo n.º 55
0
def thumbnail(infile,
              thumbfile,
              scale=0.1,
              interpolation='bilinear',
              preview=False):
    """
    Make a thumbnail of the image in *infile* with output filename
    *thumbfile*.

      *infile*
        the image file -- must be PNG or PIL readable if you
        have `PIL <http://www.pythonware.com/products/pil/>`_ installed;
        grayscale (2-D) as well as RGB/RGBA images are accepted

      *thumbfile*
        the thumbnail filename

      *scale*
        the scale factor for the thumbnail

      *interpolation*
        the interpolation scheme used in the resampling

      *preview*
        if True, the default backend (presumably a user interface
        backend) will be used which will cause a figure to be raised
        if :func:`~matplotlib.pyplot.show` is called.  If it is False,
        a pure image backend will be used depending on the extension,
        'png'->FigureCanvasAgg, 'pdf'->FigureCanvasPdf,
        'svg'->FigureCanvasSVG

    Raises ValueError if *preview* is False and the extension of
    *thumbfile* is not png, pdf or svg.

    See examples/misc/image_thumbnail.py.

    .. htmlonly::

        :ref:`misc-image_thumbnail`

    Return value is the figure instance containing the thumbnail

    """
    extension = os.path.splitext(thumbfile)[1].lower()

    im = imread(infile)
    # Only rows/cols matter; the colour axis is optional so that grayscale
    # (2-D) images work as well as RGB/RGBA (3-D) ones.
    rows, cols = im.shape[:2]

    # this doesn't really matter, it will cancel in the end, but we
    # need it for the mpl API
    dpi = 100

    height = float(rows) / dpi * scale
    width = float(cols) / dpi * scale

    if preview:
        # let the UI backend do everything
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(width, height), dpi=dpi)
    else:
        if extension == '.png':
            from matplotlib.backends.backend_agg \
                import FigureCanvasAgg as FigureCanvas
        elif extension == '.pdf':
            from matplotlib.backends.backend_pdf \
                import FigureCanvasPdf as FigureCanvas
        elif extension == '.svg':
            from matplotlib.backends.backend_svg \
                import FigureCanvasSVG as FigureCanvas
        else:
            raise ValueError("Can only handle "
                             "extensions 'png', 'svg' or 'pdf'")

        from matplotlib.figure import Figure
        fig = Figure(figsize=(width, height), dpi=dpi)
        # Attach the extension-specific canvas to the figure.
        FigureCanvas(fig)

    # One axes filling the whole figure, with no frame or ticks.
    ax = fig.add_axes([0, 0, 1, 1],
                      aspect='auto',
                      frameon=False,
                      xticks=[],
                      yticks=[])

    # A 2-D array would otherwise be false-coloured with the default colormap.
    extra = {'cmap': 'gray'} if im.ndim == 2 else {}
    ax.imshow(im,
              aspect='auto',
              resample=True,
              interpolation=interpolation,
              **extra)
    fig.savefig(thumbfile, dpi=dpi)
    return fig
Ejemplo n.º 56
0
def validate_nip(model,
                 data,
                 save_dir=False,
                 epoch=0,
                 show_ref=False,
                 loss_type='L2'):
    """
    Develops image patches using the given NIP and returns standard image quality measures.
    If requested, resulting patches are visualized as thumbnails and saved to a directory.

    :param model: the NIP model (``model.process(x).numpy()`` yields the developed batch)
    :param data: the dataset (instance of Dataset)
    :param save_dir: path to the directory where figures should be generated; a falsy
        value (the default ``False``, ``None`` or ``''``) disables plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the developed image or also the GT target
    :param loss_type: L1 or L2
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losss)
    :raises ValueError: if loss_type is neither 'L1' nor 'L2'
    """
    # Validate up front instead of after the first (expensive) model call.
    if loss_type not in ('L1', 'L2'):
        raise ValueError('Invalid loss! Use either L1 or L2.')

    ssims = []
    psnrs = []
    losss = []

    n_images = data.count_validation
    # The previous `is not None` check treated the default False as a path
    # and wrote into a directory literally named "False".
    save_figure = bool(save_dir) and n_images > 0

    # If requested, plot a figure with output/target pairs
    if save_figure:
        images_x = min(n_images, 10 if not show_ref else 5)
        # Integer ceil division: add_subplot needs ints, and this is also
        # correct under Python 2's integer '/'.
        images_y = -(-n_images // images_x)
        fig = Figure(figsize=(20, 20.0 / images_x * images_y *
                              (1 if not show_ref else 0.5)))

    for b in range(n_images):
        example_x, example_y = data.next_validation_batch(b, 1)
        developed = model.process(example_x).numpy().clip(0, 1).squeeze()
        reference = example_y.squeeze()

        # Compute stats
        ssim = metrics.ssim(reference, developed).mean()
        psnr = metrics.psnr(reference, developed).mean()

        if loss_type == 'L2':
            loss = np.mean(np.power(reference - developed, 2.0))
        else:
            loss = np.mean(np.abs(reference - developed))

        ssims.append(ssim)
        psnrs.append(psnr)
        losss.append(loss)

        # Add images to the plot
        if save_figure:
            ax = fig.add_subplot(images_y, images_x, b + 1)
            panel = (np.concatenate((reference, developed), axis=1)
                     if show_ref else developed)
            plots.image(panel, '{:.1f} dB / {:.2f}'.format(psnr, ssim), axes=ax)

    if save_figure:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        fig.savefig('{}/nip_validation_{:05d}.jpg'.format(save_dir, epoch),
                    bbox_inches='tight',
                    dpi=100,
                    quality=90)

    return ssims, psnrs, losss
Ejemplo n.º 57
0
				
- scripting layer: user layer
"""

# example 1
# generate a histogram of random numbers using the artist layer
import numpy as np  # to generate random numbers
# FigureCanvas comes from the backend layer; Agg (Anti-Grain Geometry) is a
# high-performance rendering library.
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure  # the Figure artist

fig = Figure()
canvas = FigureCanvas(fig)  # attach the Figure artist to the Figure canvas

x = np.random.randn(10000)
# Create an Axes artist; it is added automatically to the figure's axes container.
ax = fig.add_subplot(111)
ax.hist(x, 100)  # histogram of the 10000 points in 100 bins
ax.set_title('Histo Example')
fig.savefig('matplotlib_histogram.png')

# example 2
# generate a histogram of random numbers using the scripting layer (pyplot)
import matplotlib.pyplot as plt

x = np.random.randn(10000)
plt.hist(x, 100)
plt.title('Histo example using script layer')
plt.show()
Ejemplo n.º 58
0
def specphot_acs_and_wfc3(
        id=69,
        grism_root='ibhm45030',
        MAIN_OUTPUT_FILE='cosmos-1.v4.6',
        OUTPUT_DIRECTORY='/Users/gbrammer/research/drg/PHOTZ/EAZY/NEWFIRM/v4.6/OUTPUT_KATE/',
        CACHE_FILE='Same',
        Verbose=False,
        SPC=None,
        cat=None,
        grismCat=None,
        zout=None,
        fout=None,
        OUT_PATH='/tmp/',
        OUT_FILE_FORMAT=True,
        OUT_FILE='junk.png',
        GET_SPEC_ONLY=False,
        GET_WFC3=False,
        WFC3_DIR='/3DHST/Spectra/Release/v2.0/GOODS-S'):
    """
    specphot_acs_and_wfc3(id)

    Get photometry/SED fit as well as WFC3 spectrum when available and overplot G141 spectrum.
    This is different from unicorn.analysis.specphot() which does not get the WFC3 spectrum.

    Outline of what the code below does:
      1. Read the grism spectrum of object *id* from *SPC* (opened from
         ``<grism_root>_2_opt.SPC.fits`` when *SPC* is None).
      2. Find the object in the grism catalog (``DATA/<grism_root>_drz.cat``
         when *grismCat* is None) and match it by position to the nearest
         source in the EAZY catalog *cat*.
      3. Read the EAZY SED fit for that source and, when *GET_WFC3* is set,
         the first WFC3 1D spectrum globbed under *WFC3_DIR*.
      4. Scale the contamination-subtracted spectrum to the template using
         pixels with 0.55e4 < LAMBDA < 1.0e4 and positive flux.
      5. Plot spectrum, photometry and template and save the figure to
         ``OUT_PATH + '/' + out_file``.

    Returns
      * False if the object has no spectrum, the positional match distance
        is > 2 (arcsec, assuming ra/dec are in degrees), or no pixels pass
        the selection of step 4;
      * with *GET_SPEC_ONLY*: False if the match is > 1, otherwise the tuple
        ``(lam, ffix * anorm, total_err, lci, fobs, efobs, photom_idx)``;
      * otherwise None, after writing the plot.

    NOTE(review): relies on module-level names not defined here
    (threedhst, unicorn, np, glob, plt, Figure, FigureCanvasAgg, USE_PLOT_GUI).
    The parameter name ``id`` shadows the builtin.
    """

    import threedhst.eazyPy as eazy
    import threedhst.catIO as catIO
    import pyfits

    #### Get G141 spectrum
    if Verbose:
        print 'Read SPC'

    if SPC is None:
        SPC = threedhst.plotting.SPCFile(grism_root + '_2_opt.SPC.fits',
                                         axe_drizzle_dir='DRIZZLE_G141')

    spec = SPC.getSpec(id)
    if spec is False:
        return False

    # x-axis (wavelength) range of the final plot
    xmin = 3000
    xmax = 2.4e4

    lam = spec.field('LAMBDA')
    flux = spec.field('FLUX')
    # contamination-subtracted flux
    ffix = flux - spec.field('CONTAM')
    ferr = spec.field('FERROR')  #*0.06/0.128254

    if Verbose:
        print 'Read grism catalog'

    #### Read the grism catalog and get coords of desired object
    if grismCat is None:
        grismCat = threedhst.sex.mySexCat('DATA/' + grism_root + '_drz.cat')

    #### Source size
    # NOTE(review): np.cast was removed in NumPy 2.0; np.asarray(..., dtype=float)
    # is the equivalent there.
    R = np.sqrt(np.cast[float](grismCat.A_IMAGE) *
                np.cast[float](grismCat.B_IMAGE))
    grism_idx = np.where(grismCat.id == id)[0][0]

    # NOTE(review): Rmatch is computed but not used below.
    Rmatch = R[grism_idx] * 1.

    ra0 = grismCat.ra[grismCat.id == id][0]
    de0 = grismCat.dec[grismCat.id == id][0]

    #### Read EAZY outputs and get info for desired object
    if cat is None:
        cat = catIO.ReadASCIICat(OUTPUT_DIRECTORY + '../' + MAIN_OUTPUT_FILE +
                                 '.cat')

    # Separation from every catalog source; RA difference scaled by cos(dec),
    # then * 3600 (arcsec if ra/dec are in degrees -- presumably; confirm).
    dr = np.sqrt((cat.ra - ra0)**2 * np.cos(de0 / 360. * 2 * np.pi)**2 +
                 (cat.dec - de0)**2) * 3600.

    # nearest photometric counterpart
    photom_idx = np.where(dr == np.min(dr))[0][0]

    drMatch = dr[photom_idx] * 1.

    # give up when the nearest counterpart is too far away
    if drMatch > 2:
        return False

    if Verbose:
        print 'Read zout'
    if zout is None:
        zout = catIO.ReadASCIICat(OUTPUT_DIRECTORY + '/' + MAIN_OUTPUT_FILE +
                                  '.zout')

    if fout is None:
        fout = catIO.ReadASCIICat(OUTPUT_DIRECTORY +
                                  '/../cosmos-1.m05.v4.6.fout')

    if Verbose:
        print 'Read binaries'

    lambdaz, temp_sed, lci, obs_sed, fobs, efobs = \
        eazy.getEazySED(photom_idx, MAIN_OUTPUT_FILE=MAIN_OUTPUT_FILE, \
                          OUTPUT_DIRECTORY=OUTPUT_DIRECTORY, \
                          CACHE_FILE = CACHE_FILE)

    # Smooth the template with the object's thumbnail; on any failure fall back
    # to an unsmoothed copy. NOTE(review): the bare except also hides real errors.
    try:
        lambdaz, temp_sed_sm = unicorn.analysis.convolveWithThumb(
            id, lambdaz, temp_sed, SPC)
    except:
        temp_sed_sm = temp_sed * 1.

    wfc3_exist = False

    # Optional WFC3 1D spectrum, located by the zero-padded photometric id.
    if GET_WFC3:
        wfc3_file_path = WFC3_DIR + "/*/1D/FITS/*%05d.1D.fits" % (
            cat.id[photom_idx])
        wfc3_file = glob.glob(wfc3_file_path)
        if wfc3_file != []:
            wfc3_spec = pyfits.open(wfc3_file[0])
            wfc3_exist = True
        else:
            print 'No WFC3 spectrum.'

    if Verbose:
        print 'Normalize spectrum'

    # Pixels used for normalisation: 0.55e4 < lam < 1.0e4 with positive flux.
    q = np.where((lam > 0.55e4) & (lam < 1.0e4) & (flux > 0))[0]

    if len(q) == 0:
        return False

    # template evaluated at the selected spectrum wavelengths
    yint = np.interp(lam[q], lambdaz, temp_sed_sm)

    # least-squares scale factor of the spectrum onto the template
    anorm = np.sum(yint * ffix[q]) / np.sum(ffix[q]**2)
    if np.isnan(anorm):
        anorm = 1.
    # flux error combined in quadrature with the contamination, then scaled
    total_err = np.sqrt((ferr)**2 + (1.0 * spec.field('CONTAM'))**2) * anorm

    if GET_SPEC_ONLY:
        if drMatch > 1:
            return False
        else:
            return lam, ffix * anorm, total_err, lci, fobs, efobs, photom_idx

    if Verbose:
        print 'Start plot'

    #### Make the plot
    threedhst.plotting.defaultPlotParameters()

    xs = 5.8
    ys = xs / 4.8 * 3.2
    # USE_PLOT_GUI (module-level) selects pyplot vs. a bare Figure + Agg canvas.
    if USE_PLOT_GUI:
        fig = plt.figure(figsize=[xs, ys], dpi=100)
    else:
        fig = Figure(figsize=[xs, ys], dpi=100)

    fig.subplots_adjust(wspace=0.2,
                        hspace=0.2,
                        left=0.13 * 4.8 / xs,
                        bottom=0.15 * 4.8 / xs,
                        right=1. - 0.02 * 4.8 / xs,
                        top=1 - 0.10 * 4.8 / xs)

    ax = fig.add_subplot(111)

    # y-axis scale is set from the normalised SPC spectrum
    ymax = np.max((ffix[q]) * anorm)

    if Verbose:
        print 'Make the plot'

    ax.plot(lambdaz, temp_sed_sm, color='red')
    ax.plot(lam[q], ffix[q] * anorm, color='blue', alpha=0.2, linewidth=1)

    #### Show own extraction
    sp1d = threedhst.spec1d.extract1D(id,
                                      root=grism_root,
                                      path='./HTML',
                                      show=False,
                                      out2d=False)  #, GRISM_NAME='G800L')
    lam = sp1d['lam']
    flux = sp1d['flux']
    ffix = sp1d['flux'] - sp1d['contam']
    ferr = sp1d['error']
    # NOTE(review): q and yint were computed on the SPC wavelength grid; reusing
    # them to index the extract1D arrays assumes both share the same grid --
    # confirm, otherwise recompute q/yint from sp1d['lam'].
    anorm = np.sum(yint * ffix[q]) / np.sum(ffix[q]**2)
    ax.plot(lam[q], ffix[q] * anorm, color='blue', alpha=0.6, linewidth=1)

    #### Show photometry + eazy template
    ax.errorbar(lci,
                fobs,
                yerr=efobs,
                color='orange',
                marker='o',
                markersize=10,
                linestyle='None',
                alpha=0.4)
    ax.plot(lambdaz, temp_sed_sm, color='red', alpha=0.4)

    if wfc3_exist:
        # WFC3 pixels used: 1.08e4 < wave < 1.68e4 with positive flux
        q_wfc3 = np.where((wfc3_spec[1].data.wave > 1.08e4)
                          & (wfc3_spec[1].data.wave < 1.68e4)
                          & (wfc3_spec[1].data.flux > 0))[0]
        yint_wfc3 = np.interp(wfc3_spec[1].data.wave[q_wfc3], lambdaz,
                              temp_sed_sm)
        # contamination-subtracted counts divided by the sensitivity curve
        spec_wfc3 = (wfc3_spec[1].data.flux -
                     wfc3_spec[1].data.contam) / wfc3_spec[1].data.sensitivity
        anorm_wfc3 = np.sum(yint_wfc3 * spec_wfc3[q_wfc3]) / np.sum(
            spec_wfc3[q_wfc3]**2)
        if np.isnan(anorm_wfc3): anorm_wfc3 = 1.
        print 'Scaling factors: ', anorm, anorm_wfc3
        ax.plot(wfc3_spec[1].data.wave[q_wfc3],
                spec_wfc3[q_wfc3] * anorm_wfc3,
                color='blue',
                alpha=0.6,
                linewidth=1)

    ax.set_ylabel(r'$f_{\lambda}$')

    # '\_' is not a recognised escape, so Python keeps the backslash: the title
    # gets LaTeX-escaped underscores. NOTE(review): in the non-usetex branch the
    # backslash may be rendered literally -- confirm that is intended.
    if plt.rcParams['text.usetex']:
        ax.set_xlabel(r'$\lambda$ [\AA]')
        ax.set_title(
            '%s: \#%d, z=%4.1f' % (SPC.filename.split('_2_opt')[0].replace(
                '_', '\_'), id, zout.z_peak[photom_idx]))
    else:
        ax.set_xlabel(r'$\lambda$ [$\AA$]')
        ax.set_title(
            '%s: #%d, z=%4.1f' % (SPC.filename.split('_2_opt')[0].replace(
                '_', '\_'), id, zout.z_peak[photom_idx]))

    #kmag = 25-2.5*np.log10(cat.ktot[photom_idx])
    kmag = cat.kmag[photom_idx]

    ##### Labels
    # NOTE(review): np.int was removed in NumPy 1.24; the builtin int is the
    # drop-in replacement.
    label = 'ID=' + r'%s   K=%4.1f  $\log M$=%4.1f' % (np.int(
        cat.id[photom_idx]), kmag, fout.field('lmass')[photom_idx])

    ax.text(5e3,
            1.08 * ymax,
            label,
            horizontalalignment='left',
            verticalalignment='bottom')

    # Match distance label, drawn in red when the match is poor (> 1.1).
    label = 'R=%4.1f"' % (drMatch)
    if drMatch > 1.1:
        label_color = 'red'
    else:
        label_color = 'black'
    ax.text(2.2e4,
            1.08 * ymax,
            label,
            horizontalalignment='right',
            color=label_color,
            verticalalignment='bottom')

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(-0.1 * ymax, 1.2 * ymax)

    if Verbose:
        print 'Save the plot'

    if OUT_FILE_FORMAT:
        out_file = '%s_%05d_SED.png' % (grism_root, id)
    else:
        out_file = OUT_FILE

    if USE_PLOT_GUI:
        fig.savefig(OUT_PATH + '/' + out_file, dpi=100, transparent=False)
        plt.close()
    else:
        canvas = FigureCanvasAgg(fig)
        canvas.print_figure(OUT_PATH + '/' + out_file,
                            dpi=100,
                            transparent=False)

    print unicorn.noNewLine + OUT_PATH + '/' + out_file

    if Verbose:
        print 'Close the plot window'
Ejemplo n.º 59
0
class MPLPlot(FCanvas):
    """
    This is the basic matplotlib canvas widget we are using for matplotlib
    plots. This canvas only provides a few convenience tools for automatic
    sizing and creating subfigures, but is otherwise not very different
    from the class ``FCanvas`` that comes with matplotlib (and which we inherit).
    It can be used as any QT widget.
    """
    def __init__(self,
                 parent: Optional[QtWidgets.QWidget] = None,
                 width: float = 4.0,
                 height: float = 3.0,
                 dpi: int = 150,
                 nrows: int = 1,
                 ncols: int = 1):
        """
        Create the canvas.

        Note that although the grid is laid out as ``nrows`` x ``ncols``,
        only a single axes is created here (``clearFig`` is called with its
        default ``naxes=1``); call :meth:`clearFig` with ``naxes`` to
        populate more grid cells.

        :param parent: the parent widget
        :param width: canvas width (inches)
        :param height: canvas height (inches)
        :param dpi: figure dpi
        :param nrows: number of subplot rows
        :param ncols: number of subplot columns
        """

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)

        self.axes: List[Axes] = []
        # Layout mode consulted by autosize().
        self._tightLayout = False
        # State of the optional info text managed by updateInfo():
        # whether it is shown, the artist currently drawn, and its text.
        self._showInfo = False
        self._infoArtist = None
        self._info = ''

        self.clearFig(nrows, ncols)
        self.setParent(parent)

    def autosize(self) -> None:
        """
        Apply the figure margins/spacings and redraw.

        Uses fixed ``subplots_adjust`` values unless tight layout mode is
        enabled (see :meth:`setTightLayout`), in which case
        ``tight_layout`` is used, leaving a small band at top and bottom.
        """
        if not self._tightLayout:
            self.fig.subplots_adjust(left=0.125,
                                     bottom=0.125,
                                     top=0.9,
                                     right=0.875,
                                     wspace=0.35,
                                     hspace=0.2)
        else:
            self.fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        self.draw()

    def clearFig(self,
                 nrows: int = 1,
                 ncols: int = 1,
                 naxes: int = 1) -> List[Axes]:
        """
        Clear and reset the canvas.

        The arguments are validated before anything is cleared, so an invalid
        request leaves the current figure and ``self.axes`` untouched.

        :param nrows: number of subplot/axes rows to prepare
        :param ncols: number of subplot/axes column to prepare
        :param naxes: number of axes in total
        :returns: the created axes in the grid
        :raises ValueError: if ``naxes`` is larger than ``nrows * ncols``
        """
        # Validate first: raising after fig.clear() would wipe the user's plot.
        if naxes > nrows * ncols:
            raise ValueError(
                f'Number of axes ({naxes}) larger than rows ({nrows}) x '
                f'columns ({ncols}).')

        self.fig.clear()
        setMplDefaults()

        # Fill the first `naxes` grid cells (matplotlib subplot indices start
        # at 1). The axes are created independently; no sharex/sharey is set.
        self.axes = [self.fig.add_subplot(nrows, ncols, i)
                     for i in range(1, naxes + 1)]

        self.autosize()
        return self.axes

    def resizeEvent(self, event: QtGui.QResizeEvent) -> None:
        """
        Re-implementation of the widget resizeEvent method.
        Makes sure we resize the plots appropriately.
        """
        self.autosize()
        super().resizeEvent(event)

    def setTightLayout(self, tight: bool) -> None:
        """
        Set tight layout mode and re-apply the layout.

        :param tight: if true, use tight layout for autosizing.
        """
        self._tightLayout = tight
        self.autosize()

    def setShowInfo(self, show: bool) -> None:
        """Whether to show additional info in the plot"""
        self._showInfo = show
        self.updateInfo()

    def updateInfo(self) -> None:
        """
        Redraw the info text in the lower-left corner of the figure.

        Any previously drawn info artist is removed first; a new one holding
        ``self._info`` is added only when showing info is enabled.
        """
        if self._infoArtist is not None:
            self._infoArtist.remove()
            self._infoArtist = None

        if self._showInfo:
            self._infoArtist = self.fig.text(
                0,
                0,
                self._info,
                fontsize='x-small',
                verticalalignment='bottom',
            )
        self.draw()

    def toClipboard(self) -> None:
        """
        Copy the current canvas to the clipboard as a 300 dpi PNG image.
        """
        with io.BytesIO() as buf:
            # An explicit facecolor takes precedence over `transparent` for
            # the figure patch, so the background is white while the axes
            # patches are rendered transparent.
            self.fig.savefig(buf,
                             dpi=300,
                             facecolor='w',
                             format='png',
                             transparent=True)
            QtWidgets.QApplication.clipboard().setImage(
                QtGui.QImage.fromData(buf.getvalue()))

    def setFigureTitle(self, title: str) -> None:
        """
        Add a title to the figure.

        Each call adds a new text artist at the top centre of the figure; a
        title added by an earlier call is not removed (clearFig clears it).
        """
        self.fig.text(0.5,
                      0.99,
                      title,
                      horizontalalignment='center',
                      verticalalignment='top',
                      fontsize='small')
        self.draw()

    def setFigureInfo(self, info: str) -> None:
        """Set the info string shown in the figure and refresh it."""
        self._info = info
        self.updateInfo()
Ejemplo n.º 60
0
# Re-map to highlight small differences: replace each value by the index of
# its quantile bin (20 bins), then restore the original array shape.
# NOTE(review): assumes `qcut` is pandas.qcut; with labels=False it returns
# integer bin codes 0..19, which the vmin=0 / vmax=20 colour limits cover.
s = t2m.data.shape
t2m.data = qcut(t2m.data.flatten(), 20, labels=False).reshape(s)
# Plot as a colour map
lats = t2m.coord('latitude').points
lons = t2m.coord('longitude').points
t2m_img = ax.pcolorfast(lons,
                        lats,
                        t2m.data,
                        cmap='coolwarm',
                        vmin=0,
                        vmax=20,
                        alpha=0.5)

# Also pressure: load 'prmsl' for date `dte` and keep only member=1
# (presumably one ensemble member — confirm against the data source).
# NOTE(review): scale=0.01 presumably converts Pa to hPa — confirm.
prmsl = twcr.load('prmsl', dte, version='4.5.1')
prmsl = prmsl.extract(iris.Constraint(member=1))
mg.pressure.plot(ax, prmsl, scale=0.01, resolution=0.25, linewidths=1)

# Also precip: load 'prate' for the same date and keep only member=1.
prate = twcr.load('prate', dte, version='4.5.1')
prate = prate.extract(iris.Constraint(member=1))
mg.precipitation.plot(ax, prate, resolution=0.25, vmin=-0.01, vmax=0.04)

# Add a fixed caption label (it does not include the date), passing the
# figure's facecolor to the label helper.
label = "Surface weather test plot"
mg.utils.plot_label(ax, label, facecolor=fig.get_facecolor())

# Render the figure as a png into the directory given by args.opdir
fig.savefig('%s/tst2_v3.png' % args.opdir)