Exemplo n.º 1
0
def plotISIDistributions(sessions,groups=None,sessionTypes=None,samplingRate=30000.0,save=False,fname=None,figsize=(10,6)):
    """Plot the log10 inter-spike-interval (ISI) distribution of every cell.

    For every (group, session) pair the file
    ``~/Documents/research/data/spikesorting/hmm/p=1e-20/<session>g<group:04d>.hdf5``
    is opened and each entry of its ``unitTimePoints`` group is read. One
    panel is drawn per cell (keyed ``g<group>c<unit>``), with one normalised
    20-bin histogram curve per session.

    :param sessions: session names used to build the file names
    :param groups: iterable of group numbers; must be supplied, the ``None``
        default is iterated directly and therefore raises ``TypeError``
    :param sessionTypes: optional legend labels, paired with ``sessions`` by
        position; when omitted the session name is used as the label
    :param samplingRate: time points are divided by ``samplingRate/1000``
        before differencing (presumably samples -> ms; confirm against the
        files that are actually loaded)
    :param save: if True the figure is written to ``fname``
    :param fname: output path; defaults to
        ``~/Documents/research/figures/isi_comparison.pdf``
    :param figsize: figure size handed to ``plt.figure``
    """
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(left=0.05,right=.95)
    if sessionTypes is not None:
        sessionTypes = dict(zip(sessions,sessionTypes))

    # ISIs[cell][session] -> array of log10 inter-spike intervals
    ISIs = {}
    for g in groups:
        for s in sessions:
            path = os.path.expanduser('~/Documents/research/data/spikesorting/hmm/p=1e-20/%sg%.4d.hdf5' % (s,g))
            with h5py.File(path,'r') as dataFile:
                units = dataFile['unitTimePoints']
                for c in units.keys():
                    # base-10 log so that the 10^x tick labels below are truthful
                    isi = np.log10(np.diff(units[c][:]/(samplingRate/1000)))
                    cn = 'g%dc%d' % (g,int(c))
                    ISIs.setdefault(cn,{})['%s' %(s,)] = isi

    ncells = len(ISIs)
    for i,(cn,perSession) in enumerate(ISIs.items(),1):
        ax = Subplot(fig,1,ncells,i)
        formatAxis(ax)
        fig.add_axes(ax)
        ax.set_title(cn)
        for k,v in perSession.items():
            label = k if sessionTypes is None else sessionTypes[k]
            n,b = np.histogram(v,bins=20,density=True)
            ax.plot(b[:-1],n,label=label)
        ax.set_xlabel('ISI [ms]')
        # Ticks run from a fixed lower limit up to the last bin edge of the
        # last histogram drawn, padded by 0.5 and rounded to the nearest half.
        xl = -0.5
        xh = int(np.round((b[-1]+0.5)*2))/2.0
        dx = np.round(10.0*(xh-xl)/4.0)/10
        xt_ = np.arange(xl,xh+1,dx)
        ax.set_xticks(xt_)
        ax.set_xticklabels([r'$10^{%.1f}$' % (t,) for t in xt_])

    # nothing to attach a legend to when no cell was found
    if fig.axes:
        fig.axes[-1].legend()
    if save:
        if fname is None:
            fname = os.path.expanduser('~/Documents/research/figures/isi_comparison.pdf')
        fig.savefig(fname,bbox_inches='tight')
Exemplo n.º 2
0
    from mpl_toolkits.axisartist import Subplot
    from mpl_toolkits.axisartist.grid_helper_curvelinear import \
        GridHelperCurveLinear

    def tr(x, y):
        """Map source (data) coordinates to target (rectilinear plot) coordinates.

        Both arguments are converted to numpy arrays; the result is the pair
        ``(x + 0.2*y, y - x)``.
        """
        xs = numpy.asarray(x)
        ys = numpy.asarray(y)
        return xs + 0.2 * ys, ys - xs

    def inv_tr(x, y):
        """Map target (rectilinear plot) coordinates back to source (data) coordinates.

        This is the inverse of ``tr``, i.e. of ``(x, y) -> (x + 0.2*y, y - x)``.
        That linear map has determinant 1.2, so the inverse must divide by
        1.2; without the division a round trip through ``tr`` and ``inv_tr``
        scales every point by 1.2.
        """
        x, y = numpy.asarray(x), numpy.asarray(y)
        return (x - 0.2 * y) / 1.2, (y + x) / 1.2

    grid_helper = GridHelperCurveLinear((tr, inv_tr))

    ax6 = Subplot(fig, nrow, ncol, 6, grid_helper=grid_helper)
    fig.add_subplot(ax6)
    ax6.set_title('non-ortho axes')

    xx, yy = tr([3, 6], [5.0, 10.])
    ax6.plot(xx, yy)

    ax6.set_aspect(1.)
    ax6.set_xlim(0, 10.)
    ax6.set_ylim(0, 10.)

    ax6.axis["t"] = ax6.new_floating_axis(0, 3.)
    ax6.axis["t2"] = ax6.new_floating_axis(1, 7.)
    ax6.grid(True)

    plt.show()
Exemplo n.º 3
0
def plotSpikeCountDistributions(sessions,groups=None,sessionTypes=None,samplingRate=30000.0,save=False,windowSize=40,fname=None,figsize=(10,6)):
    """Plot the spike-count distribution of every cell and return the fraction of empty windows.

    For every (group, session) pair the file
    ``~/Documents/research/data/spikesorting/hmm/p=1e-20/<session>g<group:04d>.hdf5``
    is opened and each entry of its ``unitTimePoints`` group is binned into
    windows of roughly ``windowSize`` (same time unit as the scaled spike
    times). One panel is drawn per cell (keyed ``g<group>c<unit>``) with one
    set of bars per session showing the relative frequency of each count.

    :param sessions: session names used to build the file names
    :param groups: iterable of group numbers; must be supplied, the ``None``
        default is iterated directly and therefore raises ``TypeError``
    :param sessionTypes: optional labels, paired with ``sessions`` by
        position; used for the legend and as keys of the returned dict
    :param samplingRate: time points are divided by ``samplingRate/1000``
        (presumably samples -> ms; confirm against the loaded files)
    :param save: if True the figure is written to ``fname``
    :param windowSize: target width of a counting window
    :param fname: output path; defaults to
        ``~/Documents/research/figures/isi_comparison.pdf``
    :param figsize: figure size handed to ``plt.figure``
    :return: ``nz[cell][label]`` -> fraction of windows containing no spike,
        where ``label`` is the session type if given, else the session name
    """
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(left=0.05,right=.95)
    spikeCounts = {}
    nz = {}
    if sessionTypes is not None:
        sessionTypes = dict(zip(sessions,sessionTypes))
    for g in groups:
        for s in sessions:
            path = os.path.expanduser('~/Documents/research/data/spikesorting/hmm/p=1e-20/%sg%.4d.hdf5' % (s,g))
            with h5py.File(path,'r') as dataFile:
                units = dataFile['unitTimePoints']
                for c in units.keys():
                    sptrain = units[c][:]/(samplingRate/1000)
                    # np.linspace needs an integer number of edges
                    nbins = int(sptrain.max()/windowSize)
                    bins = np.linspace(0,sptrain.max(),nbins)
                    sc,bins = np.histogram(sptrain,bins)
                    cn = 'g%dc%d' % (g,int(c))
                    sl = s if sessionTypes is None else sessionTypes[s]
                    spikeCounts.setdefault(cn,{})['%s' %(s,)] = {'counts': sc, 'bins':bins}
                    nz.setdefault(cn,{})['%s' %(sl,)] = np.mean(sc==0)

    ncells = len(spikeCounts)
    colors = ['b','r','g','y','c','m']
    for i,(cn,perSession) in enumerate(spikeCounts.items(),1):
        ax = Subplot(fig,1,ncells,i)
        formatAxis(ax)
        fig.add_axes(ax)
        ax.set_title(cn)
        for j,(k,v) in enumerate(perSession.items()):
            label = k if sessionTypes is None else sessionTypes[k]
            n = np.bincount(v['counts'])
            # bincount always starts at 0; drop the leading entries below the
            # smallest observed count so positions and heights line up
            lowest = v['counts'].min()
            b = np.arange(lowest,len(n))
            # each session is shifted by 0.2 so the bars sit side by side;
            # the colour list is cycled when there are more sessions than colours
            ax.bar(b+j*0.2,1.0*n[lowest:]/n.sum(),align='center',width=0.3,fc=colors[j % len(colors)],label=label)
        ax.set_xlabel('Spike counts')
    # nothing to attach a legend to when no cell was found
    if fig.axes:
        fig.axes[-1].legend()
    if save:
        if fname is None:
            fname = os.path.expanduser('~/Documents/research/figures/isi_comparison.pdf')
        fig.savefig(fname,bbox_inches='tight')

    return nz
Exemplo n.º 4
0
def _plotWaveformPanel(fig,index,x,waveforms,template,title):
    """Draw mean +/- std of ``waveforms`` plus ``template`` in panel ``index`` of a 2x3 grid; return the y-limits."""
    ax = Subplot(fig,2,3,index)
    fig.add_axes(ax)
    formatAxis(ax)
    if len(waveforms)>0:
        m = waveforms.mean(0)
        s = waveforms.std(0)
        ax.plot(x.T,m,'k',lw=1.5)
        for i in range(x.shape[0]):
            ax.fill_between(x[i],m[:,i]-s[:,i],m[:,i]+s[:,i],color='b',alpha=0.5)
    ax.plot(x.T,template,'r')
    ax.set_title(title)
    return ax.get_ylim()


