class StackViewer(object):
    """
    Browse a stack of 2D images with a matplotlib slider.

    Parameters
    ----------
    viewer : object
        expected to have an ``update_image`` method and a ``_fig`` attribute
        (the figure the slider axes is added to)
    images : array-like
        must support ``len()`` and integer indexing, returning a 2D array
    """

    def __init__(self, viewer, images):
        self.viewer = viewer
        self.images = images
        length = len(self.images)
        fig = self.viewer._fig
        slider_ax = fig.add_axes([0.1, 0.01, 0.8, 0.02])
        self.slider = Slider(slider_ax, 'Frame', 0, length - 1, 0,
                             valfmt='%d/{}'.format(length - 1))
        self.slider.on_changed(self.update)
        self.update(0)  # Trigger the initialization of viewer.

    def update(self, val):
        """Show frame *val*, snapping a non-integer slider value to the nearest frame."""
        if not isinstance(val, int):
            # set_val notifies on_changed, which re-enters update() with the
            # rounded int and displays that frame. Stop here; falling through
            # would overwrite it with the *truncated* index int(val).
            self.slider.set_val(int(round(val)))
            return
        self.viewer.update_image(self.images[int(val)])
def generate_plots():
    """Load a random structure set and draw one CT slice with its contour.

    Results are published through the module-level names listed in the
    ``global`` statement. The first call creates ``current_slice_slider``;
    later calls reuse it and only update its range and value.
    """
    global rts_filename, img3d, contour_matrix, im, structure_wanted, list_of_ct_filenames, current_slice_slider, contour_plot

    # get all the relevant data (choose random structure set, get ct and contour data from dicoms)
    structure_wanted, rts_filename, list_of_ct_filenames = get_random_structure_set()
    img3d, contour_matrix, x_origin, y_origin, pixel_size = prepare_dicom_data_for_individual_case(
        rts_filename, list_of_ct_filenames, structure_wanted)

    # smooth the human contours if desired
    if smooth_human_contours and human_rts_filename in rts_filename:
        smooth_contour_matrix()

    # can remove ax.clear() for randomly generated colors of contours
    ax.clear()

    n_slices = img3d.shape[0]
    middle_slice = n_slices // 2
    if globals().get('current_slice_slider') is None:
        # First image plotted: create the slider.
        current_slice_slider = Slider(slice_slider_axes, 'CT Slice', 0, n_slices - 1,
                                      valstep=1, valfmt='%0.0f')
        current_slice_slider.set_val(middle_slice)
        current_slice_slider.on_changed(on_slider_change)
    else:
        # Subsequent images: resize the existing slider *before* moving it,
        # so the new value is within range when on_slider_change fires.
        current_slice_slider.valmax = n_slices - 1
        current_slice_slider.ax.set_xlim(current_slice_slider.valmin, current_slice_slider.valmax)
        current_slice_slider.set_val(middle_slice)

    # plot the data
    slice_index = int(current_slice_slider.val)
    n_rows, n_cols = img3d.shape[1], img3d.shape[2]
    # imshow lays array columns along x and rows along y, so the x extent
    # spans n_cols pixels and the y extent n_rows pixels.
    im = ax.imshow(img3d[slice_index, :, :],
                   extent=[x_origin, x_origin + n_cols * pixel_size,
                           y_origin, y_origin + n_rows * pixel_size],
                   cmap='Greys_r', vmin=hu_min, vmax=hu_max, animated=True,
                   interpolation='nearest', origin='upper')
    contour_plot = ax.plot(contour_matrix[slice_index][0], contour_matrix[slice_index][1])
def main(filenames, equal_axes):
    """Interactively plot iso-surface contours from one or more files.

    Parameters
    ----------
    filenames : list of str
        Each entry is ``path`` or ``path:name``. Without ``:name`` the 1-based
        position in the list is used as the legend name.
    equal_axes : bool
        If true, force an equal aspect ratio on the plot.

    A 'Time' slider selects the time. For each file, the data at the closest
    stored timestep is drawn, unless that timestep is more than 1.5 timestep
    spacings away from the slider time.
    """
    all_data = []
    tmin = tmax = 0
    xmin, ymin, xmax, ymax = 1e100, 1e100, -1e100, -1e100
    for i, filename in enumerate(filenames):
        name = str(i + 1)
        if ':' in filename:
            # Split on the last ':' so that a path containing ':' still works.
            filename, name = filename.rsplit(':', 1)
        print('Reading %s from file %s' % (name, filename))
        _description, _value, _dim, timesteps, data = read_iso_surfaces(filename)
        all_data.append((name, numpy.array(timesteps), data))
        tmax = max(tmax, timesteps[-1])
        for contours in data:
            for contour in contours:
                xmin = min(xmin, numpy.min(contour[0]))
                ymin = min(ymin, numpy.min(contour[1]))
                xmax = max(xmax, numpy.max(contour[0]))
                ymax = max(ymax, numpy.max(contour[1]))

    fig, ax = pyplot.subplots()
    pyplot.subplots_adjust(bottom=0.25)
    axcolor = '#a1b8dd'
    slider_ax = pyplot.axes([0.1, 0.1, 0.8, 0.03], facecolor=axcolor)
    slider = Slider(slider_ax, 'Time', tmin, tmax, valinit=tmin)

    # Pad the data bounding box by 10% on every side.
    xdiff = xmax - xmin
    ydiff = ymax - ymin
    xmin, xmax = xmin - 0.1 * xdiff, xmax + 0.1 * xdiff
    ymin, ymax = ymin - 0.1 * ydiff, ymax + 0.1 * ydiff
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    def update(val):
        # Keep the user's current zoom/pan across ax.clear().
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        ax.clear()
        t = slider.val
        for name, timesteps, data in all_data:
            i = numpy.argmin(abs(timesteps - t))
            # Skip files whose nearest timestep is far from t. With a single
            # timestep there is no spacing to compare against (and
            # timesteps[1] would not exist), so always draw it.
            if len(timesteps) > 1:
                dt = timesteps[1] - timesteps[0]
                if abs(timesteps[i] - t) > 1.5 * dt:
                    continue
            plotit(ax, data[i], name)
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
        if equal_axes:
            ax.set_aspect('equal')
        if len(filenames) > 1:
            ax.legend(loc='lower right')
        fig.canvas.draw_idle()

    slider.on_changed(update)
    slider.set_val(tmin)
    pyplot.show()
class ImgView(object):
    """Display a 2D image with sliders controlling its colour limits.

    Parameters
    ----------
    img : ndarray
        2D image to display.
    fig_id : optional
        Passed to ``plt.figure`` to select or create the figure.
    imin, imax : float, optional
        Slider range. They default to the values at the 0.5% / 99.5% rank of
        the sorted pixels. The sliders start at the 2% / 98% rank values.
    """

    def __init__(self, img, fig_id=None, imin=None, imax=None):
        self.fig = plt.figure(fig_id)
        self.ax = self.fig.add_axes([0.1, 0.25, 0.8, 0.7])
        self.fig.subplots_adjust(bottom=0.25)
        axcolor = 'lightgoldenrodyellow'
        # ``axisbg`` was deprecated in Matplotlib 2.0 and later removed;
        # ``facecolor`` is its replacement.
        self.axlo = self.fig.add_axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
        self.axhi = self.fig.add_axes([0.15, 0.15, 0.65, 0.03], facecolor=axcolor)
        self.axrefr = self.fig.add_axes([0.15, 0.05, 0.10, 0.03], facecolor=axcolor)
        imsort = np.sort(img.flatten())
        n = len(imsort)
        if imin is None:
            imin = imsort[int(n * 0.005)]
        if imax is None:
            imax = imsort[int(n * 0.995)]
        self.slo = Slider(self.axlo, 'Scale min', imin, imax, valinit=imsort[int(n * 0.02)])
        self.shi = Slider(self.axhi, 'Scale max', imin, imax, valinit=imsort[int(n * 0.98)])
        self.brefr = Button(self.axrefr, 'Refresh')
        self.slo.on_changed(self.update_slider)
        self.shi.on_changed(self.update_slider)
        self.brefr.on_clicked(self.refresh)
        self.imgplt = None  # current AxesImage, (re)created by set_img
        self._cax = None  # colour-bar axes, created once by set_img
        self.set_img(img)

    def update_slider(self, val):
        """Apply the slider values as colour limits, keeping max above min."""
        if self.shi.val <= self.slo.val:
            # Re-enters update_slider through the 'Scale max' observer.
            self.shi.set_val(self.slo.val + 1)
        self.imgplt.set_clim(self.slo.val, self.shi.val)
        self.fig.canvas.draw()

    def set_img(self, img=None):
        """Display *img*, or redisplay the current image when None, with a colour bar."""
        if img is not None:
            self.img = img
        self.ax.set_xlim(0, self.img.shape[1] - 1)
        self.ax.set_ylim(0, self.img.shape[0] - 1)
        # Remove the previous image so repeated refreshes don't stack artists.
        old_img = getattr(self, 'imgplt', None)
        if old_img is not None:
            old_img.remove()
        self.imgplt = self.ax.imshow(self.img, vmin=self.slo.val, vmax=self.shi.val,
                                     cmap='hot', origin='lower', interpolation='nearest')
        # Reuse a single colour-bar axes. Appending a new one on every call
        # would keep shrinking the image axes.
        cax = getattr(self, '_cax', None)
        if cax is None:
            divider = make_axes_locatable(self.ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            self._cax = cax
        else:
            cax.clear()
        self.fig.colorbar(self.imgplt, cax=cax)
        self.fig.canvas.draw()

    def refresh(self, event=None):
        """Button callback: redraw the current image."""
        self.set_img()
class GUI:
    """Interactive plot of a coin-toss posterior.

    The 'Tosses' slider selects which row of ``coinToss.posterior`` to plot.
    The 'Re-toss' button (or the ``r`` key) re-runs ``coinToss.doTosses()``.
    The left/right arrow keys move the slider by one.
    """

    def __init__(self, coinToss):
        self._coinToss = coinToss
        max_tosses = coinToss.tosses
        button_color = 'lightgoldenrodyellow'
        self.figure = pyplot.figure()
        self.mainAxis = pyplot.axes([0.05, 0.2, 0.9, 0.75])

        # Slider choosing how many tosses are included.
        slider_axes = pyplot.axes([0.1, 0.05, 0.8, 0.05])
        self.tossSlider = Slider(slider_axes, 'Tosses', 0., 1. * max_tosses,
                                 valinit=0, valfmt=u'%d')
        self.tossSlider.on_changed(lambda x: self.draw(x))

        # Button that re-runs the experiment.
        button_axes = pyplot.axes([0.8, 0.85, 0.1, 0.05])
        self.resetButton = Button(button_axes, 'Re-toss',
                                  color=button_color, hovercolor='0.975')
        self.resetButton.on_clicked(self.retoss)

        # Keyboard shortcuts, handled by press().
        self.figure.canvas.mpl_connect('key_press_event',
                                       lambda x: self.press(x))
        self._coinToss.doTosses()

    def retoss(self, event):
        """Re-run the tosses and redraw at the current slider position."""
        self._coinToss.doTosses()
        self.draw(self.tossSlider.val)

    def press(self, event):
        """Handle 'left'/'right' (step the slider by one, within bounds) and 'r' (re-toss)."""
        slider = self.tossSlider
        key = event.key
        if key == u'left':
            if slider.val > slider.valmin:
                slider.set_val(slider.val - 1)
        elif key == u'right':
            if slider.val < slider.valmax:
                slider.set_val(slider.val + 1)
        elif key == u'r':
            self.retoss(event)

    def draw(self, x=0):
        """Plot posterior row ``int(x)`` and mark its first maximum with a vertical line."""
        toss = self._coinToss
        row = toss.posterior[int(x), :]
        peak_index = max(enumerate(row), key=operator.itemgetter(1))[0]
        self.mainAxis.clear()
        self.mainAxis.plot(toss.conditional, row, lw=2, color='red')
        self.mainAxis.vlines(toss.conditional[peak_index], 0, toss.posterior.max())
        self.figure.canvas.draw()
class Index(object):
    """Step through the images in module-level ``R`` with a slider and Prev/Next buttons.

    Slider positions ``0 .. len(wavelengths) - 1`` are labelled with the
    matching wavelength. The extra last position ``len(wavelengths)`` is
    labelled "Calibrated Pixels" and uses a 0..1 colour range. Relies on the
    module-level names ``wavelengths``, ``R``, ``image``, ``cbar``,
    ``maximum``, ``np`` and ``plt``.
    """

    def __init__(self, ax_slider, ax_prev, ax_next):
        self.ind = 0
        self.num = len(wavelengths)
        self.bnext = Button(ax_next, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(ax_prev, 'Previous')
        self.bprev.on_clicked(self.prev)
        self.slider = Slider(ax_slider,
                             "Energy Resolution: {:.2f} nm".format(wavelengths[0]),
                             0, self.num, valinit=0, valfmt='%d')
        # The numeric value is hidden; the label itself carries the description.
        self.slider.valtext.set_visible(False)
        self.slider.label.set_horizontalalignment('center')
        self.slider.on_changed(self.update)
        # Put the label centred underneath the slider bar.
        self.slider.label.set_position((0.5, -0.5))
        self.slider.valtext.set_position((0.5, -0.5))

    def next(self, event):
        """Advance one position, wrapping from the last back to 0."""
        i = (self.ind + 1) % (self.num + 1)
        self.slider.set_val(i)

    def prev(self, event):
        """Go back one position, wrapping from 0 to the last."""
        i = (self.ind - 1) % (self.num + 1)
        self.slider.set_val(i)

    def update(self, i):
        """Show ``R[int(i)]`` and refresh the slider label and colour bar."""
        self.ind = int(i)
        image.set_data(R[self.ind])
        if self.ind != len(wavelengths):
            # One of the per-wavelength images.
            self.slider.label.set_text(
                "Energy Resolution: {:.2f} nm".format(wavelengths[self.ind]))
            number = 11
            cbar.set_clim(vmin=0, vmax=maximum)
            cbar_ticks = np.linspace(0., maximum, num=number, endpoint=True)
        else:
            # Extra final position: "Calibrated Pixels", colour range 0..1.
            self.slider.label.set_text("Calibrated Pixels")
            number = 2
            cbar.set_clim(vmin=0, vmax=1)
            cbar_ticks = np.linspace(0., 1, num=number)
        cbar.set_ticks(cbar_ticks)
        # NOTE(review): Colorbar.draw_all is deprecated in recent Matplotlib.
        cbar.draw_all()
        plt.draw()
def adv_window():
    """Open the 'Advanced settings' figure with sliders for the segmentation parameters.

    Initial slider values come from the module-level ``lambda1``, ``lambda2``,
    ``smoothing``, ``iterations`` and ``rad``. Key presses and the 'exit'
    button call ``adv_exit``; the 'start' button calls ``adv_start``.

    Returns
    -------
    tuple
        ``(figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5)``
    """
    figs, axx = plt.subplots(num='Advanced settings')
    figs.canvas.mpl_connect('key_press_event', adv_exit)
    axx.axis('off')

    ax_as1 = plt.axes([0.15, 0.02, 0.5, 0.05])
    slider_as1 = Slider(ax_as1, 'lambda1', 0.1, 4, dragging=True, valstep=0.1)
    ax_as2 = plt.axes([0.15, 0.10, 0.5, 0.05])
    slider_as2 = Slider(ax_as2, 'lambda2', 0.1, 4, dragging=True, valstep=0.1)
    ax_as3 = plt.axes([0.15, 0.18, 0.5, 0.05])
    slider_as3 = Slider(ax_as3, 'smoothing', 0, 4, dragging=True, valstep=1)
    ax_as4 = plt.axes([0.15, 0.26, 0.5, 0.05])
    slider_as4 = Slider(ax_as4, 'iterations', 1, 1000, dragging=True, valstep=1)
    ax_as5 = plt.axes([0.15, 0.34, 0.5, 0.05])
    slider_as5 = Slider(ax_as5, 'radius', 0.5, 5, dragging=True, valstep=0.1)

    ax_b1 = plt.axes([0.85, 0.15, 0.07, 0.08])
    ax_b2 = plt.axes([0.85, 0.05, 0.07, 0.08])
    but_as1 = Button(ax_b1, 'exit', color='beige', hovercolor='beige')
    but_as2 = Button(ax_b2, 'start', color='beige', hovercolor='beige')

    textstr = "Press ENTER in terminal to start segmentation. \n Shouldn't be necessairy to change settings below, but can be tuned if \n resulting segmentation is not ideal. \n Especially if small nodule: try setting lambda1 <= lambda2 \n Or if very nonhomogeneous nodule: try setting lambda1 > lambda2."
    props = dict(boxstyle='round', facecolor='wheat')
    # NOTE(review): ``ax`` is not defined in this function (presumably a
    # module-level Axes of another figure), so (-0.18, 0.25) is measured in
    # that Axes' coordinates, not axx's. Confirm whether axx.transAxes was meant.
    axx.text(-0.18, 0.25, textstr, transform=ax.transAxes, fontsize=12,
             verticalalignment='top', bbox=props)

    slider_as1.set_val(lambda1)
    slider_as2.set_val(lambda2)
    slider_as3.set_val(smoothing)
    slider_as4.set_val(iterations)
    slider_as5.set_val(rad)
    but_as1.on_clicked(adv_exit)
    but_as2.on_clicked(adv_start)
    # Matplotlib widgets stop responding once they are garbage-collected.
    # The buttons are not returned to the caller, so keep them alive on the
    # figure itself.
    figs._adv_buttons = (but_as1, but_as2)
    figs.canvas.draw_idle()
    return figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5
def plot_hfo_samples(hfo_detection_run: HfoDetectionRun):
    """Plot one detected period at a time, selectable with a 1-based 'Period Index' slider.

    Does nothing when the last detector run produced no periods. Depending on
    the run configuration, the figure is shown interactively and/or one plot
    per period is saved.
    """
    periods = hfo_detection_run.detector.last_run.analytics.periods
    period_windows = list(zip(periods.start, periods.stop))
    plt.rc('font', family='sans-serif')
    # Bail out before creating a figure that would stay empty.
    if not period_windows:
        return

    fig_height = 6
    fig_width = 10
    rows = 4
    columns = 1
    fig = plt.figure(figsize=(fig_width, fig_height))
    spec = gridspec.GridSpec(rows, columns, figure=fig, hspace=0.7)
    bandwidth_axes = fig.add_subplot(spec[0, 0])
    spike_train_axes = fig.add_subplot(spec[1, 0])
    raster_axes = fig.add_subplot(spec[2, 0])
    slider_axes = fig.add_subplot(8, 1, 8)

    n_periods = len(period_windows)
    # The slider's last position is the last period (was n_periods + 1, one
    # past the end). Keep a non-degenerate range when only one period exists;
    # plot_time clamps the index anyway.
    slider = Slider(slider_axes, 'Period Index\n(Interactive)', 1, max(n_periods, 2),
                    valinit=1, valstep=1.0)

    initial_start, initial_stop = period_windows[0]
    _plot_hfo_sample(hfo_detection_run, np.float64(initial_start), np.float64(initial_stop),
                     bandwidth_axes, spike_train_axes, raster_axes)

    def plot_time(one_based_index):
        zero_based = int(np.round(one_based_index - 1))
        zero_based = min(max(zero_based, 0), n_periods - 1)
        start, stop = period_windows[zero_based]
        _plot_hfo_sample(hfo_detection_run, np.float64(start), np.float64(stop),
                         bandwidth_axes, spike_train_axes, raster_axes)
        fig.canvas.draw_idle()

    slider.on_changed(plot_time)

    if should_show_plot(hfo_detection_run.configuration):
        plt.show()

    if should_save_plot(hfo_detection_run.configuration):
        for one_based_index in range(1, n_periods + 1):
            slider.set_val(one_based_index)
            save_or_show_channel_plot(f'hfo_sample_period_{one_based_index}', hfo_detection_run)
class HistoryPlotter:
    """Replay recorded chunks in a RealTimePlotter, browsing with a slider and Prev/Next buttons.

    Args:
        saved_filename: path of a bz2-compressed pickle holding an indexable
            sequence of dicts with 'buffer' and 'stats' entries.
    """

    def __init__(self, saved_filename):
        self.plotter = RealTimePlotter()
        self.plotter.fig.suptitle('History data', fontsize='14', fontweight='bold')
        self.plotter.fig.subplots_adjust(bottom=0.23, hspace=0.5)
        self.index = 0
        self.saved_filename = saved_filename
        self.data = []
        self.prev_axis = plt.axes([0.395, 0.03, 0.1, 0.06])
        self.next_axis = plt.axes([0.505, 0.03, 0.1, 0.06])
        self.btn_next = Button(self.next_axis, 'Next')
        self.btn_next.on_clicked(self.next)
        self.btn_prev = Button(self.prev_axis, 'Previous')
        self.btn_prev.on_clicked(self.prev)
        self.slider_axis = plt.axes([0.25, 0.11, 0.5, 0.03])
        self.read_data()
        self.slider = Slider(self.slider_axis, 'chunk', 0, len(self.data) - 1,
                             valinit=0, valfmt='%10.0f')
        self.slider.on_changed(self.update)

    def read_data(self):
        """Load ``self.data`` from the bz2-compressed pickle at ``saved_filename``."""
        # SECURITY NOTE: pickle.load can execute arbitrary code; only open
        # history files that come from a trusted source.
        with bz2.BZ2File(self.saved_filename, 'r') as f:
            self.data = pickle.load(f)
        print('Number of chunks', len(self.data))

    def _plot_current(self):
        """Plot the chunk at ``self.index``."""
        chunk = self.data[self.index]
        self.plotter.plot(chunk['buffer'], chunk['stats'])

    def loop(self):
        """Show the first chunk and block until the window closes or Ctrl-C."""
        try:
            self.plotter.show()
            self._plot_current()
            plt.show(block=True)
        except KeyboardInterrupt:
            print("Stopping...")

    def update(self, val):
        """Slider callback: jump to chunk ``int(val)`` and plot it."""
        self.index = int(val)
        self._plot_current()

    def next(self, event):
        """Button callback: move to the next chunk, if any."""
        if self.index < len(self.data) - 1:
            # set_val notifies on_changed -> update(), which sets self.index
            # and plots. Plotting here as well would draw the chunk twice.
            self.slider.set_val(self.index + 1)

    def prev(self, event):
        """Button callback: move to the previous chunk, if any."""
        if self.index > 0:
            self.slider.set_val(self.index - 1)
class ImgView(object):
    """Display a 2D image with sliders controlling its colour limits.

    Parameters
    ----------
    img : ndarray
        2D image to display.
    fig_id : optional
        Passed to ``plt.figure`` to select or create the figure.
    imin, imax : float, optional
        Slider range. They default to the values at the 0.5% / 99.5% rank of
        the sorted pixels. The sliders start at the 2% / 98% rank values.
    """

    def __init__(self, img, fig_id=None, imin=None, imax=None):
        self.fig = plt.figure(fig_id)
        self.ax = self.fig.add_axes([0.1, 0.25, 0.8, 0.7])
        self.fig.subplots_adjust(bottom=0.25)
        axcolor = 'lightgoldenrodyellow'
        # ``axisbg`` was deprecated in Matplotlib 2.0 and later removed;
        # ``facecolor`` is its replacement.
        self.axlo = self.fig.add_axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
        self.axhi = self.fig.add_axes([0.15, 0.15, 0.65, 0.03], facecolor=axcolor)
        self.axrefr = self.fig.add_axes([0.15, 0.05, 0.10, 0.03], facecolor=axcolor)
        imsort = np.sort(img.flatten())
        n = len(imsort)
        if imin is None:
            imin = imsort[int(n * 0.005)]
        if imax is None:
            imax = imsort[int(n * 0.995)]
        self.slo = Slider(self.axlo, 'Scale min', imin, imax, valinit=imsort[int(n * 0.02)])
        self.shi = Slider(self.axhi, 'Scale max', imin, imax, valinit=imsort[int(n * 0.98)])
        self.brefr = Button(self.axrefr, 'Refresh')
        self.slo.on_changed(self.update_slider)
        self.shi.on_changed(self.update_slider)
        self.brefr.on_clicked(self.refresh)
        self.imgplt = None  # current AxesImage, (re)created by set_img
        self._cax = None  # colour-bar axes, created once by set_img
        self.set_img(img)

    def update_slider(self, val):
        """Apply the slider values as colour limits, keeping max above min."""
        if self.shi.val <= self.slo.val:
            # Re-enters update_slider through the 'Scale max' observer.
            self.shi.set_val(self.slo.val + 1)
        self.imgplt.set_clim(self.slo.val, self.shi.val)
        self.fig.canvas.draw()

    def set_img(self, img=None):
        """Display *img*, or redisplay the current image when None, with a colour bar."""
        if img is not None:
            self.img = img
        self.ax.set_xlim(0, self.img.shape[1] - 1)
        self.ax.set_ylim(0, self.img.shape[0] - 1)
        # Remove the previous image so repeated refreshes don't stack artists.
        old_img = getattr(self, 'imgplt', None)
        if old_img is not None:
            old_img.remove()
        self.imgplt = self.ax.imshow(self.img, vmin=self.slo.val, vmax=self.shi.val,
                                     cmap='hot', origin='lower', interpolation='nearest')
        # Reuse a single colour-bar axes. Appending a new one on every call
        # would keep shrinking the image axes.
        cax = getattr(self, '_cax', None)
        if cax is None:
            divider = make_axes_locatable(self.ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            self._cax = cax
        else:
            cax.clear()
        self.fig.colorbar(self.imgplt, cax=cax)
        self.fig.canvas.draw()

    def refresh(self, event=None):
        """Button callback: redraw the current image."""
        self.set_img()
class Index(object):
    """Step through the images in module-level ``R`` with a slider and Prev/Next buttons.

    Slider positions ``0 .. len(wavelengths) - 1`` are labelled with the
    matching wavelength. The extra last position ``len(wavelengths)`` is
    labelled "Calibrated Pixels" and uses a 0..1 colour range. Relies on the
    module-level names ``wavelengths``, ``R``, ``image``, ``cbar``,
    ``maximum``, ``np`` and ``plt``.
    """

    def __init__(self, ax_slider, ax_prev, ax_next):
        self.ind = 0
        self.num = len(wavelengths)
        self.bnext = Button(ax_next, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(ax_prev, 'Previous')
        self.bprev.on_clicked(self.prev)
        self.slider = Slider(ax_slider,
                             "Energy Resolution: {:.2f} nm".format(wavelengths[0]),
                             0, self.num, valinit=0, valfmt='%d')
        # The numeric value is hidden; the label itself carries the description.
        self.slider.valtext.set_visible(False)
        self.slider.label.set_horizontalalignment('center')
        self.slider.on_changed(self.update)
        # Put the label centred underneath the slider bar.
        self.slider.label.set_position((0.5, -0.5))
        self.slider.valtext.set_position((0.5, -0.5))

    def next(self, event):
        """Advance one position, wrapping from the last back to 0."""
        i = (self.ind + 1) % (self.num + 1)
        self.slider.set_val(i)

    def prev(self, event):
        """Go back one position, wrapping from 0 to the last."""
        i = (self.ind - 1) % (self.num + 1)
        self.slider.set_val(i)

    def update(self, i):
        """Show ``R[int(i)]`` and refresh the slider label and colour bar."""
        self.ind = int(i)
        image.set_data(R[self.ind])
        if self.ind != len(wavelengths):
            # One of the per-wavelength images.
            self.slider.label.set_text(
                "Energy Resolution: {:.2f} nm".format(wavelengths[self.ind]))
            number = 11
            cbar.set_clim(vmin=0, vmax=maximum)
            cbar_ticks = np.linspace(0., maximum, num=number, endpoint=True)
        else:
            # Extra final position: "Calibrated Pixels", colour range 0..1.
            self.slider.label.set_text("Calibrated Pixels")
            number = 2
            cbar.set_clim(vmin=0, vmax=1)
            cbar_ticks = np.linspace(0., 1, num=number)
        cbar.set_ticks(cbar_ticks)
        # NOTE(review): Colorbar.draw_all is deprecated in recent Matplotlib.
        cbar.draw_all()
        plt.draw()
def set_val(self, val):
    """
    Snap *val* to an integer index, recolour the page depth, then delegate
    to :meth:`Slider.set_val`.

    Notes
    -----
    valmin/valmax are set on the parent to 0 and len(depths). Indices
    outside ``[valmin, valmax)`` are silently ignored.
    """
    index = int(val)
    # valmin is a valid depth index; valmax is one past the end of the array.
    if not (self.valmin <= index < self.valmax):
        return
    # Recolour first, while self.val still holds the previous value.
    self.updatePageDepthColor(index)
    Slider.set_val(self, index)
class GUI:
    """Interactive viewer for a coin-toss posterior.

    The slider picks which row of ``coinToss.posterior`` to plot. 'Re-toss'
    (button or ``r`` key) repeats the experiment, and the left/right arrow
    keys move the slider by one toss.
    """

    def __init__(self, coinToss):
        self._coinToss = coinToss
        total = coinToss.tosses
        colour = 'lightgoldenrodyellow'
        self.figure = pyplot.figure()
        self.mainAxis = pyplot.axes([0.05, 0.2, 0.9, 0.75])

        # Tosses slider
        self.tossSlider = Slider(pyplot.axes([0.1, 0.05, 0.8, 0.05]), 'Tosses',
                                 0., 1. * total, valinit=0, valfmt=u'%d')
        self.tossSlider.on_changed(lambda x: self.draw(x))

        # Re-toss button
        self.resetButton = Button(pyplot.axes([0.8, 0.85, 0.1, 0.05]), 'Re-toss',
                                  color=colour, hovercolor='0.975')
        self.resetButton.on_clicked(self.retoss)

        # Keyboard handling
        self.figure.canvas.mpl_connect('key_press_event', lambda x: self.press(x))
        self._coinToss.doTosses()

    def retoss(self, event):
        """Repeat the tosses and redraw the currently selected row."""
        self._coinToss.doTosses()
        self.draw(self.tossSlider.val)

    def press(self, event):
        """Arrow keys nudge the slider (staying within its bounds); 'r' re-tosses."""
        slider = self.tossSlider
        step = {u'left': -1, u'right': 1}.get(event.key)
        if step is not None:
            if step < 0:
                can_move = slider.val > slider.valmin
            else:
                can_move = slider.val < slider.valmax
            if can_move:
                slider.set_val(slider.val + step)
        elif event.key == u'r':
            self.retoss(event)

    def draw(self, x=0):
        """Redraw posterior row ``int(x)`` and a vertical line at its first peak."""
        coin = self._coinToss
        selected = coin.posterior[int(x), :]
        peak_idx, _peak_value = max(enumerate(selected), key=operator.itemgetter(1))
        axis = self.mainAxis
        axis.clear()
        axis.plot(coin.conditional, selected, lw=2, color='red')
        axis.vlines(coin.conditional[peak_idx], 0, coin.posterior.max())
        self.figure.canvas.draw()
class IndexTracker(object):
    """Browse the slices of a 3-D array along its first axis.

    Each slice is shown transposed. ``press`` (connected by the caller to key
    events) jumps ``step`` slices with wrap-around; the slider selects a slice
    directly.
    """

    def __init__(self, ax, X, step=41):
        self.ax = ax
        ax.figure.subplots_adjust(left=0.25, bottom=0.25)
        ax.set_title('use scroll wheel to navigate images')
        self.step = step
        self.X = X
        # Unpacking also checks that X is 3-D.
        self.slices, _rows, _cols = X.shape
        self.ind = 0
        # Show the first slice transposed, like update() does, so the image
        # extent matches the data update() sets later.
        self.im = ax.imshow(self.X[self.ind, :, :].T, vmin=np.min(X), vmax=np.max(X))
        # Use this axes' own figure instead of a module-level ``fig``.
        slider_ax = ax.figure.add_axes([0.25, 0.1, 0.65, 0.03])
        # Valid slice indices are 0 .. slices - 1.
        self.slider = Slider(slider_ax, 'Axis %i index' % self.slices, 0, self.slices - 1,
                             valinit=self.ind, valfmt='%i')
        self.slider.on_changed(self.update_slider)
        self.update()

    def press(self, event):
        """Key handler: 'right'/'left' jump ``step`` slices forward/back, wrapping around."""
        if event.key == 'right':
            self.ind = (self.ind + self.step) % self.slices
        elif event.key == 'left':
            self.ind = (self.ind - self.step) % self.slices
        else:
            return
        # set_val notifies update_slider, which redraws; no second update() needed.
        self.slider.set_val(self.ind)

    def update_slider(self, event):
        """Slider callback: show the slice selected on the slider."""
        # Clamp in case the slider reports a value past the last slice.
        self.ind = min(int(self.slider.val), self.slices - 1)
        self.update()

    def update(self):
        """Redraw the current slice and its label."""
        self.im.set_data(self.X[self.ind, :, :].T)
        self.ax.set_ylabel('slice %s' % self.ind)
        self.im.axes.figure.canvas.draw()
def plot_data(self):
    """Draw the data via ``self.draw_vals`` and show a 'Beta value' slider.

    The slider spans 4..50 in steps of 0.1 and starts at 10. Moving it calls
    ``self.update_beta``. Blocks in ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    plt.figure()
    ax0 = plt.axes([0.085, 0.2, 0.9, 0.75])
    ax1 = plt.axes([0.21, 0.02, 0.7, 0.03])
    self.draw_vals(ax0)
    # Pass the start value as valinit. The previous default (0.5) lay outside
    # [valmin, valmax], which also made reset() and the initial-value marker
    # point outside the range.
    beta_slider = Slider(ax1, label='Beta value (1e-8)', valmin=4, valmax=50,
                         valinit=10, valstep=0.1)
    beta_slider.on_changed(self.update_beta)
    plt.show()
def animate_plot_to_pictures(min_speed, max_speed, pictures):
    """Render *pictures* frames of the acceleration colour wheel to ``pic_NNNN.png``.

    Speed is swept linearly from *min_speed* (first frame) to *max_speed*
    (last frame), and the 'Speed' slider is set to each frame's speed.
    """
    axcolor = 'lightgoldenrodyellow'
    axspeed = plt.axes([0.125, 0.05, 0.65, 0.03], facecolor=axcolor)
    sspeed = Slider(axspeed, 'Speed', 0, max_speed, valinit=0, valstep=1)
    for i in range(pictures):
        if pictures > 1:
            # Linear interpolation. The old formula (i/(n-1)*max + min)
            # overshot to max_speed + min_speed on the last frame.
            fraction = 1.0 * i / (pictures - 1)
            speed = min_speed + fraction * (max_speed - min_speed)
        else:
            # A single frame: avoid dividing by zero.
            speed = 1.0 * min_speed
        accels, rads = strafevis.strafe_stats.get_stats(
            720, strafevis.strafe_stats.StatType.ACCEL, speed=speed)

        display_axes = plt.subplot(1, 1, 1, polar=True)
        norm = mpl.colors.Normalize(0.0, 2 * np.pi)
        cmap = AngleMap(accels)
        cb = mpl.colorbar.ColorbarBase(display_axes, cmap=cmap, norm=norm,
                                       orientation='horizontal')
        # aesthetics - get rid of border and axis labels
        cb.outline.set_visible(False)
        display_axes.set_axis_off()

        sspeed.set_val(speed)
        plt.savefig('pic_%04d.png' % i)
class Visualizer(object):
    """Show each entry of *snapshot* side by side, with a shared time slider, so the user can pick one.

    The picked key is stored in ``self.choice`` (None until a choice is made).
    Keys: escape closes, 'r' resets time to 0, up/down move time by 0.2
    (clamped to [0, T]), and typing a snapshot key (case-insensitive)
    selects it.
    """

    def __init__(self, snapshot, save_on_close=None):
        self.snapshot = snapshot
        self.save_on_close = save_on_close
        # dict views are not indexable on Python 3, so take the first value
        # through an iterator. Its ``T`` is the end of the time slider.
        self.T = next(iter(self.snapshot.human.values())).T
        self.choice = None

    @property
    def t(self):
        """Current time shown by every scene."""
        return self._t

    @t.setter
    def t(self, value):
        self._t = value
        for scene in self.scenes:
            scene.t = self.t

    def select(self, key):
        """Record *key* as the user's choice and close the window."""
        self.choice = key
        plt.close(self.fig)

    def close(self, event):
        """On window close, save the figure when ``save_on_close`` is set."""
        if self.save_on_close is not None:
            self.fig.savefig(self.save_on_close)

    def run(self):
        """Build the figure, slider and 'Prefer' buttons, then block in ``plt.show()``."""
        # squeeze=False always yields a 2-D array, so self.ax stays iterable
        # even when the snapshot has a single key.
        self.fig, axes = plt.subplots(1, len(self.snapshot.keys()), sharex=True, sharey=True,
                                      figsize=(13, 7), squeeze=False)
        self.ax = axes[0]
        self.fig.canvas.mpl_connect('key_press_event', self.key_press)
        self.fig.canvas.mpl_connect('close_event', self.close)
        self.scenes = [Scene(ax, self.snapshot.view(key))
                       for ax, key in zip(self.ax, self.snapshot.keys())]
        self.fig.subplots_adjust(bottom=0.15, top=0.85)

        box = self.fig.add_axes([0.15, 0.05, 0.7, 0.05])
        self.slider = Slider(box, 'Time', 0., self.T, valinit=0.)
        self.t = 0.

        def update_t(t):
            self.t = t

        self.slider.on_changed(update_t)

        def click(key):
            def f(event):
                self.select(key)
            return f

        self.buttons = []
        for ax, key in zip(self.ax, self.snapshot.keys()):
            # ``Axes.figbox`` is deprecated/removed in recent Matplotlib; this
            # is its documented replacement (the subplot's grid-cell box).
            box = ax.get_subplotspec().get_position(self.fig)
            box = self.fig.add_axes([box.x0, box.y1 + 0.05, box.width, 0.05])
            self.buttons.append(Button(box, 'Prefer {}'.format(key)))
            self.buttons[-1].on_clicked(click(key))
        plt.show()

    def key_press(self, event):
        """Keyboard shortcuts; see the class docstring."""
        key = event.key
        if key is None:
            # Some keys are reported without a name; nothing to do.
            return
        if key == 'escape':
            plt.close(self.fig)
        elif key == 'r':
            self.slider.set_val(0.)
        elif key == 'up':
            self.slider.set_val(min(max(self.t + 0.2, 0), self.T))
        elif key == 'down':
            self.slider.set_val(min(max(self.t - 0.2, 0), self.T))
        else:
            wanted = key.lower()
            for snapshot_key in self.snapshot.keys():
                if wanted == snapshot_key.lower():
                    self.select(snapshot_key)
class slicer(object):
    """Interactive viewer for the slices of a 3-D array along one axis.

    Parameters
    ----------
    data : ndarray
        3-D array; the chosen *axis* is swapped to the front.
    axis : int
        0, 1 or 2, labelled 'X', 'Y' or 'Z'.
    init_slice : int
        Index of the slice shown first.

    '+' / '-' step one slice with wrap-around; the slider picks a slice
    directly. Slices are displayed transposed with ``origin='lower'``.
    The constructor blocks in ``plt.show()``.
    """

    def __init__(self, data, axis, init_slice):
        self.axis = axis
        self.axis_label = {0: 'X', 1: 'Y', 2: 'Z'}[axis]
        self.data = np.swapaxes(data, self.axis, 0)
        self.slice_index = init_slice
        self.slice = self.data[self.slice_index]
        self.max_index = self.data.shape[0]  # number of slices
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        # Keep the image handle so later slices replace the pixels instead of
        # stacking a new image artist on every update.
        self.im = self.ax.imshow(self.slice.T, interpolation='none', origin='lower')
        self.along_axis = plt.axes([0.2, 0.1, 0.65, 0.03])
        # Valid indices are 0 .. max_index - 1.
        self.slab = Slider(self.along_axis, '%s_Slab' % self.axis_label, 0, self.max_index - 1,
                           valinit=self.slice_index, valfmt='%i')
        self.slab.on_changed(self.update_figure)
        self.fig.canvas.mpl_connect('key_press_event', self.update_slice_index)
        plt.show()

    def draw(self):
        # NOTE(review): placeholder code; ``your_function``, ``self.values``
        # and ``pylab`` are not defined here, so calling this fails.
        im = your_function(self.values)
        pylab.show()
        self.ax.imshow(im)

    def update_slice_index(self, event):
        """Key handler: '+' / '-' move one slice forward/back, wrapping around."""
        if event.key == '+':
            step = 1
        elif event.key == '-':
            step = -1
        else:
            return
        # Wrap within 0 .. max_index - 1. Previously '-' at 0 jumped to
        # max_index, which update_figure's modulo mapped back to 0, so
        # navigation got stuck.
        self.slice_index = (self.slice_index + step) % self.max_index
        self.slab.set_val(self.slice_index)

    def update_figure(self, event=None):
        """Slider callback: display the slice selected on the slider."""
        self.slice_index = int(self.slab.val % self.max_index)
        self.slice = self.data[self.slice_index]
        self.im.set_data(self.slice.T)
        self.fig.canvas.draw()
class InteractivePlot(object):
    """ScorePlot with sliders for the visible window's start and width.

    The left/right arrow keys shift the window start by half the current width.
    """

    def __init__(self, record_wins, sliding_wins, example_labels, similarities,
                 is_test=None, plot_rc=None):
        self.score_plot = ScorePlot(record_wins, sliding_wins, example_labels, similarities,
                                    is_test=is_test, plot_rc=plot_rc)
        self.fig, self.main_ax = self.score_plot.fig, self.score_plot.main_ax
        plt.subplots_adjust(bottom=0.2)
        self.score_plot.draw()
        max_time = self.score_plot.records.absolute_end
        ax_s = plt.axes([0.15, 0.1, 0.75, 0.03])
        ax_w = plt.axes([0.15, 0.05, 0.75, 0.03])
        self.slider_start = Slider(ax_s, 'Start', 0., max_time, valinit=0.)
        self.slider_start.on_changed(self.update)
        self.slider_width = Slider(ax_w, 'Width', 0., 30., valinit=15.)
        self.slider_width.on_changed(self.update)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)

    def update(self, val):
        """Slider callback: set the window to [start, start + width] and redraw."""
        self.score_plot.current.absolute_start = self.slider_start.val
        width = self.slider_width.val
        self.score_plot.current.absolute_end = (
            self.score_plot.current.absolute_start + width)
        self.score_plot.draw()

    def on_key(self, event):
        """Arrow keys: move the window start by half a window width."""
        if event.key == 'right':
            direction = 1
        elif event.key == 'left':
            direction = -1
        else:
            return
        target = self.slider_start.val + direction * .5 * self.slider_width.val
        # Keep the start inside the slider's range.
        target = min(max(target, self.slider_start.valmin), self.slider_start.valmax)
        # set_val notifies on_changed -> update(), which redraws. Calling
        # update() again here would redraw twice.
        self.slider_start.set_val(target)
def CreateDisplay(self):
    """Build the operation selector and the three index sliders, then call ``plt.show()``.

    The 'Search'/'Insert' radio buttons call ``self.OnOperationTypeSelect``.
    Each slider ranges over the indices of ``self.AtomicSize``,
    ``self.Capacity`` or ``self.GridSize``, and its change handler receives
    the slider as its first argument. Every handler is fired once with index 0.
    """
    rax = plt.axes([0.025, 0.8, 0.15, 0.15])
    radioSelectOperation = RadioButtons(rax, ("Search", "Insert"), active=0)
    radioSelectOperation.on_clicked(self.OnOperationTypeSelect)
    radioSelectOperation.set_active(0)

    axAS = plt.axes([0.25, 0.20, 0.65, 0.03])
    axC = plt.axes([0.25, 0.15, 0.65, 0.03])
    axGS = plt.axes([0.25, 0.10, 0.65, 0.03])

    sAtomicSize = Slider(axAS, 'AtomicSize', 0, len(self.AtomicSize) - 1, valinit=0, valfmt="%1.2f")
    sAtomicSize.on_changed(partial(self.setAS_slider, sAtomicSize))
    sAtomicSize.set_val(0.0)

    sCapacity = Slider(axC, 'Capacity', 0, len(self.Capacity) - 1, valinit=0, valfmt="%i")
    sCapacity.on_changed(partial(self.setC_slider, sCapacity))
    sCapacity.set_val(0.0)

    sGridSize = Slider(axGS, 'GridSize', 0, len(self.GridSize) - 1, valinit=0, valfmt="%i")
    sGridSize.on_changed(partial(self.setGS_slider, sGridSize))
    sGridSize.set_val(0.0)

    # Matplotlib widgets stop responding once garbage-collected. Keep
    # references on self so they survive even if plt.show() returns
    # immediately (e.g. interactive mode).
    self._display_widgets = (radioSelectOperation, sAtomicSize, sCapacity, sGridSize)
    plt.show()
class AtlasEditor(plot_support.ImageSyncMixin):
    """Graphical interface to view an atlas in multiple orthogonal
    dimensions and edit atlas labels.

    :attr:`plot_eds` are dictionaries of keys specified by one of
    :const:`magmap.config.PLANE` plane orientations to Plot Editors.

    Attributes:
        image5d: Numpy image array in t,z,y,x,[c] format.
        labels_img: Numpy image array in z,y,x format.
        channel: Channel of the image to display.
        offset: Index of plane at which to start viewing in x,y,z (user)
            order.
        fn_close_listener: Handle figure close events; called with the
            close event and this ``AtlasEditor``.
        borders_img: Numpy image array in z,y,x,[c] format to show label
            borders, such as that generated during label smoothing.
            Defaults to None. If this image has a different number of
            labels than that of ``labels_img``, a new colormap will be
            generated.
        fn_show_label_3d: Function to call to show a label in a 3D viewer.
            Defaults to None.
        title (str): Window title; defaults to None.
        fn_refresh_atlas_eds (func): Callback for refreshing other Atlas
            Editors to synchronize them; defaults to None. Called with this
            ``AtlasEditor`` as its only argument.
        fig: Matplotlib figure to draw in; if None, :meth:`show_atlas`
            creates one with ``figure.Figure(self.title)``.
        fn_status_bar (func): Function to call during status bar updates in
            :class:`pixel_display.PixelDisplay`; defaults to None.
        alpha_slider: Matplotlib alpha slider control.
        alpha_reset_btn: Matplotlib button for resetting alpha transparency.
        alpha_last: Float specifying the previous alpha value.
        interp_planes: Current :class:`InterpolatePlanes` object.
        interp_btn: Matplotlib button to initiate plane interpolation.
        save_btn: Matplotlib button to save the atlas.
        edit_btn: Matplotlib button toggling editing mode.
        color_picker_box: Matplotlib ``TextBox`` holding the intensity
            value applied by the Plot Editors.
        fn_update_coords (func): Handler for coordinate updates, which
            takes coordinates in z-plane orientation; defaults to None.

    """

    _EDIT_BTN_LBLS = ("Edit", "Editing")

    def __init__(self, image5d, labels_img, channel, offset,
                 fn_close_listener, borders_img=None, fn_show_label_3d=None,
                 title=None, fn_refresh_atlas_eds=None, fig=None,
                 fn_status_bar=None):
        """Store the images and callbacks; widgets are built later by
        :meth:`show_atlas`."""
        super().__init__()
        self.image5d = image5d
        self.labels_img = labels_img
        self.channel = channel
        self.offset = offset
        self.fn_close_listener = fn_close_listener
        self.borders_img = borders_img
        self.fn_show_label_3d = fn_show_label_3d
        self.title = title
        self.fn_refresh_atlas_eds = fn_refresh_atlas_eds
        self.fig = fig
        self.fn_status_bar = fn_status_bar

        self.alpha_slider = None
        self.alpha_reset_btn = None
        self.alpha_last = None
        self.interp_planes = None
        self.interp_btn = None
        self.save_btn = None
        self.edit_btn = None
        self.color_picker_box = None
        self.fn_update_coords = None

        self._labels_img_sitk = None  # for saving labels image

    def show_atlas(self):
        """Set up the atlas display with multiple orthogonal views."""
        # set up the figure
        if self.fig is None:
            fig = figure.Figure(self.title)
            self.fig = fig
        else:
            fig = self.fig
        fig.clear()
        gs = gridspec.GridSpec(
            2, 1, wspace=0.1, hspace=0.1, height_ratios=(20, 1), figure=fig,
            left=0.06, right=0.94, bottom=0.02, top=0.98)
        gs_viewers = gridspec.GridSpecFromSubplotSpec(
            2, 2, subplot_spec=gs[0, 0])

        # set up a colormap for the borders image if present
        cmap_borders = colormaps.get_borders_colormap(
            self.borders_img, self.labels_img, config.cmap_labels)
        coord = list(self.offset[::-1])

        # editor controls, split into a slider sub-spec to allow greater
        # spacing for labels on either side and a separate sub-spec for
        # buttons and other fields
        gs_controls = gridspec.GridSpecFromSubplotSpec(
            1, 2, subplot_spec=gs[1, 0], width_ratios=(1, 1), wspace=0.15)
        self.alpha_slider = Slider(
            fig.add_subplot(gs_controls[0, 0]), "Opacity", 0.0, 1.0,
            valinit=plot_editor.PlotEditor.ALPHA_DEFAULT)
        gs_controls_btns = gridspec.GridSpecFromSubplotSpec(
            1, 5, subplot_spec=gs_controls[0, 1], wspace=0.1)
        self.alpha_reset_btn = Button(
            fig.add_subplot(gs_controls_btns[0, 0]), "Reset")
        self.interp_btn = Button(
            fig.add_subplot(gs_controls_btns[0, 1]), "Fill Label")
        self.interp_planes = InterpolatePlanes(self.interp_btn)
        self.interp_planes.update_btn()
        self.save_btn = Button(
            fig.add_subplot(gs_controls_btns[0, 2]), "Save")
        self.edit_btn = Button(
            fig.add_subplot(gs_controls_btns[0, 3]), "Edit")
        self.color_picker_box = TextBox(
            fig.add_subplot(gs_controls_btns[0, 4]), None)

        # adjust button colors based on theme and enabled status; note
        # that colors do not appear to refresh until fig mouseover
        for btn in (self.alpha_reset_btn, self.edit_btn):
            enable_btn(btn)
        enable_btn(self.save_btn, False)
        enable_btn(self.color_picker_box, color=config.widget_color + 0.1)

        def setup_plot_ed(axis, gs_spec):
            # set up a PlotEditor for the given axis

            # subplot grid, with larger height preference for plot for
            # each increased row to make sliders of approx equal size and
            # align top borders of top images
            rows_cols = gs_spec.get_rows_columns()
            extra_rows = rows_cols[3] - rows_cols[2]
            gs_plot = gridspec.GridSpecFromSubplotSpec(
                2, 1, subplot_spec=gs_spec,
                height_ratios=(1, 10 + 14 * extra_rows),
                hspace=0.1 / (extra_rows * 1.4 + 1))

            # transform arrays to the given orthogonal direction
            ax = fig.add_subplot(gs_plot[1, 0])
            plot_support.hide_axes(ax)
            plane = config.PLANE[axis]
            arrs_3d, aspect, origin, scaling = \
                plot_support.setup_images_for_plane(
                    plane,
                    (self.image5d[0], self.labels_img, self.borders_img))
            img3d_tr, labels_img_tr, borders_img_tr = arrs_3d

            # slider through image planes
            ax_scroll = fig.add_subplot(gs_plot[0, 0])
            plane_slider = Slider(
                ax_scroll, plot_support.get_plane_axis(plane), 0,
                len(img3d_tr) - 1, valfmt="%d", valinit=0, valstep=1)

            # plot editor
            max_size = max_sizes[axis] if max_sizes else None
            plot_ed = plot_editor.PlotEditor(
                ax, img3d_tr, labels_img_tr, config.cmap_labels,
                plane, aspect, origin, self.update_coords,
                self.refresh_images, scaling, plane_slider,
                img3d_borders=borders_img_tr, cmap_borders=cmap_borders,
                fn_show_label_3d=self.fn_show_label_3d,
                interp_planes=self.interp_planes,
                fn_update_intensity=self.update_color_picker,
                max_size=max_size, fn_status_bar=self.fn_status_bar)
            return plot_ed

        # setup plot editors for all 3 orthogonal directions
        max_sizes = plot_support.get_downsample_max_sizes()
        for i, gs_viewer in enumerate(
                (gs_viewers[:2, 0], gs_viewers[0, 1], gs_viewers[1, 1])):
            self.plot_eds[config.PLANE[i]] = setup_plot_ed(i, gs_viewer)
        self.set_show_crosslines(True)

        # attach listeners
        fig.canvas.mpl_connect("scroll_event", self.scroll_overview)
        fig.canvas.mpl_connect("key_press_event", self.on_key_press)
        fig.canvas.mpl_connect("close_event", self._close)
        fig.canvas.mpl_connect("axes_leave_event", self.axes_exit)

        self.alpha_slider.on_changed(self.alpha_update)
        self.alpha_reset_btn.on_clicked(self.alpha_reset)
        self.interp_btn.on_clicked(self.interpolate)
        self.save_btn.on_clicked(self.save_atlas)
        self.edit_btn.on_clicked(self.toggle_edit_mode)
        self.color_picker_box.on_text_change(self.color_picker_changed)

        # initialize and show planes in all plot editors
        if self._max_intens_proj is not None:
            self.update_max_intens_proj(self._max_intens_proj)
        self.update_coords(coord, config.PLANE[0])

        plt.ion()  # avoid the need for draw calls

    def _close(self, evt):
        """Handle figure close events by calling :attr:`fn_close_listener`
        with this object.

        Args:
            evt (:obj:`matplotlib.backend_bases.CloseEvent`): Close event.

        """
        self.fn_close_listener(evt, self)

    def on_key_press(self, event):
        """Respond to key press events.

        Keys: ``a`` toggles between the current and zero opacity; ``A``
        halves the opacity; ``up``/``down`` scroll planes; ``w`` toggles
        editing mode; ``ctrl+s``/``cmd+s`` save the figure.
        """
        if event.key == "a":
            # toggle between current and 0 opacity
            if self.alpha_slider.val == 0:
                # return to saved alpha if available and reset
                if self.alpha_last is not None:
                    self.alpha_slider.set_val(self.alpha_last)
                self.alpha_last = None
            else:
                # make translucent, saving alpha if not already saved
                # during a halve-opacity event
                if self.alpha_last is None:
                    self.alpha_last = self.alpha_slider.val
                self.alpha_slider.set_val(0)
        elif event.key == "A":
            # halve opacity, only saving alpha on first halving to allow
            # further halving or manual movements while still returning to
            # originally saved alpha
            if self.alpha_last is None:
                self.alpha_last = self.alpha_slider.val
            self.alpha_slider.set_val(self.alpha_slider.val / 2)
        elif event.key == "up" or event.key == "down":
            # up/down arrow for scrolling planes
            self.scroll_overview(event)
        elif event.key == "w":
            # shortcut to toggle editing mode
            self.toggle_edit_mode(event)
        elif event.key == "ctrl+s" or event.key == "cmd+s":
            # support default save shortcuts on multiple platforms;
            # ctrl-s will bring up save dialog from fig, but cmd/win-S
            # will bypass
            self.save_fig(self.get_save_path())

    def update_coords(self, coord, plane_src=config.PLANE[0]):
        """Update all plot editors with given coordinates.

        Args:
            coord: Coordinate at which to center images, in z,y,x order.
            plane_src: One of :const:`magmap.config.PLANE` to specify the
                orientation from which the coordinates were given; defaults
                to the first element of :const:`magmap.config.PLANE`.

        """
        coord_rev = libmag.transpose_1d_rev(list(coord), plane_src)
        for i, plane in enumerate(config.PLANE):
            coord_transposed = libmag.transpose_1d(list(coord_rev), plane)
            if i == 0:
                self.offset = coord_transposed[::-1]
                if self.fn_update_coords:
                    # update offset based on xy plane, without centering
                    # planes are centered on the offset as-is
                    self.fn_update_coords(coord_transposed, False)
            self.plot_eds[plane].update_coord(coord_transposed)

    def view_subimg(self, offset, shape):
        """Zoom all Plot Editors to the given sub-image.

        Args:
            offset: Sub-image coordinates in ``z,y,x`` order.
            shape: Sub-image shape in ``z,y,x`` order.

        """
        for plane in config.PLANE:
            offset_tr = libmag.transpose_1d(list(offset), plane)
            shape_tr = libmag.transpose_1d(list(shape), plane)
            self.plot_eds[plane].view_subimg(offset_tr[1:], shape_tr[1:])
        self.fig.canvas.draw_idle()

    def refresh_images(self, plot_ed=None, update_atlas_eds=False):
        """Refresh images in a plot editor, such as after editing one
        editor and updating the displayed image in the other editors.

        Args:
            plot_ed (:obj:`magmap.plot_editor.PlotEditor`): Editor that
                does not need updating, typically the editor that originally
                changed. Defaults to None.
            update_atlas_eds (bool): True to update other ``AtlasEditor``s;
                defaults to False.

        """
        for ed in self.plot_eds.values():
            if ed != plot_ed:
                ed.refresh_img3d_labels()
            if ed.edited:
                # display save button as enabled if any editor has been edited
                enable_btn(self.save_btn)
        if update_atlas_eds and self.fn_refresh_atlas_eds is not None:
            # callback to synchronize other Atlas Editors
            self.fn_refresh_atlas_eds(self)

    def scroll_overview(self, event):
        """Scroll images and crosshairs in all plot editors.

        Args:
            event: Scroll event.

        """
        for ed in self.plot_eds.values():
            ed.scroll_overview(event)

    def alpha_update(self, event):
        """Update the alpha transparency in all plot editors.

        Args:
            event: Slider event.

        """
        for ed in self.plot_eds.values():
            ed.alpha_updater(event)

    def alpha_reset(self, event):
        """Reset the alpha transparency in all plot editors.

        Args:
            event: Button event, currently ignored.

        """
        self.alpha_slider.reset()

    def axes_exit(self, event):
        """Trigger axes exit for all plot editors.

        Args:
            event: Axes exit event.

        """
        for ed in self.plot_eds.values():
            ed.on_axes_exit(event)

    def interpolate(self, event):
        """Interpolate planes using :attr:`interp_planes`.

        Errors of type ``ValueError`` raised by the interpolation are
        printed rather than propagated.

        Args:
            event: Button event, currently ignored.

        """
        try:
            self.interp_planes.interpolate(self.labels_img)
            # flag Plot Editors as edited so labels can be saved
            for ed in self.plot_eds.values():
                ed.edited = True
            self.refresh_images(None, True)
        except ValueError as e:
            print(e)

    def save_atlas(self, event):
        """Save atlas labels using the registered image suffix given by
        :attr:`config.reg_suffixes[config.RegSuffixes.ANNOTATION]`.

        Nothing is saved unless at least one Plot Editor has been edited.

        Args:
            event: Button event, currently not used.

        """
        # only save if at least one editor has been edited
        if not any(ed.edited for ed in self.plot_eds.values()):
            return

        # save to the labels reg suffix; use sitk Image if loaded and store
        # any Image loaded during saving
        reg_name = config.reg_suffixes[config.RegSuffixes.ANNOTATION]
        if self._labels_img_sitk is None:
            self._labels_img_sitk = config.labels_img_sitk
        self._labels_img_sitk = sitk_io.write_registered_image(
            self.labels_img, config.filename, reg_name,
            self._labels_img_sitk, overwrite=True)

        # reset edited flag in all editors and show save button as disabled
        for ed in self.plot_eds.values():
            ed.edited = False
        enable_btn(self.save_btn, False)
        print("Saved labels image at {}".format(datetime.datetime.now()))

    def get_save_path(self):
        """Get figure save path based on filename, ROI, and overview plane
        shown.

        Returns:
            str: Figure save path.

        """
        # NOTE(review): os.path.basename raises TypeError if self.title is
        # None (its default); confirm callers always set a title.
        ext = config.savefig if config.savefig else config.DEFAULT_SAVEFIG
        return "{}.{}".format(naming.get_roi_path(
            os.path.basename(self.title), self.offset), ext)

    def toggle_edit_mode(self, event):
        """Toggle editing mode, determining the current state from the
        first :class:`magmap.plot_editor.PlotEditor` and switching to the
        opposite value for all plot editors.

        Args:
            event: Button event, currently not used.

        """
        edit_mode = False
        for i, ed in enumerate(self.plot_eds.values()):
            if i == 0:
                # change edit mode based on current mode in first plot editor
                edit_mode = not ed.edit_mode
                toggle_btn(self.edit_btn, edit_mode, text=self._EDIT_BTN_LBLS)
            ed.edit_mode = edit_mode
        if not edit_mode:
            # reset the color picker text box when turning off editing
            self.color_picker_box.set_val("")

    def update_color_picker(self, val):
        """Update the color picker :class:`TextBox` with the given value.

        Args:
            val (str): Color value. If None, only
                :meth:`color_picker_changed` will be triggered.

        """
        if val is None:
            # updated picked color directly
            self.color_picker_changed(val)
        else:
            # update text box, which triggers color_picker_changed
            self.color_picker_box.set_val(val)

    def color_picker_changed(self, text):
        """Respond to color picker :class:`TextBox` changes by updating
        the specified intensity value in all plot editors.

        Args:
            text (str): String of text box value. Converted to an int if
                non-empty. Text that is not a number, or that cannot be
                parsed as an int (e.g. ``"1.5"``), is ignored.

        """
        intensity = text
        if text:
            if not libmag.is_number(intensity):
                return
            try:
                intensity = int(intensity)
            except ValueError:
                # guard in case is_number accepts non-integer numbers that
                # int() rejects; ignore the entry instead of crashing the
                # widget callback
                return
        print("updating specified color to", intensity)
        for ed in self.plot_eds.values():
            ed.intensity_spec = intensity
class VideoViewer(object):
    """
    A matplotlib-based video viewer.

    Parameters
    ----------
    video : list-like, iterator
        A list of a tuple of 2D arrays or a generator of a tuple of 2D arrays.
        If an iterator is provided, you must set 'count' as well.
    count: int
        Length of the video. When this is set it displays only first 'count'
        frames of the video.
    id : int, optional
        For multi-frame data specifies camera index.
    norm_func : callable
        Normalization function that takes a single argument (array) and
        returns a single element (array). Can be used to apply custom
        normalization function to the image before it is shown.
    title : str, optional
        Plot title.
    kw : options, optional
        Extra arguments passed directly to imshow function

    Raises
    ------
    ValueError
        If ``count`` is not given and ``video`` has no ``len()``.

    Notes
    -----
    The '>>>' / '<<<' buttons and the 'FF' button move by a fast stride of
    about 1% of ``count`` frames (at least one frame).

    Examples
    --------

    >>> from cddm.viewer import VideoViewer
    >>> video = (np.random.randn(256,256) for i in range(256))
    >>> vg = VideoViewer(video, 256, title = "iterator example") #must set nframes, because video has no __len__

    #>>> vg.show()

    >>> video = [np.random.randn(256,256) for i in range(256)]
    >>> vl = VideoViewer(video, title = "list example")

    #>>> vl.show()
    """

    def __init__(self, video, count=None, id=0, norm_func=lambda x: x.real,
                 title="", **kw):
        if count is None:
            try:
                count = len(video)
            except TypeError:
                # ValueError is still an Exception, so existing handlers of
                # the previous bare Exception keep working.
                raise ValueError("You must specify count!") from None

        self._norm = norm_func
        self.id = id
        self.index = 0
        self.video = video

        self.fig, self.ax = plt.subplots()
        self.ax.set_title(title)
        plt.subplots_adjust(bottom=0.25)

        frame = next(iter(video))  # take first frame
        frame = self._prepare_image(frame)

        self.img = self.ax.imshow(frame, **kw)
        self.fig.colorbar(self.img, ax=self.ax)

        self.axframe = plt.axes([0.1, 0.1, 0.7, 0.03])
        self.sframe = Slider(self.axframe, '', 0, count - 1, valinit=0,
                             valstep=1, valfmt='%i')

        self.axnext = plt.axes([0.7, 0.02, 0.1, 0.05])
        self.bnext = Button(self.axnext, '>')

        self.axnext2 = plt.axes([0.6, 0.02, 0.1, 0.05])
        self.bnext2 = Button(self.axnext2, '>>>')

        self.axprev = plt.axes([0.1, 0.02, 0.1, 0.05])
        self.bprev = Button(self.axprev, '<')

        self.axprev2 = plt.axes([0.2, 0.02, 0.1, 0.05])
        self.bprev2 = Button(self.axprev2, '<<<')

        self.axstop = plt.axes([0.4, 0.02, 0.1, 0.05])
        self.bstop = Button(self.axstop, 'Stop')

        self.axplay = plt.axes([0.3, 0.02, 0.1, 0.05])
        self.bplay = Button(self.axplay, 'Play')

        self.axfast = plt.axes([0.5, 0.02, 0.1, 0.05])
        self.bfast = Button(self.axfast, 'FF')

        self.playing = False
        # Integer stride of at least one frame; count / 100 was fractional
        # and made "fast" slower than normal play for videos < 100 frames.
        self.step_fast = self._fast_step(count)
        self.step = 1
        self.pause_duration = 0.001

        def _play():
            while self.playing:
                plt.pause(self.pause_duration)
                next_frame = self.sframe.val + self.step
                if next_frame >= count:
                    self.playing = False
                else:
                    self.sframe.set_val(next_frame)

        def stop(event):
            self.playing = False

        @skip_runtime_error
        def play(event):
            self.playing = True
            self.step = 1
            _play()

        @skip_runtime_error
        def play_fast(event):
            self.playing = True
            self.step = self.step_fast
            _play()

        def next_frame(event):
            self.sframe.set_val(min(self.sframe.val + 1, count - 1))

        def next_fast(event):
            # '>>>' always skips by the fast stride (self.step is 1 unless
            # FF was pressed, which made this button identical to '>')
            self.sframe.set_val(min(self.sframe.val + self.step_fast,
                                    count - 1))

        def prev_frame(event):
            self.sframe.set_val(max(self.sframe.val - 1, 0))

        def prev_fast(event):
            self.sframe.set_val(max(self.sframe.val - self.step_fast, 0))

        @skip_runtime_error
        def update(val):
            i = int(self.sframe.val)
            try:
                frame = self.video[i]  # assume list-like object
                self.index = i
            except TypeError:
                # assume generator: it can only be advanced, not rewound
                frame = None
                if i > self.index:
                    for frame in self.video:
                        self.index += 1
                        if self.index >= i:
                            break

            if frame is not None:
                frame = self._prepare_image(frame)
                self.img.set_data(frame)
                self.fig.canvas.draw_idle()

        self.sframe.on_changed(update)
        self.bnext.on_clicked(next_frame)
        self.bstop.on_clicked(stop)
        self.bplay.on_clicked(play)
        self.bfast.on_clicked(play_fast)
        self.bnext2.on_clicked(next_fast)
        self.bprev2.on_clicked(prev_fast)
        self.bprev.on_clicked(prev_frame)

    @staticmethod
    def _fast_step(count):
        """Return the fast-forward stride: ~1% of ``count``, at least 1."""
        return max(1, int(count // 100))

    def _prepare_image(self, im):
        """Select camera ``self.id`` from a tuple/list frame and normalize."""
        if isinstance(im, tuple) or isinstance(im, list):
            return self._norm(im[self.id])
        else:
            return self._norm(im)

    def show(self):
        """Shows video."""
        plt.show()
class Control:
    """A slice-selection control: a slider with '-' and '+' buttons below it.

    The slider ranges over ``0 .. dim - 1`` in steps of 1. The caller is
    expected to wire ``slider.on_changed`` to :meth:`update` and the buttons
    to :meth:`decrement` / :meth:`increment`.
    """

    def __init__(self, figure, position, label, initial_value, dim):
        """
        Control group for slice

        Args:
            figure: Matplotlib figure to which the control's axes are added.
            position: ``[left, bottom, width, height]`` of the slider axes;
                the buttons are placed 0.04 below it, each half as wide. The
                sequence is copied, so the caller's list is not modified.
            label: Slider label.
            initial_value: Initial slider value.
            dim: Number of slices; the slider maximum is ``dim - 1``.
        """
        # Work on a copy: the button layout below is derived by editing the
        # rectangle, which previously mutated the caller's list.
        position = list(position)

        # Slider
        self.slider = Slider(figure.add_axes(position, xticks=[], yticks=[],
                                             facecolor='#222222'),
                             label, 0, dim - 1, valinit=initial_value,
                             valstep=1, valfmt='%1.0f', color='#444444')

        # Set Button Positions
        position[1] = position[1] - 0.04
        position[2] = position[2] / 2
        position[3] = 0.04
        self.buttondown = Button(figure.add_axes(position), '-',
                                 color='#222222', hovercolor='#333333')

        # Buttons
        position[0] = position[0] + position[2]
        self.buttonup = Button(figure.add_axes(position), '+',
                               color='#222222', hovercolor='#333333')

        # save value for display
        self.value = initial_value

        # save dim for slider limit
        self.lim = dim

    def get_value(self):
        """
        Returns the current value of the control
        """
        return int(self.value)

    def decrement(self, event):
        """Move the slider down one slice unless already at 0."""
        new_val = self.value - 1
        if new_val >= 0:
            self.slider.set_val(new_val)

    def increment(self, event):
        """Move the slider up one slice unless already at ``dim - 1``."""
        new_val = self.value + 1
        if new_val < self.lim:
            self.slider.set_val(new_val)

    def update(self, value, callback_func):
        """
        Updates the current value of the control
        then runs the callback_func with the current value
        """
        self.value = int(value)  # set new value
        callback_func(int(value))  # execute callback
class ytViewer(object):
    """Interactive viewer that folds a 1-D sample stream into a 2-D image.

    The file is read with ``fromfile(filename, dtype=dtype)`` and reshaped
    into rows of ``fold`` samples (a trailing partial row is dropped).
    ``nmax`` rows starting at ``index`` are shown as an image, with a line
    plot of one row (``Y0``, picked with the mouse) underneath. Sliders
    select the rows shown, crop columns from both ends of each row, and
    apply a per-row circular shift ("shear"). A GObject idle callback
    advances ``index`` while playback is enabled. Keys: space toggles
    playback, ``n`` toggles the colour normalisation, ``q`` exits.

    ``DEB`` is accepted but unused.
    """

    def __init__(self, filename, fold=19277, nmax=100, NORM=True, dtype=int8,
                 DEB=0, shear_val=0.):
        self.UPDATE = True
        self.color = True
        self.NORM = NORM
        self.NMAX = nmax
        self.fold = fold
        self.increment = 5
        self.index = 0
        self.shear_val = shear_val
        self.remove_len1 = 0
        self.remove_len2 = 1
        self.fig = figure(figsize=(16, 7))

        self.data = fromfile(filename, dtype=dtype)
        self.max_index = int(len(self.data) / self.fold)
        self.data = self.data[:self.max_index * self.fold]
        self.folded_data_orig2 = self.data.reshape(self.max_index, self.fold)
        self.folded_data_orig = array(self.folded_data_orig2)
        self.folded_data_orig3 = array(self.folded_data_orig)
        self.folded_data = self.folded_data_orig3[:self.NMAX]
        self.Y0 = 0

        self.ax = axes([0.1, 0.4, 0.8, 0.47])
        vmin, vmax = self._color_limits()
        self.im = self.ax.imshow(self.folded_data, interpolation='nearest',
                                 aspect='auto', origin='lower',
                                 vmin=vmin, vmax=vmax)
        self.cursor = Cursor(self.ax, useblit=True, color='red', linewidth=2)

        self.axh = axes([0.1, 0.05, 0.8, 0.2])
        self.hline, = self.axh.plot(self.folded_data[self.Y0, :])
        self.axh.set_xlim(0, len(self.folded_data[0, :]))
        self.axh.set_ylim(self.folded_data.min(), self.folded_data.max())

        # create 'remove_len1' slider: columns cropped from each row's start
        self.remove_len1_sliderax = axes([0.1, 0.925, 0.8, 0.02])
        self.remove_len1_slider = Slider(
            self.remove_len1_sliderax, 'beg', 0., self.fold * (3.5 / 4),
            valinit=self.remove_len1, valfmt='%d')
        self.remove_len1_slider.on_changed(self.update_tab)

        # create 'remove_len2' slider: columns cropped from each row's end
        self.remove_len2_sliderax = axes([0.1, 0.905, 0.8, 0.02])
        self.remove_len2_slider = Slider(
            self.remove_len2_sliderax, 'end', 0., self.fold * (3.5 / 4),
            valinit=self.remove_len2, valfmt='%d')
        self.remove_len2_slider.on_changed(self.update_tab)

        # create 'shear' slider
        self.shear_sliderax = axes([0.1, 0.88, 0.8, 0.02])
        self.shear_slider = Slider(self.shear_sliderax, 'shear', -0.5, 0.5,
                                   valinit=self.shear_val, valfmt='%1.3f')
        self.shear_slider.on_changed(self.update_shear)

        # create 'index' slider: first displayed row
        self.index_sliderax = axes([0.1, 0.975, 0.8, 0.02])
        self.index_slider = Slider(self.index_sliderax, 'index', 0,
                                   self.max_index - self.increment,
                                   valinit=0, valfmt='%d')
        self.index_slider.on_changed(self.update_param)

        # create 'nmax' slider: number of displayed rows
        self.nmax_sliderax = axes([0.1, 0.955, 0.8, 0.02])
        self.nmax_slider = Slider(self.nmax_sliderax, 'nmax', 0,
                                  self.max_index, valinit=self.NMAX,
                                  valfmt='%d')
        self.nmax_slider.on_changed(self.update_tab)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mousemove)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)

        # play/pause indicator (green while playing, red while paused)
        self.axe_toggledisplay = self.fig.add_axes([0.43, 0.27, 0.14, 0.1])
        self.plot_circle(0, 0, 2, fc='#00FF7F')
        mpl.pyplot.axis('off')

        if self.shear_val != 0.:
            self.shear()

        gobject.idle_add(self.update_plot)
        show()

    def _color_limits(self):
        """Image colour limits: full data range if ``NORM``, else 0..255."""
        if self.NORM:
            return self.data.min(), self.data.max()
        return 0, 255

    def _trimmed_columns(self):
        """Return a copy of ``folded_data_orig`` with ``remove_len1`` columns
        dropped from the start of each row and ``remove_len2`` from the end.

        A zero ``remove_len2`` keeps every trailing column; slicing with
        ``-0`` would otherwise produce an empty array.
        """
        stop = -self.remove_len2 if self.remove_len2 else None
        return array(self.folded_data_orig[:, self.remove_len1:stop])

    def update_shear(self, value):
        """Slider callback: apply the new shear and refresh the display."""
        self.shear_val = round(self.shear_slider.val, 3)
        self.shear()
        self.update_tab(value)

    def update_param(self, value):
        """Slider callback: move the first displayed row and refresh."""
        self.index = int(round(self.index_slider.val, 0))
        self.update_tab(value)

    def shear(self):
        """Store in ``folded_data_orig`` the raw folded data with row ``i``
        circularly shifted by ``int(i * shear_val)`` samples."""
        if self.shear_val == 0:
            # No shift: restore the unsheared data, since a previous
            # non-zero shear may have been applied.
            self.folded_data_orig = array(self.folded_data_orig2)
            return
        sheared = array(self.folded_data_orig2)
        for i in range(sheared.shape[0]):
            sheared[i, :] = roll(self.folded_data_orig2[i, :],
                                 int(i * self.shear_val))
        self.folded_data_orig = sheared

    def update_tab(self, val):
        """Re-crop the data from the slider values and redraw both plots.

        Args:
            val: Slider value passed by the callback; ignored.
        """
        self.remove_len1 = int(self.remove_len1_slider.val)
        self.remove_len2 = int(self.remove_len2_slider.val)
        # at least one row, otherwise min()/max() below fail on empty data
        self.NMAX = max(1, int(round(self.nmax_slider.val, 0)))
        self.folded_data_orig3 = self._trimmed_columns()
        self.folded_data = \
            self.folded_data_orig3[self.index:(self.index + self.NMAX)]
        self.Y0 = 0

        vmin, vmax = self._color_limits()
        self.ax.clear()
        self.im = self.ax.imshow(self.folded_data, interpolation='nearest',
                                 aspect='auto', origin='lower',
                                 vmin=vmin, vmax=vmax)

        # reuse the existing cut axes instead of creating a new one per call,
        # and set its limits explicitly rather than via pyplot's current axes
        self.axh.clear()
        self.hline, = self.axh.plot(self.folded_data[self.Y0, :])
        self.axh.set_ylim(self.folded_data.min(), self.folded_data.max())
        self.axh.set_xlim(0, len(self.folded_data[self.Y0, :]))
        draw()

    def update_plot(self):
        """Idle callback: advance the displayed rows by ``increment``.

        Returns True while ``UPDATE`` is set (keeping the GObject idle
        callback registered) and False otherwise.
        """
        if not self.UPDATE:
            return False
        self.folded_data = \
            self.folded_data_orig3[self.index:(self.index + self.NMAX)]
        # Update picture
        self.im.set_data(self.folded_data)
        self.hline.set_ydata(self.folded_data[self.Y0, :])
        self.index = self.index + self.increment
        self.index_slider.set_val(self.index)
        draw()
        return True

    def update_cut(self):
        """Redraw the line plot of row ``Y0``."""
        self.hline.set_ydata(self.folded_data[self.Y0, :])
        draw()

    def keypress(self, event):
        """Handle keys: 'q' exits, 'n' toggles normalisation, space toggles
        playback; other keys are reported as unknown."""
        if event.key == 'q':
            # eXit
            sys.exit()
        elif event.key == 'n':
            self.NORM = not self.NORM
            self.ax.clear()
            if not self.NORM:
                self.im = self.ax.imshow(self.folded_data,
                                         interpolation='nearest',
                                         aspect='auto', origin='lower',
                                         vmin=0, vmax=255)
                self.axh.set_ylim(0, 255)
            else:
                self.im = self.ax.imshow(self.folded_data,
                                         interpolation='nearest',
                                         aspect='auto', origin='lower',
                                         vmin=self.folded_data.min(),
                                         vmax=self.folded_data.max())
                self.axh.set_ylim(self.folded_data.min(),
                                  self.folded_data.max())
            draw()
        elif event.key == ' ':
            # play/pause
            self.toggle_update()
        else:
            print('Key ' + str(event.key) + ' not known')

    def mousemove(self, event):
        """Select the row under the mouse pointer for the line plot."""
        # called on each mouse motion to get mouse position
        if event.inaxes != self.ax:
            return
        self.X0 = int(round(event.xdata, 0))
        self.Y0 = int(round(event.ydata, 0))
        self.update_cut()

    def toggle_update(self):
        """Toggle playback and recolour the indicator (green = playing)."""
        self.UPDATE = not self.UPDATE
        if self.UPDATE:
            gobject.idle_add(self.update_plot)
        self.color = not self.color
        self._draw_indicator('#00FF7F' if self.color else '#FF4500')

    def _draw_indicator(self, fc):
        """Redraw the indicator circle with face colour ``fc``, reusing the
        existing indicator axes instead of adding a new one each time."""
        self.patch.remove()
        self.axe_toggledisplay.clear()
        # plot_circle draws on pyplot's current axes
        self.fig.sca(self.axe_toggledisplay)
        self.plot_circle(0, 0, 2, fc=fc)
        self.axe_toggledisplay.axis('off')
        draw()

    def plot_circle(self, x, y, r, fc='r'):
        """Plot a circle of radius r at position x,y"""
        cir = mpl.patches.Circle((x, y), radius=r, fc=fc)
        self.patch = mpl.pyplot.gca().add_patch(cir)
class Notes(object):
    """
    An interactive matplotlib window for ROI drawing, definition of the main
    axis of the eye, and selection of reference frames where the eye is open.

    Args:
        stack: Frames sampled from the files to be analyzed, indexed as
            ``stack[frame, row, col]``.
        frame_lut: Look-up table mapping positions in ``stack`` to frame
            indices; it is indexed with a list in :meth:`wrap_up`, so
            presumably a NumPy array (confirm with callers).
        key: Mapping whose ``'m'`` entry is appended to the window title.
    """

    def __init__(self, stack, frame_lut, key):
        self.wintitle = "ixtract - " + key['m']

        # stack of frames sampled from the files to be analyzed:
        self.stack = stack
        self.nframes = self.stack.shape[0]
        self.frame_lut = frame_lut  # look-up table for frame indices
        self.ref_stack = []  # selected frames
        self.ref_frame_inds = []  # indices of selected frames

        self.mode = 'roi'  # init in roi drawing mode
        self.points = np.zeros((2, 2))  # holds mouse click coords
        self.roi = None  # init ROI and axis variables to None-type
        self.axis = None
        self.line = None  # patch variables to display ROI and axis on figure
        self.rect = None

        # set up frame display axis
        self.fig, _ = plt.subplots()
        # FigureCanvas.set_window_title was removed in Matplotlib 3.6; the
        # figure manager's method works across versions.
        self.fig.canvas.manager.set_window_title(self.wintitle)
        grid = gs.GridSpec(9, 12)
        self.ax_frame = plt.subplot(grid[:8, :])
        self.ax_frame.axis('off')
        # keep the image artist so frame changes can update it in place
        self.frame_img = self.ax_frame.imshow(self.stack[0, :, :],
                                              cmap='gray')

        # set up slider axis
        self.ax_slider = plt.subplot(grid[8, 0:11])
        self.slider = Slider(self.ax_slider, 'Frame', 0, self.nframes - 1,
                             valinit=0, valfmt='%d')

        # connect callback functions
        # REMOVE AXIS BINDINGS (AXIS CHECK IS IN CALLBACK FUNCTIONS)
        self.cidpress = self.ax_frame.figure.canvas.mpl_connect(
            'button_press_event', self.on_click)
        self.cidrelease = self.ax_frame.figure.canvas.mpl_connect(
            'button_release_event', self.on_release)
        self.slider.on_changed(self.update_frame)

        # disable default matplotlib key bindings
        manager, canvas = self.fig.canvas.manager, self.fig.canvas
        canvas.mpl_disconnect(manager.key_press_handler_id)

        # connect to custom key bindings
        self.cidkey = self.fig.canvas.mpl_connect('key_press_event',
                                                  self.on_key)

        # display user prompts
        print("Select an ROI and define the main axis of the eye")
        print(" - use 'd' to switch between drawing modes")
        print(" - use 's' to select a few reference frames in which the eye is open")
        print(" - click on the slider or use the arrow keys to scroll between video frames")
        print(" - use 'c' to continue, or 'esc' to quit")

    def on_click(self, event):
        """Store the coordinates of a mouse press inside the frame axes."""
        if event.inaxes != self.ax_frame:
            return
        self.points[0, 0], self.points[0, 1] = event.xdata, event.ydata

    def on_release(self, event):
        """Assign coordinates of click & release to the ROI or axis and
        draw the appropriate shape patch onto the figure."""
        if event.inaxes != self.ax_frame:
            return

        # assign click data to temporary variables
        self.points[1, 0], self.points[1, 1] = event.xdata, event.ydata

        # assign click data as object property and draw patch
        if self.mode == 'roi':
            # parse click data
            self.roi = self.points.copy()
            x0 = min(self.points[:, 0])
            y0 = min(self.points[:, 1])
            dx = np.absolute(self.points[1, 0] - self.points[0, 0])  # width
            dy = np.absolute(self.points[1, 1] - self.points[0, 1])  # height

            # remove old patch
            if self.rect is not None:
                self.rect.remove()

            # add new patch
            self.rect = Rectangle((x0, y0), dx, dy, lw=2, ec=[0, 1, 0.5],
                                  fill=False)
            self.ax_frame.add_patch(self.rect)

        elif self.mode == 'axis':
            # parse click data
            self.axis = self.points.copy()
            x0 = self.points[0, 0]
            y0 = self.points[0, 1]
            dx = self.points[1, 0] - self.points[0, 0]
            dy = self.points[1, 1] - self.points[0, 1]

            # remove old patch
            if self.line is not None:
                self.line.remove()

            # add new patch
            # arrow patch used in lieu of a line patch (arrow head width set to zero)
            self.line = Arrow(x0, y0, dx, dy, width=0, lw=2, color=[1, 0, 1])
            self.ax_frame.add_patch(self.line)

        # update figure
        self.fig.canvas.draw()

    def on_key(self, event):
        """Handle keys: 'd' switches drawing mode, 's' selects a reference
        frame, left/right change frame, 'c' confirms, 'escape' quits."""
        # change drawing mode
        if event.key == 'd':
            if self.mode == 'roi':
                self.mode = 'axis'
            elif self.mode == 'axis':
                self.mode = 'roi'
            print(" Drawing: " + self.mode)

        # select current frame as open-eye reference
        elif event.key == 's':
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            newframe = int(np.round(self.slider.val))
            print(" Frame %d selected as reference" % newframe)
            self.ref_stack.append(self.stack[newframe, :, :])
            self.ref_frame_inds.append(newframe)

        # change frame
        elif event.key == 'right':  # move forward one frame
            if np.round(self.slider.val) < self.nframes - 1:
                self.slider.set_val(self.slider.val + 1)
        elif event.key == 'left':  # move back one frame
            if np.round(self.slider.val) > 0:
                self.slider.set_val(self.slider.val - 1)

        # confirm or exit
        elif event.key == 'c':  # confirm and continue
            # check that all annotations have been made
            if self.roi is None:
                print(" Please define an ROI")
            elif self.axis is None:
                print(" Please define the eye's axis")
            elif len(self.ref_stack) == 0:
                print(" Please select at least one open-eye reference frame")
            else:
                print("ROI and axis confirmed")
                self.wrap_up()
                plt.close('all')
        elif event.key == 'escape':  # quit eye tracking
            print("Eye tracking aborted")
            plt.close('all')
            sys.exit()

    def update_frame(self, val):
        """Show the frame selected by the slider.

        The existing image is updated in place (previously a new image was
        stacked on the axes at every slider move). ``autoscale`` rescales
        the grey levels to the new frame's range, as a fresh ``imshow``
        call did.
        """
        newframe = int(np.round(self.slider.val))
        self.frame_img.set_data(self.stack[newframe, :, :])
        self.frame_img.autoscale()
        self.fig.canvas.draw_idle()

    def wrap_up(self):
        """Perform cropping and reference frame computations before closing."""
        # crop stacks based on selected ROI
        print(" Cropping frames...")
        self.stack = img.crop(self.stack, self.roi)
        self.ref_stack = img.crop(np.array(self.ref_stack), self.roi)

        # take median if multiple frames are given as reference
        if len(self.ref_stack.shape) == 2:
            self.ref_frame = self.ref_stack
        elif len(self.ref_stack.shape) == 3:
            self.ref_frame = np.median(self.ref_stack, axis=0)

        # compute mean pixel-wise differences from the reference frame
        print(" Comparing frames to reference...")
        self.diffs = img.mean_diffs(self.stack, self.ref_frame)

        # store indices of frames used as reference
        self.ref_frames = self.frame_lut[self.ref_frame_inds]
def pager(L,shape,vmake,vpaint,offset=0,save=None,savedefaults=dict(dest='pager-sav',format='svg'),bstyle={},**_ka):
    r"""Display a list of objects page by page on a grid of cells.

    :param L: a list of arbitrary objects other than :const:`None`
    :param shape: a pair (number of rows, number of columns), or a single number if they are equal
    :param vmake: a function to initialise the display (see below)
    :param vpaint: a function to instantiate the display for a given page (see below)
    :param offset: index in *L* of an element on the page shown first
    :param save: when not :const:`None`, do not display; save every page instead, using *save* to override *savedefaults*
    :param savedefaults: used as default keyword arguments of method :meth:`savefig` when saving pages; key ``dest`` is the target directory
    :type savedefaults: :class:`dict`
    :param bstyle: keyword arguments for the fallback :class:`Menu` used when the toolbar cannot take buttons

    This function first creates a :class:`Cell` instance with all the remaining arguments,
    then splits it into a grid of sub-cells according to *shape*, then displays *L* page per page on the grid.
    Each page displays a slice of *L* of length equal to the product of the components of *shape*
    (or less, for the final page).

    The toolbar is enriched with page navigation buttons. A save button also saves the whole
    collection of pages in a given directory (beware: may be long). The directory is created
    if missing, and emptied before saving.

    Function *vmake* takes as input a :class:`Cell` instance and instantiates it as needed.
    It can store information (e.g. about the specific role of each artist created in the cell),
    if needed, by simply setting attributes in the cell. This is called once for each cell
    at the beginning of the display.

    Function *vpaint* takes as input a cell and an element of *L* or None, and displays
    that element in the cell (or resets the cell to indicate a missing value), possibly using
    the artists created by *vmake* and stored in the cell. This is called once at each page
    display and for each cell.

    Unfortunately, matplotlib toolbars are not standardised: they depend on the backend and
    may not support adding buttons; a :class:`Menu` is used instead in that case.
    """
    from numpy import ceil, rint, clip
    from matplotlib.widgets import Slider
    from matplotlib.pyplot import close
    from pathlib import Path
    from shutil import rmtree

    def gen(L):
        # the elements of L, then an endless supply of None for blank cells
        yield from L
        while True:
            yield None

    def genc(cell):
        # the sub-cells of the grid in row-major order
        Nr, Nc = cell.shape
        yield from (cell[row, col] for row in range(Nr) for col in range(Nc))

    def paintp(cell, p, draw=True):
        # paint page p (0-based); zip stops at the last cell, gen pads with None
        for c, x in zip(genc(cell), gen(L[p*cellpp:])):
            vpaint(c, x)
        if draw:
            cell.figure.canvas.draw()

    def toggle_ctrl():
        ctrl.ax.set_visible(not ctrl.ax.get_visible())
        cell.figure.canvas.draw()

    def save_all():
        # re-run pager in save mode on a fresh figure of the same size
        ka = _ka.copy()
        ka.update(fig=None, figsize=(cell.figure.get_figwidth(), cell.figure.get_figheight()))
        pager(L, shape, vmake, vpaint, save={}, savedefaults=savedefaults, **ka)

    cell = Cell.create(**_ka)
    Nr, Nc = (shape, shape) if isinstance(shape, int) else shape
    cell.make_grid(Nr, Nc)
    for c in genc(cell):
        vmake(c)
    cellpp = Nr*Nc
    npage = int(ceil(len(L)/cellpp))

    if save is None:
        actions = [
            ('<<', (lambda: ctrl.set_val(clip(ctrl.val-1, 1, npage)))),
            ('>>', (lambda: ctrl.set_val(clip(ctrl.val+1, 1, npage)))),
            ('toggle-ctrl', toggle_ctrl),
            ('save-all', save_all),
        ]
        menu = cell.figure.canvas.toolbar
        if not hasattr(menu, 'addAction'):
            # backend toolbar cannot take extra buttons (or there is no toolbar)
            menu = Menu(cell.figure, **bstyle)
        for a, f in actions:
            menu.addAction(a, f)
        ctrl = Slider(cell.figure.add_axes((0.1, 0., .8, .03), visible=False, zorder=1), 'page', .5, npage+.5,
                      valinit=0, valfmt='%.0f/{}'.format(npage), closedmin=False, closedmax=False)
        # slider value is the 1-based page number; paintp wants 0-based
        ctrl.on_changed(lambda p: paintp(cell, int(rint(p))-1))
        ctrl.set_val(1+offset/cellpp)
    else:
        s = savedefaults.copy()
        s.update(save)
        pth = Path(s.pop('dest'))
        try:
            # create the destination on first use, then empty it (pages from a
            # previous save are plain files, so they need unlink, not rmtree)
            pth.mkdir(parents=True, exist_ok=True)
            for f in list(pth.iterdir()):
                if f.is_dir() and not f.is_symlink():
                    rmtree(str(f))
                else:
                    f.unlink()
        except Exception as e:
            logger.warning('Error on save directory %s: %s', pth, e)
            raise
        try:
            for p in range(npage):
                paintp(cell, p, False)
                target = (pth/'p{:02d}'.format(p)).with_suffix('.'+s['format'])
                cell.figure.savefig(str(target), **s)
        except Exception as e:
            # best effort: report the failing page, still close the figure
            logger.warning('Error saving page %s: %s', p, e)
        finally:
            close(cell.figure.number)
class InteractiveView:
    """Frame-by-frame viewer for an image stack with detected peaks overlaid.

    Parameters
    ----------
    img : numpy.ndarray
        Stack of 2D frames. When it has more than two dimensions, all leading
        dimensions are flattened into a single frame axis (the caller's array
        is not modified).
    peaks : pandas.DataFrame or None
        Peak table with 'x', 'y' and 'w' columns, looked up with
        ``peaks.loc[frame]`` using the 0-based frame number (presumably the
        first level of a MultiIndex -- confirm against the producer).
    """

    def __init__(self, img, peaks):
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Button, Slider

        self.fig = plt.figure(figsize=(15, 10))
        self.ax = self.fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.15)

        self.i = 1
        self.peaks = peaks
        self.img = img
        if self.img.ndim > 2:
            # reshape rather than assigning .shape: leaves the caller's array
            # untouched and also works for non-contiguous input
            n_frames = reduce(operator.mul, self.img.shape[:-2])
            self.img = self.img.reshape([n_frames] + list(self.img.shape[-2:]))

        self.text = plt.figtext(0.06, 0.05, '', transform=self.fig.transFigure)

        w = 0.1
        h = 0.050
        y_pos = 0.04

        self.axprev = plt.axes([0.7, y_pos, w, h])
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        self.axnext = plt.axes([0.8, y_pos, w, h])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)

        self.axslide = plt.axes([0.4, 0.04, 0.25, 0.03])
        self.slider = Slider(self.axslide, 'Frames', 1, int(len(self.img)), valinit=1)
        self.slider.on_changed(self.slide)

        self.artists = []
        self.im = None

        self.draw(0)

    def _peaks_for_frame(self, frame):
        """Return the peak rows for 0-based *frame* as a DataFrame, or None if there are none."""
        if self.peaks is None:
            return None
        try:
            rows = self.peaks.loc[frame]
        except KeyError:
            return None
        if not hasattr(rows, 'iterrows'):
            # a unique label on a flat index yields a Series: promote to a 1-row frame
            rows = rows.to_frame().T
        return rows

    def draw(self, i):
        """Display 1-based frame *i* with its peaks; out-of-range *i* wraps to frame 1."""
        from matplotlib.patches import Circle
        import matplotlib.pyplot as plt

        if i > len(self.img) or i <= 0:
            i = 1
        self.i = i

        if self.im:
            self.im.remove()
        for art in self.artists:
            art.remove()
        self.artists = []

        n_peaks = 0
        current_peaks = self._peaks_for_frame(i - 1)
        if current_peaks is not None:
            for _, data in current_peaks.iterrows():
                x = data['x']
                y = data['y']
                w = data['w']
                # plotted as (y, x), presumably because x is the row index -- confirm
                outline = Circle((y, x), w, alpha=0.4, color='red')
                self.artists.append(self.ax.add_patch(outline))
                self.artists.append(self.ax.scatter(y, x, marker='+'))
            n_peaks = current_peaks.shape[0]

        self.text.set_text("Frame %i/%i | Peaks number %i" % (i, len(self.img), n_peaks))
        self.im = self.ax.imshow(self.img[i-1], interpolation='none', cmap='gray')
        plt.draw()

    def next(self, event=None):
        """Advance one frame (wraps to the first frame past the end)."""
        self.slider.set_val(self.i + 1)

    def prev(self, event=None):
        """Go back one frame (wraps to the first frame below 1)."""
        self.slider.set_val(self.i - 1)

    def slide(self, event):
        """Slider callback: *event* is the new slider value."""
        self.draw(int(event))

    def show(self):
        self.fig.show()
class SelectFromCollection(object):
    """Interactive RLS classifier interface for image segmentation

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The Figure object on which the interface is drawn.
    mmc : rlscore.learner.interactive_rls_classifier.InteractiveRlsClassifier
        Interactive RLS classifier object
    img : numpy.array
        Array consisting of image data
    collection : numpy.array, shape = [n_pixels, 2]
        array consisting of the (x,y) coordinates of all usable pixels in the image
    windowsize : int
        Determines the size of a window around grid points (2 * windowsize + 1)
    """

    def __init__(self, fig, mmc, img, collection, windowsize = 0):
        # Initialize the main axis
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        ax.set_yticklabels([])
        ax.yaxis.set_tick_params(size=0)
        ax.set_xticklabels([])
        ax.xaxis.set_tick_params(size=0)
        self.imdata = ax.imshow(img)

        # Initialize LassoSelector on the main axis
        self.lasso = LassoSelector(ax, onselect=self.onselect)
        self.lasso.connect_event('key_press_event', self.onkeypressed)
        self.lasso.line.set_visible(False)

        self.mmc = mmc
        self.img = img
        self.img_orig = img.copy()
        self.collection = collection
        self.selectedset = set()
        self.lockedset = set()
        self.windowsize = windowsize

        # Initialize the fraction slider (guarded against an empty working set)
        self.slider_axis = fig.add_axes([0.2, 0.06, 0.6, 0.02])
        self.in_selection_slider = Slider(self.slider_axis, 'Fraction slider', 0., 1,
                                          valinit=self._class_one_fraction())

        def sliderupdate(val):
            # translate the requested class-1 fraction into a number of
            # working-set points to relabel, towards class 1 or class 0
            val = int(val * len(self.mmc.working_set))
            nonzeroc = len(np.nonzero(self.mmc.classvec_ws)[0])
            if val > nonzeroc:
                claims = val - nonzeroc
                newclazz = 1
            elif val < nonzeroc:
                claims = nonzeroc - val
                newclazz = 0
            else:
                return
            print('Claimed', claims, 'points for class', newclazz)
            self.claims = claims
            self.mmc.claim_n_points(claims, newclazz)
            self.redrawall()
        self.in_selection_slider.on_changed(sliderupdate)

        # Initialize the display for the RLS objective function
        self.objfun_display_axis = fig.add_axes([0.1, 0.96, 0.8, 0.02])
        self.objfun_display_axis.imshow(self.mmc.compute_steepness_vector()[np.newaxis, :], cmap=plt.get_cmap("Oranges"))
        self.objfun_display_axis.set_aspect('auto')
        self.objfun_display_axis.set_yticklabels([])
        self.objfun_display_axis.yaxis.set_tick_params(size=0)

    def _class_one_fraction(self):
        """Fraction of the current working set labelled class 1 (0 for an empty working set)."""
        n_ws = len(self.mmc.working_set)
        if n_ws == 0:
            return 0
        return len(np.nonzero(self.mmc.classvec_ws)[0]) / n_ws

    def onselect(self, verts):
        """Lasso callback: the enclosed, unlocked points become the new working set."""
        self.path = Path(verts)
        self.selectedset = set(np.nonzero(self.path.contains_points(self.collection))[0])
        print('Selected ' + str(len(self.selectedset)) + ' points')
        newws = list(self.selectedset - self.lockedset)
        self.mmc.new_working_set(newws)
        self.redrawall()

    def onkeypressed(self, event):
        """Keyboard commands, followed by a full redraw.

        '1'/'0'  assign every point of the working set to class 1 / 0
        'a'      make all unlocked points the working set
        'c'      run cyclic descent on the working set
        'l'/'u'  lock / unlock the labels of the selected points
        'p'      print the AUC of predictions on ``Xmat``
        """
        print('You pressed', event.key)
        if event.key == '1':
            print('Assigned all selected points to class 1')
            newclazz = 1
            self.mmc.claim_all_points_in_working_set(newclazz)
        elif event.key == '0':
            print('Assigned all selected points to class 0')
            newclazz = 0
            self.mmc.claim_all_points_in_working_set(newclazz)
        elif event.key == 'a':
            print('Selected all points')
            newws = list(set(range(len(self.collection))) - self.lockedset)
            self.mmc.new_working_set(newws)
            self.lasso.line.set_visible(False)
        elif event.key == 'c':
            changecount = self.mmc.cyclic_descent_in_working_set()
            print('Performed ', changecount, 'cyclic descent steps')
        elif event.key == 'l':
            print('Locked the class labels of selected points')
            self.lockedset = self.lockedset | self.selectedset
            newws = list(self.selectedset - self.lockedset)
            self.mmc.new_working_set(newws)
        elif event.key == 'u':
            print('Unlocked the selected points')
            self.lockedset = self.lockedset - self.selectedset
            newws = list(self.selectedset - self.lockedset)
            self.mmc.new_working_set(newws)
        elif event.key == 'p':
            print('Compute predictions and AUC on data')
            # NOTE(review): Xmat and auc are module-level names, not attributes
            preds = self.mmc.predict(Xmat)
            print(auc(self.mmc.Y[:, 0], preds[:, 0]))
        self.redrawall()

    def redrawall(self):
        """Repaint class labels, the fraction slider and the objective display."""
        # Color all class one labeled pixels red
        oneclazz = np.nonzero(self.mmc.classvec)[0]
        col_row = self.collection[oneclazz]
        rowcs, colcs = col_row[:, 1], col_row[:, 0]
        red = np.array([255, 0, 0])
        for i in range(-self.windowsize, self.windowsize + 1):
            for j in range(-self.windowsize, self.windowsize + 1):
                self.img[rowcs+i, colcs+j, :] = red

        # Return the original color of the class zero labeled pixels
        zeroclazz = np.nonzero(self.mmc.classvec - 1)[0]
        col_row = self.collection[zeroclazz]
        rowcs, colcs = col_row[:, 1], col_row[:, 0]
        for i in range(-self.windowsize, self.windowsize + 1):
            for j in range(-self.windowsize, self.windowsize + 1):
                self.img[rowcs+i, colcs+j, :] = self.img_orig[rowcs+i, colcs+j, :]
        self.imdata.set_data(self.img)

        # Update the slider position according to labeling of the current working set
        self.in_selection_slider.set_val(self._class_one_fraction())

        # Update the RLS objective function display
        self.objfun_display_axis.imshow(self.mmc.compute_steepness_vector()[np.newaxis, :], cmap=plt.get_cmap("Oranges"))
        self.objfun_display_axis.set_aspect('auto')

        # Final stuff
        self.lasso.canvas.draw_idle()
        plt.draw()
        print_instructions()
class FindCenters(pb.PointBrowser):
    """
    Semi-automatic fitting of bright points in TEM image

    dragging ... (opt) if True, dragging is allowed for sliders
    """

    def __init__(self, image, dragging=True, **kwargs):
        # init PointBrowser
        super(FindCenters, self).__init__(image, [[None, None]], **kwargs)
        self.axis.set_title('FitHexagonCenters: %s' % self.imginfo['desc'])
        self.fig.subplots_adjust(bottom=0.2)  # space for sliders

        # add slider for neighborhood size
        self.nbhd_size = 5  # neighborhood size
        axNbhd = self.fig.add_axes([0.2, 0.05, 0.1, 0.04])
        self.sNbhd = Slider(axNbhd, 'neighbors ', 2, 20, valinit=self.nbhd_size,
                            valfmt=' (%d)', dragging=dragging)
        self.sNbhd.on_changed(self.ChangeNeighborhood)

        # add slider for number of points; start with "all" local maxima,
        # clipped to the actual count in ChangeNeighborhood
        self.num_points = 1e6
        axNum = self.fig.add_axes([0.45, 0.05, 0.3, 0.04])
        self.sNum = Slider(axNum, 'points ', 0, 100, valinit=self.num_points,
                           valfmt=' (%d)', dragging=dragging)
        self.sNum.on_changed(self.ChangeMaxPoints)

        # add buttons
        axRefine = self.fig.add_axes([0.85, 0.7, 0.1, 0.04])
        self.bRefine = Button(axRefine, 'Refine')
        self.bRefine.on_clicked(self.RefineCenters)

        # initial calculation of local maxima
        self.ChangeNeighborhood(self.nbhd_size)

    def ChangeNeighborhood(self, val):
        """Recompute local maxima for neighborhood size *val* and rescale the points slider."""
        self.nbhd_size = int(val)
        maxima, _ = self.find_local_maxima(self.image, self.nbhd_size)
        Nmax = np.sum(maxima)  # number of local maxima
        self.sNum.valmax = Nmax
        self.sNum.ax.set_xlim((self.sNum.valmin, Nmax))
        # set_val fires ChangeMaxPoints, which refreshes the displayed points
        self.sNum.set_val(min(self.num_points, Nmax))

    def ChangeMaxPoints(self, val):
        """Keep the *val* strongest local maxima and redraw them."""
        self.num_points = int(val)
        self.points = self.refine_local_maxima(self.num_points)
        self._update_points()

    def RefineCenters(self, event):
        " refine positions by fitting 2D Gaussian in neighborhood of local max "
        from scipy.optimize import leastsq
        from sys import stdout

        NN = self.nbhd_size
        Nx, Ny = self.image.shape
        dx, dy = np.mgrid[-NN:NN + 1, -NN:NN + 1]

        def gauss(x0, y0, A, B, fwhm):
            # Gaussian on the (2NN+1)^2 neighborhood grid, centred at (x0, y0)
            return A * np.exp(-((dx - x0)**2 + (dy - y0)**2) / fwhm**2) + B

        # refine each point separately
        self.points = self.points.astype(float)  # allow subpixel precision
        npts = len(self.points)
        for ip in range(npts):
            # integer pixel nearest to the current estimate (float indices are rejected by numpy)
            x, y = [int(v) for v in np.round(self.points[ip])]

            # get neighborhood (skip points too close to the border)
            xmin, xmax = x - NN, x + NN
            ymin, ymax = y - NN, y + NN
            if xmin < 0 or ymin < 0 or xmax >= Nx or ymax >= Ny:
                continue
            nbhd = self.image[xmin:xmax + 1, ymin:ymax + 1]
            assert nbhd.shape == (2 * NN + 1, 2 * NN + 1)

            p0 = (0., 0., self.image[x, y], 0., NN / 2.0)  # initial guess
            p, ierr = leastsq(lambda q: (nbhd - gauss(*q)).ravel(), p0)
            self.points[ip] = (x + p[0], y + p[1])  # corrected position of point

            # DEBUG: report / plot fits for each point
            if self.verbosity > 0:
                stdout.write("Refining Points... %d %%\r" % (100 * (ip + 1) // npts))
            if self.verbosity > 3:
                print("IN:  %s" % (p0,))
                print("OUT:  %s" % (p,))
            if self.verbosity > 10:
                plt.figure()
                ix = nbhd.shape[0] // 2
                plt.plot(dy[ix], nbhd[ix], 'k', label='image')
                plt.plot(dy[ix], gauss(*p0)[ix], 'g', label='first guess')
                plt.plot(dy[ix], gauss(*p)[ix], 'r', label='final fit')
                plt.plot(dx[:, ix], nbhd[:, ix], 'k--')
                plt.plot(dx[:, ix], gauss(*p0)[:, ix], 'g--')
                plt.plot(dx[:, ix], gauss(*p)[:, ix], 'r--')
                plt.legend()
                plt.show()
        if self.verbosity > 0:
            print("Refining Points. Finished.")
            stdout.flush()
        self._update_points()

    def find_local_maxima(self, data, neighborhood_size):
        """
        find local maxima within neighborhood
        idea from http://stackoverflow.com/questions/9111711
        (get-coordinates-of-local-maxima-in-2d-array-above-certain-value)

        Returns (maxima, diff): a boolean mask with one True pixel per local
        maximum (plateaus collapsed to their centre) and, per pixel, the
        difference between neighborhood maximum and minimum. Both are also
        cached for refine_local_maxima.
        """
        # find local maxima in image (width specified by neighborhood_size)
        data_max = filters.maximum_filter(data, neighborhood_size)
        maxima = (data == data_max)
        assert np.sum(maxima) > 0  # we should always find local maxima

        # remove connected pixels (plateaus): keep only the centre of each
        labeled, num_objects = ndimage.label(maxima)
        slices = ndimage.find_objects(labeled)
        maxima = np.zeros_like(maxima)  # bool; in-place `*= 0` fails on bool arrays
        for sx, sy in slices:
            maxima[(sx.start + sx.stop - 1) // 2, (sy.start + sy.stop - 1) // 2] = True

        # calculate difference between local maxima and lowest
        # pixel in neighborhood (will be used in refine_local_maxima)
        data_min = filters.minimum_filter(data, neighborhood_size)
        diff = data_max - data_min
        self._maxima = maxima
        self._diff = diff

        return maxima, diff

    def refine_local_maxima(self, N):
        """Select the N local maxima with the largest neighborhood contrast.

        Returns an (n, 2) integer array of (x, y) positions. N <= 0 selects
        nothing; ties at the threshold may keep slightly more than N points.
        """
        maxima = self._maxima
        diff = self._diff

        if N <= 0:
            maxima = np.zeros_like(maxima)
        elif np.sum(maxima) > N:
            # threshold = N-th largest contrast among the local maxima; `>=`
            # keeps that N-th point too (a strict `>` kept only N-1)
            thresh = np.sort(diff[maxima].ravel())[-N]
            maxima = np.logical_and(maxima, diff >= thresh)

        # TODO: refine fit by local 2D Gauss-Fit

        # return list of x,y positions of local maxima
        return np.asarray(np.where(maxima)).T
class BasicDendrogramViewer(object):
    """Interactive viewer linking a 2-D image or 3-D cube with its dendrogram.

    The left panel shows the data (one slice at a time for cubes, chosen with
    a slider) with vmin/vmax sliders; the right panel shows the dendrogram.
    Clicking a pixel or picking a dendrogram branch selects a structure,
    which is highlighted in red in both panels.
    """

    def __init__(self, dendrogram):
        if dendrogram.data.ndim not in [2, 3]:
            raise ValueError("Only 2- and 3-dimensional arrays are supported at this time")

        self.array = dendrogram.data
        self.dendrogram = dendrogram
        self.plotter = DendrogramPlotter(dendrogram)
        self.plotter.sort(reverse=True)

        # Get the lines as individual elements, and the mapping from line to structure
        self.lines = self.plotter.get_lines()

        # Define the currently selected subtree
        self.selected = None
        self.selected_lines = None
        self.selected_contour = None

        # Initiate plot
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Slider
        self.fig = plt.figure(figsize=(14, 8))
        self.ax1 = self.fig.add_axes([0.1, 0.1, 0.4, 0.7])

        # colour limits from the finite values only (ignores NaN and +/-inf)
        finite = self.array[np.isfinite(self.array)]
        self._clim = (np.min(finite), np.max(finite))

        if self.array.ndim == 2:
            self.slice = None
            shown = self.array
        else:
            last = self.array.shape[0] - 1
            self.slice = min(int(round(self.array.shape[0] / 2.)), last)
            shown = self.array[self.slice, :, :]
        self.image = self.ax1.imshow(shown, origin='lower', interpolation='nearest',
                                     vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray)

        if self.array.ndim == 3:
            self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
            self.slice_slider_ax.set_xticklabels("")
            self.slice_slider_ax.set_yticklabels("")
            # valid slice indices are 0 .. shape[0]-1
            self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.array.shape[0] - 1,
                                       valinit=self.slice, valfmt="%i")
            self.slice_slider.on_changed(self.update_slice)
            self.slice_slider.drawon = False

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1], valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1], valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.ax2 = self.fig.add_axes([0.6, 0.3, 0.35, 0.4])
        self.ax2.add_collection(self.lines)

        self.selected_label = self.fig.text(0.6, 0.75, "No structure selected", fontsize=18)
        x = [p.vertices[:, 0] for p in self.lines.get_paths()]
        y = [p.vertices[:, 1] for p in self.lines.get_paths()]
        xmin = np.min(x)
        xmax = np.max(x)
        ymin = np.min(y)
        ymax = np.max(y)
        self.lines.set_picker(2.)
        dx = xmax - xmin
        self.ax2.set_xlim(xmin - dx * 0.1, xmax + dx * 0.1)
        self.ax2.set_ylim(ymin * 0.5, ymax * 2.0)
        self.ax2.set_yscale('log')

        self.fig.canvas.mpl_connect('pick_event', self.line_picker)
        self.fig.canvas.mpl_connect('button_press_event', self.select_from_map)

        plt.show()

    @staticmethod
    def _toolbar_idle(event):
        """True when no toolbar tool (zoom, pan, ...) is active; also True without a toolbar."""
        toolbar = event.canvas.toolbar
        return toolbar is None or toolbar.mode == ''

    def update_slice(self, pos=None):
        """Show slice *pos* (rounded, clamped to the cube) and redraw the selection contour."""
        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            last = self.array.shape[0] - 1
            self.slice = min(max(int(round(pos)), 0), last)
            self.image.set_array(self.array[self.slice, :, :])

        self.remove_contour()
        self.update_contour()

        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """Set the lower colour limit, never above the current upper limit."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """Set the upper colour limit, never below the current lower limit."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def select_from_map(self, event):
        """Select the structure under a click on the image panel."""
        if not self._toolbar_idle(event):
            return
        if event.inaxes is not self.ax1:
            return

        # Find pixel co-ordinates of click; ignore clicks on the image border
        # that round to an index outside the array
        ix = int(round(event.xdata))
        iy = int(round(event.ydata))
        ny, nx = self.array.shape[-2:]
        if not (0 <= ix < nx and 0 <= iy < ny):
            return

        if self.array.ndim == 2:
            indices = (iy, ix)
        else:
            indices = (self.slice, iy, ix)

        structure = self.dendrogram.node_at(indices)
        self.select(structure)

        event.canvas.draw()

    def line_picker(self, event):
        """Select the picked dendrogram branch with the highest peak; jump to its peak slice for cubes."""
        if not self._toolbar_idle(event):
            return

        # event.ind gives the indices of the paths that have been selected;
        # among them take the one whose subtree peak is highest
        peaks = [event.artist.structures[i].get_peak(subtree=True)[1] for i in event.ind]
        ind = event.ind[np.argmax(peaks)]
        structure = event.artist.structures[ind]

        # If 3-d, select the slice containing the peak
        if self.array.ndim == 3:
            peak_index = structure.get_peak(subtree=True)
            self.slice_slider.set_val(peak_index[0][0])

        self.select(structure)

        event.canvas.draw()

    def select(self, structure):
        """Highlight *structure* (or clear the selection when it is None)."""
        # Remove previously selected collection (Artist.remove works across
        # matplotlib versions; ax.collections.remove no longer does)
        if self.selected_lines is not None:
            self.selected_lines.remove()
            self.selected_lines = None
            self.remove_contour()

        if structure is None:
            self.selected_label.set_text("No structure selected")
            self.fig.canvas.draw()
            return

        self.selected = structure
        self.selected_label.set_text("Selected structure: {0}".format(structure.idx))

        # Get collection for this substructure
        self.selected_lines = self.plotter.get_lines(structure=structure)
        self.selected_lines.set_color('red')
        self.selected_lines.set_linewidth(2)
        self.selected_lines.set_alpha(0.5)

        self.ax2.add_collection(self.selected_lines)

        self.update_contour()

    def remove_contour(self):
        """Remove the selection contour from the image panel, if any."""
        if self.selected_contour is None:
            return
        if hasattr(self.selected_contour, 'remove'):
            # matplotlib >= 3.8: ContourSet is itself a removable Artist
            self.selected_contour.remove()
        else:
            for collection in self.selected_contour.collections:
                collection.remove()
        self.selected_contour = None

    def update_contour(self):
        """Draw the outline of the selected structure on the displayed slice."""
        if self.selected is not None:
            mask = self.selected.get_mask(self.array.shape, subtree=True)
            if self.array.ndim == 3:
                mask = mask[self.slice, :, :]
            self.selected_contour = self.ax1.contour(mask, colors='red', linewidths=2, levels=[0.5], alpha=0.5)
class viscm_editor(object):
    """Interactive editor for a colormap defined by a Bezier curve in (a', b') plus a J' range.

    Sliders set the minimum/maximum lightness J'; control points of the
    Bezier curve are edited with the mouse. Buttons show a 3-D gamut view,
    save the colormap as a Python module, and open the property viewer.
    """

    def __init__(self, min_Jp=15, max_Jp=95, xp=None, yp=None):
        from .bezierbuilder import BezierModel, BezierBuilder

        axes = _viscm_editor_axes()

        ax_btn_wireframe = plt.axes([0.7, 0.15, 0.1, 0.025])
        self.btn_wireframe = Button(ax_btn_wireframe, 'Show 3D gamut')
        self.btn_wireframe.on_clicked(self.plot_3d_gamut)

        ax_btn_save = plt.axes([0.81, 0.15, 0.1, 0.025])
        self.btn_save = Button(ax_btn_save, 'Save colormap')
        self.btn_save.on_clicked(self.save_colormap)

        ax_btn_props = plt.axes([0.81, 0.1, 0.1, 0.025])
        self.btn_props = Button(ax_btn_props, 'Properties')
        self.btn_props.on_clicked(self.show_viscm)
        self.prop_windows = []

        # `facecolor` replaces the `axisbg` keyword removed from matplotlib
        axcolor = 'None'
        ax_jp_min = plt.axes([0.1, 0.1, 0.5, 0.03], facecolor=axcolor)
        ax_jp_min.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap='gray')
        ax_jp_min.set_xlim(0, 100)

        ax_jp_max = plt.axes([0.1, 0.15, 0.5, 0.03], facecolor=axcolor)
        ax_jp_max.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap='gray')

        self.jp_min_slider = Slider(ax_jp_min, r"$J'_\mathrm{min}$", 0, 100, valinit=min_Jp)
        self.jp_max_slider = Slider(ax_jp_max, r"$J'_\mathrm{max}$", 0, 100, valinit=max_Jp)

        self.jp_min_slider.on_changed(self._jp_update)
        self.jp_max_slider.on_changed(self._jp_update)

        # This is my favorite set of control points so far (just from playing
        # around with things):
        #   min_Jp = 15
        #   max_Jp = 95
        #   xp =
        #     [-4, 27.041103603603631, 84.311067635550557, 12.567076579094476, -9.6]
        #   yp =
        #     [-34, -41.447876447876524, 36.28563443264386, 25.357741755170423, 41]
        # -- njs, 2015-04-05

        if xp is None:
            xp = [-4, 38.289146128951984, 52.1923711457504,
                  39.050944362271053, 18.60872492130315, -9.6]

        if yp is None:
            yp = [-34, -34.34528254916614, -21.594701710471412,
                  31.701084689194829, 29.510846891948262, 41]

        self.bezier_model = BezierModel(xp, yp)
        self.cmap_model = BezierCMapModel(self.bezier_model,
                                          self.jp_min_slider.val,
                                          self.jp_max_slider.val)
        self.highlight_point_model = HighlightPointModel(self.cmap_model, 0.5)

        self.bezier_builder = BezierBuilder(axes['bezier'], self.bezier_model)
        self.bezier_gamut_viewer = GamutViewer2D(axes['bezier'], self.highlight_point_model)
        self.bezier_highlight_point_view = HighlightPoint2DView(axes['bezier'], self.highlight_point_model)

        axes['bezier'].set_xlim(-100, 100)
        axes['bezier'].set_ylim(-100, 100)

        self.cmap_view = CMapView(axes['cm'], self.cmap_model)
        self.cmap_highlighter = HighlightPointBuilder(axes['cm'], self.highlight_point_model)

        print("Click sliders at bottom to change min/max lightness")
        print("Click on colorbar to adjust gamut view")
        print("Click-drag to move control points, ")
        print("  shift-click to add, control-click to delete")

    def plot_3d_gamut(self, event):
        """Button callback: open a 3-D wireframe view of the colormap."""
        fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
        self.wireframe_view = WireframeView(ax, self.cmap_model, self.highlight_point_model)
        plt.show()

    def save_colormap(self, event):
        """Button callback: write the colormap as a runnable module ``new_cm.py`` in the temp directory."""
        import os
        import tempfile
        import textwrap

        template = textwrap.dedent('''
        from matplotlib.colors import LinearSegmentedColormap
        from numpy import nan, inf

        # Used to reconstruct the colormap in pycam02ucs.cm.viscm
        parameters = {{'xp': {xp},
                      'yp': {yp},
                      'min_Jp': {min_Jp},
                      'max_Jp': {max_Jp}}}

        cm_data = {array_list}

        test_cm = LinearSegmentedColormap.from_list(__file__, cm_data)


        if __name__ == "__main__":
            import matplotlib.pyplot as plt
            import numpy as np

            try:
                from pycam02ucs.cm.viscm import viscm
                viscm(test_cm)
            except ImportError:
                print("pycam02ucs not found, falling back on simple display")
                plt.imshow(np.linspace(0, 100, 256)[None, :], aspect='auto',
                           cmap=test_cm)
            plt.show()
        ''')

        rgb, _ = self.cmap_model.get_sRGB(num=256)
        array_list = np.array_repr(rgb, max_line_width=78)
        array_list = array_list.replace('array(', '')[:-1]
        xp, yp = self.cmap_model.bezier_model.get_control_points()
        data = dict(array_list=array_list,
                    xp=xp,
                    yp=yp,
                    min_Jp=self.cmap_model.min_Jp,
                    max_Jp=self.cmap_model.max_Jp)

        # portable temp location ("/tmp" does not exist on Windows)
        target = os.path.join(tempfile.gettempdir(), 'new_cm.py')
        with open(target, 'w') as f:
            f.write(template.format(**data))

        print("*" * 50)
        print("Saved colormap to %s" % target)
        print("*" * 50)

    def show_viscm(self, event):
        """Button callback: open the colormap property viewer for the current colormap."""
        cm = LinearSegmentedColormap.from_list(
            'test_cm',
            self.cmap_model.get_sRGB(num=256)[0])
        self.prop_windows.append(viscm(cm, name='test_cm'))
        plt.show()

    def _jp_update(self, val):
        """Slider callback: keep J'min <= J'max by swapping the sliders, then update the model."""
        jp_min = self.jp_min_slider.val
        jp_max = self.jp_max_slider.val

        smallest, largest = min(jp_min, jp_max), max(jp_min, jp_max)
        if (jp_min > smallest) or (jp_max < largest):
            self.jp_min_slider.set_val(smallest)
            self.jp_max_slider.set_val(largest)
        self.cmap_model.set_Jp_minmax(smallest, largest)
class GUI(animation.TimedAnimation):
    """
    interface for viewing a movie and its associated roi, traces, other things

    implementation is only through matplotlib, and is non-blocking

    backend affects performance somewhat dramatically. have achieved decent performance with qt5agg and tkagg
    """

    def __init__(self, mov, roi, traces, images=None, cmap=pl.cm.viridis, **kwargs):
        """
        Parameters:
            mov : 3d np array, 0'th axis is time/frames
            roi : 3d np array, one roi per item in 0'th axis, each of which is a True/False mask indicating roi (True=inside roi)
            traces : 2d np array, 0'th axis is time, 1st axis is sources
            images: dictionary of still images (default: none)

        Attributes:
            roi_kept : boolean array of length of supplied roi, indicating whether or not roi should be kept based on user input
        """
        if images is None:
            images = {}

        self.mov = mov
        # flat pixel indices of each ROI, as sets: ROIs differ in size (a ragged
        # np.array fails on modern NumPy) and evt_click needs fast membership
        self.roi_idxs = [set(np.flatnonzero(r).tolist()) for r in roi]
        self.roi_centers = np.array([np.mean(np.argwhere(r), axis=0) for r in roi])
        self.roi_orig = roi.copy()
        self.roi = pretty_roi(roi)
        self.roi_kept = np.ones(len(self.roi_idxs), dtype=bool)
        self.traces = traces
        self.images = images

        # figure setup
        self.cmap = cmap
        self.fig = pl.figure()
        NR, NC = 128, 32
        gs = gridspec.GridSpec(nrows=NR, ncols=NC)
        gs.update(wspace=0.1, hspace=0.1, left=.04, right=.96, top=.98, bottom=.02)

        # movie axes
        self.ax_contrast0 = self.fig.add_subplot(gs[0:5, 0:NC//3])
        self.ax_contrast1 = self.fig.add_subplot(gs[5:10, 0:NC//3])
        self.ax_mov = self.fig.add_subplot(gs[10:55, 0:NC//3])
        self.ax_mov.axis('off')
        self.ax_img = self.fig.add_subplot(gs[55:100, 0:NC//3])
        self.ax_img.axis('off')
        self.axs_imbuts = [self.fig.add_subplot(gs[110:128, idx*2:idx*2+2]) for idx in range(len(self.images))]

        # trace axes
        self.ax_trcs = self.fig.add_subplot(gs[0:64, NC//2:])
        self.ax_trc = self.fig.add_subplot(gs[65:85, NC//2:])
        self.ax_trc.set_xlim([0, len(self.traces)])
        self.ax_nav = self.fig.add_subplot(gs[85:90, NC//2:])
        self.ax_nav.set_xlim([0, len(self.traces)])
        self.ax_nav.axis('off')
        self.ax_rm = self.fig.add_subplot(gs[95:110, NC//2:])

        # interactivity
        self.c0, self.c1 = 0, 100
        self.sl_contrast0 = Slider(self.ax_contrast0, 'Low', 0., 255.0, valinit=self.c0, valfmt='%d')
        self.sl_contrast1 = Slider(self.ax_contrast1, 'Hi', 0., 255.0, valinit=self.c1, valfmt='%d')
        self.sl_contrast0.on_changed(self.evt_contrast)
        self.sl_contrast1.on_changed(self.evt_contrast)
        self.img_buttons = [Button(ax, k) for k, ax in zip(list(self.images.keys()), self.axs_imbuts)]
        self.but_rm = Button(self.ax_rm, 'Remove All ROIs Currently in FOV')

        # display initial things
        self.movdata = self.ax_mov.imshow(self.mov[0])
        self.movdata.set_animated(True)
        self.roidata = self.ax_mov.imshow(self.roi, cmap=self.cmap, alpha=0.5, vmin=np.nanmin(self.roi), vmax=np.nanmax(self.roi))
        self.trdata, = self.ax_trc.plot(np.zeros(len(self.traces)))
        self.navdata, = self.ax_nav.plot([-2, -2], [-1, np.max(self.traces)], 'r-')
        # still-image panel exists only when images were supplied
        self.imgdata = None
        if len(self.images):
            lab, im = list(self.images.items())[0]
            self.imgdata = self.ax_img.imshow(im)
            self.ax_img.set_ylabel(lab)
        self.plot_current_traces()

        # callbacks (lab bound as a default so each button keeps its own label)
        for ib, lab in zip(self.img_buttons, list(self.images.keys())):
            ib.on_clicked(lambda evt, lab=lab: self.evt_imbut(evt, lab))
        self.but_rm.on_clicked(self.remove_roi)
        self.fig.canvas.mpl_connect('button_press_event', self.evt_click)
        self.ax_mov.callbacks.connect('xlim_changed', self.evt_zoom)
        self.ax_mov.callbacks.connect('ylim_changed', self.evt_zoom)

        # runtime (time.clock was removed in Python 3.8)
        self._idx = -1
        self.t0 = time.perf_counter()
        self.always_draw = [self.movdata, self.roidata, self.navdata]
        self.blit_clear_axes = [self.ax_mov, self.ax_nav]

        # parent init
        animation.TimedAnimation.__init__(self, self.fig, interval=40, blit=True, **kwargs)

    @property
    def frame_seq(self):
        # advance one frame per animation step, wrapping at the end of the movie
        self._idx += 1
        if self._idx == len(self.mov):
            self._idx = 0
        self.navdata.set_xdata([self._idx, self._idx])
        yield self.mov[self._idx]

    @frame_seq.setter
    def frame_seq(self, val):
        pass

    def new_frame_seq(self):
        return self.mov

    def _init_draw(self):
        self._draw_frame(self.mov[0])
        self._drawn_artists = self.always_draw

    def _draw_frame(self, d):
        self.t0 = time.perf_counter()

        self.movdata.set_data(d)

        # blit
        self._drawn_artists = self.always_draw
        for da in self._drawn_artists:
            da.set_animated(True)

    def _blit_clear(self, artists, bg_cache):
        # restore only the axes that hold animated artists
        for ax in self.blit_clear_axes:
            if ax in bg_cache:
                self.fig.canvas.restore_region(bg_cache[ax])

    def evt_contrast(self, val):
        """Slider callback: apply the Low/Hi contrast limits, keeping Low < Hi."""
        self.c0 = self.sl_contrast0.val
        self.c1 = self.sl_contrast1.val

        if self.c0 > self.c1:
            self.c0 = self.c1 - 1
            self.sl_contrast0.set_val(self.c0)
        if self.c1 < self.c0:
            self.c1 = self.c0 + 1
            self.sl_contrast1.set_val(self.c1)

        self.movdata.set_clim(vmin=self.c0, vmax=self.c1)
        if self.imgdata is not None:
            self.imgdata.set_clim(vmin=self.c0, vmax=self.c1)

    def evt_imbut(self, evt, lab):
        """Button callback: show still image *lab*."""
        if self.imgdata is None:
            return
        self.imgdata.set_data(self.images[lab])
        self.ax_img.set_title(lab)

    def evt_click(self, evt):
        """Click on the movie selects a kept ROI; click on the nav bar jumps to that frame."""
        if not evt.inaxes:
            return

        if evt.inaxes == self.ax_mov:
            x, y = int(np.round(evt.xdata)), int(np.round(evt.ydata))
            try:
                idx = np.ravel_multi_index((y, x), self.roi.shape)
            except ValueError:  # click rounded to a pixel outside the image
                return
            hits = [i for i, ri in enumerate(self.roi_idxs) if self.roi_kept[i] and idx in ri]
            if not hits:
                return
            self.set_current_trace(hits[0])

        elif evt.inaxes == self.ax_nav:
            # clamp so frame_seq's wrap-around check still triggers
            self._idx = int(np.clip(np.round(evt.xdata), 0, len(self.mov) - 1))

    def evt_zoom(self, *args):
        self.plot_current_traces()

    def set_current_trace(self, idx):
        """Show the trace of ROI *idx* (int) in the single-trace panel."""
        # same colour assignment as plot_current_traces
        col = self.cmap(np.linspace(0, 1, len(self.roi_idxs)))[idx]
        t = self.traces[:, idx]
        self.trdata.set_ydata(t)
        self.trdata.set_color(col)
        self.ax_trc.set_ylim([t.min(), t.max()])
        self.ax_trc.set_title('ROI {}'.format(idx))
        self.ax_trc.figure.canvas.draw()

    def get_current_roi(self):
        """Boolean mask of kept ROIs whose centre lies in the current movie view."""
        croi = np.array([isin(rc, self.ax_mov) for rc in self.roi_centers])
        croi[~self.roi_kept] = False
        return croi

    def remove_roi(self, evt):
        """Button callback: drop every kept ROI whose centre is in the current view."""
        self.current_roi = self.get_current_roi()
        self.roi_kept[self.current_roi] = False

        if np.any(self.roi_kept):
            proi = pretty_roi(self.roi_orig[self.roi_kept])
            self.roidata.set_data(proi)
            self.roidata.set_clim(vmin=np.nanmin(proi), vmax=np.nanmax(proi))
        else:
            # hide rather than remove: the artist stays in always_draw, and a
            # second Artist.remove() would raise
            self.roidata.set_visible(False)

        # refresh the stacked-trace panel so removed ROIs disappear from it
        self.plot_current_traces()
        self.fig.canvas.draw_idle()

    def plot_current_traces(self):
        """Stack the traces of the ROIs in view, each shifted above the previous one."""
        self.current_roi = self.get_current_roi()
        for line in list(self.ax_trcs.get_lines()):
            line.remove()
        if not np.any(self.current_roi):
            return

        cols = self.cmap(np.linspace(0, 1, len(self.roi_idxs)))[self.current_roi]
        offset = 0
        for t, c in zip(self.traces.T[self.current_roi], cols):
            shifted = (t - t.min()) + offset
            self.ax_trcs.plot(shifted, color=c)
            # next trace starts at the top of this one (no overlap)
            offset = shifted.max()
        self.ax_trcs.set_ylim([0, offset])
def view_patches_bar(Yr, A, C, b, f, d1, d2, YrA=None, img=None):
    """View spatial and temporal components interactively.

    A slider (also driven by the left/right arrow keys) selects an index:
    ``0 .. nr-1`` shows a spatial component, its filtered/inferred traces and
    its overlay on ``img``; ``nr .. nr+nb-1`` shows a background component.

    Parameters
    ----------
    Yr : np.ndarray
        Movie in format pixels (d) x frames (T).
    A : sparse matrix
        Matrix of spatial components (d x K); converted to ``csc_matrix``
        when it is not one already.
    C : np.ndarray
        Matrix of temporal components (K x T).
    b : np.ndarray or sparse matrix
        Spatial background, indexed column-wise as ``b[:, j]`` (d x nb);
        converted to a dense array when it is not one already.
    f : np.ndarray
        Temporal background (nb x T); ``nb`` is taken from ``f.shape[0]``.
    d1, d2 : int
        Frame dimensions.
    YrA : np.ndarray, optional
        ROI filtered residual as it is given from update_temporal_components
        (K x T). If not given, then it is computed.
    img : np.ndarray, optional
        Background image for contour plotting (d1 x d2). Default is the mean
        image of all spatial components.
    """
    pl.ion()
    if 'csc_matrix' not in str(type(A)):
        A = csc_matrix(A)
    if 'array' not in str(type(b)):
        b = b.toarray()

    nr, T = C.shape
    nb = f.shape[0]
    last_index = nr + nb - 1  # highest valid slider / component index
    nA2 = np.sqrt(np.array(A.power(2).sum(axis=0))).squeeze()

    if YrA is None:
        Y_r = spdiags(old_div(1, nA2), 0, nr, nr) * (
            A.T.dot(Yr) - (A.T.dot(b)).dot(f) - (A.T.dot(A)).dot(C)) + C
    else:
        Y_r = YrA + C

    if img is None:
        img = np.reshape(np.array(A.mean(axis=1)), (d1, d2), order='F')

    fig = pl.figure(figsize=(10, 10))

    axcomp = pl.axes([0.05, 0.05, 0.9, 0.03])
    ax1 = pl.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = pl.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = pl.axes([0.05, 0.1, 0.9, 0.4])
    s_comp = Slider(axcomp, 'Component', 0, last_index, valinit=0)
    vmax = np.percentile(img, 95)

    def update(val):
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        i = int(np.round(s_comp.val))
        print(('Component:' + str(i)))

        if i < nr:
            ax1.cla()
            imgtmp = np.reshape(A[:, i].toarray(), (d1, d2), order='F')
            ax1.imshow(imgtmp, interpolation='None', cmap=pl.cm.gray,
                       vmax=np.max(imgtmp) * 0.5)
            ax1.set_title('Spatial component ' + str(i + 1))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), Y_r[i], 'c', linewidth=3)
            ax2.plot(np.arange(T), C[i], 'r', linewidth=2)
            ax2.set_title('Temporal component ' + str(i + 1))
            ax2.legend(labels=['Filtered raw data', 'Inferred trace'])

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=pl.cm.gray, vmax=vmax)
            # Zero pixels become NaN so only the component's support is overlaid.
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None', alpha=0.5, cmap=pl.cm.hot)
            ax3.axis('off')
        else:
            ax1.cla()
            bkgrnd = np.reshape(b[:, i - nr], (d1, d2), order='F')
            ax1.imshow(bkgrnd, interpolation='None')
            ax1.set_title('Spatial background ' + str(i + 1 - nr))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(f[i - nr, :])))
            ax2.set_title('Temporal background ' + str(i + 1 - nr))

    def arrow_key_image_control(event):
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)
        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            # Clamp to the slider maximum: nr + nb would index past b and f.
            if new_val > last_index:
                new_val = last_index
            s_comp.set_val(new_val)

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    pl.show()
RmaxV = Slider(Rmax, 'Rmax', 1, 254, valinit=rgbinit[0]) RminV = Slider(Rmin, 'Rmin', 1, 254, valinit=rgbinit[1]) GmaxV = Slider(Gmax, 'Gmax', 1, 254, valinit=rgbinit[2]) GminV = Slider(Gmin, 'Gmin', 1, 254, valinit=rgbinit[3]) BmaxV = Slider(Bmax, 'Bmax', 1, 254, valinit=rgbinit[4]) BminV = Slider(Bmin, 'Bmin', 1, 254, valinit=rgbinit[5]) RmaxV.on_changed(sliceupdateRmax) RminV.on_changed(sliceupdateRmin) GmaxV.on_changed(sliceupdateGmax) GminV.on_changed(sliceupdateGmin) BmaxV.on_changed(sliceupdateBmax) BminV.on_changed(sliceupdateBmin) ff=1 else: RmaxV.set_val(rgbinit[0]) RminV.set_val(rgbinit[1]) GmaxV.set_val(rgbinit[2]) GminV.set_val(rgbinit[3]) BmaxV.set_val(rgbinit[4]) BminV.set_val(rgbinit[5]) file = open("rgb.txt", "w") file.write(str(rgbinit)) file.close() print rgbinit plt.pause(0.001) plt.show(block=False) #print samp.val,sfreq.val
def view_patches_bar(Yr, A, C, b, f, d1, d2, YrA=None, secs=1, img=None):
    """View spatial and temporal components interactively.

    A slider (also driven by the left/right arrow keys) selects an index:
    ``0 .. nr-1`` shows a normalised spatial component, its traces and its
    overlay on ``img``; index ``nr`` shows the background.

    Parameters
    -----------
    Yr: np.ndarray
        movie in format pixels (d) x frames (T)
    A: sparse matrix
        matrix of spatial components (d x K)
    C: np.ndarray
        matrix of temporal components (K x T)
    b: np.ndarray
        spatial background (vector of length d)
    f: np.ndarray
        temporal background (vector of length T)
    d1,d2: int
        frame dimensions
    YrA: np.ndarray
        ROI filtered residual as it is given from update_temporal_components
        If not given, then it is computed (K x T)
    secs:
        unused; kept for backward compatibility
    img: np.ndarray
        background image for contour plotting. Default is the mean image of
        all spatial components (d1 x d2)
    """
    plt.ion()
    nr, T = C.shape
    # Column norms of A (A is expected to be a scipy sparse matrix: .data).
    A2 = A.copy()
    A2.data **= 2
    nA2 = np.sqrt(np.array(A2.sum(axis=0))).squeeze()
    b = np.squeeze(b)
    f = np.squeeze(f)
    if YrA is None:
        Y_r = np.array(A.T * np.matrix(Yr)
                       - (A.T * np.matrix(b[:, np.newaxis])) * np.matrix(f[np.newaxis])
                       - (A.T.dot(A)) * np.matrix(C) + C)
    else:
        Y_r = YrA + C

    # Normalise every spatial component to unit norm before display.
    A = A * spdiags(1 / nA2, 0, nr, nr)
    A = A.todense()
    imgs = np.reshape(np.array(A), (d1, d2, nr), order='F')
    if img is None:
        # Average over *all* components, as documented; the previous
        # imgs[:, :, :-1] dropped the last one and gave NaN for nr == 1.
        img = np.mean(imgs, axis=-1)

    bkgrnd = np.reshape(b, (d1, d2), order='F')
    fig = plt.figure(figsize=(10, 10))

    axcomp = plt.axes([0.05, 0.05, 0.9, 0.03])
    ax1 = plt.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = plt.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = plt.axes([0.05, 0.1, 0.9, 0.4])
    s_comp = Slider(axcomp, 'Component', 0, nr, valinit=0)
    vmax = np.percentile(img, 98)

    def update(val):
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        i = int(np.round(s_comp.val))
        # Single-argument print() behaves the same on Python 2 and 3.
        print('Component:' + str(i))

        if i < nr:
            ax1.cla()
            imgtmp = imgs[:, :, i]
            ax1.imshow(imgtmp, interpolation='None', cmap=plt.cm.gray)
            ax1.set_title('Spatial component ' + str(i + 1))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(Y_r[i, :])), 'c', linewidth=3)
            ax2.plot(np.arange(T), np.squeeze(np.array(C[i, :])), 'r', linewidth=2)
            ax2.set_title('Temporal component ' + str(i + 1))
            ax2.legend(labels=['Filtered raw data', 'Inferred trace'])

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=plt.cm.gray, vmax=vmax)
            # Zero pixels become NaN so only the component's support is overlaid.
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None', alpha=0.5, cmap=plt.cm.hot)
        else:
            ax1.cla()
            ax1.imshow(bkgrnd, interpolation='None')
            ax1.set_title('Spatial background')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(f)))
            ax2.set_title('Temporal background')

    def arrow_key_image_control(event):
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)
        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            if new_val > nr:
                new_val = nr
            s_comp.set_val(new_val)

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    plt.show()
class GridCircle:
    """
    Provided a 2D grid, this object:
    > plots the grid with a circle whose centre and radius are adjustable,
    > plots values of the grid along the circle alongside the grid.
    """

    def __init__(self, grid, extent=(-1, 1, -1, 1), circle_centre=(0, 0),
                 min=None, max=None, points_theta=100,
                 linear_interpolation=False, show_slider=True):
        """
        Parameters
        ----------
        grid : 2D array-like
            Grid to plot values from.
        extent : scalars (left, right, bottom, top)
            Values of space variables at corners. (default: (-1, 1, -1, 1))
        circle_centre : scalars (x, y)
            Location of the centre of the circle to draw on top of grid.
        min : float
            Minimum value for the colormap. (default: None)
            NOTE: None will be considered as the minimum being the minimum
                  value of grid.
        max : float
            Maximum value for the colormap. (default: None)
            NOTE: None will be considered as the maximum being the maximum
                  value of grid.
        points_theta : int
            Number of points to consider in the interval [0, 2\\pi] when
            computing values along circle.
        linear_interpolation : bool
            Get value by linear interpolation of neighbouring grid boxes.
            (default: False)
        show_slider : bool
            Display circle radius slider. When False the radius is changed
            only by clicking on the grid.
        """

        self.circle_centre = np.array(circle_centre)
        self.radius = 0  # radius of the circle
        self.points_theta = points_theta
        self.theta = np.linspace(0, 2 * np.pi, points_theta)
        self.linear_interpolation = linear_interpolation
        self.show_slider = show_slider

        # Figure plus the grid axes (left) and the values-along-circle axes (right).
        self.fig, (self.ax_grid, self.ax_plot) = plt.subplots(1, 2)

        # COLORMAP
        # `cmap` is a name defined elsewhere in the module (not in this class).
        self.min = np.min(grid) if min is None else min
        self.max = np.max(grid) if max is None else max
        self.norm = colors.Normalize(vmin=self.min, vmax=self.max)
        self.scalarmap = cmx.ScalarMappable(norm=self.norm, cmap=cmap)

        # PLOT
        self.ax_plot.set_ylim([self.min, self.max])  # y-range = colormap range
        self.ax_plot.set_xlim([0, 2 * np.pi])  # angle on the circle
        self.line, = self.ax_plot.plot(self.theta, [0] * self.points_theta)

        # SLIDER
        self.extent = np.array(extent)  # grid extent
        if self.show_slider:
            self.slider_ax = make_axes_locatable(self.ax_plot).append_axes(
                'bottom', size='5%', pad=0.5)
            self.slider = Slider(self.slider_ax, 'radius', 0,
                                 np.min(np.abs(self.extent)),
                                 valinit=self.radius)
            self.slider.on_changed(self.update_slider)

        # GRID
        self.grid_plot = self.ax_grid.imshow(grid, cmap=cmap, norm=self.norm,
                                             extent=self.extent)
        self.colormap_ax = make_axes_locatable(self.ax_grid).append_axes(
            'right', size='5%', pad=0.05)
        self.colormap = mpl.colorbar.ColorbarBase(
            self.colormap_ax, cmap=cmap, norm=self.norm,
            orientation='vertical')
        self.circle = plt.Circle(self.circle_centre, self.radius,
                                 color='black', fill=False)
        self.ax_grid.add_artist(self.circle)
        self.ax_grid.figure.canvas.mpl_connect('button_press_event',
                                               self.update_grid)

        self.update_grid_plot(grid)  # plots grid and updates circle and plot

    def get_fig_ax_cmap(self):
        """
        Returns
        -------
        fig : matplotlib.pyplot.figure object
            Figure.
        (ax_grid, ax_plot) : matplotlib.axes.Axes tuple
            Grid and plot axes.
        colormap : matplotlib.colorbar.ColorbarBase object
            Color map.
        """

        return self.fig, (self.ax_grid, self.ax_plot), self.colormap

    def update_grid_plot(self, grid, extent=None):
        """
        Plots grid.

        Parameters
        ----------
        grid : 2D array-like
            Grid to plot values from.
        extent : scalars (left, right, bottom, top)
            Values of space variables at corners. (default: None)
            NOTE: None will be considered as extent to be self.extent.
        """

        # `is not None`: `!= None` is element-wise (and ambiguous) for arrays.
        if extent is not None:
            self.extent = np.array(extent)
        self.grid = Grid(grid, extent=self.extent)

        self.grid_plot.set_data(self.grid.grid)  # plots grid
        self.grid_plot.set_extent(self.extent)  # set extent

        self.draw()  # updates circle and plot

    def update_grid(self, event):
        """
        Executes on click on figure.
        Updates radius of circle on figure.
        """

        if event.inaxes != self.circle.axes:
            return  # click was not on the grid axes

        # Radius = distance between the circle centre and the clicked point.
        self.radius = np.sqrt(np.sum(
            (np.array([event.xdata, event.ydata]) - self.circle_centre) ** 2))
        if self.show_slider:
            # The slider only exists when show_slider is True.
            self.slider.set_val(self.radius)
        self.draw()  # updates figure

    def update_slider(self, event):
        """
        Executes on slider change.
        Updates radius of circle on figure.
        """

        self.radius = self.slider.val  # radius set to slider value
        self.draw()  # updates figure

    def draw(self):
        """
        Updates figure.
        """

        # Values of the grid sampled along the circle.
        self.line.set_ydata([
            self.grid.get_value_polar(
                self.radius, angle, centre=self.circle_centre,
                linear_interpolation=self.linear_interpolation)
            for angle in self.theta])
        self.circle.set_radius(self.radius)  # adjusting circle radius

        self.ax_grid.figure.canvas.draw()  # updating grid
        self.ax_plot.figure.canvas.draw()  # updating plot
class BasicDendrogramViewer(object):
    """Interactive matplotlib viewer for a dendrogram and the data it was built from.

    The left panel shows the data (for 3-D data, one slice chosen with a
    slider) with vmin/vmax sliders; the right panel shows the dendrogram.
    Clicking a branch or an image pixel selects a structure for that mouse
    button; the selection is highlighted on the dendrogram and outlined with
    a contour on the image.
    """

    def __init__(self, dendrogram):

        if dendrogram.data.ndim not in [2, 3]:
            raise ValueError(
                "Only 2- and 3-dimensional arrays are supported at this time")

        self.hub = SelectionHub()
        self._connect_to_hub()

        self.array = dendrogram.data
        self.dendrogram = dendrogram
        self.plotter = DendrogramPlotter(dendrogram)
        self.plotter.sort(reverse=True)

        # Get the lines as individual elements, and the mapping from line to structure
        self.lines = self.plotter.get_lines(edgecolor='k')

        # Currently selected artists; keys are selection IDs (event button IDs).
        self.selected_lines = {}
        self.selected_contour = {}

        # Initiate plot
        import matplotlib.pyplot as plt
        self.fig = plt.figure(figsize=(14, 8))

        ax_image_limits = [0.1, 0.1, 0.4, 0.7]

        try:
            from wcsaxes import WCSAxes
            wcsaxes_available = True
        except ImportError:
            wcsaxes_available = False
            if self.dendrogram.wcs is not None:
                warnings.warn("`WCSAxes` package required for wcs coordinate display.")

        if self.dendrogram.wcs is not None and wcsaxes_available:

            if self.array.ndim == 2:
                slices = ('x', 'y')
            else:
                slices = ('x', 'y', 1)

            ax_image = WCSAxes(self.fig, ax_image_limits, wcs=self.dendrogram.wcs, slices=slices)
            self.ax_image = self.fig.add_axes(ax_image)

        else:
            self.ax_image = self.fig.add_axes(ax_image_limits)

        from matplotlib.widgets import Slider

        # Colour limits from finite values only (NaN/inf excluded).
        finite = self.array[np.isfinite(self.array)]
        self._clim = (np.min(finite), np.max(finite))

        if self.array.ndim == 2:

            self.slice = None
            self.image = self.ax_image.imshow(self.array, origin='lower', interpolation='nearest',
                                              vmin=self._clim[0], vmax=self._clim[1],
                                              cmap=plt.cm.gray)

            self.slice_slider = None

        else:

            if self.array.shape[0] > 1:

                self.slice = int(round(self.array.shape[0] / 2.))

                self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
                self.slice_slider_ax.set_xticklabels("")
                self.slice_slider_ax.set_yticklabels("")
                # Upper bound is the last valid index; shape[0] itself would
                # let update_slice index past the end of the cube.
                self.slice_slider = Slider(self.slice_slider_ax, "3-d slice",
                                           0, self.array.shape[0] - 1,
                                           valinit=self.slice, valfmt="%i")
                self.slice_slider.on_changed(self.update_slice)
                self.slice_slider.drawon = False

            else:

                self.slice = 0
                self.slice_slider = None

            self.image = self.ax_image.imshow(self.array[self.slice, :, :], origin='lower',
                                              interpolation='nearest',
                                              vmin=self._clim[0], vmax=self._clim[1],
                                              cmap=plt.cm.gray)

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1],
                                  valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1],
                                  valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.ax_dendrogram = self.fig.add_axes([0.6, 0.3, 0.35, 0.4])
        self.ax_dendrogram.add_collection(self.lines)

        self.selected_label = {}  # map selection IDs -> text objects
        self.selected_label[1] = self.fig.text(0.6, 0.85, "No structure selected", fontsize=18,
                                               color=self.hub.colors[1])
        self.selected_label[2] = self.fig.text(0.6, 0.8, "No structure selected", fontsize=18,
                                               color=self.hub.colors[2])
        self.selected_label[3] = self.fig.text(0.6, 0.75, "No structure selected", fontsize=18,
                                               color=self.hub.colors[3])

        paths = self.lines.get_paths()
        # Concatenate first: paths may have different vertex counts, and
        # np.min over a ragged list of arrays fails on recent NumPy.
        x = np.concatenate([p.vertices[:, 0] for p in paths])
        y = np.concatenate([p.vertices[:, 1] for p in paths])
        xmin, xmax = np.min(x), np.max(x)
        ymin, ymax = np.min(y), np.max(y)

        self.lines.set_picker(2.)
        self.lines.set_zorder(0)

        dx = xmax - xmin
        self.ax_dendrogram.set_xlim(xmin - dx * 0.1, xmax + dx * 0.1)
        self.ax_dendrogram.set_ylim(ymin * 0.5, ymax * 2.0)
        self.ax_dendrogram.set_yscale('log')

        self.fig.canvas.mpl_connect('pick_event', self.line_picker)
        self.fig.canvas.mpl_connect('button_press_event', self.select_from_map)

    def show(self):
        import matplotlib.pyplot as plt
        plt.show()

    def update_slice(self, pos=None):
        """Slice-slider callback: show slice ``round(pos)`` (3-D data) and redraw contours."""
        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            self.slice = int(round(pos))
            self.image.set_array(self.array[self.slice, :, :])

        self.update_contours()

        self.fig.canvas.draw()

    def _connect_to_hub(self):
        self.hub.add_callback(self._on_selection_change)

    def _on_selection_change(self, selection_id):
        self._update_lines(selection_id)
        self.update_contours()
        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """vmin-slider callback; vmin is capped at the current vmax."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """vmax-slider callback; vmax is floored at the current vmin."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def select_from_map(self, event):
        """Select the structure at the clicked image pixel for the pressed mouse button."""

        # Only do this if no tools are currently selected
        if event.canvas.toolbar.mode != '':
            return
        if event.button not in self.selected_label:
            return

        if event.inaxes is self.ax_image:

            input_key = event.button

            # Find pixel co-ordinates of click
            ix = int(round(event.xdata))
            iy = int(round(event.ydata))

            if self.array.ndim == 2:
                indices = (iy, ix)
            else:
                indices = (self.slice, iy, ix)

            # Select the structure
            structure = self.dendrogram.structure_at(indices)
            self.hub.select(input_key, structure)

            # Re-draw
            event.canvas.draw()

    def line_picker(self, event):
        """Select the picked dendrogram branch for the pressed mouse button."""

        # Only do this if no tools are currently selected
        if event.canvas.toolbar.mode != '':
            return
        if event.mouseevent.button not in self.selected_label:
            return

        input_key = event.mouseevent.button

        # event.ind gives the indices of the paths that have been selected

        # Find levels of selected paths
        peaks = [event.artist.structures[i].get_peak(subtree=True)[1] for i in event.ind]

        # Pick the path with the highest subtree peak (ties resolved by np.argmax)
        ind = event.ind[np.argmax(peaks)]

        # Extract structure
        structure = event.artist.structures[ind]

        # If 3-d, select the slice
        if self.slice_slider is not None:
            peak_index = structure.get_peak(subtree=True)
            self.slice_slider.set_val(peak_index[0][0])

        # Select the structure
        self.hub.select(input_key, structure)

        # Re-draw
        event.canvas.draw()

    def _update_lines(self, selection_id):
        """Replace the highlighted dendrogram lines and label for ``selection_id``."""

        structures = self.hub.selections[selection_id]
        select_subtree = self.hub.select_subtree[selection_id]

        structure = structures[0]

        # Remove previously selected collection. Artist.remove() works on all
        # matplotlib versions; mutating ax.collections is no longer supported.
        if selection_id in self.selected_lines:
            self.selected_lines.pop(selection_id).remove()

        if structure is None:
            self.selected_label[selection_id].set_text("No structure selected")
            self.remove_contour(selection_id)
            self.fig.canvas.draw()
            return

        self.remove_all_contours()

        # The comprehension variable is `s`, not `structure`: on Python 2 it
        # would leak and rebind `structure`, used for the zorder below.
        if len(structures) <= 1:
            label_text = "Selected structure: {0}".format(structure.idx)
        elif len(structures) <= 3:
            label_text = "Selected structures: {0}".format(
                ', '.join([str(s.idx) for s in structures]))
        else:
            label_text = "Selected structures: {0}...".format(
                ', '.join([str(s.idx) for s in structures[:3]]))

        self.selected_label[selection_id].set_text(label_text)

        # Get collection for this substructure
        self.selected_lines[selection_id] = self.plotter.get_lines(
            structures=structures, subtree=select_subtree)
        self.selected_lines[selection_id].set_color(self.hub.colors[selection_id])
        self.selected_lines[selection_id].set_linewidth(2)
        self.selected_lines[selection_id].set_zorder(structure.height)

        # Add to axes
        self.ax_dendrogram.add_collection(self.selected_lines[selection_id])

    def remove_contour(self, selection_id):
        """Remove the image contour drawn for ``selection_id``, if any."""
        if selection_id in self.selected_contour:
            for collection in self.selected_contour[selection_id].collections:
                collection.remove()
            del self.selected_contour[selection_id]

    def remove_all_contours(self):
        """
        Remove all selected contours.
        """
        # Iterate over a snapshot: remove_contour deletes keys from the dict.
        for key in list(self.selected_contour):
            self.remove_contour(key)

    def update_contours(self):
        """Redraw one contour per active selection on the current image slice."""
        from functools import reduce  # not a builtin on Python 3

        self.remove_all_contours()

        for selection_id in self.hub.selections.keys():

            structures = self.hub.selections[selection_id]
            select_subtree = self.hub.select_subtree[selection_id]

            struct = structures[0]

            if struct is None:
                continue

            if select_subtree:
                mask = struct.get_mask(subtree=True)
            else:
                mask = reduce(np.add, [s.get_mask(subtree=True) for s in structures])

            if self.array.ndim == 3:
                mask = mask[self.slice, :, :]

            self.selected_contour[selection_id] = self.ax_image.contour(
                mask, colors=self.hub.colors[selection_id],
                linewidths=2, levels=[0.5], alpha=0.75, zorder=struct.height)
class Visualizer:
    """Interactive pcolor view of one k level (third axis) of a 3-D field.

    A slider selects the k level; check buttons toggle the halo region and a
    symmetric-log colour scale.
    """

    def __init__(self, field, fieldname, halospec=None):
        """Initializes a visualization instance, that is a window with a field

        field is a 3D numpy array
        fieldname is a string with the name of the field
        halospec is a 2x2 array with the definition of the halo size
            (default [[3, 3], [3, 3]])

        After this call the window is shown
        """
        self.field = field
        self.fieldname = fieldname

        # Register halo information
        if halospec is None:
            halospec = [[3, 3], [3, 3]]

        self.istart = halospec[0][0]
        self.iend = field.shape[0] - halospec[0][1]
        self.jstart = halospec[1][0]
        self.jend = field.shape[1] - halospec[1][1]
        self.plotHalo = True
        self.plotLogLog = False

        self.curklevel = 0

        self.figure = plt.figure()

        # Slider (`axisbg` was removed from matplotlib; `facecolor` replaces it)
        slideraxes = plt.axes([0.15, 0.02, 0.5, 0.03], facecolor="lightgoldenrodyellow")
        self.slider = Slider(slideraxes, "K level", 0, field.shape[2] - 1, valinit=0)
        self.slider.valfmt = "%2d"
        self.slider.set_val(0)
        self.slider.on_changed(self.updateSlider)

        # CheckButton
        self.cbaxes = plt.axes([0.8, -0.04, 0.12, 0.15])
        self.cbaxes.set_axis_off()
        self.cb = CheckButtons(self.cbaxes, ("Halo", "Logscale"), (self.plotHalo, self.plotLogLog))
        self.cb.on_clicked(self.updateButton)

        # Initial plot: draw explicitly on fieldaxes instead of relying on
        # pyplot's current axes and the `axes=` keyword.
        self.fieldaxes = self.figure.add_axes([0.1, 0.15, 0.9, 0.75])
        self.collection = self.fieldaxes.pcolor(self._getField())
        self.colorbar = self.figure.colorbar(self.collection, ax=self.fieldaxes)
        self.fieldaxes.set_xlim(right=self._getField().shape[1])
        self.fieldaxes.set_ylim(top=self._getField().shape[0])

        self.fieldaxes.set_xlabel("i")
        self.fieldaxes.set_ylabel("j")
        self.title = self.fieldaxes.set_title("%s - Level 0" % (fieldname,))

        plt.show(block=False)

    def updateSlider(self, val):
        """Slider callback: switch to k level ``round(val)`` and redraw."""
        # Compare the *rounded* level, and keep it an int so it can index numpy.
        level = int(round(val))
        if level == self.curklevel:
            return
        self.curklevel = level
        self.title.set_text("%s - Level %d" % (self.fieldname, self.curklevel))

        # Draw new field level
        field = self._getField()
        size = field.shape[0] * field.shape[1]
        array = field.reshape(size)
        self.collection.set_array(array)

        self.colorbar.set_clim(vmin=field.min(), vmax=field.max())
        self.collection.set_clim(vmin=field.min(), vmax=field.max())
        self.colorbar.update_normal(self.collection)
        self.figure.canvas.draw_idle()

    def updateButton(self, label):
        """Check-button callback: toggle halo or log scale, then replot."""
        if label == "Halo":
            self.plotHalo = not self.plotHalo
        if label == "Logscale":
            self.plotLogLog = not self.plotLogLog
        self.updatePlot()

    def updatePlot(self):
        """Recreate the pcolor collection with the current halo/log settings."""
        # Redraw field
        self.collection.remove()
        field = self._getField()
        if self.plotLogLog:
            minvalue = field.min()
            norm = SymLogNorm(linthresh=1e-10)
            self.collection = self.fieldaxes.pcolor(field, norm=norm)
            self.colorbar.set_clim(vmin=minvalue, vmax=field.max())
        else:
            self.collection = self.fieldaxes.pcolor(field)
            self.colorbar.set_clim(vmin=field.min(), vmax=field.max())
            self.colorbar.set_norm(norm=Normalize(vmin=field.min(), vmax=field.max()))

        self.fieldaxes.set_xlim(right=field.shape[1])
        self.fieldaxes.set_ylim(top=field.shape[0])

        self.colorbar.update_normal(self.collection)
        self.figure.canvas.draw_idle()

    def _getField(self):
        """Return the current k level (optionally without halo), rotated by 90 degrees."""
        if self.plotHalo:
            return np.rot90(self.field[:, :, self.curklevel])
        else:
            return np.rot90(self.field[self.istart : self.iend, self.jstart : self.jend, self.curklevel])
def plot_biplot(x, information, rhomin, rhomax):
    """Build the density figures: a colour bar, a height-vs-density profile and a time slider.

    Parameters
    ----------
    x : array-like
        Heights at which the density is given (plotted on the y axis).
    information : array-like or callable
        Densities at ``x``, or a callable evaluated as ``information(x)``.
    rhomin, rhomax : float
        Density range mapped onto the colormap.

    Returns
    -------
    (fig1, fig2, fig3) : the colour-bar figure, the profile figure and the
    slider figure.

    NOTE(review): relies on module-level names defined elsewhere
    (``evolver``, ``colorline``, ``maxtime``, ``time``, ``rc``, ``mpl``,
    ``colors``); ``time`` is used as a number of seconds here.
    """
    ## Formatting input
    if callable(information):
        width = 20
        y = information(x)
    else:
        n = len(information)
        # At least one column so the colour-bar image is never empty (n < 5).
        width = max(1, n // 5)
        y = information

    array = np.repeat(y, width).reshape((len(y), width))

    # Font type for all plots ('Palatino' is the real font family name)
    rc('font', **{'family': 'serif', 'serif': ['Palatino'], 'size': 24, 'weight': 'bold'})
    rc('text', usetex=True)
    # A single string: list values for this rcParam are rejected by newer matplotlib.
    mpl.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

    # Colour-bar figure
    fig1, ax1 = plt.subplots(figsize=(2, 10))
    cm = plt.get_cmap('OrRd_r')
    cNorm = colors.Normalize(vmin=rhomin, vmax=rhomax)
    ax1.imshow(array, cmap=cm, norm=cNorm)

    # Profile figure
    fig2, ax2 = plt.subplots(figsize=(10, 8))
    x_0, rho_0 = evolver.rho_aprox(0.5)
    ax2.plot(rho_0, x_0, linewidth=5, linestyle="-", c="grey", zorder=1, alpha=0.3)
    # Plot the evaluated densities `y`: `information` may be a callable.
    ax2.plot(y, x, linewidth=8, linestyle="-", c="black", zorder=1)
    ax2.plot(y, x, linewidth=5, linestyle="-", c="white", zorder=1)

    colorline(ax2, y, x,
              z=rhomin * (rhomin - y) / (rhomin - rhomax) + rhomax * (rhomax - y) / (rhomax - rhomin),
              cmap=cm, norm=cNorm, linewidth=5, alpha=1)

    # Aesthetic tuning of plot
    ax2.set_ylim([min(x), max(x)])
    ax2.set_xlim([990, 1093])
    ax2.set_xlabel(r"\textbf{Density / g L}$\boldsymbol{^{-1}}$")
    ax2.set_ylabel(r"\textbf{Height / mm}", rotation=270, labelpad=10)
    ax2.yaxis.set_label_position("right")
    ax2.tick_params(labeltop=True)
    ax2.grid(True)
    plt.setp(ax2.get_xticklabels(), fontsize=18)

    # Final Configuration
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax1.get_yticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)
    ax1.get_xaxis().set_visible(False)
    for axis in ['top', 'bottom', 'left', 'right']:
        ax1.spines[axis].set_linewidth(5)

    # Slider figure
    fig3, axslider = plt.subplots(figsize=(10, 1))
    samp = Slider(axslider, 'Time (min)', 0., maxtime, valinit=0, color='grey', alpha=0.3,
                  valfmt='%i'.ljust(5))
    samp.set_val(float(time) / 60.)

    return fig1, fig2, fig3
class py3DSeedEditor:
    """ Viewer and seed editor for 2D and 3D data.

    py3DSeedEditor(img, ...)

    img: 2D or 3D grayscale data
    voxelsizemm: size of voxel, default is [1, 1, 1] (currently unused)
    initslice: index of the first displayed slice, default is 0
    colorbar: True/False, default is True
    cmap: colormap, default is matplotlib.cm.Greys_r
    seeds: optional seed array, same shape as img
    contour: optional array drawn as a contour over each slice
    zaxis: axis with slice numbers (0 or 2)
    mouse_button_map: maps mouse button number -> seed label (identity by default)
    windowW, windowC: optional intensity window width and centre
    range_per_slice: compute the intensity range per slice instead of globally

    ed = py3DSeedEditor(img)
    ed.show()
    selected_seeds = ed.seeds

    """
    def __init__(self, img, voxelsizemm=None, initslice=0, colorbar=True,
                 cmap=None, seeds=None, contour=None, zaxis=0,
                 mouse_button_map=None, windowW=None, windowC=None,
                 range_per_slice=False):
        # Defaults resolved here (not in the signature) to avoid shared
        # mutable defaults and a matplotlib lookup at class-definition time.
        if cmap is None:
            cmap = matplotlib.cm.Greys_r
        if mouse_button_map is None:
            mouse_button_map = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8}

        self.fig = plt.figure()
        if len(img.shape) == 2:
            # Promote a 2D image to a single-slice volume.
            imgtmp = img
            img = np.zeros([1, imgtmp.shape[0], imgtmp.shape[1]])
            img[-1, :, :] = imgtmp
            zaxis = 0

        # Rotate data so that the slice axis is the last one
        img = self._rotate_start(img, zaxis)
        seeds = self._rotate_start(seeds, zaxis)
        contour = self._rotate_start(contour, zaxis)
        self.rotated_back = False
        self.zaxis = zaxis

        # If True, the intensity range is calculated per slice: better
        # visualisation for a higher number of labels.
        self.range_per_slice = range_per_slice

        self.imgshape = list(img.shape)
        self.img = img
        self.actual_slice = initslice
        self.colorbar = colorbar
        self.cmap = cmap
        if seeds is None:
            self.seeds = np.zeros(self.imgshape, np.int8)
        else:
            self.seeds = seeds
        if not (windowW and windowC):
            self.imgmax = np.max(img)
            self.imgmin = np.min(img)
        else:
            self.imgmax = windowC + (windowW / 2)
            self.imgmin = windowC - (windowW / 2)

        # Mapping mouse button to class number. Default is normal order.
        self.button_map = mouse_button_map
        self.contour = contour

        self.press = None
        self.press2 = None

        # language
        self.texts = {'btn_delete': 'Delete', 'btn_close': 'Close'}

        self.ax = self.fig.add_axes([0.2, 0.3, 0.7, 0.6])

        self.draw_slice()

        if self.colorbar:
            self.fig.colorbar(self.imsh)

        # user interface look (`axisbg` was removed from matplotlib; use `facecolor`)
        axcolor = 'lightgoldenrodyellow'
        ax_actual_slice = self.fig.add_axes([0.2, 0.2, 0.6, 0.03], facecolor=axcolor)
        # Last valid slice index is imgshape[2] - 1; keep a non-zero range for
        # single-slice input (sliceslider_update clamps the value anyway).
        self.actual_slice_slider = Slider(ax_actual_slice, 'Slice', 0,
                                          max(self.imgshape[2] - 1, 1),
                                          valinit=initslice)

        # connection to wheel events
        self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
        self.actual_slice_slider.on_changed(self.sliceslider_update)

        # draw
        self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.fig.canvas.mpl_connect('button_release_event', self.on_release)
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_motion)

        # delete seeds
        self.ax_delete_seeds = self.fig.add_axes([0.2, 0.1, 0.1, 0.075])
        self.btn_delete = Button(self.ax_delete_seeds, self.texts['btn_delete'])
        self.btn_delete.on_clicked(self.callback_delete)

        # close button; separate attributes so the delete button keeps a
        # reference (matplotlib widgets must stay referenced to stay responsive)
        self.ax_close = self.fig.add_axes([0.7, 0.1, 0.1, 0.075])
        self.btn_close = Button(self.ax_close, self.texts['btn_close'])
        self.btn_close.on_clicked(self.callback_close)

        self.draw_slice()

    def _rotate_start(self, data, zaxis):
        """Move the slice axis ``zaxis`` (0 or 2) to the last position; None passes through."""
        # `is not None`: `!= None` is element-wise (and ambiguous) for arrays.
        if data is not None:
            if zaxis == 0:
                data = np.transpose(data, (1, 2, 0))
            elif zaxis == 2:
                pass
            else:
                print("problem with zaxis in _rotate_start()")

        return data

    def _rotate_end(self, data, zaxis):
        """Undo _rotate_start (only once; a second call just prints a warning)."""
        if data is not None:
            if self.rotated_back == False:
                if zaxis == 0:
                    data = np.transpose(data, (2, 0, 1))
                elif zaxis == 2:
                    pass
                else:
                    print("problem with zaxis in _rotate_end()")
            else:
                print("There is a danger in calling show() twice")

        return data

    def update_slice(self):
        # TODO: clearing the axes is here because of the contour; there is no
        # other known way to remove it.
        self.ax.cla()
        self.draw_slice()

    def draw_slice(self):
        """Draw the current slice, its seed overlay and (optionally) the contour."""
        self.actual_slice = int(self.actual_slice)
        sliceimg = self.img[:, :, self.actual_slice]
        if self.range_per_slice:
            vmin, vmax = sliceimg.min(), sliceimg.max()
        else:
            vmin, vmax = self.imgmin, self.imgmax
        self.imsh = self.ax.imshow(sliceimg, cmap=self.cmap, vmin=vmin, vmax=vmax,
                                   interpolation='nearest')

        self.ax.imshow(self.prepare_overlay(self.seeds[:, :, self.actual_slice]),
                       interpolation='nearest', vmin=self.imgmin, vmax=self.imgmax)

        if self.contour is not None:
            try:
                self.ax.contour(self.contour[:, :, self.actual_slice], 1, linewidths=2)
            except Exception:
                # Best effort: skip the contour if it cannot be drawn for this slice.
                pass

        self.fig.canvas.draw()

    def next_slice(self):
        """Advance one slice, wrapping to the first slice after the last."""
        self.actual_slice = self.actual_slice + 1
        if self.actual_slice >= self.imgshape[2]:
            self.actual_slice = 0

    def prev_slice(self):
        """Go back one slice, wrapping to the last slice before the first."""
        self.actual_slice = self.actual_slice - 1
        if self.actual_slice < 0:
            self.actual_slice = self.imgshape[2] - 1

    def sliceslider_update(self, val):
        """Slider callback: show the nearest valid slice."""
        self.actual_slice = min(max(int(round(val)), 0), self.imgshape[2] - 1)
        self.update_slice()

    def prepare_overlay(self, seeds):
        """Return an RGBA overlay: label 1 red, 2 green, 3 blue, opaque where seeds > 0."""
        sh = list(seeds.shape)
        if len(sh) == 2:
            sh.append(4)
        else:
            sh[2] = 4
        overlay = np.zeros(sh)

        overlay[:, :, 0] = (seeds == 1)
        overlay[:, :, 1] = (seeds == 2)
        overlay[:, :, 2] = (seeds == 3)
        overlay[:, :, 3] = (seeds > 0)

        return overlay

    def show(self):
        """ Function run viewer window.
        """
        plt.show()

        # Rotate data back to the caller's axis order
        self.img = self._rotate_end(self.img, self.zaxis)
        self.seeds = self._rotate_end(self.seeds, self.zaxis)
        self.contour = self._rotate_end(self.contour, self.zaxis)
        self.rotated_back = True
        return self.seeds

    def on_scroll(self, event):
        ''' mouse wheel is used for setting slider value'''
        if event.button == 'up':
            self.next_slice()
        if event.button == 'down':
            self.prev_slice()
        self.actual_slice_slider.set_val(self.actual_slice)

    ## drawing -------------------

    def on_press(self, event):
        'on button press inside the image axes, start recording the stroke'
        if event.inaxes != self.ax:
            return
        self.press = [event.xdata], [event.ydata], event.button

    def on_motion(self, event):
        'while a button is held, append the pointer position to the stroke'
        if self.press is None:
            return

        if event.inaxes != self.ax:
            return
        x0, y0, btn = self.press
        x0.append(event.xdata)
        y0.append(event.ydata)

    def on_release(self, event):
        'on release, store the stroke as seeds and reset the press data'
        if self.press is None:
            return
        x0, y0, btn = self.press
        # button mapping
        btn = self.button_map[btn]

        self.set_seeds(y0, x0, self.actual_slice, btn)
        self.press = None
        self.update_slice()

    def callback_delete(self, event):
        """Clear all seeds on the current slice."""
        self.seeds[:, :, self.actual_slice] = 0
        self.update_slice()

    def callback_close(self, event):
        """Close the editor window."""
        self.fig.clf()
        plt.close(self.fig)

    def set_seeds(self, px, py, pz, value=1, voxelsizemm=None, cursorsizemm=None):
        """Set ``seeds[px[i], py[i], pz] = value`` for every point (coordinates rounded)."""
        assert len(px) == len(py), 'px and py describes a point, their size must be same'

        for x, y in zip(px, py):
            # Mouse coordinates are floats; numpy requires integer indices.
            self.seeds[int(round(x)), int(round(y)), pz] = value

    def get_seed_sub(self, label):
        """ Return list of all seeds with specific label
        """
        sx, sy, sz = np.nonzero(self.seeds == label)

        return sx, sy, sz

    def get_seed_val(self, label):
        """ Return data values for specific seed label"""
        return self.img[self.seeds == label]
class Parameters(object):
    """
    An interactive matplotlib window to test and set parameters for eye
    tracking on a subset of frames.

    The window shows the raw frame, the pre-processed frame, the dark-contour
    binarization, the convolution output and the gradient magnitude. It also
    shows an intensity histogram and the difference-from-reference
    timecourse. Sliders, mouse clicks and key presses edit ``self.params``.
    ``self.done`` is set to True when the user confirms with 'c'.

    Parameters
    ----------
    notes : object
        Must provide ``wintitle``, ``stack`` (frames indexed as
        ``stack[i, :, :]``), ``axis`` (passed to ``geom.parse_axis``) and
        ``diffs`` (indexed by frame number).
    """

    @staticmethod
    def _odd_int(value):
        """rounds `value` to an int and bumps it to the next odd number if even"""
        size = int(np.round(value))  # builtin int: np.int was removed in NumPy 1.24
        if size % 2 == 0:  # kernel sizes must be odd
            size += 1
        return size

    def __init__(self, notes):
        self.wintitle = notes.wintitle
        self.stack = notes.stack  # cropped frame stack
        self.frame_dims = notes.stack.shape[-2:]  # cropped frame dimensions
        self.nframes = notes.stack.shape[0]  # number of frames
        _, _, _, mmperpx = geom.parse_axis(notes.axis)  # mm to pixel conversion factor
        self.diffs = notes.diffs  # mean pixel-wise differences from reference frame

        # status variable, set to True on 'c' keypress to exit parameter-setting
        # loop and perform eye tracking on all files:
        self.done = False

        # define a parameter dictionary with some initial values
        # first specify maximum/ initial values for some of the parameters
        method = 'convolve'  # eye tracking method ('threshold' or 'convolve')
        shape = 'lse_ellipse'  # shape fitting method ('lse_ellipse' or 'min_enclosing')

        # parameters for image pre-processing:
        eq_sp_maxval = 1  # bihist equalization separation point
        eq_rp_maxval = 1  # bihist equalization range point
        c_sig_maxval = 255 / 2  # intensity space sigma for bilateral filter
        s_sig_maxval = max(self.frame_dims) / 8  # spatial sigma for bilateral filter
        k_size_maxval = max(self.frame_dims) / 8  # kernel size for median filter
        k_size_inival = self._odd_int(k_size_maxval / 2)

        # parameters for threshold method:
        area_thr_maxval = np.pi * (0.5 / mmperpx)**2  # area threshold

        # parameters for convolve method:
        conv_size_maxval = 0.5 / mmperpx  # kernel size for convolution
        conv_size_inival = self._odd_int(conv_size_maxval / 4)
        rad_maxval = np.round(2 / mmperpx).astype('int')  # max radius for edges

        # parameters for blink detection
        b_thr_inival = np.median(self.diffs) * 3

        # build parameter dictionary
        self.params = {
            'eq': False, 'eq_sp': 0, 'eq_rp': 0,
            'k_size': k_size_inival, 'c_sig': c_sig_maxval / 2, 's_sig': s_sig_maxval / 2,
            'dark_thr': 20, 'light_thr': 235, 'area_thr': 0,
            'conv_size': conv_size_inival, 'max_rad': rad_maxval / 2,
            'shape': shape, 'ht_fit': False, 'method': method,
            'blink_thr': b_thr_inival,
        }

        # set initial frames to display
        self.frame_ind = 0  # current frame index
        self.frame = self.stack[self.frame_ind, :, :]  # original frame
        # processed (filtered and equalized) frame:
        self.p_frame = img.pre_process(self.frame, self.params)
        # inverted binarized frame (dark contours):
        self.b_frame, _ = img.binarize(self.p_frame, self.params)
        # p_frame convolved with a black square (center = argmin[c_frame]);
        # uses 'conv_size', the same parameter update_pframes() uses
        self.c_frame, self.center = img.square_convolve(self.p_frame, self.params['conv_size'])
        self.g_frame = img.gradient(self.p_frame)  # gradient magnitude of p_frame
        self.edge_pts = geom.starburst(self.g_frame, self.center, self.params['max_rad'], 100)

        # set up figure with plots and images; plt.figure() (not plt.subplots())
        # so no unused full-figure Axes is left underneath the grid
        self.fig = plt.figure()
        self.fig.canvas.manager.set_window_title(self.wintitle)
        grid = gs.GridSpec(12, 12)
        # original frame display axis
        self.ax_frame = plt.subplot(grid[:6, :4])
        self.ax_frame.axis('off')
        # processed frame display axis
        self.ax_pframe = plt.subplot(grid[6:9, :2])
        self.ax_pframe.axis('off')
        self.ax_pframe.set_title('Processed Frame')
        # dark binary display axis
        self.ax_bframe = plt.subplot(grid[6:9, 2:4])
        self.ax_bframe.axis('off')
        self.ax_bframe.set_title('Dark Contours')
        # convolution display axis
        self.ax_cframe = plt.subplot(grid[9:12, :2])
        self.ax_cframe.axis('off')
        self.ax_cframe.set_title('Convolution')
        self.c_dot = self.ax_cframe.scatter(self.center[0], self.center[1], s=8)
        # gradient display axis
        self.ax_gframe = plt.subplot(grid[9:12, 2:4])
        self.ax_gframe.axis('off')
        self.ax_gframe.set_title('Gradient')
        self.e_dots = self.ax_gframe.scatter(self.edge_pts[:, 0], self.edge_pts[:, 1], s=1, color='C0')

        # intensity histogram
        self.ax_hist = plt.subplot(grid[:4, 4:7])
        self.hist = self.ax_hist.hist(self.p_frame.ravel(), bins=np.arange(256), color='C0', alpha=0.5, density=True)
        self.dark_thr = self.ax_hist.axvline(self.params['dark_thr'], color=[0, 0, 0], marker=' ', label='Dark Threshold')
        self.light_thr = self.ax_hist.axvline(self.params['light_thr'], color=[.75, .75, .75], marker=' ', label='Light Threshold')
        self.ax_hist.set_title('Intensity Histogram')
        self.ax_hist.legend()

        # timecourse of mean differences between each frame and the reference;
        # x positions are the integer frame indices so the current-frame dot
        # (plotted at frame_ind) lines up with its sample
        self.ax_diffs = plt.subplot(grid[4:8, 4:7])
        self.ax_diffs.scatter(np.arange(len(self.diffs)), self.diffs, s=3, color='C0')
        self.ax_diffs.set_xlim([0, self.nframes])
        self.f_dot = self.ax_diffs.scatter(0, self.diffs[0], s=4, color='C1')
        self.blink_thr1 = self.ax_diffs.axhline(self.params['blink_thr'], marker=' ', color=[0, 0, 0], label='Blink Threshold')
        self.ax_diffs.legend()
        self.ax_diffs.set_ylabel('Diff. from reference')

        # slider for frame scrolling
        self.ax_fslider = plt.subplot(grid[8, 4:7])
        self.f_slider = Slider(self.ax_fslider, 'Frame', 0, self.nframes - 1, valinit=0, valfmt='%d')

        # distribution of the differences
        self.ax_dist = plt.subplot(grid[9:12, 4:7])
        self.ax_dist.hist(self.diffs.ravel(), bins=100, color='C0', density=True)
        self.blink_thr2 = self.ax_dist.axvline(self.params['blink_thr'], marker=' ', color=[0, 0, 0])
        self.ax_dist.set_xlabel('Diff. from reference')

        # slider for eq_sp equalization parameter ('%.2f': '%02f' printed 6 decimals)
        self.ax_spslider = plt.subplot(grid[0, 8:11])
        self.sp_slider = Slider(self.ax_spslider, 'EQ Sep. Point', 0, eq_sp_maxval, valinit=self.params['eq_sp'], valfmt='%.2f')
        # slider for eq_rp equalization parameter
        self.ax_rpslider = plt.subplot(grid[1, 8:11])
        self.rp_slider = Slider(self.ax_rpslider, 'EQ Range Point', 0, eq_rp_maxval, valinit=self.params['eq_rp'], valfmt='%.2f')
        # slider for c_sig filter parameter
        self.ax_cslider = plt.subplot(grid[3, 8:11])
        self.c_slider = Slider(self.ax_cslider, 'Color Sigma', 0, c_sig_maxval, valinit=self.params['c_sig'], valfmt='%d')
        # slider for s_sig filter parameter
        self.ax_sslider = plt.subplot(grid[4, 8:11])
        self.s_slider = Slider(self.ax_sslider, 'Spatial Sigma', 0, s_sig_maxval, valinit=self.params['s_sig'], valfmt='%d')
        # slider for k_size filter parameter
        self.ax_kslider = plt.subplot(grid[5, 8:11])
        self.k_slider = Slider(self.ax_kslider, 'Kernel Size', 1, k_size_maxval, valinit=k_size_inival, valfmt='%d')
        # slider for area_thr parameter
        self.ax_aslider = plt.subplot(grid[7, 8:11])
        self.a_slider = Slider(self.ax_aslider, 'Area Thresh', 0, area_thr_maxval, valinit=0, valfmt='%d')
        # slider for conv_size convolution parameter
        self.ax_ckslider = plt.subplot(grid[9, 8:11])
        self.ck_slider = Slider(self.ax_ckslider, 'Conv. Kernel Size', 0, conv_size_maxval, valinit=self.params['conv_size'], valfmt='%d')
        # slider for max_rad starburst parameter
        self.ax_rslider = plt.subplot(grid[10, 8:11])
        self.r_slider = Slider(self.ax_rslider, 'Max. Radius', 0, rad_maxval, valinit=self.params['max_rad'], valfmt='%d')

        # disconnect default matplotlib key bindings
        manager, canvas = self.fig.canvas.manager, self.fig.canvas
        canvas.mpl_disconnect(manager.key_press_handler_id)
        # maximize display window (only Qt windows provide showMaximized)
        try:
            manager.window.showMaximized()
        except AttributeError:
            pass  # other backends: keep the default window size

        # connect callback functions
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.cidkey = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
        self.f_slider.on_changed(self.update_frame)
        self.sp_slider.on_changed(self.sp_update)
        self.rp_slider.on_changed(self.rp_update)
        self.c_slider.on_changed(self.c_update)
        self.s_slider.on_changed(self.s_update)
        self.k_slider.on_changed(self.k_update)
        self.a_slider.on_changed(self.a_update)
        self.ck_slider.on_changed(self.ck_update)
        self.r_slider.on_changed(self.r_update)

        # get pupil and reflection patches
        self.get_eye_patches()
        # display
        self.update_display()
        # display user prompts
        self._print_instructions()

    @staticmethod
    def _print_instructions():
        """prints the interactive controls to the console"""
        print("Set Parameters")
        print(" - click on the slider or use the arrow keys to scroll between video frames")
        print("")
        print(" Pre-Processing:")
        print(" - use 'e' to toggle histogram equalization")
        print(" - set parameters for equalizationg and smoothing using the sliders")
        print("")
        print(" - use 'm' to change eye tracking method")
        print(" Threshold Method:")
        print(" - set dark and light binarization thresholds by clicking on the")
        # on_click: left click (button 1) sets dark, right click (button 3) sets light
        print(" intensity histogram (left & right clicks respectively)")
        print(" - set an area threshold using the slider")
        print(" Convolution Method:")
        print(" - set a kernel size and a maximum radius using the sliders")
        print("")
        print(" Ellipse Fitting:")
        print(" - use 'f' to toggle the shape fit ('lse_ellipse' / 'min_enclosing')")
        print(" - use 'h' to toggle ellipse fitting via re-sampling (Hough Transform)")
        print("")
        print(" - set the blink-detection threshold by clicking on the difference")
        print(" scatter plot or distribution plot")
        print("")
        print(" - use 'r' to re-set paramters and return to the previous window,")
        print(" 'c' to confirm the current parameters and commence tracking,")
        print(" or 'esc' to quit (parameters will not be saved)")

    def _set_blink_thr(self, value):
        """moves the blink threshold to `value` in both difference plots"""
        self.blink_thr1.remove()
        self.blink_thr2.remove()
        self.params['blink_thr'] = value
        self.blink_thr1 = self.ax_diffs.axhline(value, marker=' ', color=[0, 0, 0], label='Blink Threshold')
        self.blink_thr2 = self.ax_dist.axvline(value, marker=' ', color=[0, 0, 0])

    def on_click(self, event):
        """sets thresholds based on click"""
        # set binarization thresholds
        if event.inaxes == self.ax_hist:
            if event.button == 1:  # left click: dark threshold
                self.dark_thr.remove()
                self.params['dark_thr'] = event.xdata
                self.dark_thr = self.ax_hist.axvline(self.params['dark_thr'], color=[0, 0, 0], marker=' ', label='Dark Threshold')
            elif event.button == 3:  # right click: light threshold
                self.light_thr.remove()
                self.params['light_thr'] = event.xdata
                self.light_thr = self.ax_hist.axvline(self.params['light_thr'], color=[0.75, 0.75, 0.75], marker=' ', label='Light Threshold')
            # update binarized frame
            self.b_frame, _ = img.binarize(self.p_frame, self.params)
            # update eye patches
            self.get_eye_patches()
        # set blink threshold in diffs timecourse axis (y value)
        elif event.inaxes == self.ax_diffs:
            if event.button == 1:
                self._set_blink_thr(event.ydata)
        # set blink threshold in diffs distribution axis (x value)
        elif event.inaxes == self.ax_dist:
            if event.button == 1:
                self._set_blink_thr(event.xdata)
        # update display
        self.update_display()

    def on_key(self, event):
        """specifies functions of pressed keys in matplotlib figure"""
        # toggle histogram equalization
        if event.key == 'e':
            self.params['eq'] = not self.params['eq']
            print(" EQ: " + str(self.params['eq']))
            self._reprocess()
        # toggle eye tracking method
        elif event.key == 'm':
            if self.params['method'] == 'threshold':
                self.params['method'] = 'convolve'
                print(" Method: convolve")
            elif self.params['method'] == 'convolve':
                self.params['method'] = 'threshold'
                print(" Method: threshold")
            self.get_eye_patches()  # update eye patches
            self.update_display()
        # toggle ellipse fitting method
        elif event.key == 'f':
            if self.params['shape'] == 'lse_ellipse':
                self.params['shape'] = 'min_enclosing'
            elif self.params['shape'] == 'min_enclosing':
                self.params['shape'] = 'lse_ellipse'
            print(" Shape fit: " + self.params['shape'])
            self.get_eye_patches()
            self.update_display()
        # toggle Hough transform ellipse fitting
        elif event.key == 'h':
            self.params['ht_fit'] = not self.params['ht_fit']
            print(" Hough transform: " + str(self.params['ht_fit']))
            self.get_eye_patches()  # update eye patches
            self.update_display()
        # change frame
        elif event.key == 'right':  # move forward one frame
            if np.round(self.f_slider.val) < self.nframes - 1:
                self.f_slider.set_val(self.f_slider.val + 1)
        elif event.key == 'left':  # move back one frame
            if np.round(self.f_slider.val) > 0:
                self.f_slider.set_val(self.f_slider.val - 1)
        # continue, close, or re-set
        elif event.key == 'c':  # confirm parameters and close figure
            print("Eye tracking parameters confirmed")
            self.done = True
            plt.close('all')
        elif event.key == 'escape':  # quit eye tracking
            print("Eye tracking aborted")
            plt.close('all')
            sys.exit()
        elif event.key == 'r':  # re-set & return to annotation
            print("Parameters re-set")
            plt.close('all')

    def _reprocess(self, hist=True):
        """re-filters the frame, optionally refreshes the histogram, re-tracks and redraws"""
        self.update_pframes()
        if hist:
            self.update_hist()
        self.get_eye_patches()
        self.update_display()

    def update_frame(self, val):
        """updates frame and plots based on frame slider value"""
        # update frame index
        self.frame_ind = int(np.round(self.f_slider.val))
        # update displays
        self.frame = self.stack[self.frame_ind, :, :]  # raw frame
        self.update_pframes()
        self.update_hist()
        # diffs plot
        self.f_dot.remove()
        self.f_dot = self.ax_diffs.scatter(self.frame_ind, self.diffs[self.frame_ind], s=4, color=[1, 0, 1])
        self.get_eye_patches()
        self.update_display()

    def sp_update(self, val):
        """updates eq_sp equalization parameter and re-applies filters"""
        self.params['eq_sp'] = self.sp_slider.val
        self._reprocess()

    def rp_update(self, val):
        """updates eq_rp equalization parameter and re-applies filters"""
        self.params['eq_rp'] = self.rp_slider.val
        self._reprocess()

    def c_update(self, val):
        """updates c_sig filter parameter and re-applies filters"""
        self.params['c_sig'] = self.c_slider.val
        self._reprocess()  # filtering changes p_frame, so refresh the histogram too

    def s_update(self, val):
        """updates s_sig filter parameter and re-applies filters"""
        self.params['s_sig'] = self.s_slider.val
        self._reprocess()

    def k_update(self, val):
        """updates k_size filter parameter (forced odd) and re-applies filters"""
        self.params['k_size'] = self._odd_int(self.k_slider.val)
        self._reprocess()

    def a_update(self, val):
        """updates dark contour area threshold and re-applies eye-tracking"""
        self.params['area_thr'] = self.a_slider.val
        self.get_eye_patches()
        self.update_display()

    def ck_update(self, val):
        """updates conv_size convolution parameter (forced odd) and re-applies eye-tracking"""
        self.params['conv_size'] = self._odd_int(self.ck_slider.val)
        self._reprocess(hist=False)

    def r_update(self, val):
        """updates max_rad and re-applies eye-tracking"""
        self.params['max_rad'] = self.r_slider.val
        self._reprocess(hist=False)

    def update_pframes(self):
        """updates the filtered, binarized, convolved and gradient frames"""
        self.p_frame = img.pre_process(self.frame, self.params)
        self.b_frame, _ = img.binarize(self.p_frame, self.params)
        self.c_frame, self.center = img.square_convolve(self.p_frame, self.params['conv_size'])
        self.c_dot.remove()
        self.c_dot = self.ax_cframe.scatter(self.center[0], self.center[1], s=8, c='C0')
        self.g_frame = img.gradient(self.p_frame)
        self.edge_pts = geom.starburst(self.g_frame, self.center, self.params['max_rad'], 100)
        self.e_dots.remove()
        self.e_dots = self.ax_gframe.scatter(self.edge_pts[:, 0], self.edge_pts[:, 1], s=1, color='C0')

    def update_hist(self):
        """updates the intensity histogram"""
        # remove old histogram patches
        for patch in self.hist[2]:
            patch.remove()
        # update histogram
        self.hist = self.ax_hist.hist(self.p_frame.ravel(), bins=np.arange(256), color='C0', alpha=0.5, density=True)
        # set ylim to match current histogram
        self.ax_hist.set_ylim(top=self.hist[0].max())

    def update_display(self):
        """redraws the five frame panels and the figure canvas"""
        panels = ((self.ax_frame, self.frame), (self.ax_pframe, self.p_frame),
                  (self.ax_bframe, self.b_frame), (self.ax_cframe, self.c_frame),
                  (self.ax_gframe, self.g_frame))
        for ax, data in panels:
            # drop the previous image so repeated updates don't pile up images
            for old in list(ax.images):
                old.remove()
            ax.imshow(data, cmap='gray')
        self.fig.canvas.draw()

    def get_eye_patches(self):
        """performs eye tracking on the current frame and updates pupil and reflection patches"""
        if self.params['method'] == 'threshold':
            pupil, cr = main.itrack_threshold(self.frame, self.params)
        elif self.params['method'] == 'convolve':
            pupil, cr = main.itrack_convolve(self.frame, self.params)

        # clear any old patches; Axes.patches is read-only in Matplotlib >= 3.5,
        # so remove each artist instead of calling .clear()
        for ax in (self.ax_frame, self.ax_pframe):
            for patch in list(ax.patches):
                patch.remove()

        if not np.isnan(pupil).all():  # if fit was successful
            # parse ellipse parameters (indices 2 and 3 are doubled into width/height)
            xy = (pupil[0], pupil[1])
            width, height, theta = pupil[2] * 2, pupil[3] * 2, pupil[4]
            # create patches
            self.pupil = Ellipse(xy, width, height, angle=theta, lw=2, ec='C0', fill=False)
            self.pupil2 = Ellipse(xy, width, height, angle=theta, lw=2, ec='C0', fill=False)
            # add patches to appropriate axes
            self.ax_frame.add_patch(self.pupil)
            self.ax_pframe.add_patch(self.pupil2)

        # check that tracking was successful
        if not np.isnan(cr).all():  # if fit was successful
            # parse ellipse parameters; same doubling as for the pupil
            # (the height previously used a factor of 3)
            xy = (cr[0], cr[1])
            width, height, theta = cr[2] * 2, cr[3] * 2, cr[4]
            # create patches
            self.cr = Ellipse(xy, width, height, angle=theta, lw=2, ec='C1', fill=False)
            self.cr2 = Ellipse(xy, width, height, angle=theta, lw=2, ec='C1', fill=False)
            # add patches to appropriate axes
            self.ax_frame.add_patch(self.cr)
            self.ax_pframe.add_patch(self.cr2)
class CompareAnimation:
    """
    Launch two parallel animations of runs, so the user can easily compare
    the structures.

    Example:
        R1 = ReadRun("fake/path/run_1/")
        R2 = ReadRun("fake/path/run_2/")
        new_anim=CompareAnimation(R1.S,R2.S)
        new_anim.launch()
        plt.show()

    Each snapshot must expose a time ``t`` and coordinates ``x``, ``y``, ``z``.
    The run that starts later gets a deep copy of its own first snapshot
    prepended at the earlier start time, so both runs share a common time
    origin. The caller's lists are copied, not modified.
    """

    def __init__(self, snaplist1, snaplist2, symbol="bo", dt=None, markersize=2, **kwargs):
        # copy the lists so the padding below does not mutate the caller's data
        self.snaplists = [list(snaplist1), list(snaplist2)]
        self.tmin = min(snapl[0].t for snapl in self.snaplists)
        for snapl in self.snaplists:
            if snapl[0].t > self.tmin:
                # hold the later run on its first snapshot until it actually starts
                first = deepcopy(snapl[0])
                first.t = self.tmin
                snapl.insert(0, first)
        # numpy arrays so `times < self.t` is an elementwise comparison
        # (a plain list cannot be compared elementwise with a float)
        self.times = [np.array([s.t for s in snapl]) for snapl in self.snaplists]
        self.n = [0, 0]
        self.t = self.tmin
        self.tmax = max(self.snaplists[0][-1].t, self.snaplists[1][-1].t)
        self.dt = self.tmax / 200. if dt is None else dt
        self.delay = 1
        self.symbol = symbol
        self.markersize = markersize
        self.pause_switch = True
        self.BackgroundColor = 'white'
        self.kwargs = kwargs

    def create_frame(self):
        """Build the figure: two 3D panels, a Play/Pause button and a time slider."""
        self.fig = plt.figure(figsize=(15, 10))
        print("fig created")
        self.ax = []
        for i in range(2):
            # 'facecolor' replaces the removed 'axisbg' keyword
            ax = self.fig.add_subplot(121 + i, projection='3d', adjustable='box', facecolor=self.BackgroundColor)
            ax.set_aspect('equal')
            plt.tight_layout()
            ax.set_xlabel("x (pc)")
            ax.set_ylabel("y (pc)")
            ax.set_zlabel("z (pc)")
            self.ax.append(ax)

        X = [snapl[0].x for snapl in self.snaplists]
        Y = [snapl[0].y for snapl in self.snaplists]
        Z = [snapl[0].z for snapl in self.snaplists]
        # both panels share the view box of the first run's initial snapshot
        max_range = np.array([X[0].max() - X[0].min(),
                              Y[0].max() - Y[0].min(),
                              Z[0].max() - Z[0].min()]).max() / 3.0
        mean_x, mean_y, mean_z = X[0].mean(), Y[0].mean(), Z[0].mean()

        self.fig.subplots_adjust(bottom=0.06)
        self.line, self.canvas = [], []
        for (x, y, z, ax) in zip(X, Y, Z, self.ax):
            self.line.append(ax.plot(x, y, z, self.symbol, markersize=self.markersize, **self.kwargs)[0])
            self.canvas.append(ax.figure.canvas)
            ax.set_xlim(mean_x - max_range, mean_x + max_range)
            ax.set_ylim(mean_y - max_range, mean_y + max_range)
            ax.set_zlim(mean_z - max_range, mean_z + max_range)

        ax_pauseB = plt.axes([0.04, 0.02, 0.06, 0.025])
        self.pauseB = Button(ax_pauseB, 'Play')
        self.pauseB.on_clicked(self.Pause_button)

        # the slider spans the full time range of both runs (what the timer reaches)
        slider_ax = plt.axes([0.18, 0.02, 0.73, 0.025])
        self.slider_time = Slider(slider_ax, "Time", self.tmin, self.tmax,
                                  valinit=self.tmin, color='#AAAAAA')
        self.slider_time.on_changed(self.slider_time_update)

    def _snap_indices(self):
        """Per run, the index of the latest snapshot strictly before self.t (0 if none)."""
        indices = []
        for times in self.times:
            before = np.nonzero(times < self.t)[0]
            indices.append(int(before.max()) if len(before) else 0)
        return indices

    def update_lines(self):
        """Show, in each panel, the latest snapshot before the current time."""
        for line, snaplist, n in zip(self.line, self.snaplists, self._snap_indices()):
            line.set_data(snaplist[n].x, snaplist[n].y)
            line.set_3d_properties(snaplist[n].z)
        # both 3D axes live on the same figure canvas, so draw it only once
        self.fig.canvas.draw()

    def timer_update(self, lines):
        """Timer callback: advance by dt (wrapping past tmax) unless paused."""
        if not self.pause_switch:
            self.t = self.t + self.dt
            if self.t > self.tmax:
                self.t = self.tmin + (self.t - self.tmax)
            # set_val fires slider_time_update, which redraws the lines
            self.slider_time.set_val(self.t)

    def Pause_button(self, event):
        """Toggle between playing and paused."""
        self.pause_switch = not self.pause_switch

    def slider_time_update(self, val):
        """Slider callback: jump to time `val` and redraw."""
        self.t = val
        self.update_lines()

    def launch(self):
        """Build the figure and start the animation timer."""
        self.create_frame()
        self.timer = self.fig.canvas.new_timer(interval=self.delay)
        args = [self.line]
        # the timer calls timer_update every `self.delay` milliseconds
        self.timer.add_callback(self.timer_update, *args)
        self.timer.start()
class UI:
    """
    Two-panel 3D viewer (entry / exit axes) with a colorbar and a time slider.

    The view can be animated interactively or rendered to a file. The slider
    drives ``ss_plotdata.update_data(entry_ax, exit_ax, offset)``. ``ppf`` is
    the number of time steps advanced per animation frame.
    """

    def __init__(self, ss_plotdata, min_limit, max_limit, coloring, ppf):
        self.ss_plotdata = ss_plotdata
        self.n = self.ss_plotdata.n
        self.ppf = ppf
        self.min_limit, self.max_limit = min_limit, max_limit
        self.azimuth, self.elevation = -65, 23

        self.fig = plt.figure()
        self.fig.set_size_inches(17, 9)
        layout = gridspec.GridSpec(2, 3, width_ratios=[1, 1, 0.1], height_ratios=[1, 0.05], left=0.05)
        self.fig.suptitle(ss_plotdata.title)

        # two 3D panels sharing the same initial camera angle
        self.entry_ax = self._make_3d_axes(layout[0])
        self.exit_ax = self._make_3d_axes(layout[1])

        # colorbar; the discrete-month colorings get one labelled tick per month
        self.cb_ax = plt.subplot(layout[2])
        month_colorings = (Coloring.DISCRETE_MONTHS,
                           Coloring.DISCRETE_MONTHS_SPLIT_MARKERS,
                           Coloring.DISCRETE_MONTHS_POLYGONS)
        colorbar_kwargs = {'cax': self.cb_ax, 'label': ss_plotdata.colorbar_label}
        by_month = coloring in month_colorings
        if by_month:
            colorbar_kwargs['ticks'] = np.arange(1.5, 13.5, 1)
        cbar = self.fig.colorbar(ss_plotdata.mapping, **colorbar_kwargs)
        if by_month:
            cbar.ax.set_yticklabels(calendar.month_abbr[1:])

        # time slider spanning the bottom row
        self.slider_ax = plt.subplot(layout[1, :])
        label_format = "%0.0f " + ss_plotdata.lag_units
        self.time_slider = Slider(self.slider_ax, "Time", 0, self.n,
                                  valinit=ss_plotdata.n, valstep=1, valfmt=(label_format))
        self.timestep = 0
        self.time_slider.on_changed(self.update_time)

    def _make_3d_axes(self, spec):
        """Create a 3D subplot at `spec` using the shared camera angle."""
        axes3d = plt.subplot(spec, projection='3d')
        axes3d.view_init(elev=self.elevation, azim=self.azimuth)
        return axes3d

    def draw(self, offset):
        """Redraw both panels at time step `offset` and reapply the limits."""
        self.ss_plotdata.update_data(self.entry_ax, self.exit_ax, offset)
        self.set_limit()

    def update_time(self, val):
        """Slider callback: redraw at the slider's integer position."""
        self.timestep = int(self.time_slider.val)
        self.draw(self.timestep)
        self.fig.canvas.draw_idle()

    def set_limit(self):
        """Apply the configured axis limits to both panels."""
        for panel in (self.exit_ax, self.entry_ax):
            set_all_limits(panel, self.min_limit, self.max_limit)

    def animate(self, k):
        """Animation callback: advance by ppf steps (mod n) and move the slider."""
        self.timestep = (self.timestep + self.ppf) % self.n
        progress = 100 * self.timestep / self.n
        print('\r{0:.2f}%'.format(progress), end='')
        self.time_slider.set_val(self.timestep)
        return []

    def _build_animation(self, repeat):
        """FuncAnimation over (n + 1) // ppf frames, stepping ppf time steps each."""
        frame_count = (self.n + 1) // self.ppf
        return animation.FuncAnimation(self.fig, self.animate,
                                       self.ppf * np.arange(0, frame_count),
                                       interval=20, repeat=repeat, blit=True)

    def render_animation_to_file(self, outfile):
        print("Rendering....")
        self.ani = self._build_animation(repeat=False)
        self.ani.save(outfile, writer="ffmpeg")

    def render_image_to_file(self, outfile):
        self.draw(self.n)
        plt.draw()
        self.fig.savefig(outfile)

    def show_animation(self):
        self.ani = self._build_animation(repeat=True)
        plt.show()

    def render_image(self, offset):
        pass
class PlotFrame(wx.Frame): """ PlotFrame is a custom wxPython frame to hold the panel with a Figure and WxAgg backend canvas for matplotlib plots or other figures. In this frame: self is an instance of a wxFrame; axes is an instance of MPL Axes; fig is an instance of MPL Figure; panel is an instance of wxPanel, used for the main panel, to hold canvas, an instance of MPL FigureCanvasWxAgg. """ # Main function to set everything up when the frame is created def __init__(self, title, pos, size): """ This will be executed when an instance of PlotFrame is created. It is the place to define any globals as "self.<name>". """ wx.Frame.__init__(self, None, wx.ID_ANY, title, pos, size) if len(sys.argv) < 2: self.filename = "" else: self.filename = sys.argv[1] # set some Boolean flags self.STOP = False self.data_loaded = False self.reverse_play = False self.step = 1 # Make the main Matplotlib panel for plots self.create_main_panel() # creates canvas and contents # Then add wxPython widgets below the MPL canvas # Layout with box sizers self.sizer = wx.BoxSizer(wx.VERTICAL) self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.EXPAND) self.sizer.AddSpacer(10) self.sizer.Add(self.toolbar, 0, wx.EXPAND) self.sizer.AddSpacer(10) # Make the control panel with a row of buttons self.create_button_bar() self.sizer.Add(self.button_bar_sizer, 0, flag=wx.ALIGN_CENTER | wx.TOP) # Make a Status Bar self.statusbar = self.CreateStatusBar() self.sizer.Add(self.statusbar, 0, wx.EXPAND) self.SetStatusText("Frame created ...") # ------------------------------------------------------- # set up the Menu Bar # ------------------------------------------------------- menuBar = wx.MenuBar() menuFile = wx.Menu() # File menu menuFile.Append(1, "&Open", "Filename(s) or wildcard list to plot") menuFile.Append(3, "Save", "Save plot as a PNG image") menuFile.AppendSeparator() menuFile.Append(10, "E&xit") menuBar.Append(menuFile, "&File") menuHelp = wx.Menu() # Help menu menuHelp.Append(11, "&About Netview") 
menuHelp.Append(12, "&Usage and Help") menuHelp.Append(13, "Program &Info") menuBar.Append(menuHelp, "&Help") self.SetMenuBar(menuBar) self.panel.SetSizer(self.sizer) self.sizer.Fit(self) # ------------------------------------------------------- # Bind the menu items to functions # ------------------------------------------------------- self.Bind(wx.EVT_MENU, self.OnOpen, id=1) self.Bind(wx.EVT_MENU, self.OnSave, id=3) self.Bind(wx.EVT_MENU, self.OnQuit, id=10) self.Bind(wx.EVT_MENU, self.OnAbout, id=11) self.Bind(wx.EVT_MENU, self.OnUsage, id=12) self.Bind(wx.EVT_MENU, self.OnInfo, id=13) # methods defined below to get and plot the data # Normally do the plot on request, and not here # self.get_data_params() # self.init_plot() # self.get_xyt_data() # plot_data() # ---------- end of __init__ ---------------------------- # ------------------------------------------------------- # Function to make the main Matplotlib panel for plots # ------------------------------------------------------- def create_main_panel(self): """ create_main_panel creates the main mpl panel with instances of: * mpl Canvas * mpl Figure * mpl Figure * mpl Axes with subplot * mpl Widget class Sliders and Button * mpl navigation toolbar self.axes is the instance of MPL Axes, and is where it all happens """ self.panel = wx.Panel(self) # Create the mpl Figure and FigCanvas objects. # 3.5 x 5 inches, 100 dots-per-inch # self.dpi = 100 self.fig = Figure((3.5, 5.0), dpi=self.dpi) self.canvas = FigCanvas(self.panel, wx.ID_ANY, self.fig) # Since we have only one plot, we could use add_axes # instead of add_subplot, but then the subplot # configuration tool in the navigation toolbar wouldn't work. 
self.axes = self.fig.add_subplot(111) # (111) == (1,1,1) --> row 1, col 1, Figure 1) # self.axes.set_title("View from: "+self.filename) # Now create some sliders below the plot after making room self.fig.subplots_adjust(left=0.1, bottom=0.20) self.axtmin = self.fig.add_axes([0.2, 0.10, 0.5, 0.03]) self.axtmax = self.fig.add_axes([0.2, 0.05, 0.5, 0.03]) self.stmin = Slider(self.axtmin, 't_min:', 0.0, 1.0, valinit=0.0) self.stmax = Slider(self.axtmax, 't_max:', 0.0, 1.0, valinit=1.0) self.stmin.on_changed(self.update_trange) self.stmax.on_changed(self.update_trange) self.axbutton = self.fig.add_axes([0.8, 0.07, 0.1, 0.07]) self.reset_button = Button(self.axbutton, 'Reset') self.reset_button.color = 'skyblue' self.reset_button.hovercolor = 'lightblue' self.reset_button.on_clicked(self.reset_trange) # Create the navigation toolbar, tied to the canvas self.toolbar = NavigationToolbar(self.canvas) def update_trange(self, event): self.t_min = self.stmin.val self.t_max = self.stmax.val # print(self.t_min, self.t_max) def reset_trange(self, event): self.stmin.reset() self.stmax.reset() def create_button_bar(self): """ create_button_bar makes a control panel bar with buttons and toggles for New Data - Play - STOP - Single Step - Forward/Back - Normal/Fast It does not create a Panel container, but simply creates Button objects with bindings, and adds them to a horizontal BoxSizer self.button_bar_sizer. This is added to the PlotFrame vertical BoxSizer, after the MPL canvas, during initialization of the frame. 
""" rewind_button = wx.Button(self.panel, -1, "New Data") self.Bind(wx.EVT_BUTTON, self.OnRewind, rewind_button) replot_button = wx.Button(self.panel, -1, "Play") self.Bind(wx.EVT_BUTTON, self.OnReplot, replot_button) sstep_button = wx.Button(self.panel, -1, "Single Step") self.Bind(wx.EVT_BUTTON, self.OnSstep, sstep_button) stop_button = wx.Button(self.panel, -1, "STOP") self.Bind(wx.EVT_BUTTON, self.OnStop, stop_button) # The toggle buttons need to be globally accessible self.forward_toggle = wx.ToggleButton(self.panel, -1, "Forward") self.forward_toggle.SetValue(True) self.forward_toggle.SetLabel("Forward") self.Bind(wx.EVT_TOGGLEBUTTON, self.OnForward, self.forward_toggle) self.fast_toggle = wx.ToggleButton(self.panel, -1, " Normal ") self.fast_toggle.SetValue(True) self.fast_toggle.SetLabel(" Normal ") self.Bind(wx.EVT_TOGGLEBUTTON, self.OnFast, self.fast_toggle) # Set button colors to some simple colors that are likely # to be independent on X11 color definitions. Some nice # bit maps (from a media player skin?) 
should be used # or the buttons and toggle state colors in OnFast() below rewind_button.SetBackgroundColour('skyblue') replot_button.SetBackgroundColour('skyblue') sstep_button.SetBackgroundColour('skyblue') stop_button.SetBackgroundColour('skyblue') self.forward_toggle.SetForegroundColour('black') self.forward_toggle.SetBackgroundColour('yellow') self.fast_toggle.SetForegroundColour('black') self.fast_toggle.SetBackgroundColour('yellow') self.button_bar_sizer = wx.BoxSizer(wx.HORIZONTAL) flags = wx.ALIGN_CENTER | wx.ALL self.button_bar_sizer.Add(rewind_button, 0, border=3, flag=flags) self.button_bar_sizer.Add(replot_button, 0, border=3, flag=flags) self.button_bar_sizer.Add(sstep_button, 0, border=3, flag=flags) self.button_bar_sizer.Add(stop_button, 0, border=3, flag=flags) self.button_bar_sizer.Add(self.forward_toggle, 0, border=3, flag=flags) self.button_bar_sizer.Add(self.fast_toggle, 0, border=3, flag=flags) # ------------------------------------------------------- # Functions to generate or read (x,y) data and plot it # ------------------------------------------------------- def get_data_params(self): # These parameters would normally be provided in a file header, # past as arguments in a function, or from other file information # Next version will bring up a dialog for dt NX NY if no file header # Here check to see if a filename should be entered from File/Open # self.filename = 'Ex_net_Vm_0001.txt' if len(self.filename) == 0: # fake a button press of File/Open self.OnOpen(wx.EVT_BUTTON) # should check here if file exists as specified [path]/filename # assume it is a bzip2 compressed file try: fp = bz2.BZ2File(self.filename) line = fp.readline() except IOError: # then assume plain text fp = open(self.filename) line = fp.readline() fp.close() # check if first line is a header line starting with '#' header = line.split() if header[0][0] == "#": self.Ntimes = int(header[1]) self.t_min = float(header[2]) self.dt = float(header[3]) self.NX = int(header[4]) 
self.NY = int(header[5]) else: pdentry = self.ParamEntryDialog() if pdentry.ShowModal() == wx.ID_OK: self.Ntimes = int(pdentry.Ntimes_dialog.entry.GetValue()) self.t_min = float(pdentry.tmin_dialog.entry.GetValue()) self.dt = float(pdentry.dt_dialog.entry.GetValue()) self.NX = int(pdentry.NX_dialog.entry.GetValue()) self.NY = int(pdentry.NY_dialog.entry.GetValue()) print 'Ntimes = ', self.Ntimes, ' t_min = ', self.t_min print 'NX = ', self.NX, ' NY = ', self.NY pdentry.Destroy() self.t_max = (self.Ntimes - 1) * self.dt # reset slider max and min self.stmin.valmax = self.t_max self.stmin.valinit = self.t_min self.stmax.valmax = self.t_max self.stmax.valinit = self.t_max self.stmax.set_val(self.t_max) self.stmin.reset() self.stmax.reset() fp.close() def init_plot(self): ''' init_plot creates the initial plot display. A normal MPL plot would be created here with a command "self.axes.plot(x, y)" in order to create a plot of points in the x and y arrays on the Axes subplot. Here, we create an AxesImage instance with imshow(), instead. The initial image is a blank one of the proper dimensions, filled with zeroes. 
''' self.t_max = (self.Ntimes - 1) * self.dt self.axes.set_title("View of " + self.filename) # Note that NumPy array (row, col) = image (y, x) data0 = np.zeros((self.NY, self.NX)) # Define a 'cold' to 'hot' color scale based in GENESIS 2 'hot' hotcolors = [ '#000032', '#00003c', '#000046', '#000050', '#00005a', '#000064', '#00006e', '#000078', '#000082', '#00008c', '#000096', '#0000a0', '#0000aa', '#0000b4', '#0000be', '#0000c8', '#0000d2', '#0000dc', '#0000e6', '#0000f0', '#0000fa', '#0000ff', '#000af6', '#0014ec', '#001ee2', '#0028d8', '#0032ce', '#003cc4', '#0046ba', '#0050b0', '#005aa6', '#00649c', '#006e92', '#007888', '#00827e', '#008c74', '#00966a', '#00a060', '#00aa56', '#00b44c', '#00be42', '#00c838', '#00d22e', '#00dc24', '#00e61a', '#00f010', '#00fa06', '#00ff00', '#0af600', '#14ec00', '#1ee200', '#28d800', '#32ce00', '#3cc400', '#46ba00', '#50b000', '#5aa600', '#649c00', '#6e9200', '#788800', '#827e00', '#8c7400', '#966a00', '#a06000', '#aa5600', '#b44c00', '#be4200', '#c83800', '#d22e00', '#dc2400', '#e61a00', '#f01000', '#fa0600', '#ff0000', '#ff0a00', '#ff1400', '#ff1e00', '#ff2800', '#ff3200', '#ff3c00', '#ff4600', '#ff5000', '#ff5a00', '#ff6400', '#ff6e00', '#ff7800', '#ff8200', '#ff8c00', '#ff9600', '#ffa000', '#ffaa00', '#ffb400', '#ffbe00', '#ffc800', '#ffd200', '#ffdc00', '#ffe600', '#fff000', '#fffa00', '#ffff00', '#ffff0a', '#ffff14', '#ffff1e', '#ffff28', '#ffff32', '#ffff3c', '#ffff46', '#ffff50', '#ffff5a', '#ffff64', '#ffff6e', '#ffff78', '#ffff82', '#ffff8c', '#ffff96', '#ffffa0', '#ffffaa', '#ffffb4', '#ffffbe', '#ffffc8', '#ffffd2', '#ffffdc', '#ffffe6', '#fffff0' ] cmap = matplotlib.colors.ListedColormap(hotcolors) self.im = self.axes.imshow(data0, cmap=cmap, origin='lower') # http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html # shows examples to use as a 'cold' to 'hot' mapping of value to color # cm.jet, cm.gnuplot and cm.afmhot are good choices, but are unlike G2 'hot' self.im.cmap = cmap # Not sure how 
to properly add a colorbar # self.cb = self.fig.colorbar(self.im, orientation='vertical') # refresh the canvas self.canvas.draw() def get_xyt_data(self): # Create scaled (0-1) luminance(x,y) array from ascii G-2 disk_out file # get the data to plot from the specified filename # Note that NumPy loadtxt transparently deals with bz2 compression self.SetStatusText('Data loading - please wait ....') rawdata = np.loadtxt(self.filename) # Note the difference between NumPy [row, col] order and network # x-y grid (x, y) = (col, row). We want a NumPy NY x NX, not # NX x NY, array to be used by the AxesImage object. xydata = np.resize(rawdata, (self.Ntimes, self.NY, self.NX)) # imshow expects the data to be scaled to range 0-1. Vmin = xydata.min() Vmax = xydata.max() self.ldata = (xydata - Vmin) / (Vmax - Vmin) self.data_loaded = True self.SetStatusText('Data has been loaded - click Play') def plot_data(self): ''' plot_data() shows successive frames of the data that was loaded into the ldata array. Creating a new self.im AxesImage instance for each frame is extremely slow, so the set_data method of AxesImage is used to load new data into the existing self.im for each frame. Normally 'self.canvas.draw()' would be used to display a frame, but redrawing the entire canvas, redraws the axes, labels, sliders, buttons, or anything else on the canvas. This uses a method taken from an example in Ch 7, p. 192 Matplotlib for Python developers, with draw_artist() and blit() redraw only the part that was changed. ''' if self.data_loaded == False: # bring up a warning dialog msg = """ Data for plotting has not been loaded! Please enter the file to plot with File/Open, unless it was already specified, and then click on 'New Data' to load the data to play back, before clicking 'Play'. 
""" wx.MessageBox(msg, "Plot Warning", wx.OK | wx.ICON_ERROR, self) return # set color limits self.im.set_clim(0.0, 1.0) self.im.set_interpolation('nearest') # 'None' is is slightly faster, but not implemented for MPL ver < 1.1 # self.im.set_interpolation('None') # do an initial draw, then save the empty figure axes self.canvas.draw() # self.bg = self.canvas.copy_from_bbox(self.axes.bbox) # However the save and restore is only needed if I change # axes legends, etc. The draw_artist(artist), and blit # are much faster than canvas.draw() and are sufficient. print 'system time (seconds) = ', time.time() # round frame_min down and frame_max up for the time window frame_min = int(self.t_min / self.dt) frame_max = min(int(self.t_max / self.dt) + 1, self.Ntimes) frame_step = self.step # Displaying simulation time to the status bar is much faster # than updating a slider progress bar, but location isn't optimum. # The check for the STOP button doesn't work because the button # click is not registered until this function exits. # check to see if self.reverse_play == True # then interchange frame_min, frame_max, and use negative step if self.reverse_play == True: frame_min = min(int(self.t_max / self.dt) + 1, self.Ntimes) - 1 frame_max = int(self.t_min / self.dt) - 1 frame_step = -self.step for frame_num in range(frame_min, frame_max, frame_step): self.SetStatusText('time: ' + str(frame_num * self.dt)) if self.STOP == True: self.t_min = frame_num * self.dt # set t_min slider ? 
self.STOP = False break self.im.set_data(self.ldata[frame_num]) self.axes.draw_artist(self.im) self.canvas.blit(self.axes.bbox) print 'system time (seconds) = ', time.time() # ------------------------------------------------------------------ # Define the classes and functions for getting parameter values # -------------------------------------------------------------- class ParamEntryDialog(wx.Dialog): def __init__(self): wx.Dialog.__init__(self, None, wx.ID_ANY) self.SetSize((250, 200)) self.SetTitle('Enter Data File Parameters') vbox = wx.BoxSizer(wx.VERTICAL) self.Ntimes_dialog = XDialog(self) self.Ntimes_dialog.entry_label.SetLabel('Number of entries') self.Ntimes_dialog.entry.ChangeValue(str(2501)) self.tmin_dialog = XDialog(self) self.tmin_dialog.entry_label.SetLabel('Start time (sec)') self.tmin_dialog.entry.ChangeValue(str(0.0)) self.dt_dialog = XDialog(self) self.dt_dialog.entry_label.SetLabel('Output time step (sec)') self.dt_dialog.entry.ChangeValue(str(0.0002)) self.NX_dialog = XDialog(self) self.NX_dialog.entry_label.SetLabel('Number of cells on x-axis') self.NX_dialog.entry.ChangeValue(str(32)) self.NY_dialog = XDialog(self) self.NY_dialog.entry_label.SetLabel('Number of cells on y-axis') self.NY_dialog.entry.ChangeValue(str(32)) vbox.Add(self.Ntimes_dialog, 0, wx.EXPAND | wx.ALL, border=5) vbox.Add(self.tmin_dialog, 0, wx.EXPAND | wx.ALL, border=5) vbox.Add(self.dt_dialog, 0, wx.EXPAND | wx.ALL, border=5) vbox.Add(self.NX_dialog, 0, wx.EXPAND | wx.ALL, border=5) vbox.Add(self.NY_dialog, 0, wx.EXPAND | wx.ALL, border=5) okButton = wx.Button(self, wx.ID_OK, 'Ok') # vbox.Add(okButton,flag=wx.ALIGN_CENTER|wx.TOP|wx.BOTTOM, border=10) vbox.Add(okButton, flag=wx.ALIGN_CENTER, border=10) self.SetSizer(vbox) self.SetSizerAndFit(vbox) # ------------------------------------------------------------------ # Define the functions executed on menu choices # --------------------------------------------------------------- def OnQuit(self, event): self.Close() def 
OnSave(self, event): file_choices = "PNG (*.png)|*.png" dlg = wx.FileDialog(self, message="Save plot as...", defaultDir=os.getcwd(), defaultFile="plot.png", wildcard=file_choices, style=wx.SAVE) if dlg.ShowModal() == wx.ID_OK: path = dlg.GetPath() self.canvas.print_figure(path, dpi=self.dpi) # self.flash_status_message("Saved to %s" % path) def OnAbout(self, event): msg = """ G-3 Netview ver. 1.7 Netview is a stand-alone Python application for viewing the output of GENESIS 2 and 3 network simulations. It is intended to replace GENESIS 2 SLI scripts that use the XODUS 'xview' widget. The design and operation is based on the G3Plot application for creating 2D plots of y(t) or y(x) from data files. Unlike G3Plot, the image created with Netview is an animated representation of a rectangular network with colored squares used to indicate the value of some variable at that position and time. Typically, this would be the membrane potenial of a cell soma, or a synaptic current in a dendrite segment. Help/Usage gives HTML help for using Netview. This is the main Help page. Help/Program Info provides some information about the objects and functions, and the wxPython and matplotlib classes used here. 
Dave Beeman, August 2012 """ dlg = wx.MessageDialog(self, msg, "About G-3 Netview", wx.OK | wx.ICON_QUESTION) dlg.ShowModal() dlg.Destroy() def OnOpen(self, event): dlg = wx.TextEntryDialog(self, "File with x,y data to plot", "File Open", self.filename, style=wx.OK | wx.CANCEL) if dlg.ShowModal() == wx.ID_OK: self.filename = dlg.GetValue() # A new filename has been entered, but the data has not been read self.data_loaded = False # print "You entered: %s" % self.filename dlg.Destroy() # This starts with the long string of HTML to display class UsageFrame(wx.Frame): text = """ <HTML> <HEAD></HEAD> <BODY BGCOLOR="#D6E7F7"> <CENTER><H1>Using G-3 Netview</H1></CENTER> <H2>Introduction and Quick Start</H2> <p>Netview is a stand-alone Python application for viewing the output of GENESIS 2 and 3 network simulations. It is intended to replace GENESIS 2 SLI scripts that use the XODUS 'xview' widget.</p> <p>The design and operation is based on the G3Plot application for creating 2D plots of y(t) or y(x) from data files. As with G3Plot, the main class PlotFrame uses a basic wxPython frame to embed a matplotlib figure for plotting. It defines some basic menu items and a control panel of buttons and toggles, each with bindings to a function to execute on a mouse click.</p> <p>Unlike G3Plot, the image created with Netview is an animated representation of a rectangular network with colored squares used to indicate the value of some variable at that position and time. Typically, this would be the membrane potenial of a cell soma, or a synaptic current in a dendrite segment.</p> <h2>Usage</h2> <p>The Menu Bar has <em>File/Open</em>, <em>File/Save</em>, and <em>File/Exit</em> choices. The Help Menu choices <em>About</em> and <em>Usage</em> give further information. 
The <em>Program Info</em> selection shows code documentation that is contained in some of the main function <em>docstrings</em>.</p> <p>After starting the <em>netview</em> program, enter a data file name in the dialog for File/Open, unless the filename was given as a command line argument. Then click on <strong>New Data</strong> to load the new data and initialize the plot. When the plot is cleared to black, press <strong>Play</strong>.</p> <p>The file types recognized are plain text or text files compressed with bzip2. The expected data format is one line for each output time step, with each line having the membrane potential value of each cell in the net. No time value should be given on the line. In order to properly display the data, netview needs some additional information about the network and the data. This can optionally be contained in a header line that precedes the data. If a header is not detected, a dialog will appear asking for the needed parameters.</p> <p>It is assumed that the cells are arranged on a NX x NY grid, numbered from 0 (bottom left corner) to NX*NY - 1 (upper right corner). In order to provide this information to netview, the data file should begin with a header line of the form:</p> <pre> #optional_RUNID_string Ntimes start_time dt NX NY SEP_X SEP_Y x0 y0 z0 </pre> <p>The line must start with "#" and can optionally be followed immediately by any string. Typically this is some identification string generated by the simulation run. 
The following parameters, separated by blanks or any whitespace, are:</p> <ul> <li>Ntimes - the number of lines in the file, exclusive of the header</li> <li>start_time - the simulation time for the first data line (default 0.0)</li> <li>dt - the time step used for output</li> <li>NX, NY - the integer dimensions of the network</li> <li>SEP_X, SEP_Y - the x,y distances between cells (optional)</li> <li>x0, y0, z0 - the location of the compartment (data source) relative to the cell origin</li> </ul> <p>The RUNID string and the last five parameters are not read or used by netview. These are available for other data analysis tools that need a RUNID and the location of each source.</p> <p>The slider bars can be used to set a time window for display, and the <strong>Reset</strong> button can set t_min and t_max back to the defaults. Use the <strong>Forward/Back</strong> toggle to reverse direction of <strong>Play</strong>, and the <strong>Normal/Fast</strong> toggle to show every tenth frame.</p> <p>The <strong>Single Step</strong> button can be used to advance a single step at a time (or 10, if in 'Fast' mode).</p> <p>The <strong>STOP</strong> button is currently not implemented</p> <p>To plot different data, enter a new filename with <strong>File/Open</strong> and repeat with <strong>New Data</strong> and <strong>Play</strong>.</p> <HR> </BODY> </HTML> """ def __init__(self, parent): wx.Frame.__init__(self, parent, -1, "Usage and Help", size=(640, 600), pos=(400, 100)) html = wx.html.HtmlWindow(self) html.SetPage(self.text) panel = wx.Panel(self, -1) button = wx.Button(panel, wx.ID_OK, "Close") self.Bind(wx.EVT_BUTTON, self.OnCloseMe, button) sizer = wx.BoxSizer(wx.VERTICAL) sizer.Add(html, 1, wx.EXPAND | wx.ALL, 5) sizer.Add(panel, 0, wx.ALIGN_CENTER | wx.ALL, 5) self.SetSizer(sizer) self.Layout() def OnCloseMe(self, event): self.Close(True) # ----------- end of class UsageFrame --------------- def OnUsage(self, event): usagewin = self.UsageFrame(self) 
usagewin.Show(True) def OnInfo(self, event): msg = "Program information for PlotFrame obtained from docstrings:" msg += "\n" + self.__doc__ + "\n" + self.create_main_panel.__doc__ msg += self.create_button_bar.__doc__ msg += self.init_plot.__doc__ msg += self.plot_data.__doc__ dlg = wx.lib.dialogs.ScrolledMessageDialog(self, msg, "PlotFrame Documentation") dlg.ShowModal() # --------------------------------------------------------------- # Define the functions executed on control button click # --------------------------------------------------------------- def OnRewind(self, event): self.get_data_params() self.init_plot() self.get_xyt_data() def OnReplot(self, event): self.plot_data() self.canvas.draw() def OnSstep(self, event): if self.data_loaded == False: # bring up a warning dialog msg = """ Data for plotting has not been loaded! Please enter the file to plot with File/Open, unless it was already specified, and then click on 'New Data' to load the data to play back, before clicking 'Play'. 
""" wx.MessageBox(msg, "Plot Warning", wx.OK | wx.ICON_ERROR, self) return self.t_max = min(self.t_max + self.dt, (self.Ntimes - 1) * self.dt) self.stmax.set_val(self.t_max) frame_num = int(self.t_max / self.dt) self.SetStatusText('time: ' + str(frame_num * self.dt)) self.im.set_data(self.ldata[frame_num]) self.axes.draw_artist(self.im) self.canvas.blit(self.axes.bbox) def OnStop(self, event): self.STOP = 'True' def OnForward(self, event): state = self.forward_toggle.GetValue() if state: self.reverse_play = False self.forward_toggle.SetLabel("Forward ") self.forward_toggle.SetForegroundColour('black') self.forward_toggle.SetBackgroundColour('yellow') else: self.reverse_play = True self.forward_toggle.SetLabel(" Back ") self.forward_toggle.SetForegroundColour('red') self.forward_toggle.SetBackgroundColour('green') def OnFast(self, event): state = self.fast_toggle.GetValue() if state: # print state self.fast_toggle.SetLabel(" Normal ") self.fast_toggle.SetForegroundColour('black') self.fast_toggle.SetBackgroundColour('yellow') self.step = 1 else: # print state self.fast_toggle.SetLabel(" Fast ") self.fast_toggle.SetForegroundColour('red') self.fast_toggle.SetBackgroundColour('green') self.step = 10
class ParametricEQSelector:
    """Interactive matplotlib editor for a multi-band yodel ParametricEQ.

    The figure shows the filter's amplitude response or phase response.
    Radio buttons choose the plot style and the band being edited. Three
    sliders set the selected band's center frequency, Q factor and dB gain.
    Every change re-applies the band to the equalizer and redraws the
    response.
    """

    def __init__(self):
        self._fs = 48000
        self._nbands = 7
        default_q = 1.0/math.sqrt(2.0)
        # One independent parameter dict per band.
        self._params = [
            {'center': 0, 'resonance': default_q, 'dbgain': 0}
            for _ in range(self._nbands)
        ]
        self._selected_band = 0
        self._blocksize = 512
        self._nfft = int(self._blocksize / 2)
        # Four distinct zero-filled buffers (not one shared list).
        (self._impulse_response,
         self._freq_response_real,
         self._freq_response_imag,
         self._response) = ([0] * self._blocksize for _ in range(4))
        self._plot_db = True
        self._eq = yodel.filter.ParametricEQ(self._fs, self._nbands)
        self._create_plot()
        self._create_plot_controls()
        self.select_band(self._band_label(self._selected_band))

    @staticmethod
    def _band_label(index):
        """Return the radio-button label ('Band N', 1-based) for band *index*."""
        return 'Band ' + str(index + 1)

    def _create_plot(self):
        """Create the figure and axes, and draw the zero line and the response curve."""
        fig, ax = plt.subplots()
        self._fig, self._ax = fig, ax
        ax.set_title('Parametric Equalizer Design')
        ax.grid()
        # Leave room at the bottom of the figure for the widgets.
        plt.subplots_adjust(bottom=0.3)
        self._update_filter_response()
        bin_width = self._fs/2/self._nfft
        self._x_axis = [k * bin_width for k in range(self._nfft)]
        self._y_axis = self._response[0:self._nfft]
        self._l_center, = ax.plot(self._x_axis, [0] * self._nfft, 'k')
        self._l_fr, = ax.plot(self._x_axis, self._y_axis, 'b')
        self._rescale_plot()

    def _create_plot_controls(self):
        """Create the plot-style and band radio buttons and the three parameter sliders."""
        self._dbrax = plt.axes([0.12, 0.05, 0.13, 0.10])
        self._dbradio = RadioButtons(self._dbrax, ('Amplitude', 'Phase'))
        self._dbradio.on_clicked(self.set_plot_style)

        self._rax = plt.axes([0.27, 0.03, 0.15, 0.20])
        labels = tuple(self._band_label(n) for n in range(self._nbands))
        self._radio = RadioButtons(self._rax, labels)
        self._radio.on_clicked(self.select_band)

        self._sfax = plt.axes([0.6, 0.19, 0.2, 0.03])
        self._sqax = plt.axes([0.6, 0.12, 0.2, 0.03])
        self._sdbax = plt.axes([0.6, 0.05, 0.2, 0.03])

        current = self._params[self._selected_band]
        self._fcslider = Slider(self._sfax, 'Cut-off frequency', 0, self._fs/2,
                                valinit=current['center'])
        self._qslider = Slider(self._sqax, 'Q factor', 0.01, 10.0,
                               valinit=current['resonance'])
        self._dbslider = Slider(self._sdbax, 'dB gain', -20.0, 20.0,
                                valinit=current['dbgain'])

        self._fcslider.on_changed(self.set_center_frequency)
        self._qslider.on_changed(self.set_resonance)
        self._dbslider.on_changed(self.set_dbgain)

    def _rescale_plot(self):
        """Set the y-limits to +/-30 in amplitude mode or +/-200 in phase mode, then redraw."""
        limit = 30 if self._plot_db else 200
        self._ax.set_ylim(-limit, limit)
        plt.draw()

    def _plot_frequency_response(self, redraw=True):
        """Recompute the response and update the response line; redraw if *redraw*."""
        self._update_filter_response()
        self._y_axis = self._response[0:self._nfft]
        self._l_fr.set_ydata(self._y_axis)
        if not redraw:
            return
        plt.draw()

    def _plot_range_limits(self, redraw=True):
        """Reset the reference line to all zeros; redraw if *redraw*."""
        self._l_center.set_ydata([0] * self._nfft)
        if not redraw:
            return
        plt.draw()

    def set_plot_style(self, style):
        """Radio callback: switch between the 'Amplitude' and 'Phase' views.

        Any other label leaves the current mode unchanged.
        """
        self._plot_db = {'Phase': False, 'Amplitude': True}.get(style, self._plot_db)
        self._plot_range_limits(False)
        self._plot_frequency_response(False)
        self._rescale_plot()

    def select_band(self, band):
        """Radio callback: make the band named 'Band N' current and sync the sliders to it.

        Each ``set_val`` fires the slider's on_changed callback. That
        callback re-stores the same value for the newly selected band.
        """
        self._selected_band = int(band.split(' ')[1]) - 1
        band_params = self._params[self._selected_band]
        sliders = ((self._fcslider, 'center'),
                   (self._qslider, 'resonance'),
                   (self._dbslider, 'dbgain'))
        for slider, key in sliders:
            slider.set_val(band_params[key])

    def _store_param(self, key, val):
        """Store *val* under *key* for the selected band, apply it, and replot."""
        self._params[self._selected_band][key] = val
        self._set_band(self._selected_band)
        self._plot_frequency_response()

    def set_center_frequency(self, val):
        """Slider callback for the selected band's center frequency."""
        self._store_param('center', val)

    def set_resonance(self, val):
        """Slider callback for the selected band's Q factor."""
        self._store_param('resonance', val)

    def set_dbgain(self, val):
        """Slider callback for the selected band's gain in dB."""
        self._store_param('dbgain', val)

    def _set_band(self, band):
        """Push the stored parameters of *band* into the yodel equalizer."""
        p = self._params[band]
        self._eq.set_band(band, p['center'], p['resonance'], p['dbgain'])

    def _update_filter_response(self):
        """Recompute the impulse and frequency responses, then the displayed response.

        The displayed response is the amplitude response or the phase
        response, depending on the plot mode.
        """
        self._impulse_response = impulse_response(self._eq, self._blocksize)
        real, imag = frequency_response(self._impulse_response)
        self._freq_response_real, self._freq_response_imag = real, imag
        to_response = amplitude_response if self._plot_db else phase_response
        self._response = to_response(real, imag)
class EnergyPlusModel(metaclass=ABCMeta): def __init__(self, model_file, log_dir=None, verbose=False): self.log_dir = log_dir self.model_basename = os.path.splitext(os.path.basename(model_file))[0] self.setup_spaces() self.action = 0.5 * (self.action_space.low + self.action_space.high) self.action_prev = self.action self.raw_state = None self.verbose = verbose self.timestamp_csv = None self.sl_episode = None # Progress data self.num_episodes = 0 self.num_episodes_last = 0 self.reward = None self.reward_mean = None def reset(self): pass # Parse date/time format from EnergyPlus and return datetime object with correction for 24:00 case def _parse_datetime(self, dstr): # ' MM/DD HH:MM:SS' or 'MM/DD HH:MM:SS' # Dirty hack if dstr[0] != ' ': dstr = ' ' + dstr # year = 2017 year = 2013 # for CHICAGO_IL_USA TMY2-94846 month = int(dstr[1:3]) day = int(dstr[4:6]) hour = int(dstr[8:10]) minute = int(dstr[11:13]) sec = 0 msec = 0 if hour == 24: hour = 0 dt = datetime(year, month, day, hour, minute, sec, msec) + timedelta(days=1) else: dt = datetime(year, month, day, hour, minute, sec, msec) return dt # Convert list of date/time string to list of datetime objects def _convert_datetime24(self, dates): # ' MM/DD HH:MM:SS' dates_new = [] for d in dates: # year = 2017 # month = int(d[1:3]) # day = int(d[4:6]) # hour = int(d[8:10]) # minute = int(d[11:13]) # sec = 0 # msec = 0 # if hour == 24: # hour = 0 # d_new = datetime(year, month, day, hour, minute, sec, msec) + dt.timedelta(days=1) # else: # d_new = datetime(year, month, day, hour, minute, sec, msec) # dates_new.append(d_new) dates_new.append(self._parse_datetime(d)) return dates_new # Generate x_pos and x_labels def generate_x_pos_x_labels(self, dates): time_delta = self._parse_datetime(dates[1]) - self._parse_datetime( dates[0]) x_pos = [] x_labels = [] for i, d in enumerate(dates): dt = self._parse_datetime(d) - time_delta if dt.hour == 0 and dt.minute == 0: x_pos.append(i) x_labels.append(dt.strftime('%m/%d')) return 
x_pos, x_labels def set_action(self, action): # In TPRO/POP1/POP2 in baseline, action seems to be normalized to [-1.0, 1.0]. # So it must be scaled back into action_space by the environment. assert action.shape == self.action_space.low.shape, 'Invalid action {}'.format( action) self.action_prev = self.action self.action = action self.action = np.clip(self.action, self.action_space.low, self.action_space.high) # self.action_prev = self.action # self.action = self.action_space.low + (normalized_action + 1.) * 0.5 * ( # self.action_space.high - self.action_space.low) # self.action = np.clip(self.action, self.action_space.low, self.action_space.high) @abstractmethod def setup_spaces(self): pass # Need to handle the case that raw_state is None @abstractmethod def set_raw_state(self, raw_state): pass def get_state(self): return self.format_state(self.raw_state) @abstractmethod def compute_reward(self): pass @abstractmethod def format_state(self, raw_state): pass # -------------------------------------------------- # Plotting staffs follow # -------------------------------------------------- def plot(self, log_dir='', csv_file='', **kwargs): if log_dir is not '': if not os.path.isdir(log_dir): print('energyplus_model.plot: {} is not a directory'.format( log_dir)) return print('energyplus_plot.plot log={}'.format(log_dir)) self.log_dir = log_dir self.show_progress() else: if not os.path.isfile(csv_file): print( 'energyplus_model.plot: {} is not a file'.format(csv_file)) return print('energyplus_model.plot csv={}'.format(csv_file)) self.read_episode(csv_file) plt.rcdefaults() plt.rcParams['font.size'] = 6 plt.rcParams['lines.linewidth'] = 1.0 plt.rcParams['legend.loc'] = 'lower right' self.fig = plt.figure(1, figsize=(16, 10)) self.plot_episode(csv_file) plt.show() # Show convergence def show_progress(self): self.monitor_file = self.log_dir + '/monitor.csv' # Read progress file if not self.read_monitor_file(): print('Progress data is missing') sys.exit(1) # Initialize graph 
plt.rcdefaults() plt.rcParams['font.size'] = 6 plt.rcParams['lines.linewidth'] = 1.0 plt.rcParams['legend.loc'] = 'lower right' self.fig = plt.figure(1, figsize=(16, 10)) # Show widgets axcolor = 'lightgoldenrodyellow' self.axprogress = self.fig.add_axes([0.15, 0.10, 0.70, 0.15], facecolor=axcolor) self.axslider = self.fig.add_axes([0.15, 0.04, 0.70, 0.02], facecolor=axcolor) axfirst = self.fig.add_axes([0.15, 0.01, 0.03, 0.02]) axlast = self.fig.add_axes([0.82, 0.01, 0.03, 0.02]) axprev = self.fig.add_axes([0.46, 0.01, 0.03, 0.02]) axnext = self.fig.add_axes([0.51, 0.01, 0.03, 0.02]) # Slider is drawn in plot_progress() # First/Last button self.button_first = Button(axfirst, 'First', color=axcolor, hovercolor='0.975') self.button_first.on_clicked(self.first_episode_num) self.button_last = Button(axlast, 'Last', color=axcolor, hovercolor='0.975') self.button_last.on_clicked(self.last_episode_num) # Next/Prev button self.button_prev = Button(axprev, 'Prev', color=axcolor, hovercolor='0.975') self.button_prev.on_clicked(self.prev_episode_num) self.button_next = Button(axnext, 'Next', color=axcolor, hovercolor='0.975') self.button_next.on_clicked(self.next_episode_num) # Timer self.timer = self.fig.canvas.new_timer(interval=1000) self.timer.add_callback(self.check_update) self.timer.start() # Progress data self.axprogress.set_xmargin(0) self.axprogress.set_xlabel('Episodes') self.axprogress.set_ylabel('Reward') self.axprogress.grid(True) self.plot_progress() # Plot latest episode self.update_episode(self.num_episodes - 1) plt.show() def check_update(self): if self.read_monitor_file(): self.plot_progress() def plot_progress(self): # Redraw all lines self.axprogress.lines = [] self.axprogress.plot(self.reward, color='#1f77b4', label='Reward') # self.axprogress.plot(self.reward_mean, color='#ff7f0e', label='Reward (average)') self.axprogress.legend() # Redraw slider if self.sl_episode is None or int(round( self.sl_episode.val)) == self.num_episodes - 2: cur_ep = 
self.num_episodes - 1 else: cur_ep = int(round(self.sl_episode.val)) self.axslider.clear() # self.sl_episode = Slider(self.axslider, 'Episode (0..{})'.format(self.num_episodes - 1), 0, self.num_episodes - 1, valinit=self.num_episodes - 1, valfmt='%6.0f') self.sl_episode = Slider(self.axslider, 'Episode (0..{})'.format(self.num_episodes - 1), 0, self.num_episodes - 1, valinit=cur_ep, valfmt='%6.0f') self.sl_episode.on_changed(self.set_episode_num) def read_monitor_file(self): # For the very first call, Wait until monitor.csv is created if self.timestamp_csv is None: while not os.path.isfile(self.monitor_file): time.sleep(1) self.timestamp_csv = os.stat( self.monitor_file ).st_mtime - 1 # '-1' is a hack to prevent losing the first set of data num_ep = 0 ts = os.stat(self.monitor_file).st_mtime if ts > self.timestamp_csv: # Monitor file is updated. self.timestamp_csv = ts f = open(self.monitor_file) firstline = f.readline() assert firstline.startswith('#') metadata = json.loads(firstline[1:]) assert metadata['env_id'] == "EnergyPlus-v0" assert set(metadata.keys()) == { 'env_id', 't_start' }, "Incorrect keys in monitor metadata" df = pd.read_csv(f, index_col=None) assert set(df.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" f.close() self.reward = [] self.reward_mean = [] self.episode_dirs = [] self.num_episodes = 0 for rew, len, time_ in zip(df['r'], df['l'], df['t']): self.reward.append(rew / len) self.reward_mean.append(rew / len) self.episode_dirs.append( self.log_dir + '/output/episode-{:08d}'.format(self.num_episodes)) self.num_episodes += 1 if self.num_episodes > self.num_episodes_last: self.num_episodes_last = self.num_episodes return True else: return False def update_episode(self, ep): self.plot_episode(ep) def set_episode_num(self, val): ep = int(round(self.sl_episode.val)) self.update_episode(ep) def first_episode_num(self, val): self.sl_episode.set_val(0) def last_episode_num(self, val): self.sl_episode.set_val(self.num_episodes - 1) def 
prev_episode_num(self, val): ep = int(round(self.sl_episode.val)) if ep > 0: ep -= 1 self.sl_episode.set_val(ep) def next_episode_num(self, val): ep = int(round(self.sl_episode.val)) if ep < self.num_episodes - 1: ep += 1 self.sl_episode.set_val(ep) def show_statistics(self, title, series): print('{:25} ave={:5,.2f}, min={:5,.2f}, max={:5,.2f}, std={:5,.2f}'. format(title, np.average(series), np.min(series), np.max(series), np.std(series))) def get_statistics(self, series): return np.average(series), np.min(series), np.max(series), np.std( series) def show_distrib(self, title, series): dist = [0 for i in range(1000)] for v in series: idx = int(math.floor(v * 10)) if idx >= 1000: idx = 999 if idx < 0: idx = 0 dist[idx] += 1 print(title) print( ' degree 0.0-0.9 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9' ) print( ' -------------------------------------------------------------------------' ) for t in range(170, 280, 10): print(' {:4.1f}C {:5.1%} '.format( t / 10.0, sum(dist[t:(t + 10)]) / len(series)), end='') for tt in range(t, t + 10): print(' {:5.1%}'.format(dist[tt] / len(series)), end='') print('') def get_episode_list(self, log_dir='', csv_file=''): if (log_dir is not '' and csv_file is not '') or (log_dir is '' and csv_file is ''): print('Either one of log_dir or csv_file must be specified') quit() if log_dir is not '': if not os.path.isdir(log_dir): print('energyplus_model.dump: {} is not a directory'.format( log_dir)) return print('energyplus_plot.dump: log={}'.format(log_dir)) # self.log_dir = log_dir # Make a list of all episodes # Note: Somethimes csv file is missing in the episode directories # We accept gziped csv file also. 
csv_list = glob(log_dir + '/output/episode-????????/eplusout.csv') \ + glob(log_dir + '/output/episode-????????/eplusout.csv.gz') self.episode_dirs = list( set([os.path.dirname(i) for i in csv_list])) self.episode_dirs.sort() self.num_episodes = len(self.episode_dirs) else: # csv_file != '' self.episode_dirs = [os.path.dirname(csv_file)] self.num_episodes = len(self.episode_dirs) # Model dependent methods @abstractmethod def read_episode(self, ep): pass @abstractmethod def plot_episode(self, ep): pass @abstractmethod def dump_timesteps(self, log_dir='', csv_file='', **kwargs): pass @abstractmethod def dump_episodes(self, log_dir='', csv_file='', **kwargs): pass
class Player(FuncAnimation):
    """
    The class makes a player with play, stop, and next buttons and a frame slider.

    Frame indices cycle through the half-open range ``[dis_start, dis_stop)``;
    stepping past either end wraps around to the other end.
    """

    def __init__(self, fig, func, init_func=None, fargs=None, save_count=None,
                 button_color='yellow', bg_color='red', dis_start=0,
                 dis_stop=100, pos=(0.125, 0.05), **kwargs):
        """
        initialization
        :param fig: matplotlib figure object
        :param func: user-defined function which takes an integer (frame number) as an input
        :param init_func: user-defined initial function used by the FuncAnimation class
        :param fargs: arguments of func, used by FuncAnimation class
        :param save_count: save count arg used by FuncAnimation class
        :param button_color: string, color of the buttons of the player
        :param bg_color: string, hovercolor of the buttons and slider
        :param dis_start: int, start frame number (inclusive)
        :param dis_stop: int, stop frame number (exclusive)
        :param pos: length 2 tuple, position of the buttons
        :param kwargs: kwargs for FuncAnimation class
        :raises ValueError: if ``dis_stop <= dis_start`` (no frames to show)
        """
        if dis_stop <= dis_start:
            # The index setter wraps modulo the range length; an empty range
            # would otherwise fail later with ZeroDivisionError.
            raise ValueError(
                'dis_stop ({}) must be greater than dis_start ({})'.format(
                    dis_stop, dis_start))
        # setting up the index
        self.start_ind = dis_start
        self.stop_ind = dis_stop
        self.dis_length = self.stop_ind - self.start_ind
        self.ind = self.start_ind
        self.runs = True
        self.forwards = True
        self.fig = fig
        self.fig.set_facecolor('k')
        self.button_color = button_color
        self.bg_color = bg_color
        self.func = func
        self.setup(pos)
        FuncAnimation.__init__(self, self.fig, self.func, frames=self.play(),
                               init_func=init_func, fargs=fargs,
                               save_count=save_count, **kwargs)

    @property
    def ind(self):
        """Current frame index, always within ``[start_ind, stop_ind)``."""
        return self._ind

    @ind.setter
    def ind(self, val):
        # Wrap any value into [start_ind, stop_ind).
        self._ind = self.start_ind + (val - self.start_ind) % self.dis_length

    def play(self):
        """Frame generator for FuncAnimation: advance one frame per call while running."""
        while self.runs:
            self.ind += 1 if self.forwards else -1
            self._update()
            yield self.ind

    def start(self):
        """Resume the animation timer."""
        self.runs = True
        self._update()
        self.event_source.start()

    def stop(self, event=None):
        """Pause the animation timer (button callback)."""
        self.runs = False
        self._update()
        self.event_source.stop()

    def forward(self, event=None):
        """Play forwards (button callback)."""
        self.forwards = True
        self.start()

    def backward(self, event=None):
        """Play backwards (button callback)."""
        self.forwards = False
        self.start()

    def oneforward(self, event=None):
        """Step a single frame forwards (button callback)."""
        self.forwards = True
        self.onestep()

    def onebackward(self, event=None):
        """Step a single frame backwards (button callback)."""
        self.forwards = False
        self.onestep()

    def onestep(self):
        """Move one frame in the current direction, render it and redraw."""
        if self.forwards:
            self.ind += 1
        else:
            self.ind -= 1
        self.func(self.ind)
        self._update()
        self.fig.canvas.draw_idle()

    def _update(self):
        """Sync the slider position with the current frame index."""
        self.slider.set_val(self.ind)

    def __set_slider(self, val):
        """Slider callback: adopt the slider value as the current frame index."""
        self.ind = int(val)

    def setup(self, pos):
        """
        Setting up the buttons and the slider
        :param pos: length 2 tuple, position of the axes for buttons and slider
        :return:
        """
        playerax = self.fig.add_axes([pos[0], pos[1], 0.22, 0.04])
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(playerax)
        bax = divider.append_axes("right", size="80%", pad=0.05)
        sax = divider.append_axes("right", size="80%", pad=0.05)
        fax = divider.append_axes("right", size="80%", pad=0.05)
        ofax = divider.append_axes("right", size="100%", pad=0.05)
        sliderax = self.fig.add_axes(
            (pos[0], pos[1] - 0.045, 0.5, 0.04), facecolor=self.bg_color)
        self.button_oneback = matplotlib.widgets.Button(
            playerax, color=self.button_color, hovercolor=self.bg_color,
            label='$\u29CF$')
        self.button_back = matplotlib.widgets.Button(
            bax, color=self.button_color, hovercolor=self.bg_color,
            label='$\u25C0$')
        self.button_stop = matplotlib.widgets.Button(
            sax, color=self.button_color, hovercolor=self.bg_color,
            label='$\u25A0$')
        self.button_forward = matplotlib.widgets.Button(
            fax, color=self.button_color, hovercolor=self.bg_color,
            label='$\u25B6$')
        self.button_oneforward = matplotlib.widgets.Button(
            ofax, color=self.button_color, hovercolor=self.bg_color,
            label='$\u29D0$')
        self.button_oneback.on_clicked(self.onebackward)
        # The slider covers exactly the playable range [start_ind, stop_ind - 1];
        # starting it at 0 let the user pick frames below dis_start.
        self.slider = Slider(sliderax, label='', valfmt='%0.0f',
                             valmin=self.start_ind, valmax=self.stop_ind - 1,
                             valinit=self.ind, color='black',
                             fc=self.button_color)
        self.slider.label.set_color(self.button_color)
        self.slider.valtext.set_position((0.5, 0.5))
        self.slider.set_val(self.ind)
        self.button_back.on_clicked(self.backward)
        self.button_stop.on_clicked(self.stop)
        self.button_forward.on_clicked(self.forward)
        self.button_oneforward.on_clicked(self.oneforward)
        self.slider.on_changed(self.__set_slider)
class CubeDisplayBase(ImageDisplay):
    """Image display for 3D data with a frame slider and scroll-wheel paging.

    Subclasses are expected to provide ``set_frame`` / ``get_frame`` behaviour
    appropriate to their data (default implementations are given here).
    """

    def __init__(self, ax, data, coords=None, **kwargs):
        '''
        Image display for 3D data. Implements frame slider and image scroll.

        Parameters
        ----------
        ax : Axes object
            Axes on which to display
        data : array-like
            initial display data
        coords : optional, np.ndarray
            Currently unused (aperture display is disabled in this class).
            Intended shape is (k, N, 2) where k is the number of apertures
            per frame, and N is the number of frames.

        kwargs are passed directly to ImageDisplay (after ``autoscale`` is
        popped and stored on the instance).
        '''
        # Must be removed before the base class sees kwargs.
        self.autoscale = kwargs.pop('autoscale', 'percentile')
        ImageDisplay.__init__(self, ax, data, **kwargs)

        # setup frame slider. Valid frames are 0 .. len(self) - 1; keep a
        # non-degenerate range for single-frame cubes.
        self._frame = 0
        self.fsax = self.divider.append_axes('bottom', size=0.2, pad=0.25)
        self.frame_slider = Slider(self.fsax, 'frame', 0,
                                   max(len(self) - 1, 1), valfmt='%d')
        self.frame_slider.on_changed(self.set_frame)
        if self.use_blit:
            self.frame_slider.drawon = False

        # save background for blitting
        fig = ax.figure
        self.background = fig.canvas.copy_from_bbox(ax.bbox)

        # enable frame scroll
        fig.canvas.mpl_connect('scroll_event', self._scroll)

    def _needs_drawing(self):
        """Return the artists to redraw when the frame changes (for blitting).

        NOTE: temporary hack while the base class is being refined.
        """
        needs_drawing = [self.imgplt]

        if self.has_hist:
            needs_drawing.extend(self.patches)

        if self.autoscale:
            needs_drawing.extend(self.sliders.sliders)

        return needs_drawing

    def get_data(self, i):
        """Return the 2D image for frame *i*."""
        return self.data[i]

    def get_frame(self):
        """Return the index of the currently displayed frame."""
        return self._frame

    def set_frame(self, i, draw=False):
        '''Display frame *i* (wrapped around the cube length); blit-draw if *draw*.'''
        # Round BEFORE wrapping: wrapping first and then rounding can yield
        # len(self) (e.g. 9.7 -> 10 on a 10-frame cube), an invalid index.
        i = int(round(i)) % len(self)
        self._frame = i
        data = self.get_data(i)

        # set the slider axis limits to this frame's data range
        dmin, dmax = data.min(), data.max()
        self.sliders.ax.set_ylim(dmin, dmax)
        self.sliders.valmin, self.sliders.valmax = dmin, dmax

        # set the image data
        self.imgplt.set_data(data)

        if self.autoscale:
            # set the slider positions / color limits
            vmin, vmax = self.get_autoscale_limits(data, autoscale=self.autoscale)
            self.imgplt.set_clim(vmin, vmax)
            self.sliders.set_positions((vmin, vmax))

        if draw:
            self.draw_blit(self._needs_drawing())

    frame = property(get_frame, set_frame)

    def _scroll(self, event):
        """Scroll-wheel callback: step one frame up/down and sync the slider."""
        self.frame += [-1, +1][event.button == 'up']
        self.frame_slider.set_val(self.frame)

    def draw_blit(self, artists):
        """Restore the saved background, draw *artists* and blit the figure.

        Drawing failures are reported (best-effort) rather than raised.
        """
        fig = self.ax.figure
        fig.canvas.restore_region(self.background)

        for art in artists:
            try:
                self.ax.draw_artist(art)
            except Exception:
                print('drawing FAILED', art)
                traceback.print_exc()

        fig.canvas.blit(fig.bbox)

    def cooDisplayFormatter(self, x, y):
        """Prefix the base coordinate readout with the current frame number."""
        s = ImageDisplay.cooDisplayFormatter(self, x, y)
        return 'frame %d: %s' % (self.frame, s)
class SpecPlotter:
    """Live spectrogram viewer fed by a queue of spectra.

    The top axes show a scrolling history of spectra; the bottom axes show the
    most recent spectrum. A "Scale" slider sets the colour/amplitude limit and
    auto-follows the largest peak seen while it sits at its maximum (only when
    no fixed ``maxscale`` was supplied).
    """

    def __init__(self, q, freqs, x_max, hist_time, time_res, maxscale):
        self.q = q
        self.freqs = freqs
        self.x_max = x_max
        self.hist_time = hist_time
        self.time_res = time_res
        self.maxscale = maxscale
        self.HORIZONTAL_PIXELS = len(freqs) + 1
        self.hist_segments = floor(self.hist_time / self.time_res)

        plt.ion()
        self.figure, (self.ax, self.ax2) = plt.subplots(2)
        self.figure.canvas.mpl_connect("close_event", self.exit_evt)

        self.Z = np.zeros((self.hist_segments, self.HORIZONTAL_PIXELS))

        if maxscale is None:
            start_peak = 0.0
        else:
            start_peak = maxscale
        self.highest_peak = start_peak
        self.scale = start_peak
        self.scale_slider_touched = False

        plt.subplots_adjust(bottom=0.1)
        self.ax_scale_slider = plt.axes([0.15, 0.03, 0.7, 0.03])
        self.scale_slider = Slider(self.ax_scale_slider, "Scale", 0,
                                   start_peak, start_peak)
        self.scale_slider.on_changed(self.scale_slider_changed)

        self.X = np.linspace(self.freqs[0], self.freqs[-1],
                             self.HORIZONTAL_PIXELS)
        self.row = np.zeros(len(self.freqs))

    def update(self):
        """Drain queued spectra and redraw, forever (this call never returns)."""
        while True:
            while not self.q.empty():
                self.row = self.q.get()
                resampled = interp1d(self.freqs, self.row, kind="linear")(self.X)
                self.Z = np.vstack((self.Z, resampled))
                if self.Z.shape[0] > self.hist_segments:
                    # keep only the most recent hist_segments rows
                    self.Z = np.delete(self.Z, 0, axis=0)

                newest_peak = np.abs(self.row).max()
                if newest_peak > self.highest_peak:
                    self.highest_peak = newest_peak
                    if self.maxscale is None:
                        # follow the peak only while the slider is at its max
                        if self.scale == self.scale_slider.valmax:
                            self.scale = self.highest_peak
                        self.ax_scale_slider.set_xlim(0, self.highest_peak)
                        self.scale_slider.valmax = self.highest_peak
                        self.scale_slider.set_val(self.scale)

            spectrogram_ax, spectrum_ax = self.ax, self.ax2

            spectrogram_ax.clear()
            spectrogram_ax.imshow(
                self.Z,
                extent=(0, self.x_max, 0, self.hist_time),
                cmap="Greys",
                vmin=0,
                vmax=self.scale,
                interpolation="none",
                aspect="auto",
            )
            spectrogram_ax.set_title("Spectrogram")
            spectrogram_ax.set_xlabel("Hz")

            spectrum_ax.clear()
            spectrum_ax.plot(self.freqs, self.row)
            spectrum_ax.axis([0, self.x_max, 0, self.scale])

            # Draw canvas
            plt.draw()
            plt.pause(0.001)

    def exit_evt(self, evt):
        """Figure-close callback: terminate the whole process immediately."""
        os._exit(0)

    def scale_slider_changed(self, v):
        """Slider callback: adopt the new display scale."""
        self.scale = v
class viscm_editor(object):
    """Interactive editor for a colormap defined by a Bezier curve in a uniform colour space.

    The figure holds a Bezier control-point editor, min/max lightness (J')
    sliders, and buttons to show the 3D gamut, save the colormap to
    ``/tmp/new_cm.py`` and open a property (viscm) window.
    """

    def __init__(self, uniform_space="CAM02-UCS", min_Jp=15, max_Jp=95, xp=None, yp=None):
        from .bezierbuilder import BezierModel, BezierBuilder

        self._uniform_space = uniform_space

        self.figure = plt.figure()
        axes = _viscm_editor_axes(self.figure)

        ax_btn_wireframe = plt.axes([0.7, 0.15, 0.1, 0.025])
        self.btn_wireframe = Button(ax_btn_wireframe, "Show 3D gamut")
        self.btn_wireframe.on_clicked(self.plot_3d_gamut)

        ax_btn_save = plt.axes([0.81, 0.15, 0.1, 0.025])
        self.btn_save = Button(ax_btn_save, "Save colormap")
        self.btn_save.on_clicked(self.save_colormap)

        ax_btn_props = plt.axes([0.81, 0.1, 0.1, 0.025])
        self.btn_props = Button(ax_btn_props, "Properties")
        self.btn_props.on_clicked(self.show_viscm)
        self.prop_windows = []

        # ``axisbg`` was deprecated in matplotlib 2.0 and later removed;
        # ``facecolor`` is the supported keyword.
        axcolor = "None"
        ax_jp_min = plt.axes([0.1, 0.1, 0.5, 0.03], facecolor=axcolor)
        ax_jp_min.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap="gray")
        ax_jp_min.set_xlim(0, 100)

        ax_jp_max = plt.axes([0.1, 0.15, 0.5, 0.03], facecolor=axcolor)
        ax_jp_max.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap="gray")

        self.jp_min_slider = Slider(ax_jp_min, r"$J'_\mathrm{min}$", 0, 100, valinit=min_Jp)
        self.jp_max_slider = Slider(ax_jp_max, r"$J'_\mathrm{max}$", 0, 100, valinit=max_Jp)

        self.jp_min_slider.on_changed(self._jp_update)
        self.jp_max_slider.on_changed(self._jp_update)

        # Default Bezier control points
        if xp is None:
            xp = [-2.0591553836234482, 59.377014829142524, 43.552546744036135, 4.7670857511283202, -9.5059638942617539]

        if yp is None:
            yp = [-25.664893617021221, -21.941489361702082, 38.874113475177353, 20.567375886524871, 32.047872340425585]

        self.bezier_model = BezierModel(xp, yp)
        self.cmap_model = BezierCMapModel(
            self.bezier_model, self.jp_min_slider.val, self.jp_max_slider.val, uniform_space
        )
        self.highlight_point_model = HighlightPointModel(self.cmap_model, 0.5)

        self.bezier_builder = BezierBuilder(axes["bezier"], self.bezier_model)
        self.bezier_gamut_viewer = GamutViewer2D(axes["bezier"], self.highlight_point_model, uniform_space)
        self.bezier_highlight_point_view = HighlightPoint2DView(axes["bezier"], self.highlight_point_model)

        axes["bezier"].set_xlim(-100, 100)
        axes["bezier"].set_ylim(-100, 100)

        self.cmap_view = CMapView(axes["cm"], self.cmap_model)
        self.cmap_highlighter = HighlightPointBuilder(axes["cm"], self.highlight_point_model)

        print("Click sliders at bottom to change min/max lightness")
        print("Click on colorbar to adjust gamut view")
        print("Click-drag to move control points, ")
        print("  shift-click to add, control-click to delete")

    def plot_3d_gamut(self, event):
        """Button callback: open a 3D wireframe view of the colormap's gamut."""
        fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
        self.wireframe_view = WireframeView(ax, self.cmap_model, self.highlight_point_model, self._uniform_space)
        plt.show()

    def save_colormap(self, event):
        """Button callback: write the current colormap as a Python module to /tmp/new_cm.py."""
        import textwrap

        template = textwrap.dedent(
            """
            from matplotlib.colors import ListedColormap
            from numpy import nan, inf

            # Used to reconstruct the colormap in viscm
            parameters = {{'xp': {xp}, 'yp': {yp}, 'min_Jp': {min_Jp}, 'max_Jp': {max_Jp}}}

            cm_data = {array_list}

            test_cm = ListedColormap(cm_data, name=__file__)


            if __name__ == "__main__":
                import matplotlib.pyplot as plt
                import numpy as np

                try:
                    from viscm import viscm
                    viscm(test_cm)
                except ImportError:
                    print("viscm not found, falling back on simple display")
                    plt.imshow(np.linspace(0, 100, 256)[None, :], aspect='auto', cmap=test_cm)
                plt.show()
            """
        )

        # Build the full file contents first so a failure cannot leave a
        # truncated file behind.
        rgb, _ = self.cmap_model.get_sRGB(num=256)
        array_list = np.array2string(rgb, max_line_width=78, prefix="cm_data = ", separator=",")
        xp, yp = self.cmap_model.bezier_model.get_control_points()
        data = dict(
            array_list=array_list, xp=xp, yp=yp, min_Jp=self.cmap_model.min_Jp, max_Jp=self.cmap_model.max_Jp
        )
        contents = template.format(**data)

        with open("/tmp/new_cm.py", "w") as f:
            f.write(contents)

        print("*" * 50)
        print("Saved colormap to /tmp/new_cm.py")
        print("*" * 50)

    def show_viscm(self, event):
        """Button callback: open a viscm property window for the current colormap."""
        cm = LinearSegmentedColormap.from_list("test_cm", self.cmap_model.get_sRGB(num=256)[0])
        self.prop_windows.append(viscm(cm, name="test_cm"))
        plt.show()

    def _jp_update(self, val):
        """Slider callback: keep J'min <= J'max (swapping if needed) and update the model."""
        jp_min = self.jp_min_slider.val
        jp_max = self.jp_max_slider.val

        smallest, largest = min(jp_min, jp_max), max(jp_min, jp_max)
        if (jp_min > smallest) or (jp_max < largest):
            self.jp_min_slider.set_val(smallest)
            self.jp_max_slider.set_val(largest)
        self.cmap_model.set_Jp_minmax(smallest, largest)
class FigureContainer():
    """Build the figure, axes and widgets, and run the frame animation.

    Owns the matplotlib figure, the frame slider, the save button and the
    PrintManager that performs the per-frame drawing. Note that construction
    ends with a (blocking) ``plt.show()``.
    """

    def __init__(self, file):
        # The app uses keys that matplotlib binds by default, so move the
        # fullscreen / y-scale bindings out of the way.
        plt.rcParams['keymap.fullscreen'] = '{'
        plt.rcParams['keymap.yscale'] = '}'

        self.fig = plt.figure(facecolor='w')
        fig = self.fig
        fig.set_size_inches(16, 9)

        layout = gspec.GridSpec(5, 2, height_ratios=[2, 6, 4, 1, 6],
                                width_ratios=[10, 3])
        layout.update(left=0.05, right=0.95, wspace=0.02, hspace=.04,
                      bottom=.02, top=.98)

        # Left column rows: 0 buttons, 1 spatial, 2 temporal, 3 slider, 4 video.
        self.spatial_axis = fig.add_subplot(layout[1, 0])
        self.vid_axis = fig.add_subplot(layout[4, 0])
        self.temporal_axis = fig.add_subplot(layout[2, 0])
        self.sel_panel = SelectPanel(fig, layout)

        # valfmt only formats the displayed text; update_func casts with int().
        frame_slider_ax = fig.add_subplot(layout[3, 0])
        self.slide = Slider(frame_slider_ax, 'Frame', 0, 100, valinit=0,
                            valfmt='%0.0f')

        self.print_manager = PrintManager(self.temporal_axis,
                                          self.spatial_axis,
                                          self.sel_panel,
                                          self.vid_axis,
                                          self.slide,
                                          file)
        self.sel_panel.set_print_manager(self.print_manager)
        self.set_slider_range()
        self.prev_frame = self.slide.val

        manager = self.print_manager
        fig.canvas.mpl_connect('pick_event', manager.on_pick)
        fig.canvas.mpl_connect('key_press_event', manager.on_key_press)

        button_row = gspec.GridSpecFromSubplotSpec(1, 4,
                                                   subplot_spec=layout[0, 0],
                                                   wspace=0.1, hspace=0.1)
        save_ax = fig.add_subplot(button_row[0, 0])
        self.sav_but = Button(save_ax, 'Save Changes')
        self.sav_but.on_clicked(manager.save)

        # and start the animation
        self.anim = animation.FuncAnimation(fig, self.update_func, fargs=(),
                                            interval=15, repeat=True)
        plt.show()

    def update_func(self, j):
        """Animation step: advance the slider one frame (wrapping) and redraw."""
        nxt = int(self.slide.val) + 1
        nxt = self.slide.valmin if nxt > self.slide.valmax else nxt
        self.print_manager.change_frame(self.prev_frame, nxt)
        self.slide.set_val(nxt)
        self.prev_frame = nxt
        return j

    def set_slider_range(self):
        """Make the slider span combo_prints' first frame to its last frame."""
        prints = self.print_manager.combo_prints
        self.slide.set_val(prints.first_frame.min())
        self.slide.valmin = prints.first_frame.min()
        self.slide.valmax = prints.last_frame.max()
        self.slide.ax.set_xlim(self.slide.valmin, self.slide.valmax)
class PVSlicer(object):
    """Interactive viewer for extracting position-velocity slices from a 3-D cube.

    The left panel shows one slice of the cube along its first axis, with
    sliders for the slice index and the colour limits; a movable slicing box
    drawn on it defines the path whose PV slice is shown in the right panel.
    Clicking the right panel jumps to the clicked slice, and the button saves
    the current PV slice to FITS.
    """

    def __init__(self, filename, backend="Qt4Agg", clim=None):

        self.filename = filename

        try:
            from spectral_cube import SpectralCube
            cube = SpectralCube.read(filename, format='fits')
            self.array = cube._data
        except Exception:
            # Best-effort fallback (also used if spectral_cube cannot read the
            # file); narrower than a bare except so Ctrl-C still works.
            warnings.warn("spectral_cube package is not available - using astropy.io.fits directly")
            from astropy.io import fits
            self.array = fits.getdata(filename)
            if self.array.ndim != 3:
                raise ValueError("dataset does not have 3 dimensions (install the spectral_cube package to avoid this error)")

        self.backend = backend

        import matplotlib as mpl
        mpl.use(self.backend)
        import matplotlib.pyplot as plt

        self.fig = plt.figure(figsize=(14, 8))

        self.ax1 = self.fig.add_axes([0.1, 0.1, 0.4, 0.7])

        if clim is None:
            warnings.warn("clim not defined and will be determined from the data")
            # To work with large arrays, sub-sample the data
            # (but don't do it for small arrays). Floor division: slice
            # steps must be integers on Python 3.
            n1 = max(self.array.shape[0] // 10, 1)
            n2 = max(self.array.shape[1] // 10, 1)
            n3 = max(self.array.shape[2] // 10, 1)
            sub_array = self.array[::n1, ::n2, ::n3]
            finite = sub_array[np.isfinite(sub_array)]
            cmin = np.min(finite)
            cmax = np.max(finite)
            crange = cmax - cmin
            self._clim = (cmin - crange, cmax + crange)
        else:
            self._clim = clim

        self.slice = int(round(self.array.shape[0] / 2.))

        from matplotlib.widgets import Slider

        # Valid slice indices are 0 .. shape[0] - 1.
        self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
        self.slice_slider_ax.set_xticklabels("")
        self.slice_slider_ax.set_yticklabels("")
        self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.array.shape[0] - 1, valinit=self.slice, valfmt="%i")
        self.slice_slider.on_changed(self.update_slice)
        self.slice_slider.drawon = False

        self.image = self.ax1.imshow(self.array[self.slice, :, :], origin='lower', interpolation='nearest', vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray)

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1], valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1], valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.grid1 = None
        self.grid2 = None
        self.grid3 = None

        self.ax2 = self.fig.add_axes([0.55, 0.1, 0.4, 0.7])

        # Add slicing box
        self.box = SliceCurve(colors=(0.8, 0.0, 0.0))
        self.ax1.add_collection(self.box)
        self.movable = MovableSliceBox(self.box, callback=self.update_pv_slice)
        self.movable.connect()

        # Add save button
        from matplotlib.widgets import Button
        self.save_button_ax = self.fig.add_axes([0.65, 0.90, 0.20, 0.05])
        self.save_button = Button(self.save_button_ax, 'Save slice to FITS')
        self.save_button.on_clicked(self.save_fits)
        self.file_status_text = self.fig.text(0.75, 0.875, "", ha='center', va='center')
        self.set_file_status(None)

        self.pv_slice = None

        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.click)

    def set_file_status(self, status, filename=None):
        """Update the status text under the save button and redraw.

        *status* is 'instructions', 'saved' (with *filename*), or anything
        else to clear the text.
        """
        if status == 'instructions':
            self.file_status_text.set_text('Please enter filename in terminal')
            self.file_status_text.set_color('red')
        elif status == 'saved':
            self.file_status_text.set_text('File successfully saved to {0}'.format(filename))
            self.file_status_text.set_color('green')
        else:
            self.file_status_text.set_text('')
            self.file_status_text.set_color('black')
        self.fig.canvas.draw()

    def click(self, event):
        """Clicking on the PV panel jumps the cube view to the clicked slice."""
        if event.inaxes != self.ax2:
            return
        # Clamp: the image extends half a pixel past the first/last slice.
        ydata = min(max(event.ydata, 0), self.array.shape[0] - 1)
        self.slice_slider.set_val(ydata)

    def save_fits(self, *args, **kwargs):
        """Button callback: prompt on the terminal for a filename and save the PV slice."""
        # Nothing to save yet: don't prompt (and don't leave the red prompt up).
        if self.pv_slice is None:
            return

        self.set_file_status('instructions')

        print("Enter filename: ", end='')
        try:
            plot_name = raw_input()
        except NameError:
            plot_name = input()

        try:
            self.pv_slice.writeto(plot_name, overwrite=True)
        except TypeError:
            # older astropy only knows ``clobber``
            self.pv_slice.writeto(plot_name, clobber=True)
        print("Saved file to: ", plot_name)

        self.set_file_status('saved', filename=plot_name)

    def update_pv_slice(self, box):
        """Recompute and display the PV slice along the slicing box's path."""
        path = Path(list(zip(box.x, box.y)))
        path.width = box.width

        self.pv_slice = extract_pv_slice(self.array, path)

        self.ax2.cla()
        self.ax2.imshow(self.pv_slice.data, origin='lower', aspect='auto', interpolation='nearest')

        self.fig.canvas.draw()

    def show(self, block=True):
        """Show the figure via ``plt.show``."""
        import matplotlib.pyplot as plt
        plt.show(block=block)

    def update_slice(self, pos=None):
        """Display the slice nearest to *pos*, clamped to the valid index range."""
        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            self.slice = min(max(int(round(pos)), 0), self.array.shape[0] - 1)
            self.image.set_array(self.array[self.slice, :, :])

        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """Set the lower colour limit, never above the current upper limit."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """Set the upper colour limit, never below the current lower limit."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()
def view_components(estimates, img, idx):
    """ View spatial and temporal components interactively

    A slider (and the left/right arrow keys) selects a component; the panels
    show its spatial weights, its traces with spike times, and its weights
    overlaid on the summary image.

    Args:
        estimates: dict
            estimates dictionary contain results of VolPy

        img: 2-D array
            summary images for detection

        idx: list
            index of selected neurons
    """
    n = len(idx)
    fig = plt.figure(figsize=(10, 10))

    axcomp = plt.axes([0.05, 0.05, 0.9, 0.03])
    ax1 = plt.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = plt.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = plt.axes([0.05, 0.1, 0.9, 0.4])
    # The slider runs to n (one past the last component); update() draws
    # nothing for that position.
    s_comp = Slider(axcomp, 'Component', 0, n, valinit=0)
    vmax = np.percentile(img, 98)

    def arrow_key_image_control(event):
        """Step the component slider with the arrow keys, clamped to [0, n]."""
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)

        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            if new_val > n:
                new_val = n
            s_comp.set_val(new_val)

    def update(val):
        """Redraw all panels for the component selected on the slider."""
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        i = int(np.round(s_comp.val))
        print(f'Component:{i}')

        if i < n:
            def comp(key):
                # Entry *key* of the i-th selected neuron.
                return estimates[key][idx][i]

            ax1.cla()
            imgtmp = comp('weights')
            ax1.imshow(imgtmp, interpolation='None', cmap=plt.cm.gray, vmax=np.max(imgtmp) * 0.5, vmin=0)
            ax1.set_title(f'Spatial component {i+1}')
            ax1.axis('off')

            trace = comp('t')
            spikes = comp('spikes')
            ax2.cla()
            ax2.plot(trace, alpha=0.8)
            ax2.plot(comp('t_sub'))
            ax2.plot(comp('t_rec'), alpha=0.4, color='red')
            # spike markers drawn just above the trace maximum
            ax2.plot(spikes,
                     1.05 * np.max(trace) * np.ones(spikes.shape),
                     color='r', marker='.', fillstyle='none', linestyle='none')
            ax2.set_title(f'Signal and spike times {i+1}')
            ax2.legend(labels=['t', 't_sub', 't_rec', 'spikes'])
            ax2.text(0.1, 0.1, f'snr:{round(comp("snr"), 2)}', horizontalalignment='center', verticalalignment='center',
                     transform=ax2.transAxes)
            ax2.text(0.1, 0.07, f'num_spikes: {len(spikes)}', horizontalalignment='center', verticalalignment='center',
                     transform=ax2.transAxes)
            ax2.text(0.1, 0.04, f'locality_test: {comp("locality")}', horizontalalignment='center', verticalalignment='center',
                     transform=ax2.transAxes)

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=plt.cm.gray, vmax=vmax)
            # zero weights become NaN so only the footprint is overlaid
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None', alpha=0.5, cmap=plt.cm.hot)
            ax3.axis('off')

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    plt.show()
class ytViewer(object):
    """Interactive folded (space/time) viewer for one to three binary channels.

    Each file is read with ``fromfile`` and folded into rows of ``fold``
    samples. For every channel the figure shows:

    * the folded image,
    * below it, a cut through row ``Y0`` (following the mouse),
    * on the shared left axes ``axhh``, the row means.

    Sliders control:

    * 'beg' / 'end': samples trimmed at the start / end of each row,
    * 'index': the first displayed row,
    * 'nmax': the number of displayed rows,
    * 'Shear': a per-row circular shift.

    Keyboard shortcuts are handled in ``keypress``.

    Per-channel state lives in numbered attributes (``data0``, ``ax1``,
    ``folded_data2``, ``remove_len10_slider``, ...). They are read and
    written through ``_ch`` / ``_set_ch``.
    """

    def __init__(self, filenames, fold=19277, nmax=100, NORM=True, dtype=int8, DEB=0, shear=0.):
        """Load the channels, fold them and build the interactive figure.

        Parameters
        ----------
        filenames : list of str
            One to three raw binary files, one per channel.
        fold : int
            Number of samples per row of the folded image.
        nmax : int
            Number of rows displayed at once; a falsy value means all rows.
        NORM : bool
            True: colour/axis limits follow the data min/max.
            False: fixed limits ``self.vmin`` / ``self.vmax``.
        dtype : numpy dtype
            Sample type passed to ``fromfile``.
        DEB : int
            Not used.
        shear : float
            Initial value of the 'Shear' slider.
        """
        self.UPDATE = False
        self.color = False
        self.CMAP = ['jet', 'seismic', 'Greys', 'plasma']
        self.COLORS = ['b', 'g', 'r']
        self.filenames = filenames
        self.NORM = NORM
        self.NMAX = nmax
        self.fold = fold
        self.increment = 5
        self.index = 0
        self.shear = shear
        self.vmin = -125
        self.vmax = 125
        self.HORIZ_VAL = 0.05
        nch = len(self.filenames)
        for i in range(nch):
            self._set_ch('remove_len1', i, 0)
            self._set_ch('remove_len2', i, 1)
        self.Y0 = 0

        ### To set the data arrays ###
        # A missing/unreadable file ends the program with a hint.
        try:
            for i in range(nch):
                self._set_ch('data', i, fromfile(filenames[i], dtype=dtype))
        except (IOError, OSError):
            print('\nYou must provide existing filenames\n')
            sys.exit(1)
        self.max_index = len(self.data0) // self.fold
        if not self.NMAX:
            self.NMAX = self.max_index
        for i in range(nch):
            # Data stages:
            #   folded_data_orig2: folded raw data (reference for the shear)
            #   folded_data_orig:  sheared copy
            #   folded_data_orig3: trimmed copy
            #   folded_data:       rows shown
            orig2 = self._ch('data', i)[:self.max_index * self.fold].reshape(self.max_index, self.fold)
            self._set_ch('folded_data_orig2', i, orig2)
            self._set_ch('folded_data_orig', i, array(orig2))
            self._set_ch('folded_data_orig3', i, array(self._ch('folded_data_orig', i)))
            self._set_ch('folded_data', i, self._ch('folded_data_orig3', i)[:self.NMAX])

        ##################################################################
        ################## Start creating the figure #####################
        self.fig = figure(figsize=(16, 7))
        if nch == 1:
            self.declare_axis_1channel()
        elif nch == 2:
            self.declare_axis_2channel()
        elif nch == 3:
            self.declare_axis_3channel()
        for i in range(nch):
            data = self._ch('folded_data', i)
            if self.NORM:
                lo, hi = data.min(), data.max()
            else:
                lo, hi = self.vmin, self.vmax
            self._set_ch('im', i, self._ch('ax', i).imshow(
                data, interpolation='nearest', aspect='auto', origin='lower', vmin=lo, vmax=hi))
        for i in range(nch):
            self._set_ch('cursor', i, Cursor(self._ch('ax', i), useblit=True, color='red', linewidth=2))
        self.axhh = axes([0.02, 0.25, 0.12, 0.62])
        for i in range(nch):
            data = self._ch('folded_data', i)
            hline, = self._ch('axh', i).plot(data[self.Y0, :])
            self._set_ch('hline', i, hline)
            self._ch('axh', i).set_xlim(0, len(data[0, :]))
            # y positions follow the actual row count (may be < NMAX)
            means = data.mean(1)
            hhline, = self.axhh.plot(means, arange(len(means)), self.COLORS[i])
            self._set_ch('hhline', i, hhline)
        self.axhh.set_ylim(0, self.NMAX - 1)
        if not self.NORM:
            for i in range(nch):
                self._ch('axh', i).set_ylim(self.vmin, self.vmax)
            self.axhh.set_xlim(self.vmin, self.vmax)
        else:
            for i in range(nch):
                data = self._ch('folded_data', i)
                self._ch('axh', i).set_ylim(data.min(), data.max())
            lim_min, lim_max = self._mean_limits()
            self.axhh.set_xlim(lim_min - 1, lim_max + 1)

        # create 'remove_len1' slider
        for i in range(nch):
            slider = Slider(self._ch('remove_len1', i, '_sliderax'), 'beg', 0., self.fold,
                            valinit=self._ch('remove_len1', i), valfmt='%d')
            self._set_ch('remove_len1', i, slider, '_slider')
            slider.on_changed(self.update_tab)
        # create 'remove_len2' slider
        for i in range(nch):
            slider = Slider(self._ch('remove_len2', i, '_sliderax'), 'end', 1., self.fold,
                            valinit=self._ch('remove_len2', i), valfmt='%d')
            self._set_ch('remove_len2', i, slider, '_slider')
            slider.on_changed(self.update_tab)
        # create 'index' slider
        self.index_sliderax = axes([0.175, 0.975, 0.775, 0.02])
        self.index_slider = Slider(self.index_sliderax, 'index', 0, self.max_index - self.increment,
                                   valinit=0, valfmt='%d')
        self.index_slider.on_changed(self.update_tab)
        # create 'nmax' slider
        self.nmax_sliderax = axes([0.175, 0.955, 0.775, 0.02])
        self.nmax_slider = Slider(self.nmax_sliderax, 'nmax', 0, self.max_index,
                                  valinit=self.NMAX, valfmt='%d')
        self.nmax_slider.on_changed(self.update_tab)
        # create 'shear' slider
        self.shear_sliderax = axes([0.175, 0.935, 0.775, 0.02])
        self.shear_slider = Slider(self.shear_sliderax, 'Shear', -0.5, 0.5,
                                   valinit=self.shear, valfmt='%1.2f')
        self.shear_slider.on_changed(self.update_shear)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mousemove)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)

        # key help, drawn in data coordinates of the current (last created) axes
        VERT_VAL = -4
        font_bold = FontProperties()
        font_bold.set_weight('bold')
        mpl.pyplot.text(-0.72, -32 + VERT_VAL, 'Useful keys:', fontsize=18, fontproperties=font_bold)
        mpl.pyplot.text(-0.72, -41 + VERT_VAL, '"c" to change colormap\n "v" to change vertical\n /colorscale\n " " to pause\n "w"/"x" set Ch1 REMOVE\n sliders values to Ch2/3\n "t" Retrigger mode\n (NOT TOO MUCH POINTS) \n "q" to exit', fontsize=18)

        # play/pause status dot: green while playing, orange-red while paused
        self.axe_toggledisplay = self.fig.add_axes([0., 0., 1.0, 0.02])
        self.plot_circle(0, 0, 2, fc='#00FF7F' if self.UPDATE else '#FF4500')
        mpl.pyplot.axis('off')
        gobject.idle_add(self.update_plot)
        show()

    ### Per-channel attribute helpers ###
    @staticmethod
    def _ch_name(prefix, i, suffix=''):
        """Return the per-channel attribute name, e.g. ('remove_len1', 0, '_slider') -> 'remove_len10_slider'."""
        return '%s%d%s' % (prefix, i, suffix)

    def _ch(self, prefix, i, suffix=''):
        """Return the per-channel attribute ``<prefix><i><suffix>``."""
        return getattr(self, self._ch_name(prefix, i, suffix))

    def _set_ch(self, prefix, i, value, suffix=''):
        """Set the per-channel attribute ``<prefix><i><suffix>`` to *value*."""
        setattr(self, self._ch_name(prefix, i, suffix), value)

    def _mean_limits(self):
        """Return (min, max) of the row means over all channels' displayed rows."""
        means = [self._ch('folded_data', i).mean(1) for i in range(len(self.filenames))]
        return min(m.min() for m in means), max(m.max() for m in means)

    ### BEGIN main loop ###
    def update_plot(self):
        """Idle callback: advance the window by ``increment`` rows while playing.

        Returns True to stay registered with ``gobject.idle_add`` while
        ``UPDATE`` is set, False otherwise.
        """
        if not self.UPDATE:
            return False
        nch = len(self.filenames)
        ### Compute the array to plot ###
        for i in range(nch):
            self._set_ch('folded_data', i,
                         self._ch('folded_data_orig3', i)[self.index:(self.index + self.NMAX)])
        ### Update picture ###
        for i in range(nch):
            data = self._ch('folded_data', i)
            self._ch('im', i).set_data(data)
            self._ch('hline', i).set_ydata(data[self.Y0, :])
            self._ch('hhline', i).set_xdata(data.mean(1))
        self.index = self.index + self.increment
        self.index_slider.set_val(self.index)
        self.fig.canvas.draw()
        return True
    ### END main loop ###

    def update_tab(self, val):
        """Slider callback: read all slider values, rebuild the data and redraw."""
        for i in range(len(self.filenames)):
            self._set_ch('remove_len1', i, int(self._ch('remove_len1', i, '_slider').val))
            self._set_ch('remove_len2', i, int(self._ch('remove_len2', i, '_slider').val))
        self.index = int(round(self.index_slider.val, 0))
        self.NMAX = int(round(self.nmax_slider.val, 0))
        self.Y0 = 0
        self.update_tabs()
        self.norm_fig()
        self.fig.canvas.draw()

    def update_tabs(self):
        """Recompute sheared, trimmed and displayed data for every channel."""
        for i in range(len(self.filenames)):
            # shear first so the trimmed copy reflects the current shear value
            self.process_data(self.shear, i)
            r1 = self._ch('remove_len1', i)
            r2 = self._ch('remove_len2', i)
            self._set_ch('folded_data_orig3', i, array(self._ch('folded_data_orig', i)[:, r1:-r2]))
            self._set_ch('folded_data', i,
                         self._ch('folded_data_orig3', i)[self.index:(self.index + self.NMAX)])

    def process_data(self, val, i):
        """ Redress data in the space/time diagram

        Row ``k`` of ``folded_data_orig2<i>`` is circularly shifted by
        ``int(k * val)`` samples; the result is stored in ``folded_data_orig<i>``.
        """
        ref = self._ch('folded_data_orig2', i)
        dd = ref.copy()
        for k in range(dd.shape[0]):
            dd[k, :] = roll(ref[k, :], int(k * val))
        self._set_ch('folded_data_orig', i, dd)

    def norm_fig(self):
        """Redraw images, cuts and row means with the current colour mode and limits."""
        nch = len(self.filenames)
        lim_min, lim_max = self._mean_limits()
        self.axhh.clear()
        for i in range(nch):
            data = self._ch('folded_data', i)
            ax = self._ch('ax', i)
            axh = self._ch('axh', i)
            if self.NORM:
                lo, hi = data.min(), data.max()
            else:
                lo, hi = self.vmin, self.vmax
            ax.clear()
            self._set_ch('im', i, ax.imshow(data, cmap=self.CMAP[0], interpolation='nearest',
                                            aspect='auto', origin='lower', vmin=lo, vmax=hi))
            axh.clear()
            hline, = axh.plot(data[self.Y0, :])
            self._set_ch('hline', i, hline)
            axh.set_ylim(lo, hi)
            axh.set_xlim(0, len(data[0]))
            means = data.mean(1)
            hhline, = self.axhh.plot(means, arange(len(means)), self.COLORS[i])
            self._set_ch('hhline', i, hhline)
        if self.NORM:
            self.axhh.set_ylim(0, len(self.folded_data0.mean(1)))
            self.axhh.set_xlim(lim_min - 1, lim_max + 1)
        else:
            self.axhh.set_ylim(0, self.max_index - self.index - 1)
            self.axhh.set_xlim(self.vmin, self.vmax)
        self.fig.canvas.draw()

    ### BEGIN Slider actions ###
    def update_shear(self, val):
        """'Shear' slider callback: store the rounded value and rebuild/redraw."""
        self.shear = round(self.shear_slider.val, 2)
        # update_tab re-runs update_tabs (which applies the shear), norm_fig and draw
        self.update_tab(0)

    def update_cut(self):
        """Refresh every channel's horizontal cut at row ``Y0``."""
        for i in range(len(self.filenames)):
            self._ch('hline', i).set_ydata(self._ch('folded_data', i)[self.Y0, :])
        self.fig.canvas.draw()
    ### END Slider actions ###

    ### BEGIN actions to the window ###
    def toggle_update(self):
        """Start/stop playback and recolour the status dot (green = playing)."""
        self.UPDATE = not self.UPDATE
        if self.UPDATE:
            gobject.idle_add(self.update_plot)
        self.color = not self.color
        self.patch.remove()
        # Reuse the status axes: recent Matplotlib versions create a new Axes
        # on every Figure.add_axes call, which would stack one per toggle.
        mpl.pyplot.sca(self.axe_toggledisplay)
        self.axe_toggledisplay.clear()
        self.plot_circle(0, 0, 2, fc='#00FF7F' if self.color else '#FF4500')
        mpl.pyplot.axis('off')
        self.fig.canvas.draw()

    def keypress(self, event):
        """Keyboard shortcuts.

        * q: exit
        * c: next colormap
        * v: toggle NORM
        * space: play/pause
        * w / x: copy channel-1 trim values to channel 2 / 3
        * t: re-trigger view
        """
        key = event.key
        nch = len(self.filenames)
        if key == 'q':  # eXit
            sys.exit()
        elif key == 'c':
            self.CMAP = roll(self.CMAP, -1)
            self.norm_fig()  # norm_fig redraws the canvas
        elif key == 'v':
            self.NORM = not self.NORM
            self.norm_fig()
        elif key == ' ':  # play/pause
            self.toggle_update()
        elif key == 'w':
            if nch >= 2:
                print('Set REMOVE values of channel 1 to channel 2')
                self.remove_len11_slider.set_val(self.remove_len10)
                self.remove_len21_slider.set_val(self.remove_len20)
        elif key == 'x':
            if nch >= 3:
                print('Set REMOVE values of channel 1 to channel 3')
                self.remove_len12_slider.set_val(self.remove_len10)
                self.remove_len22_slider.set_val(self.remove_len20)
        elif key == 't':
            print('Trying to smooth from index %s' % self.index)
            self.smooth_array()
            print('Done MF')
        else:
            print('Key ' + str(event.key) + ' not known')

    def mousemove(self, event):
        """Motion callback: move the horizontal cut to the row under the mouse."""
        channel_axes = [self._ch('ax', i) for i in range(len(self.filenames))]
        if not any(event.inaxes is ax for ax in channel_axes):
            return
        self.X0 = int(round(event.xdata, 0))
        # clamp: rounding at the image edges can land one row outside the data
        last_row = self.folded_data0.shape[0] - 1
        self.Y0 = min(max(int(round(event.ydata, 0)), 0), last_row)
        self.update_cut()
    ### END actions to the window ###

    ### Divers useful functions ###
    def plot_circle(self, x, y, r, fc='r'):
        """Plot a circle of radius r at position x,y"""
        cir = mpl.patches.Circle((x, y), radius=r, fc=fc)
        self.patch = mpl.pyplot.gca().add_patch(cir)

    def smooth_array(self):
        """Show the re-triggered displayed rows in figure 5.

        Requires at least two channels and paused playback. The last channel
        is used as the trigger:

        * with 2 channels, channel 0 is shown;
        * with 3 channels, channels 0 and 1 are shown.
        """
        nch = len(self.filenames)
        if nch < 2 or self.UPDATE:
            print('This function REQUIRES a trigger')
            return

        def draw(ax, img):
            # image with colour limits taken from its own min/max
            ax.clear()
            ax.imshow(img, cmap=self.CMAP[0], interpolation='nearest', aspect='auto',
                      origin='lower', vmin=img.min(), vmax=img.max())

        self.fig2 = figure(5, figsize=(16, 7))
        clf()
        for i in range(nch):
            self._set_ch('folded_data', i,
                         self._ch('folded_data_orig3', i)[self.index:(self.index + self.NMAX)])
        if nch == 2:
            self.folded_data_retriggered1, self.folded_data_retriggered0 = \
                self.trig_by_interpolation_pola(self.folded_data1, self.folded_data0)
            self.fig2ax0 = self.fig2.add_subplot(111)
            draw(self.fig2ax0, self.folded_data_retriggered0)
        elif nch == 3:
            self.folded_data_retriggered2, self.folded_data_retriggered0, self.folded_data_retriggered1 = \
                self.trig_by_interpolation_pola(self.folded_data2, self.folded_data0, self.folded_data1)
            self.fig2ax1 = self.fig2.add_subplot(121)
            draw(self.fig2ax1, self.folded_data_retriggered0)
            self.fig2ax2 = self.fig2.add_subplot(122)
            draw(self.fig2ax2, self.folded_data_retriggered1)
        show(block=False)
        self.fig2.canvas.draw()

    def trig_by_interpolation_pola(self, data, pola1, pola2=None, thr=30, FACT=25, num_trig=0, DOWN=False):
        """Re-align every row on a threshold crossing found in *data*.

        Steps:

        1. Each row of *data*, *pola1* and (optionally) *pola2* is linearly
           interpolated onto ``FACT`` times more points.
        2. The ``num_trig``-th crossing of *thr* in the interpolated *data*
           row (downward if *DOWN*, upward otherwise) is that row's trigger.
        3. All rows are circularly shifted so their triggers line up with
           the earliest one.

        Rows whose evaluation or trigger lookup fails are skipped with a
        warning.

        Returns
        -------
        tuple of arrays
            ``(data, pola1, pola2)`` re-aligned, or ``(data, pola1)`` when
            *pola2* is None.
        """
        find_edges = self.find_down if DOWN else self.find_up
        with_p2 = pola2 is not None
        ll = []
        trace = []
        trace_p1 = []
        trace_p2 = []
        for i in range(len(data)):
            row, row_p1 = data[i], pola1[i]
            b = interpolate.interp1d(linspace(0, len(row), len(row)), row)
            bp1 = interpolate.interp1d(linspace(0, len(row_p1), len(row_p1)), row_p1)
            xnew2 = linspace(0, len(row), FACT * len(row))
            xnew = linspace(0, len(row_p1), FACT * len(row_p1))
            if with_p2:
                # only touch pola2 when it was given
                row_p2 = pola2[i]
                bp2 = interpolate.interp1d(linspace(0, len(row_p2), len(row_p2)), row_p2)
                xnew3 = linspace(0, len(row_p2), FACT * len(row_p2))
            try:
                temp2 = b(xnew2)
                temp2_p1 = bp1(xnew)
                if with_p2:
                    temp2_p2 = bp2(xnew3)
                temp = find_edges(temp2, thr)
                ll.append(temp[num_trig])  # If several downward event found for the trigger
                trace.append(temp2)
                trace_p1.append(temp2_p1)
                if with_p2:
                    trace_p2.append(temp2_p2)
            except Exception:
                print('%d WARNING: Error repering => skiping a line' % i)
        shifts = array(ll) - array(ll).min()
        lll = array([roll(t, -s) for t, s in zip(trace, shifts)])
        lllp1 = array([roll(t, -s) for t, s in zip(trace_p1, shifts)])
        if with_p2:
            lllp2 = array([roll(t, -s) for t, s in zip(trace_p2, shifts)])
            return (lll, lllp1, lllp2)
        return (lll, lllp1)

    @staticmethod
    def _rising_edges(mask):
        """Indices k (along axis 0) where *mask* is False at k and True at k+1."""
        return where(mask[1:] & ~mask[:-1])[0]

    def find_down(self, d, threshold):
        """Indices just before *d* drops below *threshold*."""
        return self._rising_edges(d < threshold)

    def find_up(self, d, threshold):
        """Indices just before *d* rises above *threshold*."""
        return self._rising_edges(d > threshold)

    def declare_axis_1channel(self):
        """Create image, cut and trim-slider axes for one channel."""
        self.ax0 = axes([0.125 + self.HORIZ_VAL, 0.25, 0.81, 0.62])
        self.axh0 = axes([0.125 + self.HORIZ_VAL, 0.05, 0.81, 0.15])
        self.remove_len10_sliderax = axes([0.125 + self.HORIZ_VAL, 0.91, 0.78, 0.02])
        self.remove_len20_sliderax = axes([0.125 + self.HORIZ_VAL, 0.88, 0.78, 0.02])

    def declare_axis_2channel(self):
        """Create image, cut and trim-slider axes for two channels side by side."""
        self.ax0 = axes([0.125 + self.HORIZ_VAL, 0.25, 0.395, 0.62])
        self.ax1 = axes([0.54 + self.HORIZ_VAL, 0.25, 0.395, 0.62])
        self.axh0 = axes([0.125 + self.HORIZ_VAL, 0.05, 0.395, 0.15])
        self.axh1 = axes([0.54 + self.HORIZ_VAL, 0.05, 0.395, 0.15])
        self.remove_len10_sliderax = axes([0.125 + self.HORIZ_VAL, 0.91, 0.37, 0.02])
        self.remove_len11_sliderax = axes([0.54 + self.HORIZ_VAL, 0.91, 0.37, 0.02])
        self.remove_len20_sliderax = axes([0.125 + self.HORIZ_VAL, 0.88, 0.37, 0.02])
        self.remove_len21_sliderax = axes([0.54 + self.HORIZ_VAL, 0.88, 0.37, 0.02])

    def declare_axis_3channel(self):
        """Create image, cut and trim-slider axes for three channels side by side."""
        self.ax0 = axes([0.125 + self.HORIZ_VAL, 0.25, 0.25, 0.62])
        self.ax1 = axes([0.405 + self.HORIZ_VAL, 0.25, 0.25, 0.62])
        self.ax2 = axes([0.685 + self.HORIZ_VAL, 0.25, 0.25, 0.62])
        self.axh0 = axes([0.125 + self.HORIZ_VAL, 0.05, 0.25, 0.15])
        self.axh1 = axes([0.405 + self.HORIZ_VAL, 0.05, 0.25, 0.15])
        self.axh2 = axes([0.685 + self.HORIZ_VAL, 0.05, 0.25, 0.15])
        self.remove_len10_sliderax = axes([0.125 + self.HORIZ_VAL, 0.91, 0.25, 0.02])
        self.remove_len11_sliderax = axes([0.405 + self.HORIZ_VAL, 0.91, 0.25, 0.02])
        self.remove_len12_sliderax = axes([0.685 + self.HORIZ_VAL, 0.91, 0.25, 0.02])
        self.remove_len20_sliderax = axes([0.125 + self.HORIZ_VAL, 0.88, 0.25, 0.02])
        self.remove_len21_sliderax = axes([0.405 + self.HORIZ_VAL, 0.88, 0.25, 0.02])
        self.remove_len22_sliderax = axes([0.685 + self.HORIZ_VAL, 0.88, 0.25, 0.02])
class timeseriesViewer():
    """Class for tsview.py

    Example:
        cmd = 'tsview.py timeseries_ERA5_ramp_demErr.h5'
        obj = timeseriesViewer(cmd)
        obj.configure()
        obj.plot()
    """

    def __init__(self, cmd=None, iargs=None):
        if cmd:
            iargs = cmd.split()[1:]
        self.cmd = cmd
        self.iargs = iargs
        # print command line (iargs may be None: join an empty list then)
        cmd = '{} '.format(os.path.basename(__file__))
        cmd += ' '.join(iargs or [])
        print(cmd)

        # figure variables
        self.figname_img = 'Cumulative Displacement Map'
        self.figsize_img = None
        self.fig_img = None
        self.ax_img = None
        self.cbar_img = None
        self.img = None

        self.ax_tslider = None
        self.tslider = None

        self.figname_pts = 'Point Displacement Time-series'
        self.figsize_pts = None
        self.fig_pts = None
        self.ax_pts = None
        return

    def configure(self):
        """Parse the arguments and copy every parsed option onto ``self``."""
        inps = cmd_line_parse(self.iargs)
        inps, self.atr = read_init_info(inps)
        # copy inps to self object
        for key, value in inps.__dict__.items():
            setattr(self, key, value)
        # input figsize for the point time-series plot
        self.figsize_pts = self.fig_size
        self.pts_marker = 'r^'
        self.pts_marker_size = 6.
        return

    def plot(self):
        """Build the map figure (with time slider) and the point time-series figure."""
        # read 3D time-series
        self.ts_data, self.mask = read_timeseries_data(self)[0:2]

        # Figure 1 - Cumulative Displacement Map
        self.fig_img = plt.figure(self.figname_img, figsize=self.figsize_img)

        # Figure 1 - Axes 1 - Displacement Map
        self.ax_img = self.fig_img.add_axes([0.125, 0.25, 0.75, 0.65])
        img_data = np.array(self.ts_data[0][self.idx, :, :])
        img_data[self.mask == 0] = np.nan
        self.plot_init_image(img_data)

        # Figure 1 - Axes 2 - Time Slider
        self.ax_tslider = self.fig_img.add_axes([0.2, 0.1, 0.6, 0.07])
        self.plot_init_time_slider(init_idx=self.idx, ref_idx=self.ref_idx)
        self.tslider.on_changed(self.update_time_slider)

        # Figure 2 - Time Series Displacement - Point
        self.fig_pts, self.ax_pts = plt.subplots(num=self.figname_pts, figsize=self.figsize_pts)
        if self.yx:
            d_ts = self.plot_point_timeseries(self.yx)
            # Output: d_ts only exists once a point has been plotted
            if self.save_fig:
                save_ts_plot(self.yx, self.fig_img, self.fig_pts, d_ts, self)

        # Final linking of the canvas to the plots.
        self.fig_img.canvas.mpl_connect('button_press_event', self.update_plot_timeseries)
        self.fig_img.canvas.mpl_connect('key_press_event', self.on_key_event)
        if self.disp_fig:
            vprint('showing ...')
            msg = '\n------------------------------------------------------------------------'
            msg += '\nTo scroll through the image sequence:'
            msg += '\n1) Move the slider, OR'
            msg += '\n2) Press left or right arrow key (if not responding, click the image and try again).'
            msg += '\n------------------------------------------------------------------------'
            vprint(msg)
            plt.show()
        return

    def plot_init_image(self, img_data):
        """Draw the first displacement map via view.plot_slice; return (image, colorbar)."""
        # prepare data
        if self.wrap:
            if self.disp_unit_img == 'radian':
                img_data *= self.range2phase
            img_data = ut.wrap(img_data, wrap_range=self.wrap_range)

        # Title and Axis Label
        disp_date = self.dates[self.idx].strftime('%Y-%m-%d')
        self.fig_title = 'N = {}, Time = {}'.format(self.idx, disp_date)

        # Initial Pixel of interest
        self.pts_yx = None
        self.pts_lalo = None
        if self.yx and self.yx != self.ref_yx:
            self.pts_yx = np.array(self.yx).reshape(-1, 2)
            if self.lalo:
                self.pts_lalo = np.array(self.lalo).reshape(-1, 2)

        # call view.py to plot
        self.img, self.cbar_img = view.plot_slice(self.ax_img, img_data, self.atr, self)[2:4]
        return self.img, self.cbar_img

    def plot_init_time_slider(self, init_idx=-1, ref_idx=0):
        """Create the 'Years' slider with a bar per acquisition and the reference date highlighted."""
        val_step = np.min(np.diff(self.yearList))
        val_min = self.yearList[0]
        val_max = self.yearList[-1]

        self.tslider = Slider(self.ax_tslider, label='Years',
                              valinit=self.yearList[init_idx],
                              valmin=val_min,
                              valmax=val_max,
                              valstep=val_step)

        bar_width = val_step / 4.
        datex = np.array(self.yearList) - bar_width / 2.
        self.tslider.ax.bar(datex, np.ones(len(datex)), bar_width, facecolor='black', ecolor=None)
        self.tslider.ax.bar(datex[ref_idx], 1., bar_width * 3, facecolor='crimson', ecolor=None)

        # xaxis tick format
        if np.floor(val_max) == np.floor(val_min):
            digit = 10.
        else:
            digit = 1.
        self.tslider.ax.set_xticks(np.round(np.linspace(val_min, val_max, num=5) * digit) / digit)
        self.tslider.ax.xaxis.set_minor_locator(MultipleLocator(1. / 12.))
        self.tslider.ax.set_xlim([val_min, val_max])
        self.tslider.ax.set_yticks([])
        return self.tslider

    def _update_image(self, idx):
        """Show acquisition *idx* on the map (title, masked/wrapped data) and redraw."""
        # update title
        disp_date = self.dates[idx].strftime('%Y-%m-%d')
        self.ax_img.set_title('N = {n}, Time = {t}'.format(n=idx, t=disp_date),
                              fontsize=self.font_size)

        # read data
        data_img = np.array(self.ts_data[0][idx, :, :])
        data_img[self.mask == 0] = np.nan
        if self.wrap:
            if self.disp_unit_img == 'radian':
                data_img *= self.range2phase
            data_img = ut.wrap(data_img, wrap_range=self.wrap_range)

        # update data
        self.img.set_data(data_img)
        self.idx = idx
        self.fig_img.canvas.draw()

    def update_time_slider(self, val):
        """Update Displacement Map using Slider"""
        idx = np.argmin(np.abs(np.array(self.yearList) - self.tslider.val))
        self._update_image(idx)
        return

    def plot_point_timeseries(self, yx):
        """Plot point displacement time-series at pixel [y, x]

        Parameters: yx   : list of 2 int
        Returns:    d_ts : list of 1D np.array, one per file, ordered from the
                           last file to the first; ``zero_first`` is applied,
                           the display ``offset`` is not.
        """
        self.ax_pts.cla()

        # plot scatter in different size for different files
        num_file = len(self.ts_data)
        if num_file <= 2:
            ms_step = 4
        elif num_file == 3:
            ms_step = 3
        elif num_file == 4:
            ms_step = 2
        else:
            ms_step = 1

        d_ts = []
        y = yx[0] - self.pix_box[1]
        x = yx[1] - self.pix_box[0]
        for i in range(num_file - 1, -1, -1):
            # get displacement data; np.array copies, so the shifts below never
            # write back into self.ts_data (which also feeds the map)
            d_tsi = np.array(self.ts_data[i][:, y, x])
            if self.zero_first:
                d_tsi -= d_tsi[self.zero_idx]
            d_ts.append(d_tsi)

            # get plot parameter - namespace ppar
            ppar = argparse.Namespace()
            ppar.label = self.file_label[i]
            ppar.ms = self.marker_size - ms_step * (num_file - 1 - i)
            ppar.mfc = pp.mplColors[num_file - 1 - i]
            if self.mask[y, x] == 0:
                ppar.mfc = 'gray'

            # the offset only separates the curves on screen
            d_plot = d_tsi + self.offset * (num_file - 1 - i) if self.offset else d_tsi

            # plot
            if not np.all(np.isnan(d_plot)):
                self.ax_pts = self.ts_plot_func(self.ax_pts, d_plot, self, ppar)

        # axis format
        self.ax_pts = _adjust_ts_axis(self.ax_pts, self)
        title_ts = _get_ts_title(yx[0], yx[1], self.coord)
        if self.mask[y, x] == 0:
            title_ts += ' (masked out)'
        if self.disp_title:
            self.ax_pts.set_title(title_ts, fontsize=self.font_size)
        if self.tick_right:
            self.ax_pts.yaxis.tick_right()
            self.ax_pts.yaxis.set_label_position("right")

        # legend
        if len(self.ts_data) > 1:
            self.ax_pts.legend()

        # Print to terminal
        vprint('\n---------------------------------------')
        vprint(title_ts)
        vprint([float('{:.2f}'.format(v)) for v in d_ts[0]])
        if not np.all(np.isnan(d_ts[0])):
            # stat info
            vprint('displacement range: [{:.2f}, {:.2f}] {}'.format(
                np.nanmin(d_ts[0]), np.nanmax(d_ts[0]), self.disp_unit))

            # estimate (print) slope
            estimate_slope(d_ts[0], self.yearList, ex_flag=self.ex_flag, disp_unit=self.disp_unit)

        # update figure
        self.fig_pts.canvas.draw()
        return d_ts

    def update_plot_timeseries(self, event):
        """Event function to get y/x from button press"""
        if event.inaxes == self.ax_img:
            # get row/col number
            if self.fig_coord == 'geo':
                y, x = self.coord.geo2radar(event.ydata, event.xdata, print_msg=False)[0:2]
            else:
                y, x = int(event.ydata + 0.5), int(event.xdata + 0.5)

            # plot time-series displacement
            self.plot_point_timeseries((y, x))
        return

    def on_key_event(self, event):
        """Slide images with left/right key on keyboard"""
        if not (event.inaxes and event.inaxes.figure == self.fig_img):
            return
        if event.key == 'left':
            idx = max(self.idx - 1, 0)
        elif event.key == 'right':
            idx = min(self.idx + 1, self.num_date - 1)
        else:
            return
        if idx == self.idx:
            return
        # moving the slider fires update_time_slider, which redraws the map
        self.tslider.set_val(self.yearList[idx])
        # that callback picks the nearest year; make sure frame idx is shown
        if self.idx != idx:
            self._update_image(idx)
        return
def profileview(model):
    """Open an interactive profile viewer for *model*.

    Layout:

    * Three stacked, x-shared axes show the last protonic-potential, oxygen
      molar-fraction and current-density profiles against the distance from
      the membrane.
    * The right-hand axes accumulate one polarization-curve point per call
      of ``model.resolve``.

    Interaction:

    * The 'V' slider (0 to 1.2) triggers a new solve at that voltage.
    * Clicking inside the polarization-curve axes moves the slider to the
      clicked y value.

    On open, the curve is pre-sampled at 10 voltages between 0.05 and 1.0,
    and the slider is then left at 0.65.
    """
    plt.rc('font', size=8)
    fig, axs = plt.subplots(3, sharex=True)
    panel_titles = (
        'protonic potential (V)',
        'oxygen molar fraction',
        'current density (A/m2)',
    )
    for panel, title in zip(axs, panel_titles):
        panel.set_title(title)
    axs[2].set_xlabel('distance from membrane (um)')
    fig.subplots_adjust(right=0.45)

    # drop gdl and membrane points (presumably where boundary values live)
    interior = ~(model.gdl | model.membrane)
    # distance from the membrane wall, scaled by 1e6 (label says um)
    dist_um = model.distance_from_membrane[interior] * 1E6

    # one polarization-curve axes, kept across all solves
    pcax = fig.add_axes([.5, .2, .4, .7])
    pcax.set_title('polarization curve')
    pcax.yaxis.tick_right()
    pcax.yaxis.set_label_position("right")
    pcax.set_xlabel('geometric current density (A/cm**2)')
    pcax.set_ylabel('voltage at gdl (V)')
    polcurve, = pcax.plot([], [], 'ko-')

    def update(V):
        model.resolve(V, 263, flood=False)
        local_current = model.current_history[-1][interior]
        total_current = local_current.sum() / model.face_area / 1000
        insert_point((total_current, V), polcurve)
        pcax.relim()
        pcax.autoscale_view()

        latest_histories = (
            model.proton_history,
            model.oxygen_history,
            model.current_history,
        )
        for panel, history in zip(axs, latest_histories):
            update_ax(panel, dist_um, history[-1])
        fig.canvas.draw()

    slider_ax = fig.add_axes([.5, .05, .4, .03])
    slider = Slider(label='V', ax=slider_ax, valmin=0, valmax=1.2)
    slider.on_changed(update)

    # coarse initial sampling of the polarization curve
    for voltage in np.linspace(0.05, 1.0, 10):
        slider.set_val(voltage)
    slider.set_val(0.65)

    def on_click(event):
        # only clicks on the polarization curve move the slider
        if event.inaxes is pcax:
            slider.set_val(event.ydata)

    fig.canvas.mpl_connect('button_press_event', on_click)
    plt.show()
class Hist4D(object):
    """Interactive 3-D view of a 3-D histogram, one coloured cube per non-empty bin.

    Each cube is coloured with ``self.colormap`` evaluated at the bin count
    divided by the largest count.  Two sliders ("low"/"high") hide every cube
    whose count lies outside ``(low, high]``.
    """

    def __init__(self, save_only=False):
        '''
        save_only : bool
            When True, ``draw`` skips creating the interactive sliders.
        '''
        self.fig = None
        # bin count -> tuple of the artists drawn for bins with that count
        self.cubes_info = None
        self.slow = None  # "low" threshold Slider
        self.shigh = None  # "high" threshold Slider
        self.colormap = None
        self.save_only = save_only

    def draw_cubes(self, _axes, vals, edges):
        '''
        Draw one cube per non-empty histogram bin and index the artists by count.

        _axes : Axes3D handle
        vals : array of shape (L, M, N), counts returned by histogramdd
        edges : three 1-D arrays of lengths L+1, M+1, N+1 (histogramdd edges)

        Fills ``self.cubes_info`` (bin count -> tuple of artists); it is left
        empty when no bin has a positive count.
        '''
        vals = np.asarray(vals)
        # 'ij' indexing keeps the edge grids in the same (x, y, z) axis order
        # as ``vals``; the default 'xy' indexing swaps the first two axes and
        # placed cubes at transposed/misaligned x-y positions.
        lower = np.meshgrid(edges[0][:-1], edges[1][:-1], edges[2][:-1],
                            indexing='ij')
        upper = np.meshgrid(edges[0][1:], edges[1][1:], edges[2][1:],
                            indexing='ij')
        nonzero = vals > 0
        counts = vals[nonzero]
        x1, y1, z1 = [grid[nonzero] for grid in lower]
        x2, y2, z2 = [grid[nonzero] for grid in upper]

        self.cubes_info = dict()
        if counts.size == 0:
            return
        peak = float(np.max(counts))
        # A plain loop rather than np.vectorize: without ``otypes`` vectorize
        # calls the function an extra time on the first element, which drew an
        # untracked duplicate cube that the sliders could never hide.
        for idx in range(counts.size):
            handles = self.draw_cube(_axes, x1[idx], x2[idx], y1[idx], y2[idx],
                                     z1[idx], z2[idx], counts[idx] / peak)
            key = counts[idx]
            self.cubes_info[key] = self.cubes_info.get(key, ()) + tuple(handles)

    def set_sliders(self, splot1, splot2):
        '''
        Create the "low" (on ``splot1``) and "high" (on ``splot2``) count sliders.

        Both span [0, largest bin count] and call ``update`` when moved.  Does
        nothing when no cube has been drawn.
        '''
        if not self.cubes_info:
            return
        maxlim = max(self.cubes_info.keys())
        self.slow = Slider(splot1, 'low', 0.0, maxlim, valfmt='%0.0f')
        self.shigh = Slider(splot2, 'high', 0.0, maxlim, valfmt='%0.0f')
        self.slow.on_changed(self.update)
        self.shigh.on_changed(self.update)
        # set_val notifies the on_changed observers, i.e. runs ``update``
        self.slow.set_val(0)
        self.shigh.set_val(maxlim)

    def update(self, val):
        '''
        Slider callback: show cubes whose count is in (low, high], hide the rest.

        val : float
            New slider value; unused, both slider values are re-read.
        '''
        if not self.cubes_info:
            return
        low, high = self.slow.val, self.shigh.val
        peak = float(max(self.cubes_info))
        for count, handles in self.cubes_info.items():
            if low < count <= high:
                # Restore the alpha the colormap gave this cube when drawn;
                # the raw count is not a valid alpha (must lie in [0, 1]).
                alpha = self.colormap(float(count) / peak)[3]
            else:
                alpha = 0
            for item in handles:
                item.set_alpha(alpha)
        self.fig.canvas.draw_idle()

    def draw_cube(self, _axes, x1_coord, x2_coord, y1_coord, y2_coord,
                  z1_coord, z2_coord, color_ind):
        '''
        Draw an axis-aligned cube as six surfaces and return them as a list.

        x1_coord .. z2_coord : cube limits along each axis
        color_ind : position in ``self.colormap`` used as the face colour
        '''
        _x_coord, _y_coord, _z_coord = np.meshgrid([x1_coord, x2_coord],
                                                   [y1_coord, y2_coord],
                                                   [z1_coord, z2_coord])
        # 3 x 8 array: the cube's 8 corners as columns.  Columns 0-3 are the
        # y1 face and 4-7 the y2 face.
        tmp1 = np.concatenate((_x_coord.ravel()[None, :],
                               _y_coord.ravel()[None, :],
                               _z_coord.ravel()[None, :]), axis=0)
        # column swaps regroup the corners into the two x faces ...
        tmp2 = tmp1.copy()
        tmp2[:, [0, 1]], tmp2[:, [6, 7]] = (tmp2[:, [6, 7]].copy(),
                                            tmp2[:, [0, 1]].copy())
        # ... and the two z faces
        tmp3 = tmp2.copy()
        tmp3[:, [0, 2]], tmp3[:, [5, 7]] = (tmp3[:, [5, 7]].copy(),
                                            tmp3[:, [0, 2]].copy())
        # 6 faces x 4 corners x 3 coordinates
        points = np.concatenate((tmp1, tmp2, tmp3), axis=1)
        points = points.T.reshape(6, 4, 3)
        face_color = self.colormap(float(color_ind))
        surf = []
        for count in range(6):
            surf.append(_axes.plot_surface(points[count, :, 0].reshape(2, 2),
                                           points[count, :, 1].reshape(2, 2),
                                           points[count, :, 2].reshape(2, 2),
                                           color=face_color,
                                           linewidth=0,
                                           antialiased=True,
                                           shade=False))
        return surf

    def array2cmap(self, X):
        '''
        Build a stepwise LinearSegmentedColormap from an (N, 4) RGBA array.

        Row i of X is held constant over [i/N, (i+1)/N] of the colormap input.
        '''
        X = np.asarray(X)
        N = X.shape[0]
        # positions 0, 1/N, 1/N, 2/N, 2/N, ..., 1: each interior boundary is
        # duplicated so that every row gets its own flat segment
        r = np.repeat(np.linspace(0., 1., N + 1), 2)[1:-1]
        cdict = {}
        for channel, name in enumerate(('red', 'green', 'blue', 'alpha')):
            level = np.repeat(X[:, channel], 2)
            cdict[name] = tuple((r[i], level[i], level[i])
                                for i in range(2 * N))
        return colors.LinearSegmentedColormap('my_colormap', cdict)

    def draw_colorbar(self, _axes, unique_vals=None, cax=None):
        '''
        Attach a colorbar for ``self.colormap`` to ``_axes`` (or into ``cax``).

        A helper scatter serves as the colorbar's mappable; it is made fully
        transparent afterwards, and the axes limits are restored so it does
        not change the view.
        unique_vals : array-like, optional
            Values for the tick labels; defaults to 1000 points in [0, 1].
        '''
        if unique_vals is None:
            unique_vals = np.linspace(0, 1, 1000)
        xmin, xmax = _axes.get_xlim()
        ymin, ymax = _axes.get_ylim()
        zmin, zmax = _axes.get_zlim()
        invis = _axes.scatter(unique_vals, unique_vals,
                              c=np.arange(len(unique_vals)),
                              cmap=self.colormap)
        _axes.set_xlim([xmin, xmax])
        _axes.set_ylim([ymin, ymax])
        _axes.set_zlim([zmin, zmax])
        cbar = self.fig.colorbar(invis, ax=_axes, cax=cax, drawedges=False)
        cbar.set_ticks(np.linspace(0, np.size(unique_vals), 5))
        cbar.set_ticklabels(
            np.around(np.linspace(0, np.max(unique_vals), 5), 2))
        invis.set_alpha(0)

    @staticmethod
    def _validated_rgb(principal_rgb_color):
        '''Return the colour as a float array after checking it lies in [0, 1].'''
        rgb = np.asarray(principal_rgb_color, dtype=float)
        if np.any(rgb > 1) or np.any(rgb < 0):
            raise ValueError(
                'principal_rgb_color values should be in range [0,1]')
        return rgb

    def create_opacity_colormap(self, principal_rgb_color, scale_size=256):
        '''
        Set ``self.colormap`` to one RGB colour with alpha ramping 0 -> 1.

        Raises ValueError if any channel lies outside [0, 1].
        '''
        rgb = self._validated_rgb(principal_rgb_color)
        opac_colormap = np.concatenate(
            (np.tile(rgb[None, :], (scale_size, 1)),
             np.linspace(0, 1, scale_size)[:, None]), axis=1)
        self.colormap = self.array2cmap(opac_colormap)

    def create_brightness_colormap(self, principal_rgb_color, scale_size=256):
        '''
        Set ``self.colormap`` to an opaque ramp of one hue from dark to bright.

        Hue and saturation come from ``principal_rgb_color``; the HSV value
        ramps 0 -> 1.  Raises ValueError if any channel lies outside [0, 1].
        '''
        rgb = self._validated_rgb(principal_rgb_color)
        hsv_color = colors.rgb_to_hsv(rgb)
        hsv_colormap = np.concatenate(
            (np.tile(hsv_color[:-1][None, :], (scale_size, 1)),
             np.linspace(0, 1, scale_size)[:, None]), axis=1)
        rgb_colormap = colors.hsv_to_rgb(hsv_colormap)
        # array2cmap reads an alpha column; without it column 3 was missing
        # and the lookup raised IndexError.
        rgba_colormap = np.concatenate(
            (rgb_colormap, np.ones((scale_size, 1))), axis=1)
        self.colormap = self.array2cmap(rgba_colormap)

    def draw(self, hist, edges, fig=None, gs=None, subplot=None,
             color=None, all_axes=None):
        '''
        Render a 3-D histogram and, unless ``save_only``, attach the sliders.

        hist, edges : output of histogramdd for 3-D data
        fig : figure handle; a new figure is created when None
        gs : contiguous slice (or whole) of a gridspec hosting the plot;
            a 50x50 GridSpec is used when None
        subplot : unused, kept for backward compatibility
        color : RGB triple in [0, 1] for the opacity colormap; red by default
        all_axes : optional (3-D axes, colorbar axes, low-slider axes,
            high-slider axes) to reuse; each is cleared first
        Returns (3-D axes, low-slider axes, high-slider axes).
        '''
        if color is None:
            color = np.array([1, 0, 0])
        self.fig = fig if fig is not None else plt.figure()
        if gs is None:
            gs = gridspec.GridSpec(50, 50)
        if all_axes is None:
            _axes = self.fig.add_subplot(gs[:-5, :45], projection='3d')
            cax = self.fig.add_subplot(gs[:-5, 45:])
            ax1 = self.fig.add_subplot(gs[-4:-2, :])
            ax2 = self.fig.add_subplot(gs[-2:, :])
        else:
            _axes, cax, ax1, ax2 = all_axes
            for ax in (_axes, cax, ax1, ax2):
                ax.clear()
        self.create_opacity_colormap(color)
        self.draw_cubes(_axes, hist, edges)
        self.draw_colorbar(_axes, cax=cax)
        _axes.set_xlim((edges[0].min(), edges[0].max()))
        _axes.set_ylim((edges[1].min(), edges[1].max()))
        _axes.set_zlim((edges[2].min(), edges[2].max()))
        self.fig.patch.set_facecolor('white')
        # xaxis/yaxis/zaxis instead of the w_*axis aliases, which newer
        # matplotlib releases removed
        for axis in (_axes.xaxis, _axes.yaxis, _axes.zaxis):
            axis.set_pane_color((0.8, 0.8, 0.8, 1.0))
        if not self.save_only:
            self.set_sliders(ax1, ax2)
        return _axes, ax1, ax2