Пример #1
0
    def __spacetime_diagram_o_frame(self):
        """Draw the space-time diagram for the O frame on ``self.fig``.

        The rectilinear axes are labelled x/t; the primed frame's x' and t'
        axes are overlaid as floating axes of a curvilinear grid whose
        transform pair is built from ``self.lorentz_transformations`` and
        ``self.velocity``. World lines are drawn last by
        ``self.world_lines_plotter``.
        """
        def primed_to_unprimed(x_prime, t_prime):
            # (x', t') -> (x, t): transform with the opposite velocity.
            xs = np.asarray(x_prime)
            ts = np.asarray(t_prime)
            return self.lorentz_transformations.transform(
                xs, ts, -self.velocity)

        def unprimed_to_primed(x, t):
            # (x, t) -> (x', t')
            xs = np.asarray(x)
            ts = np.asarray(t)
            return self.lorentz_transformations.transform(
                xs, ts, self.velocity)

        grid_helper = GridHelperCurveLinear(
            (primed_to_unprimed, unprimed_to_primed))

        # Use the left half of a 1x2 layout when the O' diagram is shown too,
        # otherwise the whole figure.
        n_columns = 2 if self.showOPrime else 1
        ax = Subplot(self.fig, 1, n_columns, 1, grid_helper=grid_helper)
        self.fig.add_subplot(ax)

        ax.set_xlabel("x", loc="center")
        ax.set_ylabel("t", loc="center")

        # Floating axis of the primed frame labelled x'.
        x1 = ax.new_floating_axis(1, 0)
        ax.axis["x1"] = x1
        x1.label.set_text("x'")

        # Floating axis of the primed frame labelled t'.
        t1 = ax.new_floating_axis(0, 0)
        ax.axis["t1"] = t1
        t1.label.set_text("t'")

        self.__add_x_and_y_axis(ax)
        ax.format_coord = self.__format_coord_o_frame
        self.__remove_ticks(ax, x1, t1)

        self.world_lines_plotter.plot(plt, ax, self.velocity)
Пример #2
0
def curvelinear_test1(fig):
    """
    Add a subplot to *fig* whose ticks and gridlines follow a custom transform.
    """
    def shear(x, y):
        xs = np.asarray(x)
        ys = np.asarray(y)
        return xs, ys - xs

    def unshear(x, y):
        xs = np.asarray(x)
        ys = np.asarray(y)
        return xs, ys + xs

    # Ticks and gridlines are defined by the transform pair (on top of the
    # Axes' transData); transData itself is not affected by it.
    helper = GridHelperCurveLinear((shear, unshear))
    ax1 = Subplot(fig, 1, 2, 1, grid_helper=helper)
    fig.add_subplot(ax1)

    line_x, line_y = shear([3, 6], [5.0, 10.])
    ax1.plot(line_x, line_y, linewidth=2.0)

    ax1.set_aspect(1.)
    ax1.set_xlim(0, 10.)
    ax1.set_ylim(0, 10.)

    # Two floating axes, one per grid coordinate.
    for key, (coord_index, value) in (("t", (0, 3.)), ("t2", (1, 7.))):
        ax1.axis[key] = ax1.new_floating_axis(coord_index, value)
    ax1.grid(True, zorder=0)
Пример #3
0
def plotCorrelation(tauArray,kappaMatrix,kappaLower=None,kappaUpper=None,CI=None,amplify=1):
    """Plot the Pearson correlation coefficient K(t, tau) on rotated axes.

    The rotated floating axes indicate absolute time t and relative time
    shift tau between two signals x(t), y(t).

    The matrix has to be square with values -1 < p < +1; ``tauArray`` gives
    the absolute time t of the centers of each correlated window.

    Args:
        tauArray: time values; the first and last entries set the image
            extent and the axis limits.
        kappaMatrix: correlation matrix; it is copied, never modified.
        kappaLower, kappaUpper: optional confidence bounds with the same
            shape as ``kappaMatrix``. When both are given, entries whose
            interval straddles zero are set to 0.
        CI: optional significance threshold; entries with ``|K| < CI`` are
            set to 0.
        amplify: the colour scale spans ``-1/amplify .. +1/amplify``.

    Returns:
        The newly created figure.
    """
    # defining transformation for relative time shifts
    # NOTE(review): Rt is handed to GridHelperCurveLinear as the inverse of R,
    # but Rt(R(x, y)) == (2*x, -y) as written -- confirm the intended pair.
    def R(x, y):
        x, y = asarray(x), asarray(y)
        return (2*x - y)/2, (y + 2*x)/2

    def Rt(x, y):
        x, y = asarray(x), asarray(y)
        return x + y, x - y

    # create figure with rotated axes
    fig = figure(figsize=(10, 10),frameon=False)
    grid_locator = angle_helper.LocatorDMS(20)
    grid_helper = GridHelperCurveLinear((R, Rt),
                  grid_locator1=grid_locator,
                  grid_locator2=grid_locator)

    ax = Subplot(fig, 1, 1, 1, grid_helper=grid_helper)
    fig.add_subplot(ax)
    ax.axis('off')

    # copying over matrix
    K = array(kappaMatrix)

    # zero out correlations if confidence intervals overlap zero
    if kappaLower is not None and kappaUpper is not None:
        lower, upper = asarray(kappaLower), asarray(kappaUpper)
        K[(lower < 0) & (0 < upper)] = 0

    # zero out statistically insignificant correlations
    # (entries already zeroed above stay zero, so testing K is equivalent to
    # testing the original matrix)
    if CI is not None:
        K[abs(K) < CI] = 0

    # display pearson correlation matrix with +ive in red and -ive in blue;
    # "lower" is the valid spelling of the bottom-left image origin
    ax.imshow(K,cmap="RdBu_r",interpolation="none",origin="lower",
              extent = (tauArray[0],tauArray[-1],tauArray[0],tauArray[-1]),vmin=-1.0/amplify,vmax=1.0/amplify)

    # display rotated axes time,t and time delay,tau
    ax.axis["tau"] = ax.new_floating_axis(0,0)
    ax.axis["t"] = ax.new_floating_axis(1,0)

    # setting axes options
    ax.set_xlim(tauArray[0],tauArray[-1])
    ax.set_ylim(tauArray[0],tauArray[-1])
    ax.grid(which="both")
    ax.set_aspect(1)

    return fig
Пример #4
0
    def setup_content(self):
        """Create the plot axes, attach them to the figure and hook redraws.

        The axes get a fully transparent white background and have
        autoscaling disabled. ``self.fix_after_drawing`` is connected to the
        canvas draw and resize events, then ``self.update()`` is called.
        """
        # Create subplot and add it to the figure. The background is set
        # after construction because the ``axisbg`` keyword is not accepted
        # by current Matplotlib; older releases only offer the
        # ``set_axis_bgcolor`` setter.
        background = (1.0, 1.0, 1.0, 0.0)
        self.plot = Subplot(self.figure, 211)
        if hasattr(self.plot, 'set_facecolor'):
            self.plot.set_facecolor(background)
        else:
            self.plot.set_axis_bgcolor(background)
        self.plot.set_autoscale_on(False)
        self.figure.add_axes(self.plot)

        # Connect events:
        self.canvas.mpl_connect('draw_event', self.fix_after_drawing)
        self.canvas.mpl_connect('resize_event', self.fix_after_drawing)

        self.update()
Пример #5
0
def standard_axis(fig,
                  subplot,
                  sharex=None,
                  sharey=None,
                  hasx=False,
                  hasy=True):
    """Create a frameless axisartist subplot with only selected axes visible.

    Deprecated: a notice pointing to ``stfio_plot.StandardAxis`` is written
    to stderr on every call.

    Args:
        fig: figure the axes are added to.
        subplot: subplot specification. An iterable that is not a
            ``GridSpec`` is unpacked as ``(rows, cols, index)``; anything
            else is passed to ``Subplot`` as a single argument.
        sharex, sharey: axes to share the x / y axis with.
        hasx: keep the bottom axis visible.
        hasy: keep the left axis visible.

    Returns:
        The new axes; the top and right axes are always hidden.
    """
    sys.stderr.write("This method is deprecated. "
                     "Use stfio_plot.StandardAxis instead.\n")

    # Only the iterability probe is guarded, so errors raised while building
    # the subplot itself are no longer swallowed and retried.
    try:
        iter(subplot)
    except TypeError:
        unpack = False
    else:
        unpack = not isinstance(subplot, matplotlib.gridspec.GridSpec)

    if unpack:
        ax1 = Subplot(fig,
                      subplot[0],
                      subplot[1],
                      subplot[2],
                      frameon=False,
                      sharex=sharex,
                      sharey=sharey)
    else:
        ax1 = Subplot(fig,
                      subplot,
                      frameon=False,
                      sharex=sharex,
                      sharey=sharey)

    fig.add_axes(ax1)
    ax1.axis["right"].set_visible(False)
    ax1.axis["top"].set_visible(False)
    if not hasx:
        ax1.axis["bottom"].set_visible(False)
    if not hasy:
        ax1.axis["left"].set_visible(False)

    return ax1
Пример #6
0
    def setup_content(self, status_callback, marker_callback):
        """Create the plot axes and wire up canvas events and helpers.

        ``status_callback`` is handed to the ``MotionTracker`` and
        ``marker_callback`` to the ``ClickCatcher`` created here.
        """
        # Axes with a fully transparent white background, no autoscaling.
        transparent_white = (1.0, 1.0, 1.0, 0.0)
        self.plot = Subplot(self.figure, 211, facecolor=transparent_white)
        self.plot.set_autoscale_on(False)
        self.figure.add_axes(self.plot)

        # Re-run the post-draw fix on every draw and on every resize.
        for event_name in ('draw_event', 'resize_event'):
            self.canvas.mpl_connect(event_name, self.fix_after_drawing)

        self.mtc = MotionTracker(self, status_callback)
        self.cc = ClickCatcher(self, marker_callback)