def plotSpikes(qdata,save=False,fname='hmmSorting.pdf',tuning=False,figsize=(10,6)):
    """Draw one summary figure per sorted unit.

    Each figure is a 2x3 grid: mean +/- std waveforms of all, non-overlapping
    and overlapping spikes (with the unit's template in red), the log10 ISI
    histogram, the auto-correlogram and, optionally, a polar tuning plot.

    :param qdata: dict or h5py-like mapping providing ``unitTimePoints``,
        ``unitSpikes``, ``nonOverlapIdx``, ``spikeForms`` and ``channels``;
        ``samplingRate`` is optional (default 30000.0). An ``autoCorr`` entry
        is created and filled on demand, so ``qdata`` is modified.
    :param save: if True every figure is saved; otherwise ``plt.draw()`` is
        called once at the end
    :param fname: base file name; ``.pdf`` is replaced by ``Unit<c>.pdf``.
        Relative names are placed under
        ``~/Documents/research/figures/SpikeSorting/hmm/``
    :param tuning: if True ``gt.getTuning`` is queried and the result drawn
        in the sixth panel
    :param figsize: figure size handed to ``plt.figure``

    Units with fewer than two spikes only get the waveform panels and are
    not saved.
    """
    units = qdata['unitTimePoints']
    spikeForms = qdata['spikeForms']
    channels = qdata['channels']
    samplingRate = qdata.get('samplingRate',30000.0)
    # 32 samples per channel, consecutive channels offset by 42 along x
    x = np.arange(32)[None,:] + 42*np.arange(spikeForms.shape[1])[:,None]
    xch = 10 + 42*np.arange(len(channels))
    for c in units.keys():
        fig = plt.figure(figsize=figsize)
        fig.subplots_adjust(hspace=0.3)
        print("Unit: %s " %(str(c),))
        print("\t Plotting waveforms...")
        sys.stdout.flush()
        allspikes = qdata['unitSpikes'][c][:]
        nonOverlapIdx = qdata['nonOverlapIdx'][c][:]
        overlapIdx = np.setdiff1d(np.arange(units[c].shape[0]),nonOverlapIdx)

        # shift the template so that its minimum lands on sample 10
        template = spikeForms[int(c)]
        ich = template.min(1).argmin()
        ix = template[ich,:].argmin()
        aligned = np.roll(template,10-ix,axis=1)[:,:32].T

        # an empty index array must not be used for fancy indexing
        nonOverlapSpikes = allspikes[nonOverlapIdx,:,:] if len(nonOverlapIdx)>0 else allspikes[:0]
        overlapSpikes = allspikes[overlapIdx,:,:] if len(overlapIdx)>0 else allspikes[:0]
        panels = [('All spikes (%d)' % (allspikes.shape[0],),allspikes),
                  ('Non-overlap spikes (%d)' %(nonOverlapIdx.shape[0],),nonOverlapSpikes),
                  ('Overlap spikes (%d)' % ((overlapIdx).shape[0],),overlapSpikes)]
        ymin,ymax = (5000,-5000)
        for index,(title,waveforms) in enumerate(panels,1):
            yl = _plotWaveformPanel(fig,index,x,waveforms,aligned,title)
            ymin = min(ymin,yl[0])
            ymax = max(ymax,yl[1])

        # common y range and channel labels for the three waveform panels
        for a in fig.axes:
            a.set_ylim((ymin,ymax))
            a.set_xticks(xch)
            a.set_xticklabels([str(ch) for ch in channels])
            a.set_xlabel('Channels')
        for a in fig.axes[1:]:
            a.set_yticklabels([])
        fig.axes[0].set_ylabel('Amplitude')

        # isi distribution
        print("\t ISI distribution...")
        sys.stdout.flush()
        timepoints = units[c][:]/(samplingRate/1000)
        if len(timepoints)<2:
            print("Too few spikes. Aborting...")
            continue
        # base-10 log so that the 10^x tick labels below are truthful
        isi = np.log10(np.diff(timepoints))
        n,b = np.histogram(isi,100)
        ax = Subplot(fig,2,3,4)
        fig.add_axes(ax)
        formatAxis(ax)
        ax.plot(b[:-1],n,'k')
        yl = ax.get_ylim()
        # red line at log10(ISI) = 0
        ax.vlines(0.0,0,yl[1],'r',lw=1.5)
        ax.set_xlabel('ISI [ms]')
        # ticks from a fixed lower limit up to the last bin edge, padded by
        # 0.5 and rounded to the nearest half
        xl = -0.5
        xh = int(np.round((b[-1]+0.5)*2))/2.0
        dx = np.round(10.0*(xh-xl)/5.0)/10
        xt_ = np.arange(xl,xh+1,dx)
        ax.set_xticks(xt_)
        ax.set_xticklabels([r'$10^{%.1f}$' % (t,) for t in xt_])

        # auto-correlogram, cached in qdata['autoCorr'] once computed
        print("\t auto-correllogram...")
        sys.stdout.flush()
        if 'autoCorr' not in qdata:
            if isinstance(qdata,dict):
                qdata['autoCorr'] = {}
            else:
                qdata.create_group('autoCorr')
        if c not in qdata['autoCorr']:
            C = pdist_threshold2(timepoints,timepoints,50)
            qdata['autoCorr'][c] = C
            if not isinstance(qdata,dict):
                qdata.flush()
        else:
            C = qdata['autoCorr'][c][:]
        n,b = np.histogram(C[C!=0],np.arange(-50,50))
        ax = Subplot(fig,2,3,5)
        fig.add_axes(ax)
        formatAxis(ax)
        ax.plot(b[:-1],n,'k')
        ax.fill_betweenx([0,n.max()],-1.0,1.0,color='r',alpha=0.3)
        ax.set_xlabel('Lag [ms]')
        if tuning:
            print("\tPlotting tuning...")
            sys.stdout.flush()
            #attempt to get tuning for the current session, based on PWD
            stimCounts,isiCounts,angles,spikedata = gt.getTuning(sptrain=timepoints)
            #reshape to number of orientations X number of reps, collapsing
            #across everything else
            C = stimCounts['0'].transpose((1,0,2,3))
            # integer division: reshape needs integer dimensions
            C = C.reshape(C.shape[0],C.size//C.shape[0])
            ax = plt.subplot(2,3,6,polar=True)
            ax.errorbar(angles*np.pi/180,C.mean(1),C.std(1))

        if save:
            if not os.path.isabs(fname):
                fn = os.path.expanduser('~/Documents/research/figures/SpikeSorting/hmm/%s' % (fname.replace('.pdf','Unit%s.pdf' %(str(c),)),))
            else:
                fn = fname.replace('.pdf','Unit%s.pdf' %(str(c),))

            fig.savefig(fn,bbox_inches='tight')

    if not save:
        plt.draw()
    """
Exemplo n.º 5
0
class SliceViewerDataView(QWidget):
    """The view for the data portion of the sliceviewer.

    Owns the dimension selector, the matplotlib canvas and toolbar, the
    colorbar, the image readout table and the status bar. User interactions
    are forwarded to the ``presenter`` given to the constructor.
    """

    def __init__(self, presenter: IDataViewSubscriber, dims_info, can_normalise, parent=None, conf=None):
        """
        :param presenter: Object notified of user interactions (dimension changes,
                          canvas clicks, zoom/pan, ...)
        :param dims_info: Passed through to the DimensionWidget
        :param can_normalise: If True a normalization combo box is added above the colorbar
        :param parent: Optional parent widget
        :param conf: Optional configuration object (has/get/set) used to persist the
                     colorbar scale. If None nothing is read or stored
        """
        super().__init__(parent)

        self.presenter = presenter

        self.image = None
        self.line_plots_active = False
        self.can_normalise = can_normalise
        self.nonortho_transform = None
        self.conf = conf

        self._line_plots = None
        self._image_info_tracker = None
        self._region_selection_on = False
        self._orig_lims = None

        # Dimension widget
        self.dimensions_layout = QGridLayout()
        self.dimensions = DimensionWidget(dims_info, parent=self)
        self.dimensions.dimensionsChanged.connect(self.presenter.dimensions_changed)
        self.dimensions.valueChanged.connect(self.presenter.slicepoint_changed)
        self.dimensions_layout.addWidget(self.dimensions, 1, 0, 1, 1)

        self.colorbar_layout = QVBoxLayout()
        self.colorbar_layout.setContentsMargins(0, 0, 0, 0)
        self.colorbar_layout.setSpacing(0)

        self.image_info_widget = ImageInfoWidget(self)
        self.image_info_widget.setToolTip("Information about the selected pixel")
        self.track_cursor = QCheckBox("Track Cursor", self)
        self.track_cursor.setToolTip(
            "Update the image readout table when the cursor is over the plot. "
            "If unticked the table will update only when the plot is clicked")
        self.dimensions_layout.setHorizontalSpacing(10)
        self.dimensions_layout.addWidget(self.track_cursor, 0, 1, Qt.AlignRight)
        self.dimensions_layout.addWidget(self.image_info_widget, 1, 1)
        self.track_cursor.setChecked(True)
        self.track_cursor.stateChanged.connect(self.on_track_cursor_state_change)

        # normalization options
        if can_normalise:
            self.norm_label = QLabel("Normalization")
            self.colorbar_layout.addWidget(self.norm_label)
            self.norm_opts = QComboBox()
            self.norm_opts.addItems(["None", "By bin width"])
            self.norm_opts.setToolTip("Normalization options")
            self.colorbar_layout.addWidget(self.norm_opts)

        # MPL figure + colorbar
        self.fig = Figure()
        self.ax = None
        self._grid_on = False
        self.fig.set_facecolor(self.palette().window().color().getRgbF())
        self.canvas = SliceViewerCanvas(self.fig)
        self.canvas.mpl_connect('button_release_event', self.mouse_release)
        self.canvas.mpl_connect('button_press_event', self.presenter.canvas_clicked)

        self.colorbar_label = QLabel("Colormap")
        self.colorbar_layout.addWidget(self.colorbar_label)
        norm_scale = self.get_default_scale_norm()
        self.colorbar = ColorbarWidget(self, norm_scale)
        self.colorbar.cmap.setToolTip("Colormap options")
        self.colorbar.crev.setToolTip("Reverse colormap")
        self.colorbar.norm.setToolTip("Colormap normalisation options")
        self.colorbar.powerscale.setToolTip("Power colormap scale")
        self.colorbar.cmax.setToolTip("Colormap maximum limit")
        self.colorbar.cmin.setToolTip("Colormap minimum limit")
        self.colorbar.autoscale.setToolTip("Automatically changes colormap limits when zooming on the plot")
        self.colorbar_layout.addWidget(self.colorbar)
        self.colorbar.colorbarChanged.connect(self.update_data_clim)
        self.colorbar.scaleNormChanged.connect(self.scale_norm_changed)
        # make width larger to fit image readout table
        self.colorbar.setMaximumWidth(200)

        # MPL toolbar
        self.toolbar_layout = QHBoxLayout()
        self.mpl_toolbar = SliceViewerNavigationToolbar(self.canvas, self, False)
        self.mpl_toolbar.gridClicked.connect(self.toggle_grid)
        self.mpl_toolbar.linePlotsClicked.connect(self.on_line_plots_toggle)
        self.mpl_toolbar.regionSelectionClicked.connect(self.on_region_selection_toggle)
        self.mpl_toolbar.homeClicked.connect(self.on_home_clicked)
        self.mpl_toolbar.nonOrthogonalClicked.connect(self.on_non_orthogonal_axes_toggle)
        self.mpl_toolbar.zoomPanClicked.connect(self.presenter.zoom_pan_clicked)
        self.mpl_toolbar.zoomPanFinished.connect(self.on_data_limits_changed)
        self.toolbar_layout.addWidget(self.mpl_toolbar)

        # Status bar
        self.status_bar = QStatusBar(parent=self)
        self.status_bar.setStyleSheet('QStatusBar::item {border: None;}')  # Hide spacers between button and label
        self.status_bar_label = QLabel()
        self.help_button = QToolButton()
        self.help_button.setText("?")
        self.status_bar.addWidget(self.help_button)
        self.status_bar.addWidget(self.status_bar_label)

        # layout
        layout = QGridLayout(self)
        layout.setSpacing(1)
        layout.addLayout(self.dimensions_layout, 0, 0, 1, 2)
        layout.addLayout(self.toolbar_layout, 1, 0, 1, 1)
        layout.addLayout(self.colorbar_layout, 1, 1, 3, 1)
        layout.addWidget(self.canvas, 2, 0, 1, 1)
        layout.addWidget(self.status_bar, 3, 0, 1, 1)
        layout.setRowStretch(2, 1)

    @property
    def grid_on(self):
        """True if the grid is currently flagged as visible"""
        return self._grid_on

    @property
    def line_plotter(self):
        """The active line plots tool or None if line plots are disabled"""
        return self._line_plots

    @property
    def nonorthogonal_mode(self):
        """True if a nonorthogonal transform is currently set"""
        return self.nonortho_transform is not None

    def create_axes_orthogonal(self, redraw_on_zoom=False):
        """Create a standard set of orthogonal axes
        :param redraw_on_zoom: If True then when scroll zooming the canvas is redrawn immediately
        """
        # clear_figure closes any line plots and resets line_plots_active, so
        # there is nothing to restore here; the presenter has to re-enable them.
        self.clear_figure()
        self.nonortho_transform = None
        self.ax = self.fig.add_subplot(111, projection='mantid')
        self.enable_zoom_on_mouse_scroll(redraw_on_zoom)
        if self.grid_on:
            self.ax.grid(self.grid_on)

        self.plot_MDH = self.plot_MDH_orthogonal

        self.canvas.draw_idle()

    def create_axes_nonorthogonal(self, transform):
        """Create a set of curvelinear axes for displaying nonorthogonal data.
        The grid is switched on as part of the setup.
        :param transform: An object with tr/inv_tr methods defining the transform
                          to and from the display coordinates
        """
        self.clear_figure()
        self.set_nonorthogonal_transform(transform)
        self.ax = CurveLinearSubPlot(self.fig,
                                     1,
                                     1,
                                     1,
                                     grid_helper=GridHelperCurveLinear(
                                         (transform.tr, transform.inv_tr)))
        # don't redraw on zoom as the data is rebinned and has to be redrawn again anyway
        self.enable_zoom_on_mouse_scroll(redraw=False)
        self.set_grid_on()
        self.fig.add_subplot(self.ax)
        self.plot_MDH = self.plot_MDH_nonorthogonal

        self.canvas.draw_idle()

    def enable_zoom_on_mouse_scroll(self, redraw):
        """Enable zoom on scroll the mouse wheel for the created axes
        :param redraw: Pass through to redraw option in enable_zoom_on_scroll
        """
        self.canvas.enable_zoom_on_scroll(self.ax,
                                          redraw=redraw,
                                          toolbar=self.mpl_toolbar,
                                          callback=self.on_data_limits_changed)

    def add_line_plots(self, toolcls, exporter):
        """Assuming line plots are currently disabled, enable them on the current figure
        The image axes must have been created first.
        :param toolcls: Use this class to handle creating the plots
        :param exporter: Object defining methods to export cuts/roi
        """
        if self.line_plots_active:
            return

        self.line_plots_active = True
        self._line_plots = toolcls(LinePlots(self.ax, self.colorbar), exporter)
        self.status_bar_label.setText(self._line_plots.status_message())
        self.canvas.setFocus()
        self.mpl_toolbar.set_action_checked(ToolItemText.LINEPLOTS, True, trigger=False)

    def switch_line_plots_tool(self, toolcls, exporter):
        """Assuming line plots are currently enabled then switch the tool used to
        generate the plot curves.
        :param toolcls: Use this class to handle creating the plots
        :param exporter: Object defining methods to export cuts/roi
        """
        if not self.line_plots_active:
            return

        # Keep the same set of line plots axes but swap the selection tool
        plotter = self._line_plots.plotter
        plotter.delete_line_plot_lines()
        self._line_plots.disconnect()
        self._line_plots = toolcls(plotter, exporter)
        self.status_bar_label.setText(self._line_plots.status_message())
        self.canvas.setFocus()
        self.canvas.draw_idle()

    def remove_line_plots(self):
        """Assuming line plots are currently enabled, remove them from the current figure
        """
        if not self.line_plots_active:
            return

        self._line_plots.plotter.close()
        self.status_bar_label.clear()
        self._line_plots = None
        self.line_plots_active = False

    def plot_MDH_orthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MDHistoWorkspace
        :param ws: The workspace to display
        :param kwargs: Additional keywords passed to imshow
        """
        self.clear_image()
        self.image = self.ax.imshow(ws,
                                    origin='lower',
                                    aspect='auto',
                                    transpose=self.dimensions.transpose,
                                    norm=self.colorbar.get_norm(),
                                    **kwargs)
        # ensure the axes data limits are updated to match the
        # image. For example if the axes were zoomed and the
        # swap dimensions was clicked we need to restore the
        # appropriate extents to see the image in the correct place
        extent = self.image.get_extent()
        self.ax.set_xlim(extent[0], extent[1])
        self.ax.set_ylim(extent[2], extent[3])
        # Set the original data limits which get passed to the ImageInfoWidget so that
        # the mouse projection to data space is correct for MDH workspaces when zoomed/changing slices
        self._orig_lims = self.get_axes_limits()

        self.on_track_cursor_state_change(self.track_cursor_checked())

        self.draw_plot()

    def plot_MDH_nonorthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MDHistoWorkspace
        drawn through the current nonorthogonal transform
        :param ws: The workspace to display
        :param kwargs: Additional keywords passed to pcolormesh_nonorthogonal
        """
        self.clear_image()
        self.image = pcolormesh_nonorthogonal(self.ax,
                                              ws,
                                              self.nonortho_transform.tr,
                                              transpose=self.dimensions.transpose,
                                              norm=self.colorbar.get_norm(),
                                              **kwargs)
        self.on_track_cursor_state_change(self.track_cursor_checked())

        # swapping dimensions in nonorthogonal mode currently resets back to the
        # full data limits as the whole axes has been recreated so we don't have
        # access to the original limits
        # pcolormesh clears any grid that was previously visible
        if self.grid_on:
            self.ax.grid(self.grid_on)
        self.draw_plot()

    def plot_matrix(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MatrixWorkspace keeping
        the axes limits that have already been set
        :param ws: The workspace to display
        :param kwargs: Additional keywords passed to imshow
        """
        # ensure view is correct if zoomed in while swapping dimensions
        # compute required extent and just have resampling imshow deal with it
        old_extent = None
        if self.image is not None:
            old_extent = self.image.get_extent()
            if self.image.transpose != self.dimensions.transpose:
                e1, e2, e3, e4 = old_extent
                old_extent = e3, e4, e1, e2

        self.clear_image()
        self.image = self.ax.imshow(ws,
                                    origin='lower',
                                    aspect='auto',
                                    interpolation='none',
                                    transpose=self.dimensions.transpose,
                                    norm=self.colorbar.get_norm(),
                                    extent=old_extent,
                                    **kwargs)
        self.on_track_cursor_state_change(self.track_cursor_checked())

        self.draw_plot()

    def clear_image(self):
        """Removes any image from the axes"""
        if self.image is not None:
            if self.line_plots_active:
                self._line_plots.plotter.delete_line_plot_lines()
            self.image_info_widget.cursorAt(DBLMAX, DBLMAX, DBLMAX)
            if hasattr(self.ax, "remove_artists_if"):
                self.ax.remove_artists_if(lambda art: art == self.image)
            else:
                self.image.remove()
            self.image = None

    def clear_figure(self):
        """Removes everything from the figure"""
        if self.line_plots_active:
            self._line_plots.plotter.close()
            self.line_plots_active = False
        self.image = None
        self.canvas.disable_zoom_on_scroll()
        self.fig.clf()
        self.ax = None

    def draw_plot(self):
        """Redraw the canvas and bring the colorbar, toolbar and line plots up to date"""
        self.ax.set_title('')
        self.canvas.draw()
        if self.image is not None:
            self.colorbar.set_mappable(self.image)
            self.colorbar.update_clim()
        self.mpl_toolbar.update()  # clear nav stack
        if self.line_plots_active:
            self._line_plots.plotter.delete_line_plot_lines()
            self._line_plots.plotter.update_line_plot_labels()

    def export_region(self, limits, cut):
        """
        React to a region selection that should be exported
        :param limits: 2-tuple of ((left, right), (bottom, top))
        :param cut: A str denoting which cuts to export.
        """
        self.presenter.export_region(limits, cut)

    def update_plot_data(self, data):
        """
        This just updates the plot data without creating a new plot. The extents
        can change if the data has been rebinned.
        :param data: Array of new values; it is transposed before being set on the image
        """
        if self.nonorthogonal_mode:
            self.image.set_array(data.T.ravel())
        else:
            self.image.set_data(data.T)
        self.colorbar.update_clim()

    def track_cursor_checked(self):
        """Return True if the track cursor box exists and is ticked"""
        return self.track_cursor.isChecked() if self.track_cursor else False

    def on_track_cursor_state_change(self, state):
        """
        Called to notify the current state of the track cursor box
        :param state: If truthy the image readout (and the line plots, unless
                      region selection is on) follow the cursor
        """
        if self._image_info_tracker is not None:
            self._image_info_tracker.disconnect()
        if self._line_plots is not None and not self._region_selection_on:
            self._line_plots.disconnect()

        # a fresh tracker is built each time so that it refers to the current image
        self._image_info_tracker = ImageInfoTracker(image=self.image,
                                                    transform=self.nonortho_transform,
                                                    do_transform=self.nonorthogonal_mode,
                                                    widget=self.image_info_widget,
                                                    cursor_transform=self._orig_lims)

        if state:
            self._image_info_tracker.connect()
            if self._line_plots and not self._region_selection_on:
                self._line_plots.connect()
        else:
            self._image_info_tracker.disconnect()
            if self._line_plots and not self._region_selection_on:
                self._line_plots.disconnect()

    def on_home_clicked(self):
        """Reset the view to encompass all of the data"""
        self.presenter.show_all_data_clicked()

    def on_line_plots_toggle(self, state):
        """Switch state of the line plots"""
        self.presenter.line_plots(state)

    def on_region_selection_toggle(self, state):
        """Switch state of the region selection"""
        self.presenter.region_selection(state)
        self._region_selection_on = state
        # If state is off and track cursor is on, make sure line plots are re-connected to move cursor
        if not state and self.track_cursor_checked():
            if self._line_plots:
                self._line_plots.connect()

    def on_non_orthogonal_axes_toggle(self, state):
        """
        Switch state of the non-orthogonal axes on/off
        """
        self.presenter.nonorthogonal_axes(state)

    def on_data_limits_changed(self):
        """
        React to when the data limits have changed
        """
        self.presenter.data_limits_changed()

    def deactivate_and_disable_tool(self, tool_text):
        """Deactivate a tool as if the control had been pressed and disable the functionality"""
        self.deactivate_tool(tool_text)
        self.disable_tool_button(tool_text)

    def activate_tool(self, tool_text):
        """Activate a given tool as if the control had been pressed"""
        self.mpl_toolbar.set_action_checked(tool_text, True)

    def deactivate_tool(self, tool_text):
        """Deactivate a given tool as if the tool button had been pressed"""
        self.mpl_toolbar.set_action_checked(tool_text, False)

    def enable_tool_button(self, tool_text):
        """Set a given tool button enabled so it can be interacted with"""
        self.mpl_toolbar.set_action_enabled(tool_text, True)

    def disable_tool_button(self, tool_text):
        """Set a given tool button disabled so it cannot be interacted with"""
        self.mpl_toolbar.set_action_enabled(tool_text, False)

    def get_axes_limits(self):
        """
        Return the limits on the image axes transformed into the nonorthogonal frame if appropriate
        :return: 2-tuple of (xlim, ylim) or None if there is no image
        """
        if self.image is None:
            return None
        else:
            xlim, ylim = self.ax.get_xlim(), self.ax.get_ylim()
            if self.nonorthogonal_mode:
                inv_tr = self.nonortho_transform.inv_tr
                # viewing axis y not aligned with plot axis
                xmin_p, ymax_p = inv_tr(xlim[0], ylim[1])
                xmax_p, ymin_p = inv_tr(xlim[1], ylim[0])
                xlim, ylim = (xmin_p, xmax_p), (ymin_p, ymax_p)
            return xlim, ylim

    def get_full_extent(self):
        """
        Return the full extent of image - only applicable for plots of matrix workspaces
        :return: The full extent or None if the image is not a SamplingImage
        """
        # isinstance is False for None so no separate check is required
        if isinstance(self.image, samplingimage.SamplingImage):
            return self.image.get_full_extent()
        else:
            return None

    def set_axes_limits(self, xlim, ylim):
        """
        Set the view limits on the image axes to the given extents. Assume the
        limits are in the orthogonal frame.

        :param xlim: 2-tuple of (xmin, xmax)
        :param ylim: 2-tuple of (ymin, ymax)
        """
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)

    def set_grid_on(self):
        """
        If not visible sets the grid visibility
        """
        if not self._grid_on:
            self._grid_on = True
            self.mpl_toolbar.set_action_checked(ToolItemText.GRID, state=self._grid_on)

    def set_nonorthogonal_transform(self, transform):
        """
        Set the transform for nonorthogonal axes mode
        :param transform: An object with a tr method to transform from nonorthogonal
                          coordinates to display coordinates
        """
        self.nonortho_transform = transform

    def show_temporary_status_message(self, msg, timeout_ms):
        """
        Show a message in the status bar that disappears after a set period
        :param msg: A str message to display
        :param timeout_ms: Timeout in milliseconds to display the message for
        """
        self.status_bar.showMessage(msg, timeout_ms)

    def toggle_grid(self, state):
        """
        Toggle the visibility of the grid on the axes
        """
        self._grid_on = state
        self.ax.grid(self._grid_on)
        self.canvas.draw_idle()

    def mouse_release(self, event):
        """
        React to a mouse button being released over the canvas. A left click
        updates the image readout/line plots, a right click resets the view
        :param event: The matplotlib mouse event
        """
        # without axes both sides of the comparison are None for a release
        # outside of any axes, so check for that explicitly
        if self.ax is None or event.inaxes != self.ax:
            return
        self.canvas.setFocus()
        if event.button == 1:
            # the tracker only exists once something has been plotted
            if self._image_info_tracker is not None:
                self._image_info_tracker.on_cursor_at(event.xdata, event.ydata)
            if self.line_plots_active and not self._region_selection_on:
                self._line_plots.on_cursor_at(event.xdata, event.ydata)
        if event.button == 3:
            self.on_home_clicked()

    def deactivate_zoom_pan(self):
        """Deactivate both the pan and zoom tools"""
        self.deactivate_tool(ToolItemText.PAN)
        self.deactivate_tool(ToolItemText.ZOOM)

    def update_data_clim(self):
        """Copy the color limits of the colorbar onto the image and redraw"""
        # the colorbar can change before anything has been plotted
        if self.image is None:
            return
        self.image.set_clim(self.colorbar.colorbar.mappable.get_clim())
        if self.line_plots_active:
            self._line_plots.plotter.update_line_plot_limits()
        self.canvas.draw_idle()

    def set_normalization(self, ws, **kwargs):
        """
        Tell the presenter whether the displayed data is normalized and
        update the normalization combo box to match
        :param ws: The workspace being displayed
        :param kwargs: Passed through to get_normalize_by_bin_width
        """
        normalize_by_bin_width, _ = get_normalize_by_bin_width(ws, self.ax, **kwargs)
        is_normalized = normalize_by_bin_width or ws.isDistribution()
        self.presenter.normalization = is_normalized
        # the combo box is only created when normalization is supported
        if self.can_normalise:
            if is_normalized:
                self.norm_opts.setCurrentIndex(1)
            else:
                self.norm_opts.setCurrentIndex(0)

    def get_default_scale_norm(self):
        """
        Return the colorbar scale to start with. 'Linear' unless the
        configuration holds a stored value
        :return: A str naming the scale or, for a stored 'Power' scale with a
                 stored exponent, a 2-tuple of ('Power', exponent)
        """
        scale = 'Linear'
        if self.conf is None:
            return scale

        if self.conf.has(SCALENORM):
            scale = self.conf.get(SCALENORM)

        if scale == 'Power' and self.conf.has(POWERSCALE):
            exponent = self.conf.get(POWERSCALE)
            scale = (scale, exponent)

        # a stored 'Log' is mapped onto the "SymmetricLog10" scale
        scale = "SymmetricLog10" if scale == 'Log' else scale
        return scale

    def scale_norm_changed(self):
        """Store the selected colorbar scale (and power exponent) in the configuration, if there is one"""
        if self.conf is None:
            return

        scale = self.colorbar.norm.currentText()
        self.conf.set(SCALENORM, scale)

        if scale == 'Power':
            exponent = self.colorbar.powerscale_value
            self.conf.set(POWERSCALE, exponent)
