class MyGui(QtGui.QMainWindow, Ui_RegionGui):
    """Main window of the region-labeling GUI.

    Lets the user pick a folder of serialized ``pipeline.TimeSeries`` result
    files, step through them, and shows per mode one spatial (base) plot and
    one temporal plot, coloured according to the mode's current label.
    """

    def __init__(self, regions_file, parent=None):
        """Initialize the GUI: load the regions file, add the plot widget and
        connect signals to slots.

        Args:
            regions_file: path to regions.json; if falsy, a file dialog asks
                for one (see ``select_region_file``).
            parent: optional parent Qt widget.
        """
        super(MyGui, self).__init__(parent)
        self.data = pipeline.TimeSeries()
        self.baseline = pipeline.TimeSeries()
        # axes created by make_plot_layout, index 0 is the topmost plot
        self.axes = {'base': [], 'time': []}
        self.setupUi(self)
        self.select_region_file(regions_file)
        self.labels = sorted(config["labels"].keys())

        # add plot widget inside the scroll area
        self.plots = PlotWidget(self, self.centralwidget)
        size_policy = QtGui.QSizePolicy(QtGui.QSizePolicy.Expanding,
                                        QtGui.QSizePolicy.Expanding)
        size_policy.setHeightForWidth(
            self.plots.sizePolicy().hasHeightForWidth())
        self.plots.setSizePolicy(size_policy)
        self.scrollArea.setWidget(self.plots)
        self.vis = Vis()
        self.vis.fig = self.plots.canvas.fig

        # connect signals to slots
        self.connect(self.selectFolderButton, QtCore.SIGNAL("clicked()"),
                     self.select_folder)
        self.connect(self.filesListBox,
                     QtCore.SIGNAL("currentIndexChanged(int)"),
                     self.load_file)
        self.connect(self.nextButton, QtCore.SIGNAL("clicked()"),
                     self.next_button_click)

    def next_button_click(self):
        """Select the next file in the list, wrapping around to the first."""
        box = self.filesListBox
        count = box.count()
        if count == 0:
            return
        # modulo count (not count - 1): otherwise the last file is never
        # reachable and a single-entry list divides by zero
        box.setCurrentIndex((box.currentIndex() + 1) % count)

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at ``path``, closing the file handle afterwards."""
        with open(path) as fh:
            return json.load(fh)

    def select_region_file(self, regions_file=None):
        """Load the regions JSON, either from the given path or via a dialog.

        Sets ``self.regions_file`` and ``self.regions``. If ``regions_file``
        is given but does not exist, ``self.regions`` starts out empty. If no
        path is given, the user is asked for one; when no existing file
        ending in 'json' is chosen, an error is logged and the program exits.
        """
        if regions_file:
            self.regions_file = regions_file
            if os.path.exists(self.regions_file):
                self.regions = self._read_json(self.regions_file)
            else:
                self.regions = {}
            return

        fname = QtGui.QFileDialog.getOpenFileNameAndFilter(
            caption='select regions.json', filter='*.json')
        fname = str(fname[0])
        if fname and os.path.exists(fname) and fname.endswith('json'):
            self.regions_file = fname
            self.regions = self._read_json(self.regions_file)
        else:
            l.error('no regions.json selected --> quitting')
            sys.exit(-1)

    def select_folder(self, folder=None):
        """Choose a data folder and fill the file combo box with its results.

        Args:
            folder: folder path; if falsy, a directory dialog is shown.

        Lists the ``*.json`` files of the folder (extension stripped),
        skipping names that contain 'base' (e.g. the ``*_baseline`` files
        loaded alongside each result) or 'regions'. Does nothing if no
        folder was chosen.
        """
        fname = folder if folder else str(QtGui.QFileDialog.getExistingDirectory())
        if not fname:
            return
        self.folder = fname
        names = (os.path.splitext(os.path.basename(path))[0]
                 for path in glob.glob(os.path.join(self.folder, '*.json')))
        # glob order is arbitrary; sort so "next" walks the files predictably
        filelist = sorted(name for name in names
                          if 'base' not in name and 'regions' not in name)
        self.filesListBox.clear()
        self.filesListBox.addItems(filelist)
        self.nextButton.setEnabled(True)
        self.filesListBox.setEnabled(True)

    def make_plot_layout(self, n_modes, base_plot_size=50,
                         vertical_plotting_space=0.8):
        """(Re)build the figure axes for ``n_modes`` modes.

        Removes the axes of a previous layout, sizes the figure so that each
        mode gets roughly ``base_plot_size`` pixels of height (of which the
        fraction ``vertical_plotting_space`` is used for plotting), and
        creates per mode one time axes and one base axes. The axes are
        stored in ``self.axes['time']`` / ``self.axes['base']`` with index 0
        being the topmost plot, as with subplot numbering.
        """
        fig = self.plots.canvas.fig
        dpi = float(fig.get_dpi())

        # drop the axes of the previously loaded file; otherwise they pile up
        # in the figure and stale plots stay visible when n_modes changes
        for ax in self.axes['base'] + self.axes['time']:
            if ax in fig.axes:
                fig.delaxes(ax)
        self.axes = {'base': [], 'time': []}
        if n_modes <= 0:
            return

        fig.set_figheight((base_plot_size * n_modes) /
                          (dpi * vertical_plotting_space))
        h = fig.get_figheight() * dpi
        w = fig.get_figwidth() * dpi
        # Qt widget sizes are integer pixels
        self.plots.setMinimumSize(int(w), int(h))

        height_fraction = vertical_plotting_space / n_modes
        time_axes, base_axes = [], []
        for i in range(n_modes):
            bottom = height_fraction * i + 0.05
            # time axes (right, wide)
            ax = fig.add_axes([0.25, bottom, 0.70, height_fraction])
            ax.set_xticklabels([])
            time_axes.append(ax)
            # base axes (left, narrow)
            ax = fig.add_axes([0.1, bottom, 0.15, height_fraction])
            ax.set_axis_off()
            ax.set_gid(n_modes - 1 - i)
            base_axes.append(ax)
        # axes were created bottom-up; reverse so index 0 is the top plot
        time_axes.reverse()
        base_axes.reverse()
        self.axes['time'] = time_axes
        self.axes['base'] = base_axes

    def load_file(self, index=None):
        """Load the result file selected in the combo box and plot it.

        Args:
            index: the index emitted by ``currentIndexChanged(int)``; unused,
                the selection is read from the combo box itself.

        Loads ``<folder>/<name>`` into ``self.data`` and
        ``<folder>/<name>_baseline`` into ``self.baseline``, takes the labels
        from ``self.regions`` if this file was labeled before (otherwise
        every mode gets the first label), and redraws all plots.
        """
        name = str(self.filesListBox.currentText())
        # the combo box can report an empty selection (e.g. while
        # select_folder clears and refills it); nothing to load then
        if not name:
            return
        fname = os.path.join(self.folder, name)
        l.info('loading: %s' % fname)
        # tupelization magic (set TimeSeries to correct size); presumably the
        # loaded shapes are not tuples yet -- TODO confirm in pipeline
        self.data.load(fname)
        self.data.shape = tuple(self.data.shape)
        self.data.base.shape = tuple(self.data.base.shape)
        self.data.name = os.path.basename(fname)
        self.baseline.load(fname + '_baseline')
        self.baseline.shape = tuple(self.baseline.shape)
        n_modes = self.data.num_objects
        self.make_plot_layout(n_modes)

        # init gui when labels already exist
        if self.data.name in self.regions:
            l.debug('already labeled, load this labeling..')
            self.current_labels = self.regions[self.data.name]
        else:
            self.current_labels = [self.labels[0]] * n_modes
        self.draw_spatial_plots()
        self.draw_temporal_plots()

    def draw_spatial_plots(self):
        """Draw every mode's base on top of the mean baseline image.

        The overlay colormap is ``config['labels'][label]`` for the mode's
        current label, and the label is shown as the y-axis label.
        """
        # TODO: only replot the changed subplots
        bases = self.data.base.shaped2D()
        # the background image is identical for every mode; compute it once
        background = np.mean(self.baseline.shaped2D(), 0)
        for i in range(self.data.num_objects):
            label = self.current_labels[i]
            ax = self.axes['base'][i]
            ax.hold(False)
            self.vis.imshow(ax, background, cmap=plt.cm.bone_r)
            ax.hold(True)
            self.vis.overlay_image(ax, bases[i], threshold=0.2,
                                   colormap=config['labels'][label])
            ax.set_ylabel(label, rotation='0')
        self.plots.canvas.draw()

    def draw_temporal_plots(self):
        """Plot the timecourse of every mode (column i of ``timecourses``).

        Each plot gets ``Vis.add_labelshade`` applied and its ticks removed;
        sample labels are added above the topmost plot.
        """
        for i in range(self.data.num_objects):
            ax = self.axes['time'][i]
            ax.hold(False)
            self.vis.plot(ax, self.data.timecourses[:, i])
            self.vis.add_labelshade(ax, self.data)
            ax.set_xticks([])
            ax.set_yticks([])
        # no axes exist when the file has no modes
        if self.axes['time']:
            self.vis.add_samplelabel(self.axes['time'][0], self.data,
                                     rotation='45', toppos=True)
        self.plots.canvas.draw()