Пример #7
0
def plotISIDistributions(sessions,groups=None,sessionTypes=None,samplingRate=30000.0,save=False,fname=None,figsize=(10,6)):
    """Plot the ISI distributions for all the cells in the given sessions.

    For every (group, session) pair the corresponding HDF5 file is read and
    the log of the inter-spike intervals (in ms) of each unit is
    histogrammed. One subplot is drawn per cell, one curve per session.

    Args:
        sessions: session names used to build the data file names.
        groups: group numbers to load; ``None`` is treated as no groups.
        sessionTypes: optional labels, one per entry of ``sessions``, used
            in the legend instead of the session names.
        samplingRate: sampling rate in Hz used to convert time points to ms.
        save: write the figure to ``fname`` when true.
        fname: output path; a default under ``~/Documents/research/figures``
            is used when ``None``.
        figsize: figure size passed to ``plt.figure``.
    """
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(left=0.05,right=.95)
    if groups is None:
        groups = []
    if sessionTypes is not None:
        sessionTypes = dict(zip(sessions,sessionTypes))

    # ISIs[cell name][session name] -> log inter-spike intervals
    ISIs = {}
    for g in groups:
        for s in sessions:
            dataFile = h5py.File(os.path.expanduser('~/Documents/research/data/spikesorting/hmm/p=1e-20/%sg%.4d.hdf5' % (s,g)),'r')
            try:
                for c in dataFile['unitTimePoints'].keys():
                    # NOTE(review): this is the natural log, while the tick
                    # labels below are written as powers of 10 -- confirm
                    # which one is intended.
                    isi = np.log(np.diff(dataFile['unitTimePoints'][c][:]/(samplingRate/1000)))
                    cn = 'g%dc%d' % (g,int(c))
                    ISIs.setdefault(cn, {})['%s' %(s,)] = isi
            finally:
                dataFile.close()

    ncells = len(ISIs)
    for i, c in enumerate(ISIs, 1):
        ax = Subplot(fig,1,ncells,i)
        formatAxis(ax)
        fig.add_axes(ax)
        ax.set_title(c)
        for k,v in ISIs[c].items():
            L = sessionTypes[k] if sessionTypes is not None else k
            # density=True is the supported spelling of the old normed=True
            n,b = np.histogram(v,bins=20,density=True)
            # draw on this cell's axes explicitly rather than on whatever
            # pyplot considers current
            ax.plot(b[:-1],n,label=L)
        ax.set_xlabel('ISI [ms]')
        # Ticks start at a fixed -0.5 and run past the upper edge of the
        # last histogram drawn on this axes.
        xl = -0.5
        xh = int(np.round((b[-1]+0.5)*2))/2
        dx = np.round(10.0*(xh-xl)/4.0)/10
        xt_ = np.arange(xl,xh+1,dx)
        ax.set_xticks(xt_)
        # a real list is needed here; a lazy map object is not a sequence
        ax.set_xticklabels([r'$10^{%.1f}$' % (s,) for s in xt_])

    # nothing to label when no cell was found
    if fig.axes:
        fig.axes[-1].legend()
    if save:
        if fname is None:
            fname = os.path.expanduser('~/Documents/research/figures/isi_comparison.pdf')
        # savefig's keyword for a tight bounding box is bbox_inches
        fig.savefig(fname,bbox_inches='tight')
Пример #8
0
def plot_hpxmap_hist(hpxmap,
                     raRange=[-180, 180],
                     decRange=[-90, 90],
                     lonRef=0.0,
                     cbar_kwargs=None,
                     hpxmap_kwargs=None,
                     hist_kwargs=None,
                     fit_gaussian=False,
                     figsize=(10, 4)):
    """Draw a HEALPix map next to a histogram of its values.

    The map occupies the left two thirds of figure number 10 (cleared
    first); the histogram occupies the right third with its tick labels on
    the right-hand side.

    Args:
        hpxmap: the map, or a file name that is read with
            ``healpy.read_map``.
        raRange, decRange: ``[min, max]`` limits handed to ``FgcmBasemap``.
        lonRef: reference longitude handed to ``FgcmBasemap``.
        cbar_kwargs: keyword arguments for ``draw_inset_colorbar``.
        hpxmap_kwargs: keyword arguments for ``draw_hpxmap``.
        hist_kwargs: keyword arguments for ``draw_hist``; ``peak=True`` is
            supplied as a default via ``set_defaults``.
        fit_gaussian: forwarded to ``draw_hist``.
        figsize: figure size.

    Returns:
        ``(fig, [ax1, ax2], ret)`` where ``ret`` is whatever ``draw_hist``
        returned.
    """
    # Work on private copies so neither a shared default nor the caller's
    # dictionaries are modified.
    cbar_kwargs = dict(cbar_kwargs) if cbar_kwargs else {}
    hpxmap_kwargs = dict(hpxmap_kwargs) if hpxmap_kwargs else {}
    hist_kwargs = dict(hist_kwargs) if hist_kwargs else {}

    hist_defaults = dict(peak=True)
    # presumably fills missing keys in place (its return value was never
    # used) -- verify against set_defaults
    set_defaults(hist_kwargs, hist_defaults)

    if isinstance(hpxmap, basestring):
        # read the file named by the argument itself
        hpxmap = healpy.read_map(hpxmap)

    fig = plt.figure(10, figsize=figsize)
    fig.clf()
    gridspec = plt.GridSpec(1, 3)

    bmap = FgcmBasemap(lonRef=lonRef,
                       raMin=raRange[0],
                       raMax=raRange[1],
                       decMin=decRange[0],
                       decMax=decRange[1])
    bmap.create_axes(rect=gridspec[0:2])
    bmap.draw_hpxmap(hpxmap, **hpxmap_kwargs)
    bmap.draw_inset_colorbar(**cbar_kwargs)
    ax1 = plt.gca()
    ax1.axis['right'].major_ticklabels.set_visible(False)
    ax1.axis['top'].major_ticklabels.set_visible(False)

    ax2 = Subplot(fig, gridspec[2])
    fig.add_subplot(ax2)
    plt.sca(ax2)
    ret = draw_hist(hpxmap, fit_gaussian=fit_gaussian, **hist_kwargs)
    ax2.yaxis.set_major_locator(MaxNLocator(6, prune='both'))
    ax2.xaxis.set_major_locator(MaxNLocator(5))
    # Move the histogram's tick labels and axis label to the right side.
    ax2.axis['left'].major_ticklabels.set_visible(False)
    ax2.axis['right'].major_ticklabels.set_visible(True)
    ax2.axis['right'].label.set_visible(True)
    ax2.axis['right'].label.set_text(r'Normalized Area (a.u.)')
    ax2.axis['bottom'].label.set_visible(True)

    plt.subplots_adjust(bottom=0.15, top=0.95)

    return fig, [ax1, ax2], ret
Пример #9
0
    def setup_axes1(self, fig, T_ticks, subplotshape=None):
        """
        Create the axes with a grid rotated by -45 degrees.

        Grid lines of the first grid coordinate are placed at ``T_ticks``;
        the second coordinate gets none. The subplot position defaults to
        ``(1, 1, 1)``. The rotation is stored in ``self.tr`` and the new
        axes in ``self.ax1``.
        """
        self.tr = Affine2D().rotate_deg(-45.)

        helper = GridHelperCurveLinear(
            self.tr,
            grid_locator1=FixedLocator(T_ticks),
            grid_locator2=FixedLocator([]))

        shape = (1, 1, 1) if subplotshape is None else subplotshape

        # Ticks and gridlines follow the rotation (on top of the Axes'
        # transData); transData itself is not affected by it.
        axes = Subplot(fig, *shape, grid_helper=helper)
        fig.add_subplot(axes)

        # Corner points in the order SW, SE, NE, NW. They are passed through
        # self._tf, but the transformed result is not used here.
        corner_points = np.array([[-25., -20.], [30., 40.], [-40., 120.],
                                  [-105., 60.]])
        self._tf(corner_points[:, 0], corner_points[:, 1])

        x_low, x_high = self.x_range
        axes.set_xlim(x_low, x_high)
        axes.set_ylim(*self.y_range)
        axes.set_xlabel('Temperature [C]')
        axes.set_aspect(1)

        self.ax1 = axes