Exemplo n.º 6
0
class SliceViewerDataView(QWidget):
    """The view for the data portion of the sliceviewer"""
    def __init__(self, presenter, dims_info, can_normalise, parent=None):
        """
        Build the dimension selector, matplotlib canvas, colorbar and toolbar
        and connect their signals to this view and to the presenter.
        :param presenter: Object whose dimensions_changed, slicepoint_changed,
                          line_plots and nonorthogonal_axes callbacks are used
        :param dims_info: Dimension description handed to DimensionWidget
        :param can_normalise: If True the normalization combo box is created
        :param parent: Optional parent widget
        """
        super().__init__(parent)

        self.presenter = presenter

        self.image = None
        self.line_plots = False
        self.can_normalise = can_normalise
        self.nonortho_tr = None

        # Dimension widget
        self.dimensions_layout = QHBoxLayout()
        self.dimensions = DimensionWidget(dims_info, parent=self)
        self.dimensions.dimensionsChanged.connect(
            self.presenter.dimensions_changed)
        self.dimensions.valueChanged.connect(self.presenter.slicepoint_changed)
        self.dimensions_layout.addWidget(self.dimensions)

        self.colorbar_layout = QVBoxLayout()

        # normalization options
        if can_normalise:
            self.norm_layout = QHBoxLayout()
            self.norm_label = QLabel("Normalization =")
            self.norm_layout.addWidget(self.norm_label)
            self.norm_opts = QComboBox()
            self.norm_opts.addItems(["None", "By bin width"])
            self.norm_opts.setToolTip("Normalization options")
            self.norm_layout.addWidget(self.norm_opts)
            self.colorbar_layout.addLayout(self.norm_layout)

        # MPL figure + colorbar
        self.mpl_layout = QHBoxLayout()
        self.fig = Figure()
        self.ax = None
        # line plot axes; only set while line plot axes exist on the figure
        self.axx, self.axy = None, None
        self._grid_on = False
        self.fig.set_facecolor(self.palette().window().color().getRgbF())
        self.canvas = FigureCanvas(self.fig)
        self.canvas.mpl_connect('motion_notify_event', self.mouse_move)
        self.create_axes_orthogonal()
        self.mpl_layout.addWidget(self.canvas)
        self.colorbar = ColorbarWidget(self)
        self.colorbar_layout.addWidget(self.colorbar)
        self.colorbar.colorbarChanged.connect(self.update_data_clim)
        self.colorbar.colorbarChanged.connect(self.update_line_plot_limits)
        self.mpl_layout.addLayout(self.colorbar_layout)

        # MPL toolbar
        self.mpl_toolbar = SliceViewerNavigationToolbar(self.canvas, self)
        self.mpl_toolbar.gridClicked.connect(self.toggle_grid)
        self.mpl_toolbar.linePlotsClicked.connect(self.line_plots_toggle)
        self.mpl_toolbar.plotOptionsChanged.connect(
            self.colorbar.mappable_changed)
        self.mpl_toolbar.nonOrthogonalClicked.connect(
            self.non_orthogonal_axes_toggle)

        # layout
        # NOTE(review): this attribute shadows the inherited QWidget.layout()
        # method; kept because external code may read it - confirm
        self.layout = QGridLayout(self)
        self.layout.addLayout(self.dimensions_layout, 0, 0)
        self.layout.addWidget(self.mpl_toolbar, 1, 0)
        self.layout.addLayout(self.mpl_layout, 2, 0)

    @property
    def grid_on(self):
        """True if the grid is currently flagged as visible"""
        return self._grid_on

    @property
    def nonorthogonal_mode(self):
        """True if a nonorthogonal transform is currently set"""
        return self.nonortho_tr is not None

    def create_axes_orthogonal(self):
        """
        Clear the figure and create a standard set of orthogonal image axes.
        If line plots are enabled their axes are recreated as well.
        """
        self.clear_figure()
        self.nonortho_tr = None
        self.ax = self.fig.add_subplot(111, projection='mantid')
        if self.grid_on:
            self.ax.grid()
        if self.line_plots:
            # add_line_plots() is a no-op while the flag is set, so rebuild
            # the axes that clear_figure() just destroyed directly
            self._create_line_plot_axes()
        self.plot_MDH = self.plot_MDH_orthogonal

        self.canvas.draw_idle()

    def create_axes_nonorthogonal(self, transform):
        """
        Clear the figure and create curvilinear image axes
        :param transform: An object providing tr and inv_tr callables
        """
        self.clear_figure()
        self.set_nonorthogonal_transform(transform)
        self.ax = CurveLinearSubPlot(self.fig,
                                     1,
                                     1,
                                     1,
                                     grid_helper=GridHelperCurveLinear(
                                         (self.nonortho_tr, transform.inv_tr)))
        self.set_grid_on()
        self.fig.add_subplot(self.ax)
        self.plot_MDH = self.plot_MDH_nonorthogonal

        self.canvas.draw_idle()

    def add_line_plots(self):
        """Assuming line plots are currently disabled, enable them on the current figure
        The image axes must have been created first.
        """
        if self.line_plots:
            return
        if self.ax is None:
            return
        self._create_line_plot_axes()

    def _create_line_plot_axes(self):
        """Shrink the image axes and add the two line plot axes around them"""
        image_axes = self.ax

        # Create a new GridSpec and reposition the existing image Axes
        gs = gridspec.GridSpec(2,
                               2,
                               width_ratios=[1, 4],
                               height_ratios=[4, 1],
                               wspace=0.0,
                               hspace=0.0)
        image_axes.set_position(gs[1].get_position(self.fig))
        image_axes.xaxis.set_visible(False)
        image_axes.yaxis.set_visible(False)
        self.axx = self.fig.add_subplot(gs[3], sharex=image_axes)
        self.axx.yaxis.tick_right()
        self.axy = self.fig.add_subplot(gs[0], sharey=image_axes)
        self.axy.xaxis.tick_top()

        self.mpl_toolbar.update()  # sync list of axes in navstack
        self.canvas.draw_idle()

    def remove_line_plots(self):
        """Assuming line plots are currently enabled, remove them from the current figure
        """
        if not self.line_plots:
            return
        image_axes = self.ax
        if image_axes is None:
            return

        self.clear_line_plots()
        # remove through the stored references rather than by position in
        # fig.axes so a missing line plot axes cannot raise IndexError
        for line_axes in (self.axx, self.axy):
            if line_axes is not None:
                line_axes.remove()

        gs = gridspec.GridSpec(1, 1)
        image_axes.set_position(gs[0].get_position(self.fig))
        image_axes.xaxis.set_visible(True)
        image_axes.yaxis.set_visible(True)
        self.axx, self.axy = None, None

        self.mpl_toolbar.update()  # sync list of axes in navstack
        self.canvas.draw_idle()

    def plot_MDH_orthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MDHistoWorkspace
        :param ws: Workspace passed to imshow on the image axes
        :param kwargs: Additional keyword arguments forwarded to imshow
        """
        self.clear_image()
        self.image = self.ax.imshow(ws,
                                    origin='lower',
                                    aspect='auto',
                                    transpose=self.dimensions.transpose,
                                    norm=self.colorbar.get_norm(),
                                    **kwargs)
        self.draw_plot()

    def plot_MDH_nonorthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one with pcolormesh_nonorthogonal
        :param ws: Workspace passed to pcolormesh_nonorthogonal
        :param kwargs: Additional keyword arguments forwarded to the plot call
        """
        self.clear_image()
        self.image = pcolormesh_nonorthogonal(
            self.ax,
            ws,
            self.nonortho_tr,
            transpose=self.dimensions.transpose,
            norm=self.colorbar.get_norm(),
            **kwargs)
        # pcolormesh clears any grid that was previously visible
        if self.grid_on:
            self.ax.grid()
        self.draw_plot()

    def plot_matrix(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MatrixWorkspace
        :param ws: Workspace passed to imshow_sampling
        :param kwargs: Additional keyword arguments forwarded to imshow_sampling
        """
        self.clear_image()
        self.image = imshow_sampling(self.ax,
                                     ws,
                                     origin='lower',
                                     aspect='auto',
                                     interpolation='none',
                                     transpose=self.dimensions.transpose,
                                     norm=self.colorbar.get_norm(),
                                     **kwargs)
        self.image._resample_image()
        self.draw_plot()

    def clear_image(self):
        """Removes any image from the axes"""
        if self.image is not None:
            self.image.remove()
            self.image = None

    def clear_figure(self):
        """Removes everything from the figure"""
        self.image = None
        self._grid_on = False
        # fig.clf() destroys the line plot axes and their lines, so drop
        # the handles instead of leaving them pointing at dead artists
        self.clear_line_plots()
        self.axx, self.axy = None, None
        self.fig.clf()

    def draw_plot(self):
        """Attach the current image to the colorbar and redraw the canvas"""
        self.ax.set_title('')
        self.colorbar.set_mappable(self.image)
        self.colorbar.update_clim()
        self.mpl_toolbar.update()  # clear nav stack
        self.clear_line_plots()
        self.canvas.draw_idle()

    def update_plot_data(self, data):
        """
        This just updates the plot data without creating a new plot
        :param data: Array whose transpose becomes the new image data
        """
        if self.nonortho_tr:
            self.image.set_array(data.T.ravel())
        else:
            self.image.set_data(data.T)
        self.colorbar.update_clim()

    def line_plots_toggle(self, state):
        """
        Notify the presenter of the new line plots state and then record it
        :param state: Requested on/off state of the line plots
        """
        self.presenter.line_plots(state)
        self.line_plots = state

    def non_orthogonal_axes_toggle(self, state):
        """
        Switch state of the non-orthogonal axes on/off
        """
        self.presenter.nonorthogonal_axes(state)

    def enable_lineplots_button(self):
        """
        Enables line plots functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.LINEPLOTS, True)

    def disable_lineplots_button(self):
        """
        Disables line plots functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.LINEPLOTS, False)

    def enable_peaks_button(self):
        """
        Enables the peaks overlay functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.OVERLAYPEAKS, True)

    def disable_peaks_button(self):
        """
        Disables the peaks overlay functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.OVERLAYPEAKS, False)

    def disable_nonorthogonal_axes_button(self):
        """
        Disables non-orthogonal axes functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.NONORTHOGONAL_AXES,
                                            False)

    def clear_line_plots(self):
        """Forget the handles of the line plot curves, if any exist"""
        # handle each curve on its own so that a missing x curve cannot
        # leave a stale y curve behind
        for name in ('xfig', 'yfig'):
            if hasattr(self, name):
                delattr(self, name)

    def update_data_clim(self):
        """Apply the colorbar limits to the image and redraw"""
        if self.image is None:
            # the colorbar can change before anything has been plotted
            return
        self.image.set_clim(self.colorbar.colorbar.mappable.get_clim())
        self.canvas.draw_idle()

    def update_line_plot_limits(self):
        """Match the intensity axes of the line plots to the colorbar limits"""
        if self.line_plots and self.axx is not None and self.axy is not None:
            self.axx.set_ylim(self.colorbar.cmin_value,
                              self.colorbar.cmax_value)
            self.axy.set_xlim(self.colorbar.cmin_value,
                              self.colorbar.cmax_value)

    def set_grid_on(self):
        """
        If not visible sets the grid visibility
        """
        if not self._grid_on:
            self.toggle_grid()

    def set_nonorthogonal_transform(self, transform):
        """
        Set the transform for nonorthogonal axes mode
        :param transform: An object with a tr method to transform from nonorthogonal
                          coordinates to display coordinates
        """
        self.nonortho_tr = transform.tr

    def toggle_grid(self):
        """
        Toggle the visibility of the grid on the axes
        """
        self.ax.grid()
        self._grid_on = not self._grid_on
        self.canvas.draw_idle()

    def mouse_move(self, event):
        """Update the line plots while the cursor moves over the image axes"""
        if self.line_plots and event.inaxes == self.ax:
            self.update_line_plots(event.xdata, event.ydata)

    def plot_x_line(self, x, y):
        """
        Show y against x on the lower line plot axes, reusing the existing
        curve when there is one
        """
        if self.axx is None:
            return
        try:
            self.xfig[0].set_data(x, y)
        except (AttributeError, IndexError):
            self.axx.clear()
            self.xfig = self.axx.plot(x, y)
            self.axx.set_xlabel(self.ax.get_xlabel())
            self.update_line_plot_limits()
        self.canvas.draw_idle()

    def plot_y_line(self, x, y):
        """
        Show x against y on the left line plot axes, reusing the existing
        curve when there is one
        """
        if self.axy is None:
            return
        try:
            self.yfig[0].set_data(y, x)
        except (AttributeError, IndexError):
            self.axy.clear()
            self.yfig = self.axy.plot(y, x)
            self.axy.set_ylabel(self.ax.get_ylabel())
            self.update_line_plot_limits()
        self.canvas.draw_idle()

    def update_line_plots(self, x, y):
        """
        Redraw the line plots with the image row and column under (x, y)
        :param x: X position in data coordinates
        :param y: Y position in data coordinates
        """
        if self.image is None:
            return
        xmin, xmax, ymin, ymax = self.image.get_extent()
        arr = self.image.get_array()
        # map data coordinates onto (row, column) array indices
        data_extent = Bbox([[ymin, xmin], [ymax, xmax]])
        array_extent = Bbox([[0, 0], arr.shape[:2]])
        trans = BboxTransform(boxin=data_extent, boxout=array_extent)
        # transform() accepts a single point; transform_point is deprecated
        point = trans.transform([y, x])
        if any(np.isnan(point)):
            return
        i, j = point.astype(int)
        if 0 <= i < arr.shape[0]:
            self.plot_x_line(np.linspace(xmin, xmax, arr.shape[1]), arr[i, :])
        if 0 <= j < arr.shape[1]:
            self.plot_y_line(np.linspace(ymin, ymax, arr.shape[0]), arr[:, j])

    def set_normalization(self, ws, **kwargs):
        """
        Select the normalization for the workspace on the presenter and, when
        the combo box exists, show the matching entry
        :param ws: Workspace; queried via get_normalize_by_bin_width and isDistribution()
        :param kwargs: Forwarded to get_normalize_by_bin_width
        """
        normalize_by_bin_width, _ = get_normalize_by_bin_width(
            ws, self.ax, **kwargs)
        is_normalized = normalize_by_bin_width or ws.isDistribution()
        if is_normalized:
            self.presenter.normalization = mantid.api.MDNormalization.VolumeNormalization
        else:
            self.presenter.normalization = mantid.api.MDNormalization.NoNormalization
        # norm_opts is only created when can_normalise was True
        if self.can_normalise:
            self.norm_opts.setCurrentIndex(1 if is_normalized else 0)
