Example #1
0
    def __init__(self, regions_file, parent=None):
        """Build the main window.

        Creates empty TimeSeries containers and the per-mode axes registry,
        sets up the generated UI, loads the regions file through
        ``select_region_file``, embeds an expanding ``PlotWidget`` in the
        scroll area, and wires the folder button, file combo box and next
        button to their slots.

        :param regions_file: path handed to ``select_region_file``.
        :param parent: optional parent widget for the main window.
        """
        super(MyGui, self).__init__(parent)
        # containers filled later when a file is loaded
        self.data = pipeline.TimeSeries()
        self.baseline = pipeline.TimeSeries()
        self.axes = {'base': [], 'time': []}

        # the Designer widgets must exist before anything below touches them
        self.setupUi(self)
        self.select_region_file(regions_file)
        self.labels = sorted(config["labels"].keys())

        # embed the plot widget and let it grow in both directions
        self.plots = PlotWidget(self, self.centralwidget)
        expanding = QtGui.QSizePolicy.Expanding
        policy = QtGui.QSizePolicy(expanding, expanding)
        policy.setHeightForWidth(self.plots.sizePolicy().hasHeightForWidth())
        self.plots.setSizePolicy(policy)
        self.scrollArea.setWidget(self.plots)

        # the visualisation helper draws into the widget's figure
        self.vis = Vis()
        self.vis.fig = self.plots.canvas.fig

        # (sender, old-style signal signature, slot), connected in this order
        wiring = (
            (self.selectFolderButton, "clicked()", self.select_folder),
            (self.filesListBox, "currentIndexChanged(int)", self.load_file),
            (self.nextButton, "clicked()", self.next_button_click),
        )
        for sender, signature, slot in wiring:
            self.connect(sender, QtCore.SIGNAL(signature), slot)