Пример #10
0
def plotXcorr(qdata,save=False,fname='hmmSortingUnits.pdf'):
    """Plot pairwise spike-time cross-correlograms for all units in qdata.

    For every unit pair the spike-time differences below 50 ms are taken
    from ``qdata['XCorr']`` when already stored there, otherwise computed
    with ``pdist_threshold2`` and stored. ``qdata`` may be a plain dict or
    an object offering ``create_group`` / ``create_dataset``.

    Args:
        qdata: mapping with 'unitTimePoints' and optionally 'samplingRate'
            (default 30000.0) and 'XCorr'.
        save: save the figure instead of drawing it.
        fname: file name used under the fixed figures directory when
            ``save`` is true.
    """
    unitTimePoints = qdata['unitTimePoints']
    samplingRate = qdata.get('samplingRate',30000.0)
    # materialise the keys so they can be indexed
    units = list(unitTimePoints.keys())
    nunits = len(units)
    xsize = max(10,nunits*2)
    fig = plt.figure(figsize=(xsize,(6.0/8)*xsize) )
    fig.subplots_adjust(left=0.03,right=0.97,bottom=0.03,top=0.97)
    if 'XCorr' not in qdata:
        if isinstance(qdata,dict):
            qdata['XCorr'] = {}
        else:
            qdata.create_group('XCorr')
    xcorr = qdata['XCorr']
    for k1 in range(nunits-1):
        u1 = units[k1]
        if u1 not in xcorr:
            # a plain dict has no create_group, so nest dicts instead
            if isinstance(xcorr,dict):
                xcorr[u1] = {}
            else:
                xcorr.create_group(u1)
        cache = xcorr[u1]
        for k2 in range(k1+1,nunits):
            u2 = units[k2]
            if u2 not in cache:
                T1 = unitTimePoints[u1][:]/(samplingRate/1000)
                T2 = unitTimePoints[u2][:]/(samplingRate/1000)
                #compute differences less than 50 ms
                C = pdist_threshold2(T1,T2,50)
                if isinstance(cache,dict):
                    cache[u2] = C
                else:
                    cache.create_dataset(u2,data=C,compression=2,fletcher32=True,shuffle=True)
            else:
                C = cache[u2][:]
            # density=True is the supported spelling of the old normed=True
            n,b = np.histogram(C,np.arange(-50,50),density=True)
            ax = Subplot(fig,nunits-1,nunits,k1*nunits+k2)
            fig.add_axes(ax)
            formatAxis(ax)
            ax.plot(b[:-1],n,'k')
            # shade the +/- 1 ms band around zero lag
            ax.fill_betweenx([0,n.max()],-1.0,1.0,color='r',alpha=0.3)
            # only the last pair keeps its tick labels
            if not (k1 == nunits-2 and k2 == nunits-1):
                ax.set_xticklabels('')
                ax.set_yticklabels('')
    if save:
        # savefig's keyword for a tight bounding box is bbox_inches
        fig.savefig(os.path.expanduser('~/Documents/research/figures/SpikeSorting/hmm/%s' %( fname,)),bbox_inches='tight')
    else:
        plt.draw()
Пример #11
0
def curvelinear_test1(fig):
    """
    Add a subplot to *fig* whose grid follows a custom (sheared) transform.
    """
    def shear(x, y):
        xs = numpy.asarray(x)
        ys = numpy.asarray(y)
        return xs, ys - (2 * xs)

    def unshear(x, y):
        xs = numpy.asarray(x)
        ys = numpy.asarray(y)
        return xs, ys + (2 * xs)

    # Ticks and gridlines are defined by the transform pair (on top of the
    # Axes' transData); transData itself is not affected by it.
    helper = GridHelperCurveLinear((shear, unshear))
    ax1 = Subplot(fig, 1, 1, 1, grid_helper=helper)
    fig.add_subplot(ax1)

    line_x, line_y = shear([0, 1], [0, 2])
    ax1.plot(line_x, line_y, linewidth=2.0)

    ax1.set_aspect(1)
    ax1.set_xlim(-3, 3)
    ax1.set_ylim(-3, 3)

    # Floating axes through the origin of the sheared grid: the first
    # argument picks the grid coordinate, the second the value it is drawn at.
    for key, coord_index in (("t", 0), ("t2", 1)):
        ax1.axis[key] = ax1.new_floating_axis(coord_index, 0)

    # Red reference lines along the rectilinear axes.
    ax1.axhline(y=0, color='r')
    ax1.axvline(x=0, color='r')
    ax1.grid(True, zorder=0)
Пример #12
0
    def __init__(self, ol=None, parent=None):
        """Build the planner window: input widgets, plot controls and canvas.

        Args:
            ol: oriented lattice to start from; when ``ValidateOL(ol)`` is
                false a default ``mantid.geometry.OrientedLattice()`` is
                used instead.
            parent: passed on to the base-class constructor.
        """
        # pylint: disable=unused-argument,super-on-old-class
        super(DGSPlannerGUI, self).__init__(parent)
        #OrientedLattice
        if ValidateOL(ol):
            self.ol = ol
        else:
            self.ol = mantid.geometry.OrientedLattice()
        self.masterDict = dict()  #holds info about instrument and ranges
        self.updatedInstrument = False
        self.updatedOL = False
        self.wg = None  #workspace group

        # Left-hand column: instrument setup, the two UB input widgets side
        # by side, the dimension selector and a row of plot controls.
        self.instrumentWidget = InstrumentSetupWidget.InstrumentSetupWidget(
            self)
        self.setLayout(QtGui.QHBoxLayout())
        controlLayout = QtGui.QVBoxLayout()
        controlLayout.addWidget(self.instrumentWidget)
        self.ublayout = QtGui.QHBoxLayout()
        self.classic = ClassicUBInputWidget.ClassicUBInputWidget(self.ol)
        self.ublayout.addWidget(self.classic,
                                alignment=QtCore.Qt.AlignTop,
                                stretch=1)
        self.matrix = MatrixUBInputWidget.MatrixUBInputWidget(self.ol)
        self.ublayout.addWidget(self.matrix,
                                alignment=QtCore.Qt.AlignTop,
                                stretch=1)
        controlLayout.addLayout(self.ublayout)
        self.dimensionWidget = DimensionSelectorWidget.DimensionSelectorWidget(
            self)
        controlLayout.addWidget(self.dimensionWidget)
        plotControlLayout = QtGui.QGridLayout()
        self.plotButton = QtGui.QPushButton("Plot", self)
        self.oplotButton = QtGui.QPushButton("Overplot", self)
        self.helpButton = QtGui.QPushButton("?", self)
        self.colorLabel = QtGui.QLabel('Color by angle', self)
        self.colorButton = QtGui.QCheckBox(self)
        self.colorButton.toggle()
        self.aspectLabel = QtGui.QLabel('Aspect ratio 1:1', self)
        self.aspectButton = QtGui.QCheckBox(self)
        self.saveButton = QtGui.QPushButton("Save Figure", self)
        plotControlLayout.addWidget(self.plotButton, 0, 0)
        plotControlLayout.addWidget(self.oplotButton, 0, 1)
        plotControlLayout.addWidget(self.colorLabel, 0, 2,
                                    QtCore.Qt.AlignRight)
        plotControlLayout.addWidget(self.colorButton, 0, 3)
        plotControlLayout.addWidget(self.aspectLabel, 0, 4,
                                    QtCore.Qt.AlignRight)
        plotControlLayout.addWidget(self.aspectButton, 0, 5)
        plotControlLayout.addWidget(self.helpButton, 0, 6)
        plotControlLayout.addWidget(self.saveButton, 0, 7)
        controlLayout.addLayout(plotControlLayout)
        self.layout().addLayout(controlLayout)

        #figure
        self.figure = Figure()
        self.figure.patch.set_facecolor('white')
        self.canvas = FigureCanvas(self.figure)
        self.grid_helper = GridHelperCurveLinear((self.tr, self.inv_tr))
        self.trajfig = Subplot(self.figure,
                               1,
                               1,
                               1,
                               grid_helper=self.grid_helper)
        # Axes.hold() does not exist in newer Matplotlib releases, so only
        # call it where the method is still available.
        if hasattr(self.trajfig, 'hold'):
            self.trajfig.hold(True)
        self.figure.add_subplot(self.trajfig)
        self.layout().addWidget(self.canvas)
        self.needToClear = False
        self.saveDir = ''

        #connections
        # Keep the two UB input widgets in sync with each other and refresh
        # the stored UB whenever either one changes.
        self.matrix.UBmodel.changed.connect(self.updateUB)
        self.matrix.UBmodel.changed.connect(self.classic.updateOL)
        self.classic.changed.connect(self.matrix.UBmodel.updateOL)
        self.classic.changed.connect(self.updateUB)
        self.instrumentWidget.changed.connect(self.updateParams)
        self.dimensionWidget.changed.connect(self.updateParams)
        self.plotButton.clicked.connect(self.updateFigure)
        self.oplotButton.clicked.connect(self.updateFigure)
        self.helpButton.clicked.connect(self.help)
        self.saveButton.clicked.connect(self.save)
        #force an update of values
        self.instrumentWidget.updateAll()
        self.dimensionWidget.updateChanges()
        #help
        self.assistantProcess = QtCore.QProcess(self)
        # pylint: disable=protected-access
        self.collectionFile = os.path.join(mantid._bindir,
                                           '../docs/qthelp/MantidProject.qhc')
        # The help URL embeds only the major.minor part of the version.
        version = ".".join(mantid.__version__.split(".")[:2])
        self.qtUrl = 'qthelp://org.sphinx.mantidproject.' + version + '/doc/interfaces/DGSPlanner.html'
        self.externalUrl = 'http://docs.mantidproject.org/nightly/interfaces/DGSPlanner.html'
        #control for cancel button
        self.iterations = 0
        self.progress_canceled = False
Пример #13
0
]
# One x-coordinate array per entry of mixing_ratios, evaluated over p_all.
# presumably C_to_K converts the line from Celsius to Kelvin -- TODO confirm
x_mixing_ratios = [
    x_from_Tp((Bolton.mixing_ratio_line(p_all, MRi) + C_to_K), p_all)
    for MRi in mixing_ratios
]
# Temperature/pressure mesh: temperatures from -60.0 in steps of 0.1 up to
# the largest T level (shifted by C_to_K), pressures from p_all.
mesh_T, mesh_p = np.meshgrid(
    np.arange(-60.0,
              T_levels.max() - C_to_K + 0.1, 0.1), p_all)