Exemplo n.º 7
0
class SliceViewerDataView(QWidget):
    """The view for the data portion of the sliceviewer"""
    def __init__(self, presenter, dims_info, can_normalise, parent=None):
        """
        Build the dimension selector, image readout table, matplotlib canvas,
        colorbar and toolbar and connect their signals to this view and to
        the presenter. No axes are created here.
        :param presenter: Object receiving the view callbacks
        :param dims_info: Dimension description handed to DimensionWidget;
                          dims_info[0]['type'] is stored as the workspace type
        :param can_normalise: If True the normalization combo box is created
        :param parent: Optional parent widget
        """
        super().__init__(parent)

        self.presenter = presenter

        self.image = None
        self.line_plots = False
        self.can_normalise = can_normalise
        self.nonortho_tr = None
        self.ws_type = dims_info[0]['type']

        # Dimension widget
        self.dimensions_layout = QGridLayout()
        self.dimensions = DimensionWidget(dims_info, parent=self)
        self.dimensions.dimensionsChanged.connect(
            self.presenter.dimensions_changed)
        self.dimensions.valueChanged.connect(self.presenter.slicepoint_changed)
        self.dimensions_layout.addWidget(self.dimensions, 1, 0, 1, 1)

        self.colorbar_layout = QVBoxLayout()
        self.colorbar_layout.setContentsMargins(0, 0, 0, 0)
        self.colorbar_layout.setSpacing(0)

        self.image_info_widget = ImageInfoWidget(self.ws_type, self)
        self.track_cursor = QCheckBox("Track Cursor", self)
        self.track_cursor.setToolTip(
            "Update the image readout table when the cursor is over the plot. "
            "If unticked the table will update only when the plot is clicked")
        if self.ws_type == 'MDE':
            self.colorbar_layout.addWidget(self.image_info_widget,
                                           alignment=Qt.AlignCenter)
            self.colorbar_layout.addWidget(self.track_cursor)
        else:
            self.dimensions_layout.setHorizontalSpacing(10)
            self.dimensions_layout.addWidget(self.track_cursor, 0, 1,
                                             Qt.AlignRight)
            self.dimensions_layout.addWidget(self.image_info_widget, 1, 1)
        self.track_cursor.setChecked(True)

        # normalization options
        if can_normalise:
            self.norm_label = QLabel("Normalization")
            self.colorbar_layout.addWidget(self.norm_label)
            self.norm_opts = QComboBox()
            self.norm_opts.addItems(["None", "By bin width"])
            self.norm_opts.setToolTip("Normalization options")
            self.colorbar_layout.addWidget(self.norm_opts)

        # MPL figure + colorbar
        self.fig = Figure()
        self.ax = None
        self.axx, self.axy = None, None
        self._grid_on = False
        self.fig.set_facecolor(self.palette().window().color().getRgbF())
        self.canvas = SliceViewerCanvas(self.fig)
        self.canvas.mpl_connect('motion_notify_event', self.mouse_move)
        self.canvas.mpl_connect('axes_leave_event', self.mouse_outside_image)
        self.canvas.mpl_connect('button_press_event', self.mouse_click)
        self.canvas.mpl_connect('button_release_event', self.mouse_release)

        self.colorbar_label = QLabel("Colormap")
        self.colorbar_layout.addWidget(self.colorbar_label)
        self.colorbar = ColorbarWidget(self)
        self.colorbar_layout.addWidget(self.colorbar)
        self.colorbar.colorbarChanged.connect(self.update_data_clim)
        self.colorbar.colorbarChanged.connect(self.update_line_plot_limits)
        # make width larger to fit image readout table
        if self.ws_type == 'MDE':
            self.colorbar.setMaximumWidth(155)

        # MPL toolbar
        self.toolbar_layout = QHBoxLayout()
        self.mpl_toolbar = SliceViewerNavigationToolbar(
            self.canvas, self, False)
        self.mpl_toolbar.gridClicked.connect(self.toggle_grid)
        self.mpl_toolbar.linePlotsClicked.connect(self.on_line_plots_toggle)
        self.mpl_toolbar.homeClicked.connect(self.on_home_clicked)
        self.mpl_toolbar.plotOptionsChanged.connect(
            self.colorbar.mappable_changed)
        self.mpl_toolbar.nonOrthogonalClicked.connect(
            self.on_non_orthogonal_axes_toggle)
        self.mpl_toolbar.zoomPanFinished.connect(self.on_data_limits_changed)
        self.toolbar_layout.addWidget(self.mpl_toolbar)

        # layout
        layout = QGridLayout(self)
        layout.setSpacing(1)
        layout.addLayout(self.dimensions_layout, 0, 0, 1, 2)
        layout.addLayout(self.toolbar_layout, 1, 0, 1, 2)
        layout.addWidget(self.canvas, 2, 0, 1, 1)
        layout.addLayout(self.colorbar_layout, 1, 1, 2, 1)
        layout.setRowStretch(2, 1)

    @property
    def grid_on(self):
        """True if the grid is currently flagged as visible"""
        return self._grid_on

    @property
    def nonorthogonal_mode(self):
        """True if a nonorthogonal transform is currently set"""
        return self.nonortho_tr is not None

    def create_axes_orthogonal(self, redraw_on_zoom=False):
        """Create a standard set of orthogonal axes
        :param redraw_on_zoom: If True then when scroll zooming the canvas is redrawn immediately
        """
        self.clear_figure()
        self.nonortho_tr = None
        self.ax = self.fig.add_subplot(111, projection='mantid')
        self.enable_zoom_on_mouse_scroll(redraw_on_zoom)
        if self.grid_on:
            self.ax.grid(self.grid_on)
        if self.line_plots:
            # add_line_plots() returns immediately while the flag is set, so
            # rebuild the axes that clear_figure() just discarded directly
            self._create_line_plot_axes()

        self.plot_MDH = self.plot_MDH_orthogonal

        self.canvas.draw_idle()

    def create_axes_nonorthogonal(self, transform):
        """
        Clear the figure and create curvilinear image axes
        :param transform: An object providing tr and inv_tr callables
        """
        self.clear_figure()
        self.set_nonorthogonal_transform(transform)
        self.ax = CurveLinearSubPlot(self.fig,
                                     1,
                                     1,
                                     1,
                                     grid_helper=GridHelperCurveLinear(
                                         (self.nonortho_tr, transform.inv_tr)))
        # don't redraw on zoom as the data is rebinned and has to be redrawn again anyway
        self.enable_zoom_on_mouse_scroll(redraw=False)
        self.set_grid_on()
        self.fig.add_subplot(self.ax)
        self.plot_MDH = self.plot_MDH_nonorthogonal

        self.canvas.draw_idle()

    def enable_zoom_on_mouse_scroll(self, redraw):
        """Enable zoom on scroll the mouse wheel for the created axes
        :param redraw: Pass through to redraw option in enable_zoom_on_scroll
        """
        self.canvas.enable_zoom_on_scroll(self.ax,
                                          redraw=redraw,
                                          toolbar=self.mpl_toolbar,
                                          callback=self.on_data_limits_changed)

    def add_line_plots(self):
        """Assuming line plots are currently disabled, enable them on the current figure
        The image axes must have been created first.
        """
        if self.line_plots:
            return

        # the flag is recorded even without image axes; the line plot axes
        # are then built by the next create_axes_orthogonal() call
        self.line_plots = True
        if self.ax is None:
            return

        self._create_line_plot_axes()

    def _create_line_plot_axes(self):
        """Shrink the image axes and add the two line plot axes around them"""
        image_axes = self.ax

        # Create a new GridSpec and reposition the existing image Axes
        gs = gridspec.GridSpec(2,
                               2,
                               width_ratios=[1, 4],
                               height_ratios=[4, 1],
                               wspace=0.0,
                               hspace=0.0)
        image_axes.set_position(gs[1].get_position(self.fig))
        set_artist_property(image_axes.get_xticklabels(), visible=False)
        set_artist_property(image_axes.get_yticklabels(), visible=False)
        self.axx = self.fig.add_subplot(gs[3], sharex=image_axes)
        self.axx.yaxis.tick_right()
        self.axy = self.fig.add_subplot(gs[0], sharey=image_axes)
        self.axy.xaxis.tick_top()
        self.update_line_plot_labels()
        self.mpl_toolbar.update()  # sync list of axes in navstack
        self.canvas.draw_idle()

    def remove_line_plots(self):
        """Assuming line plots are currently enabled, remove them from the current figure
        """
        if not self.line_plots:
            return

        self.line_plots = False
        image_axes = self.ax
        if image_axes is None:
            return

        self.delete_line_plot_lines()
        # remove through the stored references rather than by position in
        # fig.axes so a missing line plot axes cannot raise IndexError
        for line_axes in (self.axx, self.axy):
            if line_axes is not None:
                line_axes.remove()

        gs = gridspec.GridSpec(1, 1)
        image_axes.set_position(gs[0].get_position(self.fig))
        image_axes.xaxis.tick_bottom()
        image_axes.yaxis.tick_left()
        self.axx, self.axy = None, None

        self.mpl_toolbar.update()  # sync list of axes in navstack
        self.canvas.draw_idle()

    def plot_MDH_orthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MDHistoWorkspace
        :param ws: Workspace passed to imshow on the image axes
        :param kwargs: Additional keyword arguments forwarded to imshow
        """
        self.clear_image()
        self.image = self.ax.imshow(ws,
                                    origin='lower',
                                    aspect='auto',
                                    transpose=self.dimensions.transpose,
                                    norm=self.colorbar.get_norm(),
                                    **kwargs)
        # ensure the axes data limits are updated to match the
        # image. For example if the axes were zoomed and the
        # swap dimensions was clicked we need to restore the
        # appropriate extents to see the image in the correct place
        extent = self.image.get_extent()
        self.ax.set_xlim(extent[0], extent[1])
        self.ax.set_ylim(extent[2], extent[3])

        self.draw_plot()

    def plot_MDH_nonorthogonal(self, ws, **kwargs):
        """
        clears the plot and creates a new one with pcolormesh_nonorthogonal
        :param ws: Workspace passed to pcolormesh_nonorthogonal
        :param kwargs: Additional keyword arguments forwarded to the plot call
        """
        self.clear_image()
        self.image = pcolormesh_nonorthogonal(
            self.ax,
            ws,
            self.nonortho_tr,
            transpose=self.dimensions.transpose,
            norm=self.colorbar.get_norm(),
            **kwargs)
        # swapping dimensions in nonorthogonal mode currently resets back to the
        # full data limits as the whole axes has been recreated so we don't have
        # access to the original limits
        # pcolormesh clears any grid that was previously visible
        if self.grid_on:
            self.ax.grid(self.grid_on)
        self.draw_plot()

    def plot_matrix(self, ws, **kwargs):
        """
        clears the plot and creates a new one using a MatrixWorkspace keeping
        the axes limits that have already been set
        :param ws: Workspace passed to imshow_sampling
        :param kwargs: Additional keyword arguments forwarded to imshow_sampling
        """
        # ensure view is correct if zoomed in while swapping dimensions
        # compute required extent and just have resampling imshow deal with it
        old_extent = None
        if self.image is not None:
            old_extent = self.image.get_extent()
            if self.image.transpose != self.dimensions.transpose:
                e1, e2, e3, e4 = old_extent
                old_extent = e3, e4, e1, e2

        self.clear_image()
        self.image = imshow_sampling(self.ax,
                                     ws,
                                     origin='lower',
                                     aspect='auto',
                                     interpolation='none',
                                     transpose=self.dimensions.transpose,
                                     norm=self.colorbar.get_norm(),
                                     extent=old_extent,
                                     **kwargs)

        self.draw_plot()

    def clear_image(self):
        """Removes any image from the axes"""
        if self.image is not None:
            if self.line_plots:
                self.delete_line_plot_lines()
            self.image_info_widget.updateTable(DBLMAX, DBLMAX, DBLMAX)
            self.image.remove()
            self.image = None

    def clear_figure(self):
        """Removes everything from the figure"""
        if self.line_plots:
            self.delete_line_plot_lines()
            self.axx, self.axy = None, None
        self.image = None
        self.canvas.disable_zoom_on_scroll()
        self.fig.clf()
        self.ax = None

    def draw_plot(self):
        """Attach the current image to the colorbar and redraw the canvas"""
        self.ax.set_title('')
        self.colorbar.set_mappable(self.image)
        self.colorbar.update_clim()
        self.mpl_toolbar.update()  # clear nav stack
        self.delete_line_plot_lines()
        self.update_line_plot_labels()
        self.canvas.draw_idle()

    def select_zoom(self):
        """Select the zoom control on the toolbar"""
        self.mpl_toolbar.zoom()

    def update_plot_data(self, data):
        """
        This just updates the plot data without creating a new plot. The extents
        can change if the data has been rebinned
        :param data: Array whose transpose becomes the new image data
        """
        if self.nonortho_tr:
            self.image.set_array(data.T.ravel())
        else:
            self.image.set_data(data.T)
        self.colorbar.update_clim()

    def on_home_clicked(self):
        """Reset the view to encompass all of the data"""
        self.presenter.show_all_data_requested()

    def on_line_plots_toggle(self, state):
        """
        Forward the requested line plots state to the presenter
        :param state: Requested on/off state of the line plots
        """
        self.presenter.line_plots(state)

    def on_non_orthogonal_axes_toggle(self, state):
        """
        Switch state of the non-orthogonal axes on/off
        """
        self.presenter.nonorthogonal_axes(state)

    def on_data_limits_changed(self):
        """
        React to when the data limits have changed
        """
        self.presenter.data_limits_changed()

    def enable_lineplots_button(self):
        """
        Enables line plots functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.LINEPLOTS, True)

    def disable_lineplots_button(self):
        """
        Disables line plots functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.LINEPLOTS, False)

    def enable_peaks_button(self):
        """
        Enables the peaks overlay functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.OVERLAYPEAKS, True)

    def disable_peaks_button(self):
        """
        Disables the peaks overlay functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.OVERLAYPEAKS, False)

    def enable_nonorthogonal_axes_button(self):
        """
        Enables access to non-orthogonal axes functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.NONORTHOGONAL_AXES,
                                            True)

    def disable_nonorthogonal_axes_button(self):
        """
        Disables non-orthogonal axes functionality
        """
        self.mpl_toolbar.set_action_enabled(ToolItemText.NONORTHOGONAL_AXES,
                                            state=False)

    def delete_line_plot_lines(self):
        """Remove the line plot curves from their axes and forget them"""
        # Each curve is handled on its own: only one of them may exist (the
        # cursor can be inside the row range but outside the column range)
        # and a failure on one must not leave a stale handle to the other.
        for name in ('xfig', 'yfig'):
            line = getattr(self, name, None)
            if line is None:
                continue
            try:
                line.remove()
            except ValueError:
                # the curve is no longer attached to its axes
                pass
            delattr(self, name)

    def get_axes_limits(self):
        """
        Return the limits of the image axes or None if no image yet exists
        """
        if self.image is None:
            return None
        else:
            return self.ax.get_xlim(), self.ax.get_ylim()

    def set_axes_limits(self, xlim, ylim):
        """
        Set the view limits on the image axes to the given extents
        :param xlim: 2-tuple of (xmin, xmax)
        :param ylim: 2-tuple of (ymin, ymax)
        """
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)

    def set_grid_on(self):
        """
        If not visible sets the grid visibility
        """
        if not self._grid_on:
            self._grid_on = True
            self.mpl_toolbar.set_action_checked(ToolItemText.GRID,
                                                state=self._grid_on)

    def set_nonorthogonal_transform(self, transform):
        """
        Set the transform for nonorthogonal axes mode
        :param transform: An object with a tr method to transform from nonorthogonal
                          coordinates to display coordinates
        """
        self.nonortho_tr = transform.tr

    def toggle_grid(self, state):
        """
        Toggle the visibility of the grid on the axes
        :param state: New visibility of the grid
        """
        self._grid_on = state
        self.ax.grid(self._grid_on)
        self.canvas.draw_idle()

    def mouse_move(self, event):
        """
        Update the line plots (when enabled) and, if cursor tracking is
        ticked, the image readout table while the cursor is over the image
        """
        if event.inaxes == self.ax:
            signal = self.update_image_data(event.xdata, event.ydata,
                                            self.line_plots)
            if self.track_cursor.checkState() == Qt.Checked:
                self.update_image_table_widget(event.xdata, event.ydata,
                                               signal)

    def mouse_outside_image(self, _):
        """
        Indicates that the mouse have moved outside of an axes.
        We clear the line plots so that it is not confusing what they mean.
        """
        if self.line_plots:
            self.delete_line_plot_lines()
            self.canvas.draw_idle()

    def mouse_click(self, event):
        """
        Update the image readout table on a left click inside the image when
        cursor tracking is unticked
        """
        if self.track_cursor.checkState() == Qt.Unchecked \
                and event.inaxes == self.ax and event.button == 1:
            signal = self.update_image_data(event.xdata, event.ydata)
            self.update_image_table_widget(event.xdata, event.ydata, signal)

    def mouse_release(self, event):
        """Reset the view when the right button is released over the image"""
        if event.button == 3 and event.inaxes == self.ax:
            self.on_home_clicked()

    def update_image_table_widget(self, xdata, ydata, signal):
        """
        Show the given position and signal in the image readout table.
        Nothing is shown when signal is None.
        """
        if signal is not None:
            # x and y are swapped for a transposed MATRIX workspace
            if self.dimensions.transpose and self.ws_type == "MATRIX":
                self.image_info_widget.updateTable(ydata, xdata, signal)
            else:
                self.image_info_widget.updateTable(xdata, ydata, signal)

    def plot_x_line(self, x, y):
        """
        Show y against x on the lower line plot axes, reusing the existing
        curve when there is one
        """
        if self.axx is None:
            return
        try:
            self.xfig.set_data(x, y)
        except (AttributeError, IndexError):
            self.axx.clear()
            self.xfig = self.axx.plot(x, y, scalex=False)[0]
            self.update_line_plot_labels()
            self.update_line_plot_limits()
        self.canvas.draw_idle()

    def plot_y_line(self, x, y):
        """
        Show x against y on the left line plot axes, reusing the existing
        curve when there is one
        """
        if self.axy is None:
            return
        try:
            self.yfig.set_data(y, x)
        except (AttributeError, IndexError):
            self.axy.clear()
            self.yfig = self.axy.plot(y, x, scaley=False)[0]
            self.update_line_plot_labels()
            self.update_line_plot_limits()
        self.canvas.draw_idle()

    def update_data_clim(self):
        """Apply the colorbar limits to the image and redraw"""
        if self.image is None:
            # the colorbar can change before anything has been plotted
            return
        self.image.set_clim(self.colorbar.colorbar.mappable.get_clim())
        self.canvas.draw_idle()

    def update_line_plot_limits(self):
        try:  # set line plot intensity axes to match colorbar limits
            self.axx.set_ylim(self.colorbar.cmin_value,
                              self.colorbar.cmax_value)
            self.axy.set_xlim(self.colorbar.cmin_value,
                              self.colorbar.cmax_value)
        except AttributeError:
            pass

    def update_line_plot_labels(self):
        try:  # ensure plot labels are in sync with main axes
            self.axx.set_xlabel(self.ax.get_xlabel())
            self.axy.set_ylabel(self.ax.get_ylabel())
        except AttributeError:
            pass

    def update_image_data(self, x, y, update_line_plot=False):
        """
        Look up the image value under (x, y) and optionally refresh the line
        plots with the row and column through that point
        :param x: X position in data coordinates
        :param y: Y position in data coordinates
        :param update_line_plot: If True the line plots are updated as well
        :return: The array value at the position, or None when there is no
                 image or the position is NaN or outside the array
        """
        if self.image is None:
            return None
        xmin, xmax, ymin, ymax = self.image.get_extent()
        arr = self.image.get_array()
        # map data coordinates onto (row, column) array indices
        data_extent = Bbox([[ymin, xmin], [ymax, xmax]])
        array_extent = Bbox([[0, 0], arr.shape[:2]])
        trans = BboxTransform(boxin=data_extent, boxout=array_extent)
        # transform() accepts a single point; transform_point is deprecated
        point = trans.transform([y, x])
        if any(np.isnan(point)):
            return None
        i, j = point.astype(int)

        if update_line_plot:
            if 0 <= i < arr.shape[0]:
                self.plot_x_line(np.linspace(xmin, xmax, arr.shape[1]),
                                 arr[i, :])
            if 0 <= j < arr.shape[1]:
                self.plot_y_line(np.linspace(ymin, ymax, arr.shape[0]),
                                 arr[:, j])

        # Clip the coordinates at array bounds
        if not (0 <= i < arr.shape[0]) or not (0 <= j < arr.shape[1]):
            return None
        else:
            return arr[i, j]

    def set_normalization(self, ws, **kwargs):
        """
        Select the normalization for the workspace on the presenter and, when
        the combo box exists, show the matching entry
        :param ws: Workspace; queried via get_normalize_by_bin_width and isDistribution()
        :param kwargs: Forwarded to get_normalize_by_bin_width
        """
        normalize_by_bin_width, _ = get_normalize_by_bin_width(
            ws, self.ax, **kwargs)
        is_normalized = normalize_by_bin_width or ws.isDistribution()
        if is_normalized:
            self.presenter.normalization = mantid.api.MDNormalization.VolumeNormalization
        else:
            self.presenter.normalization = mantid.api.MDNormalization.NoNormalization
        # norm_opts is only created when can_normalise was True
        if self.can_normalise:
            self.norm_opts.setCurrentIndex(1 if is_normalized else 0)