Example #2
0
class MyGui(QtGui.QMainWindow, Ui_RegionGui):
    '''Main window for browsing result files and viewing/labeling their modes.

    The user picks a folder of ``*.json`` result files. For the selected file,
    every mode (object) gets one row of plots: its spatial base overlaid on the
    mean baseline image (left) and its time course (right). The mode's current
    label is shown next to the base plot. Existing labels come from
    ``self.regions``, a mapping from data name to a list of labels.
    '''

    def __init__(self, regions_file, parent=None):
        """Initialize the GUI, connect signals and embed the plot widget.

        :param regions_file: path handed to :meth:`select_region_file`; if
            falsy, a file dialog is shown instead.
        :param parent: optional parent widget.
        """
        super(MyGui, self).__init__(parent)
        self.data = pipeline.TimeSeries()
        self.baseline = pipeline.TimeSeries()
        # per-mode axes, filled by make_plot_layout (index 0 = top row)
        self.axes = {'base': [], 'time': []}

        self.setupUi(self)
        self.select_region_file(regions_file)
        self.labels = sorted(config["labels"].keys())

        # add plot widget, expanding in both directions inside the scroll area
        self.plots = PlotWidget(self, self.centralwidget)
        sizePolicy = QtGui.QSizePolicy(QtGui.QSizePolicy.Expanding, QtGui.QSizePolicy.Expanding)
        sizePolicy.setHeightForWidth(self.plots.sizePolicy().hasHeightForWidth())
        self.plots.setSizePolicy(sizePolicy)
        self.scrollArea.setWidget(self.plots)
        self.vis = Vis()
        self.vis.fig = self.plots.canvas.fig

        # connect signals to slots
        self.connect(self.selectFolderButton, QtCore.SIGNAL("clicked()"), self.select_folder)
        self.connect(self.filesListBox, QtCore.SIGNAL("currentIndexChanged(int)"), self.load_file)
        self.connect(self.nextButton, QtCore.SIGNAL("clicked()"), self.next_button_click)

    def next_button_click(self):
        '''Select the next file in the combo box, wrapping around at the end.

        Changing the index emits ``currentIndexChanged``, which ``__init__``
        connects to :meth:`load_file`.
        '''
        box = self.filesListBox
        count = box.count()
        if count == 0:
            # nothing to advance to (and the modulo below would divide by 0)
            return
        # modulo by the item count so the last file is reachable and the
        # index after it wraps back to 0
        box.setCurrentIndex((box.currentIndex() + 1) % count)

    def select_region_file(self, regions_file=None):
        '''Load the regions JSON into ``self.regions``.

        :param regions_file: path given on the command line. An existing file
            is parsed as JSON; a path that does not exist yet yields an empty
            ``self.regions``. If falsy, a file dialog asks for the file.

        If the dialog yields nothing usable (no name, missing file, or not a
        ``.json`` file), an error is logged and ``sys.exit(-1)`` is called.
        '''
        if regions_file:
            self.regions_file = regions_file
            if os.path.exists(self.regions_file):
                # the context manager closes the handle deterministically
                with open(self.regions_file) as regions_fobj:
                    self.regions = json.load(regions_fobj)
            else:
                self.regions = {}
            return

        fname = QtGui.QFileDialog.getOpenFileNameAndFilter(caption='select regions.json',
                                                           filter='*.json')
        # the call returns (file name, selected filter); keep the name only
        fname = str(fname[0])
        if fname and os.path.exists(fname) and fname.endswith('.json'):
            self.regions_file = fname
            with open(self.regions_file) as regions_fobj:
                self.regions = json.load(regions_fobj)
        else:
            l.error('no regions.json selected --> quitting')
            sys.exit(-1)

    def select_folder(self, folder=None):
        """Fill the file combo box with the data files found in a folder.

        :param folder: folder to scan; if falsy, a directory dialog asks for one.

        Every ``*.json`` file is listed by its base name without extension,
        skipping names that contain ``'base'`` or ``'regions'``. Names are
        sorted so "next" walks the files in a stable order. If no folder was
        chosen, nothing changes.
        """
        fname = folder if folder else str(QtGui.QFileDialog.getExistingDirectory())
        if not fname:
            return
        self.folder = fname
        names = [os.path.splitext(os.path.basename(path))[0]
                 for path in glob.glob(os.path.join(self.folder, '*.json'))]
        # 'base' presumably drops the '<name>_baseline' companions read by
        # load_file; 'regions' drops the regions file itself
        names = sorted(name for name in names
                       if 'base' not in name and 'regions' not in name)
        self.filesListBox.clear()
        self.filesListBox.addItems(names)
        self.nextButton.setEnabled(True)
        self.filesListBox.setEnabled(True)

    def make_plot_layout(self, n_modes, base_plot_size=50, vertical_plotting_space=0.8):
        '''Create one row of axes (spatial base + time course) per mode.

        :param n_modes: number of rows to create.
        :param base_plot_size: height of one row in pixels.
        :param vertical_plotting_space: fraction of the figure height taken by
            the rows; the rows start 5% above the bottom edge.

        Axes created by a previous call are removed from the figure first, so
        loading another file does not stack new axes over old ones. Afterwards
        ``self.axes['time'][k]`` / ``self.axes['base'][k]`` are the axes of
        mode ``k``, with mode 0 in the top row (as with ``subplot``).
        '''
        fig = self.plots.canvas.fig
        dpi = fig.get_dpi()

        # remove the axes this method created for the previous file; without
        # this every load_file call added another full set to the figure
        for old_ax in self.axes['base'] + self.axes['time']:
            if old_ax in fig.axes:
                fig.delaxes(old_ax)
        self.axes['base'] = []
        self.axes['time'] = []
        if n_modes < 1:
            # nothing to lay out (and the row height would divide by zero)
            return

        fig_height = (base_plot_size * n_modes) / (float(dpi) * vertical_plotting_space)
        fig.set_figheight(fig_height)
        h = fig.get_figheight() * dpi
        w = fig.get_figwidth() * dpi
        # Qt expects integer pixel sizes
        self.plots.setMinimumSize(int(w), int(h))

        # float() guards against integer division if an int is passed (Py2)
        height_fraction = float(vertical_plotting_space) / n_modes
        for i in range(n_modes):
            # time-course axes (right, wide)
            ax = fig.add_axes([0.25, height_fraction * i + 0.05, 0.70, height_fraction])
            ax.set_xticklabels([])
            self.axes['time'].append(ax)
            # spatial-base axes (left, narrow); gid = mode index after reversal
            ax = fig.add_axes([0.1, height_fraction * i + 0.05, 0.15, height_fraction])
            ax.set_axis_off()
            ax.set_gid(n_modes - 1 - i)
            self.axes['base'].append(ax)
        # rows were created bottom-up; reverse so index 0 is the top row
        self.axes['base'].reverse()
        self.axes['time'].reverse()

    def load_file(self):
        """Load the selected TimeSeries (MF results) and its baseline, then redraw.

        The path is ``self.folder`` joined with the combo box's current text.
        The baseline is loaded from that path with ``'_baseline'`` appended.
        Labels stored in ``self.regions`` under the data name are reused;
        otherwise every mode starts with the first label in sorted order.
        """
        name = str(self.filesListBox.currentText())
        if not name:
            # no current item (e.g. while select_folder clears and refills the
            # box); joining '' would make us try to load the folder itself
            return
        fname = os.path.join(self.folder, name)
        l.info('loading: %s' % fname)

        # tupelization magic: shapes are converted to tuples after loading
        # (presumably load restores them as lists -- confirm in pipeline)
        self.data.load(fname)
        self.data.shape = tuple(self.data.shape)
        self.data.base.shape = tuple(self.data.base.shape)
        self.data.name = os.path.basename(fname)
        self.baseline.load(fname + '_baseline')
        self.baseline.shape = tuple(self.baseline.shape)

        n_modes = self.data.num_objects
        self.make_plot_layout(n_modes)

        if self.data.name in self.regions:
            l.debug('already labeled, load this labeling..')
            # same list object as in self.regions, not a copy
            self.current_labels = self.regions[self.data.name]
        else:
            self.current_labels = [self.labels[0]] * n_modes
        self.draw_spatial_plots()
        self.draw_temporal_plots()

    def draw_spatial_plots(self):
        '''Draw each mode's base over the mean baseline image, then redraw.

        The overlay colormap is looked up as ``config['labels'][label]`` for
        the mode's current label, and the label is written as the y-label.
        '''
        # TODO: only replot the changed subplots
        # NOTE(review): Axes.hold was removed in matplotlib 3.0; this code
        # targets the older API.
        bases = self.data.base.shaped2D()
        # the background image is identical for every mode, so compute it once
        background = np.mean(self.baseline.shaped2D(), 0)
        for i in range(self.data.num_objects):
            label = self.current_labels[i]
            ax = self.axes['base'][i]
            ax.hold(False)
            self.vis.imshow(ax, background, cmap=plt.cm.bone_r)
            ax.hold(True)
            self.vis.overlay_image(ax, bases[i], threshold=0.2,
                                   colormap=config['labels'][label])
            # numeric rotation; newer matplotlib rejects the string '0'
            ax.set_ylabel(label, rotation=0)
        self.plots.canvas.draw()

    def draw_temporal_plots(self):
        '''Plot each mode's time course with label shading, then redraw.

        The sample labels are added to the top row's axes only.
        '''
        time_axes = self.axes['time']
        for i in range(self.data.num_objects):
            ax = time_axes[i]
            # presumably so the new plot replaces the previous curve
            ax.hold(False)
            self.vis.plot(ax, self.data.timecourses[:, i])
            self.vis.add_labelshade(ax, self.data)
            ax.set_xticks([])
            ax.set_yticks([])
        if time_axes:
            # guard: with zero modes there is no top row to annotate
            self.vis.add_samplelabel(time_axes[0], self.data, rotation='45', toppos=True)
        self.plots.canvas.draw()