theta_ep_mesh = Bolton.theta_ep_field(mesh_T, mesh_p)

# Plotting Code!

# Curvilinear grid built from the (from_thermo, to_thermo) transform pair
# defined elsewhere in this script.
skew_grid_helper = GridHelperCurveLinear((from_thermo, to_thermo))

fig = plt.figure()
ax = Subplot(fig, 1, 1, 1, grid_helper=skew_grid_helper)
fig.add_subplot(ax)

# One horizontal line per pressure level, spanning x_min..x_max.
for yi in y_p_levels:
    ax.plot((x_min, x_max), (yi, yi), color=(1.0, 0.8, 0.8))

# One line per entry of x_T_levels, drawn against y_all_p.
for x_T in x_T_levels:
    ax.plot(x_T, y_all_p, color=(1.0, 0.5, 0.5))

# One line per entry of x_thetas, drawn against y_all_p.
for x_theta in x_thetas:
    ax.plot(x_theta, y_all_p, color=(1.0, 0.7, 0.7))

# Mixing-ratio lines, masked to the points where p_all >= 600.
for x_mixing_ratio in x_mixing_ratios:
    good = p_all >= 600  # restrict mixing ratio lines to below 600 mb
    ax.plot(x_mixing_ratio[good], y_all_p[good], color=(0.8, 0.8, 0.6))
Пример #14
0
    def __init__(self, parent=None, window_flags=None, ol=None):
        """Build the planner window: input widgets, plot controls and canvas.

        Args:
            parent: passed on to the base-class constructor.
            window_flags: applied with ``setWindowFlags`` when truthy.
            ol: oriented lattice to start from; when ``ValidateOL(ol)`` is
                false a default ``mantid.geometry.OrientedLattice()`` is
                used instead.
        """
        # pylint: disable=unused-argument,super-on-old-class
        super(DGSPlannerGUI, self).__init__(parent)
        if window_flags:
            self.setWindowFlags(window_flags)
        # OrientedLattice
        if ValidateOL(ol):
            self.ol = ol
        else:
            self.ol = mantid.geometry.OrientedLattice()
        self.masterDict = dict()  # holds info about instrument and ranges
        self.updatedInstrument = False
        self.instrumentWAND = False
        self.updatedOL = False
        self.wg = None  # workspace group

        # Left-hand column: three group boxes ("Instrument Geometry",
        # "Sample", "Plot Axes") stacked vertically.
        self.instrumentWidget = InstrumentSetupWidget.InstrumentSetupWidget(
            self)
        self.setLayout(QtWidgets.QHBoxLayout())
        controlLayout = QtWidgets.QVBoxLayout()
        geometryBox = QtWidgets.QGroupBox("Instrument Geometry")
        geometryBoxLayout = QtWidgets.QVBoxLayout()
        geometryBoxLayout.addWidget(self.instrumentWidget)
        geometryBox.setLayout(geometryBoxLayout)
        controlLayout.addWidget(geometryBox)
        self.ublayout = QtWidgets.QHBoxLayout()
        self.classic = ClassicUBInputWidget.ClassicUBInputWidget(self.ol)
        self.ublayout.addWidget(self.classic,
                                alignment=QtCore.Qt.AlignTop,
                                stretch=1)
        self.matrix = MatrixUBInputWidget.MatrixUBInputWidget(self.ol)
        self.ublayout.addWidget(self.matrix,
                                alignment=QtCore.Qt.AlignTop,
                                stretch=1)
        sampleBox = QtWidgets.QGroupBox("Sample")
        sampleBox.setLayout(self.ublayout)
        controlLayout.addWidget(sampleBox)
        self.dimensionWidget = DimensionSelectorWidget.DimensionSelectorWidget(
            self)
        plotBoxLayout = QtWidgets.QVBoxLayout()
        plotBoxLayout.addWidget(self.dimensionWidget)
        plotControlLayout = QtWidgets.QGridLayout()
        self.plotButton = QtWidgets.QPushButton("Plot", self)
        self.oplotButton = QtWidgets.QPushButton("Overplot", self)
        self.helpButton = QtWidgets.QPushButton("?", self)
        self.colorLabel = QtWidgets.QLabel('Color by angle', self)
        self.colorButton = QtWidgets.QCheckBox(self)
        self.colorButton.toggle()
        self.aspectLabel = QtWidgets.QLabel('Aspect ratio 1:1', self)
        self.aspectButton = QtWidgets.QCheckBox(self)
        self.saveButton = QtWidgets.QPushButton("Save Figure", self)
        plotControlLayout.addWidget(self.plotButton, 0, 0)
        plotControlLayout.addWidget(self.oplotButton, 0, 1)
        plotControlLayout.addWidget(self.colorLabel, 0, 2,
                                    QtCore.Qt.AlignRight)
        plotControlLayout.addWidget(self.colorButton, 0, 3)
        plotControlLayout.addWidget(self.aspectLabel, 0, 4,
                                    QtCore.Qt.AlignRight)
        plotControlLayout.addWidget(self.aspectButton, 0, 5)
        plotControlLayout.addWidget(self.helpButton, 0, 6)
        plotControlLayout.addWidget(self.saveButton, 0, 7)
        plotBoxLayout.addLayout(plotControlLayout)
        # The group box is created once, right where its layout is ready.
        plotBox = QtWidgets.QGroupBox("Plot Axes")
        plotBox.setLayout(plotBoxLayout)
        controlLayout.addWidget(plotBox)
        self.layout().addLayout(controlLayout)

        # figure
        self.figure = Figure()
        self.figure.patch.set_facecolor('white')
        self.canvas = FigureCanvas(self.figure)
        self.grid_helper = GridHelperCurveLinear((self.tr, self.inv_tr))
        self.trajfig = Subplot(self.figure,
                               1,
                               1,
                               1,
                               grid_helper=self.grid_helper)
        # hold is deprecated since 2.1.0, true by default. The version
        # helper is missing from newer Matplotlib releases, so it is looked
        # up defensively; where it is absent hold() is simply not called.
        compare_versions = getattr(matplotlib, 'compare_versions', None)
        if compare_versions is not None and compare_versions(
                '2.1.0', matplotlib.__version__):
            self.trajfig.hold(True)
        self.figure.add_subplot(self.trajfig)
        self.toolbar = MantidNavigationToolbar(self.canvas, self)
        figureLayout = QtWidgets.QVBoxLayout()
        figureLayout.addWidget(self.toolbar, 0)
        figureLayout.addWidget(self.canvas, 1)
        self.layout().addLayout(figureLayout)
        self.needToClear = False
        self.saveDir = ''

        # connections
        # Keep the two UB input widgets in sync with each other and refresh
        # the stored UB whenever either one changes.
        self.matrix.UBmodel.changed.connect(self.updateUB)
        self.matrix.UBmodel.changed.connect(self.classic.updateOL)
        self.classic.changed.connect(self.matrix.UBmodel.updateOL)
        self.classic.changed.connect(self.updateUB)
        self.instrumentWidget.changed.connect(self.updateParams)
        self.instrumentWidget.getInstrumentComboBox().activated[str].connect(
            self.instrumentUpdateEvent)
        self.instrumentWidget.getEditEi().textChanged.connect(
            self.eiWavelengthUpdateEvent)
        self.dimensionWidget.changed.connect(self.updateParams)
        self.plotButton.clicked.connect(self.updateFigure)
        self.oplotButton.clicked.connect(self.updateFigure)
        self.helpButton.clicked.connect(self.help)
        self.saveButton.clicked.connect(self.save)
        # force an update of values
        self.instrumentWidget.updateAll()
        self.dimensionWidget.updateChanges()
        # help
        self.assistant_process = QtCore.QProcess(self)
        # pylint: disable=protected-access
        self.mantidplot_name = 'DGS Planner'
        # control for cancel button
        self.iterations = 0
        self.progress_canceled = False

        # register startup
        mantid.UsageService.registerFeatureUsage(
            mantid.kernel.FeatureType.Interface, "DGSPlanner", False)
Пример #15
0
    # http://matplotlib.org/examples/axes_grid/demo_curvelinear_grid.html
    from mpl_toolkits.axisartist import Subplot
    from mpl_toolkits.axisartist.grid_helper_curvelinear import \
        GridHelperCurveLinear

    def tr(x, y):
        """Map source (data) to target (rectilinear plot) coordinates."""
        xs = numpy.asarray(x)
        ys = numpy.asarray(y)
        return xs + 0.2 * ys, ys - xs

    def inv_tr(x, y):
        """Map target (plot) coordinates back to source (data) coordinates.

        This is the exact inverse of ``tr``: solving ``u = x + 0.2*y`` and
        ``v = y - x`` for ``(x, y)`` gives ``x = (u - 0.2*v) / 1.2`` and
        ``y = (u + v) / 1.2``.
        """
        x, y = numpy.asarray(x), numpy.asarray(y)
        return (x - 0.2 * y) / 1.2, (x + y) / 1.2

    # Curvilinear grid driven by the (tr, inv_tr) pair defined above.
    grid_helper = GridHelperCurveLinear((tr, inv_tr))

    # Sixth cell of the nrow x ncol layout of fig.
    ax6 = Subplot(fig, nrow, ncol, 6, grid_helper=grid_helper)
    fig.add_subplot(ax6)
    ax6.set_title('non-ortho axes')

    # Two data points are mapped through tr and plotted as a line.
    xx, yy = tr([3, 6], [5.0, 10.])
    ax6.plot(xx, yy)

    ax6.set_aspect(1.)
    ax6.set_xlim(0, 10.)
    ax6.set_ylim(0, 10.)

    # Two floating axes, one per grid coordinate (index 0 at value 3,
    # index 1 at value 7).
    ax6.axis["t"] = ax6.new_floating_axis(0, 3.)
    ax6.axis["t2"] = ax6.new_floating_axis(1, 7.)
    ax6.grid(True)

    plt.show()