Exemplo n.º 8
0
def plot_univariate_inequality(xdata,
                               inclusive=False,
                               raycolor='blue',
                               bgcolor='white',
                               xlim=None,
                               figsize=(5, 0.5),
                               figure=None,
                               title=None,
                               **kwargs):
    """Plot a chart of a univariate inequality ray with matplotlib

    The ray is drawn along ``y = 0`` through the points of ``xdata``. A
    circle marks the first point (``xdata[0]``) and an arrowhead marks the
    last point (``xdata[-1]``), pointing right when
    ``xdata[0] <= xdata[-1]`` and left otherwise.

    Args:
        xdata (list, tuple): endpoints of the ray: ``[start, ..., end]``

    Kwargs:
        inclusive (bool): if True, draw a filled circle; if False, draw a circle without fill
        raycolor (str): matplotlib color of the line and markers
        bgcolor (str): matplotlib color for the background and circle without fill
        xlim (None, True, or list/tuple): if True, calculate defaults for xmin and xmax of the x-axis
            as ``(min(xdata) - 1, max(xdata) + 0.5)``; if a list/tuple, use it as the x-axis limits
        figsize (tuple): matplotlib figsize
        figure (None or matplotlib.figure.Figure): Figure to add a plot to;
            when not given, a new figure is created with `plt.figure()`
        title (None or str): Title to add to the plot
        kwargs: all other kwargs will be passed to `plt.figure()`
            (only used when a new figure is created)

    Returns:
        matplotlib.figure.Figure: inequality figure
    """

    ordered = sorted(xdata)
    xstart = ordered[0]
    xend = ordered[-1]
    # Direction of the ray as given by the caller (not the sorted data)
    points_right = xdata[0] <= xdata[-1]

    fig1 = figure or plt.figure(figsize=figsize, facecolor=bgcolor, **kwargs)
    ax1 = Subplot(fig1, 111)
    fig1.add_subplot(ax1)

    style = {
        'linewidth': 4,
        'linestyle': 'solid',
        'color': raycolor,
        'marker': 'o',
        'markersize': 12,
        # The circle always sits on the first point, i.e. the ray's origin
        'markevery': [0],
        'markerfacecolor': raycolor if inclusive else bgcolor,
        'markeredgecolor': raycolor,
        'markeredgewidth': 4,
    }

    arrowstyle = {
        'marker': '>' if points_right else '<',
        'markersize': 12,
        'markerfacecolor': raycolor,
        'markeredgecolor': raycolor,
    }
    # The arrowhead always sits on the last point, i.e. the ray's open end
    arrow_x = xdata[-1]

    ax1.set(ylim=(-0.5, 1))

    if xlim is True:
        xlim = (float(xstart - 1), float(xend + 0.5))
    if xlim:
        ax1.set(xlim=xlim)

    ax1.axis["left"].set_visible(False)
    ax1.axis["right"].set_visible(False)
    ax1.axis["top"].set_visible(False)

    # Arrow tips at both ends of the bottom axis line; x is in axes
    # coordinates and y in data coordinates (y-axis transform).
    ax1.plot(1, -0.5, ">k", transform=ax1.get_yaxis_transform(), clip_on=False)
    ax1.plot(0, -0.5, "<k", transform=ax1.get_yaxis_transform(), clip_on=False)

    # Draw on ax1 explicitly: plt.plot() targets pyplot's *current* axes,
    # which is not ax1 when the caller supplies a figure that is not the
    # current pyplot figure.
    xs = list(xdata)
    ys = [0] * len(xs)
    ax1.plot(xs, ys, **style)
    if title:
        ax1.set_title(title)
    ax1.plot(arrow_x, 0, **arrowstyle)
    return fig1