Пример #16
0
    def subplots(self,
                 nrows=1,
                 ncols=1,
                 n=None,
                 sharex=False,
                 sharey=False,
                 sharex_hspace=0.15,
                 sharey_wspace=0.15,
                 axisartist=True,
                 reshape=True):
        """Create a grid of subplots on this figure.

        Args:
            nrows, ncols: grid shape; ignored when ``n`` is given and > 1.
            n: number of plots wanted; the grid is then made as square as
                possible (``ceil(sqrt(n))`` columns), which may create more
                than ``n`` axes.
            sharex: share the x axis within a single-column grid and hide
                the x tick labels of every axes but the last.
            sharey: share the y axis within a single-row grid and hide the
                y tick labels of every axes but the first.
            sharex_hspace, sharey_wspace: spacing applied with
                ``subplots_adjust`` when the matching share flag is set.
            axisartist: create axisartist ``Subplot`` axes instead of
                regular ones.
            reshape: when true, return a single axes for one plot, a 1-d
                array for one row or column, otherwise a 2-d array; when
                false always return the 1-d array.

        Returns:
            The axes, shaped as described under ``reshape``.
        """
        if n is not None and n > 1:
            ncols = int(np.ceil(np.sqrt(n)))
            nrows = int(np.ceil(n / float(ncols)))
        # Subplot() and ndarray.reshape() need genuine integers.
        nrows, ncols = int(nrows), int(ncols)

        # Create empty object array to hold all axes. It's easiest to make it 1-d
        # so we can just append subplots upon creation, and then reshape
        nplots = nrows * ncols
        axarr = np.empty(nplots, dtype=object)

        sharedx_axes = None
        sharedy_axes = None

        if sharex:
            self.subplots_adjust(hspace=sharex_hspace)
        if sharey:
            self.subplots_adjust(wspace=sharey_wspace)

        for i in range(nplots):
            if axisartist:
                subplot = Subplot(self,
                                  nrows,
                                  ncols,
                                  i + 1,
                                  sharex=sharedx_axes,
                                  sharey=sharedy_axes)
                axarr[i] = self.add_subplot(subplot)
            else:
                axarr[i] = self.add_subplot(nrows,
                                            ncols,
                                            i + 1,
                                            sharex=sharedx_axes,
                                            sharey=sharedy_axes)
            if ncols == 1 and sharex:
                # The first axes created becomes the one all others share.
                if sharedx_axes is None:
                    sharedx_axes = axarr[i]
                if i != nplots - 1:
                    if axisartist:
                        axarr[i].axis["top", "bottom"].toggle(ticklabels=False)
                    else:
                        # regular axes have no subscriptable ``axis``
                        for label in axarr[i].get_xticklabels():
                            label.set_visible(False)
            if nrows == 1 and sharey:
                if sharedy_axes is None:
                    sharedy_axes = axarr[i]
                if i != 0:
                    if axisartist:
                        axarr[i].axis["left", "right"].toggle(ticklabels=False)
                    else:
                        for label in axarr[i].get_yticklabels():
                            label.set_visible(False)

        if reshape:
            if nplots == 1:
                return axarr[0]
            elif nrows == 1 or ncols == 1:
                return axarr

            axarr = axarr.reshape(nrows, ncols)

        return axarr
Пример #17
0
def plot_univariate_inequality(xdata,
                               inclusive=False,
                               raycolor='blue',
                               bgcolor='white',
                               xlim=None,
                               figsize=(5, 0.5),
                               figure=None,
                               title=None,
                               **kwargs):
    """Plot a chart of a univariate inequality ray with matplotlib

    Args:
        xdata (list, tuple): endpoints of the ray: ``[start, ..., end]``

    Kwargs:
        inclusive (bool): if True, draw a filled circle; if False, draw a circle without fill
        raycolor (str): matplotlib color of the line and markers
        bgcolor (str): matplotlib color for the background and circle without fill
        xlim (None, True, or list/tuple): if True, calculate defaults for xmin and xmax of the x-axis
        figsize (tuple): matplotlib figsize
        figure (None or matplotlib.figure.Figure): Figure to add a plot to
        title (None or str): Title to add to the plot
        kwargs: all other kwargs will be passed to `plt.figure()`; only
            used when no ``figure`` is given

    Returns:
        matplotlib.figure.Figure: inequality figure
    """

    ordered = sorted(xdata)
    xstart = ordered[0]
    xend = ordered[-1]
    # Direction of the ray as given by the caller.
    points_right = xdata[0] <= xdata[-1]

    fig1 = figure or plt.figure(figsize=figsize, facecolor=bgcolor, **kwargs)
    ax1 = Subplot(fig1, 111)
    fig1.add_subplot(ax1)

    # The circle always sits on the first point of xdata (the ray's start).
    style = {
        'linewidth': 4,
        'linestyle': 'solid',
        'color': raycolor,
        'marker': 'o',
        'markersize': 12,
        'markevery': [0],
        'markerfacecolor': raycolor if inclusive else bgcolor,
        'markeredgecolor': raycolor,
        'markeredgewidth': 4,
    }

    # The arrow head sits on the last point and points along the ray.
    arrowstyle = {
        'marker': '>' if points_right else '<',
        'markersize': 12,
        'markerfacecolor': raycolor,
        'markeredgecolor': raycolor,
    }
    arrow_x = xdata[-1]

    ax1.set(ylim=(-0.5, 1))

    if xlim is True:
        xlim = (float(xstart - 1), float(xend + 0.5))
    if xlim:
        ax1.set(xlim=xlim)

    ax1.axis["left"].set_visible(False)
    ax1.axis["right"].set_visible(False)
    ax1.axis["top"].set_visible(False)

    # Arrow heads at both ends of the bottom axis line.
    ax1.plot(1, -0.5, ">k", transform=ax1.get_yaxis_transform(), clip_on=False)
    ax1.plot(0, -0.5, "<k", transform=ax1.get_yaxis_transform(), clip_on=False)

    # Draw on ax1 explicitly: pyplot's current axes may belong to another
    # figure when ``figure`` was supplied by the caller.
    ax1.plot(list(xdata), [0] * len(xdata), **style)
    if title:
        ax1.set_title(title)
    ax1.plot(arrow_x, 0, **arrowstyle)
    return fig1
Пример #18
0
def plot2d(data, variables, filename):
    """Scatter-plot one or more data series against a shared x series.

    ``data[0]`` holds the x values; every further entry of ``data`` is
    scattered against it. A logarithmic y axis is selected automatically
    when the series differ strongly in magnitude (and all y values are
    positive). The figure is written to ``filename + '.pdf'`` and
    ``filename + '.pgf'``.

    Invalid input is reported with ``print`` and aborts the call; no
    exception is raised for the cases checked below.

    Args:
        data (list): list of equally long sequences of int/float values;
            the first sequence is the x series.
        variables (list of str): one name per entry of ``data``;
            ``variables[0]`` labels the x axis, the others label the legend.
        filename (str): output path without extension.

    Returns:
        None
    """
    # default values
    logscale_x = False
    logscale_y = False
    logable = True

    grid_bool = False

    axistop_bool = False
    axisright_bool = False
    axisbottom_bool = True
    axisleft_bool = True

    legend_outside = True

    # internal parameter: if the means of two data series differ by more
    # than this factor, linear scaling is considered ill-defined
    plot_range = 1. / 8

    # prepare data --->

    # check for appropriate format of input (before anything indexes into
    # the arguments, so bad input is reported instead of crashing)
    if not isinstance(variables, list):
        print('second argument must be a list')
        return None
    for name in variables:
        if not isinstance(name, str):
            print('second argument must be a list of str')
            return None
    if not isinstance(filename, str):
        print('third argument must be of type str')
        return None
    if not isinstance(data, list):
        print('first argument must be of type list')
        return None
    if not data:
        print('data must not be empty')
        return None
    for value in data[0]:
        if not isinstance(value, (int, float)):
            print('data must be passed as a list of int or float')
            return None
    for series in data[1:]:
        for value in series:
            if not isinstance(value, (int, float)):
                print('data must be passed as list of int or float')
                return None
            elif value <= 0 and logable:
                print('logscale impossible')
                logable = False  # to avoid using logscale

    # get number of rows and columns of data
    ncols = len(data)
    nrows = len(data[0])

    if len(variables) != ncols:
        print('number of variables must match number of data series')
        return None
    if ncols < 2:
        print('at least one data series besides the x values is required')
        return None
    if nrows == 0:
        print('missing values in data set')
        return None

    # check whether all series have the same size and build a
    # dataframe-like mapping from variable name to series
    df = {}
    for name, series in zip(variables, data):
        if len(series) != nrows:
            print('missing values in data set')
            return None
        df[name] = series

    # <--- prepare data

    # ---> analyse data

    y_names = variables[1:]

    # check whether magnitudes are approximately equal and use logscale if
    # not; `logable` is tested first because it guarantees positive means
    # and therefore a non-zero divisor
    means = [np.mean(df[name]) for name in y_names]
    if logable and min(means) / max(means) < plot_range:
        print('ill scaling - use logscale')
        logscale_y = True

    # if one variable changes magnitude drastically set logscale
    for name, mean in zip(y_names, means):
        spread = max(df[name]) - min(df[name])
        if spread != 0:
            if mean / spread < plot_range and logable and not logscale_y:
                print('very ill scaling - use logscale')
                logscale_y = True

    # <--- analyse data

    # one colour per y series; the palette is cycled so that any number
    # of series can be plotted
    palette = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    colors = {}
    for index, name in enumerate(y_names):
        colors[name] = palette[index % len(palette)]

    # plot --->
    # NOTE: figure number 1 is requested explicitly, so repeated calls
    # draw onto the same pyplot figure.
    fig = plt.figure(1)
    ax = Subplot(fig, 111)
    fig.add_subplot(ax)

    for name in y_names:
        ax.scatter(df[variables[0]],
                   df[name],
                   label='$' + name + '$',
                   color=colors[name])

    # labels and legend
    ax.legend(loc='best')
    # `hfont` is not defined in this function; it is expected to be a
    # module-level mapping of font keyword arguments.
    plt.xlabel(variables[0], **hfont)
    if legend_outside:
        # shrink the axes to 80% width and put the legend to the right
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # logscale
    if logscale_y:
        plt.yscale('log')
    if logscale_x:
        plt.xscale('log')
    # grid
    plt.grid(grid_bool)

    # visible axis
    ax.axis["right"].set_visible(axisright_bool)
    ax.axis["top"].set_visible(axistop_bool)
    ax.axis["bottom"].set_visible(axisbottom_bool)
    ax.axis["left"].set_visible(axisleft_bool)

    # <--- plot

    # save plot --->
    pdfdirectory = filename + '.pdf'
    pgfdirectory = filename + '.pgf'
    fig.savefig(pdfdirectory)
    fig.savefig(pgfdirectory)
    # <--- save plot

    print('run successful')
    return None
Пример #19
0
                     s=4,
                     marker='s',
                     vmin=vmin,
                     vmax=vmax,
                     latlon=True,
                     rasterized=True)
        smap.draw_inset_colorbar(label=r'mag arcsec$^{-2}$')

        ax1 = plt.gca()
        #ax1.annotate(f"{band}-band",(0.05,0.9),xycoords='axes fraction',
        #             fontsize=14)
        ax1.axis['right'].major_ticklabels.set_visible(False)
        #ax1.axis['top'].major_ticklabels.set_visible(False)

        # Plot histogram
        ax2 = Subplot(fig, gridspec[2])
        fig.add_subplot(ax2)
        plt.sca(ax2)
        #fig.add_subplot(gridspec[2])

        bins = np.linspace(x['sb'].min(), x['sb'].max(), 35)
        ret = draw_hist(hpxmap,
                        label=band,
                        color=v,
                        peak=False,
                        bins=bins,
                        density=True)
        ax2.yaxis.set_major_locator(MaxNLocator(6, prune='both'))
        ax2.axis['left'].major_ticklabels.set_visible(True)
        ax2.axis['right'].major_ticklabels.set_visible(False)
        ax2.axis['left'].label.set_visible(True)
Пример #20
0
def plotwitherrors(time, timeunit, means, meansunit, std):
    """Plot mean values with error bars over time and save the figure.

    The plot title (used for the output file name) and the observable
    name (used for the y-axis label and legend entry) are read
    interactively from standard input. The figure is written to
    ``../../figures/<title>.pdf`` and ``../../figures/<title>.pgf``.

    Args:
        time: x values of the data points.
        timeunit (str): unit appended to the x-axis label.
        means: y values of the data points.
        meansunit (str): unit appended to the y-axis label.
        std: error bar sizes, passed as ``yerr`` to ``errorbar``.

    Returns:
        None
    """
    # ---> basic settings
    # Each prompt falls back to 'test' on its own, so an answer to one
    # prompt is kept even when the other is left empty.
    title = input('Title: ') or 'test'
    y_name = input('Observable name: ') or 'test'

    axistop_bool = False
    axisright_bool = False
    axisbottom_bool = True
    axisleft_bool = True
    # <--- basic settings

    # plot --->
    fig = plt.figure()
    ax = Subplot(fig, 111)
    fig.add_subplot(ax)

    # The label gives the legend below an entry to show; without any
    # labelled artist the legend would be empty.
    ax.errorbar(time,
                means,
                yerr=std,
                marker='o',
                markersize=2,
                c='black',
                linestyle='none',
                ecolor='black',
                elinewidth=1,
                capthick=1,
                capsize=3,
                label=y_name)

    # faint dashed line connecting the data points
    ax.plot(time,
            means,
            color='black',
            alpha=0.5,
            linewidth=1.0,
            linestyle='--')

    # legend and labels
    ax.legend(loc='best')
    # `hfont` is not defined in this function; it is expected to be a
    # module-level mapping of font keyword arguments.
    plt.xlabel('time in ' + timeunit, **hfont)
    plt.ylabel(y_name + ' in ' + meansunit, **hfont)

    # visible axis
    ax.axis["right"].set_visible(axisright_bool)
    ax.axis["top"].set_visible(axistop_bool)
    ax.axis["bottom"].set_visible(axisbottom_bool)
    ax.axis["left"].set_visible(axisleft_bool)

    # <--- plot

    # save plot --->
    pdfdirectory = '../../figures/' + title + '.pdf'
    pgfdirectory = '../../figures/' + title + '.pgf'
    fig.savefig(pdfdirectory)
    fig.savefig(pgfdirectory)
    # <--- save plot

    print('test run successful')
    return None
Пример #21
0
def plotSpikeCountDistributions(sessions, groups=None, sessionTypes=None,
                                samplingRate=30000.0, save=False,
                                windowSize=40, fname=None, figsize=(10, 6)):
    """Plot the spike count distributions for all cells in the given sessions.

    For every group/session pair the corresponding HDF5 file is opened,
    each unit's spike times are converted to milliseconds and binned into
    windows of roughly ``windowSize`` ms. One subplot per cell shows, for
    every session, the normalised histogram of spike counts per window.

    Args:
        sessions: iterable of session names (used to build the file names).
        groups: iterable of group numbers; ``None`` is treated as no groups.
        sessionTypes: optional labels, one per session, used in the legend
            and as keys of the returned mapping instead of the session name.
        samplingRate (float): sampling rate in Hz of the stored time points.
        save (bool): whether to save the figure.
        windowSize: width of a counting window in ms.
        fname: output file; a default path is used when ``None``.
        figsize (tuple): matplotlib figure size.

    Returns:
        dict: ``{cell name: {session label: fraction of empty windows}}``.
    """
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(left=0.05, right=.95)
    spikeCounts = {}
    nz = {}
    if groups is None:
        # the default must not blow up the loop below
        groups = ()
    if sessionTypes is not None:
        sessionTypes = dict(zip(sessions, sessionTypes))
    for g in groups:
        for s in sessions:
            path = os.path.expanduser('~/Documents/research/data/spikesorting/hmm/p=1e-20/%sg%.4d.hdf5' % (s, g))
            with h5py.File(path, 'r') as dataFile:
                for c in dataFile['unitTimePoints'].keys():
                    # time points are stored in samples; convert to ms
                    sptrain = dataFile['unitTimePoints'][c][:] / (samplingRate / 1000)
                    if sptrain.size == 0:
                        # a unit without spikes has no distribution to show
                        continue
                    # linspace needs an integer number of edges, and at
                    # least two edges are required to form one bin
                    nbins = max(int(sptrain.max() / windowSize), 2)
                    bins = np.linspace(0, sptrain.max(), nbins)
                    sc, bins = np.histogram(sptrain, bins)
                    cn = 'g%dc%d' % (g, int(c))
                    sl = s if sessionTypes is None else sessionTypes[s]
                    spikeCounts.setdefault(cn, {})['%s' % (s,)] = {'counts': sc, 'bins': bins}
                    # fraction of windows that contain no spike
                    nz.setdefault(cn, {})['%s' % (sl,)] = 1.0 * np.sum(sc == 0) / len(sc)

    ncells = len(spikeCounts)
    colors = ['b', 'r', 'g', 'y', 'c', 'm']
    for i, c in enumerate(spikeCounts, 1):
        ax = Subplot(fig, 1, ncells, i)
        formatAxis(ax)
        fig.add_axes(ax)
        ax.set_title(c)
        for j, (k, v) in enumerate(spikeCounts[c].items()):
            L = k if sessionTypes is None else sessionTypes[k]
            n = np.bincount(v['counts'])
            # bincount always starts at a count of 0, so the x positions
            # must do the same to match n in length
            b = np.arange(len(n))
            # cycle the palette so more sessions than colours still plot
            ax.bar(b + j * 0.2, 1.0 * n / n.sum(), align='center', width=0.3,
                   fc=colors[j % len(colors)], label=L)
        ax.set_xlabel('Spike counts')
    if fig.axes:
        fig.axes[-1].legend()
    if save:
        if fname is None:
            fname = os.path.expanduser('~/Documents/research/figures/isi_comparison.pdf')
        # the savefig keyword for a tight bounding box is bbox_inches
        fig.savefig(fname, bbox_inches='tight')

    return nz
Пример #22
0
    def __init__(
        self,
        figure=None,
        isotherm_locator=None,
        dry_adiabat_locator=None,
        anchor=None,
    ):
        """
        Initialise the tephigram transformation and plot axes.

        Kwargs:

        * figure:
            An existing :class:`matplotlib.figure.Figure` instance for the
            tephigram plot. If a figure is not provided, a new figure will
            be created by default.
        * isotherm_locator:
            A :class:`tephi.Locator` instance or a numeric step size
            for the isotherm lines.
        * dry_adiabat_locator:
            A :class:`tephi.Locator` instance or a numeric step size
            for the dry adiabat lines.
        * anchor:
            A sequence of two pressure, temperature pairs specifying the extent
            of the tephigram plot in terms of the bottom left hand corner and
            the top right hand corner. Pressure data points must be in units of
            mb or hPa, and temperature data points must be in units of degC.

        Raises:

        * ValueError:
            If a locator is given that is neither a ``Locator`` nor a
            number, if the anchor does not have the shape of two
            (pressure, temperature) pairs, or if the first anchor pair has
            a lower pressure or a lower temperature than the second pair.

        For example:

        .. plot::
            :include-source:

            import matplotlib.pyplot as plt
            from numpy import column_stack
            import os.path
            import tephi
            from tephi import Tephigram

            dew_point = os.path.join(tephi.DATA_DIR, 'dews.txt')
            dry_bulb = os.path.join(tephi.DATA_DIR, 'temps.txt')
            dew_data, temp_data = tephi.loadtxt(dew_point, dry_bulb)
            dews = column_stack((dew_data.pressure, dew_data.temperature))
            temps = column_stack((temp_data.pressure, temp_data.temperature))
            tpg = Tephigram()
            tpg.plot(dews, label='Dew-point', color='blue', linewidth=2)
            tpg.plot(temps, label='Dry-bulb', color='red', linewidth=2)
            plt.show()

        """
        if not figure:
            # Create a default figure.
            self.figure = plt.figure(0, figsize=(9, 9))
        else:
            self.figure = figure

        # Configure the locators.
        # A truthy non-Locator value must be a number and is wrapped in a
        # Locator; a falsy value (e.g. None) is passed through unchanged.
        if isotherm_locator and not isinstance(isotherm_locator, Locator):
            if not isinstance(isotherm_locator, numbers.Number):
                raise ValueError("Invalid isotherm locator")
            locator_isotherm = Locator(isotherm_locator)
        else:
            locator_isotherm = isotherm_locator

        if dry_adiabat_locator and not isinstance(dry_adiabat_locator,
                                                  Locator):
            if not isinstance(dry_adiabat_locator, numbers.Number):
                raise ValueError("Invalid dry adiabat locator")
            locator_theta = Locator(dry_adiabat_locator)
        else:
            locator_theta = dry_adiabat_locator

        # Define the tephigram coordinate-system transformation.
        self.tephi_transform = transforms.TephiTransform()
        # Grid coordinate 1 uses the isotherm formatter/locator and grid
        # coordinate 2 the theta (dry adiabat) formatter/locator.
        ghelper = GridHelperCurveLinear(
            self.tephi_transform,
            tick_formatter1=_FormatterIsotherm(),
            grid_locator1=locator_isotherm,
            tick_formatter2=_FormatterTheta(),
            grid_locator2=locator_theta,
        )
        self.axes = Subplot(self.figure, 1, 1, 1, grid_helper=ghelper)
        # Presumably composes the tephigram transform with the axes data
        # transform, so tephigram coordinates map straight to display space.
        self.transform = self.tephi_transform + self.axes.transData
        # Two floating axes through the grid, stored under their own names
        # and styled further below.
        self.axes.axis["isotherm"] = self.axes.new_floating_axis(1, 0)
        self.axes.axis["theta"] = self.axes.new_floating_axis(0, 0)
        # Choose which of the two grid coordinates supplies the ticks on
        # each border axis (presumably 0 = isotherm and 1 = theta, matching
        # the order of the locators given to the grid helper - TODO confirm)
        # and toggle the axis components on or off.
        self.axes.axis["left"].get_helper().nth_coord_ticks = 0
        self.axes.axis["left"].toggle(all=True)
        self.axes.axis["bottom"].get_helper().nth_coord_ticks = 1
        self.axes.axis["bottom"].toggle(all=True)
        self.axes.axis["top"].get_helper().nth_coord_ticks = 0
        self.axes.axis["top"].toggle(all=False)
        self.axes.axis["right"].get_helper().nth_coord_ticks = 1
        self.axes.axis["right"].toggle(all=True)
        self.axes.gridlines.set_linestyle("solid")

        self.figure.add_subplot(self.axes)

        # Configure default axes.
        axis = self.axes.axis["left"]
        axis.major_ticklabels.set_fontsize(10)
        axis.major_ticklabels.set_va("baseline")
        axis.major_ticklabels.set_rotation(135)
        axis = self.axes.axis["right"]
        axis.major_ticklabels.set_fontsize(10)
        axis.major_ticklabels.set_va("baseline")
        axis.major_ticklabels.set_rotation(-135)
        self.axes.axis["top"].major_ticklabels.set_fontsize(10)
        axis = self.axes.axis["bottom"]
        axis.major_ticklabels.set_fontsize(10)
        axis.major_ticklabels.set_ha("left")
        axis.major_ticklabels.set_va("top")
        axis.major_ticklabels.set_rotation(-45)

        # Isotherms: lines of constant temperature (degC).
        axis = self.axes.axis["isotherm"]
        axis.set_axis_direction("right")
        axis.set_axislabel_direction("-")
        axis.major_ticklabels.set_rotation(90)
        axis.major_ticklabels.set_fontsize(10)
        axis.major_ticklabels.set_va("bottom")
        axis.major_ticklabels.set_color("grey")
        axis.major_ticklabels.set_visible(False)  # turned-off

        # Dry adiabats: lines of constant potential temperature (degC).
        axis = self.axes.axis["theta"]
        axis.set_axis_direction("right")
        axis.set_axislabel_direction("+")
        axis.major_ticklabels.set_fontsize(10)
        axis.major_ticklabels.set_va("bottom")
        axis.major_ticklabels.set_color("grey")
        axis.major_ticklabels.set_visible(False)  # turned-off
        axis.line.set_linewidth(3)
        axis.line.set_linestyle("--")

        # Lock down the aspect ratio.
        self.axes.set_aspect(1.0)
        self.axes.grid(True)

        # Initialise the text formatter for the navigation status bar.
        self.axes.format_coord = self._status_bar

        # Factor in the tephigram transform.
        # NOTE(review): these *_TEXT mappings are not defined in this method,
        # so the assignment mutates objects shared with whatever else uses
        # them (looks like module-level dicts - confirm).
        ISOBAR_TEXT["transform"] = self.transform
        WET_ADIABAT_TEXT["transform"] = self.transform
        MIXING_RATIO_TEXT["transform"] = self.transform

        # Create plot collections for the tephigram isopleths.
        # Each partial pre-binds the fixed arguments of the isopleth
        # function and is handed to the _PlotCollection for that kind of line.
        func = partial(
            isopleths.isobar,
            MIN_THETA,
            MAX_THETA,
            self.axes,
            self.transform,
            ISOBAR_LINE,
        )
        self._isobars = _PlotCollection(
            self.axes,
            ISOBAR_SPEC,
            MAX_PRESSURE,
            func,
            ISOBAR_TEXT,
            fixed=ISOBAR_FIXED,
            minimum=MIN_PRESSURE,
        )

        func = partial(
            isopleths.wet_adiabat,
            MAX_PRESSURE,
            MIN_TEMPERATURE,
            self.axes,
            self.transform,
            WET_ADIABAT_LINE,
        )
        self._wet_adiabats = _PlotCollection(
            self.axes,
            WET_ADIABAT_SPEC,
            MAX_WET_ADIABAT,
            func,
            WET_ADIABAT_TEXT,
            fixed=WET_ADIABAT_FIXED,
            minimum=MIN_WET_ADIABAT,
            xfocus=True,
        )

        func = partial(
            isopleths.mixing_ratio,
            MIN_PRESSURE,
            MAX_PRESSURE,
            self.axes,
            self.transform,
            MIXING_RATIO_LINE,
        )
        self._mixing_ratios = _PlotCollection(
            self.axes,
            MIXING_RATIO_SPEC,
            MIXING_RATIOS,
            func,
            MIXING_RATIO_TEXT,
            fixed=MIXING_RATIO_FIXED,
        )

        # Initialise for the tephigram plot event handler.
        # The state below is attached to the axes object itself, presumably
        # so that _handler can read it from the event's axes - TODO confirm.
        plt.connect("motion_notify_event", _handler)
        self.axes.tephigram = True
        self.axes.tephigram_original_delta_xlim = DEFAULT_WIDTH
        self.original_delta_xlim = DEFAULT_WIDTH
        self.axes.tephigram_transform = self.tephi_transform
        self.axes.tephigram_inverse = self.tephi_transform.inverted()
        self.axes.tephigram_isopleths = [
            self._isobars,
            self._wet_adiabats,
            self._mixing_ratios,
        ]

        # The tephigram profiles.
        # The same list object is shared between the instance and the axes.
        self._profiles = []
        self.axes.tephigram_profiles = self._profiles

        # Center the plot around the anchor extent.
        self._anchor = anchor
        if self._anchor is not None:
            self._anchor = np.asarray(anchor)
            # The anchor must be a 2x2 array: two (pressure, temperature) pairs.
            if (self._anchor.ndim != 2 or self._anchor.shape[-1] != 2
                    or len(self._anchor) != 2):
                msg = ("Invalid anchor, expecting [(bottom-left-pressure, "
                       "bottom-left-temperature), (top-right-pressure, "
                       "top-right-temperature)]")
                raise ValueError(msg)
            (
                (bottom_pressure, bottom_temp),
                (top_pressure, top_temp),
            ) = self._anchor

            # Reject an anchor whose first pair has a lower pressure or a
            # lower temperature than its second pair.
            if (bottom_pressure - top_pressure) < 0:
                raise ValueError("Invalid anchor pressure range")
            if (bottom_temp - top_temp) < 0:
                raise ValueError("Invalid anchor temperature range")

            # Plot the anchor as an invisible profile and set the axes
            # limits from the extents calculated afterwards.
            self._anchor = isopleths.Profile(anchor, self.axes)
            self._anchor.plot(visible=False)
            xlim, ylim = self._calculate_extents()
            self.axes.set_xlim(xlim)
            self.axes.set_ylim(ylim)
Пример #23
0
def plotSpikes(qdata,save=False,fname='hmmSorting.pdf',tuning=False,figsize=(10,6)):

    allSpikes = qdata['allSpikes'] 
    unitSpikes = qdata['unitSpikes']
    spikeIdx = qdata['spikeIdx']
    spikeIdx = qdata['unitTimePoints']
    units = qdata['unitTimePoints']
    spikeForms = qdata['spikeForms']
    channels = qdata['channels']
    uniqueIdx = qdata['uniqueIdx']
    samplingRate = qdata.get('samplingRate',30000.0)
    """
    mustClose = False
    if isinstance(dataFile,str):
        dataFile = h5py.File(dataFile,'r')
        mustClose = True
    data = dataFile['data'][:]
    """
    keys = np.array(units.keys())
    x = np.arange(32)[None,:] + 42*np.arange(spikeForms.shape[1])[:,None]
    xt = np.linspace(0,31,spikeForms.shape[-1])[None,:] + 42*np.arange(spikeForms.shape[1])[:,None]
    xch = 10 + 42*np.arange(len(channels))
    for c in units.keys():
        ymin,ymax = (5000,-5000)
        fig = plt.figure(figsize=figsize)
        fig.subplots_adjust(hspace=0.3)
        print "Unit: %s " %(str(c),)
        print "\t Plotting waveforms..."
        sys.stdout.flush()
        #allspikes = data[units[c][:,None]+np.arange(-10,22)[None,:],:]
        #allspikes = allSpikes[spikeIdx[c]]
        allspikes = qdata['unitSpikes'][c]
        otherunits = keys[keys!=c]
        #nonOverlapIdx = np.prod(np.array([~np.lib.arraysetops.in1d(spikeIdx[c],spikeIdx[c1]) for c1 in otherunits]),axis=0).astype(np.bool)
        #nonOverlapIdx = np.prod(np.array([pdist_threshold(spikeIdx[c],spikeIdx[c1],3) for c1 in otherunits]),axis=0).astype(np.bool)
        #nonOverlapIdx = uniqueIdx[c]
        nonOverlapIdx = qdata['nonOverlapIdx'][c]
        overlapIdx = np.lib.arraysetops.setdiff1d(np.arange(qdata['unitTimePoints'][c].shape[0]),nonOverlapIdx)
        #allspikes = allSpikes[np.lib.arraysetops.union1d(nonOverlapIdx,overlapIdx)]
        ax = Subplot(fig,2,3,1)
        fig.add_axes(ax)
        formatAxis(ax)
        #plt.plot(x.T,sp,'b')
        m = allspikes[:].mean(0)
        s = allspikes[:].std(0)
        plt.plot(x.T,m,'k',lw=1.5)
        #find the minimum point for this template
        ich = spikeForms[int(c)].min(1).argmin()
        ix = spikeForms[int(c)][ich,:].argmin()
        #plt.plot(x.T,spikeForms[int(c)][:,ix-10:ix+22].T,'r')
        plt.plot(x.T,np.roll(spikeForms[int(c)],10-ix,axis=1)[:,:32].T,'r')
        for i in xrange(x.shape[0]):
            plt.fill_between(x[i],m[:,i]-s[:,i],m[:,i]+s[:,i],color='b',alpha=0.5)
        yl = ax.get_ylim()
        ymin = min(ymin,yl[0])
        ymax = max(ymax,yl[1])
        ax.set_title('All spikes (%d)' % (allspikes.shape[0],))

        ax = Subplot(fig,2,3,2)
        fig.add_axes(ax)
        formatAxis(ax)
        if len(nonOverlapIdx)>0:
            m =  allspikes[:][nonOverlapIdx,:,:].mean(0)
            s =  allspikes[:][nonOverlapIdx,:,:].std(0)
            plt.plot(x.T,m,'k',lw=1.5)
            for i in xrange(x.shape[0]):
                plt.fill_between(x[i],m[:,i]-s[:,i],m[:,i]+s[:,i],color='b',alpha=0.5)
        #plt.plot(x.T,spikeForms[int(c)][:,ix-10:ix+22].T,'r')
        plt.plot(x.T,np.roll(spikeForms[int(c)],10-ix,axis=1)[:,:32].T,'r')
        yl = ax.get_ylim()
        ymin = min(ymin,yl[0])
        ymax = max(ymax,yl[1])
        #for sp in allspikes[nonOverlapIdx,:,:]:
        #    plt.plot(x.T,sp,'r')

        ax.set_title('Non-overlap spikes (%d)' %(nonOverlapIdx.shape[0],))
        ax = Subplot(fig,2,3,3)
        fig.add_axes(ax)
        formatAxis(ax)
        if len(overlapIdx)>0:
            m =  allspikes[:][overlapIdx,:,:].mean(0)
            s =  allspikes[:][overlapIdx,:,:].std(0)
            plt.plot(x.T,m,'k',lw=1.5)
            for i in xrange(x.shape[0]):
                plt.fill_between(x[i],m[:,i]-s[:,i],m[:,i]+s[:,i],color='b',alpha=0.5)
        #plt.plot(x.T,spikeForms[int(c)][:,ix-10:ix+22].T,'r')
        plt.plot(x.T,np.roll(spikeForms[int(c)],10-ix,axis=1)[:,:32].T,'r')
        yl = ax.get_ylim()
        ymin = min(ymin,yl[0])
        ymax = max(ymax,yl[1])
        #for sp in allspikes[~nonOverlapIdx,:,:]:
        #    plt.plot(x.T,sp,'g')
        ax.set_title('Overlap spikes (%d)' % ((overlapIdx).shape[0],))
        for a in fig.axes:
            a.set_ylim((ymin,ymax))
            a.set_xticks(xch)
            a.set_xticklabels(map(str,channels))
            a.set_xlabel('Channels')
        for a in fig.axes[1:]:
            a.set_yticklabels([])
        fig.axes[0].set_ylabel('Amplitude')
        """
        isi distribution
        """
        print "\t ISI distribution..."
        sys.stdout.flush()
        timepoints = qdata['unitTimePoints'][c][:]/(samplingRate/1000)
        if len(timepoints)<2:
            print "Too few spikes. Aborting..."
            continue 
        isi = np.log(np.diff(timepoints))
        n,b = np.histogram(isi,100)
        ax = Subplot(fig,2,3,4)
        fig.add_axes(ax)
        formatAxis(ax)
        ax.plot(b[:-1],n,'k')
        yl = ax.get_ylim()
        ax.vlines(0.0,0,yl[1],'r',lw=1.5)
        ax.set_xlabel('ISI [ms]')
        #get xticklabels
        xl,xh = int(np.round((b[0]-0.5)*2))/2,int(np.round((b[-1]+0.5)*2))/2
        xl = -0.5
        dx = np.round(10.0*(xh-xl)/5.0)/10
        xt_ = np.arange(xl,xh+1,dx)
        ax.set_xticks(xt_)
        ax.set_xticklabels(map(lambda s: r'$10^{%.1f}$' % (s,),xt_))

        """
        auto-correlogram
        """
        print "\t auto-correllogram..."
        sys.stdout.flush()
        if not 'autoCorr' in qdata:
            if isinstance(qdata,dict):
                qdata['autoCorr'] = {}
            else:
                qdata.create_group('autoCorr')
        if not c in qdata['autoCorr']:
            C = pdist_threshold2(timepoints,timepoints,50)
            qdata['autoCorr'][c] = C
            if not isinstance(qdata,dict):
                qdata.flush()
        else:
            C = qdata['autoCorr'][c][:]
        n,b = np.histogram(C[C!=0],np.arange(-50,50))
        ax = Subplot(fig,2,3,5)
        fig.add_axes(ax)
        formatAxis(ax)
        ax.plot(b[:-1],n,'k')
        ax.fill_betweenx([0,n.max()],-1.0,1.0,color='r',alpha=0.3)
        ax.set_xlabel('Lag [ms]')
        if tuning:
            print "\tPlotting tuning..."
            sys.stdout.flush()
            #attempt to get tuning for the current session, based on PWD
            stimCounts,isiCounts,angles,spikedata = gt.getTuning(sptrain=timepoints)        
            #reshape to number of orientations X number of reps, collapsing
            #across everything else
            #angles = np.append(angles,[angles[0]])
            C = stimCounts['0'].transpose((1,0,2,3))
            C = C.reshape(C.shape[0],C.size/C.shape[0])
            ax = plt.subplot(2,3,6,polar=True) 
            ax.errorbar(angles*np.pi/180,C.mean(1),C.std(1))

        if save:
            if not os.path.isabs(fname):
                fn = os.path.expanduser('~/Documents/research/figures/SpikeSorting/hmm/%s' % (fname.replace('.pdf','Unit%s.pdf' %(str(c),)),))
            else:
                fn = fname.replace('.pdf','Unit%s.pdf' %(str(c),))

            fig.savefig(fn,bbox='tight')

    if not save:
        plt.draw()
    """